In [83]:
import sys
sys.path.append("../")

import jax
import jax.numpy as np
from functools import partial
import jraph

from models.graph_utils import nearest_neighbors, nearest_neighbors_ann, RadiusSearch

## Brute force kNN

In [195]:
rng = jax.random.PRNGKey(42)

n_nodes = 5000
k = 10

# Generate random inputs
x = jax.random.normal(rng, (n_nodes, 3))

In [196]:
sources, targets = nearest_neighbors(x, k=k)
sources_ann, targets_ann = nearest_neighbors_ann(x, k=k)

In [197]:
np.array_equal(sources_ann, sources), np.array_equal(targets_ann, targets)

(Array(True, dtype=bool), Array(False, dtype=bool))

In [198]:
(targets_ann == targets).sum()

Array(46485, dtype=int32)

In [199]:
%%timeit
sources_ann, targets_ann = nearest_neighbors_ann(x, k=k)

4.88 ms ± 455 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [200]:
%%timeit
sources, targets = nearest_neighbors(x, k=k)

7.05 ms ± 661 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Using jax-md

Scale up without having to recalculate pairwise distance matrix. Let's see if it can scale to 50000 nodes.

In [236]:
n_batch = 4
n_nodes = 50000
n_pos = 3

x = np.load("/n/holyscratch01/iaifi_lab/ccuesta/data_for_sid/halos.npy")[:n_batch, :n_nodes,:n_pos]
# x = jax.random.uniform(rng, (n_batch, n_nodes, n_pos)) * 1000.

In [237]:
ns = RadiusSearch(1000., 30.)
nbr = ns.init_neighbor_lst(x[0])

In [238]:
%%timeit
nbr_update, _ = ns.update_neighbor_lst(x, nbr)

66.5 ms ± 793 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [239]:
nbr_update, nbr = ns.update_neighbor_lst(x, nbr)

In [240]:
nbr_update.idx

Array([[[   0,  331,    1, ..., 5000, 5000, 5000],
        [   0,    0,    1, ..., 5000, 5000, 5000]],

       [[   0, 1379,    1, ..., 5000, 5000, 5000],
        [   0,    0,    1, ..., 5000, 5000, 5000]],

       [[   0, 3839,    1, ..., 5000, 5000, 5000],
        [   0,    0,    1, ..., 5000, 5000, 5000]],

       [[   0, 4977, 4138, ..., 5000, 5000, 5000],
        [   0,    0,    0, ..., 5000, 5000, 5000]]], dtype=int32)

In [241]:
# Shape [batch, (senders, receivers), -1]
nbr_update.idx.shape

(4, 2, 29868)

In [251]:
from jax_md import space, partition
globals = jax.random.normal(rng, (n_batch, 4))

## Batches 

In [285]:
graph = jax.vmap(partition.to_jraph)(nbr_update, mask=None,  nodes=x, edges=None, globals=globals)

In [286]:
graph.n_node

Array([[5000,    1],
       [5000,    1],
       [5000,    1],
       [5000,    1]], dtype=int32)

In [287]:
graph.n_edge

Array([[19912,  9956],
       [15296, 14572],
       [17400, 12468],
       [14846, 15022]], dtype=int32)

In [288]:
graph.n_node

Array([[5000,    1],
       [5000,    1],
       [5000,    1],
       [5000,    1]], dtype=int32)

In [289]:
graph.n_node.shape, graph.n_edge.shape, graph.nodes.shape, graph.globals.shape, graph.senders.shape, graph.receivers.shape, graph.receivers.shape 

((4, 2), (4, 2), (4, 5001, 3), (4, 5), (4, 29868), (4, 29868), (4, 29868))

In [278]:
n_batch = 4
graph = jraph.GraphsTuple(
    n_node=np.array(n_batch * [[5000]]),
    n_edge=np.array(n_batch * [[k]]),
    nodes=x,
    edges=None,#np.zeros(sources.shape),
    globals=globals,
    senders=np.array(n_batch * [sources]),
    receivers=np.array(n_batch * [targets]),
)

