# Trying to get multiple INRs to train in parallel on a single GPU


In [1]:
import pdb
import traceback

import jax
from jax import numpy as jnp
import optax
# import wandb

from common_dl_utils.config_creation import Config
import common_jax_utils as cju

# wandb.login()

key = jax.random.PRNGKey(12398)
key_gen = cju.key_generator(key)

2025-01-30 19:24:53.411474: W external/xla/xla/service/gpu/nvptx_compiler.cc:893] The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version 12.6.68. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


We want to train a single INR on `example_data/parrot.png`. We'll use the `CombinedINR` class from `model_components.inr_modules` together with the `SirenLayer` and `GaussianINRLayer` from `model_components.inr_layers` for the model, and we'll train it using the tools from `inr_utils`.

To do all of this, we basically only need to create a config. We'll use the `common_dl_utils.config_creation.Config` class for this, which is essentially a dictionary that also allows attribute-style access to its elements (so we can write `config.model_type = "CombinedINR"` instead of `config["model_type"] = "CombinedINR"`). You can also just use a plain dictionary instead.

Then we'll use the tools from `common_jax_utils` to first get a model from this config so we can inspect it, and then just run the experiment specified by the config.

Doing this in a config instead of hard coded might seem like extra work, but consider this:
1. you can serialize this config as a json file or a yaml file to later get the same model and experimental settings back 
   so when you are experimenting with different architectures, if you just store the configs you've used, you can easily recreate previous results
2. when we get to running hyperparameter sweeps, you can easily get these configs (with concrete values picked for the varying hyperparameters) from wandb
   and then run the experiment specified by that config on any machine you want, e.g. on Snellius

In [2]:
# Configuration of the single-INR experiment. This first part describes the model;
# the trainer, loss, sampler and optimizer settings are added to the same object further down.
config = Config()

# first we specify what the model should look like
config.architecture = './model_components'  # module containing all relevant classes for architectures
# NB if the classes relevant for creating the model are spread over multiple modules, this is no problem
# let config.architecture be the module that contains the "main" model class, and for all other components just specify the module
# or specify the other modules as default modules to the tools in common_jax_utils.run_utils
config.model_type = 'inr_modules.CombinedINR'

config.model_config = Config()
config.model_config.in_size = 2  # presumably the 2D (pixel) coordinates; shows up as the f32[256,2] weights of the first SirenLayer
config.model_config.out_size = 3  # presumably the colour channels; shows up as the f32[3,256] weights of the final Linear layer
config.model_config.terms = [  # CombinedINR uses multiple MLPs and returns the sum of their outputs. These 'terms' are the MLPs
    # each term is a (factory, kwargs) pair; only one term is active here
    ('inr_modules.MLPINR.from_config',{
        'hidden_size': 256,
        'num_layers': 5,  # the printed model has 5 entries in `layers`: CountingIdentity, 3x SirenLayer, Linear
        'layer_type': 'inr_layers.SirenLayer',
        'num_splits': 1, #3,
        'activation_kwargs': {'w0':12.},#{'inverse_scale': 5.},
        'initialization_scheme':'initialization_schemes.siren_scheme',
        'initialization_scheme_kwargs': {'w0': 12.},  # same w0 as in activation_kwargs
        # (file, name) pair: CountingIdentity is taken from state_test_objects.py instead of the default module
        'positional_encoding_layer': ('state_test_objects.py', 'CountingIdentity'),
    }),
    # ('inr_modules.MLPINR.from_config',{
    #     'hidden_size': 1024,
    #     'num_layers': 2,
    #     'num_splits': 1,
    #     'layer_type': 'inr_layers.GaussianINRLayer',
    #     'use_complex': False,
    #     'activation_kwargs': {'inverse_scale': 1},
    # })
]

