# LaB-GATr Detailed Model Reference And Usage Tutorial

## Installation

In this notebook we explain the inner blocks of LaB-GATr.

Before using this notebook, install the correct dependencies and LaB-GATr as follows:

Clone the repository:
```
git clone git@github.com:sukjulian/lab-gatr.git
```

Change into the repository directory:
```
cd lab-gatr
```

Optionally, create and activate a new Anaconda environment:
```
conda create --name lab-gatr python=3.10
conda activate lab-gatr
``` 
Next, install PyTorch, xFormers and the remaining dependencies:
```
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu121
pip install xformers==0.0.22.post7 --index-url https://download.pytorch.org/whl/cu121
pip install torch_geometric==2.4.0
pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
```
Install LaB-GATr itself, which also installs GATr
```
pip install .
```
If you created a new Anaconda environment, also install Jupyter:
```
pip install jupyter jupyterlab
```

# Introduction

In this tutorial we gradually introduce the building blocks of [LaB-GATr](https://arxiv.org/abs/2403.07536), which is built on top of [GATr](https://www.semanticscholar.org/paper/Geometric-Algebra-Transformer-Brehmer-Haan/4689f6603587e64e87ae36a385e9aab34af2966a). We introduce the tokenization and interpolation steps discussed in our paper and show how to process dummy input with the model from scratch.

# Method

## 1. Infrastructure

In [None]:
from gatr.interface import embed_oriented_plane, extract_translation
import torch
import gatr
from lab_gatr.nn.class_token import class_token_forward_wrapper
from xformers.ops.fmha import BlockDiagonalMask

from lab_gatr.nn.mlp.geometric_algebra import MLP
from lab_gatr.nn.gnn import PointCloudPooling, pool
from torch_scatter import scatter
from gatr.interface import embed_translation

import torch_geometric as pyg
from lab_gatr.data import Data
from torch_cluster import fps, knn

Let us first create a dummy mesh: $n$ positions with orientations (e.g. surface normals) and an arbitrary scalar feature (e.g. geodesic distance).

In [None]:
n = 1000

pos, orientation = torch.rand((n, 3)), torch.rand((n, 3))
scalar_feature = torch.rand(n)

Next, we define a point cloud pooling transform for the tokenisation (patching). It pre-computes the information required for coarsening and refinement of the input mesh.

The transform builds a nested hierarchy of sub-sampled point clouds. Each coarse-scale point is mapped to a cluster of fine-scale points, and additional indices are stored for interpolation from the coarse back to the fine scales. For correct batching, ```torch_geometric.data.Data.__inc__()``` has to be overridden.

For coarsening, we construct a [Voronoi](https://en.wikipedia.org/wiki/Voronoi_diagram)-type diagram. We first use farthest point sampling to select a subset $ \mathcal{P} $ of the mesh vertices that is approximately equidistant. Then we create the Voronoi clusters by assigning each original mesh vertex to its nearest point in $ \mathcal{P} $ via k-nearest neighbours ( $ k = 1 $ ).

For interpolation, we simply query for each original mesh vertex the closest $ k $ points in the sub-sampled set $ \mathcal{P} $. For surface meshes, we choose $ k = 3 $ which corresponds to a triangular reference simplex, while for volume meshes we choose a tetrahedron ( $ k = 4 $ ).
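
The two queries described above can be illustrated with a small, dependency-free sketch. It is not the `torch_cluster` implementation used in the transform below: the function names `farthest_point_sampling` and `nearest_sampled` are hypothetical, and the sampling here always starts from the first point, whereas the library version may start from a random one.

```python
import math


def farthest_point_sampling(points, ratio, start=0):
    """Greedy farthest point sampling; returns indices of the sampled points."""
    num_samples = max(1, math.ceil(ratio * len(points)))
    sampled = [start]
    # Distance from every point to its nearest already-sampled point.
    min_dist = [math.dist(p, points[start]) for p in points]
    while len(sampled) < num_samples:
        farthest = max(range(len(points)), key=min_dist.__getitem__)
        sampled.append(farthest)
        min_dist = [min(d, math.dist(p, points[farthest])) for d, p in zip(min_dist, points)]
    return sampled


def nearest_sampled(points, sampled, k):
    """For each point, the positions within `sampled` of its k nearest sampled points."""
    return [
        sorted(range(len(sampled)), key=lambda j: math.dist(p, points[sampled[j]]))[:k]
        for p in points
    ]


# Six points on a line; keep half of them.
points = [(float(i), 0.0, 0.0) for i in range(6)]
sampled = farthest_point_sampling(points, ratio=0.5)
clusters = nearest_sampled(points, sampled, k=1)    # Voronoi-type cluster assignment
neighbours = nearest_sampled(points, sampled, k=3)  # interpolation sources (triangle)
```

Every sampled point ends up in its own cluster, and every original point receives exactly $ k $ interpolation sources.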


In [None]:
class PointCloudPoolingScales():
    """
    Args:
        rel_sampling_ratios (tuple): relative ratios for successive farthest point sampling
        interp_simplex (str): reference simplex for barycentric interpolation ('triangle' or 'tetrahedron')
    """

    def __init__(self, rel_sampling_ratios: tuple, interp_simplex: str):
        self.rel_sampling_ratios = rel_sampling_ratios
        self.interp_simplex = interp_simplex

        self.dim_interp_simplex = {'triangle': 2, 'tetrahedron': 3}[interp_simplex]

    def __call__(self, data: pyg.data.Data) -> Data:

        pos = data.pos
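        # Fall back to a single-sample batch vector on the same device as the positions.
        batch = data.surface_id.long() if hasattr(data, 'surface_id') else torch.zeros(pos.size(0), dtype=torch.long, device=pos.device)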

        for i, sampling_ratio in enumerate(self.rel_sampling_ratios):

            sampling_idcs = fps(pos, batch, ratio=sampling_ratio)  # takes some time but is worth it
            # sampling_idcs = torch.arange(0, pos.size(0), 1. / sampling_ratio, dtype=torch.int)

            pool_source, pool_target = knn(pos[sampling_idcs], pos, 1, batch[sampling_idcs], batch)
            interp_target, interp_source = knn(pos[sampling_idcs], pos, self.dim_interp_simplex + 1, batch[sampling_idcs], batch)

            data[f'scale{i}_pool_target'], data[f'scale{i}_pool_source'] = pool_target.int(), pool_source.int()
            data[f'scale{i}_interp_target'], data[f'scale{i}_interp_source'] = interp_target.int(), interp_source.int()
            data[f'scale{i}_sampling_index'] = sampling_idcs.int()

            pos = pos[sampling_idcs]
            batch = batch[sampling_idcs]

        return Data(**data)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(rel_sampling_ratios={self.rel_sampling_ratios}, interp_simplex={self.interp_simplex})"

In [None]:
transform = PointCloudPoolingScales(rel_sampling_ratios=(0.2,), interp_simplex='triangle')
dummy_data = transform(pyg.data.Data(pos=pos, orientation=orientation, scalar_feature=scalar_feature))

## 2. Building the Model

First, we need a geometric algebra interface to embed the data in the projective geometric algebra $\mathbf{G}(3, 0, 1)$.



Specifying the numbers of input and output channels is not enough when working with geometric algebra. Instead, we must define how our data is going to be represented by multivectors (and auxiliary scalars). For convenience, we package this in a `GeometricAlgebraInterface`. Here, we embed the mesh surface normals as oriented planes at each vertex and pass an additional scalar feature.

At the model output, we "translate" the multivectors back into Euclidean vectors by interpreting them as translations (pun intended). There are multiple modelling choices for how to perform this extraction (see [paper](https://arxiv.org/abs/2403.07536)).
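
To make the embed-then-extract round trip concrete, here is a plain-Python sketch for a single translation vector. The component layout is an assumption about GATr's 16-dimensional basis ordering (scalar at index 0, the $e_{01}, e_{02}, e_{03}$ components at indices 5 to 7, scaled by $-\tfrac{1}{2}$). The helper names are hypothetical, and the real `embed_translation` and `extract_translation` operate on batched tensors.

```python
def embed_translation_sketch(t):
    """Embed a 3D translation as a 16-component multivector (assumed layout)."""
    mv = [0.0] * 16
    mv[0] = 1.0                      # scalar component
    mv[5:8] = [-0.5 * c for c in t]  # assumed e01, e02, e03 components
    return mv


def extract_translation_sketch(mv):
    """Invert the embedding above by reading back the same three components."""
    return [-2.0 * c for c in mv[5:8]]


t = [0.5, -1.0, 2.0]
mv = embed_translation_sketch(t)
recovered = extract_translation_sketch(mv)
```

Under this assumed layout, only 4 of the 16 components are non-zero for a pure translation, which is why the model output has to be interpreted before it can be read as a Euclidean vector.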


In [None]:
class GeometricAlgebraInterface:
    num_input_channels = num_output_channels = 1
    num_input_scalars = num_output_scalars = 1

    @staticmethod
    @torch.no_grad()
    def embed(data):

        multivectors = embed_oriented_plane(normal=data.orientation, position=data.pos).view(-1, 1, 16)
        scalars = data.scalar_feature.view(-1, 1)

        return multivectors, scalars

    @staticmethod
    def dislodge(multivectors, scalars):
        return extract_translation(multivectors).squeeze()

### 2.1 Tokenizer

The interpolation layer forms a convex combination of multivectors, weighted by inverse squared distance, which is compatible with the projective geometric algebra.
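
As a minimal sketch of the weighting scheme for a single target point (plain Python, hypothetical helper names, not the batched `scatter`-based code of the next cell), the weights are the normalised inverse squared distances to the source points. They are non-negative and sum to one, which is what makes the result a convex combination.

```python
def interpolation_weights(target, sources, eps=1e-16):
    """Normalised inverse-squared-distance weights of `sources` with respect to `target`."""
    inverse = []
    for source in sources:
        squared_dist = sum((a - b) ** 2 for a, b in zip(source, target))
        inverse.append(1.0 / max(squared_dist, eps))  # clamp avoids division by zero
    total = sum(inverse)
    return [w / total for w in inverse]


def interpolate(values, weights):
    """Convex combination of per-source feature vectors."""
    return [sum(w * v[c] for w, v in zip(weights, values)) for c in range(len(values[0]))]


sources = [(0.0, 0.0, 0.0), (2.0, 0.0, 0.0), (0.0, 2.0, 0.0)]
target = (1.0, 0.0, 0.0)
weights = interpolation_weights(target, sources)
features = interpolate([[1.0, 0.0], [3.0, 0.0], [5.0, 10.0]], weights)
```

The two equidistant sources receive equal weight, the more distant one contributes less, and each interpolated feature stays within the range spanned by the source features.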

In [None]:
def interp(
    mlp: torch.nn.Module,
    multivectors: torch.Tensor,
    multivectors_skip: torch.Tensor,
    scalars: torch.Tensor,
    scalars_skip: torch.Tensor,
    pos_source: torch.Tensor,
    pos_target: torch.Tensor,
    data: Data,
    scale_id: int,
    reference_multivector: torch.Tensor
) -> torch.Tensor:

    pos_diff = pos_source[data[f'scale{scale_id}_interp_source']] - pos_target[data[f'scale{scale_id}_interp_target']]
    squared_pos_dist = torch.clamp(torch.sum(pos_diff ** 2, dim=-1), min=1e-16).view(-1, 1, 1)

    denominator = scatter(1. / squared_pos_dist, data[f'scale{scale_id}_interp_target'].long(), dim=0, reduce='sum')

    multivectors = scatter(
        multivectors[data[f'scale{scale_id}_interp_source']] / squared_pos_dist,
        data[f'scale{scale_id}_interp_target'].long(),
        dim=0,
        reduce='sum'
    ) / denominator

    scalars = scatter(
        scalars[data[f'scale{scale_id}_interp_source']] / squared_pos_dist.view(-1, 1),
        data[f'scale{scale_id}_interp_target'].long(),
        dim=0,
        reduce='sum'
    ) / denominator.view(-1, 1)

    multivectors = torch.cat((multivectors, multivectors_skip), dim=-2)
    scalars = torch.cat((scalars, scalars_skip), dim=-1)

    return mlp(multivectors, scalars, reference_mv=reference_multivector)

The pooling layer is implemented via PyG message passing. We just have to embed the relative position between points as a multivector.

In [None]:
class PointCloudPooling(PointCloudPooling):

    def message(
        self,
        x_j: torch.Tensor,
        pos_i: torch.Tensor,
        pos_j: torch.Tensor,
        scalars_j: torch.Tensor,
        reference_multivector_j: torch.Tensor
    ) -> torch.Tensor:

        multivectors, scalars = self.mlp(
            torch.cat((x_j, embed_translation(pos_j - pos_i).unsqueeze(-2)), dim=-2),
            scalars=scalars_j,
            reference_mv=reference_multivector_j
        )

        return multivectors, scalars

    def aggregate(self, inputs: tuple, index: torch.Tensor, ptr=None, dim_size=None) -> torch.Tensor:
        multivectors, scalars = (self.aggr_module(tensor, index, ptr=ptr, dim_size=dim_size, dim=self.node_dim) for tensor in inputs)

        return multivectors, scalars

The complete tokenization module now just wraps pooling and interpolation. For reflection equivariance, we have to keep track of a reference multivector. Furthermore, we support extracting a global class token for classification tasks. This is a lot more efficient than taking the global average of vertex-wise outputs.
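
In the batched case, the reference multivector computed in `construct_reference_multivector` is a per-sample mean that is broadcast back to every point of that sample. The following plain-Python sketch mimics that reduce-then-gather pattern on small lists. The name `per_sample_mean` is hypothetical, and the additional averaging over channels done in the module is left out for brevity.

```python
def per_sample_mean(rows, batch):
    """Average `rows` within each sample, then broadcast the mean back to every row."""
    num_samples = max(batch) + 1
    width = len(rows[0])
    sums = [[0.0] * width for _ in range(num_samples)]
    counts = [0] * num_samples
    for row, b in zip(rows, batch):
        counts[b] += 1
        for c, value in enumerate(row):
            sums[b][c] += value
    means = [[s / n for s in total] for total, n in zip(sums, counts)]
    return [means[b] for b in batch]  # one reference row per input row


rows = [[1.0, 0.0], [3.0, 2.0], [10.0, 10.0]]
batch = [0, 0, 1]  # first two points belong to sample 0, the last to sample 1
reference = per_sample_mean(rows, batch)
```

All points of one sample share the same reference, so samples in a batch never influence each other's reference.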

In [None]:
class Tokeniser(torch.nn.Module):
    def __init__(self, geometric_algebra_interface: object, d_model: int, num_latent_channels=None, dropout_probability=None):
        super().__init__()
        self.geometric_algebra_interface = geometric_algebra_interface()

        num_input_channels = self.geometric_algebra_interface.num_input_channels
        num_output_channels = self.geometric_algebra_interface.num_output_channels

        num_input_scalars = self.geometric_algebra_interface.num_input_scalars
        num_output_scalars = self.geometric_algebra_interface.num_output_scalars

        num_latent_channels = num_latent_channels or d_model

        self.point_cloud_pooling = PointCloudPooling(MLP(
            (num_input_channels + 1, num_latent_channels, d_model),
            num_input_scalars,
            num_output_scalars=num_input_scalars,
            plain_last=False,
            use_norm_in_first=False,
            dropout_probability=dropout_probability
        ), node_dim=0)

        self.mlp = MLP(
            (d_model + num_input_channels, *[num_latent_channels] * 2, num_output_channels),
            num_input_scalars=2 * num_input_scalars,
            num_output_scalars=num_output_scalars,
            use_norm_in_first=False,
            dropout_probability=dropout_probability
        )

        self.cache = None

    def forward(self, data: Data) -> torch.Tensor:
        multivectors, scalars = self.geometric_algebra_interface.embed(data)

        self.cache = {
            'multivectors': multivectors,
            'scalars': scalars,
            'data': data,
            'reference_multivector': self.construct_reference_multivector(multivectors, data.batch)
        }

        (multivectors, scalars), self.cache['pos'] = pool(
            self.point_cloud_pooling,
            multivectors,
            data.pos,
            data,
            scale_id=0,
            scalars=scalars,
            reference_multivector=self.cache['reference_multivector']
        )

        return multivectors, scalars, self.cache['reference_multivector'][data.scale0_sampling_index]

    @staticmethod
    def construct_reference_multivector(x: torch.Tensor, batch=None):

        if batch is None:
            reference_multivector = x.mean(dim=(0,1)).expand(x.size(0), 1, -1)

        else:
            reference_multivector = scatter(x, batch, dim=0, reduce='mean').mean(dim=1, keepdim=True)[batch]

        return reference_multivector

    def lift(self, multivectors: torch.Tensor, scalars: torch.Tensor) -> torch.Tensor:

        if multivectors.size(0) == self.cache['data'].scale0_sampling_index.numel():
            multivectors, scalars = interp(
                self.mlp,
                multivectors,
                self.cache['multivectors'],
                scalars,
                self.cache['scalars'],
                self.cache['pos'],
                self.cache['data'].pos,
                self.cache['data'],
                scale_id=0,
                reference_multivector=self.cache['reference_multivector']
            )

        else:
            multivectors, scalars = self.extract_class(
                self.mlp,
                multivectors,
                self.cache['multivectors'],
                scalars,
                self.cache['scalars'],
                self.cache['data']
            )

        return self.geometric_algebra_interface.dislodge(multivectors, scalars)

    def extract_class(
        self,
        mlp: MLP,
        multivectors: torch.Tensor,
        multivectors_skip: torch.Tensor,
        scalars: torch.Tensor,
        scalars_skip: torch.Tensor,
        data: Data
    ) -> torch.Tensor:

        if data.batch is None:
            multivectors_skip = multivectors_skip.mean(dim=0, keepdim=True)
            scalars_skip = scalars_skip.mean(dim=0, keepdim=True)

            reference_multivector = self.cache['reference_multivector'][0:1]

        else:
            multivectors_skip = scatter(multivectors_skip, data.batch, dim=0, reduce='mean')
            scalars_skip = scatter(scalars_skip, data.batch, dim=0, reduce='mean')

            reference_multivector = self.cache['reference_multivector'][data.ptr[:-1]]

        multivectors = torch.cat((multivectors, multivectors_skip), dim=-2)
        scalars = torch.cat((scalars, scalars_skip), dim=-1)

        return mlp(multivectors, scalars, reference_mv=reference_multivector)



### 2.2 LaB-GATr

Now we can assemble the complete LaB-GATr model. It comprises the tokeniser and a GATr backend. For correct handling of PyG-style batching, we use block-diagonal attention masks. Another highlight is the class token support, which we implement via a ```class_token_forward_wrapper```.
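
To illustrate what a block-diagonal attention mask encodes, the sketch below derives per-sample sequence lengths from a batch vector and expands them into a dense boolean matrix. This is only for intuition: the helper names are hypothetical, and the dense matrix is not how xFormers' `BlockDiagonalMask` stores the mask.

```python
from collections import Counter


def sequence_lengths(batch):
    """Number of tokens per sample, assuming sample indices start at 0."""
    counts = Counter(batch)
    return [counts[b] for b in range(max(batch) + 1)]


def block_diagonal_mask(seqlens):
    """Boolean matrix: entry [i][j] is True if tokens i and j belong to the same sample."""
    sample_of_token = [s for s, n in enumerate(seqlens) for _ in range(n)]
    return [[a == b for b in sample_of_token] for a in sample_of_token]


batch = [0, 0, 0, 1, 1]  # three tokens from sample 0, two from sample 1
seqlens = sequence_lengths(batch)
mask = block_diagonal_mask(seqlens)
```

Tokens attend only within their own sample, so concatenating several meshes along the token dimension does not mix information between them.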

In [None]:
class LaBGATr(torch.nn.Module):

    def __init__(
        self,
        geometric_algebra_interface: object,
        d_model: int,
        num_blocks: int,
        num_attn_heads: int,
        num_latent_channels=None,
        use_class_token: bool = False,
        dropout_probability=None
    ):
        super().__init__()

        num_latent_channels = num_latent_channels or d_model

        self.tokeniser = Tokeniser(
            geometric_algebra_interface,
            d_model,
            num_latent_channels=4 * num_latent_channels,
            dropout_probability=dropout_probability
        )

        self.gatr = gatr.GATr(
            in_mv_channels=d_model,
            out_mv_channels=d_model,
            hidden_mv_channels=num_latent_channels,
            in_s_channels=geometric_algebra_interface.num_input_scalars,
            out_s_channels=geometric_algebra_interface.num_input_scalars,
            hidden_s_channels=4 * num_latent_channels,
            attention=gatr.SelfAttentionConfig(num_heads=num_attn_heads),
            mlp=gatr.MLPConfig(),
            num_blocks=num_blocks,
            dropout_prob=dropout_probability
        )

        if use_class_token:
            self.gatr.forward = class_token_forward_wrapper(self.gatr.forward)

        self.num_parameters = sum(parameter.numel() for parameter in self.parameters() if parameter.requires_grad)
        print(f"LaB-GATr ({self.num_parameters} parameters)")

    def forward(self, data: Data) -> torch.Tensor:
        multivectors, scalars, reference_multivector = self.tokeniser(data)

        multivectors, scalars = self.gatr(
            multivectors,
            scalars=scalars,
            attention_mask=self.get_attn_mask(data),
            join_reference=reference_multivector
        )

        return self.tokeniser.lift(multivectors, scalars)

    @staticmethod
    def get_attn_mask(data: Data):

        if data.batch is None:
            attn_mask = None

        else:
            batch = data.batch[data.scale0_sampling_index]
            attn_mask = BlockDiagonalMask.from_seqlens(torch.bincount(batch).tolist())

        return attn_mask

## 3. Usage

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Running test on {device}!")

model = LaBGATr(GeometricAlgebraInterface, d_model=8, num_blocks=10, num_attn_heads=4, use_class_token=False).to(device)

Generate some output with the dummy data to verify that the model works. From here on, training and inference are the same as for any other PyTorch model.


In [None]:
output = model(dummy_data.to(device))

print(output.shape)