In [1]:
%load_ext autoreload
%autoreload 2

import math
import copy
import matplotlib.pyplot as plt
import numpy as np
import awkward as ak
import uproot
import pandas as pd
import dask
import os
import multiprocessing
import vector
import particle
import hepunits

# Deep learning imports
import zuko # Flow based models
import torch
from torch import nn, optim
import pytorch_lightning as L # PyTorch Lightning for training
from pytorch_lightning import loggers as pl_loggers

from torch.utils.data import DataLoader

from memflow.dataset.data import RootData, ParquetData
from memflow.dataset.dataset import HardDataset, RecoDataset, CombinedDataset

from memflow.ttH.ttH_dataclasses import ttHHardDataset, ttHRecoDataset

from transfer_flow.tools import *
from transfer_flow.custom_flows import *
from transfer_flow.transfer_flow_model import TransferFlow

from transfer_flow.transfer_flow_callbacks import *
from models.callbacks import ModelCheckpoint

# Allow up to 100 simultaneously open matplotlib figures before a warning is raised.
plt.rcParams.update({'figure.max_open_warning': 100})

# Register the `vector` behaviours on awkward arrays.
vector.register_awkward()

# Number of dataloading worker processes: one per CPU core, capped at 16.
num_workers = min(16, multiprocessing.cpu_count())
print(f'Number of CPU workers for dataloading: {num_workers}')

# Restrict the notebook to a single GPU of the node; edit the index to pick another one.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Fall back to CPU when no CUDA device is usable.
accelerator = 'cuda' if torch.cuda.is_available() else 'cpu'
print (f"Running on GPU : {accelerator == 'cuda'}")
print (f"Accelerator : {accelerator}")

torch.set_float32_matmul_precision('medium')

if accelerator == 'cuda':
    # Start from an empty allocator cache and report the current memory state.
    torch.cuda.empty_cache()
    print (torch.cuda.memory_summary(device=None, abbreviated=True))

  from pandas.core.computation.check import NUMEXPR_INSTALLED


Number of CPU workers for dataloading: 16
Running on GPU : True
Accelerator : cuda
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Requested memory      |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| GPU reserved memory   |      0 B   |      0 B   |      0 

# Data

In [2]:
# Hard-level ttH (H -> invisible) sample, read lazily from parquet.
data_hard = ParquetData(
    files=[
        '/cephfs/dice/users/sa21722/datasets/MEM_data/ttH/TF_v6/hard/2018/ttH/ttH_HToInvisible_M125.parquet',
    ],
    lazy=True,
    # N=int(1e5),  # NOTE(review): presumably caps the number of events read - confirm against ParquetData
)

print(data_hard)

Data object
Loaded branches:
   ... file: 1903554
   ... sample: 1903554
   ... tree: 1903554
Branch in files not loaded:
   ... Generator_scalePDF
   ... Generator_weight
   ... Generator_x1
   ... Generator_x2
   ... Generator_xpdf1
   ... Generator_xpdf2
   ... W_minus_from_antitop_eta
   ... W_minus_from_antitop_genPartIdxMother
   ... W_minus_from_antitop_idx
   ... W_minus_from_antitop_mass
   ... W_minus_from_antitop_pdgId
   ... W_minus_from_antitop_phi
   ... W_minus_from_antitop_pt
   ... W_minus_from_antitop_status
   ... W_minus_from_antitop_statusFlags
   ... W_plus_from_top_eta
   ... W_plus_from_top_genPartIdxMother
   ... W_plus_from_top_idx
   ... W_plus_from_top_mass
   ... W_plus_from_top_pdgId
   ... W_plus_from_top_phi
   ... W_plus_from_top_pt
   ... W_plus_from_top_status
   ... W_plus_from_top_statusFlags
   ... Z_from_higgs_eta
   ... Z_from_higgs_genPartIdxMother
   ... Z_from_higgs_idx
   ... Z_from_higgs_mass
   ... Z_from_higgs_pdgId
   ... Z_from_higgs_phi

