In [1]:
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from ipywidgets import interact, interactive, IntSlider, ToggleButtons
from tensorflow.keras.models import load_model

# Custom objects required for loading the U-Net segmentation models
import sys 
sys.path.append('..')
from segmentation_losses import (
    dice_coefficient,
    dice_loss,
    log_cosh_dice_loss,
    iou
)

In [2]:
model = load_model(
    "../trained_models/task05_prostate_unet_3d_segmentation_model.h5",
    custom_objects={
        "log_cosh_dice_loss": log_cosh_dice_loss,
        "dice_coefficient": dice_coefficient,
        "iou": iou
    }
)

In [3]:
def standardize(mri):
    """
    Standardize mean and standard deviation of each channel and z_dimension slice to mean 0 and standard 
     deviation 1.

    Note: setting the type of the input mri to np.float16 beforehand causes issues, set it afterwards.

    Args:
        mri (np.array): input mri, shape (dim_x, dim_y, dim_z, num_channels)
    Returns:
        standardized_mri (np.array): standardized version of input mri
    """

    standardized_mri = np.zeros(mri.shape)

    # Iterate over channels
    for c in range(mri.shape[3]):
        # Iterate over the `z` depth dimension
        for z in range(mri.shape[2]):
            # Get a slice of the mri at channel c and z-th dimension
            mri_slice = mri[:, :, z, c]

            # Subtract the mean from mri_slice
            centered = mri_slice - np.mean(mri_slice)

            # Divide by the standard deviation (only if it is different from zero)
            if np.std(centered) != 0:
                centered_scaled = centered / np.std(centered)

                # Update the slice of standardized mri with the centered and scaled mri
                standardized_mri[:, :, z, c] = centered_scaled

    return standardized_mri

### 1. Visualizing predictions on validation set images

### 1.1. Visualizing prediction on prostate_37

In [4]:
mri = nib.load("../datasets/Task05_Prostate_320x320x32/val/images/prostate_37.nii.gz").get_fdata().astype(np.float32)

In [5]:
mask = nib.load("../datasets/Task05_Prostate_320x320x32/val/masks/prostate_37.nii.gz").get_fdata().astype(np.uint8)

In [6]:
mask.shape

(320, 320, 32)

In [7]:
mri_standardized = standardize(mri)

In [8]:
mri_standardized = np.expand_dims(mri_standardized, axis=0)  # Keras models require an additional dimension of 'batch_size'

In [9]:
prediction = model.predict(mri_standardized)

In [10]:
prediction = np.squeeze(prediction, axis=0)  # Remove 'batch_size' dimension

In [11]:
prediction.shape

(320, 320, 32, 3)

In [12]:
prediction

array([[[[0.9805638 , 0.01175441, 0.00768164],
         [0.9805638 , 0.01175441, 0.00768164],
         [0.9805638 , 0.01175441, 0.00768164],
         ...,
         [0.9805638 , 0.01175441, 0.00768164],
         [0.9805638 , 0.01175441, 0.00768164],
         [0.9805638 , 0.01175441, 0.00768164]],

        [[0.9805638 , 0.01175441, 0.00768164],
         [0.9805638 , 0.01175441, 0.00768164],
         [0.9805638 , 0.01175441, 0.00768164],
         ...,
         [0.9805638 , 0.01175441, 0.00768164],
         [0.9805638 , 0.01175441, 0.00768164],
         [0.9805638 , 0.01175441, 0.00768164]],

        [[0.9805638 , 0.01175441, 0.00768164],
         [0.9805638 , 0.01175441, 0.00768164],
         [0.9805638 , 0.01175441, 0.00768164],
         ...,
         [0.9805638 , 0.01175441, 0.00768164],
         [0.9805638 , 0.01175441, 0.00768164],
         [0.9805638 , 0.01175441, 0.00768164]],

        ...,

        [[0.9805638 , 0.01175441, 0.00768164],
         [0.9805638 , 0.01175441, 0.00768164]

In [13]:
type(prediction)

numpy.ndarray

In [14]:
prediction = np.argmax(prediction, axis=3)

In [15]:
prediction.shape

(320, 320, 32)

In [16]:
np.unique(prediction)

array([0, 1, 2])

#### Model prediction vs groundtruth number of labels

In [17]:
# Number of labels per class in model prediction
print("Background - class 0")
print(len(prediction[prediction == 0])) 
print("--------")
print("PZ - class 1")
print(len(prediction[prediction == 1]))
print("--------")
print("TZ - class 2")
print(len(prediction[prediction == 2]))

Background - class 0
3232999
--------
PZ - class 1
7856
--------
TZ - class 2
35945


In [18]:
# Number of labels per class in ground truth mask
print("Background - class 0")
print(len(mask[mask == 0]))
print("--------")
print("PZ - class 1")
print(len(mask[mask == 1]))
print("--------")
print("TZ - class 2")
print(len(mask[mask == 2]))

Background - class 0
3228634
--------
PZ - class 1
5897
--------
TZ - class 2
42269


In [19]:
classes_dict = {
    'Background': 0,
    'PZ': 1,
    'TZ': 2
}

# Create button values
select_class = ToggleButtons(
    options=['Background', 'PZ', 'TZ', 'All'],
    description='Select Class:',
    disabled=False,
    button_style='info', 
    
)
# Create layer slider
select_layer = IntSlider(min=0, max=mri.shape[2] - 1, description='Select Layer', continuous_update=False)

    
# Define a function for plotting images
def plot(seg_class, layer, channel):
    print(f"Plotting Layer: {layer} | Label: {seg_class} | Channel: {channel}")
    fig = plt.figure(figsize=(20, 10))
    
    fig.add_subplot(1, 3, 1)
    plt.title("prostate_37", fontsize=20)
    plt.imshow(mri[:, :, layer, channel], cmap='gray');
    plt.axis('off')
    
    fig.add_subplot(1, 3, 2)
    if seg_class == "All":
        mask_groundtruth = mask[:, :, layer]
        plt.title("Groundtruth mask", fontsize=20)
        plt.imshow(mask_groundtruth)
        plt.axis('off');
    else:
        img_label = classes_dict[seg_class]
        mask_groundtruth = np.where(mask[:, :, layer] == img_label, 255, 0)
        plt.title("Groundtruth mask", fontsize=20)
        plt.imshow(mask_groundtruth, cmap='gray')
        plt.axis('off');
    
    fig.add_subplot(1, 3, 3)
    if seg_class == "All":
        mask_model_prediction = prediction[:, :, layer]
        plt.title("Model prediction mask", fontsize=20)
        plt.imshow(mask_model_prediction)
        plt.axis('off');
    else:
        img_label = classes_dict[seg_class]
        mask_model_prediction = np.where(prediction[:, :, layer] == img_label, 255, 0)
        plt.title("Model prediction mask", fontsize=20)
        plt.imshow(mask_model_prediction, cmap='gray')
        plt.axis('off');

# Set channel to view:
#  Channel 0: "T2"
#  Channel 1: "ADC"   
    
# Use the interactive() tool to create the visualization
interactive(plot, seg_class=select_class, layer=select_layer, channel=(0, 1))

interactive(children=(ToggleButtons(button_style='info', description='Select Class:', options=('Background', '…