<a href="https://colab.research.google.com/github/plaban1981/Pytorch_lightning/blob/main/Cifar_10_Image_Classification_using_PyTorch_Lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## ⚙️ Imports and Setups

In [None]:
# Install PyTorch Lightning.
# NOTE(review): the version is not pinned. Later cells use APIs such as
# `pytorch_lightning.metrics` and `ModelCheckpoint(filepath=...)` that look tied to a
# specific (older) release -- confirm the intended version and consider pinning it.
! pip install pytorch-lightning --quiet

[K     |████████████████████████████████| 542kB 4.8MB/s 
[K     |████████████████████████████████| 276kB 19.3MB/s 
[K     |████████████████████████████████| 92kB 7.8MB/s 
[K     |████████████████████████████████| 829kB 19.7MB/s 
[?25h  Building wheel for PyYAML (setup.py) ... [?25l[?25hdone
  Building wheel for future (setup.py) ... [?25l[?25hdone


In [None]:
# Install Weights & Biases (wandb); it is used below for experiment logging and for
# uploading the checkpoint directory as an artifact.
!pip install wandb --quiet

[K     |████████████████████████████████| 1.7MB 4.8MB/s 
[K     |████████████████████████████████| 122kB 17.0MB/s 
[K     |████████████████████████████████| 102kB 11.3MB/s 
[K     |████████████████████████████████| 102kB 10.9MB/s 
[K     |████████████████████████████████| 163kB 32.4MB/s 
[K     |████████████████████████████████| 71kB 9.4MB/s 
[?25h  Building wheel for subprocess32 (setup.py) ... [?25l[?25hdone
  Building wheel for watchdog (setup.py) ... [?25l[?25hdone
  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


In [None]:
# regular imports
import os
import re
import numpy as np

# pytorch related imports
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchvision import transforms

# lightning related imports
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

# sklearn related imports
from sklearn.metrics import precision_recall_curve
from sklearn.preprocessing import label_binarize

# import wandb and login
import wandb
!wandb login

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


## 🎨 Using DataModules

DataModules are a way of decoupling data-related hooks from the `LightningModule`, so you can develop dataset-agnostic models.

In [None]:
# Human-readable names of the ten classes (presumably in CIFAR-10 label-index order).
CLASS_NAMES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

In [None]:
class CIFAR10DataModule(pl.LightningDataModule):
    """LightningDataModule that downloads CIFAR-10 and serves train/val/test loaders.

    The 50,000-image training set is split into 45,000 training and 5,000
    validation images. The split is drawn from a seeded generator so that every
    call to ``setup`` (and every fresh kernel) produces the same partition;
    otherwise validation images captured before training could end up in the
    training subset after a later ``setup`` call.

    Args:
        batch_size: Batch size used by all three dataloaders.
        data_dir: Directory the dataset is downloaded to / read from.
        seed: Seed of the generator that drives the train/val split.
    """

    def __init__(self, batch_size, data_dir: str = './', seed: int = 42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.seed = seed

        # Convert images to tensors, then normalise every channel with mean 0.5 / std 0.5.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Shape of a single sample (channels, height, width) and number of target classes.
        self.dims = (3, 32, 32)
        self.num_classes = 10

    def prepare_data(self):
        """Download both CIFAR-10 splits into ``data_dir`` (no state is assigned here)."""
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` ('fit', 'test', or None for both)."""
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            # A fresh generator with a fixed seed makes the split identical on every call.
            split_rng = torch.Generator().manual_seed(self.seed)
            self.cifar_train, self.cifar_val = random_split(
                cifar_full, [45000, 5000], generator=split_rng)

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the (unshuffled) validation loader."""
        return DataLoader(self.cifar_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the (unshuffled) test loader."""
        return DataLoader(self.cifar_test, batch_size=self.batch_size)

## 📲 Callbacks

#### 🚏 Early Stopping

In [None]:
# Stop training once the monitored validation loss has failed to improve
# for `patience` consecutive validation checks.
early_stop_callback = EarlyStopping(
    monitor='val_loss',  # metric name logged by the model's validation_step
    mode='min',          # lower is better
    patience=3,
    verbose=False,
)

#### 🛃 Custom Callback - `ImagePredictionLogger`

In [None]:
class ImagePredictionLogger(Callback):
    """Callback that logs predictions for a fixed set of validation images to W&B.

    Args:
        val_samples: An ``(images, labels)`` pair of tensors taken from the
            validation set.
        num_samples: Maximum number of images from ``val_samples`` to log
            after each validation epoch.
    """

    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples

    def on_validation_epoch_end(self, trainer, pl_module):
        """Run the model on the stored samples and log them as captioned images."""
        # Only the first `num_samples` examples are logged, so only those are
        # moved to the model's device and passed through the network.
        val_imgs = self.val_imgs[:self.num_samples].to(device=pl_module.device)
        val_labels = self.val_labels[:self.num_samples].to(device=pl_module.device)

        # Get model prediction; this is pure inference, so make sure no autograd
        # graph is built regardless of the state the hook is called in.
        with torch.no_grad():
            logits = pl_module(val_imgs)
        preds = torch.argmax(logits, -1)

        # Log the images as wandb Image
        trainer.logger.experiment.log({
            "examples": [wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
                         for x, pred, y in zip(val_imgs, preds, val_labels)]
            })

#### 💾 Model Checkpoint Callback

In [None]:
# Directory that receives the checkpoints.
MODEL_CKPT_PATH = 'model/'
# Checkpoint path template; the epoch and val_loss placeholders are filled in per checkpoint.
MODEL_CKPT = MODEL_CKPT_PATH + 'model-{epoch:02d}-{val_loss:.2f}'

# Keep the three checkpoints with the lowest validation loss.
# NOTE(review): `filepath` is assumed to be accepted by the installed Lightning version -- confirm.
checkpoint_callback = ModelCheckpoint(
    filepath=MODEL_CKPT,
    monitor='val_loss',
    mode='min',
    save_top_k=3,
)

## 🎺 Define The Model

In [None]:
class LitModel(pl.LightningModule):
    """Small convolutional classifier: four conv layers, two max-pools, three linear layers.

    ``forward`` returns per-class log-probabilities (``log_softmax``), which the
    step methods pair with ``F.nll_loss``.

    Args:
        input_shape: Shape of a single input sample without the batch
            dimension; used to size the first linear layer.
        num_classes: Number of output classes.
        learning_rate: Learning rate passed to the Adam optimizer.
    """

    def __init__(self, input_shape, num_classes, learning_rate=2e-4):
        super().__init__()

        # log hyperparameters
        self.save_hyperparameters()
        self.learning_rate = learning_rate

        # Conv2d(in_channels, out_channels, kernel_size, stride)
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.conv3 = nn.Conv2d(32, 64, 3, 1)
        self.conv4 = nn.Conv2d(64, 64, 3, 1)

        self.pool1 = nn.MaxPool2d(2)
        self.pool2 = nn.MaxPool2d(2)

        # The flattened size of the conv block output depends on the input shape,
        # so it is measured with a dummy forward pass instead of being hard-coded.
        n_sizes = self._get_conv_output(input_shape)

        self.fc1 = nn.Linear(n_sizes, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

    def _get_conv_output(self, shape):
        """Return the number of features per sample that the conv block feeds into ``fc1``."""
        batch_size = 1
        # Shape probing only: no gradients are needed for the dummy pass.
        with torch.no_grad():
            dummy = torch.rand(batch_size, *shape)
            output_feat = self._forward_features(dummy)
        return output_feat.view(batch_size, -1).size(1)

    def _forward_features(self, x):
        """Return the feature tensor produced by the conv block."""
        x = F.relu(self.conv1(x))
        x = self.pool1(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool2(F.relu(self.conv4(x)))
        return x

    def forward(self, x):
        """Return per-class log-probabilities for a batch ``x``; also used during inference."""
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)  # flatten everything except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.log_softmax(self.fc3(x), dim=1)

        return x

    def _shared_step(self, batch):
        """Compute ``(loss, accuracy)`` for one ``(inputs, targets)`` batch."""
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        acc = accuracy(preds, y)
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Logic for a single training step; logs per-step and per-epoch metrics."""
        loss, acc = self._shared_step(batch)

        # training metrics
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Logic for a single validation step."""
        loss, acc = self._shared_step(batch)

        # validation metrics
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Logic for a single testing step."""
        loss, acc = self._shared_step(batch)

        # test metrics
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

## ⚡ Train and Evaluate The Model

In [None]:
# Build the data pipeline.
dm = CIFAR10DataModule(batch_size=32)

# The dataloaders only exist after the data has been downloaded (prepare_data)
# and the datasets have been created (setup).
dm.prepare_data()
dm.setup()

# Grab one validation batch; the custom ImagePredictionLogger callback needs it
# to log image predictions.
val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples
val_imgs.shape, val_labels.shape

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./cifar-10-python.tar.gz to ./

Files already downloaded and verified


(torch.Size([32, 3, 32, 32]), torch.Size([32]))

In [None]:
# Init our model
model = LitModel(dm.size(), dm.num_classes)

# Initialize wandb logger
wandb_logger = WandbLogger(project='wandb-lightning', job_type='train')

# Initialize a trainer.
# Use one GPU when CUDA is available and fall back to CPU otherwise, so the cell
# also runs on a runtime without an accelerator.
trainer = pl.Trainer(max_epochs=50,
                     progress_bar_refresh_rate=20,
                     gpus=1 if torch.cuda.is_available() else 0,
                     logger=wandb_logger,
                     callbacks=[early_stop_callback,
                                ImagePredictionLogger(val_samples)],
                     checkpoint_callback=checkpoint_callback)

# Train the model ⚡🚅⚡
trainer.fit(model, dm)

# Evaluate the model on the held out test set ⚡⚡
trainer.test()

# Close wandb run
wandb.finish()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: Currently logged in as: [33mwandb[0m (use `wandb login --relogin` to force relogin)



  | Name  | Type      | Params
------------------------------------
0 | conv1 | Conv2d    | 896   
1 | conv2 | Conv2d    | 9 K   
2 | conv3 | Conv2d    | 18 K  
3 | conv4 | Conv2d    | 36 K  
4 | pool1 | MaxPool2d | 0     
5 | pool2 | MaxPool2d | 0     
6 | fc1   | Linear    | 819 K 
7 | fc2   | Linear    | 65 K  
8 | fc3   | Linear    | 1 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.7097, device='cuda:0'),
 'test_loss': tensor(0.8496, device='cuda:0'),
 'train_acc': tensor(0.6250, device='cuda:0'),
 'train_acc_epoch': tensor(0.8456, device='cuda:0'),
 'train_acc_step': tensor(0.6250, device='cuda:0'),
 'train_loss': tensor(0.7803, device='cuda:0'),
 'train_loss_epoch': tensor(0.4344, device='cuda:0'),
 'train_loss_step': tensor(0.7803, device='cuda:0'),
 'val_acc': tensor(0.7258, device='cuda:0'),
 'val_loss': tensor(0.8611, device='cuda:0')}
--------------------------------------------------------------------------------



VBox(children=(Label(value=' 0.94MB of 0.94MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
_step,374.0
_runtime,278.0
_timestamp,1603467570.0
global_step,16884.0
train_loss_step,0.31585
train_acc_step,0.875
epoch,11.0
val_acc,0.7258
val_loss,0.86108
train_loss_epoch,0.43438


0,1
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train_loss_step,█▆▆▅▅▅▄▆▃▄▅▄▄▄▄▅▅▄▃▄▁▃▅▃▄▃▃▂▃▃▃▂▃▃▃▂▃▁▂▂
train_acc_step,▁▁▂▃▂▂▅▁▆▄▅▅▅▄▅▅▄▄▆▅█▆▃▅▄▅▆█▇▆▆▆▆▇▅▇▆█▆▇
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇███
val_acc,▁▃▄▅▆▆▇▇████
val_loss,█▆▄▃▃▂▂▁▁▁▁▁
train_loss_epoch,█▆▅▄▄▃▃▃▂▂▁▁


## 💾 Save Your Hard Work (Checkpoints) as W&B Artifacts

In [None]:
# Start a separate W&B run whose only job is to upload the checkpoints.
run = wandb.init(project='wandb-lightning', job_type='producer')

# Bundle every file in the checkpoint directory into one artifact named 'model'.
artifact = wandb.Artifact('model', type='model')
artifact.add_dir(MODEL_CKPT_PATH)

# Attach the artifact to the run, then close the run.
run.log_artifact(artifact)
run.join()

[34m[1mwandb[0m: Adding directory to artifact (./model)... Done. 0.1s


VBox(children=(Label(value=' 0.00MB of 32.71MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.6850630067…

## 📡 Load The Best Model

In [None]:
# Pick the checkpoint with the lowest validation loss, as encoded in its filename
# (e.g. "model-epoch=08-val_loss=0.84.ckpt").
model_ckpts = []
losses = []
# Sorted so that ties are resolved the same way on every run (listdir order is arbitrary).
for model_ckpt in sorted(os.listdir(MODEL_CKPT_PATH)):
    # Anchor on "val_loss=" so another decimal number in the name cannot be picked up
    # by mistake, and skip files that do not carry a validation loss at all.
    match = re.search(r"val_loss=(\d+\.\d+)", model_ckpt)
    if match is None:
        continue
    model_ckpts.append(model_ckpt)
    losses.append(float(match.group(1)))

if not model_ckpts:
    raise FileNotFoundError(
        f"No checkpoint with a val_loss in its filename found in {MODEL_CKPT_PATH!r}")

losses = np.array(losses)
best_model_index = int(np.argmin(losses))
best_model = model_ckpts[best_model_index]
print(best_model)

model-epoch=08-val_loss=0.84.ckpt


In [None]:
# Restore the model from the best checkpoint selected above.
best_ckpt_file = os.path.join(MODEL_CKPT_PATH, best_model)
inference_model = LitModel.load_from_checkpoint(best_ckpt_file)

## 📉 Precision Recall Curve 

In [None]:
def evaluate(model, loader):
    """Run ``model`` over every batch of ``loader`` and collect labels and outputs.

    Args:
        model: A ``torch.nn.Module`` mapping a batch of inputs to per-class scores.
        loader: Iterable yielding ``(inputs, labels)`` batches of tensors.

    Returns:
        A ``(y_true, y_pred)`` pair of numpy arrays: the labels of all batches in
        order, and the matching model outputs stacked row by row.
    """
    y_true = []
    y_pred = []

    # Evaluate in inference mode without building autograd graphs, then put the
    # model back into whatever mode the caller had it in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for imgs, labels in loader:
                # Use the model that was passed in, not a notebook-level global.
                scores = model(imgs)

                y_true.extend(labels.cpu().numpy())
                y_pred.extend(scores.cpu().numpy())
    finally:
        model.train(was_training)

    return np.array(y_true), np.array(y_pred)

# Collect the labels and the model outputs for every batch of the test dataloader.
y_true, y_pred = evaluate(inference_model, dm.test_dataloader())

In [None]:
# One-hot encode the ground-truth labels: one 0/1 column per class id.
binary_ground_truth = label_binarize(y_true, classes=list(range(10)))

# Flatten both arrays so every (sample, class) pair counts as one binary decision,
# then compute the precision-recall curve over all of those pairs with sklearn.
precision_micro, recall_micro, _ = precision_recall_curve(
    binary_ground_truth.ravel(),
    y_pred.ravel(),
)

In [None]:
# Log the precision-recall curve to W&B in a dedicated evaluation run.
run = wandb.init(project='wandb-lightning', job_type='evaluate')

data = [[x, y] for (x, y) in zip(recall_micro, precision_micro)]
# Down-sample the curve to roughly 10,000 points. The step must be at least 1:
# with fewer than 10,000 points the integer division yields 0, and a slice step
# of 0 raises ValueError.
sample_rate = max(1, len(data) // 10000)

table = wandb.Table(columns=["recall_micro", "precision_micro"], data=data[::sample_rate])
wandb.log({"precision_recall" : wandb.plot.line(table,
                                                "recall_micro",
                                                "precision_micro",
                                                stroke=None,
                                                title="Average Precision")})

run.join()



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
_step,0
_runtime,1
_timestamp,1603467987


0,1
_step,▁
_runtime,▁
_timestamp,▁
