In [2]:
import cv2 as cv
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset,DataLoader,random_split
from matplotlib import pyplot as plt
from torchvision import transforms
from torchsummary import summary
from dataset2 import TrainData
import torch.optim as optim
from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix,classification_report

In [4]:

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [5]:
class InterMediateBlock(nn.Module):
    
    def __init__(self,in_channels,inter_channels,identity_connection=None,stride=1):
        
        super(InterMediateBlock,self).__init__()
        self.expansion = 4
        
        self.conv1 = nn.Conv2d(
            in_channels,
            inter_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False
        )
        
        self.conv2 = nn.Conv2d(
            inter_channels,
            inter_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False
        )
        
        self.conv3 = nn.Conv2d(
            inter_channels,
            inter_channels*self.expansion,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False
        )
        
        self.bn1 = nn.BatchNorm2d(inter_channels)
        self.bn2 = nn.BatchNorm2d(inter_channels)
        self.bn3 = nn.BatchNorm2d(inter_channels*self.expansion)
        
        self.relu = nn.ReLU()
        
        self.identity_connection = identity_connection
        self.initialize_weights()
        
    
    def forward(self,x):
        
        id_connection = x.clone()
        
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        
        x = self.conv3(x)
        x = self.bn3(x)
        
        if self.identity_connection:
            
            id_connection = self.identity_connection(id_connection)
            
        x += id_connection
        
        x = self.relu(x)
        return x      

      

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight)

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)



In [6]:
class ResNet(nn.Module):
    
    def __init__(self, block, no_layers, channels, num_classes):
        
        super(ResNet,self).__init__()
        self.in_channels = 64
        
        self.conv1 = nn.Conv2d(channels,64,kernel_size=7,stride=2,padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        
        self.layer1 = self.layer(InterMediateBlock,no_layers[0],64,1)
        self.layer2 = self.layer(InterMediateBlock,no_layers[1],128,2) 
        self.layer3 = self.layer(InterMediateBlock,no_layers[2],256,2)
        self.layer4 = self.layer(InterMediateBlock,no_layers[3],512,2)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc =  nn.Linear(2048,num_classes)
        self.softmax = nn.Softmax(dim=1)
        self.initialize_weights()
        
    def layer(self,block, no_layer_blocks, out_channels, stride):
        
        identity_connection = None
        layers = []
        
        if stride!=1 or self.in_channels!= out_channels*4:
            
            identity_connection = nn.Sequential(
                    nn.Conv2d(self.in_channels,out_channels*4,kernel_size=1,stride=stride),
                    nn.BatchNorm2d(out_channels*4),
            )
        layers.append(block(self.in_channels,out_channels,identity_connection,stride=stride))
        self.in_channels = out_channels*4
        
        
        for i in range(no_layer_blocks):
            layers.append(block(self.in_channels,out_channels))
        
        return nn.Sequential(*layers)
        
        
    def forward(self,x):
    
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)
        x = x.reshape(x.shape[0],-1)
        x = self.fc(x)
        # x = self.softmax(x)
        
        return x


    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight)

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)
            

In [7]:


def create_model(img_channel=1, layers = [1,1,1,1], num_classes=10):
    return ResNet(InterMediateBlock, layers, img_channel, num_classes)

In [8]:
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        # transforms.RandomHorizontalFlip(p=0.4),
#         transforms.RandomVerticalFlip(p=0.4),
        transforms.RandomRotation((-30,30)),
        transforms.Resize((300,400)),
        transforms.CenterCrop((280,280)),
        transforms.Resize((28,28)),
        transforms.ToTensor(),
        # transforms.Normalize(mean=(0.8840,),std=(0.3186,))
    ]
)

In [12]:
num_classes = 10
classes = [str(i) for i in range(num_classes)]

In [14]:
def save_chkpt(state,filename='model.pth.tar'):
  print('===== Saving Checkpoint =====')
  torch.save(state,filename)

In [15]:
def load_chkpoint(chkpt):
  print('===== Loading Checkpoint =====')
  model.load_state_dict(chkpt['state_dict'])
  


In [19]:


