# MLP Reference Implementation in PyTorch

This implementation is almost identical to the FashionMNIST example in PyTorch's
[quickstart](https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html),
except that it trains on MNIST instead of FashionMNIST.

See the mlax implementation in the `mlp.ipynb` notebook.

In [1]:
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
import torchvision
from torch.utils import data
import numpy as np

### Load the MNIST dataset

In [2]:
# Build the train split first, then the test split. Both live under "data/",
# are downloaded if missing, and have their images converted to tensors.
mnist_train, mnist_test = (
    torchvision.datasets.MNIST(
        root="data",
        train=is_train,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )
    for is_train in (True, False)
)
print(mnist_train.data.shape, mnist_test.data.shape, sep="\n")

torch.Size([60000, 28, 28])
torch.Size([10000, 28, 28])


### Batch the MNIST data with PyTorch DataLoaders
Data is not shuffled to keep in line with the mlax implementation.

In [3]:
# Note data is not shuffled
train_dataloader = DataLoader(mnist_train, batch_size=64)
test_dataloader = DataLoader(mnist_test, batch_size=64)
print(len(train_dataloader), len(test_dataloader))

938 157


Unlike in mlax, in Pytorch, the default device is the CPU.

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

### Define MLP using modules.
Unlike in mlax's model, the Pytorch MLP does not end with softmax layer. This is
because `torch.nn.CrossEntropyLoss` take in logits, not probabilities.

In [5]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )
    
    def forward(self, batch):
        flattened = self.flatten(batch)
        logits = self.linear_stack(flattened)
        return logits

model = MLP()
print(model)

MLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)


### Define loss function and optimizer
Pytorch's [SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html)
implementation differs from the conventional mathematical definition, which
mlax follows.
The mlax implementation of SGD is more similar to [Tensorflow's](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD).

In [6]:
cross_entropy = nn.CrossEntropyLoss()
sgd = optim.SGD(model.parameters(), lr=5e-3, momentum=0.6)

### Define training and testing loops

In [7]:
def train(dataloader, model, loss_fn, optimizer, device):
    model.to(device)
    model.train()

    train_loss = 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        loss = loss_fn(model(X), y)
        train_loss += loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Train loss: {train_loss / len(dataloader)}")

In [8]:
def test(dataloader, model, loss_fn, device):
    model.to(device)
    model.eval()

    test_loss, accurate = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            accurate += (pred.argmax(1) == y).type(torch.float).sum().item()
    
    print(f"Test loss: {test_loss / len(dataloader)}, accuracy: {accurate / len(dataloader.dataset)}")

In [9]:
def train_loop(
    train_dataloader,
    test_dataloader,
    model, loss_fn, optimizer,
    device,
    epochs, test_every):
    for i in range(epochs):
        epoch = (i + 1)
        print(f"Epoch {epoch}\n----------------")
        train(train_dataloader, model, loss_fn, optimizer, device)
        if (epoch % test_every == 0):
            test(test_dataloader, model, loss_fn, device)
        print(f"----------------")

### Train MLP on MNIST dataset
Achieves an accuracy of ~98% in ~4 minutes.

In [10]:
train_loop(train_dataloader, test_dataloader, model, cross_entropy, sgd, device, 50, 5)

Epoch 1
----------------
Train loss: 1.3777899742126465
----------------
Epoch 2
----------------
Train loss: 0.43444499373435974
----------------
Epoch 3
----------------
Train loss: 0.3439459204673767
----------------
Epoch 4
----------------
Train loss: 0.3045724630355835
----------------
Epoch 5
----------------
Train loss: 0.2771012485027313
Test loss: 0.2571821173498775, accuracy: 0.9252
----------------
Epoch 6
----------------
Train loss: 0.25439509749412537
----------------
Epoch 7
----------------
Train loss: 0.2343655526638031
----------------
Epoch 8
----------------
Train loss: 0.21647438406944275
----------------
Epoch 9
----------------
Train loss: 0.20058171451091766
----------------
Epoch 10
----------------
Train loss: 0.18654023110866547
Test loss: 0.178603809412901, accuracy: 0.9465
----------------
Epoch 11
----------------
Train loss: 0.1740572154521942
----------------
Epoch 12
----------------
Train loss: 0.1629309356212616
----------------
Epoch 13
------------