# Non-bonded interactions: usage example
Example calculation of the non-bonded potential energy, forces, and partial charges of the LiPON (Li2PO2N) electrolyte.

In [1]:
import time
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from ase.io import read
from jax_md import partition, space

from reaxnet.jax_nb.jax_nb import pqeq_fori_loop, nonbond_potential, LAMBDA
from reaxnet.jax_nb.parameters import pqeq_parameters

## Read data
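The NumPy arrays returned by ASE are converted to `jax.numpy` arrays. Positions are read as fractional (scaled) coordinates, and the cell matrix is transposed so that the lattice vectors become columns. See the [JAX NumPy documentation](https://jax.readthedocs.io/en/latest/jax.numpy.html) for how `jax.numpy` arrays differ from NumPy arrays.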

In [2]:
# Load the crystal structure from the CIF file.
atoms = read('Li2PO2N.cif')
symbols = atoms.get_chemical_symbols()

# Positions are kept as fractional (scaled) coordinates, matching the
# `fractional_coordinates=True` setting used for the space and neighbor list.
positions = jnp.array(atoms.get_scaled_positions())

# ASE stores lattice vectors as the rows of `atoms.cell.array`; transposing puts
# them in columns, which is the matrix passed to jax_md as `box`.
cell = jnp.array(atoms.cell.array.T)

## Define the displacement function

In [3]:
# `periodic_general` returns a (displacement_fn, shift_fn) pair; only the
# displacement function is needed in this notebook.
displacement_fn = space.periodic_general(box=cell, fractional_coordinates=True)[0]

## Define the neighbor list function

In [4]:
# Sparse-format neighbor list over fractional coordinates. The 12.5 cutoff is the
# same value passed as `cutoff` to the charge solver and the energy function.
# capacity_multiplier controls buffer headroom relative to the initial
# allocation (see jax_md.partition.neighbor_list docs).
neighbor_fn = partition.neighbor_list(
    displacement_fn,
    box=cell,
    r_cutoff=12.5,
    format=partition.Sparse,
    fractional_coordinates=True,
    capacity_multiplier=2.0,
)


## Load PQEq parameters

In [5]:
def _per_atom(key):
    """Gather PQEq parameter `key` for every atom, in the order of `symbols`."""
    return jnp.array([pqeq_parameters[s][key] for s in symbols])


rad = _per_atom('rad')

# Per-atom exponent derived from the radius (presumably a Gaussian width
# parameter -- confirm in reaxnet), then combined pairwise via broadcasting as
# sqrt(a_i * a_j / (a_i + a_j)).
alpha_atom = 0.5 * LAMBDA / rad / rad
alpha_i = alpha_atom.reshape(-1, 1)
alpha_j = alpha_atom.reshape(1, -1)
alpha = jnp.sqrt(alpha_i * alpha_j / (alpha_i + alpha_j))

chi0 = _per_atom('chi0')
eta0 = _per_atom('eta0')
z = _per_atom('Z')
Ks = _per_atom('Ks')

## Define the partial-charge solver

In [6]:
# Bind everything except positions and the neighbor list, so the solver is
# called as charges_fn(positions, nbr) and yields (charges, r_shell).
# NOTE(review): iterations=2 is forwarded unchanged to pqeq_fori_loop --
# confirm it is enough for the charges to converge for this system.
charges_fn = partial(
    pqeq_fori_loop,
    displacement_fn,
    alpha=alpha,
    cutoff=12.5,
    iterations=2,
    net_charge=0.0,
    eta0=eta0,
    chi0=chi0,
    z=z,
    Ks=Ks,
)


## Define energy function

In [7]:
# PBE zero-damping D3 parameters. compute_d3=False below, so they are passed
# through but presumably unused in this run -- confirm in nonbond_potential.
D3_PBE_ZERO = {'s6': 1.0, 'rs6': 1.217, 's18': 0.722, 'rs18': 1.0, 'alp': 14.0}

energy_fn_nb = partial(
    nonbond_potential,
    displacement_fn,
    # PQEq keyword arguments
    alpha=alpha,
    cutoff=12.5,
    eta0=eta0,
    chi0=chi0,
    z=z,
    Ks=Ks,
    # D3 dispersion keyword arguments
    atomic_numbers=jnp.array(atoms.numbers),
    compute_d3=False,
    d3_params=D3_PBE_ZERO,
    damping='zero',
    smooth_fn=None,
)


In [8]:
def energy_fn(positions, nbr, **displ_kwargs):
    """Non-bonded potential energy of `positions`, with solved charges as aux output.

    The neighbor list is refreshed for the given positions, partial charges and
    shell positions are solved with `charges_fn`, and the energy is evaluated
    with `energy_fn_nb`.

    Args:
        positions: Atomic positions (fractional coordinates in this notebook).
        nbr: Neighbor list from `neighbor_fn.allocate`; it is updated
            internally and the updated copy is not returned.
        **displ_kwargs: Forwarded to `nbr.update` and `energy_fn_nb`
            (e.g. `box=cell`); not forwarded to `charges_fn`.

    Returns:
        `(energy, (charges, r_shell))`. The aux pair makes this usable with
        `jax.value_and_grad(..., has_aux=True)`.
    """
    updated_nbr = nbr.update(positions, **displ_kwargs)
    # The charge solve sees positions through stop_gradient, so differentiating
    # this function does not propagate through the charge equilibration.
    frozen_positions = jax.lax.stop_gradient(positions)
    charges, r_shell = charges_fn(frozen_positions, updated_nbr)
    energy = energy_fn_nb(positions, updated_nbr, r_shell, charges, **displ_kwargs)
    # NOTE(review): updated_nbr.did_buffer_overflow is never inspected, so an
    # overflowing neighbor buffer would go unnoticed by callers.
    return energy, (charges, r_shell)


## JIT-compile for speed

In [9]:
# Allocate the neighbor list once (outside jit) and bind it into energy_fn.
nbr = neighbor_fn.allocate(positions)
energy_with_nbr = partial(energy_fn, nbr=nbr)

# Differentiate w.r.t. the first argument (positions). has_aux=True because
# energy_fn returns (energy, (charges, r_shell)).
value_and_grad_fn = jax.jit(
    jax.value_and_grad(energy_with_nbr, argnums=0, has_aux=True)
)


## Get results

In [10]:
# The first call to a jitted function triggers compilation, so this timing
# includes JIT compile time, not just the energy/gradient evaluation.
# JAX dispatches work asynchronously: block until the results are materialized
# so the timer measures the computation rather than only its dispatch.
t_start = time.perf_counter()
results = jax.block_until_ready(value_and_grad_fn(positions, box=cell))
t = time.perf_counter() - t_start
print(f"Time to compute energy and gradient (incl. JIT compilation): {t:.2f} s")

Time to compute energy and gradient: 2.71 s


In [11]:
# value_and_grad(..., has_aux=True) returns ((energy, (charges, r_shell)), grad).
(pe_raw, (charges_raw, _r_shell)), grad_raw = results
pe = np.asarray(pe_raw)
# Forces are the negative gradient of the energy w.r.t. positions.
# NOTE(review): positions are fractional here; this relies on jax_md's
# fractional-coordinate handling yielding real-space forces -- confirm.
forces = np.asarray(-grad_raw)
charges = np.asarray(charges_raw)

In [12]:
# Total non-bonded potential energy.
print(str(pe))

-592.375


In [13]:
# One row per atom, three force components each.
for force_vec in forces:
    print(force_vec)

[ 0.03590737 -0.04427568  0.00150718]
[0.03591576 0.04428577 0.00147697]
[-0.03597849  0.04434611  0.00154626]
[-0.03582811 -0.04434761  0.00143227]
[ 0.0358956  -0.04431414  0.00149903]
[0.03589303 0.04434717 0.00153489]
[-0.03597565  0.04429619  0.00149233]
[-0.03587262 -0.04432746  0.0015046 ]
[-2.4467870e-03  5.5903336e-05  4.2212196e-06]
[ 2.4619722e-03  1.9973610e-05 -5.9171580e-06]
[-2.5286418e-03 -3.0279974e-05 -7.0404261e-05]
[ 2.5465142e-03 -2.0967907e-05 -1.6514510e-05]
[-5.7206739e-02  4.3120235e-06  1.2250954e-02]
[ 5.7045266e-02 -1.4290214e-05  1.2232358e-02]
[-5.70769459e-02  6.71017915e-07  1.21901315e-02]
[5.7132743e-02 5.5289827e-05 1.2276527e-02]
[ 0.05011837  0.06925315 -0.00758929]
[ 0.0501386  -0.06907848 -0.00753374]
[-0.05014071 -0.06919979 -0.00756007]
[-0.05013699  0.06918366 -0.00758684]
[ 0.05012479  0.06914984 -0.00764345]
[ 0.05014459 -0.06916161 -0.00758295]
[-0.05015745 -0.06910705 -0.00759446]
[-0.05018462  0.0691582  -0.00755359]
[ 0.03588421 -0.044327

In [14]:
# Solved partial charge, one value per atom.
for atom_charge in charges:
    print(atom_charge)

0.37500295
0.37500528
0.37502065
0.37501678
0.37501904
0.3750121
0.3750109
0.37501788
0.058882836
0.05886814
0.058877792
0.05884354
-0.27176946
-0.27178088
-0.2717858
-0.2717726
-0.2685643
-0.26857054
-0.26856714
-0.2685617
-0.26856884
-0.26855847
-0.2685419
-0.26855892
0.3750088
0.3750197
0.37499848
0.3750074
0.37501925
0.37500975
0.37500235
0.37500194
0.058919836
0.058859576
0.05887129
0.058878712
-0.27176824
-0.27176636
-0.27177477
-0.2717768
-0.26858413
-0.2685702
-0.2685686
-0.26856384
-0.26857433
-0.26856753
-0.26854494
-0.26855397
0.37500304
0.3750034
0.3750062
0.37500432
0.37501237
0.37500453
0.37499568
0.37500215
0.05884924
0.058816224
0.058875162
0.058925465
-0.27174535
-0.2717656
-0.2717908
-0.2717972
-0.2685589
-0.26855868
-0.26857218
-0.26856402
-0.26856637
-0.2685596
-0.2685392
-0.26855597
0.37500754
0.37501472
0.37501818
0.3750106
0.37500307
0.37500268
0.3749989
0.37500992
0.058856413
0.058858134
0.058864214
0.058877148
-0.27172336
-0.27177086
-0.27178475
-0.2717956
-0.2