# L-GATr-slim Quickstart
# [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/heidelberg-hepml/lgatr/blob/main/examples/demo_lgatr_slim.ipynb)

In this tutorial, we give a quick introduction to L-GATr-slim, a Lorentz-equivariant transformer for applications in high-energy physics and other domains where Lorentz symmetry is relevant.

`LGATrSlim` is built on a mix of scalar and vector representations. The design is inspired by the original `LGATr` network and closely follows it, but deviates in a few critical points.

The main difference is that `LGATrSlim` uses only scalar and vector representations, whereas `LGATr` uses the unifying multivector representation, which also contains bivector, axial vector, and pseudoscalar representations. We find that these extra representations are not required in most practical high-energy physics tasks. To reduce the network size, we link the axial vector to the vector representation and the pseudoscalar to the scalar representation, effectively breaking $O(1,3)$-equivariance to $SO^+(1,3)$-equivariance.

In principle, `LGATrSlim` could be implemented to be completely equivalent to `LGATr` where the higher-order representations are set to zero. Our implementation mostly follows this approach, but deviates in a few aspects:
- `LGATrSlim`'s `Linear` simply multiplies scalars and vectors with one common weight, whereas `LGATr`'s `EquiLinear` creates a 16x16 weight matrix for multivectors that covers all possible interactions. Most entries in this matrix are zero, so multiplying with it adds many unnecessary FLOPs, though the `opt_einsum` package allows for an efficient implementation. The two implementations yield equivalent results, but the `LGATrSlim` linear layer avoids unnecessary operations, making it more efficient.
- `LGATrSlim`'s `GatedLinearUnit` replaces `LGATr`'s mix of `GeometricProduct`, `EquiLinear` and `ScalarGatedNonlinearity` in the MLP. First, this means that `LGATrSlim` cannot express outer products. We find that this restriction does not affect the network performance in the tasks that we tested, but it significantly speeds up the network. Second, the `GatedLinearUnit` also uses gated nonlinearities for the scalar channels, and constructs the vector gates from an inner product instead of using scalar channels directly.
- `LGATrSlim`'s `RMSNorm` computes one unified norm for scalars and vectors, whereas `LGATr`'s `EquiLayerNorm` normalizes multivectors and scalars separately.
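
To illustrate the idea of building vector gates from an inner product, here is a minimal sketch in plain PyTorch. It is not the `lgatr` implementation: the helper names `minkowski_inner` and `gate_vectors` are hypothetical, the metric signature $(+,-,-,-)$ is an assumption, and the sigmoid gate is chosen only for illustration. The sketch checks numerically that gating with a Lorentz-invariant quantity commutes with a boost along the x axis.

```python
import math

import torch


def minkowski_inner(a, b):
    """Inner product with metric diag(+1, -1, -1, -1) over the last axis."""
    return a[..., 0] * b[..., 0] - (a[..., 1:] * b[..., 1:]).sum(dim=-1)


def gate_vectors(v):
    """Scale each vector channel by a gate built from its invariant <v, v>."""
    gate = torch.sigmoid(minkowski_inner(v, v))  # shape (..., channels)
    return gate.unsqueeze(-1) * v


torch.manual_seed(0)
v = torch.randn(5, 3, 4, dtype=torch.float64)  # (items, vector channels, 4)

# Lorentz boost along x with rapidity 0.5
ch, sh = math.cosh(0.5), math.sinh(0.5)
boost = torch.tensor(
    [
        [ch, sh, 0.0, 0.0],
        [sh, ch, 0.0, 0.0],
        [0.0, 0.0, 1.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ],
    dtype=torch.float64,
)

# Equivariance: boosting then gating equals gating then boosting
boosted_then_gated = gate_vectors(v @ boost.T)
gated_then_boosted = gate_vectors(v) @ boost.T
max_diff = (boosted_then_gated - gated_then_boosted).abs().max().item()
print(f"max deviation: {max_diff:.1e}")
```

Because the gate depends on the vectors only through an invariant, multiplying by it cannot break equivariance; the deviation printed here should be at the level of floating-point rounding.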

In [None]:
# install the lgatr package
%pip install lgatr

After importing the required modules, we construct a `LGATrSlim` encoder module. The user can specify the number of input, output, and hidden channels for scalars and vectors, as well as the number of heads and transformer blocks. More hyperparameters can be found in the documentation.

In [1]:
from lgatr import LGATrSlim

lgatrslim = LGATrSlim(
    in_v_channels=1,
    out_v_channels=1,
    hidden_v_channels=8,
    in_s_channels=1,
    out_s_channels=1,
    hidden_s_channels=16,
    num_blocks=2,
    num_heads=1,
)

We now test `LGATrSlim` on toy data, e.g. a batch of LHC events. We create particles with fixed mass and Gaussian noise as three-momentum. The resulting four-momenta have shape `p.shape = (128, 20, 1, 4)`: batch size 128, 20 particles per event, 1 four-momentum per particle, and 4 components per four-momentum. We also generate random particle types `pid` with `pid.shape = (128, 20, 1)`. More generally, `LGATrSlim` operates on vectors of shape `(batch_size, num_particles, num_v_channels, 4)` and scalars of shape `(batch_size, num_particles, num_s_channels)`, while normal transformers operate on `(batch_size, num_particles, num_channels)`, without the extra 'vector' dimension.

In [2]:
# generate toy data
import torch

p3 = torch.randn(128, 20, 1, 3)
mass = 1
E = (mass**2 + (p3**2).sum(dim=-1, keepdim=True)) ** 0.5
p = torch.cat((E, p3), dim=-1)
pid = torch.randint(high=3, size=p3.shape[:-1]).float()
print(p.shape)  # torch.Size([128, 20, 1, 4])
print(pid.shape)  # torch.Size([128, 20, 1])

torch.Size([128, 20, 1, 4])
torch.Size([128, 20, 1])
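
As a sanity check on the toy data, every four-momentum should satisfy the mass-shell condition $E^2 - |\vec{p}|^2 = m^2$. The following sketch regenerates the data and verifies this. It assumes the metric signature $(+,-,-,-)$ and uses double precision, a choice made here only to keep rounding errors small.

```python
import torch

torch.manual_seed(0)
mass = 1.0

# same construction as above, but in float64
p3 = torch.randn(128, 20, 1, 3, dtype=torch.float64)
E = (mass**2 + (p3**2).sum(dim=-1, keepdim=True)) ** 0.5
p = torch.cat((E, p3), dim=-1)

# invariant mass squared: E^2 - |p3|^2, one value per four-momentum
m2 = p[..., 0] ** 2 - (p[..., 1:] ** 2).sum(dim=-1)
on_shell = torch.allclose(m2, torch.full_like(m2, mass**2))
print(on_shell)
```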


We can now process the vectors and scalars with the `LGATrSlim` architecture! It returns new vectors and scalars, from which we can extract the component that we want -- for instance the scalar component for a jet tagging or amplitude regression application, or the vector component for flow matching.

In [3]:
vectors = p
scalars = pid
output_v, output_s = lgatrslim(vectors=vectors, scalars=scalars)
print(output_v.shape)  # torch.Size([128, 20, 1, 4])
print(output_s.shape)  # torch.Size([128, 20, 1])

torch.Size([128, 20, 1, 4])
torch.Size([128, 20, 1])
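
Lorentz equivariance means that transforming the input vectors and then applying the network gives the same result as applying the network first and then transforming its vector outputs, while the scalar outputs stay unchanged. A minimal sketch of such a check follows. The helper names `boost_x` and `equivariance_error` are hypothetical, and `toy_model` is a stand-in with the same call signature as `lgatrslim` so that the snippet runs without the `lgatr` package. Passing `lgatrslim` in its place is expected to give errors at the level of floating-point precision, though that is not verified here.

```python
import math

import torch


def boost_x(rapidity, dtype=torch.float64):
    """4x4 matrix of a Lorentz boost along the x axis."""
    ch, sh = math.cosh(rapidity), math.sinh(rapidity)
    return torch.tensor(
        [
            [ch, sh, 0.0, 0.0],
            [sh, ch, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ],
        dtype=dtype,
    )


def equivariance_error(model, vectors, scalars, transform):
    """Max deviation between model(transformed input) and transformed model(input)."""
    out_v, out_s = model(vectors=vectors, scalars=scalars)
    out_v_t, out_s_t = model(vectors=vectors @ transform.T, scalars=scalars)
    err_v = (out_v_t - out_v @ transform.T).abs().max().item()
    err_s = (out_s_t - out_s).abs().max().item()
    return err_v, err_s


def toy_model(vectors, scalars):
    """Stand-in equivariant map: scalars rescale vectors, invariants shift scalars."""
    m2 = vectors[..., 0] ** 2 - (vectors[..., 1:] ** 2).sum(dim=-1)
    return scalars.unsqueeze(-1) * vectors, scalars + m2


torch.manual_seed(0)
toy_vectors = torch.randn(8, 5, 1, 4, dtype=torch.float64)
toy_scalars = torch.randn(8, 5, 1, dtype=torch.float64)
err_v, err_s = equivariance_error(toy_model, toy_vectors, toy_scalars, boost_x(0.5))
print(f"vector error {err_v:.1e}, scalar error {err_s:.1e}")
```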


The design choice of using only scalar and vector channels in `LGATrSlim` allows a particularly efficient implementation. We add a `compile` option, defaulting to `compile=False`, that triggers an internal `self.__class__ = torch.compile(self.__class__, dynamic=True)` call in the `LGATrSlim` constructor. `torch.compile` automatically selects the most efficient kernels and fuses operations. We find significant speed gains and use `compile=True` in all our production runs.

In [4]:
lgatrslim = LGATrSlim(
    in_v_channels=1,
    out_v_channels=1,
    hidden_v_channels=8,
    in_s_channels=1,
    out_s_channels=1,
    hidden_s_channels=16,
    num_blocks=2,
    num_heads=1,
    compile=True,
)

The price of `torch.compile` is the initial compilation, which happens on the first forward pass and typically takes between 10 seconds and 2 minutes. Thanks to the option `dynamic=True` in `torch.compile`, the code does not have to be re-compiled later on when the network is called with different shapes.

In [5]:
import time

t0 = time.time()
output_v, output_s = lgatrslim(vectors=vectors, scalars=scalars)
dt = time.time() - t0
print(output_v.shape)  # torch.Size([128, 20, 1, 4])
print(output_s.shape)  # torch.Size([128, 20, 1])
print(f"First iteration (including compilation) takes {dt:.2}s")

t0 = time.time()
output_v, output_s = lgatrslim(vectors=vectors, scalars=scalars)
dt = time.time() - t0
print(f"Second iteration takes {dt:.2}s")

torch.Size([128, 20, 1, 4])
torch.Size([128, 20, 1])
First iteration (including compilation) takes 7.4s
Second iteration takes 0.0042s
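
A single timed call can be noisy, so it can help to average over several calls after the first one. The sketch below wraps this in a hypothetical helper `time_forward`, which is not part of `lgatr`. It times the first call separately (this is where compilation would happen), then averages over repeated calls, and synchronizes the GPU when CUDA is available so that asynchronous kernels are included in the measurement. A plain matrix product stands in for the model so that the snippet runs on its own.

```python
import time

import torch


def time_forward(fn, n_repeats=10):
    """Return (first_call_seconds, mean_later_call_seconds) for a zero-argument callable."""

    def sync():
        # CUDA kernels run asynchronously; wait for them before reading the clock
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    sync()
    t0 = time.perf_counter()
    fn()
    sync()
    first = time.perf_counter() - t0

    t0 = time.perf_counter()
    for _ in range(n_repeats):
        fn()
    sync()
    later = (time.perf_counter() - t0) / n_repeats
    return first, later


x = torch.randn(64, 64)
first, later = time_forward(lambda: x @ x)
print(f"first call {first:.2g}s, later calls {later:.2g}s on average")
```

For the compiled network, the same helper could be called as `time_forward(lambda: lgatrslim(vectors=vectors, scalars=scalars))`.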


That's it, now you're ready to build your own `LGATrSlim` model!