In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np

import matplotlib.pyplot as plt
import time

from IPython import display
import seaborn as sns
sns.set_style("darkgrid")

# Seed torch's default RNG (used for weight initialisation and DataLoader
# shuffling) so that Restart & Run All reproduces the same training run.
SEED = 0
torch.manual_seed(SEED)


In [12]:
INPUT_LAYER = 28*28
HIDDEN_LAYER = 512
OUTPUT_LAYER = 10

class FCN(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear -> Softmax.

    Accepts either a single flattened input of shape ``(input_size,)`` or a
    batch of shape ``(batch, input_size)``. The softmax is taken over the last
    (class) dimension, so each sample's outputs sum to 1.

    Args:
        input_size: number of input features (defaults to INPUT_LAYER).
        hidden_size: width of the hidden layer (defaults to HIDDEN_LAYER).
        output_size: number of classes (defaults to OUTPUT_LAYER).
    """

    def __init__(self, input_size=INPUT_LAYER, hidden_size=HIDDEN_LAYER, output_size=OUTPUT_LAYER):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        # dim=-1 normalises over the classes for both 1-D and batched input;
        # dim=0 would normalise across the batch when the input is 2-D.
        self.softmax = nn.Softmax(dim=-1)
        self.ReLU = nn.ReLU()

    def forward(self, x):
        """Return per-class probabilities for ``x``."""
        x = self.fc1(x)
        # Without this non-linearity, fc2(fc1(x)) is just one linear map.
        x = self.ReLU(x)
        x = self.fc2(x)
        return self.softmax(x)

In [13]:
# Load MNIST data (Handwritten Digits).
# batch_size stays 1: the training loop in this notebook flattens each batch
# into a single input vector and reads a single label per step.
batch_size = 1
from torchvision import datasets, transforms

# One transform shared by both splits so train and test inputs are normalised
# identically (0.1307 / 0.3081 are the commonly quoted MNIST mean / std).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=mnist_transform),
    batch_size=batch_size, shuffle=True)

# Evaluation does not need shuffling; a fixed order keeps it reproducible.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True, transform=mnist_transform),
    batch_size=batch_size, shuffle=False)

In [14]:
# Initialise the neural network on the GPU when one is available, otherwise on
# the CPU (a hard-coded .cuda() fails on CPU-only machines).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FCN().to(device)

In [15]:
# Summed (not averaged) squared error between predicted probabilities and the
# one-hot target; reduction="sum" replaces the deprecated size_average=False.
loss_fn = torch.nn.MSELoss(reduction="sum")

In [16]:
# Helper Classes
class ComputeAccuracy:
    """Rolling accuracy over the most recent predictions.

    Each call to ``update`` records whether one prediction was correct. Once
    ``buffer_length`` results have been seen, every update appends the
    percentage of correct predictions among the last ``buffer_length``
    results to ``data``, producing a moving-average curve for plotting.

    Args:
        buffer_length: size of the sliding window (must be >= 1).

    Raises:
        ValueError: if ``buffer_length`` is less than 1.
    """

    def __init__(self, buffer_length=100):
        if buffer_length < 1:
            raise ValueError("buffer_length must be at least 1")
        self.data = []    # rolling accuracy values, in percent
        self.buffer = []  # correctness flags for the current window
        self.buffer_length = buffer_length

    def update(self, predicted_label, true_label):
        """Record one prediction and, once the window is full, log its accuracy.

        Args:
            predicted_label: label predicted by the model.
            true_label: ground-truth label, compared with ``==``.
        """
        self.buffer.append(true_label == predicted_label)
        # Trim before measuring so each window holds exactly buffer_length results.
        if len(self.buffer) > self.buffer_length:
            self.buffer.pop(0)
        if len(self.buffer) == self.buffer_length:
            self.data.append(100.0 * sum(self.buffer) / len(self.buffer))



In [17]:
learning_rate = 0.001
max_iterations = 10000  # training samples to process (batch_size is 1)
plot_every = 100        # redraw the dashboard every N iterations; 1 redraws every step (slow)

accuracy_history = ComputeAccuracy()
loss_history = []


def _plot_progress(image, predicted_label, true_label, probabilities, losses, accuracies):
    """Draw the current digit, its class probabilities, the loss curve and rolling accuracy.

    Args:
        image: 2-D array for the input digit, drawn as a heatmap.
        predicted_label: label chosen by the model for ``image``.
        true_label: ground-truth label for ``image``.
        probabilities: 1-D array of per-class model outputs.
        losses: per-iteration loss values recorded so far.
        accuracies: rolling accuracy values (percent) recorded so far.
    """
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, sharex=False, sharey=False, figsize=(30, 5))

    # Input digit
    ax1.set_title("Predicted : {} ; Actual : {}".format(predicted_label, true_label))
    sns.heatmap(image, xticklabels=False, yticklabels=False, cbar=False, ax=ax1)

    # Prediction probabilities
    ax2.set_title("Confidence")
    ax2.set_ylim(bottom=0, top=1)
    ax2.bar(np.arange(len(probabilities)), probabilities)

    # Loss history
    ax3.set_title("Loss")
    ax3.set_xlabel("Iterations")
    ax3.set_ylabel("Loss")
    ax3.plot(losses)

    # Rolling accuracy
    ax4.set_title("Accuracy ")
    ax4.set_xlabel("Iterations")
    ax4.set_ylabel("Accuracy % ")
    ax4.set_ylim(bottom=0, top=100)
    ax4.plot(accuracies)

    display.clear_output(wait=True)
    plt.show()
    plt.close(fig)  # a new figure is created on every redraw; release it explicitly


# Train on whichever device holds the model's parameters (GPU or CPU).
device = next(model.parameters()).device
model.train()

for batch_idx, (data, true_label) in enumerate(train_loader):
    # batch_size is 1: flatten the single image into one input vector.
    inputs = data.view(-1).to(device)
    predictions = model(inputs)

    # Score the prediction made *before* this sample's weight update (online accuracy).
    # .item() extracts Python scalars from the 0-dim / single-element tensors.
    predicted_label = predictions.argmax(dim=0).item()
    label = true_label.item()
    accuracy_history.update(predicted_label, label)

    # One-hot target shaped like `predictions`, so MSELoss compares element-wise
    # instead of broadcasting a (1, 10) target against a (10,) output.
    target = torch.zeros_like(predictions)
    target[label] = 1.0

    loss = loss_fn(predictions, target)
    # Record the loss every iteration, independent of how often we plot.
    loss_history.append(loss.item())

    if batch_idx % plot_every == 0:
        _plot_progress(
            image=data[0, 0].cpu().numpy(),
            predicted_label=predicted_label,
            true_label=label,
            probabilities=predictions.detach().cpu().numpy(),
            losses=loss_history,
            accuracies=accuracy_history.data,
        )

    model.zero_grad()
    loss.backward()
    # Plain SGD step; no_grad keeps the in-place update out of the autograd graph.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

    if batch_idx + 1 >= max_iterations:
        break

KeyboardInterrupt: 