def check_accuracy_probs_preds_loss(loader,loss_fn , model,mode='train',global_step=None):


    num_correct = 0
    num_samples = 0
    class_probs = []
    class_preds = []
    losses = []

    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)

            scores = model(x)
            loss = loss_fn(scores,y)

            losses.append(loss.item())

            _, predictions = scores.max(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)


            class_probs_batch = [F.softmax(el, dim=0) for el in scores]
            _, class_preds_batch = torch.max(scores, 1)

            class_probs.append(class_probs_batch)
            class_preds.append(class_preds_batch)

    test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
    test_preds = torch.cat(class_preds)

    model.train()
    acc = num_correct/num_samples



    return acc, test_probs,test_preds,sum(losses)/len(losses)


## Training on the dataset used for part 1 by selecting 0-9 classes

In [None]:

dataset = TrainData(root='./train' ,transform=transform)
train_data,val_data = random_split(dataset,[342,58])
train_loader = DataLoader(train_data,batch_size=32,shuffle=True)
val_loader = DataLoader(val_data,batch_size=32,shuffle=True)

In [None]:

writer = SummaryWriter('runs/part2_custom_mnist')

In [None]:
learning_rate = 3e-4
num_epochs = 80
model = create_model(img_channel=1,layers=[1,2,2,2],num_classes=num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate,weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=2,gamma=0.9)

In [85]:


tot_train_loss = []
tot_train_acc = []
tot_val_loss = []
tot_val_acc = []
best_acc = -10


# Train Network
for epoch in range(num_epochs):
    losses = []
    loop = tqdm(train_loader)
    num_correct = 0
    num_samples = 0


    for batch_idx, (data, targets) in enumerate(loop):
        # Get data to cuda if possible

        data = data.to(device=device)
        targets = targets.to(device=device)

        # forward
        scores = model(data)
        loss = criterion(scores, targets)

        _, predictions = scores.max(1)
        num_correct += (predictions == targets).sum()
        num_samples += predictions.size(0)

        # epoch_loss+= loss.item()

        losses.append(loss.item())
        # backward
        optimizer.zero_grad()
        loss.backward()

        # gradient descent or adam step
        optimizer.step()

        loop.set_description(f'Epoch [{epoch+1}/{num_epochs}]')
        loop.set_postfix(loss=loss.item(),train_acc= (num_correct*100/num_samples).item())



    val_acc,val_probs,val_preds,val_loss = check_accuracy_probs_preds_loss(val_loader,criterion,model,'val',epoch)
    train_acc = num_correct/num_samples
    print('Val_acc: {:0.2f} Val_loss: {:0.2f}'.format(val_acc*100,val_loss))

    writer.add_scalar('Loss/Train',torch.tensor(losses).mean(),epoch)
    writer.add_scalar('Loss/Val',val_loss,epoch)
    writer.add_scalar('Acc/Train',train_acc,epoch)
    writer.add_scalar('Acc/Val',val_acc,epoch)

    tot_train_loss.append(torch.tensor(losses).mean())
    tot_train_acc.append(train_acc)
    tot_val_loss.append(val_loss)
    tot_val_acc.append(val_acc)

    if val_acc > best_acc:
      best_acc = val_acc
      chkpt = { 
          'state_dict': model.state_dict(),
          'val_acc': best_acc,
          'train_acc': train_acc
      }
      save_chkpt(chkpt,'pretrained_custom_mnist.pth.tar')




    # for i in range(len(classes)):
    #     add_pr_curve_tensorboard(i, train_probs, train_preds,epoch,'train')

    # for i in range(len(classes)):
    #     add_pr_curve_tensorboard(i, val_probs, val_preds,epoch,'val')






HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 10.34 Val_loss: 8593.58
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 8.62 Val_loss: 891.41


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 12.07 Val_loss: 59.10
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 37.93 Val_loss: 3.87
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 34.48 Val_loss: 8.93


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 53.45 Val_loss: 2.01
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 55.17 Val_loss: 1.97
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 51.72 Val_loss: 1.34


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 43.10 Val_loss: 5.23


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 53.45 Val_loss: 1.90


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 50.00 Val_loss: 1.80


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 65.52 Val_loss: 2.37
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 65.52 Val_loss: 1.12


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 70.69 Val_loss: 0.88
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 68.97 Val_loss: 0.88


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 62.07 Val_loss: 1.43


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 68.97 Val_loss: 1.02


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 72.41 Val_loss: 0.85
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 74.14 Val_loss: 0.91
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 58.62 Val_loss: 2.18


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 41.38 Val_loss: 6.65


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 56.90 Val_loss: 1.98


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 65.52 Val_loss: 1.08


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 60.34 Val_loss: 1.24


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 75.86 Val_loss: 0.83
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 70.69 Val_loss: 1.02


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 67.24 Val_loss: 1.35


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 44.83 Val_loss: 29.67


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 36.21 Val_loss: 13.72


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 68.97 Val_loss: 1.13


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 68.97 Val_loss: 1.64


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 62.07 Val_loss: 3.33


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 77.59 Val_loss: 0.54
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 75.86 Val_loss: 0.88


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 65.52 Val_loss: 1.91


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 68.97 Val_loss: 1.16


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 74.14 Val_loss: 1.25


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 63.79 Val_loss: 6.25


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 68.97 Val_loss: 1.23


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 72.41 Val_loss: 0.96


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 65.52 Val_loss: 1.00


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 75.86 Val_loss: 0.79


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 81.03 Val_loss: 0.63
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 82.76 Val_loss: 0.61
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 81.03 Val_loss: 0.74


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 55.17 Val_loss: 7.12


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 68.97 Val_loss: 1.26


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 79.31 Val_loss: 0.77


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 86.21 Val_loss: 0.51
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Val_acc: 81.03 Val_loss: 4.91


In [86]:
load_chkpoint(torch.load('./pretrained_custom_mnist.pth.tar'))

===== Loading Checkpoint =====


In [87]:
def ret_preds_labels(model,loader):

  class_preds = []
  actual_preds = []

  model.eval()

  with torch.no_grad():
      for x, y in loader:
          x = x.to(device=device)
          y = y.to(device=device)

          scores = model(x)

          _, class_preds_batch = torch.max(scores, 1)

          class_preds.append(class_preds_batch)
          actual_preds.append(y)

  test_preds = torch.cat(class_preds)
  actual_preds = torch.cat(actual_preds)

  model.train()
  return test_preds,actual_preds


In [88]:
y_pred,y_true = ret_preds_labels(model,val_loader)

In [89]:
print(classification_report(y_true.cpu(),y_pred.cpu(),digits=3))

              precision    recall  f1-score   support

           0      0.667     0.800     0.727         5
           1      1.000     0.889     0.941         9
           2      1.000     0.833     0.909         6
           3      0.600     0.750     0.667         4
           4      0.750     0.750     0.750         4
           5      0.714     0.714     0.714         7
           6      0.750     0.500     0.600         6
           7      1.000     1.000     1.000         6
           8      0.667     1.000     0.800         6
           9      0.750     0.600     0.667         5

    accuracy                          0.793        58
   macro avg      0.790     0.784     0.778        58
weighted avg      0.810     0.793     0.792        58



In [None]:

writer.close()


# Training on MNIST Data Using and not using pretrained model

## Using Pretrained

In [None]:
writer = SummaryWriter('runs/part2_standard_mnist_pretrained')


In [9]:
dataset = torchvision.datasets.MNIST(root='.',
                                  train=True,
                                  transform=transforms.Compose([                                                   
                                    transforms.RandomRotation((-30,30)),
                                    transforms.ToTensor(),
                                  ]),download=True,
  )
print(len(dataset))
train_data,val_data = random_split(dataset,[50000,10000])
train_loader = DataLoader(train_data,batch_size=32,shuffle=True)
val_loader = DataLoader(val_data,batch_size=32,shuffle=True)

test_loader = DataLoader( torchvision.datasets.MNIST(root='.',
                                  train=False,
                                  transform=transforms.Compose([                                                   
                                    transforms.RandomRotation((-30,30)),
                                    transforms.ToTensor(),
                                  ]),download=True,
  ),batch_size=32,shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw

Processing...
Done!
60000


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [None]:
writer = SummaryWriter('runs/part2_standard_mnist_pretrained')

In [10]:
! cp '/content/drive/MyDrive/midas/task2/part2/part_2_custom_mnist_model_tb.zip' .
! unzip -q part_2_custom_mnist_model_tb.zip

In [20]:
learning_rate = 3e-4
num_epochs = 30

model = create_model(img_channel=1,layers=[1,2,2,2],num_classes=num_classes).to(device)
load_chkpoint(torch.load('./pretrained_custom_mnist.pth.tar'))

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate,weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=2,gamma=0.9)

===== Loading Checkpoint =====


In [21]:

