In [None]:
# Widen the notebook container to 98% of the browser window.
# `display` lives in IPython.display; importing it from IPython.core.display is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:98% !important; }</style>"))

In [None]:
import torch
import torch_geometric
import torch
import os
import os.path as osp
import numpy as np
from tqdm import tqdm as tqdm

In [None]:
# import tau events: every data*.pt file in raw_dir holds one pre-built graph
import glob

raw_dir = '/grid_mnt/data__data.polcms/cms/sghosh/TAUGNN/GRAPH_TENTAU/'
# glob returns files in arbitrary (filesystem-dependent) order; sort so the event
# order -- and therefore any seeded shuffle/split done later -- is reproducible.
fnamelist = sorted(glob.glob(osp.join(raw_dir, 'data*.pt')))
assert fnamelist, "no data*.pt graph files found in {}".format(raw_dir)

### load graphs
# NOTE(review): torch.load unpickles arbitrary objects -- only point this at trusted files.
# On recent torch versions the weights_only default may reject pickled graph objects;
# if loading fails there, pass weights_only=False explicitly.
data_list_tau = [torch.load(fname) for fname in tqdm(fnamelist)]

totalevpho = len(data_list_tau)
print("total events:", totalevpho)

In [None]:
ntrain = 7000  # NOTE(review): not used below; the split is controlled by N_TEST
import random

SEED = 42
N_TEST = 900  # events held out for testing: the last N_TEST after shuffling

# Seed everything stochastic used downstream (this shuffle, DataLoader shuffling,
# weight init, dropout) so a Restart & Run All reproduces the same run.
torch.manual_seed(SEED)

# Shuffle a *copy* with a dedicated seeded RNG. Shuffling data_list_tau in place made
# the split depend on how many times this cell had been executed, so re-running it
# moved previous test events into the training set.
data_list_comb = list(data_list_tau)
random.Random(SEED).shuffle(data_list_comb)
totalev = len(data_list_comb)
print("total comb evs:", totalev)
assert totalev > N_TEST, "need more than {} events for the train/test split".format(N_TEST)


### test train split etc

import torch_geometric
ntrainbatch = 100
ntestbatch = 100
train_list = data_list_comb[:totalev - N_TEST]
test_list = data_list_comb[totalev - N_TEST:totalev]
# Reshuffle training graphs every epoch. The number of batches per epoch is unchanged,
# so the scheduler's epoch_size bookkeeping still holds.
trainloader = torch_geometric.data.DataLoader(train_list, batch_size=ntrainbatch, shuffle=True)
testloader = torch_geometric.data.DataLoader(test_list, batch_size=ntestbatch)
epoch_size = len(train_list)
print("epoch size,batch_size:", epoch_size, ntrainbatch)

In [None]:
### define the network

import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T

from torch.utils.checkpoint import checkpoint
from torch_cluster import knn_graph

from torch_geometric.nn import EdgeConv
from torch_geometric.utils import remove_self_loops
from torch_geometric.utils.undirected import to_undirected

# Edge-attribute transform (torch_geometric Cartesian, cat=False).
# NOTE(review): `transform` is never referenced elsewhere in this notebook; candidate for removal.
transform = T.Cartesian(cat=False)



