<a href="https://colab.research.google.com/github/peeush-agarwal/week-based-learning/blob/master/Deep-Learning/Batch-Normalization/MNIST_with_Batch_Normalization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Neural Network for MNIST Dataset

Outline:
+ Load MNIST dataset (Train, test)
+ Build a custom Fully connected Neural network model
+ Build a custom Convolutional Neural Network model
+ Train and eval different models
+ Apply **Batch Normalization** on both the above models
+ Compare each model with and without Batch Norm

## Import libraries

In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

plt.style.use('dark_background')

## Load MNIST dataset

In [0]:
# Fetch the MNIST train and test splits into ./Data (downloading them when
# missing) and convert every image to a tensor.
mnist_root = './Data'
trainset = datasets.MNIST(root=mnist_root, train=True, transform=transforms.ToTensor(), download=True)
testset = datasets.MNIST(root=mnist_root, train=False, transform=transforms.ToTensor(), download=True)

In [0]:
# Small batch size for the first loaders, which feed the sample-image preview;
# it is rebound to 512 before the training loops.
batch_size = 4

In [0]:
# Wrap both splits in DataLoaders; only the training split is shuffled.
trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=batch_size, shuffle=False)

In [0]:
# Class names exposed by the dataset object.
classes = trainset.classes
print(classes)

## Visualize sample images

In [0]:
def imshow(images, labels, n_samples = 4):
  """Show the first `n_samples` images of a batch in one row, titled with their labels.

  Args:
    images: batch of image tensors in a layout accepted by
      `torchvision.utils.make_grid` (assumed (N, C, H, W) -- TODO confirm).
    labels: tensor of labels aligned with `images`; each entry is read with `.item()`.
    n_samples: maximum number of images to display; also sets the figure width.
  """
  images = images[:n_samples]
  labels = labels[:n_samples]

  grid = torchvision.utils.make_grid(images)
  title = f'Labels:{[label.item() for label in labels]}'

  plt.figure(figsize=(n_samples * 4, 4))
  # matplotlib expects channels last. Convert explicitly to a host NumPy array
  # instead of handing a torch.Tensor to np.transpose, which only works through
  # NumPy's fallback coercion and fails for CUDA or grad-tracking tensors.
  plt.imshow(grid.detach().cpu().permute(1, 2, 0).numpy())
  plt.title(title)
  plt.axis('off')
  plt.show()

In [0]:
# Pull a single batch from the training loader and report its tensor shapes.
images, labels = next(iter(trainloader))

print(images.shape)
print(labels.shape)

# Preview at most 8 images from the batch.
imshow(images, labels, n_samples=min(batch_size, 8))

## Get computing device

In [0]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')
print(device)

## Build a custom fully connected neural network for this dataset

In [0]:
class MNIST_NN(nn.Module):
  """Fully connected classifier: 784 -> 48 -> 24 -> 10 with ReLU activations.

  The final layer returns raw class scores; no softmax is applied here.
  """

  def __init__(self):
    super().__init__()

    # Layer positions matter: other cells index into `classifier` by position.
    dense_stack = [
        nn.Linear(784, 48),
        nn.ReLU(),
        nn.Linear(48, 24),
        nn.ReLU(),
        nn.Linear(24, 10),
    ]
    self.classifier = nn.Sequential(*dense_stack)

  def forward(self, x):
    """Reshape `x` into rows of 784 values and return the (rows, 10) class scores."""
    flat_pixels = x.view(-1, 784)
    return self.classifier(flat_pixels)

In [0]:
class MNIST_NN_BatchNorm(nn.Module):
  """Fully connected classifier 784 -> 48 -> 24 -> 10 with BatchNorm after each hidden Linear.

  Each hidden Linear layer is followed by BatchNorm1d and then ReLU. The final
  layer returns raw class scores; no softmax is applied here.
  """

  def __init__(self):
    super().__init__()

    self.classifier = nn.Sequential(
        nn.Linear(784, 48),
        # BatchNorm1d is used because the activations here are flat (N, features)
        # vectors; BatchNorm2d is for (N, C, H, W) feature maps. The choice follows
        # the input rank, not the number of image channels. The argument (48) is
        # the number of output features of the preceding layer.
        nn.BatchNorm1d(48),
        nn.ReLU(),
        nn.Linear(48, 24),
        nn.BatchNorm1d(24),
        nn.ReLU(),
        nn.Linear(24, 10)
    )

  def forward(self, x):
    """Flatten each sample in the batch and return the (N, 10) class scores."""
    # Keep the batch dimension and flatten everything else.
    x = x.view(x.size(0), -1)

    x = self.classifier(x)

    return x

In [0]:
# Cross-entropy loss applied to the models' raw class scores; used for both FC models.
loss_fn = nn.CrossEntropyLoss()

In [0]:
# Model without Batch Normalization
# Baseline fully connected model (no BatchNorm), optimized with plain SGD.
fc_model = MNIST_NN()
fc_model = fc_model.to(device)
fc_optimizer = optim.SGD(params=fc_model.parameters(), lr=0.01)

