# Example of training an INR locally
This notebook provides an example of how to create an INR and train it locally using the tools in this repo.

In [2]:
import pdb
import traceback

import jax
from jax import numpy as jnp
import optax
import wandb

from common_dl_utils.config_creation import Config
import common_jax_utils as cju

wandb.login()

key = jax.random.PRNGKey(12398)
key_gen = cju.key_generator(key)

We want to train a single INR on `example_data/parrot.png`. We'll use the `CombinedINR` clas from `model_components.inr_modules` together with the `SirenLayer` and `GaussianINRLayer` from `model_components.inr_layers` for the model, and we'll train it using the tools from `inr_utils`.

To do all of this, basically we only need to create a config. We'll use the `common_dl_utils.config_creation.Config` class for this, but this is basically just a dictionary that allows for attribute access-like acces of its elements (so we can do `config.model_type = "CombinedINR"` instead of `config["model_type"] = "CombinedINR"`). You can also just use a dictionary instead.

Then we'll use the tools from `common_jax_utils` to first get a model from this config so we can inspect it, and then just run the experiment specified by the config.

Doing this in a config instead of hard coded might seem like extra work, but consider this:
1. you can serialize this config as a json file or a yaml file to later get the same model and experimental settings back 
   so when you are experimenting with different architectures, if you just store the configs you've used, you can easily recreate previous results
2. when we get to running hyper parameter sweeps, you can easily get these configs (with a pick for the varying hyper parameters) from wandb
   and then run an experiment specified by that config on any machine you want, e.g. on Snellius

In [3]:
config = Config()

# first we specify what the model should look like
config.architecture = './model_components'  # module containing all relevant classes for architectures
# NB if the classes relevant for creating the model are spread over multiple modules, this is no problem
# let config.architecture be the module that contains the "main" model class, and for all other components just specify the module
# or specify the other modules as default modules to the tools in common_jax_utils.run_utils
config.model_type = 'inr_modules.CombinedINR'

config.model_config = Config()
config.model_config.in_size = 2
config.model_config.out_size = 3
config.model_config.terms = [  # CombinedINR uses multiple MLPs and returns the sum of their outputs. These 'terms' are the MLPs
    # ('inr_modules.MLPINR.from_config',{
    #     'hidden_size': 256,
    #     'num_layers': 5,
    #     'layer_type': 'inr_layers.RealWire',
    #     'num_splits': 1,
    #     'use_complex': False,
    #     'activation_kwargs': {'w0': 15., 's0':1.}
    # }),
    # ('inr_modules.MLPINR.from_config',{
    #     'hidden_size': 1024,
    #     'num_layers': 2,
    #     'num_splits': 1,
    #     'layer_type': 'inr_layers.GaussianINRLayer',
    #     'use_complex': False,
    #     'activation_kwargs': {'inverse_scale': 1},
    # })
    # ('inr_modules.MLPINR.from_config',{
    #     'hidden_size': 256,
    #     'num_layers': 5,
    #     'layer_type': 'inr_layers.FinerLayer',  # Include the FinerLayer here
    #     'num_splits': 1,
    #     'use_complex': False,
    #     'activation_kwargs': {'w0': 20}  # Set the w0 hyperparameter for the FINER activation
    # }),
    ('inr_layers.FinerLayer', {
        'in_features': config.model_config.in_size,
        'out_features': config.model_config.out_size,
        'hidden_layers': 3,  # Number of hidden layers for FINER
        'hidden_size': 256,  # Size of each hidden layer
        'omega': 30  # Frequency parameter for FINER activation
    })
]

# next, we set up the training loop, including the 'target_function' that we want to mimic
config.trainer_module = './inr_utils/'  # similarly to config.architecture above, here we just specify in what module to look for objects by default
config.trainer_type = 'training.train_inr'
config.target_function = 'images.ContinuousImage'
config.target_function_config = {
    'image': './example_data/parrot.png',
    'scale_to_01': True,
    'interpolation_method': 'images.make_piece_wise_constant_interpolation'
}
config.loss_function = 'losses.scaled_mse_loss'
config.sampler = ('sampling.GridSubsetSampler',{  # samples coordinates in a fixed grid, that should in this case coincide with the pixel locations in the image
    'size': [2040, 1356],
    'batch_size': 2000,
    'allow_duplicates': False,
})

config.optimizer = 'adam'  # we'll have to add optax to the additional default modules later
config.optimizer_config = {
    'learning_rate': 1.5e-4
}
config.steps = 40000
config.use_wandb = True

