# Deformable DETR Lesion Detector

This notebook trains and evaluates a lesion detector on the DeepLesion dataset using the Deformable DETR object-detection architecture with a ResNet-50 backbone.

## Assumptions:
- Use 2D slice inputs (optionally use the neighbouring ones too),
- Resize all images to 512x512,
- Use a YOLO-style Dataset class with a slight modification: bounding-box coordinates are stored as [x1, y1, x2, y2].
- Use DeepLesion to train a general lesion localizer, and a smaller dataset to train a more specialized localizer.

## 📚 Thesis Value Summary
### Contribution and Value:
- Comparison of 1-stage vs 2-stage vs Transformer vs legacy detectors on DeepLesion -> ✅ Fills a gap in the literature
- Evaluation of an improved DETR (Deformable DETR) -> ✅ Modern insight
- General vs specialized lesion detection -> ✅ Strong clinical relevance


# Google Colab only

### Download required packages

In [None]:
!pip install -r https://raw.githubusercontent.com/pmalesa/lesion_detector/main/notebooks/requirements.txt

### Mount DeepLesion images and checkpoints from Google Drive

In [None]:
# Mount Google Drive and expose the DeepLesion data and the Deformable DETR checkpoints
# through symlinks under /content, so later cells can use short relative paths.
from google.colab import drive
drive.mount('/content/drive')
%cd /content

# remove existing link if any
# NOTE(review): `rm -rf` also deletes a real directory at these paths, not only a symlink.
!rm -rf data/deeplesion
!rm -rf deformable_detr_checkpoints

!mkdir -p data
!ln -s /content/drive/MyDrive/deeplesion/data/deeplesion data/deeplesion
!ln -s /content/drive/MyDrive/deeplesion/checkpoints/deformable_detr deformable_detr_checkpoints
!ls -l data

# Clone the Deformable DETR repository (+ PyTorch 2.x / CUDA 12 fix)

In [None]:
# Clone and rename original Deformable DETR repo
!git clone https://github.com/fundamentalvision/Deformable-DETR.git
!mv Deformable-DETR deformable_detr

# Clone Torch2.xCUDA12 fix repo
!git clone https://github.com/Norman-Ou/Deformable-DETR-Torch2.x-cuda12.git

# Overwrite models/ops subdirectory
!rm -rf deformable_detr/models/ops
!mv Deformable-DETR-Torch2.x-cuda12 deformable_detr/models/ops

%cd deformable_detr

# Compile CUDA Operators

In [None]:
!python deformable_detr/models/ops/setup.py build_ext --inplace

# Import all packages

In [None]:
# General packages
import os
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import random
from pathlib import Path
from datetime import datetime
import copy
from types import SimpleNamespace

# PyTorch packages
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms.functional as F
from torchmetrics.detection.mean_ap import MeanAveragePrecision
import torchvision.ops as ops
from torchvision.ops import box_iou

# Deformable DETR packages
import sys
sys.path.append("deformable_detr")

# repo root
# sys.path.insert(0, "/content/deformable_detr")

# # CUDA extension location
# sys.path.insert(0, "/content/deformable_detr/models/ops")

from deformable_detr.models.matcher import HungarianMatcher
from deformable_detr.models.deformable_detr import DeformableDETR, SetCriterion
from deformable_detr.models.backbone import build_backbone
from deformable_detr.models.deformable_transformer import build_deforamble_transformer
from deformable_detr.util.misc import NestedTensor

# Set paths to DeepLesion images and metadata

## Paths to unprocessed data

In [None]:
# Google Colab: raw DeepLesion metadata CSV and key-slice images (symlinked from Google Drive).
deeplesion_metadata_path = Path("data") / "deeplesion" / "deeplesion_metadata.csv"
deeplesion_image_path = Path("data") / "deeplesion" / "key_slices"

In [None]:
# Local run: raw DeepLesion metadata CSV and key-slice images (paths presumably relative to the
# notebook's directory).
# NOTE(review): running all cells overrides the Colab paths defined above. The metadata CSV is
# expected directly under ../data/ while the images are under ../data/deeplesion/ — confirm layout.
deeplesion_metadata_path = Path("..") / "data" / "deeplesion_metadata.csv"
deeplesion_image_path = Path("..") / "data" / "deeplesion" / "key_slices"

# Paths to processed data

In [None]:
# Root of the DeepLesion data and the preprocessed (uint8) key slices + metadata inside it.
deeplesion_data_dir = Path("data", "deeplesion")
deeplesion_preprocessed_image_path = deeplesion_data_dir.joinpath("deeplesion_preprocessed_uint8", "key_slices")
deeplesion_preprocessed_metadata_path = deeplesion_data_dir.joinpath(
    "deeplesion_preprocessed_uint8", "deeplesion_metadata_preprocessed.csv"
)

# [DEBUG] DETR Criterion 

In [None]:
def build_criterion(model, num_classes, device):
    """
    Build the Deformable DETR set-prediction criterion (Hungarian matcher + losses).

    Args:
        model: Deformable DETR model; `model.transformer.decoder.num_layers` is used
            to add auxiliary-loss weights for every intermediate decoder layer.
        num_classes: class count as used by the caller; the criterion receives
            `num_classes - 1` (foreground only).
            NOTE(review): Deformable DETR's SetCriterion uses index
            `criterion.num_classes` as the "no object" target for unmatched
            queries — confirm it matches how the model's class head was built.
        device: device the criterion is moved to.

    Returns:
        SetCriterion on `device`.
    """
    matcher = HungarianMatcher(cost_class=1.0, cost_bbox=5.0, cost_giou=2.0)

    base_weight_dict = {"loss_ce": 1.0, "loss_bbox": 5.0, "loss_giou": 2.0}

    # Auxiliary losses from the intermediate decoder layers reuse the same weights,
    # with keys suffixed by the layer index (e.g. "loss_ce_0").
    num_decoder_layers = model.transformer.decoder.num_layers
    weight_dict = dict(base_weight_dict)
    for i in range(num_decoder_layers - 1):
        weight_dict.update({f"{k}_{i}": v for k, v in base_weight_dict.items()})

    # Deformable DETR's SetCriterion classifies with a sigmoid focal loss and takes
    # `focal_alpha`; it has no DETR-style `eos_coef` argument (passing it raises TypeError).
    criterion = SetCriterion(
        num_classes=num_classes - 1,  # foreground only
        matcher=matcher,
        weight_dict=weight_dict,
        losses=["labels", "boxes"],
        focal_alpha=0.25,
    )
    criterion.to(device)

    return criterion

# Evaluation Functions

## Helper functions

In [None]:
def _count_instances_per_class(dataset, num_classes):
    """
    Function that counts the ground truth instances per class by reading the label .txt files on disk.
    """

    counts = [0] * num_classes
    total = 0

    for img_name in dataset.image_names:
        label_path = os.path.join(dataset.label_dir, Path(img_name).stem + ".txt")
        if not os.path.exists(label_path):
            continue
        with open(label_path, "r") as file:
            for line in file:
                parts = line.strip().split()
                if len(parts) != 5: # Ill-written label text file
                    continue
                cls = int(float(parts[0]))
                # Label files are 1-based (1...K), map to 0...K-1 indices
                if 1 <= cls <= num_classes:
                    counts[cls - 1] += 1
                    total += 1
    return counts, total

def _to_float(x, default=float("nan")):
    if x is None:
        return default
    if isinstance(x, torch.Tensor):
        if x.numel() == 0:
            return default
        x = x.detach().cpu().item() if x.ndim == 0 else x.detach().cpu().numpy()
    if isinstance(x, np.ndarray):
        return float(x.item()) if x.size == 1 else default
    try:
        return float(x)
    except Exception:
        return default
    
def _to_int(x, default=0):
    try:
        return int(x)
    except Exception:
        return default
    
def _nan_if_undefined(x):
    return float("nan") if x is None or (isinstance(x, (float, int)) and x < 0) else float(x)

def _cxcywh_to_xyxy(boxes):
    """
    Function to convert boxes' coordinates
    from cxcywh to xyxy.

    boxes: [B, 4] normalized cxcywh
    """

    cx, cy, w, h = boxes.unbind(-1)
    x1 = cx - 0.5 * w
    y1 = cy - 0.5 * h
    x2 = cx + 0.5 * w
    y2 = cy + 0.5 * h
    return torch.stack([x1, y1, x2, y2], dim=-1)

