Run molecular dynamics with Jax-MD #17
Hey, we are currently using the […]. I do not have much experience with […].

Import stuff:

```python
import numpy as onp
import time

import jax
# enable double precision (works across JAX versions, unlike the deprecated
# `from jax.config import config`)
jax.config.update('jax_enable_x64', True)
import jax.numpy as np
from jax import random
from jax import jit
from jax import lax

from jax_md import space, smap, energy, minimize, quantity, simulate, partition
from jax_md.space import periodic_general, free, DisplacementOrMetricFn, Box

from mlff.mdx.potential import MLFFPotential
from mlff.utils.structures import Graph
```

Define the […]:

```python
def to_jax_md(
        potential,  # the mlff potential
        displacement_or_metric: DisplacementOrMetricFn,
        box_size: Box,  # box if it exists, check the JAX MD documentation for conventions
        species: np.ndarray = None,  # the atomic species, array of shape (n_atoms,)
        dr_threshold: float = 0.,  # currently dr_threshold > 0 is experimental
        fractional_coordinates: bool = False,
        format: partition.NeighborListFormat = partition.NeighborListFormat.Sparse,  # only sparse is supported in mlff
        **neighbor_kwargs
):
    # create the neighbor_fn
    neighbor_fn = partition.neighbor_list(
        displacement_or_metric,
        box_size,
        potential.cutoff,  # load the cutoff of the model from the MLFFPotential
        dr_threshold,
        fractional_coordinates=fractional_coordinates,
        format=format,
        **neighbor_kwargs)

    # create an energy_fn that is compatible with JAX MD
    def energy_fn(R, neighbor, **energy_fn_kwargs):
        idx_i = neighbor.idx[1]  # shape: P
        idx_j = neighbor.idx[0]  # shape: P
        R_ij = jax.vmap(displacement_or_metric)(R[idx_j], R[idx_i])
        # unused (padded) neighbor slots hold the index N = len(species), so mask them out
        mask = idx_i < len(species)
        graph = Graph(nodes=species, edges=R_ij, centers=idx_i, others=idx_j, mask=mask)
        return potential(graph).sum()

    return neighbor_fn, energy_fn
```

Test that everything is working:

```python
# test that the energy_fn is working
# load some example data
ethanol_data = np.load('/path/to/mlff/repository/examples/example_data/ethanol.npz')
# create the mlff potential from the ckpt_dir
mlff_potential = MLFFPotential.create_from_ckpt_dir('path/to/ckpt_dir')
# create displacement and shift_fn (no periodic boundary conditions present)
displacement, shift = space.free()
# create the neighbor_fn and energy_fn from the mlff potential
neighbor_fn, energy_fn = to_jax_md(potential=mlff_potential,
                                   displacement_or_metric=displacement,
                                   box_size=None,
                                   species=ethanol_data['z'])
# jit energy_fn
energy_fn = jit(energy_fn)
# test
R0 = ethanol_data['R'][0]
nbrs = neighbor_fn.allocate(R0)
print('E = {}'.format(energy_fn(R0, neighbor=nbrs)))
```

Do a little NVE simulation with neighbor list updates:

```python
# do an NVE simulation with JAX MD, closely following
# https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/nve_neighbor_list.ipynb#scrollTo=c5CMTjhGVyl4
ethanol_data = np.load('/path/to/mlff/repository/examples/example_data/ethanol.npz')
mlff_potential = MLFFPotential.create_from_ckpt_dir('path/to/ckpt_dir')
displacement, shift = space.free()
neighbor_fn, energy_fn = to_jax_md(potential=mlff_potential,
                                   displacement_or_metric=displacement,
                                   box_size=None,
                                   species=ethanol_data['z'])

init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)  # time step dt = 1e-3
R0 = ethanol_data['R'][0]
nbrs = neighbor_fn.allocate(R0)
state = init_fn(random.PRNGKey(0), R0, kT=1e-3, neighbor=nbrs)


def body_fn(i, state):
    state, nbrs = state
    nbrs = nbrs.update(state.position)
    state = apply_fn(state, neighbor=nbrs)
    return state, nbrs


step = 0
while step < 40:
    # run 100 MD steps inside a compiled loop
    new_state, nbrs = lax.fori_loop(0, 100, body_fn, (state, nbrs))
    if nbrs.did_buffer_overflow:
        # discard this chunk and reallocate the neighbor list for the current positions
        print('Neighbor list overflowed, reallocating.')
        nbrs = neighbor_fn.allocate(state.position)
    else:
        state = new_state
        step += 1

# print the difference between initial and final positions
print(R0 - state.position)
```

Do not hesitate to ask if you have further questions.
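To make the masking in `energy_fn` concrete, here is a NumPy-only sketch of the sparse neighbor-list layout it relies on. It assumes, consistent with `mask = idx_i < len(species)`, that the sparse `idx` array has shape `(2, capacity)` and that unused slots hold the index N (the number of atoms). `build_sparse_neighbors` is a hypothetical brute-force stand-in for `partition.neighbor_list`, and the toy 1/r pair energy stands in for the MLFF potential; where JAX MD sets `did_buffer_overflow`, this sketch simply raises.

```python
import numpy as np


def build_sparse_neighbors(R, cutoff, capacity):
    """Hypothetical brute-force stand-in for a sparse neighbor list.

    Returns idx of shape (2, capacity). Unused slots are filled with
    N = len(R), the padding convention assumed by the mask below.
    """
    N = len(R)
    dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)
    # ordered pairs (j, i) with i != j that lie within the cutoff
    j, i = np.nonzero((dist < cutoff) & ~np.eye(N, dtype=bool))
    if len(j) > capacity:
        # JAX MD flags this via did_buffer_overflow instead of raising
        raise ValueError("neighbor buffer overflow")
    idx = np.full((2, capacity), N, dtype=int)
    idx[0, :len(j)] = j
    idx[1, :len(i)] = i
    return idx


def masked_pair_energy(R, species, idx):
    """Toy 1/r pair energy that ignores padded neighbor slots."""
    idx_i, idx_j = idx[1], idx[0]
    mask = idx_i < len(species)  # False for padded slots
    # clip so NumPy indexing stays in bounds; padded entries are masked anyway
    safe_i = np.minimum(idx_i, len(R) - 1)
    safe_j = np.minimum(idx_j, len(R) - 1)
    R_ij = R[safe_j] - R[safe_i]  # free-space displacement
    r = np.linalg.norm(R_ij, axis=-1)
    # inner where avoids dividing by zero on padded (self-paired) slots
    pair_e = np.where(mask, 1.0 / np.where(mask, r, 1.0), 0.0)
    return pair_e.sum(), mask


R = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
species = np.array([8, 1, 1])
idx = build_sparse_neighbors(R, cutoff=1.5, capacity=6)
E, mask = masked_pair_energy(R, species, idx)
print(idx)
print(E)  # 2.0: only the 0-1 pair is within the cutoff, counted in both directions
```

Because the buffer has a fixed capacity, array shapes never change between steps, which is what lets the MD loop stay compiled; the price is the explicit overflow check.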
Thank you so much for your help! It indeed helped me try out So3krates, which turns out to be very competitive. I would like to try something else next: is it possible to use mlff for training on datasets that contain multiple different molecules? From reading your code, it seems the input dataset is padded into an array for training, which makes it difficult to train NNPs on, e.g., molecules with different atom types.
Oh, I think I found how to do it: I can use the node_mask property_key, right? Basically, set the atom_type to 0 and fill node_type with True and False?
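The padding approach described here can be sketched in NumPy: every molecule is padded to the atom count of the largest one, padded atoms get atomic number 0, and a boolean node mask marks the real atoms. `pad_molecules` and its return layout are hypothetical and do not reproduce mlff's data pipeline or its `node_mask` handling; the last line shows how much of the padded array is spent on padding when molecule sizes differ.

```python
import numpy as np


def pad_molecules(positions, atomic_numbers):
    """Hypothetical helper: pad a list of molecules to a common atom count.

    positions: list of (n_i, 3) arrays; atomic_numbers: list of (n_i,) int arrays.
    Returns R (M, n_max, 3), z (M, n_max), node_mask (M, n_max).
    """
    n_max = max(len(zs) for zs in atomic_numbers)
    M = len(atomic_numbers)
    R = np.zeros((M, n_max, 3))
    z = np.zeros((M, n_max), dtype=int)  # atomic number 0 marks padding
    node_mask = np.zeros((M, n_max), dtype=bool)
    for m, (pos, zs) in enumerate(zip(positions, atomic_numbers)):
        n = len(zs)
        R[m, :n] = pos
        z[m, :n] = zs
        node_mask[m, :n] = True  # True for real atoms only
    return R, z, node_mask


# water (3 atoms) and ethanol (9 atoms); coordinates are placeholders
zs = [np.array([8, 1, 1]), np.array([6, 6, 8, 1, 1, 1, 1, 1, 1])]
pos = [np.zeros((3, 3)), np.zeros((9, 3))]
R, z, node_mask = pad_molecules(pos, zs)
print(z.shape)  # (2, 9)
# fraction of array slots occupied by padding atoms
print(1.0 - node_mask.sum() / node_mask.size)
```

Here a third of the slots are padding; the fraction grows as the spread between the smallest and largest molecule grows.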
It works exactly as you said. In the […] This way of padding is, however, suboptimal if your dataset has a high variance in the number of atoms.

An important note: if you check the […] Alternatively, you can set […] The corresponding part in the CLI code is here: mlff/mlff/cAPI/mlff_train_so3krates.py, line 311 (commit 4a0f330).
Not applying a shift at all often leads to bad results, since the energies of molecules with different sizes and compositions have a large variance, which makes training difficult. Let me know if something is unclear :-) Best
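A common way to construct such a shift (an assumption here; the thread does not state exactly which scheme mlff's CLI uses) is to fit one reference energy per element by linear least squares on the composition counts and subtract the fitted sum from each molecule's energy. `fit_element_shifts` is a hypothetical helper, and the energies below are synthetic numbers made up for illustration.

```python
import numpy as np


def fit_element_shifts(z, energies, elements):
    """Hypothetical helper: per-element reference energies by least squares.

    z: (M, n_max) atomic numbers, 0 marks padding; energies: (M,).
    Returns (dict element -> reference energy, composition matrix).
    """
    # counts[m, k] = number of atoms of elements[k] in molecule m (padding ignored)
    counts = np.stack([(z == e).sum(axis=1) for e in elements], axis=1).astype(float)
    coeffs, *_ = np.linalg.lstsq(counts, energies, rcond=None)
    return dict(zip(elements, coeffs)), counts


# synthetic data: made-up per-element energies plus a small "interaction" part
elements = [1, 6, 8]  # H, C, O
true_ref = {1: -0.5, 6: -38.0, 8: -75.0}
z = np.array([[8, 1, 1, 0, 0, 0, 0, 0, 0],   # H2O (padded)
              [6, 1, 1, 1, 1, 0, 0, 0, 0],   # CH4
              [6, 6, 8, 1, 1, 1, 1, 1, 1],   # C2H5OH
              [6, 8, 8, 0, 0, 0, 0, 0, 0]])  # CO2
interaction = np.array([0.01, -0.02, 0.03, -0.01])
energies = np.array([sum(true_ref[int(a)] for a in row if a != 0) for row in z]) + interaction

shifts, counts = fit_element_shifts(z, energies, elements)
shifted = energies - counts @ np.array([shifts[e] for e in elements])
print(energies.std(), shifted.std())  # the shifted targets vary far less
```

The network then only has to learn the small residual `shifted`, whose spread no longer scales with molecule size and composition.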
Thanks to the developers for the great MLFF!
May I ask whether there are any plans to support jax-md as an MD simulator?