In [23]:
import jax
import jax.numpy as np
import jraph
import flax.linen as nn
import e3nn_jax as e3nn

from functools import partial
from typing import Callable, List, Tuple

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
jnp = np

In [25]:
def _index_max(i: jnp.ndarray, x: jnp.ndarray, out_dim: int) -> jnp.ndarray:
    return jnp.zeros((out_dim,) + x.shape[1:], x.dtype).at[i].max(x)


class Transformer(nn.Module):
    irreps_node_output: e3nn.Irreps
    list_neurons: Tuple[int, ...]
    act: Callable[[jnp.ndarray], jnp.ndarray]
    num_heads: int = 1

    @nn.compact
    def __call__(
        self,
        edge_src: jnp.ndarray,  # [E] dtype=int32
        edge_dst: jnp.ndarray,  # [E] dtype=int32
        edge_weight_cutoff: jnp.ndarray,  # [E] dtype=float
        edge_attr: e3nn.IrrepsArray,  # [E, D] dtype=float
        node_feat: e3nn.IrrepsArray,  # [N, D] dtype=float
    ) -> e3nn.IrrepsArray:
        r"""Equivariant Transformer.

        Args:
            edge_src (array of int32): source index of the edges
            edge_dst (array of int32): destination index of the edges
            edge_weight_cutoff (array of float): cutoff weight for the edges (typically given by ``soft_envelope``)
            edge_attr (e3nn.IrrepsArray): attributes of the edges (typically given by ``spherical_harmonics``)
            node_f (e3nn.IrrepsArray): features of the nodes

        Returns:
            e3nn.IrrepsArray: output features of the nodes
        """

        def f(x, y, filter_ir_out=None, name=None):
            out1 = e3nn.concatenate([x, e3nn.tensor_product(x, y.filter(drop="0e"))]).regroup().filter(keep=filter_ir_out)
            out2 = e3nn.flax.MultiLayerPerceptron(
                self.list_neurons + (out1.irreps.num_irreps,), self.act, output_activation=False, name=name
            )(y.filter(keep="0e"))
            return out1 * out2

        edge_key = f(node_feat[edge_src], edge_attr, node_feat.irreps, name="mlp_key")
        edge_logit = e3nn.flax.Linear(f"{self.num_heads}x0e", name="linear_logit")(
            e3nn.tensor_product(node_feat[edge_dst], edge_key, filter_ir_out="0e")
        ).array  # [E, H]
        node_logit_max = _index_max(edge_dst, edge_logit, node_feat.shape[0])  # [N, H]
        exp = edge_weight_cutoff[:, None] * jnp.exp(edge_logit - node_logit_max[edge_dst])  # [E, H]
        z = e3nn.scatter_sum(exp, dst=edge_dst, output_size=node_feat.shape[0])  # [N, H]
        z = jnp.where(z == 0.0, 1.0, z)
        alpha = exp / z[edge_dst]  # [E, H]

        edge_v = f(node_feat[edge_src], edge_attr, self.irreps_node_output, "mlp_val")  # [E, D]
        edge_v = edge_v.mul_to_axis(self.num_heads)  # [E, H, D]
        edge_v = edge_v * jnp.sqrt(jax.nn.relu(alpha))[:, :, None]  # [E, H, D]
        edge_v = edge_v.axis_to_mul()  # [E, D]

        node_out = e3nn.scatter_sum(edge_v, dst=edge_dst, output_size=node_feat.shape[0])  # [N, D]
        return e3nn.flax.Linear(self.irreps_node_output, name="linear_out")(node_out)  # [N, D]


In [26]:
@partial(jax.jit, static_argnums=(1,))
def nearest_neighbors(x, k, mask=None):
    """The shittiest implementation of nearest neighbours with masking in the world"""

    if mask is None:
        mask = np.ones((x.shape[0],), dtype=np.int32)

    n_nodes = x.shape[0]

    distance_matrix = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)

    distance_matrix = np.where(mask[:, None], distance_matrix, np.inf)
    distance_matrix = np.where(mask[None, :], distance_matrix, np.inf)

    indices = np.argsort(distance_matrix, axis=-1)[:, :k]

    sources = indices[:, 0].repeat(k)
    targets = indices.reshape(n_nodes * (k))

    return (sources, targets)

