# Kseg nnUNet Evaluation

## Calculate Quantitative Results

In [None]:
import os
import nibabel as nib
import numpy as np
import torch

from torchmetrics import Specificity, Recall

In [None]:
def calculate_recall_specificity(pred, gt, num_classes):
    """Score a predicted label map against ground truth with recall and specificity.

    Both inputs are cast to integer labels and flattened to 1-D before being
    handed to the torchmetrics ``Recall`` and ``Specificity`` metrics. The
    metrics are built with ``average='none'``, and the averages returned here
    are plain means over whatever those metrics return.

    Args:
        pred: Tensor of predicted class labels.
        gt: Tensor of ground-truth class labels; presumably the same shape as
            ``pred`` (not checked here).
        num_classes: Number of label classes. A value of 2 is reduced to 1 and
            the metrics are created with ``multiclass=False``.

    Returns:
        Tuple ``(avg_recall, per_class_recall, avg_specificity,
        per_class_specificity)``.
    """
    labels_pred = pred.int().flatten()
    labels_true = gt.int().flatten()

    # Two classes are evaluated as a single-output problem; any other count is
    # passed through unchanged.
    n_outputs = 1 if num_classes == 2 else num_classes
    metric_options = dict(
        num_classes=n_outputs,
        average='none',
        multiclass=n_outputs > 1,
    )

    per_class_recall = Recall(**metric_options)(labels_pred, labels_true)
    per_class_specificity = Specificity(**metric_options)(
        labels_pred, labels_true
    )

    return (
        per_class_recall.mean(),
        per_class_recall,
        per_class_specificity.mean(),
        per_class_specificity,
    )

In [None]:
def calculate_dice(y_pred, y_true, num_classes, smooth=1):
    """Compute smoothed per-class Dice scores between two label maps.

    The label maps are one-hot encoded and every axis except the class axis is
    flattened, so inputs of any rank (2-D slices, 3-D volumes, batches, ...)
    give one score per class. For 3-D inputs the result is identical to
    summing over the three spatial axes.

    Args:
        y_pred: Tensor of predicted class labels (any numeric dtype; cast to
            ``long`` before encoding).
        y_true: Tensor of ground-truth class labels with the same shape as
            ``y_pred``.
        num_classes: Number of classes used for the one-hot encoding. Every
            label in both tensors must be smaller than this value.
        smooth: Additive smoothing applied to numerator and denominator. It
            keeps the score defined (and equal to 1) for classes absent from
            both tensors.

    Returns:
        Tuple ``(mean_dice, per_class_dice)`` where ``per_class_dice`` has one
        entry per class and ``mean_dice`` is its mean.
    """
    # Convert ordinal encoded tensors to one-hot encoded tensors
    pred_one_hot = torch.nn.functional.one_hot(
        y_pred.long(), num_classes=num_classes
    )
    true_one_hot = torch.nn.functional.one_hot(
        y_true.long(), num_classes=num_classes
    )

    # Collapse every non-class axis so the reduction below does not depend on
    # the rank of the input.
    pred_one_hot = pred_one_hot.reshape(-1, pred_one_hot.shape[-1])
    true_one_hot = true_one_hot.reshape(-1, true_one_hot.shape[-1])

    # Per-class overlap and per-class total size |pred| + |true|. The latter
    # is the Dice denominator, not a set union.
    intersection = torch.sum(pred_one_hot * true_one_hot, dim=0)
    cardinality = torch.sum(pred_one_hot + true_one_hot, dim=0)

    # Calculate Dice score for each class
    dice_scores = (2.0 * intersection + smooth) / (cardinality + smooth)

    # Return average Dice score and per class Dice score
    return dice_scores.mean(), dice_scores

In [None]:
# gt_dir = "/Users/me/upenn/preprocessed/test/"
# pred_dir = "/Users/me/nnUNet/nnUNet_Prediction_Results/Task508_UPenn_GBM_SS/"
gt_dir = "/Users/me/oasis/preprocessed/test/"
pred_dir = "/Users/me/nnUNet/nnUNet_Prediction_Results/Task509_Oasis_Tissue/"


