# Imports

In [None]:
%load_ext autoreload
%autoreload 2

import os
import math
import copy

import matplotlib.pyplot as plt
import numpy as np
import awkward as ak
import uproot
import pandas as pd
import dask

import vector
import particle
import hepunits

import comet_ml
import zuko
import torch
from torch import nn, optim
import lightning as L
from lightning.pytorch import loggers as pl_loggers

from torch.utils.data import DataLoader

from memflow.dataset.data import RootData,ParquetData
from memflow.dataset.dataset import CombinedDataset
from memflow.ttH.ttH_dataclasses import ttHHardDataset, ttHRecoDataset
from memflow.models.transfer_flow_model import TransferFlow
from memflow.models.custom_flows import *
from memflow.callbacks.transfer_flow_callbacks import SamplingCallback, BiasCallback

# Hook the `vector` library into awkward arrays
vector.register_awkward()

# Restrict the process to one GPU of the node; edit the index to pick another card
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Query CUDA once and derive from it the accelerator name later given to the Trainer
cuda_available = torch.cuda.is_available()
print(f"Running on GPU : {cuda_available}")
if cuda_available:
    accelerator = 'gpu'
else:
    accelerator = 'cpu'
print(f"Accelerator : {accelerator}")

torch.set_float32_matmul_precision('medium')

# On GPU only: release the cached memory and print a short memory report
if cuda_available:
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary(device=None, abbreviated=True))

# Data

In [None]:
# Parquet input(s) for the hard-level events; the commented entries are
# alternative files kept for reference
hard_files = [
    '/cephfs/dice/users/sa21722/datasets/MEM_data/ttH/TF_v6/hard/2018/ttH/ttH_HToInvisible_M125.parquet',
    #'all_jets_fullRun2_ttHbb_forTraining_allyears_spanetprov_part1_validation.parquet',
    #'all_jets_fullRun2_ttHTobb_forTraining_2016_PreVFP_v3.parquet',
]

# NOTE(review): `N` presumably caps the number of events and `lazy` defers
# the reading -- confirm in ParquetData
data_hard = ParquetData(files=hard_files, lazy=True, N=100_000)

print(data_hard)

# Hard dataset

In [None]:
# Particle groups kept at hard level; the commented names are left unselected
hard_selection = [
    # 'higgs',
    # 'tops',
    # 'bottoms',
    # 'Ws',
    # 'Zs',
    'quarks',
    'neutrinos',
]

# NOTE(review): `build` / `fit` presumably build the tensors and fit the
# preprocessing -- confirm in ttHHardDataset
hard_dataset = ttHHardDataset(
    data=data_hard,
    selection=hard_selection,
    build=True,
    fit=True,
    coordinates='cylindrical',
    apply_preprocessing=True,
    apply_boost=False,
    dtype=torch.float32,
)
print(hard_dataset)

In [None]:
# Plot the selected hard-level objects twice, with raw=True and then
# raw=False, to compare the distributions before/after preprocessing
common_plot_args = {'selection': True}

print('Before preprocessing')
hard_dataset.plot(raw=True, **common_plot_args)
print('After preprocessing')
hard_dataset.plot(raw=False, **common_plot_args)

In [None]:
# Optional sanity check: make sure a batch can be drawn from the hard dataset.
# The combined (hard+reco) dataset built further down is what is used later.
hard_loader = DataLoader(hard_dataset, batch_size=32)
batch = next(iter(hard_loader))

# One line per selected particle type: name, data shape, mask shape
selection_names = hard_loader.dataset.selection
for name, values, mask in zip(selection_names, batch['data'], batch['mask']):
    print(name, values.shape, mask.shape)

# Reco dataset

In [None]:
# Parquet input(s) for the reco-level events
reco_files = [
    '/cephfs/dice/users/sa21722/datasets/MEM_data/ttH/TF_v6/reco/2018/ttH/ttH_HToInvisible_M125.parquet',
]

# Reuse the `N` of the hard-level data so both levels are configured alike
data_reco = ParquetData(files=reco_files, lazy=True, N=data_hard.N)

print(data_reco)

In [None]:
# Object types kept at reco level
reco_selection = ['jets', 'met']

# Same options as for the hard-level dataset (cylindrical coordinates,
# preprocessing applied, no boost)
reco_dataset = ttHRecoDataset(
    data=data_reco,
    selection=reco_selection,
    build=True,
    fit=True,
    coordinates='cylindrical',
    apply_preprocessing=True,
    apply_boost=False,
    dtype=torch.float32,
)
print(reco_dataset)

