In [130]:
'''
Dataset
Build a model
Define loss_func & optimizer 
Define trainer (model produces prediction -> compute the loss (label - pred) -> backprop)
Define test (on validation dataset for now)
Run trainer and test
'''

'\nDataset\nBuild a model\nDefine loss_func & optimizer \nDefine trainer (model produces prediciton -> compute the loss (label - pred) -> backprop)\nDefine test (on validation dataset for now)\nRun trainer and test\n'

In [131]:
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import torchmetrics

In [132]:
# Get the dataset and create dataloaders.
# The train=False split of MNIST serves as the validation set for now.
train_ds = MNIST(root='data', train=True, download=True, transform=ToTensor())
valid_ds = MNIST(root='data', train=False, download=True, transform=ToTensor())

bs = 64
# Shuffle only the training data so every epoch sees the batches in a new order.
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps the per-step
# validation logs comparable from one epoch to the next.
valid_dl = DataLoader(valid_ds, batch_size=bs, shuffle=False)

In [134]:
# Build the model
# With Lightning, you need just 3 functions: forward, training_step, configure_optimizers
class MNISTModel(pl.LightningModule): # pl.LightningModule is nn.Module - just with a few extra features
    """Single linear layer classifier (784 inputs -> 10 outputs) as a LightningModule.

    Args:
        lr: learning rate passed to the SGD optimizer in ``configure_optimizers``.
    """

    def __init__(self, lr=0.5):
        super().__init__()
        self.lr = lr
        self.lin = nn.Linear(28 * 28, 10)

        # Metrics: one instance per phase so training and validation state never mix.
        # NOTE(review): newer torchmetrics releases may require extra constructor
        # arguments for Accuracy - confirm against the installed version.
        self.train_accuracy = torchmetrics.Accuracy()
        self.valid_accuracy = torchmetrics.Accuracy()

    def forward(self, xb): # it's exactly same as PyTorch!
        """Flatten everything after the batch dimension and apply the linear layer."""
        xb = xb.flatten(1, -1) # (bs, 1, 28, 28) -> (bs, 784)
        return self.lin(xb) # how do we know the shape of xb? See docs -> example

    def training_step(self, batch, batch_idx):
        """Run the shared step in training mode and return the loss for backprop."""
        return self.shared_step(batch, train=True)

    def validation_step(self, batch, batch_idx):
        """Run the shared step in validation mode and return the loss.

        The loss used to be computed and then silently dropped here; it is now
        returned (and logged inside ``shared_step``) so validation loss is tracked.
        """
        return self.shared_step(batch, train=False)

    # optionally add test_step (will only be called on trainer.test()):
    # def test_step(self, batch, batch_idx):
    #     loss = self.shared_step(batch, train=False)
    #     return loss

    def shared_step(self, batch, train):
        """Compute the cross-entropy loss for one batch and log the metrics.

        Args:
            batch: an ``(inputs, labels)`` pair as yielded by the dataloader.
            train: ``True`` to update/log the training accuracy, ``False`` to
                update/log the validation accuracy and validation loss.

        Returns:
            The cross-entropy loss for the batch.
        """
        xb, yb = batch
        pred = self(xb)
        loss = F.cross_entropy(pred, yb)

        # pred holds raw scores that do not add up to 1 (cross_entropy handles the
        # normalisation itself), so convert them to probabilities for the metric.
        probs = pred.softmax(dim=-1)

        # Logging
        if train:
            self.train_accuracy(probs, yb)
            self.log('train_accuracy', self.train_accuracy, on_step=True, on_epoch=False, prog_bar=True)
        else:
            self.valid_accuracy(probs, yb)
            self.log('valid_accuracy', self.valid_accuracy, on_step=True, on_epoch=True, prog_bar=True)
            # Epoch-level validation loss, so it can be followed next to the accuracy.
            self.log('valid_loss', loss, on_step=False, on_epoch=True, prog_bar=True)

        return loss # just return the loss and the Lightning module will take care of backprop and update

    def configure_optimizers(self):
        """Plain SGD over all parameters with the configured learning rate."""
        return optim.SGD(self.parameters(), lr=self.lr)

In [133]:
# TensorBoard logger handed to the Trainer; it is rooted at the 'tb_logs' directory
tb_logger = TensorBoardLogger(save_dir='tb_logs')

In [135]:
# Seed the random number generators before the model is built and the training
# data is shuffled, so a fresh "Restart & Run All" reproduces the same run.
pl.seed_everything(42)

# Init model
mnist_model = MNISTModel()

# Init trainer
trainer = pl.Trainer(max_epochs=2, logger=tb_logger) # there's a bunch of options including logging!

# Train the model!
trainer.fit(mnist_model, train_dl, valid_dl)

# Optionally: run test (the model needs a test_step for this)
# trainer.test()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name           | Type     | Params
--------------------------------------------
0 | lin            | Linear   | 7.9 K 
1 | train_accuracy | Accuracy | 0     
2 | valid_accuracy | Accuracy | 0     
--------------------------------------------
7.9 K     Trainable params
0         Non-trainable params
7.9 K     Total params
0.031     Total estimated model params size (MB)
Epoch 0:  86%|████████▌ | 938/1095 [00:10<00:01, 85.60it/s, loss=0.286, v_num=0, valid_accuracy_epoch=0.0938, train_accuracy=0.906]
Validating: 0it [00:00, ?it/s][A
Epoch 0:  86%|████████▋ | 946/1095 [00:11<00:01, 85.74it/s, loss=0.286, v_num=0, valid_accuracy_epoch=0.0938, train_accuracy=0.906]
Epoch 0:  87%|████████▋ | 958/1095 [00:11<00:01, 85.96it/s, loss=0.286, v_num=0, valid_accuracy_epoch=0.0938, train_accuracy=0.906]
Epoch 0:  89%|████████▉ | 974/1095 [00:11<00:01, 86.59it/s, loss=0.286, v_num=0, valid_accuracy_epoch=0.0938, train_ac