In [892]:
import sys
sys.path.append("../")

import jax
import jax.numpy as jnp
import jraph
from e3nn_jax import Irreps
from e3nn_jax import IrrepsArray

from models.segnn import SEGNN

from models.utils.irreps_utils import balanced_irreps

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [893]:
import numpy as np

n_batch = 4
n_nodes = 1000
k = 20

x = np.load("../data/halos_small.npy")[:, :n_nodes, :]

# Normalize
x_mean = x.mean((0, 1))
x_std = x.std((0, 1))
x[:, :, 3:] = (x[:, :, 3:] - x_mean[3:]) / x_std[3:]
x[:, :, :3] = x[:, :, :3] / 1000.  # Divide by box size

x = x[:n_batch]

In [894]:
from models.utils.graph_utils import nearest_neighbors, rotate_representation

# Get nearest neighbors graph
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x, k)

In [895]:
l_attr = 1
hidden_feats = 64

irreps_sh = Irreps.spherical_harmonics(l_attr)
irreps_hidden = balanced_irreps(lmax=l_attr, feature_size=hidden_feats, use_sh=True)

irreps_hidden

34x0e+10x1o

In [906]:
x_irreps = IrrepsArray("1o + 1o + 1x0e", x)

graph = jraph.GraphsTuple(
          n_node=n_nodes,
          n_edge=k,
          edges=None,
          globals=None,
          nodes=x_irreps[0], 
          senders=sources[0],
          receivers=targets[0])

segnn = SEGNN(num_message_passing_steps=3, intermediate_hidden_irreps=False, task="graph")

key = jax.random.PRNGKey(0)
out, params = segnn.init_with_output(key, graph)

out

Array([-0.0112217], dtype=float32)

In [907]:
sum(x.size for x in jax.tree_util.tree_leaves(params))

66694

In [908]:
axis = np.array([0, 1 / np.sqrt(2), 1 / np.sqrt(2)])

graph = jraph.GraphsTuple(
          n_node=n_nodes,
          n_edge=k,
          edges=None,
          globals=None,
          nodes=IrrepsArray("1o + 1o + 1x0e", rotate_representation(x[0], 45., axis)), 
          senders=sources[0],
          receivers=targets[0])

out_rot, params = segnn.init_with_output(key, graph)

out_rot

Array([-0.0112217], dtype=float32)

In [909]:
# out_rot.nodes.array / rotate_representation(out.nodes.array, 45, axis)

In [910]:
import flax.linen as nn
import e3nn_jax as e3nn

class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(SEGNN(num_message_passing_steps=3, intermediate_hidden_irreps=False, task="graph"))
        return model(x)

model = GraphWrapper()
rng = jax.random.PRNGKey(0)

In [911]:
graph = jraph.GraphsTuple(
          n_node=jnp.array(n_batch * [[n_nodes]]), 
          n_edge=jnp.array(n_batch * [[k]]),
          nodes=x_irreps, 
          edges=None,
          globals=None,
          senders=sources,
          receivers=targets)

graph_out, params1 = model.init_with_output(rng, graph)
x_out = graph_out  # Output features

x_out

In [None]:
sum(x.size for x in jax.tree_util.tree_leaves(params1))

66694