In [1]:
# IPython autoreload in mode 2: imported modules are reloaded before each cell
# executes, so edits to library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import copy
import math
import os
import time

import yaml
import torch
from torch import nn, optim
import numpy as np
import awkward as ak
import uproot
import pandas as pd
import dask
import vector
import particle
import hepunits

import zuko
import torch
from torch import nn, optim
import lightning as L
from lightning.pytorch import loggers as pl_loggers

from torch.utils.data import DataLoader

from memflow.dataset.data import RootData,ParquetData
from memflow.dataset.dataset import MultiDataset
from memflow.callbacks.clustering_callbacks import *
from memflow.models.custom_flows import *
from memflow.models.transformer_autoencoder import TAE

from matplotlib import pyplot as plt
# Raise the threshold at which matplotlib warns about too many open figures.
plt.rcParams.update({'figure.max_open_warning': 100})

# Register the vector behaviors with awkward.
vector.register_awkward()

# Pick the Lightning accelerator from CUDA availability.
cuda_available = torch.cuda.is_available()
print (f"Running on GPU : {cuda_available}")
accelerator = 'gpu' if cuda_available else 'cpu'
print (f"Accelerator : {accelerator}")

# Trade some float32 matmul precision for speed.
torch.set_float32_matmul_precision('medium')

if cuda_available:
    # Release cached blocks, then report the allocator state of the current device.
    torch.cuda.empty_cache()
    print (torch.cuda.memory_summary(device=None, abbreviated=True))

Running on GPU : True
Accelerator : gpu
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Requested memory      |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| GPU reserved memory   |      0 B   |      0 B   |      0 B   |      0 B   |
|-----------------------

# Dataset

In [21]:
from memflow.HH.HH import *
from memflow.HH.DY import *
from memflow.HH.ttbar import *
from memflow.HH.ZZ import *
from memflow.HH.ZH import *
from memflow.HH.ST import *

# Keyword arguments passed unchanged to every per-process dataset class below.
args = {
    'data' : None,
    'selection' : ['muons', 'electrons', 'met', 'jets'],
    'coordinates' : 'cylindrical',
    'topology' : 'resolved',
    'apply_boost' : False,
    'apply_preprocessing' : True,
    'default_features' : {
        'pt': 0.,
        'eta': 0.,
        'phi': 0.,
        'mass': 0.,
        'btag' : 0.,
        'btagged': None,
        'pdgId' : 0.,
        'charge' : 0.,
    },
    'build' : False,
    'fit' : True,
    'build_dir' : '/nfs/scratch/fynu/fbury/MEMFlow_data/transfer_flow_v7',
    'dtype' : torch.float32,
}

# Process label -> dataset class. Insertion order is the order in which the
# datasets are built and handed to MultiDataset.
classes = {
    'HH' : HHbbWWDoubleLeptonRecoDataset,
    'TT' : TTDoubleLeptonRecoDataset,
    # 'DY' : DYDoubleLeptonRecoDataset,   (currently disabled)
    'ZZ' : ZZDoubleLeptonRecoDataset,
    'ZH' : ZHDoubleLeptonRecoDataset,
    'STminus' : STMinusDoubleLeptonRecoDataset,
    'STplus' : STPlusDoubleLeptonRecoDataset,
}

# One dataset instance per process, all sharing the same configuration.
datasets = {
    label : dataset_cls(**args)
    for label, dataset_cls in classes.items()
}

# Report the number of entries per process.
for suffix, dataset in datasets.items():
    print (suffix, len(dataset))

# Combine the per-process datasets into a single dataset and report its total size.
multi_dataset = MultiDataset(datasets)
print (len(multi_dataset))

