In [1]:
import jax
import jax.numpy as np
import jraph
import flax.linen as nn
import numpy as onp
import matplotlib.pyplot as plt

from functools import partial

# `from jax.config import config` relied on the `jax.config` submodule, which
# recent JAX releases no longer ship. The same configuration object is exposed
# as the `jax.config` attribute, so bind it to the name this notebook uses.
config = jax.config
# Turn on JAX's NaN-debugging flag for the whole session.
config.update("jax_debug_nans", True)

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
sys.path.append("../")

from models.gnn import GraphConvNet
from models.graph_utils import nearest_neighbors
from models.graph_utils import add_graphs_tuples
from models.train_utils import create_input_iter

In [4]:
from datasets import load_data

In [5]:
# Node count: passed to load_data and used for the `n_node` field of the GraphsTuples below.
n_nodes = 5000

In [6]:
# NOTE(review): the meaning of the positional arguments is defined by
# datasets.load_data; a later cell passes `n_batch` in the fourth slot, the
# others (3, 234) are presumably feature count and seed -- confirm.
train_ds, norm_dict = load_data(
    "nbody",
    3,
    n_nodes,
    2,
    234,
)

In [7]:
# Show the normalisation mean as plain Python floats.
tuple(float(component) for component in norm_dict['mean'])

(500.28973388671875, 500.1238708496094, 500.25457763671875)

In [8]:
# Draw one batch and keep entry 0 of its first two fields.
# NOTE(review): the leading index is presumably a device axis added by
# create_input_iter -- confirm.
batches = create_input_iter(train_ds)
x_data = next(batches)
x, cond = x_data[0][0], x_data[1][0]

In [9]:
# Neighbour count handed to nearest_neighbors and to EGNN(k=k).
k = 20

In [10]:
# Map nearest_neighbors over the leading axis of x; k is shared by all entries.
sources, targets = jax.vmap(
    nearest_neighbors,
    in_axes=(0, None),
)(x, k)

In [20]:
def translate_pbc(coordinates, translation_vector, box_size):
    """
    Shift coordinates by a vector and wrap the result back into the periodic box.

    :param coordinates: array of shape (batch, points, 3) holding the positions
    :param translation_vector: array of shape (3,) added to every position
    :param box_size: float, side length of the periodic box
    :return: array of shape (batch, points, 3); every component is the shifted
        value taken modulo ``box_size``
    """
    # Shift first, then fold back into the box with a modulo.
    return np.mod(coordinates + translation_vector, box_size)

## EGNN; equivariance test

In [15]:
from models.egnn_jax import EGNN
from models.graph_utils import rotate_representation

In [16]:
# Normalisation statistics of the loaded dataset and the periodic box length
# used by the equivariance / invariance checks below.
coord_mean = norm_dict['mean']
coord_std = norm_dict['std']
box_size = 1000.

In [17]:
class GraphWrapper(nn.Module):
    """Run an EGNN over the leading axis of ``graph`` and ``pos`` with ``jax.vmap``.

    The module reads the notebook-level names ``k``, ``coord_mean``,
    ``coord_std`` and ``box_size``; they must be defined before it is called.
    """

    @nn.compact
    def __call__(self, graph, pos):
        egnn = EGNN(k=k)
        # The first four arguments are mapped over axis 0 (the two `None`
        # placeholders have nothing to map); the normalisation statistics and
        # the box size are shared across the mapped axis.
        batched_egnn = jax.vmap(egnn, in_axes=(0, 0, 0, 0, None, None, None))
        return batched_egnn(graph, pos, None, None, coord_mean, coord_std, box_size)


model = GraphWrapper()
rng = jax.random.PRNGKey(42)

In [18]:
# Sum of squared components along the last axis, kept as a trailing length-1
# axis; used as the node feature of the graphs built below.
d2 = np.sum(
    x ** 2,
    axis=-1,
    keepdims=True,
)

### Rotation equivariance

In [15]:
pos = x
# Rotated copy of the input positions.
# NOTE(review): 45 and [0, 0, 1] are presumably the angle (degrees?) and the
# rotation axis expected by rotate_representation -- confirm.
pos_rot = jax.vmap(rotate_representation, in_axes=(0, None, None))(
    x,
    45,
    np.array([0., 0., 1.]),
)

In [16]:
# Graph over the un-rotated neighbour lists, with d2 as node features.
# NOTE(review): the factor 2 hard-codes the batch size, and n_edge is filled
# with k rather than a total edge count -- confirm EGNN does not rely on it.
graph = jraph.GraphsTuple(
    nodes=d2,
    edges=None,
    senders=sources,
    receivers=targets,
    globals=cond,
    n_node=np.array([[n_nodes]] * 2),
    n_edge=np.array([[k]] * 2),
)