# now we want some extra things, like logging, to happen during training
# the inr_utils.training.train_inr function allows for this through callbacks.
# The callbacks we want to use can be found in inr_utils.callbacks
config.after_step_callback = 'callbacks.ComposedCallback'
config.after_step_callback_config = {
    'callbacks':[
        ('callbacks.print_loss', {'after_every':400}),  # only print the loss every 400th step
        'callbacks.report_loss',  # but log the loss to wandb after every step
        ('callbacks.MetricCollectingCallback', # this thing will help us collect metrics and log images to wandb
            {'metric_collector':'metrics.MetricCollector'}
        ),
        'callbacks.raise_error_on_nan'  # stop training if the loss becomes NaN
    ],
    'show_logs': False
}

config.metric_collector_config = {  # the metrics for MetricCollectingCallback / metrics.MetricCollector
    'metrics':[
        ('metrics.PlotOnGrid2D', {'grid': 256, 'batch_size':8*256, 'frequency':'every_n_batches'}),  
        # ^ plots the image on this fixed grid so we can visually inspect the inr on wandb
        ('metrics.MSEOnFixedGrid', {'grid': [2040, 1356], 'batch_size':2040, 'frequency': 'every_n_batches'})
        # ^ compute the MSE with the actual image pixels
    ],
    'batch_frequency': 400,  # compute all of these metrics every 400 batches
    'epoch_frequency': 1  # not actually used
}

config.after_training_callback = None  # don't care for one now, but you could have this e.g. store some nice loss plots if you're not using wandb 
config.optimizer_state = None  # we're starting from scratch

In [4]:
# let's first see if we get the correct model
try:
    inr = cju.run_utils.get_model_from_config_and_key(
        prng_key=next(key_gen),
        config=config,
        model_sub_config_name_base='model',
        add_model_module_to_architecture_default_module=False, # since the model is already in the default module specified by 'architecture',
    )
except Exception as e:
    traceback.print_exc()
    print(e)
    print('\n')
    pdb.post_mortem()

