# Faster R-CNN Lesion Detector

This notebook trains and evaluates a lesion detector on the DeepLesion dataset using the Faster R-CNN object-detection architecture with a ResNet-50 (FPN v2) backbone.

## Assumptions:
- Use 2D slice inputs (optionally use the neighbouring ones too),
- Resize all images to 512x512,
- Use a YOLO-style dataset layout (`images/<split>` and `labels/<split>`, one `.txt` label file per image), with one modification: each label row is `class x1 y1 x2 y2`, i.e. box corner coordinates `[x1, y1, x2, y2]`.
- Use DeepLesion to train a general lesion localizer, and other datasets such as LiTS (Liver Tumor Segmentation) or CHAOS (CT liver dataset) to train more specialized localizers.

## 📚 Thesis Value Summary
### Contribution and Value:
- Comparison of one-stage vs. two-stage vs. Transformer-based vs. legacy detectors on DeepLesion -> ✅ Fills a gap in the literature
- Evaluation of improved DETRs (DINO/Deformable DETR) -> ✅ Modern insight
- General vs. specialized lesion detection -> ✅ Strong clinical relevance
- Analysis of training time, robustness, and failure modes -> ✅ Engineering depth


# Google Colab only

### Download required packages

In [None]:
# Install the packages listed in the repository's notebooks/requirements.txt (needed on Google Colab).
# NOTE(review): `%pip install` targets the running kernel's environment more reliably than `!pip`.
!pip install -r https://raw.githubusercontent.com/pmalesa/lesion_detector/main/notebooks/requirements.txt

### Mount DeepLesion images and checkpoints from Google Drive

In [None]:
# Mount Google Drive and expose the DeepLesion data and Faster R-CNN checkpoints under /content via symlinks.
from google.colab import drive
drive.mount('/content/drive')
%cd /content

# remove existing link if any
# NOTE: `rm -rf` on a symlink path (no trailing slash) removes only the link; if these paths are
# real directories instead of links, their contents are deleted.
!rm -rf data/deeplesion
!rm -rf faster_rcnn_checkpoints

!mkdir -p data
# Link the Drive-hosted folders so the rest of the notebook can use the relative paths data/deeplesion
# and faster_rcnn_checkpoints
!ln -s /content/drive/MyDrive/deeplesion/data/deeplesion data/deeplesion
!ln -s /content/drive/MyDrive/deeplesion/checkpoints/faster_rcnn faster_rcnn_checkpoints
!ls -l data

# Import all packages

In [None]:
# General packages
import os
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import random
from pathlib import Path
from datetime import datetime

# Faster R-CNN packages
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torchvision.transforms as T
import torchvision.transforms.functional as F
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_V2_Weights
from torchmetrics.detection.mean_ap import MeanAveragePrecision
import torchvision.ops as ops
import copy
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.rpn import RPNHead
from torchvision.ops import box_iou

# Set paths to DeepLesion images and metadata

## Paths to unprocessed data

In [None]:
# Google Colab: relative to /content, where data/deeplesion is a symlink into Google Drive.
# Run either this cell or the "Local" one below, not both (the later one wins).
deeplesion_metadata_path = Path("data", "deeplesion", "deeplesion_metadata.csv")
deeplesion_image_path = Path("data", "deeplesion", "key_slices")

In [None]:
# Local: relative to the notebooks/ directory.
# NOTE(review): unlike the Colab layout, the metadata CSV sits directly under ../data here — confirm.
deeplesion_metadata_path = Path("..", "data", "deeplesion_metadata.csv")
deeplesion_image_path = Path("..", "data", "deeplesion", "key_slices")

## Paths to processed data

In [None]:
# Root of the preprocessed (uint8) DeepLesion export and the split directories used below.
# NOTE(review): these paths are Colab-relative and have no Local variant — confirm for local runs.
deeplesion_data_dir = Path("data") / "deeplesion"
deeplesion_preprocessed_image_path = deeplesion_data_dir / "deeplesion_preprocessed_uint8" / "key_slices"
deeplesion_preprocessed_metadata_path = (
    deeplesion_data_dir / "deeplesion_preprocessed_uint8" / "deeplesion_metadata_preprocessed.csv"
)

# Evaluation Functions

## FROC Curve

In [None]:
def _froc_match_image(boxes_pred, scores_pred, boxes_gt, iou_thr):
    """
    Greedy class-agnostic matching of one image's predictions to its ground-truth boxes.

    Predictions are visited in descending score order (stable sort, so tied scores keep their
    original order). A prediction is a TP when its highest-IoU GT box reaches `iou_thr` and that
    GT box is still unmatched; otherwise it is an FP.

    Returns:
        (scores_sorted, is_tp): numpy arrays aligned element-wise, sorted by descending score.
    """
    scores_sorted, order = torch.sort(scores_pred, descending=True, stable=True)
    is_tp = torch.zeros(order.numel(), dtype=torch.bool)

    # Without predictions or without GT boxes every prediction stays an FP
    if order.numel() > 0 and boxes_gt.numel() > 0:
        ious = box_iou(boxes_pred[order], boxes_gt)      # [N_pred, N_gt]
        best_ious, best_gt_idxs = ious.max(dim=1)        # best-matching GT per prediction
        matched_gt = torch.zeros(boxes_gt.shape[0], dtype=torch.bool)
        for p_idx in range(order.numel()):
            gt_idx = best_gt_idxs[p_idx]
            if best_ious[p_idx] >= iou_thr and not matched_gt[gt_idx]:
                is_tp[p_idx] = True
                matched_gt[gt_idx] = True

    return scores_sorted.numpy(), is_tp.numpy()


