# Exploring MRI data and labels
3D MRI brain scans from the public [Medical Segmentation Decathlon](https://decathlon-10.grand-challenge.org/) challenge project.


<img src="images/mri-slice.png" alt="U-net Image" width="300"/>

### Importing packages

In [None]:
import numpy as np
import nibabel as nib
import itk
import itkwidgets
from ipywidgets import interact, interactive, IntSlider, ToggleButtons
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set_style('darkgrid')

# 1 Exploring the data

### 1.1  Loading images of the Brain
Grabbing a single 3D MRI brain scan

In [None]:
# Path to a single 3D scan. Forward slashes are used so the path works on every
# operating system and contains no invalid "\B" / "\i" escape sequences.
image_path = "data/BraTS-Data/imagesTr/BRATS_001.nii.gz"
# Load the file into a nibabel image object and report which type was returned
image_obj = nib.load(image_path)
print(f"Type of the image {type(image_obj)}")

### 1.2 Extracting the data as a numpy array

In [None]:
# Extract the voxel data from the loaded image object via its get_fdata() method
image_data = image_obj.get_fdata()
# Evaluate the type of the extracted data (as the last expression of a
# notebook cell, its value is what the cell displays)
type(image_data)

In [None]:
# Unpack the image shape into its four axes; the unpacking raises ValueError
# if the array does not have exactly four dimensions.
height, width, depth, channels = image_data.shape
print(f"the image object has these dimensions: \n height:{height} \n width:{width} \n depth:{depth} \n channels:{channels}")

### 1.3 Visualise the data
The "depth" indicates that there are 155 layers (slices through the brain) in every image object.


In [None]:
# select a random layer number
# np.random.randint excludes its upper bound, so the bound is the number of
# layers (size of the third axis); this makes every layer, including the
# last one, selectable and avoids a hard-coded layer count.
maxval = image_data.shape[2]
i = np.random.randint(0, maxval)

# defining a channel to look at
channel = 0
print(f"Displaying layer {i} Channel {channel} of Image")
# show the chosen 2D slice in grayscale, without axis ticks
plt.imshow(image_data[:, :, i, channel], cmap='gray')
plt.axis('off')

#### 1.3.1 Interactive Exploration

In [None]:
# plotting callback used by the interactive layer browser below
def explore_3dimage(layer):
    """Plot one layer of the loaded image and return the layer index.

    Opens a new 10x5 figure and draws the 2D slice of ``image_data`` at
    index ``layer`` along the third axis, always taken from channel 3,
    in grayscale with the axes hidden.
    """
    channel = 3
    plt.figure(figsize=(10, 5))
    layer_slice = image_data[:, :, layer, channel]
    plt.imshow(layer_slice, cmap='gray')
    plt.title("explore layers of brain mri", fontsize=20)
    plt.axis('off')
    return layer

# hand the callback to ipywidgets' interact(); the (min, max) tuple spans
# every valid index along the third axis of image_data
interact(explore_3dimage, layer=(0, image_data.shape[2] - 1))

## 2. Explore the data labels
We'll load a new dataset containing the labels for the MRI scan we loaded above

In [None]:
# Define the path of the label file (it has the same file name as the scan
# loaded earlier, but lives in the labelsTr folder) and load it with nibabel
label_path = "data/BraTS-Data/labelsTr/BRATS_001.nii.gz"
label_obj = nib.load(label_path)
# Evaluate the type of the loaded label object
type(label_obj)

### Extract the data labels as a numpy array

In [None]:
# Extract the label data from the label object via its get_fdata() method
label_array = label_obj.get_fdata()
# Evaluate the type of the extracted label data
type(label_array)

In [None]:
# Extract and print out the shape of the labels data; the unpacking raises
# ValueError if the label array does not have exactly three dimensions.
# Note: this rebinds height, width and depth set from the image array earlier.
height, width, depth = label_array.shape
print(f"dimensions of the labels data array: height: {height}, width: {width}, depth:{depth} ")
# List the distinct values that actually occur in the label array
print(f"with the unique values: {np.unique(label_array)}")
print(""" Corresponding to the following label categories:
0: for normal 
1: for edema
2: for non-enhancing tumor 
3: for enhancing tumor""")

### 2.2 Visualize the labels for a specific layer
Visualising a single layer of the labeled data

In [None]:
# define a single layer to look at
layer = 50
# map each class name to the value that marks it in the label array
classes_dict = {
    'Normal': 0.,
    'Edema': 1.,
    'Non-enhancing tumor': 2.,
    'Enhancing tumor': 3.
}
# set up one subplot per class, side by side
fig, ax = plt.subplots(nrows=1, ncols=len(classes_dict), figsize=(50, 30))
# the slice is the same for every class, so take it once outside the loop
img = label_array[:, :, layer]
for i, (img_label_str, class_value) in enumerate(classes_dict.items()):
    # binary mask: 255 where the slice equals this class value, 0 elsewhere
    mask = np.where(img == class_value, 255, 0)
    ax[i].imshow(mask)
    # "fontsize" is the valid keyword; the misspelled one made set_title raise
    ax[i].set_title(f"Layer{layer} for {img_label_str}", fontsize=45)
    ax[i].axis('off')
plt.tight_layout()

#### 2.2.1 Interactive Visualisation across layers
Here we can choose the class we want to look at by clicking the button for a particular label, and scroll across layers using the slider.

In [None]:
# create the toggle buttons used to pick a class; the option strings must
# match the keys of classes_dict, because plot_image looks them up there
select_classes = ToggleButtons(
    options=['Normal', 'Edema', 'Non-enhancing tumor', 'Enhancing tumor'],
    description='Select Class:',
    disabled=False,
    button_style='info',
)

# create the layer slider; its range covers every valid index along the third
# axis of label_array instead of relying on a hard-coded layer count
select_layer = IntSlider(
    min=0,
    max=label_array.shape[2] - 1,
    description="Select Layer",
    continuous_update=False,
)

# plotting callback driven by the class buttons and the layer slider
def plot_image(seg_class, layer):
    """Show where one class occurs in one layer of the label array.

    ``seg_class`` is looked up in ``classes_dict`` to get the class value;
    the slice of ``label_array`` at index ``layer`` is then turned into a
    mask (255 where it equals that value, 0 elsewhere) and drawn in
    grayscale on a new 10x5 figure with the axes hidden.
    """
    print(f"displaying {layer} layer label:{seg_class}")
    class_value = classes_dict[seg_class]
    layer_labels = label_array[:, :, layer]
    class_mask = np.where(layer_labels == class_value, 255, 0)
    plt.figure(figsize=(10, 5))
    plt.imshow(class_mask, cmap='gray')
    plt.axis("off")

# build the interactive view by binding the two widgets to the callback
interactive(plot_image, seg_class=select_classes, layer=select_layer)