## Extract the Dataset

In [0]:
import zipfile  # For faster extraction
dataset_path = "MICCAI_BraTS_2018_Data_Training.zip"  # Replace with your dataset path
# Open the archive in a context manager so the file handle is released
# even if extraction fails part-way through.
with zipfile.ZipFile(dataset_path) as zfile:
    # With no target path, the archive is extracted into the current working directory.
    zfile.extractall()

## Imports and helper functions

In [0]:
import SimpleITK as sitk  # For loading the dataset
import numpy as np  # For data manipulation
from model import build_model  # For creating the model
import glob  # For populating the list of files
from scipy.ndimage import zoom  # For resizing
import re  # For parsing the filenames (to know their modality)

Using TensorFlow backend.


In [0]:
def read_img(img_path):
    """
    Load the image stored at `img_path` with SimpleITK and hand it back
    converted to a numpy array (used here for the .nii.gz volumes).
    """
    image = sitk.ReadImage(img_path)
    return sitk.GetArrayFromImage(image)

In [0]:
def resize(img, shape, mode='constant', orig_shape=None, order=3):
    """
    Wrapper for scipy.ndimage.zoom suited for MRI images.

    Parameters
    ----------
    img : array
        The 3-D volume to resample.
    shape : sequence of 3 ints
        Target shape of the returned volume.
    mode : str
        Boundary handling forwarded to ``zoom`` (how samples outside the
        input are filled). This is NOT the interpolation method.
    orig_shape : sequence of 3 ints, optional
        Shape the zoom factors are computed from. Defaults to the shape of
        ``img`` itself, so the result has ``shape`` for any input size.
    order : int
        Spline order forwarded to ``zoom``. 3 (cubic) is zoom's own default;
        use 0 (nearest neighbour) for label masks.

    Returns
    -------
    The resampled array produced by ``zoom``.

    Raises
    ------
    ValueError
        If ``shape`` does not have exactly 3 entries.
    """
    if len(shape) != 3:
        raise ValueError(
            "shape must have exactly 3 dimensions, got %d" % len(shape)
        )

    # Derive the factors from the real input unless the caller overrides it;
    # a fixed (155, 240, 240) is only correct for untouched BraTS volumes.
    if orig_shape is None:
        orig_shape = np.shape(img)

    factors = tuple(target / source for target, source in zip(shape, orig_shape))

    # Resize to the given shape
    return zoom(img, factors, order=order, mode=mode)


def preprocess(img, out_shape=None):
    """
    Preprocess the image.
    Just an example, you can add more preprocessing steps if you wish to.

    Parameters
    ----------
    img : numpy array
        The volume to preprocess.
    out_shape : sequence of 3 ints, optional
        When given, the volume is first resized to this shape via `resize`.

    Returns
    -------
    The volume shifted to zero mean and scaled to unit standard deviation.
    The statistics are computed over the whole array. A constant volume
    (standard deviation of 0) is returned as all zeros instead of NaN/inf.
    """
    if out_shape is not None:
        img = resize(img, out_shape, mode='constant')

    # Normalize the image
    mean = img.mean()
    std = img.std()
    centered = img - mean

    # Dividing by a zero std would fill the volume with NaN/inf, which then
    # silently poisons training; a constant volume is already "normalized".
    if std == 0:
        return centered
    return centered / std


