In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import os


In [None]:
# --- Optimisation settings ---
batch_size = 64
num_epochs = 3
learning_rate = 0.001

# --- Data dimensions ---
# An MNIST image is 28x28 pixels, i.e. 784 values once flattened.
# NOTE(review): the cells shown here feed unflattened images to a CNN and do
# not read `input_shape`; it looks like a leftover from a dense network.
input_shape = 28 * 28
classes = 10  # digits 0-9

# --- Checkpoint location (relative to the working directory) ---
model_file = 'model_run/CNN.pth.tar'


## Model Definition


In [None]:
class CNN(nn.Module):
    """Small two-convolution image classifier (MNIST-sized by default).

    Pipeline: conv 3x3 (4 channels) -> ReLU -> 2x2 max-pool ->
    conv 3x3 (8 channels) -> ReLU -> flatten -> linear layer that emits one
    raw score (logit) per class.

    Args:
        in_channels: number of channels of the input images.
        num_classes: number of scores produced per image.
        image_size: spatial size of the input images, either a single int
            for square images or a ``(height, width)`` pair. Defaults to 28,
            which reproduces the previously hard-coded ``8*14*14`` features.
    """

    def __init__(self, in_channels, num_classes, image_size=28):
        super().__init__()
        if isinstance(image_size, int):
            height, width = image_size, image_size
        else:
            height, width = image_size

        self.cnn1 = nn.Conv2d(in_channels=in_channels, out_channels=4,
                              kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1))
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.cnn2 = nn.Conv2d(in_channels=4, out_channels=8,
                              kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1))
        # Both convolutions (3x3, stride 1, padding 1) keep the spatial size;
        # only the single 2x2/stride-2 pooling step halves each dimension.
        flat_features = 8 * (height // 2) * (width // 2)
        self.fc1 = nn.Linear(flat_features, num_classes)

    def forward(self, x):
        """Map a batch of images (N, in_channels, H, W) to (N, num_classes) scores."""
        x = F.relu(self.cnn1(x))
        x = self.pool(x)
        x = F.relu(self.cnn2(x))
        # Flatten everything except the batch dimension for the linear layer.
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        return x


In [None]:
# MNIST images are grayscale, so the network takes a single input channel.
# The number of output scores comes from the `classes` constant in the
# configuration cell instead of repeating the literal 10 here.
model = CNN(in_channels=1, num_classes=classes)


In [None]:
# Training loss: cross-entropy computed on the raw class scores. The name
# `metrics` is kept because the training cell refers to it.
metrics = nn.CrossEntropyLoss()
# Adam over every parameter of `model`, using the configured step size.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)


In [None]:
# Checkpoint saving function
def save_checkpoint(state, filename=model_file):
    """Persist ``state`` to ``filename`` unless a better checkpoint is stored.

    Args:
        state: dict holding at least the keys ``'epoch'`` and ``'accuracy'``;
            the whole dict is what gets written to disk.
        filename: path of the checkpoint file.

    The checkpoint is written when ``state['epoch'] == 0``, when no file
    exists at ``filename`` yet, or when the stored ``'accuracy'`` is lower
    than ``state['accuracy']``. Otherwise nothing is written.
    """
    # torch.save does not create missing parent directories.
    target_dir = os.path.dirname(filename)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    # Only compare against disk when there is something to compare with;
    # previously a missing file at a later epoch made torch.load raise.
    if state['epoch'] != 0 and os.path.isfile(filename):
        checkpoint = torch.load(filename)
        if not (checkpoint['accuracy'] < state['accuracy']):
            return

    print("===> Saving Checkpoint ")
    torch.save(state, filename)


In [None]:
# Accuracy check function
def check_accuracy(loader, model):
    """Return the classification accuracy of ``model`` over ``loader`` in percent.

    Args:
        loader: iterable yielding ``(inputs, labels)`` batches. When
            ``loader.dataset`` exposes a ``train`` attribute (torchvision's
            MNIST does), it selects which banner is printed; otherwise no
            banner is printed.
        model: module mapping a batch of inputs to scores of shape
            ``(batch, num_classes)``.

    Returns:
        Accuracy as a percentage rounded to two decimals, or ``0.0`` when the
        loader yields no samples.

    The model is switched to eval mode for the measurement and afterwards put
    back into whichever mode (train/eval) it was in when the call started.
    """
    split_flag = getattr(getattr(loader, 'dataset', None), 'train', None)
    if split_flag is not None:
        if split_flag:
            print("testing on training data")
        else:
            print("Testing on Testing data")

    num_correct = 0
    num_sample = 0
    was_training = model.training
    model.eval()
    try:
        # Gradients are not needed for evaluation, only the outputs.
        with torch.no_grad():
            for x, y in loader:
                scores = model(x)
                # max over dim 1 returns (values, indices); the index of the
                # highest score is the predicted class for each sample.
                _, predictions = scores.max(1)
                # Element-wise comparison gives one boolean per sample;
                # summing them counts the correct predictions in the batch.
                num_correct += int((predictions == y).sum())
                num_sample += predictions.size(0)
    finally:
        # Restore the caller's mode even if evaluation raised.
        model.train(was_training)

    if num_sample == 0:
        return 0.0
    return round(num_correct / num_sample * 100, ndigits=2)


## Dataset

In [None]:
# MNIST training and test splits, downloaded into dataset/ when missing;
# every image is converted to a tensor on access.
train_dataset = datasets.MNIST(
    root="dataset/",
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = datasets.MNIST(
    root="dataset/",
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

# Shuffled mini-batch loaders. The training loader discards a final
# incomplete batch (drop_last=True); the test loader keeps it.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

## Training

In [None]:
# Resume from an existing checkpoint when one is present; otherwise start
# from scratch with no accuracy to beat.
if os.path.isfile(model_file):
    prev_model = torch.load(model_file)
    model.load_state_dict(prev_model['state_dict'])
    # The training cell stores the optimizer state next to the weights;
    # restore it as well so resumed training continues with the same Adam
    # statistics instead of a freshly initialised optimizer.
    if 'optimizer' in prev_model:
        optimizer.load_state_dict(prev_model['optimizer'])
    max_acc = prev_model['accuracy']
else:
    max_acc = 0


In [None]:
for epochs in range(num_epochs):
    # Make sure the model is in training mode regardless of what ran before.
    model.train()

    for data, targets in train_loader:
        # Forward pass and loss for this mini-batch.
        scores = model(data)
        loss = metrics(scores, targets)

        # Clear gradients left over from the previous batch, back-propagate,
        # then take one optimizer step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    acc = check_accuracy(train_loader, model)
    print(f'For epoch : {epochs} accuracy is {acc}')

    # Only keep a checkpoint when this epoch beats the best accuracy so far.
    if (acc > max_acc):
        max_acc = acc
        checkpoint = {'state_dict': model.state_dict(),
                      'optimizer': optimizer.state_dict(),
                      'accuracy': acc,
                      'epoch': epochs}
        save_checkpoint(checkpoint)


In [None]:
# Final evaluation on the held-out test split. The training loop already
# reports accuracy on the training data after every epoch, and `test_loader`
# was otherwise never used.
check_accuracy(test_loader, model)
