In [83]:
import sys
sys.path.append("../")

import jax
import jax.numpy as np
from functools import partial
import jraph

from models.graph_utils import nearest_neighbors, nearest_neighbors_ann, RadiusSearch

## Brute force kNN

In [195]:
rng = jax.random.PRNGKey(42)

n_nodes = 5000
k = 10

# Generate random inputs
x = jax.random.normal(rng, (n_nodes, 3))

In [196]:
sources, targets = nearest_neighbors(x, k=k)
sources_ann, targets_ann = nearest_neighbors_ann(x, k=k)

In [197]:
np.array_equal(sources_ann, sources), np.array_equal(targets_ann, targets)

(Array(True, dtype=bool), Array(False, dtype=bool))

In [198]:
(targets_ann == targets).sum()

Array(46485, dtype=int32)

In [199]:
%%timeit
sources_ann, targets_ann = nearest_neighbors_ann(x, k=k)

4.88 ms ± 455 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [200]:
%%timeit
sources, targets = nearest_neighbors(x, k=k)

7.05 ms ± 661 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Using jax-md

Scale up without having to recalculate pairwise distance matrix. Let's see if it can scale to 50000 nodes.

In [236]:
n_batch = 4
n_nodes = 50000
n_pos = 3

x = np.load("/n/holyscratch01/iaifi_lab/ccuesta/data_for_sid/halos.npy")[:n_batch, :n_nodes,:n_pos]
# x = jax.random.uniform(rng, (n_batch, n_nodes, n_pos)) * 1000.

In [237]:
ns = RadiusSearch(1000., 30.)
nbr = ns.init_neighbor_lst(x[0])

In [238]:
%%timeit
nbr_update, _ = ns.update_neighbor_lst(x, nbr)

66.5 ms ± 793 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [239]:
nbr_update, nbr = ns.update_neighbor_lst(x, nbr)

In [240]:
nbr_update.idx

Array([[[   0,  331,    1, ..., 5000, 5000, 5000],
        [   0,    0,    1, ..., 5000, 5000, 5000]],

       [[   0, 1379,    1, ..., 5000, 5000, 5000],
        [   0,    0,    1, ..., 5000, 5000, 5000]],

       [[   0, 3839,    1, ..., 5000, 5000, 5000],
        [   0,    0,    1, ..., 5000, 5000, 5000]],

       [[   0, 4977, 4138, ..., 5000, 5000, 5000],
        [   0,    0,    0, ..., 5000, 5000, 5000]]], dtype=int32)

In [241]:
# Shape [batch, (senders, receivers), -1]
nbr_update.idx.shape

(4, 2, 29868)

In [251]:
from jax_md import space, partition
globals = jax.random.normal(rng, (n_batch, 4))

## Batches 

In [265]:
graph = jax.vmap(partition.to_jraph)(nbr_update, mask=None,  nodes=x, edges=None, globals=globals)

In [266]:
graph.n_node

Array([[5000,    1],
       [5000,    1],
       [5000,    1],
       [5000,    1]], dtype=int32)

In [267]:
graph.n_edge

Array([[19912,  9956],
       [15296, 14572],
       [17400, 12468],
       [14846, 15022]], dtype=int32)

In [262]:
graph.n_node.shape, graph.n_edge.shape, graph.nodes.shape, graph.globals.shape, graph.senders.shape, graph.receivers.shape, 

((4, 2), (4, 2), (4, 5001, 3), (4, 5), (4, 29868), (4, 29868))

In [263]:
n_batch = 4
graph = jraph.GraphsTuple(
    n_node=np.array(n_batch * [[5000]]),
    n_edge=np.array(n_batch * [[k]]),
    nodes=x,
    edges=None,#np.zeros(sources.shape),
    globals=globals,
    senders=np.array(n_batch * [sources]),
    receivers=np.array(n_batch * [targets]),
)

graph

GraphsTuple(nodes=Array([[[112.30655 , 701.6364  , 649.7085  ],
        [137.20975 , 557.04095 , 747.1494  ],
        [ 99.483185, 248.12627 , 932.8447  ],
        ...,
        [553.4064  , 856.65936 , 302.90207 ],
        [226.60483 , 676.7364  , 757.2367  ],
        [677.3594  ,  66.43591 , 264.32083 ]],

       [[904.3494  , 337.22562 , 234.15457 ],
        [ 63.514877, 261.34128 , 157.26772 ],
        [159.5823  , 473.1259  , 573.4834  ],
        ...,
        [552.94183 , 856.9142  , 305.17673 ],
        [225.61792 , 677.6981  , 757.5429  ],
        [675.90625 ,  66.14816 , 264.68692 ]],

       [[605.41376 , 144.60951 , 655.7672  ],
        [632.3991  , 406.96896 , 924.6638  ],
        [783.7091  , 947.40924 , 166.29826 ],
        ...,
        [ 90.095436, 958.2254  , 405.9588  ],
        [676.97943 , 183.01384 , 682.7778  ],
        [244.5     , 120.537735, 187.83807 ]],

       [[524.7156  , 900.19855 , 220.11913 ],
        [657.20465 ,  12.784197, 912.2405  ],
        [427.2985

In [257]:
graph.n_node.shape, graph.n_edge.shape, graph.nodes.shape, graph.globals.shape, graph.senders.shape, graph.receivers.shape, 

((4, 1), (4, 1), (4, 5000, 3), (4, 4), (4, 50000), (4, 50000))

In [258]:
from models.gnn import GraphConvNet
import flax.linen as nn

class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(GraphConvNet(
            latent_size=6, 
            num_mlp_layers=4, 
            message_passing_steps=3))
        return model(x)

In [259]:
model = GraphWrapper()
model.init_with_output(rng, graph)

(GraphsTuple(nodes=Array([[[-7.42330998e-02,  1.06472954e-01,  4.27184016e-01],
         [-2.62082130e-01, -1.55594368e-02,  5.40986896e-01],
         [-6.02795660e-01, -2.13898361e-01,  6.93150461e-01],
         ...,
         [ 3.57081264e-01,  1.84403256e-01,  1.44310266e-01],
         [-1.85057983e-01, -5.70326578e-03,  4.92387146e-01],
         [-2.06039637e-01, -3.85865837e-01,  4.36843187e-01]],
 
        [[-6.79195276e-04, -3.83904636e-01,  2.29710907e-01],
         [ 2.80707717e-01,  2.00759456e-01,  1.46481842e-02],
         [-1.17081106e-01, -2.82531027e-02,  3.31816822e-01],
         ...,
         [ 3.21702540e-01,  4.06617448e-02,  6.64333403e-02],
         [-1.24051020e-01, -2.72298008e-02,  4.06954706e-01],
         [-2.16668859e-01, -5.17479122e-01,  3.09679031e-01]],
 
        [[-2.36724108e-01, -1.38110116e-01,  5.67132294e-01],
         [-2.07769811e-01, -5.95804192e-02,  5.94775081e-01],
         [ 5.42180240e-01,  1.89586118e-01, -4.91482466e-02],
         ...,
    