# Initialise the model and keep its output for the original positions.
pos_out, _ = model.init_with_output(rng, graph, pos)

In [17]:
# Same graph as for the un-rotated run; only the positions fed to the model differ.
# NOTE(review): the factor 2 hard-codes the batch size, and n_edge is filled
# with k rather than a total edge count -- confirm EGNN does not rely on it.
graph = jraph.GraphsTuple(
    nodes=d2,
    edges=None,
    senders=sources,
    receivers=targets,
    globals=cond,
    n_node=np.array([[n_nodes]] * 2),
    n_edge=np.array([[k]] * 2),
)

# Same rng as before, so the initialised parameters match the un-rotated run.
pos_rot_out, _ = model.init_with_output(rng, graph, pos_rot)

In [18]:
# Apply the same rotation to the output obtained from the un-rotated input.
pos_out_rot = jax.vmap(rotate_representation, in_axes=(0, None, None))(
    pos_out,
    45,
    np.array([0., 0., 1.]),
)

In [19]:
# Element-wise ratio of "model(rotated input)" to "rotated model(input)";
# entries equal to 1 mean the two agree.
np.divide(pos_rot_out, pos_out_rot)

Array([[[1.0000187 , 1.0041956 , 1.0012716 ],
        [0.9996871 , 0.99964136, 1.0001278 ],
        [0.99993014, 0.9998736 , 0.9998053 ],
        ...,
        [0.9938977 , 0.9999632 , 1.0000858 ],
        [1.0005772 , 1.000247  , 0.99893963],
        [1.0001596 , 1.0000917 , 1.0006036 ]],

       [[0.99974424, 0.99993306, 0.9999587 ],
        [1.0002201 , 1.000392  , 0.9999066 ],
        [1.0005502 , 0.99963886, 0.9998647 ],
        ...,
        [0.9997811 , 0.9992433 , 1.0001013 ],
        [0.9999908 , 0.9902577 , 1.0000782 ],
        [1.0034574 , 1.0000578 , 0.9995494 ]]], dtype=float32)

### Translation equivariance

In [20]:
# Translation with two leading singleton axes so it broadcasts against the positions.
tran = np.array([[[-0.45, 0.9, -1.2]]])
pos_tran = pos + tran
pos_out_tran = pos_out + tran

In [21]:
# Same graph as before; only the positions fed to the model are translated.
# NOTE(review): the factor 2 hard-codes the batch size, and n_edge is filled
# with k rather than a total edge count -- confirm EGNN does not rely on it.
graph = jraph.GraphsTuple(
    nodes=d2,
    edges=None,
    senders=sources,
    receivers=targets,
    globals=cond,
    n_node=np.array([[n_nodes]] * 2),
    n_edge=np.array([[k]] * 2),
)

pos_tran_out, _ = model.init_with_output(rng, graph, pos_tran)

In [22]:
# Element-wise ratio of "model(translated input)" to "translated model(input)";
# entries equal to 1 mean the two agree.
np.divide(pos_tran_out, pos_out_tran)

Array([[[1.0000131 , 1.0000918 , 1.0000881 ],
        [1.000087  , 1.0004143 , 1.000023  ],
        [1.0000138 , 1.0000178 , 0.99999046],
        ...,
        [1.0000104 , 1.0000035 , 1.0000101 ],
        [1.0000414 , 1.0000386 , 0.999928  ],
        [1.0000116 , 0.9999892 , 0.9999824 ]],

       [[0.99999607, 1.0000114 , 1.0000179 ],
        [1.0000027 , 0.999789  , 1.000614  ],
        [1.0000229 , 0.99996877, 0.9999823 ],
        ...,
        [0.99998987, 0.9998664 , 0.9999764 ],
        [1.0000228 , 0.999852  , 1.0000552 ],
        [0.9999945 , 0.99997723, 1.0000137 ]]], dtype=float32)

In [24]:
# Translate with periodic wrapping: un-normalise, shift and wrap with
# translate_pbc, then normalise again. The box length comes from `box_size`
# (the same value handed to the model) instead of a repeated literal.
tran = np.array([700., 100., 800.])
pos_tran = (translate_pbc(pos * coord_std + coord_mean, tran, box_size) - coord_mean) / coord_std
pos_out_tran = (translate_pbc(pos_out * coord_std + coord_mean, tran, box_size) - coord_mean) / coord_std

