In [2]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, vmap
from functools import partial
from tqdm import tqdm

key = jax.random.PRNGKey(0)

# -------------------------------
# Simulation parameters
# -------------------------------

N = 32
box_size = 10.0
dt = 1e-2
n_steps = 200
epsilon = 5.0
r0 = 1.2
sigma_r = 0.2
sigma_angle = 0.2

# -------------------------------
# Initialize state
# -------------------------------

R = jax.random.uniform(key, (N,3), minval=0.0, maxval=box_size)
v = jnp.zeros((N,3))
omega = jnp.zeros((N,3))

# Random unit quaternions
def random_quaternion(key):
    u1,u2,u3 = jax.random.uniform(key, (3,))
    q = jnp.array([
        jnp.sqrt(1-u1) * jnp.sin(2*jnp.pi*u2),
        jnp.sqrt(1-u1) * jnp.cos(2*jnp.pi*u2),
        jnp.sqrt(u1)   * jnp.sin(2*jnp.pi*u3),
        jnp.sqrt(u1)   * jnp.cos(2*jnp.pi*u3)
    ])
    return q / jnp.linalg.norm(q)

keys = jax.random.split(key, N)
q = vmap(random_quaternion)(keys)

# -------------------------------
# Patch directions (body frame)
# Two opposite patches
# -------------------------------

patch_body = jnp.array([
    [1.,0.,0.],
    [-1.,0.,0.]
])

# -------------------------------
# Quaternion utilities
# -------------------------------

def quat_mul(q1, q2):
    w1,x1,y1,z1 = q1
    w2,x2,y2,z2 = q2
    return jnp.array([
        w1*w2 - x1*x2 - y1*y2 - z1*z2,
        w1*x2 + x1*w2 + y1*z2 - z1*y2,
        w1*y2 - x1*z2 + y1*w2 + z1*x2,
        w1*z2 + x1*y2 - y1*x2 + z1*w2
    ])

def quat_conj(q):
    w,x,y,z = q
    return jnp.array([w,-x,-y,-z])

def rotate(q, v):
    qv = jnp.concatenate([jnp.array([0.0]), v])
    return quat_mul(quat_mul(q, qv), quat_conj(q))[1:]

# -------------------------------
# Energy function
# -------------------------------

def pair_energy(Ri, Rj, qi, qj):
    dR = Rj - Ri
    r = jnp.linalg.norm(dR)
    r_hat = dR / (r + 1e-8)

    # radial Gaussian well
    U_r = jnp.exp(-(r - r0)**2 / sigma_r**2)

    # rotate patches
    n_i = vmap(lambda pb: rotate(qi, pb))(patch_body)
    n_j = vmap(lambda pb: rotate(qj, pb))(patch_body)

    def patch_pair(ni):
        def inner(nj):
            Ai = jnp.dot(ni, r_hat)
            Aj = jnp.dot(nj, -r_hat)
            ang = jnp.exp(-(1-Ai)**2/sigma_angle**2) \
                * jnp.exp(-(1-Aj)**2/sigma_angle**2)
            return -epsilon * U_r * ang
        return jnp.sum(vmap(inner)(n_j))
    return jnp.sum(vmap(patch_pair)(n_i))

def total_energy(R, q):
    def particle_i(i):
        def particle_j(j):
            return jnp.where(j>i,
                pair_energy(R[i], R[j], q[i], q[j]),
                0.0)
        return jnp.sum(vmap(particle_j)(jnp.arange(N)))
    return jnp.sum(vmap(particle_i)(jnp.arange(N)))

# gradients
force_fn = grad(lambda R,q: total_energy(R,q), argnums=0)
quat_grad_fn = grad(lambda R,q: total_energy(R,q), argnums=1)

# -------------------------------
# Overdamped integration
# -------------------------------

trajectory_R = []
trajectory_q = []

for step in tqdm(range(n_steps)):

    F = -force_fn(R, q)
    G = -quat_grad_fn(R, q)

    R = R + dt * F

    # update quaternion with gradient descent step
    q = q + dt * G
    q = q / jnp.linalg.norm(q, axis=1, keepdims=True)

    R = R % box_size

    trajectory_R.append(np.array(R))
    trajectory_q.append(np.array(q))

trajectory_R = np.array(trajectory_R)
trajectory_q = np.array(trajectory_q)


100%|██████████| 200/200 [00:37<00:00,  5.39it/s]


In [3]:
n_frames = trajectory_R.shape[0]

with open("patchy_3d.lammpstrj", "w") as f:
    for t in range(n_frames):
        f.write("ITEM: TIMESTEP\n")
        f.write(f"{t}\n")
        f.write("ITEM: NUMBER OF ATOMS\n")
        f.write(f"{N}\n")
        f.write("ITEM: BOX BOUNDS pp pp pp\n")
        f.write(f"0 {box_size}\n0 {box_size}\n0 {box_size}\n")
        f.write("ITEM: ATOMS id type x y z qw qx qy qz\n")

        for i in range(N):
            x,y,z = trajectory_R[t,i]
            qw,qx,qy,qz = trajectory_q[t,i]
            f.write(f"{i+1} 1 {x} {y} {z} {qw} {qx} {qy} {qz}\n")
