In [None]:
# Importing Libraries that are necessary

import cv2 as cv
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset,DataLoader,random_split
from matplotlib import pyplot as plt
from torchvision import transforms
from torchsummary import summary
from dataset1 import TrainData
import torch.optim as optim
from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix,classification_report

In [None]:
# TensorBoard writer: per-epoch train/val losses and accuracies are logged under runs/task1.

writer = SummaryWriter(log_dir='runs/task1')

In [None]:


# Run on the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Model Architecture
 

In [None]:
'''

ResNet model was used: a modified ResNet-50 whose four intermediate stages contain 1, 2, 2 and 2
bottleneck blocks instead of the 3, 4, 6 and 3 blocks of the original ResNet-50.

'''

'\n\nResNet model was used: a modified ResNet-50 whose four intermediate stages contain 1, 2, 2 and 2\nbottleneck blocks instead of the 3, 4, 6 and 3 blocks of the original ResNet-50.\n\n'

In [None]:

'''
  Bottleneck residual block (1x1 -> 3x3 -> 1x1 convolutions) used in every stage of the network.
  Its identity (skip) connection, the core idea of ResNets, adds the block input to its output.
'''

class InterMediateBlock(nn.Module):
    """Bottleneck residual block: 1x1 conv -> 3x3 conv (strided) -> 1x1 conv, plus a skip path.

    The block maps ``in_channels`` input channels to ``inter_channels * expansion`` output
    channels. The block input (passed through ``identity_connection`` first, if given) is
    added onto the output of the last BatchNorm, and the sum goes through a final ReLU.

    Args:
        in_channels: channel count of the incoming feature map.
        inter_channels: width of the two inner convolutions.
        identity_connection: optional module applied to the skip path. It must be supplied
            whenever ``stride != 1`` or ``in_channels != inter_channels * expansion``,
            otherwise the addition in ``forward`` fails on mismatched shapes.
        stride: stride of the 3x3 convolution (2 halves the spatial size, rounding up).
    """

    # Output channels = inter_channels * expansion. Declared on the class so model builders
    # can read it without an instance; it is still reachable as ``self.expansion``.
    expansion = 4

    def __init__(self, in_channels, inter_channels, identity_connection=None, stride=1):
        super(InterMediateBlock, self).__init__()

        # 1st layer: 1x1 conv down to the bottleneck width.
        # The convolutions carry no bias because each one is followed by a BatchNorm.
        self.conv1 = nn.Conv2d(
            in_channels, inter_channels, kernel_size=1, stride=1, padding=0, bias=False
        )

        # 2nd layer: 3x3 conv; the only layer of the main branch that applies the stride.
        self.conv2 = nn.Conv2d(
            inter_channels, inter_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )

        # 3rd layer: 1x1 conv expanding to inter_channels * expansion.
        self.conv3 = nn.Conv2d(
            inter_channels, inter_channels * self.expansion,
            kernel_size=1, stride=1, padding=0, bias=False
        )

        self.bn1 = nn.BatchNorm2d(inter_channels)
        self.bn2 = nn.BatchNorm2d(inter_channels)
        self.bn3 = nn.BatchNorm2d(inter_channels * self.expansion)

        self.relu = nn.ReLU()

        # Identity Connection: None means the input is added back unchanged.
        self.identity_connection = identity_connection

    def forward(self, x):
        """Apply the bottleneck branch and add the (optionally projected) input back."""
        # No copy needed: the main branch only reads x and produces new tensors, and
        # the in-place add below modifies the branch output, never the input.
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Project the skip path when the main branch changed the shape (stride/channels).
        if self.identity_connection is not None:
            identity = self.identity_connection(identity)

        out += identity
        return self.relu(out)

    def initialize_weights(self):
        """Kaiming-uniform init for convolutions; BatchNorm scale 1 and shift 0.

        The ``nn.Linear`` branch never matches inside this block (it has no Linear layer);
        it mirrors the model-level initializer.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight)

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)



In [None]:
"""
  The complete ResNet model
