# Basic usage
Example calculation of the potential energy and forces of a LiPON (lithium phosphorus oxynitride) electrolyte.

In [None]:
! pip install ase

In [None]:
from ase.io import read
import numpy as np
import jax
import jax.numpy as jnp
from jax_nb.jax_nb import pqeq_fori_loop, nonbond_potential, LAMBDA
from jax_md import partition, space
from jax_nb.parameters import pqeq_parameters
from functools import partial

## Read data
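Read the structure with ASE and convert the resulting NumPy arrays to `jax.numpy` arrays. Positions are stored as fractional (scaled) coordinates, and the cell matrix is transposed so that the lattice vectors are its columns. For the differences between NumPy arrays and `jax.numpy` arrays, see the [JAX documentation](https://jax.readthedocs.io/en/latest/jax.numpy.html).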

In [None]:
atoms = read('Li2PO2N.cif')
symbols = atoms.get_chemical_symbols()

# Fractional (scaled) coordinates, as a JAX array.
positions = jnp.array(atoms.get_scaled_positions())

# ASE stores the lattice vectors as rows of `atoms.cell.array`; transposing
# makes them the columns of `cell`.
cell = jnp.array(atoms.cell.array.T)

## Define the displacement function

In [None]:
# Periodic displacement function for the (possibly triclinic) cell. Inputs are
# fractional coordinates; the shift function is not needed here and is discarded.
displacement_fn, _ = space.periodic_general(
    box=cell,
    fractional_coordinates=True,
)

## Define the neighbor list function

In [None]:
# Sparse neighbor list over fractional coordinates with a 12.5 cutoff (the same
# cutoff used for the charge solver and energy below). capacity_multiplier
# over-allocates neighbor slots relative to the count found at allocation time.
neighbor_fn = partition.neighbor_list(
    displacement_fn,
    box=cell,
    fractional_coordinates=True,
    r_cutoff=12.5,
    format=partition.Sparse,
    capacity_multiplier=2.0,
)


## Load PQEq parameters

In [None]:
def _gather_pqeq(key):
    """Return a JAX array of PQEq parameter `key`, one entry per atom in `symbols` order."""
    return jnp.array([pqeq_parameters[sym][key] for sym in symbols])


def _pairwise_alpha(per_atom):
    """Combine per-atom exponents into the (N, N) matrix sqrt(a_i * a_j / (a_i + a_j))."""
    a_col = per_atom.reshape(-1, 1)
    a_row = per_atom.reshape(1, -1)
    return jnp.sqrt(a_col * a_row / (a_col + a_row))


rad = _gather_pqeq('rad')
# Per-atom exponent derived from the radius as 0.5 * LAMBDA / rad**2, then
# combined pairwise for use by the charge solver and energy function.
alpha = _pairwise_alpha(0.5 * LAMBDA / rad / rad)
chi0 = _gather_pqeq('chi0')
eta0 = _gather_pqeq('eta0')
z = _gather_pqeq('Z')
Ks = _gather_pqeq('Ks')

## Define a loop function that solves for the partial charges

In [None]:
# PQEq charge solver with the displacement function and all per-atom parameters
# bound. Callers supply (positions, neighbor_list, **displacement kwargs).
# NOTE(review): `iterations` presumably sets the number of solver iterations in
# pqeq_fori_loop; confirm against jax_nb before relying on convergence.
charges_fn = partial(
    pqeq_fori_loop,
    displacement_fn,
    alpha=alpha,
    chi0=chi0,
    eta0=eta0,
    z=z,
    Ks=Ks,
    cutoff=12.5,
    iterations=2,
    net_charge=0.0,
)


## Define the energy function

In [None]:
energy_fn_nb = partial(
    nonbond_potential,
    displacement_fn,
    # PQEq keyword arguments
    alpha=alpha,
    cutoff=12.5,
    eta0=eta0,
    chi0=chi0,
    z=z,
    Ks=Ks,
    # DFT-D3 keyword arguments; compute_d3=False presumably disables the D3 term
    # (confirm in jax_nb), in which case d3_params/damping are not used.
    atomic_numbers=jnp.array(atoms.numbers),
    compute_d3=False,
    # PBE zero-damping parameters
    d3_params={'s6': 1.0, 'rs6': 1.217, 's18': 0.722, 'rs18': 1.0, 'alp': 14.0},
    damping='zero',
    smooth_fn=None,
)


In [None]:
def energy_fn(positions, nbr, **displ_kwargs):
    """Compute the non-bonded energy, solving PQEq charges on the fly.

    Args:
        positions: atomic positions in the coordinate system expected by
            `displacement_fn` (fractional coordinates in this notebook).
        nbr: a jax_md neighbor list; it is refreshed for `positions` before use.
        **displ_kwargs: forwarded to the neighbor-list update, `charges_fn`
            and `energy_fn_nb` (e.g. `box=cell`).

    Returns:
        `(energy, (charges, r_shell))`, the layout `jax.value_and_grad(...,
        has_aux=True)` expects.

    NOTE(review): the refreshed neighbor list is not returned, so a neighbor
    buffer overflow would go unreported.
    """
    refreshed = nbr.update(positions, **displ_kwargs)

    # The charge solver sees a gradient-stopped copy of the positions, so
    # differentiating the energy does not backpropagate through the solver;
    # charges and shell positions are treated as constants for the gradient.
    frozen = jax.lax.stop_gradient(positions)
    charges, r_shell = charges_fn(frozen, refreshed, **displ_kwargs)

    energy = energy_fn_nb(positions, refreshed, r_shell, charges, **displ_kwargs)
    return energy, (charges, r_shell)


## JIT-compile for acceleration

In [None]:
nbr = neighbor_fn.allocate(positions)

# Bind the freshly allocated neighbor list so the differentiated function takes
# the positions (argnums=0) plus displacement kwargs such as box=cell.
# NOTE(review): the neighbor list is captured at this point; if positions later
# move far enough to exceed its capacity, re-allocate and rebuild this function.
_energy_of_positions = partial(energy_fn, nbr=nbr)
value_and_grad_fn = jax.jit(
    jax.value_and_grad(_energy_of_positions, argnums=0, has_aux=True)
)


## Get results

In [None]:
# Layout: ((energy, (charges, r_shell)), d_energy / d_positions).
results = value_and_grad_fn(
    positions,
    box=cell,
)

In [None]:
# `results` is ((energy, (charges, r_shell)), grad). Because `positions` are
# fractional (scaled) coordinates and displacement_fn was built with
# fractional_coordinates=True, `grad` is dE/ds (fractional), not dE/dr.
pe = np.asarray(results[0][0])

# Gradient w.r.t. fractional coordinates; same shape as `positions`.
grad_frac = np.asarray(results[1])

# Cartesian positions are r = cell @ s (cell columns are the lattice vectors).
# For per-atom row-vector gradients the chain rule gives dE/ds = dE/dr @ cell,
# so dE/dr = dE/ds @ inv(cell). Forces are the negative Cartesian gradient.
forces = -grad_frac @ np.linalg.inv(np.asarray(cell))

charges = np.asarray(results[0][1][0])