@torch.no_grad()
def compute_froc_curve(model, loader, device, iou_thr=0.5, score_thresholds=None):
    """
    Computes the class-agnostic FROC curve: sensitivity (recall) vs. false positives per image.

    Each prediction's TP/FP outcome comes from greedy matching in descending score order. The
    predictions kept at a score threshold are always a prefix of that order, and greedy matching
    of a prefix never depends on lower-scored predictions. So the matching is done once per image
    and then reused for every threshold, instead of recomputing IoUs T times.

    Args:
        model:            detector called as model(list_of_images), returning one dict per image
                          with "boxes" and "scores".
        loader:           iterable of (images, targets) batches; each target has "boxes".
        device:           device the images are moved to before inference.
        iou_thr:          IoU a prediction needs with a GT box to count as a hit.
        score_thresholds: score cut-offs; defaults to 101 evenly spaced points in [0, 1].

    Returns:
        fp_per_image:     np.ndarray of shape [T]
        sensitivity:      np.ndarray of shape [T]
        score_thresholds: np.ndarray of shape [T] (float32)
    """
    model.eval()

    if score_thresholds is None:
        # 0.0 ... 1.0, 101 points
        score_thresholds = np.linspace(0.0, 1.0, 101)
    score_thresholds = np.array(score_thresholds, dtype=np.float32)

    # Single pass over the data: keep each prediction's score and its TP/FP outcome
    per_image_scores = []
    per_image_is_tp = []
    n_images = 0
    n_gt_total = 0

    for images, targets in loader:
        images = [img.to(device) for img in images]
        outputs = model(images)

        for output, target in zip(outputs, targets):
            # ground truth (class agnostic)
            boxes_gt = target["boxes"].detach().cpu()
            scores, is_tp = _froc_match_image(
                output["boxes"].detach().cpu(),
                output["scores"].detach().cpu(),
                boxes_gt,
                iou_thr,
            )
            per_image_scores.append(scores)
            per_image_is_tp.append(is_tp)
            n_images += 1
            n_gt_total += boxes_gt.shape[0]

    all_scores = np.concatenate(per_image_scores) if per_image_scores else np.zeros(0, dtype=np.float32)
    all_is_tp = np.concatenate(per_image_is_tp) if per_image_is_tp else np.zeros(0, dtype=bool)

    fp_per_image = np.zeros_like(score_thresholds)
    sensitivity = np.zeros_like(score_thresholds)

    # For each threshold, count TP/FP over the whole dataset from the precomputed outcomes
    for i, thr in enumerate(score_thresholds):
        kept = all_scores >= thr
        tp = int(np.count_nonzero(all_is_tp & kept))
        fp = int(np.count_nonzero(kept)) - tp
        sensitivity[i] = tp / max(1, n_gt_total)
        fp_per_image[i] = fp / max(1, n_images)

    return fp_per_image, sensitivity, score_thresholds


def sensitivity_at_fp(fp_per_image, sensitivity, fp_targets):
    """
    Interpolate sensitivity at target FP/image values.

    Args:
        fp_per_image, sensitivity: arrays from compute_froc_curve() (one operating point per threshold).
        fp_targets: list or array of target FP/image values, e.g. [0.5, 1, 2, 4].

    Several thresholds often yield the same FP/image value with different sensitivities. Only the
    best sensitivity per FP value is kept, so the interpolation does not depend on how ties happen
    to be ordered. Targets below the smallest observed FP/image get 0.0. Targets above the largest
    get the sensitivity at the largest. With an empty curve every target maps to NaN.

    Returns dict: {fp_target: sensitivity_value}
    """
    fp_per_image = np.asarray(fp_per_image, dtype=np.float64).ravel()
    sensitivity = np.asarray(sensitivity, dtype=np.float64).ravel()
    fp_targets = np.atleast_1d(np.asarray(fp_targets, dtype=np.float32))

    if fp_per_image.size == 0:
        return {float(fp): float("nan") for fp in fp_targets}

    # Strictly increasing FP values, each paired with the best sensitivity reached at that FP rate
    fp_unique, inverse = np.unique(fp_per_image, return_inverse=True)
    sens_best = np.full(fp_unique.shape, -np.inf)
    np.maximum.at(sens_best, inverse, sensitivity)

    sens_at = np.interp(fp_targets, fp_unique, sens_best, left=0.0, right=sens_best[-1])

    return {float(fp): float(s) for fp, s in zip(fp_targets, sens_at)}