tot_train_loss = []
tot_train_acc = []
tot_val_loss = []
tot_val_acc = []
best_acc = -10


# Train Network
for epoch in range(num_epochs):
    losses = []
    loop = tqdm(train_loader)
    num_correct = 0
    num_samples = 0


    for batch_idx, (data, targets) in enumerate(loop):
        # Get data to cuda if possible

        data = data.to(device=device)
        targets = targets.to(device=device)

        # forward
        scores = model(data)
        loss = criterion(scores, targets)

        _, predictions = scores.max(1)
        num_correct += (predictions == targets).sum()
        num_samples += predictions.size(0)

        # epoch_loss+= loss.item()

        losses.append(loss.item())
        # backward
        optimizer.zero_grad()
        loss.backward()

        # gradient descent or adam step
        optimizer.step()

        loop.set_description(f'Epoch [{epoch+1}/{num_epochs}]')
        loop.set_postfix(loss=loss.item(),train_acc= (num_correct*100/num_samples).item())



    val_acc,val_probs,val_preds,val_loss = check_accuracy_probs_preds_loss(val_loader,criterion,model,'val',epoch)
    train_acc = num_correct/num_samples
    print('Val_acc: {:0.2f} Val_loss: {:0.2f}'.format(val_acc*100,val_loss))

    writer.add_scalar('Loss/Train',torch.tensor(losses).mean(),epoch)
    writer.add_scalar('Loss/Val',val_loss,epoch)
    writer.add_scalar('Acc/Train',train_acc,epoch)
    writer.add_scalar('Acc/Val',val_acc,epoch)

    tot_train_loss.append(torch.tensor(losses).mean())
    tot_train_acc.append(train_acc)
    tot_val_loss.append(val_loss)
    tot_val_acc.append(val_acc)

    if val_acc > best_acc:
      best_acc = val_acc
      chkpt = { 
          'state_dict': model.state_dict(),
          'val_acc': best_acc,
          'train_acc': train_acc
      }
      save_chkpt(chkpt,'part_2_standard_mnist_pretrained.pth.tar')






HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 94.85 Val_loss: 0.16
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 96.93 Val_loss: 0.10
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 97.35 Val_loss: 0.09
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.12 Val_loss: 0.07
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 97.65 Val_loss: 0.08


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.05 Val_loss: 0.07


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.13 Val_loss: 0.06
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.13 Val_loss: 0.06


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.15 Val_loss: 0.07
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.19 Val_loss: 0.06
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.50 Val_loss: 0.05
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.65 Val_loss: 0.05
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.58 Val_loss: 0.05


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.28 Val_loss: 0.06


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.65 Val_loss: 0.05


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.35 Val_loss: 0.05


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.33 Val_loss: 0.06


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.65 Val_loss: 0.05


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.94 Val_loss: 0.04
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.85 Val_loss: 0.04


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 99.00 Val_loss: 0.04
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.78 Val_loss: 0.04


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.85 Val_loss: 0.04


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.83 Val_loss: 0.04


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.57 Val_loss: 0.05


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.89 Val_loss: 0.04


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.88 Val_loss: 0.04


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.81 Val_loss: 0.04


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.64 Val_loss: 0.04


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.79 Val_loss: 0.04


In [22]:
model = create_model(1,[1,2,2,2],10).to(device)

In [25]:
load_chkpoint(torch.load('./part_2_standard_mnist_pretrained.pth.tar',map_location=device))

===== Loading Checkpoint =====


In [26]:
def ret_preds_labels(model,loader):

  class_preds = []
  actual_preds = []

  model.eval()

  with torch.no_grad():
      for x, y in loader:
          x = x.to(device=device)
          y = y.to(device=device)

          scores = model(x)

          _, class_preds_batch = torch.max(scores, 1)

          class_preds.append(class_preds_batch)
          actual_preds.append(y)

  test_preds = torch.cat(class_preds)
  actual_preds = torch.cat(actual_preds)

  model.train()
  return test_preds,actual_preds


In [27]:
y_pred,y_true = ret_preds_labels(model,test_loader)

