# LaB-GATr Basic Usage

Before using this notebook, install the correct dependencies and LaB-GATr as follows:

Optionally, create a new Anaconda environment:
```
conda create --name lab-gatr python=3.10
conda activate lab-gatr
``` 
Next, install PyTorch, xFormers and the PyTorch Geometric libraries (CUDA 12.1 builds):
```
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu121
pip install xformers==0.0.22.post7 --index-url https://download.pytorch.org/whl/cu121
pip install torch_geometric==2.4.0
pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
```
From the root of the LaB-GATr repository, install LaB-GATr itself, which also installs GATr:
```
pip install .
```
If you created a new Anaconda environment, also install Jupyter:
```
pip install jupyter jupyterlab
```
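To confirm the installation, the snippet below reports the installed version of each package named in the `pip` commands above, using only the standard library's `importlib.metadata`. It is a minimal sketch: `installed_versions` and `PACKAGES` are hypothetical names chosen for this illustration, and LaB-GATr and GATr themselves are left out because their exact distribution names are not stated here.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the pip commands above
PACKAGES = ("torch", "xformers", "torch_geometric", "torch_scatter", "torch_cluster")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version string, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:16s} {ver or 'NOT INSTALLED'}")
```

CUDA wheels from the PyTorch index usually carry a local suffix, so expect strings such as `2.1.0+cu121` rather than a bare `2.1.0`.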


```python
import torch
from lab_gatr import PointCloudPoolingScales, LaBGATr
import torch_geometric as pyg
from gatr.interface import embed_oriented_plane, extract_translation
```

Let us first create a dummy mesh: `n` vertex positions, an orientation per vertex (e.g. the surface normal) and an arbitrary scalar feature (e.g. geodesic distance).

```python
n = 1000

# Dummy per-vertex data: positions and orientations drawn from [0, 1)^3, plus one scalar
pos, orientation = torch.rand((n, 3)), torch.rand((n, 3))
scalar_feature = torch.rand(n)
```
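The random tensors above are enough to test the plumbing, but their orientations are not unit normals and the scalar is not a geodesic distance. If you want dummy data with a clear geometric meaning, one option is to sample a unit sphere: there the outward normal equals the position, and the geodesic distance from the north pole $(0, 0, 1)$ is $\arccos z$. The sketch below uses Fibonacci sphere sampling with plain Python lists; `fibonacci_sphere` is a hypothetical helper, not part of LaB-GATr.

```python
import math


def fibonacci_sphere(n):
    """Hypothetical helper: n roughly uniform points on the unit sphere.

    Returns (pos, normal, geodesic) as plain lists. On the unit sphere the outward
    normal equals the position, and the geodesic distance to (0, 0, 1) is acos(z).
    """
    golden_angle = math.pi * (3.0 - math.sqrt(5.0))
    pos, normal, geodesic = [], [], []
    for i in range(n):
        z = 1.0 - 2.0 * (i + 0.5) / n  # heights evenly spaced in (-1, 1)
        r = math.sqrt(1.0 - z * z)  # radius of the horizontal circle at height z
        phi = i * golden_angle  # rotate by the golden angle for each new point
        p = (r * math.cos(phi), r * math.sin(phi), z)
        pos.append(p)
        normal.append(p)
        geodesic.append(math.acos(z))
    return pos, normal, geodesic
```

To use it in place of the random data, convert each list to a tensor, e.g. `pos, orientation, scalar_feature = (torch.tensor(x) for x in fibonacci_sphere(n))`.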

Next, define a point cloud pooling transform for the tokenisation (patching), and apply it to the dummy data to build a PyTorch Geometric `Data` object.

```python
transform = PointCloudPoolingScales(rel_sampling_ratios=(0.2,), interp_simplex='triangle')
# Wrap the dummy tensors in a PyG Data object and apply the pooling transform to it
dummy_data = transform(pyg.data.Data(pos=pos, orientation=orientation, scalar_feature=scalar_feature))
```
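Tokenisation picks a subset of the vertices as token positions and groups the remaining vertices around them. We assume, without having checked the LaB-GATr source, that `rel_sampling_ratios=(0.2,)` keeps about 20% of the points at a single pooling scale. With `n = 1000` that means attention over 200 tokens instead of 1000 points, i.e. 200² = 40,000 instead of 1,000,000 pairwise scores (25 times fewer). The sketch below illustrates the idea with farthest point sampling, a common way to choose a well-spread subset, followed by nearest-centre patch assignment. Both helpers are hypothetical and written in plain Python; they are not the internals of `PointCloudPoolingScales`.

```python
import math
import random


def farthest_point_sampling(points, ratio, seed=0):
    """Hypothetical helper: greedily pick about ratio * len(points) well-spread indices."""
    n = len(points)
    k = max(1, round(ratio * n))
    first = random.Random(seed).randrange(n)
    selected = [first]
    # Distance from every point to its nearest already-selected point
    nearest = [math.dist(p, points[first]) for p in points]
    while len(selected) < k:
        # The point farthest from all current selections becomes the next one
        nxt = max(range(n), key=nearest.__getitem__)
        selected.append(nxt)
        for i, p in enumerate(points):
            nearest[i] = min(nearest[i], math.dist(p, points[nxt]))
    return selected


def assign_patches(points, centres):
    """Hypothetical helper: for every point, the index (into centres) of its nearest centre."""
    return [
        min(range(len(centres)), key=lambda j: math.dist(p, points[centres[j]]))
        for p in points
    ]


rng = random.Random(42)
cloud = [(rng.random(), rng.random(), rng.random()) for _ in range(1000)]
centres = farthest_point_sampling(cloud, ratio=0.2)
patches = assign_patches(cloud, centres)
print(len(centres), len(set(patches)))
```

Each centre is the nearest centre to itself, so every patch contains at least its own centre.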

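Next, define a geometric algebra interface that tells the model how to embed your data in the projective geometric algebra $\mathbf{G}(3, 0, 1)$ and how to extract the output from the model's multivectors.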

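```python
class GeometricAlgebraInterface:
    # One multivector channel and one scalar channel, both for input and output
    num_input_channels = num_output_channels = 1
    num_input_scalars = num_output_scalars = 1

    @staticmethod
    @torch.no_grad()
    def embed(data):
        # Encode each vertex as an oriented plane (normal through position): shape (n, 1, 16)
        multivectors = embed_oriented_plane(normal=data.orientation, position=data.pos).view(-1, 1, 16)
        # The scalar feature becomes a single scalar channel: shape (n, 1)
        scalars = data.scalar_feature.view(-1, 1)

        return multivectors, scalars

    @staticmethod
    def dislodge(multivectors, scalars):
        # Read a 3D translation vector out of the output multivectors and drop singleton dims
        return extract_translation(multivectors).squeeze()
```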

Create a model instance with hidden dimensionality `d_model=8`, 10 GATr blocks and 4 attention heads per block, using the geometric algebra interface defined above.

```python
model = LaBGATr(GeometricAlgebraInterface, d_model=8, num_blocks=10, num_attn_heads=4, use_class_token=False)
```

Run a forward pass on the dummy data to check that the model works. From here on, training and inference work as with any other PyTorch model.

```python
output = model(dummy_data)
print(output.shape)  # shape of the output returned via GeometricAlgebraInterface.dislodge
```