<a href="https://colab.research.google.com/github/plaban1981/Pytorch_lightning/blob/main/Transfer_Learning_with_PyTorch_Lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## ⚙️ Imports and Setup

In [None]:
%%capture
# Install PyTorch Lightning (training framework used by the DataModule and model below).
# NOTE(review): this notebook imports `pytorch_lightning.metrics` and passes
# `progress_bar_refresh_rate`, `gpus` and `checkpoint_callback=` to `pl.Trainer`;
# these are 1.x-era APIs, so consider pinning a compatible 1.x release — confirm.
!pip install pytorch-lightning --quiet
# Install Weights & Biases (experiment logging, artifacts and plots)
!pip install wandb --quiet
# Install patool, used by the DataModule to extract the .rar dataset archive
# (it shells out to the system `unrar` binary).
!pip install patool

In [None]:
# regular imports
import os
import re
import numpy as np
import patoolib

# pytorch related imports 
import torch
from torch import nn
from torch.nn import functional as F
import torchvision.models as models
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_url

# lightning related imports
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

# sklearn related imports
from sklearn.metrics import precision_recall_curve
from sklearn.preprocessing import label_binarize

# import wandb and login
import wandb
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

## 🎨 Using DataModules - `Caltech101DataModule`

DataModules are a way of decoupling data-related hooks from the `LightningModule` so you can develop dataset agnostic models.

In [None]:
class Caltech101DataModule(pl.LightningDataModule):
    """LightningDataModule for the Caltech101 image dataset (figshare mirror).

    ``prepare_data`` downloads the ``.rar`` archive into ``data_dir`` and
    extracts it there. ``setup`` builds train / validation / test subsets of
    the extracted ``Caltech101`` image folder.

    Args:
        batch_size: Batch size used by every dataloader.
        data_dir: Directory the archive is downloaded to and extracted into.
        train_size: Number of images in the training split.
        val_size: Number of images in the validation split. All remaining
            images form the test split.
        seed: Seed for the random split. Repeated ``setup`` calls and
            notebook reruns therefore produce the same split.
    """

    def __init__(self, batch_size, data_dir: str = './', train_size: int = 6500,
                 val_size: int = 1000, seed: int = 42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_size = train_size
        self.val_size = val_size
        self.seed = seed

        # Augmentation policy (training split only)
        self.augmentation = transforms.Compose([
              transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
              transforms.RandomRotation(degrees=15),
              transforms.RandomHorizontalFlip(),
              transforms.CenterCrop(size=224),
              transforms.ToTensor(),
              transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
        ])
        # Deterministic preprocessing (validation and test splits)
        self.transform = transforms.Compose([
              transforms.Resize(size=256),
              transforms.CenterCrop(size=224),
              transforms.ToTensor(),
              transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
        ])

        self.num_classes = 102

    def _dataset_root(self):
        """Return the path of the extracted image folder inside ``data_dir``."""
        return os.path.join(self.data_dir, 'Caltech101')

    def prepare_data(self):
        """Download the archive and extract it into ``data_dir``."""
        # source: https://figshare.com/articles/dataset/Caltech101_Image_Dataset/7007090
        url = 'https://s3-eu-west-1.amazonaws.com/pfigshare-u-files/12855005/Caltech101ImageDataset.rar'
        # download
        download_url(url, self.data_dir)
        # Extract next to the download. Skip this when a previous run has
        # already produced the image folder.
        if not os.path.isdir(self._dataset_root()):
            archive = os.path.join(self.data_dir, 'Caltech101ImageDataset.rar')
            patoolib.extract_archive(archive, outdir=self.data_dir)

    def setup(self, stage=None):
        """Split the image folder into ``self.train``, ``self.val`` and ``self.test``.

        Raises:
            ValueError: if ``train_size + val_size`` exceeds the number of images.
        """
        from torch.utils.data import Subset

        root = self._dataset_root()
        # Use a separate ImageFolder per transform. random_split() returns
        # Subsets that share ONE underlying dataset. Assigning
        # `.dataset.transform` per split therefore left every split with the
        # last transform assigned, and training ran without augmentation.
        train_base = ImageFolder(root, transform=self.augmentation)
        eval_base = ImageFolder(root, transform=self.transform)

        n_total = len(train_base)
        n_test = n_total - self.train_size - self.val_size
        if n_test < 0:
            raise ValueError(
                f"train_size + val_size ({self.train_size + self.val_size}) "
                f"exceeds the number of images ({n_total})")

        # Split indices once with a seeded generator, then apply the same
        # indices to the dataset carrying the appropriate transform.
        generator = torch.Generator().manual_seed(self.seed)
        train_idx, val_idx, test_idx = random_split(
            range(n_total), [self.train_size, self.val_size, n_test],
            generator=generator)

        self.train = Subset(train_base, train_idx.indices)
        self.val = Subset(eval_base, val_idx.indices)
        self.test = Subset(eval_base, test_idx.indices)

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)

