# Multi Layer Perceptron (MLP)

> A simple feedforward multilayer perceptron (MLP) model.

In [None]:
#| default_exp models.mlp

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

In [None]:
#| export
import torch.nn as nn
import torch
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

from lightning import LightningModule, Trainer
from lightning.pytorch.tuner.tuning import Tuner
from lightning.pytorch.callbacks import LearningRateFinder
from lightning.pytorch.loggers import CSVLogger

from hydra.utils import instantiate
from omegaconf import OmegaConf
from matplotlib import pyplot as plt
import pandas as pd

from nimrod.utils import get_device
from nimrod.image.datasets import MNISTDataModule
from nimrod.models.core import Classifier
# torch.set_num_interop_threads(1)
# from IPython.core.debugger import set_trace

import logging
logger = logging.getLogger(__name__)


## MLP

In [None]:
#| export
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> Dropout -> ReLU -> Linear.

    Expects already-flattened inputs of shape (B, n_in), e.g. n_in = H*W for a
    single-channel image, and returns unnormalized scores of shape (B, n_out)
    (the last layer is a plain Linear with no activation).
    """

    def __init__(
                self,
                n_in:int=784, # input dimension, e.g. H*W for a flattened image
                n_h:int=64, # hidden dimension
                n_out:int=10, # output dimension (= number of classes for classification)
                dropout:float=0.2 # dropout probability applied after the first linear layer
                ) -> None:
        logger.info("MLP: init")
        super().__init__()
        # Module order (0: Linear, 1: Dropout, 2: ReLU, 3: Linear) is part of the
        # state_dict layout (`layers.0.weight`, `layers.3.weight`), so keep it fixed.
        self.layers = nn.Sequential(
            nn.Linear(n_in, n_h),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(n_h, n_out),
        )

    def forward(self, x: torch.Tensor # dim (B, H*W)
                ) -> torch.Tensor:
        """Map a (B, n_in) batch to (B, n_out) scores."""
        scores = self.layers(x)
        return scores

### Usage

In [None]:
# Render the nbdev API documentation for MLP in the notebook output.
show_doc(MLP)

In [None]:
# A random batch of 5 already-flattened 28x28 "images": shape (5, 784).
image = torch.rand((5, 28 * 28))

# 1) Build the network directly from Python arguments.
mlp = MLP(n_in=28 * 28, n_h=64, n_out=10, dropout=0.1)
out = mlp(image)
print(out.shape)

# 2) Build the same kind of network from the hydra model config (`nnet` entry).
cfg = OmegaConf.load('../config/image/model/mlp.yaml')
model = instantiate(cfg.nnet)
out = model(image)
print(out.shape)

### Training

In [None]:
# Build the MNIST datamodule from its config file and materialize the splits.
cfg = OmegaConf.load('../config/image/data/mnist.yaml')
datamodule = instantiate(cfg.datamodule)
datamodule.prepare_data()
datamodule.setup()

test_set = datamodule.data_test
x = test_set[0][0]  # first test image, (C, H, W)
print(len(test_set))
label = test_set[0][1]  # its class label (int)
print("original shape (C,H,W): ", x.shape)
flat = x.view(x.size(0), -1)  # what the MLP consumes: one row per channel
print("reshape (C,HxW): ", flat.shape)
print(x[0][1])  # channel 0, pixel row 1

In [None]:
# Train / val / test dataloaders provided by the nimrod datamodule.
train_loader, val_loader, test_loader = (
    datamodule.train_dataloader(),
    datamodule.val_dataloader(),
    datamodule.test_dataloader(),
)

In [None]:
# Pick the fastest available backend: CUDA GPU, then Apple-silicon MPS, else CPU.
# (The previous version skipped CUDA entirely, so GPU machines trained on CPU.)
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
# device = "cpu" # for CI on cpu instance
device = torch.device(device)


#### Training loop

In [None]:
#| notest

%time
# data
cfg = OmegaConf.load('../config/image/data/mnist.yaml')
# The datamodule is instantiated from `cfg.datamodule`, so the batch size must be
# set there; a top-level `cfg.batch_size` is never read.
cfg.datamodule.batch_size = 2048
datamodule = instantiate(cfg.datamodule)
datamodule.prepare_data()
datamodule.setup()
train_dl = datamodule.train_dataloader()
# Monitor on the validation split; the test split is kept for final evaluation.
eval_dl = datamodule.val_dataloader()

# model: build a fresh network so re-running this cell restarts training from scratch
model = MLP(n_in=28*28, n_h=64, n_out=10, dropout=0.1).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


n_epochs = 2
losses = []
lrs = []
current_step = 0
# len(loader) is the number of batches the loader actually yields (including a
# final partial batch unless drop_last is set), unlike len(dataset) // batch_size.
steps_per_epoch = len(train_dl)
total_steps = steps_per_epoch * n_epochs
print(f"steps_per_epoch: {steps_per_epoch}, total_steps: {total_steps}")

for epoch in range(n_epochs):
    model.train()
    for images, labels in train_dl:
        optimizer.zero_grad()
        # model expects input (B, H*W)
        images = images.view(-1, 28*28).to(device)
        labels = labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        losses.append(loss_value)
        current_lr = optimizer.param_groups[0]['lr']
        lrs.append(current_lr)
        if current_step % 100 == 0:
            print(f"Loss {loss_value:.4f}, Current LR: {current_lr:.10f}, Step: {current_step}/{total_steps}")
        current_step += 1

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in eval_dl:
            # model expects input (B, H*W)
            images = images.view(-1, 28*28).to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            # accumulate a Python int rather than a device tensor
            correct += (predicted == labels).sum().item()

    print(f"Epoch {epoch + 1}: Accuracy = {100 * correct / total:.2f}%")


In [None]:
#| notest
# Per-step training loss collected by the manual training loop.
plt.plot(losses)
plt.xlabel('step')
plt.ylabel('loss')
# `lrs` is collected as well; plot it the same way (plt.plot(lrs)) once a scheduler is used.

## MLP_X

In [None]:
#| export

class MLP_X(Classifier, LightningModule):
    """Lightning classifier wrapping an `MLP`; flattens each input batch to (B, -1).

    `self.loss` and the handling of `optimizer` / `scheduler` are inherited from
    `Classifier` (from `nimrod.models.core`, not shown here). This class supplies
    the forward pass, the shared training/validation step and the predict step.
    """
    def __init__(
            self,
            nnet:MLP, # network mapping (B, n_in) inputs to (B, num_classes) scores
            num_classes:int, # number of target classes
            optimizer:torch.optim.Optimizer, # optimizer factory; must carry an `lr` keyword
            scheduler:torch.optim.lr_scheduler # lr scheduler factory, forwarded to `Classifier`
        ):

        logger.info("MLP_X init")
        super().__init__(num_classes, optimizer, scheduler)
        self.nnet = nnet
        # `nnet` is a module (already part of the state_dict), so keep it out of hparams.
        self.save_hyperparameters(logger=False,ignore=['nnet'])
        # `optimizer` is read as a functools.partial (presumably hydra `_partial_: true`);
        # expose its learning rate as `self.lr` for the lr finder.
        self.lr = optimizer.keywords['lr']

    @staticmethod
    def _flatten(x:torch.Tensor)->torch.Tensor:
        """Collapse all non-batch dims, e.g. (B, C, H, W) -> (B, C*H*W).

        `reshape` (unlike `view`) also accepts non-contiguous tensors.
        """
        return x.reshape(x.size(0), -1)

    def forward(self, x:torch.Tensor)->torch.Tensor:
        return self.nnet(x)

    def _step(self, batch, batch_idx):
        """Return (loss, predicted class indices, targets) for one labeled batch."""
        x, y = batch
        # call the module (not .forward) so nn.Module hooks run
        y_hat = self(self._flatten(x))
        loss = self.loss(y_hat, y)
        preds = y_hat.argmax(dim=1)
        return loss, preds, y

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return predicted class indices; accepts `(x, y)` pairs or bare `x` batches."""
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        return self(self._flatten(x)).argmax(dim=1)

