In [1]:
import h5py
import awkward as ak
import energyflow as ef
import fastjet as fj
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pandas as pd
import vector
from cycler import cycler
import uproot
import logging
from pathlib import Path
from omegaconf import OmegaConf
from collections import deque
from pathlib import Path
import time

In [2]:
import os
import sys

# Make the project root (the parent of the notebook's working directory) importable,
# presumably so the local `gabbro` package resolves; only add it once.
project_root = os.path.abspath(os.pardir)
if project_root not in sys.path:
    sys.path.append(project_root)

### Functions to process JetClass data

In [3]:
# Read jet histories and filter the ones clustered into only one inclusive jet
# which often result in lower pT jets (i.e. outliers)
def read_histories(cluster_seq, pt_lower=0, pt_upper=np.inf): 

    histories = ak.Array(cluster_seq.unique_history_order())
    inc_jets = cluster_seq.inclusive_jets()

    # Get number of inclusive jets per reclustered jet
    num_inc_jets = ak.num(inc_jets)

    # Get the pt of the first jet
    jet_pts = ak.firsts(inc_jets.pt)

    # Filtering mask
    mask = (num_inc_jets == 1) & (pt_lower < jet_pts) & (jet_pts < pt_upper)

    # Filter histories and indices
    filtered_histories = histories[mask]
    filtered_indices = np.nonzero(mask)[0]

    return filtered_histories, filtered_indices


def return_decays(jet_arrays, jet_hists, num_consts, dummy_momentum, return_indices=True):
    """Convert cluster histories into declustering steps (pseudojet "decays").

    Each jet's ``unique_history_order()`` is walked with a stack. An index below the
    jet's constituent count is a constituent; any other index is a merge of the two
    most recent stack entries. The final history entry ends the walk and is not
    emitted.

    Parameters
    ----------
    jet_arrays : sequence
        Per-jet sequence indexable by history index (e.g. ``cluster.jets()``). It is
        only read when ``return_indices`` is False.
    jet_hists : sequence
        Per-jet history order (e.g. from ``read_histories``).
    num_consts : sequence of int
        Number of constituents per jet.
    dummy_momentum :
        Placeholder used for both "children" of a constituent when
        ``return_indices`` is False.
    return_indices : bool
        If True, a step is ``[parent_idx, left_idx, right_idx]``, with ``None``
        children for constituents. Otherwise the entries are the corresponding
        elements of ``jet_arrays`` (e.g. 4-momenta).

    Returns
    -------
    list
        One list of steps per jet. Steps are ordered from the last merge (the jet root)
        down to the first constituent, i.e. the reverse of the clustering order, for
        both output modes.
    """
    decays = []
    for i in range(len(jet_arrays)):
        hist = jet_hists[i]
        n_consts = num_consts[i]
        last_idx = hist[-1]
        # Only needed (and only accessed) when returning 4-momenta
        jet_array = None if return_indices else jet_arrays[i]

        stack = []
        jet_decays = []

        for idx in hist:
            # End of the cluster history reached
            if idx == last_idx:
                break

            if idx < n_consts:
                # Constituent: no further decay, so both children are placeholders
                stack.append(idx)
                if return_indices:
                    jet_decays.append([idx, None, None])
                else:
                    jet_decays.append([jet_array[idx], dummy_momentum, dummy_momentum])
            else:
                # Merge of the two most recent entries on the stack
                left = stack.pop()
                right = stack.pop()
                stack.append(idx)
                if return_indices:
                    jet_decays.append([idx, left, right])
                else:
                    jet_decays.append([jet_array[idx], jet_array[left], jet_array[right]])

        # Steps were collected in clustering order. Reverse once (O(n)) to get the
        # declustering order, instead of inserting at the front (O(n^2)).
        jet_decays.reverse()
        # Previously the 4-momentum branch never stored its result and returned []
        decays.append(jet_decays)

    return decays
    

