In [426]:
from typing import Optional, Tuple, Union
import warnings

import jax
import jax.numpy as jnp
import flax.linen as nn

import e3nn_jax as e3nn
from e3nn_jax import Irreps
from e3nn_jax import IrrepsArray
from e3nn_jax import FunctionalLinear

from e3nn_jax import tensor_product
from e3nn_jax.flax import Linear


In [427]:
class TensorProductLinearGate(nn.Module):

    output_irreps: Irreps = None
    bias: bool = True
    gradient_normalization: Optional[Union[str, float]] = "element"
    path_normalization: Optional[Union[str, float]] = None

    @nn.compact
    def __call__(self, x: IrrepsArray, y: IrrepsArray) -> IrrepsArray:

        output_irreps = self.output_irreps
        if not isinstance(output_irreps, Irreps):
            output_irreps = Irreps(output_irreps)

        # Predict extra scalars for gating \ell > 0 irreps
        gate_irreps = Irreps(f"{output_irreps.num_irreps - output_irreps.count('0e')}x0e")
        output_irreps = (gate_irreps + output_irreps).regroup()
        linear = Linear(output_irreps, biases=self.bias, gradient_normalization=self.gradient_normalization, path_normalization=self.path_normalization)
        out = linear(tensor_product(x, y))
        out = e3nn.gate(out)  # Default activations
        return out

In [428]:
import numpy as np

n_batch = 2
n_nodes = 400
k = 20


x = np.load("../data/halos_small.npy")[:n_batch, :n_nodes, :]

# Normalize
x_mean = x.mean((0, 1))
x_std = x.std((0, 1))
x[:, :, 3:] = (x[:, :, 3:] - x_mean[3:]) / x_std[3:]
x[:, :, :3] = x[:, :, :3] / 1000.  # Divide by box size

In [429]:
pos = IrrepsArray("1o", x[0, :2, :3])  # Positions
feat = IrrepsArray("1o + 1x0e", x[0, :2, 3:])  # Features

irreps_out = Irreps("3x0o + 1x1e + 2x1o")

In [430]:
tp = TensorProductLinearGate(irreps_out)
key = jax.random.PRNGKey(0)

tp.init_with_output(key, pos, feat)

(3x0e+2x1o+1x1e
 [[-0.04109664  0.1015624  -0.07117075 -0.11383748 -0.7112009  -0.6585651
   -0.03152191 -0.1969335  -0.18235849 -0.32363778  0.24062441 -0.20391332]
  [-0.05772255  0.15109392 -0.09818695 -0.13843763 -0.56202585 -0.7538356
   -0.04004519 -0.16257456 -0.21805845  0.20437935 -0.28139314  0.17226094]],
 {'params': {'Linear_0': {'w[0,0] 1x0e,6x0e': Array([[ 0.21604401, -0.4652762 ,  0.3886473 ,  0.08119641, -0.7921135 ,
             0.45960167]], dtype=float32),
    'w[1,1] 1x1o,2x1o': Array([[1.073293  , 0.26672718]], dtype=float32),
    'w[2,2] 1x1e,1x1e': Array([[-0.35275576]], dtype=float32),
    'b[0] 6x0e': Array([ 0.,  0., -0., -0., -0., -0.], dtype=float32)}}})

In [431]:
def balanced_irreps(lmax: int, feature_size: int, use_sh: bool = True) -> Irreps:
    """Allocates irreps uniformely up until level lmax with budget feature_size."""
    irreps = ["0e"]
    n_irreps = 1 + (lmax if use_sh else lmax * 2)
    total_dim = 0
    for level in range(1, lmax + 1):
        dim = 2 * level + 1
        multi = int(feature_size / dim / n_irreps)
        if multi == 0:
            break
        if use_sh:
            irreps.append(f"{multi}x{level}{'e' if (level % 2) == 0 else 'o'}")
            total_dim = multi * dim
        else:
            irreps.append(f"{multi}x{level}e+{multi}x{level}o")
            total_dim = multi * dim * 2

    # add scalars to fill missing dimensions
    irreps[0] = f"{feature_size - total_dim}x{irreps[0]}"

    return Irreps("+".join(irreps))


In [433]:
l_attr = 1
hidden_feats = 64

irreps_sh = Irreps.spherical_harmonics(l_attr)
irreps_hidden = balanced_irreps(lmax=l_attr, feature_size=hidden_feats, use_sh=True)

irreps_hidden

34x0e+10x1o

In [435]:
import jraph
from jraph._src import utils

def get_edge_mlp_updates(irreps_out: Irreps = None, n_layers: int = 2, irreps_sh: Irreps = None):
    def update_fn(
        edges: jnp.array,
        senders: jnp.array,
        receivers: jnp.array,
        globals: jnp.array,
    ) -> jnp.array:
        
        x_i = senders.array[..., :3]
        x_j = receivers.array[..., :3]

        m_ij = e3nn.concatenate([senders, receivers], axis=-1)
        a_ij = e3nn.spherical_harmonics(irreps_out=irreps_sh, input=x_i - x_j, normalize=False)

        for _ in range(n_layers):
            m_ij = TensorProductLinearGate(irreps_out)(m_ij, a_ij)

        return m_ij, a_ij

    return update_fn