In [3]:
# Hard-level dataset restricted to the object types listed in `selection`
# (commented-out entries are available but currently not used).
hard_dataset = ttHHardDataset(
    data=data_hard,
    selection=[
        # 'higgs',
        # 'tops',
        'bottoms',
        # 'Ws',
        # 'Zs',
        'quarks',
        'neutrinos',
    ],
    build=False,
    fit=True,
    coordinates='cylindrical',
    apply_preprocessing=True,
    apply_boost=False,
    dtype=torch.float32,
)

# Inspect how one object type is stored inside the dataset.
obj = hard_dataset.objects['neutrinos']
print("Type of neutrinos object:", type(obj))

if isinstance(obj, tuple):
    print("Length of tuple:", len(obj))
    for idx, item in enumerate(obj):
        # Anything exposing `.shape` (tensor / array) reports its shape; other entries are printed as-is.
        summary = item.shape if hasattr(item, 'shape') else item
        print(f"Element {idx} has type {type(item)} and shape/length:", summary)

Loading objects from /cephfs/dice/users/sa21722/projects/MEM/memflow/ttH/ttH_hard
Saving preprocessing to /cephfs/dice/users/sa21722/projects/MEM/memflow/ttH/ttH_hard
Will overwrite what is in output directory /cephfs/dice/users/sa21722/projects/MEM/memflow/ttH/ttH_hard/preprocessing
Preprocessing saved in /cephfs/dice/users/sa21722/projects/MEM/memflow/ttH/ttH_hard/preprocessing
Type of neutrinos object: <class 'tuple'>
Length of tuple: 3
Element 0 has type <class 'torch.Tensor'> and shape/length: torch.Size([756642, 4, 5])
Element 1 has type <class 'torch.Tensor'> and shape/length: torch.Size([756642, 4])
Element 2 has type <class 'torch.Tensor'> and shape/length: torch.Size([756642, 4])


In [4]:
# Optional sanity check: make sure batches can be drawn from the hard dataset alone.
# Training itself uses the combined (hard + reco) dataset built further down.
hard_loader = DataLoader(
    hard_dataset,
    batch_size=32,
    num_workers=num_workers,  # parallel loading
    pin_memory=True,          # faster transfer to GPU
)
batch = next(iter(hard_loader))

# One line per selected object type: name, data shape, mask shape.
for sel, obj, mask in zip(hard_loader.dataset.selection, batch['data'], batch['mask']):
    print(sel, obj.shape, mask.shape)

bottoms torch.Size([32, 2, 5]) torch.Size([32, 2])
quarks torch.Size([32, 4, 5]) torch.Size([32, 4])
neutrinos torch.Size([32, 4, 5]) torch.Size([32, 4])


In [5]:
# Reco-level counterpart of the hard sample, read lazily from parquet.
data_reco = ParquetData(
    files=[
        '/cephfs/dice/users/sa21722/datasets/MEM_data/ttH/TF_v6/reco/2018/ttH/ttH_HToInvisible_M125.parquet',
    ],
    lazy=True,
    # N=data_hard.N,  # NOTE(review): presumably caps the number of events read - confirm against ParquetData
)

print(data_reco)

Data object
Loaded branches:
   ... file: 231528
   ... sample: 231528
   ... tree: 231528
Branch in files not loaded:
   ... Generator_scalePDF
   ... Generator_weight
   ... Generator_x1
   ... Generator_x2
   ... Generator_xpdf1
   ... Generator_xpdf2
   ... InputMet_phi
   ... InputMet_pt
   ... cleanedJet_btagDeepFlavB
   ... cleanedJet_eta
   ... cleanedJet_mass
   ... cleanedJet_phi
   ... cleanedJet_pt
   ... event
   ... ncleanedBJet
   ... ncleanedJet
   ... region
   ... weight_nominal
   ... xs_weight