In [None]:
# Plot the selected reco-level objects with raw=True and then raw=False
# (log option enabled) to compare the distributions before/after preprocessing
common_plot_args = {'selection': True, 'log': True}

print('Before preprocessing')
reco_dataset.plot(raw=True, **common_plot_args)
print('After preprocessing')
reco_dataset.plot(raw=False, **common_plot_args)

In [None]:
# Optional sanity check as well: draw one batch from the reco dataset
reco_loader = DataLoader(reco_dataset, batch_size=32)
batch = next(iter(reco_loader))

# One line per selected object type: name, data shape, mask shape
selection_names = reco_loader.dataset.selection
for name, values, mask in zip(selection_names, batch['data'], batch['mask']):
    print(name, values.shape, mask.shape)

# Combined dataset

In [None]:
# Inspect the intersection branch and the metadata keys of both levels
# NOTE(review): presumably what CombinedDataset uses to pair hard and reco
# events -- confirm in CombinedDataset
print(reco_dataset.intersection_branch)
for level_dataset in (hard_dataset, reco_dataset):
    print(level_dataset.metadata.keys())

# Single dataset exposing the hard and reco levels together
combined_dataset = CombinedDataset(hard_dataset=hard_dataset, reco_dataset=reco_dataset)
print(combined_dataset)

In [None]:
# Draw one batch from the combined dataset and print, for each level, the
# name / data shape / mask shape of every selected particle type
combined_loader = DataLoader(combined_dataset, batch_size=256)
batch = next(iter(combined_loader))

levels = (
    ('Reco', 'reco', combined_loader.dataset.reco_dataset),
    ('Hard', 'hard', combined_loader.dataset.hard_dataset),
)
for title, key, level_dataset in levels:
    print(title)
    level_batch = batch[key]
    for name, values, mask in zip(level_dataset.selection, level_batch['data'], level_batch['mask']):
        print(name, values.shape, mask.shape)

In [None]:
# Deterministic train/validation split: the first `train_frac` of the events
# go to training and the remainder to validation. The split is positional
# rather than random so that it is reproducible.
train_frac = 0.8
n_events = len(combined_dataset)
sep = int(train_frac * n_events)
indices = torch.arange(n_events)
train_indices, valid_indices = indices[:sep], indices[sep:]

combined_dataset_train = torch.utils.data.Subset(combined_dataset, train_indices)
combined_dataset_valid = torch.utils.data.Subset(combined_dataset, valid_indices)
print(f'Dataset : training {len(combined_dataset_train)} / validation {len(combined_dataset_valid)}')

# Training batch size; validation uses larger, unshuffled batches
batch_size = 1024

combined_loader_train = DataLoader(combined_dataset_train, batch_size=batch_size, shuffle=True)
combined_loader_valid = DataLoader(combined_dataset_valid, batch_size=10000, shuffle=False)
print(f'Batching {len(combined_loader_train)} / Validation {len(combined_loader_valid)}')

In [13]:
# # Find some indices in the validation set with max number of jets
# # To use in the sampling to see a maximum number of jets
# count = ak.count(data_reco['jets'].pt,axis=1).to_numpy()
# mask_max = count == ak.max(count)
# mask_valid = np.full(len(count),fill_value=False)
# mask_valid[valid_indices] = True
# idx_max = np.where(
#     np.logical_and(mask_max,mask_valid)
# )[0]
# for i in idx_max:
#     prov = data_reco['jets'].prov[i]
#     unique, counts = np.unique(prov, return_counts=True)
#     print ('idx',i,', '.join([f'prov {u:.0f} x {c:.0f}' for u,c in zip(unique, counts)]))

# Transfer flow

In [None]:
# List the reco features that can be selected in `flow_input_features`.
# `flow_input_features` must have as many entries as this list (2 elements),
# but each entry does not have to select all the features of its element.
reco_features = combined_dataset.reco_dataset.input_features
print(reco_features)

In [None]:
# ----- Transfer flow configuration -----
# Shorthands for the two levels of the combined dataset
hard_level = combined_dataset.hard_dataset
reco_level = combined_dataset.reco_dataset

# Features used in the flows, one list per reco object type
# (they may differ from the features given to the transformer)
flow_input_features = [
    # ['pt','eta','phi'], # leptons
    ['pt', 'eta', 'phi'],  # jets
    ['pt', 'phi'],         # met
]

