# 👀 Multilayer perceptron (MLP)

In this notebook, we'll walk through the steps required to train your own multilayer perceptron on the CIFAR-10 dataset.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision.transforms import v2

import lightning as L
from lightning.pytorch.callbacks import RichModelSummary

## 0. Parameters

In [None]:
# Number of target categories; used as the width of the network's final linear layer.
NUM_CLASSES = 10
# Number of samples per mini-batch served by the train and test data loaders.
BATCH_SIZE = 32
# Number of passes over the training data (handed to the trainer as max_epochs).
EPOCHS = 10

## 1. Prepare the data

In [None]:
# Preprocessing applied to every sample: wrap the image as a tensor image, then
# convert it to float32 with value rescaling enabled (scale=True).
transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

# Build the training split first, then the held-out test split. Both are
# downloaded into ./data on first use and share the same preprocessing.
train_set, test_set = (
    datasets.CIFAR10(root='data', train=is_train, download=True, transform=transforms)
    for is_train in (True, False)
)

In [None]:
# Training batches are reshuffled by the loader; the test loader keeps the
# dataset order so predictions can be matched back to dataset positions.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

## 2. Build the model

In [None]:
class Model(L.LightningModule):
    """Multilayer perceptron classifier.

    The network flattens each input, passes it through two hidden ReLU
    layers (200 and 150 units) and ends in a linear layer with
    ``NUM_CLASSES`` outputs. It expects 3072 features per sample after
    flattening.

    Internally the network produces raw logits, which is what
    ``F.cross_entropy`` requires. ``forward`` and ``predict_step`` convert
    them to class probabilities, so callers still receive a softmax output.
    """

    def __init__(self):
        super().__init__()
        # No Softmax layer here: F.cross_entropy applies log-softmax itself,
        # so normalising inside the network would apply softmax twice and
        # squash both the loss range and the gradients.
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(3072, 200),
            nn.ReLU(),
            nn.Linear(200, 150),
            nn.ReLU(),
            nn.Linear(150, NUM_CLASSES),
        )

    def forward(self, batch):
        """Return class probabilities for a batch.

        Args:
            batch: An ``(inputs, targets)`` pair; the targets are ignored.

        Returns:
            Tensor of softmax probabilities taken over dimension 1.
        """
        X, _ = batch
        return F.softmax(self.model(X), dim=1)

    def _shared_step(self, batch, stage):
        """Compute and log cross-entropy loss and accuracy for one batch.

        Args:
            batch: An ``(inputs, targets)`` pair.
            stage: Prefix for the logged metric names (``'train'`` or ``'test'``).

        Returns:
            The scalar cross-entropy loss for the batch.
        """
        X, y_true = batch
        logits = self.model(X)
        loss = F.cross_entropy(logits, y_true)
        self.log(f'{stage}_loss', loss)

        # argmax over logits equals argmax over probabilities.
        acc = (logits.argmax(dim=1) == y_true).float().mean()
        self.log(f'{stage}_acc', acc)
        return loss

    def training_step(self, batch, batch_idx):
        """Log ``train_loss``/``train_acc`` and return the loss to optimise."""
        return self._shared_step(batch, 'train')

    def test_step(self, batch, batch_idx):
        """Log ``test_loss``/``test_acc`` for one evaluation batch."""
        self._shared_step(batch, 'test')

    def predict_step(self, batch, batch_idx):
        """Return class probabilities for one prediction batch."""
        return self(batch)

    def configure_optimizers(self):
        """Return an Adam optimiser over all parameters (learning rate 0.0005)."""
        opt = torch.optim.Adam(self.parameters(), lr=0.0005)
        return opt

## 3. Train the model

In [None]:
model = Model()
# 'auto' lets Lightning choose the accelerator: a CUDA GPU is still used when
# one is present, but the notebook no longer fails on machines without CUDA.
trainer = L.Trainer(
    max_epochs=EPOCHS,
    accelerator='auto',
    callbacks=[RichModelSummary(max_depth=2)],
)
trainer.fit(model, train_loader)

## 4. Evaluation

In [None]:
# Run the test loop over the held-out split; the returned metrics are left as
# the cell's final expression so the notebook displays them.
trainer.test(model=model, dataloaders=test_loader)

In [None]:
# Label names stored in an array so that integer class ids can index them directly.
CLASSES = np.array(
    [
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    ]
)

# trainer.predict returns one output per batch; join them into a single tensor.
preds = torch.cat(trainer.predict(model, test_loader))

# Highest-scoring class per sample, translated to its label name.
preds_single = CLASSES[np.argmax(preds, axis=-1)]
# Ground-truth label names, looked up from the dataset's integer targets.
actual_single = CLASSES[test_set.targets]

In [None]:
# Pick random test-set positions to display (sampled with replacement).
n_to_show = 10
indices = np.random.choice(len(test_set), n_to_show)

fig = plt.figure(figsize=(15, 3))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

for i, idx in enumerate(indices):
    ax = fig.add_subplot(1, n_to_show, i + 1)
    ax.axis("off")

    # Two captions below each image, positioned in axes-relative coordinates:
    # the predicted label first, the actual label underneath it.
    captions = (
        (-0.35, f"pred = {preds_single[idx]}"),
        (-0.7, f"act = {actual_single[idx]}"),
    )
    for y_offset, caption in captions:
        ax.text(
            0.5,
            y_offset,
            caption,
            fontsize=10,
            ha="center",
            transform=ax.transAxes,
        )

    # Show the raw stored image rather than the transformed tensor.
    img = test_set.data[idx]
    ax.imshow(img)