def _scale_boxes_to_pixels(boxes_xyxy_norm, sizes):
    """
    Function to scale normalized xyxy boxes
    to the correct pixel sizes.

    boxes_xyxy_norm: [N, 4] in [0, 1]
    sizes: (height, width)
    """

    h, w = sizes
    scale = boxes_xyxy_norm.new_tensor([w, h, w, h])
    return boxes_xyxy_norm * scale

def _orig_size_from_mask(mask):
    """
    Function to determine the actual
    image sizes from mask.

    mask: [h_max, w_max] bool, True=pad, False = valid
    """

    valid = ~mask
    h = int(valid.any(dim=1).sum().item())
    w = int(valid.any(dim=0).sum().item())
    return (h, w)

def _detr_outputs_to_predictions(outputs, masks, score_thr=0.0, top_k=100):
    """
    Convert raw DETR outputs into per-image prediction dicts for torchmetrics/torchvision.

    Args:
        outputs: dict with "pred_logits" [B, Q, K+1] and "pred_boxes" [B, Q, 4]
            (normalized cxcywh). A softmax is taken over the class dimension and
            the last column is dropped as the "no object" class.
            NOTE(review): Deformable DETR's reference post-processing uses sigmoid
            scores without a no-object column; this convention assumes the model
            was built with an extra trailing logit — confirm.
        masks: [B, H, W] bool padding masks (True = padding), used to recover each
            image's unpadded size.
        score_thr: predictions scoring below this are dropped.
        top_k: if not None, keep at most this many highest-scoring predictions.

    Returns:
        list of B dicts with CPU tensors "boxes" ([N, 4] xyxy pixels), "scores"
        and "labels" (0...K-1).
    """
    # Per-query softmax, then discard the trailing no-object column.
    class_probs = outputs["pred_logits"].softmax(-1)[..., :-1]   # [B, Q, K]
    all_scores, all_labels = class_probs.max(-1)                 # [B, Q], [B, Q]
    pred_boxes = outputs["pred_boxes"]                           # [B, Q, 4]

    results = []
    for img_idx in range(all_scores.shape[0]):
        image_size = _orig_size_from_mask(masks[img_idx])
        boxes = _scale_boxes_to_pixels(_cxcywh_to_xyxy(pred_boxes[img_idx]).clamp(0, 1), image_size)

        selected = all_scores[img_idx] >= score_thr
        boxes = boxes[selected]
        scores = all_scores[img_idx][selected]
        labels = all_labels[img_idx][selected]

        if top_k is not None and boxes.shape[0] > top_k:
            best = torch.argsort(scores, descending=True)[:top_k]
            boxes, scores, labels = boxes[best], scores[best], labels[best]

        results.append({"boxes": boxes.cpu(), "scores": scores.cpu(), "labels": labels.cpu()})

    return results

def _targets_to_ground_truths(targets, masks):
    """
    Convert DETR targets (normalized cxcywh boxes) into torchmetrics-style ground truths.

    Each target's boxes are converted to xyxy, clamped to [0, 1] and scaled to the
    image's unpadded pixel size (recovered from masks[idx]). Returns a list of dicts
    with CPU tensors "boxes" and "labels".
    """
    return [
        {
            "boxes": _scale_boxes_to_pixels(
                _cxcywh_to_xyxy(target["boxes"]).clamp(0, 1),
                _orig_size_from_mask(masks[idx]),
            ).cpu(),
            "labels": target["labels"].cpu(),
        }
        for idx, target in enumerate(targets)
    ]

## FROC Curve Computation

In [None]:
def _match_predictions_class_agnostic(boxes_pred, scores_pred, boxes_gt, iou_thr):
    """
    Greedily match one image's predictions to its ground-truth boxes, ignoring classes.

    Predictions are visited in descending score order. Each one claims the
    still-unmatched GT box with the highest IoU and is a true positive if that
    IoU is >= iou_thr (one prediction <-> one GT). All other predictions are
    false positives.

    Returns:
        scores: np.ndarray [N] float32, prediction scores in descending order
        is_tp:  np.ndarray [N] bool, True where the prediction is a true positive
    """
    scores_sorted, order = torch.sort(scores_pred, descending=True, stable=True)
    is_tp = torch.zeros(scores_sorted.shape[0], dtype=torch.bool)

    if scores_sorted.numel() > 0 and boxes_gt.shape[0] > 0:
        ious = box_iou(boxes_pred[order], boxes_gt)  # [N_pred, N_gt]
        gt_taken = torch.zeros(boxes_gt.shape[0], dtype=torch.bool)
        for p_idx in range(ious.shape[0]):
            if gt_taken.all():
                break  # every GT is claimed; the remaining predictions are all FPs
            # GTs already claimed by a higher-scoring prediction cannot be matched again.
            best_iou, best_gt_idx = ious[p_idx].masked_fill(gt_taken, -1.0).max(0)
            if best_iou >= iou_thr:
                is_tp[p_idx] = True
                gt_taken[best_gt_idx] = True

    return scores_sorted.to(torch.float32).numpy(), is_tp.numpy()


@torch.no_grad()
def compute_froc_curve(model, loader, device, iou_thr=0.5, score_thresholds=None):
    """
    Computes the FROC curve: sensitivity (recall) vs. false positives per image.

    Lesions are treated class-agnostically. Within each image, predictions are
    matched greedily in descending score order to the best still-unmatched GT
    box with IoU >= iou_thr.

    Because greedy matching processes high-scoring predictions first, the
    matching at any score threshold is a prefix of the matching over all
    predictions. The matching is therefore computed once per image, and each
    threshold is evaluated by counting. This avoids re-matching every image
    for every threshold.

    Args:
        model: detector returning DETR-style outputs for a NestedTensor input.
        loader: DataLoader yielding ((images, masks), targets).
        device: device the model runs on.
        iou_thr: minimum IoU for a prediction to hit a lesion.
        score_thresholds: score thresholds to evaluate; defaults to 101 points in [0, 1].

    Returns:
        fp_per_image:     np.ndarray of shape [T]
        sensitivity:      np.ndarray of shape [T]
        score_thresholds: np.ndarray of shape [T]
    """
    model.eval()

    if score_thresholds is None:
        # 0.0 ... 1.0, 101 points
        score_thresholds = np.linspace(0.0, 1.0, 101)
    score_thresholds = np.array(score_thresholds, dtype=np.float32)

    per_image_scores = []   # sorted prediction scores per image
    per_image_is_tp = []    # TP flag per prediction, aligned with per_image_scores
    n_images = 0
    n_gt_total = 0

    for (images, masks), targets in loader:
        images = images.to(device)
        masks = masks.to(device)
        targets = [{key: val.to(device) for key, val in target.items()} for target in targets]

        # DETR models expect a NestedTensor (padded images + padding mask)
        outputs = model(NestedTensor(images, masks))

        # Postprocess DETR outputs (moved to CPU for metrics)
        predictions = _detr_outputs_to_predictions(outputs, masks, score_thr=0.0, top_k=300)
        ground_truths = _targets_to_ground_truths(targets, masks)

        for prediction, ground_truth in zip(predictions, ground_truths):
            boxes_gt = ground_truth["boxes"].detach().cpu()  # class agnostic
            scores, is_tp = _match_predictions_class_agnostic(
                prediction["boxes"].detach().cpu(),
                prediction["scores"].detach().cpu(),
                boxes_gt,
                iou_thr,
            )
            per_image_scores.append(scores)
            per_image_is_tp.append(is_tp)
            n_images += 1
            n_gt_total += boxes_gt.shape[0]

    all_scores = np.concatenate(per_image_scores) if per_image_scores else np.zeros(0, dtype=np.float32)
    all_is_tp = np.concatenate(per_image_is_tp) if per_image_is_tp else np.zeros(0, dtype=bool)

    fp_per_image = np.zeros_like(score_thresholds)
    sensitivity = np.zeros_like(score_thresholds)

    # For each threshold, count TP/FP among the predictions it keeps
    for i, thr in enumerate(score_thresholds):
        kept = all_scores >= thr
        n_tp = int(np.count_nonzero(kept & all_is_tp))
        n_fp = int(np.count_nonzero(kept & ~all_is_tp))
        sensitivity[i] = n_tp / max(1, n_gt_total)
        fp_per_image[i] = n_fp / max(1, n_images)

    return fp_per_image, sensitivity, score_thresholds


