In [None]:
from __future__ import print_function
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
import numpy as np

In [None]:
# Where the MNIST files are stored (downloaded on first use) and the mini-batch size.
data_path = "data/MNIST"
batch_size = 128

# ToTensor converts each image to a float tensor; Normalize then subtracts a
# mean of 0.5 and divides by a standard deviation of 1.0 on the single channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [1.0])
])

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(data_path, train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True)

# Evaluation only aggregates metrics over the whole split, so the order is
# irrelevant: keep it fixed rather than drawing a random permutation each epoch.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(data_path, train=False, download=True, transform=transform),
    batch_size=batch_size, shuffle=False)

In [None]:
class LeNet5(nn.Module):
    """LeNet-5 style convolutional classifier with 10 log-probability outputs.

    The layer sizes are chosen for single-channel 28x28 inputs:
    conv(5x5, pad 2) -> 6x28x28, pool -> 6x14x14,
    conv(5x5, no pad) -> 16x10x10, pool -> 16x5x5, which is flattened to the
    ``16 * 5 * 5`` features expected by the first fully connected layer.

    The network ends in ``LogSoftmax``, so its output is meant to be paired
    with ``nn.NLLLoss``.
    """

    def __init__(self):
        super(LeNet5, self).__init__()

        # Feature extractor: two conv / tanh / max-pool stages.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 6, 5, 1, 2),
            nn.Tanh(),
            nn.MaxPool2d(2),
            nn.Conv2d(6, 16, 5, 1, 0),
            nn.Tanh(),
            nn.MaxPool2d(2)
        )

        # Classifier head: 400 -> 120 -> 84 -> 10 log-probabilities.
        self.fc = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.Tanh(),
            nn.Linear(120, 84),
            nn.Tanh(),
            nn.Linear(84, 10),
            # Normalise over the class dimension explicitly; omitting ``dim``
            # relies on a deprecated implicit choice and emits a warning.
            nn.LogSoftmax(dim=1)
        )

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(batch, 10)``.

        Args:
            x: image batch of shape ``(batch, 1, 28, 28)``.
        """
        x = self.conv(x)
        # Flatten per sample. Keeping the batch dimension fixed means a
        # wrong-sized input fails loudly in the first Linear layer instead of
        # being silently reshaped into a different batch size.
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [None]:
def initialize_weights(m):
    """Re-initialise the weights of a Linear or Conv2d module in place.

    Intended to be passed to ``nn.Module.apply``, which calls it once for
    every sub-module. Modules of any other type are left untouched, and
    biases are not modified.

    Args:
        m: the module visited by ``apply``.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        # ``kaiming_uniform_`` is the in-place initialiser; the un-suffixed
        # ``kaiming_uniform`` is a deprecated alias that warns on every call.
        # NOTE(review): a=0 with the default nonlinearity gives the ReLU gain,
        # while the network in this notebook uses Tanh - confirm this is intended.
        nn.init.kaiming_uniform_(m.weight, a=0, mode="fan_in")

In [None]:
# True when PyTorch reports a usable CUDA device; evaluated once and stored as a notebook-level flag.
use_cuda = torch.cuda.is_available()

In [None]:
# Step size for plain SGD.
learning_rate = 0.01

model = LeNet5()

# Move the model to the GPU *before* building the optimizer: the torch.optim
# documentation asks for parameters to live on their final device when the
# optimizer is constructed.
if use_cuda:
    model.cuda()

# The model outputs log-probabilities, hence negative log-likelihood loss.
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# Weights are re-initialised in place, so the optimizer's parameter references
# stay valid. Kept as the last expression so the cell still displays the model.
model.apply(initialize_weights)

In [None]:
# Print the running training loss once every this many batches.
log_every_batches = 100

def run_batches(loader, train=True):
    """Run the notebook-level ``model`` once over every batch in ``loader``.

    Uses the notebook globals ``model``, ``loss_function``, ``optimizer``,
    ``use_cuda`` and ``log_every_batches``.

    Args:
        loader: sized iterable yielding ``(images, labels)`` tensor pairs.
        train: when True, back-propagate and take an optimizer step per
            batch and periodically print the batch loss; when False, only
            evaluate (no gradients are tracked, no parameters change).

    Returns:
        ``(mean_loss, accuracy)`` as Python floats, both averaged over the
        *samples* seen. ``(0.0, 0.0)`` is returned for an empty loader.
    """
    # Put the model in the matching mode (matters once dropout/batch-norm exist).
    if train:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    # Only record the autograd graph when backward() will actually be called.
    with torch.set_grad_enabled(train):
        for batch_id, (images, labels) in enumerate(loader):
            if use_cuda:
                images = images.cuda()
                labels = labels.cuda()

            if train:
                optimizer.zero_grad()

            predictions = model(images)
            batch_loss = loss_function(predictions, labels)

            if train:
                batch_loss.backward()
                optimizer.step()

            # .item() works for the 0-dim loss tensor on CPU and GPU alike.
            loss = batch_loss.item()
            batch_count = labels.size(0)

            # Count every correct prediction in the batch, not just the first one.
            predicted_classes = predictions.max(1)[1]
            total_correct += (predicted_classes == labels).sum().item()

            # Assumes loss_function returns the per-sample mean for the batch;
            # weighting by batch size keeps a short final batch from being
            # over-represented in the epoch average.
            total_loss += loss * batch_count
            total_samples += batch_count

            if train and batch_id % log_every_batches == log_every_batches - 1:
                print("Train Batch: {:5d} Loss: {:.4f}".format(batch_id + 1, loss))

    if total_samples == 0:
        return 0.0, 0.0

    return total_loss / total_samples, float(total_correct) / total_samples

In [None]:
# Number of full passes over the training set.
epochs = 10

for epoch_id in range(epochs):
    # One optimisation pass over the training data, then a pass over the test data.
    train_loss, train_accuracy = run_batches(train_loader, train=True)
    test_loss, test_accuracy = run_batches(test_loader, train=False)

    # Epochs are reported 1-based.
    epoch_summary = "Epoch: {:5d} Train Loss: {:.4f} Test Loss: {:.4f} Train Accuracy: {:.4f} Test Accuracy: {:.4f}".format(
        epoch_id + 1,
        train_loss,
        test_loss,
        train_accuracy,
        test_accuracy,
    )
    print(epoch_summary)