In [None]:
import torch
from torch import nn
import torchvision
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

## MNIST

In [None]:
# Fetch both MNIST splits (download=True retrieves the files into ./data/mnist
# when they are missing) and wrap each one in a DataLoader. The cells below only
# read the wrapped `.dataset` tensors from these loaders.
dl_train = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.MNIST(root='./data/mnist', train=True, download=True)
)

dl_test = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.MNIST(root='./data/mnist', train=False, download=True)
)

In [None]:
# Training labels, and the raw training images converted to float32 so they can
# be rescaled with floating-point arithmetic.
train_labels = dl_train.dataset.targets
train_data = dl_train.dataset.data.float()

In [None]:
# Show the first eight training digits side by side, one image per axis.
fig_mnist, ax = plt.subplots(nrows=1, ncols=8, figsize=(32, 4))
for axis, image in zip(ax, train_data[:8]):
    axis.imshow(image.numpy(), cmap='Greys')

In [None]:
# Labels belonging to the first eight training images.
train_labels[:8]

## Standardisation/Normalisation

In [None]:
# Rescale every pixel with x/128 - 1 (0 -> -1 and 255 -> ~0.99, assuming 8-bit
# pixel values) and flatten each image into a vector of 28*28 features, then
# pair the features with their labels.
train_dataset = torch.utils.data.TensorDataset(
    train_data.div(128.0).sub(1.0).view(-1, 28 * 28),
    train_labels,
)

In [None]:
# Serve the training set in shuffled mini-batches of 100 samples.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, shuffle=True, batch_size=100
)

In [None]:
# Test split, preprocessed exactly like the training split: float32 pixels
# rescaled with x/128 - 1 and flattened into vectors of 28*28 features.
# The divisor has to be the same 128.0 that is applied to the training images;
# any other value evaluates the network on inputs whose scale differs from the
# one it was trained on.
test_data   = dl_test.dataset.data.to(dtype=torch.float32)
test_labels = dl_test.dataset.targets
test_dataset = torch.utils.data.TensorDataset(
    (test_data/128.0-1.0).view(-1,28*28), test_labels)

## Batch normalisation

$$\hat{x}_{ij} = \frac{x_{ij}-\mu_j}{\sigma_j},
\quad \mu_j=\frac{1}{N_{batch}}\sum_{i\in batch} x_{ij},
\quad \sigma_j = \sqrt{\frac{1}{N_{batch}}\sum_{i\in batch}(x_{ij}-\mu_j)^2}$$

$$y_{ij}=\gamma_j \hat{x}_{ij}+\beta_j $$

In [None]:
# Fully connected classifier: three hidden blocks of Linear -> ReLU -> BatchNorm1d
# (1200 units each), followed by a 10-way output layer whose logits are
# batch-normalised as well. Layers are created in the same order as they appear
# in the network.
model = nn.Sequential(
    *[
        layer
        for n_inputs in (28 * 28, 1200, 1200)
        for layer in (nn.Linear(n_inputs, 1200), nn.ReLU(), nn.BatchNorm1d(1200))
    ],
    nn.Linear(1200, 10),
    nn.BatchNorm1d(10),
)

In [None]:
# SGD with momentum over all parameters of the network.
optim = torch.optim.SGD(params=model.parameters(), momentum=0.6, lr=0.1)

In [None]:
# Cross-entropy computed directly from the raw network outputs (logits).
loss_f = torch.nn.CrossEntropyLoss()

In [None]:
# Training bookkeeping. These live in their own cell, so re-running the training
# cell appends to the history rather than resetting it.
errors = list()       # loss of every mini-batch processed so far
epochs = batches = 0  # completed epochs / processed mini-batches

In [None]:
%%time
model.train()
for e in range(10):
    for d in train_loader:        
        optim.zero_grad()
        features, labels = d
        pred = model(features)
        loss = loss_f(pred, labels)
        errors.append(loss.item())
        loss.backward()
        optim.step()
        batches += 1
    epochs += 1   
print(loss)        

In [None]:
# Per-batch training loss; the x-axis spreads the batches evenly over [0, epochs].
plt.plot(np.linspace(start=0, stop=epochs, num=batches), errors)

In [None]:
# Accuracy on the full training set.
model.eval()  # evaluation mode: BatchNorm uses its running statistics
with torch.no_grad():
    pred = model(train_dataset.tensors[0]).softmax(dim=1)
    n_correct = (pred.argmax(dim=1) == train_labels).sum()
    ac = n_correct.to(torch.float32) / len(train_dataset)
ac

In [None]:
# Accuracy on the full test set.
model.eval()  # evaluation mode: BatchNorm uses its running statistics
with torch.no_grad():
    pred = model(test_dataset.tensors[0]).softmax(dim=1)
    n_correct = (pred.argmax(dim=1) == test_labels).sum()
    ac = n_correct.to(torch.float32) / len(test_dataset)
ac