def sensitivity_at_fp(fp_per_image, sensitivity, fp_targets):
    """
    Interpolate sensitivity at target FP/image values.

    Several score thresholds can produce the same FP/image with different
    sensitivities (lowering the threshold can add true positives without adding
    false positives). Duplicate FP/image values are therefore collapsed to
    their best sensitivity before interpolating. np.interp needs increasing
    sample points to give well-defined results.

    Args:
        fp_per_image, sensitivity: arrays from compute_froc_curve().
        fp_targets: list or array of target FP/image values, e.g. [0.5, 1, 2, 4].

    Returns:
        dict {fp_target: sensitivity_value}. Targets below the smallest observed
        FP/image map to 0.0, targets above the largest map to the sensitivity
        there. If the curve is empty, every target maps to NaN.
    """
    fp_per_image = np.asarray(fp_per_image, dtype=np.float64).ravel()
    sensitivity = np.asarray(sensitivity, dtype=np.float64).ravel()
    fp_targets = np.asarray(fp_targets, dtype=np.float32)

    if fp_per_image.size == 0:
        return {float(fp): float("nan") for fp in fp_targets}

    # np.unique returns the FP values sorted ascending; keep the best sensitivity for each.
    fp_unique, inverse = np.unique(fp_per_image, return_inverse=True)
    best_sens = np.full(fp_unique.shape, -np.inf)
    np.maximum.at(best_sens, inverse.ravel(), sensitivity)

    sens_at = np.interp(fp_targets, fp_unique, best_sens, left=0.0, right=best_sens[-1])

    return {float(fp): float(s) for fp, s in zip(fp_targets, sens_at)}