def plot_froc_curve(fp_per_image, sensitivity):
    """
    Draws and shows the FROC curve, i.e. sensitivity vs. false positives per image.

    Args:
        fp_per_image, sensitivity: arrays from compute_froc_curve(), one point per score threshold.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(fp_per_image, sensitivity, color="red", marker="o")
    ax.set_xlabel("False positives per image")
    ax.set_ylabel("Sensitivity (recall)")
    ax.grid(True, linestyle="--", alpha=0.3)
    ax.set_title("FROC curve (class-agnostic lesions)")
    fig.tight_layout()
    plt.show()


def print_froc_curve_info(model, loader, device):
    """
    Computes the FROC curve at IoU 0.5, plots it, and prints sensitivity at common FP/image rates.
    """
    operating_points = [0.5, 1.0, 2.0, 4.0, 5.0, 8.0]
    fp_curve, sens_curve, _thresholds = compute_froc_curve(model, loader, device, iou_thr=0.5)
    sens_dict = sensitivity_at_fp(fp_curve, sens_curve, operating_points)
    plot_froc_curve(fp_curve, sens_curve)

    print("FROC (class-agnostic):")
    for fp, s in sens_dict.items():
        print(f"Sensitivity at {fp} FP/image: {s*100:.2f}%")


## Mean Average Precision

In [None]:
def _accumulate_pr_counts(prediction, ground_truth, TP, FP, FN, pr_conf_thr, pr_iou_thr):
    """
    Updates the per-class TP/FP/FN counters (in place) for a single image.

    Predictions with score >= pr_conf_thr are visited in descending score order. Each is matched to
    the highest-IoU ground-truth box of the same class. It is a TP when that IoU is >= pr_iou_thr
    and the GT box is still unmatched. Otherwise it is an FP, which includes predictions on images
    with no GT box at all. Ground-truth boxes left unmatched are FNs. Class ids are 1...K and map to
    counter index id - 1.
    """
    keep = prediction["scores"] >= pr_conf_thr
    scores = prediction["scores"][keep]
    order = torch.argsort(scores, descending=True)
    pred_boxes = prediction["boxes"][keep][order]    # Reorder to the same order as scores
    pred_labels = prediction["labels"][keep][order]  # Reorder to the same order as scores
    gt_boxes = ground_truth["boxes"]
    gt_labels = ground_truth["labels"]

    matched = torch.zeros(len(gt_boxes), dtype=torch.bool)
    # IoU matrix [N_pred, N_gt]; only computable when both sides are non-empty
    ious = ops.box_iou(pred_boxes, gt_boxes) if len(pred_boxes) and len(gt_boxes) else None

    for pred_idx in range(len(pred_boxes)):
        cls = int(pred_labels[pred_idx].item())  # classes are 1...K
        same = gt_labels == cls
        if ious is None or not bool(same.any()):
            # No GT box of this class in the image (or no GT at all) -> false positive
            FP[cls - 1] += 1
            continue

        gt_idxs = torch.where(same)[0]
        best_iou, best_loc = ious[pred_idx, same].max(0)
        gt_idx = gt_idxs[best_loc]
        if best_iou >= pr_iou_thr and not matched[gt_idx]:
            TP[cls - 1] += 1
            matched[gt_idx] = True
        else:
            FP[cls - 1] += 1

    # Any unmatched ground truths are FN
    for gt_idx, gt_label in enumerate(gt_labels):
        if not matched[gt_idx]:
            FN[int(gt_label.item()) - 1] += 1


def _per_class_ap_rows(res_all, res_10, res_30, res_50, class_names):
    """
    Builds the per-class AP rows from torchmetrics results.

    torchmetrics reports `map_per_class` only for the class ids listed in its `classes` output, so the
    values are keyed by that id rather than by position. A class missing from the evaluated data
    therefore cannot shift the names of the others.
    """
    def _ap_by_class(result):
        values = result.get("map_per_class", None)
        if values is None:
            return {}
        values = values.reshape(-1).tolist()  # reshape guards against a 0-dim tensor (single class)
        classes = result.get("classes", None)
        ids = classes.reshape(-1).tolist() if classes is not None else range(1, len(values) + 1)
        return {int(cls): ap for cls, ap in zip(ids, values)}

    if res_all.get("map_per_class", None) is None:
        return []

    ap_all = _ap_by_class(res_all)
    ap_10 = _ap_by_class(res_10)
    ap_30 = _ap_by_class(res_30)
    ap_50 = _ap_by_class(res_50)

    rows = []
    for cls in sorted(ap_all):
        name = class_names[cls - 1] if class_names and 1 <= cls <= len(class_names) else f"class_{cls}"
        rows.append({
            "idx": cls,
            "name": name,
            "AP": _nan_if_undefined(ap_all[cls]),
            "AP10": _nan_if_undefined(ap_10.get(cls)),
            "AP30": _nan_if_undefined(ap_30.get(cls)),
            "AP50": _nan_if_undefined(ap_50.get(cls)),
        })
    return rows


@torch.no_grad()
def evaluate_detector(model, loader, device, num_classes=8, class_names=None, pr_conf_thr=0.25, pr_iou_thr=0.5):
    """
        Evaluates a detector with torchmetrics mAP and with precision/recall at fixed thresholds.

        pr_conf_thr is used to filter predictions by their score.
        pr_iou_thr is used to define what is considered a TP in fixed-threshold precision-recall computation.
        pr_conf_thr and pr_iou_thr parameters affect only precision and recall calculation, not mAPs.
        mAPs are calculated by torchmetrics, which uses its own IoU thresholds.

        Args:
            model:       detector returning dicts with "boxes", "scores" and "labels".
            loader:      DataLoader yielding (images, targets); targets carry "boxes" and "labels".
            device:      device the images are moved to.
            num_classes: number of foreground classes K; labels are expected in 1...K.
            class_names: optional names, class_names[k - 1] naming label k.

        Returns a dictionary:
        {
            "mAP10", "mAP30", "mAP50", "mAP50_95", "mAP_S", "mAP_M", "mAP_L": float,
            "per_class": [{"idx": ..., "name": ..., "AP": ..., "AP10": ..., "AP30": ..., "AP50": ...}, ...],
            "precision_overall": float,
            "recall_overall": float,
            "precision_per_class": [...],
            "recall_per_class": [...],
            "pr_conf_thr": float,
            "pr_iou_thr": float,
        }
    """

    model.eval()
    metric_all = MeanAveragePrecision(iou_type="bbox", box_format="xyxy", class_metrics=True)
    metric_10  = MeanAveragePrecision(iou_type="bbox", box_format="xyxy", class_metrics=True, iou_thresholds=[0.1])
    metric_30  = MeanAveragePrecision(iou_type="bbox", box_format="xyxy", class_metrics=True, iou_thresholds=[0.3])
    metric_50  = MeanAveragePrecision(iou_type="bbox", box_format="xyxy", class_metrics=True, iou_thresholds=[0.5])

    # For precision and recall at fixed thresholds (IoU=pr_iou_thr, conf=pr_conf_thr)
    TP = torch.zeros(num_classes, dtype=torch.long)
    FP = torch.zeros(num_classes, dtype=torch.long)
    FN = torch.zeros(num_classes, dtype=torch.long)
    n_processed_images = 0

    for batch_idx, (images, targets) in enumerate(loader, start=1):
        images = [image.to(device) for image in images]
        outputs = model(images)
        n_processed_images += len(images)

        # Move to CPU for metrics
        predictions, ground_truths = [], []
        for output, target in zip(outputs, targets):
            predictions.append({"boxes": output["boxes"].cpu(),
                                "scores": output["scores"].cpu(),
                                "labels": output["labels"].cpu()})
            ground_truths.append({"boxes": target["boxes"].cpu(),
                                  "labels": target["labels"].cpu()})

        # mAP update
        metric_all.update(predictions, ground_truths)
        metric_10.update(predictions, ground_truths)
        metric_30.update(predictions, ground_truths)
        metric_50.update(predictions, ground_truths)

        # Precision and recall accumulation at fixed thresholds
        for prediction, ground_truth in zip(predictions, ground_truths):
            _accumulate_pr_counts(prediction, ground_truth, TP, FP, FN, pr_conf_thr, pr_iou_thr)

        if batch_idx % 10 == 0 or batch_idx == len(loader):
            print(f"\r[{n_processed_images}/{len(loader.dataset)}] images validated.", end="", flush=True)

    print() # print a new line

    # mAP metrics
    res_all = metric_all.compute()
    res_10  = metric_10.compute()
    res_30  = metric_30.compute()
    res_50  = metric_50.compute()

    out = {
        "mAP10": float(res_10["map"]),
        "mAP30": float(res_30["map"]),
        "mAP50": float(res_50["map"]),
        "mAP50_95": float(res_all["map"]),
        "mAP_S": float(res_all["map_small"]),
        "mAP_M": float(res_all["map_medium"]),
        "mAP_L": float(res_all["map_large"])
    }

    # Per-class AP (if available), aligned by class id
    out["per_class"] = _per_class_ap_rows(res_all, res_10, res_30, res_50, class_names)

    # Precision/Recall at fixed thresholds
    precision_per_class = (TP.float() / (TP + FP).clamp(min=1)).tolist()
    recall_per_class    = (TP.float() / (TP + FN).clamp(min=1)).tolist()
    overall_precision   = float(TP.sum() / (TP.sum() + FP.sum()).clamp(min=1))
    overall_recall      = float(TP.sum() / (TP.sum() + FN.sum()).clamp(min=1))

    out["precision_overall"]   = overall_precision
    out["recall_overall"]      = overall_recall
    out["precision_per_class"] = precision_per_class
    out["recall_per_class"]    = recall_per_class
    out["pr_conf_thr"]         = pr_conf_thr
    out["pr_iou_thr"]          = pr_iou_thr

    return out

# =================================================================================================================================================
# =================================================================================================================================================

def _count_instances_per_class(dataset, num_classes):
    """
    Function that counts the ground truth instances per class by reading the label .txt files on disk.
    """

    counts = [0] * num_classes
    total = 0

    for img_name in dataset.image_names:
        label_path = os.path.join(dataset.label_dir, Path(img_name).stem + ".txt")
        if not os.path.exists(label_path):
            continue
        with open(label_path, "r") as file:
            for line in file:
                parts = line.strip().split()
                if len(parts) != 5: # Ill-written label text file
                    continue
                cls = int(float(parts[0]))
                # Faster R-CNN labels are 1...K (background is implicit), we map to 0...K-1 index
                if 1 <= cls <= num_classes:
                    counts[cls - 1] += 1
                    total += 1
    return counts, total

def _to_float(x, default=float("nan")):
    if x is None:
        return default
    if isinstance(x, torch.Tensor):
        if x.numel() == 0:
            return default
        x = x.detach().cpu().item() if x.ndim == 0 else x.detach().cpu().numpy()
    if isinstance(x, np.ndarray):
        return float(x.item()) if x.size == 1 else default
    try:
        return float(x)
    except Exception:
        return default
    
def _to_int(x, default=0):
    try:
        return int(x)
    except Exception:
        return default
    
def _nan_if_undefined(x):
    return float("nan") if x is None or (isinstance(x, (float, int)) and x < 0) else float(x)

def print_result_report(metrics, loader, class_names):
    """
    Prints a YOLO-style evaluation table: an "all" row followed by one row per class.

    Args:
        metrics:     dict produced by evaluate_detector().
        loader:      DataLoader whose dataset exposes `image_names` and `label_dir`. The image count
                     and the ground-truth instance counts are read from it and from the label files.
        class_names: ordered class names; row i reports class_names[i].

    Note: the per-class "Images" column repeats the total number of images in the split. Per-class
    rows do not include the mAPS/mAPM/mAPL columns.
    """
    nan = float("nan")
    num_classes = len(class_names)
    images = _to_int(len(loader.dataset))
    per_class = metrics.get("per_class", [])

    (p_overall, r_overall, map10, map30, map50, map50_95, map_S, map_M, map_L) = (
        _to_float(metrics[key])
        for key in ("precision_overall", "recall_overall", "mAP10", "mAP30", "mAP50",
                    "mAP50_95", "mAP_S", "mAP_M", "mAP_L")
    )

    # Ground-truth instance counts come from the label files on disk
    counts, total_instances = _count_instances_per_class(loader.dataset, num_classes)

    # column -> {class name -> AP value}
    ap_lookup = {
        column: {row["name"]: row[column] for row in per_class}
        for column in ("AP10", "AP30", "AP50", "AP")
    }

    # Header
    print(f"{'Class':>18} {'Images':>8} {'Instances':>10} {'P':>10} {'R':>10} {'mAP10':>10} {'mAP30':>10} {'mAP50':>10} {'mAP50-95':>10} {'mAPS':>10} {'mAPM':>10} {'mAPL':>10}")

    # Overall row ("all")
    print(f"{'all':>18} {images:8d} {_to_int(total_instances):10d} {p_overall:10.3f} {r_overall:10.3f} {map10:10.3f} {map30:10.3f} {map50:10.3f} {map50_95:10.3f} {map_S:10.3f} {map_M:10.3f} {map_L:10.3f}")

    p_pc = metrics.get("precision_per_class", [])
    r_pc = metrics.get("recall_per_class", [])

    # Per-class rows
    for i, name in enumerate(class_names):
        P_i = _to_float(p_pc[i] if i < len(p_pc) else nan)
        R_i = _to_float(r_pc[i] if i < len(r_pc) else nan)
        AP10_i, AP30_i, AP50_i, AP_i = (
            _to_float(ap_lookup[column].get(name, nan)) for column in ("AP10", "AP30", "AP50", "AP")
        )
        inst_i = _to_int(counts[i])
        print(f"{name:>18} {images:8d} {inst_i:10d} {P_i:10.3f} {R_i:10.3f} {AP10_i:10.3f} {AP30_i:10.3f} {AP50_i:10.3f} {AP_i:10.3f}")


# Faster R-CNN

## Load Pre-trained Faster R-CNN Model

In [None]:
# Lesion categories in label order: label k (1...8) is class_names[k - 1]; label 0 is the implicit background
class_names = ["bone", "abdomen", "mediastinum", "liver", "lung", "kidney", "soft_tissue", "pelvis"]
# Number of classes for the detection head (no. dataset classes + 1 for background)
num_classes = len(class_names) + 1

# Use the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Optional custom RPN anchor generator (only used if enabled in construct_fasterrcnn_model)
def _deeplesion_anchorgen():
    """
    Builds an AnchorGenerator with one anchor size per FPN level, smaller than the torchvision
    default, and aspect ratios 0.5/1.0/2.0 on every level.
    """
    # torchvision default sizes per level: ((32,), (64,), (128,), (256,), (512,))
    # Alternative tried: ((8, 12), (16, 24), (32, 48), (64, 96), (128, 192))
    sizes_per_level = tuple((size,) for size in (16, 24, 32, 48, 96))
    ratios_per_level = tuple((0.5, 1.0, 2.0) for _ in sizes_per_level)
    return AnchorGenerator(sizes_per_level, ratios_per_level)

def construct_fasterrcnn_model():
    """
    Builds a COCO-pretrained Faster R-CNN (ResNet-50 FPN v2) adapted to single-channel DeepLesion slices.

    Adaptations:
      - 512x512 input size (min_size=max_size=512) and trainable_backbone_layers=2;
      - box predictor replaced so it outputs the global `num_classes` (lesion types + background);
      - first backbone conv replaced by a 1-channel conv initialised from the mean of the RGB filters.
        The new conv inherits the original conv1's requires_grad flag, so the swap does not silently
        unfreeze a layer that trainable_backbone_layers froze;
      - internal normalisation set to a single-channel mean/std of 0.5;
      - model moved to the global `device`.

    Returns:
        The torchvision Faster R-CNN model, on `device`.
    """
    # 1) Load COCO-pretrained Faster R-CNN
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1

    # Unfreeze 2 last layers
    model = fasterrcnn_resnet50_fpn_v2(
        weights=weights,
        min_size=512,
        max_size=512,
        trainable_backbone_layers=2
    )

    # 2) Swap custom anchor generator
    # rpn_anchor_generator = _deeplesion_anchorgen()
    # model.rpn.anchor_generator = rpn_anchor_generator

    # (OPTIONAL) Tweak proposal counts for small lesions
    # model.rpn.pre_nms_top_n_train  = 4000
    # model.rpn.post_nms_top_n_train = 2000
    # model.rpn.pre_nms_top_n_test   = 2000
    # model.rpn.post_nms_top_n_test  = 1000

    # (WARNING) Run these lines below if different than default numbers of anchors per level are used
    # num_anchors = rpn_anchor_generator.num_anchors_per_location()[0]
    # in_channels = model.backbone.out_channels
    # model.rpn.head = RPNHead(in_channels, num_anchors)

    # Sanity check (prints trainable layers)
    # print("[LAYERS INFO]")
    # for n, p in model.backbone.body.named_parameters():
    #     print(f"{n} trainable = {p.requires_grad}")

    # 3) Replace the detection head to match the DeepLesion's number of classes
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # 4) Patch the first conv layer to accept 1-channel input
    #    (model.backbone.body is the ResNet-50 backbone)
    old_conv = model.backbone.body.conv1 # shape: [out_c, 3, k, k] ( == [out_channels, in_channels, kernel_height, kernel_width])

    new_conv = nn.Conv2d(
        in_channels=1,
        out_channels=old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        dilation=old_conv.dilation,
        bias=False,
    )

    # Initialize 1-channel conv layer using pretrained RGB weights
    with torch.no_grad():
        # Option A: simple average over RGB
        new_conv.weight[:] = old_conv.weight.mean(dim=1, keepdim=True)

        # Option B: luminance-weighted sum to mimic grayscale
        # r = old_conv.weight[:, 0:1, :, :]
        # g = old_conv.weight[:, 1:2, :, :]
        # b = old_conv.weight[:, 2:3, :, :]
        # new_conv.weight[:] = 0.2989 * r + 0.5870 * g + 0.1140 * b

    # A freshly created nn.Conv2d always has requires_grad=True. Keep conv1 frozen (or trainable)
    # exactly as torchvision configured it via trainable_backbone_layers.
    new_conv.weight.requires_grad_(old_conv.weight.requires_grad)

    model.backbone.body.conv1 = new_conv

    # 5) Adjust the model's internal normalization to 1 channel
    # If the loader returns tensors in [0, 1], this centers to roughly ImageNet-like scale.
    # [!] If we already pre-normalize to [0, 1] and don't want extra normalization, use mean = [0.0] and std = [1.0].
    model.transform.image_mean = [0.5]
    model.transform.image_std = [0.5]
    # These lists usually contain 3 values, one per RGB channel.
    # With a single input channel, each list needs only one value.

    # Move model to GPU if available
    model.to(device)

    # Sanity check
    # print(f"First conv layer shape: {model.backbone.body.conv1.weight.shape}") # Should be [64, 1, 7, 7]

    return model

## Data Augmentation - definition of data transformation classes

In [None]:
class ComposeTransform:
    """Chains (image, target) -> (image, target) transforms, applying them in list order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for step in self.transforms:
            result = step(image, target)
            image, target = result
        return image, target
    