Loading objects from /nfs/scratch/fynu/fbury/MEMFlow_data/transfer_flow_v7/hh_reco
Saving preprocessing to /nfs/scratch/fynu/fbury/MEMFlow_data/transfer_flow_v7/hh_reco
Will overwrite what is in output directory /nfs/scratch/fynu/fbury/MEMFlow_data/transfer_flow_v7/hh_reco/preprocessing
Preprocessing saved in /nfs/scratch/fynu/fbury/MEMFlow_data/transfer_flow_v7/hh_reco/preprocessing
Loading objects from /nfs/scratch/fynu/fbury/MEMFlow_data/transfer_flow_v7/tt_reco
Saving preprocessing to /nfs/scratch/fynu/fbury/MEMFlow_data/transfer_flow_v7/tt_reco
Will overwrite what is in output directory /nfs/scratch/fynu/fbury/MEMFlow_data/transfer_flow_v7/tt_reco/preprocessing
Preprocessing saved in /nfs/scratch/fynu/fbury/MEMFlow_data/transfer_flow_v7/tt_reco/preprocessing
Loading objects from /nfs/scratch/fynu/fbury/MEMFlow_data/transfer_flow_v7/zz_reco
Saving preprocessing to /nfs/scratch/fynu/fbury/MEMFlow_data/transfer_flow_v7/zz_reco
Will overwrite what is in output directory /nfs/scratch/f

In [3]:
# `datasets` is a dict, so iterate over its values (the dataset objects), not its keys.
# for dataset in datasets.values():
#     dataset.plot(raw=True,log=True)

In [4]:
# `datasets` is a dict, so iterate over its values (the dataset objects), not its keys.
# for dataset in datasets.values():
#     dataset.plot(raw=False,log=True)

In [22]:
# Random train/validation split of the combined dataset.
# The permutation is drawn from a dedicated, seeded generator so that the same
# events end up in the same split after every kernel restart (an unseeded
# permutation makes results and saved latents impossible to compare between runs).
split_seed = 42
train_frac = 0.9
n_events = len(multi_dataset)
indices = torch.randperm(
    n_events,
    generator = torch.Generator().manual_seed(split_seed),
)
sep = int(train_frac*n_events)
train_indices = indices[:sep]
valid_indices = indices[sep:]

dataset_train = torch.utils.data.Subset(multi_dataset,train_indices)
dataset_valid = torch.utils.data.Subset(multi_dataset,valid_indices)
print (f'Dataset : training {len(dataset_train)} / validation {len(dataset_valid)}')

batch_size = 1024

# Training batches are reshuffled every epoch.
loader_train = DataLoader(
    dataset_train,
    batch_size = batch_size,
    shuffle = True,
)
# Validation uses larger batches (no gradients are needed) and a fixed order.
loader_valid = DataLoader(
    dataset_valid,
    batch_size = 10000,
    shuffle = False,
)
print (f'Batching {len(loader_train)} / Validation {len(loader_valid)}')

Dataset : training 571715 / validation 63524
Batching 559 / Validation 7


# Training

In [27]:
# Every per-process dataset is built from the same arguments; the first one is
# used as the reference for feature layout, particle counts and attention mask.
reference_dataset = next(iter(datasets.values()))

model = TAE(
    dim_input = len(reference_dataset.input_features[0]),
    dim_embeds = [32,64,128,256],
    dim_latents = [128,64,32],
    nhead = 8,
    expansion_factor = 4,
    activation = nn.GELU,
    num_encoding_layers = 12,
    num_decoding_layers = 1,
    decoder_window = 1,
    max_seq_len = sum(reference_dataset.number_particles_per_type),
    reco_mask_attn = reference_dataset.attention_mask,
    dropout = 0.,
    process_names = multi_dataset.names,
)

# Smoke test: push one training batch through the model and compare shapes.
batch = next(iter(loader_train))

# The per-particle-type tensors are concatenated along dim 1.
x_init = torch.cat(batch['data'],dim=1)
mask = torch.cat(batch['mask'],dim=1)
x_reco = model(x_init,mask)
print (x_init.shape,x_reco.shape)

n_parameters = sum(p.numel() for p in model.parameters())
print (f'Number of parameters {n_parameters}')
print (model)


