In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
from pathlib import Path

import awkward as ak
import numpy as np
import vector
from omegaconf import OmegaConf

# Make the local gabbro checkout importable. Override the location with the
# GABBRO_DIR environment variable; the default is the original author's checkout.
GABBRO_DIR = os.environ.get("GABBRO_DIR", "/data/dust/user/rosehenn/gabbro")
# Guard so re-running this cell does not append duplicate sys.path entries.
if GABBRO_DIR not in sys.path:
    sys.path.append(GABBRO_DIR)

## Shower generation with a trained OmniJet model

This notebook provides a short example of how to load a trained OmniJet model with the next-token-prediction head, generate showers with it, and decode the generated tokens back to physical features with the VQ-VAE tokenizer.

In [None]:
from gabbro.models.backbone import BackboneNextTokenPredictionLightning

# Checkpoint from a backbone training with the next-token-prediction head.
# Make sure the checkpoint is available in advance; for the public checkpoints,
# run the script `checkpoints/download_checkpoints.sh` and point `gen_ckpt_path` at it.
# NOTE: the default below is an absolute path on the original author's storage.
# A dedicated name is used so the VQ-VAE cell's `ckpt_path` cannot be shadowed
# when cells are re-run out of order.
gen_ckpt_path = "/data/dust/user/rosehenn/gabbro_output/full_resolution/runs/2024-11-21_13-49-55_max-wng060_TerminativeCirculation/checkpoints/epoch_032_loss_4.10881.ckpt"
if not os.path.isfile(gen_ckpt_path):
    raise FileNotFoundError(
        f"Generative-model checkpoint not found: {gen_ckpt_path}. "
        "Download it first (see `checkpoints/download_checkpoints.sh`) and update the path."
    )
gen_model = BackboneNextTokenPredictionLightning.load_from_checkpoint(gen_ckpt_path)
gen_model.eval()

## Generating Showers

In [None]:
# How many showers to sample, and how many to generate per batch.
# Kept tiny so the example runs quickly; increase N_GEN_SHOWERS for real studies.
N_GEN_SHOWERS = 2
GEN_BATCH_SIZE = 2

# NOTE(review): generation is presumably stochastic; set a seed first if you need
# reproducible samples.
generated_showers = gen_model.generate_n_showers_batched(
    n_showers=N_GEN_SHOWERS,
    batch_size=GEN_BATCH_SIZE,
    # saveas=save_path,  # define `save_path` first to save the awkward array as a parquet file
)

In [None]:
# Inspect the raw generated token sequences (before decoding them with the VQ-VAE).
generated_showers

In [None]:
# --- Load the tokenizer (VQ-VAE) model from its checkpoint ---
# The feature_dict used for (de)preprocessing is read from this run's config.yaml
# in the next cell, which locates it via `ckpt_path`.
from gabbro.models.vqvae import VQVAELightning

ckpt_path = "/data/dust/user/rosehenn/gabbro_output/TokTrain/runs/2024-09-21_16-54-39_max-wng062_CerousLocknut/checkpoints/epoch_231_loss_0.17179.ckpt"
if not os.path.isfile(ckpt_path):
    raise FileNotFoundError(f"Tokenizer checkpoint not found: {ckpt_path}")

vqvae_model = VQVAELightning.load_from_checkpoint(ckpt_path)
# Trailing `;` suppresses the repr here; the model is printed explicitly once the
# config has been loaded.
vqvae_model.eval();

In [None]:
# Load the tokenizer run's config (two levels above the checkpoint file) and extract
# the feature dictionary used for preprocessing.
cfg = OmegaConf.load(Path(ckpt_path).parent.parent / "config.yaml")
pp_dict = OmegaConf.to_container(cfg.data.dataset_kwargs_common.feature_dict)
print("\npp_dict:")
for feat_name, feat_settings in pp_dict.items():
    print(feat_name, feat_settings)

# Get the cuts from the pp_dict. They remove particles during preprocessing/tokenization,
# so the same particles also have to be removed from the original inputs when comparing
# those to the tokenized+reconstructed particles.
CUT_CRITERIA = ("larger_than", "smaller_than")
pp_dict_cuts = {
    # `or {}` guards features whose config entry is null (no preprocessing settings);
    # such features get None for every criterion, as if the key were absent.
    feat_name: {criterion: (feat_settings or {}).get(criterion) for criterion in CUT_CRITERIA}
    for feat_name, feat_settings in pp_dict.items()
}

print("\npp_dict_cuts:")
for feat_name, cuts in pp_dict_cuts.items():
    print(feat_name, cuts)

print("\nModel:")
print(vqvae_model)

In [None]:
# Reconstruct the generated tokens to physical features.
#
# The generative model uses 0 as the start token, so its tokens are shifted by +1
# relative to the VQ-VAE codebook indices. Drop the start token and undo the shift
# before decoding.
tokens_vqvae = generated_showers[:, 1:] - 1
# A negative index means a start token (0) appeared after the first position; fail
# loudly rather than passing invalid codebook indices to the decoder.
assert ak.all(tokens_vqvae >= 0), "found start tokens (0) after the first position"

RECO_BATCH_SIZE = 512
# NOTE(review): presumably the sequence length the VQ-VAE pads to; confirm generated
# sequences never exceed it.
RECO_PAD_LENGTH = 128

showers_reconstructed = vqvae_model.reconstruct_ak_tokens(
    tokens_ak=tokens_vqvae,
    pp_dict=pp_dict,
    batch_size=RECO_BATCH_SIZE,
    pad_length=RECO_PAD_LENGTH,
)

In [None]:
# Inspect the reconstructed showers (generated tokens decoded back to physical features).
showers_reconstructed