# Learning a GNN potential for diamond

In this notebook, we will learn a graph neural network (GNN) potential
from experimental stiffness tensor data. We will directly apply the DiffTRe
library. If you are interested in the working mechanisms of DiffTRe,
please have a look at the double well notebook.

We will define all components necessary to initialize DiffTRe:
experimental data, simulation box, GNN potential with prior, simulator,
loss function and optimizer.


In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # expose only GPU 1 to JAX; adjust to your machine
from chemtrain.jax_md_mod import io, custom_space, custom_energy, \
    custom_quantity, custom_simulator
from chemtrain import difftre
from jax_md import partition, simulate, space
import optax
from jax import random, checkpoint
import jax.numpy as jnp
from functools import partial
import matplotlib.pyplot as plt
import pickle
import warnings
warnings.filterwarnings('ignore')  # disable warnings about float64 usage

optimization_pickle_file_path = 'saved_models/Diamond_GNN.pkl'
long_traj_pickle_file_path = 'saved_optimization_results/Diamond_long_traj.pkl'

### Experimental observables

It is often useful to start with the experimental target data. In this example,
these are elastic stiffness tensor values from an experiment conducted in the
paper _The Elastic Stiffness Moduli of Diamond_
[(McSkimin et al., 1972)](https://doi.org/10.1063/1.1661318).
Due to symmetries in cubic crystals, the stiffness tensor consists of three
distinct stiffness moduli $c_{11}$, $c_{12}$ and $c_{44}$
(in Voigt notation).<br> In the unstrained state, we assume the crystal to
be stress-free. The second target observable is therefore an
initial stress tensor of $\mathbf{0}$.
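
As an illustration of the cubic symmetry mentioned above, the following minimal sketch assembles the full 6x6 Voigt stiffness matrix from the three moduli. The helper `cubic_voigt_matrix` is a hypothetical name introduced only for this example and is not part of chemtrain. The bulk modulus is added as a plausibility check, using the standard relation $K = (c_{11} + 2 c_{12}) / 3$ for cubic crystals.

```python
# Target moduli in GPa, copied from the cell below
c_11, c_12, c_44 = 1079., 124., 578.

def cubic_voigt_matrix(c11, c12, c44):
    """Hypothetical helper: 6x6 Voigt stiffness matrix of a cubic crystal."""
    C = [[0.0] * 6 for _ in range(6)]
    for i in range(3):
        for j in range(3):
            # upper-left block couples normal stresses and normal strains
            C[i][j] = c11 if i == j else c12
        # shear components are decoupled and share a single modulus
        C[i + 3][i + 3] = c44
    return C

C_voigt = cubic_voigt_matrix(c_11, c_12, c_44)
for row in C_voigt:
    print(row)

# bulk modulus of a cubic crystal
bulk_modulus = (c_11 + 2 * c_12) / 3
print('Bulk modulus:', round(bulk_modulus, 1), 'GPa')
```

All remaining entries of the matrix vanish, which is why three numbers suffice as training targets.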

In [None]:
c_11_target = 1079.  # in GPa
c_12_target = 124.
c_44_target = 578.

convert_from_GPa_to_kJ_mol_nm_3 = 10**3 / 1.66054
stiffness_targets = jnp.array([c_11_target, c_12_target, c_44_target]) * \
                    convert_from_GPa_to_kJ_mol_nm_3

target_dict = {'stress': jnp.zeros((3, 3)), 'stiffness': stiffness_targets}

### Simulation setup

Next, we define the parameters of the simulation.
Some of these parameters are fixed by the experiment, e.g. the temperature
and the density of the material.
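
To see how the experimental density fixes the box size, here is a small worked sketch. It assumes the conventional cubic unit cell of the diamond structure with 8 atoms and uses made-up variable names. The resulting lattice constant should land close to the literature value of about 0.357 nm. The second part applies the same cube-root rescaling as the cell below to a hypothetical box that is 2 % too large.

```python
mass_carbon = 12.011              # u
experimental_density = 3512.      # kg / m^3
u_per_nm3_in_kg_per_m3 = 1.66054  # 1 u / nm^3 expressed in kg / m^3
atoms_per_cell = 8                # assumption: conventional cubic diamond cell

# volume that 8 carbon atoms must occupy to reproduce the experimental density
cell_volume = (atoms_per_cell * mass_carbon * u_per_nm3_in_kg_per_m3
               / experimental_density)  # nm^3
lattice_constant = cell_volume ** (1 / 3)
print('Lattice constant:', round(lattice_constant, 4), 'nm')

# hypothetical box with a 2 % too large edge length
box_length = 1.02 * lattice_constant
density = (atoms_per_cell * mass_carbon * u_per_nm3_in_kg_per_m3
           / box_length ** 3)
# too low a density shrinks the box, too high a density expands it
box_length *= (density / experimental_density) ** (1 / 3)
print('Rescaled box length:', round(box_length, 4), 'nm')
```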

In [None]:
system_temperature = 298.15  # Kelvin = 25 deg. celsius
Boltzmann_constant = 0.0083145107  # in kJ / mol K
kbT = system_temperature * Boltzmann_constant

experimental_density = 3512  # kg / m^3
mass = 12.011  # mass of carbon atoms in u

file_loc = '../examples/data/confs/Diamond.pdb'
R_init, box = io.load_box(file_loc)  # initial configuration of diamond crystal
N = R_init.shape[0]
density = mass * N * 1.66054 / jnp.prod(box)
# adjust box size to experiment
box *= (density / experimental_density) ** (1 / 3)
density = mass * N * 1.66054 / jnp.prod(box)
print('Model Density:', density, 'kg/m^3. Experimental density:', experimental_density, 'kg/m^3')

# simulation times in ps
dt = 0.5e-3
total_time = 70.
t_equilib = 10.  # discard all states within the first 10 ps as equilibration
print_every = 0.025  # save states every 0.025 ps for use in averaging
timing_struct = custom_simulator.process_printouts(dt, total_time,
                                                   t_equilib, print_every)

# we need to take derivatives wrt. the simulation box to compute the
# stiffness tensor; we therefore use a differentiable box in fractional
# coordinates: all atom positions are mapped onto a unit cube
box_tensor, scale_fn = custom_space.init_fractional_coordinates(box)
R_init = scale_fn(R_init)
displacement, shift = space.periodic_general(box_tensor)

### GNN Potential

Now we define the GNN as well as the prior potential, which provides a
physically reasonable starting point for the optimization. As prior potential,
we select the Stillinger-Weber potential originally developed for silicon
and only roughly adjust the length scale $\sigma$ and energy scale $\epsilon$
such that it is compatible with diamond. After defining both potentials,
we can combine them by simply adding computed potential energy values.
This is done inside the `energy_fn_template` defining a combined `energy`
function from the current GNN weights. This scheme enables updating GNN
weights during the optimization without compromising "jitability".
You can define a custom template by creating a function that takes current
learnable weights and returns a corresponding `energy_fn`.
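
The template pattern can be sketched with a toy one-dimensional example. Everything in it is hypothetical and only illustrates the structure: `learnable_energy` stands in for the GNN, `prior_energy` for the Stillinger-Weber prior, and the parameter names `k` and `r0` are made up.

```python
from functools import partial

def learnable_energy(params, r):
    """Hypothetical learnable term: harmonic bond with trainable parameters."""
    return 0.5 * params['k'] * (r - params['r0']) ** 2

def prior_energy(r, sigma=0.14, eps=200.):
    """Hypothetical fixed prior: purely repulsive, no trainable parameters."""
    return eps * (sigma / r) ** 12

def toy_energy_fn_template(energy_params):
    # bind the current parameters; the prior is shared by all energy functions
    learned = partial(learnable_energy, energy_params)
    def energy(r):
        return learned(r) + prior_energy(r)
    return energy

# two parameter sets yield two energy functions from the same template
energy_a = toy_energy_fn_template({'k': 1000., 'r0': 0.154})
energy_b = toy_energy_fn_template({'k': 2000., 'r0': 0.154})
print(energy_a(0.16), energy_b(0.16))
```

Because the template is an ordinary function of the parameters, an optimizer can call it again after every update without touching the prior.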

In [None]:
# define random seed for initialization of model and simulation
key = random.PRNGKey(0)
model_init_key, simuation_init_key = random.split(key, 2)

# prior parameters
sigma = 0.14  # nm
eps = 200  # kJ / mol
r_cut_SW = sigma * 1.8  # a = 1.8 in original SW potential
r_cut_NN = 0.2

# we use the same neighbor list for prior and GNN to avoid keeping track
# of 2 separate lists
r_cut_nbrs = max(r_cut_SW, r_cut_NN)
neighbor_fn = partition.neighbor_list(displacement, box[0], r_cut_nbrs,
                                      dr_threshold=0.05,
                                      capacity_multiplier=1.5,
                                      fractional_coordinates=True)
nbrs_init = neighbor_fn(R_init)

init_fn, GNN_energy = custom_energy.DimeNetPP_neighborlist(displacement,
                                                           R_init,
                                                           nbrs_init,
                                                           r_cut_NN)
init_params = init_fn(model_init_key, R_init, neighbor=nbrs_init)

prior_fn = custom_energy.\
    stillinger_weber_neighborlist(displacement,
                                  cutoff=r_cut_SW,
                                  sigma=sigma,
                                  epsilon=eps,
                                  initialize_neighbor_list=False)

def energy_fn_template(energy_params):
    gnn_energy = partial(GNN_energy, energy_params)
    def energy(R, neighbor, **dynamic_kwargs):
        return gnn_energy(R, neighbor=neighbor, **dynamic_kwargs) + \
               prior_fn(R, neighbor=neighbor, **dynamic_kwargs)
    return energy

### Loss function

After defining the `energy_fn_template`, we can define the loss function.
We compute the stiffness tensor via the 'stress-fluctuation method',
where multiple instantaneous quantities need to be combined:
$C_{ijkl} = \langle C^B_{ijkl} \rangle - \frac{V}{k_B T}\left(\langle
\sigma^B_{ij} \sigma^B_{kl} \rangle - \langle \sigma^B_{ij} \rangle \langle
\sigma^B_{kl} \rangle\right) + \frac{N k_B T}{V}\left( \delta_{ik} \delta_{jl} +
\delta_{il} \delta_{jk} \right)$
$\mathrm{with} \quad \sigma^B_{ij} = \frac{1}{V}\frac{\partial U}{\partial
\epsilon_{ij}} \ ; \quad  C^B_{ijkl} = \frac{1}{V}\frac{\partial^2 U}{\partial
 \epsilon_{ij} \partial \epsilon_{kl}}$

We therefore need to compute instantaneous values of the stress $\sigma_{ij}$,
the Born stress $\sigma^B_{ij}$ and the Born stiffness $C^B_{ijkl}$, and combine
them into the observables $\langle \sigma_{ij} \rangle$ and $C_{ijkl}$. Because
$C_{ijkl}$ contains a product of averages, we cannot use the default
loss function in `DiffTRe_init`, which is only applicable to
observables that are directly computed as averages of instantaneous quantities.
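
The following sketch evaluates the formula above for the single component $C_{1111}$, where the kinetic term reduces to $2 N k_B T / V$. It is a simplified scalar illustration with made-up snapshot values and a hypothetical function name, not the tensor implementation used by chemtrain.

```python
from statistics import fmean

def stiffness_1111(born_stiffness, born_stress, volume, kbT, n_atoms):
    """Combine snapshots of C^B_1111 and sigma^B_11 into C_1111."""
    born_term = fmean(born_stiffness)
    mean_stress = fmean(born_stress)
    mean_stress_sq = fmean([s * s for s in born_stress])
    # variance of the Born stress: not an average of an instantaneous quantity
    fluctuation_term = volume / kbT * (mean_stress_sq - mean_stress ** 2)
    # delta_ik delta_jl + delta_il delta_jk equals 2 for i = j = k = l = 1
    kinetic_term = 2 * n_atoms * kbT / volume
    return born_term - fluctuation_term + kinetic_term

# made-up snapshots in arbitrary but consistent units
born_stiffness_snapshots = [10., 12., 10., 12.]
born_stress_snapshots = [-1., 1., -1., 1.]
c_1111 = stiffness_1111(born_stiffness_snapshots, born_stress_snapshots,
                        volume=2., kbT=0.5, n_atoms=4)
print('C_1111 =', c_1111)
```

With these inputs the Born term is 11, the fluctuation term is 4 and the kinetic term is 2. The fluctuation term needs the full trajectory before it can be evaluated, which is the reason for the custom loss function.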

We can compute instantaneous quantities from a trajectory via a `quantity_dict`.
It requires for each instantaneous observable $O_k(S_i, U_{\theta})$ a function
to compute the respective observable for each state $S_i$.
In the end, we combine loss contributions from stress and stiffness tensors by
multiplication of each component with a scale $\gamma$, which is useful to
counteract the effect of different units / orders of magnitude of observables.

When constructing a custom loss function, we need to adhere to the function
signature required by the implementation of DiffTRe:
`loss_fn` takes as input the trajectories of computed instantaneous quantities,
saved in a dict under the respective keys of the `quantity_dict`.
Additionally, it receives corresponding weights $w_i$ to compute averages
under the reweighting scheme. With these components, target observables can
be computed. The output of the function is a tuple `(loss value, predicted
ensemble averages)`. The latter is only necessary for post-processing of the
optimization process. `loss_fn` below provides an example implementation.
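
A minimal, library-free sketch of this signature is given below. The key `'observable'`, the target value and both function names are hypothetical. The weights are computed with the usual thermodynamic perturbation expression $w_i \propto e^{-(U_{\theta}(S_i) - U_{\hat{\theta}}(S_i)) / k_B T}$, which is assumed here since the text above does not spell it out.

```python
import math

def reweighting_weights(energies_new, energies_ref, kbT):
    """Normalized weights of reference states under a perturbed potential."""
    exponents = [-(u_new - u_ref) / kbT
                 for u_new, u_ref in zip(energies_new, energies_ref)]
    shift = max(exponents)  # subtract the maximum for numerical stability
    unnormalized = [math.exp(e - shift) for e in exponents]
    total = sum(unnormalized)
    return [w / total for w in unnormalized]

def toy_loss_fn(quantity_trajs, weights, target=1.0):
    """Same signature as required above: returns (loss, predictions)."""
    prediction = sum(w * q for w, q in zip(weights, quantity_trajs['observable']))
    loss = (prediction - target) ** 2
    return loss, {'observable': prediction}

quantity_trajs = {'observable': [0.5, 1.0, 1.5, 2.0]}
energies_ref = [1.0, 2.0, 3.0, 4.0]

# unchanged potential: uniform weights, plain trajectory average
weights = reweighting_weights(energies_ref, energies_ref, kbT=2.479)
loss, predictions = toy_loss_fn(quantity_trajs, weights)

# perturbed potential that lowers the energy of the last state
energies_new = [1.0, 2.0, 3.0, 3.0]
weights_new = reweighting_weights(energies_new, energies_ref, kbT=2.479)
loss_new, predictions_new = toy_loss_fn(quantity_trajs, weights_new)
print(predictions, predictions_new)
```

Lowering the energy of a state increases its weight, so the reweighted average shifts towards that state's observable value without a new simulation.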

In [None]:
# initialize observable functions
stress_fn = custom_quantity.init_virial_stress_tensor(energy_fn_template,
                                                      box_tensor)

born_stiffness_fn, sigma_born, sigma_tensor_prod, stiffness_tensor_fn = \
    custom_quantity.init_stiffness_tensor_stress_fluctuation(energy_fn_template,
                                                             box_tensor, kbT, N)

# add all functions that compute instantaneous values into quantity_dict
# DiffTRe computes corresponding quantities for each state in the (sparse) trajectory
# checkpoint is needed to avoid out-of-memory errors for expensive models, e.g. GNNs
quantity_dict = {}
quantity_dict['stress'] = {'compute_fn': checkpoint(stress_fn)}
quantity_dict['born_stiffness'] = {'compute_fn': checkpoint(born_stiffness_fn)}
quantity_dict['born_stress'] = {'compute_fn': checkpoint(sigma_born)}

gamma_stress = 5.e-8
gamma_stiffness = 1.e-10

def loss_fn(quantity_trajs, weights):
    
    def reweighting_average(quantity_snapshots):
        weighted_snapshots = (quantity_snapshots.T * weights).T
        return jnp.sum(weighted_snapshots, axis=0)
    
    # compute all contributions to stiffness tensor
    born_stress_prod_snapshots = sigma_tensor_prod(quantity_trajs['born_stress'])
    stress_product_born_mean = reweighting_average(born_stress_prod_snapshots)  # <sigma^B_ij sigma^B_kl>
    born_stress_tensor_mean = reweighting_average(quantity_trajs['born_stress'])
    born_stiffness_mean = reweighting_average(quantity_trajs['born_stiffness'])
    stiffness_tensor = stiffness_tensor_fn(born_stiffness_mean,
                                           born_stress_tensor_mean,
                                           stress_product_born_mean)
    
    stiffness_pred = custom_quantity.\
        stiffness_tensor_components_cubic_crystal(stiffness_tensor)
    stress_tensor_mean = reweighting_average(quantity_trajs['stress'])
    
    loss = gamma_stress * difftre.mse_loss(stress_tensor_mean, target_dict['stress']) + \
           gamma_stiffness * difftre.mse_loss(stiffness_pred, target_dict['stiffness'])

    predictions = {'stress': stress_tensor_mean, 'stiffness': stiffness_pred}
    
    return loss, predictions


### Simulator

Let's define a `simulator_template` taking the current potential energy
function from the `energy_fn_template` as input and returning an NVT
simulator for the current energy function. This is achieved by simply
"baking in" all constant simulation parameters via `partial`. For this
example, we select a Langevin thermostat. With a given simulator,
we can generate the initial simulation state from which the first trajectory
is sampled.
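
The "baking in" idea can be illustrated with a toy stand-in for a simulator factory. `toy_simulator` is hypothetical: it only mimics the pattern of returning an `(init_fn, apply_fn)` pair and performs a deterministic overdamped step, with the stochastic part of a Langevin thermostat omitted.

```python
from functools import partial

def toy_simulator(energy_fn, shift, dt, kT, gamma):
    """Hypothetical factory returning (init_fn, apply_fn) for a 1D particle."""
    def init_fn(position):
        return {'position': position, 'kT': kT}

    def apply_fn(state):
        x = state['position']
        h = 1e-6
        # central finite difference instead of automatic differentiation
        force = -(energy_fn(x + h) - energy_fn(x - h)) / (2 * h)
        return {'position': shift(x, dt / gamma * force), 'kT': kT}

    return init_fn, apply_fn

# all constant settings are fixed once; only the energy function stays open
toy_simulator_template = partial(toy_simulator,
                                 shift=lambda x, dx: x + dx,
                                 dt=0.01, kT=2.479, gamma=4.)

init_fn, apply_fn = toy_simulator_template(lambda x: 0.5 * (x - 1.0) ** 2)
state = init_fn(0.0)
for _ in range(5):
    state = apply_fn(state)
print(state['position'])
```

Calling the template again with a different energy function yields a new simulator with identical settings, which is the property exploited whenever the potential parameters change.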

In [None]:
energy_fn_init = energy_fn_template(init_params)
simulator_template = partial(simulate.nvt_langevin, shift=shift, dt=dt, kT=kbT,
                             gamma=4.)
init, _ = simulator_template(energy_fn_init)
state = init(simuation_init_key, R_init, mass=mass, neighbor=nbrs_init)
# store neighbor list together with current simulation state
init_sim_state = (state, nbrs_init)

### Optimizer

The last missing component for DiffTRe is the optimizer. Here, we select an
Adam optimizer with exponential learning rate decay.
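
To get a feeling for the schedule, the sketch below re-implements the continuous (non-staircase) decay $\eta_t = \eta_0 \, r^{t / T}$ in plain Python with the values used in the next cell. It is a stand-alone approximation of what `optax.exponential_decay` computes, written only for illustration. The negative sign turns the update into a descent step, since optax adds the updates to the parameters.

```python
def exponential_decay(init_value, transition_steps, decay_rate):
    """Plain-Python stand-in for a continuous exponential decay schedule."""
    def schedule(step):
        return init_value * decay_rate ** (step / transition_steps)
    return schedule

toy_lr_schedule = exponential_decay(-0.002, 500, 0.01)
for step in (0, 250, 500, 699):
    print(step, toy_lr_schedule(step))
```

The step size drops by a factor of 100 over the first 500 updates and keeps decaying during the remaining ones.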

In [None]:
num_updates = 700
initial_step_size = 0.002
lr_schedule = optax.exponential_decay(-initial_step_size, 500, 0.01)
optimizer = optax.chain(
    optax.scale_by_adam(0.1, 0.4),
    optax.scale_by_schedule(lr_schedule)
)

### Initialize DiffTRe

With all individual components defined, we can initialize the DiffTRe trainer.

In [None]:
trainer = difftre.Trainer(init_params, quantity_dict, simulator_template,
                          energy_fn_template, neighbor_fn, init_sim_state,
                          timing_struct, optimizer, kbT, loss_fn=loss_fn)

### Learning the GNN potential

This optimization takes some time. You can skip it and load our trained model
in the next cell.

In [None]:
trainer.train(num_updates)

trainer.save_trainer(optimization_pickle_file_path)

In [None]:
# load results, if necessary
trainer = difftre.Trainer.load_trainer(optimization_pickle_file_path)

## Visualize training process

Let's start post-processing by plotting the loss and the compute time per update.
In this application, most of the update steps do not require a re-computation
of the trajectory, which reduces computation time significantly.
However, computing the stiffness tensor is itself rather expensive, so the
speed-up from re-using reference trajectories is smaller than for
cheap observables, e.g. RDFs.

In [None]:
fig, ax1 = plt.subplots()
ax1.set_xlabel('Update Step')
ax1.set_ylabel('MSE Loss')
ax1.semilogy(trainer.losses, color='#3c5488ff', label='Loss')
ax2 = ax1.twinx()
ax2.plot(trainer.update_times, label='Time per update', color='#4dbbd5ff')
ax2.set_ylabel('$t$ in s')
fig.legend(loc="upper right", bbox_to_anchor=(1,1),
           bbox_transform=ax1.transAxes)
plt.savefig('Train_history_diamond.png')
plt.show()


Next, we visualize convergence of our observables, the stress and stiffness
tensor values.

In [None]:
predicted_quantities = trainer.predictions

fig, ax = plt.subplots()
ax.set_xlabel('Update Step')
ax.set_ylabel('Hydrostatic stress in GPa')

if 'stress' in predicted_quantities[0]:
    pressure_series = [jnp.trace(prediction_dict['stress']) /
                       convert_from_GPa_to_kJ_mol_nm_3 / 3. for
                       prediction_dict in predicted_quantities]
    ax.plot(pressure_series, label=r'$\sigma_h$', color='#3c5488ff')
    ax.axhline(y=0., linestyle='--', color='k')
    print('Predicted stress tensor:', predicted_quantities[-1]['stress'] /
          convert_from_GPa_to_kJ_mol_nm_3)

if 'stiffness' in predicted_quantities[0]:
    stiffness_list = [prediction_dict['stiffness'] for prediction_dict in
                      predicted_quantities]
    stiffness_array = jnp.stack(stiffness_list) / convert_from_GPa_to_kJ_mol_nm_3  # back to GPa
    ax1 = ax.twinx()
    ax1.set_ylabel('$C_{ij}$ in GPa')
    ax1.set_ylim([0, 1200])
    ax1.plot(stiffness_array[:, 0], label='$c_{11}$', color='#00a087ff')
    ax1.plot(stiffness_array[:, 1], label='$c_{12}$', color='#91d1c2ff')
    ax1.axhline(y=c_12_target, linestyle='--', color='k')
    ax1.plot(stiffness_array[:, 2], label='$c_{44}$', color='#8491b4ff')
    ax1.axhline(y=c_44_target, linestyle='--', color='k')
    ax1.axhline(y=c_11_target, linestyle='--', label='targets', color='k')


fig.legend(loc="lower right", bbox_to_anchor=(1.4, 0.),
           bbox_transform=ax1.transAxes)
plt.savefig('Stress_stiffness_history.png')
plt.show()

## Validate simulation results

Given that we've trained the model on rather short trajectories, let's validate
the model on a longer trajectory. This allows us to detect whether the model has
overfitted to the initial conditions or drifts away from the targets.
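
The difference in scale between the training and validation runs can be estimated with simple arithmetic. The sketch assumes that one state is retained every `print_every` picoseconds after equilibration; how `custom_simulator.process_printouts` treats the end points is not checked here, so the counts may be off by one. The helper name is hypothetical.

```python
def count_steps_and_states(dt, total_time, t_equilib, print_every):
    """Rough count of integration steps and retained states (times in ps)."""
    n_steps = round(total_time / dt)
    n_states = round((total_time - t_equilib) / print_every)
    return n_steps, n_states

training = count_steps_and_states(0.5e-3, 70., 10., 0.025)
validation = count_steps_and_states(0.5e-3, 11000., 1000., 0.025)
print('Training run:  ', training)
print('Validation run:', validation)
print('Ratio of retained states:', round(validation[1] / training[1], 1))
```

The validation run therefore averages over roughly two orders of magnitude more states than a single training trajectory.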

In [None]:
total_time_long = 11000.  # 11 ns
t_equilib_long = 1000.  # 1 ns


long_trajectory_struct = custom_simulator.process_printouts(dt, total_time_long,
                                                            t_equilib_long, print_every)
trajectory_generator = custom_simulator.trajectory_generator_init(simulator_template,
                                                                  energy_fn_template,
                                                                  neighbor_fn,
                                                                  long_trajectory_struct)
long_traj_state = trajectory_generator(trainer.params, init_sim_state)
quantity_traj = difftre.quantity_traj(long_traj_state, quantity_dict,
                                      neighbor_fn, trainer.params)

stress_tensor = jnp.mean(quantity_traj['stress'], axis=0)
print('Predicted stress tensor:', stress_tensor /
      convert_from_GPa_to_kJ_mol_nm_3)

born_stress_prod_snapshots = sigma_tensor_prod(quantity_traj['born_stress'])
stress_product_born_mean = jnp.mean(born_stress_prod_snapshots, axis=0)
born_stress_tensor_mean = jnp.mean(quantity_traj['born_stress'], axis=0)
born_stiffness_mean = jnp.mean(quantity_traj['born_stiffness'], axis=0)
stiffness_tensor = stiffness_tensor_fn(born_stiffness_mean,
                                       born_stress_tensor_mean,
                                       stress_product_born_mean)

stiffness_components = \
    custom_quantity.stiffness_tensor_components_cubic_crystal(stiffness_tensor)
print('Predicted c_11:', stiffness_components[0] /
      convert_from_GPa_to_kJ_mol_nm_3, 'GPa; Target:', c_11_target, 'GPa')
print('Predicted c_12:', stiffness_components[1] /
      convert_from_GPa_to_kJ_mol_nm_3, 'GPa; Target:', c_12_target, 'GPa')
print('Predicted c_44:', stiffness_components[2] /
      convert_from_GPa_to_kJ_mol_nm_3, 'GPa; Target:', c_44_target, 'GPa')


with open(long_traj_pickle_file_path, 'wb') as f:
    pickle.dump([stress_tensor, stiffness_components], f)

In [None]:
# you can skip the long run above and just load results
with open(long_traj_pickle_file_path, 'rb') as f:
    stress_tensor, stiffness_components = pickle.load(f)

In [None]:
final_epoch = num_updates - 1
x_position = [final_epoch] * 3  # one marker per stiffness component
stiffnesses_GPa = stiffness_components / convert_from_GPa_to_kJ_mol_nm_3

ax.scatter(final_epoch, jnp.trace(stress_tensor) / convert_from_GPa_to_kJ_mol_nm_3 / 3.,
           marker='x', color='k', label='10 ns')
ax1.scatter(x_position, stiffnesses_GPa, marker='x', color='k')

fig.legend(loc="lower right", bbox_to_anchor=(1.4, 0.),
           bbox_transform=ax1.transAxes)
fig.savefig('Stress_stiffness_history.png')
fig