In [1]:
import jax
import jax.numpy as np
import jraph
import flax.linen as nn

from functools import partial

In [2]:
%load_ext autoreload
%autoreload 2

In [197]:
import sys
sys.path.append("../")

from models.gnn import GraphConvNet
from models.graph_utils import nearest_neighbors
from models.graph_utils import add_graphs_tuples

In [4]:
n_nodes = 5000
x = np.load("/n/holyscratch01/iaifi_lab/ccuesta/data_for_sid/halos.npy")[:, :n_nodes,:]

In [5]:
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x[:4], 20)

## EGNN

In [484]:
from models.egnn import EGNN

In [485]:
class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(EGNN())
        return model(x)
    
model = GraphWrapper()
rng = jax.random.PRNGKey(42)

## Test equivariance

In [486]:
from models.graph_utils import rotate_representation

In [487]:
graph = jraph.GraphsTuple(
          n_node=np.array(4 * [[n_nodes]]), 
          n_edge=np.array(4 * [[20]]),
          nodes=x[:4, :, :], 
          edges=None,
          globals=np.ones((4, 7)),
          senders=sources,
          receivers=targets)

graph_out, _ = model.init_with_output(rng, graph)

x_out = graph_out.nodes

angle_deg = 45.
axis = jnp.array([0, 1 / np.sqrt(2), 1 / np.sqrt(2)])

x_out_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(x_out, angle_deg, axis)

graph = jraph.GraphsTuple(
          n_node=np.array(4 * [[n_nodes]]), 
          n_edge=np.array(4 * [[20]]),
          nodes=jax.vmap(rotate_representation, in_axes=(0,None,None))(x[:4, :, :], angle_deg, axis),
          edges=None,
          globals=np.ones((4, 7)),
          senders=sources,
          receivers=targets)

graph_out, params = model.init_with_output(rng, graph)
x_out = graph_out.nodes

In [488]:
# Equivariance ratio
eq_ratio = x_out / x_out_rot
print(eq_ratio.max(), eq_ratio.min(), eq_ratio)

109.473564 -115.38762 [[[1.0002339  0.9960663  0.99954814 ... 0.99648935 0.99980164 0.99895155]
  [0.99985373 1.0009943  0.9863638  ... 1.0014861  0.9888656  0.9974784 ]
  [0.9999505  0.9996911  0.9997461  ... 1.0000232  0.9990794  0.9957409 ]
  ...
  [0.9999879  0.9939521  0.99819815 ... 0.99556774 0.99789757 1.0002339 ]
  [1.0008882  0.99951917 1.0003865  ... 0.99909174 1.0002927  0.9990782 ]
  [0.9994261  1.000005   0.99979764 ... 0.99999774 0.9999151  0.99951893]]

 [[0.99888057 1.0005398  0.9565121  ... 1.0004878  0.98075706 0.997439  ]
  [0.99880654 1.000618   1.0001235  ... 0.9999942  0.99992913 0.9962117 ]
  [1.0002842  0.9994418  0.99967384 ... 1.0002614  0.9988869  1.0009137 ]
  ...
  [0.9998485  1.0026954  1.0011265  ... 1.0037766  0.9994561  0.99975455]
  [1.0006504  1.0000986  1.0005809  ... 0.99983186 1.0003997  0.999679  ]
  [1.0000618  0.9999948  0.9998149  ... 1.0003642  0.99999964 1.0001197 ]]

 [[0.9989343  0.99609506 1.0001825  ... 0.9957949  1.0006255  1.0010105 ]


In [489]:
sum(x.size for x in jax.tree_util.tree_leaves(params))

119316