## MNIST CNN with 20K or fewer parameters

In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import torchvision

!pip install torchsummary
from torchsummary import summary 



# Check for Cuda device

In [2]:
# Pick the compute device: the GPU when PyTorch reports CUDA support,
# otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


# Load and Prepare Dataset

MNIST contains 70,000 images of handwritten digits: 60,000 for training and 10,000 for testing. The images are grayscale, 28x28 pixels

We load the PIL images using torchvision.datasets.MNIST. While loading, each image is converted to a tensor and normalized with the mean and standard deviation of the MNIST images.

In [3]:
# Seed PyTorch's global RNG before any loader is created.
torch.manual_seed(1)
batch_size = 128

# Extra DataLoader options are only passed when CUDA is in use.
kwargs = dict(num_workers=1, pin_memory=True) if use_cuda else dict()

# Training pipeline: RandomRotation(10) augmentation, tensor conversion, then
# normalization with the (0.1307,) mean and (0.3081,) standard deviation.
train_transform = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_dataset = datasets.MNIST('../data', train=True, download=True,
                               transform=train_transform)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, **kwargs)

# Test pipeline: same tensor conversion and normalization, no augmentation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
test_dataset = datasets.MNIST('../data', train=False, transform=test_transform)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=True, **kwargs)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


# MNIST Neural Net