Decoder mask
tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False,  True, False,  True,  True,  True,  True,  True,  True,  True],
        [False,  True,  True, False,  True,  True,  True,  True,  True,  True],
        [False,  True,  True,  True, False,  True,  True,  True,  True,  True],
        [False,  True,  True,  True,  True, False,  True,  True,  True,  True],
        [False,  True,  True,  True,  True,  True, False,  True,  True,  True],
        [False,  True,  True,  True,  True,  True,  True, False,  True,  True],
        [False,  True,  True,  True,  True,  True,  True,  True, False,  True],
        [False,  True,  True,  True,  True,  True,  True,  True,  True, False]])
torch.Size([1024, 9, 7]) torch.Size([1024, 9, 7])
Number of parameters 10181671
TAE(
  (_loss_function): MSELoss()
  (embedding_encoder): MLP(
    (layers): Sequential(
      (0): Linear

In [28]:
def _make_reconstruction_callback(generative):
    """Build a ReconstructionCallback on the validation set.

    The two plotting callbacks differ only in their `generative` flag; every
    other option is read from the first per-process dataset.
    """
    reference = list(datasets.values())[0]
    return ReconstructionCallback(
        dataset = dataset_valid,
        preprocessing = reference._preprocessing,
        names = reference.selection,
        features = reference.input_features,
        number_particles_per_type = reference.number_particles_per_type,
        generative = generative,
        frequency = 10,
        bins = 51,
        log_scale = True,
        batch_size = 5000,
    )

plot_callback_reco = _make_reconstruction_callback(False)
plot_callback_gen = _make_reconstruction_callback(True)
# Optional manual check before training:
# figs = plot_callback_reco.make_plots(model.cuda(),show=True)

In [None]:
##### Parameters #####
epochs = 500
steps_per_epoch_train = math.ceil(len(dataset_train)/loader_train.batch_size)

print (f'Training   : Batch size = {loader_train.batch_size} => {steps_per_epoch_train} steps per epoch')
##### Optimizer #####
optimizer = optim.RAdam(model.parameters(), lr=1e-4)
model.set_optimizer(optimizer)

##### Scheduler #####
# Divide the learning rate by 10 when the monitored loss has not improved by
# at least 0.1% (relative) over 10 epochs, down to a floor of 1e-7.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer = optimizer,
    mode='min',
    factor=0.1,
    patience=10,
    threshold=0.001,
    threshold_mode='rel',
    cooldown=0,
    min_lr=1e-7
)
model.set_scheduler_config(
    {
        'scheduler' : scheduler,
        # OneCycleLR has to be stepped every batch, any other scheduler once per epoch
        'interval' : 'step' if isinstance(scheduler,optim.lr_scheduler.OneCycleLR) else 'epoch',
        'frequency' : 1,
        'monitor' : 'val/loss_tot',
        'strict' : True,
        'name' : 'scheduler',
    }
)

##### Callbacks #####
callbacks = [
    L.pytorch.callbacks.LearningRateMonitor(logging_interval = 'epoch'),
    L.pytorch.callbacks.ModelSummary(max_depth=2),
    plot_callback_reco,
    plot_callback_gen,
]

##### Logger #####
logger = pl_loggers.CometLogger(
    save_dir = '../comet_logs',
    project_name = 'TAE',
    experiment_name = 'AE',
    offline = False,
)
logger.log_graph(model)
# logger.log_hyperparams()
# logger.experiment.log_code(folder='../src/')
# `__session__` may be missing from the globals depending on how the notebook
# is run; skip the upload instead of failing with a KeyError before training.
notebook_path = globals().get('__session__')
if notebook_path is not None:
    logger.experiment.log_notebook(filename=notebook_path,overwrite=True)
else:
    print ('Notebook path (__session__) not available : notebook not uploaded to Comet')

##### Trainer #####
trainer = L.Trainer(
    min_epochs = 5,
    max_epochs = epochs,
    callbacks = callbacks,
    devices = 'auto',
    accelerator = accelerator,
    logger = logger,
    # The Trainer expects an integer number of steps: aim for ~100 log points
    # per epoch, and never go below 1 for very short epochs.
    log_every_n_steps = max(1, steps_per_epoch_train // 100),
)
##### Fit #####
trainer.fit(
    model = model,
    train_dataloaders = loader_train,
    val_dataloaders = loader_valid,
)

CometLogger will be initialized in online mode


Training   : Batch size = 1024 => 559 steps per epoch


[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/florianbury/tae/94eadd7578c247c289a15f0ad21cfea8

/home/ucl/cp3/fbury/scratch/anaconda3/envs/mem-flow/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/ucl/cp3/fbury/scratch/anaconda3/envs/mem-flow/ ...
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                      | Type               | Params | Mode 
--------------------------------------------------------------------------
0  | _loss_function

Sanity Checking: |                                                                            | 0/? [00:00<?, …

Training: |                                                                                   | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

Validation: |                                                                                 | 0/? [00:00<?, …

In [None]:
# Reconstruction plots on the validation set. The model is moved to the GPU
# only when one is present, so the cell also runs on a CPU-only machine.
figs = plot_callback_reco.make_plots(model.to('cuda' if torch.cuda.is_available() else 'cpu'),show=True)

In [None]:
# Plots from the callback configured with generative=True. The model is moved
# to the GPU only when one is present, so the cell also runs on a CPU-only machine.
figs = plot_callback_gen.make_plots(model.to('cuda' if torch.cuda.is_available() else 'cpu'),show=True)

In [None]:
# Serialize the whole model object (not only its state_dict) to a file in the
# current working directory.
torch.save(model,'TAE_32.pt')

# Encode latent space

In [None]:
# Reload the full pickled model onto the CPU.
# weights_only=False is needed to unpickle a whole module on PyTorch versions
# where torch.load defaults to weights_only=True. Unpickling can execute
# arbitrary code, so only load checkpoints you produced yourself.
model = torch.load('TAE_32.pt',map_location="cpu",weights_only=False)

In [None]:
# Print the module structure of the model loaded in the previous cell.
print (model)

In [None]:
def encode(model,loader,device='cpu'):
    """Encode every batch of `loader` and collect the results on the CPU.

    Parameters
    ----------
    model : module exposing `encode(x, mask)`; it is moved to `device`.
    loader : iterable of batches. Each batch is a mapping with the keys
        'data' and 'mask' (sequences of tensors, concatenated along dim 1)
        and 'process' (a tensor, concatenated along dim 0 across batches).
    device : device on which the encoder is evaluated.

    Returns
    -------
    (latent, x_init, mask, process), each concatenated over all batches along
    dim 0. latent, x_init and mask are returned on the CPU; process is left on
    whatever device the loader produced it on.

    Notes
    -----
    The model is switched to eval mode while encoding, so stochastic layers
    such as dropout do not perturb the latents; its previous training flag is
    restored afterwards. `tqdm` is expected to be available in the notebook
    namespace (NOTE(review): it is not imported explicitly in this notebook).
    """
    x_init = []
    mask   = []
    latent = []
    process = []

    model = model.to(device)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in tqdm(loader,position=0,leave=True):
                xi = torch.cat(batch['data'],dim=1).to(device)
                m = torch.cat(batch['mask'],dim=1).to(device)
                process.append(batch['process'])
                # Move everything that is accumulated back to the CPU: keeping
                # the inputs of the whole dataset on the GPU exhausts its memory
                # and would return tensors living on different devices.
                latent.append(model.encode(xi,m).cpu())
                x_init.append(xi.cpu())
                mask.append(m.cpu())
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    latent = torch.cat(latent,dim=0)
    x_init = torch.cat(x_init,dim=0)
    mask   = torch.cat(mask,dim=0)
    process= torch.cat(process,dim=0)

    return latent,x_init,mask,process

# Encode both splits on the CPU; each call returns (latent, inputs, mask, process).
encoded_train = encode(model,loader_train,'cpu')
encoded_valid = encode(model,loader_valid,'cpu')
latent_train,x_train,mask_train,process_train = encoded_train
latent_valid,x_valid,mask_valid,process_valid = encoded_valid

n_latent_train = latent_train.size(0)
n_latent_valid = latent_valid.size(0)
print (f'Training {n_latent_train} / Validation {n_latent_valid}')
print ('Shape :',latent_train.shape)

In [None]:
outdir = '/nfs/scratch/fynu/fbury/MEMFlow_data/clustering'
# Create the output directory if needed, so the first save does not fail.
os.makedirs(outdir,exist_ok=True)

# File name (inside outdir) -> tensor to store.
tensors_to_save = {
    'latent_train'  : latent_train,
    'x_train'       : x_train,
    'mask_train'    : mask_train,
    'process_train' : process_train,
    'latent_valid'  : latent_valid,
    'x_valid'       : x_valid,
    'mask_valid'    : mask_valid,
    'process_valid' : process_valid,
}
for name,tensor in tensors_to_save.items():
    torch.save(tensor,os.path.join(outdir,name))

# Encoding-decoding

In [None]:
# Reload the trained autoencoder onto the CPU (a later cell moves it to the compute device).
# NOTE(review): this cell used to read 'TAE.pt', a file this notebook never
# writes; it now reads 'TAE_32.pt', the checkpoint the notebook saves. Confirm
# that this is the intended model.
# weights_only=False is needed to unpickle a whole module on PyTorch versions
# where torch.load defaults to weights_only=True; only load trusted files.
model = torch.load('TAE_32.pt',map_location="cpu",weights_only=False)

In [None]:
outdir = '/nfs/scratch/fynu/fbury/MEMFlow_data/clustering'

def _load_on_cpu(name):
    """Load the object stored under `name` in `outdir`, mapping it to the CPU."""
    return torch.load(os.path.join(outdir,name),map_location=torch.device('cpu'))

# Reload the stored latents together with the inputs and masks they were computed from.
latent_train = _load_on_cpu('latent_train')
x_train      = _load_on_cpu('x_train')
mask_train   = _load_on_cpu('mask_train')
latent_valid = _load_on_cpu('latent_valid')
x_valid      = _load_on_cpu('x_valid')
mask_valid   = _load_on_cpu('mask_valid')

print (f'Training {latent_train.shape}, Validation {latent_valid.shape}')

In [None]:
# Compare three paths on the first N validation events:
#   forward pass, explicit encode, and decode of the encoded latent.
# Run on the GPU when one is present, otherwise fall back to the CPU.
eval_device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(eval_device)
N = 1
model.eval()
x_tmp = x_valid[:N].cpu()
m_tmp = mask_valid[:N].cpu()
with torch.no_grad():
    # Transfer the inputs once and reuse them for all three calls.
    x_dev = x_tmp.to(eval_device)
    m_dev = m_tmp.to(eval_device)
    print ('Forward\n')
    reco_tmp   = model(x_dev,m_dev).cpu()
    print ('Encode\n')
    latent_tmp = model.encode(x_dev,m_dev).cpu()
    print ('Decode\n')
    y_tmp      = model.decode(latent_tmp.to(eval_device),m_dev).cpu()

print (y_tmp.shape,reco_tmp.shape,x_tmp.shape)

In [None]:
# First four features of every particle of event `idx`, rounded to two decimals:
# decoded latent, direct forward pass, then the original input.
idx = 0
for shown in (y_tmp, reco_tmp, x_tmp):
    print (shown[idx,:,:4].round(decimals=2))

In [None]:
# Per-particle absolute error, summed over features, between the decoded latent
# and the input (first event).
torch.abs(y_tmp - x_tmp).sum(dim=-1)[0]

In [None]:
# Per-particle absolute error, summed over features, between the forward-pass
# reconstruction and the input (first event).
torch.abs(reco_tmp - x_tmp).sum(dim=-1)[0]