# MLIP example: Running inference and simulations with MLIP

In this notebook, we present general considerations on how to **run molecular dynamics (MD) with pre-trained MLIP models** and what performance to expect from them. It is designed to help users without a machine learning background quickly learn to use MLIP models and integrate them into their workflows.

For more advanced use of our library, we have prepared two follow-up tutorials: (1) how to train MLIP models from scratch, and (2) how to build new models and tools to extend the *mlip* library.

**This notebook aims at showcasing:**
- **How to load a pre-trained model** from [Hugging Face](https://huggingface.co/collections/InstaDeepAI/ml-interatomic-potentials-68134208c01a954ede6dae42) or from a locally saved zip file
- **How to run inference** on a batch of systems using an MLIP model
- **How to create a simulation engine**, using either Jax-MD or ASE, with a force field based on a pre-trained MLIP model
- **How to run the simulation**, save, and access the results
- **How to select the appropriate model** for your simulation from our existing pre-trained models
- **What to expect from MLIP models** in terms of runtime and accuracy, based on our experiments

**Install and logging setup**

In order to run this notebook, we first need to set up the appropriate environment. This requires (1) InstaDeep's *mlip* library, (2) the appropriate Jax CUDA backend to run the models on GPU, and (3) Jax-MD. ASE is installed as a dependency of the *mlip* library. Jax-MD is installed separately because it has to be installed directly from GitHub.

We also set up logging to display information about the runs, and download into the working directory all the files required for this tutorial from InstaDeep's [Hugging Face collection](https://huggingface.co/collections/InstaDeepAI/ml-interatomic-potentials-68134208c01a954ede6dae42).

In [None]:
%pip install mlip "jax[cuda12]==0.4.33" huggingface_hub git+https://github.com/jax-md/jax-md.git

# Use this instead for installation without GPU:
# %pip install mlip huggingface_hub git+https://github.com/jax-md/jax-md.git

In [None]:
import logging

logging.basicConfig(level=logging.INFO, force=True, format='%(levelname)s - %(message)s')

In [None]:
from huggingface_hub import snapshot_download

# Download the tutorial files into the current working directory
snapshot_download(repo_id="InstaDeepAI/MLIP-tutorials", allow_patterns="simulation/*", local_dir=".")

Let's also check what device we are using:

In [None]:
import jax

print(jax.devices())

## 1. Loading a pre-trained MLIP model

In this notebook, we use as an illustrative example a simple MACE model pre-trained on aspirin. It was trained for 100 epochs on ~1000 conformations of aspirin, with energies and forces computed at the DFT level.

Our library stores all the information about a pre-trained model in a single zip file. Once the model has been downloaded (or trained on your device), we need to (1) specify the model family, (2) specify the path, and (3) load the model directly into a [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) object.

In [None]:
from mlip.models.model_io import load_model_from_zip
from mlip.models import Mace  # (1) we use MACE for this example

model_path = "simulation/example_model.zip"  # (2) path to the zip file

# (3) load the model into a ForceField instance
force_field = load_model_from_zip(Mace, model_path)

## 2. Run batched inference with an MLIP model

Using the loaded MLIP model, we can run batched inference on several structures at once. In the example below, we run inference on 16 conformations of aspirin in batches of size 8 (the output will show three batches, because the underlying *jraph* library relies on padding, which creates an empty graph at the end of each batch).

We first load the 16 structures:

In [None]:
from ase.io import read as ase_read

batched_aspirin_files = "simulation/aspirin_batched_example.xyz"
structures = ase_read(batched_aspirin_files, index=":")  # index=":" reads all frames

We can now run inference with a single pre-built function. Note that Jax starts by compiling all the required functions: this makes the first call appear slow, but provides significant acceleration at scale. Compiled functions are cached in the notebook kernel, so to see the speed-up you can run the cell twice:

In [None]:
from mlip.inference import run_batched_inference

predictions = run_batched_inference(structures, force_field, batch_size=8)
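The padding mentioned above exists because Jax needs static shapes: every batch is padded to the same number of graphs so that the compiled function can be reused across batches. The pure-Python sketch below illustrates that idea under the assumption, as with *jraph*'s `pad_with_graphs`, that one extra graph slot per batch is reserved for padding. The helper `padded_batch_layout` is hypothetical and does not reproduce how *mlip* splits or reports batches.

```python
def padded_batch_layout(num_structures: int, batch_size: int):
    """Split structures into batches and count real vs. padding graphs per batch.

    Assumption for illustration: every batch is padded to the same static size
    (batch_size + 1 graphs), so a smaller final batch receives more padding
    graphs instead of triggering a new compiled shape.
    """
    n_graph = batch_size + 1  # one slot always reserved for a padding graph
    layout = []
    for start in range(0, num_structures, batch_size):
        n_real = min(batch_size, num_structures - start)
        layout.append({"real": n_real, "padding": n_graph - n_real})
    return layout


print(padded_batch_layout(16, 8))  # 16 structures in batches of 8
print(padded_batch_layout(20, 8))  # a partial last batch gets extra padding
```

Keeping the padded size fixed is the design choice that lets Jax compile the model once and reuse it for every batch.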

Energy and forces for each structure can be read directly from the predictions:

In [None]:
# Example: energy and forces of the structure at index 7 (the 8th, since indexing starts at 0)
print("Energy for structure 7:", predictions[7].energy)
print("Forces for structure 7:\n", predictions[7].forces)
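To summarise the forces numerically, here is a minimal sketch that computes per-atom force magnitudes. It assumes the forces are array-like with shape `(n_atoms, 3)`, which we expect `predictions[i].forces` to be but have not verified here; the helper name `force_norms` is hypothetical, and units follow whatever the model outputs.

```python
import numpy as np


def force_norms(forces):
    """Return the per-atom force magnitudes for an (n_atoms, 3) force array."""
    forces = np.asarray(forces, dtype=np.float64)
    if forces.ndim != 2 or forces.shape[1] != 3:
        raise ValueError(f"expected shape (n_atoms, 3), got {forces.shape}")
    return np.linalg.norm(forces, axis=1)  # Euclidean norm of each row


# Toy forces for a 2-atom system (illustrative values, not model output)
toy_forces = [[3.0, 4.0, 0.0], [0.0, 0.0, 1.0]]
norms = force_norms(toy_forces)
print(norms, norms.max())
```

In the notebook, `force_norms(predictions[7].forces).max()` would give the largest force acting on any atom of that structure.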

## 3. Configure a simulation engine

There are two options for setting up a simulation engine in the *mlip* library: ASE and Jax-MD. While ASE may be more familiar to most users, we recommend Jax-MD for performance, as it runs the entire simulation on GPU. The code below therefore sets up a Jax-MD simulation config, with the equivalent ASE script left commented out.

For advanced options (e.g. timestep, friction, temperature scheduling), see the relevant documentation for [Jax-MD](https://instadeepai.github.io/mlip/user_guide/simulations.html#simulations-with-jax-md) and [ASE](https://instadeepai.github.io/mlip/user_guide/simulations.html#simulations-with-ase), respectively.

For users with little experience with Jax, here are a few points to note regarding Jax-MD:
- Jax compiles the required functions before running them. This lets Jax code run very efficiently, but it also requires inputs to Jax functions to have static shapes, which implies some compromises when running MD.
- The approach used with Jax-MD is to split the simulation into a number of *episodes*; between episodes, the code checks whether the functions need to be recompiled.
- In the config for the `JaxMDSimulationEngine` in the *mlip* library, users set a total number of steps for their simulation and a number of episodes, which **must divide the total number of steps**.
- In general, we find that an episode length of ~1000 steps is a good compromise for most simulations.
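The divisibility rule and the ~1000-step guideline can be checked before building a config. Below is a minimal sketch in plain Python; both helpers are hypothetical (not part of *mlip*), and the 1000-step target is simply the rule of thumb stated above.

```python
def episode_length(num_steps: int, num_episodes: int) -> int:
    """Steps per episode; raises if num_episodes does not divide num_steps."""
    if num_episodes <= 0 or num_steps % num_episodes != 0:
        raise ValueError(
            f"num_episodes={num_episodes} must be positive and divide num_steps={num_steps}"
        )
    return num_steps // num_episodes


def suggest_num_episodes(num_steps: int, target_length: int = 1000) -> int:
    """Among the divisors of num_steps, pick the one giving episodes closest to target_length."""
    divisors = [d for d in range(1, num_steps + 1) if num_steps % d == 0]
    return min(divisors, key=lambda d: abs(num_steps // d - target_length))


print(episode_length(500, 10))        # the config used in this notebook
print(suggest_num_episodes(100_000))  # a longer run split into ~1000-step episodes
```

For short demo runs such as the 500-step example below, episodes are much shorter than 1000 steps; the guideline matters mainly for long production simulations.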

In [None]:
from mlip.simulation.jax_md import JaxMDSimulationEngine

config = JaxMDSimulationEngine.Config(
    num_steps=500,
    num_episodes=10,  # 500 / 10 = 50 steps per episode; MD results are logged to the console at each episode
    snapshot_interval=1,  # number of MD steps between saved snapshots (frames)
)

In [None]:
# Example ASE script:

# from mlip.simulation.ase import ASESimulationEngine

# config = ASESimulationEngine.Config(
#     num_steps=500,
#     log_interval=50,  # ASE does not use episodes, so the logging frequency is set directly
#     snapshot_interval=1,
# )
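Because `snapshot_interval=1` saves every step, it is worth estimating the size of the stored trajectory before launching a long run. Below is a rough sketch, assuming positions are saved once every `snapshot_interval` steps as float32 values; the helper is hypothetical, and the engine's `state` may hold other per-frame arrays as well, so treat the result as a lower bound.

```python
def trajectory_size(num_steps: int, snapshot_interval: int, num_atoms: int,
                    bytes_per_value: int = 4):
    """Approximate number of saved frames and bytes needed for stored positions.

    Counts one frame every `snapshot_interval` steps (ignoring any extra initial
    frame) and 3 coordinates per atom; bytes_per_value=4 assumes float32.
    """
    num_frames = num_steps // snapshot_interval
    num_bytes = num_frames * num_atoms * 3 * bytes_per_value
    return num_frames, num_bytes


# Aspirin (C9H8O4) has 21 atoms; the config above saves every one of 500 steps
frames, nbytes = trajectory_size(500, 1, 21)
print(frames, nbytes)
```

For long simulations of large systems, increasing `snapshot_interval` is the simplest way to keep this figure manageable.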

## 4. Running a simulation

With the force field loaded and the MD config set up, we are ready to run a quick simulation in vacuum. We begin by loading a file and creating an ASE `Atoms` object, which the Jax-MD engine also accepts as input. As a result, the library currently supports any file format readable by ASE's `read` function.

In [None]:
from ase.io import read as ase_read

example_aspirin_file = "simulation/aspirin_md_example.xyz"
atoms = ase_read(example_aspirin_file)

One final step before running the simulation: we need to initialise the engine with the force field, config, and input system created above. Initialisation is nearly identical for Jax-MD and ASE:

In [None]:
md_engine = JaxMDSimulationEngine(atoms, force_field, config)

# md_engine = ASESimulationEngine(atoms, force_field, config)

We can now run the simulation with a single line. With Jax-MD, the simulation may appear slow to start because, as mentioned before, Jax first compiles all the required functions. This overhead is quickly recovered, as the fully GPU-based simulation runs significantly faster than with ASE.

In [None]:
md_engine.run()

## 5. Visualise the simulation

The *mlip* library does not include visualisation tools, but we provide a simple function below to visualise trajectories in a notebook.

In [None]:
%%capture log

%pip install py3Dmol rdkit rdkit2ase

In [None]:
import py3Dmol
from rdkit import Chem
from rdkit2ase import ase2rdkit
from rdkit.Geometry import Point3D
import numpy as np


def update_rdkitpositions(mol, xyz):
    """Overwrite the coordinates of the molecule's conformer with `xyz` (shape (n_atoms, 3))."""
    conf = mol.GetConformer()
    for i in range(mol.GetNumAtoms()):
        x, y, z = np.asarray(xyz[i], dtype=np.float64)
        conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))
    return mol


def MolTo3DView(mol, positions, size=(300, 300), style="stick", surface=False, opacity=0.5):
    """Draw an animated 3D view of a molecule along an MD trajectory.

    Args:
        mol: RDKit Mol with one conformer; its coordinates are overwritten frame by frame.
        positions: array-like of shape (n_frames, n_atoms, 3), atomic positions per frame.
        size: tuple(int, int), canvas size in pixels.
        style: str, drawing style, one of 'line', 'stick', 'sphere', 'cartoon'.
        surface: bool, whether to display the solvent-accessible surface (SAS).
        opacity: float, opacity of the surface, in the range 0.0-1.0.

    Returns:
        viewer: py3Dmol.view, an embeddable 3Dmol.js view for Jupyter notebooks.
    """
    assert style in ("line", "stick", "sphere", "cartoon")
    viewer = py3Dmol.view(width=size[0], height=size[1])

    # Write one PDB block per frame and load them all as animation frames
    models = ""
    for xyz in np.asarray(positions):
        mol = update_rdkitpositions(mol, xyz)
        models += Chem.MolToPDBBlock(mol)
    viewer.addModelsAsFrames(models, "pdb")

    # Apply the requested style (previously overridden by a hard-coded 'stick' style)
    viewer.setStyle({style: {}})
    if surface:
        viewer.addSurface(py3Dmol.SAS, {"opacity": opacity})

    viewer.zoomTo()
    viewer.animate({"loop": "forward"})
    return viewer

The MD frames are stored in the `state` property of the MD engine used for the simulation. It can be accessed as follows and includes the atomic positions for each saved frame:

In [None]:
md_state_aspirin = md_engine.state

In [None]:
rdkit_object = ase2rdkit(atoms)
viewer = MolTo3DView(rdkit_object, positions=md_state_aspirin.positions)
viewer.show()
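Beyond the animation, the saved positions can be analysed numerically. Below is a minimal sketch computing, for each frame, the root-mean-square deviation (RMSD) from the first frame after removing overall translation. It assumes `positions` has shape `(n_frames, n_atoms, 3)`, does not remove rotations (so it overestimates the RMSD of a molecule tumbling in vacuum), and uses a hypothetical helper name.

```python
import numpy as np


def rmsd_from_first_frame(positions):
    """Per-frame RMSD relative to frame 0, in the same length units as the input.

    Each frame is centred on its centroid to remove overall translation;
    no rotational alignment is performed.
    """
    pos = np.asarray(positions, dtype=np.float64)          # (n_frames, n_atoms, 3)
    centred = pos - pos.mean(axis=1, keepdims=True)         # remove translation
    diff = centred - centred[0]                             # displacement from frame 0
    return np.sqrt((diff ** 2).sum(axis=-1).mean(axis=-1))  # mean over atoms, then sqrt


toy = [[[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
       [[-2.0, 0.0, 0.0], [2.0, 0.0, 0.0]]]
print(rmsd_from_first_frame(toy))
```

In the notebook, `rmsd_from_first_frame(md_state_aspirin.positions)` would give one value per saved frame; a steadily growing RMSD can hint at an unstable trajectory.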

## 6. Points to consider when running MD with MLIP models

MLIP models can approach DFT accuracy in simulations that run orders of magnitude faster than DFT. Because they are machine-learned functions, different MLIP models have different speed/accuracy trade-offs. When choosing a model, users should therefore keep their objectives in mind: smaller models are faster, but often less accurate.

Another key consideration is that MLIP models can be specialised, and a model trained for a specialised task may perform worse on general tasks. Relatedly, MLIP models are trained on a subset of the periodic table and can therefore only be run on systems whose atomic species are all included in the training set.

As a simple illustration of using the wrong MLIP model, the following example uses the model trained on aspirin to run a simulation of 3BPA, a different small organic molecule (which, unlike aspirin, also contains nitrogen). We can see that the model trained only on aspirin is not able to produce a suitable simulation of 3BPA.

In [None]:
example_3bpa_file = "simulation/3bpa_md_example.xyz"
atoms_3bpa = ase_read(example_3bpa_file)
md_engine = JaxMDSimulationEngine(atoms_3bpa, force_field, config)
# md_engine = ASESimulationEngine(atoms_3bpa, force_field, config)
md_engine.run()

In [None]:
rdkit_3bpa = ase2rdkit(atoms_3bpa)
md_state = md_engine.state
viewer = MolTo3DView(rdkit_3bpa, positions=md_state.positions)
viewer.show()
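Visual inspection can be complemented by a simple numerical sanity check. Below is a minimal sketch, assuming positions in Å with shape `(n_frames, n_atoms, 3)` for a system in vacuum (no periodic boundaries): it reports the shortest interatomic distance in each frame and flags frames below a cutoff. The 0.5 Å threshold and the helper names are illustrative assumptions, not values from the *mlip* library.

```python
import numpy as np


def min_pair_distance(positions):
    """Smallest interatomic distance in each frame (no periodic boundaries)."""
    pos = np.asarray(positions, dtype=np.float64)   # (n_frames, n_atoms, 3)
    diff = pos[:, :, None, :] - pos[:, None, :, :]  # (n_frames, n_atoms, n_atoms, 3)
    dist = np.linalg.norm(diff, axis=-1)
    i, j = np.triu_indices(pos.shape[1], k=1)       # each atom pair counted once
    return dist[:, i, j].min(axis=1)


def flag_collapsed_frames(positions, threshold=0.5):
    """Indices of frames where two atoms come closer than `threshold` (illustrative cutoff)."""
    return np.flatnonzero(min_pair_distance(positions) < threshold)


toy = [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],  # atoms 1.0 apart
       [[0.0, 0.0, 0.0], [0.3, 0.0, 0.0]]]  # atoms 0.3 apart: flagged
print(flag_collapsed_frames(toy))
```

On the 3BPA run, `flag_collapsed_frames(md_state.positions)` would list suspicious frames. The check builds an `n_atoms x n_atoms` distance matrix per frame, which is fine for small molecules, and it only catches atoms collapsing onto each other, not a molecule fragmenting.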

## 7. Using pre-trained models for simulation (requires GPU set-up)

The pre-trained models provided with the *mlip* library are designed to be broadly applicable to biochemical systems and include the following 15 atomic species: ($\mathrm{B}$, $\mathrm{Br}$, $\mathrm{C}$, $\mathrm{Cl}$, $\mathrm{F}$, $\mathrm{H}$, $\mathrm{I}$, $\mathrm{K}$, $\mathrm{Li}$, $\mathrm{N}$, $\mathrm{Na}$, $\mathrm{O}$, $\mathrm{P}$, $\mathrm{S}$, $\mathrm{Si}$).
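Before running a pre-trained model, it is worth checking that the input system contains only supported species. Below is a minimal sketch in plain Python using the 15 species listed above; the helper is hypothetical. With ASE, the symbols of a loaded structure can be obtained from `atoms.get_chemical_symbols()`.

```python
# The 15 species listed above for the pre-trained models
ORGANICS_SPECIES = frozenset(
    {"B", "Br", "C", "Cl", "F", "H", "I", "K", "Li", "N", "Na", "O", "P", "S", "Si"}
)


def unsupported_species(symbols, supported=ORGANICS_SPECIES):
    """Return, sorted, the chemical symbols in `symbols` that the model does not cover."""
    return sorted(set(symbols) - set(supported))


print(unsupported_species(["C", "H", "O", "N"]))   # organic molecule: nothing missing
print(unsupported_species(["Fe", "C", "H", "O"]))  # iron is not covered
```

An empty list means every species in the system appears in the model's training set; it says nothing about whether the chemistry itself is well represented.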

For an example of expected runtimes on the different pre-trained models, on various hardware and on various system sizes, we recommend that you read the relevant sections of the associated white paper.

Below, we load a ViSNet foundation model (users can easily load MACE or NequIP instead using the commented-out code) and run a 1000-step simulation of chignolin. We strongly recommend using a GPU with the Jax-MD backend for this part of the tutorial.

As previously described, **the process is**: 
- Download the pre-trained model from HuggingFace
- Initialise the force field
- Configure the MD engine
- Initialise the engine with an input structure
- Run the simulation

#### 1. Downloading the model

In [None]:
from huggingface_hub import hf_hub_download

hf_hub_download(repo_id="InstaDeepAI/visnet-organics", filename="visnet_organics_01.zip", local_dir="pretrained_models/")
# hf_hub_download(repo_id="InstaDeepAI/mace-organics", filename="mace_organics_01.zip", local_dir="pretrained_models/")
# hf_hub_download(repo_id="InstaDeepAI/nequip-organics", filename="nequip_organics_01.zip", local_dir="pretrained_models/")

In [None]:
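from mlip.models import Visnet
# To use MACE or NequIP instead, import the matching model class from mlip.models (e.g. Mace)

organics_model_path = "pretrained_models/visnet_organics_01.zip"
# organics_model_path = "pretrained_models/mace_organics_01.zip"
# organics_model_path = "pretrained_models/nequip_organics_01.zip"

# The model class must match the downloaded model (Visnet here)
force_field = load_model_from_zip(Visnet, organics_model_path)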

In [None]:
config = JaxMDSimulationEngine.Config(
    num_steps=1000,
    num_episodes=1,  # a single 1000-step episode
    snapshot_interval=1,
)

In [None]:
chignolin_file = "simulation/chignolin_protonated.xyz"
atoms = ase_read(chignolin_file)

In [None]:
md_engine = JaxMDSimulationEngine(atoms, force_field, config)

Before running the next cell, be aware that it may take a while depending on the hardware used. After compilation, a 1000-step episode on chignolin takes roughly 5 s on an H100 and 8 s on an A100.

In [None]:
md_engine.run() 
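To compare timings like the ones quoted above with other MD codes, they are often converted to simulated nanoseconds per day. The arithmetic is sketched below; the 1 fs timestep is an assumption for illustration (the config in this notebook does not set the timestep explicitly), so substitute the value your engine actually uses. Compilation time is excluded, as in the quoted timings.

```python
def ns_per_day(num_steps: int, wall_seconds: float, timestep_fs: float) -> float:
    """Simulated nanoseconds per day of wall-clock time."""
    simulated_ns = num_steps * timestep_fs * 1e-6  # 1 fs = 1e-6 ns
    return simulated_ns * 86_400 / wall_seconds     # 86,400 s per day


# Timings quoted above for a 1000-step episode, assuming a 1 fs timestep
print(ns_per_day(1000, 5.0, 1.0))  # H100
print(ns_per_day(1000, 8.0, 1.0))  # A100
```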

We can visualise the results here as well; due to the size of the system, however, this may take a few minutes to run in Jupyter.

In [None]:
md_state_chignolin = md_engine.state
rdkit_object = ase2rdkit(atoms)
viewer = MolTo3DView(rdkit_object, positions=md_state_chignolin.positions)
viewer.show()
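If rendering the full chignolin trajectory is too slow, one option is to pass only every k-th frame to `MolTo3DView`. Below is a minimal sketch; the helper name and the 100-frame default are illustrative choices, not part of the *mlip* library.

```python
import numpy as np


def subsample_frames(positions, max_frames=100):
    """Keep evenly strided frames so that at most `max_frames` remain (first frame always kept)."""
    pos = np.asarray(positions)
    stride = max(1, int(np.ceil(len(pos) / max_frames)))
    return pos[::stride]


traj = np.zeros((1000, 5, 3))  # toy trajectory: 1000 frames, 5 atoms
print(subsample_frames(traj).shape)
```

In the notebook, this would be used as `MolTo3DView(rdkit_object, positions=subsample_frames(md_state_chignolin.positions))`.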