In [None]:

import pdb
import traceback
import numpy as np
import jax
from jax import numpy as jnp
import optax
import wandb
import equinox as eqx
from typing import Optional, Callable
import librosa

from common_dl_utils.config_creation import Config
import common_jax_utils as cju

wandb.login()

# Single global seed for the notebook; draw PRNG keys from `key_gen` (via next(key_gen))
# rather than creating unrelated keys elsewhere, so runs stay reproducible.
SEED = 12398
key = jax.random.PRNGKey(SEED)
key_gen = cju.key_generator(key)

In [None]:
# Cell [2] - Audio loading function
def load_audio_file(file_path, sr=16000):
    """
    Load an audio file and return it peak-normalized to the [-1, 1] range.

    Args:
        file_path: Path to the audio file.
        sr: Target sampling rate passed to ``librosa.load`` (default: 16000).

    Returns:
        tuple: (audio_array, fragment_length), where ``audio_array`` is a float32
        numpy array and ``fragment_length`` is ``len(audio_array)``.
        Silent (all-zero) or empty audio is returned unscaled rather than being
        divided by a zero peak, which would otherwise produce NaNs.
    """
    # Load (and resample to `sr`); the returned sample rate is not needed.
    audio, _ = librosa.load(file_path, sr=sr)

    audio = np.asarray(audio, dtype=np.float32)

    # Peak-normalize. Guard against empty input (np.max has no identity for a
    # zero-size array) and all-zero input (division by zero -> NaN).
    peak = np.max(np.abs(audio)) if audio.size else 0.0
    if peak > 0:
        audio = audio / peak

    return audio, len(audio)


In [None]:
# Cell [3] - Configuration setup
config = Config()

# Model architecture: directory holding the model modules, and the model class to build.
config.architecture = './model_components'
config.model_type = 'inr_modules.CombinedINR'

# Model hyper-parameters: one time coordinate in, one audio amplitude out.
config.model_config = Config()
config.model_config.in_size = 1
config.model_config.out_size = 1

# A single INR term: an MLP of FinerLayer layers with the FINER initialization scheme.
# Other terms (e.g. an MLP of 'inr_layers.SirenLayer' or 'inr_layers.GaussianINRLayer'
# layers) can be appended to this list as further (factory, kwargs) pairs.
config.model_config.terms = [
    (
        'inr_modules.MLPINR.new_from_config',
        dict(
            hidden_size=256,
            num_layers=5,
            layer_type='inr_layers.FinerLayer',
            num_splits=1,
            use_complex=False,
            activation_kwargs=dict(w0=30),
            initialization_scheme='initialization_schemes.finer_scheme',
            initialization_scheme_kwargs=dict(bias_k=10),
        ),
    ),
]

# Training entry point.
config.trainer_module = './inr_utils/'
config.trainer_type = 'training.train_inr'

# No target function is configured: neither the sampler nor the loss evaluator uses it.
# (Source audio for reference: './example_data/example.wav', 16 kHz.)

# Loss function and its time-/frequency-domain weights.
config.loss_function = 'losses.SoundLossEvaluator'
config.loss_function_config = dict(
    time_domain_weight=1.0,
    frequency_domain_weight=0.1,
)

# Sampler spec as a (name, kwargs) pair; fragment_length and sound_fragment are
# placeholders until the audio has been loaded.
config.sampler = (
    'sampling.SoundSampler',
    dict(
        window_size=1024,
        batch_size=32,
        fragment_length=None,  # Will be set after loading audio
        sound_fragment="path_to_audio_file.npy",  # TODO store audio file as npy and put path here
    ),
)


In [None]:
# Cell [4] - Load audio and update sampler config
# `config.target_function_config` is intentionally not set in the configuration cell,
# so the audio source is defined here instead of being read from the config.
AUDIO_FILE = './example_data/example.wav'
AUDIO_NPY_FILE = './example_data/example.npy'
AUDIO_SAMPLE_RATE = 16000

audio_data, fragment_length = load_audio_file(AUDIO_FILE, sr=AUDIO_SAMPLE_RATE)

# The sampler spec expects the audio stored as a .npy file (see its `sound_fragment` TODO).
np.save(AUDIO_NPY_FILE, audio_data)
config.sampler[1]['sound_fragment'] = AUDIO_NPY_FILE
config.sampler[1]['fragment_length'] = fragment_length

In [None]:
# Cell [7] - Training loop
# Each "epoch" here is a single optimisation step on one freshly sampled batch.
num_epochs = 1000
log_every = 10

# NOTE(review): `learning_rate`, `model`, `opt_state`, `train_step` and an instantiated
# `sampler` are not defined anywhere in this notebook (cells [5]/[6] appear to be missing).
# `config.sampler` is only a ('sampling.SoundSampler', kwargs) spec, so it has no `.sample`.
# Fail fast with an explicit message instead of a NameError/AttributeError mid-run.
_required_names = ('learning_rate', 'model', 'opt_state', 'train_step', 'sampler')
_missing_names = [name for name in _required_names if name not in globals()]
if _missing_names:
    raise NameError(
        f"Training loop needs {', '.join(_missing_names)} to be defined first "
        "(build the model, optimizer state, train_step and sampler before this cell)."
    )

# Draw from the notebook's seeded key generator instead of an unrelated PRNGKey(0).
key = next(key_gen)

# The `with` block finishes the wandb run on exit, so no separate wandb.finish() is needed.
with wandb.init(
    project="inr-audio",
    config={
        "window_size": config.sampler[1]['window_size'],
        "batch_size": config.sampler[1]['batch_size'],
        "learning_rate": learning_rate,
        "num_epochs": num_epochs
    }
) as run:
    for epoch in range(num_epochs):
        key, subkey = jax.random.split(key)
        # presumably `sampler.sample(key)` returns one training batch — verify against SoundSampler
        batch = sampler.sample(subkey)
        model, opt_state, loss = train_step(model, opt_state, batch)

        if epoch % log_every == 0:
            # Convert the (assumed scalar) jax array to a plain float for printing/logging.
            loss_value = float(loss)
            print(f"Epoch {epoch}, Loss: {loss_value}")
            run.log({
                "loss": loss_value,
                "epoch": epoch,
            })