In [1]:
import torch
from torch.utils.data import DataLoader

from segmentation import DeepLabV3Plus, IDRiDSegmentation, get_transforms

In [2]:
DEVICE = "cuda:5"
cfg = {
    "batch_size": 2,
    "size": [1424, 2144],
    "crop_size": [1024, 1024],
    "norm_mean": [0.485, 0.456, 0.406],
    "norm_std": [0.229, 0.224, 0.225],
    "backbone": "resnet50",
    "num_classes": 5,
    "output_stride": 8,
}

In [3]:
def compute_batch_metrics(preds, targets, num_classes):
    """
    Compute metrics for a batch of predictions and targets to later compute metrics for the whole dataset.
    Args:
        - preds (torch.Tensor): The predictions of shape (N, num_classes, H, W)
        - targets (torch.Tensor): The targets of shape (N, num_classes, H, W)
        - num_classes (int): The number of classes
    Returns:
        - intersection (torch.Tensor): The intersection of the predictions and targets of shape (num_classes,)
        - union (torch.Tensor): The union of the predictions and targets of shape (num_classes,)
        - pred_cardinality (torch.Tensor): The cardinality of the predictions of shape (num_classes,)
        - target_cardinality (torch.Tensor): The cardinality of the targets of shape (num_classes,)
    """
    intersection = torch.zeros(num_classes, device=preds.device, dtype=torch.float)
    union = torch.zeros(num_classes, device=preds.device, dtype=torch.float)
    pred_cardinality = torch.zeros(num_classes, device=preds.device, dtype=torch.float)
    target_cardinality = torch.zeros(num_classes, device=preds.device, dtype=torch.float)

    for cls in range(num_classes):
        pred_mask = preds[:, cls, :, :]
        target_mask = targets[:, cls, :, :]

        inter = (pred_mask & target_mask).sum() # the intersection is same as true positive or class correct
        uni = (pred_mask | target_mask).sum()

        intersection[cls] = inter
        union[cls] = uni
        pred_cardinality[cls] = pred_mask.sum()
        target_cardinality[cls] = target_mask.sum() # the target cardinality is same as class total

    return intersection, union, pred_cardinality, target_cardinality


In [4]:
@torch.no_grad()
def evaluate(model, data_loader, num_classes, device):
    total_samples = 0
    total_intersection = torch.zeros(num_classes, device=device)
    total_union = torch.zeros(num_classes, device=device)
    total_pred_cardinality = torch.zeros(num_classes, device=device)
    total_target_cardinality = torch.zeros(num_classes, device=device)
    model.eval()
    with torch.no_grad():
        for image, mask in data_loader:
            image = image.to(device)
            mask = mask.to(device)

            with torch.autocast(device_type='cuda', dtype=torch.float16):
                pred_logits = model(image)
            
            total_samples += mask.shape[0]

            pred_prob = torch.sigmoid(pred_logits)
            pred_mask = torch.zeros_like(pred_prob).long()
            pred_mask[pred_prob > 0.50] = 1
            intersection, union, pred_cardinality, target_cardinality = compute_batch_metrics(pred_mask, mask, num_classes)

            total_intersection += intersection
            total_union += union
            total_pred_cardinality += pred_cardinality
            total_target_cardinality += target_cardinality

    iou_per_class = total_intersection / (total_union + 1e-6)
    dice_per_class = 2 * total_intersection / (total_pred_cardinality + total_target_cardinality + 1e-6)
    pixel_accuracy_per_class = total_intersection / (total_target_cardinality + 1e-6)

    metrics = {
        "iou_per_class": iou_per_class.tolist(),
        "dice_per_class": dice_per_class.tolist(),
        "pixel_accuracy_per_class": pixel_accuracy_per_class.tolist(),
    }

    return metrics

In [5]:
transform_train, transform_val_test = get_transforms(cfg["size"], cfg["crop_size"], cfg["norm_mean"], cfg["norm_std"])
test_set = IDRiDSegmentation(root="./data/A. Segmentation/", mode="eval", transform=transform_val_test)
test_loader = DataLoader(test_set, batch_size=cfg["batch_size"], shuffle=False, num_workers=2, prefetch_factor=10)

In [6]:
model = DeepLabV3Plus(
    backbone=cfg["backbone"],
    num_classes=cfg["num_classes"],
    output_stride=cfg["output_stride"],
)
checkpoint = torch.load("./saved/dlv3p_os_8_100.pth", map_location="cpu", weights_only=True)
model.load_state_dict(checkpoint)
model.to(DEVICE)
model.eval()

DeepLabV3Plus(
  (resnet): ResNet(
    (layer_0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (layer_1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (dow

In [9]:
def print_metrics(metrics, mode):
    tasks = ["Microaneurysms", "Haemorrhages", "Hard Exudates", "Soft Exudates", "Optic Disc"]
    print(f"{mode} Metrics:")
    for i in range(5):
        print(f"\t{tasks[i]} IoU: {metrics['iou_per_class'][i]:.04f},\t{tasks[i]} Dice Coefficient: {metrics['dice_per_class'][i]:.04f},\t{tasks[i]} Pixel Accuracy: {metrics['pixel_accuracy_per_class'][i]:.04f}")


In [10]:
test_metrics = evaluate(model, test_loader, 5, DEVICE)
print_metrics(test_metrics, "Test")

Test Metrics:
	Microaneurysms IoU: 0.2671,	Microaneurysms Dice Coefficient: 0.4216,	Microaneurysms Pixel Accuracy: 0.3860
	Haemorrhages IoU: 0.3261,	Haemorrhages Dice Coefficient: 0.4918,	Haemorrhages Pixel Accuracy: 0.4115
	Hard Exudates IoU: 0.1706,	Hard Exudates Dice Coefficient: 0.2915,	Hard Exudates Pixel Accuracy: 0.1785
	Soft Exudates IoU: 0.4459,	Soft Exudates Dice Coefficient: 0.6168,	Soft Exudates Pixel Accuracy: 0.6365
	Optic Disc IoU: 0.9132,	Optic Disc Dice Coefficient: 0.9546,	Optic Disc Pixel Accuracy: 0.9767
