In [None]:
import argparse
import os, sys
import torch
import tqdm
import torch.nn as nn
import math
import random
import pickle as pk
import pandas as pd
from dataset import VAdatasetDecoder
from transformer import VertexFitting
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib import font_manager
import matplotlib
from trainer import plot_scatter, plot_scatter_len, set_seed

# set parameters for matplotlib
# Start from matplotlib's defaults, then register the Computer Modern (cmr10) font
# found under matplotlib's own data directory and use it as the sans-serif family.
plt.rcdefaults()
from pathlib import Path
font_path = str(Path(matplotlib.get_data_path(), "fonts/ttf/cmr10.ttf"))
font_manager.fontManager.addfont(font_path)
prop = font_manager.FontProperties(fname=font_path)
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = prop.get_name()
# render tick-label exponents with mathtext, using the regular (non-italic) font
plt.rcParams["axes.formatter.use_mathtext"] = True
params = {'mathtext.default': 'regular' }          
plt.rcParams.update(params)
# share tensors between processes through the file system
# (NOTE(review): presumably needed for the DataLoader workers - confirm)
torch.multiprocessing.set_sharing_strategy('file_system')

# manually specify the GPUs to use
# (set before the CUDA availability check below; only GPU "0" in PCI bus order is exposed)
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

# run on the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# to force CPU execution, uncomment: device = 'cpu'

In [None]:
# parameters

