In [136]:
import numpy as np
import jax.numpy as jnp
import jax
import flax.linen as nn
import jraph

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [137]:
import sys
sys.path.append("../")

from models.egnn import EGNN
from models.graph_utils import nearest_neighbors, rotate_representation

In [138]:
import numpy as np

# --- Configuration -----------------------------------------------------------
n_nodes = 2000          # number of entries kept along axis 1 of the loaded array
n_feat = 3              # number of leading feature columns passed to the model
positions_only = True   # forwarded to EGNN and rotate_representation

# --- Load --------------------------------------------------------------------
# NOTE(review): absolute, machine-specific path — consider a configurable data dir.
x = np.load("/Users/smsharma/Downloads/halos_small.npy")
x = x[:, :n_nodes, :]

# --- Normalise ---------------------------------------------------------------
# Per-column statistics taken over axes 0 and 1.
x_mean = x.mean(axis=(0, 1))
x_std = x.std(axis=(0, 1))

lead_cols = slice(None, 3)   # first three columns (presumably positions — confirm)
rest_cols = slice(3, None)   # every remaining column

# Remaining columns are standardised; the first three are only divided by 1000.
x[:, :, rest_cols] = (x[:, :, rest_cols] - x_mean[rest_cols]) / x_std[rest_cols]
x[:, :, lead_cols] = x[:, :, lead_cols] / 1000.

In [139]:
k = 10        # second argument handed to nearest_neighbors
n_batch = 1   # number of leading samples of x that are used

# Map nearest_neighbors over the leading axis of x; k is shared across the batch.
batched_nearest_neighbors = jax.vmap(nearest_neighbors, in_axes=(0, None))
sources, targets = batched_nearest_neighbors(x[:n_batch], k)

In [140]:
class GraphWrapper(nn.Module):
    """Flax module that applies an EGNN over the leading axis of its input with jax.vmap."""

    @nn.compact
    def __call__(self, x):
        """Build the EGNN, vmap it, and return its result on x."""
        egnn = EGNN(
            message_passing_steps=3,
            d_hidden=32,
            n_layers=3,
            skip_connections=False,
            activation='gelu',
            positions_only=positions_only,  # read from the notebook-level setting
        )
        batched_egnn = jax.vmap(egnn)
        return batched_egnn(x)

# Instantiate the wrapper and create the PRNG key later passed to init_with_output.
seed = 42
model = GraphWrapper()
rng = jax.random.PRNGKey(seed)

In [141]:
def _build_graph(nodes):
    """Wrap node features in a GraphsTuple that reuses the precomputed connectivity.

    Args:
        nodes: array indexed as [sample, node, feature]; only the first
            ``n_batch`` samples and first ``n_feat`` feature columns are used.

    Returns:
        A ``jraph.GraphsTuple`` whose senders/receivers are the notebook-level
        ``sources``/``targets``.
    """
    return jraph.GraphsTuple(
        n_node=np.array(n_batch * [[n_nodes]]),
        # NOTE(review): n_edge is k per graph; if senders holds n_nodes * k
        # indices per graph this undercounts the edges — confirm.
        n_edge=np.array(n_batch * [[k]]),
        nodes=nodes[:n_batch, :, :n_feat],
        edges=None,
        globals=np.ones((n_batch, 7)),  # constant placeholder globals, 7 per graph
        senders=sources,
        receivers=targets)


def get_rotated(x_rot):
    """Run the model on ``x_rot`` and return the output node features.

    The model is initialised with the same ``rng`` on every call.
    The graph connectivity is not recomputed for the rotated input.
    """
    graph = _build_graph(x_rot)
    graph_out, _ = model.init_with_output(rng, graph)
    return graph_out.nodes


# Rotation applied in both branches of the equivariance check.
angle_deg = 45.
axis = np.array([0, 1 / np.sqrt(2), 1 / np.sqrt(2)])
rotate_batch = jax.vmap(rotate_representation, in_axes=(0, None, None, None))

# Branch 1: model first, then rotate the output -> x_out_rot.
graph = _build_graph(x)
graph_out, _ = model.init_with_output(rng, graph)
x_out_rot = rotate_batch(graph_out.nodes, angle_deg, axis, positions_only)

# Branch 2: rotate the input first, then run the model -> x_out.
x_out = get_rotated(rotate_batch(x[:n_batch, :, :], angle_deg, axis, positions_only))

(20000, 3, 5)
(20000, 3, 5)
(20000, 3, 5)
(20000, 3, 5)
(20000, 3, 5)
(20000, 3, 5)


In [142]:
# Display the shape of x_out (model output for the rotated input).
x_out.shape

(1, 2000, 3)

In [143]:
# Equivariance check: x_out is the model applied to the rotated input and
# x_out_rot is the rotated output of the model on the original input.
# The two should agree up to floating-point error.
eq_ratio = x_out / x_out_rot
print(eq_ratio.max(), eq_ratio.min(), eq_ratio)

# The ratio is ill-conditioned wherever an output component is close to zero,
# so also enforce agreement with a combined relative/absolute tolerance.
np.testing.assert_allclose(x_out, x_out_rot, rtol=1e-3, atol=1e-6)

1.0001143 0.99994147 [[[0.9999999 1.        1.       ]
  [1.0000114 1.        1.       ]
  [1.0000001 1.        1.0000001]
  ...
  [1.0000001 1.        0.9999999]
  [0.9999999 1.0000002 1.0000001]
  [1.        1.0000001 1.       ]]]