def tokenise_decays(jet_decays, token_arrays, codebook_size):
    """Replace the constituent/pseudojet indices of each declustering sequence by tokens.

    Parameters
    ----------
    jet_decays : ak.Array
        Per-jet flat list of indices (the ``parent, left, right`` triplets flattened).
        ``None`` marks the missing children of a constituent.
    token_arrays : ak.Array
        Per-jet tokens, indexed by the same indices as ``jet_decays``.
    codebook_size : number
        Size of the VQ-VAE codebook; used to build the end/placeholder token.

    Returns
    -------
    ak.Array
        Per-jet token sequence, built as follows:
        - a single start token ``0``;
        - every token shifted by ``+1``, with ``None`` placeholders mapped to
          ``codebook_size + 1``;
        - a single end token ``codebook_size + 1``.
        Jets with an empty sequence get neither start nor end token.
    """
    end_token = codebook_size
    masked_decays = ak.fill_none(jet_decays, -1)

    # Gather each jet's tokens at its decay indices. Placeholders (-1) temporarily read
    # index 0 and are overwritten with the end token below.
    safe_indices = ak.where(masked_decays != -1, masked_decays, 0)
    masked_tokens = token_arrays[safe_indices]

    result = ak.where(masked_decays == -1, end_token, masked_tokens)

    # Shift all tokens by one so that 0 is free for the start token. The start and end
    # markers are one token each, shaped like the first slot of every jet.
    return ak.concatenate(
        [
            ak.zeros_like(result[:, :1]),
            result + 1,
            ak.ones_like(result[:, :1]) + end_token,
        ],
        axis=1,
    )

### Reading the unique history order when jets are reclustered into multiple inclusive jets

In [None]:
def read_histories_multi_inclusive(cluster_seq, num_jets):
    """Split each jet's cluster history into one sub-history per inclusive jet.

    Renamed from ``read_histories``. Reusing that name shadowed the
    ``read_histories(cluster_seq, pt_lower, pt_upper)`` helper defined earlier in this
    notebook. After this cell had run, the processing loop (which passes three
    arguments) therefore failed.

    Parameters
    ----------
    cluster_seq : fastjet ClusterSequence
        Multi-event cluster sequence.
    num_jets : int
        Only the first ``num_jets`` events are processed.

    Returns
    -------
    ak.Array
        A list of histories per event. An event with a single (or no) inclusive jet
        gets its full history as the only entry. Otherwise every sub-history ends with
        the shared placeholder index ``len(cluster_seq.jets()[i])``, and entries after
        the last placeholder are not kept.
    """
    corrected_histories = []

    histories = ak.Array(cluster_seq.unique_history_order())
    num_consts = cluster_seq.n_particles()
    jet_arrays = cluster_seq.jets()
    inc_jets = cluster_seq.inclusive_jets()

    for i, hist in enumerate(histories[:num_jets]):
        jet_history = hist.tolist()
        jet_array = jet_arrays[i]
        length_array = len(jet_array)
        n_consts = num_consts[i]

        if len(inc_jets[i]) <= 1:
            # A single inclusive jet needs no splitting
            corrected_histories.append([jet_history])
            continue

        # .jets() is reindexed after clustering, which creates a mismatch with
        # .unique_history_order(). Locate the inclusive jets in .jets() by exact pT match.
        inc_jet_pt = inc_jets[i].pt
        inc_indices = {j for j, pt in enumerate(jet_array.pt) if pt in inc_jet_pt}

        cluster_ends = []
        # Walk the history backwards; the entry after an inclusive-jet index is a placeholder
        for cluster_step in reversed(range(len(jet_history) - 1)):
            if jet_history[cluster_step] in inc_indices:
                next_idx = jet_history[cluster_step + 1]
                # Indices 0..n_consts-1 are constituents, so a pseudojet index is >= n_consts
                if next_idx >= n_consts:
                    cluster_ends.append(cluster_step + 1)
                    # Shift every index larger than the placeholder down by one
                    jet_history = [x if x <= next_idx else x - 1 for x in jet_history]

        # Give each inclusive jet the same ending placeholder: the length of .jets()
        for cluster_end in cluster_ends:
            jet_history[cluster_end] = length_array

        # Cut the history into one piece per placeholder
        jet_histories = []
        current_hist = []
        for x in jet_history:
            current_hist.append(x)
            if x == length_array:
                jet_histories.append(current_hist)
                current_hist = []
        corrected_histories.append(jet_histories)

    return ak.Array(corrected_histories)

### Load the VQ-VAE model for tokenisation

In [None]:
from gabbro.models.vqvae import VQVAELightning

# Checkpoint from a tokenisation (VQ-VAE) training; its config.yaml lives in the same directory
ckpt_path = "../checkpoints/vqvae_12288_tokens/last.ckpt"
cfg = OmegaConf.load(Path(ckpt_path).parent / "config.yaml")
pp_dict = OmegaConf.to_container(cfg.data.dataset_kwargs_common.feature_dict)

# Keys of a feature's preprocessing settings that describe selection cuts, not transforms
CUT_CRITERIA = ("larger_than", "smaller_than")

