# DETR Lesion Detector

This notebook is dedicated to training and evaluating a lesion detector on the DeepLesion dataset with the following supervised object detection architectures, both with a ResNet-50 backbone:
- DETR (Facebook DETR),
- Deformable DETR.

## Assumptions:
- Use 2D slice inputs (optionally use the neighbouring ones too),
- Resize all images to 512x512,
- Use a YOLO-style Dataset class, with a slight modification (bbox coordinates are [x1, y1, x2, y2]),
- Use DeepLesion for training a general lesion localizer, and another dataset such as LiTS (Liver Tumor Segmentation) or CHAOS (CT liver dataset) for a more specialized localizer.
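
Resizing every slice to 512x512 means the box coordinates have to be rescaled together with the image. Below is a minimal sketch, assuming images are `[C, H, W]` tensors and boxes are absolute `[x1, y1, x2, y2]` pixel coordinates. `resize_with_boxes` is a hypothetical helper, and the notebook's own preprocessing may have done this step differently:

```python
import torch
import torch.nn.functional as nnF  # aliased to avoid clashing with the notebook's F (torchvision.transforms.functional)

def resize_with_boxes(image: torch.Tensor, boxes: torch.Tensor, size: int = 512):
    """Hypothetical helper: resize a [C, H, W] image to size x size and rescale [x1, y1, x2, y2] boxes."""
    _, h, w = image.shape
    # interpolate() expects a batch dimension: [1, C, H, W]
    resized = nnF.interpolate(image.unsqueeze(0), size=(size, size), mode="bilinear", align_corners=False).squeeze(0)
    # x coordinates scale with the width, y coordinates with the height
    scale = torch.tensor([size / w, size / h, size / w, size / h], dtype=boxes.dtype)
    return resized, boxes * scale

# Usage: a non-square slice with one lesion box
image = torch.rand(1, 256, 1024)
boxes = torch.tensor([[100.0, 50.0, 300.0, 150.0]])
resized, scaled = resize_with_boxes(image, boxes)
print(resized.shape)  # torch.Size([1, 512, 512])
print(scaled)         # boxes rescaled to [50, 100, 150, 300]
```

Because x and y are scaled independently, a non-square slice is stretched rather than letterboxed.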

## 📚 Thesis Value Summary
### Contribution and Value:
- Comparison of 1-stage vs 2-stage vs Transformer vs legacy detectors on DeepLesion -> ✅ Fills a gap in the literature
- Evaluation of improved DETRs (DINO/Deformable) -> ✅ Modern insight
- General vs specialized lesion detection -> ✅ Strong clinical relevance
- Analysis of training time, robustness, failure modes -> ✅ Engineering depth


# Google Colab only

### Download required packages

In [None]:
!pip install -r https://raw.githubusercontent.com/pmalesa/lesion_detector/main/notebooks/requirements.txt

### Mount DeepLesion images and checkpoints from Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')
%cd /content

# Remove existing links, if any
!rm -rf data/deeplesion
!rm -rf detr_checkpoints

!mkdir -p data
!ln -s /content/drive/MyDrive/deeplesion/data/deeplesion data/deeplesion
!ln -s /content/drive/MyDrive/deeplesion/checkpoints/detr detr_checkpoints
!ls -l data

# Import all packages

In [None]:
# General packages
import os
import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray
from typing import Any
from PIL import Image
import random
from pathlib import Path
from datetime import datetime

# PyTorch / torchvision packages
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision import transforms
import torchvision.transforms as T
import torchvision.transforms.functional as F
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_V2_Weights
from torchmetrics.detection.mean_ap import MeanAveragePrecision
import torchvision.ops as ops
from torchvision.ops import box_iou

# DETR packages
from transformers import DetrForObjectDetection, DetrImageProcessor

# Set paths to DeepLesion images and metadata

## Paths to unprocessed data

In [None]:
# Google Colab
deeplesion_metadata_path = Path("data/deeplesion/deeplesion_metadata.csv")
deeplesion_image_path = Path("data/deeplesion/key_slices/")

In [None]:
# Local
deeplesion_metadata_path = Path("../data/deeplesion_metadata.csv")
deeplesion_image_path = Path("../data/deeplesion/key_slices/")

## Paths to processed data

In [None]:
deeplesion_data_dir = Path("data/deeplesion/")
deeplesion_preprocessed_image_path = deeplesion_data_dir / "deeplesion_preprocessed_uint8/key_slices"
deeplesion_preprocessed_metadata_path = deeplesion_data_dir / "deeplesion_preprocessed_uint8/deeplesion_metadata_preprocessed.csv"

# Evaluation Functions

## FROC Curve Computation

In [None]:
@torch.no_grad()
def compute_froc_curve(model, loader, device, iou_thr=0.5, score_thresholds=None):
    """
    Computes FROC curve: sensitivity (recall) vs. FP per image.
    Returns:
        fp_per_image:     np.ndarray of shape [T]
        sensitivity:      np.ndarray of shape [T]
        score_thresholds: np.ndarray of shape [T] 
    """

    model.eval()

    if score_thresholds is None:
        # 0.0 ... 1.0, 101 points
        score_thresholds = np.linspace(0.0, 1.0, 101)
    
    # Collect all predictions and GTs first to avoid calling model many times
    all_image_preds = []
    all_image_gts = []
    n_images = 0
    n_gt_total = 0

    for images, targets in loader:
        images = [img.to(device) for img in images]
        outputs = model(images)

        for output, target in zip(outputs, targets):
            # predictions
            boxes_pred = output["boxes"].detach().cpu()
            scores_pred = output["scores"].detach().cpu()

            # ground truth (class agnostic)
            boxes_gt = target["boxes"].detach().cpu()
            all_image_preds.append((boxes_pred, scores_pred))
            all_image_gts.append(boxes_gt)
            n_images += 1
            n_gt_total += boxes_gt.shape[0]
    
    score_thresholds = np.array(score_thresholds, dtype=np.float32)
    fp_per_image = np.zeros_like(score_thresholds)
    sensitivity = np.zeros_like(score_thresholds)

    # For each threshold, count TP/FP over the whole dataset
    for i, thr in enumerate(score_thresholds):
        TP = 0
        FP = 0

        for (boxes_pred, scores_pred), boxes_gt in zip(all_image_preds, all_image_gts):
            # Filter predictions by score threshold
            keep = scores_pred >= thr
            boxes_p = boxes_pred[keep]
            scores_p = scores_pred[keep]

            if boxes_gt.numel() == 0:
                # No GT lesions in this image: all predictions are FPs
                FP += boxes_p.shape[0]
                continue

            if boxes_p.numel() == 0:
                # No predictions above threshold, but GTs exist -> all missed (FN)
                continue

            # Sort preds by score (descending) for greedy matching
            order = torch.argsort(scores_p, descending=True)
            boxes_p = boxes_p[order]

            ious = box_iou(boxes_p, boxes_gt) # [N_pred, N_gt]
            matched_gt = torch.zeros(boxes_gt.shape[0], dtype=torch.bool)

            for p_idx in range(boxes_p.shape[0]):
                # Best-matching GT among those not yet claimed by a higher-scoring prediction
                iou_vals = ious[p_idx].masked_fill(matched_gt, -1.0)
                best_iou, best_gt_idx = iou_vals.max(0)

                if best_iou >= iou_thr:
                    TP += 1
                    matched_gt[best_gt_idx] = True
                else:
                    FP += 1
        
        sensitivity[i] = TP / max(1, n_gt_total)
        fp_per_image[i] = FP / max(1, n_images)

    return fp_per_image, sensitivity, score_thresholds


def sensitivity_at_fp(fp_per_image, sensitivity, fp_targets):
    """
    Interpolate sensitivity at target FP/image values.
    fp_per_image, sensitivity: np arrays from compute_froc_curve()
    fp_targets: list or array of target FP/image values, e.g. [0.5, 1, 2, 4]
    
    Returns dict: {fp_target: sensitivity_value}
    """

    fp_per_image = np.asarray(fp_per_image)
    sensitivity = np.asarray(sensitivity)
    fp_targets = np.asarray(fp_targets, dtype=np.float32)

    # Ensure fp_per_image is sorted ascending
    order = np.argsort(fp_per_image)
    fp_sorted = fp_per_image[order]
    sens_sorted = sensitivity[order]


    # Use numpy interpolation
    sens_at = np.interp(fp_targets, fp_sorted, sens_sorted, left=0.0, right=sens_sorted[-1])

    return {float(fp): float(s) for fp, s in zip(fp_targets, sens_at)}


def plot_froc_curve(fp_per_image, sensitivity):
    """
    Plots the full FROC curve.

    fp_per_image, sensitivity: np arrays from compute_froc_curve()
    """

    plt.figure(figsize=(6, 4))
    plt.plot(fp_per_image, sensitivity, color="red", marker="o")
    plt.xlabel("False positives per image")
    plt.ylabel("Sensitivity (recall)")
    plt.grid(True, linestyle="--", alpha=0.3)
    plt.title("FROC curve (class-agnostic lesions)")
    plt.tight_layout()
    plt.show()


def print_froc_curve_info(model, loader, device):
    fp_curve, sens_curve, thr = compute_froc_curve(model, loader, device, iou_thr=0.5)
    targets = [0.5, 1.0, 2.0, 4.0, 5.0, 8.0]
    sens_dict = sensitivity_at_fp(fp_curve, sens_curve, targets)
    plot_froc_curve(fp_curve, sens_curve)
    print("FROC (class-agnostic):")
    for fp, s in sens_dict.items():
        print(f"Sensitivity at {fp} FP/image: {s*100:.2f}%")


## Mean Average Precision

In [None]:
@torch.no_grad()
def evaluate_detector(model, loader, device, num_classes=8, class_names=None, pr_conf_thr=0.25, pr_iou_thr=0.5):
    """
        Returns a dictionary:
        {
            "mAP50": float,
            "mAP50_95": float,
            "per_class": [{"idx": ..., "name": ..., "AP": ..., "AP50": ...}, ...],
            "precision_overall": float,
            "recall_overall": float,
            "precision_per_class": [...],
            "recall_per_class": [...],
            "pr_conf_thr": float,
            "pr_iou_thr": float,
        }
    """

    model.eval()
    metric_all = MeanAveragePrecision(iou_type="bbox", box_format="xyxy", class_metrics=True)
    metric_50  = MeanAveragePrecision(iou_type="bbox", box_format="xyxy", class_metrics=True, iou_thresholds=[0.5])

    # For precision and recall at fixed thresholds (IoU=0.5, conf=pr_conf_thr)
    TP = torch.zeros(num_classes, dtype=torch.long)
    FP = torch.zeros(num_classes, dtype=torch.long)
    FN = torch.zeros(num_classes, dtype=torch.long)
    n_processed_images = 0

    for batch_idx, (images, targets) in enumerate(loader, start=1):
        images = [image.to(device) for image in images]
        outputs = model(images)
        n_processed_images += len(images)

        # Move to CPU for metrics
        predictions, ground_truths = [], []
        for output, target in zip(outputs, targets):
            predictions.append({"boxes": output["boxes"].cpu(),
                                "scores": output["scores"].cpu(),
                                "labels": output["labels"].cpu()})
            ground_truths.append({"boxes": target["boxes"].cpu(),
                                  "labels": target["labels"].cpu()})
            
        # mAP update
        metric_all.update(predictions, ground_truths)
        metric_50.update(predictions, ground_truths)

        # Precision and recall accumulation at fixed thresholds
        for output, target in zip(predictions, ground_truths):
            # Filter predictions by confidence
            keep = output["scores"] >= pr_conf_thr
            scores = output["scores"][keep]
            order = torch.argsort(scores, descending=True)
            pred_boxes = output["boxes"][keep][order] # Reorder to the same order as scores
            pred_labels = output["labels"][keep][order] # Reorder to the same order as scores
            gt_boxes = target["boxes"]
            gt_labels = target["labels"]

            matched = torch.zeros(len(gt_boxes), dtype=torch.bool)
            ious = ops.box_iou(pred_boxes, gt_boxes) if len(pred_boxes) and len(gt_boxes) else None
            for pred_idx in range(len(pred_boxes)):
                cls = int(pred_labels[pred_idx].item()) # classes are 1...K
                # Candidates: ground truths of the same class that are not matched yet
                candidates = (gt_labels == cls) & ~matched
                if ious is not None and candidates.any():
                    best_iou, best_loc = ious[pred_idx, candidates].max(0)
                    if best_iou >= pr_iou_thr:
                        TP[cls - 1] += 1
                        matched[torch.where(candidates)[0][best_loc]] = True
                        continue
                # No matching ground truth (including images without lesions): false positive
                FP[cls - 1] += 1
            
            # Any unmatched ground truths are FN
            for gt_idx, gt_label in enumerate(gt_labels):
                if not matched[gt_idx]:
                    FN[int(gt_label.item()) - 1] += 1

        if batch_idx % 10 == 0 or batch_idx == len(loader):
            print(f"\r[{n_processed_images}/{len(loader.dataset)}] images validated.", end="", flush=True)

    print() # print a new line

    # mAP metrics
    res_all = metric_all.compute()
    res_50  = metric_50.compute()

    out = {
        "mAP50": float(res_50["map"]),
        "mAP50_95": float(res_all["map"]),
    }

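    # Per-class AP (if available). torchmetrics reports per-class values only for the classes it observed
    # (listed in res["classes"]), so map them back to class names by class id rather than by position.
    per_class = []
    map_per_class = res_all.get("map_per_class", None)
    map50_per_class = res_50.get("map_per_class", None)
    if map_per_class is not None and map50_per_class is not None:
        ap        = torch.atleast_1d(map_per_class).tolist()
        ap50      = torch.atleast_1d(map50_per_class).tolist()
        class_ids = torch.atleast_1d(res_all.get("classes", torch.arange(1, len(ap) + 1))).tolist()
        for cls_id, ap_c, ap50_c in zip(class_ids, ap, ap50):
            cls_id = int(cls_id)
            name = class_names[cls_id - 1] if class_names and 1 <= cls_id <= len(class_names) else f"class_{cls_id}"
            per_class.append({"idx": cls_id, "name": name, "AP": _nan_if_undefined(ap_c), "AP50": _nan_if_undefined(ap50_c)})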

    out["per_class"] = per_class

    # Precision/Recall at fixed thresholds
    precision_per_class = (TP.float() / (TP + FP).clamp(min=1)).tolist()
    recall_per_class    = (TP.float() / (TP + FN).clamp(min=1)).tolist()
    overall_precision   = float(TP.sum() / (TP.sum() + FP.sum()).clamp(min=1))
    overall_recall      = float(TP.sum() / (TP.sum() + FN.sum()).clamp(min=1))

    out["precision_overall"]   = overall_precision
    out["recall_overall"]      = overall_recall
    out["precision_per_class"] = precision_per_class
    out["recall_per_class"]    = recall_per_class
    out["pr_conf_thr"]         = pr_conf_thr
    out["pr_iou_thr"]          = pr_iou_thr

    return out

# =================================================================================================================================================
# =================================================================================================================================================

def _count_instances_per_class(dataset, num_classes):
    """
    Function that counts the ground truth instances per class by reading the label .txt files on disk.
    """

    counts = [0] * num_classes
    total = 0

    for img_name in dataset.image_names:
        label_path = os.path.join(dataset.label_dir, Path(img_name).stem + ".txt")
        if not os.path.exists(label_path):
            continue
        with open(label_path, "r") as file:
            for line in file:
                parts = line.strip().split()
                if len(parts) != 5: # Malformed label line, skip it
                    continue
                cls = int(float(parts[0]))
                # Labels are 1...K (background is implicit, as in Faster R-CNN); map them to 0...K-1 list indices
                if 1 <= cls <= num_classes:
                    counts[cls - 1] += 1
                    total += 1
    return counts, total

def _to_float(x, default=float("nan")):
    if x is None:
        return default
    if isinstance(x, torch.Tensor):
        if x.numel() == 0:
            return default
        x = x.detach().cpu().item() if x.ndim == 0 else x.detach().cpu().numpy()
    if isinstance(x, np.ndarray):
        return float(x.item()) if x.size == 1 else default
    try:
        return float(x)
    except Exception:
        return default
    
def _to_int(x, default=0):
    try:
        return int(x)
    except Exception:
        return default
    
def _nan_if_undefined(x):
    return float("nan") if x is None or (isinstance(x, (float, int)) and x < 0) else float(x)

def print_result_report(metrics, loader, class_names):
    """
    Function that prints a formatted report with evaluation metrics.
    Uses dataset files to compute number of images and instances.
    """

    num_classes = len(class_names)
    images      = _to_int(len(loader.dataset))
    per_class   = metrics.get("per_class", [])
    p_overall   = _to_float(metrics["precision_overall"])
    r_overall   = _to_float(metrics["recall_overall"])
    map50       = _to_float(metrics["mAP50"])
    map50_95    = _to_float(metrics["mAP50_95"])

    # Count instances per class from labels
    counts, total_instances = _count_instances_per_class(loader.dataset, num_classes)

    # Build quick dicts for per-class AP50/AP
    ap50_by_name = {d['name']: d['AP50'] for d in per_class}
    ap_by_name = {d['name']: d['AP'] for d in per_class}

    # Header
    print(f"{'Class':>18} {'Images':>8} {'Instances':>10} {'P':>10} {'R':>10} {'mAP50':>10} {'mAP50-95':>10}")

    # Overall row ("all")
    print(f"{'all':>18} {images:8d} {_to_int(total_instances):10d} {p_overall:10.3f} {r_overall:10.3f} {map50:10.3f} {map50_95:10.3f}")

    # Per-class rows
    p_pc = metrics.get("precision_per_class", [])
    r_pc = metrics.get("recall_per_class", [])

    for i, name in enumerate(class_names):
        P_i = _to_float(p_pc[i] if i < len(p_pc) else float("nan"))
        R_i = _to_float(r_pc[i] if i < len(r_pc) else float("nan"))
        AP50_i = _to_float(ap50_by_name.get(name, float("nan")))
        AP_i = _to_float(ap_by_name.get(name, float("nan")))
        inst_i = _to_int(counts[i])
        print(f"{name:>18} {images:8d} {inst_i:10d} {P_i:10.3f} {R_i:10.3f} {AP50_i:10.3f} {AP_i:10.3f}")
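
The precision/recall bookkeeping at the end of `evaluate_detector` can be checked by hand on toy counts. A minimal sketch with hypothetical TP/FP/FN values for three classes; `nan_if_undefined` mirrors `_nan_if_undefined` and assumes, as that helper does, that a negative AP means "undefined" (COCO-style evaluators typically report -1, for example for a class without ground truth):

```python
import torch

# Hypothetical per-class counts for K = 3 classes at fixed IoU/confidence thresholds
TP = torch.tensor([8, 0, 5])
FP = torch.tensor([2, 0, 5])
FN = torch.tensor([2, 4, 0])

# clamp(min=1) avoids 0/0: a class that was never predicted gets precision 0.0 instead of NaN
precision = (TP.float() / (TP + FP).clamp(min=1)).tolist()
recall = (TP.float() / (TP + FN).clamp(min=1)).tolist()
overall_p = float(TP.sum() / (TP.sum() + FP.sum()).clamp(min=1))
overall_r = float(TP.sum() / (TP.sum() + FN.sum()).clamp(min=1))

def nan_if_undefined(x):
    """Treat None or negative values (the -1 'undefined' marker) as NaN."""
    return float("nan") if x is None or x < 0 else float(x)

print(precision, recall)   # roughly [0.8, 0.0, 0.5] and [0.8, 0.0, 1.0]
print(overall_p, overall_r)
print(nan_if_undefined(-1.0), nan_if_undefined(0.42))
```

Class 2 has no predictions at all, so its precision is shown as 0.0 rather than NaN. This keeps the report free of NaNs but makes a never-predicted class look identical to a class whose predictions were all wrong.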


## Data Augmentation: definition of data transformation classes

In [None]:
class ComposeTransform:
    """Compose for (image, target) pairs."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target
    
class ToTensorTransform:
    """Convert PIL image to tensor, leave target unchanged"""
    def __call__(self, image, target):
        image = F.to_tensor(image) # [C, H, W], float32 in [0, 1]
        return image, target
    
class RandomHorizontalFlipTransform:
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, image, target):
        if random.random() < self.p:
            # image: [1, H, W] or [C, H, W]
            _, h, w = image.shape
            image = torch.flip(image, dims=[2]) # flip width dimension

            boxes = target["boxes"]
            if boxes.numel() > 0:
                # boxes: [N, 4] in [x_min, y_min, x_max, y_max]
                x_min = boxes[:, 0]
                y_min = boxes[:, 1]
                x_max = boxes[:, 2]
                y_max = boxes[:, 3]

                # flip x-coordinates: x' = w - x
                new_x_min = w - x_max
                new_x_max = w - x_min

                boxes = torch.stack([new_x_min, y_min, new_x_max, y_max], dim=1)
                target["boxes"] = boxes

        return image, target
    
class RandomBrightnessContrastTransform:
    def __init__(self, brightness=0.1, contrast=0.1, p=0.5):
        """
        Relative change of brightness and contrast.
        brightness=0.1 means factor in [0.9, 1.1], etc. 
        """
        self.brightness = brightness
        self.contrast = contrast
        self.p = p

    def __call__(self, image, target):
        if random.random() < self.p:
            # image in tensor [C, H, W]
            # Random brightness
            if self.brightness > 0:
                factor = 1.0 + random.uniform(-self.brightness, self.brightness)
                image = F.adjust_brightness(image, factor)
            # Random contrast
            if self.contrast > 0:
                factor = 1.0 + random.uniform(-self.contrast, self.contrast)
                image = F.adjust_contrast(image, factor)
            image = image.clamp(0.0, 1.0)
        return image, target
    
class RandomGaussianNoiseTransform:
    def __init__(self, sigma=0.01, p=0.5):
        self.sigma = sigma
        self.p = p

    def __call__(self, image, target):
        if random.random() < self.p:
            noise = torch.randn_like(image) * self.sigma
            image = image + noise
            image = image.clamp(0.0, 1.0)
        return image, target
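
The box update in `RandomHorizontalFlipTransform` relies on x' = W - x. A small sanity check of that formula, assuming box coordinates are pixel edges (a box spanning `[x_min, x_max)` covers columns `x_min` to `x_max - 1`), which is what makes it agree exactly with `torch.flip` on the image. `hflip_boxes` is a hypothetical standalone version of the transform's box logic:

```python
import torch

def hflip_boxes(boxes: torch.Tensor, width: int) -> torch.Tensor:
    """Mirror [x_min, y_min, x_max, y_max] boxes horizontally: x' = W - x."""
    x_min, y_min, x_max, y_max = boxes.unbind(dim=1)
    # The old x_max becomes the new x_min (and vice versa), so x_min <= x_max still holds
    return torch.stack([width - x_max, y_min, width - x_min, y_max], dim=1)

# Toy slice of width 512 with a "lesion" filling exactly the box [10, 20, 110, 60]
image = torch.zeros(1, 64, 512)
image[:, 20:60, 10:110] = 1.0
boxes = torch.tensor([[10.0, 20.0, 110.0, 60.0]])

flipped_image = torch.flip(image, dims=[2])    # same image operation as the transform
flipped_boxes = hflip_boxes(boxes, width=512)  # [402, 20, 502, 60]

x0, y0, x1, y1 = flipped_boxes[0].int().tolist()
# The flipped box covers every flipped lesion pixel and nothing else
covers_all = bool(flipped_image[:, y0:y1, x0:x1].all())
nothing_else = float(flipped_image.sum()) == float(flipped_image[:, y0:y1, x0:x1].sum())
print(covers_all, nothing_else)
```

Flipping the boxes a second time returns the original coordinates, which is a cheap property to keep in a unit test for the augmentation.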


## Prepare DeepLesion dataset for DETR model

In [None]:
# Custom Dataset class for DeepLesion dataset
class DeepLesionDataset(Dataset):
    def __init__(self, root, split):
        # Initialize dataset path, split and transformations
        self.root = root
        self.split = split

        # Apply data augmentations only for the train split
        if split == "train":
            self.transforms = ComposeTransform([
                ToTensorTransform(),
                RandomHorizontalFlipTransform(p=0.5),
                RandomBrightnessContrastTransform(brightness=0.1, contrast=0.1, p=0.5),
                RandomGaussianNoiseTransform(sigma=0.01, p=0.5),
            ])
        else:
            self.transforms = ComposeTransform([
                ToTensorTransform(), # Converts [0, 255] uint8 values to float [0.0, 1.0] and preserves the single channel
            ])

        # Dataset logic (image paths, annotations, etc.)
        self.image_dir = os.path.join(root, "images", split)
        self.label_dir = os.path.join(root, "labels", split)
        self.image_names = sorted([img for img in os.listdir(self.image_dir) if img.endswith(".png") or img.endswith(".jpg")])

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_dir, image_name)
        label_path = os.path.join(self.label_dir, os.path.splitext(image_name)[0] + ".txt")

        # Load grayscale PIL image
        image = Image.open(image_path).convert("L")

        # Load corresponding bounding boxes and labels
        boxes, labels = [], []
        if os.path.exists(label_path):
            with open(label_path) as file:
                for line in file:
                    if not line.strip(): # Skip blank lines
                        continue
                    cls, x_min, y_min, x_max, y_max = map(float, line.split())
                    boxes.append([x_min, y_min, x_max, y_max])
                    labels.append(int(cls))

        # Create a target dictionary (reshape keeps boxes as [0, 4] for images without lesions)
        target = {
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.as_tensor(labels, dtype=torch.int64)
        }
        
        # Apply transforms
        if self.transforms:
            image, target = self.transforms(image, target)
        
        return image, target
    
    def __len__(self):
        return len(self.image_names)
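
Each label file is expected to hold one `cls x_min y_min x_max y_max` line per lesion. One edge case is worth handling explicitly: for an image with no lesion lines, `torch.as_tensor([])` has shape `[0]` rather than `[0, 4]`, which code that indexes `boxes[:, 0]` (for example `torchvision.ops.box_iou`) does not accept. A minimal sketch of a parser that always returns `[N, 4]` boxes; `parse_label_lines` is a hypothetical helper that takes the lines of one label file:

```python
import torch

def parse_label_lines(lines):
    """Parse 'cls x_min y_min x_max y_max' lines; boxes are always returned with shape [N, 4]."""
    boxes, labels = [], []
    for line in lines:
        parts = line.split()
        if len(parts) != 5:  # skip blank or malformed lines
            continue
        cls, *coords = map(float, parts)
        boxes.append(coords)
        labels.append(int(cls))
    boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
    labels = torch.as_tensor(labels, dtype=torch.int64)
    return boxes, labels

print(torch.as_tensor([], dtype=torch.float32).shape)  # torch.Size([0])
boxes, labels = parse_label_lines([])
print(boxes.shape, labels.shape)                       # torch.Size([0, 4]) torch.Size([0])
boxes, labels = parse_label_lines(["3 10 20 110 60\n", "\n"])
print(boxes.shape, labels.tolist())                    # torch.Size([1, 4]) [3]
```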
    

## Prepare DataLoader objects

### Set up the dataset's split path

In [None]:
# [!] Same data splits as for Faster R-CNN
deeplesion_detr_path = "deeplesion_fasterrcnn_split_1" # There are three splits: *_1, *_2 and *_3

### Set up the batch sizes

In [None]:
# Google Colab
train_batch_size = 4       # Set to 4 (2-4 is a sweet spot for two-stage detectors such as Faster R-CNN; higher values may hurt training dynamics)
test_val_batch_size = 32   # Set to 32 (a higher value does not affect the metrics, only memory usage)

In [None]:
# Local
train_batch_size = 1
test_val_batch_size = 1

In [None]:
"""
- Shuffling is enabled for the training DataLoader, because SGD-style optimizers benefit from seeing the data in a new random order every epoch.
  During validation and testing we do not need it, since the order does not affect the metrics.

- num_workers is the number of background processes that load and transform batches in parallel. A good rule of thumb is 2-4.

- pin_memory (pinned, page-locked host memory) speeds up host-to-GPU copies and enables asynchronous transfers.
  It should be set to True when training on a GPU. It usually gives a small but real throughput bump, consumes a bit more system RAM,
  and is useless on CPU-only runs.

- Detection models expect lists of images and lists of target dicts, because each image can have a different size and a different
  number of boxes. The default PyTorch collate function tries to stack everything into tensors of the same shape, which breaks for
  variable-length targets. The custom collate_fn below unzips the list of pairs into a pair of lists that detection models can consume:
    images: List[Tensor[C,H,W]]
    targets: List[Dict{'boxes': Tensor[N,4], 'labels': Tensor[N]}]
  This is exactly what torchvision's detection references use.

"""

train_ds = DeepLesionDataset(deeplesion_data_dir / deeplesion_detr_path, "train")
val_ds = DeepLesionDataset(deeplesion_data_dir / deeplesion_detr_path, "val")
test_ds = DeepLesionDataset(deeplesion_data_dir / deeplesion_detr_path, "test")

def collate_fn(batch):
    # batch: [(img1, target1), (img2, target2), ...]
    # returns: ([img1, img2, ...], [target1, target2, ...])
    return tuple(zip(*batch))

train_loader = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=test_val_batch_size, shuffle=False, collate_fn=collate_fn, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=test_val_batch_size, shuffle=False, collate_fn=collate_fn, num_workers=2, pin_memory=True)
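
To see why the custom `collate_fn` is needed, here is a toy batch of two samples with different numbers of lesions, collated both ways. The tensors are made up; `default_collate` is PyTorch's built-in collate function, which the DataLoader uses when no `collate_fn` is given:

```python
import torch
from torch.utils.data import default_collate

def collate_fn(batch):
    # [(img1, target1), (img2, target2)] -> ((img1, img2), (target1, target2))
    return tuple(zip(*batch))

# Two toy samples: the first slice has two lesions, the second has none
batch = [
    (torch.zeros(1, 512, 512), {"boxes": torch.zeros(2, 4), "labels": torch.tensor([1, 3])}),
    (torch.zeros(1, 512, 512), {"boxes": torch.zeros(0, 4), "labels": torch.zeros(0, dtype=torch.int64)}),
]

images, targets = collate_fn(batch)
print(len(images), [t["boxes"].shape[0] for t in targets])  # 2 [2, 0]

# The default collate tries to torch.stack() the per-image boxes, which fails for [2, 4] vs [0, 4]
try:
    default_collate(batch)
    default_failed = False
except RuntimeError:
    default_failed = True
print("default_collate failed:", default_failed)
```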

# DETR

### Training Loop 

In [None]:
# Instantiate DETR and its image processor with COCO-pretrained weights for the entire model
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
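
`DetrForObjectDetection` in Hugging Face Transformers takes `pixel_values` and, for training, `labels` given as one dict per image with `class_labels` and `boxes`, where boxes are normalized `(center_x, center_y, width, height)`. The dataset above yields grayscale `[1, H, W]` tensors with absolute `[x_min, y_min, x_max, y_max]` boxes and labels 1...K, so a conversion step is needed. A minimal sketch: `to_detr_inputs` is a hypothetical helper, and both the 1...K to 0...K-1 label shift and the channel repetition are assumptions about how this notebook will feed DETR. It also skips the ImageNet mean/std normalization that `DetrImageProcessor` would normally apply. Separately, the `facebook/detr-resnet-50` checkpoint has a COCO classification head; passing `num_labels=8` and `ignore_mismatched_sizes=True` to `from_pretrained` is one way to re-initialize it for the eight lesion types.

```python
import torch

def to_detr_inputs(image: torch.Tensor, target: dict):
    """Hypothetical adapter from this notebook's (image, target) pairs to DETR-style inputs.
    Assumes boxes have shape [N, 4] in absolute [x_min, y_min, x_max, y_max] pixels."""
    _, h, w = image.shape
    # The ResNet-50 backbone expects 3 channels, so repeat the grayscale channel
    pixel_values = image.expand(3, -1, -1) if image.shape[0] == 1 else image
    x_min, y_min, x_max, y_max = target["boxes"].unbind(dim=1)
    # Absolute corners -> normalized (center_x, center_y, width, height)
    cxcywh = torch.stack([(x_min + x_max) / 2 / w, (y_min + y_max) / 2 / h,
                          (x_max - x_min) / w, (y_max - y_min) / h], dim=1)
    # Assumption: dataset classes 1...K map to DETR class ids 0...K-1
    labels = {"class_labels": target["labels"] - 1, "boxes": cxcywh}
    return pixel_values, labels

# Usage on a toy 512x512 slice with one lesion of class 2
pixel_values, labels = to_detr_inputs(
    torch.rand(1, 512, 512),
    {"boxes": torch.tensor([[100.0, 200.0, 300.0, 260.0]]), "labels": torch.tensor([2])},
)
print(pixel_values.shape, labels["class_labels"].tolist(), labels["boxes"])
```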

### Final Evaluation