class TauNetwork(nn.Module):
    """Per-node scorer: an input MLP followed by two dynamic EdgeConv layers.

    Before each EdgeConv, a k-nearest-neighbour graph is rebuilt from the current
    node features. Neighbours are restricted to the same graph via ``data.batch``,
    self loops are excluded, and the graph is made undirected. The second layer
    therefore sees neighbourhoods in the learned feature space.

    Args:
        input_dim: number of input node features (columns of ``data.x``).
        hidden_dim: width of the latent node representation.
        output_dim: number of scores produced per node.
        k: neighbours per node passed to ``knn_graph``.
        aggr: EdgeConv neighbourhood aggregation (e.g. 'add').
    """

    def __init__(self, input_dim=24, hidden_dim=50, output_dim=2, k=16, aggr='add'):
        super(TauNetwork, self).__init__()

        self.k = k
        # The edge MLPs consume 2 * hidden_dim features per edge (EdgeConv combines
        # the features of both endpoints).
        start_width = 2 * hidden_dim
        middle_width = 3 * hidden_dim // 2

        ### input net takes to higher dim representation
        self.inputnet = nn.Sequential(
            nn.Linear(input_dim, hidden_dim * 2),
            nn.ELU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ELU(),
        )

        ## nets used in edgeconv (layer order fixed: it defines the state_dict keys)
        convnn1 = nn.Sequential(nn.Linear(start_width, middle_width),
                                nn.ELU(),
                                nn.Dropout(0.2),
                                nn.Linear(middle_width, middle_width),
                                nn.ELU(),
                                nn.Dropout(0.2),
                                nn.Linear(middle_width, middle_width),
                                nn.ELU(),
                                nn.Dropout(0.2),
                                nn.Linear(middle_width, hidden_dim),
                                nn.ELU()
                                )
        # NOTE(review): the final scores pass through ELU (bounded below by -1) before
        # being used as logits; kept as-is for compatibility with saved checkpoints.
        convnn2 = nn.Sequential(nn.Linear(start_width, middle_width),
                                nn.ELU(),
                                nn.Dropout(0.2),
                                nn.Linear(middle_width, output_dim),
                                nn.ELU()
                                )

        self.edgeconv1 = EdgeConv(nn=convnn1, aggr=aggr)
        self.edgeconv2 = EdgeConv(nn=convnn2, aggr=aggr)

    def _knn_edges(self, x, batch, flow):
        """Undirected k-NN edge_index over x (no self loops), within each graph of the batch."""
        return to_undirected(knn_graph(x, self.k, batch, loop=False, flow=flow))

    ### data flow
    def forward(self, data):
        """Return per-node scores with ``output_dim`` columns.

        ``data`` must provide ``x`` and ``batch``. It is left unmodified: the
        embedded features and k-NN edges are kept in locals rather than written
        back to ``data.x`` / ``data.edge_index``, so the same batch can be passed
        to the model again.
        """
        x = self.inputnet(data.x)

        edge_index = self._knn_edges(x, data.batch, self.edgeconv1.flow)
        x = self.edgeconv1(x, edge_index)

        edge_index = self._knn_edges(x, data.batch, self.edgeconv2.flow)
        return self.edgeconv2(x, edge_index)

In [None]:
## LR optimiser

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
import math
import torch
import sys

class ReduceMaxLROnRestart:
    """Restart callback that scales the upper LR multiplier by ``ratio`` on each restart.

    Called by CyclicLRWithRestarts as ``eta_on_restart_cb(eta_min, eta_max)``;
    eta_min is returned unchanged.
    """

    def __init__(self, ratio=0.75):
        self.ratio = ratio

    # __call__ used to be indented inside __init__, which made it a dead local
    # function: instances were not callable, so the "triangular2" policy raised
    # TypeError on its first restart.
    def __call__(self, eta_min, eta_max):
        return eta_min, eta_max * self.ratio
        
        
class ExpReduceMaxLROnIteration:
    """Iteration callback: scales eta_max by gamma ** iterations and passes eta_min through."""

    def __init__(self, gamma=1):
        self.gamma = gamma

    def __call__(self, eta_min, eta_max, iterations):
        decay = self.gamma ** iterations
        return eta_min, eta_max * decay


class CosinePolicy:
    """Cosine multiplier: 1 at the start of a period, falling to 0 at t_cur == restart_period."""

    def __call__(self, t_cur, restart_period):
        phase = math.pi * (t_cur / restart_period)
        return 0.5 * (1. + math.cos(phase))
    
    
class ArccosinePolicy:
    """Arccosine multiplier: maps the period fraction onto [-1, 1] (clamped) and returns acos(.)/pi."""

    def __call__(self, t_cur, restart_period):
        scaled = 2 * t_cur / restart_period - 1
        clamped = max(-1, min(1, scaled))
        return math.acos(clamped) / math.pi
    
    
class TriangularPolicy:
    """Triangular multiplier.

    It rises linearly from 0 to 1 until ``triangular_step * restart_period``,
    then falls linearly back to 0 at ``restart_period``.
    """

    def __init__(self, triangular_step=0.5):
        self.triangular_step = triangular_step

    def __call__(self, t_cur, restart_period):
        peak = self.triangular_step * restart_period
        if t_cur < peak:
            return t_cur / peak
        return 1.0 - (t_cur - peak) / (restart_period - peak)
    
    