"""

class ResNet(nn.Module):
    """Bottleneck ResNet: 7x7 stem, four residual stages, global average pool, linear head.

    Args:
        block: residual block class, called as
            ``block(in_channels, inter_channels, identity_connection, stride=stride)`` and
            expected to emit ``inter_channels * expansion`` channels. ``expansion`` is read
            from the class and defaults to 4 when the class does not declare it.
        no_layers: number of blocks in each of the four stages, e.g. ``[1, 2, 2, 2]``.
        channels: number of channels of the input images.
        num_classes: number of output logits.

    NOTE(review): each stage now contains exactly ``no_layers[i]`` blocks. Checkpoints
    saved by the earlier version (which built one extra block per stage) will not load
    into this model with ``strict=True``; retrain or rebuild them.
    """

    def __init__(self, block, no_layers, channels, num_classes):
        super(ResNet, self).__init__()
        # Running channel count; updated by self.layer() as stages are appended.
        self.in_channels = 64

        # Stem: strided 7x7 conv followed by a strided max-pool.
        self.conv1 = nn.Conv2d(channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages; every stage after the first uses stride 2 in its first block.
        self.layer1 = self.layer(block, no_layers[0], 64, 1)
        self.layer2 = self.layer(block, no_layers[1], 128, 2)
        self.layer3 = self.layer(block, no_layers[2], 256, 2)
        self.layer4 = self.layer(block, no_layers[3], 512, 2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # After the last stage self.in_channels is 512 * expansion (2048 for expansion 4).
        self.fc = nn.Linear(self.in_channels, num_classes)
        # Not used in forward(): the model returns raw logits, which is what
        # nn.CrossEntropyLoss expects (it applies log-softmax itself).
        self.softmax = nn.Softmax(dim=1)
        # self.initialize_weights()

    def layer(self, block, no_layer_blocks, out_channels, stride):
        """Build one stage holding exactly ``no_layer_blocks`` residual blocks.

        Only the first block may change the stride or channel count, so it alone gets a
        1x1 conv + BatchNorm projection on its skip path when shapes differ.

        Raises:
            ValueError: if ``no_layer_blocks`` is smaller than 1.
        """
        if no_layer_blocks < 1:
            raise ValueError(f'each stage needs at least one block, got {no_layer_blocks}')

        expansion = getattr(block, 'expansion', 4)
        stage_channels = out_channels * expansion

        identity_connection = None
        if stride != 1 or self.in_channels != stage_channels:
            identity_connection = nn.Sequential(
                nn.Conv2d(self.in_channels, stage_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(stage_channels),
            )

        layers = [block(self.in_channels, out_channels, identity_connection, stride=stride)]
        self.in_channels = stage_channels

        # The first block is already in the list, so only no_layer_blocks - 1 remain.
        for _ in range(1, no_layer_blocks):
            layers.append(block(self.in_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes) for a batch of images."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

    def initialize_weights(self):
        """Kaiming-uniform init for conv/linear weights, zero biases, BatchNorm scale 1."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight)

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)
            

In [None]:


def create_model(img_channel=1, layers=(1, 1, 1, 1), num_classes=10):
    """Build a bottleneck ResNet whose stages are made of ``InterMediateBlock``.

    Args:
        img_channel: number of input image channels (1 for grayscale).
        layers: number of blocks in each of the four stages; any indexable sequence of
            four ints. The default is a tuple so no mutable object is shared between calls.
        num_classes: number of output logits.

    Returns:
        A new ``ResNet``; its parameters live on PyTorch's default device, so call
        ``.to(device)`` as needed.
    """
    return ResNet(InterMediateBlock, layers, img_channel, num_classes)


# Dataset Preparation