# Arguments passed to the Transformer pytorch class
transformer_args = {
    'nhead': 8,
    'num_encoder_layers': 8,
    'num_decoder_layers': 8,
    'dim_feedforward': 128,
    'activation': 'gelu',
}

# Arguments common to all flows
flow_common_args = {
    'bins': 16,
    'transforms': 5,
    'randperm': True,
    'passes': None,
    'hidden_features': [128, 128],
}

# Flow class used for each feature
flow_classes = {
    'pt': zuko.flows.NSF,
    'eta': UniformNSF,
    'phi': UniformNCSF,
}

# Extra arguments specific to some of the flow classes above
flow_specific_args = {
    'eta': {'bound': 1.},
    'phi': {'bound': 1.},
}

# ----- Model -----
model = TransferFlow(
    # General args
    dropout=0.,
    # Embedding arguments
    embed_dims=[32, 64],
    embed_act=nn.GELU,
    # Particle features, names, masks, and number for printouts and logging
    n_hard_particles_per_type=hard_level.number_particles_per_type,
    hard_particle_type_names=hard_level.selection,
    hard_input_features_per_type=hard_level.input_features,
    n_reco_particles_per_type=reco_level.number_particles_per_type,
    reco_particle_type_names=reco_level.selection,
    reco_input_features_per_type=reco_level.input_features,
    flow_input_features=flow_input_features,
    hard_mask_attn=None,
    reco_mask_attn=reco_level.attention_mask,
    # Transformer arguments
    onehot_encoding=False,  # add onehot encoded position vector to particles
    transformer_args=transformer_args,
    # Flow args
    flow_common_args=flow_common_args,
    flow_classes=flow_classes,
    flow_specific_args=flow_specific_args,
    flow_mode='global',  # 'global', 'type' or 'particle'
).cpu()

# ----- Quick check of the untrained model on one training batch -----
batch = next(iter(combined_loader_train))

# Forward pass
log_probs, mask, weights = model(batch)
mask = mask > 0
for label, tensor in (('log_probs', log_probs), ('mask', mask), ('weights', weights)):
    print(label, tensor, tensor.shape)

# Same evaluation entry point, called here with the 'test' tag
log_probs_tot = model.shared_eval(batch, 0, 'test')
print('tot log probs', log_probs_tot)

# Sampling: print the shape of each returned sample tensor
hard_batch, reco_batch = batch['hard'], batch['reco']
samples = model.sample(
    hard_batch['data'],
    hard_batch['mask'],
    reco_batch['data'],
    reco_batch['mask'],
    N=100,
)
print('samples')
for sample in samples:
    print('\t', sample.shape)

print(model)

In [16]:
# Disabled experiment, kept for reference:
# # Preprocess the dataset to include only the first 2 jets
# print(combined_dataset_valid.keys())
# print(type(combined_dataset_valid['reco']))
# print(combined_dataset_valid['reco'])
# # Limit to the first 2 jets in the 'data' key
# for i, tensor in enumerate(combined_dataset_valid['reco']['data']):
#     combined_dataset_valid['reco']['data'][i] = tensor[:2]  # Keep only the first 2 jets
# print(combined_dataset_valid['reco']['data'])
# MODIFY EVENTS IN CLASS WHEN SELECTING EVENTS (EG EVENTS THAT HAVE 1 B JET OR 5 JETS, ORDER IN BTAG OR PT)

# Callback making the bias plots within comet; also reused in the Trainer callbacks
bias = BiasCallback(
    dataset = combined_dataset_valid,               # dataset on which to evaluate bias
    preprocessing = combined_dataset.reco_dataset.preprocessing, # preprocessing pipeline to draw raw variables
    N_sample = 100,                                 # number of samples to draw
    frequency = 20,                                 # plotting frequency (epochs)
    raw = True,
    bins = 101,                                     # 1D/2D plot number of bins
    points = 20,                                    # number of points for the quantile
    log_scale = True,                               # log
    batch_size = 50000,                             # batch size to evaluate the dataset (internally makes a loader)
    N_batch = 1,                                    # stop after N batches (makes it faster)
    suffix = 'ttH',                                 # name for plots
    label_names = {                                 # makes nicer labels
        'pt' : 'p_T',
        # Raw strings: '\e' and '\p' are not valid escape sequences, so the
        # non-raw literals trigger a SyntaxWarning on recent Python versions
        'eta' : r'\eta',
        'phi' : r'\phi',
    },
)