def preprocess_label(img, out_shape=None, mode='nearest'):
    """
    Separates out the 3 labels from the segmentation provided, namely:
    GD-enhancing tumor (ET — label 4), the peritumoral edema (ED — label 2))
    and the necrotic and non-enhancing tumor core (NCR/NET — label 1)

    Parameters
    ----------
    img : numpy array
        Segmentation volume holding the integer labels above.
    out_shape : sequence of ints, optional
        When given, every mask is resampled to this shape. It must have one
        entry per dimension of `img`.
    mode : str
        Boundary handling forwarded to scipy's `zoom`. It does NOT select the
        interpolation; masks are always resampled with nearest-neighbour
        interpolation (order=0).

    Returns
    -------
    uint8 array of shape (3,) + spatial shape, channels ordered
    (NCR/NET, ED, ET), containing only 0 and 1.

    Raises
    ------
    ValueError
        If `out_shape` does not match the dimensionality of `img`.
    """
    img = np.asarray(img)

    # One binary mask per sub-region, in channel order NCR/NET (1), ED (2), ET (4).
    masks = [(img == label).astype(np.uint8) for label in (1, 2, 4)]

    if out_shape is not None:
        if len(out_shape) != img.ndim:
            raise ValueError(
                "out_shape must have %d dimensions, got %d" % (img.ndim, len(out_shape))
            )
        factors = tuple(target / source for target, source in zip(out_shape, img.shape))
        # zoom defaults to cubic-spline interpolation (order=3), which yields
        # non-binary intermediate values and leaves holes in a binary mask.
        # order=0 copies the nearest source voxel, so masks stay strictly 0/1.
        masks = [zoom(mask, factors, order=0, mode=mode) for mask in masks]

    return np.array(masks, dtype=np.uint8)
    

## Loading Data


In [0]:
# Get a list of files for all modalities individually.
# glob returns matches in arbitrary order; the lists are sorted so that they
# are reproducible between runs and every list walks the patient folders in
# the same order (they are paired up by position further below).
t1 = sorted(glob.glob('*GG/*/*t1.nii.gz'))
t2 = sorted(glob.glob('*GG/*/*t2.nii.gz'))
flair = sorted(glob.glob('*GG/*/*flair.nii.gz'))
t1ce = sorted(glob.glob('*GG/*/*t1ce.nii.gz'))
seg = sorted(glob.glob('*GG/*/*seg.nii.gz'))  # Ground Truth

Parse all the filenames and create a dictionary for each patient with structure:

{<br />
    &nbsp;&nbsp;&nbsp;&nbsp;'t1': _&lt;path to t1 MRI file&gt;_,<br />
    &nbsp;&nbsp;&nbsp;&nbsp;'t2': _&lt;path to t2 MRI file&gt;_,<br />
    &nbsp;&nbsp;&nbsp;&nbsp;'flair': _&lt;path to FLAIR MRI file&gt;_,<br />
    &nbsp;&nbsp;&nbsp;&nbsp;'t1ce': _&lt;path to t1ce MRI file&gt;_,<br />
    &nbsp;&nbsp;&nbsp;&nbsp;'seg': _&lt;path to Ground Truth file&gt;_,<br />
}<br />

In [0]:
# Captures the modality suffix of a file name, i.e. the word characters
# between the last underscore and ".nii.gz" (e.g. "t1ce", "seg").
# Raw string: "\w" and "\." are regex escapes, not Python string escapes.
pat = re.compile(r'.*_(\w*)\.nii\.gz')

# One dict per patient, mapping modality name -> file path.
# The lists are paired by position, so each one is sorted first to make sure
# they all enumerate the patient folders in the same order.
data_paths = [
    {pat.findall(item)[0]: item for item in items}
    for items in zip(sorted(t1), sorted(t2), sorted(t1ce), sorted(flair), sorted(seg))
]

## Load the data in a Numpy array
Creating an empty Numpy array beforehand and then filling up the data helps you gauge beforehand if the data fits in your memory.



_Loading only the first 4 images here, to save time._

In [0]:
# (channels, depth, height, width): one channel per MRI modality, followed by
# the spatial shape every volume is resized to.
input_shape = (4, 80, 96, 64)
# input_shape = (4, 155, 240, 240)
output_channels = 3
# Pre-allocate the arrays for the first 4 patients. They are zero-initialised
# rather than left uninitialised so that a slot which is never filled holds
# zeros instead of arbitrary memory.
data = np.zeros((len(data_paths[:4]),) + input_shape, dtype=np.float32)
labels = np.zeros((len(data_paths[:4]), output_channels) + input_shape[1:], dtype=np.uint8)

In [0]:
import math

# Parameters for the progress bar
total = len(data_paths[:4])
# Guard the division: when no patient was found, total is 0 and the original
# expression raised ZeroDivisionError before anything could be reported.
step = 25 / total if total else 0