In [28]:
print(classification_report(y_true.cpu(),y_pred.cpu(),digits=3))

              precision    recall  f1-score   support

           0      0.997     0.994     0.995       980
           1      0.984     0.994     0.989      1135
           2      0.990     0.992     0.991      1032
           3      0.986     0.994     0.990      1010
           4      0.986     0.995     0.990       982
           5      0.993     0.987     0.990       892
           6      0.991     0.991     0.991       958
           7      0.984     0.980     0.982      1028
           8      0.996     0.984     0.990       974
           9      0.986     0.982     0.984      1009

    accuracy                          0.989     10000
   macro avg      0.989     0.989     0.989     10000
weighted avg      0.989     0.989     0.989     10000



In [None]:
writer.close()

# From Scratch

In [31]:
dataset = torchvision.datasets.MNIST(root='.',
                                  train=True,
                                  transform=transforms.Compose([                                                   
                                    transforms.RandomRotation((-30,30)),
                                    transforms.ToTensor(),
                                  ]),download=True,
  )
print(len(dataset))
train_data,val_data = random_split(dataset,[50000,10000])
train_loader = DataLoader(train_data,batch_size=32,shuffle=True)
val_loader = DataLoader(val_data,batch_size=32,shuffle=True)

test_loader = DataLoader( torchvision.datasets.MNIST(root='.',
                                  train=False,
                                  transform=transforms.Compose([                                                   
                                    transforms.RandomRotation((-30,30)),
                                    transforms.ToTensor(),
                                  ]),download=True,
  ),batch_size=32,shuffle=True)

In [32]:
writer = SummaryWriter('runs/part2_standard_mnist_scratch')

In [33]:
num_classes = 10
classes = [str(i) for i in range(num_classes)]

In [39]:
learning_rate = 3e-4
num_epochs = 30
model = create_model(img_channel=1,layers=[1,2,2,2],num_classes=num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate,weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=2,gamma=0.9)

In [35]:


def check_accuracy_probs_preds_loss(loader,loss_fn , model,mode='train',global_step=None):


    num_correct = 0
    num_samples = 0
    class_probs = []
    class_preds = []
    losses = []

    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)

            scores = model(x)
            loss = loss_fn(scores,y)

            losses.append(loss.item())

            _, predictions = scores.max(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)


            class_probs_batch = [F.softmax(el, dim=0) for el in scores]
            _, class_preds_batch = torch.max(scores, 1)

            class_probs.append(class_probs_batch)
            class_preds.append(class_preds_batch)

    test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
    test_preds = torch.cat(class_preds)

    model.train()
    acc = num_correct/num_samples



    return acc, test_probs,test_preds,sum(losses)/len(losses)

In [36]:
def save_chkpt(state,filename='model.pth.tar'):
  print('===== Saving Checkpoint =====')
  torch.save(state,filename)

In [37]:
def load_chkpoint(chkpt):
  print('===== Loading Checkpoint =====')
  model.load_state_dict(chkpt['state_dict'])
  


In [None]:

dataset = torchvision.datasets.MNIST(root='.',
                                  train=True,
                                  transform=transforms.Compose([                                                   
                                    transforms.RandomRotation((-30,30)),
                                    transforms.ToTensor(),
                                  ]),download=True,
  )
print(len(dataset))
train_data,val_data = random_split(dataset,[50000,10000])
train_loader = DataLoader(train_data,batch_size=32,shuffle=True)
val_loader = DataLoader(val_data,batch_size=32,shuffle=True)