## 📲 Callbacks

#### 🚏 Early Stopping

In [None]:
# Stop training once the monitored validation loss has not decreased
# (mode='min') for `patience` consecutive validation checks.
early_stop_callback = EarlyStopping(monitor='val_loss',
                                    mode='min',
                                    patience=3,
                                    verbose=False)

#### 🛃 Custom Callback - `ImagePredictionLogger`

In [None]:
class ImagePredictionLogger(Callback):
    """Log validation images with their predicted and true labels to W&B.

    Args:
        val_samples: ``(images, labels)`` pair of tensors, e.g. one batch
            drawn from the validation dataloader.
        num_samples: Maximum number of images logged per validation epoch.
    """

    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples

    def on_validation_epoch_end(self, trainer, pl_module):
        # Only the first `num_samples` images are logged. Move only those to
        # the device and run only those through the model.
        val_imgs = self.val_imgs[:self.num_samples].to(device=pl_module.device)
        val_labels = self.val_labels[:self.num_samples].to(device=pl_module.device)

        # Pure inference: no autograd graph is needed for these predictions.
        with torch.no_grad():
            logits = pl_module(val_imgs)
        preds = torch.argmax(logits, -1)

        trainer.logger.experiment.log({
            "examples": [wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
                         for x, pred, y in zip(val_imgs, preds, val_labels)]
            })

#### 💾 Model Checkpoint Callback

In [None]:
# Directory the checkpoints are written to; later cells upload and reload from it.
MODEL_CKPT_PATH = 'model/'
# Checkpoint file-name template relative to MODEL_CKPT_PATH. The epoch and the
# monitored validation loss are filled in per checkpoint.
MODEL_CKPT = 'model-{epoch:02d}-{val_loss:.2f}'

# Keep the 3 checkpoints with the lowest validation loss. `dirpath` is set
# explicitly so the files land in MODEL_CKPT_PATH. Otherwise the checkpoint
# directory would depend on the logger's / trainer's default save location.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath=MODEL_CKPT_PATH,
    filename=MODEL_CKPT,
    save_top_k=3,
    mode='min')

## 🎺 Define The Model

