In [1]:

# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

import os
# Expose only device index 0 to CUDA-based libraries; this is set ahead of the
# jax import in the next cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


In [2]:
import jax
from jax import numpy as jnp, random as rnd, jit, vmap
from pd_code.config import Config
from jax_md.partition import neighbor_list, NeighborListFormat, to_jraph
from jax_md import space

import numpy as onp
from jraph._src import utils
from typing import Sequence
import jraph
from flax import linen as nn

from jax_md import energy

from jax_md import quantity

import time 

from functools import partial
import numpy as onp

import jax.numpy as np

from jax import jit
from jax import grad
from jax import vmap
from jax import value_and_grad

from jax import random
from jax import lax

from jax.experimental import stax
from jax.experimental import optimizers

from jax.config import config

# NOTE(schsam): We need this in OSS I think.
from IPython.display import HTML, display
import time



In [3]:
# Build a random particle configuration from the project config and allocate
# a sparse neighbour list for it.
cfg = Config()
key = rnd.PRNGKey(cfg.seed)

# Uniform samples with shape (n_particles, n_dim).
positions = rnd.uniform(key, (cfg.n_particles, cfg.n_dim))

displacement_fn, shift_fn = space.periodic(cfg.side, wrapped=True)
# Row-wise difference of two equally shaped arrays (no periodic wrapping applied).
vmap_displacement_fn = vmap(lambda r0, r1: r0 - r1, in_axes=(0, 0), out_axes=0)
# NOTE(review): both arguments are `positions`, so the result is all zeros, and
# `displacements` is not read again in this cell.
displacements = vmap_displacement_fn(positions, positions)

# Dense: (N, n_max_neighbors_per_atom), Sparse (2, n_max_neighbors): Ordered Sparse: Sparse but half (no mirrors)
# NOTE(review): box_size is hard-coded to 1. while the displacement function is
# built from cfg.side — confirm the two agree.
neighbor_fn = neighbor_list(displacement_fn, box_size=1., r_cutoff=0.3, dr_threshold=0.01, format=NeighborListFormat.Sparse)
nbrs = neighbor_fn.allocate(positions)

# jraph_tuple = to_jraph(nbrs)
# print(nbrs)
# print(jraph_tuple)

In [7]:
# Inspect the raw archive: list the arrays it contains and peek at 'R'.
# NOTE(review): absolute, machine-specific path — consider taking it from the
# project config so the notebook runs elsewhere.
raw_data_test = np.load('/home/amawi/projects/mol-td/data/uracil_dft.npz')
keys = raw_data_test.keys()
print(list(keys))
print('R' in keys)
# Bare last expression so the notebook displays the array (None if 'R' is absent).
raw_data_test.get('R', None)

['E', 'name', 'F', 'theory', 'R', 'z', 'type', 'md5']
True


