In [1]:
import jax
import jax.numpy as np
import jraph
import flax.linen as nn
import numpy as onp
import matplotlib.pyplot as plt

from functools import partial

# The `jax.config` submodule import (`from jax.config import config`) is not
# available in newer JAX releases; the configuration object itself is exposed
# as `jax.config`. Keep the `config` name bound for any cell that refers to it.
config = jax.config
# Raise an error as soon as a NaN is produced instead of propagating it silently.
config.update("jax_debug_nans", True)

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
sys.path.append("../")

from models.gnn import GraphConvNet
from models.graph_utils import nearest_neighbors
from models.graph_utils import add_graphs_tuples
from models.train_utils import create_input_iter

In [4]:
from datasets import load_data

In [5]:
# Number of nodes (points) per sample; passed to load_data and used for n_node in the graphs below.
n_nodes = 5000

In [6]:
# Load the "nbody" training set together with its normalisation statistics.
# NOTE(review): the positional arguments look like (dataset, n_features, n_nodes, batch_size, seed) — confirm against load_data.
train_ds, norm_dict = load_data("nbody", 3, n_nodes, 2, 234)

In [7]:
# Show the 'mean' entry of the normalisation dict as plain Python floats.
tuple(map(float, norm_dict['mean']))

(500.28973388671875, 500.1238708496094, 500.25457763671875)

In [8]:
# Pull one batch from the input iterator.
batches = create_input_iter(train_ds)
x_data = next(batches)
# Element 0 holds the positions and element 1 the conditioning values; the extra [0]
# drops a leading axis — presumably a device axis added by create_input_iter; verify.
x = x_data[0][0]
cond = x_data[1][0]

In [9]:
# Neighbour count handed to nearest_neighbors and to EGNN(k=k) below.
k = 20

In [10]:
# Compute neighbour edge lists per sample: vmap over axis 0 of x, with k shared across samples.
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x, k)

## EGNN; equivariance test

In [11]:
from models.egnn_jax import EGNN
from models.graph_utils import rotate_representation

In [12]:
# Normalisation statistics from load_data and the box size; these are passed to the EGNN
# and used to un-normalise positions further down.
coord_mean, coord_std, box_size = norm_dict['mean'], norm_dict['std'], 1000.

In [13]:
class GraphWrapper(nn.Module):
    """Applies an EGNN with `k` neighbours to every sample of a batch via `jax.vmap`."""

    @nn.compact
    def __call__(self, graph, pos):
        # graph, pos and the two unused (None) arguments are mapped over axis 0;
        # the normalisation statistics and the box size are shared by all samples.
        batch_axes = (0, 0, 0, 0, None, None, None)
        batched_egnn = jax.vmap(EGNN(k=k), in_axes=batch_axes)
        return batched_egnn(graph, pos, None, None, coord_mean, coord_std, box_size)


# Model instance and PRNG key reused by the equivariance checks below.
model = GraphWrapper()
rng = jax.random.PRNGKey(42)

In [14]:
# Squared norm of each position vector (trailing axis kept as size 1); used as the node feature below.
d2 = np.sum(x ** 2, axis=-1, keepdims=True)

### Rotation equivariance

In [15]:
# Reference positions, and a copy in which every sample is rotated about the axis (0, 0, 1).
# NOTE(review): the angle 45 is presumably in degrees — confirm in rotate_representation.
pos = x
pos_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(x, 45, np.array([0., 0., 1.]))

In [16]:
# Batched graph for the two samples: node features are d2, globals are the conditioning
# values, and the edges come from the neighbour search above.
# NOTE(review): n_edge is set to k per sample rather than a total edge count — confirm
# that the EGNN does not rely on n_edge.
graph = jraph.GraphsTuple(
          n_node=np.array(2 * [[n_nodes]]), 
          n_edge=np.array(2 * [[k]]),
          nodes=d2, 
          edges=None,
          globals=cond,
          senders=sources,
          receivers=targets)

# Initialise the model and keep only its output for the unrotated positions
# (the returned variables are discarded).
pos_out, _ = model.init_with_output(rng, graph, pos)

