### 3D NVT Nose-Hoover Dynamics for Crystallization of Particles under Lennard-Jones Potential

In [3]:
from jax import random
from jax_md import quantity

key = random.PRNGKey(0)
num_particles = 24
dim = 3

box_size = quantity.box_size_at_number_density(
    particle_count=num_particles,
    number_density=0.1,
    spatial_dimension=dim
)

R = random.uniform(key, (num_particles, dim), maxval=box_size)

In [4]:
# use soft sphere potential to separate particles first
import jax.numpy as jnp
from jax import vmap, grad
from jax_md import space, energy, simulate, minimize
from tqdm import tqdm

displacement_fn, shift_fn = space.periodic(box_size)

energy_fn = energy.soft_sphere_pair(displacement_fn)
init_fn, apply_fn = minimize.fire_descent(energy_fn, shift_fn)

state = init_fn(R)
for _ in tqdm(range(5000)):
    state = apply_fn(state)
R = state.position

  return _reduction(a, "sum", lax.add, 0, preproc=_cast_to_numeric,
  return lax.convert_element_type(result, dtype or result_dtype)
100%|██████████| 5000/5000 [00:03<00:00, 1590.38it/s]


In [7]:
energy_fn = energy.lennard_jones_pair(displacement_fn, epsilon=2.0, sigma=0.9)
force_fn = quantity.force(energy_fn)

# # routine for energy minimization via FIRE
# init_fn, apply_fn = minimize.fire_descent(force_fn, shift_fn)
# state = init_fn(R)

# for _ in tqdm(range(5000)):
#     state = apply_fn(state)
# R_final = state.position

# routine for simulation via NVT dynamics
init_fn, apply_fn = simulate.nvt_nose_hoover(energy_fn, shift_fn, kT=0.1, dt=1e-3)
state = init_fn(key, R)

trajectory = []

for _ in tqdm(range(500)):
    for __ in range(200):
        state = apply_fn(state)
    trajectory.append(state.position)
    
R_final = state.position

100%|██████████| 500/500 [04:38<00:00,  1.80it/s]


In [8]:
# Render out the data into an XYZ file

import numpy as np

trajectory_np = np.array(trajectory)  # shape: (n_frames, 256, 3)
n_frames, N, dim = trajectory_np.shape

with open("trajectory.xyz", "w") as f:
    for t in tqdm(range(n_frames)):
        f.write(f"{N}\n")
        f.write(f"Frame {t}\n")
        
        positions = trajectory_np[t]
        for i in range(N):
            x, y, z = positions[i]
            f.write(f"Ar {x} {y} {z}\n")

100%|██████████| 500/500 [00:00<00:00, 827.14it/s]