if not total:
    print('No patient records found in data_paths; nothing to load.')

for i, imgs in enumerate(data_paths[:4]):
    try:
        # Stack the four modalities as channels, in the order t1, t2, t1ce, flair.
        data[i] = np.array(
            [preprocess(read_img(imgs[m]), input_shape[1:]) for m in ['t1', 't2', 't1ce', 'flair']],
            dtype=np.float32,
        )
        # preprocess_label already returns (channels,) + spatial shape, which is
        # exactly the shape of labels[i]; no extra leading axis is needed.
        labels[i] = preprocess_label(read_img(imgs['seg']), input_shape[1:])
    except Exception as e:
        # Best effort: report the failing patient and carry on with the next one.
        print('Something went wrong with %s, skipping...\n Exception:\n%s'%(imgs["t1"], str(e)))
        continue

    # Print the progress bar
    filled = int((i+1) * step)
    print('\r' + 'Progress: ' + "[%s %s]"%('=' * filled, ' ' * (24 - filled)) + "(%s percentage)"%(math.ceil((i+1) * 100 / (total))),
        end='')

In [0]:
# Display the shape of the label array; it was allocated as
# (number of loaded patients, output_channels) + input_shape[1:].
labels.shape

## Model

Build the model.

In [0]:
# Use the shared output_channels setting (also used to allocate `labels`)
# instead of repeating the literal, so the two cannot drift apart.
model = build_model(input_shape=input_shape, output_channels=output_channels)

Train the model

In [0]:
# Train on the loaded volumes, then run a prediction on the first one.
# These two calls were commented out, which left `preds` undefined and made
# this cell fail with a NameError on a fresh kernel.
# NOTE(review): assumes build_model returns a single-output model whose
# prediction has the layout (batch, channels, depth, height, width) - confirm
# against model.py before relying on the shapes printed below.
model.fit(data, labels, batch_size=1, epochs=5)
preds = model.predict(np.array([data[0]]))
import matplotlib.pyplot as plt
pred = preds[0]  # Prediction for the single volume in the batch
print(pred.shape)
print(pred[0,:,:,:].shape)  # First output channel only

That's it!

## Closing Regards

If you are resizing the segmentation mask, the resized segmentation mask retains the overall shape, but loses a lot of pixels and becomes somewhat 'grainy'. See the illustration below.

1. Original segmentation mask:

In [0]:
import matplotlib.pyplot as plt

# Load the first ground-truth volume untouched.
img = read_img(seg[0])
print(img.shape)

# Collapse the volume along axis 0 by summing the label values, then keep the
# positions whose summed value exceeds 1 as a binary (0/1) image.
img = np.asarray(img.sum(axis=0) > 1, dtype=np.uint8)

# Report the projected shape, the distinct values and the count of set pixels.
print(img.shape, np.unique(img), img.sum(), sep='\n')
plt.imshow(img, cmap='Greys_r')

In [0]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
slice_no = 60
# Path relative to the working directory, like the glob patterns used to list
# the dataset, instead of an absolute path that only exists on one machine.
s = read_img('LGG/Brats18_TCIA12_101_1/Brats18_TCIA12_101_1_flair.nii.gz')
print(s.shape)
fig = plt.figure() # make figure

# make axesimage object
# the vmin and vmax here are very important to get the color map correct:
# without them the colour limits are taken from the first slice only and are
# not updated when later slices are swapped in by set_array.
im = plt.imshow(s[0], cmap='gray', vmin=s.min(), vmax=s.max())

# function to update figure
def updatefig(j):
    """Show slice `j` of the volume and return the artists that changed."""
    # set the data in the axesimage object
    im.set_array(s[j])
    # return the artists set
    return [im]
# kick off the animation; slice 0 is already on screen, so start from 1 and
# run over however many slices the volume actually has.
ani = animation.FuncAnimation(fig, updatefig, frames=range(1, s.shape[0]),
                              interval=50, blit=True)
plt.show()

