### Image Recognition on MNIST using PyTorch Lightning

Demonstrating the elements of machine learning:

1) **E**xperience (Datasets and Dataloaders)<br>
2) **T**ask (Classifier Model)<br>
3) **P**erformance (Accuracy)<br>

**Experience:** <br>
We use the MNIST dataset for this demo. MNIST consists of 28x28 gray-scale images of handwritten digits, `0` to `9`. The train split has 60,000 images and the test split has 10,000 images.

**Task:**<br>
Our task is to classify the images into 10 classes. We use the ResNet18 model from `torchvision.models`. Its first convolutional layer (`conv1`) is replaced to accept single-channel input, and the number of output classes is set to 10.

**Performance:**<br>
We use the accuracy metric to evaluate the performance of our model on the test split. `torchmetrics.functional.accuracy` calculates the accuracy.
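To make the metric concrete, here is a small sketch of top-1 accuracy written in plain PyTorch. It is not the `torchmetrics` implementation, only an illustration of the quantity being measured: the fraction of samples whose highest-scoring class matches the label. The toy `logits` and `targets` are made up.

```python
import torch

# Toy scores for 4 samples over 3 classes, with their true labels
logits = torch.tensor([[2.0, 1.0, 0.0],
                       [0.0, 3.0, 1.0],
                       [1.0, 0.0, 4.0],
                       [5.0, 0.0, 1.0]])
targets = torch.tensor([0, 1, 2, 1])

preds = logits.argmax(dim=1)  # predicted class per sample: [0, 1, 2, 0]
acc = (preds == targets).float().mean()
print(acc.item())  # 0.75, since 3 of the 4 predictions are correct
```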

**[PyTorch Lightning](https://www.pytorchlightning.ai/):**<br>
Our demo uses PyTorch Lightning to simplify the process of training and testing. The PyTorch Lightning `Trainer` trains and evaluates our model. The default configuration assumes a GPU-enabled system with 48 CPU cores. Please change the configuration if your system is different.
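One way to adapt the configuration to a different system is to detect the hardware at runtime. The helper below is a hypothetical sketch (`pick_runtime_config` is not part of this notebook or of PyTorch Lightning); it falls back to the CPU when no CUDA device is visible and caps the number of dataloader workers at the number of available cores.

```python
import os
import torch


def pick_runtime_config(max_workers=48):
    """Return accelerator and worker settings that fit the current machine."""
    accelerator = "gpu" if torch.cuda.is_available() else "cpu"
    # os.cpu_count() may return None, so fall back to a single worker
    num_workers = min(max_workers, os.cpu_count() or 1)
    return {"accelerator": accelerator, "devices": 1,
            "num_workers": num_workers}


config = pick_runtime_config()
print(config)
```

The resulting values could then be used in place of the hard-coded defaults for `accelerator`, `devices`, and `num_workers`.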

**[Weights and Biases](https://www.wandb.ai/):**<br>
`wandb` is used by the PyTorch Lightning module to log training and evaluation results. Use `--no-wandb` to disable `wandb`.


Let us install `pytorch-lightning` and `torchmetrics`.

In [1]:
%pip install pytorch-lightning --upgrade
%pip install torchmetrics --upgrade

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting pytorch-lightning
  Downloading pytorch_lightning-1.7.7-py3-none-any.whl (708 kB)
Collecting tensorboard>=2.9.1
  Downloading tensorboard-2.10.1-py3-none-any.whl (5.9 MB)
Collecting tensorboard-data-server<0.7.0,>=0.6.0
  Downloading tensorboard_data_server-0.6.1-py3-none-manylinux2010_x86_64.whl (4.9 MB)
Installing collected packages: tensorboard-data-server, tensorboard, pytorch-lightning
  Attempting uninstall: tensorboard
    Found existing installation: tensorboard 2.4.1
    Uninstalling tensorboard-2.4.1:
      Successfully uninstalled tensorboard-2.4.1
  Attempting uninstall: pytorch-lightning
    Found existing installation: pytorch-lightning 1.6.3
    Uninstalling pytorch-lightning-

In [1]:
import torch
import torchvision
import wandb
from argparse import ArgumentParser
from pytorch_lightning import LightningModule, Trainer, Callback
from pytorch_lightning.loggers import WandbLogger
from torchmetrics.functional import accuracy

### PyTorch Lightning Module

The Lightning module wraps a PyTorch ResNet18 model and is a subclass of `LightningModule`. We replace the model's input convolutional layer so that it accepts single-channel inputs. The Lightning module is also a container for the model, the optimizer, the loss function, the metrics, and the data loaders.

`ResNet` class can be found [here](https://pytorch.org/vision/0.8/_modules/torchvision/models/resnet.html).

By using PyTorch Lightning, we simplify the training and testing processes since we do not need to write boilerplate code. Lightning handles the automatic transfer to the chosen device (i.e. `gpu` or `cpu`), the switching between model `eval` and `train` modes, and the backpropagation routines.
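For comparison, this is roughly what that boilerplate looks like when written by hand in plain PyTorch. It is a minimal sketch on a tiny linear model and one random batch, not the notebook's ResNet18 pipeline; the function names `manual_train_step` and `manual_eval_step` are hypothetical.

```python
import torch


def manual_train_step(model, batch, optimizer, loss_fn, device):
    model.train()                             # switch to training mode
    x, y = (t.to(device) for t in batch)      # move the batch to the device
    optimizer.zero_grad()                     # clear stale gradients
    loss = loss_fn(model(x), y)
    loss.backward()                           # backpropagation
    optimizer.step()                          # parameter update
    return loss.item()


@torch.no_grad()
def manual_eval_step(model, batch, device):
    model.eval()                              # switch to evaluation mode
    x, y = (t.to(device) for t in batch)
    preds = model(x).argmax(dim=1)
    return (preds == y).float().mean().item()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Flatten(),
                            torch.nn.Linear(28 * 28, 10)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
# One random batch shaped like MNIST: 8 images and 8 labels
batch = (torch.rand(8, 1, 28, 28), torch.randint(0, 10, (8,)))

train_loss = manual_train_step(model, batch, optimizer, loss_fn, device)
eval_acc = manual_eval_step(model, batch, device)
print(train_loss, eval_acc)
```

With Lightning, the equivalent logic is reduced to the body of `training_step` and `test_step`, and the `Trainer` takes care of the rest.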

In [2]:
class LitMNISTModel(LightningModule):
    def __init__(self, num_classes=10, lr=0.001, batch_size=32):
        super().__init__()
        self.save_hyperparameters()
        self.model = torchvision.models.resnet18(num_classes=num_classes)
        self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7,
                                           stride=2, padding=3, bias=False)
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    # this is called during fit()
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)
        return {"loss": loss}

    # calls to self.log() are recorded in wandb
    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("train_loss", avg_loss, on_epoch=True)

    # this is called for every batch during test()
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)
        acc = accuracy(y_hat, y) * 100.
        # we use y_hat to display predictions during callback
        return {"y_hat": y_hat, "test_loss": loss, "test_acc": acc}

    # this is called at the end of the test epoch
    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["test_acc"] for x in outputs]).mean()
        self.log("test_loss", avg_loss, on_epoch=True, prog_bar=True)
        self.log("test_acc", avg_acc, on_epoch=True, prog_bar=True)

    # validation is the same as test
    def validation_step(self, batch, batch_idx):
        return self.test_step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        return self.test_epoch_end(outputs)

    # we use Adam optimizer
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
    
    # this is called after model instantiation to initialize the datasets
    # and dataloaders (building them triggers the MNIST download)
    def setup(self, stage=None):
        self.train_dataloader()
        self.test_dataloader()

    # build train and test dataloaders using MNIST dataset
    # we use simple ToTensor transform
    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            torchvision.datasets.MNIST(
                "./data", train=True, download=True, 
                transform=torchvision.transforms.ToTensor()
            ),
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=48,
            pin_memory=True,
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            torchvision.datasets.MNIST(
                "./data", train=False, download=True, 
                transform=torchvision.transforms.ToTensor()
            ),
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=48,
            pin_memory=True,
        )

    def val_dataloader(self):
        return self.test_dataloader()

### PyTorch Lightning Callback

We can instantiate a callback object to perform certain tasks during training. In this case, we log sample images, ground truth labels, and predicted labels from the test dataset.

We can also use the `ModelCheckpoint` callback to save the model after each epoch.

In [3]:
class WandbCallback(Callback):

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # process first 10 images of the first batch
        if batch_idx == 0:
            n = 10
            x, y = batch
            outputs = outputs["y_hat"]
            outputs = torch.argmax(outputs, dim=1)
            # log image, ground truth and prediction on wandb table
            columns = ['image', 'ground truth', 'prediction']
            data = [[wandb.Image(x_i), y_i, y_pred] for x_i, y_i, y_pred in list(
                zip(x[:n], y[:n], outputs[:n]))]
            wandb_logger.log_table(
                key='ResNet18 on MNIST Predictions',
                columns=columns,
                data=data)


### Program Arguments

When running from the command line, we can pass arguments to the program. In a Jupyter notebook, we can pass arguments using the `%run` magic command.
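Inside a notebook cell, arguments can also be supplied by handing an explicit list to `parse_args`. The sketch below uses a reduced parser with only two of the options to show how defaults and overrides behave; note that `argparse` turns dashes in option names into underscores on the resulting namespace.

```python
from argparse import ArgumentParser

parser = ArgumentParser(description="argument parsing sketch")
parser.add_argument("--max-epochs", type=int, default=5)
parser.add_argument("--no-wandb", default=False, action="store_true")

# An empty list means "use the defaults"
defaults = parser.parse_args([])
# An explicit list plays the role of the command line
overridden = parser.parse_args(["--max-epochs", "2", "--no-wandb"])

print(defaults.max_epochs, defaults.no_wandb)      # 5 False
print(overridden.max_epochs, overridden.no_wandb)  # 2 True
```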


In [4]:
def get_args():
    parser = ArgumentParser(description="PyTorch Lightning MNIST Example")
    parser.add_argument("--max-epochs", type=int, default=5, help="num epochs")
    parser.add_argument("--batch-size", type=int, default=32, help="batch size")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")

    parser.add_argument("--num-classes", type=int, default=10, help="num classes")

    parser.add_argument("--devices", default=1)
    parser.add_argument("--accelerator", default='gpu')
    parser.add_argument("--num-workers", type=int, default=48, help="num workers")
    
    parser.add_argument("--no-wandb", default=False, action='store_true')
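    # an empty argument string makes the parser fall back to the defaults,
    # which avoids parsing Jupyter's own command line inside a notebook
    args = parser.parse_args("")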
    return args

### Training and Evaluation using `Trainer`

Get the command line arguments. Instantiate the PyTorch Lightning model. Train the model. Evaluate the model.

In [5]:
if __name__ == "__main__":
    args = get_args()
    model = LitMNISTModel(num_classes=args.num_classes,
                          lr=args.lr, batch_size=args.batch_size)
    model.setup()

    # printing the model is useful for debugging
    print(model)

    # wandb is a great way to debug and visualize this model
    wandb_logger = WandbLogger(project="pl-mnist")
    
    trainer = Trainer(accelerator=args.accelerator,
                      devices=args.devices,
                      max_epochs=args.max_epochs,
                      logger=wandb_logger if not args.no_wandb else None,
                      callbacks=[WandbCallback()] if not args.no_wandb else None)
    trainer.fit(model)
    trainer.test(model)

    wandb.finish()


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


LitMNISTModel(
  (model): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

[34m[1mwandb[0m: Currently logged in as: [33mrowel[0m. Use [1m`wandb login --relogin`[0m to force relogin
Exception in thread StatsThr:
Traceback (most recent call last):
  File "/home/rowel/anaconda3/envs/voice/lib/python3.9/threading.py", line 973, in _bootstrap_inner
    self.run()
  File "/home/rowel/anaconda3/envs/voice/lib/python3.9/threading.py", line 910, in run
    self._target(*self._args, **self._kwargs)
  File "/home/rowel/anaconda3/envs/voice/lib/python3.9/site-packages/wandb/sdk/internal/stats.py", line 137, in _thread_body
    stats = self.stats()
  File "/home/rowel/anaconda3/envs/voice/lib/python3.9/site-packages/wandb/sdk/internal/stats.py", line 183, in stats
    handle = pynvml.nvmlDeviceGetHandleByIndex(i)
  File "/home/rowel/anaconda3/envs/voice/lib/python3.9/site-packages/wandb/vendor/pynvml/pynvml.py", line 2232, in nvmlDeviceGetHandleByIndex
    _nvmlCheckReturn(ret)
  File "/home/rowel/anaconda3/envs/voice/lib/python3.9/site-packages/wandb/vendor/pynvml

MisconfigurationException: GPUAccelerator can not run on your system since the accelerator is not available. The following accelerator(s) is available and can be passed into `accelerator` argument of `Trainer`: ['cpu'].