In [1]:
import jax
import jax.numpy as np
import jraph
import flax.linen as nn
import e3nn_jax as e3nn

from functools import partial
from typing import Callable, List, Tuple

%load_ext autoreload
%autoreload 2

In [2]:
# The import cell binds jax.numpy to the name `np`; expose it under the conventional
# `jnp` alias too, since the model code below is written against `jnp`.
jnp = np

In [3]:
def _index_max(i: jnp.ndarray, x: jnp.ndarray, out_dim: int) -> jnp.ndarray:
    return jnp.zeros((out_dim,) + x.shape[1:], x.dtype).at[i].max(x)


class Transformer(nn.Module):
    """Equivariant attention layer over a graph given as edge index arrays.

    For every edge a key and a value are built from the source-node features and the
    edge attributes. Attention weights are normalised over the edges that share the
    same destination node, and the weighted values are summed into that node.
    """

    # Irreps of the node features returned by the layer.
    irreps_node_output: e3nn.Irreps
    # Hidden widths of the MLPs applied to the scalar ("0e") part of the edge attributes.
    list_neurons: Tuple[int, ...]
    # Activation used inside those MLPs.
    act: Callable[[jnp.ndarray], jnp.ndarray]
    # Number of attention heads (width of the logits; the values are split accordingly).
    num_heads: int = 1

    @nn.compact
    def __call__(
        self,
        edge_src: jnp.ndarray,  # [E] dtype=int32
        edge_dst: jnp.ndarray,  # [E] dtype=int32
        edge_weight_cutoff: jnp.ndarray,  # [E] dtype=float
        edge_attr: e3nn.IrrepsArray,  # [E, D] dtype=float
        node_feat: e3nn.IrrepsArray,  # [N, D] dtype=float
    ) -> e3nn.IrrepsArray:
        r"""Equivariant Transformer.

        Args:
            edge_src (array of int32): source index of the edges
            edge_dst (array of int32): destination index of the edges
            edge_weight_cutoff (array of float): cutoff weight for the edges (typically given by ``soft_envelope``)
            edge_attr (e3nn.IrrepsArray): attributes of the edges (typically given by ``spherical_harmonics``)
            node_feat (e3nn.IrrepsArray): features of the nodes

        Returns:
            e3nn.IrrepsArray: output features of the nodes
        """

        def f(x, y, filter_ir_out=None, name=None):
            # Combine node features `x` with the non-scalar part of the edge attributes `y`
            # (x itself plus the tensor product), keeping only the irreps in `filter_ir_out`.
            out1 = e3nn.concatenate([x, e3nn.tensor_product(x, y.filter(drop="0e"))]).regroup().filter(keep=filter_ir_out)
            # An MLP on the scalar part of `y` produces `out1.irreps.num_irreps` outputs
            # (no activation on the last layer), which are multiplied into `out1`.
            # NOTE(review): presumably one weight per irrep of `out1` — confirm against e3nn broadcasting rules.
            out2 = e3nn.flax.MultiLayerPerceptron(
                self.list_neurons + (out1.irreps.num_irreps,), self.act, output_activation=False, name=name
            )(y.filter(keep="0e"))
            return out1 * out2

        # Keys live on the edges and are restricted to the irreps of the input node features.
        edge_key = f(node_feat[edge_src], edge_attr, node_feat.irreps, name="mlp_key")
        # Logits: scalar ("0e") part of destination-features (x) key, mapped linearly to one value per head.
        edge_logit = e3nn.flax.Linear(f"{self.num_heads}x0e", name="linear_logit")(
            e3nn.tensor_product(node_feat[edge_dst], edge_key, filter_ir_out="0e")
        ).array  # [E, H]
        # Per-destination reference value subtracted from the logits before exponentiating.
        node_logit_max = _index_max(edge_dst, edge_logit, node_feat.shape[0])  # [N, H]
        # Un-normalised attention, additionally scaled by the per-edge cutoff weight.
        exp = edge_weight_cutoff[:, None] * jnp.exp(edge_logit - node_logit_max[edge_dst])  # [E, H]
        z = e3nn.scatter_sum(exp, dst=edge_dst, output_size=node_feat.shape[0])  # [N, H]
        # Replace exact zeros by 1 so the division below never divides by zero.
        z = jnp.where(z == 0.0, 1.0, z)
        alpha = exp / z[edge_dst]  # [E, H]

        # Values are built like the keys but restricted to the requested output irreps.
        edge_v = f(node_feat[edge_src], edge_attr, self.irreps_node_output, "mlp_val")  # [E, D]
        edge_v = edge_v.mul_to_axis(self.num_heads)  # [E, H, D]
        # Values are weighted by sqrt(alpha); relu clamps alpha at zero before the square root.
        edge_v = edge_v * jnp.sqrt(jax.nn.relu(alpha))[:, :, None]  # [E, H, D]
        edge_v = edge_v.axis_to_mul()  # [E, D]

        # Sum the weighted values over the incoming edges of every node, then mix linearly.
        node_out = e3nn.scatter_sum(edge_v, dst=edge_dst, output_size=node_feat.shape[0])  # [N, D]
        return e3nn.flax.Linear(self.irreps_node_output, name="linear_out")(node_out)  # [N, D]


