<a href="https://colab.research.google.com/github/swarr438/hmumu-rmm/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Check GPU configuration
!nvidia-smi

#Environment Building#

In [None]:
!pip install --upgrade pip --quiet
# !pip install --upgrade numpy==1.24 --quiet
!pip install --upgrade matplotlib --quiet
!pip install --upgrade torchaudio torch==2.1.0  torchvision --quiet
!pip install tensorboard --quiet

!pip install torchmetrics --quiet

%matplotlib inline

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import models

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Load datasets from Google Drive
# or you can upload them manually
from google.colab import drive
from google.colab import files
import os

drive.mount('/content/drive')

In [None]:
#@title Extract datasets

Path_of_folder = '/content/drive/MyDrive/dataset/rmm/' # @param {type:"string"}
Filename = 'dataset_f.zip' # @param {type:"string"}
# os.path.join yields the same path as plain concatenation when the folder ends
# with '/', and still works when the form field is entered without it.
pathdata = os.path.join(Path_of_folder, Filename)

import zipfile
# Unpack the archive into the current working directory; the dataset cells
# below refer to the extracted files through relative paths.
with zipfile.ZipFile(pathdata, 'r') as zip_ref:
  zip_ref.extractall()

#Model Config#

In [None]:
#@title Dataset Class

import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image

'''Datasets & dataloaders'''
class ImageDataset(Dataset):
    """Map-style dataset that pairs image files with labels listed in a CSV.

    The first CSV column is used as the file name (joined onto ``img_dir``)
    and the second column as the label.
    """

    def __init__(self, csv_file, img_dir, transform=None, target_transform=None):
        """
        Arguments:
            csv_file (string): Path to the csv file with annotations.
            img_dir (string): Directory with all the images.
            transform (callable, optional): Applied to every loaded image.
            target_transform (callable, optional): Applied to every label.
        """
        self.transform = transform
        self.target_transform = target_transform
        self.img_dir = img_dir
        self.img_labels = pd.read_csv(csv_file)

    def __len__(self):
        """Number of annotation rows, i.e. number of samples."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at row ``idx``."""
        annotations = self.img_labels
        image = read_image(os.path.join(self.img_dir, annotations.iloc[idx, 0]))
        label = annotations.iloc[idx, 1]
        # Each transform is optional and only applied when one was supplied.
        image = self.transform(image) if self.transform else image
        label = self.target_transform(label) if self.target_transform else label
        return image, label

In [None]:
from sqlalchemy.sql.selectable import SelectLabelStyle
#@title Dataset Loader

import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Subset