array([[[ 1.63200000e+00,  2.95100000e-01, -6.37000000e-02],
        [ 1.44620000e+00, -1.03750000e+00,  4.14000000e-02],
        [ 1.57900000e-01, -1.58810000e+00,  1.15800000e-01],
        ...,
        [ 4.33000000e-02, -2.57330000e+00,  1.92700000e-01],
        [-1.62410000e+00,  1.22990000e+00, -5.14000000e-02],
        [ 2.63630000e+00,  7.28600000e-01, -1.21900000e-01]],

       [[ 1.63438356e+00,  2.95888310e-01, -6.02989200e-02],
        [ 1.44408771e+00, -1.03792413e+00,  4.11265200e-02],
        [ 1.56076510e-01, -1.58651275e+00,  1.14934410e-01],
        ...,
        [ 4.29233400e-02, -2.59002107e+00,  1.65317530e-01],
        [-1.62925491e+00,  1.21958722e+00, -4.03368300e-02],
        [ 2.64457191e+00,  7.14081440e-01, -9.98188800e-02]],

       [[ 1.63675376e+00,  2.96713960e-01, -5.69050200e-02],
        [ 1.44186630e+00, -1.03847227e+00,  4.08707500e-02],
        [ 1.54428440e-01, -1.58459168e+00,  1.14035550e-01],
        ...,
        [ 4.15212900e-02, -2.60694132e+00,

In [None]:
def untransform(data, old_min, old_max, new_min=-1, new_max=1, mean=None):
    """Linearly map ``data`` from [old_min, old_max] to [new_min, new_max], then add ``mean``.

    This undoes ``transform`` when it is called with the two ranges swapped,
    i.e. ``old_*`` is the normalised range and ``new_*`` the original one.

    Args:
        data: Value or array to rescale.
        old_min, old_max: Bounds of the range ``data`` currently lives in.
        new_min, new_max: Bounds of the range to map onto.
        mean: Optional offset added after rescaling; skipped when ``None``.

    Returns:
        The rescaled (and optionally shifted) data.
    """
    unit = (data - old_min) / (old_max - old_min)
    rescaled = unit * (new_max - new_min) + new_min
    return rescaled if mean is None else rescaled + mean

def transform(data, old_min, old_max, new_min=-1, new_max=1, mean=None):
    """Subtract ``mean`` (if given), then linearly map [old_min, old_max] to [new_min, new_max].

    Args:
        data: Value or array to rescale.
        old_min, old_max: Bounds of the range the (centred) data lives in.
        new_min, new_max: Bounds of the range to map onto.
        mean: Optional offset removed before rescaling; skipped when ``None``.

    Returns:
        The centred and rescaled data.
    """
    centred = data if mean is None else data - mean
    unit = (centred - old_min) / (old_max - old_min)
    return unit * (new_max - new_min) + new_min

class DataReader:
    """Load an ``.npz`` dataset, record feature statistics on ``cfg`` and build model inputs.

    All work happens in the constructor. It reads ``cfg.data_path``, validates
    ``cfg.node_features``, writes per-feature ``<name>_min`` / ``<name>_max`` /
    ``<name>_mean`` and the raw array ``<name>`` onto ``cfg``, normalises the
    positions (and forces when present), one-hot encodes the atom types and
    allocates a sparse neighbour list from the first frame.

    Results are exposed as attributes (``__init__`` cannot return a value):
    ``positions``, ``forces``, ``atoms``, ``node_features``, ``data``,
    ``target``, ``displacement_fn``, ``shift_fn``, ``neighbor_fn``, ``nbrs``.

    The summary print reads ``cfg.R_*``, ``cfg.F_*`` and ``cfg.z_*``, so 'R',
    'F' and 'z' are expected to be listed in ``cfg.node_features``.
    """

    def __init__(self, cfg):
        raw_data = np.load(cfg.data_path)
        # Materialise the key view once; it is a view object, not a callable.
        keys = list(raw_data.keys())

        for feature in cfg.node_features:
            assert feature in keys, f'{feature} not in dataset'

        for name in cfg.node_features:
            tmp_min, tmp_max, tmp_mean = self.get_state(raw_data[name])
            setattr(cfg, f'{name}_min', tmp_min)
            setattr(cfg, f'{name}_max', tmp_max)
            setattr(cfg, f'{name}_mean', tmp_mean)
            setattr(cfg, f'{name}', raw_data[name])

        setattr(cfg, 'R_lims', tuple((cfg.R_min[i], cfg.R_max[i]) for i in range(3)))
        setattr(cfg, 'n_nodes', raw_data['R'].shape[1])

        print(f' Pos-Lims: {tuple((float(cfg.R_min[i]), float(cfg.R_max[i])) for i in range(3))} \
                \n F-Lims: {tuple((float(cfg.F_min[i]), float(cfg.F_max[i])) for i in range(3))} \
                \n A-Lims: {int(cfg.z_min)} {int(cfg.z_max)}')

        n_data, n_nodes, _ = raw_data['R'].shape

        # Transform positions. Membership tests replace truth-testing an array,
        # which is ambiguous for arrays with more than one element.
        box_size = raw_data['box_size'] if 'box_size' in keys else None
        if box_size is not None:
            positions = transform(raw_data['R'], 0, box_size, mean=box_size/2.)
        else:
            # NOTE(review): the mean is removed before scaling with the
            # un-centred min/max — confirm this is the intended normalisation.
            positions = transform(raw_data['R'], cfg.R_min, cfg.R_max, mean=cfg.R_mean)

        # One-hot encode the atom types and repeat them for every frame. The
        # result is kept in a local because the loaded archive is read-only.
        atoms = None
        if 'z' in keys:
            node_id = raw_data['z']
            one_hot = jax.nn.one_hot(node_id - 1, int(node_id.max()), dtype=jnp.float32)
            atoms = jnp.repeat(one_hot[None, :], n_data, axis=0)

        forces = None
        if 'F' in keys:
            forces = transform(raw_data['F'], -1, 1, mean=cfg.F_mean)

        # Raw per-node features in the configured order, with 'z' swapped for
        # its per-frame one-hot encoding so every entry has a leading frame axis.
        features = {name: raw_data[name] for name in cfg.node_features}
        if atoms is not None and 'z' in features:
            features['z'] = atoms
        node_features = jnp.concatenate([features[name] for name in cfg.node_features], axis=-1)

        setattr(cfg, 'n_node_features', node_features.shape[-1] * n_nodes)
        setattr(cfg, 'n_target_features', 3 * n_nodes)

        # Model input: whichever of the normalised pieces exist, joined on the feature axis.
        parts = [part for part in (positions, forces, atoms) if part is not None]
        data = jnp.concatenate(parts, axis=-1)
        target = positions

        # Get neighbor lists for all data points
        periodic = bool(raw_data['periodic']) if 'periodic' in keys else False
        if periodic:
            displacement_fn, shift_fn = space.periodic(cfg.side, wrapped=True)
        else:
            displacement_fn, shift_fn = space.free()

        neighbor_fn = neighbor_list(displacement_fn,
                                    box_size=cfg.box_size,
                                    r_cutoff=cfg.r_cutoff,
                                    dr_threshold=cfg.dr_threshold,  # when the neighbor list updates
                                    format=NeighborListFormat.Sparse)

        # Extra capacity to reduce re-allocation / jit overhead.
        nbrs = neighbor_fn.allocate(raw_data['R'][0], extra_capacity=6)

        # Expose results as attributes; returning a non-None value from
        # __init__ raises TypeError.
        self.positions = positions
        self.forces = forces
        self.atoms = atoms
        self.node_features = node_features
        self.data = data
        self.target = target
        self.displacement_fn = displacement_fn
        self.shift_fn = shift_fn
        self.neighbor_fn = neighbor_fn
        self.nbrs = nbrs

    @staticmethod
    def get_state(array):
        """Return ``(min, max, mean)`` of ``array``.

        Arrays with more than one axis are reduced over every axis except the
        last, so the statistics are per feature component; 1-D arrays reduce
        to scalars. NOTE(review): this reduction is inferred from how the
        statistics are consumed in ``__init__`` (``cfg.R_min[i]`` for three
        components, ``int(cfg.z_min)``) — confirm it matches the intent.
        """
        array = jnp.asarray(array)
        axes = tuple(range(array.ndim - 1)) if array.ndim > 1 else None
        return (jnp.min(array, axis=axes),
                jnp.max(array, axis=axes),
                jnp.mean(array, axis=axes))


In [None]:
# Fresh project configuration; read below when the network is instantiated.
cfg = Config()

class ExplicitMLP(nn.Module):
  """Stack of dense layers with ReLU between them and a linear final layer."""
  # Output width of each dense layer, in application order.
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    """Apply the layers to ``inputs`` and return the last layer's raw output."""
    # Build every layer up front, then apply them in order.
    layers = [nn.Dense(width) for width in self.features]
    last = len(self.features) - 1
    hidden = inputs
    for idx, layer in enumerate(layers):
      hidden = layer(hidden)
      if idx != last:
        # No activation after the final layer.
        hidden = nn.relu(hidden)
    return hidden

def make_embed_fn(latent_size):
  """Return a function that projects its input through a Dense layer of width ``latent_size``."""
  def embed(inputs):
    projection = nn.Dense(latent_size)
    return projection(inputs)
  return embed


def make_mlp(features):
  """Return a jraph update function that runs an ``ExplicitMLP`` on its concatenated arguments."""
  def update_fn(inputs):
    mlp = ExplicitMLP(features)
    return mlp(inputs)
  # Same wrapping as the decorator form: jraph concatenates the arguments first.
  return jraph.concatenated_args(update_fn)

class GraphNetwork(nn.Module):
  """A flax GraphNetwork: embed nodes and edges, then run one jraph message-passing step.

  Globals are replaced with a zero column per graph before embedding; no
  global embedding or global update function is configured.
  """
  # Layer widths of the MLP used for both the node and the edge update.
  mlp_features: Sequence[int]
  # Width of the Dense embedding applied to node and edge features.
  latent_size: int
  # Reduction that pools incoming edge features per node. Optional; the default
  # is jraph's segment_sum, so instances built without it behave as before.
  aggregate_edges_for_nodes_fn: jraph.AggregateEdgesToNodesFn = jraph.segment_sum

  @nn.compact
  def __call__(self, graph):
    """Return ``graph`` after embedding and one GraphNetwork update."""
    # One zero-valued global feature per graph in the batch.
    graph = graph._replace(globals=jnp.zeros([graph.n_node.shape[0], 1]))

    embedder = jraph.GraphMapFeatures(
        embed_node_fn=make_embed_fn(self.latent_size),
        embed_edge_fn=make_embed_fn(self.latent_size),)
        # embed_global_fn=make_embed_fn(self.latent_size))

    net = jraph.GraphNetwork(
        update_node_fn=make_mlp(self.mlp_features),
        update_edge_fn=make_mlp(self.mlp_features),
        aggregate_edges_for_nodes_fn=self.aggregate_edges_for_nodes_fn,)
        # The global update outputs size 2 for binary classification.
        # update_global_fn=make_mlp(self.mlp_features + (2,)))  # pytype: disable=unsupported-operands
    return net(embedder(graph))

# Instantiate the network: two MLP layers and the embedding all share the
# width cfg.graph_mlp_features; edge messages are pooled with segment_mean.
# NOTE(review): `aggregate_edges_for_nodes_fn` must be a declared field of
# GraphNetwork for this call to be accepted — confirm against the class.
net = GraphNetwork(mlp_features=(cfg.graph_mlp_features, cfg.graph_mlp_features), 
                   latent_size=cfg.graph_mlp_features, 
                   aggregate_edges_for_nodes_fn=utils.segment_mean)

In [46]:


def ProgressIter(iter_fun, iter_len=0):
  """Yield each item of ``iter_fun`` while updating an HTML progress bar.

  Args:
    iter_fun: Iterable to walk through.
    iter_len: Total number of items; when falsy, ``len(iter_fun)`` is used.
  """
  total = iter_len if iter_len else len(iter_fun)
  handle = display(progress(0, total), display_id=True)
  completed = 0
  for item in iter_fun:
    yield item
    # Advance the bar only after the consumer has finished with the item.
    completed += 1
    handle.update(progress(completed, total))

def progress(value, max):
    """Return an ``HTML`` ``<progress>`` element showing ``value`` out of ``max``.

    Args:
        value: Current progress count.
        max: Count that corresponds to a full bar (parameter name kept for
            callers, although it shadows the builtin inside this function).
    """
    # A stray comma used to follow the max attribute, producing malformed markup.
    return HTML("""
        <progress
            value='{value}'
            max='{max}'
            style='width: 45%'
        >
            {value}
        </progress>
    """.format(value=value, max=max))

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
  
# Global seaborn styling for every figure in the notebook.
sns.set_style(style='white')
# NOTE(review): sns.set() takes its own `style` argument, so this call may
# reset the 'white' style chosen on the previous line — confirm the order.
sns.set(font_scale=1.6)

def format_plot(x, y):
  """Set the x- and y-axis labels of the current axes at font size 20."""
  label_size = 20
  plt.xlabel(x, fontsize=label_size)
  plt.ylabel(y, fontsize=label_size)
  
def finalize_plot(shape=(1, 1)):
  """Resize the current figure relative to its height and tighten the layout.

  Both new dimensions are derived from the figure's current height:
  width = shape[0] * 1.5 * height and height = shape[1] * 1.5 * height.
  """
  fig = plt.gcf()
  height = fig.get_size_inches()[1]
  fig.set_size_inches(
    shape[0] * 1.5 * height,
    shape[1] * 1.5 * height)
  plt.tight_layout()

# Short dtype aliases (`np` is jax.numpy in this notebook).
f32 = np.float32
f64 = np.float64

def draw_system(R, box_size, marker_size, color=None):
  """Plot 2-D particle positions as open circles inside a square box.

  The particles are drawn once in place and again shifted by +/- ``box_size``
  along some axes so that markers crossing the box edge show up on the other side.

  Args:
    R: Positions; columns 0 and 1 are used as x and y.
    box_size: Side length of the box; also sets the axis limits.
    marker_size: Marker size before division by ``box_size``.
    color: Matplotlib colour; defaults to a dark grey.
  """
  # Identity check: `== None` compares elementwise when an array colour is passed.
  if color is None:
    color = [64 / 256] * 3
  ms = marker_size / box_size

  R = onp.array(R)

  marker_style = dict(
      linestyle='none',
      markeredgewidth=3,
      marker='o',
      markersize=ms,
      color=color,
      fillstyle='none')

  # Shifts in units of box_size, in the order the copies are drawn.
  offsets = ((0, 0), (1, 0), (0, 1), (1, 1), (-1, 0), (0, -1), (-1, -1))
  for dx, dy in offsets:
    plt.plot(R[:, 0] + dx * box_size, R[:, 1] + dy * box_size, **marker_style)

  plt.xlim([0, box_size])
  plt.ylim([0, box_size])
  plt.axis('off')

def square_lattice(N, box_size):
  """Return ``N`` points on a regular square grid spanning ``box_size``.

  Args:
    N: Number of particles; must be a perfect square.
    box_size: Side length of the box; grid spacing is ``box_size / sqrt(N)``.

  Returns:
    Array of shape (N, 2) with one (x, y) row per particle.

  Raises:
    ValueError: If ``N`` is not a perfect square.
  """
  Nx = int(np.sqrt(N))
  Ny, ragged = divmod(N, Nx)
  if Ny != Nx or ragged:
    # Raise, rather than assert on the exception object, which is always truthy.
    raise ValueError('Particle count should be a square. Found {}.'.format(N))
  length_scale = box_size / Nx
  R = [[i * length_scale, j * length_scale]
       for i in range(Nx)
       for j in range(Ny)]
  return np.array(R)

# Particle count; square_lattice expects a perfect square.
N = 256
box_size = quantity.box_size_at_number_density(particle_count=N, 
                                               number_density=1, 
                                               spatial_dimension=2)

r = square_lattice(N, box_size)
# Move the first particle's x coordinate to 12.1.
r = r.at[0, 0].set(12.1)
# draw_system(r, box_size, 270.0)
# finalize_plot((0.75, 0.75))

print(r.shape, box_size, jnp.max(r), jnp.min(r))

displacement_fn, shift_fn = space.periodic(box_size)

energy_fn = energy.soft_sphere_pair(displacement_fn, per_particle=True)
# Use a separate name for the result: assigning it to `energy` would replace
# the imported jax_md `energy` module and break a re-run of this cell.
per_particle_energy = energy_fn(r)
print(per_particle_energy)
# print('Energy of the system, U = {:f}'.format(jnp.sum(per_particle_energy)))

(256, 2) 16.0 15.0 0.0
[0.205  0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0. 

In [27]:
def batch_graph(node_features, edge_features, neighbors=None):
    """Build one ``jraph.GraphsTuple`` per sample and merge them with ``jraph.batch``.

    Every graph shares the same connectivity, taken from ``neighbors``.

    Args:
        node_features: Iterable of per-sample node feature arrays.
        edge_features: Iterable of per-sample edge feature arrays, paired
            with ``node_features`` by position.
        neighbors: Either a (2, n_edges) index array or an object exposing
            one as ``.idx`` (e.g. the result of ``neighbor_fn.allocate``).
            Row 0 is used as receivers and row 1 as senders. Defaults to
            the module-level ``nbrs``.

    Returns:
        The batched ``jraph.GraphsTuple``.
    """
    # Built in a Python loop rather than vmapped: node features get stacked
    # by jraph.batch, while senders/receivers are offset per graph so they
    # only refer to nodes of their own graph.
    # Possible optimisation, see
    # https://github.com/google/flax/blob/main/examples/ogbg_molpcba/input_pipeline.py

    if neighbors is None:
        neighbors = nbrs  # module-level neighbour list
    # A neighbour-list object is not indexable itself; use its index array.
    idx = getattr(neighbors, 'idx', neighbors)
    receivers = idx[0, :]
    senders = idx[1, :]

    graphs = [
        jraph.GraphsTuple(nodes=nf,
                          receivers=receivers,
                          senders=senders,
                          edges=ef,
                          n_node=jnp.array([len(nf)]),
                          n_edge=jnp.array([len(ef)]),
                          globals=jnp.array([0.0]))  # could be for temperature or something global
        for nf, ef in zip(node_features, edge_features)
    ]
    return jraph.batch(graphs)

# Given a batch of positions

# Compute the neighbor list for each batch
# NOTE(review): only positions[0] is passed to allocate — open question
# whether one neighbour list built from the first entry is valid for the rest.
neighbor = neighbor_fn.allocate(positions[0], extra_capacity=6)  # allocated from the first entry only

# TODO: vmap the train function



In [35]:
# Compare two ways of representing a pair of identical 3-node graphs:
# stacking along a second axis by hand (graph1) versus jraph.batch (graph).

# Hand-stacked variant: every array is repeated along axis 1.
node_features = jnp.repeat(jnp.array([[0.], [1.], [2.]]), 2, axis=1)

senders = jnp.repeat(jnp.array([0, 1, 2])[:, None], 2, axis=1)
receivers = jnp.repeat(jnp.array([1, 2, 0])[:, None], 2, axis=1)

# You can optionally add edge attributes.
edges = jnp.repeat(jnp.array([[5.], [6.], [7.]]), 2, axis=1)

n_node = jnp.array([[3], [3]])
n_edge = jnp.array([[3], [3]])

global_context = jnp.array([[1], [1]]) # One global row per graph.
graph1 = jraph.GraphsTuple(nodes=node_features, senders=senders, receivers=receivers,
edges=edges, n_node=n_node, n_edge=n_edge, globals=global_context)

# Define a three node graph; each node has a single float as its feature.
node_features = jnp.array([[0.], [1.], [2.]])
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 0])
edges = jnp.array([[5.], [6.], [7.]])
n_node = jnp.array([3])
n_edge = jnp.array([3])
global_context = jnp.array([[1]]) # A single global row for the single graph.
graph = jraph.GraphsTuple(nodes=node_features, senders=senders, receivers=receivers,
edges=edges, n_node=n_node, n_edge=n_edge, globals=global_context)

# jraph.batch concatenates the nodes along axis 0 and offsets the sender
# indices of the second graph, as the printed output shows.
graph = jraph.batch([graph, graph])
print(graph1.nodes, graph.nodes)
print(graph1.senders, graph.senders)

[[0. 0.]
 [1. 1.]
 [2. 2.]] [[0.]
 [1.]
 [2.]
 [0.]
 [1.]
 [2.]]
[[0 0]
 [1 1]
 [2 2]] [0 1 2 3 4 5]


In [29]:
# testing the displacement function
# custom displacement functions don't work, because they can't accept the t=0 parameter

# Imported here so the cell runs on a fresh kernel; neither name is imported
# by the notebook's import cell.
from jax import eval_shape
from jax.core import ShapedArray

vmap_displacement_fn = vmap(lambda r0, r1: r0 - r1, in_axes=(0, 0), out_axes=0)
displacement_fn = lambda r0, r1: r0 - r1
R = ShapedArray((1,), f32)
# The failure is the point of this cell: show it without aborting a full run.
try:
    eval_shape(displacement_fn, R, R, t=0)
except TypeError as err:
    print(f'TypeError: {err}')


TypeError: <lambda>() got an unexpected keyword argument 't'