In [53]:
import sys
sys.path.append("../")

import jax
import jax.numpy as jnp
import jraph
from e3nn_jax import Irreps
from e3nn_jax import IrrepsArray

from models.nequip import NequIP

from models.utils.irreps_utils import balanced_irreps

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [54]:
import numpy as np

n_batch = 2
n_nodes = 1000
k = 20

x = np.load("../data/halos_small.npy")[:n_batch, :n_nodes, :]

# Normalize
x_mean = x.mean((0, 1))
x_std = x.std((0, 1))
x[:, :, 3:] = (x[:, :, 3:] - x_mean[3:]) / x_std[3:]
x[:, :, :3] = x[:, :, :3] / 1000.  # Divide by box size

In [55]:
from models.utils.graph_utils import nearest_neighbors, rotate_representation

# Get nearest neighbors graph
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x, k)

In [61]:
x_irreps = IrrepsArray("1o + 1o + 1x0e", x)

graph = jraph.GraphsTuple(
          n_node=n_nodes,
          n_edge=k,
          edges=None,
          globals=None,
          nodes=x_irreps[0], 
          senders=sources[0],
          receivers=targets[0])

segnn = NequIP(num_message_passing_steps=11, task="graph")

key = jax.random.PRNGKey(0)
out, params = segnn.init_with_output(key, graph)

out

Array([13845.224], dtype=float32)

In [62]:
sum(x.size for x in jax.tree_util.tree_leaves(params))


168712

In [63]:
sum(x.size for x in jax.tree_util.tree_leaves(params))

168712

In [64]:
axis = np.array([0, 1 / np.sqrt(2), 1 / np.sqrt(2)])
x_irreps = IrrepsArray("1o + 1o + 1x0e", rotate_representation(x[0], 45., axis))

graph = jraph.GraphsTuple(
          n_node=n_nodes,
          n_edge=k,
          edges=None,
          globals=None,
          nodes=x_irreps, 
          senders=sources[0],
          receivers=targets[0])

out_rot, params = segnn.init_with_output(key, graph)

out_rot

Array([13845.237], dtype=float32)

In [52]:
# out_rot.nodes.array / rotate_representation(out.nodes.array, 45, axis)

In [45]:
import flax.linen as nn
import e3nn_jax as e3nn

class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(NequIP(num_message_passing_steps=3))
        return model(x)

model = GraphWrapper()
rng = jax.random.PRNGKey(42)

In [46]:
x_irreps = IrrepsArray("1o + 1o + 1x0e", x)

graph = jraph.GraphsTuple(
          n_node=jnp.array(n_batch * [[n_nodes]]), 
          n_edge=jnp.array(n_batch * [[k]]),
          nodes=x_irreps, 
          edges=None,
          globals=jnp.ones((n_batch, 7)),
          senders=sources,
          receivers=targets)

graph_out, _ = model.init_with_output(rng, graph)
x_out = graph_out  # Output features

x_out

GraphsTuple(nodes=1x1o+1x1o+1x0e
[[[-2.5133986e+00  3.2979611e-01 -8.6644256e-01 ...  2.5471106e-01
   -6.6917866e-01  1.7413012e+01]
  [-8.3343214e-01 -8.5485518e-01  4.6340966e-01 ... -6.6022938e-01
    3.5790467e-01  7.8666644e+00]
  [-4.5500264e+00 -5.9013815e+00  1.1662164e+00 ... -4.5578074e+00
    9.0070260e-01 -0.0000000e+00]
  ...
  [ 2.6453974e+00  1.3448736e-01 -4.6251394e-02 ...  1.0386847e-01
   -3.5721287e-02  1.7531221e+01]
  [ 8.6989157e-02  7.8718656e-01  1.1298593e+00 ...  6.0796690e-01
    8.7262303e-01  9.1318512e+00]
  [ 1.0827856e+00  1.0906880e+00 -1.0774341e+00 ...  8.4236991e-01
   -8.3213353e-01  3.1309504e+01]]

 [[ 4.6635823e+00 -3.7457795e+00 -6.2306490e+00 ... -2.8929739e+00
   -4.8121099e+00  1.9082710e+01]
  [-1.4896696e+01 -8.0221760e-01 -9.5045252e+00 ... -6.1957580e-01
   -7.3406196e+00 -0.0000000e+00]
  [-1.8548755e-01  3.7974778e-01  9.6922672e-01 ...  2.9329017e-01
    7.4856180e-01 -0.0000000e+00]
  ...
  [-4.8790962e-01 -7.0162886e-01 -1.3959970e