# Imports

In [None]:
%load_ext autoreload
%autoreload 2

import os
import math

import matplotlib.pyplot as plt
import numpy as np
import awkward as ak
import pandas as pd
import dask

import vector
import particle
import hepunits

import comet_ml
import zuko
import torch
from torch import nn, optim
import lightning as L
from lightning.pytorch import loggers as pl_loggers
from torch.utils.data import DataLoader, Subset
from torch.optim.lr_scheduler import SequentialLR, LambdaLR, CosineAnnealingLR

import multiprocessing
import uuid

from memflow.dataset.data import ParquetData
from memflow.dataset.dataset import CombinedDataset
from memflow.ttH.ttH_dataclasses import ttHHardDataset, ttHRecoDataset
from memflow.ttH.models.transfer_flow_callbacks import SamplingCallback, BiasCallback

from memflow.ttH.models.TransferCFM import StandardCFM as TransferCFM
from memflow.ttH.models.Transfusion import StandardCFM as Transfusion
from memflow.ttH.models.ParallelTransfusion import StandardCFM as ParallelTransfusion
from memflow.ttH.models.TransferCFM_original import StandardCFM as OriginalCFM
from models.callbacks import ModelCheckpoint

# Register the `vector` behaviors with awkward arrays.
vector.register_awkward()

# Dataloading worker count: every available core, capped at 16.
num_workers = min(multiprocessing.cpu_count(), 16)
print(f'Number of CPU workers for dataloading: {num_workers}')

# Pin the process to a single GPU of the node; edit the index to use another
# device. This has to happen before the first CUDA query below.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

print(f"Running on GPU : {torch.cuda.is_available()}")
if torch.cuda.is_available():
    accelerator = 'cuda'
else:
    accelerator = 'cpu'
print(f"Accelerator : {accelerator}")

# Trade some float32 matmul precision for speed.
torch.set_float32_matmul_precision('medium')

# On GPU, start from a clean allocator cache and report the memory state.
if accelerator == 'cuda':
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary(device=None, abbreviated=True))

# Data

In [None]:
# Hard-level ttH (H -> invisible, M125) sample, read from parquet.
data_hard = ParquetData(
    lazy=True,
    files=[
        '/cephfs/dice/users/sa21722/datasets/MEM_data/ttH/TF_v6/hard/2018/ttH/ttH_HToInvisible_M125.parquet',
    ],
    # Optional: N=int(1e5) (presumably limits the number of events read - confirm).
)

print(data_hard)

# Hard dataset

In [None]:
hard_dataset = ttHHardDataset(
    data=data_hard,
    # Object types kept at hard level. The other options tried in this
    # notebook ('higgs', 'tops', 'Ws', 'Zs') are currently switched off.
    selection=['bottoms', 'quarks', 'neutrinos'],
    coordinates='cylindrical',
    dtype=torch.float32,
    build=False,
    fit=True,
    apply_preprocessing=True,
    apply_boost=False,
)

# Sanity check: look at how the neutrino objects are stored in the dataset.
obj = hard_dataset.objects['neutrinos']
print("Type of neutrinos object:", type(obj))

# When the entry is a tuple, describe each of its elements in turn.
if isinstance(obj, tuple):
    print("Length of tuple:", len(obj))
    for i, element in enumerate(obj):
        print(f"Element {i} has type {type(element)} and shape/length:", end=" ")
        # Array-like elements (tensor / numpy array) report their shape;
        # anything else (e.g. a list of fields) is printed as-is.
        print(element.shape if hasattr(element, 'shape') else element)

In [None]:
# Distributions of the selected hard-level objects, raw first and then
# after the preprocessing has been applied.
print('Before preprocessing')
hard_dataset.plot(raw=True, selection=True)
print('After preprocessing')
hard_dataset.plot(raw=False, selection=True)

In [None]:
# Optional smoke test of the hard-level loading on its own; the actual
# training uses the combined (hard + reco) dataset built further down.
hard_loader = DataLoader(
    dataset=hard_dataset,
    batch_size=32,
    pin_memory=True,          # faster host-to-GPU transfer
    num_workers=num_workers,  # parallel loading
)
batch = next(iter(hard_loader))

# One line per selected object type: name, data shape, mask shape.
for sel, obj, mask in zip(hard_loader.dataset.selection, batch['data'], batch['mask']):
    print(sel, obj.shape, mask.shape)

