In [97]:
import sys
sys.path.append("../")

import jax
import jax.numpy as np
from functools import partial
import jraph
from jax_md import space, partition

from models.graph_utils import nearest_neighbors, nearest_neighbors_ann, RadiusSearch

In [98]:
# Fixed PRNG key so the synthetic positions below are reproducible.
rng = jax.random.PRNGKey(42)

n_nodes = 5000  # points per sample
k = 10  # neighbours requested from the kNN helpers
n_batch = 4  # number of samples in the batch
n_pos = 3  # coordinates per point

# x = np.load("/n/holyscratch01/iaifi_lab/ccuesta/data_for_sid/halos.npy")[:n_batch, :n_nodes,:n_pos]
# Synthetic stand-in for the file above: uniform draws scaled by 1000,
# shaped (n_batch, n_nodes, n_pos).
x = jax.random.uniform(rng, (n_batch, n_nodes, n_pos)) * 1000.

## Brute force kNN

In [99]:
sources, targets = nearest_neighbors(x[0], k=k)
sources_ann, targets_ann = nearest_neighbors_ann(x[0], k=k)

In [100]:
np.array_equal(sources_ann, sources), np.array_equal(targets_ann, targets)

(Array(True, dtype=bool), Array(False, dtype=bool))

In [101]:
(targets_ann == targets).sum()

Array(34500, dtype=int32)

In [102]:
%%timeit
sources_ann, targets_ann = nearest_neighbors_ann(x[0], k=k)

4.9 ms ± 2.16 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [103]:
%%timeit
sources, targets = nearest_neighbors(x[0], k=k)

7.04 ms ± 1.03 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Using jax-md

Scale up without having to recompute the full pairwise distance matrix. Let's see if it can scale to 50,000 nodes.

In [104]:
ns = RadiusSearch(1000., 30.)
nbr = ns.init_neighbor_lst(x[0])

In [105]:
# %%timeit
# nbr_update, _ = ns.update_neighbor_lst(x, nbr)

In [106]:
nbr_update, nbr = ns.update_neighbor_lst(x, nbr)

In [107]:
nbr_update.idx

Array([[[   0,    1,  954, ..., 5000, 5000, 5000],
        [   0,    1,    1, ..., 5000, 5000, 5000]],

       [[   0,    1, 1924, ..., 5000, 5000, 5000],
        [   0,    1,    1, ..., 5000, 5000, 5000]],

       [[2935,    0, 3396, ..., 5000, 5000, 5000],
        [   0,    0,    0, ..., 5000, 5000, 5000]],

       [[   0, 2944,    1, ..., 5000, 5000, 5000],
        [   0,    0,    1, ..., 5000, 5000, 5000]]], dtype=int32)

In [108]:
# Shape [batch, 2, n_edges].
# NOTE(review): in the printed idx the second row is the sorted one, which
# suggests the pair order is (receivers, senders) rather than
# (senders, receivers) — confirm against partition.to_jraph.
nbr_update.idx.shape

(4, 2, 13965)

In [109]:
globals = jax.random.normal(rng, (n_batch, 1, 4))

## Simple graph

In [110]:
def update_edge_fn(edge_features, sender_node_features, receiver_node_features, globals_):
    """Identity edge update: hand back ``edge_features`` untouched.

    The sender/receiver node features and the globals are accepted but
    deliberately discarded.
    """
    del sender_node_features, receiver_node_features, globals_  # intentionally unused
    return edge_features

def update_node_fn(node_features, aggregated_sender_edge_features, aggregated_receiver_edge_features, globals_):
    """Identity node update: hand back ``node_features`` untouched.

    The aggregated edge features (both directions) and the globals are
    accepted but deliberately discarded.
    """
    del aggregated_sender_edge_features, aggregated_receiver_edge_features, globals_  # intentionally unused
    return node_features

def update_global_fn(aggregated_node_features, aggregated_edge_features, globals_):
    """Identity global update: hand back ``globals_`` untouched.

    The aggregated node and edge features are accepted but deliberately
    discarded.
    """
    del aggregated_node_features, aggregated_edge_features  # intentionally unused
    return globals_

# Graph network assembled from the three pass-through update functions:
# each one returns its own feature argument unchanged.
net = jraph.GraphNetwork(
    update_edge_fn=update_edge_fn,
    update_node_fn=update_node_fn,
    update_global_fn=update_global_fn,
)


