In [86]:
import sys
sys.path.append("../")

import jax
import jax.numpy as jnp
import jraph
from e3nn_jax import Irreps
from e3nn_jax import IrrepsArray

from models.nequip import NequIP

from models.utils.irreps_utils import balanced_irreps

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [87]:
import numpy as np

n_batch = 2
n_nodes = 1000
k = 20

x = np.load("../data/halos_small.npy")[:n_batch, :n_nodes, :]

# Normalize
x_mean = x.mean((0, 1))
x_std = x.std((0, 1))
x[:, :, 3:] = (x[:, :, 3:] - x_mean[3:]) / x_std[3:]
x[:, :, :3] = x[:, :, :3] / 1000.  # Divide by box size

In [88]:
from models.utils.graph_utils import nearest_neighbors, rotate_representation

# Get nearest neighbors graph
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x, k)

In [89]:
x_irreps = IrrepsArray("1o + 1o + 1x0e", x)

graph = jraph.GraphsTuple(
          n_node=n_nodes,
          n_edge=k,
          edges=None,
          globals=None,
          nodes=x_irreps[0], 
          senders=sources[0],
          receivers=targets[0])

segnn = NequIP(num_message_passing_steps=5)

key = jax.random.PRNGKey(0)
out, params = segnn.init_with_output(key, graph)

out

