Mathematically, symmetries are usually described by *groups*.
We can characterize the relationship between a function (such as a neural network layer) and a symmetry group by considering its \textit{equivariance} properties.
A map $f: X \rightarrow Y$ is said to be equivariant w.r.t. the actions $\rho:G\times X\to X$ and $\rho':G\times Y\to Y$ of a group $G$ on $X$ and $Y$ if
$$
    f\Big(\rho_g(x) \Big) = \rho_g' \Big(f(x)\Big) \qquad \forall\, g \in G,\; x \in X.
$$

Reference: arXiv:2203.06153

<center><img src="invariance_vs_equivariance.png" alt="Invariance vs. equivariance" width="600px"></center>


A transformation $\phi$ acting on graphs is equivariant if, for every group element $g$ acting as $T_g$ on the input and as $S_g$ on the output,
$$\phi \Big( T_g (x) \Big) = S_g  \Big( \phi (x)\Big) $$
<center><img src="egnn.png" alt="EGNN layer" width="500px"></center>

The equivariant graph convolutional layer (EGCL) of Satorras et al. (arXiv:2102.09844) is defined as follows:

$$h^{l+1}, x^{l+1} = EGCL(h^l, x^l, \mathcal{E})$$

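It proceeds in the following steps, where $\mathcal{N}_i$ is the neighbourhood of node $i$, $a_{ij}$ are edge attributes and $C$ is a normalisation constant: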

$$m_{ij} = \phi_e \Bigg( h_i^l, h_j^l, ||x_i^l - x_j^l||^2, a_{ij}\Bigg)$$
$$x_i^{l+1} = x_i^l + C \sum_{j \in \mathcal{N}_i} (x_i^l - x_j^l) ~\phi_x (m_{ij})$$
$$m_i = \sum_{j \in \mathcal{N}_i} m_{ij}$$
$$h_i^{l+1} = \phi_h(h_i^l, m_i)$$





In [3]:
# https://github.com/RobDHess/Steerable-E3-GNN
import torch
from torch.nn import Linear, ReLU, SiLU, Sequential
from torch_geometric.nn import MessagePassing, global_add_pool, global_mean_pool
from torch_scatter import scatter


class EGNNLayer(MessagePassing):
    """E(n) Equivariant GNN Layer

    Paper: E(n) Equivariant Graph Neural Networks, Satorras et al.
    (arXiv:2102.09844).
    """
    def __init__(self, emb_dim, edge_dim, activation="relu", norm="layer", aggr="add"):
        r"""
        Args:
            emb_dim: (int) - hidden dimension `d`
            edge_dim: (int) - dimension of the edge attributes `a_ij`
            activation: (str) - non-linearity within MLPs (swish/relu)
            norm: (str) - normalisation layer (layer/batch)
            aggr: (str) - aggregation function `\oplus` for the feature
                messages (add/sum/mean/max); coordinate updates are always summed
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.activation = {"swish": SiLU(), "relu": ReLU()}[activation]
        self.norm = {"layer": torch.nn.LayerNorm, "batch": torch.nn.BatchNorm1d}[norm]

        # MLP `\phi_e` for computing messages `m_ij` from the concatenation
        # [h_i, h_j, ||x_i - x_j||, a_ij]  ->  2d + 1 + edge_dim inputs
        self.mlp_msg = Sequential(
            Linear(2 * emb_dim + 1 + edge_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
            Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )
        # MLP `\phi_x` mapping `m_ij` to the scalar weight of (x_i - x_j)
        self.mlp_pos = Sequential(
            Linear(emb_dim, emb_dim), self.norm(emb_dim), self.activation, Linear(emb_dim, 1)
        )
        # MLP `\phi_h` for computing updated node features `h_i^{l+1}` from [h_i, m_i]
        self.mlp_upd = Sequential(
            Linear(2 * emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
            Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )

    def forward(self, h, pos, edge_index, edge_attribute):
        """
        Args:
            h: (n, d) - initial node features
            pos: (n, 3) - initial node coordinates
            edge_index: (2, e) - source/target node indices of each edge (PyG layout)
            edge_attribute: (e, edge_dim) - edge attributes `a_ij`
        Returns:
            (h_out, pos_out): (n, d) updated node features and (n, 3) updated coordinates
        """
        out = self.propagate(edge_index, h=h, pos=pos, edge_attribute=edge_attribute)
        return out

    def message(self, h_i, h_j, pos_i, pos_j, edge_attribute):
        """Compute the feature message `m_ij` and the displacement update for each edge."""
        pos_diff = pos_i - pos_j
        # NOTE: Euclidean norm ||x_i - x_j||, not the squared norm of the EGCL equation.
        dists = torch.norm(pos_diff, dim=-1).unsqueeze(1)
        msg = torch.cat([h_i, h_j, dists, edge_attribute], dim=-1)
        msg = self.mlp_msg(msg)
        # Scale magnitude of displacement vector by the invariant weight phi_x(m_ij)
        pos_diff = pos_diff * self.mlp_pos(msg)
        # NOTE: some papers divide pos_diff by (dists + 1) to stabilise model.
        # NOTE: lucidrains clamps pos_diff between some [-n, +n], also for stability.
        return msg, pos_diff

    def aggregate(self, inputs, index, dim_size=None):
        """Aggregate per-edge outputs onto their target nodes.

        Args:
            inputs: tuple (msgs, pos_diffs) returned by `message`
            index: (e,) - target node index of each edge
            dim_size: (int, optional) - number of nodes `n`, supplied by
                `propagate`. Passing it to `scatter` keeps both outputs at `n`
                rows even when the highest-indexed nodes receive no edges, so
                the concatenation with `h` in `update` stays aligned.
        Returns:
            (msg_aggr, pos_aggr): (n, d) aggregated messages, (n, 3) summed displacements
        """
        msgs, pos_diffs = inputs
        # Aggregate messages with the configured aggregation
        msg_aggr = scatter(msgs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
        # Aggregate displacement vectors (always summed)
        pos_aggr = scatter(pos_diffs, index, dim=self.node_dim, dim_size=dim_size, reduce="sum")
        return msg_aggr, pos_aggr

    def update(self, aggr_out, h, pos):
        """Apply `\\phi_h` to [h_i, m_i] and shift coordinates by the summed displacements."""
        msg_aggr, pos_aggr = aggr_out
        upd_out = self.mlp_upd(torch.cat([h, msg_aggr], dim=-1))
        upd_pos = pos + pos_aggr
        return upd_out, upd_pos

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"

# Lorentz equivariant GNN
Reference: arXiv:2201.08187 (LorentzNet, Gong et al.)
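Proposition: A continuous function $\phi: (\mathbb{R}^4)^N \to \mathbb{R}^4$ is Lorentz equivariant if and only if it can be written as $\phi(x_1, \dots, x_N) = \sum_{i=1}^N g_i\big(\langle x_j, x_k\rangle_{j,k=1}^N\big)\, x_i$, where the $g_i$ are continuous Lorentz-invariant scalar functions. This motivates the coordinate update below, which rescales neighbouring four-vectors by functions of Minkowski inner products.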

<center><img src="architecture.jpg" alt="LorentzNet architecture" width="700px"></center>

$$h^{l+1}, x^{l+1} = LGCL(h^l, x^l, \mathcal{E})$$

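It proceeds in the following steps, where $\langle\cdot,\cdot\rangle$ is the Minkowski inner product with metric $\mathrm{diag}(1,-1,-1,-1)$, $\|x\|^2 = \langle x, x\rangle$ and $\psi(a) = \operatorname{sgn}(a)\log(|a| + 1)$: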

$$m_{ij} = \phi_e \Bigg( h_i^l, h_j^l, \psi(||x_i^l - x_j^l||^2), \psi( \langle x_i^l, x_j^l\rangle)\Bigg)$$
$$w_{ij} = \phi_{m}(m_{ij})$$
$$x_i^{l+1} = x_i^l + C \sum_{j \in \mathcal{N}_i}  \phi_x (m_{ij})~x_j^l $$
$$h_i^{l+1} = h_i^l + \phi_h(h_i^l, \sum_{j \in \mathcal{N}_i} w_{ij} m_{ij})$$

In [4]:
import torch
import torch_geometric

# Prefer CUDA, then Apple's MPS backend, and otherwise fall back to the CPU:
# tensors cannot be moved to "mps" on machines where that backend is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
from tqdm.notebook import tqdm
import numpy as np

local: bool = False  # module-level flag; not read by any of the cells shown here

In [5]:
from JetDataset import Jet_Dataset
from mlp import build_mlp

import os

# Directory that holds the JetClass example file. Set JET_DATASET_PATH to
# point elsewhere instead of editing this cell.
dataset_path = os.environ.get(
    "JET_DATASET_PATH", '/Users/sanmay/Documents/ICTS_SCHOOL/Main_School/JetDataset/'
)
# os.path.join works whether or not dataset_path ends with a separator.
# -- from -- "https://hqu.web.cern.ch/datasets/JetClass/example/"
file_name = os.path.join(dataset_path, 'JetClass_example_100k.root')
jet_dataset = Jet_Dataset(dataset_path=file_name)

In [40]:
# Minkowski metric eta = diag(+1, -1, -1, -1), time-like component first.
# Kept in float32 to match the default dtype of tensors such as torch.randn(...);
# torch.matmul does not promote mixed float32/float64 operands.
minkowski = torch.from_numpy(
    np.array(
        [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, -1.0, 0.0, 0.0],
            [0.0, 0.0, -1.0, 0.0],
            [0.0, 0.0, 0.0, -1.0],
        ],
        dtype=np.float32,
    )
)


def innerprod(x1, x2):
    """Minkowski inner product of column four-vectors.

    Args:
        x1, x2: (4, k) - k four-vectors stored as columns (k = 1 for a single vector)
    Returns:
        (k, 1) - x2[:, m]^T @ eta @ x1[:, m] for each column m
    """
    # Match the metric to the inputs so float64 / GPU tensors also work.
    eta = minkowski.to(dtype=x1.dtype, device=x1.device)
    # Contract column by column (sum over the 4 components) instead of forming
    # the full x2^T eta x1 Gram matrix, whose row sums mix different vectors.
    return torch.sum(x2 * torch.matmul(eta, x1), dim=0).unsqueeze(-1)
        

In [64]:
class Lorentz_GNNLayer(MessagePassing):
    """Lorentz group equivariant graph convolutional layer (LGCL).

    Paper: An Efficient Lorentz Equivariant Graph Neural Network for Jet
    Tagging (LorentzNet), Gong et al., arXiv:2201.08187.

    For every node i with neighbours N_i this computes

        m_ij      = phi_e(h_i, h_j, psi(<x_i - x_j, x_i - x_j>), psi(<x_i, x_j>))
        w_ij      = phi_m(m_ij)
        x_i^{l+1} = x_i + c * sum_j phi_x(m_ij) * x_j
        h_i^{l+1} = h_i + phi_h(h_i, sum_j w_ij * m_ij)

    where <., .> is the Minkowski inner product with metric diag(1, -1, -1, -1)
    and psi(a) = sign(a) * log(|a| + 1). Here m_ij is a scalar per edge.
    """
    def __init__(self, emb_dim, coord_dim,  activation="relu", norm="layer", aggr="add", c=1.0):
        r"""
        Args:
            emb_dim: (int) - hidden dimension `d` of the node features `h`
            coord_dim: (int) - stored as an attribute; the layer operates on
                four-vectors because the metric below is 4x4
            activation: (str) - non-linearity (swish/relu); stored, not used
                by the `build_mlp` MLPs below
            norm: (str) - normalisation layer (layer/batch); stored, not used
                by the `build_mlp` MLPs below
            aggr: (str) - aggregation `\oplus` for the coordinate messages
                (add/sum/mean/max); feature messages are always summed
            c: (float) - scale `C` of the coordinate update; 1.0 means unscaled
        """
        # Set the aggregation function
        super(Lorentz_GNNLayer, self).__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.coord_dim = coord_dim
        self.c = c
        self.activation = {"swish": SiLU(), "relu": ReLU()}[activation]
        self.norm = {"layer": torch.nn.LayerNorm, "batch": torch.nn.BatchNorm1d}[norm]

        # m_ij is a single scalar per edge.
        msg_dim = 1
        # NOTE(review): assumes build_mlp(in_dim, out_dim, features=hidden_sizes)
        # maps (..., in_dim) -> (..., out_dim) -- confirm against mlp.py.
        # phi_e: [h_i, h_j, psi(|x_i - x_j|^2), psi(<x_i, x_j>)] -> m_ij
        # (each psi(...) term contributes one scalar, hence 2 * emb_dim + 2)
        self.mlp_phi_e = build_mlp(2 * emb_dim + 2, msg_dim, features=[3, 4, 2])
        # phi_x: m_ij -> scalar weight of the neighbour four-vector x_j
        self.mlp_phi_x = build_mlp(msg_dim, 1, features=[3, 4, 2])
        # phi_h: [h_i, sum_j w_ij m_ij] -> residual feature update of width emb_dim
        self.mlp_phi_h = build_mlp(emb_dim + msg_dim, emb_dim, features=[3, 4, 2])
        # phi_m: m_ij -> edge weight w_ij
        self.mlp_phi_m = build_mlp(msg_dim, 1, features=[3, 4, 2])

        # Registered as a (non-persistent) buffer so `.to(device)` moves it with
        # the module without adding an entry to the state dict.
        self.register_buffer(
            "minkowski",
            torch.diag(torch.tensor([1.0, -1.0, -1.0, -1.0])),
            persistent=False,
        )

    def psi(self, x):
        """Signed log scaling sign(x) * log(|x| + 1) for the invariant features."""
        return torch.sign(x) * torch.log(torch.abs(x) + 1)

    def innerprod(self, x1, x2):
        """Row-wise Minkowski inner product.

        Args:
            x1, x2: (..., 4) - four-vectors stored along the last axis
        Returns:
            (..., 1) - x1 @ eta @ x2 for each row
        """
        eta = self.minkowski.to(dtype=x1.dtype, device=x1.device)
        return torch.sum(torch.matmul(x1, eta) * x2, dim=-1, keepdim=True)

    def forward(self, h, pos, edge_index, edge_attribute=None):
        """
        Args:
            h: (n, d) - node features
            pos: (n, 4) - node four-vectors, time-like component first
                (metric diag(+1, -1, -1, -1))
            edge_index: (2, e) - source/target node indices (PyG layout)
            edge_attribute: unused by the LGCL equations; accepted for
                signature parity with EGNNLayer
        Returns:
            (h_out, x_out): (n, d) updated node features and (n, 4) updated four-vectors
        """
        # message/update read the coordinates under the name `x`.
        return self.propagate(edge_index, h=h, x=pos)

    def message(self, h_i, h_j, x_i, x_j):
        """Per-edge coordinate message phi_x(m_ij) x_j and feature message w_ij m_ij."""
        diff = x_i - x_j
        invariants = torch.cat(
            [self.psi(self.innerprod(diff, diff)), self.psi(self.innerprod(x_i, x_j))],
            dim=-1,
        )
        m_ij = self.mlp_phi_e(torch.cat([h_i, h_j, invariants], dim=-1))
        x_msg = self.mlp_phi_x(m_ij) * x_j
        h_msg = self.mlp_phi_m(m_ij) * m_ij
        return x_msg, h_msg

    def aggregate(self, inputs, index, dim_size=None):
        """Aggregate per-edge messages onto target nodes.

        `dim_size` (supplied by `propagate`) keeps the outputs at `n` rows even
        when the highest-indexed nodes receive no edges.
        """
        x_msg, h_msg = inputs
        # Aggregate coordinate messages with the configured aggregation
        x_aggr = scatter(x_msg, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
        # Sum the weighted feature messages: sum_j w_ij m_ij
        w_aggr = scatter(h_msg, index, dim=self.node_dim, dim_size=dim_size, reduce="sum")
        return x_aggr, w_aggr

    def update(self, aggr_out, h, x):
        """Residual feature and coordinate updates; returns (h^{l+1}, x^{l+1})."""
        x_aggr, w_aggr = aggr_out
        upd_h = h + self.mlp_phi_h(torch.cat([h, w_aggr], dim=-1))
        upd_x = x + self.c * x_aggr
        return upd_h, upd_x

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"

In [41]:
# Two random column four-vectors, drawn in the same order as before (x1 first).
x1 = torch.randn(4, 1)
x2 = torch.randn(4, 1)

In [42]:
# Minkowski inner product of the two random column four-vectors; the value is the cell output.
innerprod(x1, x2)

tensor([[2.3278]])

In [28]:
# NOTE(review): `a` is not defined in any cell shown here (this cell ran as
# In [28], before the cells above); presumably an intermediate from an earlier
# session, e.g. minkowski @ x1 -- confirm or remove this cell.
a

tensor([[ 0.3714],
        [-0.2663],
        [-0.7747],
        [ 0.0447]])

In [31]:
# NOTE(review): depends on the undefined `a` above; this matches the outer
# x2^T (...) step of innerprod only if `a` was minkowski @ x1 -- verify.
torch.matmul(x2.T, a)

tensor([[0.7197]])

In [58]:

import torch_geometric
from torch_geometric.data import Data
from torch_geometric.data import Batch
import torch_geometric.transforms as T
from torch_geometric.utils import remove_self_loops, to_dense_adj, dense_to_sparse, to_undirected
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool, knn_graph
from torch_geometric.datasets import QM9
from torch_scatter import scatter
from torch_cluster import knn

In [59]:
# Shuffled mini-batches of 5 samples from the jet dataset.
data_loader = DataLoader(jet_dataset, batch_size=5, shuffle=True)

In [60]:
# Take a single mini-batch to inspect its tensors in the cells below.
gr_b = next(iter(data_loader))

In [61]:
# Unpack the batched graph: node features, connectivity, and node-to-graph assignment.
x = gr_b.x
edge_index = gr_b.edge_index
batch = gr_b.batch

In [62]:
# Node-feature matrix of the whole mini-batch: (total nodes in batch, node-feature width).
x.shape

torch.Size([172, 16])

In [63]:
# NOTE(review): `model` is not defined in any cell shown here. The traceback
# below multiplies a matrix with 16 columns (the node-feature width of `x`
# above) by a 4x4 matrix (the Minkowski metric's shape). Presumably the
# four-momentum columns must be sliced out of `x` before any Lorentz inner
# product is taken -- TODO confirm against the model definition.
model(x, edge_index, batch)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (860x16 and 4x4)