In [1]:

import pdb
import os
import traceback
import numpy as np
import jax
from jax import numpy as jnp
import optax
import wandb
import equinox as eqx
from typing import Optional, Callable
import librosa

from common_dl_utils.config_creation import Config
import common_jax_utils as cju

wandb.login()

key = jax.random.PRNGKey(12398)
key_gen = cju.key_generator(key)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mabdtab[0m ([33mabdtab-tue[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
def load_audio_file(file_path, sr=16000, save_npy=True):
    """
    Load an audio file and return it as a normalized numpy array.
    Optionally save as .npy file.
    
    Args:
        file_path: Path to the audio file
        sr: Target sampling rate (default: 16000)
        save_npy: Whether to save the audio as .npy file (default: True)
        
    Returns:
        tuple: (audio_array, fragment_length, npy_path)
    """
    file_path = './data_gt_bach.wav'
    # Load the audio file
    audio, _ = librosa.load(file_path, sr=sr)
    
    # Convert to numpy array and normalize to [-1, 1] range
    audio = np.array(audio, dtype=np.float32)
    audio = audio / np.max(np.abs(audio))

    if save_npy:
        # Create npy filename from original audio filename
        npy_path = os.path.splitext(file_path)[0] + '.npy'
        np.save(npy_path, audio)
        return audio, len(audio), npy_path
    
    return audio, len(audio), None


In [3]:
# Cell [3] - Configuration setup
config = Config()
# audio_path = './data_gt_bach.wav'
# audio_data, fragment_length = load_audio_file(audio_path)

# Model architecture configuration
config.architecture = './model_components'
config.model_type = 'inr_modules.CombinedINR'

# Model configuration
config.model_config = Config()
config.model_config.in_size = 1  # Time dimension input
config.model_config.out_size = 1  # Audio amplitude output
config.model_config.terms = [
    # ('inr_modules.MLPINR.new_from_config',{
    #     'hidden_size': 256,
    #     'num_layers': 5,
    #     'layer_type': 'inr_layers.SirenLayer',
    #     'num_splits': 1,
    #     'use_complex': False,
    #     'activation_kwargs': {'w0': 30.},
    #     'initialization_scheme':'initialization_schemes.siren_scheme',
    #     'positional_encoding_layer': ('inr_layers.ClassicalPositionalEncoding.from_config', {'num_frequencies': 10}),
    # }),
    # ('inr_modules.MLPINR.from_config',{
    #     'hidden_size': 1024,
    #     'num_layers': 2,
    #     'num_splits': 1,
    #     'layer_type': 'inr_layers.GaussianINRLayer',
    #     'use_complex': False,
    #     'activation_kwargs': {'inverse_scale': 1},
    # })
    ('inr_modules.MLPINR.from_config',{
        'hidden_size': 256,
        'num_layers': 5,
        'layer_type': 'inr_layers.FinerLayer',
        'num_splits': 1,
        'use_complex': False,
        'activation_kwargs': {'w0': 10},
        'initialization_scheme':'initialization_schemes.finer_scheme',
        'initialization_scheme_kwargs':{'bias_k' : 5}
        # 'initialization_scheme_k' : {'k': 20}
        #'positional_encoding_layer': ('inr_layers.ClassicalPositionalEncoding.from_config', {'num_frequencies': 10}),
    })
]

# Training configuration
config.trainer_module = './inr_utils/'
config.trainer_type = 'training.train_inr'

# # Target function configuration  # don't need this because neither sampler nor loss evaluator uses it
# config.target_function = 'audio.ContinuousAudio'
# config.target_function_config = {
#     'audio_file': './example_data/example.wav',
#     'scale_to_01': True,
#     'interpolation_method': 'audio.make_piecewise_constant_interpolation',
#     'sample_rate': 16000,
# }

# Loss function configuration
config.loss_function = 'losses.SoundLossEvaluator'
config.loss_function_config = {
    'time_domain_weight': 1.0,
    'frequency_domain_weight': 0.1
}

config.optimizer = 'optax.adam'
config.optimizer_config = {
    'learning_rate': 1e-3
}

# Sampler configuration
config.sampler = ('sampling.SoundSampler', {
    'window_size': 1024,
    'batch_size': 32,
    'fragment_length': None, # Will be set after loading audio
    'sound_fragment': './data_gt_bach.npy',  # TODO store audio file as npy and put path here
    #'sound_fragment': "path_to_audio_file.npy",  # TODO store audio file as npy and put path here
})
config.optimizer_state = None 


## the two code blocks below are where the erros are coming from, it is in the training loop the code below is way more than it has to be and ivejust been trying to fix the errors, thank you agian for the help :)

In [4]:
# Cell [7] - Training loop
import sys
sys.path.append('/home/abdtab/INR_BEP')  # Add the parent directory containing inr_utils

import inr_utils.sampling as sampling

from inr_utils.training import make_inr_train_step_function

train_step = make_inr_train_step_function(
    loss_evaluator=config.loss_function,
    sampler=config.sampler,
    optimizer=config.optimizer
)

num_epochs = 1000
key = jax.random.PRNGKey(0)


# First, instantiate the sampler
sampler_class_name, sampler_config = config.sampler
sampler = sampling.SoundSampler(**sampler_config)

# First, instantiate the sampler
sampler_class_name, sampler_config = config.sampler
sampler = sampling.SoundSampler(**sampler_config)

with wandb.init(
    project="inr-audio",
    config={
        "window_size": config.sampler[1]['window_size'],
        "batch_size": config.sampler[1]['batch_size'],
        "learning_rate": config.optimizer_config['learning_rate'],
        "num_epochs": num_epochs,
        "audio_path": './data_gt_bach.npy'
    }
) as run:
    for epoch in range(num_epochs):
        key, subkey = jax.random.split(key)
        # The train_step function handles sampling internally, so we don't need to call sampler
        model, opt_state, loss, state = train_step(config.model_config, config.optimizer_state, subkey)
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {loss}")
            wandb.log({
                "loss": loss,
                "epoch": epoch,
            })
wandb.finish()

Traceback (most recent call last):
  File "/tmp/ipykernel_2514/742912583.py", line 40, in <module>
    model, opt_state, loss, state = train_step(config.model_config, config.optimizer_state, subkey)
  File "/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/equinox/_jit.py", line 242, in __call__
    return self._call(False, args, kwargs)
  File "/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/equinox/_module.py", line 1078, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/equinox/_jit.py", line 215, in _call
    out = self._cached(dynamic_donate, dynamic_nodonate, static)
ValueError: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'tuple'>, (((<function make_inr_train_step_function.<locals>.train_step at 0x7f93d310e3b0>,), PyTreeDef(*)), (({'in_size': 1, 'out_size': 1, 'terms': [('inr_mo

ValueError: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'tuple'>, (((<function make_inr_train_step_function.<locals>.train_step at 0x7f93d310e3b0>,), PyTreeDef(*)), (({'in_size': 1, 'out_size': 1, 'terms': [('inr_modules.MLPINR.from_config', {'hidden_size': 256, 'num_layers': 5, 'layer_type': 'inr_layers.FinerLayer', 'num_splits': 1, 'use_complex': False, 'activation_kwargs': {'w0': 10}, 'initialization_scheme': 'initialization_schemes.finer_scheme', 'initialization_scheme_kwargs': {'bias_k': 5}})]},), PyTreeDef(*)), ((None,), PyTreeDef(((None, *, None), {})))). The error was:
TypeError: unhashable type: 'Config'


In [2]:
# this is the code i wrote to try and get the model working and solving the errors 
# and i know for a fact that this is way more complicated than it needs to be
# but i was solving the errors with gpt and cursor at that point cuz i gave up on manually fixing it 

import sys
sys.path.append('/home/abdtab/INR_BEP')  # Add the parent directory containing inr_utils

import inr_utils.sampling as sampling
from inr_utils.training import make_inr_train_step_function, initialize_state
from model_components.inr_modules import CombinedINR  # Import your model class

# Create train_step with config components
train_step = make_inr_train_step_function(
    loss_evaluator=config.loss_function,
    sampler=config.sampler,
    optimizer=config.optimizer
)

num_epochs = 1000
key = jax.random.PRNGKey(0)

# First, instantiate the sampler
sampler_class_name, sampler_config = config.sampler
sampler = sampling.SoundSampler(**sampler_config)

# Create the model instance using CombinedINR and your config
model = CombinedINR.from_config(config.model_config)
model, state = initialize_state(model)

# Initialize optimizer state
opt_state = config.optimizer.init(model.parameters())

with wandb.init(
    project="inr-audio",
    config={
        "window_size": config.sampler[1]['window_size'],
        "batch_size": config.sampler[1]['batch_size'],
        "learning_rate": config.optimizer_config['learning_rate'],
        "num_epochs": num_epochs,
        "audio_path": './data_gt_bach.npy'
    }
) as run:
    for epoch in range(num_epochs):
        key, subkey = jax.random.split(key)
        model, opt_state, loss, state = train_step(model, opt_state, subkey, state)
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {loss}")
            wandb.log({
                "loss": loss,
                "epoch": epoch,
            })
wandb.finish()

NameError: name 'config' is not defined

In [None]:
# this i just took from inr_example.ipynb to try and get the model working
try:
    inr = cju.run_utils.get_model_from_config_and_key(
        prng_key=next(key_gen),
        config=config,
        model_sub_config_name_base='model',
        add_model_module_to_architecture_default_module=False,
    )
except Exception as e:
    traceback.print_exc()
    print(e)
    print('\n')
    pdb.post_mortem()

Traceback (most recent call last):
  File "/tmp/ipykernel_57731/1907487116.py", line 3, in <module>
    inr = cju.run_utils.get_model_from_config_and_key(
  File "/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/common_jax_utils/run_utils.py", line 95, in get_model_from_config_and_key
    un_initialized_model = get_model_from_config(
  File "/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/common_dl_utils/config_realization.py", line 1185, in get_model_from_config
    default_module = load_from_path(name="architecture", path=config[default_module_key]) if default_module_key is not None else None
  File "/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/common_dl_utils/module_loading.py", line 48, in load_from_path
    raise ModuleNotFoundError(f"Could not find {path=}")
ModuleNotFoundError: Could not find path='./model_components'


Could not find path='./model_components'


> [0;32m/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/common_dl_utils/module_loading.py[0m(48)[0;36mload_from_path[0;34m()[0m
[0;32m     46 [0;31m    [0;31m# first check if path exists[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     47 [0;31m    [0;32mif[0m [0;32mnot[0m [0mos[0m[0;34m.[0m[0mpath[0m[0;34m.[0m[0mexists[0m[0;34m([0m[0mpath[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 48 [0;31m        [0;32mraise[0m [0mModuleNotFoundError[0m[0;34m([0m[0;34mf"Could not find {path=}"[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     49 [0;31m[0;34m[0m[0m
[0m[0;32m     50 [0;31m    [0;31m# if path exists but is not a .py or .pyc file, see if it is a directory with a __init__.py file[0m[0;34m[0m[0;34m[0m[0m
[0m
> [0;32m/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/common_dl_utils/config_realization.py[0m(1185)[0;36mget_model