In [4]:
import sys
sys.path.append("../")

from models.graph_utils import nearest_neighbors

In [24]:
class Module(nn.Module):
    """Single attention step that updates node positions and node features.

    Edge attributes are derived from the node positions, fed with the node features
    through one ``Transformer`` layer, and the first output irrep of that layer is
    added to the positions as a displacement.
    """

    irreps_out: e3nn.Irreps

    @nn.compact
    def __call__(
        self,
        positions: e3nn.IrrepsArray,  # [N, 3] dtype=float
        features: e3nn.IrrepsArray,  # [N, D] dtype=float
        senders: np.array,
        receivers: np.array,
        cutoff: float = 1.,
    ):
        r"""Update positions and features with one equivariant attention layer.

        Args:
            positions (e3nn.IrrepsArray): positions of the nodes
            features (e3nn.IrrepsArray): features of the nodes
            senders (np.array): node index at which every edge starts (used as edge source)
            receivers (np.array): node index at which every edge ends (used as edge destination)
            cutoff (float): cutoff radius; edge lengths are divided by it

        Returns:
            tuple: ``(positions, features)`` where ``positions`` is the input position
            plus the predicted displacement and ``features`` are the remaining output
            irreps (``self.irreps_out``) of the nodes
        """
        # Relative vector of every edge and its length in units of the cutoff.
        rel_vec = positions[senders] - positions[receivers]
        scaled_len = jnp.linalg.norm(rel_vec.array, axis=1) / cutoff

        # Edge attributes: 8 Bessel radial functions plus normalised spherical
        # harmonics of degree 1..3 of the relative vector.
        radial = e3nn.bessel(scaled_len, 8)
        angular = e3nn.spherical_harmonics(list(range(1, 3 + 1)), rel_vec, True)
        edge_attr = e3nn.concatenate([radial, angular])
        envelope = e3nn.soft_envelope(scaled_len)

        # The layer predicts one extra vector ("1o") in front of the requested irreps.
        layer = Transformer(
            irreps_node_output=e3nn.Irreps("1o") + self.irreps_out,
            list_neurons=(64, 64),
            act=jax.nn.gelu,
            num_heads=1,
        )
        node_out = layer(senders, receivers, envelope, edge_attr, features)

        # Leading irrep -> displacement of the positions, the rest -> new features.
        displacements = node_out.slice_by_mul[:1]
        new_features = node_out.slice_by_mul[1:]
        return positions + displacements, new_features

In [25]:
data_dir = "/n/holyscratch01/iaifi_lab/ccuesta/data_for_sid/"

# Load and standardise: statistics are taken over the first two axes, i.e. one
# mean / std per entry of the last axis.
x = np.load("{}/halos.npy".format(data_dir))
x_mean = x.mean(axis=(0, 1))
x_std = x.std(axis=(0, 1))
# The epsilon only guards the division against a zero standard deviation. Adding it to
# the numerator as well (as before) shifts every standardised value by eps / std.
x = (x - x_mean) / (x_std + 1e-7)

In [26]:
# Split the last axis of the standardised array: the first three channels are used as
# node positions, everything after them as node features.
positions = x[:, :, :3]
features = x[:, :, 3:]

In [46]:
model = Module(irreps_out=" 1o + 0e")  # Output irreps: one vector (1o) then one scalar (0e), same layout as the input features (velocities, masses)

In [47]:
idx = 0  # index (along the first axis) of the sample used for the single-graph check

pos = e3nn.IrrepsArray("1o", positions[idx])
feat = e3nn.IrrepsArray("1o + 0e", features[idx])  # Concat other parameters (time, cosmology) here as scalars

# Edge index arrays from the project helper.
# NOTE(review): the second argument (20) is presumably the number of neighbours per node — confirm against models.graph_utils.
sources, targets = nearest_neighbors(positions[idx], 20)

In [48]:
rng = jax.random.PRNGKey(42)
# Initialise the model on the single graph and keep both the forward output and the
# initialised variables; the trailing 1. is the `cutoff` argument of Module.__call__.
out, params = model.init_with_output(rng, pos, feat, sources, targets, 1.)

