In [1]:
import jax
import jax.numpy as np
import jraph
import flax.linen as nn

from functools import partial

%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("../")

from models.gnn import GraphConvNet
from models.graph_utils import nearest_neighbors

In [3]:
n_nodes = 5000
x = np.load("../data/nbody_samples_only_pos.npz")['x_true'][:, :n_nodes, :]

In [4]:
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x[:4], 20)

In [5]:
graph = jraph.GraphsTuple(
          n_node=np.array(4 * [[n_nodes]]), 
          n_edge=np.array(4 * [[10]]),
          nodes=x[:4], 
          edges=None,
          globals=np.ones((4, 6)),
          senders=sources,
          receivers=targets)

In [6]:
class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(GraphConvNet(latent_size=6, num_mlp_layers=4, message_passing_steps=3, skip_connections=True))
        return model(x)

In [7]:
model = GraphWrapper()
rng = jax.random.PRNGKey(42)
graph_out, _ = model.init_with_output(rng, graph)

In [8]:
graph_out.nodes.shape

(4, 5000, 3)

In [9]:
import sys
sys.path.append("../")

from models.diffusion import ScoreNet

In [10]:
model = ScoreNet(5, 16)
rng = jax.random.PRNGKey(42)

n_batch = 4

model.init_with_output(rng, x[:n_batch], np.linspace(0., 1., n_batch), np.ones((n_batch, 6)), np.ones((n_batch, x.shape[1])))

(4, 5000, 3) (4, 5000, 3)


(Array([[[903.8656  , 368.60745 , 830.4612  ],
         [837.30084 , 872.7723  , 764.3804  ],
         [277.46283 , 385.47812 , 559.92017 ],
         ...,
         [486.67163 , 401.82965 , 156.97105 ],
         [134.36885 , 378.06964 ,  40.45679 ],
         [287.27963 , 494.56064 , 396.5344  ]],
 
        [[581.6499  , 364.99442 , 859.5758  ],
         [990.69415 , 964.5659  , 222.57562 ],
         [703.97516 , 730.88666 , 879.9244  ],
         ...,
         [602.93494 ,  43.60773 , 854.40894 ],
         [ 29.818102, 844.79785 , 636.4221  ],
         [ 66.17915 , 377.77112 , 694.73956 ]],
 
        [[444.6601  , 408.54828 , 528.1016  ],
         [185.81073 , 338.89386 , 650.596   ],
         [ 74.36075 , 399.5336  , 605.35455 ],
         ...,
         [932.522   , 321.66864 , 276.05817 ],
         [200.77759 , 607.71234 , 162.46689 ],
         [401.99054 , 283.14972 , 562.74817 ]],
 
        [[697.26526 , 908.07574 , 123.614456],
         [875.4801  , 246.53026 , 467.9639  ],
         

In [None]:
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x[:n_batch], 20, mask=np.ones((n_batch, x.shape[1])))

In [12]:
np.ones((n_batch, x.shape[1])).shape

(4, 5000)