In [0]:
# T1 volume of the same patient. The path is relative to the working
# directory (like the glob patterns used to list the dataset) instead of an
# absolute path that only exists on one machine.
s = read_img('LGG/Brats18_TCIA12_101_1/Brats18_TCIA12_101_1_t1.nii.gz')
print(s.shape)
plt.imshow(s[slice_no], cmap='gray')

In [0]:
# T2 volume of the same patient. The path is relative to the working
# directory (like the glob patterns used to list the dataset) instead of an
# absolute path that only exists on one machine.
s = read_img('LGG/Brats18_TCIA12_101_1/Brats18_TCIA12_101_1_t2.nii.gz')
print(s.shape)
plt.imshow(s[slice_no], cmap='gray')

In [0]:
# T1ce volume of the same patient. The path is relative to the working
# directory (like the glob patterns used to list the dataset) instead of an
# absolute path that only exists on one machine.
s = read_img('LGG/Brats18_TCIA12_101_1/Brats18_TCIA12_101_1_t1ce.nii.gz')
print(s.shape)
plt.imshow(s[slice_no], cmap='gray')

In [0]:
# Ground-truth segmentation of the same patient. The path is relative to the
# working directory (like the glob patterns used to list the dataset) instead
# of an absolute path that only exists on one machine.
s = read_img('LGG/Brats18_TCIA12_101_1/Brats18_TCIA12_101_1_seg.nii.gz')
print(s.shape)
# Slice 90 is shown here rather than slice_no.
plt.imshow(s[90], cmap='Greys_r')

2. After resizing to (80, 96, 64):

In [0]:
# Shape check only: split the first ground truth into its label channels at
# the reduced resolution and print the resulting shape.
img = preprocess_label(read_img(seg[0]), out_shape=(80, 96, 64), mode='nearest')
print(img.shape)

# NOTE(review): the mask displayed below is produced with
# out_shape=(155, 240, 240), not (80, 96, 64) - confirm this is intended.
img = preprocess_label(read_img(seg[0]), out_shape=(155, 240, 240), mode='nearest')

# Sum over the channel axis and then the first spatial axis in one step, and
# keep the positions whose total exceeds 1 as a binary (0/1) image.
img = img.sum(axis=(0, 1))
img = np.asarray(img > 1, dtype=np.uint8)

# Report the projected shape, the distinct values and the count of set pixels.
print(img.shape, np.unique(img), img.sum(), sep='\n')
plt.imshow(img, cmap='Greys_r')

One can clearly notice that the resized mask can contain a lot of black pixels in regions that should have been entirely white. This can potentially hurt our model, so it is best not to resize the image too much. However, due to computational constraints and the model requirements, some resizing is unavoidable.

However, given below are a few things one could try to reduce the downsampling noise as much as possible.

In [0]:
import cv2

- Original Image > preprocess_label > Morphological Closing

In [0]:
# 3x3 all-ones structuring element for the morphological operation.
kernel = np.ones(shape=(3, 3))
# Apply morphological closing to the mask, repeated for 3 iterations.
img_closed = cv2.morphologyEx(
    img,
    cv2.MORPH_CLOSE,
    kernel,
    iterations=3,
)
# Report the distinct values and the count of set pixels after closing.
print(np.unique(img_closed), img_closed.sum(), sep='\n')
plt.imshow(img_closed, cmap='Greys_r')

- Original Image > preprocess_label > Morphological Dilation

In [0]:
# 3x3 all-ones structuring element for the morphological operation.
kernel = np.ones(shape=(3, 3))
# Apply a single dilation pass to the mask.
img_dilated = cv2.dilate(
    img,
    kernel,
    iterations=1,
)
# Report the distinct values and the count of set pixels after dilation.
print(np.unique(img_dilated), img_dilated.sum(), sep='\n')
plt.imshow(img_dilated, cmap='Greys_r')

You could try these things to get even better results.

## Feedback

If you have any feedback, queries, or bug reports, please feel free to [raise an issue](https://github.com/IAmSuyogJadhav/3d-mri-brain-tumor-segmentation-using-autoencoder-regularization/issues/new) on GitHub. It would be really helpful!