In [2]:
import sys
sys.path.append("../")

import jax
import jax.numpy as jnp
import jraph
from e3nn_jax import Irreps
from e3nn_jax import IrrepsArray

from models.segnn import SEGNN

from models.utils.irreps_utils import balanced_irreps

%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np

n_batch = 2
n_nodes = 1000
k = 20

x = np.load("../data/halos_small.npy")[:n_batch, :n_nodes, :]

# Normalize
x_mean = x.mean((0, 1))
x_std = x.std((0, 1))
x[:, :, 3:] = (x[:, :, 3:] - x_mean[3:]) / x_std[3:]
x[:, :, :3] = x[:, :, :3] / 1000.  # Divide by box size

In [5]:
from models.utils.graph_utils import nearest_neighbors, rotate_representation

# Get nearest neighbors graph
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x, k)

In [6]:
l_attr = 2
hidden_feats = 64

irreps_sh = Irreps.spherical_harmonics(l_attr)
irreps_hidden = balanced_irreps(lmax=l_attr, feature_size=hidden_feats, use_sh=True)

irreps_hidden

44x0e+7x1o+4x2e

In [8]:
x_irreps = IrrepsArray("1o + 1o + 1x0e", x)

graph = jraph.GraphsTuple(
          n_node=n_nodes,
          n_edge=k,
          edges=None,
          globals=None,
          nodes=x_irreps[0], 
          senders=sources[0],
          receivers=targets[0])

segnn = SEGNN(num_message_passing_steps=3)

key = jax.random.PRNGKey(0)
out, params = segnn.init_with_output(key, graph)

out

GraphsTuple(nodes=1x1o+1x1o+1x0e
[[ 0.1127352   0.7019572   0.64938754 ... -1.2253296   0.9189348
  -1.0862536 ]
 [ 0.13718818  0.5570412   0.74716806 ...  0.30911353 -1.330875
  -1.0863101 ]
 [ 0.09951164  0.24814542  0.93283325 ...  0.77555734  0.21970801
  -1.0862983 ]
 ...
 [ 0.47534055  0.7629518   0.62139857 ...  0.18325417  0.20174643
  -0.8891841 ]
 [ 0.47973168  0.9436957   0.3180124  ... -1.0878631  -0.59072036
  -0.88918823]
 [ 0.58530563  0.48572883  0.13688986 ... -0.47685483  0.06073122
  -0.8891901 ]], edges=(34x0e+10x1o
[[-7.0330918e-01 -3.2343227e-01  3.6037989e-02 ... -2.4040036e-01
  -1.8852744e-01  1.3454834e-01]
 [-6.2264746e-01  4.5916218e-01  2.7856091e-01 ...  5.4117595e-04
  -2.5518769e-01  2.7464986e-01]
 [-1.4202735e+00  6.7755282e-01  4.8944324e-01 ... -4.2370981e-01
  -2.1518993e-01  1.4726941e-01]
 ...
 [-9.1243070e-01 -4.2818181e-02 -3.5486576e-01 ... -5.3200221e-01
   3.2040301e-01  2.5485125e-01]
 [-1.0641135e+00 -3.3805856e-01 -5.4417920e-01 ... -6.506

In [9]:
sum(x.size for x in jax.tree_util.tree_leaves(params))

16965

In [10]:
axis = np.array([0, 1 / np.sqrt(2), 1 / np.sqrt(2)])
x_irreps = IrrepsArray("1o + 1o + 1x0e", rotate_representation(x[0], 45., axis))

graph = jraph.GraphsTuple(
          n_node=n_nodes,
          n_edge=k,
          edges=None,
          globals=None,
          nodes=x_irreps, 
          senders=sources[0],
          receivers=targets[0])

out_rot, params = segnn.init_with_output(key, graph)

out_rot

GraphsTuple(nodes=1x1o+1x1o+1x0e
[[ 0.10600067  0.63789093  0.71345377 ... -0.11498036 -0.19141458
  -1.0862536 ]
 [ 0.00194322  0.5162905   0.78791875 ... -0.6867165  -0.335045
  -1.0863101 ]
 [-0.27197862  0.29865986  0.88231885 ...  0.36774498  0.6275203
  -1.0862983 ]
 ...
 [ 0.40689313  0.5045514   0.8797988  ...  0.8751244  -0.4901238
  -0.8891841 ]
 [ 0.6520632   0.6122006   0.64950746 ... -1.0223409  -0.65624255
  -0.88918823]
 [ 0.5882931   0.14198966  0.480629   ... -0.12055017 -0.2955734
  -0.8891901 ]], edges=(34x0e+10x1o
[[-0.7033092  -0.32343227  0.03603799 ... -0.33152667 -0.02101387
  -0.03296524]
 [-0.6226473   0.4591622   0.27856097 ... -0.26453626 -0.17786507
   0.1973272 ]
 [-1.4202739   0.6775529   0.48944354 ... -0.48083752  0.0497459
  -0.11766642]
 ...
 [-0.9124307  -0.04281837 -0.35486618 ... -0.34340668  0.57680464
  -0.00155008]
 [-1.0641137  -0.3380588  -0.5441793  ... -0.7323551   0.6138189
   0.34823993]
 [-1.1288533  -0.5980963  -1.0674239  ... -1.2300696

In [11]:
out_rot.nodes.array / rotate_representation(out.nodes.array, 45, axis)

Array([[1.0000001 , 1.        , 1.        , ..., 1.0000008 , 1.0000004 ,
        1.        ],
       [0.9999862 , 0.9999999 , 1.0000001 , ..., 0.99999994, 1.0000004 ,
        1.        ],
       [1.        , 1.0000001 , 1.0000001 , ..., 0.99999994, 1.        ,
        1.        ],
       ...,
       [1.0000001 , 0.9999999 , 1.        , ..., 1.0000001 , 1.        ,
        1.        ],
       [1.0000001 , 1.0000001 , 1.        , ..., 1.0000001 , 1.        ,
        1.        ],
       [1.0000001 , 1.0000001 , 1.0000001 , ..., 0.99999994, 1.        ,
        1.        ]], dtype=float32)

In [14]:
import flax.linen as nn
import e3nn_jax as e3nn

class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(SEGNN(num_message_passing_steps=3))
        return model(x)

model = GraphWrapper()
rng = jax.random.PRNGKey(42)

In [15]:
x_irreps = IrrepsArray("1o + 1o + 1x0e", x)

graph = jraph.GraphsTuple(
          n_node=jnp.array(n_batch * [[n_nodes]]), 
          n_edge=jnp.array(n_batch * [[k]]),
          nodes=x_irreps, 
          edges=None,
          globals=jnp.ones((n_batch, 7)),
          senders=sources,
          receivers=targets)

graph_out, _ = model.init_with_output(rng, graph)
x_out = graph_out  # Output features