In [110]:
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.nn import functional as F
import pytorch_lightning as pl

In [111]:
train_ds = MNIST(root='data-nn', train=True,
                 download=True, transform=ToTensor())
valid_ds = MNIST(root='data-nn', train=False,
                 download=True, transform=ToTensor())
bs = 64

train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=bs, shuffle=True)

In [112]:
import torchmetrics


class MNISTModel(pl.LightningModule):
    def __init__(self, lr=0.5):
        super().__init__()
        self.lin = nn.Linear(784, 10)
        self.lr = lr

        self.train_accuracy = torchmetrics.Accuracy(
            task='multiclass', num_classes=10)
        self.valid_accuracy = torchmetrics.Accuracy(
            task='multiclass', num_classes=10)

    def forward(self, xb):
        xb = xb.flatten(1, -1)
        return self.lin(xb)

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch, train=True)
        return loss

    def validation_step(self, batch):
        self.shared_step(batch, train=False)

    def shared_step(self, batch, train):
        xb, yb = batch
        pred = self(xb)
        loss_func = F.cross_entropy(pred, yb)

        if (train):
            self.train_accuracy(pred.softmax(dim=1), yb)
            self.log('train_accuracy', self.train_accuracy,
                     on_step=True, on_epoch=False, prog_bar=True)
        else:
            self.valid_accuracy(pred.softmax(dim=1), yb)
            self.log('valid_accuracy', self.valid_accuracy,
                     on_step=True, on_epoch=True, prog_bar=True)
        return loss_func

    def configure_optimizers(self):
        return optim.SGD(self.parameters(), lr=self.lr)

In [113]:
from lightning.pytorch.loggers import TensorBoardLogger

tb_logger = TensorBoardLogger('tb_logs')

In [114]:
mnist_model = MNISTModel()

trainer = pl.Trainer(max_epochs=2, logger=tb_logger)

trainer.fit(mnist_model, train_dl, valid_dl)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type               | Params | Mode 
--------------------------------------------------------------
0 | lin            | Linear             | 7.9 K  | train
1 | train_accuracy | MulticlassAccuracy | 0      | train
2 | valid_accuracy | MulticlassAccuracy | 0      | train
--------------------------------------------------------------
7.9 K     Trainable params
0         Non-trainable params
7.9 K     Total params
0.031     Total estimated model params size (MB)
3         Modules in train mode
0         Modules in eval mode


Epoch 1: 100%|██████████| 938/938 [00:04<00:00, 188.61it/s, v_num=0, train_accuracy=1.000, valid_accuracy_step=1.000, valid_accuracy_epoch=0.919]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 938/938 [00:04<00:00, 188.52it/s, v_num=0, train_accuracy=1.000, valid_accuracy_step=1.000, valid_accuracy_epoch=0.919]