In [6]:
# Reco-level dataset made of jets and MET, configured like the hard dataset
# (cylindrical coordinates, preprocessing fitted and applied, no boost).
reco_dataset = ttHRecoDataset(
    data=data_reco,
    selection=[
        'jets',
        'met',
    ],
    build=False,
    fit=True,
    coordinates='cylindrical',
    apply_preprocessing=True,
    apply_boost=False,
    dtype=torch.float32,
)
print(reco_dataset)

Loading objects from /cephfs/dice/users/sa21722/projects/MEM/memflow/ttH/ttH_reco
Saving preprocessing to /cephfs/dice/users/sa21722/projects/MEM/memflow/ttH/ttH_reco
Will overwrite what is in output directory /cephfs/dice/users/sa21722/projects/MEM/memflow/ttH/ttH_reco/preprocessing
Preprocessing saved in /cephfs/dice/users/sa21722/projects/MEM/memflow/ttH/ttH_reco/preprocessing
Reco dataset with 114647 events
Containing the following tensors
jets  : data ([114647, 6, 5]), mask ([114647, 6])
        Mask exist    : [100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 62.85%]
        Mask attn     : [True, True, True, True, True, True]
        Weights       : 114647.00, 114647.00, 114647.00, 114647.00, 114647.00, 114647.00
        Features      : ['pt', 'eta', 'phi', 'mass', 'btag']
        Selected for batches : True
