## PyTorch Lightning for MNIST

Use PyTorch Lightning to easily train and test a model on the MNIST dataset.

In [None]:
! pip install lightning torchmetrics --upgrade
! pip install torch torchvision --upgrade

### DataModule for MNIST



In [None]:
import torch
import torchvision
import lightning as L
import torch.nn as nn
from torchvision import transforms

# create an MNIST datamodule

class MNISTDataModule(L.LightningDataModule):
    """LightningDataModule serving the MNIST train and test splits.

    Args:
        batch_size: Number of samples per batch for both dataloaders.
        data_dir: Directory MNIST is downloaded to and read from.
        num_workers: Worker processes for each ``DataLoader``; ``0`` (the
            default) loads batches in the main process.
    """

    def __init__(self, batch_size: int = 32, data_dir: str = "./data", num_workers: int = 0):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.num_workers = num_workers
        # Images are converted to float tensors; no normalization is applied.
        self.transform = transforms.ToTensor()

    def prepare_data(self):
        """Download both MNIST splits to ``data_dir``."""
        torchvision.datasets.MNIST(root=self.data_dir, train=True, download=True)
        torchvision.datasets.MNIST(root=self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the dataset(s) needed for ``stage``; ``None`` builds both."""
        # download=True should be a no-op once prepare_data() has fetched the
        # files, but keeps setup() usable when called on its own.
        if stage in ("fit", None):
            self.mnist_train = torchvision.datasets.MNIST(
                root=self.data_dir, train=True, download=True, transform=self.transform
            )
        if stage in ("test", None):
            self.mnist_test = torchvision.datasets.MNIST(
                root=self.data_dir, train=False, download=True, transform=self.transform
            )

    def _dataloader(self, dataset, shuffle: bool):
        """Wrap ``dataset`` in a DataLoader using this module's batch/worker settings."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return a shuffled loader over the training split (requires ``setup("fit")``)."""
        return self._dataloader(self.mnist_train, shuffle=True)

    def test_dataloader(self):
        """Return an unshuffled loader over the test split (requires ``setup("test")``)."""
        return self._dataloader(self.mnist_test, shuffle=False)


### Model for MNIST

In [3]:
# Small CNN: two conv/ReLU/max-pool stages followed by a two-layer MLP classifier
class CNN(nn.Module):
    """Small convolutional image classifier, sized for MNIST by default.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of channels in the input images (1 for MNIST).
        input_size: ``(height, width)`` of the input images; used to derive
            the flattened feature size fed to the classifier.
    """

    def __init__(self, num_classes=10, in_channels=1, input_size=(28, 28)):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Derive the flattened feature size from a dummy pass instead of
        # hard-coding it (32 * 5 * 5 = 800 for 1x28x28 inputs). The pass is
        # deterministic and creates no parameters, so initialization order is
        # the same as building the classifier with a literal size.
        with torch.no_grad():
            num_features = self.features(torch.zeros(1, in_channels, *input_size)).numel()
        self.classifier = nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Map a ``(batch, in_channels, H, W)`` image batch to ``(batch, num_classes)`` logits.

        ``H, W`` must match ``input_size`` given at construction.
        """
        features = self.features(x)
        flat = torch.flatten(features, 1)  # keep the batch dimension
        return self.classifier(flat)

### LightningModule for MNIST

Notice that, by default, this module wraps the sample `CNN` model defined above.

In [4]:
# build a pytorch lightning model
from torchmetrics import Accuracy

class LitModel(L.LightningModule):
    """LightningModule that trains and tests a classifier with cross-entropy loss.

    Args:
        model: Network mapping an image batch to class logits. When ``None``
            (the default), a fresh ``CNN()`` is created for this instance.
        num_classes: Number of classes used by the test accuracy metric.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, model=None, num_classes=10, lr=1e-3):
        super().__init__()
        # A ``model=CNN()`` default would be evaluated once at class definition,
        # so every LitModel would share (and keep training) the same network.
        self.model = CNN() if model is None else model
        self.lr = lr
        self.loss_fn = nn.CrossEntropyLoss()
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        """Return the wrapped model's logits for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one ``(x, y)`` batch."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def on_train_epoch_end(self):
        """Log the first optimizer's current learning rate."""
        lr = self.trainer.optimizers[0].param_groups[0]["lr"]
        self.log("learning_rate", lr, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def configure_optimizers(self):
        """Create an Adam optimizer over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def test_step(self, batch, batch_idx):
        """Accumulate test loss and accuracy for one ``(x, y)`` batch.

        Loss is logged with the batch size so the epoch value is a
        sample-weighted mean (the last batch may be smaller). Accuracy is
        accumulated in the metric's state and logged as the Metric object,
        which Lightning computes and resets at epoch end.
        """
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.accuracy.update(y_hat, y)
        self.log(
            "test_loss", loss,
            on_step=False, on_epoch=True, prog_bar=True, logger=True,
            batch_size=y.size(0),
        )
        self.log("test_acc", self.accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

### Training and Testing the Model

Note how the `Trainer` class is used to both train and test the model, and note the accelerator and precision settings passed to it.

In [None]:
# Train for 10 epochs, then evaluate on the test split.
dm = MNISTDataModule()
model = LitModel()

# Use the first GPU with 16-bit mixed precision when CUDA is available;
# otherwise fall back to full-precision CPU training instead of failing.
if torch.cuda.is_available():
    trainer = L.Trainer(max_epochs=10, accelerator="gpu", devices=[0], precision="16-mixed")
else:
    trainer = L.Trainer(max_epochs=10, accelerator="cpu", precision="32-true")

trainer.fit(model, dm)
trainer.test(model, dm)