class ToTensorTransform:
    """Converts a PIL image to a float tensor [C, H, W] in [0, 1]; the target passes through unchanged."""

    def __call__(self, image, target):
        tensor_image = F.to_tensor(image)
        return tensor_image, target
    
class RandomHorizontalFlipTransform:
    """
    With probability p, mirrors a [C, H, W] image tensor left-right and remaps its
    [x_min, y_min, x_max, y_max] boxes as x' = W - x (min and max swap roles).
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, image, target):
        if not random.random() < self.p:
            return image, target

        _channels, _height, width = image.shape
        flipped = torch.flip(image, dims=[2])  # dim 2 is the width axis

        boxes = target["boxes"]
        if boxes.numel() > 0:
            x_min, y_min, x_max, y_max = (boxes[:, k] for k in range(4))
            target["boxes"] = torch.stack([width - x_max, y_min, width - x_min, y_max], dim=1)

        return flipped, target
    
class RandomBrightnessContrastTransform:
    """
    With probability p, jitters brightness and then contrast of an image tensor by random
    relative factors, and clamps the result to [0, 1]. The target is returned unchanged.
    """

    def __init__(self, brightness=0.1, contrast=0.1, p=0.5):
        """
        Relative change of brightness and contrast.
        brightness=0.1 means a factor drawn from [0.9, 1.1]; the same holds for contrast.
        A value of 0 disables that adjustment.
        """
        self.brightness = brightness
        self.contrast = contrast
        self.p = p

    def __call__(self, image, target):
        if not random.random() < self.p:
            return image, target

        if self.brightness > 0:
            image = F.adjust_brightness(image, 1.0 + random.uniform(-self.brightness, self.brightness))
        if self.contrast > 0:
            image = F.adjust_contrast(image, 1.0 + random.uniform(-self.contrast, self.contrast))

        return image.clamp(0.0, 1.0), target
    
class RandomGaussianNoiseTransform:
    """With probability p, adds zero-mean Gaussian noise (std `sigma`) and clamps to [0, 1]."""

    def __init__(self, sigma=0.01, p=0.5):
        self.sigma = sigma
        self.p = p

    def __call__(self, image, target):
        if not random.random() < self.p:
            return image, target
        noisy = image + torch.randn_like(image) * self.sigma
        return noisy.clamp(0.0, 1.0), target


## Prepare custom Dataset class for Faster R-CNN model

In [None]:
# Custom Dataset class for DeepLesion dataset
class DeepLesionDataset(Dataset):
    """
    DeepLesion detection dataset stored in a YOLO-like directory layout:
        <root>/images/<split>/<name>.png | .jpg
        <root>/labels/<split>/<name>.txt   one "cls x_min y_min x_max y_max" row per box

    Each item is (image, target):
        image:  grayscale image after `self.transforms` (a [1, H, W] float tensor in [0, 1])
        target: {"boxes": FloatTensor[N, 4] (xyxy), "labels": Int64Tensor[N]}
    Images without a label file (or with an empty one) yield boxes of shape [0, 4].
    """

    def __init__(self, root, split):
        # Initialize dataset path and split
        self.root = root
        self.split = split

        # Apply data augmentations only for the train split
        if split == "train":
            self.transforms = ComposeTransform([
                ToTensorTransform(),
                RandomHorizontalFlipTransform(p=0.5),
                RandomBrightnessContrastTransform(brightness=0.1, contrast=0.1, p=0.5),
                RandomGaussianNoiseTransform(sigma=0.01, p=0.5),
            ])
        else:
            self.transforms = ComposeTransform([
                ToTensorTransform(), # Converts [0, 255] uint8 values to float [0.0, 1.0] and preserves the single channel
            ])

        # Dataset logic (image paths, annotations, etc.)
        self.image_dir = os.path.join(root, "images", split)
        self.label_dir = os.path.join(root, "labels", split)
        self.image_names = sorted(
            img for img in os.listdir(self.image_dir) if img.endswith((".png", ".jpg"))
        )

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_dir, image_name)
        label_path = os.path.join(self.label_dir, os.path.splitext(image_name)[0] + ".txt")

        # Load grayscale PIL image; convert() returns a copy, so the file can be closed right away
        with Image.open(image_path) as pil_image:
            image = pil_image.convert("L")

        # Load corresponding bounding boxes and labels
        boxes, labels = [], []
        if os.path.exists(label_path):
            with open(label_path) as label_file:
                for line in label_file:
                    fields = line.split()
                    if not fields:
                        continue  # tolerate blank lines
                    cls, x_min, y_min, x_max, y_max = map(float, fields)
                    boxes.append([x_min, y_min, x_max, y_max])
                    labels.append(int(cls))

        # Create a target dictionary; reshape keeps boxes [0, 4] (not [0]) when there are no boxes,
        # which is the shape torchvision detection models expect
        target = {
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.as_tensor(labels, dtype=torch.int64)
        }

        # Apply transforms
        if self.transforms:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self):
        return len(self.image_names)
    

## Prepare DataLoader objects

### Set up the dataset's split path

In [None]:
# Split directory name, resolved under deeplesion_data_dir when the datasets are built; change the suffix to pick another split
dataset_fasterrcnn_path = "deeplesion_fasterrcnn_split_1" # There are three splits: *_1, *_2 and *_3

### Set up the batch sizes

In [None]:
# Google Colab (run either this cell or the "Local" one below, not both)
# Training: 2-4 images per batch is the sweet spot for two-stage detectors; larger batches may hurt training dynamics.
# Validation/test: a large batch only raises memory usage, it does not change the computed metrics.
train_batch_size, test_val_batch_size = 4, 32

In [None]:
# Local: one image per batch for both training and evaluation
train_batch_size = test_val_batch_size = 1

### Create DataLoader objects

In [None]:
"""
- Shuffling is enabled for the training DataLoader, because SGD benefits from seeing the data in a new random order every epoch.
  During validation and testing we do not need it, since the order does not affect the metrics.