In [None]:
# Preprocessing pipeline handed to TrainData below:
#   to PIL image -> resize to (300, 400) -> center-crop to (210, 280) -> tensor.
# torchvision sizes are given as (height, width).
# NOTE(review): ToPILImage expects a tensor or ndarray image; confirm that is what
# TrainData (dataset1.py) passes to the transform.
_preprocess_steps = [
    transforms.ToPILImage(),
    transforms.Resize((300, 400)),
    transforms.CenterCrop((210, 280)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)

In [None]:

# Split the images into train/validation subsets. The validation size is fixed and the
# training split takes the remainder (2200/280 for the 2480-image set used so far), so
# the cell keeps working if the dataset size changes. A seeded generator makes the
# split identical across kernel restarts.
VAL_SIZE = 280
SPLIT_SEED = 42

dataset = TrainData(root='./train', transform=transform)
train_size = len(dataset) - VAL_SIZE
if train_size <= 0:
    raise ValueError(f'dataset has {len(dataset)} samples; need more than VAL_SIZE={VAL_SIZE}')

train_data, val_data = random_split(
    dataset,
    [train_size, VAL_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
# Validation metrics do not depend on order, so keep it deterministic.
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)

# Training setup

In [None]:
# 62 output classes, labelled by their integer index as strings: '0' ... '61'.
num_classes = 62
classes = list(map(str, range(num_classes)))

In [None]:


def check_accuracy_probs_preds_loss(loader, loss_fn, model, mode='train', global_step=None):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        loader: iterable of ``(inputs, labels)`` batches.
        loss_fn: criterion called as ``loss_fn(scores, labels)``. Assumed to return the
            batch mean (the ``nn.CrossEntropyLoss`` default); it is re-weighted by batch
            size so a short final batch is not over-counted.
        model: classifier; its scores are assumed to have shape (batch, num_classes).
            Inputs are moved to the device of the model's parameters (CPU if it has none).
        mode, global_step: unused; kept so existing calls keep working.

    Returns:
        ``(acc, probs, preds, mean_loss)``:
        - ``acc``: accuracy as a 0-d tensor;
        - ``probs``: softmax probabilities of shape (N, num_classes);
        - ``preds``: argmax class indices of shape (N,);
        - ``mean_loss``: per-sample mean loss as a Python float.

    The model runs in eval mode and is put back into its previous train/eval mode
    afterwards, even if an exception is raised.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    num_correct = 0
    num_samples = 0
    loss_sum = 0.0
    prob_batches = []
    pred_batches = []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device=run_device)
                y = y.to(device=run_device)

                scores = model(x)
                batch_size = y.size(0)
                # Undo the batch-mean so the final average is per sample.
                loss_sum += loss_fn(scores, y).item() * batch_size

                _, predictions = scores.max(1)
                num_correct += (predictions == y).sum()
                num_samples += batch_size

                prob_batches.append(F.softmax(scores, dim=1))
                pred_batches.append(predictions)
    finally:
        model.train(was_training)

    if num_samples == 0:
        raise ValueError('loader yielded no samples')

    acc = num_correct / num_samples
    return acc, torch.cat(prob_batches), torch.cat(pred_batches), loss_sum / num_samples

In [None]:
learning_rate = 3e-5
num_epochs = 60
SEED = 0

# Seed weight initialisation and DataLoader shuffling so runs are repeatable.
torch.manual_seed(SEED)

# Stage depths [1, 2, 2, 2]; the head size comes from num_classes so it always matches
# the `classes` list defined above.
model = create_model(img_channel=1, layers=[1, 2, 2, 2], num_classes=num_classes).to(device)

# Loss and optimizer.
# CrossEntropyLoss takes raw logits (it applies log-softmax internally).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
# Multiply the learning rate by 0.9 every 2 scheduler steps (stepped once per epoch).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

In [None]:
def save_chkpt(state, filename='model.pth.tar'):
    """Announce the save on stdout, then serialize ``state`` to ``filename`` via torch.save."""
    banner = '===== Saving Checkpoint ====='
    print(banner)
    torch.save(state, filename)


In [None]:
def load_chkpoint(chkpt, target=None):
    """Load the weights stored under ``chkpt['state_dict']`` into a model.

    Args:
        chkpt: checkpoint mapping such as the one saved during training; only its
            ``'state_dict'`` entry is read.
        target: module to load into. Defaults to the notebook-global ``model`` (the
            original behaviour); pass a module explicitly to avoid relying on hidden
            kernel state.

    Raises:
        KeyError: if ``chkpt`` has no ``'state_dict'`` entry.
        RuntimeError: from ``load_state_dict`` when keys or shapes do not match.
    """
    print('===== Loading Checkpoint =====')
    net = model if target is None else target
    net.load_state_dict(chkpt['state_dict'])
  


# Training loop

In [None]:

# Per-epoch histories, collected for later inspection/plotting.
tot_train_loss = []
tot_train_acc = []
tot_val_loss = []
tot_val_acc = []
# Best validation accuracy so far; any real accuracy (>= 0) beats this sentinel.
best_acc = float('-inf')


# Train Network
for epoch in range(num_epochs):
    losses = []
    num_correct = 0
    num_samples = 0
    loop = tqdm(train_loader, desc=f'Epoch [{epoch+1}/{num_epochs}]')

    for data, targets in loop:
        # Get data to cuda if possible
        data = data.to(device=device)
        targets = targets.to(device=device)

        # forward
        scores = model(data)
        loss = criterion(scores, targets)

        # Running training accuracy, measured on the fly in train mode.
        _, predictions = scores.max(1)
        num_correct += (predictions == targets).sum()
        num_samples += predictions.size(0)

        batch_loss = loss.item()
        losses.append(batch_loss)

        # backward + Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loop.set_postfix(loss=batch_loss, train_acc=(num_correct * 100 / num_samples).item())

    # StepLR counts epochs here: one scheduler step per pass over the training data.
    scheduler.step()

    val_acc, val_probs, val_preds, val_loss = check_accuracy_probs_preds_loss(val_loader, criterion, model, 'val', epoch)
    train_acc = num_correct / num_samples
    train_loss = torch.tensor(losses).mean()
    print('Val_acc: {:0.2f} Val_loss: {:0.2f}'.format(val_acc * 100, val_loss))

    writer.add_scalar('Loss/Train', train_loss, epoch)
    writer.add_scalar('Loss/Val', val_loss, epoch)
    writer.add_scalar('Acc/Train', train_acc, epoch)
    writer.add_scalar('Acc/Val', val_acc, epoch)

    tot_train_loss.append(train_loss)
    tot_train_acc.append(train_acc)
    tot_val_loss.append(val_loss)
    tot_val_acc.append(val_acc)

    # Keep only the checkpoint with the best validation accuracy. Optimizer/scheduler
    # state and the epoch are saved too, so training can be resumed from it.
    if val_acc > best_acc:
        best_acc = val_acc
        chkpt = {
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'epoch': epoch,
            'val_acc': best_acc,
            'train_acc': train_acc,
        }
        save_chkpt(chkpt, 'part1_trained_model.pth.tar')






HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 4.29 Val_loss: 3.89
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 6.07 Val_loss: 3.63
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 14.64 Val_loss: 3.30
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 22.86 Val_loss: 2.92
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 38.93 Val_loss: 2.54
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 45.71 Val_loss: 2.16
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 60.71 Val_loss: 1.81
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 62.86 Val_loss: 1.66
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 67.14 Val_loss: 1.44
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 70.71 Val_loss: 1.26
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 72.86 Val_loss: 1.19
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 72.50 Val_loss: 1.12


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 78.21 Val_loss: 1.00
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 77.50 Val_loss: 0.89


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 80.71 Val_loss: 0.83
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 80.71 Val_loss: 0.80


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 76.79 Val_loss: 0.85


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 81.07 Val_loss: 0.79
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 82.50 Val_loss: 0.71
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 81.43 Val_loss: 0.73


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 83.93 Val_loss: 0.67
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 82.86 Val_loss: 0.65


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 82.14 Val_loss: 0.64


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 82.86 Val_loss: 0.65


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 84.64 Val_loss: 0.58
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 82.86 Val_loss: 0.61


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 86.43 Val_loss: 0.59
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 84.29 Val_loss: 0.58


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 85.36 Val_loss: 0.58


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 84.64 Val_loss: 0.59


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 84.64 Val_loss: 0.54


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 86.43 Val_loss: 0.54


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 83.93 Val_loss: 0.56


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 87.86 Val_loss: 0.50
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 86.43 Val_loss: 0.52


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 83.93 Val_loss: 0.54


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 83.57 Val_loss: 0.53


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 85.71 Val_loss: 0.54


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 85.00 Val_loss: 0.49


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 85.00 Val_loss: 0.52


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 85.71 Val_loss: 0.52


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 83.57 Val_loss: 0.52


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 85.00 Val_loss: 0.50


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 84.29 Val_loss: 0.49


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 85.00 Val_loss: 0.49


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 85.00 Val_loss: 0.48


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 86.07 Val_loss: 0.47


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 84.64 Val_loss: 0.53


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 85.36 Val_loss: 0.50


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 83.21 Val_loss: 0.50


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 87.50 Val_loss: 0.49


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 87.86 Val_loss: 0.48


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 87.14 Val_loss: 0.50


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 86.79 Val_loss: 0.51


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 85.36 Val_loss: 0.48


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 85.36 Val_loss: 0.48


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 85.36 Val_loss: 0.50


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 87.14 Val_loss: 0.47


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 86.79 Val_loss: 0.45


HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))


