In [1]:

%load_ext autoreload
%autoreload 2

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


In [5]:
from jax import random as rnd, numpy as jnp 
from jax_md.partition import to_jraph, neighbor_list, NeighborListFormat
from jax_md import space
from jax import vmap, jit

key = rnd.PRNGKey(2)
n_nodes, n_dim = 12, 3
r_cutoff = 0.3

displacement_fn, shift_fn = space.free()
neighbor_fn = neighbor_list(displacement_fn,
                            capacity_multiplier=1.,
                            box_size=2., 
                            r_cutoff=r_cutoff,
                            dr_threshold=0.1,  # when the neighbor list updates
                            format=NeighborListFormat.Sparse,
                            mask_self=True)

x0 = rnd.normal(key, (1, n_nodes, n_dim)) * 0.01
nbrs0 = neighbor_fn.allocate(x0[0])  # extra capacity per node
# nbrsj0 = to_jraph(nbrs0)
# print(nbrsj0.senders)


def update_nbr(pos):
    nbr = nbrs0.update(pos)
    receivers = nbr.idx[0]
    return receivers
    # return nbr don't return in MWE

update_nbr = vmap(update_nbr, in_axes=(0,), out_axes=0)

x1 = rnd.normal(key, (10, n_nodes, n_dim)) * 0.1  # to ensure overflow 
x2 = rnd.normal(key, (10, n_nodes, n_dim)) * 100  # to ensure overflow 
x1 = jnp.concatenate([x1, x2], axis=0)
receivers = update_nbr(x1)
print(nbrs0.did_buffer_overflow)  # Answer is yes, but not clear for which sample, so not clear how to reinitialise

# print(receivers)
# so aim to reinit with quadratically increasing extra capacity 
# nbrs0 = neighbor_fn.allocate(x0, extra_capacity=2**n_times_reallocated)

print(x1[0, 0, :].shape)




True
(3,)


In [8]:
receivers[-1]

DeviceArray([12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
             12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
             12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
             12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
             12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
             12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
             12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
             12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
             12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12], dtype=int32)

In [37]:
from jax import random as rnd, numpy as jnp 
from jax_md.partition import to_jraph, neighbor_list, NeighborListFormat
from jax_md import space
from jax import vmap

def compute_rs(pos):
    return jnp.linalg.norm(pos[None, ...] - pos[:, None, :], axis=-1)

def where_is_less_than_rcut(pos, r_cutoff):
    rs = compute_rs(pos)
    # print(rs)
    print(rs < r_cutoff)

key = rnd.PRNGKey(1)

n_nodes, n_dim = 5, 3
r_cutoff = 0.1

displacement_fn, shift_fn = space.free()
neighbor_fn = neighbor_list(displacement_fn,
                            # capacity_multiplier=4.,
                            box_size=1., 
                            r_cutoff=r_cutoff,
                            dr_threshold=0.,  # when the neighbor list updates
                            format=NeighborListFormat.Sparse,
                            mask_self=True)

x0 = rnd.normal(key, (n_nodes, n_dim)) * 0.5
where_is_less_than_rcut(x0, r_cutoff)
nbrs0 = neighbor_fn.allocate(x0)  # extra capacity per node
nbrsj0 = to_jraph(nbrs0)
print(nbrsj0.senders)

x1 = rnd.normal(key, (n_nodes, n_dim)) * 0.1
where_is_less_than_rcut(x1, r_cutoff)
nbrs1 = nbrs0.update(x1)
nbrsj1 = to_jraph(nbrs1)
print(nbrsj1.senders)

print(nbrs1.did_buffer_overflow)


def update_nbr(pos, nbr):
    nbr = nbr.update(pos)
    return nbr

update_nbr = vmap(update_nbr, in_axes=(0, None))

x1 = rnd.normal(key, (10, n_nodes, n_dim)) * 0.01
nbrs = update_nbr(x1, nbrs0)
print(nbrs)



[[ True False False False False]
 [False  True False False False]
 [False False  True False False]
 [False False False  True False]
 [False False False False  True]]
