<a href="https://colab.research.google.com/github/pinballsurgeon/deluxo_adjacency/blob/main/N2_O2_CO2_NVT_simulator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

###Installs

In [1]:
# jax molecular dynamics
!pip install jax-md

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting jax-md
  Downloading jax_md-0.1.29-py2.py3-none-any.whl (95 kB)
[K     |████████████████████████████████| 95 kB 1.3 MB/s 
Collecting jraph
  Downloading jraph-0.0.5.dev0-py3-none-any.whl (90 kB)
[K     |████████████████████████████████| 90 kB 5.1 MB/s 
Collecting dataclasses
  Downloading dataclasses-0.6-py3-none-any.whl (14 kB)
Collecting dm-haiku>=0.0.2
  Downloading dm_haiku-0.0.7-py3-none-any.whl (342 kB)
[K     |████████████████████████████████| 342 kB 11.6 MB/s 
Collecting jmp>=0.0.2
  Downloading jmp-0.0.2-py3-none-any.whl (16 kB)
Installing collected packages: jmp, jraph, dm-haiku, dataclasses, jax-md
Successfully installed dataclasses-0.6 dm-haiku-0.0.7 jax-md-0.1.29 jmp-0.0.2 jraph-0.0.5.dev0


###Imports

In [2]:
#                          __         __              /\  .___
#            ____   _____/  |_      |__|____  ___  __)/__| _/
#           / ___\_/ __ \   __\     |  \__  \ \  \/  // __ | 
#          / /_/  >  ___/|  |       |  |/ __ \_>    </ /_/ | 
#          \___  / \___  >__|   /\__|  (____  /__/\_ \____ | 
#          /_____/      \/       \______|    \/      \/    \/ 


from jax.config import config ; config.update('jax_enable_x64', True)
import jax.numpy as np
from jax import random, jit, lax, ops
from jax_md import space, smap, energy, minimize, quantity, simulate

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
  
sns.set_style(style='white')

# details on new things you dont yet know
from inspect import signature
import inspect, re


  


Experiment configuration

In [3]:
N = 2600
dimension = 3
box_size = quantity.box_size_at_number_density( N, 0.8, dimension)
dt = 5e-3
displacement, shift = space.periodic(box_size) 

steps = 10000
write_every = 100


kT = lambda t: np.where(t < 5000.0 * dt, 0.01, 0.012)



Helper functions

In [4]:
# relative concentration
conc_n2 = 0.78084
conc_o2 = 0.20946
conc_ar = 0.009340
conc_co2 = 0.000407
conc_h2o = 0
conc_ch4 = 0.0000018


# kinetic diameters
kd_n2 = 3.64
kd_o2 = 3.46
kd_ar = 0
kd_co2 = 3.3
kd_h2o = 2.65
kd_ch4 = 3.8

# molecular weight
mw_n2 = 28
mw_o2 = 32
mw_ar = 0
mw_co2 = 44
mw_h2o = 18
mw_ch4 = 16

# molecular diameter
md_n2 = 0
md_o2 = 0
md_ar = 0
md_co2 = 0
md_h2o = 0
md_ch4 = 0


Helper functions

In [5]:
def step_fn(i, state_and_log):
  state, log = state_and_log

  t = i * dt

  # Log information about the simulation.
  T = quantity.temperature(state.velocity)
  log['kT'] = log['kT'].at[i].set(T)
  H = simulate.nvt_nose_hoover_invariant(energy_fn, state, kT(t))
  log['H'] = log['H'].at[i].set(H)
  # Record positions every `write_every` steps.
  log['position'] = lax.cond(i % write_every == 0,
                             lambda p: \
                             p.at[i // write_every].set(state.position),
                             lambda p: p,
                             log['position'])

  # Take a simulation step.
  state = apply(state, kT=kT(t))



  return state, log

In [6]:
key = random.PRNGKey(0)

In [7]:
key, split = random.split(key)
R = box_size * random.uniform(split, (N, dimension), dtype=np.float64)

# The system ought to be a 50:50 mixture of two types of particles, one
# large and one small.
# sigma = np.array([[1.0, 1.2], [1.2, 1.4]])
# sigma = np.array([[0.3, 0.3], [1.2, 1.2], [0.8, 0.8]])
sigma = np.array([[0.3, 0.3], [1.2, 1.2], [0.8, 0.8], [0.5, 0.5]])


N_2 = int(N * conc_o2)
N_3 = int(N * conc_ar)
N_4 = int(N * conc_co2)


buf = 0
buf_lst = []
for i in range(0, N):

  if buf < N_4:
    buf_lst.append(0)
  elif buf < N_3:
    buf_lst.append(1)
  elif buf < N_2:
    buf_lst.append(2)    
  else:
    buf_lst.append(3)

  buf += 1

In [8]:
species = np.array(buf_lst)


In [9]:
species.shape
species

DeviceArray([0, 1, 1, ..., 3, 3, 3], dtype=int64)

In [10]:
energy_fn = energy.soft_sphere_pair(displacement, species=species, sigma=sigma)

init, apply = simulate.nvt_nose_hoover(energy_fn, shift, dt, kT(0.))

print(type(init(key, R)))
state = init(key, R)


<class 'jax_md.simulate.NVTNoseHooverState'>


In [11]:
# establish log
log = {
    'kT': np.zeros((steps,)),
    'H': np.zeros((steps,)),
    'position': np.zeros((steps // write_every,) + R.shape) 
}

In [12]:
## object details

print("Let review what we've done"); print()

objs = [energy_fn, init, apply, key, R, log]

for obj in objs:

  # function deets
  try:

    var_nm = [key for key, value in locals().items() if value == obj]
    print(var_nm[0])

    print('   ', obj.__name__ ,' type -', type(obj))
    print('   ', str(signature(obj)))
    print('')
  except:
    pass

  # dict deets
  try:
    var_nm = [key for key, value in locals().items() if value == obj]
    print(' type -', type(obj))
    print('   ', str(signature(obj)))
    print('   ', obj.shape)
    print('')
  except:
    pass

  # array deets
  try:
    print(' type -', type(obj))
    print('   ', str(signature(obj)))
    print('   ', obj.shape)
    print('')
  except:
    pass





Let review what we've done

energy_fn
    fn_mapped  type - <class 'function'>
    (R, **dynamic_kwargs)

 type - <class 'function'>
    (R, **dynamic_kwargs)
 type - <class 'function'>
    (R, **dynamic_kwargs)
init
    init_fn  type - <class 'function'>
    (key, R, mass=DeviceArray(1., dtype=float32), **kwargs)

 type - <class 'function'>
    (key, R, mass=DeviceArray(1., dtype=float32), **kwargs)
 type - <class 'function'>
    (key, R, mass=DeviceArray(1., dtype=float32), **kwargs)
apply
    apply_fn  type - <class 'function'>
    (state, **kwargs)

 type - <class 'function'>
    (state, **kwargs)
 type - <class 'function'>
    (state, **kwargs)
 type - <class 'jaxlib.xla_extension.DeviceArray'>
 type - <class 'jaxlib.xla_extension.DeviceArray'>
log
 type - <class 'dict'>
 type - <class 'dict'>


In [None]:

state, log = lax.fori_loop(0, steps, step_fn, (state, log))

R = state.position

In [None]:
buf_lst = []
for i in species:
   
  # carbon dioxide
  if i == 0: 
    buf_lst.append(np.array([1.5, 3.2, 0.01]))

  # argon
  elif i == 1: 
    buf_lst.append(np.array([1.0, 0.2, 0.5]))   

  # oxygen
  elif i == 2: 
    buf_lst.append(np.array([3.0, 1.2, 2.5]))   

  # nitrogen
  elif i == 3: 
    buf_lst.append(np.array([0.3, .8, 0.85 ]))        


In [None]:
from jax_md.colab_tools import renderer

diameters = sigma[species, species]

colors = np.array(buf_lst)

renderer.render(box_size,
                {
                    'particles': renderer.Sphere(log['position'], 
                                               diameters,
                                               colors)   
                                      
                },
                resolution=(600, 600))

(for example, at 63 degrees F, CO2 molecules crash together about 7 billion times per second)

A CO2 molecule is 0.33nm diameter

The diameter of an O2 molecule is 292 picometers, and that of N2 is 300 picometers

In [None]:
log['position'].shape

In [None]:
initial_positions = log['position'][0]

In [None]:
log['position'][1]

In [None]:
# '''
# import imageio
# import jax.numpy as jnp

# def make_from_image(filename, size_in_pixels):
#   position = []
#   angle = []
#   color = []

#   img = imageio.imread(filename)

#   scale = 2**(1/6)
#   ratio = jnp.sqrt(1 - 0.25)
#   for i, y in enumerate(range(0, img.shape[0], size_in_pixels)):
#     for x in range(0, img.shape[1], size_in_pixels):
#       r, g, b, a = img[y, x]
#       if a == 255:
#         hshift = size_in_pixels * (i % 2) / 2.0
#         position += [[scale * (x + hshift) / size_in_pixels, scale * (img.shape[0] - y) / size_in_pixels * ratio]]
#         color += [[r / 255, g / 255, b / 255]]
#   img_size = jnp.array(img.shape[:2]).T / size_in_pixels * scale
#   box_size = jnp.max(img_size) * 1.5
#   position = jnp.array(position, jnp.float64) + box_size / 2.0 - img_size / 2
#   color = jnp.array(color, jnp.float64)

#   return box_size, position, color
#   '''

In [None]:
# '''
# box, positions, colors = make_from_image('mfi_three.png', 24)
# '''

In [None]:
# '''
# from jax_md.colab_tools import renderer

# renderer.render(box,
#                 renderer.Disk(positions, color=colors))

#                 '''

In [None]:

# 
# from jax_md import space

# displacement_fn, shift_fn = space.periodic(box)
# 

In [None]:


# positions[0]

In [None]:
# displacement_fn(positions[0], positions[-1])

In [None]:
# shift_fn(positions[0], jnp.array([10.0, 0.0]))

## Energy

"Energy" in Physics plays a similar role to "Loss" in machine learning. 

Write down an energy function between two grains of sand, $\epsilon(r)$. 

The total energy will be the sum of all pairs of energies.

$$E = \sum_{i,j} \epsilon(r_{ij})$$

where $r_{ij}$ is the distance between grain $i$ and grain $j$.


We want to model wet sand:

*   Grains are hard (no interpenetration).
*   Grains stick together a little bit.
*   Grains far away from one another don't notice each other.

In [None]:
# from jax_md import energy

# rs = jnp.linspace(0.5, 2.5)
# plt.plot(rs, energy.lennard_jones(rs))

# plt.ylim([-1, 1])
# plt.xlim([0, 2.5])
# plt.xlabel('$r_{ij}$')
# plt.ylabel('$\\epsilon$')

In [None]:
# sand_energy = energy.lennard_jones_pair(displacement_fn)

# sand_energy(positions)

## Simulate

In [None]:
# from jax import random

# simulation_steps = 10000
# write_every = 50
# key = random.PRNGKey(1)

In [None]:
# from jax_md import simulate
# from jax import jit

# init_fn, step_fn = simulate.nvt_langevin(sand_energy, shift_fn, dt=5e-3, kT=0.0, gamma=1e-2)

# sand = init_fn(key, positions)
# step_fn = jit(step_fn)

In [None]:

# trajectory = []

# for i in range(simulation_steps):
#  if i % write_every == 0:
#    trajectory += [sand.position]
    
#  sand = step_fn(sand)

# trajectory = jnp.stack(trajectory)

In [None]:
# renderer.render(box, renderer.Disk(trajectory, color=colors))