In [1]:
import jax
import jax.numpy as np
import jraph
import flax.linen as nn
import numpy as onp
import matplotlib.pyplot as plt

from functools import partial

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
sys.path.append("../")

from models.gnn import GraphConvNet
from models.graph_utils import nearest_neighbors
from models.graph_utils import add_graphs_tuples
from models.train_utils import create_input_iter

In [4]:
from datasets import load_data

In [5]:
train_ds, norm_dict = load_data("nbody", 3, 5000, 2, 234)

In [6]:
batches = create_input_iter(train_ds)
x = next(batches)
x = x[0][0]

In [7]:
norm_dict

{'mean': Array([250.20403, 250.02423, 250.07912], dtype=float32),
 'std': Array([144.30989, 144.35008, 144.3569 ], dtype=float32)}

In [None]:
z = [[[-0.74147505 -1.6911283   1.5016145 ]
  [-1.2356482   0.83966905 -0.6621572 ]
  [-0.93439144  0.7761591   1.7189647 ]
  ...
  [-0.37795684  0.05500687 -0.8659335 ]
  [-1.708255   -1.1711149  -0.9808073 ]
  [-1.0818185   0.7704281  -1.5321599 ]]

 [[-0.5730722  -0.08018497 -0.70453495]
  [-0.3276197  -0.12879108 -1.3185476 ]
  [-0.8818957  -0.07687593 -1.0378724 ]
  ...
  [-0.0777759  -0.44764853  1.6243275 ]
  [ 0.98842955  0.85802907 -1.4616792 ]
  [-1.565508    0.18810192  1.3286879 ]]]

In [8]:
x

Array([[[-0.74147505, -1.6911283 ,  1.5016145 ],
        [-1.2356482 ,  0.83966905, -0.6621572 ],
        [-0.93439144,  0.7761591 ,  1.7189647 ],
        ...,
        [-0.37795684,  0.05500687, -0.8659335 ],
        [-1.708255  , -1.1711149 , -0.9808073 ],
        [-1.0818185 ,  0.7704281 , -1.5321599 ]],

       [[-0.5730722 , -0.08018497, -0.70453495],
        [-0.3276197 , -0.12879108, -1.3185476 ],
        [-0.8818957 , -0.07687593, -1.0378724 ],
        ...,
        [-0.0777759 , -0.44764853,  1.6243275 ],
        [ 0.98842955,  0.85802907, -1.4616792 ],
        [-1.565508  ,  0.18810192,  1.3286879 ]]], dtype=float32)

In [9]:
k = 20

In [10]:
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x, k)

## EGNN

In [11]:
from models.egnn import EGNN

In [12]:
class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(EGNN(
                        message_passing_steps=2, d_hidden=64, n_layers=3, norm_layer=False, skip_connections=False,
                ))
        return model(x)
    
model = GraphWrapper()
rng = jax.random.PRNGKey(42)

## Test equivariance

In [13]:
from models.graph_utils import rotate_representation

In [14]:
x.shape, x.mean(), x.std()

((2, 5000, 3),
 Array(0.00506021, dtype=float32),
 Array(1.0109903, dtype=float32))

In [15]:
n_nodes = 5000

graph = jraph.GraphsTuple(
          n_node=np.array(2 * [[n_nodes]]), 
          n_edge=np.array(2 * [[k]]),
          nodes=x, 
          edges=None,
          globals=np.ones((2, 7)),
          senders=sources,
          receivers=targets)

graph_out, _ = model.init_with_output(rng, graph)

x_out = graph_out.nodes

# angle_deg = 45.
# axis = np.array([0, 1 / np.sqrt(2), 1 / np.sqrt(2)])

# x_out_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(x_out, angle_deg, axis)

# graph = jraph.GraphsTuple(
#           n_node=np.array(4 * [[n_nodes]]), 
#           n_edge=np.array(4 * [[20]]),
#           nodes=jax.vmap(rotate_representation, in_axes=(0,None,None))(x[0:4, :, :], angle_deg, axis),
#           edges=None,
#           globals=np.ones((4, 7)),
#           senders=sources,
#           receivers=targets)

# graph_out, params = model.init_with_output(rng, graph)
# x_out = graph_out.nodes

(100000, 3) (100000, 3)
9.999999747378752e-06, 0.40845030546188354, 0.06262736767530441, 0.047285839915275574, 0, 0
9.999999747378752e-06, 0.4638831913471222, 0.06226273253560066, 0.04635818302631378, 0, 0
xi = [[-0.74147505 -1.6911283   1.5016145 ]
 [-0.74147505 -1.6911283   1.5016145 ]
 [-0.74147505 -1.6911283   1.5016145 ]
 [-0.74147505 -1.6911283   1.5016145 ]], xj = [[-0.74147505 -1.6911283   1.5016145 ]
 [-0.7410142  -1.6854184   1.5522099 ]
 [-0.7767966  -1.6671458   1.6424124 ]
 [-0.771016   -1.5445424   1.5368899 ]]
