In [1]:
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from ipywidgets import interact, interactive, IntSlider, ToggleButtons
from tensorflow.keras.models import load_model

# Custom objects required for loading the U-Net segmentation models
import sys 
sys.path.append('../..')
from segmentation_losses import (
    dice_coefficient,
    dice_loss,
    log_cosh_dice_loss,
    iou
)

In [2]:
# Load the trained 3D U-Net. Keras needs the project's custom functions (imported above) to
# deserialize the names stored in the saved .h5 file.
model = load_model(
    "../trained_models/task01_braintumour_unet_3d_segmentation_model.h5",
    custom_objects=dict(
        log_cosh_dice_loss=log_cosh_dice_loss,
        dice_coefficient=dice_coefficient,
        iou=iou,
    ),
)

In [3]:
def standardize(mri):
    """
    Standardize every (z-slice, channel) 2D plane of an MRI volume to mean 0 and standard
    deviation 1.

    Planes whose standard deviation is exactly zero (e.g. all-zero padding layers) are left as
    zeros instead of being divided by zero.

    Note: setting the type of the input mri to np.float16 beforehand causes issues, set it afterwards.

    Args:
        mri (np.array): input mri, shape (dim_x, dim_y, dim_z, num_channels)
    Returns:
        standardized_mri (np.array): float64 array with the same shape as `mri`
    """
    mri = np.asarray(mri)
    standardized_mri = np.zeros(mri.shape)  # float64 output, as before
    if mri.size == 0:
        return standardized_mri

    # Per-plane statistics: reduce over the two in-plane axes in one vectorized pass instead of
    # looping over every (z, channel) slice. keepdims makes them broadcast back over the volume.
    centered = mri - mri.mean(axis=(0, 1), keepdims=True)
    stds = centered.std(axis=(0, 1), keepdims=True)

    # Divide only where the std is non-zero; the remaining planes keep the zeros written above.
    np.divide(centered, stds, out=standardized_mri, where=stds != 0)
    return standardized_mri

### 1. Visualizing predictions on validation set images

### 1.1. Visualizing prediction on BRATS_389 file

In [4]:
mri = nib.load("../datasets/Task01_BrainTumour_240x240x160x4/val/images/BRATS_389.nii.gz").get_fdata().astype(np.float32)

In [5]:
mask = nib.load("../datasets/Task01_BrainTumour_240x240x160x4/val/masks/BRATS_389.nii.gz").get_fdata().astype(np.uint8)

In [6]:
mask.shape

(240, 240, 160)

In [7]:
mri_standardized = standardize(mri)

In [8]:
mri_standardized = np.expand_dims(mri_standardized, axis=0)  # Keras models require an additional dimension of 'batch_size'

In [9]:
prediction = model.predict(mri_standardized)

In [10]:
prediction = np.squeeze(prediction, axis=0)  # Remove 'batch_size' dimension

In [11]:
prediction.shape

(240, 240, 160, 4)

In [None]:
prediction

In [13]:
type(prediction)

numpy.ndarray

In [14]:
prediction = np.argmax(prediction, axis=3)

In [15]:
prediction.shape

(240, 240, 160)

In [16]:
np.unique(prediction)

array([0, 1, 2, 3])

#### Model prediction vs. groundtruth: number of voxels per class

In [17]:
# Number of voxels per class in model prediction
class_names = ["Background", "Edema", "Non-Enhancing-Tumor", "Enhancing-Tumor"]
# One pass over the volume instead of building a boolean-indexed copy per class
prediction_counts = np.bincount(prediction.ravel(), minlength=len(class_names))
for label, name in enumerate(class_names):
    if label > 0:
        print("--------")
    print(f"{name} - class {label}")
    print(prediction_counts[label])

Background - class 0
9179123
--------
Edema - class 1
23876
--------
Non-Enhancing-Tumor - class 2
3064
--------
Enhancing-Tumor - class 3
9937


In [18]:
# Number of voxels per class in ground truth mask
class_names = ["Background", "Edema", "Non-Enhancing-Tumor", "Enhancing-Tumor"]
# One pass over the volume instead of building a boolean-indexed copy per class
mask_counts = np.bincount(mask.ravel(), minlength=len(class_names))
for label, name in enumerate(class_names):
    if label > 0:
        print("--------")
    print(f"{name} - class {label}")
    print(mask_counts[label])

