# MLIP examples: Model Training

In this notebook, we present general guidelines on **how to run trainings with the *mlip* library**. It is aimed at users with some knowledge of machine learning, although the library is designed to let both beginners and advanced users train and deploy MLIP models on their target datasets.

Note that we keep this tutorial minimal but try to cover all the important aspects you'll need to build your own end-to-end MLIP pipelines! Additional tools that we leave to the user's discretion include connecting loggers to visualisation tools (e.g., Weights & Biases) and cloud connections for checkpointing and saving models.

**This notebook aims at showcasing:**

- **How to prepare a dataset** for training, validation and testing
- **How to set up a training loop**, including initializing all the necessary tools and configs
- **How to run a training** and save trained models and evaluate them on test sets
- **How to load models** for later use (e.g., in MD)
- ***[Advanced topic]* How to create custom loggers** for bespoke applications
- ***[Advanced topic]* How to load models from checkpoints** for restarting a training loop

**Install, required imports, and logging setup**


As a first step, we install the *mlip* library directly from pip. We also install the appropriate Jax CUDA backend to run on GPU (to run on CPU, use the commented-out command instead). In this notebook, we do not run any simulations and therefore do not install Jax-MD; for details on how to do so, please refer to our *simulation* tutorial. Note that if you have already run another tutorial in the same environment, this installation is not required. Please refer to [our installation page](https://instadeepai.github.io/mlip/installation/index.html) for more information.


In [None]:
%pip install mlip "jax[cuda12]==0.4.33" huggingface_hub

# Use this instead for installation without GPU:
# %pip install mlip huggingface_hub

For convenience, and since users are expected to already have some knowledge of machine learning, we run all the required imports for this tutorial upfront (with comments on where they will be used):

In [None]:
# For dataset loading
from mlip.data import GraphDatasetBuilder, ExtxyzReader

# For model
from mlip.models import Mace, Nequip, Visnet, ForceField

# For optimizer
from mlip.training import get_default_mlip_optimizer, OptimizerConfig

# For loss function
from mlip.models.loss import MSELoss

# For training
from mlip.training import TrainingLoop
from mlip.models.model_io import save_model_to_zip, load_model_from_zip
from mlip.models.params_loading import load_parameters_from_checkpoint

# Other
import logging
import os
import matplotlib.pyplot as plt

We also set up logging to display information about the runs.

In [None]:
logging.basicConfig(level=logging.INFO, force=True, format='%(levelname)s - %(message)s')

One can also set a dedicated logging level for *mlip* specifically. Feel free to set this to `logging.DEBUG` for more verbose command line output.

In [None]:
logging.getLogger("mlip").setLevel(logging.INFO)
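
If you also want a persistent record of a run, the standard-library `logging` module can route the *mlip* logger's output to a file as well as to the console. The sketch below relies only on `logging` (nothing *mlip*-specific is assumed beyond the `"mlip"` logger name used above); the file location and format string are arbitrary choices for illustration.

```python
import logging
import os
import tempfile

# Example destination for the log file (any writable path works).
log_path = os.path.join(tempfile.gettempdir(), "mlip_training.log")

mlip_logger = logging.getLogger("mlip")
mlip_logger.setLevel(logging.INFO)  # same level as set in the cell above

# Records from the "mlip" logger, and from any child logger that propagates
# to it, are now also written to the file.
file_handler = logging.FileHandler(log_path, mode="w")
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
mlip_logger.addHandler(file_handler)

mlip_logger.info("Example message written to file")
```

Re-running this cell attaches a second handler, so each message would then be written twice.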

Let's also check what device we are using:

In [None]:
import jax

print(jax.devices())

## 1. Preparing a dataset

For this example, we train on configurations of the aspirin molecule only. The dataset is a subset of the [Revised MD17 dataset](https://figshare.com/articles/dataset/Revised_MD17_dataset_rMD17_/12672038). The training, validation and test sets will be downloaded to "training/rmd17_aspirin_train.xyz", "training/rmd17_aspirin_val.xyz" and "training/rmd17_aspirin_test.xyz", respectively. All data is downloaded from InstaDeep's [HuggingFace collection](https://huggingface.co/collections/InstaDeepAI/ml-interatomic-potentials-68134208c01a954ede6dae42).

In [None]:
from huggingface_hub import snapshot_download

snapshot_download(repo_id="InstaDeepAI/MLIP-tutorials", allow_patterns="training/*", local_dir=".")

The data processing **is a two-step process**:

1. **We read the data from disk into [`ChemicalSystem`](https://instadeepai.github.io/mlip/api_reference/data/chemical_system.html) objects**. This is done by a "reader"; since the dataset is stored in extended XYZ format, it can be read with the [`ExtxyzReader`](https://instadeepai.github.io/mlip/api_reference/data/chemical_systems_readers/extxyz_reader.html). The *mlip* library also includes an HDF5 reader: [`Hdf5Reader`](https://instadeepai.github.io/mlip/api_reference/data/chemical_systems_readers/hdf5_reader.html). We expect that users may want to implement their own readers to handle custom data formats.


In [None]:
reader = ExtxyzReader(
    ExtxyzReader.Config(
        train_dataset_paths="training/rmd17_aspirin_train.xyz",
        valid_dataset_paths="training/rmd17_aspirin_val.xyz",
        test_dataset_paths="training/rmd17_aspirin_test.xyz",
    )
)
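
To make the reader's job more concrete, here is a minimal sketch of what parsing extended XYZ involves. It is not the *mlip* `ExtxyzReader`: `SimpleSystem` and `parse_extxyz` are hypothetical names, and the sketch assumes the per-atom columns are described by a `Properties=` key and the energy by an `energy=` key, conventions that vary between datasets. A custom reader for another format would perform a similar translation into the library's `ChemicalSystem` objects.

```python
import shlex
from dataclasses import dataclass
from typing import List, Optional

import numpy as np


@dataclass
class SimpleSystem:
    """Hypothetical stand-in for mlip's ChemicalSystem (illustration only)."""

    symbols: List[str]
    positions: np.ndarray  # shape (n_atoms, 3), Angstrom
    forces: Optional[np.ndarray]  # shape (n_atoms, 3), or None if absent
    energy: Optional[float]


def _parse_comment(line: str) -> dict:
    """Split an extxyz comment line into key=value pairs (shlex handles quotes)."""
    info = {}
    for token in shlex.split(line):
        key, _, value = token.partition("=")
        info[key] = value
    return info


def parse_extxyz(text: str) -> List[SimpleSystem]:
    """Parse every frame of an extended-xyz string into SimpleSystem objects."""
    lines = text.strip().splitlines()
    systems, i = [], 0
    while i < len(lines):
        n_atoms = int(lines[i])
        info = _parse_comment(lines[i + 1])
        # e.g. "species:S:1:pos:R:3:forces:R:3" -> (name, type, n_columns) triples
        fields = info.get("Properties", "species:S:1:pos:R:3").split(":")
        columns, start = {}, 0
        for name, width in zip(fields[0::3], fields[2::3]):
            columns[name] = slice(start, start + int(width))
            start += int(width)
        rows = [lines[i + 2 + k].split() for k in range(n_atoms)]
        symbols = [row[columns["species"]][0] for row in rows]
        positions = np.array([row[columns["pos"]] for row in rows], dtype=float)
        forces = None
        if "forces" in columns:
            forces = np.array([row[columns["forces"]] for row in rows], dtype=float)
        energy = float(info["energy"]) if "energy" in info else None
        systems.append(SimpleSystem(symbols, positions, forces, energy))
        i += n_atoms + 2  # atom-count line + comment line + one line per atom
    return systems


EXAMPLE = """2
Properties=species:S:1:pos:R:3:forces:R:3 energy=-1.5 pbc="F F F"
H 0.00 0.0 0.0 0.1 0.0 0.0
H 0.74 0.0 0.0 -0.1 0.0 0.0
"""
frames = parse_extxyz(EXAMPLE)
print(frames[0].symbols, frames[0].energy)
```

For real files, a mature parser such as ASE's `ase.io.read` is usually preferable to hand-rolled code.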

2. **We process these [`ChemicalSystem`](https://instadeepai.github.io/mlip/api_reference/data/chemical_system.html) objects into graphs.** This process uses the class [`GraphDatasetBuilder`](https://instadeepai.github.io/mlip/api_reference/data/graph_dataset_builder.html) which offers some degree of customisation through its [config class](https://instadeepai.github.io/mlip/api_reference/data/dataset_configs.html#mlip.data.configs.GraphDatasetBuilderConfig). 



In [None]:
builder_config = GraphDatasetBuilder.Config(
    graph_cutoff_angstrom=5.0,
    batch_size=16,
)

builder = GraphDatasetBuilder(reader, builder_config)
builder.prepare_datasets() # This step is required to compute all dataset information (used later on by most MLIP models)

train_set, validation_set, test_set = builder.get_splits()
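
The `graph_cutoff_angstrom` setting controls which atom pairs become graph edges. As a rough sketch, assuming a plain distance threshold and no periodic boundary conditions (which suits an isolated molecule like aspirin), edge construction looks like the following. `radius_graph` is a hypothetical helper; the library's own neighbour-list code may differ in details such as periodic images, edge ordering, or whether the comparison is strict.

```python
import numpy as np


def radius_graph(positions, cutoff):
    """Return (senders, receivers) for all ordered atom pairs closer than `cutoff`.

    Non-periodic sketch: every pair is compared directly, which is O(n^2).
    """
    diff = positions[:, None, :] - positions[None, :, :]  # (n, n, 3) displacements
    dist = np.linalg.norm(diff, axis=-1)  # (n, n) pairwise distances
    mask = (dist < cutoff) & ~np.eye(len(positions), dtype=bool)  # drop self-loops
    senders, receivers = np.nonzero(mask)
    return senders, receivers


# Three atoms on a line, 3 Angstrom apart: with a 5 Angstrom cutoff only the
# neighbouring pairs are connected; the outer pair (6 Angstrom apart) is not.
positions = np.array([[0.0, 0.0, 0.0], [3.0, 0.0, 0.0], [6.0, 0.0, 0.0]])
senders, receivers = radius_graph(positions, cutoff=5.0)
print(list(zip(senders.tolist(), receivers.tolist())))
# → [(0, 1), (1, 0), (1, 2), (2, 1)]
```

A larger cutoff therefore yields more edges per atom, which increases both the receptive field of the model and its cost.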

More details can be found in the [deep-dive on data processing](https://instadeepai.github.io/mlip/user_guide/data_processing.html) in our documentation.

We can now **print some statistics about our dataset** along with the [`DatasetInfo`](https://instadeepai.github.io/mlip/api_reference/data/dataset_info.html) object that will be required for downstream tasks. The **dataset info** holds all the hyperparameters of the models that are directly derived from the dataset or its processing, e.g., the cutoff distance to determine the graph edges.

In [None]:
print("Dataset info:", builder.dataset_info)
print("Number of batches in train set:", len(train_set))
print("Number of batches in validation set:", len(validation_set))
print("Number of batches in test set:", len(test_set))

## 2. Preparing a training loop

To start training, we first need to prepare some prerequisites. These are, as for all ML models: 
- A **model architecture**,
- An **optimizer**, and
- A **loss function**

We start with the **model architecture**: 

For this tutorial, we provide the initialization code for MACE, NequIP and ViSNet, with two of them commented out. For all available hyperparameters, see the documentation of the [MACE config](https://instadeepai.github.io/mlip/api_reference/models/mace.html#mlip.models.mace.config.MaceConfig), the [NequIP config](https://instadeepai.github.io/mlip/api_reference/models/nequip.html#mlip.models.nequip.config.NequipConfig), and the [ViSNet config](https://instadeepai.github.io/mlip/api_reference/models/visnet.html#mlip.models.visnet.config.VisnetConfig).

The model creation process has two steps: (1) creating the MLIP network and (2) creating the force field from it. See our [deep-dive on models](https://instadeepai.github.io/mlip/user_guide/models.html) for a detailed explanation of this pattern. The force field is the essential object required for the training below, as well as for running MD simulations.

In [None]:
# We override some of the model's default hyperparameters to make it smaller,
# keeping this training example minimal
mlip_network = Mace(
    Mace.Config(num_channels=16, correlation=2),
    builder.dataset_info,
)

# mlip_network = Nequip(
#     Nequip.Config(
#         node_irreps="4x0e + 4x0o + 4x1o + 4x1e + 4x2e + 4x2o",
#         num_layers=2,
#     ),
#     builder.dataset_info,
# )

# mlip_network = Visnet(
#     Visnet.Config(num_channels=16, num_layers=2),
#     builder.dataset_info,
# )

force_field = ForceField.from_mlip_network(mlip_network)

Next, we **create an optimizer**: 

The *mlip* library is set up so that you can use any [`optax`](https://github.com/google-deepmind/optax) optimizer you like, e.g., [`optax.adam`](https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.adam). Nonetheless, we provide a default optimizer specialized for MLIP models (see [this](https://instadeepai.github.io/mlip/api_reference/training/optimizer.html#mlip.training.optimizer.get_default_mlip_optimizer) part of the documentation for more details).

In [None]:
optimizer = get_default_mlip_optimizer()

For the **loss**:

We use a mean squared error (MSE) loss, which by default uses a weighting factor of 25.0 for the force MSE, 1.0 for the energy MSE, and zero for the stress MSE (stress is not available in this dataset). See [this](https://instadeepai.github.io/mlip/user_guide/training.html#loss) part of the documentation for more information on further options, such as alternative loss functions and how to create a weight flip schedule between energy and forces.

In [None]:
loss = MSELoss()
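
To make the weighting concrete, here is an illustrative weighted MSE using the default weights quoted above. It is a simplified sketch, not the library's `MSELoss`: the real implementation may, for example, normalise energies per atom or average over batches differently. `weighted_mse` and the dict layout of `pred`/`ref` are hypothetical.

```python
import numpy as np


def weighted_mse(pred, ref, energy_weight=1.0, forces_weight=25.0, stress_weight=0.0):
    """Illustrative weighted MSE over energies, forces and (optionally) stress.

    `pred` and `ref` are dicts with "energy" of shape (n_graphs,) and "forces"
    of shape (n_atoms, 3); "stress" is only read when its weight is non-zero.
    """
    loss = energy_weight * np.mean((pred["energy"] - ref["energy"]) ** 2)
    loss += forces_weight * np.mean((pred["forces"] - ref["forces"]) ** 2)
    if stress_weight:
        loss += stress_weight * np.mean((pred["stress"] - ref["stress"]) ** 2)
    return float(loss)


pred = {"energy": np.array([1.0, 2.0]), "forces": np.zeros((3, 3))}
ref = {"energy": np.array([1.5, 2.0]), "forces": np.full((3, 3), 0.1)}
# Energy MSE = 0.125, force MSE = 0.01, so the total is 0.125 + 25 * 0.01.
print(weighted_mse(pred, ref))
```

The large force weight reflects that each configuration provides many force components but only one energy, so without it the force error would contribute comparatively little to the gradient signal.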

Finally, we can **create our training loop**:

At a minimum, it needs as input:
- a training dataset
- a validation dataset
- a force field
- a loss
- an optimizer
- a config (which specifies, for instance, the number of epochs)

Note that the config is documented [here](https://instadeepai.github.io/mlip/api_reference/training/training_loop_config.html). Its only argument that lacks a default value is the number of epochs to train for.

In [None]:
training_config = TrainingLoop.Config(num_epochs=10)

training_loop = TrainingLoop(
    train_dataset=train_set,
    validation_dataset=validation_set,
    force_field=force_field,
    loss=loss,
    optimizer=optimizer,
    config=training_config,
)

## 3. Running a training loop

**Running the loop**:

The following cell runs the prepared training loop. Note that training will be a **lot more efficient for GPU users** (depending on the GPU, one should expect ~1s to ~12s per epoch once the code is compiled); for CPU users, our measurements ranged from ~12s to ~100s per epoch.

In [None]:
training_loop.run()

To **evaluate on the test set**, simply run the following line:

In [None]:
training_loop.test(test_set)

This test function is documented [here](https://instadeepai.github.io/mlip/api_reference/training/training_loop.html#mlip.training.training_loop.TrainingLoop.test).

**Recovering the best validation model**:

After training has completed, the [`TrainingLoop`](https://instadeepai.github.io/mlip/api_reference/training/training_loop.html) holds all the relevant information about the run. We can obtain the force field with the best validation parameters as follows:

In [None]:
optimized_force_field = training_loop.best_model

This force field object can now be applied in, for example, MD simulations or energy minimizations. 

**Saving the model to a zip file**:

We can also save the trained model in zip format. This is also the format that we provide our pre-trained models in.

In [None]:
save_model_to_zip("training/final_model.zip", optimized_force_field)

## 4. Loading a pre-trained model

We can load a previously trained model via the zip format and, for instance, print the dataset info stored within it:

In [None]:
loaded_force_field = load_model_from_zip(Mace, "training/final_model.zip")

print("Dataset info:", loaded_force_field.dataset_info)

See [this](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) part of the documentation for more details on the `ForceField` class.

## 5. Advanced topics: Custom logging and Checkpointing

For more **advanced logging and checkpointing**, we import a few more objects from the library:

In [None]:
from mlip.training import TrainingIOHandler, log_metrics_to_line
from mlip.training.training_io_handler import LogCategory

We can now set up a **custom I/O handler** with checkpointing for training:

The I/O handler class is documented [here](https://instadeepai.github.io/mlip/api_reference/training/training_io_handling.html). Also check out [this](https://instadeepai.github.io/mlip/user_guide/training.html#io-handling-and-logging) part in our deep-dive on logging for more information. The code below adds a local directory for checkpointing to the I/O handler, which activates checkpointing during training.

In [None]:
io_handler = TrainingIOHandler(
    TrainingIOHandler.Config(
        local_model_output_dir="training/model_training"
    )
)

Next, we can **attach logging functions to the I/O handler**: 

Users can attach as many logging functions as required to the I/O handler. In this example, we attach two:
1. The [`log_metrics_to_line`](https://instadeepai.github.io/mlip/api_reference/training/training_io_handling.html#mlip.training.training_loggers.log_metrics_to_line) function that is also included in the default I/O handler that we used in the previous example.
2. A custom function that just keeps track of the validation set losses, so we can later easily create a curve from it.

In [None]:
# The following logger is also attached in the default I/O handler
# that was used in the training above
io_handler.attach_logger(log_metrics_to_line)

# Define a custom logging function that keeps track of validation loss
validation_losses = []
def _custom_logger(category, to_log, epoch_number):
    if category == LogCategory.EVAL_METRICS:
        validation_losses.append(to_log["loss"])

# Attach our custom logging function to the I/O handler
io_handler.attach_logger(_custom_logger)

The custom logging function is called several times during the training loop. Its `category` argument (e.g., `TRAIN_METRICS` or `EVAL_METRICS`) tells the function what is currently being logged, such as training or evaluation metrics, and is of the enum type [`LogCategory`](https://instadeepai.github.io/mlip/api_reference/training/training_io_handling.html#mlip.training.training_io_handler.LogCategory). See the documentation of the built-in function [`log_metrics_to_line`](https://instadeepai.github.io/mlip/api_reference/training/training_io_handling.html#mlip.training.training_loggers.log_metrics_to_line) for the expected signature of a logging function.
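
Loggers do not have to be plain functions. As a sketch, the following tracks the best validation loss and its epoch with a callable object that has the same `(category, to_log, epoch_number)` signature. To keep it runnable without the library, it uses a hypothetical `FakeLogCategory` enum in place of `LogCategory` and simulates the calls an I/O handler might make, using made-up loss values.

```python
from enum import Enum


class FakeLogCategory(Enum):
    """Hypothetical stand-in for mlip's LogCategory (demo only)."""

    TRAIN_METRICS = "train_metrics"
    EVAL_METRICS = "eval_metrics"


class BestValidationTracker:
    """Stateful logger with the (category, to_log, epoch_number) signature."""

    def __init__(self, eval_category):
        self.eval_category = eval_category
        self.best_loss = float("inf")
        self.best_epoch = None

    def __call__(self, category, to_log, epoch_number):
        # Ignore everything except validation metrics.
        if category != self.eval_category:
            return
        if to_log["loss"] < self.best_loss:
            self.best_loss = to_log["loss"]
            self.best_epoch = epoch_number


# Simulated calls with made-up losses (not results from a real run).
tracker = BestValidationTracker(FakeLogCategory.EVAL_METRICS)
for epoch, val_loss in enumerate([0.9, 0.5, 0.6]):
    tracker(FakeLogCategory.TRAIN_METRICS, {"loss": val_loss + 0.1}, epoch)
    tracker(FakeLogCategory.EVAL_METRICS, {"loss": val_loss}, epoch)
print(tracker.best_epoch, tracker.best_loss)  # → 1 0.5
```

Assuming `attach_logger` accepts any callable with this signature, the real version would be attached with `io_handler.attach_logger(BestValidationTracker(LogCategory.EVAL_METRICS))`.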

To illustrate this use of custom logging, we **start a new training run**. It will use the original force field object we created, hence, it will again start with random parameters:

In [None]:
# Only run 5 epochs for this example
training_config.num_epochs = 5

training_loop = TrainingLoop(
    train_dataset=train_set,
    validation_dataset=validation_set,
    force_field=force_field,
    loss=loss,
    optimizer=optimizer,
    config=training_config,
    io_handler=io_handler,
)

training_loop.run()

We can now **access the information stored** by the custom logger saved into our validation loss list:

In [None]:
print(validation_losses)

Let's create a training curve from these values.

In [None]:
epoch_nums = list(range(len(validation_losses)))
plt.plot(epoch_nums, validation_losses)
plt.xlabel("Epoch")
plt.ylabel("Validation loss")
plt.xticks(epoch_nums)
plt.show()

Finally, as mentioned earlier, **creating the custom I/O handler triggered checkpointing**. We can now load a force field from a given model checkpoint, for example, the most recently saved model:

In [None]:
from pathlib import Path

# Find out what is the most recent epoch that was saved
checkpoints = os.listdir("training/model_training/model")
max_epoch_num = max(int(num) for num in checkpoints)

# Load the parameters from the checkpoint
loaded_params_via_ckpt = load_parameters_from_checkpoint(
    local_checkpoint_dir=Path("training/model_training/model").resolve(),
    initial_params=force_field.params,
    epoch_to_load=max_epoch_num,
    load_ema_params=False,
)

# Create a new force field with those parameters
loaded_force_field = ForceField(force_field.predictor, loaded_params_via_ckpt)
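
The checkpoint lookup above assumes that every entry in the checkpoint directory is named by an epoch number. A slightly more defensive sketch, using a hypothetical helper `latest_checkpoint_epoch`, skips non-numeric entries and returns `None` when nothing is found. Converting names to `int` before taking the maximum also matters: compared as strings, `"2"` would sort after `"10"`.

```python
import os
import tempfile


def latest_checkpoint_epoch(checkpoint_dir):
    """Return the largest integer-named entry in `checkpoint_dir`, or None.

    Non-numeric entries (e.g. hidden or temporary files, if any) are skipped
    instead of making int() raise a ValueError.
    """
    epochs = [int(name) for name in os.listdir(checkpoint_dir) if name.isdecimal()]
    return max(epochs, default=None)


# Demo on a temporary directory that mimics epoch-numbered checkpoints.
with tempfile.TemporaryDirectory() as tmp:
    for name in ["0", "2", "10", ".DS_Store", "tmp_writing"]:
        os.mkdir(os.path.join(tmp, name))
    print(latest_checkpoint_epoch(tmp))  # → 10
```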

See our other example notebooks for how to use loaded models in downstream tasks like batched inference or simulations (MD or energy minimizations).

Furthermore, these checkpoints can of course also be **used to restart a training from a given checkpoint**. We refer to the [documentation of the I/O handler's config](https://instadeepai.github.io/mlip/api_reference/training/training_io_handling.html#mlip.training.training_io_handler.TrainingIOHandlerConfig) for more information on this.