GraphsTuple(nodes=1x1o+1x1o+1x0e
[[-6.16390839e+01 -3.87518692e+01  1.45572405e+01 ...  1.98271313e+01
  -7.44811344e+00 -3.97852441e+03]
 [-3.48966255e+01 -8.00071144e+00  1.72874584e+01 ...  4.09350967e+00
  -8.84501076e+00 -5.22493164e+03]
 [ 4.96933044e+02  7.60859528e+01 -1.29368866e+02 ... -3.89288559e+01
   6.61907043e+01 -4.39523584e+03]
 ...
 [-3.80794792e+01 -4.61656427e+00  3.28418770e+01 ...  2.36203384e+00
  -1.68033237e+01 -6.13700195e+03]
 [ 2.63453430e+02  3.03911648e+01 -4.44013275e+02 ... -1.55494328e+01
   2.27176376e+02 -4.38056934e+03]
 [-5.24121208e+01 -1.02180119e+01  7.14080477e+00 ...  5.22797585e+00
  -3.65354395e+00 -7.23831348e+03]], edges=6x0e+9x1o+4x1e+7x2e+4x2o+4x3o+1x3e+1x4e
[[ 1.6039459e+03 -3.2757256e+00 -0.0000000e+00 ...  0.0000000e+00
   0.0000000e+00  0.0000000e+00]
 [ 1.6039459e+03 -3.2757256e+00  7.9584933e-22 ...  2.5352295e+01
   6.8216076e+00 -3.1906335e+00]
 [ 1.6039459e+03 -3.2757256e+00 -0.0000000e+00 ...  0.0000000e+00
   0.0000000e+00  0.

In [90]:
sum(x.size for x in jax.tree_leaves(params))


  sum(x.size for x in jax.tree_leaves(params))


54025

In [91]:
sum(x.size for x in jax.tree_util.tree_leaves(params))

54025

In [92]:
axis = np.array([0, 1 / np.sqrt(2), 1 / np.sqrt(2)])
x_irreps = IrrepsArray("1o + 1o + 1x0e", rotate_representation(x[0], 45., axis))

graph = jraph.GraphsTuple(
          n_node=n_nodes,
          n_edge=k,
          edges=None,
          globals=None,
          nodes=x_irreps, 
          senders=sources[0],
          receivers=targets[0])

out_rot, params = segnn.init_with_output(key, graph)

out_rot

GraphsTuple(nodes=1x1o+1x1o+1x0e
[[-7.1265549e+01 -2.9520565e-01 -2.3866104e+01 ...  1.5103996e-01
   1.2210931e+01 -3.9784980e+03]
 [-3.5641479e+01  1.2272608e+01 -3.9331954e+00 ... -6.2791958e+00
   2.0123928e+00 -5.2250049e+03]
 [ 4.5408710e+02 -2.0246201e+02  1.4918707e+02 ...  1.0358830e+02
  -7.6330551e+01 -4.3952393e+03]
 ...
 [-4.5870655e+01  1.9482538e+01  8.3496656e+00 ... -9.9681082e+00
  -4.2720494e+00 -6.1370068e+03]
 [ 4.2349399e+02 -1.7081511e+02 -2.4279933e+02 ...  8.7396385e+01
   1.2422664e+02 -4.3805679e+03]
 [-4.5740585e+01  1.8530035e+01 -2.1607677e+01 ... -9.4807663e+00
   1.1055422e+01 -7.2383208e+03]], edges=6x0e+9x1o+4x1e+7x2e+4x2o+4x3o+1x3e+1x4e
[[ 1.6039447e+03 -3.2757256e+00 -0.0000000e+00 ...  0.0000000e+00
   0.0000000e+00  0.0000000e+00]
 [ 1.6039440e+03 -3.2757311e+00  7.9601710e-22 ...  2.0544931e-01
   2.3520678e+01  1.3702793e+01]
 [ 1.6039447e+03 -3.2757256e+00 -0.0000000e+00 ...  0.0000000e+00
   0.0000000e+00  0.0000000e+00]
 ...
 [ 1.7252953e+03 -

In [93]:
out_rot.nodes.array / rotate_representation(out.nodes.array, 45, axis)

Array([[1.0146011 , 2.3544729 , 0.99156016, ..., 2.3544097 , 0.99156034,
        0.9999934 ],
       [0.9550306 , 0.93320924, 1.0178486 , ..., 0.93320924, 1.017849  ,
        1.0000141 ],
       [0.99994487, 0.9999666 , 1.000008  , ..., 0.99996644, 1.0000081 ,
        1.0000008 ],
       ...,
       [1.004713  , 0.9785873 , 1.0039915 , ..., 0.97858727, 1.0039914 ,
        1.0000008 ],
       [1.0000049 , 1.000027  , 0.9999495 , ..., 1.000027  , 0.9999496 ,
        0.99999964],
       [1.0000046 , 0.99999136, 1.0000128 , ..., 0.99999124, 1.0000128 ,
        1.000001  ]], dtype=float32)

In [94]:
import flax.linen as nn
import e3nn_jax as e3nn

class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(NequIP(num_message_passing_steps=3))
        return model(x)

model = GraphWrapper()
rng = jax.random.PRNGKey(42)

In [95]:
x_irreps = IrrepsArray("1o + 1o + 1x0e", x)

graph = jraph.GraphsTuple(
          n_node=jnp.array(n_batch * [[n_nodes]]), 
          n_edge=jnp.array(n_batch * [[k]]),
          nodes=x_irreps, 
          edges=None,
          globals=jnp.ones((n_batch, 7)),
          senders=sources,
          receivers=targets)

graph_out, _ = model.init_with_output(rng, graph)
x_out = graph_out  # Output features

x_out

GraphsTuple(nodes=1x1o+1x1o+1x0e
[[[ 8.0869305e-01  2.4658722e-01 -4.5842662e-01 ... -1.0893855e+00
    2.8542280e+00  1.0231916e-01]
  [ 1.6788752e-01  5.0235432e-01 -1.9495252e-01 ...  1.3089621e+00
   -5.1104575e-01 -5.4781165e-02]
  [ 7.0445490e+00  3.1431112e+00 -5.3350925e+00 ... -8.3813534e+00
    1.4177038e+01  6.8024471e-02]
  ...
  [-3.5712576e+00 -2.4034457e+00  2.2403276e+00 ...  7.1788449e+00
   -6.6258383e+00 -4.1972633e-02]
  [-4.0234259e-01  7.9447114e-01  8.4972030e-01 ...  6.1443293e-01
    6.5706849e-01  1.3994895e+00]
  [-7.5217980e-01 -1.6053778e+00  3.8537998e+00 ...  5.2818809e+00
   -1.2871581e+01  1.0857159e-01]]

 [[ 1.5418994e-01 -7.0568867e-02 -7.4971455e-01 ... -2.2951296e-01
   -2.0088096e+00 -5.5209219e-02]
  [ 2.1186693e+01 -2.9733748e+00  1.3333091e+01 ...  5.6325879e+00
   -2.5271307e+01 -7.2338758e-03]
  [-7.0508838e-01 -1.1527719e+00 -6.6511196e-01 ... -1.0938417e+00
   -6.4568228e-01 -5.5738314e-05]
  ...
  [ 3.8190970e-01  8.7662393e-01  8.7452358e