### Image Recognition on MNIST using PyTorch Lightning

This demo illustrates the three elements of machine learning:

1) **E**xperience (Datasets and Dataloaders)<br>
2) **T**ask (Classifier Model)<br>
3) **P**erformance (Accuracy)<br>

**Experience:** <br>
We use the MNIST dataset for this demo. MNIST consists of 28x28 gray-scale images of handwritten digits, `0` to `9`. The train split has 60,000 images and the test split has 10,000 images.

**Task:**<br>
Our task is to classify the images into 10 classes, one per digit. We use the ResNet18 model from `torchvision.models`. Its first convolutional layer (`conv1`) is modified to accept single-channel input, and the number of output classes is set to 10.

**Performance:**<br>
We use the accuracy metric to evaluate the performance of our model on the test split. The accuracy is computed with `torchmetrics.functional.accuracy`.
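This sketch computes accuracy by hand to show what the metric measures: the fraction of samples whose highest-scoring class (the argmax of the logits) equals the label. It uses three classes instead of ten to keep the example short, and `batch_accuracy` is a hypothetical helper. With torchmetrics 0.11 or later, `accuracy(logits, target, task="multiclass", num_classes=10, average="micro")` should compute the same quantity.

```python
import torch


def batch_accuracy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Fraction of samples whose highest-scoring class equals the label."""
    preds = logits.argmax(dim=1)  # (N, C) class scores -> (N,) predicted class indices
    return (preds == target).float().mean()


logits = torch.tensor([
    [0.1, 2.0, 0.3],  # predicts class 1
    [1.5, 0.2, 0.1],  # predicts class 0
    [0.0, 0.1, 3.0],  # predicts class 2
    [2.0, 0.5, 0.4],  # predicts class 0
])
target = torch.tensor([1, 0, 2, 1])  # the last sample is misclassified
print(batch_accuracy(logits, target))  # tensor(0.7500)
```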

**[PyTorch Lightning](https://www.pytorchlightning.ai/):**<br>
Our demo uses PyTorch Lightning to simplify training and testing: the Lightning `Trainer` trains and evaluates our model. The default configuration targets a GPU-enabled system with 48 CPU cores. Please adjust it if your system differs.
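This is a sketch of how such configurations could be exposed on the command line. The `argparse` flag names (`--accelerator`, `--devices`, `--num-workers`, `--batch-size`, `--max-epochs`) and their defaults are hypothetical. We also assume the 48 CPU cores are meant for the data loader's worker processes. `get_args` takes an explicit argument list because, in a notebook, `sys.argv` holds the kernel's own arguments.

```python
from argparse import ArgumentParser, Namespace
from typing import List, Optional


def get_args(argv: Optional[List[str]] = None) -> Namespace:
    # Hypothetical flags; defaults mirror a GPU machine with 48 CPU cores.
    parser = ArgumentParser(description="MNIST ResNet18 demo (sketch)")
    parser.add_argument("--accelerator", default="gpu", help='"gpu", "cpu" or "auto"')
    parser.add_argument("--devices", type=int, default=1, help="number of accelerators to use")
    parser.add_argument("--num-workers", type=int, default=48, help="DataLoader worker processes")
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--max-epochs", type=int, default=5)
    return parser.parse_args(argv)


args = get_args([])  # defaults; on a laptop try ["--accelerator", "cpu", "--num-workers", "4"]
print(args.accelerator, args.num_workers)  # gpu 48
```

The parsed values would then feed the Lightning objects, for example `Trainer(accelerator=args.accelerator, devices=args.devices, max_epochs=args.max_epochs)` and `DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)`.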

**[Weights and Biases](https://www.wandb.ai/):**<br>
PyTorch Lightning logs training and evaluation results to `wandb` through its `WandbLogger`. Use `--no-wandb` to disable `wandb`.
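This sketch shows how `--no-wandb` could switch logging off, assuming the logger is passed to the `Trainer` through its `logger` argument. The function name `build_logger` and the project name `mnist-demo` are hypothetical. `parse_known_args` lets the parser ignore flags meant for other parsers or for a notebook kernel. `WandbLogger` is imported only when it is needed, so `wandb` does not have to be installed when logging is disabled.

```python
from argparse import ArgumentParser
from typing import List, Optional


def build_logger(argv: Optional[List[str]] = None, project: str = "mnist-demo"):
    """Return a WandbLogger, or False (no logging) when --no-wandb is given."""
    parser = ArgumentParser()
    parser.add_argument("--no-wandb", action="store_true", help="disable Weights and Biases logging")
    args, _ = parser.parse_known_args(argv)  # unknown flags are returned, not rejected
    if args.no_wandb:
        return False  # Trainer(logger=False) turns logging off entirely
    # Lazy import: the wandb integration is only touched when logging is enabled.
    from lightning.pytorch.loggers import WandbLogger
    return WandbLogger(project=project)
```

`Trainer(logger=build_logger())` then logs to Weights and Biases unless `--no-wandb` is given. Returning `True` instead of `False` would fall back to Lightning's default logger rather than disabling logging altogether.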


Let us install `pytorch-lightning`, `torchmetrics`, and `lightning`. Uncomment the lines below if the packages are not already installed.

```python
# %pip install pytorch-lightning --upgrade
# %pip install torchmetrics --upgrade
# %pip install lightning --upgrade
```

```python
import torch
import torchvision
import wandb
from argparse import ArgumentParser
from lightning.pytorch import LightningModule, Trainer, Callback
from lightning.pytorch.loggers import WandbLogger
from torchmetrics.functional import accuracy
```

### PyTorch Lightning Module

Our Lightning module is a subclass of `LightningModule` that wraps a PyTorch ResNet18 model. The ResNet18 model is itself subclassed so that its input convolutional layer is replaced by one that accepts single-channel inputs. Besides the model, the Lightning module also contains the optimizer, the loss function, the metrics, and the data loaders.

The source code of the `ResNet` class can be found [here](https://pytorch.org/vision/0.8/_modules/torchvision/models/resnet.html).

Using PyTorch Lightning simplifies training and testing, since we do not need to write boilerplate code. This boilerplate includes moving the model and data to the chosen device (e.g. `gpu` or `cpu`), switching the model between `train` and `eval` modes, and running backpropagation.
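For comparison, here is a minimal plain-PyTorch loop covering what the `Trainer` handles for us. It picks a device and moves the model and each batch to it, switches between `train()` and `eval()`, and runs `zero_grad`, `backward` and `step`. The function name `train_and_test`, the Adam optimizer, and the toy linear model on random data are illustrative assumptions. With Lightning, the equivalent logic lives in `training_step`, `test_step` and `configure_optimizers`.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_and_test(model: nn.Module, train_loader: DataLoader, test_loader: DataLoader,
                   epochs: int = 1, lr: float = 1e-3) -> float:
    """Plain-PyTorch loop: device, mode and backprop handling done by hand."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)  # move parameters to the chosen device
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for _ in range(epochs):
        model.train()  # batch-norm uses batch statistics, dropout is active
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)  # move each batch to the device
            loss = F.cross_entropy(model(x), y)
            optimizer.zero_grad()  # clear gradients from the previous step
            loss.backward()        # backpropagation
            optimizer.step()       # parameter update

    model.eval()  # batch-norm uses running statistics, dropout is off
    correct, total = 0, 0
    with torch.no_grad():  # no gradient tracking during evaluation
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            correct += (model(x).argmax(dim=1) == y).sum().item()
            total += y.numel()
    return correct / total


# Tiny stand-in model and random data, just to exercise the loop.
toy = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
data = TensorDataset(torch.rand(32, 1, 28, 28), torch.randint(0, 10, (32,)))
print(train_and_test(toy, DataLoader(data, batch_size=8), DataLoader(data, batch_size=8)))
```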