# Reco dataset

In [None]:
# Reco-level counterpart of the hard-level sample, read from parquet.
data_reco = ParquetData(
    lazy=True,
    files=[
        '/cephfs/dice/users/sa21722/datasets/MEM_data/ttH/TF_v6/reco/2018/ttH/ttH_HToInvisible_M125.parquet',
    ],
    # Optional: N=data_hard.N (presumably matches the hard-level event count - confirm).
)

print(data_reco)

Have a look at the minimum values of the jet and MET pT in the raw dataset. They indicate where the cutoff in the signal region (SR) lies, and hence what value to give `'pt':lowercutshift()` in the preprocessing.

In [None]:
# Inspect the pt range of the raw reco-level inputs (before any preprocessing).

# MET pt: read the branch from the raw dataset and print its min and max
raw_met_pt = data_reco['InputMet_pt']
print("Raw Minimum MET pt:", ak.min(raw_met_pt))
print("Raw Maximum MET pt:", ak.max(raw_met_pt))

# Jet pt: same check on the jet branch.
# The labels previously said "MET pt" here (copy-paste slip), which made the
# two pairs of printed numbers indistinguishable.
raw_jet_pt = data_reco['cleanedJet_pt']
print("Raw Minimum Jet pt:", ak.min(raw_jet_pt))
print("Raw Maximum Jet pt:", ak.max(raw_jet_pt))

In [None]:
# Reco-level dataset, configured with the same options as the hard-level one.
reco_dataset = ttHRecoDataset(
    data=data_reco,
    selection=['jets', 'met'],  # reco object types to keep
    coordinates='cylindrical',
    dtype=torch.float32,
    build=False,
    fit=True,
    apply_preprocessing=True,
    apply_boost=False,
)
print(reco_dataset)

In [None]:
# Distributions of the selected reco-level objects (log option enabled),
# raw first and then after the preprocessing has been applied.
print('Before preprocessing')
reco_dataset.plot(raw=True, log=True, selection=True)
print('After preprocessing')
reco_dataset.plot(raw=False, log=True, selection=True)

In [None]:
# Optional smoke test of the reco-level loading on its own.
reco_loader = DataLoader(
    dataset=reco_dataset,
    batch_size=32,
    pin_memory=True,
    num_workers=num_workers,
)
batch = next(iter(reco_loader))

# One line per selected object type: name, data shape, mask shape.
for sel, obj, mask in zip(reco_loader.dataset.selection, batch['data'], batch['mask']):
    print(sel, obj.shape, mask.shape)

# Combined dataset

In [None]:
# Show the intersection branch reported by the reco dataset and the metadata
# keys available on each side before combining them.
print(f'Intersection Branch: {reco_dataset.intersection_branch}')
print(f'Hard Datset keys: {hard_dataset.metadata.keys()}')
print(f'Reco Datset keys: {reco_dataset.metadata.keys()}')

# Pair the hard-level and reco-level datasets into a single dataset.
combined_dataset = CombinedDataset(reco_dataset=reco_dataset, hard_dataset=hard_dataset)
print(combined_dataset)

In [None]:
# Draw a single batch from the combined dataset and list, for each level,
# the selected object types with their data and mask shapes.
combined_loader = DataLoader(dataset=combined_dataset, batch_size=256)
batch = next(iter(combined_loader))

print('Reco')
for sel, obj, mask in zip(
    combined_loader.dataset.reco_dataset.selection,
    batch['reco']['data'],
    batch['reco']['mask'],
):
    print(sel, obj.shape, mask.shape)

print('Hard')
for sel, obj, mask in zip(
    combined_loader.dataset.hard_dataset.selection,
    batch['hard']['data'],
    batch['hard']['mask'],
):
    print(sel, obj.shape, mask.shape)

In [None]:
# Training / validation split.
# The split is positional rather than random so that it is reproducible:
# the first `train_frac` of the entries train, the remainder validates.
train_frac = 0.8
indices = torch.arange(len(combined_dataset))
sep = int(train_frac * len(combined_dataset))
train_indices, valid_indices = indices[:sep], indices[sep:]

combined_dataset_train = Subset(combined_dataset, train_indices)
combined_dataset_valid = Subset(combined_dataset, valid_indices)
print(f'Dataset : training {len(combined_dataset_train)} / validation {len(combined_dataset_valid)}')

