In [1]:
import torch
from torch.utils.data import DataLoader

from codebase import MultiTaskModel, IDRiDSegmentation, IDRiDClassification, get_seg_transforms

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [2]:
# Preferred accelerator is GPU index 5. Fall back to the first GPU, or to the
# CPU, when that index does not exist so the notebook still runs on other hosts.
if torch.cuda.is_available():
    DEVICE = "cuda:5" if torch.cuda.device_count() > 5 else "cuda:0"
else:
    DEVICE = "cpu"

# Single source of truth for data, model and path settings used by later cells.
cfg = {
    # --- data ---
    "batch_size": 8,
    "size": [712, 1072],  # presumably (height, width) -- TODO confirm against get_seg_transforms
    "norm_mean": [0.485, 0.456, 0.406],  # standard ImageNet channel means
    "norm_std": [0.229, 0.224, 0.225],   # standard ImageNet channel stds

    # --- model ---
    "backbone": "resnet50",
    # Task name -> integer passed to MultiTaskModel (presumably the number of
    # output classes/channels per head -- verify against the model code).
    "task_specs": {
        "classification_1": 5,
        "classification_2": 3,
        "segmentation_MA": 1,
        "segmentation_HE": 1,
        "segmentation_EX": 1,
        "segmentation_SE": 1,
        "segmentation_OD": 1,
    },
    "num_experts": 3,

    # --- label files ---
    "train_path": "./data/B. Disease Grading/2. Groundtruths/a. IDRiD_Disease Grading_Training Labels.csv",
    "test_path": "./data/B. Disease Grading/2. Groundtruths/b. IDRiD_Disease Grading_Testing Labels.csv",
}

In [3]:
def compute_batch_metrics(outputs, targets, num_classes, device, threshold=0.5):
    """Collect raw classification and segmentation counts for a single batch.

    Only tasks whose key is present in ``targets`` contribute; every other
    entry of the result stays at zero.

    Args:
        outputs: Mapping of model outputs. For each task ``t`` found in
            ``targets`` the entry ``outputs[t + "_logits"]`` is read.
        targets: Mapping from task name to ground-truth tensor. Classification
            targets are compared against ``argmax(dim=1)`` of the matching
            output; segmentation targets are treated as binary masks where any
            non-zero value is foreground.
        num_classes: Length of the per-task count tensors. Slot ``i`` belongs
            to the ``i``-th segmentation task (MA, HE, EX, SE, OD); slots past
            the fifth task are left at zero.
        device: Device on which the count tensors are allocated.
        threshold: Cut-off applied to the segmentation outputs; values strictly
            greater than it are predicted foreground. Defaults to 0.5.

    Returns:
        dict with:
            ``classification_1_correct`` / ``classification_2_correct``: number
                of correct predictions as Python ints (0 when the task is absent),
            ``intersection``, ``union``, ``pred_cardinality``,
            ``target_cardinality``: float tensors of shape ``(num_classes,)``
                holding pixel counts per segmentation task.
    """
    seg_tasks = ("segmentation_MA", "segmentation_HE", "segmentation_EX", "segmentation_SE", "segmentation_OD")

    def _count_correct(task):
        # Number of samples whose arg-max prediction equals the label.
        if task not in targets:
            return 0
        predicted = torch.argmax(outputs[task + "_logits"], dim=1)
        return (predicted == targets[task]).sum().item()

    def _zeros():
        return torch.zeros(num_classes, device=device, dtype=torch.float)

    intersection, union = _zeros(), _zeros()
    pred_cardinality, target_cardinality = _zeros(), _zeros()

    # Slicing bounds the loop by the known task list, so num_classes larger
    # than the number of tasks no longer raises IndexError.
    for idx, task in enumerate(seg_tasks[:num_classes]):
        if task not in targets:
            continue

        # Boolean masks make the set operations valid for bool, integer and
        # float targets alike (bitwise ops are not defined for float tensors).
        gt_mask = targets[task].bool()

        # NOTE(review): the outputs are named "*_logits" yet are cut at 0.5 by
        # default. If they are raw pre-sigmoid scores the probability-0.5
        # boundary is 0.0 -- confirm against the model head.
        pred_mask = torch.zeros_like(gt_mask)
        pred_mask[outputs[task + "_logits"] > threshold] = True

        intersection[idx] = (pred_mask & gt_mask).sum()  # true positives
        union[idx] = (pred_mask | gt_mask).sum()
        pred_cardinality[idx] = pred_mask.sum()
        target_cardinality[idx] = gt_mask.sum()  # total foreground pixels

    return {
        "classification_1_correct": _count_correct("classification_1"),
        "classification_2_correct": _count_correct("classification_2"),
        "intersection": intersection,
        "union": union,
        "pred_cardinality": pred_cardinality,
        "target_cardinality": target_cardinality,
    }


