In [1]:
import jax
import jax.numpy as np
import jraph
import flax.linen as nn
import numpy as onp
import matplotlib.pyplot as plt

from functools import partial

from jax.config import config
config.update("jax_debug_nans", True)

In [2]:
%load_ext autoreload
%autoreload 2

In [183]:
import sys
sys.path.append("../")

from models.gnn import GraphConvNet
from models.graph_utils import nearest_neighbors
from models.graph_utils import add_graphs_tuples
from models.train_utils import create_input_iter

In [184]:
from datasets import load_data

In [185]:
n_nodes = 5000

In [186]:
train_ds, norm_dict = load_data("nbody", 3, n_nodes, 2, 234)

In [187]:
tuple(map(float, norm_dict['mean']))

(500.28973388671875, 500.1238708496094, 500.25457763671875)

In [188]:
batches = create_input_iter(train_ds)
x_data = next(batches)
x = x_data[0][0]
cond = x_data[1][0]

In [189]:
k = 20

In [190]:
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x, k)

## EGNN; equivariance test

In [191]:
from models.egnn_jax import EGNN
from models.graph_utils import rotate_representation

In [192]:
coord_mean, coord_std, box_size = norm_dict['mean'], norm_dict['std'], 1000.

In [193]:
class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, graph, pos):
        return jax.vmap(EGNN(k=k), in_axes=(0, 0, 0, 0, None, None, None))(graph, pos, None, None, coord_mean, coord_std, box_size)
    
model = GraphWrapper()
rng = jax.random.PRNGKey(42)

In [194]:
d2 = np.sum(x ** 2, axis=-1, keepdims=True)

### Rotation equivariance

In [195]:
pos = x
pos_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(x, 45, np.array([0., 0., 1.]))

In [196]:
graph = jraph.GraphsTuple(
          n_node=np.array(2 * [[n_nodes]]), 
          n_edge=np.array(2 * [[k]]),
          nodes=d2, 
          edges=None,
          globals=cond,
          senders=sources,
          receivers=targets)

pos_out, _ = model.init_with_output(rng, graph, pos)

In [197]:
graph = jraph.GraphsTuple(
          n_node=np.array(2 * [[n_nodes]]), 
          n_edge=np.array(2 * [[k]]),
          nodes=d2, 
          edges=None,
          globals=cond,
          senders=sources,
          receivers=targets)

pos_rot_out, _ = model.init_with_output(rng, graph, pos_rot)

In [198]:
pos_out_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(pos_out, 45, np.array([0., 0., 1.]))

In [199]:
pos_rot_out / pos_out_rot

Array([[[1.0000186 , 1.0041968 , 1.0012723 ],
        [0.9996871 , 0.99964136, 1.0001278 ],
        [0.99993014, 0.9998736 , 0.9998053 ],
        ...,
        [0.9938977 , 0.9999632 , 1.0000858 ],
        [1.0005772 , 1.000247  , 0.99893963],
        [1.0001596 , 1.0000917 , 1.0006036 ]],

       [[0.99974424, 0.99993306, 0.9999587 ],
        [1.00022   , 1.0003921 , 0.9999065 ],
        [1.0005502 , 0.99963886, 0.9998647 ],
        ...,
        [0.9997811 , 0.9992433 , 1.0001013 ],
        [0.9999908 , 0.9902577 , 1.0000782 ],
        [1.0034574 , 1.0000578 , 0.9995494 ]]], dtype=float32)

### Translation equivariance

In [200]:
tran = np.array([-0.45, 0.9, -1.2])[None, None, :]
pos_tran = pos + tran
pos_out_tran = pos_out + tran

In [201]:
graph = jraph.GraphsTuple(
          n_node=np.array(2 * [[n_nodes]]), 
          n_edge=np.array(2 * [[k]]),
          nodes=d2, 
          edges=None,
          globals=cond,
          senders=sources,
          receivers=targets)

pos_tran_out, _ = model.init_with_output(rng, graph, pos_tran)

In [202]:
pos_tran_out / pos_out_tran

Array([[[1.0000131 , 1.0000918 , 1.0000881 ],
        [1.0001146 , 1.0003357 , 1.0000314 ],
        [1.0000138 , 1.0000178 , 0.99999046],
        ...,
        [1.0000104 , 1.0000035 , 1.0000101 ],
        [1.0000414 , 1.0000386 , 0.999928  ],
        [1.0000116 , 0.9999892 , 0.9999824 ]],

       [[0.99999607, 1.0000114 , 1.0000179 ],
        [1.0000027 , 0.999789  , 1.000614  ],
        [1.0000232 , 0.9999705 , 0.99998134],
        ...,
        [0.99998987, 0.9998664 , 0.9999764 ],
        [1.0000228 , 0.999852  , 1.0000552 ],
        [0.99999464, 0.9999804 , 1.0000181 ]]], dtype=float32)

## Look at VLB loss

In [104]:
from models.diffusion import VariationalDiffusionModel
from flax.core import FrozenDict
from models.diffusion_utils import loss_vdm
from models.train_utils import create_input_iter, param_count, train_step

In [148]:
n_batch = 8

In [149]:
train_ds, norm_dict_tmp = load_data("nbody", 3, 5000, n_batch, 23)

