# Kseg Custom Model Evaluation

## Show Qualitative Results

In [None]:
import cv2
import glob
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch

In [None]:
# `identifier` is embedded in the training worker's log-directory name and
# `batch_id` selects one of the test batches saved under `test_samples/`.
identifier = 'c94bf'
batch_id = '0'

# Root of the training logs: the only machine-specific value, kept in one place.
log_root = '/Users/me/pytorch_logs'
# The `*` is a glob wildcard for the rest of the worker directory name; it is
# expanded when the files are loaded.
sample_dir = f'{log_root}/train_worker_{identifier}_*/test_samples/batch_{batch_id}'

input_file = f'{sample_dir}/input.nii.gz'
gt_file = f'{sample_dir}/gt.nii.gz'
pred_file = f'{sample_dir}/pred.nii.gz'

In [None]:
def _find_single(pattern):
    """Return the path matching glob ``pattern``.

    If several paths match, the lexicographically first one is used and a
    warning names it, so the choice is deterministic and visible.

    Raises:
        FileNotFoundError: if nothing matches ``pattern``.
    """
    import warnings

    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f'No file matches {pattern!r}')
    if len(matches) > 1:
        # glob returns matches in arbitrary order; several worker directories
        # can share the identifier, so make the pick explicit.
        warnings.warn(f'{len(matches)} files match {pattern!r}; using {matches[0]!r}')
    return matches[0]


def _to_rgb(gray):
    """Min-max scale a 2-D intensity slice to uint8 and replicate it into (H, W, 3)."""
    gray_u8 = cv2.normalize(gray, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8U)
    return cv2.cvtColor(gray_u8, cv2.COLOR_GRAY2RGB)


def _colorize(labels, colors):
    """Paint each class ID of a 2-D label slice with its RGB color from ``colors``.

    ``get_fdata`` returns floating-point data, so values are rounded to the
    nearest integer before matching class IDs. IDs missing from ``colors``
    stay black.
    """
    class_ids = np.rint(labels).astype(np.int64)
    rgb = np.zeros((*class_ids.shape, 3), dtype=np.uint8)
    for class_id, color in colors.items():
        rgb[class_ids == class_id] = color
    return rgb


def _overlay(base_rgb, label_rgb, alpha=0.5):
    """Blend ``label_rgb`` over ``base_rgb`` (both (H, W, 3) uint8) with weight ``alpha``."""
    # Keep images channels-last: OpenCV reads the last axis as the channel axis.
    return cv2.addWeighted(base_rgb, 1.0 - alpha, label_rgb, alpha, 0)


input_image = nib.load(_find_single(input_file))
gt_image = nib.load(_find_single(gt_file))
pred_image = nib.load(_find_single(pred_file))

x = input_image.get_fdata()
y = gt_image.get_fdata()
y_hat = pred_image.get_fdata()

# Fixed slice index along the third (z) axis. This is not necessarily the
# middle slice; use x.shape[2] // 2 for that.
selected_z = 32
x_slice = x[:, :, selected_z]
y_slice = y[:, :, selected_z]
y_hat_slice = y_hat[:, :, selected_z]

# Define a colormap where each class ID maps to an RGB color
color_map = {
    0: [0, 0, 0],  # Black for class 0 (background)
    1: [0, 255, 0],  # Green for class 1 (CSF / femoral cartilage)
    2: [255, 0, 0],  # Red for class 2 (cortical GM / tibial cartilage)
    3: [0, 0, 255],  # Blue for class 3 (WM / patellar cartilage)
    4: [255, 255, 0],  # Yellow for class 4 (deep GM / femur)
    5: [0, 255, 255],  # Cyan for class 5 (brain stem / tibia)
    6: [255, 0, 255],  # Magenta for class 6 (cerebellum / patella)
}

# Grayscale input as (H, W, 3) uint8, shared by both overlays.
x_rgb = _to_rgb(x_slice)

# 50/50 blends of the input with the ground-truth / predicted label colors,
# as (H, W, 3) uint8 tensors ready for plt.imshow.
y_slice = torch.from_numpy(_overlay(x_rgb, _colorize(y_slice, color_map)))
y_hat_slice = torch.from_numpy(_overlay(x_rgb, _colorize(y_hat_slice, color_map)))

# Preserve the previous cell output: x_slice as a (3, H, W) uint8 tensor.
x_slice = torch.from_numpy(x_rgb).permute(2, 0, 1)

In [None]:
# Save the predicted-label overlay for the chosen slice as a borderless PNG.
y_hat_png = f'{identifier}_y_hat_{selected_z}.png'
plt.imshow(y_hat_slice)
plt.axis('off')  # hide ticks and frame so only the image is written
plt.savefig(y_hat_png, bbox_inches='tight', pad_inches=0)

In [None]:
# Optional: uncomment to render and save the ground-truth overlay (y_slice)
# for the same slice, written to '{identifier}_y_{selected_z}.png'.
# plt.imshow(y_slice)
# plt.axis('off')
# plt.savefig(f'{identifier}_y_{selected_z}.png', bbox_inches='tight', pad_inches=0)