In [27]:
class Module(nn.Module):
    irreps_out: e3nn.Irreps

    @nn.compact
    def __call__(
        self,
        positions: e3nn.IrrepsArray,  # [N, 3] dtype=float
        features: e3nn.IrrepsArray,  # [N, D] dtype=float
        senders: np.array, 
        receivers: np.array,
        cutoff: float = 1.,
    ):
        r"""Equivariant Transformer.

        Args:
            positions (e3nn.IrrepsArray): positions of the nodes
            features (e3nn.IrrepsArray): features of the nodes
            senders (np.array): features of the nodes
            receivers (np.array): features of the nodes
            cutoff (float): cutoff radius

        Returns:
            e3nn.IrrepsArray: output features of the nodes
        """

        vectors = positions[senders] - positions[receivers]
        dist = jnp.linalg.norm(vectors.array, axis=1) / cutoff

        edge_attr = e3nn.concatenate([
            e3nn.bessel(dist, 8),
            e3nn.spherical_harmonics(list(range(1, 3 + 1)), vectors, True)
        ])
        edge_weight_cutoff = e3nn.soft_envelope(dist)

        features = Transformer(
            irreps_node_output=e3nn.Irreps("1o") + self.irreps_out,
            list_neurons=(64, 64),
            act=jax.nn.gelu,
            num_heads=1,
        )(senders, receivers, edge_weight_cutoff, edge_attr, features)

        displacements, features = features.slice_by_mul[:1], features.slice_by_mul[1:]
        positions = positions + displacements
        return positions, features

## Batched

In [34]:
class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, pos, feat, sources, targets):
        model = jax.vmap(Module(irreps_out="0e + 1o"), in_axes=(0,0,0,0,None))
        return model(pos, feat, sources, targets, 0.1)

In [35]:
# Make some dummy data

n_batch = 4

rng = jax.random.PRNGKey(345)

positions = jax.random.uniform(rng, (n_batch, 5000, 3))
vel = jax.random.normal(rng, (n_batch, 5000, 3))
mass = jax.random.normal(rng, (n_batch, 5000, 1))
cond = jax.random.normal(rng, (n_batch, 3))

In [38]:
# Make irreps arrays, get edges

pos = e3nn.IrrepsArray("1o", positions[:n_batch])
feat = e3nn.IrrepsArray("1o + 3x0e", np.concatenate([vel, mass + cond[:, None, :]], -1))  # Velocities and masses; concat other parameters (time, cosmology) here as scalars

sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(positions[:n_batch], 50)  # 50 is number of edge connections

In [39]:
# Update positions and features
model = GraphWrapper()
(pos_update, feat_update), params = model.init_with_output(rng, pos, feat, sources, targets)

In [40]:
# Original positions
np.max(pos.array), np.min(pos.array), np.max(feat.array), np.min(feat.array), 

(Array(0.9999982, dtype=float32),
 Array(2.2649765e-05, dtype=float32),
 Array(4.875551, dtype=float32),
 Array(-5.4267583, dtype=float32))

In [41]:
# Updated positions are big
np.max(pos_update.array), np.min(pos_update.array), np.max(feat_update.array), np.min(feat_update.array), 

(Array(88.94738, dtype=float32),
 Array(-72.18628, dtype=float32),
 Array(206.50436, dtype=float32),
 Array(-197.39395, dtype=float32))

In [42]:
# Updated full array
z = np.concatenate([pos_update.array, feat_update.array], -1)

In [43]:
sum(x.size for x in jax.tree_leaves(params))

  sum(x.size for x in jax.tree_leaves(params))


10399