In [1]:
import jax
import jax.numpy as np
import jraph
import flax.linen as nn

from functools import partial

In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [133]:
import sys
# Put the parent directory on the import path so the `models` package imported
# below can be resolved (presumably the repository root -- confirm against the
# project layout).
sys.path.append("../")

from models.gnn import GraphConvNet
from models.graph_utils import nearest_neighbors

In [134]:
# Load the halo data. From how it is sliced below, the array is indexed as
# (sample, node, feature), with the first three features fed to the neighbour
# search -- assumed to be coordinates, TODO confirm against the data file.
x = np.load("data/halos_small.npy")

n_graphs = 4          # number of samples batched together in this notebook
k_neighbors = 50      # neighbour count handed to `nearest_neighbors`
n_nodes = x.shape[1]  # nodes per graph, read from the data rather than hard-coded

# Run the neighbour search once per sample: vmap over the leading axis of the
# feature slice while sharing the same neighbour count across samples.
sources, targets, dist = jax.vmap(nearest_neighbors, in_axes=(0, None))(
    x[:n_graphs, :, :3], k_neighbors
)

In [135]:
# Batched GraphsTuple: every field carries a leading per-graph axis so the whole
# tuple can be mapped over with `jax.vmap`.
# The batch size and per-graph edge count are read from the edge-index array so
# that `n_edge` matches the number of entries actually stored in
# `senders`/`receivers` (previously it was hard-coded to 50, the neighbour count,
# while each graph holds one edge per (node, neighbour) pair).
n_graphs, n_edges_per_graph = sources.shape
graph = jraph.GraphsTuple(
          n_node=np.array(n_graphs * [[n_nodes]]),
          n_edge=np.array(n_graphs * [[n_edges_per_graph]]),
          nodes=x[:n_graphs],
          edges=dist[..., None],  # trailing feature axis of size 1 for the edge attribute
          globals=np.ones((n_graphs, 7)),
          senders=sources,
          receivers=targets)

In [136]:
# Shapes of the batched graph fields, in a fixed order, as a sanity check.
tuple(
    getattr(graph, field_name).shape
    for field_name in ("n_node", "n_edge", "nodes", "globals", "senders", "receivers")
)

((4, 1), (4, 1), (4, 5000, 7), (4, 7), (4, 250000), (4, 250000))

In [137]:
class GraphWrapper(nn.Module):
    """Wrap `GraphConvNet` so it is applied per graph along the leading batch axis."""

    @nn.compact
    def __call__(self, x):
        """Map the graph network over the leading axis of `x` and return its output.

        Args:
            x: Batched graph input whose array fields all share a leading batch axis.

        Returns:
            Whatever `GraphConvNet` returns, stacked along the leading axis by `jax.vmap`.
        """
        # Hyper-parameters of the wrapped network, collected in one place.
        gnn_config = dict(
            latent_size=6,
            hidden_size=32,
            num_mlp_layers=4,
            attention=True,
            message_passing_steps=3,
            shared_weights=False,
            skip_connections=True,
        )
        batched_gnn = jax.vmap(GraphConvNet(**gnn_config))
        return batched_gnn(x)

In [115]:
# Instantiate the wrapper and initialize it on the example batch with a fixed seed.
model = GraphWrapper()
rng = jax.random.PRNGKey(42)
# `init_with_output` yields both the forward-pass output and the initialized variables.
graph_out, params = model.init_with_output(rng, graph)

# NOTE(review): this displays the entire output GraphsTuple; for arrays this large,
# inspecting shapes or a small slice keeps the notebook readable.
graph_out

GraphsTuple(nodes=Array([[[ 0.3201818 ,  0.27534413, -0.3180032 ],
        [ 0.33794418,  0.30218726, -0.31939706],
        [ 0.32354224,  0.28020442, -0.31872648],
        ...,
        [ 0.54821897,  0.52038336, -0.4289983 ],
        [ 0.55066013,  0.5230596 , -0.428186  ],
        [ 0.5485737 ,  0.5207604 , -0.42850506]],

       [[ 0.33066076,  0.29415044, -0.31264204],
        [ 0.33746338,  0.29990524, -0.32095745],
        [ 0.318626  ,  0.27991438, -0.30049586],
        ...,
        [ 0.5462873 ,  0.51871043, -0.42789146],
        [ 0.5475256 ,  0.52038574, -0.4265357 ],
        [ 0.54554456,  0.517431  , -0.4304901 ]],

       [[ 0.3071271 ,  0.2596523 , -0.31489176],
        [ 0.33791628,  0.30205736, -0.31948337],
        [ 0.33137852,  0.29102114, -0.32083347],
        ...,
        [ 0.5432364 ,  0.51646173, -0.42521912],
        [ 0.54638654,  0.5210671 , -0.41968778],
        [ 0.54745233,  0.52049667, -0.42569938]],

       [[ 0.32646844,  0.284438  , -0.31935483],
      

In [39]:
# Total number of scalar entries across all leaves of the `params` pytree.
# Uses `jax.tree_util.tree_leaves`, the supported spelling of the deprecated
# top-level `jax.tree_leaves` alias; the loop variable no longer reuses `x`,
# which names the data array elsewhere in the notebook.
sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))

  sum(x.size for x in jax.tree_leaves(params))


23433

In [34]:
# Shape of the node features in the model output.
graph_out.nodes.shape

(4, 5000, 3)