- num_workers is the number of background processes that load & transform batches in parallel. A good rule of thumb is 2-4 workers.

- pin_memory, i.e. pinned (page-locked) host memory, speeds up host-to-GPU copies and lets us use asynchronous transfers.
  It should be set to True when training on a GPU. It usually gives a small but real throughput bump, consumes a bit more system RAM,
  and is useless on CPU-only runs.

- Detection models expect lists of images and lists of target dicts, because each image can have a different size and a different
  number of boxes. The default PyTorch collate tries to stack everything into tensors of the same shape, which breaks for
  variable-length targets. The custom collate_fn below unzips the list of pairs into a pair of sequences so Faster R-CNN can consume them:
    images: Sequence[Tensor[C,H,W]]
    targets: Sequence[Dict{'boxes': Tensor[N,4], 'labels': Tensor[N]}]
  That is exactly what torchvision's detection references use.

"""

train_ds = DeepLesionDataset(deeplesion_data_dir / dataset_fasterrcnn_path, "train")
val_ds = DeepLesionDataset(deeplesion_data_dir / dataset_fasterrcnn_path, "val")
test_ds = DeepLesionDataset(deeplesion_data_dir / dataset_fasterrcnn_path, "test")

def collate_fn(batch):
    """
    Transposes [(img1, target1), (img2, target2), ...] into ((img1, img2, ...), (target1, target2, ...)).

    Images and targets are kept as tuples instead of being stacked, since detection samples vary in size.
    """
    transposed = zip(*batch)
    return tuple(transposed)

# Pinned (page-locked) host memory only speeds up host-to-GPU copies, so enable it only when CUDA is available
train_loader = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=2, pin_memory=torch.cuda.is_available())
val_loader = DataLoader(val_ds, batch_size=test_val_batch_size, shuffle=False, collate_fn=collate_fn, num_workers=2, pin_memory=torch.cuda.is_available())
test_loader = DataLoader(test_ds, batch_size=test_val_batch_size, shuffle=False, collate_fn=collate_fn, num_workers=2, pin_memory=torch.cuda.is_available())

### Training Loop 

In [None]:
"""

