In [1]:
## Import Modules 

import torch
import torchvision
from torch import nn, optim
import torch.nn.functional as F
from torchsummary import summary

In [2]:
## Model Configuration (Initialize hyperparameters)
batch_size = 64
learning_rate = 0.01
cross_entropy = nn.CrossEntropyLoss()

In [3]:
## DataLoader (Load the training set and validation set using Dataset and DataLoader)
transform = torchvision.transforms.ToTensor()
train_data = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
    'mnist_data', train=True, download=True, transform=transform
    ), batch_size=batch_size
)
val_data = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
    'mnist_data', train=False, download=True, transform=transform
    ), batch_size=batch_size
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist_data/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [4]:
# Validation function (To check whether the model is learning properly we can use a validation set)
def validate(model, data):
    total = 0
    correct = 0
    for i, (images, labels) in enumerate(data):
        images = images.cuda()
        labels = labels.cuda()
        y_pred = model(images)
        value, pred = torch.max(y_pred, 1)
        total += y_pred.size(0)
        correct += torch.sum(pred == labels)
    return correct * 100 / total

In [5]:
## Training Function (Training the model)
def train(model,epochs=5) :
    optimizer = optim.Adam(model.parameters(),lr=learning_rate)    
    for n in range(epochs)  :
        for i , (images , labels) in enumerate(train_data) :
            images = images.cuda()
            labels = labels.cuda()
            optimizer.zero_grad()
            prediction = model(images)
            loss = cross_entropy(prediction, labels)
            loss.backward()
            optimizer.step()
        accuracy = float(validate(model, val_data))
        print("Epoch:", n+1, "Loss: ", float(loss.data), "Accuracy:", accuracy)

In [6]:
## Model (A sample CNN is defined here for image classification)
class CNNWithPoolRelu(nn.Module) :
    def __init__(self):
        super(CNNWithPoolRelu,self).__init__()
        self.conv_1 = nn.Conv2d(in_channels=1,out_channels=16,kernel_size=3)
        self.conv_2 = nn.Conv2d(in_channels=16,out_channels=32,kernel_size=3)
        
        self.pool_1 = nn.MaxPool2d(2)
        self.pool_2 = nn.MaxPool2d(2)

        self.dense_1 = nn.Linear(in_features=800,out_features=256)
        self.dense_2 = nn.Linear(in_features=256,out_features=10)

        self.relu = nn.ReLU()
    def forward(self,x) :
        x = self.relu(self.conv_1(x))
        x = self.pool_1(x)
        x = self.relu(self.conv_2(x))
        x = self.pool_2(x)
        x = x.view(x.shape[0],-1)
        x = self.relu(self.dense_1(x))
        x = self.dense_2(x)
        # output = self.tanh(x)
        output = F.log_softmax(x, dim=1)

        return output

In [7]:
# Model (Initialize the neural network)
model = CNNWithPoolRelu().cuda()

In [8]:
# Summary
summary(model, (1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 26, 26]             160
              ReLU-2           [-1, 16, 26, 26]               0
         MaxPool2d-3           [-1, 16, 13, 13]               0
            Conv2d-4           [-1, 32, 11, 11]           4,640
              ReLU-5           [-1, 32, 11, 11]               0
         MaxPool2d-6             [-1, 32, 5, 5]               0
            Linear-7                  [-1, 256]         205,056
              ReLU-8                  [-1, 256]               0
            Linear-9                   [-1, 10]           2,570
Total params: 212,426
Trainable params: 212,426
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.25
Params size (MB): 0.81
Estimated Total Size (MB): 1.07
-------------------------------------------

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [9]:
# Train for 30 Epochs
train(model,epochs=30)

Epoch: 1 Loss:  0.084713414311409 Accuracy: 97.62999725341797
Epoch: 2 Loss:  0.002180100418627262 Accuracy: 97.29000091552734
Epoch: 3 Loss:  0.0005753105506300926 Accuracy: 97.79000091552734
Epoch: 4 Loss:  0.0023213077802211046 Accuracy: 98.25999450683594
Epoch: 5 Loss:  0.024052107706665993 Accuracy: 98.19999694824219
Epoch: 6 Loss:  0.014843380078673363 Accuracy: 97.68999481201172
Epoch: 7 Loss:  0.011326142586767673 Accuracy: 98.18999481201172
Epoch: 8 Loss:  2.529386620153673e-06 Accuracy: 98.36000061035156
Epoch: 9 Loss:  2.5404424377484247e-05 Accuracy: 98.22999572753906
Epoch: 10 Loss:  0.003313989145681262 Accuracy: 97.68999481201172
Epoch: 11 Loss:  5.7216253480874e-05 Accuracy: 98.23999786376953
Epoch: 12 Loss:  0.2729506194591522 Accuracy: 98.19999694824219
Epoch: 13 Loss:  1.2854721717303619e-05 Accuracy: 98.25
Epoch: 14 Loss:  2.2351736461700966e-08 Accuracy: 98.14999389648438
Epoch: 15 Loss:  0.0001866416132543236 Accuracy: 98.14999389648438
Epoch: 16 Loss:  4.99185375

**We can observe that utilising the ReLU function improves accuracy and reduces loss more efficiently than using tanh.**