In [2]:
from ipywidgets import IntSlider, Dropdown, interactive
from IPython.display import display
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import os

In [3]:
def _display_slices(slices: list):
    """
    Draw a single MRI slice as a row of four panels: the plain image, the
    image with the ground-truth mask tinted blue, the image with the
    predicted mask tinted red, and the image with both tints combined.

    Args:
        slices: List of 2D arrays. The first three entries are read as
            [original image, ground-truth mask, predicted mask].

    Returns:
        None
    """
    def _scale(arr):
        # Rescale by the maximum and clamp into [0, 1]; arrays whose maximum
        # is not positive are passed through unchanged (avoids dividing by 0).
        if np.max(arr) > 0:
            return np.clip(arr / np.max(arr), 0, 1)
        return arr

    scaled = [_scale(arr) for arr in slices]
    original = scaled[0]
    ground_truth = scaled[1]
    prediction = scaled[2]

    def _tint(red=None, blue=None):
        # Black RGB canvas with the given masks written to the red/blue channels
        canvas = np.zeros((*original.shape, 3))
        if red is not None:
            canvas[..., 0] = red
        if blue is not None:
            canvas[..., 2] = blue
        return canvas

    overlay = _tint(red=prediction, blue=ground_truth)
    overlay_truth = _tint(blue=ground_truth)
    overlay_pred = _tint(red=prediction)

    # One row, four panels; each entry is (title, RGB layer drawn on top or None)
    panels = [
        ("Original MRI Slice", None),
        ("Ground Truth", overlay_truth),
        ("Model's Prediction", overlay_pred),
        ("Overlap", overlay),
    ]

    _, axes = plt.subplots(1, 4, figsize=(20, 20))
    for ax, (title, layer) in zip(axes, panels):
        # The grayscale slice is always the base layer
        ax.imshow(original, cmap='gray')
        if layer is not None:
            # Half-transparent colour layer on top of the grayscale image
            ax.imshow(layer, alpha=0.5)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [4]:
def display_slices(slice_type: str, slice_index: int, nifti_files: list):
    """
    Function to display a slice of an MRI image

    The first two entries of ``nifti_files`` are given WITHOUT a ".gz" suffix
    while the files on disk carry one. They are temporarily renamed to the
    plain name for loading and are always renamed back afterwards, even when
    loading or slicing fails.
    NOTE(review): presumably nibabel cannot read these two files under their
    ".gz" name - confirm why the rename is needed.

    Args:
        slice_type: Type of slice to display (Axial, Coronal, Sagittal)
        slice_index: Index of the slice to display
        nifti_files: List of NIfTI paths ordered as
            [image, ground truth, prediction]; every entry is loaded and one
            slice per file is passed to the plotting helper

    Returns:
        None

    Raises:
        ValueError: If slice_type is not one of the supported orientations
        IndexError: If slice_index is outside the volume along that axis
        FileNotFoundError: If one of the first two files exists neither as
            "<path>.gz" nor as "<path>"
    """
    # Array axis that is fixed for each orientation
    axis_by_orientation = {"Sagittal": 0, "Coronal": 1, "Axial": 2}
    if slice_type not in axis_by_orientation:
        # Fail before touching any file instead of hitting an unbound
        # (or stale) slice variable later on
        raise ValueError(
            f"Unknown slice_type {slice_type!r}; expected one of {sorted(axis_by_orientation)}"
        )
    axis = axis_by_orientation[slice_type]

    to_restore = []  # plain paths that must get their ".gz" suffix back
    slices = []
    try:
        for path in nifti_files[:2]:
            gz_path = f"{path}.gz"
            # Skip the rename only when an earlier interrupted run already left
            # the file under its plain name; otherwise rename as before (a
            # missing file still raises FileNotFoundError here)
            if os.path.exists(gz_path) or not os.path.exists(path):
                os.rename(gz_path, path)
            to_restore.append(path)

        # Load NIFTI files and extract the slice
        for file in nifti_files:
            data = nib.load(file).get_fdata()
            # Same as data[:, :, i] / data[:, i, :] / data[i, :, :]
            slices.append(np.take(data, slice_index, axis=axis))
    finally:
        # Always put the ".gz" names back so the next call finds the files
        for path in to_restore:
            os.rename(path, f"{path}.gz")

    # Plot the slices
    _display_slices(slices)

In [5]:
def display_slices_interactive(nifti_files: list):
    """
    Show an orientation dropdown and a slice-index slider, and redraw the
    slice through display_slices whenever either control changes.

    Args:
        nifti_files: List of NIfTI file paths, forwarded unchanged to
            display_slices

    Returns:
        None
    """
    # Orientation selector
    orientation_picker = Dropdown(
        options=["Axial", "Coronal", "Sagittal"],
        value="Axial",
        description='Orientation:'
    )

    # Slice selector; the upper bound starts at the non-coronal limit
    index_slider = IntSlider(value=110, min=0, max=181, step=1, description="Slice Index")

    def _sync_slider_range(change):
        # Hard-coded limits: 217 for coronal, 181 for the other orientations
        # NOTE(review): presumably matches a 182 x 218 x 182 volume - confirm
        index_slider.max = 217 if change['new'] == "Coronal" else 181

    # Keep the slider range in step with the selected orientation
    orientation_picker.observe(_sync_slider_range, names='value')

    def _render(slice_type, slice_index):
        # Binds nifti_files so the widget only has to supply the two controls
        display_slices(slice_type=slice_type, slice_index=slice_index, nifti_files=nifti_files)

    # Build the interactive widget and show it
    display(interactive(_render, slice_type=orientation_picker, slice_index=index_slider))

In [6]:
cwd = os.getcwd()

# Dataset folder, resolved against the notebook's working directory
nnUNet_raw = "/".join((cwd, "nnUNet_raw", "Dataset024_MSLesSeg"))

# Image and ground-truth paths are written without a ".gz" suffix;
# display_slices deals with the ".gz"-named files on disk for these two
flair = "/".join((nnUNet_raw, "imagesTs", "BRATS_88_0000.nii"))
mask = "/".join((nnUNet_raw, "labelsTs", "BRATS_88.nii"))

# The prediction is loaded directly under its ".nii.gz" name
nnUNet_prediction = "/".join((nnUNet_raw, "nnUNet_tests_0", "BRATS_88.nii.gz"))

# Order matters: [image, ground truth, prediction]
nifti_files = [flair, mask, nnUNet_prediction]

display_slices_interactive(nifti_files)

interactive(children=(Dropdown(description='Orientation:', options=('Axial', 'Coronal', 'Sagittal'), value='Ax…