## Use in graph 

In [111]:
graph = jax.vmap(partition.to_jraph)(nbr_update, mask=None,  nodes=x, edges=None, globals=globals)

In [112]:
graph.n_node.shape, graph.n_edge.shape, graph.nodes.shape, graph.globals.shape, graph.senders.shape, graph.receivers.shape

((4, 2), (4, 2), (4, 5001, 3), (4, 2, 4), (4, 13965), (4, 13965))

In [114]:
updated_graph = jax.vmap(net)(graph)

In [115]:
from models.gnn import GraphConvNet
import flax.linen as nn

class GraphWrapper(nn.Module):
    """Apply one ``GraphConvNet`` to every graph of a batched input.

    The leading axis of every array in the input pytree is treated as the
    batch axis; the same parameters are used for every batch element.
    """

    @nn.compact
    def __call__(self, x):
        """Run the batched network on ``x`` and return its batched output."""
        # Use Flax's lifted transform instead of a raw ``jax.vmap``: the
        # wrapped module creates its parameters while it is being called, and
        # Flax refuses variable access from inside a plain JAX transform
        # (JaxTransformError).
        # ``variable_axes={"params": None}`` shares one set of parameters
        # across the batch and ``split_rngs={"params": False}`` keeps a single
        # init RNG for them.
        # NOTE(review): assumes GraphConvNet only uses the "params" collection
        # and the "params" RNG — extend the two dicts if it uses others.
        batched_net_cls = nn.vmap(
            GraphConvNet,
            in_axes=0,
            out_axes=0,
            variable_axes={"params": None},
            split_rngs={"params": False},
        )
        model = batched_net_cls(
            latent_size=6,
            num_mlp_layers=4,
            message_passing_steps=3,
        )
        return model(x)

In [117]:
model = GraphWrapper()
graph_update, params  = model.init_with_output(rng, graph)

In [31]:
globals = jax.random.normal(rng, (n_batch, 5))

In [32]:
n_batch = 4
k = 10

# Every sample reuses the same flat kNN edge list (sources/targets), so the
# per-graph edge count is the length of that list, not k. Likewise the node
# count is read from x instead of being hard-coded.
n_edges_per_graph = sources.shape[0]
n_nodes_per_graph = x.shape[1]

graph = jraph.GraphsTuple(
    n_node=np.full((n_batch, 1), n_nodes_per_graph),
    n_edge=np.full((n_batch, 1), n_edges_per_graph),
    nodes=x,
    edges=None,
    globals=globals,
    # Replicate the single edge list along a leading batch axis.
    senders=np.tile(sources, (n_batch, 1)),
    receivers=np.tile(targets, (n_batch, 1)),
)

In [33]:
graph.n_node.shape, graph.n_edge.shape, graph.nodes.shape, graph.globals.shape, graph.senders.shape, graph.receivers.shape

((4, 1), (4, 1), (4, 5000, 3), (4, 5), (4, 50000), (4, 50000))

In [34]:
model = GraphWrapper()
graph_update, params  = model.init_with_output(rng, graph)

In [59]:
net(graph)

GraphsTuple(nodes=Array([[112.30655 , 701.6364  , 649.7085  ],
       [137.20975 , 557.04095 , 747.1494  ],
       [ 99.483185, 248.12627 , 932.8447  ],
       ...,
       [226.60483 , 676.7364  , 757.2367  ],
       [677.3594  ,  66.43591 , 264.32083 ],
       [  0.      ,   0.      ,   0.      ]], dtype=float32), edges=None, receivers=Array([   0,  331,    1, ..., 5000, 5000, 5000], dtype=int32), senders=Array([   0,    0,    1, ..., 5000, 5000, 5000], dtype=int32), globals=Array([[ 0.18693547,  1.0653336 , -1.5593132 , -1.5352962 ],
       [ 0.        ,  0.        ,  0.        ,  0.        ]],      dtype=float32), n_node=Array([5000,    1], dtype=int32), n_edge=Array([19912,  9956], dtype=int32))

## Test without batch

