
Run molecular dynamics with Jax-MD #17

Closed
Chenghao-Wu opened this issue Oct 8, 2023 · 4 comments

@Chenghao-Wu

Thanks to the developers for the great MLFF!

May I ask if there is any plan to support jax-md as an MD simulator?

@thorben-frank
Owner

thorben-frank commented Oct 9, 2023

Hey,

we are currently using the mlff-internal MD library mdx. However, if you want to use jax_md with a trained NN from mlff, you can do so by defining a helper function that maps an MLFFPotential to a jax_md potential. All you need is some molecular data and the pre-trained mlff model (the path to its ckpt_dir).

I do not have much experience with jax_md, so I basically copy-pasted their NVE example from here https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/nve_neighbor_list.ipynb#scrollTo=c5CMTjhGVyl4 and defined the helper function to_jax_md. The code below was tested with jax_md version 0.2.0, since the newest pip version throws an e3nn_jax import error:

Import stuff:

import numpy as onp

import jax
jax.config.update('jax_enable_x64', True)  # enable double precision
import jax.numpy as np
from jax import random
from jax import jit
from jax import lax

from jax_md import space, simulate, partition
from jax_md.space import DisplacementOrMetricFn, Box

from mlff.mdx.potential import MLFFPotential
from mlff.utils.structures import Graph

Define the to_jax_md function.

def to_jax_md(
    potential,  # the mlff potential (MLFFPotential)
    displacement_or_metric: DisplacementOrMetricFn,  # must return displacement vectors, not distances
    box_size: Box,  # box if it exists, check the jax_md documentation for conventions
    species: np.ndarray = None,  # the atomic species, array of shape (n_atoms,)
    dr_threshold: float = 0.,  # currently dr_threshold > 0 is experimental
    fractional_coordinates: bool = False,
    format: partition.NeighborListFormat = partition.NeighborListFormat.Sparse,  # only sparse is supported in mlff
    **neighbor_kwargs
    ):

    if format != partition.NeighborListFormat.Sparse:
        raise ValueError('mlff only supports the sparse neighbor list format.')

    # create the neighbor_fn, using the cutoff stored in the MLFFPotential
    neighbor_fn = partition.neighbor_list(
        displacement_or_metric,
        box_size,
        potential.cutoff,
        dr_threshold,
        fractional_coordinates=fractional_coordinates,
        format=format,
        **neighbor_kwargs)

    # create an energy_fn that is compatible with jax_md
    def energy_fn(R, neighbor, **energy_fn_kwargs):
        # sparse neighbor lists store pairs in an index array of shape (2, P);
        # unused slots are filled with the number of atoms
        idx_i = neighbor.idx[1]  # shape: P
        idx_j = neighbor.idx[0]  # shape: P
        R_ij = jax.vmap(displacement_or_metric)(R[idx_j], R[idx_i])
        mask = idx_i < len(species)  # False for padded pairs
        graph = Graph(nodes=species, edges=R_ij, centers=idx_i, others=idx_j, mask=mask)
        return potential(graph).sum()

    return neighbor_fn, energy_fn
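
The index bookkeeping inside energy_fn can be tried without mlff. Below is a minimal plain-JAX sketch, assuming the sparse layout used above (an index array of shape (2, P) whose unused slots hold the atom count N). The pair energy is a hypothetical stand-in for the mlff potential; masking the squared distance before the square root keeps the padded edges from producing NaNs.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy system: 3 atoms on a line (not mlff data).
R = jnp.array([[0.0, 0.0, 0.0],
               [1.0, 0.0, 0.0],
               [2.5, 0.0, 0.0]])
N = R.shape[0]

# Sparse index array of shape (2, P): the pair 0-1 in both directions,
# plus two unused slots filled with N (the padding convention assumed above).
idx = jnp.array([[0, 1, N, N],
                 [1, 0, N, N]])

def displacement(Ra, Rb):
    # free-space displacement, analogous to space.free()
    return Ra - Rb

def edge_features(R, idx):
    idx_i = idx[1]
    idx_j = idx[0]
    R_ij = jax.vmap(displacement)(R[idx_j], R[idx_i])
    mask = idx_i < R.shape[0]  # True only for real pairs
    return R_ij, mask

def toy_energy(R, idx):
    R_ij, mask = edge_features(R, idx)
    # replace padded squared distances by 1.0 so sqrt stays finite
    r2 = jnp.where(mask, jnp.sum(R_ij ** 2, axis=-1), 1.0)
    r = jnp.sqrt(r2)
    # hypothetical harmonic pair term with rest length 0.9; padded edges contribute 0
    e = jnp.where(mask, (r - 0.9) ** 2, 0.0)
    return 0.5 * e.sum()  # each pair appears twice, hence the factor 0.5

print(edge_features(R, idx))
print(toy_energy(R, idx))
```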

Test that everything is working

# test that the energy_fn is working

# load some example data
ethanol_data = onp.load('/path/to/mlff/repository/examples/example_data/ethanol.npz')

# create the mlff potential from the ckpt_dir
mlff_potential = MLFFPotential.create_from_ckpt_dir('path/to/ckpt_dir')

# create displacement and shift_fn (no periodic boundary conditions present)
displacement, shift = space.free()

# create the neighbor_fn and energy_fn from the mlff potential
neighbor_fn, energy_fn = to_jax_md(potential=mlff_potential,
                                   displacement_or_metric=displacement,
                                   box_size=None,
                                   species=ethanol_data['z']
                                   )
# jit energy_fn
energy_fn = jit(energy_fn)

# test: allocate a neighbor list for the first frame and evaluate the energy
R0 = ethanol_data['R'][0]
nbrs = neighbor_fn.allocate(R0)
print('E = {}'.format(energy_fn(R0, neighbor=nbrs)))
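
When simulate.nve is given an energy function, jax_md obtains the forces as the negative gradient of that energy with respect to the positions. A quick sanity check of this relation, sketched with a hypothetical all-pairs harmonic energy in place of the mlff energy_fn (no neighbor list, for brevity):

```python
import jax
import jax.numpy as jnp

def toy_energy(R):
    # hypothetical harmonic pair energy over all distinct pairs, rest length 1.0
    diff = R[:, None, :] - R[None, :, :]
    n = R.shape[0]
    off_diag = ~jnp.eye(n, dtype=bool)
    # mask the diagonal before sqrt so the gradient stays finite
    r2 = jnp.where(off_diag, jnp.sum(diff ** 2, axis=-1), 1.0)
    r = jnp.sqrt(r2)
    return 0.5 * jnp.sum(jnp.where(off_diag, (r - 1.0) ** 2, 0.0))

# forces are the negative gradient of the energy w.r.t. positions
force_fn = jax.jit(lambda R: -jax.grad(toy_energy)(R))

R0 = jnp.array([[0.0, 0.0, 0.0],
                [1.2, 0.0, 0.0],
                [0.0, 0.9, 0.0]])
F = force_fn(R0)
print('E =', toy_energy(R0))
print('F =', F)
# a translation-invariant energy gives forces that sum to (numerically) zero
print('sum of forces =', F.sum(axis=0))
```

With the real model, `-jax.grad(energy_fn)(R0, neighbor=nbrs)` should pass the same check, since the mlff energy only depends on displacement vectors.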

Do a little NVE simulation with neighbor list updates:

# do an NVE simulation with jax_md, closely following
# https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/nve_neighbor_list.ipynb#scrollTo=c5CMTjhGVyl4

ethanol_data = onp.load('/path/to/mlff/repository/examples/example_data/ethanol.npz')

mlff_potential = MLFFPotential.create_from_ckpt_dir('path/to/ckpt_dir')

displacement, shift = space.free()

neighbor_fn, energy_fn = to_jax_md(potential=mlff_potential,
                                   displacement_or_metric=displacement,
                                   box_size=None,
                                   species=ethanol_data['z']
                                   )

# NVE integrator with time step dt = 1e-3 (in the units of the model)
init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)

R0 = ethanol_data['R'][0]
nbrs = neighbor_fn.allocate(R0)

state = init_fn(random.PRNGKey(0), R0, kT=1e-3, neighbor=nbrs)

def body_fn(i, carry):
    state, nbrs = carry
    # update the neighbor list for the current positions, then take one MD step
    nbrs = nbrs.update(state.position)
    state = apply_fn(state, neighbor=nbrs)
    return state, nbrs

step = 0
while step < 40:  # 40 blocks of 100 MD steps each
    new_state, nbrs = lax.fori_loop(0, 100, body_fn, (state, nbrs))
    if nbrs.did_buffer_overflow:
        # the neighbor list ran out of capacity: reallocate and redo this block
        print('Neighbor list overflowed, reallocating.')
        nbrs = neighbor_fn.allocate(state.position)
    else:
        state = new_state
        step += 1

# print the difference between initial and final positions
print(R0 - state.position)
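
simulate.nve integrates the equations of motion with velocity Verlet. The sketch below is a simplified plain-JAX version of that scheme (unit masses, a hypothetical harmonic potential, no neighbor list), not jax_md's implementation; it illustrates what each apply_fn call does and how lax.fori_loop runs a whole block of steps inside one compiled call.

```python
import jax
import jax.numpy as jnp
from jax import lax

def energy(x):
    # hypothetical harmonic potential with unit spring constant
    return 0.5 * jnp.sum(x ** 2)

force = jax.grad(lambda x: -energy(x))  # F = -dE/dx

def velocity_verlet_step(i, carry, dt=1e-2):
    x, v, f = carry
    v_half = v + 0.5 * dt * f          # first half kick (unit mass)
    x_new = x + dt * v_half            # drift
    f_new = force(x_new)               # forces at the new positions
    v_new = v_half + 0.5 * dt * f_new  # second half kick
    return x_new, v_new, f_new

def total_energy(x, v):
    return energy(x) + 0.5 * jnp.sum(v ** 2)

x0 = jnp.array([1.0, 0.0, 0.0])
v0 = jnp.zeros(3)
run = jax.jit(lambda c: lax.fori_loop(0, 1000, velocity_verlet_step, c))
x, v, _ = run((x0, v0, force(x0)))

print('E initial:', total_energy(x0, v0))
print('E final:  ', total_energy(x, v))
```

The total energy should stay close to its initial value; a steady drift in the real simulation is usually a sign that the time step is too large.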

Do not hesitate to ask if you have further questions.

@Chenghao-Wu
Author

Thank you so much for your help!!! It really helped me try out So3krates, which turns out to be very competitive!

I would like to continue trying something else. May I ask if it is possible to use mlff for training on datasets that contain multiple different molecules? I read your code, and it seems the input dataset is padded into a single array for training, which makes it difficult to train NNPs on, e.g., molecules with different atom types?

@Chenghao-Wu
Author

Oh, I think I found how to do it. I can use the node_mask property key, right? Basically, set the atom type of padded atoms to 0 and the node_mask to True/False?

@thorben-frank
Owner

thorben-frank commented Oct 21, 2023

It works exactly as you said. In the .npz file, pad the atomic types and atomic positions with zeros and set 'node_mask' to True for real atoms and False for padded ones. If you are using ASE-digestible inputs, you do not need to do anything; the corresponding data structure is created automatically internally.
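
A minimal NumPy sketch of this padding step, assuming each molecule comes as a dict with atomic numbers, positions and a total energy. The key names R and z follow the ethanol.npz example; node_mask is the key discussed here; the energy key E is an assumption, so check which keys mlff actually expects. Forces would be padded exactly like R.

```python
import numpy as onp

def pad_molecules(molecules):
    """Pad molecules of different size to a common atom count (hypothetical helper)."""
    n_mol = len(molecules)
    n_max = max(len(m['z']) for m in molecules)
    z = onp.zeros((n_mol, n_max), dtype=onp.int64)       # padded atoms get type 0
    R = onp.zeros((n_mol, n_max, 3), dtype=onp.float64)  # padded atoms sit at the origin
    node_mask = onp.zeros((n_mol, n_max), dtype=bool)    # True = real atom
    E = onp.zeros(n_mol, dtype=onp.float64)              # 'E' key name is an assumption
    for k, m in enumerate(molecules):
        n = len(m['z'])
        z[k, :n] = m['z']
        R[k, :n] = m['R']
        node_mask[k, :n] = True
        E[k] = m['E']
    return {'z': z, 'R': R, 'E': E, 'node_mask': node_mask}

# made-up molecules (water-like and methane-like, arbitrary positions and energies)
molecules = [
    {'z': [8, 1, 1], 'R': onp.ones((3, 3)), 'E': -76.4},
    {'z': [6, 1, 1, 1, 1], 'R': onp.ones((5, 3)), 'E': -40.5},
]
data = pad_molecules(molecules)
print(data['z'])
print(data['node_mask'])
# onp.savez('padded_dataset.npz', **data)
```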

This way of padding is, however, suboptimal if the number of atoms varies a lot across the dataset.
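
One way to quantify this is the fraction of atom slots that are pure padding when every molecule is padded to the largest one. A small sketch; the atom counts in the example are made up:

```python
import numpy as onp

def padding_overhead(n_atoms):
    """Fraction of atom slots that are padding when all molecules are padded to max(n_atoms)."""
    n_atoms = onp.asarray(n_atoms)
    n_max = n_atoms.max()
    return 1.0 - n_atoms.sum() / (len(n_atoms) * n_max)

print(padding_overhead([9, 9, 9, 9]))   # -> 0.0, uniform sizes need no padding
print(padding_overhead([3, 5, 9, 60]))  # -> ~0.68, one large molecule dominates
```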

An important note: the train_so3krates command has an option called --shift_by, which is set to mean by default. This should not be used for datasets with molecules of different size (I actually need to add a warning/error at some point). Instead, use --shift_by no if you have already shifted the energies in your .npz file yourself, or --shift_by type, which shifts the energy by an atom-type-specific shift that must be passed to the train_so3krates command using the --shifts keyword. See here https://github.com/thorben-frank/mlff#energy-shifts.
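
The arithmetic behind a type-specific shift: each molecule's energy is reduced by the sum of the shifts of its real atoms. A minimal NumPy sketch on padded arrays; the shift values are made up, and how mlff parses --shifts is described in the README section linked above, not here.

```python
import numpy as onp

def apply_type_shifts(E, z, node_mask, shifts):
    """Subtract per-element shifts (dict: atomic number -> shift) summed over real atoms."""
    size = max(int(z.max()), max(shifts)) + 1
    table = onp.zeros(size)  # lookup table indexed by atomic number; type 0 stays 0
    for Z, s in shifts.items():
        table[Z] = s
    per_atom = onp.where(node_mask, table[z], 0.0)  # ignore padded atoms
    return E - per_atom.sum(axis=1)

z = onp.array([[8, 1, 1, 0, 0],
               [6, 1, 1, 1, 1]])
node_mask = z > 0  # valid here because padded atoms have type 0
E = onp.array([-76.4, -40.5])
shifts = {1: -0.5, 6: -37.8, 8: -75.0}  # hypothetical values, not fitted to any data
print(apply_type_shifts(E, z, node_mask, shifts))
```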

Alternatively, you can set --shift_by lse which will determine energy shifts from the data by solving an LSE.
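
The general idea of fitting shifts from data can be sketched as a least-squares problem: with C[m, k] the number of atoms of element k in molecule m, find per-element shifts s with E ≈ C s. This is an illustration under that assumption; mlff's lse option may be formulated differently. The compositions and shifts below are made up, and the energies are generated exactly from them so the fit should recover them.

```python
import numpy as onp

def fit_type_shifts(E, z, node_mask):
    """Least-squares per-element shifts s such that E ~ C @ s."""
    types = onp.unique(z[node_mask])
    # C[m, k] = number of real atoms of element types[k] in molecule m
    C = onp.stack([((z == t) & node_mask).sum(axis=1) for t in types], axis=1).astype(float)
    s, *_ = onp.linalg.lstsq(C, E, rcond=None)
    return dict(zip(types.tolist(), s.tolist())), E - C @ s

z = onp.array([[8, 1, 1, 0, 0],   # water-like
               [6, 1, 1, 1, 1],   # methane-like
               [6, 8, 0, 0, 0],   # CO-like
               [1, 1, 0, 0, 0]])  # H2-like
node_mask = z > 0
true_shifts = {1: -0.5, 6: -37.8, 8: -75.0}  # hypothetical
E = onp.array([sum(true_shifts[int(a)] for a in row if a > 0) for row in z])
fitted, residual = fit_type_shifts(E, z, node_mask)
print(fitted)
print(residual)
```

On real data the residuals are not zero, but their spread is typically far smaller than that of the raw energies, which is what makes the shifted targets easier to learn.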

The corresponding part in the CLI code is here:

if shift_by == 'mean':

Not applying a shift at all often leads to bad results, since the energies of molecules with different size and composition have a large variance, which makes training difficult.

Let me know if something is unclear :-)

Best
Thorben
