In [None]:
from mlff.src.md.calculator import mlffCalculator
from mlff.src.io import read_json
from mlff.src.nn.stacknet import init_stack_net
from flax.training import checkpoints
import os
from ase import units
from ase import Atoms
from ase.optimize import QuasiNewton
from ase.md.langevin import Langevin
from ase.md.verlet import VelocityVerlet
from ase.md.velocitydistribution import (MaxwellBoltzmannDistribution, Stationary, ZeroRotation)
from ase.io import read
from ase.io.trajectory import Trajectory
from ase.visualize import view
from tqdm import tqdm

# Build an ASE Calculator

We load the model that we trained in the Learning_Force_Fields example and construct an ASE calculator from the mlff interface. One can pass an `n_interactions_max` value to the calculator, which sets the maximal number of 
pairwise interactions before the energy and force functions are recompiled. If `n_interactions_max=None`, this can result in recompiling every time the total number of interactions changes. This actually happens quite often, which
is why we suggest passing a number here. A good estimate is usually obtained by calculating the average neighborhood size for the given cutoff radius, adding 1 or 2 to this result, and multiplying it by the total number of atoms in the 
system. (For simplicity, the cell below uses `n_interactions_max=None`.)

In [None]:
# Directory holding the trained model: hyperparameters.json plus the Flax
# checkpoint files whose names start with 'checkpoint_loss_'.
load_path = 'example_model/module/'

# Rebuild the network architecture from the stored hyperparameters.
h = read_json(os.path.join(load_path, 'hyperparameters.json'))
stack_net = init_stack_net(h)

# Restore the trained state without a target structure; only its 'params'
# entry is needed to build the calculator.
restored_state = checkpoints.restore_checkpoint(load_path, target=None, prefix='checkpoint_loss_')
params = restored_state['params']

# With n_interactions_max=None the energy/force functions may be recompiled
# whenever the total number of pairwise interactions changes; pass an upper
# bound instead to avoid repeated recompilation.
calc = mlffCalculator(params=params, stack_net=stack_net, n_interactions_max=None)

# Run MD in ASE

In [None]:
mol = read('example_data/md_start_point_ethanol.xyz')
# Attach the ML force field. Assigning `atoms.calc` is the current ASE idiom;
# `Atoms.set_calculator` is deprecated.
mol.calc = calc

# Quick geometry relaxation before MD: stop once the maximum force drops
# below fmax or after `steps` optimizer iterations, whichever comes first.
qn = QuasiNewton(mol)
qn.run(fmax=1e-4, steps=100)

# Optional: draw initial velocities for T=300K instead of starting from the
# momenta stored in the input file (zero if none are stored). Without this,
# the Langevin thermostat below has to bring the system to 300 K by itself.
# MaxwellBoltzmannDistribution(mol, temperature_K=300)
# Stationary(mol)  # zero linear momentum
# ZeroRotation(mol)  # zero angular momentum

# Langevin dynamics at 300 K with a 0.2 fs time step; frames are written to
# the trajectory file that the visualization cell reads back.
dyn = Langevin(mol, 0.2 * units.fs, temperature_K=300, friction=0.002,
               trajectory='example_data/md_ethanol_langevin.traj')
def printenergy(a):
    """Print per-atom potential, kinetic and total energy of ``a``.

    Also prints the instantaneous temperature derived from the per-atom
    kinetic energy via Ekin = 3/2 * kB * T.
    """
    epot = a.get_potential_energy() / len(a)
    ekin = a.get_kinetic_energy() / len(a)
    temperature = ekin / (1.5 * units.kB)
    etot = epot + ekin
    template = ('Energy per atom: Epot = %.3feV  Ekin = %.3feV (T=%3.0fK)  '
                'Etot = %.3feV')
    print(template % (epot, ekin, temperature, etot))

# Now run the dynamics. The integration proceeds in blocks of
# `steps_per_block` MD steps and energies are printed after every block, so
# the total number of MD steps is n_blocks * steps_per_block.
steps_per_block = 10
n_blocks = 1000
n_steps = n_blocks * steps_per_block  # total number of MD steps
report_every = 10  # print a progress line every `report_every` blocks

printenergy(mol)
for block in range(n_blocks):
    if block % report_every == 0:
        # Report progress in MD steps, not in blocks.
        print('{} / {}'.format(block * steps_per_block, n_steps))
    dyn.run(steps_per_block)
    printenergy(mol)

In [None]:
# Executing this cell requires nglview, see
# https://github.com/nglviewer/nglview#installation
# TODO: visualization currently breaks ...
frame_stride = 5  # subsample: keep every 5th frame of the MD trajectory
with Trajectory('example_data/md_ethanol_langevin.traj') as traj:
    # Build the frame list while the file is still open; the reader is
    # closed when the `with` block exits.
    atoms = [a for n, a in enumerate(traj) if n % frame_stride == 0]
v = view(atoms, viewer='ngl')
v.view.add_ball_and_stick()
v