Val_acc: 86.07 Val_loss: 0.48


# Evaluation on the validation split

In [None]:
# Rebuild the architecture used for training (same stage depths and class count) so the
# saved weights fit it exactly.
model = create_model(img_channel=1, layers=[1, 2, 2, 2], num_classes=num_classes).to(device)

In [None]:
# map_location remaps tensors saved on another device (e.g. a GPU run) onto `device`.
# torch.load unpickles the file, so only load checkpoints produced by this notebook.
load_chkpoint(
    chkpt=torch.load(f='./part1_trained_model.pth.tar', map_location=device),
)

===== Loading Checkpoint =====


In [None]:
def ret_preds_labels(model, loader):
    """Run ``model`` over ``loader`` and collect predicted and true class indices.

    Args:
        model: classifier; its scores are assumed to have shape (batch, num_classes).
            Inputs are moved to the device of the model's parameters (CPU if it has none).
        loader: iterable of ``(inputs, labels)`` batches.

    Returns:
        ``(predictions, labels)``: two 1-D tensors aligned sample by sample, both on the
        model's device.

    The model runs in eval mode and is put back into its previous train/eval mode
    afterwards, even if an exception is raised.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    pred_batches = []
    label_batches = []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                scores = model(x.to(device=run_device))
                _, batch_preds = torch.max(scores, 1)
                pred_batches.append(batch_preds)
                label_batches.append(y.to(device=run_device))
    finally:
        model.train(was_training)

    if not pred_batches:
        raise ValueError('loader yielded no batches')
    return torch.cat(pred_batches), torch.cat(label_batches)


In [None]:
# Predicted and true labels for every validation sample, paired in loader order.
y_pred, y_true = ret_preds_labels(model=model, loader=val_loader)

In [None]:
# Per-class precision / recall / F1 on the validation split; sklearn needs CPU tensors.
print(classification_report(y_true=y_true.cpu(), y_pred=y_pred.cpu(), digits=3))

              precision    recall  f1-score   support

           0      1.000     0.500     0.667         2
           1      0.500     0.667     0.571         3
           2      1.000     1.000     1.000         7
           3      1.000     1.000     1.000         5
           4      1.000     1.000     1.000         6
           5      1.000     0.833     0.909         6
           6      1.000     1.000     1.000         6
           7      1.000     1.000     1.000         4
           8      0.833     1.000     0.909         5
           9      1.000     1.000     1.000         6
          10      1.000     1.000     1.000         3
          11      1.000     1.000     1.000         6
          12      0.667     1.000     0.800         2
          13      1.000     1.000     1.000         9
          14      1.000     1.000     1.000         5
          15      1.000     1.000     1.000         4
          16      1.000     1.000     1.000         4
          17      1.000    