class CyclicLRWithRestarts(_LRScheduler):
    """Decays learning rate with cosine annealing, normalizes weight decay
    hyperparameter value, implements restarts.
    https://arxiv.org/abs/1711.05101
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        batch_size: minibatch size
        epoch_size: training samples per epoch
        restart_period: epoch count in the first restart period
        t_mult: multiplication factor by which the next restart period will expand/shrink
        last_epoch: -1 for a fresh run; any other value means "resuming", in which case
            every param group must already carry 'initial_lr'
        verbose: print a message on every restart
        policy: ["cosine", "arccosine", "triangular", "triangular2", "exp_range"]
        policy_fn: optional callable (t_cur, restart_period) -> multiplier; overrides `policy`
        min_lr: minimum allowed learning rate (stored per group as 'minimum_lr')
        eta_on_restart_cb: callback executed on every restart, adjusts max or min lr
        eta_on_iteration_cb: callback executed on every iteration, adjusts max or min lr
        gamma: exponent used in "exp_range" policy
        triangular_step: adjusts ratio of increasing/decreasing phases for triangular policy
    Raises:
        TypeError: if `optimizer` is not a torch Optimizer.
        KeyError: when resuming and a param group lacks 'initial_lr'.
        ValueError: if `policy` is not one of the names above and no `policy_fn` is given.
    Note:
        Weight decay is currently held at each group's initial value; the
        normalisation from the paper is disabled in get_lr.
    Example:
        >>> scheduler = CyclicLRWithRestarts(optimizer, 32, 1024, restart_period=5, t_mult=1.2)
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>         ...
        >>>         optimizer.zero_grad()
        >>>         loss.backward()
        >>>         optimizer.step()
        >>>         scheduler.batch_step()
        >>>     validate(...)
    """

    def __init__(self, optimizer, batch_size, epoch_size, restart_period=100,
                 t_mult=2, last_epoch=-1, verbose=False,
                 policy="cosine", policy_fn=None, min_lr=1e-7,
                 eta_on_restart_cb=None, eta_on_iteration_cb=None,
                 gamma=1.0, triangular_step=0.5):

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))

        self.optimizer = optimizer

        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an"
                                   " optimizer".format(i))
        # 'minimum_lr' is read unconditionally just below. Previously it was only
        # defaulted for fresh runs, so resuming raised KeyError('minimum_lr').
        for group in optimizer.param_groups:
            group.setdefault('minimum_lr', min_lr)

        self.base_lrs = [group['initial_lr'] for group
                         in optimizer.param_groups]

        self.min_lrs = [group['minimum_lr'] for group
                        in optimizer.param_groups]

        self.base_weight_decays = [group['weight_decay'] for group
                                   in optimizer.param_groups]

        self.policy = policy
        self.eta_on_restart_cb = eta_on_restart_cb
        self.eta_on_iteration_cb = eta_on_iteration_cb
        if policy_fn is not None:
            self.policy_fn = policy_fn
        elif self.policy == "cosine":
            self.policy_fn = CosinePolicy()
        elif self.policy == "arccosine":
            self.policy_fn = ArccosinePolicy()
        elif self.policy == "triangular":
            self.policy_fn = TriangularPolicy(triangular_step=triangular_step)
        elif self.policy == "triangular2":
            self.policy_fn = TriangularPolicy(triangular_step=triangular_step)
            self.eta_on_restart_cb = ReduceMaxLROnRestart(ratio=0.5)
        elif self.policy == "exp_range":
            self.policy_fn = TriangularPolicy(triangular_step=triangular_step)
            self.eta_on_iteration_cb = ExpReduceMaxLROnIteration(gamma=gamma)
        else:
            # Fail fast: an unknown name used to leave policy_fn unset and only
            # surfaced as an AttributeError at the first batch_step().
            raise ValueError("Unknown policy {!r}; expected one of 'cosine', 'arccosine', "
                             "'triangular', 'triangular2', 'exp_range', or pass "
                             "policy_fn".format(policy))

        self.last_epoch = last_epoch
        self.batch_size = batch_size
        self.epoch_size = epoch_size

        self.iteration = 0
        self.total_iterations = 0

        self.t_mult = t_mult
        self.verbose = verbose
        self.restart_period = math.ceil(restart_period)
        self.restarts = 0
        self.t_epoch = -1
        self.epoch = -1

        # Multiplier bounds in [0, 1]; the actual lr interpolates min_lr..base_lr.
        self.eta_min = 0
        self.eta_max = 1

        self.end_of_period = False
        self.batch_increments = []
        self._set_batch_increment()

    def _on_restart(self):
        """Let the restart callback (if any) adjust eta_min / eta_max."""
        if self.eta_on_restart_cb is not None:
            self.eta_min, self.eta_max = self.eta_on_restart_cb(self.eta_min,
                                                                self.eta_max)

    def _on_iteration(self):
        """Let the iteration callback (if any) adjust eta_min / eta_max."""
        if self.eta_on_iteration_cb is not None:
            self.eta_min, self.eta_max = self.eta_on_iteration_cb(self.eta_min,
                                                                  self.eta_max,
                                                                  self.total_iterations)

    def get_lr(self, t_cur):
        """Return an iterator of (lr, weight_decay) pairs, one per param group.

        ``t_cur`` is the (fractional) epoch position inside the current restart
        period. This method also advances the restart bookkeeping: once
        ``t_epoch`` reaches ``restart_period``, the period is stretched by
        ``t_mult``, ``t_epoch`` resets to 0 and the restart callback runs.
        """
        eta_t = (self.eta_min + (self.eta_max - self.eta_min)
                 * self.policy_fn(t_cur, self.restart_period))

        lrs = [min_lr + (base_lr - min_lr) * eta_t for base_lr, min_lr
               in zip(self.base_lrs, self.min_lrs)]
        # Weight decay is held at its base value. The paper's normalisation
        # (eta_t * sqrt(batch_size / (epoch_size * restart_period))) is disabled here.
        weight_decays = list(self.base_weight_decays)

        if (self.t_epoch + 1) % self.restart_period < self.t_epoch:
            self.end_of_period = True

        if self.t_epoch % self.restart_period < self.t_epoch:
            if self.verbose:
                print("Restart {} at epoch {}".format(self.restarts + 1,
                                                      self.last_epoch))
            self.restart_period = math.ceil(self.restart_period * self.t_mult)
            self.restarts += 1
            self.t_epoch = 0
            self._on_restart()
            self.end_of_period = False

        return zip(lrs, weight_decays)

    def _set_batch_increment(self):
        """Precompute the fractional-epoch offsets used by successive batch_step() calls."""
        d, r = divmod(self.epoch_size, self.batch_size)
        # One more point than batches, so the last batch of an epoch stays below t_epoch + 1.
        batches_in_epoch = d + 2 if r > 0 else d + 1
        self.iteration = 0
        self.batch_increments = torch.linspace(0, 1, batches_in_epoch).tolist()

    def step(self):
        """Start a new epoch: advance the epoch counters and apply the first batch's lr."""
        self.last_epoch += 1
        self.t_epoch += 1
        self._set_batch_increment()
        self.batch_step()

    def batch_step(self):
        """Advance by one minibatch and write the new lr / weight decay into the optimizer.

        Raises:
            StopIteration: more batches were stepped than the epoch/batch sizes given
                at construction allow.
        """
        try:
            t_cur = self.t_epoch + self.batch_increments[self.iteration]
            self._on_iteration()
            self.iteration += 1
            self.total_iterations += 1
        except (IndexError):
            raise StopIteration("Epoch size and batch size used in the "
                                "training loop and while initializing "
                                "scheduler should be the same.")

        for param_group, (lr, weight_decay) in zip(self.optimizer.param_groups,
                                                   self.get_lr(t_cur)):
            param_group['lr'] = lr
            param_group['weight_decay'] = weight_decay