# Per-image preprocessing applied by every ImageDataset below.
# Normalize is configured with three means and three stds (all 0.5), so each
# image is expected to have three channels; it computes (x - 0.5) / 0.5.
transform = transforms.Compose(
    [transforms.ToPILImage(),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Three splits, each backed by its own CSV file and image folder
# (paths are relative to the working directory).
training_data = ImageDataset(csv_file='dataset0.csv',img_dir='./dataset0/',transform=transform)
validation_data = ImageDataset(csv_file='dataset1.csv',img_dir='./dataset1/',transform=transform)
test_data = ImageDataset(csv_file='dataset2.csv',img_dir='./dataset2/',transform=transform)
# val_data3 = ImageDataset(csv_file='dataset3.csv',img_dir='./dataset3/',transform=transform)
# val_data4 = ImageDataset(csv_file='dataset4.csv',img_dir='./dataset4/',transform=transform)
# NOTE: the triple-quoted block below is an inert string literal holding an
# alternative split assignment (dataset5/7/6); it is never executed.
'''
training_data = ImageDataset(csv_file='dataset5.csv',img_dir='./dataset5/',transform=transform)
validation_data = ImageDataset(csv_file='dataset7.csv',img_dir='./dataset7/',transform=transform)
test_data = ImageDataset(csv_file='dataset6.csv',img_dir='./dataset6/',transform=transform)
'''
# Loader settings shared by the three DataLoaders below.
batch_size=4096
num_workers=2

# NOTE(review): all three loaders shuffle, including validation and test, so
# the sample order differs between evaluation passes.
train_loader = DataLoader(training_data, batch_size=batch_size,
                          shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(validation_data, batch_size=batch_size,
                        shuffle=True, num_workers=num_workers, pin_memory=True)
test_loader = DataLoader(test_data, batch_size=batch_size,
                         shuffle=True, num_workers=num_workers, pin_memory=True)
# val_loader3 = DataLoader(val_data3, batch_size=batch_size,
#                         shuffle=True, num_workers=num_workers, pin_memory=True)
# val_loader4 = DataLoader(val_data4, batch_size=batch_size,
#                         shuffle=True, num_workers=num_workers, pin_memory=True)

# `global` at module level has no effect; `classes` is a module-level name anyway.
global classes
# Class names indexed by integer label; the trailing comments keep the
# alternative multi-class label sets.
classes = ('bkg','sig')#('bkg','ttH','VH','VBF','ggF')##('bkg','VH','ggF')

def subload(dataloader, step=10, batch_size=4096):
  """Build a shuffled loader over a regularly strided subset of a loader's dataset.

  Arguments:
      dataloader (DataLoader): Loader whose ``.dataset`` is subsampled.
      step (int): Keep every ``step``-th sample, starting at index 0.
          Defaults to 10, the stride that used to be hard-coded.
      batch_size (int): Batch size of the returned loader. Defaults to 4096,
          the value that used to be hard-coded.

  Returns:
      DataLoader over ``Subset(dataset, [0, step, 2*step, ...])`` with
      ``shuffle=True``, ``num_workers=0`` and ``pin_memory=True``.
  """
  dataset = dataloader.dataset
  select = list(range(0, len(dataset), step))
  subs = Subset(dataset, select)
  sub_loader = DataLoader(subs, batch_size=batch_size, shuffle=True,
                          num_workers=0, pin_memory=True)
  return sub_loader

In [None]:
#@title Initialize the ResNet Model

num_classes = 2 #@param{type:'integer'}

model = models.resnet18(weights=None)#'ResNet18_Weights.DEFAULT')
# Size the replacement head from the backbone's own classifier instead of
# hard-coding 512, so swapping in another ResNet variant keeps working.
model.fc = nn.Linear(model.fc.in_features, out_features=num_classes)
model = model.to(device)

FOUND_LR = 1e-3

# Layer-wise learning rates: earlier layers get a smaller rate.
# NOTE: this list is only consumed by the commented-out Adam optimizer below;
# the active SGD optimizer is built from model.parameters() with one rate.
params = [
          {'params': model.conv1.parameters(), 'lr': FOUND_LR / 10},
          {'params': model.bn1.parameters(), 'lr': FOUND_LR / 10},
          {'params': model.layer1.parameters(), 'lr': FOUND_LR / 8},
          {'params': model.layer2.parameters(), 'lr': FOUND_LR / 6},
          {'params': model.layer3.parameters(), 'lr': FOUND_LR / 4},
          {'params': model.layer4.parameters(), 'lr': FOUND_LR / 2},
          {'params': model.fc.parameters()}
         ]

#optimizer = torch.optim.Adam(params, lr = FOUND_LR)
optimizer = torch.optim.SGD(model.parameters(), lr=FOUND_LR, momentum=0.8)

# One entry per optimizer parameter group (a single group for the SGD above).
MAX_LRS = [p['lr'] for p in optimizer.param_groups]

criterion = nn.CrossEntropyLoss()

def count_parameters(model):
    """Return the total number of elements in the trainable parameters of ``model``.

    Parameters with ``requires_grad`` set to False are not counted.
    """
    trainable = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            trainable += tensor.numel()
    return trainable

print(f'The model has {count_parameters(model):,} trainable parameters')

#Training#

In [None]:
#@title Train Functions

from torch.utils.tensorboard import SummaryWriter
import time
!mkdir models

def epoch_time(start_time, end_time):
    """Split the span between two timestamps (in seconds) into whole minutes and seconds.

    Returns:
        (minutes, seconds) as integers, both truncated toward zero.
    """
    span = end_time - start_time
    minutes = int(span / 60)
    seconds = int(span - (minutes * 60))
    return minutes, seconds

def testpf(loader):
   # corresponding test performance
   model.eval()
   with torch.no_grad():
        total = 0
        correct = 0
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            total += labels.size(0)
            losst = criterion(outputs, labels)
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()

        return losst.item(), total, 100*correct/total

def train(train_loader, test_loader, criterion, optimizer, scheduler,
          n_epochs=50, start_epoch=0):
    """Train the global ``model`` and log per-epoch metrics to TensorBoard.

    After each epoch, ``testpf`` is run on both loaders. The results are
    written to the 'Loss' and 'Accuracy' scalar groups, and the loss on
    ``test_loader`` is passed to ``scheduler.step`` (as ReduceLROnPlateau
    expects). Every 10th epoch a checkpoint is written to
    ``/content/models/<epoch>.pt``.

    Arguments:
        train_loader: Loader iterated for the optimisation steps.
        test_loader: Loader used for the held-out metrics and the scheduler.
        criterion: Loss function applied to ``(outputs, labels)``.
        optimizer: Optimizer stepping the parameters of ``model``.
        scheduler: LR scheduler whose ``step`` accepts a metric value.
        n_epochs (int): Number of epochs to run in this call.
        start_epoch (int): Offset added to the epoch index for logging and
            checkpoint names, used when resuming.

    Returns:
        (total_epoch, loss): the 0-based index of the last completed epoch
        and the loss tensor of the last training batch. When no epoch runs,
        this is ``(start_epoch - 1, None)``.
    """
    writer = SummaryWriter()
    # Bound up-front so the return below is valid even when n_epochs == 0.
    total_epoch = start_epoch - 1
    loss = None
    try:
        for epoch in range(n_epochs):
            start_time = time.monotonic()

            # testpf() leaves the model in eval mode, so switch back once
            # per epoch rather than on every batch.
            model.train()
            for images, labels in train_loader:
                # move images and labels to the configured device
                images = images.to(device)
                labels = labels.to(device)

                # forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)

                # Backward and optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            lossa, ta, crta = testpf(train_loader)
            losst, tt, crtt = testpf(test_loader)
            total_epoch = start_epoch + epoch
            writer.add_scalars('Loss', {'train': lossa, 'test': losst}, total_epoch)
            writer.add_scalars('Accuracy', {'train': crta, 'test': crtt}, total_epoch)
            writer.flush()

            scheduler.step(losst)

            end_time = time.monotonic()
            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            # Loss of the last training batch; nan if the loader was empty.
            last_loss = loss.item() if loss is not None else float('nan')
            print('Epoch [{}/{}], {}m {}s, Loss: {:.4f}, '
                  .format(total_epoch+1, start_epoch+n_epochs,
                          epoch_mins, epoch_secs,
                          last_loss))
            print('Accuracy of the network on the {} test images: {} %'
                  .format(tt, crtt))

            # autosave
            if (epoch+1) % 10 == 0:
                PATH = f"/content/models/{total_epoch+1}.pt"
                # Do not rely on the working directory of the `!mkdir models` call.
                os.makedirs(os.path.dirname(PATH), exist_ok=True)
                torch.save({
                    'epoch': total_epoch+1,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr': [group['lr'] for group in optimizer.param_groups],
                    'loss': loss
                    }, PATH)
    finally:
        # Release the event file even if training is interrupted.
        writer.close()
    return total_epoch, loss

In [None]:
from IPython.core.inputtransformer2 import show_linewise_tokens
#@title Check Datasets
import matplotlib.pyplot as plt
import numpy as np
import torchvision

# functions to show an image


def imshow(img):
    """Display a channel-first image tensor with matplotlib.

    The tensor is first un-normalised with ``img / 2 + 0.5`` and then
    rearranged from (C, H, W) to (H, W, C) for ``plt.imshow``.
    """
    restored = img / 2 + 0.5     # unnormalize
    channels_last = np.transpose(restored.numpy(), (1, 2, 0))
    plt.imshow(channels_last)
    plt.show()


# get some random training images
dataiter = iter(train_loader)
images, labels = next(dataiter)

# show images
# Cap at the batch length so a batch shorter than 16 cannot overrun the
# label lookup below.
shown = min(16, len(images))
imshow(torchvision.utils.make_grid(images[:shown]))
# print labels
print(' '.join(f'{classes[int(labels[j])]:5s}' for j in range(shown)))

In [None]:
#@title Process

%reload_ext tensorboard
%tensorboard --logdir=runs

from torch.optim.lr_scheduler import ReduceLROnPlateau, OneCycleLR

n_epochs = 30 # @param{type:'integer'}
#@markdown For resuming, load the checkpoint first
resume = False #@param{type:"boolean"}
start_epoch = 0
if resume:
  try:
    # `epoch` is set by the checkpoint cell when a checkpoint has been loaded.
    start_epoch = epoch
  except NameError:
    # Say so instead of silently restarting the epoch count.
    print('resume requested but no checkpoint was loaded; starting from epoch 0')
    start_epoch = 0

# Only needed by the commented-out OneCycleLR schedule below.
EPOCHS = n_epochs
STEPS_PER_EPOCH = len(train_loader)
TOTAL_STEPS = EPOCHS * STEPS_PER_EPOCH

# scheduler = OneCycleLR(optimizer, max_lr = MAX_LRS,
#                        total_steps = TOTAL_STEPS)

scheduler = ReduceLROnPlateau(optimizer, 'min', patience=3, cooldown=3,
                              min_lr=1e-10)

# Fit on the training split and monitor / schedule on a 1-in-10 subsample of
# the validation split. The previous call passed test_loader as the training
# data, which contaminated the later evaluation on test_loader.
total_epoch, loss = train(train_loader, subload(val_loader),
                          criterion, optimizer, scheduler,
                          n_epochs, start_epoch=start_epoch)

In [None]:
#@title Save or Load the Checkpoint from Google Drive


ckpt_name  = 'rmm_f2.pt' # @param{type:"string"}
# os.path.join also works when Path_of_folder lacks its trailing slash.
PATH = os.path.join(Path_of_folder, ckpt_name)
#@markdown Initialize the resnet model before loading
load = False # @param {type:"boolean"}
if not load:
  # Save: record the epoch count reached by the training cell.
  epoch = start_epoch + n_epochs
  torch.save({
              'epoch': epoch,#total_epoch+1,
              'model': model.state_dict(),
              'optimizer': optimizer.state_dict(),
              'lr': [ group['lr'] for group in optimizer.param_groups ],
              'loss': loss
              }, PATH)
else:
  # map_location lets a checkpoint written on a GPU runtime be restored on a
  # CPU-only runtime (and vice versa).
  checkpoint = torch.load(PATH, map_location=device)
  model.load_state_dict(checkpoint['model'])
  optimizer.load_state_dict(checkpoint['optimizer'])
  epoch = checkpoint['epoch']
  loss = checkpoint['loss']

#Evaluation#

In [None]:
import torchmetrics

def score(model,loader,num_classes=2):
    """Collect softmax scores over ``loader`` and derive ROC and accuracy metrics.

    Uses the module-level ``device`` and the ``torchmetrics`` package.

    Arguments:
        model: Network returning one logit per class.
        loader: DataLoader; ``len(loader.dataset)`` sizes the result buffers.
        num_classes (int): Number of output classes.

    Returns:
        (roc, acc, ave_acc):
          * roc: output of ``torchmetrics.ROC(task="multiclass",
            thresholds=100)``. Downstream code indexes it as
            (fpr, tpr, thresholds).
          * acc: tensor of shape (num_classes, 100) with accuracy in percent
            at each of 100 thresholds in [0, 1]. In the binary case, row c
            treats class c as the positive class. In the multiclass case the
            per-class accuracy does not depend on the threshold, so every
            column is the same.
          * ave_acc: macro-averaged multiclass accuracy in percent (float).
    """
    with torch.no_grad():
        model.eval()
        n_events = len(loader.dataset)
        real_res = torch.zeros(n_events, dtype=torch.long)
        pred_res = torch.zeros(n_events, num_classes)
        i = 0
        for images, labels in loader:
            outputs = model(images.to(device))
            interval = len(labels)

            # Slice assignment copies the values into the CPU buffers.
            real_res[i:i+interval] = labels
            pred_res[i:i+interval] = torch.nn.functional.softmax(outputs, dim=1)

            i += interval

        # accuracy
        thre_range = np.linspace(0,1,100)
        acc = torch.zeros(len(thre_range), num_classes)
        if num_classes != 2:
            # The multiclass metric ignores the threshold: compute it once
            # and broadcast it to every row instead of recomputing 100 times.
            accuracy = torchmetrics.Accuracy(task="multiclass",
                                             num_classes=num_classes, average=None)
            acc[:] = 100*accuracy(pred_res,real_res)
        else:
            for k, threshold in enumerate(thre_range):
                accuracy = torchmetrics.Accuracy(task="binary", threshold=threshold)
                acc[k][0] = 100*accuracy(pred_res[:,0],1-real_res)
                acc[k][1] = 100*accuracy(pred_res[:,1],real_res)

        accuracy = torchmetrics.Accuracy(task="multiclass",
                    num_classes=num_classes, average='macro')
        ave_acc = 100*accuracy(pred_res,real_res)

        print(f'Total Events: {i}')

        roc = torchmetrics.ROC(task="multiclass",thresholds=100,num_classes=num_classes)

        return roc(pred_res,real_res), torch.transpose(acc,0,1), ave_acc.item()

In [None]:
roc, acc, ave_acc = score(model,test_loader,num_classes=num_classes)

# Bring every curve back to the host as a NumPy array for plotting.
fpr, tpr, thre = (curve.cpu().detach().numpy() for curve in roc[:3])
acc = acc.cpu().detach().numpy()

In [None]:
import matplotlib.pyplot as plt

# Class whose one-vs-rest ROC curve is drawn.
classindex = 1

fig = plt.figure(figsize=plt.figaspect(0.4))

if num_classes==2:
  classes = ('bkg','sig')
elif num_classes==5:
  classes = ('bkg','ttH','VH','VBF','ggF')

# 3d
ax = fig.add_subplot(1, 2, 2, projection='3d')
# for classindex in range(num_classes):
ax.plot3D(fpr[classindex], tpr[classindex], thre, color='orange')
ax.set_xlabel('FPR')
ax.set_ylabel('TPR')
ax.set_zlabel('Threshold')
ax.set_box_aspect(aspect=None, zoom=0.8)

# 2d
ax = fig.add_subplot(1, 2, 1)
#for classindex in range(num_classes):
# np.trapz is deprecated in NumPy 2 in favour of np.trapezoid; use whichever exists.
_trapezoid = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz
auc = _trapezoid(tpr[classindex],x=fpr[classindex])
# acc holds accuracies, so the legend says "Accuracy" (it used to say "Precision").
ax.plot(fpr[classindex],tpr[classindex], color='orange',
        label=f'{classes[classindex]}, AUC = {auc:.2f}, Accuracy = {acc[classindex][50]:.1f}%')
# Index 50 of the 100 thresholds is the bin next to 0.5.
# Named mark_x / mark_y so the builtin `any` is no longer shadowed.
mark_x = fpr[classindex][50]
mark_y = tpr[classindex][50]
# Drawn once, with the legend label (the point used to be plotted twice).
ax.plot(mark_x, mark_y,'o',color='red',label='Points at threshold = 0.5')

ax.set_xlabel('FPR')
ax.set_ylabel('TPR')
plt.title('ROC Curve')
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
# Raw strings: same text, without the invalid "\m" and "\%" escape sequences.
ax.text(0.05, 0.95, r'$\mathrm{Accuracy}='+r'{:.2f}\%$'.format(ave_acc), transform=ax.transAxes, fontsize=14,
        verticalalignment='top', bbox=props)

# dashed line
ax.plot(np.linspace(0,1),np.linspace(0,1),'--',color='k')
lp = np.array((.5, .5))
th2 = ax.text(*lp, 'Random classifier', horizontalalignment='center',
              verticalalignment='center', color='k',
              fontsize=12, rotation=45, rotation_mode='anchor',
              transform_rotates_text=True)
plt.legend(loc='lower right')
plt.show()

In [None]:
from matplotlib.lines import lineStyles
fig, ax = plt.subplots()
classname = ['bkg','sig']
# Only the signal class (index 1) is drawn.
for classindex in (1,):
  accx = acc[classindex]
  thre_range = np.linspace(0, 1, len(accx))
  maxp = np.argmax(accx)
  best_thr, best_acc = thre_range[maxp], accx[maxp]

  # Accuracy curve, its maximum, and a dashed reference line at threshold 0.5
  ax.plot(thre_range, accx, color='orange', label=classname[classindex])
  ax.plot(best_thr, best_acc, 'o', color='red')
  ax.vlines(0.5, 30, 75, 'k', linestyles='dashed')
  ax.annotate(f'Max: ${best_acc:.2f}$', xy=(best_thr, best_acc),
              xytext=(best_thr, best_acc-3), color='red')
  ax.set_xlabel('Threshold')
  ax.set_ylabel('Accuracy(%)')

#plt.legend()
plt.show()

In [None]:
from collections import Counter

def decompose(loader, classindex):
    """Print how samples whose true label is ``classindex`` were classified.

    Uses the module-level ``model``, ``device`` and ``classes``. For every
    predicted class that occurs, one line ``<class name>:<share>%`` is
    printed, in order of first occurrence. If the loader holds no sample
    with the requested label, a short notice is printed instead.

    Arguments:
        loader: Iterable yielding ``(images, labels)`` batches.
        classindex (int): True label to select.
    """
    model.eval()
    selected = []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)

            _, preds = torch.max(outputs, 1)
            # A boolean mask keeps the predictions of the requested label in
            # their original order, with no per-sample Python loop.
            selected.append(preds[labels == classindex].cpu())

    print(f"{'label '+classes[classindex]:*^20}")
    # Concatenate once at the end instead of growing a tensor per sample.
    x = torch.cat(selected).tolist() if selected else []
    k = len(x)
    if k == 0:
        print('no events with this label')
        return
    for nam, numb in Counter(x).items():
        namin = int(nam)
        print(f'{classes[namin]}:{100*(numb/k):.2f}%')

for q in range(num_classes):
  decompose(val_loader,q)
  print('-'*30)
