In [None]:
from torch.utils.data import Dataset
from PIL import Image
import torchvision
import torch
import pandas as pd
import ast
import numpy as np

import os
from torch import optim, nn, utils, Tensor
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from torchvision.ops import nms, box_iou
from skimage import io
from skimage.color import gray2rgb
from torchvision.ops import nms 
import seaborn as sns


import lightning as L
from retinanet import retinanet_resnet50_fpn

from model import RetinaDataset, RetinaNet, collate

from faster_rcnn import RCNNDataset, FasterRCNNModel
from faster_rcnn import collate as collate_rcnn

from results import evaluate_model_retina

In [None]:
# Shared compute device for the plotting helpers below: the GPU when PyTorch
# can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def display_predictions_and_ground_truth_grid(model, dataloader, num_images=30, grid_size=(5, 6), figsize=(20, 20), iou_threshold=0.15, score_threshold=0.2):
    """
    Draw ground-truth and predicted boxes for a few images in a grid of subplots.

    Images are taken from the END of ``dataloader`` (last batch first). For each
    image, predictions with a score not above ``score_threshold`` are dropped and
    the survivors are passed through NMS (applied across all labels at once).
    Ground-truth boxes are drawn in green with their label, predictions in red
    with ``label:score``, where the score is the value returned by the model.
    The figure is saved to ``sample_preds.png`` and then shown.

    Args:
        model: Detector whose ``model(images)`` returns a tuple; element ``[1]``
            is read as a list with one dict per image holding ``'bbox'``,
            ``'labels'`` and ``'scores'`` tensors.
        dataloader: Iterable of ``(images, targets)`` batches, where ``targets``
            is a dict with per-image ``'bbox'`` and ``'labels'`` tensors.
            Boxes are read as ``(x_min, y_min, x_max, y_max)``.
        num_images: Maximum number of images to draw.
        grid_size: ``(rows, cols)`` of the subplot grid; at most ``rows * cols``
            images are drawn regardless of ``num_images``.
        figsize: Figure size passed to ``plt.subplots``.
        iou_threshold: IoU threshold used by NMS.
        score_threshold: Predictions must score strictly above this to be drawn.
    """
    from collections import deque

    n_rows, n_cols = grid_size
    # Never try to draw more images than there are grid cells.
    max_images = max(0, min(num_images, n_rows * n_cols))

    model.eval()
    model.to(device)  # `device` is the notebook-level torch.device

    # squeeze=False guarantees a 2-D array of axes, so ravel() also works for
    # 1x1 and single-row/column grids.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    # Only the last batches are ever displayed. Every batch holds at least one
    # image, so keeping the final `max_images` batches is always enough and
    # avoids holding the whole dataset in memory.
    tail_batches = deque(dataloader, maxlen=max_images)

    image_count = 0
    with torch.no_grad():
        for images, targets in reversed(tail_batches):
            if image_count >= max_images:
                break

            # model(...) returns a tuple; element 1 is the per-image detection list.
            detections = model(images.to(device))[1]

            for j in range(len(images)):
                if image_count >= max_images:
                    break
                axis = axes[image_count]

                gt_boxes = targets['bbox'][j].cpu().numpy()
                gt_labels = targets['labels'][j].cpu().numpy()

                pred_boxes = detections[j]['bbox'].cpu()
                pred_labels = detections[j]['labels'].cpu()
                pred_scores = detections[j]['scores'].cpu()

                # Drop low-confidence predictions before NMS.
                confident = pred_scores > score_threshold
                pred_boxes = pred_boxes[confident]
                pred_labels = pred_labels[confident]
                pred_scores = pred_scores[confident]

                # Suppress overlapping predictions.
                keep = nms(pred_boxes, pred_scores, iou_threshold)
                pred_boxes = pred_boxes[keep].numpy()
                pred_labels = pred_labels[keep].numpy()
                pred_scores = pred_scores[keep].numpy()

                # (C, H, W) -> (H, W, C); imshow rejects (H, W, 1), so drop a
                # singleton channel axis and let cmap='gray' render it.
                img = np.transpose(images[j].cpu().numpy(), [1, 2, 0])
                if img.shape[-1] == 1:
                    img = img[..., 0]
                axis.imshow(img, cmap='gray')

                # Ground truth in green.
                for (x_min, y_min, x_max, y_max), label in zip(gt_boxes, gt_labels):
                    axis.add_patch(patches.Rectangle(
                        (x_min, y_min), x_max - x_min, y_max - y_min,
                        linewidth=2, edgecolor='g', facecolor='none'))
                    axis.text(x_min, y_min, str(label), color='green', fontsize=10, backgroundcolor='white')

                # Predictions in red, annotated with the model's own confidence.
                # No rescaling is applied, so the figure reports the true score.
                for (x_min, y_min, x_max, y_max), label, score in zip(pred_boxes, pred_labels, pred_scores):
                    axis.add_patch(patches.Rectangle(
                        (x_min, y_min), x_max - x_min, y_max - y_min,
                        linewidth=2, edgecolor='r', facecolor='none'))
                    axis.text(x_min, y_min, f'{label}:{score:.2f}', color='red', fontsize=10, backgroundcolor='white')

                axis.axis('off')
                image_count += 1

    # Hide grid cells that received no image so they do not show empty frames.
    for unused_axis in axes[image_count:]:
        unused_axis.axis('off')

    plt.tight_layout()
    plt.savefig('sample_preds.png')
    plt.show()