def get_node_mlp_updates(irreps_out: Irreps = None, n_layers: int = 2):
    def update_fn(
        nodes: jnp.array,
        senders: jnp.array,
        receivers: jnp.array,
        globals: jnp.array,
    ) -> jnp.array:
        
        m_i, a_i = receivers

        m_i = e3nn.concatenate([m_i, a_i], axis=-1)

        for _ in range(n_layers):
            nodes = TensorProductLinearGate(irreps_out)(nodes, m_i)

        return nodes

    return update_fn

class SEGNN(nn.Module):
    num_message_passing_steps: int = 3
    message_passing_agg: str = "mean"

    @nn.compact
    def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:

        aggregate_edges_for_nodes_fn = getattr(utils, f"segment_{self.message_passing_agg}")

        irreps_in = graphs.nodes.irreps
        for _ in range(self.num_message_passing_steps):

            update_edge_fn = get_edge_mlp_updates(irreps_out=irreps_hidden, n_layers=2, irreps_sh=irreps_sh)
            update_node_fn = get_node_mlp_updates(irreps_out=irreps_in, n_layers=2)

            graph_net = jraph.GraphNetwork(update_node_fn=update_node_fn, update_edge_fn=update_edge_fn, aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn)
            graphs = graph_net(graphs)

            graphs = graphs._replace(nodes=e3nn.flax.Linear(irreps_in)(graphs.nodes))

        return graphs

In [436]:
import sys
sys.path.append("../")
from utils.graph_utils import nearest_neighbors, rotate_representation

# Get nearest neighbors graph
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x, k)

In [449]:
x_irreps = IrrepsArray("1o + 1o + 1x0e", x)

graph = jraph.GraphsTuple(
          n_node=n_nodes,
          n_edge=k,
          edges=None,
          globals=None,
          nodes=x_irreps[0], 
          senders=sources[0],
          receivers=targets[0])

segnn = SEGNN()

out, params = segnn.init_with_output(key, graph)

out.nodes

1x1o+1x1o+1x0e
[[-7.6288194e-07  1.7018822e-07  1.0956992e-06 ... -1.5330832e-07
  -5.2964089e-07  9.3473750e-07]
 [ 3.0357103e-07  4.7801899e-07  4.4065919e-07 ... -2.5360916e-07
  -2.8834089e-07  8.8322116e-07]
 [ 8.9096574e-08  3.8601397e-07  1.2021733e-06 ... -1.8044206e-07
  -6.0714831e-07  1.5226134e-07]
 ...
 [ 8.7897450e-07  6.3043950e-07 -6.8528550e-07 ... -2.9397842e-07
   3.0832001e-07  1.4800338e-07]
 [ 7.9245484e-07  1.3323471e-06  1.7097606e-06 ... -6.2310596e-07
  -8.1323913e-07  1.9142585e-07]
 [ 7.8191596e-07  2.4820852e-07  1.1782128e-06 ... -1.1576098e-07
  -5.9680838e-07  1.8308998e-07]]

In [450]:
sum(x.size for x in jax.tree_util.tree_leaves(params))

10395

In [451]:
axis = np.array([0, 1 / np.sqrt(2), 1 / np.sqrt(2)])
x_irreps = IrrepsArray("1o + 1o + 1x0e", rotate_representation(x[0], 45., axis))

graph = jraph.GraphsTuple(
          n_node=n_nodes,
          n_edge=k,
          edges=None,
          globals=None,
          nodes=x_irreps, 
          senders=sources[0],
          receivers=targets[0])

segnn = SEGNN()

out_rot, params = segnn.init_with_output(key, graph)

out_rot.nodes

1x1o+1x1o+1x0e
[[-1.0021924e-06  6.8716861e-07  5.7871983e-07 ... -3.7197483e-07
  -3.1097488e-07  9.3473750e-07]
 [ 2.3333621e-07  3.2076218e-07  5.9791682e-07 ... -2.0016006e-07
  -3.4179030e-07  8.8322082e-07]
 [-3.4507957e-07  4.6098992e-07  1.1271973e-06 ... -2.2670181e-07
  -5.6088874e-07  1.5226135e-07]
 ...
 [ 1.2793918e-06 -1.7314945e-09 -5.3114281e-08 ...  9.2481800e-09
   5.0933995e-09  1.4800337e-07]
 [ 3.7164335e-07  9.9138992e-07  2.0507177e-06 ... -4.6889585e-07
  -9.6744884e-07  1.9142576e-07]
 [ 8.7895764e-08 -6.5537367e-09  1.4329753e-06 ...  1.7068778e-08
  -7.2963837e-07  1.8308971e-07]]

In [456]:
out_rot.nodes.array / rotate_representation(out.nodes.array, 45., axis)

Array([[0.999998  , 1.000002  , 0.9999995 , ..., 1.000002  , 0.9999993 ,
        1.        ],
       [0.99999654, 0.99999994, 1.0000015 , ..., 0.9999997 , 1.0000012 ,
        0.99999964],
       [1.0000018 , 1.0000011 , 0.9999996 , ..., 1.0000014 , 0.9999998 ,
        1.0000001 ],
       ...,
       [1.0000004 , 1.0000811 , 0.99999297, ..., 1.0000149 , 0.9999702 ,
        0.9999999 ],
       [1.        , 0.9999994 , 1.0000004 , ..., 0.9999991 , 1.0000001 ,
        0.99999946],
       [0.9999986 , 1.0000318 , 1.0000004 , ..., 1.0000045 , 1.0000005 ,
        0.9999985 ]], dtype=float32)