Background - class 0
9178172
--------
Edema - class 1
25665
--------
Non-Enhancing-Tumor - class 2
3169
--------
Enhancing-Tumor - class 3
8994


In [19]:
classes_dict = {
    'Background': 0,
    'Edema': 1,
    'Non-enhancing tumor': 2,
    'Enhancing tumor': 3
}

# `plot` runs lazily on every widget change. If it read the shared `mri` / `mask` / `prediction`
# globals, this widget would silently switch to whichever case a later cell loaded (still titled
# BRATS_389, and possibly indexing past the end of a shorter volume). Snapshot this case instead.
mri_389, mask_389, prediction_389 = mri, mask, prediction

# Create button values
select_class = ToggleButtons(
    options=['Background', 'Edema', 'Non-enhancing tumor', 'Enhancing tumor', 'All'],
    description='Select Class:',
    disabled=False,
    button_style='info',
)
# Create layer slider
select_layer = IntSlider(min=0, max=mri_389.shape[2] - 1, description='Select Layer', continuous_update=False)


def plot(seg_class, layer, channel):
    """
    Plot one z-layer of BRATS_389: the MRI channel, the groundtruth mask and the model prediction.

    Args:
        seg_class (str): a key of `classes_dict` (shown as a white-on-black binary mask) or "All"
            (the full label map).
        layer (int): index along the third (z) axis.
        channel (int): MRI channel index to display.
    """
    print(f"Plotting Layer: {layer} | Label: {seg_class} | Channel: {channel}")
    fig = plt.figure(figsize=(20, 10))

    fig.add_subplot(1, 3, 1)
    plt.title("MRI_BRATS_389", fontsize=20)
    plt.imshow(mri_389[:, :, layer, channel], cmap='gray')
    plt.axis('off')

    label_panels = (("Groundtruth mask", mask_389), ("Model prediction mask", prediction_389))
    for position, (title, label_volume) in enumerate(label_panels, start=2):
        fig.add_subplot(1, 3, position)
        plt.title(title, fontsize=20)
        label_slice = label_volume[:, :, layer]
        if seg_class == "All":
            # Fixed colour range: the same class gets the same colour in both panels even when a
            # slice lacks some classes (autoscaling would remap the colours per panel).
            plt.imshow(label_slice, vmin=0, vmax=max(classes_dict.values()))
        else:
            # Fixed range: with autoscaling, a slice that is entirely the selected class
            # (vmin == vmax) renders black, indistinguishable from one containing none of it.
            binary_mask = np.where(label_slice == classes_dict[seg_class], 255, 0)
            plt.imshow(binary_mask, cmap='gray', vmin=0, vmax=255)
        plt.axis('off')
    plt.show()

# Set channel to view:
#  Channel 0: "FLAIR" Fluid-attenuated inversion recovery
#  Channel 1: "T1w" T1-weighted
#  Channel 2: "t1gd" T1-weighted with gadolinium contrast enhancement
#  Channel 3: "T2w" T2-weighted

# Use the interactive() tool to create the visualization
interactive(plot, seg_class=select_class, layer=select_layer, channel=(0, 3))