In [25]:
# Same graph as before; the positions are the periodically wrapped translation.
# NOTE(review): the factor 2 hard-codes the batch size, and n_edge is filled
# with k rather than a total edge count -- confirm EGNN does not rely on it.
graph = jraph.GraphsTuple(
    nodes=d2,
    edges=None,
    senders=sources,
    receivers=targets,
    globals=cond,
    n_node=np.array([[n_nodes]] * 2),
    n_edge=np.array([[k]] * 2),
)

pos_tran_out, _ = model.init_with_output(rng, graph, pos_tran)

In [26]:
# Element-wise ratio for the periodically wrapped translation; entries equal
# to 1 mean the two agree.
np.divide(pos_tran_out, pos_out_tran)

Array([[[1.0000098 , 1.0000012 , 1.0000873 ],
        [0.9999845 , 1.0002767 , 0.99993163],
        [1.0000025 , 1.0000275 , 1.0000205 ],
        ...,
        [0.99997884, 0.99999005, 1.0001746 ],
        [1.0000045 , 0.9999778 , 0.999953  ],
        [1.0000043 , 0.99998236, 1.00002   ]],

       [[1.0000103 , 1.0000004 , 0.99998987],
        [1.0000541 , 0.9998161 , 0.99974245],
        [0.99990183, 0.99998397, 0.99998796],
        ...,
        [1.0000011 , 0.99973375, 0.99991703],
        [0.9999598 , 1.0000031 , 1.0000739 ],
        [1.0000082 , 0.9999883 , 1.0000718 ]]], dtype=float32)

## Look at VLB loss

In [22]:
from models.diffusion import VariationalDiffusionModel
from flax.core import FrozenDict
from models.diffusion_utils import loss_vdm
from models.train_utils import create_input_iter, param_count, train_step

In [23]:
# Passed as the fourth positional argument of load_data in the next cell.
n_batch = 2

In [24]:
# Reload the data for the diffusion-model checks. The node count reuses
# `n_nodes` (5000) instead of repeating the literal, matching the first
# load_data call in this notebook.
train_ds, norm_dict_tmp = load_data("nbody", 3, n_nodes, n_batch, 23)

In [25]:
# Rebinds `batches` to an iterator over the freshly loaded dataset.
batches = create_input_iter(train_ds)

In [26]:
# Take one batch and keep entry 0 of each of its three fields.
x, conditioning, mask = (field[0] for field in next(batches))

In [27]:
# Score-network (GNN) arguments; the neighbour count reuses the notebook-level `k`.
score_dict = FrozenDict({"k": k, "n_pos_features": 3})
encoder_dict = decoder_dict = FrozenDict({})
# Normalisation settings handed to the model. NOTE: this rebinds `norm_dict`,
# which earlier held the raw dict returned by load_data. The box length reuses
# `box_size`, so the model and the translation checks share one value.
norm_dict = FrozenDict({
    "x_mean": tuple(map(float, norm_dict_tmp['mean'])),
    "x_std": tuple(map(float, norm_dict_tmp['std'])),
    "box_size": box_size,
})
score = "egnn"

vdm = VariationalDiffusionModel(
    timesteps=0,
    d_t_embedding=16,
    d_feature=3,
    score=score,
    score_dict=score_dict,
    n_classes=0,
    embed_context=True,
    d_context_embedding=16,
    noise_schedule="learned_linear",
    gamma_min=-8.,
    gamma_max=14.,
    use_encdec=False,
    encoder_dict=encoder_dict,
    decoder_dict=decoder_dict,
    norm_dict=norm_dict,
)

In [28]:
# Shapes of the three batch fields, in order.
tuple(field.shape for field in (x, conditioning, mask))

((2, 5000, 3), (2, 2), (2, 5000))

In [29]:
rng = jax.random.PRNGKey(42)
# NOTE(review): the same key feeds both the "sample" and "params" streams;
# splitting it would decorrelate them, but would change the initialised values.
out, params = vdm.init_with_output(
    {"sample": rng, "params": rng},
    x,
    conditioning,
    mask,
)

### Check rotational invariance of loss

In [35]:
# Reference loss on the unmodified positions.
loss = loss_vdm(
    params,
    vdm,
    rng,
    x,
    conditioning,
    mask,
)
loss

Array(9.643856, dtype=float32)

In [36]:
# Loss on positions rotated by 45. about [0, 0, 1]; same params and rng as the reference.
x_rot = jax.vmap(rotate_representation, in_axes=(0, None, None))(
    x,
    45.,
    np.array([0., 0., 1.]),
)
loss = loss_vdm(
    params,
    vdm,
    rng,
    x_rot,
    conditioning,
    mask,
)
loss