test_loader = DataLoader( torchvision.datasets.MNIST(root='.',
                                  train=False,
                                  transform=transforms.Compose([                                                   
                                    transforms.RandomRotation((-30,30)),
                                    transforms.ToTensor(),
                                  ]),download=True,
  ),batch_size=32,shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw

Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [40]:


tot_train_loss = []
tot_train_acc = []
tot_val_loss = []
tot_val_acc = []
best_acc = -10


# Train Network
for epoch in range(num_epochs):
    losses = []
    loop = tqdm(train_loader)
    num_correct = 0
    num_samples = 0


    for batch_idx, (data, targets) in enumerate(loop):
        # Get data to cuda if possible

        data = data.to(device=device)
        targets = targets.to(device=device)

        # forward
        scores = model(data)
        loss = criterion(scores, targets)

        _, predictions = scores.max(1)
        num_correct += (predictions == targets).sum()
        num_samples += predictions.size(0)

        # epoch_loss+= loss.item()

        losses.append(loss.item())
        # backward
        optimizer.zero_grad()
        loss.backward()

        # gradient descent or adam step
        optimizer.step()

        loop.set_description(f'Epoch [{epoch+1}/{num_epochs}]')
        loop.set_postfix(loss=loss.item(),train_acc= (num_correct*100/num_samples).item())



    val_acc,val_probs,val_preds,val_loss = check_accuracy_probs_preds_loss(val_loader,criterion,model,'val',epoch)
    train_acc = num_correct/num_samples
    print('Val_acc: {:0.2f} Val_loss: {:0.2f}'.format(val_acc*100,val_loss))

    writer.add_scalar('Loss/Train',torch.tensor(losses).mean(),epoch)
    writer.add_scalar('Loss/Val',val_loss,epoch)
    writer.add_scalar('Acc/Train',train_acc,epoch)
    writer.add_scalar('Acc/Val',val_acc,epoch)

    tot_train_loss.append(torch.tensor(losses).mean())
    tot_train_acc.append(train_acc)
    tot_val_loss.append(val_loss)
    tot_val_acc.append(val_acc)

    if val_acc > best_acc:
      best_acc = val_acc
      chkpt = { 
          'state_dict': model.state_dict(),
          'val_acc': best_acc,
          'train_acc': train_acc
      }
      save_chkpt(chkpt,'part_2_standard_mnist_scratch.pth.tar')




    # for i in range(len(classes)):
    #     add_pr_curve_tensorboard(i, train_probs, train_preds,epoch,'train')

    # for i in range(len(classes)):
    #     add_pr_curve_tensorboard(i, val_probs, val_preds,epoch,'val')






HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 93.40 Val_loss: 0.22
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 95.35 Val_loss: 0.14
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 96.63 Val_loss: 0.11
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 96.51 Val_loss: 0.11


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 97.61 Val_loss: 0.08
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 97.43 Val_loss: 0.09


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 97.67 Val_loss: 0.08
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 97.78 Val_loss: 0.08
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.25 Val_loss: 0.06
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 97.65 Val_loss: 0.07


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.41 Val_loss: 0.06
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.38 Val_loss: 0.06


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 97.39 Val_loss: 0.09


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.29 Val_loss: 0.06


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.42 Val_loss: 0.06
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.57 Val_loss: 0.05
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.63 Val_loss: 0.05
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.32 Val_loss: 0.06


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.67 Val_loss: 0.05
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.48 Val_loss: 0.05


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.30 Val_loss: 0.06


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.45 Val_loss: 0.06


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.73 Val_loss: 0.04
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.62 Val_loss: 0.05


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.65 Val_loss: 0.05


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.47 Val_loss: 0.05


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.62 Val_loss: 0.05


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.68 Val_loss: 0.05


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.87 Val_loss: 0.04
===== Saving Checkpoint =====


HBox(children=(FloatProgress(value=0.0, max=1563.0), HTML(value='')))


Val_acc: 98.63 Val_loss: 0.05


In [None]:
model = create_model(1,[1,2,2,2],10).to(device)

In [41]:
load_chkpoint(torch.load('./part_2_standard_mnist_scratch.pth.tar',map_location=device))

===== Loading Checkpoint =====


In [42]:
def ret_preds_labels(model,loader):

  class_preds = []
  actual_preds = []

  model.eval()

  with torch.no_grad():
      for x, y in loader:
          x = x.to(device=device)
          y = y.to(device=device)

          scores = model(x)

          _, class_preds_batch = torch.max(scores, 1)

          class_preds.append(class_preds_batch)
          actual_preds.append(y)

  test_preds = torch.cat(class_preds)
  actual_preds = torch.cat(actual_preds)

  model.train()
  return test_preds,actual_preds


In [45]:
y_pred,y_true = ret_preds_labels(model,test_loader)

In [46]:
print(classification_report(y_true.cpu(),y_pred.cpu(),digits=3))

              precision    recall  f1-score   support

           0      0.984     0.997     0.990       980
           1      0.991     0.992     0.992      1135
           2      0.983     0.993     0.988      1032
           3      0.993     0.995     0.994      1010
           4      0.994     0.987     0.990       982
           5      0.987     0.987     0.987       892
           6      0.997     0.993     0.995       958
           7      0.984     0.980     0.982      1028
           8      0.993     0.984     0.988       974
           9      0.985     0.983     0.984      1009

    accuracy                          0.989     10000
   macro avg      0.989     0.989     0.989     10000
weighted avg      0.989     0.989     0.989     10000



In [None]:
writer.close()