# next, we set up the training loop, including the 'target_function' that we want to mimic
config.trainer_module = './inr_utils/'  # similarly to config.architecture above, here we just specify in what module to look for objects by default
config.trainer_type = 'training.train_inr_scan'
config.loss_evaluator = 'losses.PointWiseLossEvaluator'
config.target_function = 'images.ContinuousImage'
config.target_function_config = {
    'image': './example_data/parrot.png',
    'scale_to_01': True,
    'interpolation_method': 'images.make_piece_wise_constant_interpolation',
    'data_index': None,
}
config.loss_function = 'losses.scaled_mse_loss'
# (file, name) pair: counter_updater lives in state_test_objects.py, next to the CountingIdentity layer used in the model
config.state_update_function = ('state_test_objects.py', 'counter_updater')
config.sampler = ('sampling.GridSubsetSampler',{  # samples coordinates in a fixed grid, that should in this case coincide with the pixel locations in the image
    'size': [2040, 1356],
    'batch_size': 2000,
    'allow_duplicates': False,
})

config.optimizer = 'adam'  # we'll have to add optax to the additional default modules later
config.optimizer_config = {
    'learning_rate': 1.5e-4
}
config.steps = 40000 # number of training steps. NOTE(review): the old comment read "changed from 40000", which equals the current value -- confirm the intended number
# config.use_wandb = True

# # now we want some extra things, like logging, to happen during training
# # the inr_utils.training.train_inr function allows for this through callbacks.
# # The callbacks we want to use can be found in inr_utils.callbacks
# config.after_step_callback = 'callbacks.ComposedCallback'
# config.after_step_callback_config = {
#     'callbacks':[
#         ('callbacks.print_loss', {'after_every':400}),  # only print the loss every 400th step
#         'callbacks.report_loss',  # but log the loss to wandb after every step
#         ('callbacks.MetricCollectingCallback', # this thing will help us collect metrics and log images to wandb
#              {'metric_collector':'metrics.MetricCollector'}
#         ),
#         'callbacks.raise_error_on_nan'  # stop training if the loss becomes NaN
#     ],
#     'show_logs': False
# }

# config.after_training_callback = ('state_test_objects.py', 'after_training_callback')

# config.metric_collector_config = {  # the metrics for MetricCollectingCallback / metrics.MetricCollector
#     'metrics':[
#         ('metrics.PlotOnGrid2D', {'grid': 256, 'batch_size':8*256, 'frequency':'every_n_batches'}),  
#         # ^ plots the image on this fixed grid so we can visually inspect the inr on wandb
#         ('metrics.MSEOnFixedGrid', {'grid': [2040, 1356], 'batch_size':2040, 'frequency': 'every_n_batches'})
#         # ^ compute the MSE with the actual image pixels
#     ],
#     'batch_frequency': 400,  # compute all of these metrics every 400 batches
#     'epoch_frequency': 1  # not actually used
# }

# #config.after_training_callback = None  # don't care for one now, but you could have this e.g. store some nice loss plots if you're not using wandb 
# config.optimizer_state = None  # we're starting from scratch

In [3]:
# let's first see if we get the correct model
# Only the model is built here (not the whole experiment) so that it can be inspected in the next cells.
try:
    inr = cju.run_utils.get_model_from_config_and_key(
        prng_key=next(key_gen),
        config=config,
        model_sub_config_name_base='model',  # so it looks for "model_config" in config
        add_model_module_to_architecture_default_module=False, # since the model is already in the default module specified by 'architecture',
    )
except Exception as e:
    # debugging aid: print the traceback and the message, then open the post-mortem debugger
    traceback.print_exc()
    print(e)
    print('\n')
    pdb.post_mortem()

In [4]:
# display the structure of the model built in the previous cell
inr