Bounding-box area distribution per class, using the multi-class dataset (train, validation and test splits combined)

In [None]:
# Distribution of ground-truth bounding-box areas per class, pooled over the
# train / val / test splits of the multi-class dataset.
MULTI_CLASS_CSV_DIR = '/vol/biomedic3/bglocker/mscproj24/mrm123/retinanet/csv_files/multi_class'
df1, df2, df3 = (
    pd.read_csv(os.path.join(MULTI_CLASS_CSV_DIR, f'{split}.csv'))
    for split in ('train', 'val', 'test')
)

df = pd.concat([df1, df2, df3])

# NOTE(review): the original comment here said "VINDR GRAPH" while the figure is
# saved as 'EMBEDBoundingBoxAreas.png' — confirm which dataset these CSVs hold.
# The 'bbox' and 'label' columns hold Python-literal lists (one entry per box);
# boxes are read as (x_min, y_min, x_max, y_max) and labels as class ids 0-3.
areas = {0: [], 1: [], 2: [], 3: []}
for bbox_literal, label_literal in zip(df['bbox'], df['label']):
    row_boxes = ast.literal_eval(bbox_literal)
    row_labels = ast.literal_eval(label_literal)
    for box, label in zip(row_boxes, row_labels):
        areas[label].append((box[2] - box[0]) * (box[3] - box[1]))  # width * height

data = [areas[0], areas[1], areas[2], areas[3]]
labels = ['Masses', 'Calcifications', 'Asymmetries', 'Architectural Distortions']

fig, ax = plt.subplots(figsize=(10, 6))

# Kept because it sets the notebook-wide colour cycle used by later plots.
sns.set_palette("Set2")
# Both layers use the same log scale, so the box statistics (quartiles and
# whiskers) are computed in the space they are displayed in. Previously the
# boxes were computed on a linear axis that the swarmplot then switched to log.
sns.boxplot(data=data, palette="Set2", showfliers=False, log_scale=True, ax=ax)
sns.swarmplot(data=data, color=".25", size=1.8, log_scale=True, ax=ax)

ax.set_title('Bounding Box Areas by Class', fontsize=16)
ax.set_xlabel('Class', fontsize=14)
ax.set_ylabel('Bounding Box Area px²', fontsize=14)
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels, fontsize=12)
ax.set_yticks([1e2, 1e3, 1e4, 1e5])
ax.tick_params(axis='y', labelsize=12)

ax.grid(True, linestyle='--', alpha=0.7)
fig.savefig('EMBEDBoundingBoxAreas.png')
plt.show()


