In [1]:
import jax
import jax.numpy as np
import jraph
import flax.linen as nn
import numpy as onp
import matplotlib.pyplot as plt

from functools import partial

from jax.config import config
config.update("jax_debug_nans", True)

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
sys.path.append("../")

from models.gnn import GraphConvNet
from models.graph_utils import nearest_neighbors
from models.graph_utils import add_graphs_tuples
from models.train_utils import create_input_iter

In [4]:
from datasets import load_data

In [5]:
n_nodes = 100

In [6]:
train_ds, norm_dict = load_data("nbody", 3, n_nodes, 2, 234, small=True)

Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB



In [7]:
tuple(map(float, norm_dict['mean']))

(505.32177734375, 502.5718994140625, 493.03143310546875)

In [8]:
batches = create_input_iter(train_ds)
x_data = next(batches)
x = x_data[0][0]
cond = x_data[1][0]

In [9]:
k = 20

In [10]:
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x, k)

In [11]:
def translate_pbc(coordinates, translation_vector, box_size):
    """
    Translate a batch of coordinates and apply periodic boundary conditions.

    :param coordinates: np.array of shape (batch, points, 3) containing the coordinates
    :param translation_vector: np.array of shape (3,) containing the translation vector
    :param box_size: float representing the box size
    :return: np.array of shape (batch, points, 3) containing the translated coordinates
    """

    # Translate the coordinates
    translated_coordinates = coordinates + translation_vector

    # Apply periodic boundary conditions
    translated_coordinates = np.mod(translated_coordinates, box_size)

    return translated_coordinates

## EGNN; equivariance test

In [12]:
from models.egnn_jax import EGNN
from models.graph_utils import rotate_representation

In [13]:
coord_mean, coord_std, box_size = norm_dict['mean'], norm_dict['std'], 1000.
unit_cell = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])

In [14]:
class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, graph, pos):
        return jax.vmap(EGNN(k=k), in_axes=(0, 0, 0, 0, None, None, None, None))(graph, pos, None, None, coord_mean, coord_std, box_size, unit_cell)
    
model = GraphWrapper()
rng = jax.random.PRNGKey(42)

In [15]:
d2 = np.sum(x ** 2, axis=-1, keepdims=True)

### Rotation equivariance

In [16]:
pos = x
pos_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(x, 45, np.array([0., 0., 1.]))

In [17]:
graph = jraph.GraphsTuple(
          n_node=np.array(2 * [[n_nodes]]), 
          n_edge=np.array(2 * [[k]]),
          nodes=d2, 
          edges=None,
          globals=cond,
          senders=sources,
          receivers=targets)

pos_out, _ = model.init_with_output(rng, graph, pos)

In [18]:
graph = jraph.GraphsTuple(
          n_node=np.array(2 * [[n_nodes]]), 
          n_edge=np.array(2 * [[k]]),
          nodes=d2, 
          edges=None,
          globals=cond,
          senders=sources,
          receivers=targets)

pos_rot_out, _ = model.init_with_output(rng, graph, pos_rot)

In [19]:
pos_out_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(pos_out, 45, np.array([0., 0., 1.]))

In [20]:
pos_rot_out / pos_out_rot

