In [1]:
import pdb
import traceback

import jax
from jax import numpy as jnp
import optax
import wandb

from common_dl_utils.config_creation import Config
import common_jax_utils as cju

wandb.login()

key = jax.random.PRNGKey(12398)
key_gen = cju.key_generator(key)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mabdtab[0m ([33mabdtab-tue[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
# Step 1: Data Preparation for Audio Modality
# Import necessary libraries
import numpy as np
import jax
import jax.numpy as jnp
from scipy.io import wavfile

file_path = "C:/Users/abdel/Downloads/Salem_Arabiano.wav"

# Function to load and preprocess audio data
def load_audio(file_path, sample_rate=16000):
    """
    Load audio file and resample to the desired sample rate.

    Args:
        file_path (str): Path to the audio file.
        sample_rate (int): Desired sample rate.

    Returns:
        jnp.ndarray: Normalized audio signal.
        float: Duration of the audio in seconds.
    """
    original_rate, audio = wavfile.read(file_path)

    if len(audio.shape) > 1:
        # Convert to mono if stereo
        audio = np.mean(audio, axis=1)

    # Resample if necessary
    if original_rate != sample_rate:
        duration = len(audio) / original_rate
        time_old = np.linspace(0, duration, len(audio))
        time_new = np.linspace(0, duration, int(sample_rate * duration))
        audio = np.interp(time_new, time_old, audio)

    # Normalize audio
    audio = audio / np.max(np.abs(audio))

    return jnp.array(audio), duration

# Function to sample audio data at specific time points
def sample_audio(audio, duration, num_samples=1000):
    """
    Sample the audio signal at specific time points.

    Args:
        audio (jnp.ndarray): Normalized audio signal.
        duration (float): Duration of the audio in seconds.
        num_samples (int): Number of samples.

    Returns:
        jnp.ndarray: Time points.
        jnp.ndarray: Corresponding audio values.
    """
    time_points = jnp.linspace(0, duration, num_samples)
    indices = (time_points * len(audio) / duration).astype(int)
    indices = jnp.clip(indices, 0, len(audio) - 1)  # Ensure indices are within bounds

    return time_points, audio[indices]

In [None]:
# Load the audio file
audio_data, fragment_length = load_audio_file(config.target_function_config['audio_file'])

# Create the sampler
window_size = 1024  # Adjust this based on your needs
batch_size = 32     # Adjust based on your memory constraints
sampler = SoundSampler(
    sound_fragment=audio_data,
    fragment_length=fragment_length,
    window_size=window_size,
    batch_size=batch_size
)

# Create the loss evaluator
loss_evaluator = SoundLossEvaluator(
    time_domain_weight=1.0,
    frequency_domain_weight=0.1  # Adjust these weights as needed
)

# Initialize the model (using your existing config)
model = cju.construct_model(config)

# Create optimizer
learning_rate = 1e-4
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(model)

# Training step function
@jax.jit
def train_step(model, opt_state, key):
    # Get batch of samples
    time_points, pressure_values = sampler(key)
    
    # Calculate loss and gradients
    def loss_fn(model):
        loss, _ = loss_evaluator(model, (time_points, pressure_values))
        return loss
    
    loss, grads = jax.value_and_grad(loss_fn)(model)
    
    # Update model
    updates, opt_state = optimizer.update(grads, opt_state)
    model = optax.apply_updates(model, updates)
    
    return model, opt_state, loss

# Training loop
num_epochs = 1000
key = jax.random.PRNGKey(0)

# Initialize wandb
wandb.init(
    project="inr-audio",
    config={
        "window_size": window_size,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        "num_epochs": num_epochs
    }
)

for epoch in range(num_epochs):
    key, subkey = jax.random.split(key)
    model, opt_state, loss = train_step(model, opt_state, subkey)
    
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss}")
        wandb.log({"loss": loss})

wandb.finish()

In [3]:
config = Config()

# Define the audio input (1D) modality
config.architecture = './model_components'  
config.model_type = 'inr_modules.CombinedINR'

config.model_config = Config()
config.model_config.in_size = 1  # Since audio is 1D
config.model_config.out_size = 1  # For output signals

config.model_config.terms = [
    ('inr_modules.MLPINR.from_config',{
        'hidden_size': 256,
        'num_layers': 5,
        'layer_type': 'inr_layers.FinerLayer',  
        'num_splits': 1,
        'use_complex': False,
        'activation_kwargs': {'w0': 30}, 
        'initialization_scheme':'initialization_schemes.finer_scheme',
        'initialization_scheme_kwargs':{'bias_k' : 10} 
    })
]

# Training setup for audio modality
config.trainer_module = './inr_utils/' 
config.trainer_type = 'training.train_inr'

# Set up for audio data input
config.target_function = 'audio.AudioSignal'  # Assuming you have an audio data handler
config.target_function_config = {
    'audio_file': 'C:/Users/abdel/Downloads/Salem_Arabiano.wav',  # Set the path to the audio file you want to use
    'scale_to_01': True,
}

config.loss_function = 'losses.scaled_mse_loss'
config.sampler = ('sampling.GridSubsetSampler',{
    'size': [audio_length],  # Define the length of the audio signal
    'batch_size': 2000,
    'allow_duplicates': False,
})

config.optimizer = 'adam'
config.optimizer_config = {
    'learning_rate': 1.5e-4
}
config.steps = 40000
config.use_wandb = False  # You can enable this if you're logging with Weights & Biases

config.after_step_callback = 'callbacks.ComposedCallback'
config.after_step_callback_config = {
    'callbacks': [
        ('callbacks.print_loss', {'after_every':400}),
        'callbacks.report_loss',
        ('callbacks.MetricCollectingCallback', {'metric_collector':'metrics.MetricCollector'}),
        'callbacks.raise_error_on_nan'
    ],
    'show_logs': False
}

config.metric_collector_config = {
    'metrics': [
        ('metrics.MSEOnFixedGrid', {'grid': [audio_length], 'batch_size': audio_length, 'frequency': 'every_n_batches'})
    ],
    'batch_frequency': 400,
    'epoch_frequency': 1
}


NameError: name 'audio_length' is not defined

In [None]:
try:
    inr = cju.run_utils.get_model_from_config_and_key(
        prng_key=next(key_gen),
        config=config,
        model_sub_config_name_base='model',
        add_model_module_to_architecture_default_module=False, # since the model is already in the default module specified by 'architecture',
    )
except Exception as e:
    traceback.print_exc()
    print(e)
    print('\n')
    pdb.post_mortem()

In [None]:
inr

In [None]:
inr(jnp.zeros(2))

In [None]:
# next we get the experiment from the config using common_jax_utils.run_utils.get_experiment_from_config_and_key
experiment = cju.run_utils.get_experiment_from_config_and_key(
    prng_key=next(key_gen),
    config=config,
    model_kwarg_in_trainer='inr',
    model_sub_config_name_base='model',  # so it looks for "model_config" in config
    trainer_default_module_key='trainer_module',  # so it knows to get the module specified by config.trainer_module
    additional_trainer_default_modules=[optax],  # remember the don't forget to add optax to the default modules? This is that 
    add_model_module_to_architecture_default_module=False,
    initialize=False  # don't run the experiment yet, we want to use wandb
)

In [None]:
# and we run the experiment while logging things to wandb
with wandb.init(
    project='inr_edu_24',
    notes='test',
    tags=['test']
) as run:
    results = experiment.initialize()

In [None]:
# Create a configuration object
config = Config()
config.architecture = './model_components'

# We'll still use the CombinedINR (summing outputs of multiple sub-MLPs)
config.model_type = 'inr_modules.CombinedINR'

# Now define how many input/output dimensions your INR expects
# For 1D audio (mono), in_size=1 and out_size=1
# (If stereo, you might do in_size=1, out_size=2, etc.)
config.model_config = Config()
config.model_config.in_size = 1
config.model_config.out_size = 1

config.model_config.terms = [  # CombinedINR uses multiple MLPs and returns the sum of their outputs. These 'terms' are the MLPs
    # ('inr_modules.MLPINR.new_from_config',{
    #     'hidden_size': 256,
    #     'num_layers': 5,
    #     'layer_type': 'inr_layers.SirenLayer',
    #     'num_splits': 1,
    #     'use_complex': False,
    #     'activation_kwargs': {'w0': 30.},
    #     'initialization_scheme':'initialization_schemes.siren_scheme',
    #     'positional_encoding_layer': ('inr_layers.ClassicalPositionalEncoding.from_config', {'num_frequencies': 10}),
    # }),
    # ('inr_modules.MLPINR.from_config',{
    #     'hidden_size': 1024,
    #     'num_layers': 2,
    #     'num_splits': 1,
    #     'layer_type': 'inr_layers.GaussianINRLayer',
    #     'use_complex': False,
    #     'activation_kwargs': {'inverse_scale': 1},
    # })
    ('inr_modules.MLPINR.new_from_config',{
        'hidden_size': 256,
        'num_layers': 5,
        'layer_type': 'inr_layers.FinerLayer',
        'num_splits': 1,
        'use_complex': False,
        'activation_kwargs': {'w0': 30},
        'initialization_scheme':'initialization_schemes.finer_scheme',
        'initialization_scheme_kwargs':{'bias_k' : 10}
        # 'initialization_scheme_k' : {'k': 20}
        #'positional_encoding_layer': ('inr_layers.ClassicalPositionalEncoding.from_config', {'num_frequencies': 10}),
    })
]

# Tell the system where to look for trainer objects
config.trainer_module = './inr_utils/'
config.trainer_type = 'training.train_inr'


config.target_function = 'audio.ContinuousAudio'
config.target_function_config = {
    'audio_file': './example_data/example.wav',
    # If your audio class has a built-in normalization, you can set it here
    'scale_to_01': True,
    # If you have a custom interpolation method or want nearest-sample access,
    # define it or remove if not needed
    'interpolation_method': 'audio.make_piecewise_constant_interpolation',
    # Additional parameters your audio class might need
    'sample_rate': 16000,
}


config.loss_function = 'losses.scaled_mse_loss_with_scale_factor'
config.loss_function_config = {
    'scale_factor': 10
}     
                  

                