In [0]:
# Same architecture with BatchNorm layers, using the same optimizer settings
# as the baseline so the two runs are comparable.
fc_batchnorm_model = MNIST_NN_BatchNorm()
fc_batchnorm_model = fc_batchnorm_model.to(device)
fc_batchnorm_optimizer = optim.SGD(params=fc_batchnorm_model.parameters(), lr=0.01)

In [0]:
# Larger batch size for training; the training loader is rebuilt with it next.
batch_size = 512

In [0]:
# Rebuild the shuffled training loader with the current batch size.
trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)

In [0]:
# Train both models simultaneously on the same mini-batches so their losses
# can be compared iteration by iteration.

epochs = 2

loss_arr = []      # per-iteration loss, model without BatchNorm
loss_bn_arr = []   # per-iteration loss, model with BatchNorm
for epoch in range(epochs):
  for i, data in enumerate(trainloader, 0):
    images, labels = data
    images = images.to(device)
    labels = labels.to(device)

    # Train FC without BatchNorm model
    fc_optimizer.zero_grad()
    outputs = fc_model(images)
    loss = loss_fn(outputs, labels)
    loss.backward()
    fc_optimizer.step()

    # Train FC with BatchNorm model
    fc_batchnorm_optimizer.zero_grad()
    outputs1 = fc_batchnorm_model(images)
    loss1 = loss_fn(outputs1, labels)
    loss1.backward()
    fc_batchnorm_optimizer.step()

    loss_arr.append(loss.item())
    loss_bn_arr.append(loss1.item())

    if i%10 == 0:
      # Every 10th iteration, plot the distribution of the second hidden layer's
      # pre-activation values: raw Linear output for the plain model, and the
      # output of BatchNorm1d(24) for the BatchNorm model.
      inputs = images.view(images.size(0), -1)

      # Switch to eval mode for the diagnostic passes; restored to train mode below.
      fc_model.eval()
      fc_batchnorm_model.eval()

      # These forward passes are for plotting only, so no autograd graph is needed.
      with torch.no_grad():
        a = fc_model.classifier[0](inputs)    # Linear(784, 48)
        a = fc_model.classifier[1](a)         # Relu()
        a = fc_model.classifier[2](a)         # Linear(48, 24)

        b = fc_batchnorm_model.classifier[0](inputs)  # Linear(784, 48)
        b = fc_batchnorm_model.classifier[1](b)       # BatchNorm1d(48)
        b = fc_batchnorm_model.classifier[2](b)       # Relu()
        b = fc_batchnorm_model.classifier[3](b)       # Linear(48, 24)
        b = fc_batchnorm_model.classifier[4](b)       # BatchNorm1d(24)

      # .cpu() is required before .numpy(): the tensors live on `device`, which
      # may be a CUDA GPU, and .numpy() only works on host tensors.
      # NOTE(review): sns.distplot is deprecated in newer seaborn releases;
      # switch to histplot/kdeplot once the notebook's seaborn version is confirmed.
      a = a.cpu().numpy().ravel()           # Flatten
      sns.distplot(a, kde=True, color='r', label='W/o Batch Norm')

      b = b.cpu().numpy().ravel()           # Flatten
      sns.distplot(b, kde=True, color='g', label='W/ Batch Norm')

      plt.title(f'iteration:{i}, loss:{loss.item():.2f}, loss_bn:{loss1.item():.2f}')
      plt.legend()
      plt.show()

      fc_model.train()
      fc_batchnorm_model.train()

plt.plot(loss_arr, 'r', label='W/o BatchNorm')
plt.plot(loss_bn_arr, 'g', label='W/ BatchNorm')
plt.legend()
plt.show()

The loss curves show that, for the fully connected network, adding Batch Normalization helps reduce the training loss compared to the same network trained without it.

## Build a Convolutional Neural Network (CNN) for the MNIST dataset

In [0]:
class MNIST_CNN(nn.Module):
  """Small CNN: two conv stages followed by a two-layer classifier.

  The conv stack is split into `features1` (ending at the second convolution)
  and `features2` (its ReLU and pooling) so the second convolution's raw output
  can be read via `features1` alone. Shape comments assume (N, 1, 28, 28) input.
  """

  def __init__(self):
    super().__init__()

    conv_stage = [
        nn.Conv2d(1, 3, 5),               # (N, 1, 28, 28) => (N, 3, 24, 24)
        nn.ReLU(),
        nn.AvgPool2d(2, stride=2),        # (N, 3, 24, 24) => (N, 3, 12, 12)
        nn.Conv2d(3, 6, 3),               # (N, 3, 12, 12) => (N, 6, 10, 10)
    ]
    post_conv_stage = [
        nn.ReLU(),
        nn.AvgPool2d(2, stride=2),        # (N, 6, 10, 10) => (N, 6, 5, 5)
    ]
    dense_stage = [
        nn.Linear(150, 25),               # 150 => 6*5*5
        nn.ReLU(),
        nn.Linear(25, 10),
    ]

    self.features1 = nn.Sequential(*conv_stage)
    self.features2 = nn.Sequential(*post_conv_stage)
    self.classifier = nn.Sequential(*dense_stage)

  def forward(self, x):
    """Run both feature stages, flatten per sample, and return the class scores."""
    feature_maps = self.features2(self.features1(x))
    flat = feature_maps.view(feature_maps.shape[0], -1)
    return self.classifier(flat)