def plot_froc_curve(fp_per_image, sensitivity):
    """
    Plot the full FROC curve (sensitivity vs. false positives per image).

    fp_per_image, sensitivity: arrays from compute_froc_curve()
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(fp_per_image, sensitivity, color="red", marker="o")
    ax.set_xlabel("False positives per image")
    ax.set_ylabel("Sensitivity (recall)")
    ax.grid(True, linestyle="--", alpha=0.3)
    ax.set_title("FROC curve (class-agnostic lesions)")
    fig.tight_layout()
    plt.show()

def print_froc_curve_info(model, loader, device):
    """Compute the FROC curve at IoU 0.5, plot it, and print sensitivity at standard FP/image points."""
    fp_targets = [0.5, 1.0, 2.0, 4.0, 5.0, 8.0]
    fp_curve, sens_curve, _ = compute_froc_curve(model, loader, device, iou_thr=0.5)
    sens_at_targets = sensitivity_at_fp(fp_curve, sens_curve, fp_targets)

    plot_froc_curve(fp_curve, sens_curve)

    print("FROC (class-agnostic):")
    for fp, sens in sens_at_targets.items():
        print(f"Sensitivity at {fp} FP/image: {sens*100:.2f}%")


## Mean Average Precision

In [None]:
def _accumulate_pr_counts(prediction, ground_truth, TP, FP, FN, conf_thr, iou_thr):
    """
    Add one image's per-class TP/FP/FN counts (in place) at fixed thresholds.

    Predictions with score >= conf_thr are visited in descending score order.
    Each one claims the still-unmatched same-class ground truth with the highest
    IoU and is a TP if that IoU is >= iou_thr. Every other kept prediction is a
    FP, including predictions on images without any ground truth. Ground truths
    left unmatched are FNs.
    """
    keep = prediction["scores"] >= conf_thr
    order = torch.argsort(prediction["scores"][keep], descending=True)
    pred_boxes = prediction["boxes"][keep][order]    # same order as the sorted scores
    pred_labels = prediction["labels"][keep][order]
    gt_boxes = ground_truth["boxes"]
    gt_labels = ground_truth["labels"]

    matched = torch.zeros(len(gt_boxes), dtype=torch.bool)
    ious = ops.box_iou(pred_boxes, gt_boxes) if len(pred_boxes) and len(gt_boxes) else None

    for pred_idx in range(len(pred_boxes)):
        cls = int(pred_labels[pred_idx].item())  # classes are 0...K-1 (foreground only)
        # Candidates: same-class GTs not yet claimed by a higher-scoring prediction
        candidates = (gt_labels == cls) & ~matched
        if ious is not None and bool(candidates.any()):
            best_iou, gt_idx = ious[pred_idx].masked_fill(~candidates, -1.0).max(0)
            if best_iou >= iou_thr:
                TP[cls] += 1
                matched[gt_idx] = True
                continue
        # No GT in the image, no free same-class GT, or IoU too low
        FP[cls] += 1

    # Any unmatched ground truths are FN
    for gt_idx, gt_label in enumerate(gt_labels):
        if not matched[gt_idx]:
            FN[int(gt_label.item())] += 1


def _per_class_ap_entries(res_all, res_10, res_30, res_50, class_names):
    """
    Build the per-class AP list from torchmetrics MeanAveragePrecision results.

    torchmetrics reports `map_per_class` only for the observed classes listed in
    `classes`, so values are paired with those class ids instead of their
    position. Positional pairing would mislabel classes whenever one is absent
    from both predictions and ground truths. Class ids are 0-based; "idx" is 1-based.
    """
    def _ap_by_class(result):
        classes = result.get("classes", None)
        values = result.get("map_per_class", None)
        if classes is None or values is None:
            return {}
        class_ids = torch.as_tensor(classes).flatten().tolist()
        ap_values = torch.as_tensor(values).flatten().tolist()
        if len(class_ids) != len(ap_values):  # per-class values unavailable
            return {}
        return {int(c): ap for c, ap in zip(class_ids, ap_values)}

    ap_all = _ap_by_class(res_all)
    ap_10 = _ap_by_class(res_10)
    ap_30 = _ap_by_class(res_30)
    ap_50 = _ap_by_class(res_50)

    entries = []
    for cls_id in sorted(ap_all):
        if class_names and 0 <= cls_id < len(class_names):
            name = class_names[cls_id]
        else:
            name = f"class_{cls_id + 1}"
        entries.append({
            "idx": cls_id + 1,
            "name": name,
            "AP": _nan_if_undefined(ap_all[cls_id]),
            "AP10": _nan_if_undefined(ap_10.get(cls_id)),
            "AP30": _nan_if_undefined(ap_30.get(cls_id)),
            "AP50": _nan_if_undefined(ap_50.get(cls_id)),
        })
    return entries


@torch.no_grad()
def evaluate_detector(model, loader, device, num_classes=8, class_names=None, pr_conf_thr=0.25, pr_iou_thr=0.5):
    """
    Evaluate a DETR-style detector on `loader`.

    Computes torchmetrics mAP (default IoU thresholds 0.50:0.95, plus single-threshold
    AP at IoU 0.1, 0.3 and 0.5). It also computes precision/recall at fixed thresholds
    (score >= pr_conf_thr, IoU >= pr_iou_thr) with one-to-one, same-class greedy matching.

    Args:
        model, loader, device: detector, DataLoader yielding ((images, masks), targets), device.
        num_classes: number of foreground classes (labels 0...num_classes-1).
        class_names: optional display names indexed by class id.
        pr_conf_thr, pr_iou_thr: thresholds for the precision/recall counts.

    Returns a dictionary:
        {
            "mAP10": float, "mAP30": float, "mAP50": float, "mAP50_95": float,
            "per_class": [{"idx", "name", "AP", "AP10", "AP30", "AP50"}, ...],
            "precision_overall": float,
            "recall_overall": float,
            "precision_per_class": [...],
            "recall_per_class": [...],
            "pr_conf_thr": float,
            "pr_iou_thr": float,
        }
    """
    model.eval()
    metric_all = MeanAveragePrecision(iou_type="bbox", box_format="xyxy", class_metrics=True)
    metric_10  = MeanAveragePrecision(iou_type="bbox", box_format="xyxy", class_metrics=True, iou_thresholds=[0.1])
    metric_30  = MeanAveragePrecision(iou_type="bbox", box_format="xyxy", class_metrics=True, iou_thresholds=[0.3])
    metric_50  = MeanAveragePrecision(iou_type="bbox", box_format="xyxy", class_metrics=True, iou_thresholds=[0.5])

    # For precision and recall at fixed thresholds (IoU=pr_iou_thr, conf=pr_conf_thr)
    TP = torch.zeros(num_classes, dtype=torch.long)
    FP = torch.zeros(num_classes, dtype=torch.long)
    FN = torch.zeros(num_classes, dtype=torch.long)
    n_processed_images = 0

    for batch_idx, batch in enumerate(loader, start=1):
        # -- Unpack DETR batch
        (images, masks), targets = batch
        images = images.to(device)
        masks = masks.to(device)
        targets = [{key: val.to(device) for key, val in target.items()} for target in targets]

        # -- Wrap into NestedTensor (DETR models expect this) and forward
        outputs = model(NestedTensor(images, masks))
        n_processed_images += images.size(0)

        # DEBUG#############################################################
        # DEBUG#############################################################
        # criterion = build_criterion(model, num_classes, device)
        # matcher = criterion.matcher
        # with torch.no_grad():
        #     indices = matcher(outputs, targets)
        #     ious = []

        #     for b, (pred_idx, target_idx) in enumerate(indices):
        #         if len(pred_idx) == 0:
        #             continue

        #         p = outputs["pred_boxes"][b][pred_idx]
        #         t = targets[b]["boxes"][target_idx]

        #         p_xyxy = _cxcywh_to_xyxy(p)
        #         t_xyxy = _cxcywh_to_xyxy(t)

        #         iou = box_iou(p_xyxy, t_xyxy).diag()
        #         ious.append(iou)

        #     if ious:
        #         print("\nMatched IoU mean: ", torch.cat(ious).mean().item())
        # DEBUG#############################################################
        # DEBUG#############################################################

        # -- Postprocess DETR outputs and move to CPU for metrics
        predictions = _detr_outputs_to_predictions(outputs, masks)
        ground_truths = _targets_to_ground_truths(targets, masks)

        # -- mAP update
        for metric in (metric_all, metric_10, metric_30, metric_50):
            metric.update(predictions, ground_truths)

        # -- Precision and recall accumulation at fixed thresholds
        for prediction, ground_truth in zip(predictions, ground_truths):
            _accumulate_pr_counts(prediction, ground_truth, TP, FP, FN, pr_conf_thr, pr_iou_thr)

        if batch_idx % 10 == 0 or batch_idx == len(loader):
            print(f"\r[{n_processed_images}/{len(loader.dataset)}] images validated.", end="", flush=True)

    print() # print a new line

    # mAP metrics
    res_all = metric_all.compute()
    res_10  = metric_10.compute()
    res_30  = metric_30.compute()
    res_50  = metric_50.compute()

    out = {
        "mAP10": float(res_10["map"]),
        "mAP30": float(res_30["map"]),
        "mAP50": float(res_50["map"]),
        "mAP50_95": float(res_all["map"]),
    }

    # Per-class AP (only for classes torchmetrics observed)
    out["per_class"] = _per_class_ap_entries(res_all, res_10, res_30, res_50, class_names)

    # Precision/Recall at fixed thresholds
    precision_per_class = (TP.float() / (TP + FP).clamp(min=1)).tolist()
    recall_per_class    = (TP.float() / (TP + FN).clamp(min=1)).tolist()
    overall_precision   = float(TP.sum() / (TP.sum() + FP.sum()).clamp(min=1))
    overall_recall      = float(TP.sum() / (TP.sum() + FN.sum()).clamp(min=1))

    out["precision_overall"]   = overall_precision
    out["recall_overall"]      = overall_recall
    out["precision_per_class"] = precision_per_class
    out["recall_per_class"]    = recall_per_class
    out["pr_conf_thr"]         = pr_conf_thr
    out["pr_iou_thr"]          = pr_iou_thr

    return out

def print_result_report(metrics, loader, class_names):
    """
    Print a YOLO-style evaluation table: one "all" row followed by one row per class.

    Args:
        metrics: dict returned by evaluate_detector().
        loader: DataLoader whose dataset provides `image_names` and `label_dir`;
            instance counts are read from those label files on disk.
        class_names: display names, one per foreground class, in label order.

    The "Images" column shows the total number of images on every row.
    """
    nan = float("nan")
    n_images = _to_int(len(loader.dataset))
    per_class_entries = metrics.get("per_class", [])
    p_overall, r_overall, map10, map30, map50, map50_95 = (
        _to_float(metrics[key])
        for key in ("precision_overall", "recall_overall", "mAP10", "mAP30", "mAP50", "mAP50_95")
    )

    # Ground-truth instances per class, read from the label files
    instance_counts, total_instances = _count_instances_per_class(loader.dataset, len(class_names))

    # {AP key: {class name: value}}
    ap_lookup = {
        ap_key: {entry['name']: entry[ap_key] for entry in per_class_entries}
        for ap_key in ("AP10", "AP30", "AP50", "AP")
    }

    # Header and overall row
    print(f"{'Class':>18} {'Images':>8} {'Instances':>10} {'P':>10} {'R':>10} {'mAP10':>10} {'mAP30':>10} {'mAP50':>10} {'mAP50-95':>10}")
    print(f"{'all':>18} {n_images:8d} {_to_int(total_instances):10d} {p_overall:10.3f} {r_overall:10.3f} {map10:10.3f} {map30:10.3f} {map50:10.3f} {map50_95:10.3f}")

    # Per-class rows
    precisions = metrics.get("precision_per_class", [])
    recalls = metrics.get("recall_per_class", [])
    for i, name in enumerate(class_names):
        p_i = _to_float(precisions[i] if i < len(precisions) else nan)
        r_i = _to_float(recalls[i] if i < len(recalls) else nan)
        ap10_i, ap30_i, ap50_i, ap_i = (
            _to_float(ap_lookup[ap_key].get(name, nan)) for ap_key in ("AP10", "AP30", "AP50", "AP")
        )
        count_i = _to_int(instance_counts[i])
        print(f"{name:>18} {n_images:8d} {count_i:10d} {p_i:10.3f} {r_i:10.3f} {ap10_i:10.3f} {ap30_i:10.3f} {ap50_i:10.3f} {ap_i:10.3f}")



## Data Augmentation - definition of data transformation classes

In [None]:
class ComposeTransform:
    """Chain (image, target) transforms, applying them in list order."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for step in self.transforms:
            new_image, new_target = step(image, target)
            image, target = new_image, new_target
        return image, target
    
class ToTensorTransform:
    """Turn a PIL image into a float32 tensor [C, H, W] scaled to [0, 1]; the target passes through unchanged."""
    def __call__(self, image, target):
        return F.to_tensor(image), target
    
class RandomHorizontalFlipTransform:
    """
    Mirror the image left-right with probability p and flip the boxes to match.

    Expects a tensor image [C, H, W] and target["boxes"] as [N, 4]
    [x_min, y_min, x_max, y_max] in pixels. When boxes are flipped, the target
    dict is updated in place.
    """
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, image, target):
        if random.random() < self.p:
            _, _, width = image.shape
            image = image.flip(2)  # mirror along the width axis

            boxes = target["boxes"]
            if boxes.numel() > 0:
                # A box spanning [x_min, x_max] maps to [width - x_max, width - x_min]; y is unchanged.
                target["boxes"] = torch.stack(
                    [width - boxes[:, 2], boxes[:, 1], width - boxes[:, 0], boxes[:, 3]],
                    dim=1,
                )

        return image, target
    
class RandomBrightnessContrastTransform:
    """
    With probability p, jitter brightness and then contrast by random relative factors.

    brightness=0.1 means a brightness factor drawn from [0.9, 1.1]; likewise for
    contrast. A value of 0 disables that adjustment. The result is clamped to [0, 1].
    Expects a tensor image [C, H, W].
    """
    def __init__(self, brightness=0.1, contrast=0.1, p=0.5):
        self.brightness = brightness
        self.contrast = contrast
        self.p = p

    def __call__(self, image, target):
        if random.random() < self.p:
            if self.brightness > 0:
                image = F.adjust_brightness(image, 1.0 + random.uniform(-self.brightness, self.brightness))
            if self.contrast > 0:
                image = F.adjust_contrast(image, 1.0 + random.uniform(-self.contrast, self.contrast))
            image = torch.clamp(image, 0.0, 1.0)
        return image, target
    
