# Example of training an INR locally
This notebook provides an example of how to create an INR and train it locally using the tools in this repo.

In [2]:
import pdb
import traceback

import jax
from jax import numpy as jnp
import optax
import wandb

from common_dl_utils.config_creation import Config
import common_jax_utils as cju

# Authenticate with Weights & Biases; the training run further down logs to wandb.
wandb.login()

# Root PRNG key with a fixed seed.
key = jax.random.PRNGKey(12398)
# Key generator: later cells call next(key_gen) whenever they need a key.
key_gen = cju.key_generator(key)

[34m[1mwandb[0m: Currently logged in as: [33mmaxwell_litsios[0m ([33mbep-circle[0m). Use [1m`wandb login --relogin`[0m to force relogin


We want to train a single INR on a signed distance function (SDF): the `Armadillo` shape provided by `inr_utils.sdf.SDFDataLoader`. We'll use the `CombinedINR` class from `model_components.inr_modules` together with the `SirenLayer` from `model_components.inr_layers` for the model (a second term using `GaussianINRLayer` is included in the config below, but commented out), and we'll train it using the tools from `inr_utils`.

To do all of this, we basically only need to create a config. We'll use the `common_dl_utils.config_creation.Config` class for this, but this is essentially just a dictionary that allows attribute-style access to its elements (so we can do `config.model_type = "CombinedINR"` instead of `config["model_type"] = "CombinedINR"`). You can also just use a plain dictionary instead.

Then we'll use the tools from `common_jax_utils` to first get a model from this config so we can inspect it, and then just run the experiment specified by the config.

Doing this in a config instead of hard coded might seem like extra work, but consider this:
1. you can serialize this config as a json file or a yaml file to later get the same model and experimental settings back 
   so when you are experimenting with different architectures, if you just store the configs you've used, you can easily recreate previous results
2. when we get to running hyper parameter sweeps, you can easily get these configs (with a pick for the varying hyper parameters) from wandb
   and then run an experiment specified by that config on any machine you want, e.g. on Snellius

In [2]:
config = Config()

# first we specify what the model should look like
config.architecture = './model_components'  # module containing all relevant classes for architectures
# NB if the classes relevant for creating the model are spread over multiple modules, this is no problem
# let config.architecture be the module that contains the "main" model class, and for all other components just specify the module
# or specify the other modules as default modules to the tools in common_jax_utils.run_utils
config.model_type = 'inr_modules.CombinedINR'

config.model_config = Config()
config.model_config.in_size = 3  # presumably a 3D query point as input -- confirm against SDFDataLoader
config.model_config.out_size = 1  # presumably the scalar SDF value as output -- confirm
config.model_config.terms = [  # CombinedINR uses multiple MLPs and returns the sum of their outputs. These 'terms' are the MLPs
    ('inr_modules.MLPINR.from_config',{
        'hidden_size': 256,
        'num_layers': 5,
        'layer_type': 'inr_layers.SirenLayer',
        'num_splits': 3,
        'activation_kwargs': {'w0':12.},#{'inverse_scale': 5.},
        # 'initialization_scheme':'initialization_schemes.siren_scheme',
        # 'initialization_scheme_kwargs': {'w0': 25.},
        # (file, object name) pair: CountingIdentity is taken from state_test_objects.py rather than from the default module
        'positional_encoding_layer': ('state_test_objects.py', 'CountingIdentity'),
        # 'post_processor': 'model_components.auxiliary.squeeze_array',
    }),
    # ('inr_modules.MLPINR.from_config',{
    #     'hidden_size': 1024,
    #     'num_layers': 2,
    #     'num_splits': 1,
    #     'layer_type': 'inr_layers.GaussianINRLayer',
    #     'use_complex': False,
    #     'activation_kwargs': {'inverse_scale': 1},
    # })
]

# next, we set up the training loop, including the dataloader that provides the SDF samples we want to fit
config.trainer_module = './inr_utils/'  # similarly to config.architecture above, here we just specify in what module to look for objects by default
config.trainer_type = 'training.train_with_dataloader_scan'
# config.trainer_type = 'training.train_inr_with_dataloader'


config.dataloader = 'sdf.SDFDataLoader'

config.dataloader_config = {
    "sdf_name": "Armadillo",
    "batch_size": 20000,
    "keep_aspect_ratio":True

}

# the trainer receives both values; presumably this amounts to 40 * 200 = 8000 training steps in total
config.num_cycles = 40
config.steps_per_cycle = 200


config.loss_evaluator = "losses.SDFLossEvaluator"


# NOTE(review): this is the same class as config.dataloader, but with batch_size 10000.
# It is not visible in this notebook whether the trainer or the loss evaluator consumes it -- confirm or remove.
config.target_function = 'sdf.SDFDataLoader'
config.target_function_config = {
    "sdf_name": "Armadillo",
    "batch_size": 10000,
    "keep_aspect_ratio":True,

}

# (file, object name) pair, like the positional encoding layer above
config.state_update_function = ('state_test_objects.py', 'counter_updater')

config.optimizer = 'training.OptimizerFactory.single_optimizer'#'adamw'  # we'll have to add optax to the additional default modules later
config.optimizer_type = 'adamw'
config.optimizer_config = {
    'learning_rate': 1e-5,#1.5e-4
    'weight_decay': 1e-4,
}
config.optimizer_mask = 'masking.array_mask'
# changed from 40000
# NOTE(review): the scan trainer is driven by num_cycles and steps_per_cycle above; this value may be unused -- confirm.
config.steps = 1200
config.use_wandb = True

# now we want some extra things, like logging, to happen during training
# the trainers in inr_utils.training allow for this through callbacks.
# The callbacks we want to use can be found in inr_utils.callbacks
config.after_cycle_callback = 'callbacks.ComposedCallback'
config.after_cycle_callback_config = {
    'callbacks':[
        ('callbacks.print_loss', {'after_every':1}),  # after_every=1: print the loss every time this (after-cycle) callback runs
        'callbacks.report_loss',  # log the loss to wandb every time this callback runs
        ('callbacks.MetricCollectingCallback', # this thing will help us collect metrics and log images to wandb
             {'metric_collector':'metrics.MetricCollector'}
        ),
        'callbacks.raise_error_on_nan'  # stop training if the loss becomes NaN
    ],
    'show_logs': False
}
# config.after_step_callback = 'callbacks.ComposedCallback'
# config.after_step_callback_config = {
#     'callbacks':[
#         ('callbacks.print_loss', {'after_every':400}),  # only print the loss every 400th step
#         'callbacks.report_loss',  # but log the loss to wandb after every step
#         ('callbacks.MetricCollectingCallback', # this thing will help us collect metrics and log images to wandb
#              {'metric_collector':'metrics.MetricCollector'}
#         ),
#         'callbacks.raise_error_on_nan'  # stop training if the loss becomes NaN
#     ],
#     'show_logs': False
# }

config.after_training_callback = ('state_test_objects.py', 'after_training_callback')

config.metric_collector_config = {  # the metrics for MetricCollectingCallback / metrics.MetricCollector
    'metrics':[
        ('metrics.JaccardIndexSDF', {
            'frequency':'every_n_batches',
            'grid_resolution': 100,
            'num_dims': 3,
            'batch_size': 10000
        }),
        ("metrics.SDFReconstructor",
         {
            'frequency':'every_n_batches',
            'grid_resolution': 100,
            'batch_size': 10000,
        }),
        # todo add view rendering here

    ],
    # NOTE(review): an earlier comment said "every 400 batches", but the value is 4;
    # presumably the metrics are computed on every 4th invocation of the callback -- confirm.
    'batch_frequency': 4,
    'epoch_frequency': 1  # not actually used
}

#config.after_training_callback = None  # don't care for one now, but you could have this e.g. store some nice loss plots if you're not using wandb
config.optimizer_state = None  # we're starting from scratch

In [3]:
# let's first see if we get the correct model
# Build only the model first, so its structure can be inspected before any training happens.
try:
    inr = cju.run_utils.get_model_from_config_and_key(
        config=config,
        prng_key=next(key_gen),
        model_sub_config_name_base='model',  # the model settings are read from config.model_config
        # the model class already lives in the default module specified by 'architecture'
        add_model_module_to_architecture_default_module=False,
    )
except Exception as err:
    # Show the traceback and the message, then open the debugger on the failing frame.
    traceback.print_exc()
    print(err, '\n', sep='\n')
    pdb.post_mortem()

In [4]:
# Display the structure of the freshly created model.
inr

CombinedINR(
  terms=(
    MLPINR(
      layers=(
        CountingIdentity(
          _embedding_matrix=f32[3],
          state_index=StateIndex(
            marker=<object object at 0x7d935391bdf0>,
            init=i32[]
          )
        ),
        SirenLayer(
          weights=f32[256,3],
          biases=f32[256],
          activation_kwargs={'w0': 12.0}
        ),
        SirenLayer(
          weights=f32[256,256],
          biases=f32[256],
          activation_kwargs={'w0': 12.0}
        ),
        SirenLayer(
          weights=f32[256,256],
          biases=f32[256],
          activation_kwargs={'w0': 12.0}
        ),
        Linear(weights=f32[1,256], biases=f32[1], activation_kwargs={})
      )
    ),
  ),
  post_processor=<function real_part>
)

In [5]:
# check that it works properly
# Smoke test: the model should accept a single input of size 3 (config.model_config.in_size).
try:
    inr(jnp.zeros(3))
except Exception as err:
    # Show the traceback and the message, then open the debugger on the failing frame.
    traceback.print_exc()
    print(err, '\n', sep='\n')
    pdb.post_mortem()

In [6]:
# Same call as above, now as the last expression of the cell so that the returned value is displayed.
inr(jnp.zeros(3))


(Array([0.99375975], dtype=float32), None)

In [7]:
# next we get the experiment from the config using common_jax_utils.run_utils.get_experiment_from_config_and_key
# Turn the full config into an experiment (model + trainer) without running it yet.
experiment = cju.run_utils.get_experiment_from_config_and_key(
    config=config,
    prng_key=next(key_gen),
    # model related settings
    model_sub_config_name_base='model',  # so it looks for "model_config" in config
    model_kwarg_in_trainer='inr',
    add_model_module_to_architecture_default_module=False,
    # trainer related settings
    trainer_default_module_key='trainer_module',  # so it knows to get the module specified by config.trainer_module
    additional_trainer_default_modules=[optax],  # the config refers to 'adamw' by name only, so optax has to be among the default modules
    initialize=False,  # don't run the experiment yet, we want to use wandb
)

In [8]:
# Display the not-yet-initialized experiment to check that every component was resolved from the config.
experiment

PostponedInitialization(cls=train_with_dataloader_scan, kwargs={'loss_evaluator': PostponedInitialization(cls=SDFLossEvaluator, kwargs={'state_update_function': <function counter_updater at 0x7d92f1761510>}, missing_args=[]), 'dataloader': PostponedInitialization(cls=SDFDataLoader, kwargs={'sdf_name': 'Armadillo', 'batch_size': 20000, 'keep_aspect_ratio': True, 'key': Array([1786414058, 1458264990], dtype=uint32)}, missing_args=[]), 'optimizer': PostponedInitialization(cls=single_optimizer, kwargs={'optimizer_type': <function adamw at 0x7d935d13e200>, 'optimizer_config': {'learning_rate': 1e-05, 'weight_decay': 0.0001}, 'optimizer_mask': <function array_mask at 0x7d9346b228c0>, 'learning_rate_schedule': None, 'schedule_boundaries': None}, missing_args=[]), 'steps_per_cycle': 200, 'num_cycles': 40, 'use_wandb': True, 'after_cycle_callback': PostponedInitialization(cls=ComposedCallback, kwargs={'callbacks': [functools.partial(<function print_loss at 0x7d92f171b1c0>, after_every=1), <func

In [9]:
# and we run the experiment while logging things to wandb
# Run the experiment inside a wandb run, so that everything it reports ends up in wandb.
try:
    with wandb.init(project='inr_edu_24', notes='test', tags=['test']) as run:
        results = experiment.initialize()  # initializing the experiment is what runs it
except Exception as err:
    # Show the traceback and the message, then open the debugger on the failing frame.
    traceback.print_exc()
    print(err, '\n', sep='\n')
    pdb.post_mortem()


Loss at step 1 is 503.4261474609375.
Loss at step 2 is 151.2399139404297.
Loss at step 3 is 158.7724609375.
Loss at step 4 is 153.7262725830078.
Loss at step 5 is 148.79888916015625.
Loss at step 6 is 154.47230529785156.
Loss at step 7 is 140.12860107421875.
Loss at step 8 is 144.83226013183594.
Loss at step 9 is 142.53469848632812.
Loss at step 10 is 142.410400390625.
Loss at step 11 is 144.37510681152344.
Loss at step 12 is 144.6142120361328.
Loss at step 13 is 142.5659942626953.
Loss at step 14 is 144.45936584472656.
Loss at step 15 is 148.80271911621094.
Loss at step 16 is 142.94775390625.
Loss at step 17 is 141.3892364501953.
Loss at step 18 is 142.63316345214844.
Loss at step 19 is 142.49746704101562.
Loss at step 20 is 139.62646484375.
Loss at step 21 is 136.5791015625.
Loss at step 22 is 142.61839294433594.
Loss at step 23 is 144.36602783203125.
Loss at step 24 is 147.67250061035156.
Loss at step 25 is 139.03599548339844.
Loss at step 26 is 142.9278106689453.
Loss at step 27 is

VBox(children=(Label(value='1507.252 MB of 1576.874 MB uploaded\r'), FloatProgress(value=0.9558483567113274, m…

0,1
batch_within_epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
jaccard_index,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
batch_within_epoch,40.0
epoch,1.0
jaccard_index,0.0
loss,140.49887


In [10]:
import equinox as eqx
from inr_utils.masking import array_mask
# array_mask is the function configured as config.optimizer_mask; applying it to the model
# shows, leaf by leaf, what the mask evaluates to (True/False).
array_mask(inr)

CombinedINR(
  terms=(
    MLPINR(
      layers=(
        CountingIdentity(
          _embedding_matrix=True,
          state_index=StateIndex(
            marker=<object object at 0x7d935391bdf0>,
            init=True
          )
        ),
        SirenLayer(weights=True, biases=True, activation_kwargs={'w0': False}),
        SirenLayer(weights=True, biases=True, activation_kwargs={'w0': False}),
        SirenLayer(weights=True, biases=True, activation_kwargs={'w0': False}),
        Linear(weights=True, biases=True, activation_kwargs={})
      )
    ),
  ),
  post_processor=False
)

In [11]:
def is_array(leaf):
    """Print the visited leaf and report whether equinox considers it an array."""
    print(f"{leaf=}")
    leaf_is_array = eqx.is_array(leaf)
    return leaf_is_array


# Map over a small example tree; the is_leaf predicate makes tree_map visit None as a leaf
# instead of treating it as an empty subtree.
jax.tree_util.tree_map(
    is_array,
    [None, jnp.zeros(3), lambda x: x],
    is_leaf=lambda x: x is None or eqx.is_array(x),
)

leaf=None
leaf=Array([0., 0., 0.], dtype=float32)
leaf=<function <lambda> at 0x7d938d057130>


[False, True, False]

In [12]:
# Without an is_leaf predicate, None is an empty subtree, so the array is the only leaf returned.
# Uses jax.tree_util.tree_flatten (available in every JAX version) instead of the
# deprecated top-level alias jax.tree_flatten.
jax.tree_util.tree_flatten([None, jnp.zeros(3)])


jax.tree_flatten is deprecated: use jax.tree.flatten (jax v0.4.25 or newer) or jax.tree_util.tree_flatten (any JAX version).



([Array([0., 0., 0.], dtype=float32)], PyTreeDef([None, *]))

In [13]:
# Unpack what the experiment returned. Note that this rebinds `inr` to the model
# returned by the training run (it no longer refers to the model created earlier).
inr, losses, optimizer_state, state, loss_evaluator, additional_output = results
inr

CombinedINR(
  terms=(
    MLPINR(
      layers=(
        CountingIdentity(
          _embedding_matrix=f32[3],
          state_index=StateIndex(marker=0, init=_Sentinel())
        ),
        SirenLayer(
          weights=f32[256,3],
          biases=f32[256],
          activation_kwargs={'w0': 12.0}
        ),
        SirenLayer(
          weights=f32[256,256],
          biases=f32[256],
          activation_kwargs={'w0': 12.0}
        ),
        SirenLayer(
          weights=f32[256,256],
          biases=f32[256],
          activation_kwargs={'w0': 12.0}
        ),
        Linear(weights=f32[1,256], biases=f32[1], activation_kwargs={})
      )
    ),
  ),
  post_processor=<function real_part>
)

In [14]:
from state_test_objects import after_training_callback, CountingIdentity
# Call the after-training callback by hand on the returned losses, model and state.
after_training_callback(losses, inr, state)

Checking model and state for CountingIdentity layers
Found a CountingIdentity layer with counter value 8000 in final state after training.


In [15]:
from inr_utils.sdf import SDFDataLoader
# Build a dataloader by hand with the same settings as config.dataloader_config, using a fixed key.
data_loader = SDFDataLoader(**config.dataloader_config, key=jax.random.PRNGKey(0))

# Draw a single batch and display its first element.
batch = next(iter(data_loader))
batch[0]

Array([[-0.9475113 ,  0.51643044, -0.732251  ],
       [-0.03327261, -0.5643355 ,  0.07955802],
       [-0.03956697,  0.89269453, -0.5453213 ],
       ...,
       [ 0.40899348, -0.77576375,  0.21623659],
       [ 0.35770416,  0.64567757, -0.31344652],
       [-0.03441858, -0.5913744 ,  0.3327012 ]], dtype=float32)

In [16]:
# Smallest Euclidean norm (taken along the last axis) over the second element of the batch.
jax.numpy.linalg.norm(batch[1], axis=-1).min()

Array(0.99999994, dtype=float32)

In [17]:
# Largest Euclidean norm (taken along the last axis) over the second element of the batch.
jax.numpy.linalg.norm(batch[1], axis=-1).max()

Array(1.7320508, dtype=float32)

In [18]:
# Show, for every array in the batch, the device it lives on.
jax.tree.map(lambda x: x.device, batch)

(CudaDevice(id=0), CudaDevice(id=0), CudaDevice(id=0))

In [19]:
# Shape of the second element of the batch.
batch[1].shape

(20000, 3)

In [20]:
def get_default_device():
  """Return the configured `jax_default_device` if one is set, else the first local device."""
  configured_device = jax.config.jax_default_device
  if configured_device:
    return configured_device
  return jax.local_devices()[0]


get_default_device()

CudaDevice(id=0)

In [21]:
# Display the experiment again; its kwargs are used below to build the individual components by hand.
experiment

PostponedInitialization(cls=train_with_dataloader_scan, kwargs={'loss_evaluator': PostponedInitialization(cls=SDFLossEvaluator, kwargs={'state_update_function': <function counter_updater at 0x7d92f1761510>}, missing_args=[]), 'dataloader': PostponedInitialization(cls=SDFDataLoader, kwargs={'sdf_name': 'Armadillo', 'batch_size': 20000, 'keep_aspect_ratio': True, 'key': Array([1786414058, 1458264990], dtype=uint32)}, missing_args=[]), 'optimizer': PostponedInitialization(cls=single_optimizer, kwargs={'optimizer_type': <function adamw at 0x7d935d13e200>, 'optimizer_config': {'learning_rate': 1e-05, 'weight_decay': 0.0001}, 'optimizer_mask': <function array_mask at 0x7d9346b228c0>, 'learning_rate_schedule': None, 'schedule_boundaries': None}, missing_args=[]), 'steps_per_cycle': 200, 'num_cycles': 40, 'use_wandb': True, 'after_cycle_callback': PostponedInitialization(cls=ComposedCallback, kwargs={'callbacks': [functools.partial(<function print_loss at 0x7d92f171b1c0>, after_every=1), <func

In [22]:
# Initialize the individual components of the experiment by hand, so a single train step can be
# built and timed outside of the trainer. This rebinds `loss_evaluator` and `data_loader`.
loss_evaluator = experiment.kwargs['loss_evaluator'].initialize()
data_loader = experiment.kwargs['dataloader'].initialize()
optimizer = experiment.kwargs['optimizer'].initialize()

In [23]:
import inr_utils
import equinox as eqx

# Build a train step from the components initialized above.
# NOTE(review): this relies on inr_utils.training being reachable after a plain `import inr_utils` -- confirm.
train_step = inr_utils.training.make_sampler_free_train_step(loss_evaluator=loss_evaluator, optimizer=optimizer)

# Devices holding the array leaves of the model.
set(leaf.device for leaf in jax.tree.leaves(inr) if eqx.is_array(leaf))

{CudaDevice(id=0)}

In [24]:
# Devices holding the array leaves of the state.
print(set(leaf.device for leaf in jax.tree.leaves(state) if eqx.is_array(leaf)))

{CudaDevice(id=0)}


In [25]:
# Devices holding the array leaves of the optimizer state.
print(set(leaf.device for leaf in jax.tree.leaves(optimizer_state) if eqx.is_array(leaf)))

{CudaDevice(id=0)}


In [26]:
# Draw one batch from the dataloader that was initialized from the experiment.
batch = next(iter(data_loader))

In [27]:
# Devices holding the array leaves of that batch.
print(set(leaf.device for leaf in jax.tree.leaves(batch) if eqx.is_array(leaf)))

{CudaDevice(id=0)}


In [28]:
# Run a single training step by hand on the batch drawn above.
out = train_step(inr, batch, optimizer_state, state)

E0130 18:11:39.454519 2601660 buffer_comparator.cc:157] Difference at 2007: -0.0535316, expected -0.618668
E0130 18:11:39.454545 2601660 buffer_comparator.cc:157] Difference at 3466: 3.75414, expected 4.34357
E0130 18:11:39.454552 2601660 buffer_comparator.cc:157] Difference at 5808: 1.5179, expected 2.07677
E0130 18:11:39.454554 2601660 buffer_comparator.cc:157] Difference at 5875: 0.961105, expected 1.66135
E0130 18:11:39.454557 2601660 buffer_comparator.cc:157] Difference at 6446: 5.06743, expected 4.17685
E0130 18:11:39.454559 2601660 buffer_comparator.cc:157] Difference at 6663: -0.518723, expected -0.749084
E0130 18:11:39.454561 2601660 buffer_comparator.cc:157] Difference at 6794: 1.70668, expected 2.22054
E0130 18:11:39.454566 2601660 buffer_comparator.cc:157] Difference at 8168: 1.41833, expected 0.894318
E0130 18:11:39.454577 2601660 buffer_comparator.cc:157] Difference at 12705: 1.84048, expected 2.29286
E0130 18:11:39.454578 2601660 buffer_comparator.cc:157] Difference at 1

In [29]:
# Devices holding the array leaves of the train step's output.
print(set(leaf.device for leaf in jax.tree.leaves(out) if eqx.is_array(leaf)))

{CudaDevice(id=0)}


In [30]:
%%timeit
# Time a single training step; the device checks above show that everything lives on the GPU.
# block_until_ready makes the timing wait for the result instead of only measuring the dispatch.
out = train_step(inr, batch, optimizer_state, state)
jax.block_until_ready(out)

13.7 ms ± 115 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [31]:
# Create the iterator once, so that the timing cell below measures only next(), not iter().
data_iter = iter(data_loader)

In [32]:
%%timeit
# Time how long it takes to draw one batch from the dataloader.
next(data_iter)

750 μs ± 7.65 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
