In [1]:
#@title Imports

import jax.numpy as np
import numpy as onp
from jax import jit
from jax import random
from jax import lax
import jax_md

from jax.config import config
config.update('jax_enable_x64', True)

from jax_md import space
from jax_md import energy
from jax_md import simulate
from jax_md import quantity
from jax_md import partition

# from jax_md.colab_tools import renderer

In [14]:
lattice_constant = 1.37820
N_rep = 5 # int(10000 ** (1/3))
box_size = N_rep * lattice_constant * 2
# Using float32 for positions / velocities, but float64 for reductions.
dtype = np.float32

# Specify the format of the neighbor list. 
# Options are Dense, Sparse, or OrderedSparse. 
format = partition.OrderedSparse

displacement, shift = space.periodic(box_size)

In [15]:
R = []
for i in range(N_rep):
  for j in range(N_rep):
    for k in range(N_rep):
      R += [[i, j, k]]
R = np.array(R, dtype=dtype) * lattice_constant

In [16]:
N = R.shape[0]
phi = N / (box_size) ** 3
print(f'Created a system of {N} LJ particles with number density {phi:.3f}')

Created a system of 125 LJ particles with number density 0.048


In [23]:
neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(displacement,
                                                            box_size, 
                                                            r_cutoff=3.0,
                                                            epsilon=1.0,
                                                            sigma=1.0,
                                                            dr_threshold=1.,
                                                            format=format)

In [24]:
nbrs = neighbor_fn.allocate(R)
energy_fn(R, nbrs)

Array(-188.36255, dtype=float32)

In [25]:
init, apply = simulate.nvt_nose_hoover(energy_fn, shift, 5e-3, kT=0.75)

FOO


In [26]:
key = random.PRNGKey(0)

# We pick an "extra capacity" to ensure ahead of time that the neighbor
# list will have enough capacity. Since sparse neighbor lists are more 
# robust to changes in the number of particles, in this case we only 
# need to actually add more capacity for dense neighbor lists.
if format is partition.Dense: 
  nbrs = neighbor_fn.allocate(R, extra_capacity=55)
else:
  nbrs = neighbor_fn.allocate(R)

state = init(key, R, neighbor=nbrs)

In [27]:
def step(state_and_nbrs, unused):
  state, nbrs = state_and_nbrs
  nbrs = nbrs.update(state.position)
  return (apply(state, neighbor=nbrs), nbrs), energy_fn(state.position, nbrs)


In [28]:
(new_state, new_nbrs), energies = lax.scan(step, (state, nbrs), np.ones((5000)))


In [29]:
(energies/N)[3000:].mean()
# energies

Array(0., dtype=float32)

In [30]:
energy_fn(R, nbrs)
# R.shape
# Array(0.16502948, dtype=float32)


Array(-188.36255, dtype=float32)

In [31]:
new_nbrs

NeighborList(idx=Array([[  0,   1,   0, ..., 125, 125, 125],
       [  1,   2,   2, ..., 125, 125, 125]], dtype=int32), reference_position=Array([[ 0.816599  ,  0.78013045,  1.3134943 ],
       [ 0.865217  , 13.30726   ,  1.0350077 ],
       [ 1.319262  , 13.276201  ,  2.1949813 ],
       [ 1.6175023 ,  1.426612  ,  4.6933904 ],
       [ 0.64612055,  0.41510585,  5.4012995 ],
       [10.890388  ,  0.35618663,  2.9244862 ],
       [ 0.29788017,  0.0802899 ,  2.2079988 ],
       [ 1.4438272 ,  1.7487953 ,  1.6585805 ],
       [13.350106  ,  0.82188624,  2.7625377 ],
       [12.133874  ,  1.2305152 ,  5.1647973 ],
       [13.538495  ,  3.7277586 ,  1.3926023 ],
       [13.745797  ,  1.4298304 ,  2.0376968 ],
       [ 0.7478255 ,  1.5557102 ,  2.7435997 ],
       [12.56515   ,  2.6466334 ,  2.3977683 ],
       [13.672601  ,  2.4747167 ,  2.2498424 ],
       [ 0.2260567 ,  3.5513508 ,  0.30416477],
       [ 1.6282372 ,  4.060674  ,  2.6097388 ],
       [ 0.15377094,  2.7241244 ,  3.496886  

In [60]:
ran = np.linspace(1e-5, 1e-4, num = 50)
gr = jax_md.quantity.pair_correlation(displacement, ran, 1.0)

In [61]:
gr(state.position[:]).shape
# state.position.shape
# np.logspace(0.1,1.0)

(125, 50)

In [63]:
import seaborn as sns
sns.scatterplot(x=ran, y=jax.vmap(gr)([state.position]))

NameError: name 'jax' is not defined