In [1]:
#@title Imports

import jax.numpy as np
import numpy as onp
from jax import jit
from jax import random
from jax import lax
import jax_md

from jax.config import config
# config.update('jax_enable_x64', True)

from jax_md import space
from jax_md import energy
from jax_md import simulate
from jax_md import quantity
from jax_md import partition

# from jax_md.colab_tools import renderer

In [45]:
lattice_constant = 1.37820
N_rep = 10
box_size = N_rep * lattice_constant
# Using float32 for positions / velocities, but float64 for reductions.
dtype = np.float32

# Specify the format of the neighbor list. 
# Options are Dense, Sparse, or OrderedSparse. 
format = partition.OrderedSparse

displacement, shift = space.periodic(box_size)

In [46]:
R = []
for i in range(N_rep):
  for j in range(N_rep):
    for k in range(N_rep):
      R += [[i, j, k]]
R = np.array(R, dtype=dtype) * lattice_constant

In [47]:
N = R.shape[0]
phi = N / (box_size) ** 3
print(f'Created a system of {N} LJ particles with number density {phi:.3f}')

Created a system of 1000 LJ particles with number density 0.382


In [48]:
energy_fn = energy.lennard_jones

In [49]:
# nbrs = neighbor_fn.allocate(R)
# energy_fn(R, nbrs)

In [50]:
init, apply = simulate.nvt_langevin(energy_fn, shift, 5e-3, kT=0.85, gamma=0.1)

In [51]:
key = random.PRNGKey(0)
state = init(key, R)

In [52]:
def step(state, unused):
  return apply(state), state.position
# energy_fn(state.position, nbrs)


In [53]:
new_state, states = lax.scan(step, state, np.ones((100)))


In [54]:
# (energies/N)[3000:].mean()
# energies


In [56]:
import mdtraj as md
import pandas as pd

# # traj = md.load('/Users/reubencohn-gordon/Downloads/ala2 2/sim4/trajectory.dcd', top='/Users/reubencohn-gordon/Downloads/ala2 2/ala2.pdb')
# traj, _ = md.load('./data/prod_alanine_dipeptide_amber/structure.pdb').topology.to_dataframe()
# print(traj)
# raise Exception

# new_t = traj.append([688,0])

n = 1000

new_t = pd.DataFrame({'name' : ['O']*n,
                      'element' : ['O']*n,
                      'serial' : range(n),
                      'resSeq' : [0]*n,
                      'resName' : ['0']*n,
                      'chainID' : [0]*n,

                      
                      })

# new_t.loc[:, 'name'] = 'HOH'
# new_t.loc[:, 'element'] = 'O'
# new_t.loc[:, 'resName'] = 'HOH'
# new_t.loc[:, 'resSeq'] = 1
# new_t.loc['name'] = ['HOH']*125
new_t

# states[:, :125, :].shape


top = md.Topology.from_dataframe(new_t)
# # states = np.where(np.isnan(states), 0, states)
traj = md.Trajectory(xyz=states[:, :, :], topology=top)

traj.save_pdb('./data/lj_small.pdb')


In [64]:
import jax
import math
from jax_md.simulate import Sampler

pos = state.position

T = 1.2
l = lambda x : energy_fn(np.reshape(x, pos.shape), nbrs) / T
value_grad = jax.value_and_grad(l)


class MD():


  def __init__(self, d):
    self.d = d

  def grad_nlogp(self, x):
    return value_grad(x)

  def transform(self, x):
    return x

  def prior_draw(self, key):
    return np.array(np.reshape(pos, math.prod(pos.shape)), dtype='float64')


# displacement_fn, shift_fn = space.periodic_general(box[0])

target = MD(d = math.prod(pos.shape))
sampler = Sampler(target, shift_fn=shift, 
                  # varEwanted=1e-6, 
                  frac_tune1=0.0,
                  frac_tune2=0.0,
                  frac_tune3=0.0,
                  eps=1e0
                  )

In [65]:

chain_length = 100
num_chains = 1
samples, energy, L, eps = sampler.sample(chain_length, num_chains, output= 'detailed')

In [70]:
import mdtraj as md
import pandas as pd

traj, _ = md.load('./data/prod_alanine_dipeptide_amber/structure.pdb').topology.to_dataframe()

# print(samples.shape)
# raise Exception

n = 64000

new_t = pd.DataFrame({'name' : ['O']*n,
                      'element' : ['O']*n,
                      'serial' : range(n),
                      'resSeq' : [0]*n,
                      'resName' : ['0']*n,
                      'chainID' : [0]*n,

                      
                      })


top = md.Topology.from_dataframe(new_t)
# # states = np.where(np.isnan(states), 0, states)
traj = md.Trajectory(xyz=np.reshape(samples, (100, 64000, 3)), topology=top)

traj.save_pdb('./data/lj_mclmc.pdb')


In [72]:
print(samples[0])
print(samples[1])

[ 0.13258595  0.20883596 55.01027739 ... 53.45025905 53.83830398
 53.79270948]
[ 0.13258898  0.20884516 55.01027526 ... 53.45027182 53.83830053
 53.79271843]


In [60]:
ran = np.linspace(1e-5, 1e-4, num = 50)
gr = jax_md.quantity.pair_correlation(displacement, ran, 1.0)

In [61]:
gr(state.position[:]).shape
# state.position.shape
# np.logspace(0.1,1.0)

(125, 50)

In [63]:
import seaborn as sns
sns.scatterplot(x=ran, y=jax.vmap(gr)([state.position]))

NameError: name 'jax' is not defined