In [17]:
# Same graph as for the unrotated pass: node features (d2) and edges are built from the
# unrotated x, and only the positions fed to the model differ.
# NOTE(review): n_edge is set to k per sample rather than a total edge count — confirm
# that the EGNN does not rely on n_edge.
graph = jraph.GraphsTuple(
          n_node=np.array(2 * [[n_nodes]]), 
          n_edge=np.array(2 * [[k]]),
          nodes=d2, 
          edges=None,
          globals=cond,
          senders=sources,
          receivers=targets)

# Initialise again with the same key and evaluate on the rotated positions.
pos_rot_out, _ = model.init_with_output(rng, graph, pos_rot)

In [18]:
# Apply the same rotation to the output of the unrotated input, to compare against
# the output of the rotated input.
pos_out_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(pos_out, 45, np.array([0., 0., 1.]))

In [19]:
# Element-wise ratio of the two results; values close to 1 mean rotating before or after
# the model gives the same answer.
pos_rot_out / pos_out_rot

Array([[[1.0000128 , 1.0034319 , 1.0012053 ],
        [0.9996871 , 0.9996415 , 1.0001278 ],
        [0.99993014, 0.9998736 , 0.9998053 ],
        ...,
        [0.9938977 , 0.9999632 , 1.0000858 ],
        [1.0005772 , 1.000247  , 0.99893963],
        [1.0001596 , 1.0000917 , 1.0006036 ]],

       [[0.99974424, 0.99993306, 0.9999587 ],
        [1.00022   , 1.0003921 , 0.9999065 ],
        [1.0005502 , 0.99963886, 0.9998647 ],
        ...,
        [0.9997811 , 0.9992432 , 1.0001013 ],
        [0.9999908 , 0.9902577 , 1.0000782 ],
        [1.0034574 , 1.0000578 , 0.9995494 ]]], dtype=float32)

### Translation equivariance

In [20]:
# Constant shift with two leading singleton axes so it broadcasts over batch and nodes.
tran = np.array([-0.45, 0.9, -1.2])[None, None, :]
# Shifted input positions.
pos_tran = pos + tran
# Reference for the comparison below: the original output shifted by the same vector.
pos_out_tran = pos_out + tran

In [21]:
# Same graph as before: node features (d2) and edges come from the unshifted x, and only
# the positions fed to the model are shifted.
# NOTE(review): n_edge is set to k per sample rather than a total edge count — confirm
# that the EGNN does not rely on n_edge.
graph = jraph.GraphsTuple(
          n_node=np.array(2 * [[n_nodes]]), 
          n_edge=np.array(2 * [[k]]),
          nodes=d2, 
          edges=None,
          globals=cond,
          senders=sources,
          receivers=targets)

# Initialise again with the same key and evaluate on the shifted positions.
pos_tran_out, _ = model.init_with_output(rng, graph, pos_tran)

In [22]:
# Element-wise ratio of the output for shifted input to the shifted original output;
# values close to 1 mean the two agree.
pos_tran_out / pos_out_tran

Array([[[1.0000131 , 1.0000918 , 1.0000881 ],
        [1.0001146 , 1.0003357 , 1.0000314 ],
        [1.0000138 , 1.0000178 , 0.99999046],
        ...,
        [0.9999998 , 1.0000001 , 0.99999994],
        [1.000044  , 1.0000398 , 0.9999269 ],
        [1.0000116 , 0.9999892 , 0.9999824 ]],

       [[0.99999607, 1.0000114 , 1.0000179 ],
        [1.0000027 , 0.999789  , 1.000614  ],
        [1.0000229 , 0.99996877, 0.9999823 ],
        ...,
        [0.99998987, 0.9998664 , 0.9999764 ],
        [1.0000228 , 0.999852  , 1.0000551 ],
        [0.99999464, 0.9999804 , 1.0000181 ]]], dtype=float32)

In [31]:
# Undo the normalisation (multiply by std, add mean) to inspect the un-normalised positions.
pos * coord_std + coord_mean

Array([[[ 49.585907,  20.940552, 418.52457 ],
        [430.10794 , 164.60818 ,  86.859344],
        [370.21463 , 136.73978 , 492.3676  ],
        ...,
        [942.817   ,  10.316071, 729.01215 ],
        [368.96332 , 979.7074  , 468.75595 ],
        [268.24493 , 945.361   , 526.4537  ]],

       [[346.52512 , 583.6051  , 266.82755 ],
        [ 68.74039 , 411.4977  , 782.6518  ],
        [ 28.042023, 825.5925  ,  41.775208],
        ...,
        [285.42636 , 284.67188 , 903.8516  ],
        [253.88661 , 253.52065 , 944.4217  ],
        [181.01636 , 876.4603  , 551.5096  ]]], dtype=float32)