In [None]:
from torch_geometric.data import DataLoader
from tqdm import tqdm_notebook as tqdm

## model
class Net(nn.Module):
    """Wrapper that holds a TauNetwork (24 input features, hidden width 50, k=4) as ``self.drn``."""

    def __init__(self):
        super(Net, self).__init__()
        # The attribute name "drn" determines the "drn." prefix of the saved state_dict keys.
        self.drn = TauNetwork(input_dim=24, hidden_dim=50, k=4)

    def forward(self, data):
        return self.drn(data)

## device and optimiser
# Fall back to CPU when no GPU is visible, instead of failing on a hard-coded 'cuda:0'.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)
# restart_period counts epochs: train() calls scheduler.step() once per epoch.
scheduler = CyclicLRWithRestarts(optimizer, ntrainbatch, epoch_size, restart_period=100, t_mult=1.2, policy="cosine")


### loss func definitions
mseloss = torch.nn.MSELoss()
# Per-node classification loss used by categorical_loss_mse2. It was referenced but
# never defined in this notebook, so the first training batch raised NameError.
losscat = torch.nn.CrossEntropyLoss()


def categorical_loss_mse(outputa, trutha):
    """MSE between the count of true positives and the count of predicted positives.

    Presumably expects ``outputa`` to hold hard 0/1 predictions aligned with
    ``trutha``. The counts are rebuilt as fresh tensors, so the result carries no
    gradient. NOTE(review): not used anywhere in this notebook.
    """
    n_true_pos = torch.tensor(len(outputa[trutha == 1][outputa[trutha == 1] == 1]), dtype=torch.float32)
    n_pred_pos = torch.tensor(len(outputa[outputa == 1]), dtype=torch.float32)
    return 1.0 * mseloss(n_true_pos, n_pred_pos)