Array(9.592255, dtype=float32)

In [37]:
# Loss on positions rotated by 93. about [0, 1/sqrt(2), 1/sqrt(2)]; same params
# and rng as the reference.
x_rot = jax.vmap(rotate_representation, in_axes=(0, None, None))(
    x,
    93.,
    np.array([0., 1. / np.sqrt(2), 1. / np.sqrt(2)]),
)
loss = loss_vdm(
    params,
    vdm,
    rng,
    x_rot,
    conditioning,
    mask,
)
loss

Array(9.698193, dtype=float32)

In [38]:
# loss, loss_grad = jax.value_and_grad(loss_vdm)(params, vdm, rng, x, np.zeros_like(conditioning), mask)

### Check rotational invariance of loss (individual components)

In [39]:
# Three loss terms returned by the model for the unmodified positions.
loss_diff, loss_klz, loss_recon = vdm.apply(
    params,
    x,
    conditioning,
    mask,
    rngs={"sample": rng},
)
# Masked sum of each term over the last two axes, divided by the per-sample
# sum of the mask.
loss_batch = (
    ((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2))
    + (loss_recon * mask[:, :, None]).sum((-1, -2))
) / mask.sum(-1)
loss_batch

Array([ 35.985077, -16.697365], dtype=float32)

In [40]:
# Per-sample loss on positions rotated by 45. about [0, 0, 1].
x_rot = jax.vmap(rotate_representation, in_axes=(0, None, None))(
    x,
    45.,
    np.array([0., 0., 1.]),
)
loss_diff, loss_klz, loss_recon = vdm.apply(
    params,
    x_rot,
    conditioning,
    mask,
    rngs={"sample": rng},
)
# Masked sum of each term over the last two axes, divided by the per-sample
# sum of the mask.
loss_batch = (
    ((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2))
    + (loss_recon * mask[:, :, None]).sum((-1, -2))
) / mask.sum(-1)
loss_batch

Array([ 35.881878, -16.697369], dtype=float32)

In [41]:
# Per-sample loss on positions rotated by 93. about [0, 1/sqrt(2), 1/sqrt(2)].
x_rot = jax.vmap(rotate_representation, in_axes=(0, None, None))(
    x,
    93.,
    np.array([0., 1. / np.sqrt(2), 1. / np.sqrt(2)]),
)
loss_diff, loss_klz, loss_recon = vdm.apply(
    params,
    x_rot,
    conditioning,
    mask,
    rngs={"sample": rng},
)
# Masked sum of each term over the last two axes, divided by the per-sample
# sum of the mask.
loss_batch = (
    ((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2))
    + (loss_recon * mask[:, :, None]).sum((-1, -2))
) / mask.sum(-1)
loss_batch

Array([ 36.093754, -16.697369], dtype=float32)

In [42]:
# Masked per-sample average of the loss_diff term alone.
(loss_diff * mask[:, :, None]).sum(axis=(-1, -2)) / mask.sum(axis=-1)

Array([5.2791515e+01, 1.3752948e-02], dtype=float32)

In [43]:
# Mean over the per-sample losses of the most recent run.
loss_batch.mean()

Array(9.698193, dtype=float32)

### Check translational invariance of loss (individual components)

In [30]:
# Translate with periodic wrapping: un-normalise, shift and wrap with
# translate_pbc, then normalise again. The box length comes from `box_size`
# instead of a repeated literal.
# NOTE(review): coord_mean / coord_std were taken from the first load_data
# call, while x comes from the second -- confirm both share the same statistics.
tran = np.array([700., 100., 800.])
x_tran = (translate_pbc(x * coord_std + coord_mean, tran, box_size) - coord_mean) / coord_std

loss_diff, loss_klz, loss_recon = vdm.apply(params, x_tran, conditioning, mask, rngs={"sample": rng})
# Masked sum of each term over the last two axes, divided by the per-sample
# sum of the mask.
loss_batch = (((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2)) + (loss_recon * mask[:, :, None]).sum((-1, -2))) / mask.sum(-1)
loss_batch

Array([ 35.224545, -16.697493], dtype=float32)

In [31]:
# Masked per-sample average of the loss_diff term alone.
(loss_diff * mask[:, :, None]).sum(axis=(-1, -2)) / mask.sum(axis=-1)

Array([5.1922321e+01, 1.3630271e-02], dtype=float32)