FROC Curve Evaluation

In [None]:
def evaluate_froc(models, title, iou_threshold=0.5, batch_size=32, device='cuda', collate_fn=None):
    """
    Plot lesion-level FROC curves (sensitivity vs. false positives per image).

    For every model, each test image is run through the detector and its
    predictions are matched to the ground-truth boxes greedily in descending
    score order. A ground-truth box can be claimed by at most one prediction,
    and a prediction is a true positive when its best still-unclaimed box has
    IoU >= ``iou_threshold``. Matching ignores class labels. All predictions of
    the test set are then ranked by score; sweeping down that ranking gives

        sensitivity = cumulative TP / total number of ground-truth boxes
        FPPI        = cumulative FP / number of test images

    Images with no predictions still contribute their ground-truth boxes (as
    misses), and images with no ground truth still contribute their false
    positives. One curve per model is drawn; the figure is saved to
    ``title + '.png'`` and shown.

    Args:
        models: Dict mapping a legend name to a ``(model, test_dataset)`` pair.
            ``model(images)`` must return a tuple whose element ``[1]`` is a
            list with one dict per image containing ``'bbox'`` and, for a
            meaningful curve, ``'scores'``.
        title: Plot title, also used as the output file stem.
        iou_threshold: Minimum IoU for a prediction to count as a hit.
        batch_size: Batch size of the evaluation DataLoader.
        device: Device string/object the models and images are moved to.
        collate_fn: Collate function for the DataLoader. Defaults to the
            notebook-level ``collate``; pass a dataset-specific one if needed.
            Batches must be ``(images, targets)`` with ``targets['bbox']``
            holding one box tensor per image.
    """
    import warnings

    def match_detections(pred_boxes, pred_scores, gt_boxes):
        """Return a bool array, aligned with ``pred_boxes``, marking true positives."""
        is_tp = np.zeros(len(pred_boxes), dtype=bool)
        if len(pred_boxes) == 0 or len(gt_boxes) == 0:
            return is_tp

        iou_matrix = box_iou(pred_boxes, gt_boxes)
        gt_claimed = torch.zeros(len(gt_boxes), dtype=torch.bool)

        # Highest-scoring predictions get first pick, and each ground-truth box
        # can be claimed once, so duplicate detections count as false positives.
        for pred_idx in torch.argsort(pred_scores, descending=True).tolist():
            available_ious = iou_matrix[pred_idx].masked_fill(gt_claimed, -1.0)
            best_iou, best_gt = available_ious.max(dim=0)
            if best_iou >= iou_threshold:
                is_tp[pred_idx] = True
                gt_claimed[best_gt] = True
        return is_tp

    def compute_froc(scores, is_tp, num_gt, num_images):
        """Sweep the score ranking and return ``(sensitivity, fppi)`` arrays."""
        ranking = np.argsort(-scores, kind='stable')
        hits = is_tp[ranking]
        cum_tp = np.cumsum(hits)
        cum_fp = np.cumsum(~hits)
        # Guard the degenerate cases instead of emitting NaNs.
        sensitivity = cum_tp / num_gt if num_gt else np.zeros(len(hits))
        fppi = cum_fp / num_images if num_images else np.zeros(len(hits))
        return sensitivity, fppi

    device = torch.device(device)
    if collate_fn is None:
        collate_fn = collate

    # Curves per model name: (sensitivity, fppi)
    froc_results = {}

    for model_name, (model, test_dataset) in models.items():
        model.eval()
        model.to(device)

        test_loader = utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

        score_chunks = []
        tp_chunks = []
        num_gt = 0
        scores_missing = False

        with torch.no_grad():
            for images, targets in test_loader:
                outputs = model(images.to(device))[1]

                for output, gt_boxes in zip(outputs, targets['bbox']):
                    # Matching runs on the CPU; per-image box counts are small.
                    gt_boxes = gt_boxes.detach().cpu().float()
                    if gt_boxes.numel() == 0:
                        gt_boxes = gt_boxes.reshape(0, 4)
                    # Every ground-truth box counts towards the denominator,
                    # even when the image has no predictions at all.
                    num_gt += len(gt_boxes)

                    if output is None or 'bbox' not in output or output['bbox'].numel() == 0:
                        continue

                    pred_boxes = output['bbox'].detach().cpu().float()
                    if 'scores' in output:
                        pred_scores = output['scores'].detach().cpu().float()
                    else:
                        # Without scores the ranking falls back to output order.
                        scores_missing = True
                        pred_scores = torch.ones(len(pred_boxes))

                    score_chunks.append(pred_scores.numpy())
                    tp_chunks.append(match_detections(pred_boxes, pred_scores, gt_boxes))

        if scores_missing:
            warnings.warn(
                f"{model_name}: detections carry no 'scores'; the FROC curve is "
                "ordered by output position rather than confidence."
            )

        all_scores = np.concatenate(score_chunks) if score_chunks else np.zeros(0)
        all_is_tp = np.concatenate(tp_chunks) if tp_chunks else np.zeros(0, dtype=bool)

        num_images = len(test_loader.dataset)
        froc_results[model_name] = compute_froc(all_scores, all_is_tp, num_gt, num_images)

    # Plot the results
    plt.figure()
    for model_name, (sensitivity, fppi) in froc_results.items():
        plt.plot(fppi, sensitivity, linewidth=1, linestyle='-', label=model_name)

    plt.xlabel('False Positives Per Image (FPPI)')
    plt.ylabel('Sensitivity')
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.savefig(title + '.png')
    plt.show()