In [10]:
class Network(nn.Module):
    """Small fully-convolutional classifier for 1x28x28 inputs with 10 classes.

    Structure (all 3x3 convolutions are unpadded unless noted):
      * ``conv1``            - three 3x3 conv blocks, 1 -> 16 -> 16 -> 32 channels
      * ``transition_layer`` - 1x1 conv (32 -> 16), ReLU, 2x2 max pooling
      * ``conv2``            - four 3x3 conv blocks, 16 -> 16 -> 16 -> 16 -> 32 channels
      * ``conv_final``       - 1x1 conv mapping 32 channels to the 10 class scores
      * ``gap``              - 5x5 average pooling that collapses the 5x5 map to 1x1

    ``forward`` returns log-probabilities of shape ``(batch, 10)``, which is the
    form expected by ``F.nll_loss``.

    Shape comments below are written as HxWxC; RF is the receptive field on the
    input image after the commented layer.
    """

    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, 3, bias=False),            # in 28x28x1   -> out 26x26x16, RF 3x3
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout2d(0.1),
            nn.Conv2d(16, 16, 3, bias=False),           # in 26x26x16  -> out 24x24x16, RF 5x5
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Conv2d(16, 32, 3, bias=False),           # in 24x24x16  -> out 22x22x32, RF 7x7
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Dropout2d(0.1),
        )

        # Transition block: squeeze channels with a 1x1 conv, then halve the
        # spatial resolution.
        self.transition_layer = nn.Sequential(
            nn.Conv2d(32, 16, 1),                       # in 22x22x32  -> out 22x22x16, RF 7x7
            nn.ReLU(),
            nn.MaxPool2d(2, 2),                         # in 22x22x16  -> out 11x11x16, RF 8x8
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 16, 3, bias=False),           # in 11x11x16  -> out 9x9x16,   RF 12x12
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout2d(0.1),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),  # in 9x9x16  -> out 9x9x16,   RF 16x16
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout2d(0.1),
            nn.Conv2d(16, 16, 3, bias=False),           # in 9x9x16    -> out 7x7x16,   RF 20x20
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout2d(0.1),
            nn.Conv2d(16, 32, 3, bias=False),           # in 7x7x16    -> out 5x5x32,   RF 24x24
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Dropout2d(0.1),
        )

        # 1x1 conv produces one 5x5 score map per class; the average pool then
        # reduces each map to a single value.
        self.conv_final = nn.Conv2d(32, 10, 1, bias=False)  # in 5x5x32 -> out 5x5x10
        self.gap = nn.AvgPool2d(5)                           # in 5x5x10 -> out 1x1x10

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(batch, 10)`` for ``x``."""
        x = self.conv1(x)
        x = self.transition_layer(x)
        x = self.conv2(x)
        x = self.conv_final(x)
        x = self.gap(x)
        # Drop the 1x1 spatial dimensions: (batch, 10, 1, 1) -> (batch, 10).
        x = x.view(-1, 10)

        # dim is given explicitly: the implicit-dimension form is deprecated
        # and emits a warning on every call.
        return F.log_softmax(x, dim=1)

# Model Summary and Parameters

In [11]:
# Build the network and move it to the selected device.
model = Network()
model = model.to(device)
# Print the per-layer summary for a single-channel 28x28 input.
summary(model, input_size=(1, 28, 28))

def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report the trainable parameter count with thousands separators.
print('The model has {:,} trainable parameters'.format(count_parameters(model)))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 26, 26]             144
              ReLU-2           [-1, 16, 26, 26]               0
       BatchNorm2d-3           [-1, 16, 26, 26]              32
         Dropout2d-4           [-1, 16, 26, 26]               0
            Conv2d-5           [-1, 16, 24, 24]           2,304
              ReLU-6           [-1, 16, 24, 24]               0
       BatchNorm2d-7           [-1, 16, 24, 24]              32
            Conv2d-8           [-1, 32, 22, 22]           4,608
              ReLU-9           [-1, 32, 22, 22]               0
      BatchNorm2d-10           [-1, 32, 22, 22]              64
        Dropout2d-11           [-1, 32, 22, 22]               0
           Conv2d-12           [-1, 16, 22, 22]             528
             ReLU-13           [-1, 16, 22, 22]               0
        MaxPool2d-14           [-1, 16,



# Training and Testing

In [6]:
from tqdm import tqdm
def train(model, device, train_loader, optimizer, epoch):
    """Run one pass over ``train_loader`` and update ``model`` in place.

    Each batch is moved to ``device``, the ``F.nll_loss`` of the model output
    against the targets is back-propagated, and ``optimizer`` takes one step.
    The progress bar description shows the latest batch loss and batch index.
    ``epoch`` is accepted but not used inside this function.
    """
    model.train()
    progress = tqdm(train_loader)
    for step, batch in enumerate(progress):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()
        progress.set_description(desc= f'loss={batch_loss.item()} batch_id={step}')



def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return 100. * correct / len(test_loader.dataset)

In [12]:
# Start from a freshly initialized network on the selected device.
model = Network()
model = model.to(device)
optimizer = optim.SGD(model.parameters(), momentum=0.9, lr=0.01)

# Train for epochs 1 through 19, evaluating on the test set after each epoch
# and recording the returned accuracy.
test_accuracy = list()
for epoch in range(1, 20):
    print(f'\nEpoch {epoch} : ')
    train(model, device, train_loader, optimizer, epoch)
    accuracy = test(model, device, test_loader)
    test_accuracy += [accuracy]

print(test_accuracy)


Epoch 1 : 


loss=0.08680995553731918 batch_id=468: 100%|██████████| 469/469 [00:29<00:00, 15.72it/s]



Test set: Average loss: 0.0604, Accuracy: 9814/10000 (98.14%)


Epoch 2 : 


loss=0.10255077481269836 batch_id=468: 100%|██████████| 469/469 [00:28<00:00, 16.19it/s]



Test set: Average loss: 0.0391, Accuracy: 9877/10000 (98.77%)


Epoch 3 : 


loss=0.03218429163098335 batch_id=468: 100%|██████████| 469/469 [00:28<00:00, 16.19it/s]



Test set: Average loss: 0.0320, Accuracy: 9898/10000 (98.98%)


Epoch 4 : 


loss=0.13792012631893158 batch_id=468: 100%|██████████| 469/469 [00:28<00:00, 16.27it/s]



Test set: Average loss: 0.0274, Accuracy: 9910/10000 (99.10%)


Epoch 5 : 


loss=0.06550478935241699 batch_id=468: 100%|██████████| 469/469 [00:28<00:00, 16.34it/s]



Test set: Average loss: 0.0237, Accuracy: 9926/10000 (99.26%)


Epoch 6 : 


loss=0.018968267366290092 batch_id=468: 100%|██████████| 469/469 [00:28<00:00, 16.44it/s]



Test set: Average loss: 0.0218, Accuracy: 9933/10000 (99.33%)


Epoch 7 : 


loss=0.0606180839240551 batch_id=468: 100%|██████████| 469/469 [00:28<00:00, 16.42it/s]



Test set: Average loss: 0.0209, Accuracy: 9937/10000 (99.37%)


Epoch 8 : 


loss=0.030050842091441154 batch_id=468: 100%|██████████| 469/469 [00:28<00:00, 16.45it/s]



Test set: Average loss: 0.0188, Accuracy: 9944/10000 (99.44%)


Epoch 9 : 


loss=0.03721245378255844 batch_id=468: 100%|██████████| 469/469 [00:29<00:00, 16.10it/s]



Test set: Average loss: 0.0203, Accuracy: 9937/10000 (99.37%)


Epoch 10 : 


loss=0.11268952488899231 batch_id=468: 100%|██████████| 469/469 [00:29<00:00, 16.11it/s]



Test set: Average loss: 0.0220, Accuracy: 9927/10000 (99.27%)


Epoch 11 : 


loss=0.08877578377723694 batch_id=468: 100%|██████████| 469/469 [00:29<00:00, 15.79it/s]



Test set: Average loss: 0.0188, Accuracy: 9933/10000 (99.33%)


Epoch 12 : 


loss=0.02515028417110443 batch_id=468: 100%|██████████| 469/469 [00:29<00:00, 15.68it/s]



Test set: Average loss: 0.0179, Accuracy: 9942/10000 (99.42%)


Epoch 13 : 


loss=0.08377642184495926 batch_id=468: 100%|██████████| 469/469 [00:29<00:00, 15.95it/s]



Test set: Average loss: 0.0184, Accuracy: 9935/10000 (99.35%)


Epoch 14 : 


loss=0.01566915400326252 batch_id=468: 100%|██████████| 469/469 [00:29<00:00, 15.90it/s]



Test set: Average loss: 0.0169, Accuracy: 9950/10000 (99.50%)


Epoch 15 : 


loss=0.13080759346485138 batch_id=468: 100%|██████████| 469/469 [00:29<00:00, 16.14it/s]



Test set: Average loss: 0.0190, Accuracy: 9946/10000 (99.46%)


Epoch 16 : 


loss=0.11341261118650436 batch_id=468: 100%|██████████| 469/469 [00:29<00:00, 15.96it/s]



Test set: Average loss: 0.0180, Accuracy: 9949/10000 (99.49%)


Epoch 17 : 


loss=0.1660809963941574 batch_id=468: 100%|██████████| 469/469 [00:29<00:00, 16.07it/s]



Test set: Average loss: 0.0176, Accuracy: 9939/10000 (99.39%)


Epoch 18 : 


loss=0.08196517825126648 batch_id=468: 100%|██████████| 469/469 [00:29<00:00, 16.11it/s]



Test set: Average loss: 0.0172, Accuracy: 9940/10000 (99.40%)


Epoch 19 : 


loss=0.11692946404218674 batch_id=468: 100%|██████████| 469/469 [00:29<00:00, 16.13it/s]



Test set: Average loss: 0.0177, Accuracy: 9940/10000 (99.40%)

[98.14, 98.77, 98.98, 99.1, 99.26, 99.33, 99.37, 99.44, 99.37, 99.27, 99.33, 99.42, 99.35, 99.5, 99.46, 99.49, 99.39, 99.4, 99.4]