CombinedINR(
  terms=(
    MLPINR(
      layers=(
        CountingIdentity(
          _embedding_matrix=f32[3],
          state_index=StateIndex(
            marker=<object object at 0x7fdc605aae20>,
            init=i32[]
          )
        ),
        SirenLayer(
          weights=f32[256,2],
          biases=f32[256],
          activation_kwargs={'w0': 12.0}
        ),
        SirenLayer(
          weights=f32[256,256],
          biases=f32[256],
          activation_kwargs={'w0': 12.0}
        ),
        SirenLayer(
          weights=f32[256,256],
          biases=f32[256],
          activation_kwargs={'w0': 12.0}
        ),
        Linear(weights=f32[3,256], biases=f32[3], activation_kwargs={})
      )
    ),
  ),
  post_processor=<function real_part>
)

In [5]:
# check that it works properly
# Smoke test: call the model once on a single all-zero input of size 2 (config.model_config.in_size).
try:
    inr(jnp.zeros(2))
except Exception as e:
    # debugging aid: print the traceback and the message, then open the post-mortem debugger
    traceback.print_exc()
    print(e)
    print('\n')
    pdb.post_mortem()

In [6]:
# next we get the experiment from the config using common_jax_utils.run_utils.get_experiment_from_config_and_key
# With initialize=False this gives back a PostponedInitialization (shown in the next cell); nothing is trained yet.
experiment = cju.run_utils.get_experiment_from_config_and_key(
    prng_key=next(key_gen),
    config=config,
    model_kwarg_in_trainer='inr',
    model_sub_config_name_base='model',  # so it looks for "model_config" in config
    trainer_default_module_key='trainer_module',  # so it knows to get the module specified by config.trainer_module
    additional_trainer_default_modules=[optax],  # this is where optax is added to the default modules, so that config.optimizer = 'adam' can be found
    add_model_module_to_architecture_default_module=False,
    initialize=False  # don't run the experiment yet (originally so wandb could be set up first; the wandb code is currently commented out)
)

In [7]:
# display the postponed (not yet initialized) experiment and the kwargs it was given
experiment

PostponedInitialization(cls=train_inr_scan, kwargs={'steps': 40000, 'loss_evaluator': PostponedInitialization(cls=PointWiseLossEvaluator, kwargs={'target_function': PostponedInitialization(cls=ContinuousImage, kwargs={'image': './example_data/parrot.png', 'scale_to_01': True, 'interpolation_method': <function make_piece_wise_constant_interpolation at 0x7fdc28121bd0>, 'data_index': None}, missing_args=[]), 'loss_function': <function scaled_mse_loss at 0x7fdbf0d34ca0>, 'state_update_function': <function counter_updater at 0x7fdbd8765ab0>}, missing_args=[]), 'sampler': PostponedInitialization(cls=GridSubsetSampler, kwargs={'size': [2040, 1356], 'batch_size': 2000, 'allow_duplicates': False, 'min': 0.0, 'max': 1.0, 'num_dimensions': None, 'indexing': 'ij'}, missing_args=[]), 'optimizer': PostponedInitialization(cls=adam, kwargs={'learning_rate': 0.00015, 'b1': 0.9, 'b2': 0.999, 'eps': 1e-08, 'eps_root': 0.0, 'mu_dtype': None, 'nesterov': False}, missing_args=[]), 'state_initialization_func

In [8]:
# run it
# experiment.initialize() is what actually executes the postponed experiment; its return value is printed.
try:
    results = experiment.initialize()
    print(results)
except Exception as e:
    # debugging aid: print the message and the traceback, then open the post-mortem debugger
    print(e)
    print()
    traceback.print_exc()
    pdb.post_mortem()

E0130 19:25:01.106727 2669192 buffer_comparator.cc:157] Difference at 2774: 0.854811, expected 0.651917
E0130 19:25:01.106762 2669192 buffer_comparator.cc:157] Difference at 5692: -0.490153, expected -0.328079
E0130 19:25:01.106769 2669192 buffer_comparator.cc:157] Difference at 7625: -0.0861647, expected -0.251582
E0130 19:25:01.106776 2669192 buffer_comparator.cc:157] Difference at 10116: 1.21569, expected 0.980453
E0130 19:25:01.106780 2669192 buffer_comparator.cc:157] Difference at 11249: 0.598515, expected 0.847711
E0130 19:25:01.106784 2669192 buffer_comparator.cc:157] Difference at 12328: 0.00544977, expected -0.148094
E0130 19:25:01.106792 2669192 buffer_comparator.cc:157] Difference at 15248: 0.391176, expected 0.24692
E0130 19:25:01.106800 2669192 buffer_comparator.cc:157] Difference at 18309: 0.0632911, expected 0.211475
E0130 19:25:01.106807 2669192 buffer_comparator.cc:157] Difference at 20861: 0.046452, expected -0.0853729
E0130 19:25:01.106816 2669192 buffer_comparator.c