def categorical_loss_mse2(outputa, trutha, weight=0.01):
    """Class-balanced cross-entropy over per-node scores.

    Cross-entropy is averaged separately over nodes with label 1 and nodes with
    label 0, and the two means are summed (each scaled by ``weight``). Both
    classes therefore contribute equally, however many nodes each has.

    Args:
        outputa: (n_nodes, n_classes) scores, as fed to ``losscat``.
        trutha: (n_nodes,) 0/1 labels of any numeric dtype (cast to long).
        weight: scale applied to each per-class term (0.01 by default, as before).
    Returns:
        Scalar loss tensor. A class with no nodes in the batch is skipped rather
        than contributing NaN (the mean over an empty selection).
    """
    total_loss = outputa.new_zeros(())
    for cls in (1, 0):
        mask = trutha == cls
        if mask.any():
            total_loss = total_loss + weight * losscat(outputa[mask], trutha[mask].long())
    return total_loss



### train func
def train(epoch):
    """Run one training epoch over ``trainloader`` and return the mean batch loss.

    The cyclic LR scheduler is stepped once at the start of the epoch
    (``scheduler.step()``) and once after every optimizer step
    (``scheduler.batch_step()``). ``epoch`` is accepted but not used.
    """
    model.train()
    torch.cuda.empty_cache()
    scheduler.step()
    batch_losses = []
    for batch in tqdm(trainloader):
        batch = batch.to(device)
        optimizer.zero_grad()
        batch_loss = categorical_loss_mse2(model(batch), batch.y)
        batch_losses.append(batch_loss.item())
        batch_loss.backward()
        optimizer.step()
        scheduler.batch_step()
    mean_loss = np.mean(np.array(batch_losses))
    print('batches for train:', len(batch_losses))
    print('train loss:', mean_loss)
    return mean_loss

from scipy.stats import norm
import matplotlib.mlab as mlab
import scipy.stats as scs
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
%matplotlib inline

def gaussian(x, mean, a, sigma):
    """Unnormalised Gaussian a * exp(-(x - mean)^2 / (2 sigma^2)); works on scalars or arrays.

    The (x, *params) signature is the form scipy's curve_fit expects.
    """
    exponent = (x - mean) ** 2 / (2 * sigma ** 2)
    return a * np.exp(-exponent)


### eval func
def evaluate(epoch):
        """"Evaluate the model"""
        model.eval()
        pred = []
        true = []
        loss= []
        
        correct = 0
        predc = []
        truec = []
        for data in tqdm(testloader):
            data = data.to(device)        
            result = model(data)
            lossc = categorical_loss_mse2(result, data.y)
            loss.append(lossc.item())

            for i in result:
                predc.append(i.detach().cpu().argmax())
            for i in data.y.detach():
                truec.append(i.detach().cpu())
            
        predc = np.array(predc)
        truec = np.array(truec)
        print("total accuracy  :",np.equal(predc,truec).sum()/len(truec))
        print("predc:",np.unique(predc,return_counts=True))
        print("truec:",np.unique(truec,return_counts=True))
        
        
        print("accuracy class 0 :",np.equal(predc[truec == 0],truec[truec == 0]).sum()/len(truec[truec == 0]))
        print("class 0 predicted as class 1:",len(predc[truec == 0][predc[truec == 0] == 1])/len(truec[truec == 0]))
        print("accuracy class 1 :",np.equal(predc[truec == 1],truec[truec == 1]).sum()/len(truec[truec == 1]))
        print("class 1 predicted as class 0:",len(predc[truec == 1][predc[truec == 1] == 0])/len(truec[truec == 1]))
        print("true# class1 / total# class1:",len(predc[truec == 1][predc[truec == 1] == 1])/len(predc[predc == 1]) )
        print("true# class0 / total# class1:",len(predc[truec == 0][predc[truec == 0] == 1])/len(predc[predc == 1]) )
        
        print('batches for test:', len(loss)) 
        print('test loss:',np.mean(np.array(loss)))

        return np.mean(np.array(loss)),np.equal(predc,truec).sum()/len(truec),np.equal(predc[truec == 0],truec[truec == 0]).sum()/len(truec[truec == 0]),np.equal(predc[truec == 1],truec[truec == 1]).sum()/len(truec[truec == 1]),len(predc[truec == 1][predc[truec == 1] == 1])/len(predc[predc == 1])