xi = [[-0.5730722  -0.08018497 -0.70453495]
 [-0.5730722  -0.08018497 -0.70453495]
 [-0.5730722  -0.08018497 -0.70453495]
 [-0.5730722  -0.08018497 -0.70453495]], xj = [[-0.5730722  -0.08018497 -0.70453495]
 [-0.679718    0.01507041 -0.6691244 ]
 [-0.4275025  -0.15570742 -0.6210643 ]
 [-0.424658   -0.1758395  -0.6367179 ]]
ms = [[9.9999997e-06]
 [2.6027097e-03]
 [2.1656834e-02]
 ...
 [7.8889132e-02]
 [8.0581762e-02]
 [8.6887777e-02]]
ms = [[9.9999997e-06]
 [2.17108

In [67]:
x_out

Array([[[-0.7414756 , -1.6911348 ,  1.5015562 ],
        [-1.2356482 ,  0.83966905, -0.6621572 ],
        [-0.93439144,  0.7761591 ,  1.7189647 ],
        ...,
        [-0.37767634,  0.05574466, -0.8661723 ],
        [-1.708255  , -1.1711149 , -0.9808073 ],
        [-1.0818185 ,  0.7704281 , -1.5321599 ]],

       [[-0.5730722 , -0.08018497, -0.70453495],
        [-0.3276868 , -0.12839817, -1.3189393 ],
        [-0.8818957 , -0.07687593, -1.0378724 ],
        ...,
        [-0.07786658, -0.44757283,  1.6241324 ],
        [ 0.98844427,  0.85803986, -1.4616636 ],
        [-1.5655274 ,  0.1881368 ,  1.3286687 ]]], dtype=float32)

In [92]:
# plt.hist(onp.array(x_out)[0,:,2])

In [93]:
# Equivariance ratio
eq_ratio = x_out / x_out_rot
print(eq_ratio.max(), eq_ratio.min(), eq_ratio)

20.295788 -7.051834 [[[ 1.0008166   1.0008724   0.999304  ]
  [ 1.0008503   1.0005685   1.000457  ]
  [ 0.99155265  0.9978411   0.99908423]
  ...
  [ 1.0001603   0.9993249   0.99925137]
  [ 0.9889371   0.99885225  0.99559397]
  [ 0.9864263   1.0003259   1.0027362 ]]

 [[ 0.9991339   0.9994029   0.9991387 ]
  [ 0.99895686  0.99927497  1.0007412 ]
  [ 1.0004698   1.0016893   1.0144402 ]
  ...
  [ 1.0006037   0.98780954  1.0003967 ]
  [ 0.9956347   0.99795693  1.0007982 ]
  [-0.8481864   1.0002426   1.0030551 ]]

 [[ 1.0004767   1.0000732   1.0003489 ]
  [ 1.0294195   1.0052049   0.57948416]
  [ 1.0051254   0.99525446  1.0029076 ]
  ...
  [ 1.0045681   0.99976003  0.98984253]
  [ 0.9997207   1.0006992   0.99946576]
  [ 0.9994839   0.99905413  0.9848319 ]]

 [[ 1.0067011   0.9998752   0.9987213 ]
  [ 0.9998024   1.0002029   1.000398  ]
  [ 0.99177223  0.9998248   1.0002543 ]
  ...
  [ 0.99944764  1.0020063   1.0002644 ]
  [ 0.9997483   0.9973999   1.000468  ]
  [ 1.0003804   0.9888876   0.

In [346]:
sum(x.size for x in jax.tree_util.tree_leaves(params))

34308

In [347]:
from clu import parameter_overview
print(parameter_overview.get_parameter_overview(params))

+------------------------------------+----------+-------+-----------+-------+
| Name                               | Shape    | Size  | Mean      | Std   |
+------------------------------------+----------+-------+-----------+-------+
| params/EGNN_0/MLP_0/Dense_0/bias   | (64,)    | 64    | 0.0       | 0.0   |
| params/EGNN_0/MLP_0/Dense_0/kernel | (2, 64)  | 128   | 0.0302    | 0.675 |
| params/EGNN_0/MLP_0/Dense_1/bias   | (64,)    | 64    | 0.0       | 0.0   |
| params/EGNN_0/MLP_0/Dense_1/kernel | (64, 64) | 4,096 | 0.00124   | 0.126 |
| params/EGNN_0/MLP_0/Dense_2/bias   | (64,)    | 64    | 0.0       | 0.0   |
| params/EGNN_0/MLP_0/Dense_2/kernel | (64, 64) | 4,096 | -0.000354 | 0.124 |
| params/EGNN_0/MLP_0/Dense_3/bias   | (1,)     | 1     | 0.0       | 0.0   |
| params/EGNN_0/MLP_0/Dense_3/kernel | (64, 1)  | 64    | -0.0182   | 0.124 |
| params/EGNN_0/MLP_1/Dense_0/bias   | (64,)    | 64    | 0.0       | 0.0   |
| params/EGNN_0/MLP_1/Dense_0/kernel | (2, 64)  | 128   | -0.003