In [7]:

import pdb
import os
import traceback
import numpy as np
import jax
import optax
import wandb
import librosa

from common_dl_utils.config_creation import Config
import common_jax_utils as cju

# Log in to Weights & Biases up front, before any run is created.
wandb.login()

# Root PRNG key for the notebook. `key_gen` is consumed with `next(key_gen)`
# wherever a fresh key is required further down.
SEED = 12398
key = jax.random.PRNGKey(SEED)
key_gen = cju.key_generator(key)

In [8]:
def load_audio_file(file_path, sr=16000, save_npy=True):
    """
    Load an audio file and return it as a peak-normalized numpy array.
    Optionally save the normalized samples as a .npy file.

    Parameters
    ----------
    file_path : str
        Path of the audio file; handed to ``librosa.load``.
    sr : int, optional
        Sample rate handed to ``librosa.load`` (default 16000).
    save_npy : bool, optional
        When True, the normalized samples are written with ``np.save`` to
        ``file_path`` with its extension replaced by ``.npy``.

    Returns
    -------
    tuple
        ``(audio, num_samples, npy_path)``: the float32 sample array, its
        length, and the path of the saved ``.npy`` file (``None`` when
        ``save_npy`` is False).
    """
    # Load the audio file and make sure we hold a float32 numpy array
    audio, _ = librosa.load(file_path, sr=sr)
    audio = np.array(audio, dtype=np.float32)

    # Scale so the largest absolute sample becomes 1. A silent (all-zero) clip
    # has a peak of 0 and dividing by it would fill the array with NaNs; an
    # empty clip would make np.max raise. Both are returned unscaled instead.
    peak = np.max(np.abs(audio)) if audio.size else 0.0
    if peak > 0:
        audio = audio / peak

    npy_path = None
    if save_npy:
        # Create npy filename from original audio filename
        npy_path = os.path.splitext(file_path)[0] + '.npy'
        np.save(npy_path, audio)

    return audio, len(audio), npy_path

In [9]:
# Root configuration object; every section below is attached to it.
config = Config()

# --- Model architecture ---
config.architecture = './model_components'  # module containing model classes
config.model_type = 'inr_modules.CombinedINR'

# --- Model hyperparameters ---
config.model_config = Config()
config.model_config.in_size = 1   # time coordinate in
config.model_config.out_size = 1  # audio amplitude out
config.model_config.terms = [
    # Active term: SIREN MLP with a classical positional encoding in front.
    (
        'inr_modules.MLPINR.from_config',
        dict(
            hidden_size=256,
            num_layers=5,
            layer_type='inr_layers.SirenLayer',
            num_splits=1,
            use_complex=False,
            activation_kwargs=dict(w0=30.),
            initialization_scheme='initialization_schemes.siren_scheme',
            positional_encoding_layer=(
                'inr_layers.ClassicalPositionalEncoding.from_config',
                dict(num_frequencies=10),
            ),
        ),
    ),
    # Alternative term (disabled): Gaussian INR layers.
    # ('inr_modules.MLPINR.from_config', {
    #     'hidden_size': 1024,
    #     'num_layers': 2,
    #     'num_splits': 1,
    #     'layer_type': 'inr_layers.GaussianINRLayer',
    #     'use_complex': False,
    #     'activation_kwargs': {'inverse_scale': 1},
    # })
    # Alternative term (disabled): FINER layers.
    # ('inr_modules.MLPINR.from_config', {
    #     'hidden_size': 256,
    #     'num_layers': 5,
    #     'layer_type': 'inr_layers.FinerLayer',
    #     'num_splits': 1,
    #     'use_complex': False,
    #     'activation_kwargs': {'w0': 30},
    #     'initialization_scheme': 'initialization_schemes.finer_scheme',
    #     'initialization_scheme_kwargs': {'bias_k': 10, 'scale_factor': 10}
    #     # 'initialization_scheme_k' : {'k': 20}
    #     # 'positional_encoding_layer': ('inr_layers.ClassicalPositionalEncoding.from_config', {'num_frequencies': 10}),
    # })
]

