# PyTorch Lightning Faster R-CNN Finetuning LB=0.358


Finetune TorchVision's [Faster R-CNN](https://pytorch.org/vision/stable/models.html#id57) with [PyTorch Lightning](https://www.pytorchlightning.ai/).

Tracks the training with [Weights & Biases](https://wandb.ai/site).

Corresponding Inference kernel: [Starter PyTorch Lightning Faster R-CNN Inference](https://www.kaggle.com/clemchris/starter-pytorch-lightning-faster-r-cnn-inference)

## Lightning Features Used in this Kernel:
- Easily switch between CPU and GPU training (`gpus=[0/1]` Trainer flag)
- Quickly check that the complete training loop runs without errors with the `fast_dev_run=True` Trainer flag
- Use half precision training on GPU with `precision=16` Trainer flag
- Log learning rate with `LearningRateMonitor` callback
- Log losses and metrics to Weights & Biases with `WandbLogger`

## Ideas for next steps:
- Add data augmentation
- Also use the images without annotations for training
- Try other TorchVision models
- Try TPU training (made easy with PyTorch Lightning's [TPU Support](https://pytorch-lightning.readthedocs.io/en/latest/advanced/tpu.html))
- Tune hyperparameters of model
- Visualize inputs and predictions with Weights & Biases

## Sources and Inspirations:
- [TorchVision Object Detection Finetuning Tutorial](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html)
- [Reef- Starter Torch FasterRCNN Train [LB=0.416]](https://www.kaggle.com/julian3833/reef-starter-torch-fasterrcnn-train-lb-0-416)



## Installs

Install pinned versions of PyTorch Lightning and [TorchMetrics](https://torchmetrics.readthedocs.io/en/latest/) that are recent enough to provide TorchMetrics' [MAP](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#map) metric, which in turn needs [pycocotools](https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools).

In [None]:
%%capture
!pip install pytorch-lightning==1.5.3 torchmetrics==0.6.0 pycocotools

## Imports

In [None]:
import ast
import math
import multiprocessing as mp
from pathlib import Path

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torchmetrics
import torchvision
import wandb
from PIL import Image
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from torchmetrics.metric import Metric
from torchvision.datasets import VisionDataset
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torch.utils.data.dataloader import DataLoader

## Load Weights & Biases API Key from secrets
- Copy key from [W&B settings page](https://wandb.ai/settings)
- Add key to Kaggle through [Add-ons](https://www.kaggle.com/product-feedback/114053) 

In [None]:
from kaggle_secrets import UserSecretsClient

# The Weights & Biases API key is read from the Kaggle secret with this label
# and used to authenticate the wandb client for the WandbLogger below.
secret_label = "wb_api"
secrets_client = UserSecretsClient()
wb_api_key = secrets_client.get_secret(secret_label)
wandb.login(key=wb_api_key)

## Paths

In [None]:
# Competition data location, relative to the notebook's working directory.
INPUT_DIR = Path("..", "input")
DATA_DIR = INPUT_DIR.joinpath("tensorflow-great-barrier-reef")
TRAIN_CSV_PATH = DATA_DIR.joinpath("train.csv")

## Settings

In [None]:
# Data Module Args
NUM_WORKERS = mp.cpu_count()
BATCH_SIZE = 8

WANDB_PROJECT = "kaggle-great-barrier-reef"

# Trainer Args
# Train on one GPU when the notebook has one, otherwise fall back to CPU (0).
# A hard-coded 1 makes pl.Trainer fail on CPU-only sessions.
GPUS = 1 if torch.cuda.is_available() else 0
FAST_DEV_RUN = True  # Set to False to properly train
MAX_EPOCHS = 5

## Competition Metric implemented with TorchMetrics

Mostly copied from https://www.kaggle.com/bamps53/competition-metric-implementation

In [None]:
def f_beta(tp, fp, fn, beta=2):
    """Return the F-beta score for the given detection counts.

    F_beta = (1 + beta^2) * TP / ((1 + beta^2) * TP + beta^2 * FN + FP)

    Args:
        tp: Number of true positives.
        fp: Number of false positives.
        fn: Number of false negatives.
        beta: Weight of recall relative to precision (2 for the competition's F2).

    Returns:
        The F-beta score. When the denominator is zero (no TP, FP or FN at all)
        0.0 is returned instead of raising ``ZeroDivisionError``.
    """
    beta_sq = beta ** 2
    numerator = (1 + beta_sq) * tp
    denominator = numerator + beta_sq * fn + fp
    if denominator == 0:
        # Nothing predicted and nothing to find: the ratio is undefined.
        return 0.0
    return numerator / denominator

In [None]:
class KaggleF2(Metric):
    """TorchMetrics implementation of the competition F2 score.

    For each image, predictions are sorted by descending confidence and greedily
    matched to ground-truth boxes at each IoU threshold 0.30, 0.35, ..., 0.80.
    True positive, false positive and false negative counts are summed over all
    images and thresholds and combined with ``f_beta(..., beta=2)``.

    Args:
        compute_on_step: Forwarded to ``torchmetrics.Metric``.
        dist_sync_on_step: Forwarded to ``torchmetrics.Metric``.
        process_group: Forwarded to ``torchmetrics.Metric``.
        dist_sync_fn: Forwarded to ``torchmetrics.Metric``.
        box_format: Format of the incoming ``boxes`` tensors: ``"xyxy"``
            (default), ``"xywh"`` or ``"cxcywh"``. Boxes are converted to
            ``xyxy`` before IoUs are computed. torchvision detection models
            output ``xyxy`` boxes and ``GBRDataset`` builds ``xyxy`` targets,
            so the default needs no conversion here.

    Raises:
        ValueError: If ``box_format`` is not one of the supported formats.
    """

    _BOX_FORMATS = ("xyxy", "xywh", "cxcywh")
    # 11 IoU thresholds: 0.30, 0.35, ..., 0.80. np.linspace pins both endpoints
    # exactly, unlike np.arange with a fractional step.
    _IOU_THRESHOLDS = np.linspace(0.3, 0.8, 11)

    def __init__(
        self,
        compute_on_step=True,
        dist_sync_on_step=False,
        process_group=None,
        dist_sync_fn=None,
        box_format="xyxy",
    ) -> None:
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )

        if box_format not in self._BOX_FORMATS:
            raise ValueError(
                f"box_format must be one of {self._BOX_FORMATS}, got {box_format!r}"
            )
        self.box_format = box_format

        # One tensor per image; lists are concatenated (not reduced) across processes.
        self.add_state("detection_boxes", default=[], dist_reduce_fx=None)
        self.add_state("detection_scores", default=[], dist_reduce_fx=None)
        self.add_state("groundtruth_boxes", default=[], dist_reduce_fx=None)

    def _to_xyxy(self, boxes):
        """Convert ``boxes`` from ``self.box_format`` to ``xyxy`` (no-op if already xyxy or empty)."""
        if self.box_format == "xyxy" or len(boxes) == 0:
            return boxes
        return torchvision.ops.box_convert(boxes, in_fmt=self.box_format, out_fmt="xyxy")

    def update(self, preds, target):
        """Accumulate one batch.

        Args:
            preds: Sequence of per-image dicts with ``boxes`` and ``scores`` tensors
                (the output format of torchvision detection models).
            target: Sequence of per-image dicts with a ``boxes`` tensor.
        """
        for item in preds:
            self.detection_boxes.append(self._to_xyxy(item["boxes"]))
            self.detection_scores.append(item["scores"])

        for item in target:
            self.groundtruth_boxes.append(self._to_xyxy(item["boxes"]))

    def compute(self):
        """Return the F2 score over everything accumulated since the last reset."""
        tps, fps, fns = 0, 0, 0
        for gt_boxes, pred_boxes, pred_scores in zip(
            self.groundtruth_boxes, self.detection_boxes, self.detection_scores
        ):
            tp, fp, fn = self._compute_stat_scores(gt_boxes, pred_boxes, pred_scores)
            tps += tp
            fps += fp
            fns += fn

        if tps + fps + fns == 0:
            # No boxes anywhere: F2 is undefined, report 0.0 instead of dividing by zero.
            return 0.0
        return f_beta(tps, fps, fns, beta=2)

    def _compute_stat_scores(self, gt_boxes, pred_boxes, pred_scores):
        """Return (tp, fp, fn) for one image, summed over all IoU thresholds."""
        if len(gt_boxes) == 0 and len(pred_boxes) == 0:
            return 0, 0, 0
        if len(gt_boxes) == 0:
            return 0, len(pred_boxes), 0
        if len(pred_boxes) == 0:
            return 0, 0, len(gt_boxes)

        # Sort by confidence so greedy matching considers confident predictions first.
        _, indices = torch.sort(pred_scores, descending=True)
        pred_boxes = pred_boxes[indices]

        # The (num_preds, num_gt) IoU matrix is threshold-independent: compute it once.
        ious = torchvision.ops.box_iou(pred_boxes, gt_boxes)

        tps, fps, fns = 0, 0, 0
        for iou_th in self._IOU_THRESHOLDS:
            tp, fp, fn = self._compute_stat_scores_at_iou_th(
                gt_boxes, pred_boxes, iou_th, ious=ious
            )
            tps += tp
            fps += fp
            fns += fn

        return tps, fps, fns

    def _compute_stat_scores_at_iou_th(self, gt_boxes, pred_boxes, iou_th, ious=None):
        """Greedily match predictions to ground truths at one IoU threshold.

        Predictions are visited in the given order (expected: descending
        confidence). Each one is matched to the not-yet-matched ground truth
        with the highest IoU if that IoU exceeds ``iou_th``.

        Args:
            gt_boxes: Ground-truth boxes in ``xyxy`` format.
            pred_boxes: Predicted boxes in ``xyxy`` format.
            iou_th: IoU threshold a match must exceed.
            ious: Optional precomputed ``box_iou(pred_boxes, gt_boxes)``.

        Returns:
            ``(tp, fp, fn)``. Every unmatched prediction is a FP and every
            unmatched ground truth is a FN.
        """
        if ious is None:
            ious = torchvision.ops.box_iou(pred_boxes, gt_boxes)

        matched = torch.zeros(len(gt_boxes), dtype=torch.bool, device=ious.device)
        tp = 0
        for pred_ious in ious:
            if bool(matched.all()):
                # All ground truths consumed: remaining predictions are false positives.
                break
            # Already-matched ground truths can't be matched again; -1 never exceeds iou_th.
            candidate_ious = pred_ious.masked_fill(matched, -1.0)
            best_gt = torch.argmax(candidate_ious)
            if candidate_ious[best_gt] > iou_th:
                tp += 1
                matched[best_gt] = True

        fp = len(pred_boxes) - tp
        fn = len(gt_boxes) - tp

        return tp, fp, fn


## PyTorch Dataset Class

In [None]:
class GBRDataset(VisionDataset):
    """VisionDataset built from the competition ``train.csv``.

    Only images with at least one annotation are used. Each item is an
    ``(image, target)`` pair:

    - ``image``: float32 tensor of shape (3, H, W) with values scaled to [0, 1].
    - ``target``: dict with ``boxes`` (float32, (N, 4), ``x_min, y_min, x_max, y_max``)
      and ``labels`` (int64, (N,), all ones).

    Args:
        csv_path: ``pathlib.Path`` of ``train.csv``. Images are read from
            ``<csv_path.parent>/train_images/video_<video_id>/<video_frame>.jpg``.
    """

    def __init__(self, csv_path):
        super().__init__(csv_path.parent)

        self.train_df = pd.read_csv(csv_path)
        self.image_paths_annotations = self._create_image_paths_annotations()

    def _create_image_paths_annotations(self):
        """Return ``(image_path, annotations)`` pairs for every annotated row of ``train_df``."""
        image_dir = Path(self.root) / "train_images"
        image_paths_annotations = []

        # Zipping the needed columns avoids building a Series per row (DataFrame.iterrows).
        rows = zip(
            self.train_df["video_id"],
            self.train_df["video_frame"],
            self.train_df["annotations"],
        )
        for video_id, video_frame, annotations_str in rows:
            # The column stores a Python-literal list of dicts; literal_eval parses it
            # without executing arbitrary code.
            annotations = ast.literal_eval(annotations_str)

            # Use only those images with annotations
            if annotations:
                image_path = image_dir / f"video_{video_id}" / f"{video_frame}.jpg"
                image_paths_annotations.append((image_path, annotations))

        return image_paths_annotations

    def __getitem__(self, index):
        image_path, annotations = self.image_paths_annotations[index]

        # Image: RGB, scaled to [0, 1]; the context manager closes the file handle.
        with Image.open(image_path) as pil_image:
            image = np.array(pil_image.convert("RGB"), dtype=np.float32) / 255.0
        # HWC -> CHW
        image = torch.from_numpy(image.transpose(2, 0, 1))

        # (x, y, width, height) annotations -> pascal_voc (x_min, y_min, x_max, y_max)
        boxes = [
            (ann["x"], ann["y"], ann["x"] + ann["width"], ann["y"] + ann["height"])
            for ann in annotations
        ]
        labels = [1] * len(boxes)

        target = {
            # reshape keeps the (N, 4) shape torchvision expects even when N == 0
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.as_tensor(labels, dtype=torch.int64),
        }

        return image, target

    def __len__(self) -> int:
        return len(self.image_paths_annotations)

## Lightning Module Class

In [None]:
class GBRModule(pl.LightningModule):
    """LightningModule to finetune torchvision's Faster R-CNN model.

    Args:
        pretrained_weights_path: Optional path to a ``fasterrcnn_resnet50_fpn``
            ``state_dict`` to load instead of downloading torchvision's weights.
    """

    def __init__(self, pretrained_weights_path=None):
        super().__init__()

        self.model = self._create_model(pretrained_weights_path)

        # Validation metrics are accumulated over all batches and computed once
        # per epoch in ``validation_epoch_end``.
        self.val_map = torchmetrics.MAP()
        self.val_f2 = KaggleF2()

    def _create_model(self, pretrained_weights_path):
        """Create a finetunable Faster R-CNN model with a 2-class box predictor.

        In the finetuning notebook, the internet can be used and the weights are downloaded.
        In the inference notebook, there is no internet access and the weights are provided
        with the pretrained_weights_path arg. Those weights are loaded *before* the box
        predictor is replaced, so they must match the stock ``fasterrcnn_resnet50_fpn``.
        """
        if pretrained_weights_path is None:
            model = fasterrcnn_resnet50_fpn(pretrained=True)
        else:
            model = fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)
            # map_location="cpu" lets weights saved on a GPU load on CPU-only machines.
            state_dict = torch.load(pretrained_weights_path, map_location="cpu")
            model.load_state_dict(state_dict)

        # In torchvision's convention num_classes includes background, so 2 means
        # background plus the single object class (label 1 in GBRDataset).
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=2)

        return model

    def forward(self, image):
        """Run inference and return the wrapped model's detections.

        The model is put in eval mode for the call and then restored to its
        previous mode. Calling ``forward`` mid-training therefore does not leave
        it in eval mode, where it returns detections instead of losses.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            return self.model(image)
        finally:
            self.model.train(was_training)

    def training_step(self, batch, batch_idx):
        images, targets = batch
        # Given targets in train mode, the detection model returns a dict of losses.
        loss_dict = self.model(images, targets)
        loss = sum(loss_dict.values())

        batch_size = len(images)
        self.log_dict(loss_dict, batch_size=batch_size)
        self.log("train_loss", loss, batch_size=batch_size)

        return loss

    def validation_step(self, batch, batch_idx):
        images, targets = batch
        outputs = self.model(images)

        # Only accumulate here. mAP and F2 are not averages of per-batch values, so
        # logging per-batch results would report the wrong number (and recompute
        # the expensive mAP on every batch).
        self.val_map.update(outputs, targets)
        self.val_f2.update(outputs, targets)

    def validation_epoch_end(self, outputs):
        """Compute, log and reset the metrics over the whole validation set."""
        map_results = self.val_map.compute()
        self.log("val_map", map_results["map"])
        self.log("val_f2", self.val_f2.compute())

        self.val_map.reset()
        self.val_f2.reset()

    def configure_optimizers(self):
        """SGD with momentum/weight decay; StepLR divides the LR by 10 every 3 scheduler steps."""
        params = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

        return [optimizer], [lr_scheduler]


## Lightning DataModule Class

In [None]:
class GBRDataModule(pl.LightningDataModule):
    """LightningDataModule that splits GBRDataset 80/20 into train/val and builds dataloaders.

    Args:
        csv_path: Path of ``train.csv``, passed to ``GBRDataset``.
        batch_size: Batch size of both dataloaders.
        num_workers: Number of DataLoader worker processes.
    """

    def __init__(self, csv_path, batch_size, num_workers):
        super().__init__()

        self.save_hyperparameters()

        self.csv_path = csv_path
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage=None):
        """Create a random 80/20 train/val split once and reuse it afterwards.

        Lightning may call ``setup`` again for another stage (e.g. ``validate``).
        Re-splitting would draw a new permutation and leak training images into
        the validation set.
        """
        if self.train_dataset is not None and self.val_dataset is not None:
            return

        # Parse the CSV once; both subsets index into the same dataset.
        dataset = GBRDataset(self.csv_path)

        # Split (uses the global torch RNG, so it follows pl.seed_everything)
        len_total = len(dataset)
        len_train = int(0.8 * len_total)
        indices = torch.randperm(len_total).tolist()
        self.train_dataset = torch.utils.data.Subset(dataset, indices[:len_train])
        self.val_dataset = torch.utils.data.Subset(dataset, indices[len_train:])

    def train_dataloader(self):
        return self._dataloader(self.train_dataset, shuffle=True, drop_last=True)

    def val_dataloader(self):
        # Never drop the last batch in validation: every image must be scored.
        return self._dataloader(self.val_dataset)

    def _dataloader(self, dataset, shuffle=False, drop_last=False):
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=collate_fn,
            pin_memory=True,
            drop_last=drop_last,
        )


def collate_fn(batch):
    """Transpose a list of ``(image, target)`` samples into ``(images, targets)`` tuples.

    Samples carry a variable number of boxes, so they are grouped into tuples
    rather than stacked into batch tensors.
    """
    return tuple(field for field in zip(*batch))


## Train Function

In [None]:
def train():
    """Finetune Faster R-CNN on the GBR data and track the run with Weights & Biases.

    Uses the module-level settings TRAIN_CSV_PATH, BATCH_SIZE, NUM_WORKERS,
    WANDB_PROJECT, FAST_DEV_RUN, GPUS and MAX_EPOCHS.
    """
    pl.seed_everything(42, workers=True)

    gbr_module = GBRModule()

    gbr_data_module = GBRDataModule(TRAIN_CSV_PATH, BATCH_SIZE, NUM_WORKERS)

    trainer = pl.Trainer(
        fast_dev_run=FAST_DEV_RUN,
        gpus=GPUS,
        logger=WandbLogger(project=WANDB_PROJECT, log_model=True),
        # Log the learning rate, which the module's StepLR scheduler changes over training.
        callbacks=[LearningRateMonitor()],
        max_epochs=MAX_EPOCHS,
        # Half precision only on GPU
        precision=16 if GPUS else 32,
    )

    try:
        trainer.fit(gbr_module, gbr_data_module)
    finally:
        # Close the W&B run so re-running the cell starts a fresh run.
        wandb.finish()

## Run the Training

In [None]:
# Launch training with the settings defined in the Settings cell.
train()