def _mean_over_cases(scores):
    """Average a list of per-case score tensors element-wise as NumPy values."""
    # Convert explicitly instead of relying on NumPy coercing torch tensors.
    return np.mean(
        [score.detach().cpu().numpy() for score in scores], axis=0
    )


# Keep only NIfTI volumes *before* pairing, so that stray files in either
# directory cannot shift the sorted, position-based pairing below.
gt_files = sorted(f for f in os.listdir(gt_dir) if f.endswith(".nii.gz"))
pred_files = sorted(f for f in os.listdir(pred_dir) if f.endswith(".nii.gz"))

# Pairing is positional: the i-th sorted ground-truth file is compared with the
# i-th sorted prediction. zip() silently truncates, so surface a mismatch.
if len(gt_files) != len(pred_files):
    print(
        f"Warning: found {len(gt_files)} ground-truth volumes but "
        f"{len(pred_files)} predictions; only the first "
        f"{min(len(gt_files), len(pred_files))} sorted pairs are evaluated."
    )

avg_dice_scores = []
avg_per_class_dice_scores = []
avg_spec_scores = []
avg_per_class_spec_scores = []
avg_rec_scores = []
avg_per_class_rec_scores = []

for gt_file, pred_file in zip(gt_files, pred_files):
    gt_path = os.path.join(gt_dir, gt_file)
    pred_path = os.path.join(pred_dir, pred_file)

    input_image = nib.load(gt_path)
    pred_image = nib.load(pred_path)

    y_true = torch.Tensor(input_image.get_fdata())
    y_pred = torch.Tensor(pred_image.get_fdata())

    # Size the class axis from both volumes: a prediction containing a label
    # above the ground-truth maximum would otherwise make one_hot() fail.
    # NOTE(review): the class count is still derived per case, so cases with
    # different label ranges yield per-class vectors of different lengths and
    # the per-class averaging below will fail; confirm all cases share the
    # same label set.
    num_classes = int(max(y_true.max().item(), y_pred.max().item())) + 1

    # Calculate Dice
    avg_dice, dice = calculate_dice(y_pred, y_true, num_classes)

    # Calculate Recall and Specificity
    avg_rec, rec, avg_spec, spec = calculate_recall_specificity(
        y_pred, y_true, num_classes
    )

    # Log intermediate scores
    avg_dice_scores.append(avg_dice)
    avg_per_class_dice_scores.append(dice)
    avg_rec_scores.append(avg_rec)
    avg_per_class_rec_scores.append(rec)
    avg_spec_scores.append(avg_spec)
    avg_per_class_spec_scores.append(spec)

# Aggregate over cases: scalars for the averages, one value per class otherwise.
overall_avg_dice = _mean_over_cases(avg_dice_scores)
overall_per_class_dice = _mean_over_cases(avg_per_class_dice_scores)
overall_avg_rec = _mean_over_cases(avg_rec_scores)
overall_per_class_rec = _mean_over_cases(avg_per_class_rec_scores)
overall_avg_spec = _mean_over_cases(avg_spec_scores)
overall_per_class_spec = _mean_over_cases(avg_per_class_spec_scores)

print("\nOverall Average Dice Score:", overall_avg_dice)
print("Overall Average Per Class Dice:", overall_per_class_dice)
print("Overall Average Specificity:", overall_avg_spec)
print("Overall Average Per Class Specificity:", overall_per_class_spec)
print("Overall Average Recall:", overall_avg_rec)
print("Overall Average Per Class Recall:", overall_per_class_rec)

## Show Qualitative Results

In [None]:
import cv2
import numpy as np
import nibabel as nib
import torch
import matplotlib.pyplot as plt

