In [7]:

import pdb
import os
import traceback
import numpy as np
import jax
from jax import numpy as jnp
import optax
import wandb
import equinox as eqx
from typing import Optional, Callable
import librosa

from common_dl_utils.config_creation import Config
import common_jax_utils as cju

wandb.login()

key = jax.random.PRNGKey(12398)
key_gen = cju.key_generator(key)

In [8]:
def load_audio_file(file_path, sr=16000, save_npy=True):
    """
    Load an audio file and return it as a normalized numpy array.
    Optionally save as .npy file.
    """
    # Load the audio file
    audio, _ = librosa.load(file_path, sr=sr)
    
    # Convert to numpy array and normalize to [-1, 1] range
    audio = np.array(audio, dtype=np.float32)
    audio = audio / np.max(np.abs(audio))

    if save_npy:
        # Create npy filename from original audio filename
        npy_path = os.path.splitext(file_path)[0] + '.npy'
        np.save(npy_path, audio)
        return audio, len(audio), npy_path
   
    return audio, len(audio), None

In [9]:
# First set up the config properly
config = Config()

# Model architecture configuration
config.architecture = './model_components'  # module containing model classes
config.model_type = 'inr_modules.CombinedINR'

# Model configuration
config.model_config = Config()
config.model_config.in_size = 1 # Time dimension input
config.model_config.out_size = 1  # Audio amplitude output
config.model_config.terms = [
    ('inr_modules.MLPINR.from_config',{
        'hidden_size': 256,
        'num_layers': 5,
        'layer_type': 'inr_layers.SirenLayer',
        'num_splits': 1,
        'use_complex': False,
        'activation_kwargs': {'w0': 30.},
        'initialization_scheme':'initialization_schemes.siren_scheme',
        'positional_encoding_layer': ('inr_layers.ClassicalPositionalEncoding.from_config', {'num_frequencies': 10}),
    }),
    # ('inr_modules.MLPINR.from_config',{
    #     'hidden_size': 1024,
    #     'num_layers': 2,
    #     'num_splits': 1,
    #     'layer_type': 'inr_layers.GaussianINRLayer',
    #     'use_complex': False,
    #     'activation_kwargs': {'inverse_scale': 1},
    # })
    # ('inr_modules.MLPINR.from_config',{
    #     'hidden_size': 256,
    #     'num_layers': 3,
    #     'layer_type': 'inr_layers.FinerLayer',
    #     'num_splits': 1,
    #     'use_complex': False,
    #     'activation_kwargs': {'w0': 20},
    #     'initialization_scheme':'initialization_schemes.finer_scheme',
    #     'initialization_scheme_kwargs': {'bias_k': 5,'scale_factor': 1}
    #     # 'initialization_scheme_k' : {'k': 20}
    #     #'positional_encoding_layer': ('inr_layers.ClassicalPositionalEncoding.from_config', {'num_frequencies': 10}),
    # })
]

# Training configuration
config.trainer_module = './inr_utils'  # module containing training code
config.trainer_type = 'training.train_inr'

# Loss function configuration
config.loss_evaluator = config.loss_function = 'losses.SoundLossEvaluator'
config.loss_evaluator_config = config.loss_function_config = {
    'time_domain_weight': 1.0,
    'frequency_domain_weight': 0.0001
}
# # Loss function configuration
# config.loss_evaluator = config.loss_function = 'losses.SoundLossEvaluator'
# config.loss_evaluator_config = config.loss_function_config = {
#     'time_domain_weight': 1.0,
#     'frequency_domain_weight': 0.001,
#     'initial_freq_weight': 0.0001,
#     'final_freq_weight': 0.001,
#     'transition_steps': 5000
# }


# Optimizer configuration
config.optimizer = 'adam'  # will use optax.adam
config.optimizer_config = {
    'learning_rate': 1e-4,  # Lower initial learning rate
    'b1': 0.8,
    'b2': 0.999999
}

