# MNIST Digit Classifier using Neural Networks

In [None]:
__author__ = "rsh"  # notebook author, kept as module-level metadata

In [None]:
import torch
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from torch import nn, optim

In [None]:
from torchvision import datasets, transforms

In [None]:
# Report the installed PyTorch version so results can be tied to it.
print(torch.__version__)

In [None]:
# Preprocessing applied to every MNIST sample: convert to a tensor, then normalize
# the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

In [None]:
# MNIST training split, downloaded into ./MNIST on first use.
traindata = datasets.MNIST(root='./MNIST', train=True, download=True, transform=transform)

In [None]:
# MNIST test split, stored alongside the training split and preprocessed identically.
testdata = datasets.MNIST(root='./MNIST', train=False, download=True, transform=transform)

In [None]:
# Number of samples in the (train, test) splits.
(len(traindata), len(testdata))

In [None]:
# Batches of 64. Only the training set is shuffled: test accuracy/loss do not depend
# on sample order, and a fixed test order makes the single-image check (which takes
# the first test batch) reproducible across runs.
trainloader = torch.utils.data.DataLoader(traindata, shuffle=True, batch_size=64)
testloader = torch.utils.data.DataLoader(testdata, shuffle=False, batch_size=64)

In [None]:
# Pull a single batch from the training loader for inspection.
(images, labels) = next(iter(trainloader))

In [None]:
# Shape of the inspected image batch.
images.size()

In [None]:
# Display the first image of the batch in grayscale and print its ground-truth label.
first_label = labels[0].item()
plt.imshow(images[0].squeeze(), cmap='Greys_r')
print(f'The label is {first_label}')

In [None]:
class DigitClassifier(nn.Module):
    """Fully connected classifier: in_features -> hidden1 -> hidden2 -> n_classes.

    ReLU follows the first two linear layers. Dropout is applied to the inputs of the
    second and third linear layers, and is only active in training mode. The output is
    log-probabilities (LogSoftmax over dim 1), which pairs with ``nn.NLLLoss``.

    The defaults reproduce the original 784-256-64-10 network with p=0.20 dropout.

    Args:
        in_features: size of one flattened input sample (784 = 28 * 28 for MNIST).
        hidden1: width of the first hidden layer.
        hidden2: width of the second hidden layer.
        n_classes: number of output classes.
        dropout: dropout probability.
    """

    def __init__(self, in_features=784, hidden1=256, hidden2=64, n_classes=10, dropout=0.20):
        super().__init__()
        # Attribute names (and creation order) are unchanged so existing state_dicts
        # still load and default initialization consumes the RNG identically.
        self.linear1 = nn.Linear(in_features, hidden1)
        self.linear2 = nn.Linear(hidden1, hidden2)
        self.linear3 = nn.Linear(hidden2, n_classes)

        self.relu = nn.ReLU()

        self.dropout = nn.Dropout(p=dropout)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, input):
        """Return (batch, n_classes) log-probabilities.

        ``input`` may already be flat, ``(batch, in_features)``, or an image batch such
        as ``(batch, 1, 28, 28)``. Everything after the batch dimension is flattened, and
        a 2-D input passes through unchanged.
        """
        out = torch.flatten(input, start_dim=1)
        out = self.relu(self.linear1(out))
        out = self.relu(self.linear2(self.dropout(out)))
        out = self.logsoftmax(self.linear3(self.dropout(out)))
        return out

In [None]:
model = DigitClassifier()  # default 784-256-64-10 architecture

In [None]:
# Negative log-likelihood averaged over the batch; expects log-probabilities,
# which is what DigitClassifier's LogSoftmax output provides.
criterion = nn.NLLLoss(reduction='mean')

In [None]:
# Plain SGD (no momentum) with learning rate 1e-3 over all model parameters.
# NOTE(review): adding momentum or raising lr would likely converge faster — tune if needed.
optimizer = optim.SGD(params=model.parameters(), lr=1e-3)