Array([[[  1.0023512 ,   1.0016347 ,   1.0003382 ],
        [  1.0272945 ,   1.0032095 ,   1.0000627 ],
        [  0.98657024,   1.0031565 ,   1.0011903 ],
        [  0.99881697,   0.99889696,   0.99880624],
        [  0.99897164,   0.9994192 ,   0.9997275 ],
        [  1.0009176 ,   1.0009068 ,   0.9999247 ],
        [  0.99880594,   1.0002304 ,   1.0000639 ],
        [  1.0004345 ,   1.0001076 ,   1.0001807 ],
        [  0.9954863 ,   1.0000083 ,   1.0001408 ],
        [  0.9999696 ,   1.0000054 ,   1.0000558 ],
        [  1.000141  ,   1.000029  ,   0.9998255 ],
        [  0.9998997 ,   1.00006   ,   1.0000829 ],
        [  0.99994177,   0.9999949 ,   1.0001148 ],
        [  1.0004723 ,   0.99999374,   0.99993414],
        [  1.0005957 ,   0.99999714,   1.0001227 ],
        [  1.0004659 ,   1.0001912 ,   1.000227  ],
        [  1.0001805 ,   1.0004332 ,   1.000162  ],
        [  1.0000856 ,   0.999969  ,   0.99978703],
        [  1.0000488 ,   1.0000314 ,   1.0002534 ],
        [  0

### Translation equivariance

In [21]:
tran = np.array([-0.45, 0.9, -1.2])[None, None, :]
pos_tran = pos + tran
pos_out_tran = pos_out + tran

In [22]:
graph = jraph.GraphsTuple(
          n_node=np.array(2 * [[n_nodes]]), 
          n_edge=np.array(2 * [[k]]),
          nodes=d2, 
          edges=None,
          globals=cond,
          senders=sources,
          receivers=targets)

pos_tran_out, _ = model.init_with_output(rng, graph, pos_tran)

In [23]:
pos_tran_out / pos_out_tran

Array([[[0.9999998 , 1.0000005 , 0.99999994],
        [1.000002  , 0.99999976, 0.9999994 ],
        [0.99999976, 1.        , 0.9999999 ],
        [0.99999416, 1.        , 0.9999997 ],
        [1.        , 1.0000001 , 0.99999857],
        [1.0000002 , 1.        , 0.9999997 ],
        [0.9999988 , 1.0000002 , 1.0000001 ],
        [0.99999994, 1.0000002 , 1.0000001 ],
        [1.        , 1.        , 1.0000002 ],
        [1.0000001 , 1.        , 0.99999964],
        [0.99999994, 1.0000001 , 0.99999994],
        [1.        , 1.0000001 , 1.0000002 ],
        [1.0000001 , 1.        , 0.99999917],
        [1.0000002 , 0.99999994, 1.0000004 ],
        [1.0000001 , 1.0000002 , 0.9999995 ],
        [1.0000002 , 1.0000012 , 1.        ],
        [1.0000001 , 1.0000008 , 1.        ],
        [1.        , 1.0000004 , 1.0000001 ],
        [0.99999994, 1.        , 1.0000001 ],
        [1.0000001 , 1.        , 0.9999999 ],
        [1.0000002 , 1.0000001 , 1.0000023 ],
        [1.        , 0.9999997 , 0

In [24]:
tran = np.array([700., 100., 800.])
pos_tran = (translate_pbc(pos * coord_std + coord_mean, tran, 1000.) - coord_mean) / coord_std
pos_out_tran = (translate_pbc(pos_out * coord_std + coord_mean, tran, 1000.) - coord_mean) / coord_std

In [25]:
graph = jraph.GraphsTuple(
          n_node=np.array(2 * [[n_nodes]]), 
          n_edge=np.array(2 * [[k]]),
          nodes=d2, 
          edges=None,
          globals=cond,
          senders=sources,
          receivers=targets)

pos_tran_out, _ = model.init_with_output(rng, graph, pos_tran)

In [26]:
pos_tran_out / pos_out_tran

Array([[[ 0.9999999 ,  1.0000002 ,  0.99999875],
        [ 1.0000001 ,  1.0000017 ,  0.9999998 ],
        [ 1.0000067 ,  1.0000005 ,  1.0000001 ],
        [ 0.9999997 ,  0.99999964,  1.0000002 ],
        [ 0.9999998 ,  0.9999998 ,  0.9999993 ],
        [ 1.        ,  1.0000002 ,  1.0000007 ],
        [ 0.9999996 ,  1.0000002 ,  1.0000001 ],
        [ 1.0000001 ,  1.0000001 ,  0.9999992 ],
        [ 1.        ,  1.0000001 ,  0.9999995 ],
        [ 0.99999994, -1.0738907 ,  1.0000002 ],
        [ 0.9999997 ,  1.0000001 ,  0.99999994],
        [-1.0092188 ,  0.99999976,  0.99999976],
        [-1.0131207 ,  1.0000001 ,  1.        ],
        [-0.9964309 ,  0.9999999 ,  1.0000001 ],
        [-1.0729132 ,  1.0000031 ,  0.99999994],
        [ 1.0000001 ,  0.99999976,  0.99999994],
        [ 1.        ,  1.        ,  1.        ],
        [-1.0455102 ,  1.0000001 ,  1.0000001 ],
        [ 1.0000001 ,  1.        ,  1.0000001 ],
        [ 1.0000001 ,  1.0000001 ,  0.9999997 ],
        [ 0.9999995 

## Look at VLB loss

In [27]:
from models.diffusion import VariationalDiffusionModel
from flax.core import FrozenDict
from models.diffusion_utils import loss_vdm
from models.train_utils import create_input_iter, param_count, train_step

In [28]:
n_batch = 2

In [29]:
train_ds, norm_dict_tmp = load_data("nbody", 3, 100, n_batch, 23, small=True)

In [30]:
batches = create_input_iter(train_ds)

In [31]:
x, conditioning, mask = next(batches)
x, conditioning, mask = x[0], conditioning[0], mask[0]

In [32]:
score_dict = FrozenDict({"k":20, "n_pos_features":3})  # GNN args
encoder_dict = decoder_dict = FrozenDict({})
norm_dict = FrozenDict({"x_mean":tuple(map(float, norm_dict_tmp['mean'])), "x_std":tuple(map(float, norm_dict_tmp['std'])), "box_size":1000.})  # GNN args
score = "egnn"

vdm = VariationalDiffusionModel(
          timesteps=0, 
          d_t_embedding=16,
          d_feature=3,
          score=score,
          score_dict=score_dict,
          n_classes=0,
          embed_context=True,
          d_context_embedding=16,
          noise_schedule="learned_linear",
          gamma_min=-8.,
          gamma_max=14.,
          use_encdec=False,
          encoder_dict=encoder_dict,
          decoder_dict=decoder_dict,
          norm_dict=norm_dict)

In [33]:
x.shape, conditioning.shape, mask.shape

((2, 100, 3), (2, 2), (2, 100))

In [34]:
rng = jax.random.PRNGKey(42)
out, params = vdm.init_with_output({"sample": rng, "params": rng}, x, conditioning, mask);

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (None, 3).

### Check rotational invariance of loss

In [35]:
loss = loss_vdm(params, vdm, rng, x, conditioning, mask)
loss

Array(9.643856, dtype=float32)

In [36]:
x_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(x, 45., np.array([0., 0., 1.]))
loss = loss_vdm(params, vdm, rng, x_rot, conditioning, mask)
loss

Array(9.592255, dtype=float32)

In [37]:
x_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(x, 93., np.array([0., 1. / np.sqrt(2), 1. / np.sqrt(2)]))
loss = loss_vdm(params, vdm, rng, x_rot, conditioning, mask)
loss

Array(9.698193, dtype=float32)

In [38]:
# loss, loss_grad = jax.value_and_grad(loss_vdm)(params, vdm, rng, x, np.zeros_like(conditioning), mask)

### Check rotational invariance of loss (individual components)

In [39]:
loss_diff, loss_klz, loss_recon = vdm.apply(params, x, conditioning, mask, rngs={"sample": rng})
loss_batch = (((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2)) + (loss_recon * mask[:, :, None]).sum((-1, -2))) / mask.sum(-1)
loss_batch

Array([ 35.985077, -16.697365], dtype=float32)

In [40]:
x_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(x, 45., np.array([0., 0., 1.]))
loss_diff, loss_klz, loss_recon = vdm.apply(params, x_rot, conditioning, mask, rngs={"sample": rng})
loss_batch = (((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2)) + (loss_recon * mask[:, :, None]).sum((-1, -2))) / mask.sum(-1)
loss_batch

Array([ 35.881878, -16.697369], dtype=float32)

In [41]:
x_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(x, 93., np.array([0., 1. / np.sqrt(2), 1. / np.sqrt(2)]))
loss_diff, loss_klz, loss_recon = vdm.apply(params, x_rot, conditioning, mask, rngs={"sample": rng})
loss_batch = (((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2)) + (loss_recon * mask[:, :, None]).sum((-1, -2))) / mask.sum(-1)
loss_batch

Array([ 36.093754, -16.697369], dtype=float32)

In [42]:
(loss_diff * mask[:, :, None]).sum((-1, -2)) / mask.sum(-1)

Array([5.2791515e+01, 1.3752948e-02], dtype=float32)

In [43]:
loss_batch[:].mean()

Array(9.698193, dtype=float32)

### Check translational invariance of loss (individual components)

In [30]:
# Translate
tran = np.array([700., 100., 800.])
x_tran = (translate_pbc(x * coord_std + coord_mean, tran, 1000.) - coord_mean) / coord_std

loss_diff, loss_klz, loss_recon = vdm.apply(params, x_tran, conditioning, mask, rngs={"sample": rng})
loss_batch = (((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2)) + (loss_recon * mask[:, :, None]).sum((-1, -2))) / mask.sum(-1)
loss_batch

Array([ 35.224545, -16.697493], dtype=float32)

In [31]:
(loss_diff * mask[:, :, None]).sum((-1, -2)) / mask.sum(-1)

Array([5.1922321e+01, 1.3630271e-02], dtype=float32)

In [32]:
# Translate
tran = np.array([100., -1200., -840.])
x_tran = (translate_pbc(x * coord_std + coord_mean, tran, 1000.) - coord_mean) / coord_std

loss_diff, loss_klz, loss_recon = vdm.apply(params, x_tran, conditioning, mask, rngs={"sample": rng})
loss_batch = (((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2)) + (loss_recon * mask[:, :, None]).sum((-1, -2))) / mask.sum(-1)
loss_batch

Array([ 36.925606, -16.697319], dtype=float32)

In [33]:
(loss_diff * mask[:, :, None]).sum((-1, -2)) / mask.sum(-1)

Array([5.3623352e+01, 1.3796040e-02], dtype=float32)

In [34]:
loss_batch[:].mean()

Array(10.114143, dtype=float32)

In [37]:
# Translate
tran = np.array([100., -1200., -840.])
x_tran = x + tran / coord_std

loss_diff, loss_klz, loss_recon = vdm.apply(params, x_tran, conditioning, mask, rngs={"sample": rng})
loss_batch = (((loss_diff + loss_klz) * mask[:, :, None]).sum((-1, -2)) + (loss_recon * mask[:, :, None]).sum((-1, -2))) / mask.sum(-1)
loss_batch

Array([311.13904, -16.57467], dtype=float32)

In [38]:
(loss_diff * mask[:, :, None]).sum((-1, -2)) / mask.sum(-1)

Array([3.2783246e+02, 1.3209820e-01], dtype=float32)

### Do rotates/translated pos produce same graph?

In [199]:
tran = np.array([0., 0., 0.])
x_tran = x * coord_std + coord_mean
x_tran += tran[None, None, :]

sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None, None))(x_tran, k, box_size)
sources, targets

targets.sum()

Array(501639938, dtype=int32)

In [216]:
# Make graph with translated data

tran = np.array([100., 3000., -120.])
x_tran = x * coord_std + coord_mean
x_tran += tran[None, None, :]

sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None, None))(x_tran, k, box_size)
sources, targets

targets.sum()

Array(501639938, dtype=int32)

In [217]:
# Make graph with rotated data

x_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(x, 45., np.array([0., 1. / np.sqrt(2), 1. / np.sqrt(2)]))
x_rot = x_rot * coord_std + coord_mean

sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None, None))(x_rot, k, box_size)
sources, targets

targets.sum()

Array(502825442, dtype=int32)

In [196]:
tran = np.array([0., 0., 0.])
x_tran = x * coord_std + coord_mean
x_tran += tran[None, None, :]

sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x_tran, k)
sources, targets

targets.sum()

Array(502067660, dtype=int32)

In [50]:
eps = jax.random.normal(rng, shape=x.shape)
eps_com = np.mean(eps[..., :3], axis=-2, keepdims=True)