# Example of training a hypernetwork locally
This notebook provides an example of how to create a hypernetwork and train it locally using the tools in this repo. We'll train the hypernetwork on MNIST, using the architecture defined in `hypernetwork_example/hypernetwork_ae.py`.

In [1]:
import pdb
import traceback
import pprint

import jax
from jax import numpy as jnp
import optax
import wandb

from common_dl_utils.config_creation import Config
import common_jax_utils as cju

# Training below logs to Weights & Biases, so log in first.
wandb.login()

# Fixed seed so that model initialisation and training can be reproduced.
SEED = 12398
key = jax.random.PRNGKey(SEED)
# Iterator of PRNG keys: every `next(key_gen)` further down hands out a fresh key.
key_gen = cju.key_generator(key)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msimon-martinus-koop[0m ([33mnld[0m). Use [1m`wandb login --relogin`[0m to force relogin
2024-09-14 14:11:55.214586: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.2 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [2]:
# Experiment configuration. String values such as 'training.Trainer' and
# (path, name) tuples are references that are presumably resolved to actual
# classes/functions by the cju.run_utils helpers that consume `config` later on.
config = Config()

# module in which model-component names are (presumably) looked up by default
config.architecture = './model_components'
# this time, however, part of the network architecture lives somewhere else:
config.model_type = ('./hypernetwork_example/hypernetwork_ae.py', 'Hypernetwork')  # we specify the module and the class within that module

# keyword arguments for the Hypernetwork class selected above
config.model_config = dict(
    in_features = 1,  # presumably the number of input image channels (MNIST is grayscale)
    conv_features = 64,
    shared_features = 2048,
    mlp_hidden_features = 2048,
    mlp_depth = 2,
    inr_hidden_size = 64,  # the generated INR's SirenLayers are 64 wide in the model repr
    inr_depth = 3,
    low_rank = 10,
    kernel_size = 3,
    groups = 1,
    layer_type = "inr_layers.SirenLayer",  # layer class used in the generated INR
    layer_kwargs = {'w0': 12},  # shows up as activation_kwargs={'w0': 12} on the SirenLayers
)

# next, we set up the training loop
config.trainer_module = './hypernetwork_utils/'
config.trainer_type = 'training.Trainer'
config.train_loader = ('./hypernetwork_example/mnist.py', 'get_train_loader')
config.validation_loader = ('./hypernetwork_example/mnist.py', 'get_validation_loader')
config.batch_size = 16  # batch size for data loaders (the training log reports 3750 batches per epoch)
config.shuffle = True
config.loss_function = 'inr_utils.losses.scaled_mse_loss'  # inr_utils is imported by hypernetwork_utils
# presumably samples the coordinates at which the generated INRs are evaluated in each gradient step
config.location_sampler = ('inr_utils.sampling.GridSubsetSampler',{  # NB when using this (str, dict) form,
    # where the dict determines the config options for the thing str points to,
    # default values of the thing str points to will take precedence over values
    # specified in config (but not in dict)
    'size': [28, 28],  # the MNIST pixel grid
    'batch_size': 400,  # presumably 400 of the 28*28 = 784 grid points per step
    'allow_duplicates': False,
})
# presumably turns each image into a target that can be queried at arbitrary coordinates
config.target = 'inr_utils.images.ArrayInterpolator'
config.target_config = {
    'interpolation_method': 'inr_utils.images.make_piece_wise_constant_interpolation',
    'scale_to_01': False,  # this is already handled by the dataloader
    'channels_first': True,  # because the dataloader puts channels first
}

config.optimizer = "training.OptimizerFactory.single_optimizer"
config.optimizer_type = 'adamw'  # don't forget to add optax to additional default modules
config.optimizer_config = {
    #'learning_rate': 0.000015,  # is handled by the learning_rate_schedule
    'weight_decay': 0.,
}
# 'exponential_decay' is presumably looked up in optax (passed as an additional default
# module when the trainer is built); starts at 1.5e-5 with decay factor 0.9 per 12_000
# steps, beginning at step 12_000 -- TODO confirm against optax.exponential_decay semantics
config.learning_rate_schedule = ('exponential_decay', { 
    'init_value': 0.000015,
    'transition_steps': 12_000,
    'decay_rate': .9,
    'transition_begin': 12_000
})

config.sub_steps_per_datapoint = 2  # for every batch of images, take two gradient update steps with different coordinates
config.epochs = 20

config.metric_collector = 'metrics.MetricCollector'
config.metric_collector_config = Config()
config.metric_collector_config.batch_frequency = 600  # presumably the N for metrics with frequency 'every_n_batches'
config.metric_collector_config.epoch_frequency = 2  # compute some metrics after every 2 epochs, e.g. because PlotOnGrid2D is slow (it probably isn't, but PlotOnGrid3D definitely is)
config.metric_collector_config.metrics = [
    'training.ValidationLoop',
    ('metrics.PlotOnGrid2D', {'grid': 28, 'batch_size': 0, 'frequency': 'every_n_epochs', 'requires_scaling': True}),
    ('metrics.MSEOnFixedGrid', {'grid': 28, 'num_dims': 2, 'batch_size': 0, 'frequency': 'every_n_batches'}),
    ('metrics.LossStandardDeviation', {'window_size': 100, 'frequency': 'every_batch'})
]

# after every step: print the loss every 400 steps and (presumably) raise an error if it turns NaN
config.after_step_callback = 'callbacks.ComposedCallback'
config.after_step_callback_config = {
    'callbacks': [
        ('callbacks.print_loss', {'after_every': 400}),
        'callbacks.raise_error_on_nan'
    ]
}
config.after_epoch_callback = None
config.use_wandb = True

In [3]:
# let's first see if we get the correct model
try:
    model = cju.run_utils.get_model_from_config_and_key(
        prng_key=next(key_gen),
        config=config,
        model_sub_config_name_base='model',
        add_model_module_to_architecture_default_module=True, # because this time much of the model is not in the module specified by 'architecture'
    )
except Exception:
    # Show the full traceback (it already includes the exception message) and drop
    # into the post-mortem debugger so the failing config entry can be inspected.
    traceback.print_exc()
    pdb.post_mortem()
    # Re-raise afterwards: otherwise this cell "succeeds" with `model` undefined,
    # "Run All" keeps going, and the next cell fails with a misleading NameError.
    raise

In [4]:
# Feed the hypernetwork an all-zeros input shaped like one MNIST image (1 x 28 x 28);
# the result is the generated INR, an MLPINR as the output shows.
dummy_image = jnp.zeros((1, 28, 28))
model(dummy_image)

MLPINR(
  input_layer=SirenLayer(
    weights=f32[64,2],
    biases=f32[64],
    activation_kwargs={'w0': 12}
  ),
  hidden_layers=[
    SirenLayer(weights=f32[64,64], biases=f32[64], activation_kwargs={'w0': 12}),
    SirenLayer(weights=f32[64,64], biases=f32[64], activation_kwargs={'w0': 12})
  ],
  output_layer=Linear(weights=f32[1,64], biases=f32[1], activation_kwargs={}),
  post_processor=<function real_part>
)

In [5]:
# The generated INR maps a 2-D coordinate to a single value; evaluate it at the origin.
generated_inr = model(jnp.zeros((1, 28, 28)))
generated_inr(jnp.zeros((2,)))

Array([-0.10312119], dtype=float32)

Now, let's set up the experiment, i.e. build the trainer described by `config`.

In [6]:
# Build the Trainer described by `config`; the model is passed to it as `hypernetwork`.
try:
    trainer = cju.run_utils.get_experiment_from_config_and_key(
        prng_key=next(key_gen),
        config=config,
        model_kwarg_in_trainer='hypernetwork',  # Trainer.__init__ takes a `hypernetwork` parameter to which the model that is to be trained should be passed
        trainer_default_module_key='trainer_module',
        additional_trainer_default_modules=[optax],  # so that optax names in the config ('adamw', 'exponential_decay') can be found
        add_model_module_to_architecture_default_module=True,
        model_sub_config_name_base='model',
    )
except Exception:
    # Show the full traceback (it already includes the exception message) and drop
    # into the post-mortem debugger to inspect what went wrong.
    traceback.print_exc()
    pdb.post_mortem()
    # Re-raise afterwards so the notebook stops here instead of continuing with
    # `trainer` undefined (which would surface later as a misleading NameError).
    raise

Finally, let's run the training while logging to Weights & Biases (wandb).

In [7]:
# Open a wandb run for the duration of training; it is finished when the block exits.
with wandb.init(project='inr_edu_24', notes='hypernetwork test', tags=('test',)) as run:
    # Store a human-readable snapshot of the full config alongside the run.
    config_snapshot = pprint.pformat(config)
    run.log({'used_config': config_snapshot})
    # Train with a fresh PRNG key from the generator set up at the top.
    results = trainer.train(next(key_gen))

Start training for 20 epochs with 3750 batches per epoch and 2 gradient steps per batch.
Start epoch 1.
    Loss at step 400 is 0.5202123522758484.
    Loss at step 800 is 0.3814256191253662.
    Loss at step 1200 is 0.310028076171875.
    Loss at step 1600 is 0.2519527077674866.
    Loss at step 2000 is 0.25903552770614624.
    Loss at step 2400 is 0.19799944758415222.
    Loss at step 2800 is 0.2446938008069992.
    Loss at step 3200 is 0.1674477756023407.
    Loss at step 3600 is 0.20979449152946472.
    Finished epoch 1 with average loss: 0.3153599798679352.
    Start Validation Loop
    validation loss: 0.1773703545331955 +/- 0.08626064658164978
Start epoch 2.
    Loss at step 400 is 0.15975871682167053.
    Loss at step 800 is 0.18665388226509094.
    Loss at step 1200 is 0.14192450046539307.
    Loss at step 1600 is 0.17154403030872345.
    Loss at step 2000 is 0.1467149555683136.
    Loss at step 2400 is 0.13949400186538696.
    Loss at step 2800 is 0.1282154619693756.
    Loss

VBox(children=(Label(value='69.657 MB of 69.657 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
batch/MSE_on_fixed_grid,█▄▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
batch/MSE_on_fixed_grid_std,█▇▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁
batch/global_step,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇█████
batch/loss,█▆▄▄▄▄▃▃▃▄▄▂▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
batch/loss_std_over_100_steps,█▇▆▅▆▃▃▃▂▂▂▂▃▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁
batch/step_within_epoch,▂▅▆▆▇▂▃▄▅▇▇▂█▆▂▅█▄▄▇▄▁▇▂▂▆▇▂▃▃▆▇▆▇▂▂▆▁▆█
batch_within_epoch,▂▂▃▃▃▁▂▃▆▅▅▇▃▂▇█▁▁▂▃▆▆▄▆▆▂▂▂▅▅█▁▂▇▄▇▇▅▅█
epoch,▁▁▁▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇███████
epoch/loss,█▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
epoch/validation/loss,█▆▄▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁

0,1
batch/MSE_on_fixed_grid,0.00496
batch/MSE_on_fixed_grid_std,0.00288
batch/global_step,75000
batch/loss,0.03992
batch/loss_std_over_100_steps,0.00625
batch/step_within_epoch,3749
batch_within_epoch,3750
epoch,20
epoch/loss,0.04102
epoch/validation/loss,0.04101