# --- Training entry point ---
config.trainer_module = './inr_utils'  # module containing training code
config.trainer_type = 'training.train_inr'

# --- Loss ---
# The same evaluator path and the very same kwargs dict object are registered
# under both the `loss_evaluator*` and the `loss_function*` names.
loss_target = 'losses.SoundLossEvaluator'
config.loss_evaluator = loss_target
config.loss_function = loss_target

loss_kwargs = dict(time_domain_weight=1.0, frequency_domain_weight=0.1)
config.loss_evaluator_config = loss_kwargs
config.loss_function_config = loss_kwargs

# --- Optimizer ---
config.optimizer = 'adam'  # will use optax.adam
config.optimizer_config = dict(learning_rate=1e-3)

# Single source of truth for the sample rate: it is used both to load the
# audio and to configure the audio metrics, so the two cannot drift apart.
SAMPLE_RATE = 16000

# Load the audio file. With its default save_npy=True, load_audio_file also
# writes an .npy copy and returns that path, which is what the sampler is given.
audio_path = './example_data/data_gt_bach.wav'  # Make sure this path exists
audio_data, fragment_length, npy_path = load_audio_file(audio_path, sr=SAMPLE_RATE)

# Sampler configuration
config.sampler = ('sampling.SoundSampler', {
    'window_size': 1024,
    'batch_size': 32,
    'fragment_length': fragment_length,  # number of samples in the loaded audio
    'sound_fragment': npy_path
})


# Metric collector configuration
config.metric_collector_config = {
    'metrics': [
        ('metrics.AudioMetricsOnGrid', {
            'target_audio': audio_data,
            'grid_size': fragment_length,
            'batch_size': 1024,  # original author's note: adjusted automatically if needed (not verified here)
            'sr': SAMPLE_RATE,
            'frequency': 'every_n_batches'
        })
    ],
    'batch_frequency': 100,
    'epoch_frequency': 1
}
# --- After-step callbacks ---
# The composed callback is configured with the list below: loss printing,
# loss reporting, a NaN check, and the audio metrics callback.
config.after_step_callback = 'callbacks.ComposedCallback'

audio_metrics_callback = (
    'callbacks.AudioMetricsCallback',
    dict(
        metric_collector=('metrics.MetricCollector', config.metric_collector_config),
        print_metrics=True,
        print_frequency=100,
    ),
)
config.after_step_callback_config = dict(
    callbacks=[
        ('callbacks.print_loss', dict(after_every=10)),
        'callbacks.report_loss',
        'callbacks.raise_error_on_nan',
        audio_metrics_callback,
    ],
    use_wandb=True,
    show_logs=False,
)

# --- Remaining run settings ---
config.after_training_callback = None
config.optimizer_state = None
config.steps = 40000
config.use_wandb = True


In [10]:
# Build the model from the config with a fresh PRNG key.
#
# Failures are deliberately NOT caught here: the previous
# `except Exception: ... pdb.post_mortem()` swallowed the error (so later cells
# kept running on a broken setup) and blocked on interactive input, which
# breaks "Restart & Run All". Letting the exception propagate stops execution
# at this cell and shows the traceback; for post-mortem debugging, run
# `%debug` in a new cell right after the failure.
inr = cju.run_utils.get_model_from_config_and_key(
    prng_key=next(key_gen),
    config=config,
    model_sub_config_name_base='model',
    add_model_module_to_architecture_default_module=False,
)


In [11]:
# Assemble the experiment from the config. It is created with
# initialize=False, so nothing is run here; the training cell below calls
# `experiment.initialize()` explicitly.
experiment = cju.run_utils.get_experiment_from_config_and_key(
    config=config,
    prng_key=next(key_gen),
    # model wiring
    model_sub_config_name_base='model',
    model_kwarg_in_trainer='inr',
    add_model_module_to_architecture_default_module=False,
    # trainer wiring
    trainer_default_module_key='trainer_module',
    additional_trainer_default_modules=[optax],
    # defer the actual run
    initialize=False,
)

