<a href="https://colab.research.google.com/github/rhiga2/DeepLearningHawaii/blob/main/workshops/pytorch_lightning_intro/PytorchLightningIntro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PyTorch Lightning Introduction


In [None]:
!pip install torch
!pip install matplotlib
!pip install pytorch_lightning
!pip install torchvision
!pip install torchinfo
!pip install torchmetrics
!pip install wandb

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary
from torchmetrics import Accuracy
import time
import pytorch_lightning as pl

## PyTorch Performance

In [None]:
def time_function(function, *args, **kwargs):
  """Call ``function(*args, **kwargs)`` and measure how long the call takes.

  Uses ``time.perf_counter``, a monotonic high-resolution clock meant for
  timing short intervals. ``time.time`` follows the wall clock, which can be
  adjusted mid-measurement and may have coarse resolution.

  Args:
    function: the callable to time.
    *args: positional arguments forwarded to ``function``.
    **kwargs: keyword arguments forwarded to ``function``.

  Returns:
    A tuple ``(output, elapsed_seconds)`` where ``output`` is whatever
    ``function`` returned.
  """
  start = time.perf_counter()
  output = function(*args, **kwargs)
  elapsed = time.perf_counter() - start
  return output, elapsed

In [None]:
# Draw two random 20x20 matrices (uniform on [0, 1)) for the benchmarks below;
# the left operand is sampled first.
a, b = torch.rand(20, 20), torch.rand(20, 20)

In [None]:
# my own matrix multiplication implementation
def my_mm(a, b):
  """Multiply two 2-D tensors with an explicit triple loop.

  Deliberately naive: this is the baseline that PyTorch's built-in matrix
  multiplication is compared against.

  Args:
    a: tensor of shape (n, m).
    b: tensor of shape (m, p).

  Returns:
    Tensor of shape (n, p) equal to ``a @ b``, in the promoted dtype of the
    two inputs.

  Raises:
    ValueError: if ``a.size(1) != b.size(0)``.
  """
  # Without this check a too-short `b` raises an opaque IndexError and a
  # too-tall `b` silently yields a partial (wrong) product.
  if a.size(1) != b.size(0):
    raise ValueError(
        f"inner dimensions must match, got {tuple(a.shape)} and {tuple(b.shape)}")
  # Accumulate in the promoted input dtype so float64 inputs are not silently
  # rounded to torch.zeros' float32 default.
  c = torch.zeros(a.size(0), b.size(1),
                  dtype=torch.promote_types(a.dtype, b.dtype))
  for i in range(a.size(0)):
    for j in range(a.size(1)):
      for k in range(b.size(1)):
        c[i, k] += a[i, j] * b[j, k]
  return c

# Benchmark the naive triple-loop implementation on the random matrices.
my_output, my_duration = time_function(
    my_mm, a, b)
print(
    "Time for custom matrix multiplication: ",
    my_duration,
)

In [None]:
# matrix multiplication in pytorch
def torch_mm(a, b):
  """Multiply two 2-D tensors with PyTorch's vectorized kernel.

  Args:
    a: tensor of shape (n, m).
    b: tensor of shape (m, p).

  Returns:
    Tensor of shape (n, p) equal to ``a @ b``.
  """
  # torch.mm accepts exactly 2-D operands, matching my_mm's contract.
  return torch.mm(a, b)

# Time PyTorch's implementation and compare it with the naive one.
torch_output, torch_duration = time_function(torch_mm, a, b)
print("Square error between torch and my output: ", torch.sum((my_output - torch_output)**2).item())
print("Time for torch's matrix multiplication: ", torch_duration)
# A 20x20 vectorized matmul can finish faster than the timer's resolution
# (a measured duration of exactly 0), so guard the ratio against division by 0.
if torch_duration > 0:
  print("How much faster is pytorch: ", my_duration / torch_duration)
else:
  print("How much faster is pytorch: too fast to measure with this timer")

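What can the time difference be attributed to?
* Interpreter overhead in the custom version: every `c[i, k] += a[i, j] * b[j, k]` runs through the Python interpreter and creates small tensors, 20 × 20 × 20 = 8,000 times
* Vectorized (batched) computation in PyTorch's implementation, which makes better use of memory and the CPU
* Better algorithms for matrix multiplication
* Low-level optimizations in PyTorch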

## Weights and Biases

In [None]:
from pytorch_lightning.loggers import WandbLogger
import os

# Tell wandb which notebook is running. This is set through os.environ rather
# than `%env WANDB_NOTEBOOK_NAME='...'`, because the %env magic stores the
# quote characters as part of the value.
os.environ['WANDB_NOTEBOOK_NAME'] = 'PytorchLightningIntro.ipynb'

# Log runs to the 'mnist_classifier' Weights & Biases project.
logger = WandbLogger(project='mnist_classifier')

## Image Classification in Pytorch Lightning

What do you need to specify when training a PyTorch Lightning model?
* The dataset / dataloader
* The model
* The trainer


### Load the Dataset

In [None]:
# Load the dataset
import torchvision

# Directory MNIST is downloaded into (download=True fetches it if missing).
data_root = '/files/'

# One shared preprocessing pipeline for both splits so the normalization
# constants cannot drift apart: convert to a tensor, then standardize with the
# (0.1307, 0.3081) mean/std values commonly used for MNIST.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = torchvision.datasets.MNIST(
    data_root, train=True, download=True, transform=transform)

test_dataset = torchvision.datasets.MNIST(
    data_root, train=False, download=True, transform=transform)

# Preview the first 25 training images in a 5x5 grid (row-major order).
fig, ax = plt.subplots(5, 5, figsize=(10, 10))
for idx, cell in enumerate(ax.flat):
  cell.grid(False)
  cell.set_xticks([])
  cell.set_yticks([])
  cell.imshow(train_dataset[idx][0].squeeze(0), cmap='gray')

In [None]:
# Create dataloaders
from torch.utils.data.sampler import SubsetRandomSampler

batch_size = 32
val_proportion = 0.2

# Split train set into train and validation.
# Shuffle all training indices once (seeded, so the split is reproducible),
# then hold out the first `val_proportion` fraction for validation.
num_train = len(train_dataset)
num_val = int(val_proportion * num_train)
split_generator = torch.Generator().manual_seed(0)
indices = torch.randperm(num_train, generator=split_generator).tolist()
val_indices, train_indices = indices[:num_val], indices[num_val:]
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

# Train and validation loaders both read train_dataset; the disjoint samplers
# keep the two subsets separate.
trainloader = torch.utils.data.DataLoader(train_dataset,
                                          batch_size=batch_size,
                                          sampler=train_sampler)

valloader = torch.utils.data.DataLoader(train_dataset,
                                        batch_size=batch_size,
                                        sampler=val_sampler)

# Evaluate on the held-out test split (not training images). Order does not
# affect test metrics, so there is no need to shuffle.
testloader = torch.utils.data.DataLoader(test_dataset,
                                         batch_size=batch_size,
                                         shuffle=False)


### Specify the Model
Description of the model:
* Input: (batch_size, 1, 28, 28)
* 1st 3x3 Convolution + ReLU w/ 32 filters
* 2nd 3x3 Convolution + ReLU w/ 64 filters
* 2x2 Max Pool
* 1st Dropout 25%
* Flatten
* 1st Dense + ReLU w/ 128 units
* 2nd Dropout 50%
* 2nd Dense w/ 10 units (one logit per digit class)

In [None]:
# Conv net
# (Removed a stray, unused `from IPython.lib.security import passwd`, which is
# not available in recent IPython releases and would break this cell.)

# Multiclass accuracy over the 10 digit classes, placed on the GPU when one is
# available and on the CPU otherwise.
accuracy = Accuracy(task='multiclass', num_classes=10).to(
    'cuda' if torch.cuda.is_available() else 'cpu'
)

class MnistClassifier(pl.LightningModule):
  """Convolutional classifier for the 10 MNIST digit classes.

  Architecture: two 3x3 conv + ReLU layers (32 then 64 channels), 2x2 max
  pooling, 25% dropout, flatten, dense 128 + ReLU, 50% dropout, dense 10.
  ``forward`` returns unnormalized logits of shape (batch, 10). The flattened
  size of the first dense layer assumes (batch, 1, 28, 28) inputs.
  """

  def __init__(self):
    super().__init__()
    network = []

    # 1st 3x3 Convolution + ReLU w/ 32 units: (B, 1, 28, 28) -> (B, 32, 26, 26)
    network += [nn.Conv2d(1, 32, kernel_size=3), nn.ReLU()]

    # 2nd 3x3 Convolution + ReLU w/ 64 units: -> (B, 64, 24, 24)
    network += [nn.Conv2d(32, 64, kernel_size=3), nn.ReLU()]

    # 2x2 Max Pool: -> (B, 64, 12, 12)
    network.append(nn.MaxPool2d(kernel_size=2))

    # 1st Dropout 25%
    network.append(nn.Dropout(0.25))

    # Flatten: -> (B, 64 * 12 * 12) = (B, 9216)
    network.append(nn.Flatten())

    # 1st Dense + ReLU w/ 128 units
    network += [nn.Linear(64 * 12 * 12, 128), nn.ReLU()]

    # 2nd Dropout 50%
    network.append(nn.Dropout(0.5))

    # 2nd Dense w/ 10 units (one logit per class)
    network.append(nn.Linear(128, 10))
    self.network = nn.Sequential(*network)

    # Registered as a submodule so Lightning moves it to the same device as
    # the model. Only the per-batch value it returns is logged.
    self.accuracy = Accuracy(task='multiclass', num_classes=10)

  def forward(self, x):
    """Return class logits of shape (batch, 10) for the image batch ``x``."""
    return self.network(x)

  def training_step(self, batch, batch_idx):
    """Compute the loss for one training batch and log loss/accuracy."""
    x, y = batch
    logits = self(x)
    loss, acc = self._get_loss_and_accuracy(logits, y)
    self.log('training loss', loss)
    # Accuracy is aggregated over the epoch and shown in the progress bar.
    self.log("training accuracy", acc, prog_bar=True, on_step=False,
             on_epoch=True)
    return loss

  def validation_step(self, batch, batch_idx):
    """Log validation loss/accuracy for one batch; returns (loss, acc)."""
    x, y = batch
    logits = self(x)
    loss, acc = self._get_loss_and_accuracy(logits, y)
    self.log('validation loss', loss)
    self.log("validation accuracy", acc)
    return loss, acc

  def test_step(self, batch, batch_idx):
    """Log test loss/accuracy for one batch; returns (loss, acc)."""
    x, y = batch
    logits = self(x)
    loss, acc = self._get_loss_and_accuracy(logits, y)
    self.log('test loss', loss)
    self.log('test accuracy', acc)
    return loss, acc

  def configure_optimizers(self):
    """Use Adam with learning rate 1e-4 over all model parameters."""
    optimizer = torch.optim.Adam(self.parameters(), lr = 1e-4)
    return optimizer

  def _get_loss_and_accuracy(self, logits, y):
    """Return (cross-entropy loss, batch accuracy) of ``logits`` vs labels ``y``."""
    loss = F.cross_entropy(logits, y)
    acc = self.accuracy(logits, y)
    return loss, acc

In [None]:
# Build the network and display its layer-by-layer summary for a batch of
# 32 single-channel 28x28 images.
model = MnistClassifier()
summary(model, input_size=(32, 1, 28, 28))

### Create the Trainer

In [None]:
# Create the trainer
# accelerator="auto" uses a GPU when one is available and falls back to the
# CPU otherwise; a hard-coded "cuda" fails on CPU-only runtimes.
trainer = pl.Trainer(accelerator="auto", max_epochs=4, logger=logger)
trainer.fit(model, train_dataloaders=trainloader, val_dataloaders=valloader)

In [None]:
# Evaluate the trained model on the test dataloader.
trainer.test(model=model, dataloaders=testloader)

In [None]:
# Visualize results
index = 3
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

# Left panel: the chosen test image.
axs[0].axis('off')
image = test_dataset[index][0]
axs[0].imshow(image.squeeze(0), cmap='gray')

# Right panel: predicted class probabilities. Switch to eval mode so dropout
# is disabled, skip gradient tracking, and run on whatever device the model
# lives on before moving the probabilities back to the CPU for plotting.
model.eval()
with torch.no_grad():
  logits = model(image.unsqueeze(0).to(model.device))
  probabilities = F.softmax(logits, dim=1).squeeze(0).cpu().numpy()

classes = torch.arange(10).numpy()
axs[1].barh(classes, probabilities)
axs[1].set_yticks(classes)
axs[1].set_xlabel('Confidence')
axs[1].set_ylabel('Class Label')