batch_size = 1024

# Options shared by the training and validation loaders.
loader_kwargs = dict(num_workers=num_workers, pin_memory=True)

# Training batches are shuffled; validation keeps the dataset order and
# uses a larger fixed batch size.
combined_loader_train = DataLoader(
    combined_dataset_train,
    batch_size=batch_size,
    shuffle=True,
    **loader_kwargs,
)
combined_loader_valid = DataLoader(
    combined_dataset_valid,
    batch_size=10000,
    shuffle=False,
    **loader_kwargs,
)
print(f'Batching {len(combined_loader_train)} / Validation {len(combined_loader_valid)}')

# TransferCFM

In [None]:
# Shapes of the data and mask tensors in the current `batch`, level by level.
reco_batch, hard_batch = batch['reco'], batch['hard']

for i, (obj, mask) in enumerate(zip(reco_batch['data'], reco_batch['mask'])):
    print(f"Reco Object {i}: Shape = {obj.shape}, Mask Shape = {mask.shape}")

for i, (obj, mask) in enumerate(zip(hard_batch['data'], hard_batch['mask'])):
    print(f"Hard Object {i}: Shape = {obj.shape}, Mask Shape = {mask.shape}")

In [None]:
# Input features declared by each side of the combined dataset (hard, then reco).
for part in (combined_dataset.hard_dataset, combined_dataset.reco_dataset):
    print(part.input_features)

In [None]:
model = ParallelTransfusion(
    # --- Embedding ---
    embed_dims=[32, 64],
    embed_act=nn.GELU,
    dropout=0.0,

    # --- Hard-level description ---
    hard_particle_type_names=combined_dataset.hard_dataset.selection,
    n_hard_particles_per_type=combined_dataset.hard_dataset.number_particles_per_type,
    # All the features available, as specified in the dataclass
    hard_input_features_per_type=combined_dataset.hard_dataset.input_features,
    hard_mask_attn=None,

    # --- Reco-level description ---
    reco_particle_type_names=combined_dataset.reco_dataset.selection,
    n_reco_particles_per_type=combined_dataset.reco_dataset.number_particles_per_type,
    reco_input_features_per_type=combined_dataset.reco_dataset.input_features,
    reco_mask_attn=reco_dataset.attention_mask,

    # --- Bridging: only a subset of the features, one list per reco type ---
    flow_input_features=[
        ["pt", "eta", "phi", "mass"],  # reco type 0 (e.g. jets)
        ["pt", "phi"],                 # reco type 1 (e.g. MET)
        # append one list per additional reco type
    ],

    # --- Transformer and CFM network sizes ---
    transformer_args=dict(
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=8,
        dim_feedforward=256,
        activation="relu",
    ),
    cfm_args=dict(
        dim_hidden=512,
        num_layers=8,
        activation=nn.SiLU,
    ),
    sigma=0.1,
    # Per the original notes, the OT / Schrodinger-bridge CFM variants take
    # extra options (ot_reg=0.1, ot_method='exact', normalize_cost=True for
    # the sinkhorn method); they are not passed here.
)

# Quick test on one training batch: print the shapes, then evaluate the loss once.
batch = next(iter(combined_loader_train))

for i, (obj, mask) in enumerate(zip(batch["hard"]["data"], batch["hard"]["mask"])):
    print(f"hard_data[{i}] shape = {obj.shape}, mask shape = {mask.shape}")

for i, (obj, mask) in enumerate(zip(batch["reco"]["data"], batch["reco"]["mask"])):
    print(f"reco_data[{i}] shape = {obj.shape}, mask shape = {mask.shape}")

loss = model.shared_eval(batch, 1, 'train')
print("Initial CFM loss:", loss.item())

print(model)

In [None]:
# Container type and length of the per-level data in the current `batch`
# (presumably one entry per selected object type - confirm).
for level in ('hard', 'reco'):
    print(type(batch[level]['data']), len(batch[level]['data']))

