# LeNet

![LeNet](images/lenet.png)

Paper: [LeCun, Yann, et al. "Gradient-based learning applied to document recognition." Proceedings of the IEEE 86.11 (1998): 2278-2324.](http://yann.lecun.com/exdb/publis/psgz/lecun-98.ps.gz)

Webpage: [LeNet-5, convolutional neural networks](http://yann.lecun.com/exdb/lenet/)

In [None]:
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim

from torch.autograd import Variable

from torchvision import transforms, datasets

import numpy as np

from tqdm import tqdm

import sys

## Load the dataset

Torchvision has helpers to load the MNIST dataset:

In [None]:
# Directory torchvision reads MNIST from (and downloads it into when missing).
data_path = "data/mnist/raw"
batch_size = 128

# Turn each image into a tensor, then subtract 0.5 and divide by 1.0
# (i.e. only the offset changes, the scale is left alone).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[1.0]),
])

# Training split, served in shuffled mini-batches.
train_set = datasets.MNIST(root=data_path, train=True, download=True,
                           transform=transform)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True)

# Held-out test split, built with the same transform and batch settings.
test_set = datasets.MNIST(root=data_path, train=False, download=True,
                          transform=transform)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=batch_size, shuffle=True)

## Model

In [None]:
class LeNet5(nn.Module):
    """LeNet-5 style convolutional classifier with 10 output classes.

    The network is a two-stage convolutional feature extractor (``conv``)
    followed by a three-layer fully connected classifier (``fc``) that ends
    in a log-softmax, so the output is meant to be paired with ``nn.NLLLoss``.

    The classifier's first layer expects ``16 * 5 * 5`` features, which is
    what the feature extractor produces for ``1 x 28 x 28`` inputs:
    28x28 -> conv(5, pad 2) 28x28 -> pool 14x14 -> conv(5) 10x10 -> pool 5x5.
    """

    def __init__(self):
        super(LeNet5, self).__init__()

        # Feature extractor: (conv -> tanh -> 2x2 max-pool) twice.
        # Conv2d arguments are (in_channels, out_channels, kernel, stride, padding).
        self.conv = nn.Sequential(
            nn.Conv2d(1, 6, 5, 1, 2),
            nn.Tanh(),
            nn.MaxPool2d(2),
            nn.Conv2d(6, 16, 5, 1, 0),
            nn.Tanh(),
            nn.MaxPool2d(2)
        )

        # Classifier: 400 -> 120 -> 84 -> 10, log-probabilities over classes.
        self.fc = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.Tanh(),
            nn.Linear(120, 84),
            nn.Tanh(),
            nn.Linear(84, 10),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(batch, 10)``.

        Args:
            x: image batch of shape ``(batch, 1, 28, 28)``.

        Raises:
            RuntimeError: if the spatial size of ``x`` does not yield
                ``16 * 5 * 5`` features per sample.
        """
        x = self.conv(x)
        # Pin the batch dimension instead of using -1: with -1 a wrong-sized
        # input could be silently reshaped into a different number of rows,
        # mixing features of different samples.
        x = x.view(x.size(0), 16 * 5 * 5)
        x = self.fc(x)
        return x

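According to the paper, the weights are initialized with random values drawn from a uniform distribution between $-2.4 / F_i$ and $2.4 / F_i$, where $F_i$ is the fan-in (the number of inputs) of the unit the connection feeds. Note that the cell below does not implement this scheme: it applies PyTorch's Kaiming-uniform initializer in fan-in mode to the weights of every `nn.Linear` and `nn.Conv2d` layer.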

In [None]:
def initialize_weights(m):
    """Re-draw the weights of a linear or convolutional layer in place.

    Written to be passed to ``nn.Module.apply``. Modules that are neither
    ``nn.Linear`` nor ``nn.Conv2d`` are skipped, and only ``m.weight`` is
    touched (biases keep whatever values they already have).

    Args:
        m: the module visited by ``apply``.
    """
    layer_types = (nn.Linear, nn.Conv2d)
    if not isinstance(m, layer_types):
        return

    # Kaiming-uniform with a=0, scaled by the layer's fan-in.
    nn.init.kaiming_uniform_(m.weight, a=0, mode="fan_in")

## Train the network

In [None]:
# Whether PyTorch reports CUDA as available. The cells below read this flag
# to decide if the model and each batch should be moved to the GPU.
use_cuda = torch.cuda.is_available()

In [None]:
learning_rate = 0.01

# Build the network and move it to the GPU *before* creating the optimizer:
# the torch.optim documentation recommends constructing optimizers only after
# the model's parameters are on their final device.
model = LeNet5()
if use_cuda:
    model.cuda()

# The model outputs log-probabilities, so negative log-likelihood is the loss.
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# Re-initialize the conv/linear weights in place; kept as the last expression
# so the notebook still displays the model summary.
model.apply(initialize_weights)

In [None]:
def run_batches(loader, title, train=True):
    """Run one full pass over ``loader`` and report loss and accuracy.

    Operates on the notebook-level ``model``, ``optimizer``,
    ``loss_function`` and ``use_cuda``.

    Args:
        loader: iterable of ``(images, labels)`` batches that supports
            ``len()`` (used to size the progress bar).
        title: description shown on the progress bar.
        train: when true, the model is put in training mode and the
            optimizer takes one step per batch; when false, the model is put
            in evaluation mode and no gradients are computed.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats, both averaged over the
        number of samples seen.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.train(mode=train)

    # Kept as floats so the divisions below are true divisions everywhere.
    epoch_loss = 0.0
    epoch_correct = 0.0
    epoch_total = 0.0
    with tqdm(total=len(loader)) as progress_bar:
        progress_bar.set_description(title)
        for images, labels in loader:
            if use_cuda:
                images = images.cuda()
                labels = labels.cuda()

            if train:
                optimizer.zero_grad()

            # Only track gradients when they are needed, so evaluation does
            # not depend on the caller wrapping the call in torch.no_grad().
            with torch.set_grad_enabled(train):
                predictions = model(images)
                batch_loss = loss_function(predictions, labels)

            if train:
                batch_loss.backward()
                optimizer.step()

            batch_count = float(labels.size(0))
            predicted_labels = torch.max(predictions, 1)[1]

            # The loss is a per-batch mean (default reduction of the
            # notebook's nn.NLLLoss), so weight it by the batch size before
            # accumulating; dividing by epoch_total then yields a true
            # per-sample mean instead of a value ~batch_size times too small.
            epoch_loss += batch_loss.item() * batch_count
            epoch_correct += torch.eq(predicted_labels, labels).sum().item()
            epoch_total += batch_count

            progress_bar.set_postfix(mean_loss="{:.03f}".format(epoch_loss / epoch_total),
                                     accuracy="{:.03f}".format(epoch_correct / epoch_total))
            progress_bar.update()

    return epoch_loss / epoch_total, epoch_correct / epoch_total

In [None]:
# Number of full passes over the training set.
epochs = 10

for epoch_id in range(epochs):
    # Progress-bar titles show the 1-based epoch number out of the total.
    position = (epoch_id + 1, epochs)
    train_title = "Train {}/{}".format(*position)
    test_title = "Test {}/{}".format(*position)

    # One optimisation pass over the training data ...
    run_batches(train_loader, train_title, train=True)

    # ... followed by an evaluation pass on the test data without gradients.
    with torch.no_grad():
        run_batches(test_loader, test_title, train=False)