In [12]:
# Run the experiment with wandb logging.
#
# Exceptions are left to propagate instead of being caught and handed to
# `pdb.post_mortem()`: the old handler swallowed the failure and blocked on
# interactive input. The `with` block still closes the W&B run when an error
# escapes, and the notebook shows the traceback; use `%debug` in a new cell
# for post-mortem debugging.
sampler_kwargs = config.sampler[1]
with wandb.init(
    project='inr-audio',
    config={
        'window_size': sampler_kwargs['window_size'],
        'batch_size': sampler_kwargs['batch_size'],
        'learning_rate': config.optimizer_config['learning_rate'],
        'steps': config.steps,
        # path of the sampler's sound fragment (the saved .npy file)
        'audio_path': sampler_kwargs['sound_fragment']
    }
) as run:
    # Initializing the experiment is what invokes the configured trainer.
    results = experiment.initialize()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.



Audio Metrics at step 0:
Loss at step 10 is 49110.19921875.
Loss at step 20 is 30321.6953125.
Loss at step 30 is 25111.845703125.
Loss at step 40 is 14980.5439453125.
Loss at step 50 is 8722.50390625.
Loss at step 60 is 5391.98779296875.
Loss at step 70 is 4698.75244140625.
Loss at step 80 is 2995.6171875.
Loss at step 90 is 4098.02490234375.
Loss at step 100 is 3059.082763671875.

Audio Metrics at step 100:
Loss at step 110 is 2348.799072265625.
Loss at step 120 is 1561.2381591796875.
Loss at step 130 is 1853.829833984375.
Loss at step 140 is 1735.5042724609375.
Loss at step 150 is 1279.7247314453125.
Loss at step 160 is 2193.74609375.
Loss at step 170 is 2152.640625.
Loss at step 180 is 1535.1761474609375.
Loss at step 190 is 1446.7889404296875.
Loss at step 200 is 1754.449951171875.

Audio Metrics at step 200:
Loss at step 210 is 2057.173828125.
Loss at step 220 is 1702.3385009765625.
Loss at step 230 is 1441.05419921875.
Loss at step 240 is 1375.185546875.
Loss at step 250 is 1154

Traceback (most recent call last):
  File "/tmp/ipykernel_8368/2655119114.py", line 14, in <module>
    results = experiment.initialize()
  File "/home/ovindar/PycharmProjects/INR_BEP/.venv/lib/python3.10/site-packages/common_dl_utils/config_realization.py", line 195, in initialize
    return cls(**processed_self_kwargs)
  File "/home/ovindar/PycharmProjects/INR_BEP/inr_utils/training.py", line 212, in train_inr
    after_step_callback(step, loss, inr, state, optimizer_state)
  File "/home/ovindar/PycharmProjects/INR_BEP/inr_utils/callbacks.py", line 69, in __call__
    log = callback(step, loss, inr, state, optimizer_state)
  File "/home/ovindar/PycharmProjects/INR_BEP/inr_utils/callbacks.py", line 105, in __call__
    metrics = self.metric_collector.on_batch_end(
  File "/home/ovindar/PycharmProjects/INR_BEP/.venv/lib/python3.10/site-packages/common_dl_utils/metrics.py", line 110, in on_batch_end
    results.update(metric.compute(**kwargs))
  File "/home/ovindar/PycharmProjects/INR_B

0,1
audio_magnitude_error,█▃▃▁
audio_mse,█▄▄▁
audio_psnr,▁▅▅█
audio_snr,▁▅▅█
audio_spectral_convergence,█▃▂▁
batch_within_epoch,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇██████
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▇▆▅▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
audio_magnitude_error,0.91537
audio_mse,0.03591
audio_psnr,14.44835
audio_snr,-1.72146
audio_spectral_convergence,0.77428
batch_within_epoch,499.0
epoch,1.0
loss,850.29407


KeyboardInterrupt: 