interactive(children=(ToggleButtons(button_style='info', description='Select Class:', options=('Background', '…

### 1.2 Visualizing prediction on BRATS_390 file

In [20]:
mri = nib.load("../datasets/Task01_BrainTumour_240x240x160x4/val/images/BRATS_390.nii.gz").get_fdata().astype(np.float32)

In [21]:
mask = nib.load("../datasets/Task01_BrainTumour_240x240x160x4/val/masks/BRATS_390.nii.gz").get_fdata().astype(np.uint8)

In [22]:
mask.shape

(240, 240, 160)

In [23]:
mri_standardized = standardize(mri)

In [24]:
mri_standardized = np.expand_dims(mri_standardized, axis=0)  # Keras models require an additional dimension of 'batch_size'

In [25]:
prediction = model.predict(mri_standardized)

In [26]:
prediction = np.squeeze(prediction, axis=0)  # Remove 'batch_size' dimension

In [27]:
prediction.shape

(240, 240, 160, 4)

In [29]:
type(prediction)

numpy.ndarray

In [30]:
prediction = np.argmax(prediction, axis=3)

In [31]:
prediction.shape

(240, 240, 160)

In [32]:
np.unique(prediction)

array([0, 1, 2, 3])

#### Model prediction vs. groundtruth: number of voxels per class

In [33]:
# Number of voxels per class in model prediction
class_names = ["Background", "Edema", "Non-Enhancing-Tumor", "Enhancing-Tumor"]
# One pass over the volume instead of building a boolean-indexed copy per class
prediction_counts = np.bincount(prediction.ravel(), minlength=len(class_names))
for label, name in enumerate(class_names):
    if label > 0:
        print("--------")
    print(f"{name} - class {label}")
    print(prediction_counts[label])

Background - class 0
9173289
--------
Edema - class 1
32362
--------
Non-Enhancing-Tumor - class 2
2059
--------
Enhancing-Tumor - class 3
8290


In [34]:
# Number of voxels per class in ground truth mask
class_names = ["Background", "Edema", "Non-Enhancing-Tumor", "Enhancing-Tumor"]
# One pass over the volume instead of building a boolean-indexed copy per class
mask_counts = np.bincount(mask.ravel(), minlength=len(class_names))
for label, name in enumerate(class_names):
    if label > 0:
        print("--------")
    print(f"{name} - class {label}")
    print(mask_counts[label])

Background - class 0
9185510
--------
Edema - class 1
19594
--------
Non-Enhancing-Tumor - class 2
2287
--------
Enhancing-Tumor - class 3
8609


In [36]:
classes_dict = {
    'Background': 0,
    'Edema': 1,
    'Non-enhancing tumor': 2,
    'Enhancing tumor': 3
}

# `plot` runs lazily on every widget change. If it read the shared `mri` / `mask` / `prediction`
# globals, this widget would silently switch to whichever case a later cell loaded (still titled
# BRATS_390, and possibly indexing past the end of a shorter volume). Snapshot this case instead.
mri_390, mask_390, prediction_390 = mri, mask, prediction

# Create button values
select_class = ToggleButtons(
    options=['Background', 'Edema', 'Non-enhancing tumor', 'Enhancing tumor', 'All'],
    description='Select Class:',
    disabled=False,
    button_style='info',
)
# Create layer slider
select_layer = IntSlider(min=0, max=mri_390.shape[2] - 1, description='Select Layer', continuous_update=False)


def plot(seg_class, layer, channel):
    """
    Plot one z-layer of BRATS_390: the MRI channel, the groundtruth mask and the model prediction.

    Args:
        seg_class (str): a key of `classes_dict` (shown as a white-on-black binary mask) or "All"
            (the full label map).
        layer (int): index along the third (z) axis.
        channel (int): MRI channel index to display.
    """
    print(f"Plotting Layer: {layer} | Label: {seg_class} | Channel: {channel}")
    fig = plt.figure(figsize=(20, 10))

    fig.add_subplot(1, 3, 1)
    plt.title("MRI_BRATS_390", fontsize=20)
    plt.imshow(mri_390[:, :, layer, channel], cmap='gray')
    plt.axis('off')

    label_panels = (("Groundtruth mask", mask_390), ("Model prediction mask", prediction_390))
    for position, (title, label_volume) in enumerate(label_panels, start=2):
        fig.add_subplot(1, 3, position)
        plt.title(title, fontsize=20)
        label_slice = label_volume[:, :, layer]
        if seg_class == "All":
            # Fixed colour range: the same class gets the same colour in both panels even when a
            # slice lacks some classes (autoscaling would remap the colours per panel).
            plt.imshow(label_slice, vmin=0, vmax=max(classes_dict.values()))
        else:
            # Fixed range: with autoscaling, a slice that is entirely the selected class
            # (vmin == vmax) renders black, indistinguishable from one containing none of it.
            binary_mask = np.where(label_slice == classes_dict[seg_class], 255, 0)
            plt.imshow(binary_mask, cmap='gray', vmin=0, vmax=255)
        plt.axis('off')
    plt.show()

# Set channel to view:
#  Channel 0: "FLAIR" Fluid-attenuated inversion recovery
#  Channel 1: "T1w" T1-weighted
#  Channel 2: "t1gd" T1-weighted with gadolinium contrast enhancement
#  Channel 3: "T2w" T2-weighted

# Use the interactive() tool to create the visualization
interactive(plot, seg_class=select_class, layer=select_layer, channel=(0, 3))

interactive(children=(ToggleButtons(button_style='info', description='Select Class:', options=('Background', '…

### 1.3 Visualizing prediction on BRATS_391 file

In [37]:
mri = nib.load("../datasets/Task01_BrainTumour_240x240x160x4/val/images/BRATS_391.nii.gz").get_fdata().astype(np.float32)

In [38]:
mask = nib.load("../datasets/Task01_BrainTumour_240x240x160x4/val/masks/BRATS_391.nii.gz").get_fdata().astype(np.uint8)

In [39]:
mask.shape

(240, 240, 160)

In [40]:
mri_standardized = standardize(mri)

In [41]:
mri_standardized = np.expand_dims(mri_standardized, axis=0)  # Keras models require an additional dimension of 'batch_size'

In [42]:
prediction = model.predict(mri_standardized)

In [43]:
prediction = np.squeeze(prediction, axis=0)  # Remove 'batch_size' dimension

In [44]:
prediction.shape

(240, 240, 160, 4)

In [46]:
type(prediction)

numpy.ndarray

In [47]:
prediction = np.argmax(prediction, axis=3)

In [48]:
prediction.shape

(240, 240, 160)

In [49]:
np.unique(prediction)

array([0, 1, 2, 3])

#### Model prediction vs. groundtruth: number of voxels per class

In [50]:
# Number of voxels per class in model prediction
class_names = ["Background", "Edema", "Non-Enhancing-Tumor", "Enhancing-Tumor"]
# One pass over the volume instead of building a boolean-indexed copy per class
prediction_counts = np.bincount(prediction.ravel(), minlength=len(class_names))
for label, name in enumerate(class_names):
    if label > 0:
        print("--------")
    print(f"{name} - class {label}")
    print(prediction_counts[label])

Background - class 0
9123505
--------
Edema - class 1
49013
--------
Non-Enhancing-Tumor - class 2
14481
--------
Enhancing-Tumor - class 3
29001


In [51]:
# Number of voxels per class in ground truth mask
class_names = ["Background", "Edema", "Non-Enhancing-Tumor", "Enhancing-Tumor"]
# One pass over the volume instead of building a boolean-indexed copy per class
mask_counts = np.bincount(mask.ravel(), minlength=len(class_names))
for label, name in enumerate(class_names):
    if label > 0:
        print("--------")
    print(f"{name} - class {label}")
    print(mask_counts[label])

Background - class 0
9134153
--------
Edema - class 1
44701
--------
Non-Enhancing-Tumor - class 2
3614
--------
Enhancing-Tumor - class 3
33532


In [52]:
classes_dict = {
    'Background': 0,
    'Edema': 1,
    'Non-enhancing tumor': 2,
    'Enhancing tumor': 3
}

# `plot` runs lazily on every widget change. If it read the shared `mri` / `mask` / `prediction`
# globals, this widget would silently switch to whichever case a later cell loaded (still titled
# BRATS_391, and possibly indexing past the end of a shorter volume). Snapshot this case instead.
mri_391, mask_391, prediction_391 = mri, mask, prediction

# Create button values
select_class = ToggleButtons(
    options=['Background', 'Edema', 'Non-enhancing tumor', 'Enhancing tumor', 'All'],
    description='Select Class:',
    disabled=False,
    button_style='info',
)
# Create layer slider
select_layer = IntSlider(min=0, max=mri_391.shape[2] - 1, description='Select Layer', continuous_update=False)


def plot(seg_class, layer, channel):
    """
    Plot one z-layer of BRATS_391: the MRI channel, the groundtruth mask and the model prediction.

    Args:
        seg_class (str): a key of `classes_dict` (shown as a white-on-black binary mask) or "All"
            (the full label map).
        layer (int): index along the third (z) axis.
        channel (int): MRI channel index to display.
    """
    print(f"Plotting Layer: {layer} | Label: {seg_class} | Channel: {channel}")
    fig = plt.figure(figsize=(20, 10))

    fig.add_subplot(1, 3, 1)
    plt.title("MRI_BRATS_391", fontsize=20)
    plt.imshow(mri_391[:, :, layer, channel], cmap='gray')
    plt.axis('off')

    label_panels = (("Groundtruth mask", mask_391), ("Model prediction mask", prediction_391))
    for position, (title, label_volume) in enumerate(label_panels, start=2):
        fig.add_subplot(1, 3, position)
        plt.title(title, fontsize=20)
        label_slice = label_volume[:, :, layer]
        if seg_class == "All":
            # Fixed colour range: the same class gets the same colour in both panels even when a
            # slice lacks some classes (autoscaling would remap the colours per panel).
            plt.imshow(label_slice, vmin=0, vmax=max(classes_dict.values()))
        else:
            # Fixed range: with autoscaling, a slice that is entirely the selected class
            # (vmin == vmax) renders black, indistinguishable from one containing none of it.
            binary_mask = np.where(label_slice == classes_dict[seg_class], 255, 0)
            plt.imshow(binary_mask, cmap='gray', vmin=0, vmax=255)
        plt.axis('off')
    plt.show()

# Set channel to view:
#  Channel 0: "FLAIR" Fluid-attenuated inversion recovery
#  Channel 1: "T1w" T1-weighted
#  Channel 2: "t1gd" T1-weighted with gadolinium contrast enhancement
#  Channel 3: "T2w" T2-weighted

# Use the interactive() tool to create the visualization
interactive(plot, seg_class=select_class, layer=select_layer, channel=(0, 3))

interactive(children=(ToggleButtons(button_style='info', description='Select Class:', options=('Background', '…

### 1.4 Visualizing prediction on BRATS_392 file

In [53]:
mri = nib.load("../datasets/Task01_BrainTumour_240x240x160x4/val/images/BRATS_392.nii.gz").get_fdata().astype(np.float32)

In [54]:
mask = nib.load("../datasets/Task01_BrainTumour_240x240x160x4/val/masks/BRATS_392.nii.gz").get_fdata().astype(np.uint8)

In [55]:
mask.shape

(240, 240, 160)

In [56]:
mri_standardized = standardize(mri)

In [57]:
mri_standardized = np.expand_dims(mri_standardized, axis=0)  # Keras models require an additional dimension of 'batch_size'

In [58]:
prediction = model.predict(mri_standardized)

In [59]:
prediction = np.squeeze(prediction, axis=0)  # Remove 'batch_size' dimension

In [60]:
prediction.shape

(240, 240, 160, 4)

In [61]:
type(prediction)

numpy.ndarray

In [62]:
prediction = np.argmax(prediction, axis=3)

In [63]:
prediction.shape

(240, 240, 160)

In [64]:
np.unique(prediction)

array([0, 1, 2, 3])

#### Model prediction vs. groundtruth: number of voxels per class

In [65]:
# Number of voxels per class in model prediction
class_names = ["Background", "Edema", "Non-Enhancing-Tumor", "Enhancing-Tumor"]
# One pass over the volume instead of building a boolean-indexed copy per class
prediction_counts = np.bincount(prediction.ravel(), minlength=len(class_names))
for label, name in enumerate(class_names):
    if label > 0:
        print("--------")
    print(f"{name} - class {label}")
    print(prediction_counts[label])

Background - class 0
9095749
--------
Edema - class 1
22441
--------
Non-Enhancing-Tumor - class 2
20798
--------
Enhancing-Tumor - class 3
77012


In [66]:
# Number of voxels per class in ground truth mask
class_names = ["Background", "Edema", "Non-Enhancing-Tumor", "Enhancing-Tumor"]
# One pass over the volume instead of building a boolean-indexed copy per class
mask_counts = np.bincount(mask.ravel(), minlength=len(class_names))
for label, name in enumerate(class_names):
    if label > 0:
        print("--------")
    print(f"{name} - class {label}")
    print(mask_counts[label])

Background - class 0
9104688
--------
Edema - class 1
28795
--------
Non-Enhancing-Tumor - class 2
21404
--------
Enhancing-Tumor - class 3
61113


In [67]:
classes_dict = {
    'Background': 0,
    'Edema': 1,
    'Non-enhancing tumor': 2,
    'Enhancing tumor': 3
}

# `plot` runs lazily on every widget change. If it read the shared `mri` / `mask` / `prediction`
# globals, this widget would silently switch to whichever case a later cell loaded (still titled
# BRATS_392, and possibly indexing past the end of a shorter volume). Snapshot this case instead.
mri_392, mask_392, prediction_392 = mri, mask, prediction

# Create button values
select_class = ToggleButtons(
    options=['Background', 'Edema', 'Non-enhancing tumor', 'Enhancing tumor', 'All'],
    description='Select Class:',
    disabled=False,
    button_style='info',
)
# Create layer slider
select_layer = IntSlider(min=0, max=mri_392.shape[2] - 1, description='Select Layer', continuous_update=False)


def plot(seg_class, layer, channel):
    """
    Plot one z-layer of BRATS_392: the MRI channel, the groundtruth mask and the model prediction.

    Args:
        seg_class (str): a key of `classes_dict` (shown as a white-on-black binary mask) or "All"
            (the full label map).
        layer (int): index along the third (z) axis.
        channel (int): MRI channel index to display.
    """
    print(f"Plotting Layer: {layer} | Label: {seg_class} | Channel: {channel}")
    fig = plt.figure(figsize=(20, 10))

    fig.add_subplot(1, 3, 1)
    plt.title("MRI_BRATS_392", fontsize=20)
    plt.imshow(mri_392[:, :, layer, channel], cmap='gray')
    plt.axis('off')

    label_panels = (("Groundtruth mask", mask_392), ("Model prediction mask", prediction_392))
    for position, (title, label_volume) in enumerate(label_panels, start=2):
        fig.add_subplot(1, 3, position)
        plt.title(title, fontsize=20)
        label_slice = label_volume[:, :, layer]
        if seg_class == "All":
            # Fixed colour range: the same class gets the same colour in both panels even when a
            # slice lacks some classes (autoscaling would remap the colours per panel).
            plt.imshow(label_slice, vmin=0, vmax=max(classes_dict.values()))
        else:
            # Fixed range: with autoscaling, a slice that is entirely the selected class
            # (vmin == vmax) renders black, indistinguishable from one containing none of it.
            binary_mask = np.where(label_slice == classes_dict[seg_class], 255, 0)
            plt.imshow(binary_mask, cmap='gray', vmin=0, vmax=255)
        plt.axis('off')
    plt.show()

# Set channel to view:
#  Channel 0: "FLAIR" Fluid-attenuated inversion recovery
#  Channel 1: "T1w" T1-weighted
#  Channel 2: "t1gd" T1-weighted with gadolinium contrast enhancement
#  Channel 3: "T2w" T2-weighted

# Use the interactive() tool to create the visualization
interactive(plot, seg_class=select_class, layer=select_layer, channel=(0, 3))

interactive(children=(ToggleButtons(button_style='info', description='Select Class:', options=('Background', '…

### 2. Testing on the Medical Segmentation Decathlon test images for Task 1
The test volumes have 155 axial layers, but the model expects 160. Each volume is therefore zero-padded to 160 layers before prediction. Only the first 155 layers of the predicted mask are kept for the challenge submission.

### 2.1 Visualizing prediction on BRATS_485 file

In [68]:
mri = nib.load("../datasets/Task01_BrainTumour/test_images_for_model_prediction_submission/BRATS_485.nii.gz").get_fdata().astype(np.float32)

In [70]:
mri.shape

(240, 240, 155, 4)

In [71]:
mri_temp = np.zeros((240, 240, 160, 4))
mri_temp[:, :, :155, :] = mri[:, :, :]

mri = mri_temp

In [72]:
mri.shape

(240, 240, 160, 4)

In [73]:
mri_standardized = standardize(mri)

In [74]:
mri_standardized = np.expand_dims(mri_standardized, axis=0)  # Keras models require an additional dimension of 'batch_size'

In [75]:
prediction = model.predict(mri_standardized)

In [76]:
prediction = np.squeeze(prediction, axis=0)  # Remove 'batch_size' dimension

In [77]:
prediction.shape

(240, 240, 160, 4)

In [78]:
type(prediction)

numpy.ndarray

In [79]:
prediction = np.argmax(prediction, axis=3)

In [80]:
prediction.shape

(240, 240, 160)

In [88]:
mri = mri[:, :, :155, :]
prediction = prediction[:, :, :155]

In [89]:
prediction.shape

(240, 240, 155)

In [90]:
np.unique(prediction)

array([0, 1, 2, 3])

#### Model prediction: number of voxels per class (no groundtruth masks available)

In [91]:
# Number of voxels per class in model prediction
class_names = ["Background", "Edema", "Non-Enhancing-Tumor", "Enhancing-Tumor"]
# One pass over the volume instead of building a boolean-indexed copy per class
prediction_counts = np.bincount(prediction.ravel(), minlength=len(class_names))
for label, name in enumerate(class_names):
    if label > 0:
        print("--------")
    print(f"{name} - class {label}")
    print(prediction_counts[label])

Background - class 0
8866196
--------
Edema - class 1
46081
--------
Non-Enhancing-Tumor - class 2
11244
--------
Enhancing-Tumor - class 3
4479


In [92]:
classes_dict = {
    'Background': 0,
    'Edema': 1,
    'Non-enhancing tumor': 2,
    'Enhancing tumor': 3
}

# `plot` runs lazily on every widget change; snapshot this case's arrays so the widget keeps
# showing BRATS_485 even if a later cell rebinds the shared `mri` / `prediction` names.
mri_485, prediction_485 = mri, prediction

# Create button values
select_class = ToggleButtons(
    options=['Background', 'Edema', 'Non-enhancing tumor', 'Enhancing tumor', 'All'],
    description='Select Class:',
    disabled=False,
    button_style='info',
)
# Create layer slider
select_layer = IntSlider(min=0, max=mri_485.shape[2] - 1, description='Select Layer', continuous_update=False)


def plot(seg_class, layer, channel):
    """
    Plot one z-layer of BRATS_485: the MRI channel and the model prediction (no groundtruth mask).

    Args:
        seg_class (str): a key of `classes_dict` (shown as a white-on-black binary mask) or "All"
            (the full label map).
        layer (int): index along the third (z) axis.
        channel (int): MRI channel index to display.
    """
    print(f"Plotting Layer: {layer} | Label: {seg_class} | Channel: {channel}")
    fig = plt.figure(figsize=(20, 10))

    fig.add_subplot(1, 2, 1)
    plt.title("MRI_BRATS_485", fontsize=20)
    plt.imshow(mri_485[:, :, layer, channel], cmap='gray')
    plt.axis('off')

    fig.add_subplot(1, 2, 2)
    plt.title("Model prediction mask", fontsize=20)
    prediction_slice = prediction_485[:, :, layer]
    if seg_class == "All":
        # Fixed colour range so each class keeps the same colour on every layer.
        plt.imshow(prediction_slice, vmin=0, vmax=max(classes_dict.values()))
    else:
        # Fixed range: with autoscaling, a slice that is entirely the selected class
        # (vmin == vmax) renders black, indistinguishable from one containing none of it.
        binary_mask = np.where(prediction_slice == classes_dict[seg_class], 255, 0)
        plt.imshow(binary_mask, cmap='gray', vmin=0, vmax=255)
    plt.axis('off')
    plt.show()

# Set channel to view:
#  Channel 0: "FLAIR" Fluid-attenuated inversion recovery
#  Channel 1: "T1w" T1-weighted
#  Channel 2: "t1gd" T1-weighted with gadolinium contrast enhancement
#  Channel 3: "T2w" T2-weighted

# Use the interactive() tool to create the visualization
interactive(plot, seg_class=select_class, layer=select_layer, channel=(0, 3))

interactive(children=(ToggleButtons(button_style='info', description='Select Class:', options=('Background', '…