In [None]:
def test_validation(model, test_data, criterion=None):
    """Evaluate ``model`` on every batch in ``test_data``.

    Args:
        model: classifier returning per-class log-probabilities of shape
            (batch, n_classes).
        test_data: iterable of ``(images, labels)`` batches, e.g. a DataLoader.
            Each image batch is flattened to ``(batch, -1)`` before the forward pass.
        criterion: loss returning the batch-mean loss. Defaults to ``nn.NLLLoss()``,
            the loss the notebook trains with.

    Returns:
        ``(accuracy, mean_loss)``:
            accuracy: percentage of correctly classified samples.
            mean_loss: loss averaged over samples, with each batch weighted by its
                size so a short final batch does not skew the average.

    Raises:
        ValueError: if ``test_data`` yields no samples.

    Leaves ``model`` in eval mode; the caller must switch back to train mode
    before further training.
    """
    if criterion is None:
        criterion = nn.NLLLoss()
    total = 0
    correct = 0
    test_loss = 0.
    model.eval()  # disable dropout for evaluation
    with torch.no_grad():
        for images, labels in test_data:
            batch_size = images.shape[0]
            pred = model(images.view(batch_size, -1))
            # criterion returns the batch mean; scale it back to a batch sum so that
            # the final average is per sample rather than per batch.
            test_loss += criterion(pred, labels).item() * batch_size
            pred_class = pred.argmax(dim=1)
            correct += (pred_class == labels).sum().item()
            total += batch_size
    if total == 0:
        raise ValueError('test_data yielded no samples')
    return correct * 100. / total, test_loss / total


# The name starts with ``test_``; stop pytest from collecting this helper as a test.
test_validation.__test__ = False

In [None]:
loss_array = []        # mean training loss per epoch
test_loss_array = []   # test loss per epoch, as returned by test_validation
accuracy_array = []    # test accuracy (%) per epoch
epochs = 15

# Training

for i in range(epochs):
    # test_validation switches the model to eval mode, so re-enable training
    # mode (dropout) at the start of every epoch.
    model.train()
    loss_val = 0.
    for images, labels in trainloader:
        images = images.view(images.shape[0], -1)  # flatten each image for linear1
        optimizer.zero_grad()

        output = model(images)
        loss = criterion(output, labels)

        loss.backward()
        optimizer.step()

        loss_val += loss.item()

    # End of epoch: evaluate once on the test set and record the metrics.
    train_loss = loss_val / len(trainloader)
    accuracy, t_loss = test_validation(model, testloader)
    accuracy_array.append(accuracy)
    test_loss_array.append(t_loss)
    loss_array.append(train_loss)
    print ('Loss at epoch {} is {}, accuracy: {}'.format(i+1, train_loss, accuracy))
    print ('Test Loss {}'.format(t_loss))

In [None]:
# Epochs are numbered from 1 so the x-axis matches the training log ("Loss at epoch 1 ...").
# Each x-range comes from the recorded data rather than `epochs`, so the plot still
# works if the training cell was interrupted before finishing.
x_epochs = list(range(1, len(loss_array) + 1))
plt.plot(x_epochs, loss_array, label='Training Loss')
plt.plot(list(range(1, len(test_loss_array) + 1)), test_loss_array, label='Testing/Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
# Test accuracy (%) per epoch. Epochs are numbered from 1 to match the training log,
# and the x-range comes from the recorded data so an interrupted run still plots.
plt.plot(list(range(1, len(accuracy_array) + 1)), accuracy_array)
plt.xlabel('Epochs')
plt.ylabel('Test Accuracy')
plt.show()

In [None]:
# Predict a single test image; change `idx` to inspect a different image in the batch.
idx = 5
test_images, test_label = next(iter(testloader))

test_image = test_images[idx].view(1, -1)  # a batch of one flattened image

# Explicitly disable dropout and autograd instead of relying on whatever mode
# the previous cell left the model in.
model.eval()
with torch.no_grad():
    pred = model(test_image)

_, pred_class = torch.max(pred, dim=1)

plt.imshow(test_images[idx].squeeze(), cmap='Greys_r')

print ('Predicted: {}'.format(pred_class.item()))

# .item() so the label prints as a plain number, like the prediction.
print ('Actual: {}'.format(test_label[idx].item()))


In [None]:
# Dump the raw weight and bias tensors of every layer, in registration order.
for _, tensor in model.named_parameters():
    print(tensor.data)