# Load and save the audio file
audio_path = './example_data/data_gt_bach.wav'  # Make sure this path exists
audio_data, fragment_length, npy_path = load_audio_file(audio_path)

# Sampler configuration
config.sampler = ('sampling.SoundSampler', {
    'window_size': 256,
    'batch_size': 64,
    'fragment_length': fragment_length,  # Will be set after loading audio
    'sound_fragment': npy_path
})


# Metric collector configuration
config.metric_collector_config = {
    'metrics': [
        ('metrics.AudioMetricsOnGrid', {
            'target_audio': audio_data,
            'grid_size': fragment_length,
            'batch_size': 1024,
            'sr': 16000,
            'frequency': 'every_n_batches'
        })
    ],
    'batch_frequency': 100,
    'epoch_frequency': 1
}
# # Callback configuration for logging
# config.after_step_callback = 'callbacks.ComposedCallback'  # This line was missing
# config.after_step_callback_config = {
#     'callbacks': [
#         ('callbacks.print_loss', {'after_every': 10}),
#         'callbacks.report_loss',
#         'callbacks.raise_error_on_nan',
#         ('callbacks.AudioMetricsCallback', {
#             'metric_collector': ('metrics.MetricCollector', config.metric_collector_config),
#             'print_metrics': True,
#             'print_frequency': 100
#         })
#     ],
#     'use_wandb': True,  # This was missing
#     'show_logs': False
# }

# Callback configuration
config.after_step_callback = 'callbacks.ComposedCallback'  # This line is crucial
config.after_step_callback_config = {
    'callbacks': [
        ('callbacks.print_loss', {'after_every': 10}),
        'callbacks.report_loss',
        'callbacks.raise_error_on_nan',
        ('callbacks.AudioMetricsWithEarlyStoppingCallback', {
            'metric_collector': ('metrics.MetricCollector', config.metric_collector_config),
            'print_metrics': True,
            'print_frequency': 100,
            'patience': 100,
            'min_delta': 0.001,
            'monitor': 'audio_mse'
        })
    ],
    'use_wandb': True,
    'show_logs': False
}




config.after_training_callback = None
config.optimizer_state = None
config.steps = 20000
config.use_wandb = True


In [10]:
try:
    inr = cju.run_utils.get_model_from_config_and_key(
        prng_key=next(key_gen),
        config=config,
        model_sub_config_name_base='model',
        add_model_module_to_architecture_default_module=False,
    )
except Exception as e:
    traceback.print_exc()
    print(e)
    print('\n')
    pdb.post_mortem()


In [11]:
# Then get the experiment
experiment = cju.run_utils.get_experiment_from_config_and_key(
    prng_key=next(key_gen),
    config=config,
    model_kwarg_in_trainer='inr',
    model_sub_config_name_base='model',
    trainer_default_module_key='trainer_module',
    additional_trainer_default_modules=[optax],
    add_model_module_to_architecture_default_module=False,
    initialize=False
)

In [12]:
# Run the experiment with wandb logging
try: 
        
    with wandb.init(
        project='inr-audio',
        config={
            'window_size': config.sampler[1]['window_size'],
            'batch_size': config.sampler[1]['batch_size'],
            'learning_rate': config.optimizer_config['learning_rate'],
            'steps': config.steps,
            'audio_path': config.sampler[1]['sound_fragment']
        }
    ) as run:
        results = experiment.initialize()
        
except Exception as e:
    traceback.print_exc()
    print(e)
    print('\n')
    pdb.post_mortem()


Audio Metrics at step 0:
Loss at step 10 is 17.838665008544922.
Loss at step 20 is 11.21815299987793.
Loss at step 30 is 7.28759241104126.
Loss at step 40 is 8.382088661193848.
Loss at step 50 is 6.444572925567627.
Loss at step 60 is 7.390853404998779.
Loss at step 70 is 6.734798908233643.
Loss at step 80 is 5.958539009094238.
Loss at step 90 is 5.995924472808838.
Loss at step 100 is 6.307222366333008.

