In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
from pathlib import Path

import awkward as ak
import numpy as np
import vector
from omegaconf import OmegaConf

# Make the local gabbro checkout importable from this notebook.
# The membership check keeps the cell idempotent: re-running it does not
# add the same entry to sys.path a second time.
GABBRO_REPO_DIR = "/data/dust/user/rosehenn/gabbro"
if GABBRO_REPO_DIR not in sys.path:
    sys.path.append(GABBRO_REPO_DIR)

## Tokenization with the VQ-VAE

In [None]:
# --- Restore the VQ-VAE tokenizer from a training checkpoint ---
# The same checkpoint path is reused further down to locate the run's config.yaml.
from gabbro.models.vqvae import VQVAELightning

ckpt_path = "/data/dust/user/rosehenn/gabbro_output/TokTrain/runs/2024-09-21_16-54-39_max-wng062_CerousLocknut/checkpoints/epoch_231_loss_0.17179.ckpt"

vqvae_model = VQVAELightning.load_from_checkpoint(checkpoint_path=ckpt_path)
# Switch to evaluation mode; as the last expression of the cell, the returned
# value is also what the notebook displays.
vqvae_model.eval()

In [None]:
# Load the training config stored two directory levels above the checkpoint file.
cfg = OmegaConf.load(Path(ckpt_path).parent.parent / "config.yaml")
# Convert the feature_dict config node into plain Python containers.
pp_dict = OmegaConf.to_container(cfg.data.dataset_kwargs_common.feature_dict)
print("\npp_dict:")
for feat_name, feat_params in pp_dict.items():
    print(feat_name, feat_params)

# Get the cuts from the pp_dict: they lead to particles being removed during
# preprocessing/tokenization, so the same particles also have to be removed from the
# original jets when comparing the tokenized+reconstructed particles to the original ones.
# A feature configured without any parameters may come back as None (presumably a
# `feature: null` entry in the YAML - TODO confirm), so fall back to an empty dict
# instead of failing on `None.get(...)`. Missing criteria are stored as None.
pp_dict_cuts = {
    feat_name: {
        criterion: (feat_params or {}).get(criterion)
        for criterion in ["larger_than", "smaller_than"]
    }
    for feat_name, feat_params in pp_dict.items()
}

print("\npp_dict_cuts:")
for feat_name, feat_cuts in pp_dict_cuts.items():
    print(feat_name, feat_cuts)

print("\nModel:")
print(vqvae_model)

### Load shower file

In [None]:
from gabbro.data.loading import read_shower_file

# Read the shower array from parquet and keep only the first 5000 entries.
filename_in = "/data/dust/user/rosehenn/gabbro/notebooks/array_real.parquet"
showers = ak.from_parquet(filename_in)[:5000]
# NOTE: no cut-based pre-selection is applied to `showers` in this cell.

## Tokenize and reconstruct showers

In [None]:
# Tokenize the showers with the VQ-VAE, then decode the tokens back into features.
# Both calls receive the same preprocessing dict, batch size and pad length.
shared_kwargs = dict(pp_dict=pp_dict, batch_size=4, pad_length=1700)

part_features_ak_tokenized = vqvae_model.tokenize_ak_array(ak_arr=showers, **shared_kwargs)

# When reconstructing tokens that come from the generative model instead, first drop
# the start token and subtract 1 from every token: by convention 0 is the start token,
# so generative-model tokens are shifted by 1 relative to the VQ-VAE tokens.
part_features_ak_reco = vqvae_model.reconstruct_ak_tokens(
    tokens_ak=part_features_ak_tokenized, **shared_kwargs
)

In [None]:
# Print the first five entries of the tokenized and of the reconstructed showers,
# each group preceded by its heading.
for heading, ak_arr in (
    ("First 5 tokenized Showers:", part_features_ak_tokenized),
    ("\nFirst 5 reconstructed Showers:", part_features_ak_reco),
):
    print(heading)
    for idx in range(5):
        print(ak_arr[idx])

## Plot the reconstructed showers

In [None]:
from gabbro.plotting.feature_plotting import plot_paper_plots

# Compare the original showers with their tokenized + reconstructed counterparts.
# The originals are truncated to the number of reconstructed showers.
n_reco = len(part_features_ak_reco)
feature_sets = [showers[:n_reco], part_features_ak_reco]

fig = plot_paper_plots(
    feature_sets=feature_sets,
    # other labels used with this plot: "OmniJet-$\\alpha_C$" "BIB-AE", "L2L Flows"
    labels=["Geant4", "Tokenized"],
    colors=["lightgrey", "#1a80bb", "#ea801c", "#4CAF50", "#1a80bb"],
)
fig.show()