met   : data ([114647, 1, 4]), mask ([114647, 1])
        Mask exist    : [100.00%]
        Mask attn     : [True]
        Weights       : 114647.00
        Features      : ['pt

In [7]:
# Optional sanity check: make sure batches can be drawn from the reco dataset alone.
reco_loader = DataLoader(
    reco_dataset,
    batch_size=32,
    num_workers=num_workers,
    pin_memory=True,
)
batch = next(iter(reco_loader))

# One line per selected object type: name, data shape, mask shape.
for sel, obj, mask in zip(reco_loader.dataset.selection, batch['data'], batch['mask']):
    print(sel, obj.shape, mask.shape)

jets torch.Size([32, 6, 5]) torch.Size([32, 6])
met torch.Size([32, 1, 4]) torch.Size([32, 1])


In [8]:
# Show the branch used to match events and the metadata available on each side.
print(f'Intersection Branch: {reco_dataset.intersection_branch}')
print(f'Hard Datset keys: {hard_dataset.metadata.keys()}')
print(f'Reco Datset keys: {reco_dataset.metadata.keys()}')

# Combine the hard and reco datasets into a single dataset used for training.
combined_dataset = CombinedDataset(hard_dataset=hard_dataset, reco_dataset=reco_dataset)
print(combined_dataset)

Intersection Branch: event
Hard Datset keys: dict_keys(['file', 'tree', 'sample', 'intersection'])
Reco Datset keys: dict_keys(['file', 'tree', 'sample', 'intersection'])
Intersection branches : `event` for hard dataset and `event` for reco dataset


Looking into file metadata
Will pair these files together :
   - /cephfs/dice/users/sa21722/datasets/MEM_data/ttH/TF_v6/hard/2018/ttH/ttH_HToInvisible_M125.parquet <-> /cephfs/dice/users/sa21722/datasets/MEM_data/ttH/TF_v6/reco/2018/ttH/ttH_HToInvisible_M125.parquet
For entry 0 : from 756642 events, 91819 selected
For entry 1 : from 114647 events, 91819 selected
Combined dataset (extracting 91819 events of the following) :
Parton dataset with 756642 events
 Initial states pdgids : [21, 21]
 Final states pdgids   : [25, 6, -6]
 Final states masses   : [125.2, 172.57, 172.57]
Containing the following tensors
bottoms    : data ([756642, 2, 5]), mask ([756642, 2])
             Mask exist    : [100.00%, 100.00%]
             Mask attn     : [True, True]
             Weights       : 756642.00, 756642.00
             Features      : ['pt', 'eta', 'phi', 'mass', 'pdgId']
             Selected for batches : True
Zs         : data ([756642, 2, 5]), mask ([756642, 2])
             Mask exist    :

In [None]:
# Split train and validation #
# Sequential, unshuffled split: the first `train_frac` of the combined events
# are used for training, the remainder for validation.
train_frac = 0.8
indices = torch.arange(len(combined_dataset))
sep = int(train_frac*len(combined_dataset))
train_indices = indices[:sep]
valid_indices = indices[sep:]

dataset_train = torch.utils.data.Subset(combined_dataset,train_indices)
dataset_valid = torch.utils.data.Subset(combined_dataset,valid_indices)
# Fixed: this line used the undefined name `ataset_valid`, which raised a NameError
# and stopped the cell before the loaders below were created.
print (f'Dataset : training {len(dataset_train)} / validation {len(dataset_valid)}')

# make data loader #
batch_size = 1024

# Training batches are reshuffled every epoch.
loader_train = DataLoader(
    dataset_train ,
    batch_size = batch_size,
    shuffle = True,
)
# Validation uses larger batches in a fixed order.
loader_valid = DataLoader(
    dataset_valid,
    batch_size = 5000,
    shuffle = False,
)
print (f'Batching {len(loader_train)} / Validation {len(loader_valid)}')

Dataset : training 73455 / validation 18364
Batching 72 / Validation 4


# Model

In [10]:
# Transfer-flow model: embeddings per object type, a transformer, and one flow per
# kinematic feature.
model = TransferFlow(
    encoder_embeddings=MultiEmbeddings(
        features_per_type=combined_dataset.hard_dataset.input_features,
        embed_dims=[32, 64],
        hidden_activation=nn.GELU,
    ),
    decoder_embeddings=MultiEmbeddings(
        features_per_type=combined_dataset.reco_dataset.input_features,
        embed_dims=[32, 64],
        hidden_activation=nn.GELU,
    ),
    transformer=Transformer(
        d_model=64,  # NOTE(review): presumably has to equal the last entry of embed_dims - confirm
        encoder_layers=6,
        decoder_layers=8,
        nhead=8,
        dim_feedforward=256,
        activation=nn.GELU,
        encoder_mask_attn=None,
        decoder_mask_attn=combined_dataset.reco_dataset.attention_mask,
        use_null_token=True,
        dropout=0.,
    ),
    flow=KinematicFlow(
        d_model=64,
        flow_mode='global',
        flow_features=[
            ['pt', 'eta', 'phi', 'mass'],  # jets
            ['pt', 'phi'],                 # met
        ],
        flow_classes={  # flow class used for each feature
            'pt': zuko.flows.NSF,
            'eta': UniformNSF,
            'phi': UniformNCSF,
            'mass': zuko.flows.NSF,
        },
        flow_common_args={  # arguments shared by all flows
            'bins': 16,
            'transforms': 5,
            'randperm': True,
            'passes': None,
            'hidden_features': [256] * 3,
        },
        flow_specific_args={  # extra arguments for individual features
            'eta': {'bound': 1.},
            'phi': {'bound': math.pi},
        },
    ),
    hard_names=combined_dataset.hard_dataset.selection,
    reco_names=combined_dataset.reco_dataset.selection,
)

# The smoke test below runs on CPU.
model = model.cpu()

# Forward pass on a single training batch.
batch = next(iter(loader_train))
print([hard_tensor.shape for hard_tensor in batch['hard']['data']])
log_probs, masks, weights = model(batch)
print('log_probs', [lp.shape for lp in log_probs])
print('masks    ', [m.shape for m in masks])
print('weights  ', [w.shape for w in weights])

# Total loss as computed by the shared evaluation step.
log_probs_tot = model.shared_eval(batch, 0, 'test')
print('tot log probs', log_probs_tot)

# Draw N = 3 samples per event and print the resulting shapes.
samples = model.sample(
    batch['hard']['data'],
    batch['hard']['mask'],
    batch['reco']['data'],
    batch['reco']['mask'],
    N=3,
)
print('samples')
for sample in samples:
    print('\t', sample.shape)

print(model)


/software/sa21722/miniconda3/envs/mem-flow/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'encoder_embeddings' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['encoder_embeddings'])`.
/software/sa21722/miniconda3/envs/mem-flow/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'decoder_embeddings' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['decoder_embeddings'])`.
/software/sa21722/miniconda3/envs/mem-flow/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'transformer' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['transformer'])`.
/software/sa21722/miniconda3/envs/mem-flow/lib/python3.10

[torch.Size([1024, 2, 5]), torch.Size([1024, 4, 5]), torch.Size([1024, 4, 5])]


  return torch.searchsorted(seq, value).squeeze(dim=-1)


log_probs [torch.Size([1024, 6]), torch.Size([1024, 1])]
masks     [torch.Size([1024, 6]), torch.Size([1024, 1])]
weights   [torch.Size([1024, 6]), torch.Size([1024, 1])]


/software/sa21722/miniconda3/envs/mem-flow/lib/python3.10/site-packages/lightning/pytorch/core/module.py:441: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


tot log probs tensor(4.8277, grad_fn=<MeanBackward0>)
samples
	 torch.Size([3, 1024, 6, 5])
	 torch.Size([3, 1024, 1, 4])
TransferFlow(
  (encoder_embeddings): MultiEmbeddings(
    (embeddings): ModuleList(
      (0-2): 3 x MLP(
        (layers): Sequential(
          (0): Linear(in_features=5, out_features=32, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=32, out_features=64, bias=True)
        )
      )
    )
  )
  (decoder_embeddings): MultiEmbeddings(
    (embeddings): ModuleList(
      (0): MLP(
        (layers): Sequential(
          (0): Linear(in_features=5, out_features=32, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=32, out_features=64, bias=True)
        )
      )
      (1): MLP(
        (layers): Sequential(
          (0): Linear(in_features=4, out_features=32, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=32, out_features=64, bias=True)
        )
      )
 

# Training

In [11]:
# Callback built on the validation subset with per-feature plotting options.
# The mathtext labels are raw strings: in a normal string literal `\e` and `\p`
# are invalid escape sequences, which Python warns about.
sampling = SamplingCallback(
    dataset = dataset_valid,
    preprocessing = combined_dataset.reco_dataset.preprocessing,
    idx_to_monitor = [0,1,2],
    N_sample = int(1e3),
    frequency = 500,
    bins = 50,
    hexbin = True,
    kde = False,
    log_scale = True,
    label_names = {
        'pt' : r'$p_T$',
        'eta' : r'$\eta$',
        'phi' : r'$\phi$',
        'mass' : r'$M$',
    },
    feature_rng = {
        'pt' : (0.,500.),
        'eta' : (-5.,5.),
        'phi' : (-math.pi,math.pi),
        'mass' : (0.,100.),
    }
)
# Second callback on the same validation subset, using the same axis labels.
bias = BiasCallback(
    dataset = dataset_valid,
    preprocessing = combined_dataset.reco_dataset.preprocessing,
    N_sample = 30,
    frequency = 500,
    bins = 51,
    points = 30,
    log_scale = True,
    batch_size = 1024,
    N_batch = math.inf,
    label_names = {
        'pt' : r'$p_T$',
        'eta' : r'$\eta$',
        'phi' : r'$\phi$',
        'mass' : r'$M$',
    },
)

NameError: name 'dataset_valid' is not defined

In [None]:
##### Parameters #####
epochs = 501
steps_per_epoch_train = math.ceil(len(dataset_train)/loader_train.batch_size)

print (f'Training   : Batch size = {loader_train.batch_size} => {steps_per_epoch_train} steps per epoch')

##### Optimizer #####
optimizer = optim.Adam(model.parameters(), lr=1e-4)
model.set_optimizer(optimizer)

##### Scheduler #####
# Multiply the learning rate by 0.1 once the monitored quantity has not improved by a
# relative 1e-3 for more than `patience` scheduler steps, never going below 1e-7.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    factor=0.1,
    patience=10,
    threshold=0.001,
    threshold_mode='rel',
    cooldown=0,
    min_lr=1e-7,
)
model.set_scheduler_config(
    {
        'scheduler': scheduler,
        # OneCycleLR would be stepped per batch; every other scheduler is stepped per epoch.
        'interval': 'step' if isinstance(scheduler, optim.lr_scheduler.OneCycleLR) else 'epoch',
        'frequency': 1,
        'monitor': 'val/loss_tot',
        'strict': True,
        'name': 'scheduler',
    }
)

##### Callbacks #####
# NOTE(review): `L.pytorch...` only resolves if `L` is the `lightning` package, while the
# import cell binds `L` to `pytorch_lightning`; presumably a later star-import rebinds `L` - confirm.
callbacks = [
    L.pytorch.callbacks.LearningRateMonitor(logging_interval='epoch'),
    L.pytorch.callbacks.ModelSummary(max_depth=2),
    sampling,
    bias,
    ModelCheckpoint(save_every_n_epochs=10, save_dir="TransferFlow_checkpoints"),
]

##### Logger #####
logger = pl_loggers.CometLogger(
    save_dir='../comet_logs',
    project_name='mem-flow-Hinv',
    experiment_name='transfer-flow',
    offline=False,
)
logger.log_graph(model)
# Optional extra logging, currently disabled:
# logger.log_hyperparams()
# logger.experiment.log_code(folder='../src/')
# logger.experiment.log_notebook(filename=globals()['__session__'],overwrite=True)

##### Trainer #####
trainer = L.Trainer(
    min_epochs=5,
    max_epochs=epochs,
    callbacks=callbacks,
    devices='auto',
    accelerator=accelerator,
    logger=logger,
)

##### Fit #####
trainer.fit(
    model=model,
    train_dataloaders=loader_train,
    val_dataloaders=loader_valid,
    # ckpt_path="TransferFlow_checkpoints/model_epoch_230.ckpt"  # uncomment to resume training from a checkpoint
)

Training   : Batch size = 1024 => 81 steps per epoch


CometLogger will be initialized in online mode
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/themrluke/mem-flow-hinv/6d4cd14bd93c471aa763d0bbfbd638e6

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]

  | Name                          | Type            | Params | Mode 
--------------------------------------------------------------------------
0 | encoder_embeddings            | MultiEmbeddings | 6.9 K  | train
1 | encoder_embeddings.embeddings | ModuleList      | 6.9 K  | train
2 | decoder_embeddings            | MultiEmbeddings | 4.6 K  | train
3 | decoder_embeddings.embeddings | ModuleList      | 4.6 K  | train
4 | transformer                   | Transformer     | 834

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/software/sa21722/miniconda3/envs/mem-flow/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.
  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
  return torch.searchsorted(seq, value).squeeze(dim=-1)
/software/sa21722/miniconda3/envs/mem-flow/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Checkpoint saved at epoch 10: TransferFlow_checkpoints/model_epoch_10.ckpt


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
# Produce the sampling-callback plots with the current model.
# `.to(accelerator)` replaces the hard-coded `.cuda()`, which fails on hosts without a GPU;
# `accelerator` is 'cuda' or 'cpu' as chosen in the setup cell.
figs = sampling.make_plots(model.to(accelerator),show=True)

In [None]:
# Produce the bias-callback plots with the current model.
# `.to(accelerator)` replaces the hard-coded `.cuda()`, which fails on hosts without a GPU;
# `accelerator` is 'cuda' or 'cpu' as chosen in the setup cell.
figs = bias.make_plots(model.to(accelerator),show=True)

In [None]:
def dphi_j1j2(jets,met):
    """Return `deltaphi` between the first and second jet of every event.

    `met` is unused; it is accepted so that all high-level variable functions
    share the same `(jets, met)` signature.
    Assumes `jets` is indexable as `[event, jet]` and its entries expose
    `deltaphi` - TODO confirm against the callback that calls this.
    """
    first_jet = jets[:,0]
    second_jet = jets[:,1]
    return first_jet.deltaphi(second_jet)
        
def dR_j1j2(jets,met):
    """Return `deltaR` between the first and second jet of every event.

    `met` is unused; it is accepted so that all high-level variable functions
    share the same `(jets, met)` signature.
    Assumes `jets` is indexable as `[event, jet]` and its entries expose
    `deltaR` - TODO confirm against the callback that calls this.
    """
    first_jet = jets[:,0]
    second_jet = jets[:,1]
    return first_jet.deltaR(second_jet)

def HT(jets,met):
    """Return the sum of `jets.pt` along axis 1, i.e. one value per event.

    `met` is unused; it is accepted so that all high-level variable functions
    share the same `(jets, met)` signature.
    """
    jet_pt = jets.pt
    return ak.sum(jet_pt,axis=1)

def dR_met_j1j2(jets,met):
    """Return `deltaR` between the first MET entry and the sum of the first two jets.

    Assumes both inputs are indexable as `[event, object]` and that their entries
    support `+` and `deltaR` - TODO confirm against the callback that calls this.
    """
    first_jet, second_jet = jets[:,0], jets[:,1]
    dijet = first_jet + second_jet
    return met[:,0].deltaR(dijet)

def min_mass_jj(jets,met):
    """Return, per event, the smallest `.mass` over all pairs of distinct jets.

    `met` is unused; it is accepted so that all high-level variable functions
    share the same `(jets, met)` signature.
    """
    # All unordered jet pairs within each event (no jet paired with itself).
    pairs = ak.combinations(jets,n=2,replacement=False,axis=1)
    # Separate the two members of every pair.
    first, second = ak.unzip(pairs)
    # Mass of each summed pair, reduced to the per-event minimum.
    pair_mass = (first+second).mass
    return ak.min(pair_mass,axis=1)

# Callback evaluating the high-level variables defined above on the validation subset.
# Every entry of `var_functions` takes `(jets, met)`; `label_names` gives the matching axis labels.
high_level = HighLevelVariableCallback(
    dataset = dataset_valid,
    preprocessing = combined_dataset.reco_dataset.preprocessing,
    N_sample = 30,
    frequency = 100,
    bins = 51,
    log_scale = True,
    batch_size = 1000,
    N_batch = math.inf,
    var_functions = {
        'dphi_j1j2'   : dphi_j1j2,
        'dR_j1j2'     : dR_j1j2,
        'HT'          : HT,
        'dR_met_j1j2' : dR_met_j1j2,
        'min_mass_jj' : min_mass_jj,
    },
    label_names = {
        'dphi_j1j2'      : r'$\Delta \phi(j_1,j_2)$',
        'dR_j1j2'        : r'$\Delta R(j_1,j_2)$',
        'dR_met_j1j2'    : r'$\Delta R(MET,jj)$',
        'HT'             : r'$H_T$',
        'min_mass_jj'    : r'$min_{j1,j2 \in \text{jets}} (m_{j_1j_2})$',
    },
)
# `.to(accelerator)` replaces the hard-coded `.cuda()`, which fails on hosts without a GPU;
# `accelerator` is 'cuda' or 'cpu' as chosen in the setup cell.
samples = high_level.make_plots(model.to(accelerator),show=True)