In [1]:
import jax
import jax.numpy as np
import jraph
import flax.linen as nn
import numpy as onp
import matplotlib.pyplot as plt

from functools import partial

from jax.config import config
config.update("jax_debug_nans", True)

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
sys.path.append("../")

from models.gnn import GraphConvNet
from models.graph_utils import nearest_neighbors
from models.graph_utils import add_graphs_tuples
from models.train_utils import create_input_iter

In [4]:
from datasets import load_data

In [5]:
n_nodes = 5000

In [6]:
train_ds, norm_dict = load_data("nbody", 3, n_nodes, 2, 234)

In [9]:
tuple(map(float, norm_dict['mean']))

(250.2040252685547, 250.02423095703125, 250.07911682128906)

In [7]:
batches = create_input_iter(train_ds)
x = next(batches)
x = x[0][0]

In [8]:
k = 20

In [9]:
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x, k)

## EGNN

In [10]:
from models.egnn import EGNN

In [11]:
class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(EGNN(
                        message_passing_steps=2, d_hidden=64, n_layers=3, norm_layer=False, skip_connections=False,
                ))
        return model(x)
    
model = GraphWrapper()
rng = jax.random.PRNGKey(42)

In [12]:
globals = np.array([np.arange(50)] * 2)

In [13]:
from einops import repeat
edges = repeat(globals, "b g -> b e g", e = n_nodes * k)

In [14]:
# graph = jraph.GraphsTuple(
#           n_node=np.array(2 * [[n_nodes]]), 
#           n_edge=np.array(2 * [[k]]),
#           nodes=x, 
#           edges=None,
#           globals=np.array([np.arange(50)] * 2),
#           senders=sources,
#           receivers=targets)

# graph_out, _ = model.init_with_output(rng, graph)

# x_out = graph_out.nodes

In [15]:
# x_out.shape

## Actually run VDM

In [16]:
from models.diffusion import VariationalDiffusionModel
from flax.core import FrozenDict
from models.diffusion_utils import loss_vdm
from models.train_utils import create_input_iter, param_count, train_step

In [17]:
from models.egnn_jax import EGNN

In [18]:
train_ds, norm_dict = load_data("nbody", 3, 5000, 4, 234)

In [19]:
batches = create_input_iter(train_ds)

In [20]:
x, conditioning, mask = next(batches)

In [21]:
score_dict = FrozenDict({"k":20, "message_passing_steps":2, "skip_connections":False, "norm_layer":False, "n_layers":4, "d_hidden":16, "n_pos_features":3})  # GNN args
encoder_dict = decoder_dict = FrozenDict({})
score = "egnn"

vdm = VariationalDiffusionModel(
          timesteps=0, 
          d_t_embedding=16,
          d_feature=3,
          score=score,
          score_dict=score_dict,
          n_classes=2,
          embed_context=True,
          d_context_embedding=16,
          noise_schedule="learned_linear",
          gamma_min=-8.,
          gamma_max=14.,
          use_encdec=False,
          encoder_dict=encoder_dict,
          decoder_dict=decoder_dict,
          norm_dict=norm_dict)

In [22]:
# Pass a test batch through to initialize model
n_smoke = 4

x = np.array(x[0, :n_smoke])
conditioning = np.array(conditioning[0, :n_smoke])
mask = np.array(mask[0, :n_smoke])

In [23]:
coord_mean = norm_dict["mean"][:3]
coord_std = norm_dict["std"][:3]

In [24]:
x.shape, conditioning.shape, mask.shape

((4, 5000, 3), (4, 2), (4, 5000))

In [25]:
rng = jax.random.PRNGKey(42)
out, params = vdm.init_with_output({"sample": rng, "params": rng}, x, conditioning, mask);

In [26]:
loss = loss_vdm(params, vdm, rng, x, np.zeros_like(conditioning), mask)

In [27]:
loss, loss_grad = jax.value_and_grad(loss_vdm)(params, vdm, rng, x, np.zeros_like(conditioning), mask)

In [28]:
loss

Array(215.55902, dtype=float32)

## EGNN other implemen

In [285]:
from models.egnn_jax import EGNN as EGNNJax

In [243]:
class GraphWrapper(nn.Module):
    
    @nn.compact
    def __call__(self, graph, x):
        model = jax.vmap(EGNNJax(hidden_size=64, num_layers=4, act_fn=jax.nn.gelu, residual=True, attention=True, normalize=True, tanh=True))
        return model(graph, x)
    
model = GraphWrapper()
rng = jax.random.PRNGKey(42)

In [244]:
graph = jraph.GraphsTuple(
          n_node=np.array(2 * [[n_nodes]]), 
          n_edge=np.array(2 * [[k]]),
          nodes=x, 
          edges=None,
          globals=np.array([np.arange(50)] * 2),
          senders=sources,
          receivers=targets)

pos, params = model.init_with_output(rng, graph, x)

In [245]:
pos

Array([[[-30.578857  , -34.251675  ,  -3.9459803 ],
        [  4.5112658 ,   6.9919643 ,   0.39431453],
        [ 10.9624605 , -11.938925  ,  44.85105   ],
        ...,
        [ -6.0228853 , -19.419931  ,  -7.7576513 ],
        [-14.449694  ,  11.418314  ,   8.207899  ],
        [ -2.053481  ,   5.8104205 ,   7.0276966 ]],

       [[ -2.6015284 ,  10.589291  ,  15.668666  ],
        [ 12.237366  ,  25.771553  ,  14.827638  ],
        [ -5.10552   ,  -5.1483183 ,  -3.9938614 ],
        ...,
        [ 12.942366  , -12.413314  ,  23.25596   ],
        [  4.384595  , -14.923397  , -18.405027  ],
        [ -2.102576  ,  14.664121  , -10.526007  ]]], dtype=float32)