In [36]:
def translate_pbc(coordinates, translation_vector, box_size):
    """
    Shift a batch of coordinates and wrap the result into a periodic box.

    :param coordinates: array of positions, e.g. shape (batch, points, 3)
    :param translation_vector: offset added to every position, e.g. shape (3,)
    :param box_size: period used when wrapping the shifted positions
    :return: array with the shape of `coordinates + translation_vector`,
        reduced modulo `box_size`
    """
    # Shift, then fold back into the box in a single modulo step.
    return np.mod(coordinates + translation_vector, box_size)

In [37]:
# Shift in un-normalised coordinates with periodic wrapping, then normalise again.
# The box size comes from `box_size` (the value passed to the model) rather than
# from a repeated literal.
tran = np.array([700., 100., 800.])
pos_tran = (translate_pbc(pos * coord_std + coord_mean, tran, box_size) - coord_mean) / coord_std
# Reference for the comparison below: the original output shifted and wrapped the same way.
pos_out_tran = (translate_pbc(pos_out * coord_std + coord_mean, tran, box_size) - coord_mean) / coord_std

In [38]:
# Same graph as before: node features (d2) and edges come from the unshifted x, and only
# the positions fed to the model are shifted (with periodic wrapping).
# NOTE(review): n_edge is set to k per sample rather than a total edge count — confirm
# that the EGNN does not rely on n_edge.
graph = jraph.GraphsTuple(
          n_node=np.array(2 * [[n_nodes]]), 
          n_edge=np.array(2 * [[k]]),
          nodes=d2, 
          edges=None,
          globals=cond,
          senders=sources,
          receivers=targets)

# Initialise again with the same key and evaluate on the wrapped, shifted positions.
pos_tran_out, _ = model.init_with_output(rng, graph, pos_tran)

In [42]:
# Element-wise ratio of the output for the wrapped, shifted input to the wrapped, shifted
# original output; values close to 1 mean the two agree.
pos_tran_out / pos_out_tran

Array([[[1.0000098 , 1.0000012 , 1.0000873 ],
        [0.9999937 , 1.0002681 , 0.9999287 ],
        [1.0000025 , 1.0000275 , 1.0000205 ],
        ...,
        [0.9999789 , 0.99999005, 1.0001746 ],
        [1.0000045 , 0.9999778 , 0.999953  ],
        [1.0000043 , 0.99998236, 1.00002   ]],

       [[1.0000103 , 1.0000004 , 0.99998987],
        [1.0000541 , 0.9998161 , 0.99974245],
        [0.99990183, 0.99998397, 0.99998796],
        ...,
        [1.0000011 , 0.99973375, 0.99991703],
        [0.9999598 , 1.0000031 , 1.0000739 ],
        [1.0000082 , 0.9999883 , 1.0000718 ]]], dtype=float32)

## Look at VLB loss

In [43]:
from models.diffusion import VariationalDiffusionModel
from flax.core import FrozenDict
from models.diffusion_utils import loss_vdm
from models.train_utils import create_input_iter, param_count, train_step

In [44]:
# Batch-size argument passed to load_data for the loss checks below.
n_batch = 2

In [45]:
# Reload the dataset for the loss checks. The node count comes from `n_nodes`
# instead of a repeated literal, so it stays in sync with the rest of the notebook.
# NOTE(review): the last argument (23) is presumably a seed — confirm against load_data.
train_ds, norm_dict_tmp = load_data("nbody", 3, n_nodes, n_batch, 23)

In [46]:
# Fresh input iterator over the reloaded dataset (rebinds `batches`).
batches = create_input_iter(train_ds)

In [47]:
# Take one batch of positions, conditioning values and mask.
x, conditioning, mask = next(batches)
# Drop the leading axis of each array — presumably a device axis added by create_input_iter; verify.
x, conditioning, mask = x[0], conditioning[0], mask[0]