In [32]:
# Translate by a vector larger than the box, with periodic wrapping:
# un-normalise, shift and wrap with translate_pbc, then normalise again.
# The box length comes from `box_size` instead of a repeated literal.
tran = np.array([100., -1200., -840.])
x_tran = (translate_pbc(x * coord_std + coord_mean, tran, box_size) - coord_mean) / coord_std

loss_diff, loss_klz, loss_recon = vdm.apply(params, x_tran, conditioning, mask, rngs={"sample": rng})
# Masked sum of each term over the last two axes, divided by the per-sample
# sum of the mask.
loss_batch = (((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2)) + (loss_recon * mask[:, :, None]).sum((-1, -2))) / mask.sum(-1)
loss_batch

Array([ 36.925606, -16.697319], dtype=float32)

In [33]:
# Masked per-sample average of the loss_diff term alone.
(loss_diff * mask[:, :, None]).sum(axis=(-1, -2)) / mask.sum(axis=-1)

Array([5.3623352e+01, 1.3796040e-02], dtype=float32)

In [34]:
# Mean over the per-sample losses of the most recent run.
loss_batch.mean()

Array(10.114143, dtype=float32)

In [37]:
# Translate directly in normalised units, without wrapping into the box.
tran = np.array([100., -1200., -840.])
x_tran = x + tran / coord_std

loss_diff, loss_klz, loss_recon = vdm.apply(
    params,
    x_tran,
    conditioning,
    mask,
    rngs={"sample": rng},
)
# Masked sum of each term over the last two axes, divided by the per-sample
# sum of the mask.
loss_batch = (
    ((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2))
    + (loss_recon * mask[:, :, None]).sum((-1, -2))
) / mask.sum(-1)
loss_batch

Array([311.13904, -16.57467], dtype=float32)

In [38]:
# Masked per-sample average of the loss_diff term alone.
(loss_diff * mask[:, :, None]).sum(axis=(-1, -2)) / mask.sum(axis=-1)

Array([3.2783246e+02, 1.3209820e-01], dtype=float32)

### Do rotated/translated positions produce the same graph?

In [199]:
# Baseline: un-normalised positions with a zero shift.
tran = np.array([0., 0., 0.])
x_tran = x * coord_std + coord_mean
x_tran += tran[None, None, :]

# Neighbour search with the box size supplied; this rebinds the notebook-level
# `sources` / `targets`.
sources, targets = jax.vmap(
    nearest_neighbors,
    in_axes=(0, None, None),
)(x_tran, k, box_size)

# The sum of all target indices serves as a quick fingerprint for comparing
# neighbour assignments across cells.
targets.sum()

Array(501639938, dtype=int32)

In [216]:
# Make graph with translated data (shift applied without wrapping).
tran = np.array([100., 3000., -120.])
x_tran = x * coord_std + coord_mean
x_tran += tran[None, None, :]

# Neighbour search with the box size supplied; this rebinds the notebook-level
# `sources` / `targets`.
sources, targets = jax.vmap(
    nearest_neighbors,
    in_axes=(0, None, None),
)(x_tran, k, box_size)

# Fingerprint of the neighbour assignment, to compare with the baseline cell.
targets.sum()

Array(501639938, dtype=int32)

In [217]:
# Make graph with rotated data: rotate in normalised units, then un-normalise.
x_rot = jax.vmap(rotate_representation, in_axes=(0, None, None))(
    x,
    45.,
    np.array([0., 1. / np.sqrt(2), 1. / np.sqrt(2)]),
)
x_rot = x_rot * coord_std + coord_mean

# Neighbour search with the box size supplied; this rebinds the notebook-level
# `sources` / `targets`.
sources, targets = jax.vmap(
    nearest_neighbors,
    in_axes=(0, None, None),
)(x_rot, k, box_size)

# Fingerprint of the neighbour assignment, to compare with the baseline cell.
targets.sum()

Array(502825442, dtype=int32)

In [196]:
# Baseline positions again, this time calling nearest_neighbors without a box size.
tran = np.array([0., 0., 0.])
x_tran = x * coord_std + coord_mean
x_tran += tran[None, None, :]

# This rebinds the notebook-level `sources` / `targets`.
sources, targets = jax.vmap(
    nearest_neighbors,
    in_axes=(0, None),
)(x_tran, k)

# Fingerprint of the neighbour assignment, to compare with the box-aware cells.
targets.sum()

Array(502067660, dtype=int32)

In [50]:
# Standard-normal noise with the same shape as x.
eps = jax.random.normal(rng, x.shape)
# Mean of the first three feature channels over the second-to-last axis,
# kept as a length-1 axis.
eps_com = np.mean(
    eps[..., :3],
    axis=-2,
    keepdims=True,
)