### Usage

In [None]:
# Instantiate the full model described by the MLP model config.
cfg = OmegaConf.load('../config/image/model/mlp.yaml')
model = instantiate(cfg)

# forward() does not flatten, and Linear acts on the last dim only, so this
# (16, 1, 784) batch keeps its leading (16, 1) dims in the output.
b = torch.rand((16, 1, 28 * 28))
y = model(b)
print(y.shape)

### Nimrod training

In [None]:
# model: this section trains the MLP (see the `mnist_mlp` logger name below),
# so load the MLP model config, not the conv one.
cfg = OmegaConf.load('../config/image/model/mlp.yaml')
model = instantiate(cfg)

# data module config
cfg = OmegaConf.load('../config/image/data/mnist.yaml')
cfg.datamodule.batch_size = 2048
cfg.datamodule.num_workers = 0
datamodule = instantiate(cfg.datamodule)
# datamodule.prepare_data()  # the download already ran in the data cells above
datamodule.setup()

In [None]:
# Log metrics to logs/mnist_mlp/version_*/metrics.csv.
csv_logger = CSVLogger("logs", name="mnist_mlp")
trainer = Trainer(
    accelerator="auto",  # let Lightning pick the available hardware
    max_epochs=3,
    logger=csv_logger,
)


In [None]:
#| notest
# Train on the train split, validating on the val split.
trainer.fit(
    model,
    train_dataloaders=datamodule.train_dataloader(),
    val_dataloaders=datamodule.val_dataloader(),
)

In [None]:
#| notest
# Metrics written by the CSVLogger during fit.
log_dir = trainer.logger.log_dir
csv_path = f"{log_dir}/metrics.csv"
metrics = pd.read_csv(csv_path)
metrics.head(5)

In [None]:
#| notest
plt.figure()
# Train-step and validation metrics are logged on separate CSV rows, so each
# column is NaN on the other's rows; drop those gaps so each curve is drawn
# as a connected line instead of isolated points.
train_loss = metrics[['step', 'train/loss_step']].dropna()
val_loss = metrics[['step', 'val/loss']].dropna()
plt.plot(train_loss['step'], train_loss['train/loss_step'], 'b.-', label='train/loss_step')
plt.plot(val_loss['step'], val_loss['val/loss'], 'r.-', label='val/loss')
plt.xlabel('step')
plt.ylabel('loss')
plt.legend()
plt.show()

In [None]:
#| notest
# Final evaluation on the held-out test split.
trainer.test(model, dataloaders=datamodule.test_dataloader())

In [None]:
#| hide
# Export this notebook's `#| export` cells to the `models.mlp` module named by `#| default_exp`.
import nbdev; nbdev.nbdev_export()