# MLP Reference Implementation in Pytorch.

This implementation is based on Pytorch's FashionMNIST example in its
[quickstart](https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html).

See the mlax implementation in `mlp.ipynb` notebook.

In [14]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision

### Load the MNIST dataset.

In [15]:
mnist_train = torchvision.datasets.MNIST(
    root="../data",
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor()
)
mnist_test = torchvision.datasets.MNIST(
    root="../data",
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor()
)
print(mnist_train.data.shape)
print(mnist_test.data.shape)

torch.Size([60000, 28, 28])
torch.Size([10000, 28, 28])


### Batch the MNIST data with Pytorch dataloaders.

In [16]:
batch_size = 128
train_dataloader = DataLoader(mnist_train, batch_size, shuffle=True, num_workers=6)
test_dataloader = DataLoader(mnist_test, batch_size, num_workers=6)
print(len(train_dataloader), len(test_dataloader))

469 79


In [17]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


### Define MLP using modules.
We jit-compiled the model for better performance.

In [18]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )
    
    def forward(self, batch):
        flattened = self.flatten(batch)
        logits = self.linear_stack(flattened)
        return logits

model = MLP()
print(model)

MLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)


### Define loss function and optimizer.

In [19]:
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

### Define training and testing steps.

In [20]:
@torch.compile
def train_step(X, y):
    with torch.enable_grad():
        loss = loss_fn(model(X), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.item()

In [21]:
@torch.compile
def test_step(X, y):
    with torch.no_grad():
        preds = model(X)
        loss = loss_fn(preds, y)
    accurate = (preds.argmax(1) == y).type(torch.int).sum()
    return loss.item(), accurate.item()

### Define training and testing loops.

In [22]:
def train(dataloader):
    model.train()
    train_loss = 0.0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        train_loss += train_step(X, y)

    print(f"Train loss: {train_loss / len(dataloader)}")

In [23]:
def test(dataloader):
    model.eval()
    test_loss, accurate = 0.0, 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        loss, acc = test_step(X, y)
        test_loss += loss
        accurate += acc
    
    print(f"Test loss: {test_loss / len(dataloader)}, accuracy: {accurate / len(dataloader.dataset)}")

In [24]:
def train_loop(
    train_dataloader,
    test_dataloader,
    epochs,
    test_every
):
    model.to(device)
    for i in range(epochs):
        epoch = (i + 1)
        print(f"Epoch {epoch}\n----------------")
        train(train_dataloader)
        if (epoch % test_every == 0):
            test(test_dataloader)
        print(f"----------------")

### Train MLP on MNIST dataset.

In [25]:
train_loop(train_dataloader, test_dataloader, 30, 5)

Epoch 1
----------------
Train loss: 0.7105030930563331
----------------
Epoch 2
----------------
Train loss: 0.2606135940691556
----------------
Epoch 3
----------------
Train loss: 0.19348088591528345
----------------
Epoch 4
----------------
Train loss: 0.14841762555242854
----------------
Epoch 5
----------------
Train loss: 0.11993246731251034


   function: '<graph break in test_step>' (/tmp/ipykernel_8192/1809820324.py:7)
   reasons:  ___stack0 == 0.07154276967048645
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.


Test loss: 0.11379139458383375, accuracy: 0.9651
----------------
Epoch 6
----------------
Train loss: 0.09955388356421167
----------------
Epoch 7
----------------
Train loss: 0.0844976609147815
----------------
Epoch 8
----------------
Train loss: 0.07103818764230971
----------------
Epoch 9
----------------
Train loss: 0.061924979644718325
----------------
Epoch 10
----------------
Train loss: 0.05413225592612458
Test loss: 0.07388464802923271, accuracy: 0.9766
----------------
Epoch 11
----------------
Train loss: 0.04712611088143991
----------------
Epoch 12
----------------
Train loss: 0.04152066931168217
----------------
Epoch 13
----------------
Train loss: 0.03652756774762292
----------------
Epoch 14
----------------
Train loss: 0.03163804555101308
----------------
Epoch 15
----------------
Train loss: 0.02818333867665277
Test loss: 0.06288157137076936, accuracy: 0.98
----------------
Epoch 16
----------------
Train loss: 0.024346933414194503
----------------
Epoch 17
-------