# Interactive visualization of segmentations and their input images

In [None]:
# import libraries
from glob import glob
import ipywidgets as widgets
from IPython.display import display
import os
from os import path as osp
%pylab
%matplotlib inline

import nibabel as nib

In [None]:
# Root directories: where the segmentation files live, and where the
# per-subject input image folders live.
_DATA_ROOT = "/data/rohitrango"
SEG_DIR = _DATA_ROOT + "/Implicit3DCNNTasks/brats2021_unimodal_val/resnetv1.5-fold0/"
IMG_DIR = _DATA_ROOT + "/BRATS2021/val/"

In [None]:
# select file
# List every entry in SEG_DIR; the dropdown shows bare file names.
files = sorted(glob(osp.join(SEG_DIR, "*")))
items = [osp.basename(x) for x in files]
if not items:
    raise FileNotFoundError(f"No segmentation files found in {SEG_DIR!r}")
# Start in sync with the dropdown's default selection so the loading cells
# work even if the user never changes the dropdown (the change handler only
# fires on an actual change, never for the initial value).
selected_file = items[0]

# Create the dropdown
dropdown = widgets.Dropdown(
    options=items,
    value=selected_file,  # default value, kept equal to selected_file
    description='Choose an item:',
    disabled=False,
)

# Function to handle changes in the dropdown
def on_dropdown_change(change):
    global selected_file
    if change['type'] == 'change' and change['name'] == 'value':
        print(f"You selected {change['new']}")
        selected_file = change['new']

# Route every change of the dropdown's 'value' trait to the handler above,
# then render the widget in this cell's output.
dropdown.observe(on_dropdown_change, names='value')
display(dropdown)


In [None]:
# Load the segmentation volume for the selected file.
# The dropdown's change handler only runs when the selection changes, so
# fall back to the dropdown's current value if nothing was recorded;
# otherwise osp.join would receive None.
if selected_file is None:
    selected_file = dropdown.value
seg = nib.load(osp.join(SEG_DIR, selected_file)).get_fdata()

In [None]:
# Load the input image volumes for the selected subject. The subject folder
# under IMG_DIR is the segmentation file name up to its first '.'
# (e.g. 'X.nii.gz' -> 'X').
subject_id = selected_file.split('.')[0]
imgfiles = sorted(glob(osp.join(IMG_DIR, subject_id, '*.nii.gz')))
if not imgfiles:
    # Fail here with a clear message instead of an IndexError on imgs[0] later.
    raise FileNotFoundError(
        f"No '*.nii.gz' images found in {osp.join(IMG_DIR, subject_id)!r}"
    )
imgs = [nib.load(x).get_fdata() for x in imgfiles]
# Per-image full-volume intensity range, so each image keeps a fixed
# grayscale window while scrolling through slices.
ranges = [(img.min(), img.max()) for img in imgs]

In [None]:
def plot_slices(slice_index):
    """Plot slice ``slice_index`` (third axis) of every input image and the segmentation.

    Reads the notebook globals ``imgs``, ``ranges`` and ``seg``. The figure is
    a 2x3 grid. The leading panels show the input images, each windowed to its
    own full-volume intensity range. The last two panels show the segmentation
    slice rendered as a color image, and its channel 0 alone.

    Raises:
        ValueError: if there are more input images than free image panels;
            otherwise they would be silently overdrawn by the segmentation.
    """
    fig, axs = plt.subplots(2, 3, figsize=(15, 10))  # Adjust the size as needed
    axs = list(axs.ravel())
    img_axes, seg_axes = axs[:-2], axs[-2:]
    if len(imgs) > len(img_axes):
        plt.close(fig)
        raise ValueError(
            f"{len(imgs)} input images do not fit in the {len(img_axes)} "
            "image panels of the 2x3 grid"
        )
    for i, (ax, img, (vmin, vmax)) in enumerate(zip(img_axes, imgs, ranges)):
        # .T puts the first array axis on the horizontal display axis.
        ax.imshow(img[:, :, slice_index].T, cmap='gray', vmin=vmin, vmax=vmax)
        ax.set_title(f'Image {i+1} - Slice {slice_index}')
    # seg is indexed as (x, y, slice, channel); swapping the first two axes
    # matches the .T used for the images. imshow treats the trailing channel
    # axis as RGB(A), so this assumes seg has 3 or 4 channels — TODO confirm.
    seg_axes[0].imshow(seg[:, :, slice_index].transpose((1, 0, 2)))
    seg_axes[0].set_title('Segmentation')
    # NOTE(review): the 'ET' title assumes channel 0 is the enhancing-tumour
    # map; confirm against the model's output channel ordering.
    seg_axes[1].imshow(seg[:, :, slice_index, 0].T, cmap='gray')
    seg_axes[1].set_title('ET')
    # Hide axes on every panel, including image panels left unused when
    # there are fewer images than slots.
    for ax in axs:
        ax.axis('off')
    plt.show()

# Create a slider over the slice axis (axis 2). Bound it by the smallest
# depth among all input images and the segmentation, so every selectable
# index is valid for every volume plot_slices indexes.
max_slice_index = min(vol.shape[2] for vol in [*imgs, seg]) - 1
slider = widgets.IntSlider(
    value=0,
    min=0,
    max=max_slice_index,
    step=1,
    description='Slice Index:',
    continuous_update=False,  # redraw only when the slider is released
)

widgets.interactive(plot_slices, slice_index=slider)