# Cut thresholds per feature (None where a feature defines no such cut)
pp_dict_cuts = {
    feat_name: {criterion: feat_settings.get(criterion) for criterion in CUT_CRITERIA}
    for feat_name, feat_settings in pp_dict.items()
}

# Remaining (transform) settings per feature, with the cut criteria removed
pp_dict_transform = {
    feat_name: {key: value for key, value in feat_settings.items() if key not in CUT_CRITERIA}
    for feat_name, feat_settings in pp_dict.items()
}

# Hacky way to set up logging in Jupyter. basicConfig() does nothing when the root
# logger already has handlers, so the level is also set explicitly.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
logger.info("Setup complete")

vqvae_model = VQVAELightning.load_from_checkpoint(ckpt_path)
# Put the model in eval mode before use. This is not the cell's last expression, so
# the model repr is not displayed a second time after the print below.
vqvae_model.eval()

print("\nModel:")
print(vqvae_model)

### Process the dataset (tokenise jets, extract declustering sequences, and insert the tokens)

In [None]:
from gabbro.utils.arrays import ak_select

# Define the dataset directory
dataset_dir = "jetclass/"
subdirs = ["test_20M/", "train_100M/", "val_5M/"]

# How many files per folder & jets per file to load and process
num_files = 3
n_load = 10000

# How many jets to process at a time (e.g. 20%), to avoid running out of memory
batch_size = n_load // 5

# Register the vector behaviours with awkward once, before any Momentum4D arrays are used
vector.register_awkward()

# Dummy 4-momentum object for children of constituents
zero_momentum = ak.zip(
    {"px": 0, "py": 0, "pz": 0, "E": 0},
    with_name="Momentum4D",
    behavior=vector.backends.awkward.behavior,
)

# Codebook size from the VQ-VAE
codebook_size = np.float32(vqvae_model.model.vqlayer.num_codes)

# Only these branches are used below, so the rest of the tree is not read
p4_branches = ["part_px", "part_py", "part_pz", "part_energy"]

# Clustering algorithm: kt algorithm with R=0.8 and WTA pt recombination scheme
jet_def = fj.JetDefinition(fj.kt_algorithm, 0.8, fj.WTA_pt_scheme)

# Keep jets whose single inclusive jet has 500 < pT < 1000
pt_cuts = {"pt_lower": 500, "pt_upper": 1000}

###---------------------------------------------------------###
### Open one file at a time and extract particle 4-momentum ###
###---------------------------------------------------------###

print(f"Number of files to process: {num_files}")
print(f"Using folders {subdirs}\n")