[]
[[ True False False False False]
 [False  True False False False]
 [False False  True  True False]
 [False False  True  True False]
 [False False False False  True]]
[]
True
NeighborList(idx=DeviceArray([], dtype=int32), reference_position=DeviceArray([[[-4.81269648e-03,  1.06660801e-03,  3.29988683e-03],
              [ 6.02418417e-03,  1.43634696e-02,  7.83192087e-03],
              [-9.35661234e-03,  5.32855606e-03,  9.40408185e-03],
              [-7.38178659e-03,  3.81043553e-03,  1.28416019e-02],
              [-1.45765574e-04,  5.24004828e-03, -5.47317555e-03]],

             [[ 5.14350773e-04, -1.48846358e-02, -9.39764082e-03],
              [ 1.37704425e-03,  4.77804244e-03,  4.95645870e-03],
              [-8.09216499e-03,  1.55307690e-03, -7.01821191e-05],
              [-1.03738727e-02,  3.95740243e-03, -8.8

In [2]:
import jax
from jax import numpy as jnp, random as rnd, jit, vmap
from pd_code.config import Config
from jax_md.partition import neighbor_list, NeighborListFormat, to_jraph
from jax_md import space

import numpy as onp
from jraph._src import utils
from typing import Sequence
import jraph
from flax import linen as nn

from jax_md import energy

from jax_md import quantity

import time 

from functools import partial
import numpy as onp

import jax.numpy as np

from jax import jit
from jax import grad
from jax import vmap
from jax import value_and_grad

from jax import random
from jax import lax

from jax.experimental import stax
from jax.experimental import optimizers

from jax.config import config

# NOTE(schsam): We need this in OSS I think.
from IPython.display import HTML, display
import time



In [26]:
cfg = Config()
key = rnd.PRNGKey(cfg.seed)

positions = rnd.uniform(key, (5, cfg.n_dim))

displacement_fn, shift_fn = space.periodic(cfg.side, wrapped=True)
vmap_displacement_fn = vmap(lambda r0, r1: r0 - r1, in_axes=(0, 0), out_axes=0)
displacements = vmap_displacement_fn(positions, positions)

# Dense: (N, n_max_neighbors_per_atom), Sparse (2, n_max_neighbors): Ordered Sparse: Sparse but half (no mirrors)
neighbor_fn = neighbor_list(displacement_fn, box_size=1., r_cutoff=0.3, dr_threshold=0.01, format=NeighborListFormat.Sparse)
nbrs = neighbor_fn.allocate(positions)

# jraph_tuple = to_jraph(nbrs)
# print(nbrs)
# print(jraph_tuple)
from jax_md import partition
dr_threshold = 0.02
r_cutoff = 0.5
d = displacement_fn
d = space.map_bond(d)
print(nbrs.idx[0])

dR = d(positions[nbrs.idx[0]], positions[nbrs.idx[1]])
print(len(dR), len(positions))
if dr_threshold > 0.0:
    dr_2 = space.square_distance(dR)
    mask = dr_2 < r_cutoff ** 2 + 1e-5
    graph = partition.to_jraph(nbrs, mask)
    print(graph.receivers)
    # TODO(schsam): It seems wasteful to recompute dR after we remask the
    # edges. If I can think of a clean way to get rid of this, I should.
    dR = d(positions[graph.receivers], positions[graph.senders])
    print(len(dR))
else:
    graph = partition.to_jraph(nbrs)

# graph = graph._replace(
# nodes=jnp.concatenate((_nodes,
#                         jnp.zeros((1,) + _nodes.shape[1:], R.dtype)),
#                         axis=0),
# edges=dR,
# globals=jnp.broadcast_to(_globals[:, None], (2, 1))
# )

print(dR[:10])
print(dr_2[:5])
print(nbrs.idx[0])
print(nbrs.idx[1])

[4 4 1 2 5]
5 5
[4 4 1 2 5]
5
[[ 0.23274148  0.00872886]
 [-0.06568599  0.2511469 ]
 [-0.23274148 -0.00872886]
 [ 0.06568599 -0.2511469 ]
 [ 0.          0.        ]]
[0.05424479 0.06738942 0.05424479 0.06738942 0.        ]
[4 4 1 2 5]
[1 2 4 4 5]


In [66]:
# Attempt to vmap the nbr process

n_node, n_dim = 3, 2

node = rnd.uniform(key, (5, 10, n_node, n_dim))
node = rnd.uniform(key, (1, 1, n_node, n_dim)).repeat(5, 0).repeat(10, 0)
nbrs = neighbor_fn.allocate(node[0, 0], extra_capacity=0)
 
def compute_edges(positions, receivers, senders):
    return jnp.linalg.norm(positions[receivers] - positions[senders], axis=-1, keepdims=True)

def get_graphs(node):
    nbr = nbrs.update(node)
    receivers = nbr.idx[cfg.receivers_idx]
    senders = nbr.idx[cfg.senders_idx]
    positions = node[:, :3]
    edges = compute_edges(positions, receivers, senders)
    return edges, senders, receivers

@jit
def batch_graphs(nodes, edges, senders, receivers):
    graphs = []
    for n, e, s, r in zip(nodes, edges, senders, receivers):
        graph = jraph.GraphsTuple(nodes=n,
                              edges=e,
                              n_node=np.array([n.shape[0]]),
                              n_edge=np.array([e.shape[0]]),
                              senders=s,
                              receivers=r,
                              globals={})
        graphs.append(graph)
    return jraph.batch(graphs) 

node = node.reshape(-1, n_node, n_dim)
get_graphs = jit(vmap(get_graphs, in_axes=(0,), out_axes=(0, 0, 0)))


In [70]:
edge_info = get_graphs(node)
graph = batch_graphs(node, *edge_info)
print(graph.senders)

[  1   2   4   5   7   8  10  11  13  14  16  17  19  20  22  23  25  26
  28  29  31  32  34  35  37  38  40  41  43  44  46  47  49  50  52  53
  55  56  58  59  61  62  64  65  67  68  70  71  73  74  76  77  79  80
  82  83  85  86  88  89  91  92  94  95  97  98 100 101 103 104 106 107
 109 110 112 113 115 116 118 119 121 122 124 125 127 128 130 131 133 134
 136 137 139 140 142 143 145 146 148 149]


In [72]:
def untransform(data, old_min, old_max, new_min=-1, new_max=1, mean=None):
    data = ((data - old_min) / (old_max - old_min)) * (new_max - new_min) + new_min
    if mean is not None:
        data = data + mean
    return data

def transform(data, old_min, old_max, new_min=-1, new_max=1, mean=None):
    if mean is not None:
        data = data - mean
    data = ((data - old_min) / (old_max - old_min)) * (new_max - new_min) + new_min
    return data

def cut_remainder(data, n_batch):
    n_batch_time, remainder = divmod(data.shape[0], n_batch)
    data = data[:-remainder] if remainder > 0 else data
    return data

def split_into_timesteps(data, n_timesteps):
    data = cut_remainder(data, n_timesteps)
    n_trajectories = data.shape[0]//n_timesteps
    data = data.reshape(n_trajectories, n_timesteps, *data.shape[1:])
    return data

def get_stats(self, data, axis=0):
    if len(data.shape) == 3:
        data = data.reshape(-1, 3)
    return jnp.min(data, axis=axis), jnp.max(data, axis=axis), jnp.mean(data, axis=axis)

def compute_edges(positions, receivers, senders):
        return jnp.linalg.norm(positions[receivers] - positions[senders], axis=-1, keepdims=True)

def batch_graphs(nodes, edges, senders, receivers):
        graphs = []
        for n, e, s, r in zip(nodes, edges, senders, receivers):
            graph = jraph.GraphsTuple(nodes=n,
                                edges=e,
                                n_node=np.array([n.shape[0]]),
                                n_edge=np.array([e.shape[0]]),
                                senders=s,
                                receivers=r,
                                globals={})
            graphs.append(graph)
        return jraph.batch(graphs)

def prep_neval_eq_ntr(cfg, split, data):
    data = split_into_timesteps(data, cfg.n_timesteps)
    n_trajectories = len(data)

    n_train, n_val, n_test = (
        int(n_trajectories * split[0]),
        int(n_trajectories * split[1]),
        int(n_trajectories * split[2])
    )

    key = rnd.PRNGKey(cfg.seed)
    idxs = rnd.permutation(key, jnp.arange(0, n_trajectories))
    tr_idxs, val_idxs, test_idxs = idxs[:n_train], idxs[n_train:(n_val+n_train)], idxs[-n_test:]

    val_idxs = jnp.delete(val_idxs, jnp.where(val_idxs==0)[0])  # remove if first one! 
    test_idxs = jnp.delete(test_idxs, jnp.where(test_idxs==0)[0])  # remove if first one! 
    initial_states_val_data = data[(val_idxs - 1), -cfg.n_eval_warmup:, ...]
    initial_states_test_data = data[(test_idxs - 1), -cfg.n_eval_warmup:, ...]

    tr_data = cut_remainder(data[tr_idxs], cfg.batch_size)
    val_data = cut_remainder(jnp.concatenate([initial_states_val_data, data[val_idxs]], axis=1), cfg.batch_size)
    test_data = cut_remainder(jnp.concatenate([initial_states_test_data, data[test_idxs]], axis=1), cfg.batch_size)

    print(f'Datasets length: Train {len(tr_data)} Val {len(val_data)} Test {len(test_data)}')
    print(f'Some idxs:  \n Train {tr_idxs[:5]} \n Val {val_idxs[:5]}')

    return tr_data, val_data, test_data

def compute_edges(positions, receivers, senders):
    return jnp.norm(positions[receivers] - positions[senders], axis=-1, keepdims=True)

In [8]:




class DataLoader():
    def __init__(self, 
                 cfg: Config,
                 nodes: jnp.array,
                 target: jnp.array,
                 shuffle: bool=True, 
                 eval: bool=False):
        self.seed = cfg.seed

        self.nodes = nodes
        self.edges = edges
        self.n_data, self.n_timesteps, self.n_nodes, self.n_node_features = nodes.shape
        self.n_edge_features = edges.shape[-1]
        self.batch_size = cfg.batch_size
        
        self.n_eval_timesteps = cfg.n_eval_timesteps
        self.eval = eval
        
        self.target = nodes[..., :3]

        if cfg.periodic:
            displacement_fn, shift_fn = space.periodic(cfg.side, wrapped=True)
            neighbor_fn = neighbor_list(displacement_fn, 
                                        box_size=cfg.box_size, 
                                        r_cutoff=cfg.r_cutoff,
                                        dr_threshold=cfg.dr_threshold,  # when the neighbor list updates
                                        format=NeighborListFormat.Sparse)
        else:
            displacement_fn, shift_fn = space.free()
            neighbor_fn = neighbor_list(displacement_fn, 
                                        box_size=cfg.box_size, 
                                        r_cutoff=999.,
                                        dr_threshold=999.,  # when the neighbor list updates
                                        format=NeighborListFormat.Sparse)
        
        nbrs = neighbor_fn.allocate(nodes[0, 0], extra_capacity=0)  # extra capacity may reduce jit overhead

        def get_edge_info(node):
            nbr = nbrs.update(node)
            receivers = nbr.idx[cfg.receivers_idx]
            senders = nbr.idx[cfg.senders_idx]
            positions = node[:, :3]
            edges = compute_edges(positions, receivers, senders)
            return edges, senders, receivers

        _get_edge_info = vmap(get_edge_info, in_axes=(0,), out_axes=(0, 0, 0))

        @jit
        def create_graphs(nodes):
            nodes = nodes.reshape(-1, *nodes.shape[2:])
            edge_info = _get_edge_info(nodes)
            graphs = batch_graphs(nodes, *edge_info)
            return graphs

        self._create_graphs = create_graphs

        self._shuffle = shuffle
        self._order = jnp.arange(0, self.n_data)
        self.key = rnd.PRNGKey(self.seed)
        self.key, subkey = rnd.split(self.key)
        self._order = rnd.permutation(subkey, self._order)

    def shuffle(self, key=None, returns_new_key=False, reset=True):
        
        self.key, subkey = rnd.split(self.key)
        self._order = rnd.permutation(subkey, self._order)

        if reset:
            print('Resetting dataloader... ')
            self._exhausted_batches = 0

        if returns_new_key:
            return self.key

    def __next__(self):
        i = self._exhausted_batches
        start = i * self.batch_size
        stop  = (i+1) * self.batch_size
        idxs = (self._order[start:stop],)

        if i < len(self):
            self._exhausted_batches += 1
            if not self.eval:
                nodes = node[idxs]
                target = nodes[..., :3]
                graphs = self._create_graphs(nodes)
                return graphs, target
            else:
                nodes_warmup, nodes_eval = jnp.split(node[idxs], [nodes.shape[1]-self.n_eval_timesteps,], axis=1)
                target_warmup, target_eval = nodes_warmup[..., :3], nodes_eval[..., :3]
                
                graphs_warmup = self._create_graphs(nodes_warmup)
                graphs_eval = self._create_graphs(nodes_eval)
                return ((graphs_warmup, target_warmup), (graphs_eval, target_eval))
            
        else:
            raise StopIteration


def create_dataloaders(cfg):
    raw_data = np.load(cfg.data_path)
    keys = raw_data.keys()

    for feature in cfg.node_features:
        assert feature in keys(), f'{feature} not in dataset'
    
    # Get the data statisitics
    for name in cfg.node_features:
        tmp_min, tmp_max, tmp_mean = get_stats(raw_data[name])
        setattr(cfg, f'{name}_min', tmp_min)
        setattr(cfg, f'{name}_max', tmp_max)
        setattr(cfg, f'{name}_mean', tmp_mean)
        setattr(cfg, f'{name}', raw_data[name])
    
    setattr(cfg, 'R_lims', tuple((cfg.R_min[i], cfg.R_max[i]) for i in range(3)))
    setattr(cfg, 'n_nodes', raw_data['R'].shape[1])

    print(f' Pos-Lims: {tuple((float(cfg.R_min[i]), float(cfg.R_max[i])) for i in range(3))} \
            \n F-Lims: {tuple((float(cfg.F_min[i]), float(cfg.F_max[i])) for i in range(3))} \
            \n A-Lims: {int(cfg.z_min)} {int(cfg.z_max)}')

    n_data, n_nodes, _ = raw_data['R'].shape

    # Transform the data
    box_size = raw_data.get('box_size', False)
    if box_size:
        positions = transform(raw_data['R'], 0, box_size, mean=box_size/2.)
    else:
        positions = transform(raw_data['R'], cfg.R_min, cfg.R_max, mean=cfg.R_mean)
    
    node_id = raw_data.get('z', False)
    if node_id:
        raw_data['z'] = jax.nn.one_hot((node_id-1), int(max(node_id)), dtype=jnp.float32)
        raw_data['z'] = raw_data['z'][None, :, None].repeat(n_data, axis=0)
    
    if 'F' in keys:
        raw_data['F'] = transform(raw_data['F'], -1, 1, mean=cfg.F_mean)

    # Set the node features
    nodes = jnp.concatenate([positions, *[raw_data[feature] for feature in cfg.node_features if not feature == 'R']], axis=-1)

    setattr(cfg, 'n_node_features', nodes.shape[-1] * n_nodes)
    setattr(cfg, 'n_target_features', 3 * n_nodes)
    

    # Get the val/train/test split and put into timeslices
    # Tr and val take normal time
    split = (0.7, 0.15, 0.15)
    tr, val, test = prep_neval_eq_ntr(cfg, split, nodes)

    train_loader = DataLoader(cfg, *tr)
    val_loader = DataLoader(cfg, *val, eval=True)
    test_loader = DataLoader(cfg, *test, eval=True)

    return train_loader, val_loader, test_loader




IndentationError: expected an indented block after function definition on line 164 (1815998741.py, line 165)

In [None]:
cfg = Config()

class ExplicitMLP(nn.Module):
  """A flax MLP."""
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, lyr in enumerate([nn.Dense(feat) for feat in self.features]):
      x = lyr(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
    return x

def make_embed_fn(latent_size):
  def embed(inputs):
    return nn.Dense(latent_size)(inputs)
  return embed


def make_mlp(features):
  @jraph.concatenated_args
  def update_fn(inputs):
    return ExplicitMLP(features)(inputs)
  return update_fn

class GraphNetwork(nn.Module):
  """A flax GraphNetwork."""
  mlp_features: Sequence[int]
  latent_size: int

  @nn.compact
  def __call__(self, graph):
    # Add a global parameter for graph classification.
    graph = graph._replace(globals=jnp.zeros([graph.n_node.shape[0], 1]))

    embedder = jraph.GraphMapFeatures(
        embed_node_fn=make_embed_fn(self.latent_size),
        embed_edge_fn=make_embed_fn(self.latent_size),)
        # embed_global_fn=make_embed_fn(self.latent_size))
    
    net = jraph.GraphNetwork(
        update_node_fn=make_mlp(self.mlp_features),
        update_edge_fn=make_mlp(self.mlp_features),)
        # The global update outputs size 2 for binary classification.
        # update_global_fn=make_mlp(self.mlp_features + (2,)))  # pytype: disable=unsupported-operands
    return net(embedder(graph))

net = GraphNetwork(mlp_features=(cfg.graph_mlp_features, cfg.graph_mlp_features), 
                   latent_size=cfg.graph_mlp_features, 
                   aggregate_edges_for_nodes_fn=utils.segment_mean)

In [46]:


def ProgressIter(iter_fun, iter_len=0):
  if not iter_len:
    iter_len = len(iter_fun)
  out = display(progress(0, iter_len), display_id=True)
  for i, it in enumerate(iter_fun):
    yield it
    out.update(progress(i + 1, iter_len))

def progress(value, max):
    return HTML("""
        <progress
            value='{value}'
            max='{max}',
            style='width: 45%'
        >
            {value}
        </progress>
    """.format(value=value, max=max))

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
  
sns.set_style(style='white')
sns.set(font_scale=1.6)

def format_plot(x, y):  
  plt.xlabel(x, fontsize=20)
  plt.ylabel(y, fontsize=20)
  
def finalize_plot(shape=(1, 1)):
  plt.gcf().set_size_inches(
    shape[0] * 1.5 * plt.gcf().get_size_inches()[1], 
    shape[1] * 1.5 * plt.gcf().get_size_inches()[1])
  plt.tight_layout()

f32 = np.float32
f64 = np.float64

def draw_system(R, box_size, marker_size, color=None):
  if color == None:
    color = [64 / 256] * 3
  ms = marker_size / box_size

  R = onp.array(R)

  marker_style = dict(
      linestyle='none', 
      markeredgewidth=3,
      marker='o', 
      markersize=ms, 
      color=color, 
      fillstyle='none')

  plt.plot(R[:, 0], R[:, 1], **marker_style)
  plt.plot(R[:, 0] + box_size, R[:, 1], **marker_style)
  plt.plot(R[:, 0], R[:, 1] + box_size, **marker_style)
  plt.plot(R[:, 0] + box_size, R[:, 1] + box_size, **marker_style)
  plt.plot(R[:, 0] - box_size, R[:, 1], **marker_style)
  plt.plot(R[:, 0], R[:, 1] - box_size, **marker_style)
  plt.plot(R[:, 0] - box_size, R[:, 1] - box_size, **marker_style)

  plt.xlim([0, box_size])
  plt.ylim([0, box_size])
  plt.axis('off')

def square_lattice(N, box_size):
  Nx = int(np.sqrt(N))
  Ny, ragged = divmod(N, Nx)
  if Ny != Nx or ragged:
    assert ValueError('Particle count should be a square. Found {}.'.format(N))
  length_scale = box_size / Nx
  R = []
  for i in range(Nx):
    for j in range(Ny):
      R.append([i * length_scale, j * length_scale])
  return np.array(R)

N = 256
box_size = quantity.box_size_at_number_density(particle_count=N, 
                                               number_density=1, 
                                               spatial_dimension=2)

r = square_lattice(N, box_size)
r = r.at[0, 0].set(12.1)
# draw_system(r, box_size, 270.0)
# finalize_plot((0.75, 0.75))

print(r.shape, box_size, jnp.max(r), jnp.min(r))

displacement_fn, shift_fn = space.periodic(box_size)

energy_fn = energy.soft_sphere_pair(displacement_fn, per_particle=True)
energy = energy_fn(r)
print(energy)
# print('Energy of the system, U = {:f}'.format(energy))

(256, 2) 16.0 15.0 0.0
[0.205  0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0. 

In [27]:
def batch_graph(node_features, edge_features):
    # can't be vmapped because the different length lists?

    # ?? do we need this ?? Yes! The node features are stacked (they reused the weights), 
    # but the receivers and senders are cumulative (only referring to each other)

    # Might need optimisation, look here https://github.com/google/flax/blob/main/examples/ogbg_molpcba/input_pipeline.py
    # and use the convert to graphs tuple function and map. Jit won't work because it suuuucks

    graphs = []
    for nf, ef in zip(node_features, edge_features):
        graph = jraph.GraphsTuple(nodes=nf, 
                                  receivers=nbrs[0, :], 
                                  senders=nbrs[1, :],
                                  edges=ef, 
                                  n_node=jnp.array([len(nf)]), 
                                  n_edge=jnp.array([len(ef)]), 
                                  globals=jnp.array([0.0]))  # could be for temperature or something global
        
        graphs.append(graph)
    batch = jraph.batch(graphs)
    return batch

# Given a batch of positions

# Compute the neighbor list for each batch
neighbor = neighbor_fn.allocate(positions[0], extra_capacity=6)  # the neighbor list only refers to the first ones?? 

# vmap the train function!



In [35]:
# Testing to see how the batching system works, does it make sense? 

node_features = jnp.repeat(jnp.array([[0.], [1.], [2.]]), 2, axis=1)

senders = jnp.repeat(jnp.array([0, 1, 2])[:, None], 2, axis=1)
receivers = jnp.repeat(jnp.array([1, 2, 0])[:, None], 2, axis=1)

# You can optionally add edge attributes.
edges = jnp.repeat(jnp.array([[5.], [6.], [7.]]), 2, axis=1)

n_node = jnp.array([[3], [3]])
n_edge = jnp.array([[3], [3]])

global_context = jnp.array([[1], [1]]) # Same feature dimensions as nodes and edges.
graph1 = jraph.GraphsTuple(nodes=node_features, senders=senders, receivers=receivers,
edges=edges, n_node=n_node, n_edge=n_edge, globals=global_context)

# Define a three node graph, each node has an integer as its feature.
node_features = jnp.array([[0.], [1.], [2.]])
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 0])
edges = jnp.array([[5.], [6.], [7.]])
n_node = jnp.array([3])
n_edge = jnp.array([3])
global_context = jnp.array([[1]]) # Same feature dimensions as nodes and edges.
graph = jraph.GraphsTuple(nodes=node_features, senders=senders, receivers=receivers,
edges=edges, n_node=n_node, n_edge=n_edge, globals=global_context)

graph = jraph.batch([graph, graph])
print(graph1.nodes, graph.nodes)
print(graph1.senders, graph.senders)

[[0. 0.]
 [1. 1.]
 [2. 2.]] [[0.]
 [1.]
 [2.]
 [0.]
 [1.]
 [2.]]
[[0 0]
 [1 1]
 [2 2]] [0 1 2 3 4 5]


In [29]:
# testing the displacement function
# custom displacement functions don't work, because they can't accept the t=0 parameter

vmap_displacement_fn = vmap(lambda r0, r1: r0 - r1, in_axes=(0, 0), out_axes=0)
displacement_fn = lambda r0, r1: r0 - r1
R = ShapedArray((1,), f32)
eval_shape(displacement_fn, R, R, t=0)


TypeError: <lambda>() got an unexpected keyword argument 't'