- For Faster R-CNN it is common to use SGD or Adam as the optimizer.
- Hyperparameters:
    - momentum:
        adds an exponential moving average of past gradients to the current step, which causes smoother updates,
        less zig-zagging and faster convergence. Typically set to 0.9, and rarely needs tuning.
    - weight_decay (L2 regularization):
        Penalizes large weights to reduce overfitting (shrinks params each step).
        Typical for detection with SGD: 5e-4 or 1e-4
    - step_size (in StepLR) / milestones (in MultiStepLR):
        The epochs at which the LR scheduler triggers a decay (every step_size epochs, or at each milestone).
    - gamma (in StepLR / MultiStepLR):
        Multiplicative LR factor at each decay: new_lr = old_lr * gamma. Commonly set to 0.1.

- Cross-validate only on the following hyperparameters:
    - LR: [0.01, 0.005, 0.002]
    - weight_decay: [5e-4, 1e-4]
    - Epochs -> don't cross-validate over it -> set a generous cap (e.g. 100) and early stop

- use_amp (AMP - Automatic Mixed Precision) - runs many ops in float16 instead of float32, which takes much less GPU memory and is often faster

- autocast() is a context manager that automatically picks a safe dtype per op (keeps numerically sensitive ops in float32, others in float16).
  It saves memory/computation.

- GradScaler multiplies the loss by a large scale before backprop to avoid float16 underflow, then unscales safely before the optimizer step.
  It makes the gradients stable in half precision.

- optimizer.zero_grad(set_to_none=True) - set_to_none=True means that param.grad is set to None for each parameter
  (no tensor is kept). On the next backward() PyTorch allocates a fresh grad tensor and writes into it. This reduces memory traffic
  by avoiding writing zeros over large grad buffers every step, and lowers the memory footprint by letting unused grads be
  garbage-collected and reallocated only when needed.