In [48]:
# Hyper-parameters for the score network, and empty settings for the (unused) encoder/decoder.
score_dict = FrozenDict({"k":20, "n_pos_features":3})  # GNN args
encoder_dict = decoder_dict = FrozenDict({})
# NOTE: this rebinds `norm_dict`. The new mapping has keys "x_mean", "x_std" and "box_size",
# so earlier cells that read norm_dict['mean'] / norm_dict['std'] will fail if re-run after this one.
norm_dict = FrozenDict({"x_mean":tuple(map(float, norm_dict_tmp['mean'])), "x_std":tuple(map(float, norm_dict_tmp['std'])), "box_size":1000.})  # GNN args
score = "egnn"

# Variational diffusion model with the EGNN score network and no encoder/decoder.
# NOTE(review): timesteps=0 presumably selects the continuous-time formulation — confirm in VariationalDiffusionModel.
vdm = VariationalDiffusionModel(
          timesteps=0, 
          d_t_embedding=16,
          d_feature=3,
          score=score,
          score_dict=score_dict,
          n_classes=0,
          embed_context=True,
          d_context_embedding=16,
          noise_schedule="learned_linear",
          gamma_min=-8.,
          gamma_max=14.,
          use_encdec=False,
          encoder_dict=encoder_dict,
          decoder_dict=decoder_dict,
          norm_dict=norm_dict)

In [49]:
# Check the shapes of the batch before feeding it to the model.
x.shape, conditioning.shape, mask.shape

((2, 5000, 3), (2, 2), (2, 5000))

In [50]:
# Initialise the diffusion model on the batch; `out` is the forward-pass output and
# `params` the initialised variables.
# NOTE: the same key is supplied for both the "sample" and "params" RNG streams.
rng = jax.random.PRNGKey(42)
out, params = vdm.init_with_output({"sample": rng, "params": rng}, x, conditioning, mask);

## Check rotational invariance of loss

In [51]:
# Baseline loss on the unrotated batch.
loss = loss_vdm(params, vdm, rng, x, conditioning, mask)
loss

Array(5.4991817, dtype=float32)

In [52]:
# Rotate every sample about the axis (0, 0, 1) and recompute the loss with the same
# parameters and the same rng as the baseline.
# NOTE(review): the angle 45. is presumably in degrees — confirm in rotate_representation.
x_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(x, 45., np.array([0., 0., 1.]))
loss = loss_vdm(params, vdm, rng, x_rot, conditioning, mask)
loss

Array(5.5293245, dtype=float32)

In [53]:
# Second rotation, about the unit axis (0, 1/sqrt(2), 1/sqrt(2)), with the same
# parameters and rng as the baseline.
# NOTE(review): the angle 93. is presumably in degrees — confirm in rotate_representation.
x_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(x, 93., np.array([0., 1. / np.sqrt(2), 1. / np.sqrt(2)]))
loss = loss_vdm(params, vdm, rng, x_rot, conditioning, mask)
loss

Array(5.728607, dtype=float32)

In [54]:
# loss, loss_grad = jax.value_and_grad(loss_vdm)(params, vdm, rng, x, np.zeros_like(conditioning), mask)

## Check rotational invariance of loss (individual components)

In [73]:
# Evaluate the three loss terms returned by the model on the unrotated batch.
loss_diff, loss_klz, loss_recon = vdm.apply(params, x, conditioning, mask, rngs={"sample": rng})
# Per-sample loss: masked sum of all three terms over the last two axes, divided by the
# sum of the mask for that sample.
loss_batch = (((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2)) + (loss_recon * mask[:, :, None]).sum((-1, -2))) / mask.sum(-1)
loss_batch

Array([ 36.10642 , -16.683311], dtype=float32)

In [74]:
# Rotate every sample about the axis (0, 0, 1), then evaluate the three loss terms with the
# same parameters and rng.
x_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(x, 45., np.array([0., 0., 1.]))
loss_diff, loss_klz, loss_recon = vdm.apply(params, x_rot, conditioning, mask, rngs={"sample": rng})
# Per-sample loss: masked sum of all three terms over the last two axes, divided by the
# sum of the mask for that sample.
loss_batch = (((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2)) + (loss_recon * mask[:, :, None]).sum((-1, -2))) / mask.sum(-1)
loss_batch

Array([ 35.976364, -16.683855], dtype=float32)