In [None]:
from matplotlib.pyplot import figure
from tqdm import tqdm_notebook as tqdm

# Global plot style. Only weight and size are set: 'normal' (used before) is not a real
# font family, so matplotlib warned "Font family ['normal'] not found" and fell back to
# its default family anyway.
font = {'weight': 'bold',
        'size': 13}

plt.rc('axes', labelsize=20)
plt.rc('font', **font)

checkpoint_dir = '/home/llr/cms/sghosh/HGCAL_TICL_STUFF/RECOTESTS_ORGANISED/MODTICLLXPLUS/CMSSW_11_2_0_pre10/src/RecoNtuples/HGCalAnalysis/test/GENGRAPHS/TAUTRAIN/ouput_train_test/'
os.makedirs(checkpoint_dir, exist_ok=True)

N_EPOCHS = 500
PLOT_EVERY = 10  # redraw the history plots every this many epochs

best_loss = float('inf')
losst = []
lossv = []
epochs = []
acc = []
acc0 = []
acc1 = []
pur = []


def _plot_history(x, series):
    """Plot each (values, plot_kwargs) pair against x on one fresh 20x10 figure, then show it."""
    figure(figsize=(20, 10), dpi=80)
    for values, plot_kwargs in series:
        plt.plot(np.array(x), np.array(values), **plot_kwargs)
    plt.legend()
    plt.grid()
    plt.show()


### train!!!  and save models

for epoch in range(1, N_EPOCHS + 1):
    print('epoch:', epoch)
    losst.append(train(epoch))
    loss_epoch, accuracy, accuracy0, accuracy1, purity1 = evaluate(epoch)
    lossv.append(loss_epoch)
    acc.append(accuracy)
    acc0.append(accuracy0)
    acc1.append(accuracy1)
    pur.append(purity1)
    epochs.append(epoch)
    # Named ckpt_state (not `checkpoint`) so it does not shadow
    # torch.utils.checkpoint.checkpoint, imported in the network cell.
    ckpt_state = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(ckpt_state, os.path.join(checkpoint_dir, 'model_epoch_%i.pth.tar' % epoch))
    if loss_epoch < best_loss:
        best_loss = loss_epoch
        print('new best test loss:', best_loss)
        torch.save(ckpt_state, os.path.join(checkpoint_dir, 'model_checkpoint_best.pth.tar'))
    if epoch % PLOT_EVERY == 0:
        _plot_history(epochs, [(losst, dict(c='b', label='training')),
                               (lossv, dict(c='r', label='testing'))])
        _plot_history(epochs, [(acc, dict(label='total acc')),
                               (acc0, dict(label='acc class0(non tau)')),
                               (acc1, dict(label='acc class1(tau)'))])
        _plot_history(epochs, [(pur, dict(label='true class1(tau)/total label class1'))])

In [None]:
### load model
# Load the best checkpoint written by the training cell above. The previous hard-coded
# path pointed at '.../ouput_train/', not at checkpoint_dir ('.../ouput_train_test/'),
# so this cell evaluated a model from a different run than the one just trained.
model_fname = os.path.join(checkpoint_dir, 'model_checkpoint_best.pth.tar')
mdl = Net().to(device)
# map_location lets a checkpoint saved on GPU be restored on the current device (e.g. CPU).
mdl.load_state_dict(torch.load(model_fname, map_location=device)['state_dict'])
mdl.eval()

In [None]:
### load model and do tests
N_EVENTS_TO_PLOT = 100

