In [1]:
import pytorch_lightning as pl
import torch
from torch import nn

from torchmetrics import Accuracy

In [2]:
class MultilayerPerceptron(pl.LightningModule):
    def __init__(self, image_shape=(1, 28, 28), hidden_units=(32, 16)):
        super().__init__()
        self.train_acc = Accuracy(task="multiclass", num_classes=10)
        self.valid_acc = Accuracy(task="multiclass", num_classes=10)
        self.test_acc = Accuracy(task="multiclass", num_classes=10)

        input_size = image_shape[0] * image_shape[1] * image_shape[2]
        all_layers = [nn.Flatten()]
        for hidden_unit in hidden_units:
            layer = nn.Linear(input_size, hidden_unit)
            all_layers.append(layer)
            all_layers.append(nn.ReLU())
            input_size = hidden_unit

        all_layers.append(nn.Linear(hidden_units[-1], 10))
        self.model = nn.Sequential(*all_layers)

    def forward(self, x):
        x = self.model(x)
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.train_acc.update(preds, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss
    
    def on_training_epoch_end(self):
        self.log("train_acc", self.train_acc.compute())
        self.train_acc.reset()

    def on_validation_epoch_end(self):
        self.log("valid_acc", self.valid_acc.compute())
        self.valid_acc.reset()

    def on_test_epoch_end(self):
        self.log("test_acc", self.test_acc.compute())
        self.test_acc.reset()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.valid_acc.update(preds, y)
        self.log("valid_loss", loss, prog_bar=True)
        self.log("valid_acc", self.valid_acc.compute(), prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.test_acc.update(preds, y)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_acc.compute(), prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

In [3]:
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms

In [None]:
class MnistDataModule(pl.LightningDataModule):
    def __init__(self, data_path='../data/'):
        super().__init__()
        self.data_path = data_path
        self.transform = transforms.Compose([transforms.ToTensor()])

    def prepare_data(self):
        # MNIST(root=self.data_path, download=True)
        MNIST(root=self.data_path, download=False)

    def setup(self, stage=None):
        mnist_all = MNIST(root=self.data_path, train=True, transform=self.transform, download=False)

        self.train, self.val = random_split(mnist_all, [55000, 5000], generator=torch.Generator().manual_seed(1))

        self.test = MNIST(root=self.data_path, train=False, transform=self.transform, download=False)

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=64, num_workers=4)
    
    def val_dataloader(self):
        return DataLoader(self.val, batch_size=64, num_workers=4)
    
    def test_dataloader(self):
        return DataLoader(self.test, batch_size=64, num_workers=4)

torch.manual_seed(1)
mnist_dm = MnistDataModule()

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

mnistclassifier = MultilayerPerceptron()
logger = TensorBoardLogger("lightning_logs", name="version_0")
callbacks = [ModelCheckpoint(dirpath="../models/", save_top_k=1, mode='max', monitor="valid_acc")]

if torch.cuda.is_available():
    trainer = pl.Trainer(max_epochs=10, logger=logger, callbacks=callbacks, accelerator="gpu", devices=1)
else:
    trainer = pl.Trainer(max_epochs=10, logger=logger, callbacks=callbacks)

trainer.fit(model=mnistclassifier, datamodule=mnist_dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/quan-pham/jupyter-env/lib/python3.12/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /home/quan-pham/Documents/Study/MLPytorch/models exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | train_acc | MulticlassAccuracy | 0      | train
1 | valid_acc | MulticlassAccuracy | 0      | train
2 | test_acc  | MulticlassAccuracy | 0      | train
3 | model     | Sequential         | 25.8 K | train
---------------------------------------------------------
25.8 K    Trainable params
0         Non-trainable params
25.8 K    Total params
0.103     Total estimated model params size (MB)
10        Modules in train mode
0         Modules in eval mode


Epoch 9: 100%|██████████| 860/860 [00:14<00:00, 57.73it/s, v_num=0, train_loss=0.260, valid_loss=0.166, valid_acc=0.949]  

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 860/860 [00:14<00:00, 57.72it/s, v_num=0, train_loss=0.260, valid_loss=0.166, valid_acc=0.949]


In [6]:
trainer.test(model=mnistclassifier, datamodule=mnist_dm, ckpt_path='best')

Restoring states from the checkpoint path at /home/quan-pham/Documents/Study/MLPytorch/models/epoch=8-step=7740.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/quan-pham/Documents/Study/MLPytorch/models/epoch=8-step=7740.ckpt


Testing DataLoader 0: 100%|██████████| 157/157 [00:01<00:00, 132.82it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9553999900817871
        test_loss           0.14912296831607819
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.14912296831607819, 'test_acc': 0.9553999900817871}]