"""

def train_one_config(
    train_loader, val_loader, device,
    learning_rate, weight_decay, momentum=0.9,
    max_epochs=100, patience=15, metric_key="mAP50_95",
    gamma=0.1, step_size=3,
    milestones=(5, 8, 12), n_warmup_epochs=1,
):
    """
    Train a fresh Faster R-CNN model for one hyperparameter configuration.

    Optimisation is plain SGD with the same LR for head and backbone. The LR is
    ramped linearly per iteration during the first `n_warmup_epochs` epochs. It
    then follows a MultiStepLR schedule stepped once per epoch. After every
    epoch the model is evaluated on `val_loader` with `evaluate_detector`. The
    weights with the best `metric_key` are kept, and training stops early once
    `metric_key` has not improved for `patience` consecutive epochs.

    Args:
        train_loader: DataLoader yielding (images, targets) batches, where targets
            is a sequence of dicts of tensors consumed by the detector.
        val_loader: DataLoader used for the per-epoch validation.
        device: device the model and every batch are moved to.
        learning_rate: base SGD learning rate (reached at the end of warmup).
        weight_decay: SGD weight decay (L2 penalty).
        momentum: SGD momentum.
        max_epochs: maximum number of training epochs.
        patience: number of epochs without improvement before early stopping.
        metric_key: key of the `evaluate_detector` result used for model selection.
        gamma: factor the LR is multiplied by at every milestone.
        step_size: unused. It was only consumed by the (commented-out) StepLR
            schedule and is kept so existing calls keep working.
        milestones: epochs after which the LR is multiplied by `gamma`.
        n_warmup_epochs: length of the linear LR warmup in epochs (0 disables it).

    Returns:
        A dict with the following keys:
        - "best_metric": best validation value of `metric_key`.
        - "best_epoch": epoch at which it was reached.
        - "best_state": CPU copy of the best state_dict, or None if no epoch ran.
        - "history": one dict per epoch with "epoch", "train_loss" and all
          validation metrics.
        - "lr", "weight_decay": the configuration that was trained.
        - "momentum", "gamma", "milestones", "n_warmup_epochs": the remaining
          optimizer/scheduler settings.
    """
    # Construct the model. Explicitly place it next to the data; this is a no-op
    # if construct_fasterrcnn_model already moved it to `device`.
    model = construct_fasterrcnn_model()
    model.to(device)

    # =======================================================
    # Alternative optimizer (different LRs for head and backbone):
    # =======================================================
    # head_lr = learning_rate
    # backbone_lr = learning_rate * 0.1
    # head_params = []
    # backbone_params = []
    # for name, param in model.named_parameters():
    #     if not param.requires_grad:
    #         continue
    #     if name.startswith("backbone.body"):
    #         backbone_params.append(param)
    #     else:
    #         head_params.append(param)
    #
    # optimizer = torch.optim.SGD(
    #     [
    #         {"params": backbone_params, "lr": backbone_lr, "weight_decay": weight_decay},
    #         {"params": head_params, "lr": head_lr, "weight_decay": weight_decay},
    #     ],
    #     momentum=momentum,
    # )
    # =======================================================

    # Set up optimizer (same LR for both head and backbone)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=learning_rate, momentum=momentum, weight_decay=weight_decay)

    # =======================================================
    # Learning rate scheduler (StepLR vs MultiStepLR vs CosineAnnealingLR)
    # =======================================================
    # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=list(milestones), gamma=gamma)
    # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs, eta_min=0.0)
    # =======================================================

    # Creating the scheduler records each param group's starting LR as "initial_lr".
    # Warmup ramps every group towards its own target, so per-group LRs (e.g. the
    # backbone/head split above) survive the warmup phase.
    warmup_targets = [group["initial_lr"] for group in optimizer.param_groups]
    warmup_iters = n_warmup_epochs * len(train_loader)

    best_metric = -float("inf")
    best_epoch = -1
    best_state = None
    epochs_no_improve = 0
    history = []
    global_step = 0

    # Train model
    for epoch in range(1, max_epochs + 1):
        print(f"\n*** Epoch [{epoch}/{max_epochs}] started ***")
        train_loss, global_step = _train_one_epoch(
            model, train_loader, optimizer, device, epoch, max_epochs,
            global_step, warmup_iters, warmup_targets,
        )
        # Epoch-level schedule. During warmup the per-iteration LR written above wins.
        lr_scheduler.step()
        print(f"*** Epoch [{epoch}/{max_epochs}] finished -> Loss: {train_loss:.4f} ***")

        # Validation
        print("*** Validation started ***")
        val_metrics = evaluate_detector(model, val_loader, device, num_classes, class_names)
        val_score = float(val_metrics[metric_key])

        history.append({"epoch": epoch, "train_loss": train_loss, **val_metrics})
        print(f"Loss={train_loss:.4f}, mAP50={val_metrics['mAP50']:.4f}, mAP50_95={val_metrics['mAP50_95']:.4f}")
        print("*** Validation finished ***")

        # Early stopping check. A NaN score compares False, so it never counts as an improvement.
        if val_score > best_metric + 1e-6:
            best_metric = val_score
            best_epoch = epoch
            best_state = _state_dict_cpu_copy(model)
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"*** Early stopping after {epoch}/{max_epochs} epochs (best at {best_epoch} with {metric_key}={best_metric:.4f}). ***")
                break

    return {
        "best_metric": best_metric,
        "best_epoch": best_epoch,
        "best_state": best_state,
        "history": history,
        "lr": learning_rate,
        "weight_decay": weight_decay,
        "momentum": momentum,
        "gamma": gamma,
        "milestones": list(milestones),
        "n_warmup_epochs": n_warmup_epochs,
    }


def _train_one_epoch(model, train_loader, optimizer, device, epoch, max_epochs,
                     global_step, warmup_iters, warmup_targets):
    """Run one training epoch with per-iteration linear LR warmup; return (mean batch loss, updated global_step)."""
    model.train()
    running_loss = 0.0
    n_processed_images = 0
    n_batches = len(train_loader)

    for batch_idx, (images, targets) in enumerate(train_loader, start=1):
        global_step += 1

        # Warmup phase: linearly scale every group's LR from ~0 up to its target.
        if global_step <= warmup_iters:
            warmup_factor = global_step / float(warmup_iters)
            for param_group, target_lr in zip(optimizer.param_groups, warmup_targets):
                param_group["lr"] = target_lr * warmup_factor
            warmup_percent = 100.0 * warmup_factor
        else:  # let the epoch-level scheduler manage the LR
            warmup_percent = 100.0

        images = [image.to(device) for image in images]
        targets = [{key: val.to(device) for key, val in target.items()} for target in targets]

        optimizer.zero_grad(set_to_none=True)
        # In train mode the detector returns a dict of partial losses; optimise their sum.
        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_processed_images += len(images)

        if batch_idx % 10 == 0 or batch_idx == n_batches:
            avg_loss_so_far = running_loss / batch_idx
            status = (
                f"\rEpoch: [{epoch}/{max_epochs}], "
                f"Images: [{n_processed_images}/{len(train_loader.dataset)}], "
                f"Warmup: [{float(warmup_percent):.2f}%], "
                f"Loss: {avg_loss_so_far:.4f}"
            )
            print(status, end="", flush=True)

    print()
    return running_loss / max(1, n_batches), global_step


def _state_dict_cpu_copy(model):
    """Return an independent CPU snapshot of model.state_dict() that later training steps cannot modify."""
    import copy

    state = model.state_dict()
    # A shallow copy keeps the OrderedDict type and its `_metadata`, which
    # load_state_dict uses for version handling. Every tensor is then replaced by
    # a detached CPU copy, so the snapshot holds no extra accelerator memory.
    snapshot = copy.copy(state)
    for key, value in state.items():
        if torch.is_tensor(value):
            snapshot[key] = value.detach().to("cpu", copy=True)
        else:
            snapshot[key] = copy.deepcopy(value)
    return snapshot

def plot_training_history(result, title=None):
    """
    Draw the learning curves of a single training run.

    `result` is the dictionary produced by train_one_config. Every entry of its
    "history" list must provide "epoch", "train_loss", "mAP50" and "mAP50_95".
    The training loss goes on the left y-axis and both mAP curves on a twin
    right y-axis. If `title` is None, a title built from the run's learning rate
    and weight decay is used.
    """
    records = result["history"]
    series = {
        key: [record[key] for record in records]
        for key in ("epoch", "train_loss", "mAP50", "mAP50_95")
    }
    epochs = series["epoch"]

    if title is None:
        title = f"lr={result['lr']}, wd={result['weight_decay']}"

    fig, loss_ax = plt.subplots(figsize=(7, 4))

    # Left axis: training loss
    loss_ax.plot(epochs, series["train_loss"], color="blue", marker="o", label="Train loss")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.grid(True, which="both", axis="both", linestyle="--", alpha=0.3)

    # Right axis (shares the epoch axis): validation mAP curves
    map_ax = loss_ax.twinx()
    map_ax.plot(epochs, series["mAP50"], color="green", marker="x", linestyle="-", label="mAP50")
    map_ax.plot(epochs, series["mAP50_95"], color="red", marker="s", linestyle=":", label="mAP50-95")
    map_ax.set_ylabel("mAP")

    # A single legend listing the curves from both axes
    handles, labels = [], []
    for axis in (loss_ax, map_ax):
        axis_handles, axis_labels = axis.get_legend_handles_labels()
        handles.extend(axis_handles)
        labels.extend(axis_labels)
    loss_ax.legend(handles, labels, loc="best")

    plt.title(title)
    plt.tight_layout()
    plt.show()

# =================================================================================================================================================
# Hyperparameter sweep: train one model per (learning_rate, weight_decay) pair, keep the one with the best validation
# metric, and save it as a self-describing checkpoint.
# =================================================================================================================================================

# Set up hyperparameters
learning_rates = [0.01] # [0.002, 0.005, 0.01, 0.02]
weight_decays = [1e-4] # [1e-4, 5e-4]
max_epochs = 30
patience = 5
metric_key = "mAP50_95"  # validation metric used for early stopping and model selection
best_result = None

print("*** Training started ***")
for learning_rate in learning_rates:
    for weight_decay in weight_decays:
        print(f"[HYPERPARAMETERS]: learning_rate = {learning_rate}, weight_decay = {weight_decay}")
        result = train_one_config(
            train_loader=train_loader, val_loader=val_loader, device=device,
            learning_rate=learning_rate, weight_decay=weight_decay, max_epochs=max_epochs, patience=patience,
            metric_key=metric_key
        )

        plot_training_history(result)

        if best_result is None or result["best_metric"] > best_result["best_metric"]:
            best_result = result

# Fail loudly instead of saving a checkpoint without weights.
if best_result is None or best_result["best_state"] is None:
    raise RuntimeError(
        "No model was trained: check that learning_rates and weight_decays are non-empty and max_epochs > 0."
    )

# Label the plot with the hyperparameters of the *best* run (not the last one trained).
plot_training_history(best_result, title=f"[BEST] lr={best_result['lr']}, wd={best_result['weight_decay']}")
print("*** Training complete ***")
print(f"Best config: lr={best_result['lr']} wd={best_result['weight_decay']} "
      f"epoch={best_result['best_epoch']} {metric_key}={best_result['best_metric']:.4f}")

# Save the best model checkpoint. The "hp" entry records the schedule that was
# actually used (MultiStepLR with milestones, not StepLR with step_size).
best_checkpoint = {
    "state_dict": best_result["best_state"],
    "epoch": best_result["best_epoch"],
    "metric_key": metric_key,
    "metric_value": best_result["best_metric"],
    "hp": {
        "lr": best_result["lr"],
        "weight_decay": best_result["weight_decay"],
        "momentum": best_result.get("momentum", 0.9),
        "scheduler": "MultiStepLR",
        "milestones": list(best_result.get("milestones", [5, 8, 12])),
        "gamma": best_result.get("gamma", 0.1),
        "n_warmup_epochs": best_result.get("n_warmup_epochs", 1),
        "max_epochs": max_epochs,
        "patience": patience,
    },
    # Describe how to reconstruct the model
    "model_spec": {
        "arch": "fasterrcnn_resnet50_fpn_v2",
        "min_size": 512,
        "max_size": 512,
        "in_channels": 1,
        "num_classes": num_classes,
        "image_mean": [0.5],
        "image_std": [0.5],
    },
    "class_names": class_names,
    "versions": {"torch": torch.__version__, "torchvision": torchvision.__version__},
}

# Save checkpoint locally
os.makedirs("checkpoints", exist_ok=True)
ts = datetime.now().strftime("%Y%m%d-%H%M%S")
save_path = f"checkpoints/fasterrcnn_best_{ts}.pt"
torch.save(best_checkpoint, save_path)
print(f"[{ts}] Saved best checkpoint to {save_path}")

# [GOOGLE COLAB ONLY - uncomment] Save checkpoint on Google Drive
# gdrive_save_dir = Path("faster_rcnn_checkpoints")
# gdrive_save_dir.mkdir(parents=True, exist_ok=True)
# gdrive_save_path = os.path.join(str(gdrive_save_dir), f"fasterrcnn_best_{ts}.pt")
# torch.save(best_checkpoint, gdrive_save_path)

### Final Evaluation

In [None]:
# Load best weights and evaluate on the held-out test set.
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python objects -
# only load checkpoints produced by this notebook or another trusted source.
best_checkpoint = torch.load(save_path, map_location="cpu", weights_only=False)
best_model = construct_fasterrcnn_model()
best_model.load_state_dict(best_checkpoint["state_dict"])
# Put the model next to the data (a no-op if construct_fasterrcnn_model already did)
# and switch it to inference mode before evaluating.
best_model.to(device)
best_model.eval()

# Evaluate the best model
print("*** Evaluation started ***")
test_metrics = evaluate_detector(best_model, test_loader, device, num_classes, class_names)

print("*** Evaluation finished ***")
print(test_metrics)
print_result_report(test_metrics, test_loader, class_names)
print_froc_curve_info(best_model, test_loader, device)

# Disconnect Google Colab runtime

In [None]:
# Release this Colab runtime once training and evaluation are finished.
# NOTE(review): unassigning presumably discards the VM's local disk, including the local
# "checkpoints/" copy saved by the training cell - enable the Google Drive save there first
# if the checkpoint must be kept.
from google.colab import runtime
runtime.unassign()