In [146]:
import sys
sys.path.append("../")

import jax
import jax.numpy as jnp
import jraph
from e3nn_jax import Irreps
from e3nn_jax import IrrepsArray

from models.segnn import SEGNN

from utils.irreps_utils import balanced_irreps

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [147]:
import numpy as np

n_batch = 2
n_nodes = 1000
k = 20

x = np.load("../data/halos_small.npy")[:n_batch, :n_nodes, :]

# Normalize
x_mean = x.mean((0, 1))
x_std = x.std((0, 1))
x[:, :, 3:] = (x[:, :, 3:] - x_mean[3:]) / x_std[3:]
x[:, :, :3] = x[:, :, :3] / 1000.  # Divide by box size

In [148]:
from utils.graph_utils import nearest_neighbors, rotate_representation

# Get nearest neighbors graph
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x, k)

In [149]:
l_attr = 2
hidden_feats = 64

irreps_sh = Irreps.spherical_harmonics(l_attr)
irreps_hidden = balanced_irreps(lmax=l_attr, feature_size=hidden_feats, use_sh=True)

irreps_hidden

44x0e+7x1o+4x2e

In [150]:
x_irreps = IrrepsArray("1o + 1o + 1x0e", x)

graph = jraph.GraphsTuple(
          n_node=n_nodes,
          n_edge=k,
          edges=None,
          globals=None,
          nodes=x_irreps[0], 
          senders=sources[0],
          receivers=targets[0])

segnn = SEGNN(irreps_hidden=irreps_hidden, irreps_sh=irreps_sh, num_message_passing_steps=3)

key = jax.random.PRNGKey(0)
out, params = segnn.init_with_output(key, graph)

out

GraphsTuple(nodes=1x1o+1x1o+1x0e
[[ 1.11022113e-04  2.57469917e-04 -2.40940208e-04 ... -5.41947957e-04
   5.02493407e-04 -5.45198418e-05]
 [-5.12439350e-04  2.13048814e-04  6.80897618e-04 ... -4.60215349e-04
  -1.53686467e-03 -1.35711789e-05]
 [-7.06139908e-05 -3.56277778e-05  2.05050295e-04 ...  1.49596555e-04
  -4.54601948e-04 -1.53323901e-06]
 ...
 [ 2.01664254e-04  4.63115139e-05 -7.53444419e-06 ... -1.14582304e-04
  -4.63076367e-06  2.92741470e-06]
 [ 3.99559431e-05  3.58298304e-04  1.02381084e-04 ... -7.74350367e-04
  -2.32465769e-04 -5.20189042e-06]
 [ 1.60769938e-04  1.22273152e-04 -5.33271414e-05 ... -2.70873774e-04
   1.42000819e-04  2.39907899e-06]], edges=(44x0e+7x1o+4x2e
[[ 1.38167024e-06  3.06219335e-05  5.28177698e-05 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [-1.19771124e-04 -4.54308814e-04  1.16038439e-03 ... -7.70996383e-04
  -4.27940860e-04  8.70001444e-04]
 [-3.59936082e-03 -3.70895374e-03  2.44372664e-03 ... -2.72686221e-03
  -7.01097830e-04  7.542887

In [151]:
sum(x.size for x in jax.tree_util.tree_leaves(params))

27699

In [152]:
axis = np.array([0, 1 / np.sqrt(2), 1 / np.sqrt(2)])
x_irreps = IrrepsArray("1o + 1o + 1x0e", rotate_representation(x[0], 45., axis))

graph = jraph.GraphsTuple(
          n_node=n_nodes,
          n_edge=k,
          edges=None,
          globals=None,
          nodes=x_irreps, 
          senders=sources[0],
          receivers=targets[0])

out_rot, params = segnn.init_with_output(key, graph)

out_rot

GraphsTuple(nodes=1x1o+1x1o+1x0e
[[ 3.2771073e-04  1.2896882e-04 -1.1243893e-04 ... -2.5075427e-04
   2.1129930e-04 -5.4519907e-05]
 [-5.9627375e-04  5.3778343e-04  3.5616304e-04 ... -1.1657906e-03
  -8.3128968e-04 -1.3571217e-05]
 [-1.7027091e-04  3.4925837e-05  1.3449664e-04 ... -3.5836962e-05
  -2.6916852e-04 -1.5332276e-06]
 ...
 [ 1.6952104e-04 -6.2406143e-05  1.0118313e-04 ...  1.2702261e-04
  -2.4623549e-04  2.9274129e-06]
 [ 1.5621171e-04  3.0084187e-04  1.5983715e-04 ... -6.4916286e-04
  -3.5765234e-04 -5.2018672e-06]
 [ 2.0148158e-04  1.6172067e-05  5.2773801e-05 ... -3.7201142e-05
  -9.1671558e-05  2.3990774e-06]], edges=(44x0e+7x1o+4x2e
[[ 1.3816696e-06  3.0622148e-05  5.2818115e-05 ...  0.0000000e+00
   0.0000000e+00  0.0000000e+00]
 [-1.1977117e-04 -4.5430945e-04  1.1603845e-03 ... -5.1295408e-04
  -2.9983526e-04  4.6480077e-04]
 [-3.5993634e-03 -3.7089558e-03  2.4437287e-03 ... -1.3093831e-03
  -1.2986651e-03  3.8429967e-04]
 ...
 [-3.2761764e-05  6.5344619e-05  6.146641

In [153]:
out_rot.nodes.array / rotate_representation(out.nodes.array, 45, axis)

Array([[1.0000036 , 1.0000037 , 1.0000026 , ..., 1.0000037 , 1.0000024 ,
        1.0000012 ],
       [1.        , 1.0000001 , 1.0000001 , ..., 1.        , 1.0000005 ,
        1.0000029 ],
       [1.0000013 , 1.0000033 , 0.9999989 , ..., 1.0000086 , 0.9999992 ,
        0.9999926 ],
       ...,
       [0.9999994 , 0.9999993 , 0.9999987 , ..., 0.9999993 , 0.9999989 ,
        0.9999994 ],
       [0.9999998 , 0.99999934, 0.9999993 , ..., 0.9999991 , 0.99999917,
        0.9999955 ],
       [0.9999996 , 0.99999845, 0.99999785, ..., 0.9999994 , 0.9999976 ,
        0.99999934]], dtype=float32)

In [154]:
import flax.linen as nn
import e3nn_jax as e3nn

class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(SEGNN(irreps_hidden=irreps_hidden, irreps_sh=irreps_sh, num_message_passing_steps=3))
        return model(x)

model = GraphWrapper()
rng = jax.random.PRNGKey(42)

In [156]:
x_irreps = IrrepsArray("1o + 1o + 1x0e", x)

graph = jraph.GraphsTuple(
          n_node=jnp.array(n_batch * [[n_nodes]]), 
          n_edge=jnp.array(n_batch * [[k]]),
          nodes=x_irreps, 
          edges=None,
          globals=jnp.ones((n_batch, 7)),
          senders=sources,
          receivers=targets)

graph_out, _ = model.init_with_output(rng, graph)
x_out = graph_out  # Output features