In [4]:
import jax
import jax.numpy as np
import jraph
import flax.linen as nn
import numpy as onp
import matplotlib.pyplot as plt

from functools import partial

# The `jax.config` submodule import (`from jax.config import config`) is
# deprecated and has been removed in recent JAX releases; the config object is
# available as an attribute of the top-level package instead. The `config`
# alias is kept so any other cell that refers to it keeps working.
config = jax.config
# Raise a FloatingPointError as soon as a computation produces a NaN.
config.update("jax_debug_nans", True)

In [5]:
%load_ext autoreload
%autoreload 2

In [6]:
import sys
# Put the parent directory on the import path so the `models.*` imports that
# follow can be resolved (presumably the project packages live one level up
# from this notebook — confirm).
sys.path.append("../")

from models.gnn import GraphConvNet
from models.graph_utils import nearest_neighbors
from models.graph_utils import add_graphs_tuples
from models.train_utils import create_input_iter

In [7]:
from datasets import load_data

In [8]:
# Load the "nbody" training set together with its normalisation statistics.
# NOTE(review): the positional arguments are not named here; judging by the
# batch displayed further down (two sets of 3-vectors) they look like
# (dataset, n_features=3, n_particles=5000, batch_size=2, seed=234) — confirm
# against `datasets.load_data`.
train_ds, norm_dict = load_data("nbody", 3, 5000, 2, 234)

In [9]:
# Build the batch iterator, pull the first batch and keep only the array
# stored at index [0][0] of it; that array is what the cells below work on.
batches = create_input_iter(train_ds)
x = next(batches)[0][0]

In [10]:
# Display the normalisation statistics ('mean' and 'std') returned with the dataset.
norm_dict

{'mean': Array([250.20403, 250.02423, 250.07912], dtype=float32),
 'std': Array([144.30989, 144.35008, 144.3569 ], dtype=float32)}

In [9]:
# z = [[[-0.74147505 -1.6911283   1.5016145 ]
#   [-1.2356482   0.83966905 -0.6621572 ]
#   [-0.93439144  0.7761591   1.7189647 ]
#   ...
#   [-0.37795684  0.05500687 -0.8659335 ]
#   [-1.708255   -1.1711149  -0.9808073 ]
#   [-1.0818185   0.7704281  -1.5321599 ]]

#  [[-0.5730722  -0.08018497 -0.70453495]
#   [-0.3276197  -0.12879108 -1.3185476 ]
#   [-0.8818957  -0.07687593 -1.0378724 ]
#   ...
#   [-0.0777759  -0.44764853  1.6243275 ]
#   [ 0.98842955  0.85802907 -1.4616792 ]
#   [-1.565508    0.18810192  1.3286879 ]]]

In [10]:
# Display the array extracted from the first batch.
x

Array([[[-0.74147505, -1.6911283 ,  1.5016145 ],
        [-1.2356482 ,  0.83966905, -0.6621572 ],
        [-0.93439144,  0.7761591 ,  1.7189647 ],
        ...,
        [-0.37795684,  0.05500687, -0.8659335 ],
        [-1.708255  , -1.1711149 , -0.9808073 ],
        [-1.0818185 ,  0.7704281 , -1.5321599 ]],

       [[-0.5730722 , -0.08018497, -0.70453495],
        [-0.3276197 , -0.12879108, -1.3185476 ],
        [-0.8818957 , -0.07687593, -1.0378724 ],
        ...,
        [-0.0777759 , -0.44764853,  1.6243275 ],
        [ 0.98842955,  0.85802907, -1.4616792 ],
        [-1.565508  ,  0.18810192,  1.3286879 ]]], dtype=float32)

In [11]:
# Neighbour count handed to `nearest_neighbors`; it is also reused as the
# per-graph `n_edge` entry when the GraphsTuple is built below.
k = 20

In [12]:
# Run `nearest_neighbors` once per entry along the leading axis of `x`
# (in_axes=0), passing the same `k` to every call (in_axes=None). The two
# returned arrays are used as graph senders and receivers below.
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x, k)

## EGNN

In [13]:
from models.egnn import EGNN