In [150]:
batches = create_input_iter(train_ds)

In [151]:
x, conditioning, mask = next(batches)
x, conditioning, mask = x[0], conditioning[0], mask[0]

In [152]:
score_dict = FrozenDict({"k":20, "n_pos_features":3})  # GNN args
encoder_dict = decoder_dict = FrozenDict({})
norm_dict = FrozenDict({"x_mean":tuple(map(float, norm_dict_tmp['mean'])), "x_std":tuple(map(float, norm_dict_tmp['std'])), "box_size":1000.})  # GNN args
score = "egnn"

vdm = VariationalDiffusionModel(
          timesteps=0, 
          d_t_embedding=16,
          d_feature=3,
          score=score,
          score_dict=score_dict,
          n_classes=0,
          embed_context=True,
          d_context_embedding=16,
          noise_schedule="learned_linear",
          gamma_min=-8.,
          gamma_max=14.,
          use_encdec=False,
          encoder_dict=encoder_dict,
          decoder_dict=decoder_dict,
          norm_dict=norm_dict)

In [153]:
x.shape, conditioning.shape, mask.shape

((8, 5000, 3), (8, 2), (8, 5000))

In [154]:
rng = jax.random.PRNGKey(42)
out, params = vdm.init_with_output({"sample": rng, "params": rng}, x, conditioning, mask);

## Check rotational invariance of loss

In [155]:
loss = loss_vdm(params, vdm, rng, x, conditioning, mask)
loss

Array(12.139326, dtype=float32)

In [156]:
x_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(x, 45., np.array([0., 0., 1.]))
loss = loss_vdm(params, vdm, rng, x_rot, conditioning, mask)
loss

Array(12.248699, dtype=float32)

In [157]:
x_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(x, 93., np.array([0., 1. / np.sqrt(2), 1. / np.sqrt(2)]))
loss = loss_vdm(params, vdm, rng, x_rot, conditioning, mask)
loss

Array(12.33543, dtype=float32)

In [27]:
# loss, loss_grad = jax.value_and_grad(loss_vdm)(params, vdm, rng, x, np.zeros_like(conditioning), mask)

## Check rotational invariance of loss (individual components)

In [174]:
loss_diff, loss_klz, loss_recon = vdm.apply(params, x, conditioning, mask, rngs={"sample": rng})
loss_batch = (((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2)) + (loss_recon * mask[:, :, None]).sum((-1, -2))) / mask.sum(-1)
loss_batch

Array([ 27.963305,   6.401713, -13.336952, -16.46254 , -16.718292,
        37.642227,  37.166668,  34.458485], dtype=float32)

In [175]:
x_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(x, 45., np.array([0., 0., 1.]))
loss_diff, loss_klz, loss_recon = vdm.apply(params, x_rot, conditioning, mask, rngs={"sample": rng})
loss_batch = (((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2)) + (loss_recon * mask[:, :, None]).sum((-1, -2))) / mask.sum(-1)
loss_batch

Array([ 28.21428  ,   6.3376017, -13.366019 , -16.469038 , -16.718327 ,
        37.53338  ,  37.779152 ,  34.67857  ], dtype=float32)

In [176]:
x_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(x, 93., np.array([0., 1. / np.sqrt(2), 1. / np.sqrt(2)]))
loss_diff, loss_klz, loss_recon = vdm.apply(params, x_rot, conditioning, mask, rngs={"sample": rng})
loss_batch = (((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2)) + (loss_recon * mask[:, :, None]).sum((-1, -2))) / mask.sum(-1)
loss_batch

Array([ 27.736063 ,   6.690169 , -13.4112835, -16.473114 , -16.718084 ,
        37.80085  ,  37.80386  ,  35.25499  ], dtype=float32)

In [178]:
loss_batch[:].mean()

Array(12.335431, dtype=float32)

## Check translational invariance of loss

In [179]:
# Translate
x_tran = x + np.array([0.2, 0.3, 0.5])[None, None, :]

loss_diff, loss_klz, loss_recon = vdm.apply(params, x_tran, conditioning, mask, rngs={"sample": rng})
loss_batch = (((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2)) + (loss_recon * mask[:, :, None]).sum((-1, -2))) / mask.sum(-1)
loss_batch

Array([ 32.587143,   9.189462, -12.947554, -16.427973, -16.716597,
        42.018078,  41.487644,  37.862812], dtype=float32)

In [180]:
# Translate
x_tran = x + np.array([-0.45, 0.9, -1.2])[None, None, :]

loss_diff, loss_klz, loss_recon = vdm.apply(params, x_tran, conditioning, mask, rngs={"sample": rng})
loss_batch = (((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2)) + (loss_recon * mask[:, :, None]).sum((-1, -2))) / mask.sum(-1)
loss_batch

Array([ 53.790115,  22.687338, -10.815434, -16.25179 , -16.706673,
        63.127132,  63.840603,  63.683865], dtype=float32)

In [211]:
eps = jax.random.normal(rng, shape=x.shape)
eps_com = np.mean(eps[..., :3], axis=-1, keepdims=True)
eps -= eps_com

In [213]:
eps.mean()

Array(4.1127204e-10, dtype=float32)