class RandomGaussianNoiseTransform:
    """With probability p, add zero-mean Gaussian noise (std = sigma) to the image and clamp to [0, 1]."""
    def __init__(self, sigma=0.01, p=0.5):
        self.sigma = sigma
        self.p = p

    def __call__(self, image, target):
        if random.random() < self.p:
            image = torch.clamp(image + self.sigma * torch.randn_like(image), min=0.0, max=1.0)
        return image, target


## Prepare DeepLesion dataset for DETR model

In [None]:
# Custom Dataset class for DeepLesion dataset
class DeepLesionDataset(Dataset):
    """
    DeepLesion detection dataset stored in a YOLO-like directory layout.

    Expected layout under `root`:
        images/<split>/*.png | *.jpg   key-slice images (loaded as grayscale)
        labels/<split>/<stem>.txt      one "cls x_min y_min x_max y_max" line per box,
                                       pixel coordinates, 1-based class ids

    Each item is (image, target). `image` is the output of the split's transforms
    (a [1, H, W] float tensor in [0, 1]). `target` holds "boxes" as [N, 4]
    normalized [cx, cy, w, h] and "labels" as [N] int64 in 0...K-1, as DETR expects.

    Data augmentation (flip, brightness/contrast, Gaussian noise) is applied only
    to the "train" split.
    """

    def __init__(self, root, split):
        # Initialize dataset path, split and transformations
        self.root = root
        self.split = split

        # Apply data augmentations only for the train split
        if split == "train":
            self.transforms = ComposeTransform([
                ToTensorTransform(),
                RandomHorizontalFlipTransform(p=0.5),
                RandomBrightnessContrastTransform(brightness=0.1, contrast=0.1, p=0.5),
                RandomGaussianNoiseTransform(sigma=0.01, p=0.5),
            ])
        else:
            self.transforms = ComposeTransform([
                ToTensorTransform(), # Converts [0, 255] uint8 values to float [0.0, 1.0], and preserves 1 channel
            ])

        # Dataset logic (image paths, annotations, etc.)
        self.image_dir = os.path.join(root, "images", split)
        self.label_dir = os.path.join(root, "labels", split)
        self.image_names = sorted(
            name for name in os.listdir(self.image_dir) if name.endswith((".png", ".jpg"))
        )

    @staticmethod
    def _load_labels(label_path):
        """
        Read one label file.

        Returns (boxes, labels): boxes as [x_min, y_min, x_max, y_max] pixel lists,
        and labels converted from the file's 1-based ids to 0-based ids (0...K-1).
        A missing file yields no boxes. Blank lines are skipped.

        Raises:
            ValueError: if a non-blank line does not contain exactly 5 numeric fields.
        """
        boxes, labels = [], []
        if not os.path.exists(label_path):
            return boxes, labels

        # `with` closes the file even if parsing fails
        with open(label_path, "r") as file:
            for line_no, line in enumerate(file, start=1):
                parts = line.split()
                if not parts:
                    continue
                if len(parts) != 5:
                    raise ValueError(
                        f"Malformed label line {line_no} in {label_path!r}: expected 5 fields, got {len(parts)}"
                    )
                cls, x_min, y_min, x_max, y_max = map(float, parts)
                boxes.append([x_min, y_min, x_max, y_max])
                labels.append(int(cls) - 1) # Convert to 0...K-1 for DETR
        return boxes, labels

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_dir, image_name)
        label_path = os.path.join(self.label_dir, os.path.splitext(image_name)[0] + ".txt")

        # 1) Load image and boxes in pixel coordinates.
        # The context manager closes the underlying file; convert() returns a new,
        # fully loaded grayscale image that stays valid after the file is closed.
        with Image.open(image_path) as pil_image:
            image = pil_image.convert("L")

        boxes, labels = self._load_labels(label_path)

        # Create a target dictionary
        if boxes:
            boxes_t = torch.as_tensor(boxes, dtype=torch.float32)
            labels_t = torch.as_tensor(labels, dtype=torch.int64)
        else:
            boxes_t = torch.zeros((0, 4), dtype=torch.float32)
            labels_t = torch.zeros((0,), dtype=torch.int64)

        target = {
            "boxes": boxes_t,   # xyxy in pixels
            "labels": labels_t
        }

        # 2) Apply transforms (tensor conversion + optional augmentation)
        if self.transforms:
            image, target = self.transforms(image, target)

        # 3) Convert to normalized center-based format required by DETR
        _, h_image, w_image = image.shape # image is now a Tensor, not PIL's Image object
        boxes = target["boxes"]
        if len(boxes) > 0:
            x_min, y_min, x_max, y_max = boxes.unbind(1)

            # For DETR boxes must be center-based ([cx, cy, w, h])
            bw = x_max - x_min
            bh = y_max - y_min
            cx = x_min + (bw / 2.0)
            cy = y_min + (bh / 2.0)

            # Normalize to [0, 1]
            boxes = torch.stack([
                cx / w_image,
                cy / h_image,
                bw / w_image,
                bh / h_image
            ], dim=1)

        target["boxes"] = boxes

        return image, target

    def __len__(self):
        return len(self.image_names)
    

## Prepare DataLoader objects

### Set up the dataset's split path

In [None]:
# [!] Same data splits as for Faster R-CNN
# Name of the split directory, joined onto deeplesion_data_dir when the datasets are built.
deeplesion_detr_path = "deeplesion_fasterrcnn_split_1" # There are three splits: *_1, *_2 and *_3

### Set up the batch sizes

In [None]:
# Google Colab
# Training batch size: keep it small (2-4); higher values may hurt training dynamics.
train_batch_size = 4
# Evaluation batch size: does not affect the computed metrics, only speed and memory usage.
test_val_batch_size = 32

In [None]:
# Local run: one image per batch for both training and evaluation.
train_batch_size = test_val_batch_size = 1

In [None]:
"""
- Shuffling is enabled for training DataLoader, because SGD benefits from seeing data in a new random order every epoch.
  During validation and testing phases we do not need that, the order does not affect the metrics.

- num_workers is the number of background processes that load & transorm batches in parallel. Good rule of thumb is num_workers being 2-4.

- pin_memory, or pinned (page-locked) host memory, speeds up host to GPU copies and lets us use asynchronous transfers
  It should be set to True if we train on GPU. It usually gives a small lbut real throughput bump. It consumes a bit more system RAM
  and is useless on CPU-only runs.

- Detection models expect lists of images and lists of target dicts, because each image can have different size and has a different
  number of boxes. The default PyTorch collate tries to stack everything into tensors of the same shape, which breaks for 
  variable-length targets. Custom collate_fn function here unzips the list oof pairs into pair of lists so Faster R-CNN can consume them:
    images: List[Tensor[C,H,W]]
    targets: List[Dict{'boxes': Tensor[N,4], 'labels': Tensor[N]}]
  That is exactly what torchvision's detection references use.

"""

train_ds = DeepLesionDataset(deeplesion_data_dir / deeplesion_detr_path, "train")
val_ds = DeepLesionDataset(deeplesion_data_dir / deeplesion_detr_path, "val")
test_ds = DeepLesionDataset(deeplesion_data_dir / deeplesion_detr_path, "test")

def collate_fn(batch):
    """
    Collate (image, target) pairs into a DETR-ready batch.

    Images may differ in height/width, so they are zero-padded into one tensor sized to the
    largest image, and a boolean mask marks which pixels are padding (True) vs. real (False).
    DETR does not pad internally, so this has to happen here.

    Args:
        batch: sequence of (image Tensor[C, H, W], target dict) pairs.

    Returns:
        ((images Tensor[B, C, H_max, W_max], masks Tensor[B, H_max, W_max]), list of targets)
    """
    imgs, tgts = zip(*batch)

    heights = [im.shape[1] for im in imgs]
    widths = [im.shape[2] for im in imgs]
    n_imgs = len(imgs)
    n_channels = imgs[0].shape[0]
    height_max, width_max = max(heights), max(widths)

    # Padded canvas inherits dtype/device from the first image
    padded = imgs[0].new_zeros((n_imgs, n_channels, height_max, width_max))
    # Start with "everything is padding", then clear the region covered by each real image
    pad_mask = torch.ones((n_imgs, height_max, width_max), dtype=torch.bool, device=padded.device)

    for idx, (im, h, w) in enumerate(zip(imgs, heights, widths)):
        padded[idx, :, :h, :w] = im
        pad_mask[idx, :h, :w] = False

    return (padded, pad_mask), list(tgts)