graph

GraphsTuple(nodes=Array([[[112.30655 , 701.6364  , 649.7085  ],
        [137.20975 , 557.04095 , 747.1494  ],
        [ 99.483185, 248.12627 , 932.8447  ],
        ...,
        [553.4064  , 856.65936 , 302.90207 ],
        [226.60483 , 676.7364  , 757.2367  ],
        [677.3594  ,  66.43591 , 264.32083 ]],

       [[904.3494  , 337.22562 , 234.15457 ],
        [ 63.514877, 261.34128 , 157.26772 ],
        [159.5823  , 473.1259  , 573.4834  ],
        ...,
        [552.94183 , 856.9142  , 305.17673 ],
        [225.61792 , 677.6981  , 757.5429  ],
        [675.90625 ,  66.14816 , 264.68692 ]],

       [[605.41376 , 144.60951 , 655.7672  ],
        [632.3991  , 406.96896 , 924.6638  ],
        [783.7091  , 947.40924 , 166.29826 ],
        ...,
        [ 90.095436, 958.2254  , 405.9588  ],
        [676.97943 , 183.01384 , 682.7778  ],
        [244.5     , 120.537735, 187.83807 ]],

       [[524.7156  , 900.19855 , 220.11913 ],
        [657.20465 ,  12.784197, 912.2405  ],
        [427.2985

In [279]:
graph.n_node.shape, graph.n_edge.shape, graph.nodes.shape, graph.globals.shape, graph.senders.shape, graph.receivers.shape, 

((4, 1), (4, 1), (4, 5000, 3), (4, 4), (4, 50000), (4, 50000))

In [280]:
from models.gnn import GraphConvNet
import flax.linen as nn

class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(GraphConvNet(
            latent_size=6, 
            num_mlp_layers=4, 
            message_passing_steps=3))
        return model(x)

In [290]:
# model = GraphWrapper()
# model.init_with_output(rng, graph)

## Test without batch

In [305]:
class RadiusSearch:
    """Jittable radius graph"""

    def __init__(self, box_size, cutoff, boundary_cond="free", capacity_multiplier=1.5):

        self.box_size = np.array(box_size)

        if boundary_cond == "free":
            self.displacement_fn, _ = space.free()
        elif boundary_cond == "periodic":
            self.displacement_fn, _ = space.periodic(self.box_size)
        else:
            raise NotImplementedError

        self.cutoff = cutoff
        self.neighbor_list_fn = partition.neighbor_list(self.displacement_fn, self.box_size, cutoff, format=partition.Sparse, dr_threshold=cutoff / 6.0, mask_self=False, capacity_multiplier=capacity_multiplier)

        self.neighbor_list_fn_jit = jax.jit(self.neighbor_list_fn)
        self.neighbor_dist_jit = self.displacement_fn

        # Each time number of neighbours buffer overflows, reallocate
        self.n_times_reallocated = 0

    def init_neighbor_lst(self, pos):
        """Allocate initial neighbour list."""
        pos = np.mod(pos, self.box_size)
        nbr = self.neighbor_list_fn.allocate(pos)
        return nbr

    def update_neighbor_lst(self, pos, nbr):
        """Update neighbour list. If buffer overflows, reallocate (re-jit)."""
        pos = np.mod(pos, self.box_size)
        nbr_update = self.neighbor_list_fn_jit.update(pos, nbr)

        # If buffer overflows, update capacity of neighbours.
        # NOTE: This reallocation strategy might be more efficient: https://github.com/jax-md/jax-md/issues/192#issuecomment-1114002995
        if np.any(nbr_update.did_buffer_overflow):
            nbr = self.neighbor_list_fn.allocate(pos[0], extra_capacity=2**self.n_times_reallocated)
            self.n_times_reallocated += 1

        return nbr_update, nbr


In [391]:
n_batch = 4
n_nodes = 5000
n_pos = 3

x = np.load("/n/holyscratch01/iaifi_lab/ccuesta/data_for_sid/halos.npy")[:n_batch, :n_nodes,:n_pos][0]

In [392]:
ns = RadiusSearch(1000., 30.)
nbr = ns.init_neighbor_lst(x)

In [393]:
nbr_update, nbr = ns.update_neighbor_lst(x, nbr)

In [396]:
globals

Array([[-0.18471177]], dtype=float32)

In [407]:
globals = jax.random.normal(rng, (1,4))
graph = partition.to_jraph(nbr_update, nodes=x, edges=None, globals=globals)

In [408]:
graph.n_node.shape, graph.n_edge.shape, graph.nodes.shape, graph.globals.shape, graph.senders.shape, graph.receivers.shape, 

((2,), (2,), (5001, 3), (2, 4), (29868,), (29868,))

In [406]:
globals

Array([[ 0.18693547,  1.0653336 , -1.5593132 , -1.5352962 ]], dtype=float32)

In [399]:
# def update_edge_fn(
#   edge_features,
#   sender_node_features,
#   receiver_node_features,
#   globals_):
#     """Returns the update edge features."""
#     del sender_node_features
#     del receiver_node_features
#     del globals_
#     return edge_features

# def update_node_fn(
#   node_features,
#   aggregated_sender_edge_features,
#   aggregated_receiver_edge_features,
#   globals_):
#     """Returns the update node features."""
#     del aggregated_sender_edge_features
#     del aggregated_receiver_edge_features
#     del globals_
#     return node_features

# def update_global_fn(
#   aggregated_node_features,
#   aggregated_edge_features,
#   globals_):
#     """Returns the global features."""
#     del aggregated_node_features
#     del aggregated_edge_features
#     return globals_

# net = jraph.GraphNetwork(update_edge_fn=update_edge_fn,
#                          update_node_fn=update_node_fn,
#                          update_global_fn=update_global_fn)

# updated_graph = net(graph)

In [400]:
# model = GraphConvNet(
#             latent_size=6, 
#             num_mlp_layers=4, 
#             message_passing_steps=3)

# model.init_with_output(rng, graph)

In [409]:
import jax.numpy as jnp

# Define a three node graph, each node has an integer as its feature.
node_features = jnp.array([[0.], [1.], [2.]])

# We will construct a graph fro which there is a directed edge between each node
# and its successor. We define this with `senders` (source nodes) and `receivers`
# (destination nodes).
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 0])

# You can optionally add edge attributes.
edges = jnp.array([[5.], [6.], [7.]])

# We then save the number of nodes and the number of edges.
# This information is used to make running GNNs over multiple graphs
# in a GraphsTuple possible.
n_node = jnp.array([3])
n_edge = jnp.array([3])

# Optionally you can add `global` information, such as a graph label.

global_context = jnp.array([[1]]) # Same feature dimensions as nodes and edges.
graph = jraph.GraphsTuple(nodes=node_features, senders=senders, receivers=receivers,
edges=edges, n_node=n_node, n_edge=n_edge, globals=global_context)


In [410]:
graph.n_node.shape, graph.n_edge.shape, graph.nodes.shape, graph.globals.shape, graph.senders.shape, graph.receivers.shape, 

((1,), (1,), (3, 1), (1, 1), (3,), (3,))

In [411]:
graph_batch = jraph.batch([graph, graph])
graph_batch.n_node.shape, graph_batch.n_edge.shape, graph_batch.nodes.shape, graph_batch.globals.shape, graph_batch.senders.shape, graph_batch.receivers.shape, 

((2,), (2,), (6, 1), (2, 1), (6,), (6,))

In [401]:
updated_graph = net(graph)

In [375]:
updated_graph = net(graph_batch)