In [4]:
@torch.no_grad()
def evaluate_model(model, classi_loader, seg_loader, device, num_classes=5, force_all_experts=False, p=0.75):
    """Evaluate a multi-task model on a classification and a segmentation loader.

    The model is switched to eval mode (and left there). Each loader must yield
    ``(images, targets)`` pairs where ``targets`` is a dict of tensors; both are
    moved to ``device`` before the forward pass, which runs under CUDA float16
    autocast. Gradients are disabled by the decorator.

    Args:
        model: Callable as ``model(images, force_all_experts=..., p=...)`` and
            returning a 2-tuple whose first element is the outputs mapping
            consumed by ``compute_batch_metrics``; the second element is ignored.
        classi_loader: Iterable of classification batches.
        seg_loader: Iterable of segmentation batches.
        device: Device for inputs and for the running count tensors.
        num_classes: Number of segmentation tasks tracked.
        force_all_experts: Forwarded unchanged to the model.
        p: Forwarded unchanged to the model.

    Returns:
        dict with:
            ``accuracy_1`` / ``accuracy_2``: classification accuracy in percent
                over all samples from ``classi_loader`` (0.0 if it is empty),
            ``iou_per_class``: intersection / union per segmentation task,
            ``dice_per_class``: 2 * intersection / (|pred| + |target|),
            ``pixel_accuracy_per_class``: intersection / |target|, i.e. the
                fraction of ground-truth foreground pixels that were predicted
                (recall), despite the key name.
        The three per-class entries are Python lists of length ``num_classes``.
    """
    model.eval()

    def _batch_metrics(images, targets):
        # Move one batch to the device, run the model, return (batch size, counts).
        images = images.to(device)
        targets = {k: v.to(device) for k, v in targets.items()}
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            outputs, _ = model(images, force_all_experts=force_all_experts, p=p)
        return images.shape[0], compute_batch_metrics(outputs, targets, num_classes, device)

    # ---- classification pass ----
    total_samples_classification = 0
    total_correct_1 = 0
    total_correct_2 = 0
    for images, targets in classi_loader:
        batch_size, batch_metrics = _batch_metrics(images, targets)
        total_samples_classification += batch_size
        total_correct_1 += batch_metrics["classification_1_correct"]
        total_correct_2 += batch_metrics["classification_2_correct"]

    # ---- segmentation pass ----
    total_intersection = torch.zeros(num_classes, device=device)
    total_union = torch.zeros(num_classes, device=device)
    total_pred_cardinality = torch.zeros(num_classes, device=device)
    total_target_cardinality = torch.zeros(num_classes, device=device)
    for images, targets in seg_loader:
        _, batch_metrics = _batch_metrics(images, targets)
        total_intersection += batch_metrics["intersection"]
        total_union += batch_metrics["union"]
        total_pred_cardinality += batch_metrics["pred_cardinality"]
        total_target_cardinality += batch_metrics["target_cardinality"]

    # The epsilon keeps the ratios finite when a task has no foreground pixels.
    eps = 1e-6
    iou_per_class = total_intersection / (total_union + eps)
    dice_per_class = 2 * total_intersection / (total_pred_cardinality + total_target_cardinality + eps)
    pixel_accuracy_per_class = total_intersection / (total_target_cardinality + eps)

    # Guard against an empty classification loader (previously ZeroDivisionError).
    classification_denominator = max(total_samples_classification, 1)

    return {
        "accuracy_1": 100 * total_correct_1 / classification_denominator,
        "accuracy_2": 100 * total_correct_2 / classification_denominator,
        "iou_per_class": iou_per_class.tolist(),
        "dice_per_class": dice_per_class.tolist(),
        "pixel_accuracy_per_class": pixel_accuracy_per_class.tolist(),
    }


In [5]:
# get_seg_transforms returns a (train, val/test) pair; only the second one is
# used by the loaders built in this cell.
transform_train, transform_val_test = get_seg_transforms(
    cfg["size"],
    cfg["norm_mean"],
    cfg["norm_std"],
)