Example: plotting FROC curves for the VinDr Faster R-CNN checkpoints

In [None]:
# Checkpoints of the runs being compared. The third (RetinaNet) checkpoint is
# only referenced by the disabled entry further down.
path_to_checkpoint_1 = '/vol/biomedic3/bglocker/mscproj24/mrm123/slurm_scripts/FasterRCNN Models MV0/nlvf848c/checkpoints/best_massfil_vindr_0ns_2.ckpt'
path_to_checkpoint_2 = '/vol/biomedic3/bglocker/mscproj24/mrm123/slurm_scripts/FasterRCNN Models MV1-1/aos75xcs/checkpoints/best_massfil_vindr_1:1_2.ckpt'
path_to_checkpoint_3 = '/vol/biomedic3/bglocker/mscproj24/mrm123/slurm_scripts/Retinanet Models M2-1/uj817edx/checkpoints/best_massfil_vindr_2:1_1.ckpt'

# Legend name -> checkpoint for the Faster R-CNN runs. No anchor ratio/scale
# overrides are passed when loading these.
vindr_checkpoints = {
    'VinDr 0NS': path_to_checkpoint_1,
    'VinDr 1:1': path_to_checkpoint_2,
}

# Each run gets its own dataset instance, all reading the same test split.
# NOTE(review): evaluate_froc builds its DataLoader with `collate` imported from
# `model` by default, while these are RCNNDataset instances and `collate_rcnn`
# is imported but never used — confirm the two collate functions are
# interchangeable for this dataset.
models_vindr = {
    run_name: (
        FasterRCNNModel.load_from_checkpoint(checkpoint_path),
        RCNNDataset(csv_file='csv_files/massfil_vindr_1:1/test.csv', augmentation=False),
    )
    for run_name, checkpoint_path in vindr_checkpoints.items()
}

# Disabled third curve (RetinaNet run with custom anchors):
# models_vindr['VinDr 2:1'] = (
#     RetinaNet.load_from_checkpoint(path_to_checkpoint_3, ratios=[1.0, 1.1927551241662488, 0.8383950567380818], scales=[0.6701667306842978, 0.430679744216531, 1.092957151337979]),
#     RetinaDataset(csv_file='csv_files/massfil_vindr_1:1/test.csv', augmentation=False),
# )


evaluate_froc(models_vindr, title='VinDr FROC Curves', iou_threshold=0.5, batch_size=8, device='cuda')