In [14]:
class GraphWrapper(nn.Module):
    """Applies one EGNN to every entry along the leading axis of its input."""

    @nn.compact
    def __call__(self, x):
        # Hyperparameters are fixed here; this wrapper is only used for a
        # quick forward pass in the notebook.
        egnn = EGNN(
            message_passing_steps=2,
            d_hidden=64,
            n_layers=3,
            norm_layer=False,
            skip_connections=False,
        )
        # jax.vmap maps the module call over axis 0 of `x`.
        batched_egnn = jax.vmap(egnn)
        return batched_egnn(x)
    
# Instantiate the wrapper and create the PRNG key used to initialise it below.
model = GraphWrapper()
rng = jax.random.PRNGKey(42)

In [15]:
# Read the batch size and node count off the data instead of hard-coding them
# (previously 2 and 5000), so this cell keeps working if the `load_data`
# arguments change.
n_batch, n_nodes = x.shape[:2]

graph = jraph.GraphsTuple(
    n_node=np.array(n_batch * [[n_nodes]]),
    # NOTE(review): `k` is the neighbour count passed to `nearest_neighbors`;
    # if that helper returns k edges for every node, the per-graph edge count
    # would be n_nodes * k rather than k — confirm what EGNN expects here.
    n_edge=np.array(n_batch * [[k]]),
    nodes=x,
    edges=None,
    # Placeholder global features: a vector of 7 ones per graph.
    globals=np.ones((n_batch, 7)),
    senders=sources,
    receivers=targets,
)

# Initialise the parameters and run the forward pass in a single call; the
# variables returned alongside the output are not needed here.
graph_out, _ = model.init_with_output(rng, graph)

x_out = graph_out.nodes

0.0, 0.4084402918815613, 0.06261736899614334, 0.047285839915275574, 0, 0
0.0, 0.4638731777667999, 0.062252726405858994, 0.04635818302631378, 0, 0
xi = [[-0.74147505 -1.6911283   1.5016145 ]
 [-0.74147505 -1.6911283   1.5016145 ]
 [-0.74147505 -1.6911283   1.5016145 ]
 [-0.74147505 -1.6911283   1.5016145 ]], xj = [[-0.74147505 -1.6911283   1.5016145 ]
 [-0.7410142  -1.6854184   1.5522099 ]
 [-0.7767966  -1.6671458   1.6424124 ]
 [-0.771016   -1.5445424   1.5368899 ]]
xi = [[-0.5730722  -0.08018497 -0.70453495]
 [-0.5730722  -0.08018497 -0.70453495]
 [-0.5730722  -0.08018497 -0.70453495]
 [-0.5730722  -0.08018497 -0.70453495]], xj = [[-0.5730722  -0.08018497 -0.70453495]
 [-0.679718    0.01507041 -0.6691244 ]
 [-0.4275025  -0.15570742 -0.6210643 ]
 [-0.424658   -0.1758395  -0.6367179 ]]
ms = [[0.        ]
 [0.00259271]
 [0.02164683]
 ...
 [0.07887913]
 [0.08057176]
 [0.08687778]]
ms = [[0.        ]
 [0.02170082]
 [0.03386152]
 ...
 [0.04503524]
 [0.0550633 ]
 [0.06724646]]
0.0, 0.4172574

In [16]:
# Display the node features returned by the EGNN forward pass.
x_out

Array([[[-0.74218833, -1.6939582 ,  1.5025882 ],
        [-1.2361424 ,  0.8383076 , -0.66255885],
        [-0.9358201 ,  0.7730425 ,  1.7211223 ],
        ...,
        [-0.3784609 ,  0.05406107, -0.86660147],
        [-1.7090771 , -1.167872  , -0.9857224 ],
        [-1.0820674 ,  0.7688353 , -1.5321354 ]],

       [[-0.5736024 , -0.07981203, -0.7029008 ],
        [-0.33021122, -0.13043177, -1.3140228 ],
        [-0.8851719 , -0.07996556, -1.0405757 ],
        ...,
        [-0.07942551, -0.45003724,  1.6275957 ],
        [ 0.98923224,  0.8587559 , -1.4602206 ],
        [-1.5674908 ,  0.18713947,  1.3279947 ]]], dtype=float32)

## Run the variational diffusion model (VDM)