In [0]:
class MNIST_CNN_BatchNorm(nn.Module):
  """Small CNN with a BatchNorm2d(6) layer after the second convolution.

  `features1` ends with the BatchNorm layer, so calling it alone yields the
  normalized output of the second convolution; `features2` applies the ReLU
  and pooling that follow.
  """

  def __init__(self):
    super().__init__()

    conv_stage = [
        nn.Conv2d(1, 3, 5),
        nn.ReLU(),
        nn.AvgPool2d(2, stride=2),
        nn.Conv2d(3, 6, 3),
        nn.BatchNorm2d(6),                # one normalization per output channel
    ]
    post_conv_stage = [
        nn.ReLU(),
        nn.AvgPool2d(2, stride=2),
    ]
    dense_stage = [
        nn.Linear(150, 25),
        nn.ReLU(),
        nn.Linear(25, 10),
    ]

    self.features1 = nn.Sequential(*conv_stage)
    self.features2 = nn.Sequential(*post_conv_stage)
    self.classifier = nn.Sequential(*dense_stage)

  def forward(self, x):
    """Run both feature stages, flatten per sample, and return the class scores."""
    feature_maps = self.features2(self.features1(x))
    flat = feature_maps.view(feature_maps.shape[0], -1)
    return self.classifier(flat)

In [0]:
# Batch size for the CNN training run.
batch_size = 512

In [0]:
# Shuffled training loader for the CNN experiments.
trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)

In [0]:
# Cross-entropy loss applied to the CNNs' raw class scores; used for both CNN models.
loss_fn = nn.CrossEntropyLoss()

In [0]:
# Baseline CNN (no BatchNorm), optimized with plain SGD.
cnn_model = MNIST_CNN()
cnn_model = cnn_model.to(device)

cnn_optimizer = optim.SGD(params=cnn_model.parameters(), lr=0.01)

In [0]:
# CNN with BatchNorm, using the same optimizer settings as the baseline so
# the two runs are comparable.
cnn_bn_model = MNIST_CNN_BatchNorm()
cnn_bn_model = cnn_bn_model.to(device)

cnn_bn_optimizer = optim.SGD(params=cnn_bn_model.parameters(), lr=0.01)

In [0]:
# Train the plain CNN and the BatchNorm CNN on the same mini-batches so their
# losses can be compared iteration by iteration.
epochs = 10

loss_arr = []      # per-iteration loss, CNN without BatchNorm
loss_bn_arr = []   # per-iteration loss, CNN with BatchNorm

for epoch in range(epochs):
  for i, data in enumerate(trainloader, 0):
    images, labels = data
    images = images.to(device)
    labels = labels.to(device)

    # One SGD step for the CNN without BatchNorm
    cnn_optimizer.zero_grad()

    outputs = cnn_model(images)
    loss = loss_fn(outputs, labels)
    loss.backward()
    cnn_optimizer.step()

    loss_arr.append(loss.item())

    # One SGD step for the CNN with BatchNorm, on the same batch
    cnn_bn_optimizer.zero_grad()

    outputs_bn = cnn_bn_model(images)
    loss_bn = loss_fn(outputs_bn, labels)
    loss_bn.backward()
    cnn_bn_optimizer.step()

    loss_bn_arr.append(loss_bn.item())

    if i%50 == 0:
      # Every 50th iteration, plot the distribution of the `features1` output
      # of each model: the second convolution's output, BatchNorm-ed in the BN model.
      # Switch to eval mode for the diagnostic passes; restored to train mode below.
      cnn_model.eval()
      cnn_bn_model.eval()

      # These forward passes are for plotting only, so no autograd graph is needed.
      with torch.no_grad():
        a = cnn_model.features1(images)
        b = cnn_bn_model.features1(images)

      # .cpu() is required before .numpy(): the tensors live on `device`, which
      # may be a CUDA GPU, and .numpy() only works on host tensors.
      # NOTE(review): sns.distplot is deprecated in newer seaborn releases;
      # switch to histplot/kdeplot once the notebook's seaborn version is confirmed.
      a = a.cpu().numpy().ravel()
      sns.distplot(a, kde=True, color='r', label='W/o BatchNorm')

      b = b.cpu().numpy().ravel()
      sns.distplot(b, kde=True, color='g', label='W/ BatchNorm')
      plt.title(f'E:{epoch} I:{i} Loss:{loss.item():.2f} loss_bn:{loss_bn.item():.2f}')
      plt.legend()
      plt.show()

      cnn_model.train()
      cnn_bn_model.train()

plt.plot(loss_arr, color='r', label='W/o BatchNorm')
plt.plot(loss_bn_arr, color='g', label='W/ BatchNorm')
plt.legend()
plt.show()

The graph above clearly shows that Batch Normalization has a large impact on reducing the loss during training.