# Go through each file at a time
for subdir in subdirs:
    print(f"Now processing files in: {dataset_dir+subdir}")
    # uproot can only read ROOT files; ignore anything else in the folder
    sorted_files = sorted(f for f in os.listdir(dataset_dir + subdir) if f.endswith(".root"))
    output_dir = Path(dataset_dir.replace("jetclass", "jetclass_tokenised") + subdir)

    # Create output directories
    output_dir.mkdir(parents=True, exist_ok=True)

    for filename in sorted_files[:num_files]:
        # Optional: Record time taken per file
        start = time.perf_counter()

        # Generate the name of the parquet file and skip files that were already processed
        filename_parquet = Path(filename).name.replace(".root", "_tokenised.parquet")
        filename_out = output_dir / filename_parquet
        if filename_out.exists():
            print(f"File {filename_parquet} already present, skipping file.")
            continue

        print(f"Using {n_load} jets from file {filename}")
        filepath = dataset_dir + subdir + filename

        # Read only the needed branches for the first n_load jets. The context manager
        # closes the file even if reading fails.
        with uproot.open(filepath) as root_file:
            jets = root_file["tree"].arrays(p4_branches, entry_stop=n_load)

        # Create 4-momentum vector
        p4 = ak.zip(
            {
                "px": jets["part_px"],
                "py": jets["part_py"],
                "pz": jets["part_pz"],
                "E": jets["part_energy"],
            },
            with_name="Momentum4D",
            behavior=vector.backends.awkward.behavior,
        )

        ###-----------------------------------------------###
        ### Truncate jets to 128 particles and apply cuts ###
        ###-----------------------------------------------###

        p4 = p4[:, :128]

        # For computing relative eta and phi
        p4_jet = ak.sum(p4, axis=1)

        unmasked_particles = ak.zip(
            {
                "part_pt": p4.pt,
                "part_etarel": p4.deltaeta(p4_jet),
                "part_phirel": p4.deltaphi(p4_jet),
                "mass": p4.mass,
            },
            with_name="Momentum4D",
        )

        # Apply preprocessing cuts (without applying transforms)
        mask = ak_select(unmasked_particles, pp_dict_cuts)
        masked_particles = p4[mask]

        ###------------------------------------------------------------###
        ### Cluster and filter out jets with more than 1 inclusive jet ###
        ###------------------------------------------------------------###

        print(f"Particles from {n_load} jets are being clustered with the following algorithm:\n{jet_def}")

        # At most n_load jets were read, so no further slicing is needed
        cluster = fj.ClusterSequence(masked_particles, jet_def)

        # Get jets and clustering information
        jets_out = cluster.inclusive_jets()
        num_particles = cluster.n_particles()
        jet_structure_array = cluster.jets()

        # Get histories and indices of jets between 500 and 1000 pt
        jet_hists, jet_indices = read_histories(cluster, pt_cuts["pt_lower"], pt_cuts["pt_upper"])

        # Filter the jet structure array to only keep single inclusive jets
        jets_filtered = jet_structure_array[jet_indices]
        num_particles_filtered = num_particles[jet_indices]
        inc_jets_filtered = jets_out[jet_indices]

        # Nothing to tokenise (and ak.concatenate([]) would fail) if no jet was selected
        n_selected = len(jets_filtered)
        if n_selected == 0:
            print(f"No jets in {filename} passed the selection, skipping file.")
            continue

        ###------------------###
        ### Jet tokenisation ###
        ###------------------###

        # Batch over the *selected* jets. Iterating over all clustered jets (jets_out)
        # produced empty trailing batches once jets had been filtered out.
        results = []
        for i in range(0, n_selected, batch_size):
            print(f"Tokenising current batch: {i} - {i+batch_size}")
            # Get the respective inclusive jet for calculating relative eta/phi
            p4_inc_jets = ak.firsts(inc_jets_filtered[i:i+batch_size])
            jets_batch = jets_filtered[i:i+batch_size]

            jets_ak = ak.zip(
                {
                    "part_pt": jets_batch.pt,
                    "part_etarel": jets_batch.deltaeta(p4_inc_jets),
                    "part_phirel": jets_batch.deltaphi(p4_inc_jets),
                },
                with_name="Momentum4D",
            )

            # Tokenise jets
            jets_tokenized = vqvae_model.tokenize_ak_array(
                ak_arr=jets_ak,
                pp_dict=pp_dict_transform,
                batch_size=512,
                pad_length=256,
            )

            results.append(jets_tokenized)

        jets_tokenised = ak.concatenate(results)

        ###----------------------------------------------###
        ### Extract declustering steps and insert tokens ###
        ###----------------------------------------------###

        results = []
        for i in range(0, n_selected, batch_size):
            print(f"Extracting decays and inserting tokens for current batch: {i} - {i+batch_size}")
            batch_jets = jets_filtered[i:i+batch_size]
            batch_hists = jet_hists[i:i+batch_size]
            batch_num_parts = num_particles_filtered[i:i+batch_size]
            batch_tokens = jets_tokenised[i:i+batch_size]

            # Returns the declustering sequence per jet, as triplets: [parent, left, right]
            ak_decays = return_decays(batch_jets, batch_hists, batch_num_parts, zero_momentum)

            # Convert jets from triplet structure to 1-dimensional arrays
            ak_decays = ak.flatten(ak_decays, axis=2)

            # Replace indices with tokens
            ak_tokens = tokenise_decays(ak_decays, batch_tokens, codebook_size)
            results.append(ak_tokens)

        ak_tokens = ak.concatenate(results)
        print(f"Successfully tokenised {len(ak_tokens)} jets.")

        # Release memory (all names below exist because n_selected > 0, so both loops ran)
        del jets, p4, p4_jet, unmasked_particles, mask, masked_particles, cluster, jets_out
        del num_particles, jet_structure_array, jet_hists, jet_indices, jets_filtered, jets_tokenised
        del inc_jets_filtered, p4_inc_jets, jets_batch, jets_ak, jets_tokenized
        del results, batch_jets, batch_hists, batch_num_parts, batch_tokens, ak_decays, num_particles_filtered

        ###-------------------------###
        ### Store as .parquet files ###
        ###-------------------------###

        # Record the time it took to process one file
        end = time.perf_counter()
        print(f"File: {filename} finished processing after {end - start:.4f} seconds")

        # Save the processed data to file
        print(f"Saving tokenised file to {filename_out}")
        ak.to_parquet(ak_tokens, filename_out)
        print("Saving completed.")

print("Done.")