In [1]:
import jax
import jax.numpy as np
import jraph
import flax.linen as nn

from functools import partial

%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("../")

from models.gnn import GraphConvNet
from models.graph_utils import nearest_neighbors

In [3]:
# Number of entries kept along the second axis of the loaded array.
n_nodes = 5000

# Pull the "x_true" entry out of the .npz archive, then truncate its second
# axis to the first `n_nodes` entries.
x = np.load("../data/nbody_samples_only_pos.npz")["x_true"]
x = x[:, :n_nodes, :]

In [4]:
# Vectorise the neighbour search over the leading axis of the positions; the
# neighbour count (second positional argument) is shared by every sample.
batched_nearest_neighbors = jax.vmap(nearest_neighbors, in_axes=(0, None))
sources, targets = batched_nearest_neighbors(x[:4], 20)

In [5]:
# Derive the batch size and the per-graph edge count from the edge list itself
# instead of hard-coding them: the previous literal `n_edge=10` could not match
# an edge list built from a 20-neighbour search.
# NOTE(review): assumes `nearest_neighbors` returns one sender index per edge,
# so every entry of `sources[i]` is an edge of graph i — confirm against
# models.graph_utils.
n_graphs = sources.shape[0]
n_edges_per_graph = sources[0].size

# Batched GraphsTuple: every field carries a leading axis of length `n_graphs`
# so the wrapper model can be vmapped over it.
graph = jraph.GraphsTuple(
    n_node=np.array(n_graphs * [[n_nodes]]),
    n_edge=np.array(n_graphs * [[n_edges_per_graph]]),
    nodes=x[:n_graphs],
    edges=None,  # no edge features supplied
    globals=np.ones((n_graphs, 6)),  # placeholder all-ones global features
    senders=sources,
    receivers=targets,
)

In [6]:
class GraphWrapper(nn.Module):
    """Applies a single `GraphConvNet` across the leading axis of its input."""

    @nn.compact
    def __call__(self, x):
        """Run the graph network on every leading-axis slice of `x`.

        Args:
            x: input whose leaves all share a leading axis; `jax.vmap` maps
                the network over that axis.

        Returns:
            Whatever `GraphConvNet` returns, stacked along a new leading axis.
        """
        graph_net = GraphConvNet(
            latent_size=6,
            num_mlp_layers=4,
            message_passing_steps=3,
            skip_connections=True,
        )
        return jax.vmap(graph_net)(x)

In [7]:
# Fixed seed so the parameter initialisation is reproducible.
rng = jax.random.PRNGKey(42)
model = GraphWrapper()

# Initialise and run the forward pass in one call; only the forward output is
# kept, the second element of the returned pair is discarded.
graph_out, _ = model.init_with_output(rng, graph)

In [8]:
# Shape check on the node features returned by the wrapped network.
np.shape(graph_out.nodes)

(4, 5000, 3)

In [9]:
import sys
sys.path.append("../")

from models.diffusion import ScoreNet

In [10]:
# Number of leading samples of `x` used for this smoke test.
n_batch = 4

rng = jax.random.PRNGKey(42)
model = ScoreNet(5, 16)

# Initialise the network and run one forward pass; both results are discarded,
# so this cell only checks that the call goes through.
_, _ = model.init_with_output(
    rng,
    x[:n_batch],
    np.linspace(0., 1., n_batch),  # one value in [0, 1] per sample — presumably diffusion times, TODO confirm
    np.ones((n_batch, 6)),  # presumably per-sample conditioning vector — TODO confirm
    np.ones((n_batch, x.shape[1])),  # presumably per-node mask — TODO confirm
)

In [11]:
# All-ones mask with one entry per (sample, node).
node_mask = np.ones((n_batch, x.shape[1]))

# `in_axes` only covers the two positional arguments; `mask` is passed by
# keyword, and jax.vmap always maps keyword arguments over their leading axis.
masked_nearest_neighbors = jax.vmap(nearest_neighbors, in_axes=(0, None))
sources, targets = masked_nearest_neighbors(x[:n_batch], 20, mask=node_mask)