# Evaluation datasets: segmentation masks and disease-grading labels.
test_seg_set = IDRiDSegmentation(
    root="./data/A. Segmentation/",
    mode="eval",
    transform=transform_val_test,
)
test_classi_set = IDRiDClassification(
    path=cfg["test_path"],
    size=cfg["size"],
    normalize_mean=cfg["norm_mean"],
    normalize_std=cfg["norm_std"],
    mode="eval",
)

# Deterministic (unshuffled) loaders; both use two worker processes and
# differ only in how many batches each worker prefetches.
test_seg_loader = DataLoader(
    test_seg_set,
    batch_size=cfg["batch_size"],
    shuffle=False,
    num_workers=2,
    prefetch_factor=5,
)
test_classi_loader = DataLoader(
    test_classi_set,
    batch_size=cfg["batch_size"],
    shuffle=False,
    num_workers=2,
    prefetch_factor=10,
)

In [6]:
# Build the network from the shared configuration.
model = MultiTaskModel(backbone=cfg["backbone"], task_specs=cfg["task_specs"], num_experts=cfg["num_experts"])

# Model size: total number of elements across all parameter tensors.
num_parameters = sum(param.numel() for param in model.parameters())
print("Number of parameters =", num_parameters)

# Deserialize the weights onto the CPU first (tensor-only unpickling), load
# them into the model, then switch to inference mode and move to the target device.
checkpoint = torch.load("./saved/mlt_100.pth", map_location="cpu", weights_only=True)
model.load_state_dict(checkpoint)
model.eval()
# NOTE(review): as the last expression of the cell this echoes the full module
# repr into the notebook output; append `;` to silence it if that is unwanted.
model.to(DEVICE)

Number of parameters = 42372668


MultiTaskModel(
  (resnet): ResNet(
    (layer_0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (layer_1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (do

In [7]:
def print_metrics(metrics, mode):
    """Print a formatted report of classification and segmentation metrics.

    Args:
        metrics: Mapping with scalar entries ``accuracy_1`` and ``accuracy_2``
            and sequence entries ``iou_per_class``, ``dice_per_class`` and
            ``pixel_accuracy_per_class``. Sequence position ``i`` is reported
            under the ``i``-th lesion name (Microaneurysms, Haemorrhages,
            Hard Exudates, Soft Exudates, Optic Disc).
        mode: Label used in the header line, e.g. ``"Test"``.

    All values are printed as given with four decimal places. If the
    per-class sequences hold fewer than five entries only the available ones
    are printed; entries beyond the fifth are ignored.
    """
    tasks = ["Microaneurysms", "Haemorrhages", "Hard Exudates", "Soft Exudates", "Optic Disc"]
    print(f"{mode} Metrics:")
    print(f"\tRG Classification Accuracy: {metrics['accuracy_1']:.04f},\tMER Classification Accuracy: {metrics['accuracy_2']:.04f}")

    # zip stops at the shortest sequence, so a run with fewer segmentation
    # classes no longer raises IndexError.
    per_task = zip(
        tasks,
        metrics['iou_per_class'],
        metrics['dice_per_class'],
        metrics['pixel_accuracy_per_class'],
    )
    for task, iou, dice, pixel_accuracy in per_task:
        print(f"\t{task} IoU: {iou:.04f},\t{task} Dice Coefficient: {dice:.04f},\t{task} Pixel Accuracy: {pixel_accuracy:.04f}")


In [8]:
# Score the loaded checkpoint on the test loaders and print the report.
test_metrics = evaluate_model(
    model,
    test_classi_loader,
    test_seg_loader,
    DEVICE,
    num_classes=5,
    force_all_experts=False,
    p=0.75,
)
print_metrics(test_metrics, "Test")

Test Metrics:
	RG Classification Accuracy: 56.3107,	MER Classification Accuracy: 78.6408
	Microaneurysms IoU: 0.1946,	Microaneurysms Dice Coefficient: 0.3258,	Microaneurysms Pixel Accuracy: 0.3016
	Haemorrhages IoU: 0.3197,	Haemorrhages Dice Coefficient: 0.4845,	Haemorrhages Pixel Accuracy: 0.3949
	Hard Exudates IoU: 0.1547,	Hard Exudates Dice Coefficient: 0.2679,	Hard Exudates Pixel Accuracy: 0.1649
	Soft Exudates IoU: 0.4280,	Soft Exudates Dice Coefficient: 0.5995,	Soft Exudates Pixel Accuracy: 0.5350
	Optic Disc IoU: 0.9091,	Optic Disc Dice Coefficient: 0.9524,	Optic Disc Pixel Accuracy: 0.9316