# Shared DataLoader settings. Pinned memory only speeds up host-to-GPU copies, so it is enabled
# only when CUDA is available (it is useless on CPU-only runs).
loader_kwargs = dict(collate_fn=collate_fn, num_workers=2, pin_memory=torch.cuda.is_available())

train_loader = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, batch_size=test_val_batch_size, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_ds, batch_size=test_val_batch_size, shuffle=False, **loader_kwargs)
    

# DETR

## Load and adjust the pre-trained DETR model

In [None]:
#################################################################################################
'''
    - DETR uses set prediction with Hungarian matching
    - The model explicitly predicts "no object"
    - DETR uses a standard ResNet-50 backbone. model.backbone is a Joiner: backbone[0] is the CNN,
      backbone[1] the positional encoding, so the first conv layer is model.backbone[0].body.conv1
'''
#################################################################################################

# NOTE(review): the notebook title says "Deformable DETR", but this section builds the original DETR
# (DETR class, SetCriterion with eos_coef, detr-r50 COCO weights) - confirm this is intended.

# Lesion categories. num_classes is derived from the list so the two can never disagree.
# NOTE(review): presumably label i produced by DeepLesionDataset corresponds to class_names[i] - confirm.
class_names = ["bone", "abdomen", "mediastinum", "liver", "lung", "kidney", "soft_tissue", "pelvis"]
num_classes = len(class_names)

# Use the GPU when available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def _make_detr_args(num_queries, num_classes, dilation):
    """
    Rebuild the argparse-style namespace consumed by _build_detr (and by the build_backbone /
    build_transformer helpers it calls).

    Args:
        num_queries: number of DETR object queries.
        num_classes: number of lesion classes.
        dilation: if True, the last ResNet stage uses dilation instead of stride (DC5).

    Returns:
        types.SimpleNamespace holding the training, model, loss and misc settings.
    """
    training = dict(
        lr=1e-4,
        lr_backbone=1e-5,
        batch_size=2,
        weight_decay=1e-4,
        epochs=100,
        lr_drop=80,
        clip_max_norm=0.1,
    )
    architecture = dict(
        backbone="resnet50",
        dilation=dilation,
        position_embedding="sine",
        enc_layers=6,
        dec_layers=6,
        dim_feedforward=2048,
        hidden_dim=256,
        dropout=0.1,
        nheads=8,
        num_queries=num_queries,
        num_classes=num_classes,
        pre_norm=False,
    )
    # Hungarian matching costs (set_cost_*) and loss weights (*_loss_coef, eos_coef)
    losses = dict(
        aux_loss=True,
        set_cost_class=1.0,
        set_cost_bbox=5.0,
        set_cost_giou=2.0,
        class_loss_coef=1.0,
        bbox_loss_coef=5.0,
        giou_loss_coef=2.0,
        eos_coef=0.1,
    )
    misc = dict(
        masks=False,
        device=device,
        dataset_file="coco",
        coco_path=None,
        remove_difficult=False,
    )
    return SimpleNamespace(**training, **architecture, **losses, **misc)

def _build_detr(args):
    """
    Build a DETR model and its SetCriterion from an args namespace (see _make_detr_args).

    Args:
        args: namespace with the backbone/transformer settings, num_classes, num_queries, aux_loss,
            Hungarian matching costs (set_cost_*), loss coefficients (*_loss_coef, eos_coef) and device.

    Returns:
        (model, criterion), both moved to args.device.
    """
    backbone = build_backbone(args)
    transformer = build_transformer(args)

    model = DETR(
        backbone=backbone,
        transformer=transformer,
        num_classes=args.num_classes,
        num_queries=args.num_queries,
        aux_loss=args.aux_loss,
    )

    matcher = HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou)

    weight_dict = {
        "loss_ce": args.class_loss_coef,
        "loss_bbox": args.bbox_loss_coef,
        "loss_giou": args.giou_loss_coef,
    }

    # Auxiliary decoding losses: one suffixed copy of the base weights per intermediate decoder layer
    # (loss_ce_0 ... loss_ce_{L-2}, likewise for bbox/giou). The copies are collected in a separate dict
    # and merged once, so already-suffixed keys are never suffixed again (e.g. no "loss_ce_0_1").
    if args.aux_loss:
        aux_weight_dict = {}
        for i in range(transformer.decoder.num_layers - 1):
            aux_weight_dict.update({f"{k}_{i}": v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)

    # NOTE(review): in the reference DETR SetCriterion, class index `num_classes` is the "no object" class,
    # so target labels must lie in [0, num_classes) - confirm DeepLesionDataset labels are 0-based.
    criterion = SetCriterion(
        num_classes=args.num_classes,
        matcher=matcher,
        weight_dict=weight_dict,
        eos_coef=args.eos_coef,
        losses=["labels", "boxes"],
    )

    model.to(args.device)
    criterion.to(args.device)

    return model, criterion

def construct_detr_model(num_queries, dilation=True):
    """
    Construct a DETR model and its criterion initialized from the official COCO weights,
    with a custom number of object queries and a 1-channel input stem for CT slices.

    Args:
        num_queries: number of object queries. query_embed depends on it and class_embed on
            num_classes, so both are trained from scratch instead of being loaded.
        dilation: if True, the last ResNet stage replaces stride with dilation (DC5) and the
            matching DETR-DC5 COCO checkpoint is loaded; otherwise the plain R50 checkpoint.

    Returns:
        (model, criterion), both moved to the global `device`.
    """
    # Build args + model (uses the global num_classes)
    args = _make_detr_args(num_queries=num_queries, num_classes=num_classes, dilation=dilation)
    model, criterion = _build_detr(args)

    # Pretrained COCO weights matching the backbone variant (DC5 weights were trained with the dilated stage)
    checkpoint_url = (
        "https://dl.fbaipublicfiles.com/detr/detr-r50-dc5-f0fb7ef5.pth"
        if dilation
        else "https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth"
    )
    checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)
    state_dict = checkpoint["model"]

    # Remove COCO-specific parameters whose shapes depend on num_queries / num_classes
    coco_specific_keys = ("query_embed.weight", "class_embed.weight", "class_embed.bias")
    for key in coco_specific_keys:
        state_dict.pop(key, None)

    # strict=False is needed for the removed heads; report anything else that did not match
    load_result = model.load_state_dict(state_dict, strict=False)
    unexpected_missing = [k for k in load_result.missing_keys if k not in coco_specific_keys]
    if unexpected_missing or load_result.unexpected_keys:
        print(
            "[WARNING] Pretrained DETR weights were only partially loaded. "
            f"Missing: {unexpected_missing}, unexpected: {load_result.unexpected_keys}"
        )

    # Adapt DETR to 1-channel CT images
    # - model.backbone is a Joiner where:
    #   - backbone[0] -> CNN backbone (ResNet50)
    #   - backbone[1] -> positional encoding model
    old_conv = model.backbone[0].body.conv1  # shape: [64, 3, 7, 7] == [out_channels, in_channels, kH, kW]

    new_conv = nn.Conv2d(
        in_channels=1,
        out_channels=old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        dilation=old_conv.dilation,
        bias=False,
    )

    # Sum (not average) the RGB filters: a grayscale slice fed to the RGB stem as three identical channels
    # produces sum_c W_c * x, so summing keeps first-layer activations at the scale the pretrained stem (and the
    # normalization layers after it) were trained with; averaging would shrink them by a factor of 3.
    with torch.no_grad():
        new_conv.weight.copy_(old_conv.weight.sum(dim=1, keepdim=True))

    # NOTE(review): new_conv is a fresh, trainable nn.Conv2d; the DETR backbone builder may have frozen the
    # original conv1 - confirm that training the adapted stem is intended.
    model.backbone[0].body.conv1 = new_conv

    # [OPTIONAL] Freeze 2 first backbone layers
    # for name, param in model.backbone[0].named_parameters():
    #     if "layer1" in name or "layer2" in name:
    #         param.requires_grad = False

    # Move again: the replacement conv was created on the CPU
    model.to(device)
    criterion.to(device)

    # Sanity check
    # print(f"First conv layer shape: {model.backbone[0].body.conv1.weight.shape}") # Should be [64, 1, 7, 7]

    return model, criterion