In [None]:
# Case to visualise; this string is interpolated into each of the paths below.
identifier = '0'
# Alternative path set (Task509 Oasis), kept commented out; swap it with the
# active assignments below to visualise that dataset instead.
# input_file = f'/Users/me/nnUNet/nnUNet_raw_data_base/nnUNet_raw_data/Task509_Oasis_Tissue/imagesTs/{identifier}_0000.nii.gz'
# gt_file = f'/Users/me/oasis/preprocessed/test/label_{identifier}.nii.gz'
# pred_file = f'/Users/me/nnUNet/nnUNet_Prediction_Results/Task509_Oasis_Tissue/{identifier}.nii.gz'
# Active path set (Task508 UPenn): the input volume from imagesTs, its
# ground-truth label file, and the corresponding prediction file.
input_file = f'/Users/me/nnUNet/nnUNet_raw_data_base/nnUNet_raw_data/Task508_UPenn_GBM_SS/imagesTs/{identifier}_0000.nii.gz'
gt_file = f'/Users/me/upenn/preprocessed/test/label_{identifier}.nii.gz'
pred_file = f'/Users/me/nnUNet/nnUNet_Prediction_Results/Task508_UPenn_GBM_SS/{identifier}.nii.gz'

In [None]:
def _overlay_labels(label_slice, background_rgb, class_colors):
    """Colourise a 2-D label slice and blend it 50/50 over an RGB image.

    Args:
        label_slice: 2-D array of class IDs.
        background_rgb: ``(H, W, 3)`` uint8 array to blend the colours onto.
        class_colors: Mapping from class ID to an ``[R, G, B]`` colour. IDs
            missing from the mapping stay black.

    Returns:
        ``(H, W, 3)`` uint8 torch tensor holding the blended image.
    """
    colored = np.zeros(
        (label_slice.shape[0], label_slice.shape[1], 3), dtype=np.uint8
    )
    for class_id, rgb in class_colors.items():
        colored[label_slice == class_id] = rgb
    # addWeighted is element-wise, so blending directly in HWC layout gives the
    # same pixels as blending in CHW and permuting back afterwards.
    return torch.from_numpy(
        cv2.addWeighted(background_rgb, 0.5, colored, 0.5, 0)
    )


input_image = nib.load(input_file)
gt_image = nib.load(gt_file)
pred_image = nib.load(pred_file)

x = input_image.get_fdata()
y = gt_image.get_fdata()
y_hat = pred_image.get_fdata()


# Extract one fixed slice along the last axis. Index 32 is hard-coded, so it
# is the middle slice only for volumes with 64 slices along that axis.
selected_z = 32
x_slice = x[:, :, selected_z]
y_slice = y[:, :, selected_z]
y_hat_slice = y_hat[:, :, selected_z]

# Define a colormap where each class ID maps to an RGB color
color_map = {
    0: [0, 0, 0],  # Black for class 0 (background)
    1: [0, 255, 0],  # Green for class 1 (CSF / femoral cartilage)
    2: [255, 0, 0],  # Red for class 2 (cortical GM / tibial cartilage)
    3: [0, 0, 255],  # Blue for class 3 (WM / patellar cartilage)
    4: [255, 255, 0],  # Yellow for class 4 (deep GM / femur)
    5: [0, 255, 255],  # Cyan for class 5 (brain stem / tibia)
    6: [255, 0, 255],  # Magenta for class 6 (cerebellum / patella)
}

# Transform x: min-max scale the intensities to 0..255 (bounds given in the
# documented lower/upper order) and expand to three channels.
x_rgb = cv2.normalize(x_slice, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
x_rgb = cv2.cvtColor(x_rgb, cv2.COLOR_GRAY2RGB)

# Transform y and y_hat: colour the labels and blend them over the input.
y_slice = _overlay_labels(y_slice, x_rgb, color_map)
y_hat_slice = _overlay_labels(y_hat_slice, x_rgb, color_map)

# Keep x_slice available as a channel-first (3, H, W) tensor, as before.
x_slice = torch.from_numpy(x_rgb).permute(2, 0, 1)


In [None]:
# Draw the prediction overlay without axes and write it to a PNG named after
# the case identifier and the selected slice index.
y_hat_png = f'{identifier}_y_hat_{selected_z}.png'
plt.imshow(y_hat_slice)
plt.axis('off')
plt.savefig(
    y_hat_png,
    bbox_inches='tight',
    pad_inches=0,
)

In [None]:
# Draw the ground-truth overlay without axes and write it to a PNG named after
# the case identifier and the selected slice index.
y_png = f'{identifier}_y_{selected_z}.png'
plt.imshow(y_slice)
plt.axis('off')
plt.savefig(
    y_png,
    bbox_inches='tight',
    pad_inches=0,
)