<a href="https://colab.research.google.com/github/rmartimarly/teaching_misc/blob/main/FROC_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

def compute_iou(box1, box2):
    """Return the intersection-over-union (IoU) of two axis-aligned boxes.

    Each box is a 4-sequence ``(x_min, y_min, x_max, y_max)``. The overlap
    along each axis is clamped at zero, so disjoint boxes yield 0. When the
    union area is not positive, 0 is returned instead of dividing by it.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    intersection = overlap_w * overlap_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = area_a + area_b - intersection

    if union > 0:
        return intersection / union
    return 0

def match_predictions_to_gt(gt_boxes, pred_boxes, iou_threshold=0.5):
    """Greedily match one image's predictions to its ground-truth boxes by IoU.

    Predictions are processed in the order given; callers presumably pass them
    sorted by descending confidence (TODO confirm — boxes carry no scores).
    Each prediction is assigned to the still-unmatched ground-truth box with
    the highest IoU, provided that IoU is at least ``iou_threshold``. Ties go
    to the lowest-index box. A ground-truth box can be matched at most once, so
    duplicate detections of the same box count as false positives.

    Args:
        gt_boxes: Sequence of ``(x1, y1, x2, y2)`` ground-truth boxes.
        pred_boxes: Sized sequence of ``(x1, y1, x2, y2)`` predicted boxes.
        iou_threshold: Minimum IoU for a prediction to count as a hit.

    Returns:
        ``(tp, fp)``: the number of predictions matched to a ground-truth box
        and the number of predictions left unmatched.
    """
    matched_gt = set()
    for pred in pred_boxes:
        best_idx = None
        best_iou = None
        for i, gt in enumerate(gt_boxes):
            if i in matched_gt:
                continue
            iou = compute_iou(pred, gt)
            # Keep the best-overlapping free GT, not merely the first one above
            # threshold. Otherwise this prediction could take a box that a later
            # prediction needs, turning that later prediction into a spurious FP.
            if iou >= iou_threshold and (best_idx is None or iou > best_iou):
                best_idx, best_iou = i, iou
        if best_idx is not None:
            matched_gt.add(best_idx)

    # Each successful match adds exactly one new GT index, so the set size is
    # the true-positive count.
    tp = len(matched_gt)
    return tp, len(pred_boxes) - tp

def compute_froc_curve(gt_list, pred_list, iou_threshold=0.5, max_detections=10):
    """Compute FROC curve points by sweeping the number of predictions kept per image.

    For each ``k`` in ``1..max_detections``, only the first ``k`` predictions
    of every image are matched against that image's ground truth. This
    prefix sweep stands in for lowering a confidence threshold, and assumes
    each image's predictions are ordered by descending confidence (TODO
    confirm with callers — boxes carry no scores).

    Args:
        gt_list: One sequence of ground-truth boxes per image.
        pred_list: One sliceable sequence of predicted boxes per image, in the
            same image order as ``gt_list``.
        iou_threshold: Minimum IoU for a prediction to count as a hit.
        max_detections: Largest number of per-image predictions considered.
            The curve has this many points; the default of 10 keeps the
            previous behaviour.

    Returns:
        ``(fp_rates, sensitivities)``: two lists of length ``max_detections``.
        They hold the mean false positives per image and the fraction of all
        ground-truth boxes detected at each ``k``. Both are 0 when there are no
        images or no ground-truth boxes respectively.

    Raises:
        ValueError: If ``gt_list`` and ``pred_list`` have different lengths.
    """
    num_images = len(gt_list)
    if num_images != len(pred_list):
        # zip() would silently drop the extra images while FP is still
        # averaged over len(gt_list), giving wrong numbers.
        raise ValueError(
            f"gt_list and pred_list must have one entry per image "
            f"(got {num_images} and {len(pred_list)})"
        )

    total_lesions = sum(len(gt) for gt in gt_list)

    fp_rates = []
    sensitivities = []
    for k in range(1, max_detections + 1):  # keep the top-k predictions per image
        tp_total = 0
        fp_total = 0
        for gt_boxes, pred_boxes in zip(gt_list, pred_list):
            tp, fp = match_predictions_to_gt(gt_boxes, pred_boxes[:k], iou_threshold)
            tp_total += tp
            fp_total += fp

        sensitivities.append(tp_total / total_lesions if total_lesions > 0 else 0)
        # Guard against an empty image list instead of dividing by zero.
        fp_rates.append(fp_total / num_images if num_images > 0 else 0)

    return fp_rates, sensitivities

def plot_froc(fp_rates, sensitivities):
    """Draw the FROC curve in a new 8x6-inch figure and display it.

    The x-axis shows false positives per image and the y-axis shows
    sensitivity; each point is drawn with a marker and connected by a line.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fp_rates, sensitivities, marker='o', linestyle='-')
    ax.set_xlabel('False Positives per Image')
    ax.set_ylabel('Sensitivity')
    ax.set_title('FROC Curve')
    ax.grid()
    plt.show()

# Example usage: two images with one ground-truth box each. Boxes are
# (x1, y1, x2, y2). Each image's predictions are listed in the order that
# compute_froc_curve consumes them (top-k prefixes).
gt_boxes = [
    [(30, 30, 60, 60)],  # image 0
    [(40, 40, 70, 70)],  # image 1
]
pred_boxes = [
    [(32, 32, 62, 62), (80, 80, 110, 110)],  # image 0: one overlapping box, one disjoint box
    [(42, 42, 72, 72)],                      # image 1: one overlapping box
]

fp_rates, sensitivities = compute_froc_curve(gt_list=gt_boxes, pred_list=pred_boxes)
plot_froc(fp_rates=fp_rates, sensitivities=sensitivities)