In [None]:
class LitModel(pl.LightningModule):
    """Image classifier: a frozen, ImageNet-pretrained ResNet-18 plus a linear head.

    Args:
        input_shape: Shape ``(C, H, W)`` of one input image. It is used only
            to probe the size of the backbone output that feeds the head.
        num_classes: Number of output classes.
        learning_rate: Learning rate for the Adam optimizer.

    ``forward`` returns log-probabilities (``log_softmax``), which is why the
    training, validation and test steps use ``F.nll_loss``.
    """

    def __init__(self, input_shape, num_classes, learning_rate=2e-4):
        super().__init__()

        # log hyperparameters
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.dim = input_shape
        self.num_classes = num_classes

        # NOTE(review): this is the complete torchvision ResNet-18, including
        # its final ImageNet `fc` layer. The head below is therefore trained
        # on the backbone's class scores rather than on pooled features.
        # Replacing `fc` would change the head's input size and break
        # existing checkpoints, so it is left as is.
        self.feature_extractor = models.resnet18(pretrained=True)
        self.feature_extractor.eval()

        # Freeze the backbone; only `classifier` receives gradients.
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        n_sizes = self._get_conv_output(input_shape)

        self.classifier = nn.Linear(n_sizes, num_classes)

    def train(self, mode=True):
        """Set train/eval mode, but always keep the frozen backbone in eval mode.

        PyTorch Lightning puts the module into training mode via ``train()``
        while fitting. Without this override, that would switch the
        backbone's BatchNorm layers back to training mode. Their running
        statistics would then keep changing even though the backbone's
        weights are frozen.
        """
        super().train(mode)
        self.feature_extractor.eval()
        return self

    def _get_conv_output(self, shape):
        """Return the flattened size of the backbone output for one input of ``shape``."""
        batch_size = 1
        probe = torch.rand(batch_size, *shape)
        # Shape probe only: no gradients needed.
        with torch.no_grad():
            output_feat = self._forward_features(probe)
        return output_feat.view(batch_size, -1).size(1)

    def _forward_features(self, x):
        """Return the backbone's output tensor for the image batch ``x``."""
        return self.feature_extractor(x)

    def forward(self, x):
        """Return per-class log-probabilities for the image batch ``x``."""
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        return F.log_softmax(self.classifier(x), dim=1)

    def _shared_step(self, batch):
        """Return ``(loss, accuracy)`` for one ``(images, labels)`` batch."""
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        acc = accuracy(preds, y)
        return loss, acc

    # logic for a single training step
    def training_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        # training metrics, logged both per step and per epoch
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        return loss

    # logic for a single validation step
    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        # validation metrics
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    # logic for a single testing step
    def test_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        # test metrics
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        # Frozen backbone parameters never receive gradients, so Adam only
        # updates the classifier.
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

## ⚡ Train and Evaluate The Model

In [None]:
# Build the data pipeline. prepare_data() and setup() must run before any
# x_dataloader() can be used, so call them explicitly here.
dm = Caltech101DataModule(batch_size=64)
dm.prepare_data()
dm.setup()

# One validation batch, consumed by the ImagePredictionLogger callback to log
# image predictions.
val_samples = next(iter(dm.val_dataloader()))
val_imgs = val_samples[0]
val_labels = val_samples[1]
val_imgs.shape, val_labels.shape

Downloading https://s3-eu-west-1.amazonaws.com/pfigshare-u-files/12855005/Caltech101ImageDataset.rar to ./Caltech101ImageDataset.rar


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