In [39]:
class RadiusSearch:
    """Jittable radius graph backed by a jax-md sparse neighbour list.

    Args:
        box_size: Box size; positions are wrapped into it with ``np.mod``
            before every allocate/update call.
        cutoff: Neighbour cutoff radius passed to ``partition.neighbor_list``.
        boundary_cond: ``"free"`` or ``"periodic"``; selects the jax-md
            displacement function.
        capacity_multiplier: Forwarded to ``partition.neighbor_list``.

    Raises:
        NotImplementedError: If ``boundary_cond`` is not one of the two
            supported values.
    """

    def __init__(self, box_size, cutoff, boundary_cond="free", capacity_multiplier=1.5):

        self.box_size = np.array(box_size)

        if boundary_cond == "free":
            self.displacement_fn, _ = space.free()
        elif boundary_cond == "periodic":
            self.displacement_fn, _ = space.periodic(self.box_size)
        else:
            raise NotImplementedError(
                f"Unsupported boundary_cond {boundary_cond!r}; expected 'free' or 'periodic'."
            )

        self.cutoff = cutoff
        self.neighbor_list_fn = partition.neighbor_list(
            self.displacement_fn,
            self.box_size,
            cutoff,
            format=partition.Sparse,
            dr_threshold=cutoff / 6.0,
            mask_self=False,
            capacity_multiplier=capacity_multiplier,
        )

        self.neighbor_list_fn_jit = jax.jit(self.neighbor_list_fn)
        # NOTE(review): despite the name this is the plain displacement
        # function; it is not wrapped in jax.jit here.
        self.neighbor_dist_jit = self.displacement_fn

        # Each time number of neighbours buffer overflows, reallocate
        self.n_times_reallocated = 0

    def init_neighbor_lst(self, pos):
        """Allocate the initial neighbour list for ``pos``.

        ``pos`` is wrapped into the box first and the result of
        ``neighbor_list_fn.allocate`` is returned.
        """
        pos = np.mod(pos, self.box_size)
        nbr = self.neighbor_list_fn.allocate(pos)
        return nbr

    def update_neighbor_lst(self, pos, nbr):
        """Update the neighbour list; on buffer overflow, reallocate (re-jit).

        Args:
            pos: Positions, either one configuration ``(n_points, n_dim)`` or
                a batch ``(n_batch, n_points, n_dim)``.
            nbr: Reference neighbour list from a previous allocate call.

        Returns:
            ``(nbr_update, nbr)``. ``nbr_update`` is the result of this update
            (it may itself be flagged as overflowed); ``nbr`` is the reference
            list to use for the next call — freshly allocated with a larger
            capacity if an overflow was detected, otherwise the one passed in.
        """
        pos = np.mod(pos, self.box_size)
        nbr_update = self.neighbor_list_fn_jit.update(pos, nbr)

        # If buffer overflows, update capacity of neighbours.
        # NOTE: This reallocation strategy might be more efficient: https://github.com/jax-md/jax-md/issues/192#issuecomment-1114002995
        if np.any(nbr_update.did_buffer_overflow):
            # Reallocate from a single configuration: the first batch element
            # for batched input, or the positions themselves when unbatched
            # (indexing [0] there would pass a single point).
            ref_pos = pos[0] if pos.ndim == 3 else pos
            nbr = self.neighbor_list_fn.allocate(ref_pos, extra_capacity=2**self.n_times_reallocated)
            self.n_times_reallocated += 1

        return nbr_update, nbr


In [40]:
n_batch = 4
n_nodes = 5000
n_pos = 3

x = np.load("/n/holyscratch01/iaifi_lab/ccuesta/data_for_sid/halos.npy")[:n_batch, :n_nodes,:n_pos][0]

In [41]:
ns = RadiusSearch(1000., 30.)
nbr = ns.init_neighbor_lst(x)

In [42]:
nbr_update, nbr = ns.update_neighbor_lst(x, nbr)

In [43]:
rng = jax.random.PRNGKey(42)
globals = jax.random.normal(rng, (1,4))
graph = partition.to_jraph(nbr_update, nodes=x, edges=None, globals=globals)

In [60]:
graph.n_node.shape, graph.n_edge.shape, graph.nodes.shape, graph.globals.shape, graph.senders.shape, graph.receivers.shape, 

((2,), (2,), (5001, 3), (2, 4), (29868,), (29868,))

In [90]:
graph.n_node.shape, graph.n_edge.shape, graph.nodes.shape, graph.globals.shape, graph.senders.shape, graph.receivers.shape

((4, 2), (4, 2), (4, 5001, 3), (4, 2, 4), (4, 13965), (4, 13965))

In [94]:
updated_graph = net(graph)