In [1]:
import numpy as np
import torch
import os
import sys

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(os.path.join(project_root, "src"))

from model.ensemble import EnsembleModel
from model.common import CommonPLModuleWrapper
from sklearn.metrics import confusion_matrix
from monai.losses import DiceLoss
from monai.networks.nets import UNETR, SegResNet
from monai.networks.utils import one_hot
import pandas as pd


from dataloader import BrainTumourDataModule

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
IMAGE_PATH = "../data/BrainTumourData/imagesTr/"
LABEL_PATH = "../data/BrainTumourData/labelsTr/"
IMG_SIZE = 128
BATCH_SIZE = 1
IN_CHANNELS = 4
OUT_CHANNELS = 4

In [3]:
data_module = BrainTumourDataModule(
    data_path=IMAGE_PATH,
    seg_path=LABEL_PATH,
    img_dim=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
)
data_module.prepare_data()
data_module.setup()

In [4]:
segresnet = CommonPLModuleWrapper(
    model=SegResNet(in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS),
    loss_fn=DiceLoss(softmax=True),
)
segres_weights = torch.load(
    f"../model/{segresnet.model.__class__.__name__}.ckpt", weights_only=True
)
segresnet.load_state_dict(segres_weights["state_dict"])

unet = CommonPLModuleWrapper(
    model=UNETR(
        in_channels=IN_CHANNELS,
        out_channels=OUT_CHANNELS,
        img_size=(IMG_SIZE, IMG_SIZE, IMG_SIZE),
    ),
    loss_fn=DiceLoss(softmax=True),
)
unet_weights = torch.load(
    f"../model/{unet.model.__class__.__name__}.ckpt", weights_only=True
)
unet.load_state_dict(unet_weights["state_dict"])

model = EnsembleModel([segresnet, unet], num_classes=4)

In [None]:
model.eval()

total_batches = len(data_module.test_dataloader())
num_classes = 4
height, width = 128, 128
slices_per_sample = 128

all_predictions = torch.empty(
    (total_batches, num_classes, slices_per_sample, height, width), dtype=torch.float32
)
all_targets = torch.empty(
    (total_batches, num_classes, slices_per_sample, height, width), dtype=torch.float32
)

for current_index, (images, targets) in enumerate(data_module.test_dataloader()):
    with torch.no_grad():
        predictions = model(images)
        all_predictions[current_index] = torch.argmax(predictions[0], dim=0)
        all_targets[current_index] = targets[0]

all_predictions = one_hot(all_predictions.unsqueeze(1), num_classes)

RuntimeError: The expanded size of the tensor (128) must match the existing size (4) at non-singleton dimension 1.  Target sizes: [4, 128, 128, 128].  Tensor sizes: [4, 128, 128]

In [None]:
from monai.metrics import (
    DiceMetric,
    compute_confusion_matrix_metric,
    get_confusion_matrix,
    compute_dice,
)

dice_metric = compute_dice(
    y_pred=all_predictions, y=all_targets, ignore_empty=False, num_classes=4
)
mean_dice = dice_metric.mean(dim=0)
print(f"Dice Coefficient: {mean_dice}")
print(f"Mean Dice Coefficient: {mean_dice.mean()}")

In [None]:
index_to_name = {
    0: "Background",
    1: "Edema",
    2: "Non-Enhancing Tumor",
    3: "Enhancing Tumor",
}

predicted_classes = torch.argmax(all_predictions, dim=1).flatten().cpu().numpy()
target_classes = torch.argmax(all_targets, dim=1).flatten().cpu().numpy()
conf_matrix = confusion_matrix(target_classes, predicted_classes)

df = pd.DataFrame(
    conf_matrix, columns=index_to_name.values(), index=index_to_name.values()
)
print(df)

In [None]:
confusion_matrix = get_confusion_matrix(
    y_pred=all_predictions, y=all_targets, include_background=False
)
cm = confusion_matrix.sum(dim=0)
print(f"Confusion Matrix: \n{cm}")

recall = compute_confusion_matrix_metric(
    confusion_matrix=cm,
    metric_name="sensitivity",
)
print(f"Recall: {recall}")

alpha = torch.tensor([0.2, 0.3, 0.5])
weighted_recall = (recall * alpha).sum()
print(f"Weighted Recall: {weighted_recall}")

precision = compute_confusion_matrix_metric(
    confusion_matrix=cm,
    metric_name="precision",
)
print(f"Precision: {precision}")