# Adding your own models

This notebook presents the general workflow for adding new models.

In principle, the framework allows training any parameterized model on the data set, as long as its function call is differentiable with respect to its parameters (once more, a lot of the heavy lifting is done by the equinox Python package, https://docs.kidger.site/equinox/).

**However**: everything that is defined as a `jax.Array` will be interpreted as a trainable parameter! For instance, if the normalization values were defined as a `jax.Array`, they would be trained as well. This may be fine if it is intentional, but otherwise it can lead to confusing bugs.

The model needs to provide the API specified in `mc2/model_interfaces/model_interface.py`.
The required methods are `__call__` and `normalized_call`.

In [None]:
# optional setup
%load_ext autoreload
%autoreload 2

import traceback
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # choose the CUDA device
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # disable memory preallocation

import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp

jax.config.update("jax_platform_name", "cpu")  # optional: force JAX to run on the CPU

In [None]:
from mc2.model_interfaces.model_interface import ModelInterface

In [None]:
help(ModelInterface.__call__)  # print signature and docstring

In [None]:
help(ModelInterface.normalized_call)  # print signature and docstring

`ModelInterface` is abstract, so it cannot be instantiated directly; you first need to subclass it and implement its abstract methods:

In [None]:
try:
    ModelInterface()
except TypeError:
    traceback.print_exc()
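The mechanism behind this error can be sketched with Python's `abc` module. `MinimalModelInterface` below is a hypothetical stand-in, not the actual `ModelInterface` (which may be implemented differently, e.g. on top of `eqx.Module`); its method signatures are copied from the `DummyModelInterface` in this notebook. A subclass only becomes instantiable once it implements every abstract method.

```python
import abc

import jax.numpy as jnp


class MinimalModelInterface(abc.ABC):
    """Hypothetical stand-in for ModelInterface (not the real class)."""

    @abc.abstractmethod
    def __call__(self, B_past, H_past, B_future, T):
        """Raw inputs in, denormalized H prediction out."""

    @abc.abstractmethod
    def normalized_call(self, B_past_norm, H_past_norm, B_future_norm, T_norm, warmup=True):
        """Normalized inputs in, normalized H prediction out."""


class IncompleteInterface(MinimalModelInterface):
    # implements __call__ but not normalized_call -> still abstract
    def __call__(self, B_past, H_past, B_future, T):
        # normalization is omitted in this sketch
        return self.normalized_call(B_past, H_past, B_future, T)


class ZeroInterface(IncompleteInterface):
    # implements the remaining abstract method -> can be instantiated
    def normalized_call(self, B_past_norm, H_past_norm, B_future_norm, T_norm, warmup=True):
        return jnp.zeros_like(B_future_norm)  # predict H = 0 for every future step


try:
    IncompleteInterface()
except TypeError as err:
    print("TypeError:", err)

pred = ZeroInterface()(None, None, jnp.ones((2, 3)), None)
print(pred.shape)  # (2, 3)
```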

This is a very simple model with a single trainable parameter: it always predicts the same constant value.
Training it yields the constant that best approximates the training data set. This is likely close to the mean of the data; the exact value depends on the training loss (for a plain mean-squared-error loss on untransformed targets it is exactly the mean, for a mean-absolute-error loss it is the median).
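For a mean-squared-error loss over $N$ target values $H_i$, setting the derivative with respect to $\theta$ to zero shows that the optimal constant is the sample mean (assuming the loss is computed directly on the untransformed targets):

$$
\mathcal{L}(\theta) = \frac{1}{N}\sum_{i=1}^{N}\left(\theta - H_i\right)^2,
\qquad
\frac{\partial \mathcal{L}}{\partial \theta} = \frac{2}{N}\sum_{i=1}^{N}\left(\theta - H_i\right) = 0
\;\Rightarrow\;
\theta^\ast = \frac{1}{N}\sum_{i=1}^{N} H_i .
$$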

The model:
```
import jax
import jax.numpy as jnp
import equinox as eqx


class DummyModel(eqx.Module):
    theta: jax.Array  # the single trainable parameter

    def __init__(self, key: jax.Array):
        # initialize theta as a random scalar
        self.theta = jax.random.normal(key, shape=())

    def __call__(self, x):
        # predict the constant theta at every position of the input
        return self.theta * jnp.ones(x.shape)
```

And model interface:
```
from typing import Callable

import jax
import jax.numpy as jnp
import equinox as eqx

from mc2.data_management import Normalizer
from mc2.model_interfaces.model_interface import ModelInterface
from mc2.models.dummy_model import DummyModel

class DummyModelInterface(ModelInterface):
    model: DummyModel
    normalizer: Normalizer
    featurize: Callable = eqx.field(static=True)

    def __call__(
        self,
        B_past: jax.Array,
        H_past: jax.Array,
        B_future: jax.Array,
        T: jax.Array,
    ) -> jax.Array:

        # concatenating and normalizing the data
        B_all = jnp.concatenate([B_past, B_future], axis=1)
        B_all_norm, H_past_norm, T_norm = self.normalizer.normalize(B_all, H_past, T)

        B_past_norm = B_all_norm[:, : B_past.shape[1]]
        B_future_norm = B_all_norm[:, B_past.shape[1] :]

        # performing prediction
        batch_H_pred = self.normalized_call(B_past_norm, H_past_norm, B_future_norm, T_norm)

        # denormalizing predicted value
        batch_H_pred_denorm = jax.vmap(jax.vmap(self.normalizer.denormalize_H))(batch_H_pred)

        return batch_H_pred_denorm

    def normalized_call(
        self,
        B_past_norm: jax.Array,
        H_past_norm: jax.Array,
        B_future_norm: jax.Array,
        T_norm: jax.Array,
        warmup: bool = True,
    ) -> jax.Array:
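        # the dummy model ignores the past values, the temperature and `warmup`;
        # it predicts a constant for every future time step of every batch element
        batch_H_pred = jax.vmap(self.model)(B_future_norm)
        return batch_H_pred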
```

To use this model, it must be importable from within the package, i.e., it has to be defined in a Python file inside the source code.

Additionally, the model needs to be added as an option to the `setup_model` function, so that it can be selected as a `model_type` in the training script.
The function can be found in `mc2/model_setup.py`.

An entry like the following needs to be added to the `match` statement in `setup_model`:
```
case "DummyModel":
    model_params_d = dict(key=model_key)
    model = DummyModel(key=model_key)
    mdl_interface_cls = DummyModelInterface
```

For this example, all of this has already been done.
You can find the `DummyModel` at `mc2/models/dummy_model.py`, the `DummyModelInterface` at `mc2/model_interfaces/dummy_model_interface.py`, and the added case for the `DummyModel` in `mc2/model_setup.py`.


As a result, we can create such a model using the `setup_model` function:

In [None]:
from mc2.model_setup import setup_model, setup_featurize
from mc2.data_management import Normalizer

In [None]:
identity_function = lambda x: x  # returns its input unchanged, i.e., H is not transformed
# placeholder normalizer that (effectively) does nothing: all maxima are 1.0 and H is not transformed
normalizer = Normalizer(
    B_max=1.0,
    H_max=1.0,
    T_max=1.0,
    norm_fe_max=[],
    H_transform=identity_function,
    H_inverse_transform=identity_function,
)
model, model_parameter_dict = setup_model(
    model_label="DummyModel",
    model_key=jax.random.PRNGKey(0),
    normalizer=normalizer,
    featurize=setup_featurize("reduce", 0, 0)
)
display("model parameters:", model_parameter_dict)
print()
display("model:", model)

And we can also train versions of it using the main training function:

In [None]:
from mc2.runners.rnn_training_jax import train_model_jax

In [None]:
train_model_jax(
    material_name="B",
    model_types=["DummyModel"],
    seeds=[155],
    epochs=100,
    loss_type="MSE",
    disable_f64=True,
)

Checking its performance:

In [None]:
from mc2.utils.model_evaluation import reconstruct_model_from_file, get_exp_ids
from mc2.utils.model_evaluation import plot_model_frequency_sweep

In [None]:
dummy_model_exp_ids = get_exp_ids(material_name="B", model_type="DummyModel")
dummy_model_exp_ids

In [None]:
exp_id = dummy_model_exp_ids[0]  # or choose whichever model you want to test
model = reconstruct_model_from_file(exp_id)

In [None]:
from mc2.utils.model_evaluation import (
    load_gt_and_pred, plot_worst_predictions, plot_first_predictions, plot_loss_trends
)

In [None]:
seed = exp_id.split("seed")[-1]  # the part of the experiment ID after "seed"
gt, pred = load_gt_and_pred(
    exp_id=exp_id,
    seed=seed,
    freq_idx=0
)

In [None]:
plot_worst_predictions(gt, pred);
plt.show()

plot_first_predictions(gt, pred);
plt.show()

plot_loss_trends(exp_id, seed);
plt.show()

In [None]:
from mc2.data_management import DataSet, MaterialSet

In [None]:
material_set = MaterialSet.from_material_name("B")
train_set, eval_set, test_set = material_set.split_into_train_val_test()

In [None]:
plot_model_frequency_sweep(model, test_set, loader_key=jax.random.PRNGKey(21), past_size=1)

Sanity-checking the model output:

In [None]:
print("model_output:", model.normalizer.denormalize_H(model.model.theta))

In [None]:
average_field = []
for frequency_set in train_set:
    average_field.append(jnp.mean(frequency_set.H))
average_H = jnp.mean(jnp.array(average_field))
print("average field value for the data set:", average_H)

The model outputs roughly $0$. While not exact, this is close to the average field value of the training data for material `B`.
For a one-parameter model trained with stochastic gradient descent, this result is acceptable.
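As a self-contained illustration of why the trained constant should land near the average, the sketch below runs plain full-batch gradient descent on a made-up set of target values under an MSE loss. Unlike the stochastic, mini-batched training in `train_model_jax`, there is no sampling noise here, so the estimate converges to the mean itself. The targets `H` are hypothetical, not material `B` data.

```python
import jax
import jax.numpy as jnp

# hypothetical stand-in for the training targets (not the material B data)
H = jnp.array([-3.0, -1.0, 0.5, 2.0, 4.5])


def mse_loss(theta, targets):
    # same structure as DummyModel: predict theta at every position
    return jnp.mean((theta * jnp.ones_like(targets) - targets) ** 2)


grad_fn = jax.jit(jax.grad(mse_loss))

theta = jnp.array(0.0)
learning_rate = 0.1
for _ in range(200):
    # the gradient is 2 * (theta - mean(H)), so the error shrinks by a factor 0.8 per step
    theta = theta - learning_rate * grad_fn(theta, H)

print("gradient-descent estimate:", theta)
print("mean of targets:          ", jnp.mean(H))
```

With mini-batches, each step only sees a subset of the targets, so the estimate fluctuates around the mean instead of settling on it exactly, which is consistent with the small deviation observed above.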