In [12]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics.functional import accuracy

import numpy as np
import matplotlib.pyplot as plt
from torchmetrics import  Accuracy

In [3]:
# MNIST dataset
class MnistDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=64):
        
        super(MnistDataModule, self).__init__()

        self.batch_size = batch_size
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
    
    def prepare_data(self):
        MNIST(root='data', train=True, download=True)
        MNIST(root='data', train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit':
            mnist_full = MNIST(root='data', train=True, transform=self.transforms)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        
        # Assign test dataset for use in dataloader(s)
        if stage == 'test':
            self.mnist_test = MNIST(root='data', train=False, transform=self.transforms)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)
    
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)
    
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

In [33]:
# # Model
class MnistModel(pl.LightningModule):
    def __init__(self):
        super(MnistModel, self).__init__()
        self.layer1 = nn.Linear(28*28, 128)
        self.layer2 = nn.Linear(128, 256)
        self.layer3 = nn.Linear(256, 10)
        self.train_accuracy = Accuracy(task='multiclass', num_classes=10)
        self.val_accuracy = Accuracy(task='multiclass', num_classes=10)
        self.test_accuracy = Accuracy(task='multiclass', num_classes=10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = nn.Dropout(0.2)(x)
        x = self.layer3(x)
        return x

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx):
        x,y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = self.train_accuracy(y_hat, y)
        self.log('Train_acc', acc, prog_bar=True, on_step=True, on_epoch=True, logger=True)
        self.log('Train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x,y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = self.val_accuracy(y_hat, y)
        self.log('Val_acc', acc, prog_bar=True, on_step=True, on_epoch=True, logger=True)
        self.log('Val_loss', loss, prog_bar=True, on_step=True, on_epoch=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        x,y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = self.test_accuracy(y_hat, y)
        self.log('Test_acc', acc, prog_bar=True, on_step=True, on_epoch=True, logger=True)
        self.log('Test_loss', loss, prog_bar=True, on_step=True, on_epoch=True, logger=True)
        return loss

In [34]:
# Training
mnist_model = MnistModel()
mnist_data = MnistDataModule()

checkpoint_callback = ModelCheckpoint(
    monitor='Val_acc',
    save_weights_only=True,
    mode='max',
)

trainer = pl.Trainer(max_epochs=10, 
                     accelerator='gpu', 
                     devices=-1, 
                     callbacks=[checkpoint_callback])
trainer.fit(mnist_model, datamodule=mnist_data)

# Testing
trainer.test(mnist_model, datamodule=mnist_data)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type               | Params
------------------------------------------------------
0 | layer1         | Linear             | 100 K 
1 | layer2         | Linear             | 33.0 K
2 | layer3         | Linear             | 2.6 K 
3 | train_accuracy | MulticlassAccuracy | 0     
4 | val_accuracy   | MulticlassAccuracy | 0     
5 | test_accuracy  | MulticlassAccuracy | 0     
------------------------------------------------------
136 K     Trainable params
0         Non-trainable params
136 K     Total params
0.544     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Test_acc_epoch         0.9765999913215637
     Test_loss_epoch        0.10789921879768372
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'Test_acc_epoch': 0.9765999913215637,
  'Test_loss_epoch': 0.10789921879768372}]

In [None]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/