Traceback (most recent call last):
  File "/tmp/ipykernel_17244/1141651026.py", line 3, in <module>
    inr = cju.run_utils.get_model_from_config_and_key(
  File "/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/common_jax_utils/run_utils.py", line 115, in get_model_from_config_and_key
    return un_initialized_model.initialize()
  File "/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/common_dl_utils/config_realization.py", line 182, in initialize
    processed_self_kwargs = tree_map(
  File "/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/common_dl_utils/trees.py", line 126, in tree_map
    {
  File "/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/common_dl_utils/trees.py", line 127, in <dictcomp>
    key: tree_map(sub_tree, func, is_leaf=is_leaf)
  File "/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/common_dl_utils/trees.py", line 132, in tree_map
    return type(tree)(tree_map(s

Can't instantiate abstract class FinerLayer with abstract class attributes {'_activation_function'}


> [0;32m/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/equinox/_better_abstract.py[0m(223)[0;36m__call__[0;34m()[0m
[0;32m    221 [0;31m        [0;32mif[0m [0mlen[0m[0;34m([0m[0mcls[0m[0;34m.[0m[0m__abstractclassvars__[0m[0;34m)[0m [0;34m>[0m [0;36m0[0m[0;34m:[0m  [0;31m# pyright: ignore[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    222 [0;31m            [0mabstract_class_vars[0m [0;34m=[0m [0mset[0m[0;34m([0m[0mcls[0m[0;34m.[0m[0m__abstractclassvars__[0m[0;34m)[0m  [0;31m# pyright: ignore[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 223 [0;31m            raise TypeError(
[0m[0;32m    224 [0;31m                [0;34mf"Can't instantiate abstract class {cls.__name__} with abstract class "[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    225 [0;31m                [0;34mf"attributes {abstract_class_vars}"[0m[0;34m[0m

In [None]:
inr

NameError: name 'inr' is not defined

In [None]:
# check that it works properly
inr(jnp.zeros(2))

TypeError: mul got incompatible shapes for broadcasting: (256,), (2,).

In [None]:
# next we get the experiment from the config using common_jax_utils.run_utils.get_experiment_from_config_and_key
experiment = cju.run_utils.get_experiment_from_config_and_key(
    prng_key=next(key_gen),
    config=config,
    model_kwarg_in_trainer='inr',
    model_sub_config_name_base='model',  # so it looks for "model_config" in config
    trainer_default_module_key='trainer_module',  # so it knows to get the module specified by config.trainer_module
    additional_trainer_default_modules=[optax],  # remember the don't forget to add optax to the default modules? This is that 
    add_model_module_to_architecture_default_module=False,
    initialize=False  # don't run the experiment yet, we want to use wandb
)

In [None]:
# and we run the experiment while logging things to wandb
with wandb.init(
    project='inr_edu_24',
    notes='test',
    tags=['test']
) as run:
    results = experiment.initialize()

Traceback (most recent call last):
  File "/tmp/ipykernel_40872/4160894630.py", line 7, in <module>
    results = experiment.initialize()
  File "/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/common_dl_utils/config_realization.py", line 195, in initialize
    return cls(**processed_self_kwargs)
  File "/home/abdtab/INR_BEP/inr_utils/training.py", line 152, in train_inr
    inr, optimizer_state, loss = train_step(inr, optimizer_state, next(key_gen))
  File "/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/equinox/_jit.py", line 242, in __call__
    return self._call(False, args, kwargs)
  File "/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/equinox/_module.py", line 1078, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/home/abdtab/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/equinox/_jit.py", line 215, in _call
    out = self._cached(dynamic_donate, dynamic_nodonate, static)
  File 

VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

TypeError: mul got incompatible shapes for broadcasting: (256,), (2,).

In [2]:
# Import the necessary modules and functions
from model_components.inr_layers import FinerLayer
import model_components.activation_functions as act
import optax
from PIL import Image
import jax
from jax import numpy as jnp
import equinox as eqx

from common_jax_utils import key_generator
import model_components.auxiliary as aux
import numpy as np

# Utility functions to load and prepare the image
def load_image(image_path):
    img = Image.open(image_path).convert("RGB")
    img = np.array(img) / 255.0  # Normalize pixel values to [0, 1]
    return img

def prepare_data(img):
    height, width, _ = img.shape
    x_coords = np.linspace(-1, 1, width)
    y_coords = np.linspace(-1, 1, height)
    coords = np.array(np.meshgrid(x_coords, y_coords)).reshape(2, -1).T  # Shape: (width*height, 2)
    rgb_values = img.reshape(-1, 3)  # Flatten RGB values to match coords
    return jnp.array(coords), jnp.array(rgb_values)

# Initialize the FINER model
class FINERModel(eqx.Module):
    layers: list

    def __init__(self, in_size, out_size, hidden_layers, hidden_features, omega, key):
        keys = jax.random.split(key, hidden_layers + 2)
        
        # Define layers
        self.layers = []
        
        # Input layer with first_omega
        self.layers.append(FinerLayer(in_size, hidden_features, omega=omega, key=keys[0], is_first=True, fbs=0.1))
        
        # Hidden layers with hidden_omega
        for i in range(1, hidden_layers + 1):
            self.layers.append(FinerLayer(hidden_features, hidden_features, omega=omega, key=keys[i]))
        
        # Output layer (no activation)
        self.layers.append(FinerLayer(hidden_features, out_size, omega=omega, key=keys[-1], is_last=True))

    def __call__(self, x):
        # Pass input through each layer
        for layer in self.layers:
            x = layer(x)
        return x

# Image path and load image data
image_path = "./example_data/parrot.png"  # Replace with your actual image path
img = load_image(image_path)
coords, rgb_values = prepare_data(img)

# Initialize model
key = jax.random.PRNGKey(0)
finer_model = FINERModel(in_size=2, out_size=3, hidden_layers=3, hidden_features=256, omega=30, key=key)

# Define loss function
def loss_fn(params, coords, rgb_values):
    pred_rgb = finer_model(coords)  # Predict RGB values
    return jnp.mean((pred_rgb - rgb_values) ** 2)

# Initialize optimizer
optimizer = optax.adam(learning_rate=1e-4)
opt_state = optimizer.init(finer_model)

# Training loop
num_steps = 5000

@jax.jit
def train_step(finer_model, opt_state, coords, rgb_values):
    # Compute loss and gradients
    loss, grads = jax.value_and_grad(loss_fn)(finer_model, coords, rgb_values)
    updates, opt_state = optimizer.update(grads, opt_state)
    finer_model = eqx.apply_updates(finer_model, updates)
    return finer_model, opt_state, loss

# Train the model
for step in range(num_steps):
    finer_model, opt_state, loss = train_step(finer_model, opt_state, coords, rgb_values)
    if step % 1000 == 0:
        print(f"Step {step}, Loss: {loss}")

# Generate reconstructed image
reconstructed_rgb = finer_model(coords).reshape(img.shape)
reconstructed_img = (np.clip(reconstructed_rgb, 0, 1) * 255).astype(np.uint8)
Image.fromarray(reconstructed_img).show()


TypeError: Can't instantiate abstract class FinerLayer with abstract class attributes {'_activation_function'}