## Training Loop 

In [None]:
"""
- DETR is usually trained with AdamW (Adam with decoupled weight decay), which is what train_one_config uses,
  with a 10x lower learning rate for the pretrained backbone.
- Hyperparameters:
    - betas (AdamW, PyTorch default (0.9, 0.999)):
        beta1 keeps an exponential moving average of past gradients (a momentum-like term), which gives smoother
        updates, less zig-zagging and faster convergence. Rarely needs tuning.
    - weight_decay (decoupled L2-style regularization in AdamW):
        Penalizes large weights to reduce overfitting (shrinks params each step).
    - milestones (in MultiStepLR):
        Epochs at which the LR scheduler triggers a decay.
    - gamma (in MultiStepLR):
        Multiplicative LR factor at each milestone: new_lr = old_lr * gamma. Commonly set to 0.1.
    - warmup_epochs:
        Number of epochs over which the LR is linearly ramped up (per iteration) to its target value.

- Hyperparameters are selected empirically on the validation split (small grid), due to computation power constraints.

- optimizer.zero_grad(set_to_none=True) - set_to_none=True means that each param.grad is set to None
  (no tensor is kept). On the next backward() PyTorch allocates a fresh grad tensor and writes into it. This causes
  less memory traffic by avoiding writing zeros over large grad buffers every step, and lowers the memory footprint by
  letting unused grads be garbage-collected and reallocated only when needed.

"""

def _detr_param_groups(model, learning_rate, backbone_lr_factor=0.1):
    """
    Split the trainable parameters of a DETR model into two AdamW param groups: the backbone
    ("backbone.*", LR scaled by backbone_lr_factor) and everything else (transformer + heads).

    Each group also stores "initial_lr", the target LR that the linear warmup ramps up to.
    """
    backbone_lr = learning_rate * backbone_lr_factor
    backbone_params, head_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not optimized
        if name.startswith("backbone."):
            backbone_params.append(param)
        else:
            head_params.append(param)
    return [
        {"params": backbone_params, "lr": backbone_lr, "initial_lr": backbone_lr},
        {"params": head_params, "lr": learning_rate, "initial_lr": learning_rate},
    ]


def _apply_linear_warmup(optimizer, global_step, warmup_iters):
    """Linearly ramp each param group's LR towards its initial_lr; return warmup progress in percent."""
    if global_step > warmup_iters:
        return 100.0  # warmup finished (or disabled): the LR scheduler manages the LR
    warmup_factor = global_step / float(warmup_iters)
    for param_group in optimizer.param_groups:
        param_group["lr"] = param_group["initial_lr"] * warmup_factor
    return 100.0 * warmup_factor


def train_one_config(
    train_loader, val_loader, device,
    learning_rate=1e-4, weight_decay=1e-4,
    max_epochs=100, patience=30, metric_key="mAP50_95",
    gamma=0.1, warmup_epochs=1, num_queries=50, dilation=True, step_size=40,
    lr_milestones=(70,), clip_max_norm=0.1,
):
    """
    Train a single DETR hyperparameter configuration and keep the best validation checkpoint.

    Schedule: linear LR warmup over the first `warmup_epochs` epochs (updated every iteration), then
    MultiStepLR decay by `gamma` at every epoch listed in `lr_milestones`. Gradients are clipped to
    `clip_max_norm` (the reference DETR args in _make_detr_args also set clip_max_norm=0.1).
    Training stops early after `patience` epochs without improvement of `metric_key`.

    Args:
        train_loader, val_loader: DataLoaders yielding ((images, masks), targets) batches (see collate_fn).
        device: torch device used for training.
        learning_rate: LR of the transformer/heads; the backbone uses 0.1 * learning_rate.
        weight_decay: AdamW weight decay.
        max_epochs: maximum number of training epochs.
        patience: early-stopping patience in epochs.
        metric_key: validation metric (key of evaluate_detector's result) to maximize.
        gamma: multiplicative LR decay applied at each milestone.
        warmup_epochs: number of epochs of linear LR warmup (0 disables warmup).
        num_queries, dilation: forwarded to construct_detr_model.
        step_size: unused (belongs to the StepLR alternative); kept so existing calls keep working.
        lr_milestones: epochs at which the LR is multiplied by gamma.
        clip_max_norm: maximum global gradient norm; None or 0 disables clipping.

    Returns:
        dict with best_metric, best_epoch, best_model (state_dict copied to CPU), best_optimizer,
        best_scheduler, per-epoch history, lr, weight_decay, lr_milestones and clip_max_norm.

    Raises:
        FloatingPointError: if the training loss becomes NaN or infinite.
    """
    import math

    # 1) Construct the model
    model, criterion = construct_detr_model(num_queries=num_queries, dilation=dilation)

    # 2) AdamW with two param groups (lower LR for the pretrained backbone)
    optimizer = torch.optim.AdamW(_detr_param_groups(model, learning_rate), weight_decay=weight_decay)

    # 3) LR decay at the milestone epochs (StepLR / CosineAnnealingLR are possible alternatives)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=list(lr_milestones), gamma=gamma)

    weight_dict = criterion.weight_dict  # constant during training
    warmup_iters = warmup_epochs * len(train_loader)

    best_metric = -float("inf")
    best_epoch = -1
    best_model = None
    best_optimizer = None
    best_scheduler = None
    epochs_no_improve = 0
    history = []
    global_step = 0

    # 4) Train model
    for epoch in range(1, max_epochs + 1):
        print(f"\n*** Epoch [{epoch}/{max_epochs}] started ***")
        model.train()
        criterion.train()
        running_loss = 0.0
        n_processed_images = 0

        # -- Training loop
        for batch_idx, batch in enumerate(train_loader, start=1):
            global_step += 1
            warmup_percent = _apply_linear_warmup(optimizer, global_step, warmup_iters)

            # ---- Unpack DETR batch: ((images, padding masks), list of target dicts) from collate_fn
            # NOTE(review): DETR's matcher/criterion expect target "boxes" as normalized (cx, cy, w, h) in [0, 1];
            # the notebook header describes [x1, y1, x2, y2] boxes - confirm DeepLesionDataset converts them.
            (images, masks), targets = batch
            images = images.to(device)
            masks = masks.to(device)
            targets = [{key: val.to(device) for key, val in target.items()} for target in targets]

            # ---- Forward + loss (facebook DETR expects a NestedTensor)
            outputs = model(NestedTensor(images, masks))
            loss_dict = criterion(outputs, targets)
            loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)

            loss_value = float(loss.item())
            if not math.isfinite(loss_value):
                raise FloatingPointError(
                    f"Non-finite training loss ({loss_value}) at epoch {epoch}, batch {batch_idx}."
                )

            optimizer.zero_grad(set_to_none=True)  # Allocates grads fresh during backward pass
            loss.backward()
            if clip_max_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
            optimizer.step()

            running_loss += loss_value
            n_processed_images += images.size(0)

            if batch_idx % 10 == 0 or batch_idx == len(train_loader):
                avg_loss_so_far = running_loss / batch_idx
                status = (
                    f"\rEpoch: [{epoch}/{max_epochs}], "
                    f"Images: [{n_processed_images}/{len(train_loader.dataset)}], "
                    f"Warmup: [{float(warmup_percent):.2f}%], "
                    f"Loss: {avg_loss_so_far:.4f}"
                )
                print(status, end="", flush=True)

        print()
        lr_scheduler.step()
        train_loss = running_loss / max(1, len(train_loader))
        print(f"*** Epoch [{epoch}/{max_epochs}] finished -> Loss: {train_loss:.4f} ***")

        # Validation
        print("*** Validation started ***")
        val_metrics = evaluate_detector(model, val_loader, device, num_classes, class_names)
        val_score = float(val_metrics[metric_key])

        history.append({"epoch": epoch, "train_loss": train_loss, **val_metrics})
        print(f"Loss={train_loss:.4f}, mAP50={val_metrics['mAP50']:.4f}, mAP50_95={val_metrics['mAP50_95']:.4f}")
        print("*** Validation finished ***")

        # Early stopping check
        if val_score > best_metric + 1e-6:
            best_metric = val_score
            best_epoch = epoch
            # Keep the best weights on the CPU so they do not occupy a second copy of GPU memory
            best_model = {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}
            best_optimizer = copy.deepcopy(optimizer.state_dict())
            best_scheduler = copy.deepcopy(lr_scheduler.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"*** Early stopping after {epoch}/{max_epochs} epochs (best at {best_epoch} with {metric_key}={best_metric:.4f}). ***")
                break

    return {
        "best_metric": best_metric,
        "best_epoch": best_epoch,
        "best_model": best_model,
        "best_optimizer": best_optimizer,
        "best_scheduler": best_scheduler,
        "history": history,
        "lr": learning_rate,
        "weight_decay": weight_decay,
        "lr_milestones": list(lr_milestones),
        "clip_max_norm": clip_max_norm,
    }