In [None]:
# Callbacks to make plots within comet
#
# NOTE: the LaTeX-style labels are written as raw strings. '\e' and '\p' are
# not valid escape sequences, so the non-raw form only works by accident and
# emits a DeprecationWarning (a SyntaxWarning from Python 3.12 on). The string
# values themselves are unchanged.
bias = BiasCallback(
    dataset = combined_dataset_valid,               # dataset on which to evaluate bias
    preprocessing = combined_dataset.reco_dataset.preprocessing, # preprocessing pipeline to draw raw variables
    N_sample = 100,                                 # number of samples to draw
    steps = 20,                                     # Number of bridging steps
    store_trajectories = False,                     # To save trajectories plots
    frequency = 501,                                # plotting frequency (epochs)
    raw = True,
    bins = 101,                                     # 1D/2D plot number of bins
    points = 20,                                    # Number of points for the quantile
    log_scale = True,                               # log
    batch_size = 1000,                              # Batch size to evaluate the dataset (internally makes a loader)
    N_batch = 1,                                    # Stop after N batches (makes it faster)
    suffix = 'ttH',                                 # name for plots
    label_names = {                                 # makes nicer labels
        'pt' : 'p_T',
        'eta' : r'\eta',
        'phi' : r'\phi',
    },
)

sampling = SamplingCallback(
    dataset = combined_dataset_valid,               # dataset to check sampling
    preprocessing = combined_dataset.reco_dataset.preprocessing, # preprocessing pipeline
    idx_to_monitor = [1,2,3,4,5,6],                 # idx of events in dataset to make plots with
    N_sample = 1000,                                # number of samples to draw
    steps = 20,                                     # Number of bridging steps
    store_trajectories = False,                     # To save trajectories plots
    frequency = 100,                                # plotting frequency (epochs)
    bins = 51,                                      # 1D/2D plot number of bins
    log_scale = True,                               # log
    label_names = {                                 # makes nicer labels
        'pt' : 'p_T',
        'eta' : r'\eta',
        'phi' : r'\phi',
    },
)

In [None]:
# Training configuration: optimizer, LR schedule, Comet logger, callbacks,
# trainer and the fit itself.
epochs = 501
steps_per_epoch_train = math.ceil(len(combined_dataset_train)/combined_loader_train.batch_size)

# Optimizer + scheduler
optimizer = optim.Adam(model.parameters(), lr=1e-4) # was lr=1e-5
model.set_optimizer(optimizer)

# Alternative tried previously: ReduceLROnPlateau(optimizer, mode='min',
# factor=0.1, patience=10, threshold=0.001, threshold_mode='rel',
# cooldown=0, min_lr=1e-7).

# Cosine decay of the learning rate over the whole training
# (eta_min is left at its default; 1e-6 was also considered).
scheduler = CosineAnnealingLR(
    optimizer=optimizer,
    T_max=epochs,  # Num epochs before lr is reduced to min val
)

model.set_scheduler_config({
    'scheduler': scheduler,
    'interval': 'epoch',
    'frequency': 1,
    'monitor': 'val_loss',
    'strict': True,
    'name': 'scheduler',
})

# Logger + Trainer
logger = pl_loggers.CometLogger(
    save_dir='../comet_logs',
    project_name='mem-flow-ttH',
    experiment_name='test',
    offline=False,
    # experiment_key="4f3d2b1e843d489ea4ebd17ce92d9035", # Append to existing experiment on Comet
    # resume = True
)

##### Callbacks #####
callbacks = [
    L.pytorch.callbacks.LearningRateMonitor(logging_interval = 'epoch'),
    L.pytorch.callbacks.ModelSummary(max_depth=2),
    sampling,
    bias,
    ModelCheckpoint(save_every_n_epochs=10, save_dir="trained_model_checkpoints/test"),
]

trainer = L.Trainer(
    min_epochs=5,
    max_epochs=epochs,
    callbacks=callbacks,
    devices='auto',
    accelerator='auto',
    logger=logger,
    # Aim for ~100 log points per epoch. The integer division yields 0 when an
    # epoch has fewer than 100 batches, which would switch off (or break)
    # step logging, so clamp the interval to at least one step.
    log_every_n_steps=max(1, steps_per_epoch_train // 100),
)

# Fit
trainer.fit(
    model=model,
    train_dataloaders=combined_loader_train,
    val_dataloaders=combined_loader_valid,
    # ckpt_path="TransferCFM_checkpoints/model_epoch_10.ckpt" # Use to resume training from a checkpoint
)

# Close the Comet experiment once training has finished.
logger.experiment.end()