<a href="https://colab.research.google.com/github/safal-singh/MNIST-Sum-CNN/blob/main/EVA6_Assignment3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision # provide access to datasets, models, transforms, utils, etc
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim

import numpy as np
import random

In [None]:
class CustomDataset(Dataset):
  """MNIST training split where every sample is paired with a random digit.

  Each item is a 4-tuple ``(image, one_hot, label, label + random_int)``:
  the MNIST image and label as produced by the wrapped torchvision dataset,
  a ``(10, 1)`` one-hot tensor encoding a freshly drawn integer in ``[0, 9]``,
  and the sum of the label and that integer.

  The integer is drawn with ``random.randint`` on every ``__getitem__`` call,
  so the same index yields a different pairing each time it is read.
  """

  def __init__(self):
    """Load (downloading if needed) the MNIST training set under ``./data``."""
    to_tensor = transforms.Compose([transforms.ToTensor()])
    self.data = torchvision.datasets.MNIST(
        root='./data',
        train=True,
        download=True,
        transform=to_tensor,
    )

  def __len__(self):
    """Number of samples in the wrapped MNIST dataset."""
    return len(self.data)

  def __getitem__(self, index):
    """Return ``(image, one_hot, label, label + random_int)`` for ``index``."""
    image, label = self.data[index]

    # Second operand of the sum, drawn uniformly from 0..9 (inclusive).
    random_int = random.randint(0, 9)

    # One-hot column vector of shape (10, 1) marking the drawn integer.
    one_hot = torch.zeros((10, 1))
    one_hot[random_int, 0] = 1

    return image, one_hot, label, label + random_int

In [None]:
class Net(nn.Module):
    """Two-headed CNN: classifies a digit image and the sum of that digit with a second number.

    ``forward(x, random_int)`` returns ``(mnist_out, sum_out)``:

    * ``mnist_out`` -- log-probabilities over 10 digit classes, shape ``(B, 10)``.
    * ``sum_out``   -- log-probabilities over 19 sum classes (0..18), shape ``(B, 19)``.

    The spatial sizes quoted in the layer comments assume 28x28 single-channel
    inputs; ``convblock7`` (7x7 kernel) only collapses the feature map to 1x1
    for that input size.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Input Block
        self.convblock1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=(3, 3), padding=0, bias=False) # output_size = 26

        # CONVOLUTION BLOCK 1
        self.convblock2 = nn.Conv2d(in_channels=8, out_channels=10, kernel_size=(3, 3), padding=0, bias=False) # output_size = 24

        self.convblock3 = nn.Conv2d(in_channels=10, out_channels=10, kernel_size=(3, 3), padding=0, bias=False) # output_size = 22

        # TRANSITION BLOCK 1
        self.pool1 = nn.MaxPool2d(2, 2) # output_size = 11
        self.convblock4 = nn.Conv2d(in_channels=10, out_channels=16, kernel_size=(1, 1), padding=0, bias=False) # output_size = 11

        # CONVOLUTION BLOCK 2
        self.convblock5 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), padding=0, bias=False) # output_size = 9

        self.convblock6 = nn.Conv2d(in_channels=16, out_channels=10, kernel_size=(3, 3), padding=0, bias=False) # output_size = 7

        self.convblock7 = nn.Conv2d(in_channels=10, out_channels=10, kernel_size=(7, 7), padding=0, bias=False) # output_size = 1

        # SUM HEAD
        # Three convolutions combine the 10 digit scores from convblock7 with the
        # 10-way one-hot code of the second number (stacked as 2 channels of a
        # 10x1 map) and classify the sum into 0..18 (19 classes).
        # Shapes below are written as H x W x C.
        self.convblock8 = nn.Conv2d(in_channels=2, out_channels=10, kernel_size=(3, 1)) # input 10x1x2, output 8x1x10

        self.convblock9 = nn.Conv2d(in_channels=10, out_channels=25, kernel_size=(3, 1)) # input 8x1x10, output 6x1x25

        self.convblock10 = nn.Conv2d(in_channels=25, out_channels=19, kernel_size=(6, 1)) # input 6x1x25, output 1x1x19

    def forward(self, x, random_int):
        """Run both heads.

        Args:
            x: image batch fed to ``convblock1`` (1 input channel).
            random_int: one-hot encoding of the second number; it is
                ``unsqueeze``d on dim 1 and concatenated with a
                ``(B, 1, 10, 1)`` tensor, so it must be ``(B, 10, 1)``.

        Returns:
            Tuple ``(mnist_out, sum_out)`` of log-softmax outputs with shapes
            ``(B, 10)`` and ``(B, 19)``.
        """
        x = self.convblock1(x)
        x = F.relu(x)
        x = self.convblock2(x)
        x = F.relu(x)
        x = self.convblock3(x)
        x = F.relu(x)
        x = self.pool1(x)
        x = self.convblock4(x)
        x = F.relu(x)
        x = self.convblock5(x)
        x = F.relu(x)
        x = self.convblock6(x)
        x = F.relu(x)
        x = self.convblock7(x)
        x_mnist = x.view(-1, 10)
        mnist_out = F.log_softmax(x_mnist, dim=-1)

        # Lay the 10 digit scores out as one channel of a 10x1 map: (B, 1, 10, 1).
        # An explicit view is used instead of squeeze(): squeeze() would also
        # drop the batch dimension when B == 1 and break the reshape below.
        digit_channel = x_mnist.view(-1, 1, 10, 1)

        # Stack digit scores and the one-hot code as 2 channels: (B, 2, 10, 1).
        x = torch.cat((digit_channel, random_int.unsqueeze(dim=1)), dim=1)
        x = F.relu(x)
        x = self.convblock8(x)
        x = F.relu(x)
        x = self.convblock9(x)
        x = F.relu(x)
        x = self.convblock10(x)
        x = x.view(-1, 19)
        sum_out = F.log_softmax(x, dim=-1)

        return mnist_out, sum_out

In [None]:
# Instantiate the model once and list every learnable tensor with its shape.
net = Net()
for name, param in net.named_parameters():
    print(name, param.shape)

convblock1.weight torch.Size([8, 1, 3, 3])
convblock2.weight torch.Size([10, 8, 3, 3])
convblock3.weight torch.Size([10, 10, 3, 3])
convblock4.weight torch.Size([16, 10, 1, 1])
convblock5.weight torch.Size([16, 16, 3, 3])
convblock6.weight torch.Size([10, 16, 3, 3])
convblock7.weight torch.Size([10, 10, 7, 7])
convblock8.weight torch.Size([10, 2, 3, 1])
convblock8.bias torch.Size([10])
convblock9.weight torch.Size([25, 10, 3, 1])
convblock9.bias torch.Size([25])
convblock10.weight torch.Size([19, 25, 6, 1])
convblock10.bias torch.Size([19])


In [None]:
# Fresh, untrained model and dataset for this run.
network = Net()
dt = CustomDataset()

# Shuffled mini-batches of 32 samples.
train_loader = torch.utils.data.DataLoader(dt, batch_size=32, shuffle=True)

# Adam over all model parameters.
optimizer = optim.Adam(network.parameters(), lr=0.01)

In [8]:
def get_num_correct(preds, labels, pred_sums, sums):
  """Count samples for which BOTH predictions are right.

  Args:
    preds: per-class scores for the digit head; argmax is taken over dim 1.
    labels: target digit classes, compared element-wise with the argmax.
    pred_sums: per-class scores for the sum head; argmax is taken over dim 1.
    sums: target sum classes, compared element-wise with the argmax.

  Returns:
    Python number of samples where the digit prediction and the sum
    prediction both match their targets.
  """
  digit_hits = preds.argmax(dim=1) == labels
  sum_hits = pred_sums.argmax(dim=1) == sums
  both_hits = digit_hits.logical_and(sum_hits)
  return both_hits.sum().item()

In [None]:
# TRAINED WITH ADAM OPTIMIZER, CROSS ENTROPY LOSS
# NOTE: Net.forward already returns log-softmax outputs and F.cross_entropy
# applies log_softmax internally; re-normalizing log-probabilities leaves them
# unchanged, so this loss equals F.nll_loss on the same outputs.
for epoch in range(10):

    total_loss = 0            # sum of per-batch losses over the epoch
    total_correct_mnist = 0   # samples with BOTH digit and sum predicted correctly

    for batch in train_loader: # Get Batch
        images, ints, labels, sums = batch

        preds, pred_sums = network(images, ints) # Pass Batch
        # Joint objective: digit classification loss + sum classification loss.
        loss = F.cross_entropy(preds, labels) + F.cross_entropy(pred_sums, sums) # Calculate Loss

        optimizer.zero_grad()
        loss.backward() # Calculate Gradients
        optimizer.step() # Update Weights

        total_loss += loss.item()
        total_correct_mnist += get_num_correct(preds, labels, pred_sums, sums)

    # Accuracy is computed from the actual dataset size rather than a
    # hard-coded divisor, so it stays correct if the dataset changes.
    print(
        "epoch", epoch,
        "total_correct:", total_correct_mnist,
        "train_acc (%):", 100 * total_correct_mnist / len(train_loader.dataset),
        "loss:", total_loss
    )

epoch 0 total_correct: 52755 train_acc: 87.925 loss: 1116.6370263546705
epoch 1 total_correct: 56993 train_acc: 94.98833333333333 loss: 564.7087776418775
epoch 2 total_correct: 57405 train_acc: 95.675 loss: 484.94715392496437
epoch 3 total_correct: 57555 train_acc: 95.925 loss: 467.5541512882337
epoch 4 total_correct: 57814 train_acc: 96.35666666666667 loss: 418.7254181718454
epoch 5 total_correct: 57777 train_acc: 96.295 loss: 419.2600337318145
epoch 6 total_correct: 57857 train_acc: 96.42833333333333 loss: 413.58039489621297
epoch 7 total_correct: 57814 train_acc: 96.35666666666667 loss: 426.6177197410725
epoch 8 total_correct: 58011 train_acc: 96.685 loss: 381.0123276892118
epoch 9 total_correct: 57669 train_acc: 96.115 loss: 456.71365990000777


In [6]:
# Re-create the model, dataset, loader and optimizer so the next run starts
# from newly initialized weights.
network = Net()
dt = CustomDataset()

# Shuffled mini-batches of 32 samples.
train_loader = torch.utils.data.DataLoader(dt, batch_size=32, shuffle=True)

# Adam over all model parameters.
optimizer = optim.Adam(network.parameters(), lr=0.01)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Using downloaded and verified file: ./data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Using downloaded and verified file: ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [9]:
# TRAINED WITH ADAM OPTIMIZER, NLL LOSS
for epoch in range(10):

    total_loss = 0            # sum of per-batch losses over the epoch
    total_correct_mnist = 0   # samples with BOTH digit and sum predicted correctly

    for batch in train_loader: # Get Batch
        images, ints, labels, sums = batch

        preds, pred_sums = network(images, ints) # Pass Batch
        # Joint objective: digit NLL + sum NLL (the network outputs log-probabilities).
        loss = F.nll_loss(preds, labels) + F.nll_loss(pred_sums, sums) # Calculate Loss

        optimizer.zero_grad()
        loss.backward() # Calculate Gradients
        optimizer.step() # Update Weights

        total_loss += loss.item()
        total_correct_mnist += get_num_correct(preds, labels, pred_sums, sums)

    # Accuracy is computed from the actual dataset size rather than a
    # hard-coded divisor, so it stays correct if the dataset changes.
    print(
        "epoch", epoch,
        "total_correct:", total_correct_mnist,
        "train_acc (%):", 100 * total_correct_mnist / len(train_loader.dataset),
        "loss:", total_loss
    )

epoch 0 total_correct: 52216 train_acc (%): 87.02666666666667 loss: 1196.1823298074305
epoch 1 total_correct: 56844 train_acc (%): 94.74 loss: 584.6205953862518
epoch 2 total_correct: 57206 train_acc (%): 95.34333333333333 loss: 521.1775610516779
epoch 3 total_correct: 57325 train_acc (%): 95.54166666666667 loss: 485.4478666591458
epoch 4 total_correct: 57441 train_acc (%): 95.735 loss: 475.08763302955776
epoch 5 total_correct: 57481 train_acc (%): 95.80166666666666 loss: 452.77352287550457
epoch 6 total_correct: 57628 train_acc (%): 96.04666666666667 loss: 439.8011473393999
epoch 7 total_correct: 57870 train_acc (%): 96.45 loss: 409.36258122045547
epoch 8 total_correct: 57866 train_acc (%): 96.44333333333333 loss: 410.1813166986685
epoch 9 total_correct: 57694 train_acc (%): 96.15666666666667 loss: 438.60501672839746


In [10]:
# Re-create the model, dataset and loader so the SGD run starts from newly
# initialized weights.
network = Net()
dt = CustomDataset()

# Shuffled mini-batches of 32 samples.
train_loader = torch.utils.data.DataLoader(dt, batch_size=32, shuffle=True)

# SGD with momentum over all model parameters.
optimizer = optim.SGD(network.parameters(), lr=0.01, momentum=0.9)

In [11]:
# TRAINED WITH SGD OPTIMIZER, NLL LOSS
for epoch in range(10):

    total_loss = 0            # sum of per-batch losses over the epoch
    total_correct_mnist = 0   # samples with BOTH digit and sum predicted correctly

    for batch in train_loader: # Get Batch
        images, ints, labels, sums = batch

        preds, pred_sums = network(images, ints) # Pass Batch
        # Joint objective: digit NLL + sum NLL (the network outputs log-probabilities).
        loss = F.nll_loss(preds, labels) + F.nll_loss(pred_sums, sums) # Calculate Loss

        optimizer.zero_grad()
        loss.backward() # Calculate Gradients
        optimizer.step() # Update Weights

        total_loss += loss.item()
        total_correct_mnist += get_num_correct(preds, labels, pred_sums, sums)

    # Accuracy is computed from the actual dataset size rather than a
    # hard-coded divisor, so it stays correct if the dataset changes.
    print(
        "epoch", epoch,
        "total_correct:", total_correct_mnist,
        "train_acc (%):", 100 * total_correct_mnist / len(train_loader.dataset),
        "loss:", total_loss
    )

epoch 0 total_correct: 37821 train_acc (%): 63.035 loss: 2812.0291588939726
epoch 1 total_correct: 57053 train_acc (%): 95.08833333333334 loss: 543.0080773606896
epoch 2 total_correct: 57933 train_acc (%): 96.555 loss: 405.43888782709837
epoch 3 total_correct: 58227 train_acc (%): 97.045 loss: 338.6837383089587
epoch 4 total_correct: 58465 train_acc (%): 97.44166666666666 loss: 296.42404373688623
epoch 5 total_correct: 58420 train_acc (%): 97.36666666666666 loss: 297.2620654888451
epoch 6 total_correct: 58602 train_acc (%): 97.67 loss: 265.6962027081754
epoch 7 total_correct: 58656 train_acc (%): 97.76 loss: 246.78087908169255
epoch 8 total_correct: 58699 train_acc (%): 97.83166666666666 loss: 240.36291756271385
epoch 9 total_correct: 58741 train_acc (%): 97.90166666666667 loss: 228.12640660093166


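HIGHEST TRAIN ACCURACY OF 97.9% (A SAMPLE COUNTS AS CORRECT ONLY WHEN BOTH THE DIGIT AND THE SUM ARE PREDICTED CORRECTLY) USING SGD OPTIMIZER (LR=0.01, MOMENTUM=0.9), NLL LOSS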