In [1]:
import jax
import jax.numpy as np
import jraph
import flax.linen as nn

from functools import partial

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
sys.path.append("../")

from models.gnn import GraphConvNet
from models.graph_utils import nearest_neighbors



In [4]:
x = np.load("../data/halos_small.npy")
n_nodes = 5000
k = 20
sources, targets, dist = jax.vmap(nearest_neighbors, in_axes=(0, None))(x[:4, :, :3], k)

In [5]:
graph = jraph.GraphsTuple(
          n_node=np.array(4 * [[n_nodes]]), 
          n_edge=np.array(4 * [[k]]),
          nodes=x[:4, :, :6], 
          edges=dist,
          globals=np.ones((4, 7)),
          senders=sources,
          receivers=targets)

In [6]:
graph.n_edge[0]

Array([20], dtype=int32)

In [7]:
graph.n_node.shape, graph.n_edge.shape, graph.nodes.shape, graph.globals.shape, graph.senders.shape, graph.receivers.shape, 

((4, 1), (4, 1), (4, 5000, 6), (4, 7), (4, 100000), (4, 100000))

In [8]:
graph.edges.shape

(4, 100000, 3)

In [9]:
class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(GraphConvNet(
            latent_size=12, 
            hidden_size=32,
            num_mlp_layers=4, 
            attention=False,
            message_passing_steps=3, 
            shared_weights=False,
            skip_connections=True,))
        return model(x)

In [10]:
model = GraphWrapper()
rng = jax.random.PRNGKey(42)
graph_out, params = model.init_with_output(rng, graph)

graph_out

GraphsTuple(nodes=Array([[[-0.1125825 , -0.09283566,  0.11377222],
        [-0.05893194, -0.10181687,  0.07655472],
        [-0.1651399 , -0.19137335,  0.06769908],
        ...,
        [-0.07573469, -0.13577965,  0.09486147],
        [-0.14018403, -0.16820411,  0.08733329],
        [-0.24084674, -0.20037228, -0.04847446]],

       [[-0.11877929, -0.19698136, -0.03018252],
        [-0.17410675, -0.20649745,  0.09587927],
        [-0.07466327, -0.10741072,  0.10991385],
        ...,
        [-0.08268358, -0.10388442,  0.08842175],
        [-0.1477319 , -0.1747698 ,  0.08993898],
        [-0.25986856, -0.2063719 , -0.04190782]],

       [[-0.02476275, -0.0844294 , -0.01236729],
        [-0.12253042, -0.16551678,  0.08586553],
        [-0.13652821, -0.15572608, -0.02954918],
        ...,
        [-0.08189055, -0.11284494,  0.13404271],
        [-0.19673815, -0.22579373,  0.02905469],
        [-0.14372848, -0.16787201, -0.00698224]],

       [[-0.10140395, -0.1546606 ,  0.06181228],
      

In [11]:
sum(x.size for x in jax.tree_util.tree_leaves(params))

38763

In [12]:
graph_out.nodes.shape

(4, 5000, 3)

## DynamicEdgeConv test

In [13]:
from models.dynamic_edge_conv import DynamicEdgeConvNet

In [14]:
class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(DynamicEdgeConvNet(
            latent_size=12, 
            hidden_size=32,
            num_mlp_layers=4, 
            message_passing_steps=4, 
            k=k,
            skip_connections=True,))
        return model(x)

In [15]:
# model = GraphWrapper()
# rng = jax.random.PRNGKey(42)
# graph_out, params = model.init_with_output(rng, graph)

# graph_out

## ChebConv

In [111]:
graph = jraph.GraphsTuple(
          n_node=np.array(4 * [[n_nodes]]), 
          n_edge=np.array(4 * [[k]]),
          nodes=(x[:4, :, :6] - 500) / 500, 
          edges=dist.sum(-1) / 1000.,
          globals=np.ones((4, 7)),
          senders=sources,
          receivers=targets)

In [112]:
from models.chebconv import ChebConvNet

class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(ChebConvNet(bias=False))
        return model(x)

In [113]:
model = GraphWrapper()
rng = jax.random.PRNGKey(42)
graph_out, params = model.init_with_output(rng, graph)

graph_out.nodes[0]

Array([[ 2.3170793 ,  0.16200593, -1.2803118 ,  1.8137808 ,  0.5908083 ,
         0.3293581 ],
       [ 0.962221  , -1.7021335 , -0.21352464,  0.9629654 ,  0.31692111,
         1.0524303 ],
       [ 2.8236384 , -1.3319454 , -0.5397958 ,  1.13712   ,  1.1572564 ,
         1.2790109 ],
       ...,
       [ 2.341406  ,  1.8625927 , -1.4193633 ,  1.790666  , -0.31950915,
         0.7152538 ],
       [ 1.7572807 ,  0.7810406 ,  0.5669117 ,  1.2669858 ,  0.6030566 ,
         1.217212  ],
       [ 0.05830427,  1.167861  ,  0.33585003,  0.8599856 , -1.3885045 ,
         1.2700629 ]], dtype=float32)

In [114]:
sum(x.size for x in jax.tree_util.tree_leaves(params))

923910