<a href="https://colab.research.google.com/github/robert-s-lee/ymxb/blob/master/mnist-pl-hello-world.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction to PyTorch Lightning ⚡

In this notebook, we'll go over the basics of Lightning by preparing a model to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).

---
  - Give us a ⭐ [on GitHub](https://www.github.com/PytorchLightning/pytorch-lightning/)
  - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)
  - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)

### Setup  
Lightning is easy to install: run `pip install pytorch-lightning`.

In [None]:
! pip install pytorch-lightning --quiet

In [42]:
import os
from argparse import ArgumentParser

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy

In [None]:
parser = ArgumentParser()
parser.add_argument('--epochs', default=3, type=int)
parser.add_argument('--gpu', default=0, type=int)
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--hidden_size', default=64, type=int)
parser.add_argument('--lr', default=2e-4, type=float)
parser.add_argument('--data', default=os.getcwd(), type=str)

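# `__name__` is "__main__" in scripts and notebooks alike, so it cannot tell them apart.
# parse_known_args() reads command-line flags in a script and skips the kernel's
# own arguments in Jupyter, where every option then keeps its default.
args, _ = parser.parse_known_args()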
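To see how the parser treats defaults, overrides, and unrelated arguments, the sketch below rebuilds a two-option subset and parses explicit argument lists instead of the real command line. The name `demo_parser` and the kernel-style arguments `-f kernel.json` are illustrative assumptions, not values taken from this notebook.

```python
from argparse import ArgumentParser

# A two-option subset of the parser defined above (demo_parser is a hypothetical name)
demo_parser = ArgumentParser()
demo_parser.add_argument('--epochs', default=3, type=int)
demo_parser.add_argument('--lr', default=2e-4, type=float)

# An empty argument list leaves every option at its default
defaults = demo_parser.parse_args([])
print(defaults.epochs, defaults.lr)

# Explicit flags override the defaults and are converted by `type`
overridden = demo_parser.parse_args(['--epochs', '10', '--lr', '0.001'])
print(overridden.epochs, overridden.lr)

# parse_known_args() returns unrecognized arguments separately instead of exiting;
# the '-f kernel.json' pair is an assumed stand-in for what a notebook kernel passes
known, extras = demo_parser.parse_known_args(['-f', 'kernel.json', '--epochs', '5'])
print(known.epochs, extras)
```

In a script, the same flags would be passed on the command line, for example `--epochs 10 --lr 0.001`.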

## Simplest example

Here's the simplest, most minimal example, with just a training loop (no validation, no testing).

**Keep in Mind** - A `LightningModule` *is* a PyTorch `nn.Module` - it just has a few more helpful features.

In [43]:
class MNISTModel(pl.LightningModule):

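    def __init__(self):
        super().__init__()
        # A single linear layer maps each flattened 28x28 image to 10 class scores
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        # Flatten (batch, 1, 28, 28) to (batch, 784) before the linear layer
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):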
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=args.lr)

By using the `Trainer` you automatically get:
1. TensorBoard logging
2. Model checkpointing
3. Training and validation loops
4. Early stopping

In [44]:
# Init our model
mnist_model = MNISTModel()

# Init DataLoader from MNIST Dataset
train_ds = MNIST(args.data, train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=args.batch_size)

# Initialize a trainer
trainer = pl.Trainer(gpus=args.gpu, max_epochs=args.epochs, progress_bar_refresh_rate=20)

# Train the model ⚡
trainer.fit(mnist_model, train_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 7.9 K 
--------------------------------
7.9 K     Trainable params
0         Non-trainable params
7.9 K     Total params
0.031     Total estimated model params size (MB)
Epoch 2: 100%|██████████| 1875/1875 [00:06<00:00, 289.10it/s, loss=0.486, v_num=3]
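
The figures in this log can be reproduced with a little arithmetic. The sketch below assumes the standard MNIST training split of 60,000 images, the default batch size of 32, 32-bit floating-point weights, and a megabyte of 10^6 bytes. These are assumptions used for illustration, not values printed in the log.

```python
import math

# Parameters of Linear(28 * 28, 10): one weight per input/output pair plus one bias per output
in_features = 28 * 28
out_features = 10
n_params = in_features * out_features + out_features

# Estimated size, assuming 4 bytes per parameter (float32) and 1 MB = 10**6 bytes
size_mb = n_params * 4 / 1e6

# Batches per epoch, assuming 60,000 training images and a batch size of 32
steps_per_epoch = math.ceil(60000 / 32)

print(n_params)            # 7850, shown rounded as "7.9 K"
print(round(size_mb, 4))   # 0.0314, shown as "0.031"
print(steps_per_epoch)     # 1875, the progress bar total
```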


### Testing

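To test a model, call `trainer.test(model)`. The model must define a `test_step`, and the test data must be supplied as well, for example through a `test_dataloader` method. The minimal `MNISTModel` above defines neither, so it needs both before the cell below will run.

If you've just trained a model, you can instead call `trainer.test()` with no arguments, and Lightning will test using the best saved checkpoint (conditioned on `val_loss`).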

In [39]:
trainer.test()

Testing: 100%|██████████| 313/313 [00:01<00:00, 207.14it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'val_acc': 0.965499997138977, 'val_loss': 0.11321506649255753}
--------------------------------------------------------------------------------


[{'val_loss': 0.11321506649255753, 'val_acc': 0.965499997138977}]
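
The `val_loss` and `val_acc` entries in the results above are metrics logged from inside the model's evaluation step. A minimal sketch of how such per-batch values could be computed is shown below. `batch_metrics` is a hypothetical helper, not part of Lightning, and the small hand-written tensors stand in for real model output.

```python
import torch
from torch.nn import functional as F


def batch_metrics(logits, y):
    """Return (mean cross-entropy loss, accuracy) for one batch of class scores."""
    loss = F.cross_entropy(logits, y)
    preds = torch.argmax(logits, dim=1)    # index of the highest score per sample
    acc = (preds == y).float().mean()      # fraction of correct predictions
    return loss, acc


# Three samples, two classes; the third prediction is wrong
logits = torch.tensor([[2.0, 0.0], [0.0, 2.0], [2.0, 0.0]])
y = torch.tensor([0, 1, 1])

loss, acc = batch_metrics(logits, y)
print(f"loss={loss.item():.4f}, acc={acc.item():.4f}")
```

Inside a `LightningModule`, a `test_step(self, batch, batch_idx)` method could call a helper like this on `self(x)` and record the two values with `self.log`.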

### Bonus Tip

You can keep calling `trainer.fit(model)` as many times as you'd like to continue training.

In [None]:
trainer.fit(mnist_model, train_loader)


# Start TensorBoard

In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!
```
%load_ext tensorboard
%tensorboard --logdir lightning_logs/
```

In VS Code, run this command from the Command Palette:
```
Python: Launch TensorBoard
```