patool: Extracting Caltech101ImageDataset.rar ...
patool: running /usr/bin/unrar x -- /content/Caltech101ImageDataset.rar
patool:     with cwd='./'
patool: ... Caltech101ImageDataset.rar extracted to `./'.


(torch.Size([64, 3, 224, 224]), torch.Size([64]))

In [None]:
# Init our model: 3x224x224 inputs (the DataModule's CenterCrop size) and one
# output per dataset class.
model = LitModel((3, 224, 224), dm.num_classes)

# Initialize wandb logger
wandb_logger = WandbLogger(project='wandb-lightning', job_type='train')

# Use a GPU when one is present instead of failing on CPU-only runtimes.
n_gpus = 1 if torch.cuda.is_available() else 0

# Initialize a trainer
trainer = pl.Trainer(max_epochs=50,
                     progress_bar_refresh_rate=20,
                     gpus=n_gpus,
                     logger=wandb_logger,
                     callbacks=[early_stop_callback,
                                ImagePredictionLogger(val_samples, 64)],
                     checkpoint_callback=checkpoint_callback)

# Train the model ⚡🚅⚡
trainer.fit(model, dm)

# Evaluate the model on the held out test set ⚡⚡
trainer.test(datamodule=dm)

# Close wandb run
wandb.finish()

## 💾 Save Your Hard Work (Checkpoints) as W&B Artifacts

In [None]:
# Upload every checkpoint in MODEL_CKPT_PATH as a single W&B model artifact.
run = wandb.init(project='wandb-lightning', job_type='producer')

artifact = wandb.Artifact('model', type='model')
artifact.add_dir(MODEL_CKPT_PATH)

run.log_artifact(artifact)
# finish() closes the run and waits for the upload; join() is a deprecated alias.
run.finish()

[34m[1mwandb[0m: Adding directory to artifact (./model)... Done. 0.5s


VBox(children=(Label(value=' 0.00MB of 405.19MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.360407402…

## 📡 Load The Best Model

In [None]:
# Select the checkpoint with the lowest validation loss. The loss is parsed
# from the file names written by ModelCheckpoint (e.g.
# "model-epoch=07-val_loss=0.18.ckpt"). Files without a `val_loss=` value,
# such as stray notebook or hidden files, are ignored instead of crashing.
val_loss_pattern = re.compile(r"val_loss=(\d+\.\d+)")

# Sorted so that ties on the (rounded) loss resolve deterministically.
model_ckpts = sorted(os.listdir(MODEL_CKPT_PATH))
ckpt_losses = {}
for model_ckpt in model_ckpts:
    match = val_loss_pattern.search(model_ckpt)
    if match is not None:
        ckpt_losses[model_ckpt] = float(match.group(1))

if not ckpt_losses:
    raise FileNotFoundError(f"no checkpoint with a val_loss in {MODEL_CKPT_PATH!r}")

best_model = min(ckpt_losses, key=ckpt_losses.get)
print(best_model)

model-epoch=07-val_loss=0.18.ckpt


In [None]:
# Load the best checkpoint onto the CPU (works without a GPU) and switch it to
# eval mode for inference.
inference_model = LitModel.load_from_checkpoint(
    os.path.join(MODEL_CKPT_PATH, best_model), map_location='cpu').eval()

## 📉 Precision-Recall Curve

In [None]:
def evaluate(model, loader):
    """Run ``model`` over ``loader`` and collect the labels and model outputs.

    Args:
        model: Callable mapping an image batch to per-class scores.
        loader: Iterable of ``(images, labels)`` tensor batches.

    Returns:
        ``(y_true, y_pred)``: a 1-D array of labels and a 2-D array of
        scores, one row per sample and in loader order. Both are empty
        arrays when ``loader`` yields no batches.
    """
    y_true = []
    y_pred = []
    # Inference only: don't build an autograd graph for every batch.
    with torch.no_grad():
        for imgs, labels in loader:
            # Use the model that was passed in (previously the global
            # `inference_model` was used regardless of this argument).
            outputs = model(imgs)
            y_true.append(labels.cpu().numpy())
            y_pred.append(outputs.cpu().numpy())

    if not y_true:
        return np.array([]), np.array([])
    return np.concatenate(y_true), np.concatenate(y_pred)

# Collect ground-truth labels and model outputs (LitModel returns
# log-probabilities) for every sample in the test split.
y_true, y_pred = evaluate(inference_model, dm.test_dataloader())

In [None]:
# Generate binary correctness labels across classes: one one-hot column per
# class. The class count comes from the score matrix instead of being
# hard-coded, so it always matches the model's output width.
num_classes = y_pred.shape[1]
binary_ground_truth = label_binarize(y_true,
                                     classes=np.arange(num_classes).tolist())

# Micro-averaged PR curve: flatten labels and scores over all classes. The
# scores are log-probabilities. Since log is monotonic, the precision/recall
# pairs are the same as for probabilities.
precision_micro, recall_micro, _ = precision_recall_curve(binary_ground_truth.ravel(),
                                                          y_pred.ravel())

In [None]:
run = wandb.init(project='wandb-lightning', job_type='evaluate')

data = [[x, y] for (x, y) in zip(recall_micro, precision_micro)]

# Downsample the curve to at most MAX_POINTS rows.
# - Ceiling division keeps the row count under the cap.
# - The step is at least 1: with fewer than MAX_POINTS points, the old
#   `int(len(data)/10000)` gave 0, and `data[::0]` raises ValueError.
MAX_POINTS = 10000
sample_rate = max(1, -(-len(data) // MAX_POINTS))

table = wandb.Table(columns=["recall_micro", "precision_micro"], data=data[::sample_rate])
wandb.log({"precision_recall" : wandb.plot.line(table, 
                                                "recall_micro", 
                                                "precision_micro", 
                                                stroke=None, 
                                                title="Average Precision")})

# finish() closes the run; join() is a deprecated alias.
run.finish()



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
_step,0
_runtime,1
_timestamp,1605020741


0,1
_step,▁
_runtime,▁
_timestamp,▁