In [75]:
# Rotate every sample about the unit axis (0, 1/sqrt(2), 1/sqrt(2)), then evaluate the three
# loss terms with the same parameters and rng.
x_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(x, 93., np.array([0., 1. / np.sqrt(2), 1. / np.sqrt(2)]))
loss_diff, loss_klz, loss_recon = vdm.apply(params, x_rot, conditioning, mask, rngs={"sample": rng})
# Per-sample loss: masked sum of all three terms over the last two axes, divided by the
# sum of the mask for that sample.
loss_batch = (((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2)) + (loss_recon * mask[:, :, None]).sum((-1, -2))) / mask.sum(-1)
loss_batch

Array([ 36.22463 , -16.684011], dtype=float32)

In [76]:
# Diffusion term on its own, masked and normalised per sample in the same way as loss_batch.
# Uses whichever loss_diff was computed last.
(loss_diff * mask[:, :, None]).sum((-1, -2)) / mask.sum(-1)

Array([5.292239e+01, 2.710890e-02], dtype=float32)

In [77]:
# Average of the per-sample losses over the batch.
loss_batch[:].mean()

Array(9.7703085, dtype=float32)

## Check translational invariance of loss (individual components)

In [78]:
# Translate in un-normalised coordinates with periodic wrapping, then normalise again.
# `x` comes from the dataset whose statistics are in `norm_dict_tmp` (also the ones the
# model was configured with), so those are used here rather than the statistics of the
# earlier load_data call; the box size comes from `box_size` instead of a literal.
tran = np.array([700., 100., 800.])
x_tran = (translate_pbc(x * norm_dict_tmp['std'] + norm_dict_tmp['mean'], tran, box_size) - norm_dict_tmp['mean']) / norm_dict_tmp['std']

# Evaluate the three loss terms on the translated batch with the same parameters and rng.
loss_diff, loss_klz, loss_recon = vdm.apply(params, x_tran, conditioning, mask, rngs={"sample": rng})
# Per-sample loss: masked sum of all three terms over the last two axes, divided by the
# sum of the mask for that sample.
loss_batch = (((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2)) + (loss_recon * mask[:, :, None]).sum((-1, -2))) / mask.sum(-1)
loss_batch

Array([ 35.37609 , -16.683496], dtype=float32)

In [79]:
# Diffusion term on its own, masked and normalised per sample in the same way as loss_batch.
# Uses whichever loss_diff was computed last.
(loss_diff * mask[:, :, None]).sum((-1, -2)) / mask.sum(-1)

Array([5.2073860e+01, 2.7626948e-02], dtype=float32)

In [80]:
# Translate by a second vector (including shifts larger than the box) in un-normalised
# coordinates with periodic wrapping, then normalise again.
# `x` comes from the dataset whose statistics are in `norm_dict_tmp` (also the ones the
# model was configured with), so those are used here rather than the statistics of the
# earlier load_data call; the box size comes from `box_size` instead of a literal.
tran = np.array([100., -1200., -840.])
x_tran = (translate_pbc(x * norm_dict_tmp['std'] + norm_dict_tmp['mean'], tran, box_size) - norm_dict_tmp['mean']) / norm_dict_tmp['std']

# Evaluate the three loss terms on the translated batch with the same parameters and rng.
loss_diff, loss_klz, loss_recon = vdm.apply(params, x_tran, conditioning, mask, rngs={"sample": rng})
# Per-sample loss: masked sum of all three terms over the last two axes, divided by the
# sum of the mask for that sample.
loss_batch = (((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2)) + (loss_recon * mask[:, :, None]).sum((-1, -2))) / mask.sum(-1)
loss_batch

Array([ 37.02712, -16.68377], dtype=float32)

In [81]:
# Diffusion term on its own, masked and normalised per sample in the same way as loss_batch.
# Uses whichever loss_diff was computed last.
(loss_diff * mask[:, :, None]).sum((-1, -2)) / mask.sum(-1)

Array([5.3724865e+01, 2.7346682e-02], dtype=float32)

In [83]:
# Average of the per-sample losses over the batch.
loss_batch[:].mean()

Array(10.171675, dtype=float32)

In [82]:
# Draw standard-normal noise with the same shape as x.
eps = jax.random.normal(rng, shape=x.shape)
# Mean of the first three feature channels over axis -2 (the node axis of x), kept as a
# size-1 axis so it can broadcast against eps.
eps_com = np.mean(eps[..., :3], axis=-2, keepdims=True)