valloader = torch_geometric.data.DataLoader(data_list_comb[totalev-900:totalev], batch_size=1)

# Plot style is loop-invariant, so set it once. Font family is left at the default:
# 'normal' is not a real family (matplotlib warned and fell back to the default).
plt.rc('axes', labelsize=20)
plt.rc('font', weight='bold', size=13)


def _scatter_clusters(ax, x, mask=None, **scatter_kwargs):
    """Scatter column 0 (eta) vs column 1 (phi) of x, marker size = 10 * column 3.

    When mask is given, only the rows where it is True are drawn.
    """
    if mask is not None:
        x = x[mask]
    ax.scatter(x[:, 0], x[:, 1], s=x[:, 3] * 10, **scatter_kwargs)


count = 0
with torch.no_grad():
    for data in tqdm(valloader):
        count += 1

        # Keep an untouched CPU copy of the raw features for plotting.
        data_uncorr = data.clone()
        data = data.to(device)
        result = mdl(data)

        trutharr = data.y.cpu().numpy()
        resultarr = result.argmax(dim=1).cpu().numpy()
        raw_x = data_uncorr.x

        fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(40, 10))
        _scatter_clusters(ax0, raw_x, c='b', label="all 3dclus")
        _scatter_clusters(ax0, raw_x, resultarr == 1, c='r', label="pred taus")
        _scatter_clusters(ax0, raw_x, trutharr == 1, c='g', marker="*", label="true taus")
        _scatter_clusters(ax1, raw_x, resultarr == 1, c='r', label="pred taus")
        _scatter_clusters(ax1, raw_x, trutharr == 1, c='g', marker="*", label="true taus")
        for ax in (ax0, ax1):
            ax.set_xlabel("eta")
            ax.set_ylabel("phi")
            ax.legend(prop={'size': 20})
        plt.show()
        plt.close(fig)

        if count == N_EVENTS_TO_PLOT:
            break
    

In [None]:
valloader = torch_geometric.data.DataLoader(data_list_comb[totalev-900:totalev], batch_size=1)

trl = []  # per-node truth label (int)
prl = []  # per-node softmax probability of class 1

with torch.no_grad():
    for data in tqdm(valloader):
        data = data.to(device)
        result = mdl(data)
        # Softmax over the class dimension of every node at once. The per-row
        # F.softmax(i) call relied on the deprecated implicit-dim behaviour.
        probs_class1 = F.softmax(result.cpu(), dim=1)[:, 1]
        trl.extend(int(label) for label in data.y.cpu().numpy())
        prl.extend(probs_class1.numpy())


In [None]:
### plot scores
prl = np.array(prl)
trl = np.array(trl)

fig, ax = plt.subplots()
ax.hist(prl[trl == 1], range=[0, 1], bins=100, label="tau", alpha=0.5)
ax.hist(prl[trl == 0], range=[0, 1], bins=100, label="PU", alpha=0.5)
ax.set_yscale("log")
# Axis labels so the figure stands on its own when skimming the notebook.
ax.set_xlabel("softmax probability of class 1 (tau)")
ax.set_ylabel("nodes per bin")
ax.legend()
plt.show()

In [None]:
### and what would life be without ROCs

from sklearn import metrics

# Aliases of the per-node truth labels / class-1 scores collected above.
trll = trl
prll = prl

fpr, tpr, threshold = metrics.roc_curve(trll, prll)
roc_auc = metrics.auc(fpr, tpr)

# Working point: the first ROC point whose signal efficiency (tpr) exceeds 0.99.
# All three values stay 0 when no such point exists.
tprv, tnrv, thrshv = 0, 0, 0
first_above = np.flatnonzero(tpr > 0.99)
if first_above.size > 0:
    wp = first_above[0]
    tprv, tnrv, thrshv = tpr[wp], 1 - fpr[wp], threshold[wp]
    print("tpr|1-fpr|threshold", tprv, tnrv, thrshv)

# method I: plt
import matplotlib.pyplot as plt
plt.title('Receiver Operating Characteristic')
# Signal efficiency on x, background efficiency on y (matching the axis labels below).
plt.plot(tpr, fpr, 'b', label='AUC = %0.2f' % roc_auc)
plt.legend(loc='upper left')
plt.plot([0, 1], [0, 1], 'r--')

plt.xlabel('sig eff')
plt.ylabel('bkg eff')
plt.show()