Audio Metrics at step 100:
Loss at step 110 is 7.3549933433532715.
Loss at step 120 is 6.299817085266113.
Loss at step 130 is 6.207917213439941.
Loss at step 140 is 7.054150104522705.
Loss at step 150 is 7.3905792236328125.
Loss at step 160 is 6.159975528717041.
Loss at step 170 is 6.5834269523620605.
Loss at step 180 is 5.606884002685547.
Loss at step 190 is 6.279910564422607.
Loss at step 200 is 8.416749000549316.

Audio Metrics at step 200:
Loss at step 210 is 10.155515670776367.
Loss at step 220 is 9.116251945495605.
Loss at step 230 is 8.2527437210083.
Loss at step 240 is 7.248920

Traceback (most recent call last):
  File "/tmp/ipykernel_43379/2668227620.py", line 14, in <module>
    results = experiment.initialize()
  File "/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/common_dl_utils/config_realization.py", line 195, in initialize
    return cls(**processed_self_kwargs)
  File "/home/abdtab/INR_BEP/inr_utils/training.py", line 194, in train_inr
    after_step_callback(step, loss, inr, state, optimizer_state)
  File "/home/abdtab/INR_BEP/inr_utils/callbacks.py", line 69, in __call__
    log = callback(step, loss, inr, state, optimizer_state)
  File "/home/abdtab/INR_BEP/inr_utils/callbacks.py", line 163, in __call__
    raise StopIteration("Early stopping triggered")
StopIteration: Early stopping triggered


0,1
audio_magnitude_error,█▇▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▂▁▂▂▂▂▂▁▂▂▁▂▁▁▁▁▁▁▁▁▁▁
audio_mse,██▇▇▇▇▄▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
audio_psnr,▁▁▁▂▃▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇█▇██▇██████████
audio_snr,▁▂▂▂▂▅▆▆▇▆▆▇▇▇▇▇▇▇▇▇███▇▇██▇█▇██████████
audio_spectral_convergence,█▇▇▇▆▄▂▂▂▂▂▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
batch_within_epoch,▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇████
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▇▆▄▄▃▂▃▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
audio_magnitude_error,0.3466
audio_mse,0.00171
audio_psnr,27.66872
audio_snr,11.49891
audio_spectral_convergence,0.2185
batch_within_epoch,17899.0
epoch,1.0
loss,0.2812


Traceback (most recent call last):
  File "/tmp/ipykernel_43379/2668227620.py", line 14, in <module>
    results = experiment.initialize()
  File "/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/common_dl_utils/config_realization.py", line 195, in initialize
    return cls(**processed_self_kwargs)
  File "/home/abdtab/INR_BEP/inr_utils/training.py", line 194, in train_inr
    after_step_callback(step, loss, inr, state, optimizer_state)
  File "/home/abdtab/INR_BEP/inr_utils/callbacks.py", line 69, in __call__
    log = callback(step, loss, inr, state, optimizer_state)
  File "/home/abdtab/INR_BEP/inr_utils/callbacks.py", line 163, in __call__
    raise StopIteration("Early stopping triggered")
StopIteration: Early stopping triggered


Early stopping triggered


> [0;32m/home/abdtab/INR_BEP/inr_utils/callbacks.py[0m(163)[0;36m__call__[0;34m()[0m
[0;32m    161 [0;31m                [0mself[0m[0;34m.[0m[0mwait[0m [0;34m+=[0m [0;36m1[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    162 [0;31m                [0;32mif[0m [0mself[0m[0;34m.[0m[0mwait[0m [0;34m>=[0m [0mself[0m[0;34m.[0m[0mpatience[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 163 [0;31m                    [0;32mraise[0m [0mStopIteration[0m[0;34m([0m[0;34m"Early stopping triggered"[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    164 [0;31m[0;34m[0m[0m
[0m[0;32m    165 [0;31m        [0;32mreturn[0m [0mmetrics[0m[0;34m[0m[0;34m[0m[0m
[0m
