# MLIP examples: Adding a new model

This is an *advanced tutorial* for users with machine learning expertise who want to build new models into the *mlip* library and benefit from its full set of integrated tools. The library is designed to allow flexible addition of models and of other components (e.g. loggers, loss functions) into a single ecosystem. Users are therefore encouraged to reuse the existing code and, where appropriate, contribute additional tools that can be used by the community.

**This notebook aims to showcase:**
- **The MLIP model hierarchy** detailing the relevant layers used to create a model from scratch
- **A simple example of new model integration** using a constant MLIP network as illustration
- **A more advanced example** using a simple message passing network

**Install and required imports**

As a first step, we install the *mlip* library directly from pip. We also install the appropriate JAX CUDA backend to run on GPU (comment it out to run on CPU). In this notebook, we will not run any simulation and therefore do not install JAX-MD; for details on how to do so, please refer to our *simulation* tutorial. Note that if you have already run another tutorial in the same environment, this installation is not required. Please refer to [our installation page](https://instadeepai.github.io/mlip/installation/index.html) for more information.

In [None]:
%pip install mlip "jax[cuda12]==0.4.33"

# Use this instead for installation without GPU:
# %pip install mlip

In [None]:
import jax
import jax.numpy as jnp
from jax import Array

import flax.linen as nn
import pydantic

from mlip.models.atomic_energies import get_atomic_energies

# Required classes for creating a novel model
from mlip.models.mlip_network import MLIPNetwork
from mlip.models import ForceFieldPredictor, ForceField

## 1. The *mlip* model hierarchy

The *mlip* library currently **relies on the following two layers for model definitions**:
- [`MLIPNetwork`][MLIPNetwork] is a base class for GNNs that **computes node-wise energy** summands from edge vectors, node species, and graph edges passed as `senders` and `receivers` index arrays.
- [`ForceFieldPredictor`][ForceFieldPredictor] is a generic wrapper around any [`MLIPNetwork`][MLIPNetwork].

  It gathers **total energy, forces (and, if required, stress)** in the [`Prediction`](https://instadeepai.github.io/mlip/api_reference/models/prediction.html) dataclass, by summing the node energies obtained from [`MLIPNetwork`][MLIPNetwork] on a [`jraph.GraphsTuple`](https://jraph.readthedocs.io/en/latest/api.html) object, and differentiating with respect to positions (and unit cell).
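To make the split between the two layers concrete, here is a minimal NumPy sketch under stated assumptions: `toy_node_energies` is a hypothetical stand-in for an `MLIPNetwork` (a toy harmonic pair term, not a real model), edge vectors are assumed to point from sender to receiver, and forces are obtained by central finite differences, whereas `ForceFieldPredictor` relies on JAX automatic differentiation.

```python
import numpy as np


def toy_node_energies(edge_vectors, node_species, senders, receivers):
    """Hypothetical MLIPNetwork stand-in: returns one energy summand per node.

    Each directed edge adds 0.5 * (r - 1)**2 to its receiver (toy harmonic term).
    """
    r = np.linalg.norm(edge_vectors, axis=-1)
    node_energies = np.zeros(node_species.shape[0])
    np.add.at(node_energies, receivers, 0.5 * (r - 1.0) ** 2)
    return node_energies


def total_energy(positions, node_species, senders, receivers):
    # Predictor-style step: build edge vectors (assumed receiver - sender),
    # evaluate node energies and sum them into a total energy
    edge_vectors = positions[receivers] - positions[senders]
    return toy_node_energies(edge_vectors, node_species, senders, receivers).sum()


def finite_difference_forces(positions, node_species, senders, receivers, eps=1e-5):
    # Forces are F = -dE/dR; central differences stand in for jax.grad here
    forces = np.zeros_like(positions)
    for i in range(positions.shape[0]):
        for k in range(positions.shape[1]):
            shift = np.zeros_like(positions)
            shift[i, k] = eps
            e_plus = total_energy(positions + shift, node_species, senders, receivers)
            e_minus = total_energy(positions - shift, node_species, senders, receivers)
            forces[i, k] = -(e_plus - e_minus) / (2 * eps)
    return forces


# Two atoms 1.5 Å apart, connected in both directions
positions = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])
node_species = np.array([0, 0])
senders = np.array([0, 1])
receivers = np.array([1, 0])

energy = total_energy(positions, node_species, senders, receivers)
forces = finite_difference_forces(positions, node_species, senders, receivers)
print(energy)  # 0.25
print(forces)
```

The network only ever sees edge vectors, never absolute positions, so the resulting forces sum to zero for this isolated pair.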


For convenience, our training loop and simulation engines finally work with [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) objects that **wrap a force field predictor and its learnable parameters within a frozen dataclass object**.

For illustration, in the following sections we will:

2. define a very simple model that returns constant energies,
3. define a more involved GNN model without equivariance constraints.

[MLIPNetwork]: https://instadeepai.github.io/mlip/api_reference/models/mlip_network.html
[ForceFieldPredictor]: https://instadeepai.github.io/mlip/api_reference/models/predictor.html
[ForceField]: https://instadeepai.github.io/mlip/api_reference/models/force_field.html
[Prediction]: https://instadeepai.github.io/mlip/api_reference/models/prediction.html

---

## 2. Constant MLIPNetwork (atomic energies)

### a. *Config and DatasetInfo*

To facilitate model loading and saving, our [`MLIPNetwork`](https://instadeepai.github.io/mlip/api_reference/models/mlip_network.html) subclasses **gather (almost) all of their hyperparameters within a `pydantic.BaseModel` subclass**, to which their class attribute `.Config` points. The only exceptions are data-dependent hyperparameters, which might conflict with the data processing pipeline.

This is why [`MLIPNetwork`](https://instadeepai.github.io/mlip/api_reference/models/mlip_network.html) subclasses **also accept a [`DatasetInfo`](https://instadeepai.github.io/mlip/api_reference/data/dataset_info.html) object** upon initialization, which notably stores:
- `cutoff_distance_angstrom : float`
- `atomic_energies_map : dict[int, float]`
- `avg_num_neighbours : float`
- and some other data computed when processing the dataset.

This way, we are sure that our models can only be used in the context they were trained for, and will not be evaluated e.g. on atomic numbers they have never seen. We create a dummy [`DatasetInfo`](https://instadeepai.github.io/mlip/api_reference/data/dataset_info.html) for the purpose of this example:

In [None]:
from mlip.data import DatasetInfo

# Dummy `DatasetInfo` for H, C, N, O 
# which have atomic numbers 1, 6, 7, 8 respectively
dataset_info = DatasetInfo(
    atomic_energies_map={
        1: -100.0,
        6: -600.0,
        7: -700.0,
        8: -800.0,
    },
    cutoff_distance_angstrom = 5.0,
)

During default data preprocessing, the `atomic_energies_map` dictionary is computed by least-squares regression. This dictionary contains the average energy contribution of each atomic species, which can be large since it accounts for the species' full electron cloud.
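As an illustration of this regression (a sketch, not the library's preprocessing code), the snippet below assumes each total energy is a sum of per-element contributions, $E = \sum_Z n_Z e_Z$, and recovers the $e_Z$ with `np.linalg.lstsq` from a few synthetic molecules whose energies are generated from the dummy values used in this notebook.

```python
import numpy as np

atomic_numbers = [1, 6, 8]  # H, C, O
# Element counts per structure (columns follow `atomic_numbers`)
counts = np.array([
    [2, 0, 1],  # H2O
    [4, 1, 0],  # CH4
    [0, 1, 2],  # CO2
    [2, 1, 1],  # CH2O
], dtype=float)

# Synthetic reference energies built from known per-element contributions
true_contributions = np.array([-100.0, -600.0, -800.0])
energies = counts @ true_contributions

# Least-squares fit of one contribution per element
fitted, *_ = np.linalg.lstsq(counts, energies, rcond=None)
atomic_energies_map = {z: float(e) for z, e in zip(atomic_numbers, fitted)}
print(atomic_energies_map)
```

With real reference energies the linear system is not exactly consistent, so the fitted values also absorb part of the average interaction energy.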

### b. *Constant node energies*

To illustrate the MLIP model hierarchy, we construct **a very simple `ForceField` that only returns the sum of atomic contributions**:
- The error on energies should be much smaller than the total energy of the structure.
- However, the predicted forces are always zero, because atoms are treated as isolated and the energy does not depend on their positions.
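The behaviour of this constant model can be reproduced in a few lines of NumPy. The sketch assumes, consistently with the H2O example later in this notebook, that species indices follow the sorted atomic numbers of `atomic_energies_map`; `species_index` and `constant_total_energy` are hypothetical helpers, not *mlip* functions.

```python
import numpy as np

atomic_energies_map = {1: -100.0, 6: -600.0, 7: -700.0, 8: -800.0}

# Species index = rank of the atomic number among known elements (assumption)
sorted_zs = np.array(sorted(atomic_energies_map))
atomic_energies = np.array([atomic_energies_map[z] for z in sorted_zs])


def species_index(atomic_numbers):
    return np.searchsorted(sorted_zs, atomic_numbers)


def constant_total_energy(positions, atomic_numbers):
    # `positions` is deliberately unused: the energy cannot depend on geometry,
    # so its gradient, and hence every force, is zero
    node_energies = atomic_energies[species_index(atomic_numbers)]
    return node_energies.sum()


atomic_numbers = np.array([1, 8, 1])  # H2O
positions = np.array([[-0.5, 0.0, 0.0], [0.0, 0.2, 0.0], [0.5, 0.0, 0.0]])
print(species_index(atomic_numbers))  # [0 3 0]
print(constant_total_energy(positions, atomic_numbers))  # -1000.0
```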

In [None]:
class ConstantMLIPConfig(pydantic.BaseModel):
    learnable: bool

class ConstantMLIP(MLIPNetwork):
    # arguments to `ConstantMLIP.__init__`
    config: ConstantMLIPConfig
    dataset_info: DatasetInfo
    # reference to the `ConstantMLIP.Config` sister class
    Config = ConstantMLIPConfig

    @nn.compact
    def __call__(self, edge_vectors, node_species, senders, receivers):
        # One reference energy per species, built from `atomic_energies_map`
        atomic_energies = get_atomic_energies(self.dataset_info)
        if self.config.learnable:
            # Register the reference energies as trainable parameters,
            # initialized from the dataset values (the RNG key is unused)
            atomic_energies = self.param(
                "atomic_energies",
                lambda _: atomic_energies,
            )
        # Look up each node's energy from its species index
        node_energies = atomic_energies[node_species]
        return node_energies

### c. *Constant force field*

Now that we have defined this simple `ConstantMLIP` subclass, we can already define a state-holding [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) object. The quickest (but slightly opaque) way is to use the helper classmethod `ForceField.from_mlip_network()`:

In [None]:
# constant_mlip : (vectors, species, senders, receivers) -> node_energies
constant_mlip = ConstantMLIP(
    config=ConstantMLIP.Config(learnable=True),
    dataset_info=dataset_info,
)

# force_field : graph -> predictions
force_field = ForceField.from_mlip_network(
    constant_mlip,
    predict_stress=False,
    seed=123,
)

# N.B. force_field is not a flax module! it wraps predictor + params
print(force_field.predictor, force_field.params, sep="\n")

For the sake of transparency, let us detail what is actually being done here.

First, a [`ForceFieldPredictor`](https://instadeepai.github.io/mlip/api_reference/models/predictor.html) instance is created on top of the `constant_mlip` model.

Then, random parameters are initialized by calling the predictor's `.init()` method on a random seed and a dummy graph. These two objects (the predictor and its parameter dict) are wrapped for convenience inside the [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) dataclass. The following is thus equivalent:

In [None]:
# constant_predictor: graph -> predictions
constant_predictor = ForceFieldPredictor(
    constant_mlip,
    predict_stress=False,
)

force_field = ForceField.init(constant_predictor, seed=123)

Below, we show how to initialize the parameters manually and call the [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) default constructor; all this requires is an example input graph.

**N.B.** The [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) dataclass is frozen: this prevents stateful operations from being performed on the parameters, which would be incompatible with JAX's compilation and tracing mechanisms. You can think of [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) as holding the _state_ of a learnable [`ForceFieldPredictor`](https://instadeepai.github.io/mlip/api_reference/models/predictor.html), although _it remains immutable_.
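The frozen-dataclass pattern can be illustrated with the standard library alone. `ToyForceField` below is a hypothetical stand-in, not the *mlip* class: attribute assignment raises `dataclasses.FrozenInstanceError`, and "updating" the parameters means building a new instance, e.g. with `dataclasses.replace`.

```python
import dataclasses
from typing import Any, Callable


@dataclasses.dataclass(frozen=True)
class ToyForceField:
    """Hypothetical stand-in holding a predictor function and its parameters."""
    predictor: Callable[[Any, Any], Any]
    params: dict


def linear_predictor(params, x):
    return params["w"] * x


ff = ToyForceField(predictor=linear_predictor, params={"w": 2.0})

try:
    ff.params = {"w": 3.0}  # rebinding a field is rejected
except dataclasses.FrozenInstanceError as err:
    print("frozen:", err)

# New parameters (e.g. after a training step) give a new object
ff_updated = dataclasses.replace(ff, params={"w": 3.0})
print(ff.params, ff_updated.params)
```

Note that `frozen=True` only blocks rebinding the fields; it does not deep-freeze the objects they hold, so the discipline of never mutating the parameter dict in place still matters.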

### d. *Evaluating the force field*

In order to illustrate the signatures and outputs of the models,
we'll need an example [`jraph.GraphsTuple`](https://jraph.readthedocs.io/en/latest/api.html) input.

In [None]:
from jraph import GraphsTuple
from mlip.data import ChemicalSystem
import numpy as np
from mlip.data.helpers import create_graph_from_chemical_system

# Example H2O molecule:
#   - H (Z=1) has species index 0
#   - O (Z=8) has species index 3 (H, C, N come first)
system = ChemicalSystem(
    atomic_numbers = np.array([1, 8, 1]),
    atomic_species = np.array([0, 3, 0]),
    positions = np.array(
        [[-.5, .0, .0], [.0, .2, .0], [.5, .0, .0]]
    ),
)

graph = create_graph_from_chemical_system(
    chemical_system = system,
    distance_cutoff_angstrom = 5.,
    # GOTCHA: need >= 1 dummy graph to sum node_energies correctly
    batch_it_with_minimal_dummy = True
)

With this graph at hand, we can now apply the flax `nn.Module` predictor to return energy and forces.

Recall that [flax.linen modules](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html) have the following methods:
- `.init()` returns initial parameters from a random number generator (RNG) key and inputs,
- `.apply()` returns outputs from learnable parameters and inputs

You might be surprised that two energy values are returned: this is because the graph was batched with a minimal dummy graph (`batch_it_with_minimal_dummy=True` above), as [jraph](https://jraph.readthedocs.io/en/latest/api.html) assumes that batches of graphs **always contain at least one dummy graph**.
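The per-graph summation behind these two values can be sketched with NumPy, in the spirit of a jraph segment sum: node energies are scattered into one slot per graph using `n_node`. The padding layout below (one dummy graph with a single node whose energy is zero) is an assumption made for illustration; it is not necessarily how `create_graph_from_chemical_system` pads.

```python
import numpy as np

# Node energies for H, O, H followed by one padding node (assumed zero here)
node_energies = np.array([-100.0, -800.0, -100.0, 0.0])
n_node = np.array([3, 1])  # real graph, then dummy graph

# Graph index of each node, e.g. [0, 0, 0, 1]
graph_index = np.repeat(np.arange(n_node.shape[0]), n_node)

# Sum node energies into one total energy per graph
graph_energies = np.zeros(n_node.shape[0])
np.add.at(graph_energies, graph_index, node_energies)
print(graph_energies)
```

The real structure's energy sits in the first slot, which is why the notebook reads `prediction.energy[0]` when evaluating the MPNN.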

In [None]:
# Initialize parameters
key = jax.random.key(123)
params = constant_predictor.init(key, graph)
print("Parameters:\n", params, "\n")


# Evaluate force field predictor on H2O graph
prediction = constant_predictor.apply(params, graph)
print("Prediction:\n", prediction)

### e. *Wrapping the model state in ForceField*

In order to hide the `flax` logic for downstream applications, our `TrainingLoop` class takes in and returns a [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) object that simply wraps the predictor with its initial and final parameters respectively.

This frozen dataclass can then be easily passed to the [`SimulationEngine`](https://instadeepai.github.io/mlip/api_reference/simulation/simulation_engine.html), or just saved for later (by JSON-serializing the MLIPNetwork's `.config` and `.dataset_info`, and dumping the flattened parameter dict as `.npz`).
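The storage scheme described above can be sketched with the standard library and NumPy. This is a hypothetical format for illustration only (the `/`-joined key convention, the helper names and the in-memory buffers are assumptions), not the actual layout written by `save_model_to_zip`.

```python
import io
import json

import numpy as np


def flatten(tree, prefix=""):
    # Nested dict of arrays -> flat {"a/b/c": array} mapping
    flat = {}
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = np.asarray(value)
    return flat


def unflatten(flat):
    # Inverse of `flatten`
    tree = {}
    for path, value in flat.items():
        *parents, leaf = path.split("/")
        node = tree
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return tree


params = {"params": {"Dense_0": {"kernel": np.ones((4, 4)), "bias": np.zeros(4)}}}
config = {"n_layers": 1, "num_features": 4}

# "Save": config as JSON text, parameters as an .npz archive
config_json = json.dumps(config)
buffer = io.BytesIO()
np.savez(buffer, **flatten(params))

# "Load": parse the JSON and rebuild the nested parameter dict
buffer.seek(0)
with np.load(buffer) as archive:
    restored = unflatten({name: archive[name] for name in archive.files})
restored_config = json.loads(config_json)
```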

In [None]:
from mlip.models.model_io import save_model_to_zip, load_model_from_zip

force_field_0 = ForceField(
    predictor = constant_predictor,
    params = params,
)

# We recommend keeping the MLIPNetwork class name in the zip filename,
save_model_to_zip("ConstantMLIP-ff.zip", force_field_0)

# as loading requires the MLIPNetwork class (higher layers are model-agnostic)
force_field_1 = load_model_from_zip(ConstantMLIP, "ConstantMLIP-ff.zip")

Note that [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) instances are also callable, and morally equivalent to `functools.partial(predictor.apply, params)`.

This means they can be evaluated directly on a graph without explicitly passing the (frozen) learnable parameters, as is done during simulation.
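The "morally equivalent" statement can be made concrete with a toy class. In this sketch, `CallableForceField` and `toy_apply` are hypothetical stand-ins: calling the frozen wrapper gives the same result as `functools.partial(toy_apply, params)`, so downstream code only ever sees a function of the graph.

```python
import dataclasses
import functools
from typing import Any, Callable


def toy_apply(params, graph):
    # Stand-in for `predictor.apply(params, graph)`
    return params["scale"] * sum(graph)


@dataclasses.dataclass(frozen=True)
class CallableForceField:
    apply_fn: Callable[[Any, Any], Any]
    params: dict

    def __call__(self, graph):
        # Bind the stored parameters; callers only pass the graph
        return self.apply_fn(self.params, graph)


params = {"scale": 0.5}
graph = [1.0, 2.0, 3.0]  # placeholder for a GraphsTuple

ff = CallableForceField(apply_fn=toy_apply, params=params)
bound = functools.partial(toy_apply, params)
print(ff(graph), bound(graph))  # 3.0 3.0
```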

In [None]:
prediction = jax.jit(force_field_0)(graph)
prediction

In theory, the [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) class is duck-typed for the [`SimulationEngine`](https://instadeepai.github.io/mlip/api_reference/simulation/simulation_engine.html), and you could provide any other object with the following methods and properties (e.g. to wrap models defined in another JAX framework):
- `.__call__(graph: GraphsTuple) -> Prediction`
- `.cutoff_distance: float`
- `.allowed_atomic_numbers: set[int]`

However, this kind of general model extension is not yet thoroughly supported. Please share your feedback if you would like to use the library in this way but encounter issues.
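If you want to experiment with this, the expected interface can be written down as a `typing.Protocol`. `ForceFieldLike` and `MyWrappedModel` are hypothetical names for illustration; a runtime-checkable protocol only checks that the listed members exist, not their types or behaviour.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class ForceFieldLike(Protocol):
    """Hypothetical protocol mirroring the members listed above."""

    cutoff_distance: float
    allowed_atomic_numbers: set[int]

    def __call__(self, graph: Any) -> Any: ...


class MyWrappedModel:
    """Hypothetical wrapper around a model from another JAX framework."""

    def __init__(self, model_fn, cutoff_distance, allowed_atomic_numbers):
        self._model_fn = model_fn
        self.cutoff_distance = cutoff_distance
        self.allowed_atomic_numbers = allowed_atomic_numbers

    def __call__(self, graph):
        # Should return a `Prediction`-like object with energy and forces
        return self._model_fn(graph)


wrapped = MyWrappedModel(lambda graph: None, 5.0, {1, 6, 7, 8})
print(isinstance(wrapped, ForceFieldLike))   # True
print(isinstance(object(), ForceFieldLike))  # False
```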

---

## 3. Message-passing MLIPNetwork

Now that we have gone through the *mlip* model hierarchy, let us **define a more meaningful model that is actually able to predict forces**.
In this tutorial, we demonstrate how to implement a simple message-passing neural network (MPNN) that can be used with all the other components of the library.

### Hyperparameters

To define our model, we first **create a pydantic BaseModel config class that encapsulates the hyperparameters of our model**, as before. This allows us to seamlessly validate the attributes passed to our model (see [the pydantic docs](https://docs.pydantic.dev/latest/) for more information)
and makes it straightforward to store and save increasingly complex configurations.
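For example, with a config shaped like the one defined below, pydantic rejects inputs that cannot be converted to the declared types and lets the config round-trip through JSON. This sketch assumes pydantic v2 (`model_dump_json` and `model_validate_json` are the v2 method names); `ExampleConfig` is a hypothetical copy kept separate from `MPNNConfig`.

```python
import pydantic


class ExampleConfig(pydantic.BaseModel):
    n_layers: int
    num_features: int
    mlp_hidden_dims: tuple[int, ...] = (64,)


# A list is validated into the declared tuple type
config = ExampleConfig(n_layers=2, num_features=16, mlp_hidden_dims=[32, 32])
print(config.mlp_hidden_dims)  # (32, 32)

# A value that cannot be parsed as an int is rejected
try:
    ExampleConfig(n_layers="two", num_features=16)
except pydantic.ValidationError as err:
    print("rejected:", err.error_count(), "error(s)")

# Round trip through JSON text
restored = ExampleConfig.model_validate_json(config.model_dump_json())
assert restored == config
```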


In [None]:
class MPNNConfig(pydantic.BaseModel):
    """
    Configuration class for our custom MLIP model.
    """
    # Define the configuration parameters
    n_layers: int
    num_features: int
    num_species: int
    mlp_hidden_dims: tuple[int, ...] = (64,)


Having defined our config, we can now create our MLIP model class. Our custom model must inherit from the [`MLIPNetwork`](https://instadeepai.github.io/mlip/api_reference/models/mlip_network.html) class, which is itself a `flax.linen.Module`. As such, we can easily define our network using the flax `@nn.compact` decorator; see [the flax docs](https://flax-linen.readthedocs.io/en/latest/quick_start.html) for more information.

Our model must also have a `dataset_info` attribute of type [`DatasetInfo`](https://instadeepai.github.io/mlip/api_reference/data/dataset_info.html). This object encapsulates the relevant information about the dataset at hand that can be used to build the model. For instance, it contains the average number of neighbours per atom in the dataset, which is used in models like [MACE](https://arxiv.org/pdf/2206.07697) to normalize the messages passed to each node.

We provide a very simple MPNN below, which computes messages by applying an `MLP` to the concatenation of edge distances with sender and receiver features.
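A single message-passing step of this kind can be sketched in NumPy, independently of flax. The random weights, the single linear layer with SiLU (in place of the `MLP` block) and the function name `message_passing_step` are illustrative assumptions; as in the model below, messages are summed into receiver nodes, divided by the average number of neighbours, and added to the existing node features.

```python
import numpy as np


def silu(x):
    return x / (1.0 + np.exp(-x))


def message_passing_step(node_feats, edge_distances, senders, receivers,
                         weights, avg_num_neighbours):
    # Edge inputs: [distance, sender features, receiver features]
    edge_feats = np.concatenate(
        [edge_distances[:, None], node_feats[senders], node_feats[receivers]],
        axis=-1,
    )
    messages = silu(edge_feats @ weights)  # [n_edges, num_features]
    # Scatter-add messages onto receivers, normalized, as a residual update
    aggregated = np.zeros_like(node_feats)
    np.add.at(aggregated, receivers, messages)
    return node_feats + aggregated / avg_num_neighbours


rng = np.random.default_rng(0)
num_nodes, num_features = 3, 4
node_feats = rng.normal(size=(num_nodes, num_features))
senders = np.array([0, 1, 1, 2])
receivers = np.array([1, 0, 2, 1])
edge_distances = np.array([0.5, 0.5, 0.5, 0.5])
weights = rng.normal(size=(1 + 2 * num_features, num_features))

updated = message_passing_step(node_feats, edge_distances, senders, receivers,
                               weights, avg_num_neighbours=4 / 3)
print(updated.shape)  # (3, 4)
```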

In [None]:
from mlip.utils.safe_norm import safe_norm

class MLP(nn.Module):
    """Multi-layer 'perceptron' with silu activation.

    Attributes:
        layers: Dimension of each layer, including the input dimension.
    """
    layers: tuple[int, ...]

    @nn.compact
    def __call__(self, x: Array) -> Array:
        assert x.shape[-1] == self.layers[0]
        for dim in self.layers[1:]:
            x = nn.Dense(dim)(x)
            x = nn.silu(x)
        return x


class MPNN(MLIPNetwork):
    """Our custom MLIP model. It is a flax Module that inherits from MLIPNetwork.

    Attributes:
        config: Configuration object containing model parameters.
        dataset_info: DatasetInfo object containing information about the dataset.
    """
    Config = MPNNConfig

    config: MPNNConfig
    dataset_info: DatasetInfo

    @nn.compact
    def __call__(
        self,
        edge_vectors: jnp.ndarray,
        node_species: jnp.ndarray,
        senders: jnp.ndarray,
        receivers: jnp.ndarray,
    ) -> jnp.ndarray:
        """Compute node-wise energy summands.

        Args:
            edge_vectors: Edge vectors, jnp.array of shape [n_edges, 3].
            node_species: Node species, jnp.array of shape [n_nodes].
            senders: Sender indices, jnp.array of shape [n_edges].
            receivers: Receiver indices, jnp.array of shape [n_edges].
        Returns:
            node_energies: Node energies, jnp.array of shape [n_nodes].
        """

        avg_num_neighbours = self.dataset_info.avg_num_neighbours
        num_species = self.config.num_species
        num_features = self.config.num_features

        # TODO: reuse RadialEmbedding block
        edge_distances = safe_norm(edge_vectors, axis=-1)[:, None]  # [n_edges, 1]

        # Encode species indices to node features
        node_feats = nn.one_hot(node_species, num_species)
        node_feats = nn.Dense(num_features)(node_feats)  # [n_nodes, num_features]

        # Message-passing steps
        for _ in range(self.config.n_layers):
            # Edge inputs: [distance, sender features, receiver features]
            edge_feats = jnp.concatenate(
                [edge_distances, node_feats[senders], node_feats[receivers]],
                axis=-1,
            )
            # Compute messages
            mlp_in = 1 + 2 * num_features
            mlp_hidden = self.config.mlp_hidden_dims
            messages = MLP((mlp_in, *mlp_hidden, num_features))(edge_feats)
            # Propagate messages: normalized sum over incoming edges (residual)
            node_feats = node_feats.at[receivers].add(messages / avg_num_neighbours)

        # Project node features to scalar node energies
        node_energies = nn.Dense(1)(node_feats)[...,0] # [n_nodes, ]

        # Add non-interacting atomic energies
        atomic_energies = ConstantMLIP(
            config=ConstantMLIP.Config(learnable=True),
            dataset_info=self.dataset_info,
        )(edge_vectors, node_species, senders, receivers)

        return node_energies + atomic_energies

Having defined both our model and its associated config classes, we can now instantiate our model and turn it into a [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) object that can be used for training and simulations.

In [None]:
config = MPNN.Config(
    n_layers=1,
    num_features=4,
    num_species=4,  # H, C, N, O, as in `dataset_info`
    mlp_hidden_dims=(4,),
)

mlip_net = MPNN(
    config=config,
    dataset_info=dataset_info
)

force_field = ForceField.from_mlip_network(
    mlip_net,
    predict_stress=False,
    seed=42,
)

# ForceField object can now be fed to the training loop or used for predictions.

In contrast with our previous constant energy predictor, let's evaluate our randomly initialized force field on the same H2O graph to check whether it outputs non-zero forces:


In [None]:
prediction = force_field(graph)
energy = prediction.energy[0]  # second value would be for dummy
forces = jnp.delete(prediction.forces, -1, axis=0)  # last row would be for dummy

print("Energy:", energy)
print("Forces:\n", forces)

Lastly, we show how to print the structure of the model parameters:

In [None]:
from collections.abc import Mapping

def print_params_structure(params):
    def print_structure(subtree, indent=0):
        for key, value in subtree.items():
            # Recurse into nested parameter collections (dict or FrozenDict)
            if isinstance(value, Mapping):
                print(' ' * indent + f"{key}:")
                print_structure(value, indent + 2)
            else:
                # Leaves are arrays: only show their shape
                print(' ' * indent + f"{key}: {value.shape}")
    print_structure(params)

print_params_structure(force_field.params)