Simple crystallization of solid spheres under LJ potential in 2D and 3D under NVT dynamics.

In [1]:
from jax import random
from jax_md import quantity

key = random.PRNGKey(0)
num_particles = 256
dim = 2

box_size = quantity.box_size_at_number_density(
    particle_count=num_particles,
    number_density=0.1,
    spatial_dimension=dim
)

R = random.uniform(key, (num_particles, dim), maxval=box_size)

In [2]:
# use soft sphere potential to separate particles first
import jax.numpy as jnp
from jax import vmap, grad
from jax_md import space, energy, simulate, minimize
from tqdm import tqdm

displacement_fn, shift_fn = space.periodic(box_size)

energy_fn = energy.soft_sphere_pair(displacement_fn)
init_fn, apply_fn = minimize.fire_descent(energy_fn, shift_fn)

state = init_fn(R)
for _ in tqdm(range(5000)):
    state = apply_fn(state)
R = state.position

  return _reduction(a, "sum", lax.add, 0, preproc=_cast_to_numeric,
  return lax.convert_element_type(result, dtype or result_dtype)
100%|██████████| 5000/5000 [00:05<00:00, 993.32it/s] 


In [3]:
import plotly.express as px
import plotly.graph_objects as go
import numpy as np

In [4]:
def plot_disks(R, radius=0.5, box_size=1.0, width=600, height=600):
    fig = go.Figure()

    # Add each circle as a shape
    for (x, y) in R:
        fig.add_shape(
            type="circle",
            x0=x - radius, x1=x + radius,
            y0=y - radius, y1=y + radius,
            line=dict(width=0),
            fillcolor="lightblue",
            opacity=0.8
        )

    # Set view box
    view_box_delta = 0.05 * box_size
    fig.update_xaxes(range=[0 - view_box_delta, box_size + view_box_delta], scaleanchor="y", scaleratio=1, title="x")
    fig.update_yaxes(range=[0 - view_box_delta, box_size + view_box_delta], title="y")

    fig.update_layout(width=width, height=height, template="plotly_dark", title={
        "text": "Particle Positions",
        "x": 0.5,
        "xanchor": "center"
    })
    fig.show()

In [None]:
plot_disks(R, radius=0.5, box_size=box_size)
# you can see pretty easily that the particles are separated and not overlapping

In [6]:
energy_fn = energy.lennard_jones_pair(displacement_fn, epsilon=2.0, sigma=0.9)
force_fn = quantity.force(energy_fn)

# # routine for energy minimization via FIRE
# init_fn, apply_fn = minimize.fire_descent(force_fn, shift_fn)
# state = init_fn(R)

# for _ in tqdm(range(5000)):
#     state = apply_fn(state)
# R_final = state.position

# routine for simulation via NVT dynamics
init_fn, apply_fn = simulate.nvt_nose_hoover(energy_fn, shift_fn, kT=0.1, dt=1e-3)
state = init_fn(key, R)

trajectory = []

for _ in tqdm(range(500)):
    for __ in range(200):
        state = apply_fn(state)
    trajectory.append(state.position)
    
R_final = state.position


Explicitly requested dtype float64 requested in sum is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.


Explicitly requested dtype float64 requested in convert_element_type is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.

100%|██████████| 500/500 [01:19<00:00,  6.26it/s]


In [None]:
plot_disks(R_final, radius=0.5, box_size=box_size)
# ideally, these particles have now crystallized. we can observe the formation of small crystals throughout the sample. very interesting.