# Standalone call (show=True) to check the plotting before training; the
# plots are of course bad at this stage, the model needs to be trained first.
# Follow the detected accelerator instead of forcing CUDA so that the cell
# also runs on a CPU-only machine.
figs = bias.make_bias_plots(
    model.to('cuda' if accelerator == 'gpu' else 'cpu'),
    show=True,
)

In [17]:
# Callback making the sampling plots for a few monitored events; also reused
# in the Trainer callbacks
sampling = SamplingCallback(
    dataset = combined_dataset_valid,           # dataset to check sampling
    preprocessing = combined_dataset.reco_dataset.preprocessing, # preprocessing pipeline
    idx_to_monitor = [0,2,],                    # idx of events in dataset to make plots with
    N_sample = 100000,                          # number of samples to draw
    frequency = 10,                             # plotting frequency (epochs)
    bins = 51,                                  # 1D/2D plot number of bins
    log_scale = True,                           # log
    label_names = {                             # makes nicer labels
        'pt' : 'p_T',
        # Raw strings: '\e' and '\p' are not valid escape sequences, so the
        # non-raw literals trigger a SyntaxWarning on recent Python versions
        'eta' : r'\eta',
        'phi' : r'\phi',
    },
)

# Standalone call (show=True) to check the plotting before training.
# Follow the detected accelerator instead of forcing CUDA so that the cell
# also runs on a CPU-only machine.
figs = sampling.make_sampling_plots(
    model.to('cuda' if accelerator == 'gpu' else 'cpu'),
    show=True,
)

In [None]:
##### Parameters #####
epochs = 40
# Number of optimizer steps needed to go once through the training set
steps_per_epoch_train = math.ceil(len(combined_dataset_train)/combined_loader_train.batch_size)

print (f'Training   : Batch size = {combined_loader_train.batch_size} => {steps_per_epoch_train} steps per epoch')

##### Optimizer #####
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model.set_optimizer(optimizer)

##### Scheduler #####
# Lower the learning rate when the monitored loss stops improving
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer = optimizer,
    mode='min',
    factor=0.1,
    patience=10,
    threshold=0.001,
    threshold_mode='rel',
    cooldown=0,
    min_lr=1e-7
)
model.set_scheduler_config(
    {
        'scheduler' : scheduler,
        # A OneCycleLR scheduler is stepped every batch, any other one every epoch
        'interval' : 'step' if isinstance(scheduler,optim.lr_scheduler.OneCycleLR) else 'epoch',
        'frequency' : 1,
        'monitor' : 'val/loss_tot',
        'strict' : True,
        'name' : 'scheduler',
    }
)

##### Callbacks #####
# `sampling` and `bias` are the plotting callbacks created in the cells above
callbacks = [
    L.pytorch.callbacks.LearningRateMonitor(logging_interval = 'epoch'),
    L.pytorch.callbacks.ModelSummary(max_depth=2),
    sampling,
    bias,
]

##### Logger #####
logger = pl_loggers.CometLogger(
    save_dir = '../comet_logs',
    project_name = 'mem-flow-ttH',
    experiment_name = 'combined',
    offline = False,
)
logger.log_graph(model)
logger.experiment.log_notebook(filename='transfer_flow.ipynb',overwrite=True)

##### Trainer #####
trainer = L.Trainer(
    min_epochs = 5,
    max_epochs = epochs,
    callbacks = callbacks,
    devices = 'auto',
    accelerator = accelerator,
    logger = logger,
    # Aim for ~100 logging points per epoch. The value must be a positive
    # integer: the previous true division produced a float (below 1 for short
    # epochs), which breaks the "step % log_every_n_steps == 0" logging check.
    log_every_n_steps = max(1, steps_per_epoch_train // 100),
)

##### Fit #####
trainer.fit(
    model = model,
    train_dataloaders = combined_loader_train,
    val_dataloaders = combined_loader_valid,
)

# Close the comet experiment once the training is over
logger.experiment.end()

In [None]:
# Redo the sampling plots with the trained model (show=True: standalone plotting).
# The model is placed on the detected accelerator rather than unconditionally
# on CUDA, so the cell also runs on a CPU-only machine.
figs = sampling.make_sampling_plots(
    model.to('cuda' if accelerator == 'gpu' else 'cpu'),
    show=True,
)

In [None]:
# Redo the bias plots with the trained model (show=True: standalone plotting).
# The model is placed on the detected accelerator rather than unconditionally
# on CUDA, so the cell also runs on a CPU-only machine.
figs = bias.make_bias_plots(
    model.to('cuda' if accelerator == 'gpu' else 'cpu'),
    show=True,
)