In [11]:
from models.diffusion import VariationalDiffusionModel
from flax.core import FrozenDict
from models.diffusion_utils import loss_vdm
from models.train_utils import create_input_iter, param_count, train_step

In [12]:
# Reload the dataset for the diffusion-model section. This rebinds `train_ds`
# and `norm_dict` from the EGNN section; the only argument that differs from
# the earlier call is the fourth positional one (4 instead of 2).
train_ds, norm_dict = load_data("nbody", 3, 5000, 4, 234)

In [13]:
# Fresh batch iterator over the reloaded dataset (rebinds `batches`).
batches = create_input_iter(train_ds)

In [14]:
# Pull one batch; it unpacks into three arrays bound to `x`, `conditioning`
# and `mask` (their shapes are displayed a few cells below).
x, conditioning, mask = next(batches)

In [15]:
# Keyword arguments forwarded to the score network (GNN args).
score_dict = FrozenDict(
    {
        "k": 20,
        "message_passing_steps": 2,
        "skip_connections": False,
        "norm_layer": False,
        "n_layers": 4,
        "d_hidden": 16,
        "n_pos_features": 3,
    }
)
# The encoder and decoder share one empty config object; `use_encdec` is
# switched off in the model below.
encoder_dict = FrozenDict({})
decoder_dict = encoder_dict
score = "egnn"

vdm = VariationalDiffusionModel(
    timesteps=0,
    d_t_embedding=16,
    d_feature=3,
    score=score,
    score_dict=score_dict,
    n_classes=2,
    embed_context=True,
    d_context_embedding=16,
    noise_schedule="learned_linear",
    gamma_min=-8.,
    gamma_max=14.,
    use_encdec=False,
    encoder_dict=encoder_dict,
    decoder_dict=decoder_dict,
)

In [16]:
# Check the shapes of the three batch arrays before slicing them for the model.
x.shape, conditioning.shape, mask.shape

((1, 4, 5000, 3), (1, 4, 2), (1, 4, 5000))

In [17]:
# Pass a test batch through to initialize model
n_smoke = 4

x = np.array(x[0, :n_smoke])
conditioning = np.array(conditioning[0, :n_smoke])
mask = np.array(mask[0, :n_smoke])

rng = jax.random.PRNGKey(42)
out, params = vdm.init_with_output({"sample": rng, "params": rng}, x, conditioning, mask);

In [18]:
# Forward-only loss evaluation. Note that the conditioning passed in is an
# all-zeros array of the same shape, not the batch's actual conditioning.
loss = loss_vdm(params, vdm, rng, x, np.zeros_like(conditioning), mask)

In [19]:
# Evaluate the loss together with its gradient with respect to `params` (the
# first positional argument, which `jax.value_and_grad` differentiates by
# default).
# NOTE(review): with `jax_debug_nans` enabled at the top of the notebook, this
# call was recorded raising a NaN FloatingPointError, while the forward-only
# evaluation in the previous cell is not shown failing — the NaN presumably
# arises in the backward pass; investigate before trusting the gradients.
loss, loss_grad = jax.value_and_grad(loss_vdm)(params, vdm, rng, x, np.zeros_like(conditioning), mask)

FloatingPointError: invalid value (nan) encountered in jit(mul)

In [25]:
# Display the scalar loss value.
loss

Array(13.893028, dtype=float32)

In [26]:
# Display the full gradient pytree.
# NOTE(review): the recorded output shows all-zero gradients for the
# `embedding_class` and `embedding_context` parameters — check whether the
# conditioning actually reaches the loss.
loss_grad

FrozenDict({
    params: {
        embedding_class: {
            embedding: Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],      dtype=float32),
        },
        embedding_context: {
            bias: Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32),
            kernel: Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],      dtype=float32),
        },
        score_model: {
            EGNN_0: {
                MLP_0: {
                    Dense_0: {
                        bias: Array([ 0.00364586,  0.00067228, -0.00084506,  0.00293911,  0.00205781,
                               -0.00048638,  0.00236795, -0.00028667, -0.00018439, -0.00475061,
                                0.00075227,  0.00250874,  0.00036629, -0.00423419, -0.00340149,
                                0.00109621], dtype=float32),
          