<a href="https://colab.research.google.com/github/robert-s-lee/ymxb/blob/master/mnist-pl-hello-world.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction to PyTorch Lightning ⚡

In this notebook, we'll go over the basics of Lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).

---
  - Give us a ⭐ [on GitHub](https://www.github.com/PytorchLightning/pytorch-lightning/)
  - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)
  - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)

### Setup  
Lightning is easy to install. Simply run `pip install pytorch-lightning`.

In [6]:
! pip install pytorch-lightning



In [2]:
import os
from argparse import ArgumentParser
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy

In [3]:
parser = ArgumentParser()
parser.add_argument('--epochs', default=3, type=int)
parser.add_argument('--gpu', default=0, type=int)
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--hidden_size', default=64, type=int)
parser.add_argument('--lr', default=2e-4, type=float)
parser.add_argument('--data', default=os.getcwd(), type=str)

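# A notebook kernel puts its own arguments in sys.argv, so only parse the
# real command line when running as a script; otherwise take the defaults.
import sys
if "ipykernel" in sys.modules:
    args = parser.parse_args([])    # take defaults in Jupyter
else:
    args = parser.parse_args()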
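
Passing an explicit list to `parse_args` is also a convenient way to try other settings from inside a notebook. The sketch below rebuilds two of the options on a separate parser (`demo_parser` is a name made up for this illustration) and compares the defaults with an overridden run:

```python
from argparse import ArgumentParser

# Rebuild a two-option subset of the parser above so the sketch runs on its own.
# `demo_parser` is a hypothetical name, not part of the notebook.
demo_parser = ArgumentParser()
demo_parser.add_argument('--epochs', default=3, type=int)
demo_parser.add_argument('--lr', default=2e-4, type=float)

# An empty list means "no command-line arguments": every option keeps its default.
defaults = demo_parser.parse_args([])

# An explicit list stands in for the command line.
custom = demo_parser.parse_args(['--epochs', '5', '--lr', '1e-3'])

print(defaults.epochs, defaults.lr)
print(custom.epochs, custom.lr)
```

The same pattern should work with the full `parser` defined above, for example `parser.parse_args(['--batch_size', '64'])`.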


## A more complete MNIST Lightning Module Example

With the setup out of the way, let's write a more complete `LightningModule` for MNIST.

This time, we'll bake all the dataset-specific pieces directly into the `LightningModule`. This way, we can avoid writing extra code at the beginning of our script every time we want to run it.

---

### Note what the following built-in functions are doing:

1. [prepare_data()](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.prepare_data) 💾
    - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.
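    - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`). Lightning calls it from a single process, so attributes set here would not exist on the other processes in a multi-GPU run.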

2. [setup(stage)](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning-module.html#setup) ⚙️
    - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test). 
    - `setup()` expects a `stage` argument, which is used to separate the logic for 'fit' and 'test'.
    - If you don't mind loading all your datasets at once, you can write the conditions so that both the 'fit' and the 'test' setup run whenever `None` is passed as `stage` (or drop the conditionals altogether).
    - **Note this runs across all GPUs and it *is* safe to make state assignments here**

3. [x_dataloader()](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning-module.html#data-hooks) ♻️
    - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`
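
The division of labour between these hooks can be sketched without Lightning at all. The toy classes below are hypothetical stand-ins (`ToyTrainer` and `ToyModule` are not Lightning classes); they only mimic the idea that the download step runs first, the stateful `setup` second, and the dataloader hooks afterwards. The real `Trainer` does far more, and the relative order of the dataloader hooks may differ between versions.

```python
class ToyTrainer:
    """Hypothetical stand-in for pl.Trainer; it only mimics the hook order."""

    def fit(self, module):
        module.prepare_data()    # download step, no state kept
        module.setup('fit')      # build the datasets; assigning state is fine here
        return module.train_dataloader(), module.val_dataloader()


class ToyModule:
    """Hypothetical stand-in for a LightningModule that records each hook call."""

    def __init__(self):
        self.calls = []

    def prepare_data(self):
        self.calls.append('prepare_data')

    def setup(self, stage=None):
        self.calls.append(f'setup({stage})')
        if stage == 'fit' or stage is None:
            # Plain lists play the role of the train/val datasets.
            self.train_set, self.val_set = list(range(8)), list(range(8, 10))

    def train_dataloader(self):
        self.calls.append('train_dataloader')
        return self.train_set

    def val_dataloader(self):
        self.calls.append('val_dataloader')
        return self.val_set


module = ToyModule()
train_data, val_data = ToyTrainer().fit(module)
print(module.calls)
```

Because the dataloader hooks read attributes that only `setup` creates, moving those assignments into `prepare_data` would break any process that never ran the download step.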

In [30]:
class LitMNIST(pl.LightningModule):
    
    def __init__(self, data_dir=args.data, hidden_size=args.hidden_size, learning_rate=args.lr):

        super().__init__()

        # Store our init args as instance attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Hardcode some dataset-specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes)
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        # self.log records these scalars so that they show up in TensorBoard
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # Reuse validation_step for testing; the test metrics are therefore logged as val_loss and val_acc
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=args.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=args.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=args.batch_size)
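
In the class above, `forward` returns log-probabilities via `F.log_softmax`, and the step methods pass them to `F.nll_loss`; together the two amount to the usual cross-entropy loss. The pure-Python sketch below works through that arithmetic for a single sample. The scores and target are made-up values, and the helper functions are simplified re-implementations for illustration, not the PyTorch ones (which operate on batches and average the loss by default).

```python
import math


def log_softmax(scores):
    # Subtract the max first for numerical stability; the result is unchanged.
    top = max(scores)
    shifted = [s - top for s in scores]
    log_sum = math.log(sum(math.exp(s) for s in shifted))
    return [s - log_sum for s in shifted]


def nll_loss(log_probs, target):
    # Negative log-probability assigned to the correct class.
    return -log_probs[target]


scores = [2.0, 0.5, -1.0]   # made-up raw outputs for a 3-class toy problem
target = 0                  # made-up index of the correct class

log_probs = log_softmax(scores)
loss = nll_loss(log_probs, target)

# Exponentiating the log-probabilities gives a distribution that sums to 1.
prob_total = sum(math.exp(lp) for lp in log_probs)
print(prob_total)
print(loss)
```

Taking `argmax` over `log_probs` picks the same class as taking it over the raw scores, which is why `validation_step` can compute predictions directly from the output of `forward`.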

In [32]:
model = LitMNIST()
trainer = pl.Trainer(gpus=args.gpu, max_epochs=args.epochs, progress_bar_refresh_rate=20)
trainer.fit(model)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 55.1 K
-------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)
Epoch 0:  92%|█████████▏| 1720/1876 [00:10<00:00, 164.02it/s, loss=0.343, v_num=1, val_loss=2.330, val_acc=0.0781]
Validating: 0it [00:00, ?it/s]
Validating:   0%|          | 0/157 [00:00<?, ?it/s]
Epoch 0:  93%|█████████▎| 1740/1876 [00:10<00:00, 164.23it/s, loss=0.343, v_num=1, val_loss=2.330, val_acc=0.0781]
Epoch 0:  95%|█████████▍| 1780/1876 [00:10<00:00, 165.15it/s, loss=0.343, v_num=1, val_loss=2.330, val_acc=0.0781]
Epoch 0:  97%|█████████▋| 1820/1876 [00:10<00:00, 166.02it/s, loss=0.343, v_num=1, val_loss=2.330, val_acc=0.0781]
Epoch 0: 100%|██████████| 1876/1876 [00:11<00:00, 166.91it/s, loss=0.256, v_num=1, val_loss=0.261, val_ac
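
The `55.1 K` figure in the model summary can be checked by hand from the layer sizes. The sketch below assumes the default `hidden_size` of 64, 4-byte float32 parameters, and 1 MB = 10^6 bytes; `linear_params` is a helper made up for this illustration.

```python
def linear_params(n_in, n_out):
    # A fully connected layer holds an n_in x n_out weight matrix plus n_out biases.
    return n_in * n_out + n_out


input_size = 1 * 28 * 28   # channels * height * width of one MNIST image
hidden_size = 64           # default value of --hidden_size
num_classes = 10

total = (linear_params(input_size, hidden_size)
         + linear_params(hidden_size, hidden_size)
         + linear_params(hidden_size, num_classes))

size_mb = total * 4 / 1e6  # assumes float32 parameters

print(total)
print(round(size_mb, 3))
```

`nn.Flatten`, `nn.ReLU`, and `nn.Dropout` have no learnable parameters, so only the three `nn.Linear` layers contribute to the count.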

### Testing

To test a model, call `trainer.test(model)`.

Or, if you've just trained a model, you can call `trainer.test()` without arguments and Lightning will automatically test using the best saved checkpoint (conditioned on `val_loss`).

In [33]:
trainer.test()

Testing: 100%|██████████| 313/313 [00:01<00:00, 211.11it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'val_acc': 0.9553999900817871, 'val_loss': 0.14908868074417114}
--------------------------------------------------------------------------------


[{'val_loss': 0.14908868074417114, 'val_acc': 0.9553999900817871}]
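
The step counts in the progress bars follow from the dataset sizes and the batch size. The sketch below is a quick check that assumes the default `--batch_size` of 32, the standard 10,000-image MNIST test set, and `DataLoader`'s default behaviour of keeping the last partial batch.

```python
import math

batch_size = 32  # default value of --batch_size

# The last, smaller batch is kept, so round up.
train_batches = math.ceil(55000 / batch_size)
val_batches = math.ceil(5000 / batch_size)
test_batches = math.ceil(10000 / batch_size)

print(train_batches, val_batches, test_batches)
print(train_batches + val_batches)
```

The sum of the training and validation batches matches the 1876 shown in the epoch progress bar, which suggests that bar counts both loops, while the 313 matches the `Testing` bar.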

### Bonus Tip

You can keep calling `trainer.fit(model)` as many times as you'd like to continue training.

In [37]:
trainer.fit(model)


  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 55.1 K
-------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)
Epoch 2:  92%|█████████▏| 1720/1876 [00:10<13:15,  5.10s/it, loss=0.162, v_num=1, val_loss=0.0508, val_acc=0.984]   
Validating: 0it [00:00, ?it/s]
Validating:   0%|          | 0/157 [00:00<?, ?it/s]
Epoch 2:  93%|█████████▎| 1740/1876 [00:10<01:03,  2.13it/s, loss=0.162, v_num=1, val_loss=0.0508, val_acc=0.984]
Epoch 2:  95%|█████████▍| 1780/1876 [00:10<00:16,  5.90it/s, loss=0.162, v_num=1, val_loss=0.0508, val_acc=0.984]
Epoch 2:  97%|█████████▋| 1820/1876 [00:10<00:05,  9.55it/s, loss=0.162, v_num=1, val_loss=0.0508, val_acc=0.984]
Epoch 2: 100%|██████████| 1876/1876 [00:10<00:00, 14.43it/s, loss=0.128, v_num=1, val_loss=0.117, val_acc=0.964] 


### Start TensorBoard

In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!
```
%load_ext tensorboard
%tensorboard --logdir lightning_logs/
```

In VS Code, run this command from the Command Palette:
```
Python: Launch TensorBoard
```