(CombinedINR(
  terms=(
    MLPINR(
      layers=(
        CountingIdentity(
          _embedding_matrix=f32[3],
          state_index=StateIndex(marker=0, init=_Sentinel())
        ),
        SirenLayer(
          weights=f32[256,2],
          biases=f32[256],
          activation_kwargs={'w0': 12.0}
        ),
        SirenLayer(
          weights=f32[256,256],
          biases=f32[256],
          activation_kwargs={'w0': 12.0}
        ),
        SirenLayer(
          weights=f32[256,256],
          biases=f32[256],
          activation_kwargs={'w0': 12.0}
        ),
        Linear(weights=f32[3,256], biases=f32[3], activation_kwargs={})
      )
    ),
  ),
  post_processor=<function real_part>
), (ScaleByAdamState(count=Array(40000, dtype=int32), mu=CombinedINR(
  terms=(
    MLPINR(
      layers=(
        CountingIdentity(
          _embedding_matrix=f32[3],
          state_index=StateIndex(marker=0, init=_Sentinel())
        ),
        SirenLayer(
          weights=f32[256,2],
   

In [9]:
from collections.abc import Sequence, Mapping
from common_dl_utils.config_realization import PostponedInitialization
def complete_postponed_initialization(postponed_init:PostponedInitialization, completion: dict):
    """Resolve missing arguments throughout a tree of postponed initializations.

    Calls ``postponed_init.resolve_missing_args(completion)`` and then does the
    same for every ``PostponedInitialization`` that can be reached through
    ``postponed_init.kwargs``. Containers are searched at any nesting depth,
    so a postponed initialization inside e.g. a list within a dict is found too.

    Nothing is returned; the work is done by the ``resolve_missing_args`` calls
    (presumably they update the objects in place -- confirm in common_dl_utils).

    :param postponed_init: root of the tree to complete.
    :param completion: passed unchanged to every ``resolve_missing_args`` call
        (in this notebook: a dict from argument name to value).
    """
    postponed_init.resolve_missing_args(completion)
    for value in postponed_init.kwargs.values():
        _complete_nested(value, completion)


def _complete_nested(value, completion):
    """Find and complete every PostponedInitialization inside ``value``."""
    if isinstance(value, PostponedInitialization):
        complete_postponed_initialization(value, completion)
    elif isinstance(value, (str, bytes, bytearray)):
        # Strings are Sequences as well, but can never contain a postponed
        # initialization (and a 1-character string contains itself, which
        # would make the recursion below endless).
        return
    elif isinstance(value, Mapping):
        for item in value.values():
            _complete_nested(item, completion)
    elif isinstance(value, Sequence):
        for item in value:
            _complete_nested(item, completion)


def run_experiment(missing_kwargs: dict, config:dict, key:jax.Array):
    """Build the experiment described by `config`, fill in its missing arguments, and run it.

    :param missing_kwargs: values for the arguments that `config` leaves out;
        handed to `complete_postponed_initialization`.
    :param config: experiment configuration (model, trainer, loss, sampler, optimizer).
    :param key: PRNG key passed on as `prng_key`.
    :return: whatever `initialize()` of the experiment returns
        (elsewhere in this notebook it is unpacked as inr, optimizer_state, state, losses).
    """
    # initialize=False: we first get the postponed experiment so the missing kwargs can be added
    postponed_experiment = cju.run_utils.get_experiment_from_config_and_key(
        config=config,
        prng_key=key,
        model_sub_config_name_base='model',  # so it looks for "model_config" in config
        model_kwarg_in_trainer='inr',
        trainer_default_module_key='trainer_module',  # so it knows to get the module specified by config.trainer_module
        additional_trainer_default_modules=[optax],  # needed to find the optimizer named in the config
        add_model_module_to_architecture_default_module=False,
        initialize=False,
    )
    complete_postponed_initialization(postponed_experiment, missing_kwargs)
    outcome = postponed_experiment.initialize()
    return outcome

In [10]:
# Same kind of configuration as `config`, but with 'activation_kwargs' left out on purpose,
# so that it can be supplied per run (see the note inside `terms`).
incomplete_config = Config()

# first we specify what the model should look like
incomplete_config.architecture = './model_components'  # module containing all relevant classes for architectures
# NB if the classes relevant for creating the model are spread over multiple modules, this is no problem
# let config.architecture be the module that contains the "main" model class, and for all other components just specify the module
# or specify the other modules as default modules to the tools in common_jax_utils.run_utils
incomplete_config.model_type = 'inr_modules.CombinedINR'

incomplete_config.model_config = Config()
incomplete_config.model_config.in_size = 2
incomplete_config.model_config.out_size = 3
incomplete_config.model_config.terms = [  # CombinedINR uses multiple MLPs and returns the sum of their outputs. These 'terms' are the MLPs
    ('inr_modules.MLPINR.from_config',{
        'hidden_size': 256,
        'num_layers': 5,
        'layer_type': 'inr_layers.SirenLayer',
        'num_splits': 3,  # NOTE(review): the complete `config` earlier in this notebook uses 1 here -- confirm the difference is intended
        # 'activation_kwargs' is the missing argument: it is deliberately not set here and is
        # filled in later through the `missing_kwargs` argument of run_experiment.
        #'activation_kwargs': {'w0':12.},
        'initialization_scheme':'initialization_schemes.siren_scheme',
        # NOTE(review): also left out, but never supplied through `missing_kwargs` in this notebook;
        # presumably the initialization scheme then uses its own default w0 -- confirm.
        #'initialization_scheme_kwargs': {'w0': 12.},
        'positional_encoding_layer': ('state_test_objects.py', 'CountingIdentity'),
    }),
    # ('inr_modules.MLPINR.from_config',{
    #     'hidden_size': 1024,
    #     'num_layers': 2,
    #     'num_splits': 1,
    #     'layer_type': 'inr_layers.GaussianINRLayer',
    #     'use_complex': False,
    #     'activation_kwargs': {'inverse_scale': 1},
    # })
]

# next, we set up the training loop, including the 'target_function' that we want to mimic
# (these settings mirror the training part of the complete `config` earlier in the notebook)
incomplete_config.trainer_module = './inr_utils/'  # similarly to config.architecture above, here we just specify in what module to look for objects by default
incomplete_config.trainer_type = 'training.train_inr_scan'
incomplete_config.loss_evaluator = 'losses.PointWiseLossEvaluator'
incomplete_config.target_function = 'images.ContinuousImage'
incomplete_config.target_function_config = {
    'image': './example_data/parrot.png',
    'scale_to_01': True,
    'interpolation_method': 'images.make_piece_wise_constant_interpolation',
    'data_index': None
}
incomplete_config.loss_function = 'losses.scaled_mse_loss'
# (file, name) pair: counter_updater lives in state_test_objects.py, next to the CountingIdentity layer used in the model
incomplete_config.state_update_function = ('state_test_objects.py', 'counter_updater')
incomplete_config.sampler = ('sampling.GridSubsetSampler',{  # samples coordinates in a fixed grid, that should in this case coincide with the pixel locations in the image
    'size': [2040, 1356],
    'batch_size': 2000,
    'allow_duplicates': False,
})

incomplete_config.optimizer = 'adam'  # we'll have to add optax to the additional default modules later
incomplete_config.optimizer_config = {
    'learning_rate': 1.5e-4
}
incomplete_config.steps = 40000 # number of training steps. NOTE(review): the old comment read "changed from 40000", which equals the current value -- confirm the intended number

In [11]:
# Sanity check without vmap: run one experiment from incomplete_config,
# supplying the missing 'activation_kwargs' by hand.
run_experiment(
    missing_kwargs={"activation_kwargs": {"w0": 12.}},
    config=incomplete_config,
    key=next(key_gen)
)

(CombinedINR(
   terms=(
     MLPINR(
       layers=(
         CountingIdentity(
           _embedding_matrix=f32[3],
           state_index=StateIndex(marker=0, init=_Sentinel())
         ),
         SirenLayer(
           weights=f32[256,2],
           biases=f32[256],
           activation_kwargs={'w0': 12.0}
         ),
         SirenLayer(
           weights=f32[256,256],
           biases=f32[256],
           activation_kwargs={'w0': 12.0}
         ),
         SirenLayer(
           weights=f32[256,256],
           biases=f32[256],
           activation_kwargs={'w0': 12.0}
         ),
         Linear(weights=f32[3,256], biases=f32[3], activation_kwargs={})
       )
     ),
   ),
   post_processor=<function real_part>
 ),
 (ScaleByAdamState(count=Array(40000, dtype=int32), mu=CombinedINR(
    terms=(
      MLPINR(
        layers=(
          CountingIdentity(
            _embedding_matrix=f32[3],
            state_index=StateIndex(marker=0, init=_Sentinel())
          ),
          

In [12]:
import equinox as eqx
# number of INRs trained in parallel: one PRNG key and one w0 value is created for each
num_parallel = 10

def v_mappable_runner(w0, key):
    """Run one experiment from `incomplete_config` with the given SIREN `w0` and PRNG `key`.

    Written with only array arguments so that it can be vmapped over `w0` and `key`.
    The result of `run_experiment` is returned as is.
    """
    completion = {"activation_kwargs": {"w0": w0}}
    # an earlier variant returned eqx.filter(result, eqx.is_array_like) instead of the raw result
    return run_experiment(missing_kwargs=completion, config=incomplete_config, key=key)

# one independent PRNG key and one w0 value (evenly spaced between 10 and 30) per parallel run
keys = jax.random.split(next(key_gen), num_parallel)
w0s = jnp.linspace(10., 30., num=num_parallel)

#results = jax.vmap(v_mappable_runner)(w0s, keys)  # no idea why this results in a user warning while the single one doesn't... but it seems to work
# train all num_parallel INRs at once by vmapping the runner over (w0s, keys)
results = eqx.filter_vmap(v_mappable_runner)(w0s, keys)

In [13]:
# Alternative to v_mappable_runner: vmap run_experiment directly.
# in_axes=(0, None, 0) maps over axis 0 of `w0s` (inside missing_kwargs) and of `keys`;
# the config is passed unmapped. The result is only displayed, not stored.
eqx.filter_vmap(run_experiment, in_axes=(0, None, 0))({"activation_kwargs": {"w0": w0s}}, incomplete_config, keys)

(CombinedINR(
   terms=(
     MLPINR(
       layers=(
         CountingIdentity(
           _embedding_matrix=f32[10,3],
           state_index=StateIndex(marker=0, init=_Sentinel())
         ),
         SirenLayer(
           weights=f32[10,256,2],
           biases=f32[10,256],
           activation_kwargs={'w0': f32[10]}
         ),
         SirenLayer(
           weights=f32[10,256,256],
           biases=f32[10,256],
           activation_kwargs={'w0': f32[10]}
         ),
         SirenLayer(
           weights=f32[10,256,256],
           biases=f32[10,256],
           activation_kwargs={'w0': f32[10]}
         ),
         Linear(weights=f32[10,3,256], biases=f32[10,3], activation_kwargs={})
       )
     ),
   ),
   post_processor=<function real_part>
 ),
 (ScaleByAdamState(count=Array([40000, 40000, 40000, 40000, 40000, 40000, 40000, 40000, 40000,
         40000], dtype=int32), mu=CombinedINR(
    terms=(
      MLPINR(
        layers=(
          CountingIdentity(
            _e

In [14]:
from state_test_objects import after_training_callback, CountingIdentity
from inr_utils.parallel_training import tree_unstack

# `results` comes from the vmapped run, so each of these four holds the values of all parallel runs.
# Note that this rebinds `inr`, which earlier referred to the single model.
inr, optimizer_state, state, losses = results

# tree_unstack splits the stacked results into one (inr, optimizer_state, state, losses) tuple per run;
# the callback is then run for each of them (the optimizer state is not passed to it).
for _inr, _optimizer_state, _state, _losses in tree_unstack(results):
    after_training_callback(_losses, _inr, _state)

Checking model and state for CountingIdentity layers
Found a CountingIdentity layer with counter value 40000 in final state after training.
Checking model and state for CountingIdentity layers
Found a CountingIdentity layer with counter value 40000 in final state after training.
Checking model and state for CountingIdentity layers
Found a CountingIdentity layer with counter value 40000 in final state after training.
Checking model and state for CountingIdentity layers
Found a CountingIdentity layer with counter value 40000 in final state after training.
Checking model and state for CountingIdentity layers
Found a CountingIdentity layer with counter value 40000 in final state after training.
Checking model and state for CountingIdentity layers
Found a CountingIdentity layer with counter value 40000 in final state after training.
Checking model and state for CountingIdentity layers
Found a CountingIdentity layer with counter value 40000 in final state after training.
Checking model and s

In [15]:
def tree_unstack(tree, axis=0):
    """Split a PyTree of stacked arrays into a list of PyTrees, one per index along `axis`.

    Every JAX array leaf is unstacked along `axis`; every other leaf is copied
    by reference into each of the resulting trees. The number of resulting
    trees is the size along `axis` of the first array leaf.

    :param tree: PyTree containing at least one JAX array leaf.
    :param axis: axis along which the array leaves are unstacked.
    :return: list of PyTrees with the same structure as `tree`.
    :raises ValueError: if `tree` has no JAX array leaves, or if its array
        leaves do not all have the same size along `axis`.
    """
    leaves, tree_def = jax.tree.flatten(tree)

    def _is_jax_array(leaf):
        # same test as equinox.is_array: JAX arrays only, NumPy arrays do not count
        return isinstance(leaf, jax.Array)

    first_array = next((leaf for leaf in leaves if _is_jax_array(leaf)), None)
    if first_array is None:
        # without an array leaf there is no way to tell how many trees to produce
        raise ValueError("tree_unstack needs at least one JAX array leaf in the tree.")
    num_out = first_array.shape[axis]

    def _safe_unstack(maybe_array):
        if not _is_jax_array(maybe_array):
            return num_out*[maybe_array]
        pieces = jnp.unstack(maybe_array, axis=axis)
        if len(pieces) != num_out:
            # zip below would otherwise silently drop the surplus entries
            raise ValueError(
                f"All array leaves must have the same size along axis {axis}: "
                f"expected {num_out}, got {len(pieces)}."
            )
        return pieces

    unstacked_leaves = [_safe_unstack(leaf) for leaf in leaves]
    return [tree_def.unflatten(group) for group in zip(*unstacked_leaves)]

# Try the locally defined tree_unstack on the stacked models from the parallel run.
try:
    tree_unstack(inr)
except Exception as e:
    # debugging aid: print the message and the traceback, then open the post-mortem debugger
    print(e)
    traceback.print_exc()
    pdb.post_mortem()