# Basic usage for ReaxNet
An example of how to use ReaxNet to predict the energy, forces, and atomic charges of a given structure.

## Import libraries
Make sure the required libraries are installed: PyYAML, JAX, NumPy, ASE, JAX MD, and ReaxNet.

In [1]:
import os
import pickle
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import yaml
from ase.io import read
from jax_md import space, partition

from reaxnet.egnn.nequip import NequIPEnergyModel
from reaxnet.egnn.data import AtomicNumberTable
from reaxnet.egnn.nn_util import neighbor_list_featurizer
from reaxnet.jax_nb.parameters import pqeq_parameters
from reaxnet.jax_nb.jax_nb import pqeq_fori_loop, nonbond_potential, LAMBDA

## Read the atomic structure
We enable 64-bit (float64) precision in JAX for all calculations. This flag must be set before any JAX arrays are created.

In [2]:
# Switch JAX to double precision. JAX defaults to float32, and the flag only
# affects arrays created after it has been set.
# NOTE(review): the reaxnet modules are imported in the previous cell, before
# this flag is set; if any of them builds JAX arrays at import time, those
# arrays will be float32. Consider moving this line right after `import jax`.
jax.config.update("jax_enable_x64", True)

# Structure to evaluate (path relative to the current working directory).
STRUCTURE_FILE = 'Li2PO2N.cif'
atoms = read(STRUCTURE_FILE)

