# Conditional L-GATr Quickstart
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/heidelberg-hepml/lgatr/blob/main/examples/demo_conditional_lgatr.ipynb)

This tutorial is a quick introduction to using conditional L-GATr models, building on the simpler [non-conditional L-GATr](https://github.com/heidelberg-hepml/lgatr/blob/main/examples/demo_lgatr.ipynb).

```python
# install the lgatr package
%pip install lgatr
```

In addition to the standard `LGATr` module `lgatr`, which serves as the encoder, we now create a `ConditionalLGATr` module `conditional_lgatr`, which serves as the decoder. We first process the condition with the `lgatr` encoder and then process the main data together with the encoded condition using `conditional_lgatr`.
Note that we set `out_mv_channels=hidden_mv_channels` and `out_s_channels=hidden_s_channels` for `LGATr`, and `condition_mv_channels=hidden_mv_channels` and `condition_s_channels=hidden_s_channels` for `ConditionalLGATr`. This avoids a bottleneck for the condition, and since both modules use the same hidden channels, the encoder output also matches the condition channels that the decoder expects.

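The one hard constraint between the two modules is that the decoder's condition channels equal the encoder's output channels. A minimal sketch of that check on plain keyword dictionaries mirroring the settings used below; `check_condition_channels` is a hypothetical helper, not part of `lgatr`:

```python
# Hypothetical helper (not part of lgatr): check that the encoder output can be
# fed to the decoder as the condition, before building the modules.
def check_condition_channels(encoder_kwargs: dict, decoder_kwargs: dict) -> bool:
    """Raise ValueError if encoder outputs and decoder condition channels differ.

    Returns True if the encoder also has no bottleneck (out_* == hidden_*),
    False if the channels match but the encoder compresses its output.
    """
    pairs = [
        ("out_mv_channels", "condition_mv_channels"),
        ("out_s_channels", "condition_s_channels"),
    ]
    for enc_key, dec_key in pairs:
        if encoder_kwargs[enc_key] != decoder_kwargs[dec_key]:
            raise ValueError(
                f"encoder {enc_key}={encoder_kwargs[enc_key]} does not match "
                f"decoder {dec_key}={decoder_kwargs[dec_key]}"
            )
    # No bottleneck: the encoder outputs as many channels as it uses internally.
    return (
        encoder_kwargs["out_mv_channels"] == encoder_kwargs["hidden_mv_channels"]
        and encoder_kwargs["out_s_channels"] == encoder_kwargs["hidden_s_channels"]
    )


encoder_kwargs = dict(
    in_mv_channels=1, out_mv_channels=8, hidden_mv_channels=8,
    in_s_channels=0, out_s_channels=16, hidden_s_channels=16,
)
decoder_kwargs = dict(
    in_mv_channels=1, condition_mv_channels=8, out_mv_channels=1, hidden_mv_channels=8,
    in_s_channels=0, condition_s_channels=16, out_s_channels=0, hidden_s_channels=16,
)
print(check_condition_channels(encoder_kwargs, decoder_kwargs))  # True
```

Keeping the channel settings in dictionaries like this makes it easy to unpack them into the constructors, e.g. `LGATr(**encoder_kwargs, attention=attention, mlp=mlp, num_blocks=2)`, so the two modules cannot drift apart.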
```python
# construct LGATr and ConditionalLGATr modules
from lgatr import LGATr, ConditionalLGATr

attention = dict(num_heads=2)
crossattention = dict(num_heads=2)
mlp = dict()
lgatr = LGATr(
    in_mv_channels=1,
    out_mv_channels=8,
    hidden_mv_channels=8,
    in_s_channels=0,
    out_s_channels=16,
    hidden_s_channels=16,
    attention=attention,
    mlp=mlp,
    num_blocks=2,
)
conditional_lgatr = ConditionalLGATr(
    in_mv_channels=1,
    condition_mv_channels=8,
    out_mv_channels=1,
    hidden_mv_channels=8,
    in_s_channels=0,
    out_s_channels=0,
    condition_s_channels=16,
    hidden_s_channels=16,
    attention=attention,
    crossattention=crossattention,
    mlp=mlp,
    num_blocks=2,
)
```

As in the `LGATr` notebook, we test our model on toy data of shape `p.shape = (128, 20, 1, 4)`: batch size 128, 20 particles per jet, 1 four-momentum per particle, and 4 components per four-momentum. As the condition, we add another set of 40 particles per batch element, `p_c.shape = (128, 40, 1, 4)`. All particles are generated on shell with mass 1.

```python
# generate toy data
import torch
p3 = torch.randn(128, 20, 1, 3)
mass = 1
E = (mass**2 + (p3**2).sum(dim=-1, keepdim=True))**0.5
p = torch.cat((E, p3), dim=-1)
p3_c = torch.randn(128, 40, 1, 3)
E_c = (mass**2 + (p3_c**2).sum(dim=-1, keepdim=True))**0.5
p_c = torch.cat((E_c, p3_c), dim=-1)
print(p.shape, p_c.shape)
```

```
torch.Size([128, 20, 1, 4]) torch.Size([128, 40, 1, 4])
```

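The energy is chosen as E = sqrt(m² + |p|²), so every particle satisfies the on-shell condition E² - |p|² = m², a Lorentz-invariant quantity equal to 1 here. A minimal pure-Python sketch of the same construction for a single particle, with the invariant checked explicitly; the helper names `make_on_shell` and `minkowski_norm` are illustrative:

```python
import math
import random


def make_on_shell(p3, mass=1.0):
    """Prepend the energy E = sqrt(m^2 + |p|^2) to a 3-momentum (px, py, pz)."""
    energy = math.sqrt(mass**2 + sum(c**2 for c in p3))
    return (energy, *p3)


def minkowski_norm(p):
    """Minkowski square E^2 - |p|^2, using the metric signature (+, -, -, -)."""
    energy, *p3 = p
    return energy**2 - sum(c**2 for c in p3)


random.seed(0)
p3 = [random.gauss(0.0, 1.0) for _ in range(3)]  # like torch.randn for one particle
p = make_on_shell(p3, mass=1.0)
print(len(p), round(minkowski_norm(p), 9))  # 4 1.0
```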

We now embed the four-momenta of the main data and of the condition into multivectors using `embed_vector`.

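Multivectors of the spacetime algebra have 2⁴ = 16 components, which is where the final axis of size 16 comes from; `extract_scalar`, used later, keeps only the scalar component. A minimal sketch of what these two interface functions do for a single multivector stored as a list. This is an illustrative re-implementation, not the `lgatr` source, and it assumes that component 0 is the scalar and components 1 to 4 hold the vector part; check `lgatr.interface` for the authoritative layout:

```python
# Illustrative re-implementation (not the lgatr source) for one multivector
# stored as a list of 16 floats.
# Assumption: index 0 is the scalar, indices 1-4 hold the vector part (E, px, py, pz).
NUM_COMPONENTS = 16  # 2**4 basis elements of the spacetime algebra


def embed_vector_sketch(p):
    """Place a four-vector into the grade-1 slots of an otherwise zero multivector."""
    if len(p) != 4:
        raise ValueError("expected a four-vector")
    mv = [0.0] * NUM_COMPONENTS
    mv[1:5] = p
    return mv


def extract_scalar_sketch(mv):
    """Return the scalar (grade-0) component, keeping a trailing axis of size 1."""
    return mv[0:1]


mv = embed_vector_sketch([2.0, 1.0, 0.5, -1.0])
print(len(mv), mv[:6])  # 16 [0.0, 2.0, 1.0, 0.5, -1.0, 0.0]
print(extract_scalar_sketch(mv))  # [0.0]
```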
```python
from lgatr.interface import embed_vector, extract_scalar
multivector = embed_vector(p)
multivector_condition = embed_vector(p_c)
print(multivector.shape, multivector_condition.shape)
```

```
torch.Size([128, 20, 1, 16]) torch.Size([128, 40, 1, 16])
```


We can now process the data with our conditional L-GATr model. First, the `lgatr` encoder maps the condition to an embedding in a higher-dimensional latent space, with 8 multivector and 16 scalar channels per particle. The `conditional_lgatr` decoder then processes the main input together with this condition embedding. Finally, we extract the scalar component of the output multivectors with `extract_scalar`.

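As in a standard transformer decoder, and as the `crossattention` argument suggests, the condition presumably enters through cross-attention: queries come from the main particles, keys and values from the encoded condition. Each of the 20 main particles can then attend to all 40 condition particles, so the two sets need not have the same length. The sketch below is plain, non-equivariant scaled dot-product cross-attention on lists of floats, meant only to illustrate this data flow; L-GATr's own attention acts on multivectors and is not reproduced here:

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(queries, keys, values):
    """Plain scaled dot-product cross-attention (single head, no projections).

    queries: n_q vectors (from the main particles)
    keys, values: n_c vectors each (from the encoded condition)
    Returns n_q outputs, each a softmax-weighted average of the values.
    """
    d = len(queries[0])
    outputs = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        outputs.append([
            sum(w * v[j] for w, v in zip(weights, values))
            for j in range(len(values[0]))
        ])
    return outputs


# 2 "main" tokens attend over 3 "condition" tokens: the lengths may differ.
queries = [[1.0, 0.0], [0.0, 1.0]]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0], [2.0], [3.0]]
out = cross_attention(queries, keys, values)
print(len(out), len(out[0]))  # 2 1
```

The output has one entry per query, which is why the decoder output below keeps the 20 particles of the main input.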
```python
# encoder (lgatr)
condition_mv, condition_s = lgatr(multivectors=multivector_condition, scalars=None)
print(condition_mv.shape, condition_s.shape)

# decoder (conditional_lgatr)
output_mv, output_s = conditional_lgatr(
    multivectors=multivector,
    multivectors_condition=condition_mv,
    scalars_condition=condition_s,
)
out = extract_scalar(output_mv)
print(out.shape)
```

```
torch.Size([128, 40, 8, 16]) torch.Size([128, 40, 16])
torch.Size([128, 20, 1, 1])
```

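The printed shapes can be read off from the settings above: the condition keeps its 40 particles and gets the encoder's 8 multivector and 16 scalar output channels, the decoder returns `out_mv_channels=1` multivector for each of the 20 main particles, and `extract_scalar` reduces the final axis of 16 components to 1. A small bookkeeping sketch of this reasoning; `trace_shapes` is a hypothetical helper, not part of `lgatr`:

```python
# Hypothetical bookkeeping helper (not part of lgatr): reproduce the tensor
# shapes of the encoder/decoder pipeline from the data and channel settings.
def trace_shapes(batch, n_main, n_cond, enc_out_mv, enc_out_s, dec_out_mv):
    mv_components = 16  # last axis of every multivector tensor
    condition_mv = (batch, n_cond, enc_out_mv, mv_components)
    condition_s = (batch, n_cond, enc_out_s)
    # The decoder keeps the main sequence length; the condition length drops out.
    output_mv = (batch, n_main, dec_out_mv, mv_components)
    # extract_scalar keeps only the scalar component, leaving an axis of size 1.
    out = output_mv[:-1] + (1,)
    return condition_mv, condition_s, output_mv, out


shapes = trace_shapes(batch=128, n_main=20, n_cond=40, enc_out_mv=8, enc_out_s=16, dec_out_mv=1)
for s in shapes:
    print(s)
```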

That's it! Now you're ready to build your own `ConditionalLGATr` model!