In [12]:
import sys
sys.path.append("../")

import numpy as np
import jax.numpy as jnp
import jax
import flax.linen as nn
import jraph

from models.egnn import EGNN
from models.graph_utils import nearest_neighbors, rotate_representation

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
n_nodes = 2000
n_feat = 7
positions_only = False
k = 10
n_batch = 2

# Load dark matter particles data
# 3 positions + 3 velocities + 1 scalar (mass)
x = np.load("../data/halos_small.npy")[:, :n_nodes, :]

# Normalize
x_mean = x.mean((0, 1))
x_std = x.std((0, 1))
x[:, :, 3:] = (x[:, :, 3:] - x_mean[3:]) / x_std[3:]
x[:, :, :3] = x[:, :, :3] / 1000.  # Divide by box size

In [14]:
# Get nearest neighbors graph
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x[:n_batch], k)

In [15]:
class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(EGNN(
                        message_passing_steps=3, d_hidden=32, n_layers=3, activation='gelu',
                        positions_only=positions_only
                ))
        return model(x)

model = GraphWrapper()
rng = jax.random.PRNGKey(42)

In [16]:
# Original graph

graph = jraph.GraphsTuple(
          n_node=np.array(n_batch * [[n_nodes]]), 
          n_edge=np.array(n_batch * [[k]]),
          nodes=x[:n_batch, :, :n_feat], 
          edges=None,
          globals=np.ones((n_batch, 7)),
          senders=sources,
          receivers=targets)

graph_out, _ = model.init_with_output(rng, graph)
x_out = graph_out.nodes  # Output features

angle_deg = 45.
axis = np.array([0, 1 / np.sqrt(2), 1 / np.sqrt(2)])

# Rotated output features
x_out_rot = jax.vmap(rotate_representation, in_axes=(0,None,None,None))(x_out, angle_deg, axis, positions_only)

def get_rotated(x_rot):

    graph = jraph.GraphsTuple(
              n_node=np.array(n_batch * [[n_nodes]]), 
              n_edge=np.array(n_batch * [[k]]),
              nodes=x_rot[:n_batch, :, :n_feat],
              edges=None,
              globals=np.ones((n_batch, 7)),
              senders=sources,
              receivers=targets)

    graph_out, _ = model.init_with_output(rng, graph)
    x_out = graph_out.nodes
    
    return x_out

# Output rotated graph
x_out = get_rotated(jax.vmap(rotate_representation, in_axes=(0,None,None,None))(x[:n_batch, :, :], angle_deg, axis, positions_only))
x_out.shape

(2, 2000, 7)

In [17]:
# Equivariance ratio; compare rotated output to output of rotated input
eq_ratio = x_out / x_out_rot
print(eq_ratio.max(), eq_ratio.min(), eq_ratio)

1.1372669 0.9680851 [[[0.99999976 1.         1.         ... 0.9999997  1.         1.0000005 ]
  [0.99999964 1.         1.         ... 0.99999934 0.99999946 0.9999997 ]
  [0.9999999  1.0000001  1.         ... 1.0000005  1.0000006  0.9999999 ]
  ...
  [1.0000002  1.0000001  1.         ... 0.99999857 0.9999982  1.0000001 ]
  [1.0000001  1.         1.         ... 1.0000006  0.99999964 0.9999996 ]
  [0.9999999  1.0000002  1.         ... 0.9999997  0.99999946 0.99999976]]

 [[1.0000001  1.         0.99999994 ... 0.9999997  0.99999905 1.0000012 ]
  [1.0000004  1.         1.0000001  ... 1.0000035  1.0000137  0.9999986 ]
  [1.0000012  1.0000002  0.9999999  ... 1.0000021  1.0000021  1.0000004 ]
  ...
  [0.99999976 1.0000002  1.         ... 1.0000007  1.0000006  0.9999991 ]
  [0.99999976 1.0000001  1.0000001  ... 0.999986   0.9999976  1.0000001 ]
  [0.99999994 1.0000002  0.99999994 ... 1.0000002  1.0000001  0.99999946]]]