In [49]:
# Total number of scalar parameters. `jax.tree_leaves` is deprecated, so use the
# `jax.tree_util` spelling; the loop variable no longer reuses the name of the data array `x`.
sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))

  sum(x.size for x in jax.tree_leaves(params))


9869

## Batched

In [91]:
class GraphWrapper(nn.Module):
    """Applies ``Module`` independently to every graph of a batch, sharing one parameter set."""

    @nn.compact
    def __call__(self, pos, feat, sources, targets):
        """Run ``Module`` over the leading (batch) axis of all inputs.

        Args:
            pos: batched node positions, leading axis = batch.
            feat: batched node features, leading axis = batch.
            sources: batched edge source indices, leading axis = batch.
            targets: batched edge destination indices, leading axis = batch.

        Returns:
            Whatever ``Module`` returns (updated positions and features), with a
            leading batch axis.
        """
        # Flax modules must not be wrapped in raw JAX transforms (parameters are created
        # as a side effect inside the traced function); use the lifted `nn.vmap` instead.
        batched_module = nn.vmap(
            Module,
            in_axes=0,
            out_axes=0,
            variable_axes={"params": None},  # parameters are broadcast, not batched
            split_rngs={"params": False},  # every batch element sees the same init RNG
        )
        # The explicit name keeps the parameter-tree path of the original submodule.
        return batched_module(irreps_out="0e + 1o", name="Module_0")(pos, feat, sources, targets)

In [103]:
# Split the feature channels: the first three are treated as velocities, the rest as masses.
vel, mass = features[:, :, :3], features[:, :, 3:] 

In [105]:
# `n_batch` has to exist before it is used here; previously it was only defined in a
# later cell, so a fresh "Restart & Run All" failed with a NameError.
n_batch = 4
# Placeholder conditioning vector: 3 standard-normal values per graph of the batch.
cond = jax.random.normal(rng, (n_batch, 3))

In [123]:
n_batch = 4

pos = e3nn.IrrepsArray("1o", positions[:n_batch])
# NOTE(review): `mass[:n_batch] + cond[:, None, :]` broadcasts the mass channel against the
# 3 conditioning values, so each of the three scalars is (mass + cond_k) rather than a
# concatenation of mass and cond — confirm this is the intended placeholder.
feat = e3nn.IrrepsArray("1o + 3x0e", np.concatenate([vel[:n_batch], mass[:n_batch] + cond[:, None, :]], -1))  # Velocities and masses; concat other parameters (time, cosmology) here as scalars

# Build the edge index arrays per graph: vmap over the batch axis of the positions only,
# the second argument (20) is shared by all graphs.
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(positions[:n_batch], 20)

In [124]:
model = GraphWrapper()
rng = jax.random.PRNGKey(42)
# Only the forward output is kept here; the initialised variables are discarded.
graph_out, _ = model.init_with_output(rng, pos, feat, sources, targets)

In [125]:
# Module returns (positions, features); unpack the batched pair.
pos_update, feat_update = graph_out

In [126]:
# Flatten both outputs to plain arrays and join them along the last (channel) axis.
z = np.concatenate([pos_update.array, feat_update.array], -1)

In [128]:
# Display the concatenated output for a quick visual inspection.
z

Array([[[-7.7683401e-01,  1.4123533e+00, -1.0109222e-01, ...,
         -5.1964450e+00, -4.0435457e+00,  3.0125337e+00],
        [-1.9902754e+00,  3.4663868e-01,  1.9180111e+00, ...,
          4.8906431e+00,  7.3061442e-01, -4.4355278e+00],
        [-1.8633548e+00,  7.1310925e-01,  3.1873112e+00, ...,
          2.1785817e+00,  1.9830981e+00, -2.3987551e-01],
        ...,
        [-2.2503205e-01,  1.8321974e+00, -9.1458201e-02, ...,
          1.5460939e+00, -2.2586634e+00, -2.2370868e+00],
        [-1.7145429e+00,  1.1601226e+00,  1.0122617e+00, ...,
          2.9002097e+00, -2.0736446e+00, -4.6399626e-01],
        [-1.7706704e-01, -1.0801482e+00, -9.5355856e-01, ...,
          2.9872074e+00, -1.5950775e+00,  5.1751769e-01]],

       [[ 1.2023078e+01, -4.3291950e-01,  3.6274630e-01, ...,
         -3.3836920e+00, -4.0713320e+00, -1.5207373e+00],
        [-9.3178052e-01,  1.6736686e+00,  1.5778362e+00, ...,
          3.7244375e+00,  7.8742927e-01, -2.9427314e+00],
        [ 1.2329445e+01, 

In [127]:
# Sanity check: number of NaN entries in the output (expected to be 0).
np.isnan(z).sum()

Array(0, dtype=int32)