def str2bool(value):
    """Convert a command-line value to a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True
    because every non-empty string is truthy. This converter accepts the usual
    textual spellings as well as actual bools/numbers.

    Args:
        value: a bool, a number, or a string such as "true"/"false", "1"/"0",
            "yes"/"no", "on"/"off" (case-insensitive).

    Returns:
        bool: the parsed value.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        return bool(value)
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


parser = argparse.ArgumentParser()
parser.add_argument("-c", "--dataset_path", required=True, type=str, help="train dataset path")
parser.add_argument("-im", "--img_size", type=int, default=5, help="image size (one dimension)")
parser.add_argument("-is", "--input_size", type=int, default=1, help="input dimension (per cube)")
parser.add_argument("-ts", "--target_size", type=int, default=7, help="target size")
parser.add_argument("-ns", "--noise_size", type=int, default=100, help="size of the noise")
parser.add_argument("-hs", "--hidden", type=int, default=256, help="hidden size of transformer model")
parser.add_argument("-dr", "--dropout", type=float, default=0.1, help="dropout of the model")
parser.add_argument("-el", "--encoder_layers", type=int, default=8, help="number of encoder layers")
parser.add_argument("-dl", "--decoder_layers", type=int, default=8, help="number of decoder layers")
parser.add_argument("-a", "--attn_heads", type=int, default=8, help="number of attention heads")
parser.add_argument("-b", "--batch_size", type=int, default=64, help="number of batch_size")
parser.add_argument("-e", "--epochs", type=int, default=10, help="number of epochs")
parser.add_argument("-w", "--num_workers", type=int, default=5, help="dataloader worker size")
parser.add_argument("-lr", "--learning_rate", type=float, default=1e-3, help="learning rate of adam")
parser.add_argument("-wd", "--weight_decay", type=float, default=0.01, help="weight_decay of adam")
parser.add_argument("-b1", "--beta1", type=float, default=0.9, help="adam first beta value")
parser.add_argument("-b2", "--beta2", type=float, default=0.999, help="adam second beta value")
parser.add_argument("--eps", type=float, default=1e-9, help="value to prevent division by zero")
parser.add_argument("-dis", "--disable", type=str2bool, default=False, help="whether to show the training progress")
parser.add_argument("-lo", "--load", type=str2bool, default=False, help="whether to load a pretrained model")
parser.add_argument("-s", "--save", type=str2bool, default=False, help="whether to save the model after every epoch")
parser.add_argument("-sp", "--save_path", type=str, default=".", help="path of the model checkpoint file")
parser.add_argument("-es", "--early_stopping", type=int, default=10, help="early stopping count")
parser.add_argument('-sr','--source_range', nargs='+', type=int, default=[-1,1], help='source range')
parser.add_argument('-tr','--target_range', nargs='+', type=int, default=[0,1], help='target range')
parser.add_argument('-tc','--total_charge', type=str2bool, default=False, help='Use total charge in discriminator')
parser.add_argument('-cr','--crit_repeats', type=int, default=5, help='Critic repetitions')
parser.add_argument('-mp','--max_protons', type=int, default=3, help='Maximum number of protons')
parser.add_argument('-ws','--warmup_steps', type=int, default=1000, help='Maximum number of warmup steps')

# The notebook has no real command line, so the configuration is passed as an explicit
# argv list. Every entry is a string, which is what argparse expects.
args = parser.parse_args(["-c", "/scratch2/salonso/vertex_activity/images_5M/event{}.npz",
                          "-lo", "False",
                          "-b", "512",
                          "-ts", "3",
                          "-wd", "0",
                          "-dis", "False",
                          "-s", "True",
                          "-sp", "pretrained/full_transformer_5M_full_best",
                          "-e", "5000",
                          "-ns", "64",
                          "-el", "5",
                          "-dl", "5",
                          "-w", "16",
                          "-hs", "192",
                          "-a", "12",
                          "-is", "1",
                          "-lr", "0.0001",
                          "--eps", "1e-12",
                          "-tr", "0", "1",
                          "-tc", "False",
                          "-mp", "5",
                          "-ws", "2000",
                          "-im", "7",
                         ]
                        )

In [None]:
# ini dataset
# PAD_IDX is the value used to mark padded entries (handed to the dataset and
# compared against in the mask helpers of this notebook).
PAD_IDX = -2
dataset = VAdatasetDecoder(
    args.dataset_path,
    source_range=args.source_range,
    target_range=args.target_range,
    max_protons=args.max_protons,
    PAD_IDX=PAD_IDX,
)

# sets: 60% training, 10% validation, the remainder for testing
fulllen = len(dataset)
train_len, val_len = int(fulllen*0.6), int(fulllen*0.1)
test_len = fulllen - (train_len + val_len)
# fixed generator seed so that the split is the same on every run
split_generator = torch.Generator().manual_seed(7)
train_set, val_set, test_set = random_split(
    dataset, [train_len, val_len, test_len], generator=split_generator
)

# loaders: only the training loader shuffles its samples
train_loader, valid_loader, test_loader = [
    DataLoader(subset, batch_size=args.batch_size, collate_fn=dataset.collate_fn2, shuffle=reshuffle)
    for subset, reshuffle in ((train_set, True), (val_set, False), (test_set, False))
]

In [None]:
'''
Auxiliary mask functions needed for the transformer network.
'''

def generate_square_subsequent_mask(sz):
    """Build the additive causal mask for a sequence of length ``sz``.

    Args:
        sz (int): sequence length.

    Returns:
        Float tensor of shape (sz, sz), created on the notebook-level ``device``,
        holding 0.0 on and below the diagonal and -inf above it.
    """
    upper_triangle = torch.triu(torch.ones((sz, sz), device=device))
    # after the transpose: True on and below the diagonal
    visible = (upper_triangle == 1).transpose(0, 1)
    hidden = visible == 0

    additive = visible.float()
    additive = additive.masked_fill(hidden, float('-inf'))
    additive = additive.masked_fill(visible == 1, float(0.0))
    return additive

def create_mask(src, tgt):
    """Create the attention and padding masks for a source/target pair.

    Every mask has one more position than the corresponding sequence; that extra
    leading position is never flagged as padding.
    NOTE(review): presumably it corresponds to a token prepended by the model - confirm.

    Side effect: entries of ``src`` equal to PAD_IDX are overwritten with 0 in place.

    Args:
        src: source tensor; dim 0 is the sequence position, dim 1 the event,
            and feature 0 equals PAD_IDX on padded positions.
        tgt: target tensor with the same layout.

    Returns:
        (src_mask, tgt_mask, src_padding_mask, tgt_padding_mask) where src_mask is an
        all-False boolean square matrix, tgt_mask is the additive causal mask and the
        padding masks are boolean tensors of shape (seq.size(1), seq.size(0)+1).
    """
    def _padding_mask(seq):
        # column 0 stays False, the remaining columns flag padded positions
        flags = torch.zeros(size=(seq.size(1), seq.size(0)+1), dtype=torch.bool).to(device)
        flags[:, 1:] = (seq[:,:,0] == PAD_IDX).transpose(0, 1)
        return flags

    # padding masks have to be computed before the padding value is erased from src
    src_padding_mask = _padding_mask(src)
    tgt_padding_mask = _padding_mask(tgt)

    extended_src_len = src.shape[0]+1
    src_mask = torch.zeros((extended_src_len, extended_src_len),device=device).type(torch.bool)
    tgt_mask = generate_square_subsequent_mask(tgt.shape[0]+1)

    # replace the padding value in the source by 0 (in place)
    src[src==PAD_IDX] = 0

    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

def create_mask_src(src):
    """Create the attention and padding masks for the source only.

    Both masks have one more position than ``src``; the extra leading position is
    never flagged as padding.

    Side effect: entries of ``src`` equal to PAD_IDX are overwritten with 0 in place.

    Args:
        src: source tensor; dim 0 is the sequence position, dim 1 the event,
            and feature 0 equals PAD_IDX on padded positions.

    Returns:
        (src_mask, src_padding_mask): an all-False boolean square matrix of side
        src.shape[0]+1 and a boolean tensor of shape (src.size(1), src.size(0)+1).
    """
    n_positions = src.shape[0]+1
    n_events = src.size(1)

    # flag padded positions (column 0 stays False) before erasing the padding value
    is_padding = (src[:,:,0] == PAD_IDX).transpose(0, 1)
    src_padding_mask = torch.zeros(size=(n_events, n_positions), dtype=torch.bool).to(device)
    src_padding_mask[:, 1:] = is_padding

    src_mask = torch.zeros((n_positions, n_positions),device=device).type(torch.bool)

    # replace the padding value in the source by 0 (in place)
    src[src==PAD_IDX] = 0

    return src_mask, src_padding_mask

def create_mask_tgt(tgt):
    """Create the causal and padding masks for a target sequence.

    Unlike ``create_mask``, no extra position is added: the masks cover exactly
    the positions of ``tgt``.

    Args:
        tgt: target tensor; dim 0 is the sequence position, dim 1 the event,
            and feature 0 equals PAD_IDX on padded positions.

    Returns:
        (tgt_mask, tgt_padding_mask): the additive causal mask of side tgt.shape[0]
        and a boolean tensor, transposed to (tgt.shape[1], tgt.shape[0]), flagging
        padded positions.
    """
    is_padding = tgt[:,:,0] == PAD_IDX
    causal_mask = generate_square_subsequent_mask(tgt.shape[0])
    return causal_mask, is_padding.transpose(0, 1)

In [None]:
# define model
# The architecture is configured from the parsed arguments.
# NOTE(review): maxlen is hard-coded to 5; presumably it should follow
# args.max_protons - confirm against VertexFitting.
model = VertexFitting(num_encoder_layers = args.encoder_layers,
                           num_decoder_layers = args.decoder_layers,
                           emb_size = args.hidden,
                           nhead = args.attn_heads,
                           img_size = args.img_size,
                           src_size = args.input_size,
                           tgt_size = args.target_size,
                           dropout = args.dropout,
                           maxlen = 5,
                           device = device, 
                           )

model = model.to(device)

# count only the parameters that are updated during training
model_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print the architecture once (it used to be printed twice)
print(model)
print("total trainable params: {} (model).".format(model_total_params))

In [None]:
# Restore the trained weights and the training history, then plot the loss curves.
# NOTE(review): the checkpoint is loaded unconditionally; args.load is not consulted here.
print("Loading saved model...")
# map_location lets a checkpoint saved on a GPU be restored when `device` is the CPU
# (or another GPU). torch.load unpickles the file, so only load trusted checkpoints.
checkpoint = torch.load(args.save_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
epoch = checkpoint['epoch']+1
train_losses = checkpoint['train_losses']
val_losses = checkpoint['val_losses']
min_val_loss = min(val_losses)
count = checkpoint['count']
    
plt.figure(figsize=(10,5))
plt.title("Training and Validation Loss", fontsize=20)
plt.plot(val_losses,label="val")
plt.plot(train_losses,label="train")
# vertical line at the epoch with the lowest validation loss
plt.axvline(np.array(val_losses).argmin(), color="red", ls="--", label="best")
plt.xlabel("epochs", fontsize=15)
plt.ylabel("Loss", fontsize=15)
plt.yscale("log")
plt.ylim(3e-1,1)
plt.legend(fontsize=15)
plt.show()

In [None]:
'''
Functions needed to evaluate the network.
''' 

@torch.no_grad()
def eval_event(event_n, model, test_set):
    """Evaluate the model on a single event.

    Runs under ``torch.no_grad()``: this is pure inference, so no autograd graph
    is built.

    Args:
        event_n (int): the event number (from the test set).
        model: transformer neural network.
        test_set: the set used for testing.

    Returns:
        tuple ``(X, Y, ys, vtx, vtx_pred)``:
            - X: feature 0 of the event's ``images`` entry.
            - Y: copy of the event's ``params`` entry with columns 0, 1, 2 rescaled
              to the (KE, theta, phi) ranges of the dataset.
            - ys: decoded parameters, one row per decoding step, rescaled like Y.
            - vtx: vertex returned by the collate function, rescaled to the
              position range of the dataset.
            - vtx_pred: vertex predicted by the encoder, rescaled like vtx.
    """
    def to_physical(values, lo, hi):
        # linear map from dataset.source_range to [lo, hi], keeping the input shape
        return np.interp(values.ravel(), dataset.source_range, (lo, hi)).reshape(values.shape)

    # model to evaluate mode
    model.eval()

    # retrieve testing event from the dataset and collate it as a batch of one
    event = test_set[event_n]
    X, Y = event['images'], event['params']
    src, vtx, tgt, _, _ = dataset.collate_fn2([event])

    # set up the input: only the source masks are needed for the encoder
    # (create_mask_src also replaces the padding value in src by 0, in place)
    src_mask, src_padding_mask = create_mask_src(src)
    src = src.to(device)
    src_mask = src_mask.to(device)
    src_padding_mask = src_padding_mask.to(device)

    # Run the model (encoder part) on the input event.
    #   - memory: encoding learnt from the input that will be passed to the decoder.
    #   - vtx_pred: predicted vertex position.
    memory, vtx_pred = model.encode(src, src_mask, src_padding_mask)
    memory = memory.to(device)

    # Run the model (decoder part) iteratively on the memory.
    #   - out: physics parameters of the decoded (subtracted) proton.
    #   - is_next: whether to keep decoding or not.
    ys = model.first_token.reshape(1,1,3)
    max_len = len(tgt)
    for _ in range(max_len):
        tgt_mask, tgt_padding_mask = create_mask_tgt(ys)
        out, is_next = model.decode(ys, memory, tgt_mask, tgt_padding_mask, src_padding_mask)

        # append the newest decoded element to the sequence
        out_last = out[-1].reshape(1, out.shape[1], out.shape[2])
        ys = torch.cat([ys, out_last], dim=0)

        # loop exiting condition: class 0 of is_next means "stop decoding"
        if is_next[-1,0].argmax(0) == 0:
            break

    # parse output
    X = X[:,:,0]
    # work on a copy so that the sample held by the dataset is not rescaled in place
    Y = np.array(Y)
    # drop the first token and the (single) event dimension
    ys = ys.detach().cpu().numpy()[1:,0,:]
    vtx_pred = vtx_pred.detach().cpu().numpy()

    # rescale the normalised values back to their physical ranges
    param_ranges = ((dataset.min_KE, dataset.max_KE),
                    (dataset.min_theta, dataset.max_theta),
                    (dataset.min_phi, dataset.max_phi))
    for col, (lo, hi) in enumerate(param_ranges):
        Y[:, col] = to_physical(Y[:, col], lo, hi)
        ys[:, col] = to_physical(ys[:, col], lo, hi)
    vtx = to_physical(vtx, dataset.min_pos, dataset.max_pos)
    vtx_pred = to_physical(vtx_pred, dataset.min_pos, dataset.max_pos)

    return X, Y, ys, vtx, vtx_pred

@torch.no_grad()
def eval_batch(batch, model):
    """Evaluate the model on a batch of events.

    Runs under ``torch.no_grad()``: this is pure inference, so no autograd graph
    is built while decoding. The source tensor of the batch is cloned first,
    so the caller's batch is left untouched.

    Args:
        batch: batch of events to test, as produced by the collate function
            (src, vtx, tgt, next_tgt, lens).
        model: transformer neural network.

    Returns:
        tuple ``(ys, vtx_pred)``:
            - ys: tensor of shape (len(tgt)+1, n_events, 3). Row 0 holds the model's
              first token, the following rows the decoded parameters; rows after an
              event has finished keep the value PAD_IDX.
            - vtx_pred: vertex positions predicted by the encoder.
    """
    # model to evaluate mode
    model.eval()

    # retrieve testing batch; clone src because create_mask_src modifies it in place
    src, _, tgt, _, _ = batch
    src = src.clone()
    n_events = tgt.shape[1]

    # set up the input
    src_mask, src_padding_mask = create_mask_src(src)
    src = src.to(device)
    src_mask = src_mask.to(device)
    src_padding_mask = src_padding_mask.to(device)

    # Run the model (encoder part) on the input events.
    #   - memory: encoding learnt from the input that will be passed to the decoder.
    #   - vtx_pred: predicted vertex position.
    memory, vtx_pred = model.encode(src, src_mask, src_padding_mask)
    memory = memory.to(device)

    # set up a structure to store the results from the decoder (PAD_IDX = not decoded)
    max_len = len(tgt)
    ys = torch.zeros(max_len+1, n_events, 3).fill_(PAD_IDX).type(torch.float).to(device)
    ys[0, :, :] = model.first_token.repeat(1, n_events, 1)

    # keep track of predictions that are still being decoded (all of them at the start)
    prev_info = torch.ones(size=(src.shape[1],)).bool().to(device)

    # the encoder inputs are no longer needed
    del src, src_mask

    # Run the model (decoder part) iteratively on the memory.
    #   - out: physics parameters of the decoded (subtracted) proton.
    #   - is_next: whether to keep decoding or not.
    for i in range(max_len):
        # create masks and run model on everything decoded so far
        tgt_mask, tgt_padding_mask = create_mask_tgt(ys[:i+1])
        out, is_next = model.decode(ys[:i+1], memory, tgt_mask, tgt_padding_mask, src_padding_mask)

        # reshape output
        out_last = out[-1].reshape(1, out.shape[1], out.shape[2])
        is_next_last = is_next[-1].argmax(1).bool()

        # only update results of predictions that haven't finished
        ys[i+1, prev_info, :] = out_last.detach()[:,prev_info,:]

        # update the information with predictions that just finished
        prev_info = torch.logical_and(prev_info, is_next_last)

    return ys, vtx_pred

In [None]:
'''
Example: run the network on one event.

Output:
  - X: images of the protons involved in the VA. Shape: (#protons, 9, 9).
  - Y: physics parameters (KE, theta, phi) of the protons involved in the VA. Shape: (#protons, 3).
  - Y_hat: predicted physics parameters (KE, theta, phi) of the protons involved in the VA. Same shape as Y.
  - vtx: vertex of the VA. Shape: (1, 3).
  - vtx_hat: predicted vertex of the VA. Same shape as vtx.
'''

# evaluate one event of the test set and print truth next to prediction
event_n = 0
np.set_printoptions(suppress = True)
set_seed(7, random, np, torch) # for reproducibility
X, Y, Y_hat, vtx, vtx_hat = eval_event(event_n, model, test_set)

print("Vertex true: {}".format(vtx[0]))
print("Vertex pred: {}\n".format(vtx_hat[0]))

# the number of true and predicted protons may differ: print whatever exists per index
n_true, n_pred = len(Y), len(Y_hat)
for i in range(max(n_true, n_pred)):
    has_true = i < n_true
    has_pred = i < n_pred
    if has_true and has_pred:
        message = ("Proton {0}:\n (true) KE={1:.2f}, theta={2:.2f}, phi={3:.2f};"
                   "(transformer) KE={4:.2f}, theta={5:.2f}, phi={6:.2f}")
        print(message.format(i+1, Y[i,0], Y[i,1], Y[i,2],
                             Y_hat[i,0], Y_hat[i,1], Y_hat[i,2]))
    elif has_pred:
        message = ("Proton {0}:\n (transformer) KE={1:.2f},"
                   "theta={2:.2f}, phi={3:.2f}")
        print(message.format(i+1, Y_hat[i,0], Y_hat[i,1], Y_hat[i,2]))
    elif has_true:
        message = "Proton {0}:\n (true) KE={1:.2f}, theta={2:.2f}, phi={3:.2f}"
        print(message.format(i+1, Y[i,0], Y[i,1], Y[i,2]))

In [None]:
'''
Test network on the entire test set.
'''

def test(loader, disable=False):
    """Run the notebook-level ``model`` on every batch of ``loader``.

    Args:
        loader: data loader yielding (src, vtx, tgt, next_tgt, lens) batches.
        disable (bool): whether to disable the progress bar.

    Returns:
        Six lists with one entry per event: the inputs, the true parameters,
        the predicted parameters (all with padding removed), the true vertices,
        the predicted vertices and the lens entries (padding removed).
    """
    batch_size = loader.batch_size
    n_batches = int(math.ceil(len(loader.dataset)/batch_size))
    progress = tqdm.tqdm(loader, total=n_batches, disable=disable)

    Xs, Ys, Ys_hat = [], [], []
    Vtx, Vtx_hat, Lens = [], [], []

    for data in progress:
        src, vtx, tgt, next_tgt, lens = data

        y_hat, vtx_hat = eval_batch(data, model)

        # move everything to numpy for the per-event bookkeeping
        src_np = src.cpu().numpy()
        tgt_np = tgt.cpu().numpy()
        y_hat_np = y_hat.detach().cpu().numpy()
        vtx_np = vtx.cpu().numpy()
        vtx_hat_np = vtx_hat.detach().cpu().numpy()
        lens_np = lens.cpu().numpy()

        for event_n in range(src_np.shape[1]):
            X = src_np[:,event_n,:]
            Y = tgt_np[:,event_n,:]
            # row 0 of the prediction is the first token, not a decoded proton
            Y_hat = y_hat_np[1:,event_n,:]
            Len = lens_np[:, event_n]

            # remove padding and restore the per-element feature layout
            Xs.append(X[X!=PAD_IDX].reshape(-1,2))
            Ys.append(Y[Y!=PAD_IDX].reshape(-1,3))
            Ys_hat.append(Y_hat[Y_hat!=PAD_IDX].reshape(-1,3))
            Vtx.append(vtx_np[event_n])
            Vtx_hat.append(vtx_hat_np[event_n])
            Lens.append(Len[Len!=PAD_IDX])

    return Xs, Ys, Ys_hat, Vtx, Vtx_hat, Lens

# Seed the random generators, then evaluate the whole test set; each returned
# list holds one entry per test event.
set_seed(7, random, np, torch) # for reproducibility
Xs, Ys, Ys_hat, Vtx, Vtx_hat, Lens = test(test_loader)

In [None]:
'''
Generate a Pandas dataframe with the results.
'''

def _to_physical(values, lo, hi):
    """Linearly map ``values`` from ``dataset.source_range`` to ``[lo, hi]``, keeping the shape."""
    return np.interp(values.ravel(), dataset.source_range, (lo, hi)).reshape(values.shape)


dic = {'KE_true':[], 'KE_reco':[], 'theta_true':[], 'theta_reco':[], 
       'phi_true':[], 'phi_reco':[], 'vertex_true':[], 'vertex_reco':[],
       'lens':[], 'nparticles_true':[], 'nparticles_reco':[], 'eventid':[]}

# (column prefix, parameter index, lower bound, upper bound) of every physics parameter
param_specs = (('KE', 0, dataset.min_KE, dataset.max_KE),
               ('theta', 1, dataset.min_theta, dataset.max_theta),
               ('phi', 2, dataset.min_phi, dataset.max_phi))

t = tqdm.tqdm(range(len(Xs)), total=len(Xs), disable=False)

for i in t:
    y = Ys[i]
    y_hat = Ys_hat[i]
    vtx = Vtx[i]
    vtx_hat = Vtx_hat[i]
    lens = Lens[i]

    nparticles_true = y.shape[0]
    nparticles_reco = y_hat.shape[0]
    min_nparticles = min(nparticles_true, nparticles_reco)

    # true and predicted protons are compared pairwise, so keep only the
    # first min(#true, #reco) entries of each
    y = y[:min_nparticles]
    y_hat = y_hat[:min_nparticles]
    lens = lens[:min_nparticles]

    # rescale the normalised parameters back to their physical ranges
    for prefix, col, lo, hi in param_specs:
        dic[prefix + '_true'].append(_to_physical(y[:, col], lo, hi))
        dic[prefix + '_reco'].append(_to_physical(y_hat[:, col], lo, hi))

    # vertices are stored as one (1, 3) row per event
    dic['vertex_true'].append(_to_physical(vtx, dataset.min_pos, dataset.max_pos).reshape(1, 3))
    dic['vertex_reco'].append(_to_physical(vtx_hat, dataset.min_pos, dataset.max_pos).reshape(1, 3))
    dic['nparticles_true'].append(nparticles_true)
    dic['nparticles_reco'].append(nparticles_reco)
    dic['lens'].append(lens)
    dic['eventid'].append(i)

df = pd.DataFrame(dic, columns = ['eventid', 'KE_true', 'KE_reco', 'theta_true', 'theta_reco', 
                                  'phi_true', 'phi_reco', 'vertex_true', 'vertex_reco',
                                  'lens', 'nparticles_true', 'nparticles_reco'])

In [None]:
'''
Plot results (scatterplots)
'''

# flatten the per-event arrays of the dataframe into one array per quantity
(KE_true, KE_reco, theta_true, theta_reco, phi_true, phi_reco,
 vertex_true, vertex_reco, lens) = (
    np.concatenate(df[column].values)
    for column in ('KE_true', 'KE_reco', 'theta_true', 'theta_reco', 'phi_true', 'phi_reco',
                   'vertex_true', 'vertex_reco', 'lens')
)
nparticles_true = df.nparticles_true.values
nparticles_reco = df.nparticles_reco.values

# marker size of the scatter plots
s = 0.005
param_plots = (
    (KE_true, KE_reco, ("KE","[MeV]")),
    (theta_true, theta_reco, (r"$\theta$","[rad]")),
    (phi_true, phi_reco, (r"$\phi$","[rad]")),
)

# true vs. reconstructed value of every physics parameter
for true_values, reco_values, axis_label in param_plots:
    plot_scatter(true_values, reco_values, label=axis_label, s=s)

# true vs. reconstructed vertex, one plot per coordinate
for axis_index, axis_name in enumerate(("X", "Y", "Z")):
    plot_scatter(vertex_true[:,axis_index], vertex_reco[:,axis_index],
                 label=("vertex "+axis_name,"[mm]"), s=s*3)

# residual (reco - true) of every physics parameter against lens
for true_values, reco_values, axis_label in param_plots:
    plot_scatter_len(lens, reco_values-true_values, label=axis_label, s=s*3)

In [None]:
'''
Plot results (number of protons predicted)
'''

from sklearn.metrics import classification_report, confusion_matrix

# proton multiplicities treated as classes
proton_counts = [1,2,3,4,5]
print(classification_report(nparticles_true, nparticles_reco, digits=3, 
                            labels=proton_counts,
                            target_names=[str(n)+" protons" for n in proton_counts]))
# scikit-learn expects (y_true, y_pred): rows are the true multiplicities and columns
# the reconstructed ones, consistent with the classification report above
# (the arguments used to be swapped, which printed the transposed matrix)
conf = confusion_matrix(nparticles_true, nparticles_reco, labels=proton_counts)
print(conf)