# =================================================================================================================================================
# =================================================================================================================================================

def plot_training_history(result, title=None):
    """
    Plot one training configuration: train loss per epoch on the left axis and
    validation mAP50 / mAP50-95 on a twin right axis.

    Args:
        result: dictionary returned by train_one_config (uses "history", "lr", "weight_decay").
        title: optional figure title; defaults to "lr=<lr>, wd=<weight_decay>".
    """
    records = result["history"]
    epochs, train_losses, map50, map50_95 = (
        [record[key] for record in records]
        for key in ("epoch", "train_loss", "mAP50", "mAP50_95")
    )

    if title is None:
        title = f"lr={result['lr']}, wd={result['weight_decay']}"

    fig, loss_ax = plt.subplots(figsize=(7, 4))

    # Left axis: training loss
    loss_ax.plot(epochs, train_losses, color="blue", marker="o", label="Train loss")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.grid(True, which="both", axis="both", linestyle="--", alpha=0.3)

    # Right axis: validation mAP curves
    map_ax = loss_ax.twinx()
    for values, color, marker, style, label in (
        (map50, "green", "x", "-", "mAP50"),
        (map50_95, "red", "s", ":", "mAP50-95"),
    ):
        map_ax.plot(epochs, values, color=color, marker=marker, linestyle=style, label=label)
    map_ax.set_ylabel("mAP")

    # One legend for both axes
    handles, labels = [], []
    for ax in (loss_ax, map_ax):
        ax_handles, ax_labels = ax.get_legend_handles_labels()
        handles += ax_handles
        labels += ax_labels
    loss_ax.legend(handles, labels, loc="best")

    plt.title(title)
    plt.tight_layout()
    plt.show()

# =================================================================================================================================================
# =================================================================================================================================================

# Set up hyperparameters
learning_rates = [1e-4]  # 1e-4 for DETR
weight_decays = [1e-4]   # candidates: [1e-4, 5e-4]
max_epochs = 100
num_queries = 50
dilation = True
warmup_epochs = 1
patience = max_epochs + 1  # larger than max_epochs -> early stopping is effectively disabled
best_result = None

print("*** Training started ***")
for learning_rate in learning_rates:
    for weight_decay in weight_decays:
        print(f"[HYPERPARAMETERS]:\n"
              f"    learning_rate: {learning_rate}\n"
              f"    weight_decay: {weight_decay}\n"
              f"    num_queries: {num_queries}\n"
              f"    dilation: {dilation}\n")
        result = train_one_config(
            train_loader=train_loader, val_loader=val_loader, device=device,
            learning_rate=learning_rate, weight_decay=weight_decay, max_epochs=max_epochs, patience=patience,
            metric_key="mAP50_95", warmup_epochs=warmup_epochs, num_queries=num_queries, dilation=dilation
        )

        plot_training_history(result)

        if best_result is None or result["best_metric"] > best_result["best_metric"]:
            best_result = result

if best_result is None:
    raise RuntimeError("No configuration was trained: learning_rates and weight_decays must be non-empty.")

# The [BEST] title must describe best_result, not the last configuration trained
plot_training_history(best_result, title=f"[BEST] lr={best_result['lr']}, wd={best_result['weight_decay']}")
print("*** Training complete ***")
print(f"Best config: lr={best_result['lr']} wd={best_result['weight_decay']} "
      f"epoch={best_result['best_epoch']} mAP50_95={best_result['best_metric']:.4f}")

# Save the best model checkpoint
best_checkpoint = {
    "state_dict": best_result["best_model"],
    "best_optimizer": best_result["best_optimizer"],
    "best_scheduler": best_result["best_scheduler"],
    "epoch": best_result["best_epoch"],
    "metric_key": "mAP50_95",
    "metric_value": best_result["best_metric"],
    # Optimization settings used by train_one_config: AdamW (backbone LR = 0.1 * lr) + linear warmup + MultiStepLR.
    # AdamW has no SGD momentum: "momentum" records AdamW's default beta1, and "step_size" is None because
    # no StepLR schedule is used.
    "hp": {
        "optimizer": "AdamW",
        "lr": best_result["lr"],
        "backbone_lr": best_result["lr"] * 0.1,
        "weight_decay": best_result["weight_decay"],
        "momentum": 0.9,
        "step_size": None,
        "lr_scheduler": "MultiStepLR",
        "lr_milestones": list(best_result.get("lr_milestones", [70])),
        "gamma": 0.1,
        "warmup_epochs": warmup_epochs,
        "clip_max_norm": best_result.get("clip_max_norm"),
        "max_epochs": max_epochs,
        "patience": patience,
    },
    # Describe how to reconstruct the model
    "model_spec": {
        "arch": "detr_resnet50",
        "min_size": 512,
        "max_size": 512,
        "in_channels": 1,
        "num_classes": num_classes,
        "num_queries": num_queries,
        "dilation": dilation,
        "image_mean": [0.5],
        "image_std": [0.5],
    },
    "class_names": class_names,
    "versions": {"torch": torch.__version__, "torchvision": torchvision.__version__},
}

# Save checkpoint locally
os.makedirs("checkpoints", exist_ok=True)
ts = datetime.now().strftime("%Y%m%d-%H%M%S")
save_path = f"checkpoints/detr_best_{ts}.pt"
torch.save(best_checkpoint, save_path)
print(f"[{ts}] Saved best checkpoint to {save_path}")

# [GOOGLE COLAB ONLY - uncomment] Save checkpoint on Google Drive.
# /content/deformable_detr_checkpoints is the Drive symlink created in the setup cells; a relative path
# would resolve inside the current working directory (the cloned repo) instead of Drive.
# gdrive_save_dir = Path("/content/deformable_detr_checkpoints")
# gdrive_save_dir.mkdir(parents=True, exist_ok=True)
# gdrive_save_path = os.path.join(str(gdrive_save_dir), f"detr_best_{ts}.pt")
# torch.save(best_checkpoint, gdrive_save_path)

## Final Evaluation

In [None]:
# Rebuild the architecture described by the saved model_spec, then restore the best weights.
# weights_only=False is acceptable here: the file was written by this notebook (trusted source).
best_checkpoint = torch.load(save_path, map_location="cpu", weights_only=False)
best_model, _ = construct_detr_model(
    num_queries=best_checkpoint["model_spec"]["num_queries"],
    dilation=best_checkpoint["model_spec"]["dilation"],
)
best_model.load_state_dict(best_checkpoint["state_dict"])

# Test-set evaluation of the best model: raw metrics, per-class report and FROC summary
print("*** Evaluation started ***")
test_metrics = evaluate_detector(best_model, test_loader, device, num_classes, class_names)
print("*** Evaluation finished ***")
print(test_metrics)
print_result_report(test_metrics, test_loader, class_names)
print_froc_curve_info(best_model, test_loader, device)

# Disconnect Google Colab runtime

In [None]:
# Release the Colab GPU runtime once the run is complete. Outside Colab there is nothing to
# disconnect, so "Run All" locally finishes cleanly instead of ending with an ImportError.
try:
    from google.colab import runtime
except ImportError:
    print("Not running on Google Colab - no runtime to disconnect.")
else:
    runtime.unassign()