## Load the pretrained model
The pretrained model was trained on the [MPtrj](https://figshare.com/articles/dataset/Materials_Project_Trjectory_MPtrj_Dataset/23713842) dataset and can be downloaded [here](https://figshare.com/s/182754086804163dd29a). Place the downloaded `model_config.yaml`, `params.pickle`, and `mapping.yaml` in the directory given by `MODELPATH` below.

In [3]:
# Directory holding the pretrained model files (see the download link above).
MODELPATH = '../pretrained/'

# os.path.join works whether or not MODELPATH ends with a path separator.
with open(os.path.join(MODELPATH, 'model_config.yaml'), 'r') as f:
    model_dict = yaml.safe_load(f)

# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# Only load parameter files obtained from a source you trust.
with open(os.path.join(MODELPATH, 'params.pickle'), 'rb') as f:
    params = pickle.load(f)

ztable = AtomicNumberTable.from_dict(os.path.join(MODELPATH, 'mapping.yaml'))
model = NequIPEnergyModel(**model_dict)

## Prepare to predict
### 1. Convert numpy array to jax.numpy array

In [4]:
# Fractional (scaled) coordinates, converted to a JAX array like the other
# inputs. The displacement and neighbor-list functions below are built with
# fractional_coordinates=True.
positions = jnp.asarray(atoms.get_scaled_positions())
# ASE stores the lattice vectors as rows of the cell matrix; transpose so that
# they become columns.
box = jnp.asarray(atoms.get_cell().array.transpose())
atomic_numbers = jnp.asarray(atoms.numbers)
chemical_symbols = atoms.get_chemical_symbols()
# Map atomic numbers to the model's species indices (via the model's
# AtomicNumberTable), then one-hot encode them as model input.
# NOTE(review): the encoding has len(ztable) + 1 classes -- presumably one extra
# class reserved for padding; confirm against the featurizer.
species_indices = ztable.mapping(atomic_numbers)
nn_atomic_numbers = jax.nn.one_hot(jnp.array(species_indices), len(ztable) + 1)

### 2. Define the displacement function and the neighbor list functions

In [5]:
# Cutoff radius for the long-range (nonbonded / PQEq) neighbor list, in the
# same length units as the cell (presumably Angstrom, as read by ASE).
NB_CUTOFF = 12.5

# Real-space displacements computed from fractional input coordinates.
displacement_fn, _ = space.periodic_general(box, fractional_coordinates=True)


def _make_neighbor_fn(cutoff):
    """Build a sparse, fractional-coordinate jax_md neighbor-list function.

    Args:
        cutoff: neighbor cutoff radius passed to ``partition.neighbor_list``.

    Returns:
        The jax_md neighbor-list function object (with ``allocate``/``update``).
    """
    return partition.neighbor_list(
        displacement_fn,
        box,
        cutoff,
        format=partition.Sparse,
        fractional_coordinates=True,
    )


# Short-range list for the NequIP model, using the model's own cutoff radius.
nn_neighbor_fn = _make_neighbor_fn(model_dict['r_max'])
# Long-range list for the polarizable nonbonded interactions.
nb_neighbor_fn = _make_neighbor_fn(NB_CUTOFF)
featurizer = neighbor_list_featurizer(displacement_fn)

### 3. Define the machine learning potential energy function

In [6]:
def energy_nn(embedded_numbers, model, params, position, neighbor, **kwargs):
    """Machine-learned (NequIP) contribution to the potential energy.

    Args:
        embedded_numbers: one-hot species encoding of the atoms, passed to the
            featurizer.
        model: the NequIP model; its ``apply`` method is called.
        params: model parameters handed to ``model.apply``.
        position: atomic positions (fractional coordinates in this notebook).
        neighbor: jax_md neighbor list for the model cutoff.
        **kwargs: forwarded unchanged to the featurizer.

    Returns:
        Sum of the per-node model outputs, excluding the last row.
    """
    g = featurizer(embedded_numbers, position, neighbor, **kwargs)
    per_node = model.apply(params, g.edges, g.nodes, g.senders, g.receivers)
    # NOTE(review): the last row is dropped -- presumably a padding node that the
    # featurizer appends for jax_md's sparse-neighbor sentinel index; confirm.
    total = jnp.sum(per_node[:-1])
    return total


# Bind species, model, and parameters so the function takes (position, neighbor).
energy_fn_nn = partial(energy_nn, nn_atomic_numbers, model, params)

### 4. Define the polarizable long-range interaction functions (charge equilibration and nonbonded energy)

In [7]:
def _pqeq_per_atom(key):
    """Return PQEq parameter ``key`` for every atom, in structure order."""
    return jnp.array([pqeq_parameters[symbol][key] for symbol in chemical_symbols])


rad = _pqeq_per_atom('rad')
# Per-atom width parameter derived from the PQEq radius.
alpha = 0.5 * LAMBDA / rad / rad
# Pairwise combination over all atom pairs: sqrt(a_i * a_j / (a_i + a_j)).
alpha_col = alpha.reshape(-1, 1)
alpha_row = alpha.reshape(1, -1)
alpha = jnp.sqrt(alpha_col * alpha_row / (alpha_col + alpha_row))
chi0 = _pqeq_per_atom('chi0')
eta0 = _pqeq_per_atom('eta0')
z = _pqeq_per_atom('Z')
Ks = _pqeq_per_atom('Ks')

# The 12.5 cutoff below matches the cutoff of the long-range neighbor list
# (nb_neighbor_fn); keep them in sync.
charges_fn = partial(
    pqeq_fori_loop,
    displacement_fn,
    alpha=alpha,
    cutoff=12.5,
    iterations=2,
    net_charge=0,
    eta0=eta0,
    chi0=chi0,
    z=z,
    Ks=Ks,
)
energy_fn_nb = partial(
    nonbond_potential,
    displacement_fn,
    alpha=alpha,
    cutoff=12.5,
    eta0=eta0,
    chi0=chi0,
    z=z,
    Ks=Ks,
    compute_d3=False,
    atomic_numbers=atomic_numbers,
    d3_params={'s6': 1.0, 'rs6': 1.217, 's18': 0.722, 'rs18': 1.0, 'alp': 14.0},
    damping='zero',
    smooth_fn=None,
)

### 5. Define the total energy function

In [8]:
def energy_fn(positions, nn_nbr, nb_nbr):
    """
    Total potential energy: NequIP term plus polarizable long-range term.

    Args:
    ----
    positions: jnp.array
        Atomic positions in fractional (scaled) coordinates, matching the
        ``fractional_coordinates=True`` displacement and neighbor-list functions.
    nn_nbr: jax_md neighbor list
        Neighbor list for the NequIP model cutoff; updated to ``positions`` here.
    nb_nbr: jax_md neighbor list
        Neighbor list for the long-range cutoff; updated to ``positions`` here.

    Returns:
    -------
    (energy, (charges, r_shell))
        Total energy (NequIP + nonbonded) and, as auxiliary data, the charges
        and ``r_shell`` values returned by the charge solver.
    """
    # Refresh both neighbor lists for the current positions.
    # NOTE(review): did_buffer_overflow is not checked and the updated lists
    # are not returned, so a capacity overflow would go unnoticed.
    nn_nbr = nn_nbr.update(positions)
    nb_nbr = nb_nbr.update(positions)

    pe_nn = energy_fn_nn(positions, nn_nbr)

    # The charge solver sees stop-gradient positions, so autodiff does not trace
    # through its iterations; presumably this relies on the energy being
    # stationary w.r.t. the equilibrated charges -- confirm.
    charges, r_shell = charges_fn(jax.lax.stop_gradient(positions), nb_nbr)
    pe_nb = energy_fn_nb(positions, nb_nbr, r_shell, charges)

    return pe_nn + pe_nb, (charges, r_shell)

### 6. Allocate the neighbor list

In [9]:
# Allocate each neighbor list once from the initial positions. The allocated
# capacity is reused by every later update inside energy_fn.
nn_nbr, nb_nbr = (
    nn_neighbor_fn.allocate(positions),
    nb_neighbor_fn.allocate(positions),
)

### 7. Define the value and gradient functions
The forces are the negative gradient of the energy with respect to the atomic positions. The charges are returned as auxiliary output (`has_aux=True`) of the polarizable long-range interaction calculation and are not differentiated.

In [10]:
# Bind the allocated neighbor lists so the function depends only on positions.
energy_of_positions = partial(energy_fn, nn_nbr=nn_nbr, nb_nbr=nb_nbr)
# has_aux=True: energy_fn returns (energy, (charges, r_shell)); only the energy
# is differentiated, and the aux tuple is passed through.
value_and_grad_fn = jax.jit(
    jax.value_and_grad(energy_of_positions, argnums=0, has_aux=True)
)

## Get results from the model
- Energy: results[0][0], in eV
- Forces: -results[1] (results[1] is the energy gradient), in eV/Angstrom
- Charges: results[0][1][0], in e


In [11]:
# One jitted call evaluates the energy, the aux outputs and the gradient:
# results == ((energy, (charges, r_shell)), gradient w.r.t. positions)
results = value_and_grad_fn(positions)

In [12]:
# Unpack ((energy, aux), gradient); the total energy is the first element.
(energy, _energy_aux), _energy_grad = results
print('Energy:', energy, 'eV')

Energy: -14602.55344216379 eV


In [13]:
# results[1] is the gradient of the energy w.r.t. positions; forces are its negative.
energy_gradient = results[1]
forces = -energy_gradient
print('Forces: \n', forces)

Forces: 
 [[ 0.05382896 -0.04230339  0.0018265 ]
 [ 0.05382896  0.04230339  0.0018265 ]
 [-0.05382896  0.04230339  0.0018265 ]
 ...
 [ 0.08315825 -0.09146194 -0.05896824]
 [-0.08315825 -0.09146194 -0.05896824]
 [-0.08315825  0.09146194 -0.05896824]]


In [14]:
# The auxiliary output of energy_fn is (charges, r_shell).
charges, _r_shell = results[0][1]
print('Charges: \n', charges)

Charges: 
 [ 0.37500883  0.37500883  0.37500883 ... -0.2685427  -0.2685427
 -0.2685427 ]
