# Overview

In this tutorial we will explore how to create a contracting-expanding fully convolutional neural network (CNN) for segmentation of brain tumors from MRI.

## Workshop Links

Use the following link to access materials from this workshop: https://github.com/peterchang77/dl_tutor/tree/master/workshops

*Tutorials*

* Introduction to Tensorflow 2.0 and Keras: https://bit.ly/2VSYaop
* CNN for pneumonia classification: https://bit.ly/2D9ZBrX
* CNN for pneumonia segmentation: https://bit.ly/2VQMWk9
* CNN for brain tumor classification: https://bit.ly/44eekdt
* CNN for brain tumor segmentation: https://bit.ly/449i5Ro (**current tutorial**)

# Environment

The following lines of code will configure your Google Colab environment for this tutorial.

### Enable GPU runtime

Use the following instructions to switch the default Colab instance into a GPU-enabled runtime:

```
Runtime > Change runtime type > Hardware accelerator > GPU
```

### Jarvis library

In this notebook we will use Jarvis, a custom Python package to facilitate data science and deep learning for healthcare. Among other things, this library will be used for low-level data management, stratification and visualization of high-dimensional medical data.

In [None]:
# --- Install Jarvis library
%pip install jarvis-md

### Imports

Use the following lines to import any needed libraries:

In [None]:
import numpy as np, pandas as pd
from tensorflow import losses, optimizers
from tensorflow.keras import Input, Model, models, layers, metrics
from jarvis.train import datasets, custom
from jarvis.utils.display import imshow

# Data

The data used in this tutorial will consist of brain tumor MRI exams derived from the MICCAI Brain Tumor Segmentation Challenge (BRaTS). More information about the BRaTS Challenge can be found here: http://braintumorsegmentation.org/. Each 2D slice consists of four different sequences (T2, FLAIR, T1 pre-contrast and T1 post-contrast), stored as separate channels. In this exercise, we will use this dataset to derive a model for slice-by-slice tumor segmentation. The custom `datasets.download(...)` method can be used to download a local copy of the dataset. By default the dataset will be archived at `/data/raw/mr_brats_2020`; as needed an alternate location may be specified using `datasets.download(name=..., path=...)`.

In [None]:
# --- Download dataset
datasets.download(name='mr/brats-2020-mip')

Once downloaded, the `datasets.prepare(...)` method can be used to generate the required python Generators to iterate through the dataset, as well as a `client` object for any needed advanced functionality.

To specify the correct Generator template file, pass a designated `keyword` string. In this tutorial, we will be using brain MRI volumes that have been preprocessed using a *mean intensity projection* (MIP) algorithm to subsample the original 155-slice inputs to 40-50 slices, which makes training feasible within the Google Colab platform. In addition we will be performing voxel-level tumor prediction (i.e., a prediction for every single voxel in the 3D volume). To select the correct Client template for this task, use the keyword string `mip*vox`.

Finally, for the sake of simplicity, this tutorial will binarize the ground-truth labels (instead of using the original four separate tumor classes). To do so, pass the following `configs` dictionary into the `datasets.prepare(...)` method. As needed, extend the custom `configs` dictionary with additional configurations (e.g. batch size, normalization parameters, etc).

In [None]:
# --- Prepare generators
configs = {'specs': {'ys': {'tumor': {'norms': {'clip': {'max': 1}}}}}}
gen_train, gen_valid, client = datasets.prepare(name='mr/brats-2020-mip', keyword='mip*vox', configs=configs)
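
The effect of the `clip` setting can be pictured with plain NumPy. This is only a sketch: it assumes the `clip` normalization behaves like capping every label value at 1, and the label values below are made-up stand-ins for the original tumor classes, not data from the dataset.

```python
import numpy as np

# Hypothetical multi-class mask: 0 = background, nonzero values = tumor subclasses
mask = np.array([[0, 1, 2],
                 [4, 0, 1]])

# Assumed equivalent of {'clip': {'max': 1}}: cap every label at 1
binary = np.minimum(mask, 1)

print(binary)
```

Under this assumption every nonzero tumor subclass collapses into a single foreground class, leaving a binary mask.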

As before, each iteration yields two variables, `xs` and `ys`, each representing a dictionary of model input(s) and output(s). In the current example, there is just a single input and output. Let us examine the generator data:

In [None]:
# --- Yield one example
xs, ys = next(gen_train)

# --- Print dict keys
print('xs keys: {}'.format(xs.keys()))
print('ys keys: {}'.format(ys.keys()))

In [None]:
# --- Print data shape
print('xs shape: {}'.format(xs['dat'].shape))
print('ys shape: {}'.format(ys['tumor'].shape))

### BRATS Data

Let us get to know this data a little better. The input images in the variable `dat` are arrays of shape `1 x 240 x 240 x 4`. Even though the images here are 2D, each example is stored as a 3D volume `(z, y, x)` with `z = 1`, followed by a channel axis. Although the z-axis is redundant for a single-slice input, many more complex models and architectures require a full 3D tensor. For consistency, we will therefore use 3D convolutions throughout the tutorial materials.

In [None]:
print(xs['dat'][0].shape)
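
To make the shape convention concrete, the following sketch builds a stand-in array (zeros, not real MRI data) and shows how a single multi-channel 2D slice becomes a `(z, y, x, channel)` tensor with `z = 1`. The variable names are illustrative only and are not part of the Jarvis API.

```python
import numpy as np

# Stand-in for one 2D slice with four MRI sequences as channels
slice_2d = np.zeros((240, 240, 4), dtype='float32')

# Add a singleton z-axis: (y, x, c) -> (z, y, x, c)
volume = np.expand_dims(slice_2d, axis=0)

# Add a leading batch axis: (z, y, x, c) -> (batch, z, y, x, c)
batch = np.expand_dims(volume, axis=0)

print(volume.shape, batch.shape)
```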

The four channels represent four different input MRI sequences. Each sequence is used to evaluate a different tissue quality. T2 and FLAIR are used to evaluate edema (fluid) that results from brain injury. T1 images are used to evaluate anatomy and breakdown of the blood-brain barrier through contrast leakage.

```
dat[..., 0] = FLAIR
dat[..., 1] = T1 precontrast
dat[..., 2] = T1 postcontrast
dat[..., 3] = T2
```

To visualize these different modalities run the following cell:

In [None]:
imshow(xs['dat'][..., 0].astype('float32'), title='FLAIR')
imshow(xs['dat'][..., 1].astype('float32'), title='T1 precontrast')
imshow(xs['dat'][..., 2].astype('float32'), title='T1 postcontrast')
imshow(xs['dat'][..., 3].astype('float32'), title='T2')

### Tumor masks

The ground-truth labels are binary masks of the same matrix shape as the model input:

In [None]:
print(ys['tumor'][0].shape)

Use the `imshow(...)` method to visualize the ground-truth tumor mask labels:

In [None]:
# --- Show tumor masks overlaid on original data
imshow(xs['dat'], ys['tumor'])

# --- Show tumor masks isolated
imshow(ys['tumor'])

### Model inputs

For every input in `xs`, a corresponding `Input(...)` variable should be created for model development:

In [None]:
# --- Create model inputs
inputs = {}
inputs['dat'] = Input(shape=(1, 240, 240, 4))

# U-Net Architecture

The **U-Net** architecture is a common fully-convolutional neural network used to perform semantic segmentation. The network topology comprises symmetric contracting and expanding arms that map an original input image to an output segmentation mask that approximates the size of the original image:

![U-Net Architecture](https://raw.githubusercontent.com/peterchang77/dl_tutor/master/cs190/spring_2020/notebooks/organ_segmentation/pngs/u-net-architecture.png)

# Contracting Layers

The contracting layers of a U-Net architecture are essentially identical to a standard feed-forward CNN. Compared to the original architecture above, several key modifications will be made for ease of implementation and to optimize for medical imaging tasks including:

* same padding (vs. valid padding)
* strided convolutions (vs. max-pooling)
* smaller filters (channel depths)

Let us start by defining the contracting layer architecture below:

In [None]:
# --- Define kwargs dictionary
kwargs = {
    'kernel_size': (1, 3, 3),
    'padding': 'same'}

# --- Define lambda functions
conv = lambda x, filters, strides : layers.Conv3D(filters=filters, strides=strides, **kwargs)(x)
norm = lambda x : layers.BatchNormalization()(x)
relu = lambda x : layers.ReLU()(x)

# --- Define stride-1, stride-2 blocks
conv1 = lambda filters, x : relu(norm(conv(x, filters, strides=1)))
conv2 = lambda filters, x : relu(norm(conv(x, filters, strides=(1, 2, 2))))

Using these lambda functions, let us define a simple 9-layer contracting network topology with a total of four subsampling (stride-2 convolution) operations:

In [None]:
# --- Define contracting layers
l1 = conv1(16, inputs['dat'])
l2 = conv1(32, conv2(32, l1))
l3 = conv1(48, conv2(48, l2))
l4 = conv1(64, conv2(64, l3))
l5 = conv1(80, conv2(80, l4))

**Checkpoint**: What is the shape of the `l5` feature map?
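
One way to reason about this checkpoint is the standard size rule for `same` padding: a convolution with stride `s` maps a spatial size `n` to `ceil(n / s)`, and a stride-1 convolution leaves the size unchanged. The sketch below applies that rule to a hypothetical 256-pixel axis (deliberately not this tutorial's 240) through four stride-2 convolutions; the helper name is made up for illustration.

```python
import math

def same_padding_output(size, stride):
    """Spatial size after a convolution with 'same' padding."""
    return math.ceil(size / stride)

# Hypothetical 256-pixel axis passed through four stride-2 convolutions
size = 256
sizes = [size]
for _ in range(4):
    size = same_padding_output(size, stride=2)
    sizes.append(size)

print(sizes)  # [256, 128, 64, 32, 16]
```

The same helper can be applied to the y- and x-axes of this tutorial's input to check your answer. The z-axis uses a stride of 1 throughout, and the channel depth of any layer equals its `filters` argument.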

# Expanding Layers

The expanding layers are implemented simply by reversing the operations found in the contracting layers above. Specifically, each subsample operation is now replaced by a **convolutional transpose**. Due to the use of **same** padding, defining a transpose operation with the exact same parameters as a strided convolution will ensure that layers in the expanding pathway will exactly match the shape of the corresponding contracting layer.

### Convolutional transpose

Let us start by defining an additional lambda function for the convolutional transpose:

In [None]:
# --- Define single transpose
tran = lambda x, filters, strides : layers.Conv3DTranspose(filters=filters, strides=strides, **kwargs)(x)

# --- Define transpose block
tran2 = lambda filters, x : relu(norm(tran(x, filters, strides=(1, 2, 2))))

Carefully compare these functions to the single `conv` operations as well as the `conv1` and `conv2` blocks above. Notice that they share the exact same configurations.

Let us now apply the first convolutional transpose block to the `l5` feature map:

In [None]:
# --- Define expanding layers
l6 = tran2(64, l5)

**Checkpoint**: What is the shape of the `l6` feature map?
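
The shape-matching property relies on each spatial size being even before it is halved. The sketch below uses the standard size rules for `same` padding (strided convolution: `ceil(n / stride)`; transposed convolution: `n * stride`) on made-up sizes; the helper names `down` and `up` are hypothetical.

```python
import math

def down(size, stride=2):
    # Strided convolution with 'same' padding
    return math.ceil(size / stride)

def up(size, stride=2):
    # Transposed convolution with 'same' padding
    return size * stride

for size in (30, 15):
    restored = up(down(size))
    status = 'match' if restored == size else 'mismatch'
    print(size, '->', down(size), '->', restored, status)
```

An even size such as 30 is restored exactly, while an odd size such as 15 comes back as 16 and could no longer be concatenated with its contracting counterpart. Since 240 is divisible by 16, the four halvings in this tutorial are all exact.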

### Concatenation

The first connection in this specific U-Net derived architecture is a link between the `l4` and the `l6` layers:

```
l1 -------------------> l9
  \                    /
   l2 -------------> l8
     \              /   
      l3 -------> l7
        \        /
         l4 -> l6
           \  /
            l5
```

To mediate the first connection between contracting and expanding layers, we must ensure that `l4` and `l6` match in feature map size (the number of filters / channel depth *does not* necessarily need to match). Using the `same` padding as above should ensure that this is the case and thus simplifies the connection operation:

In [None]:
# --- Ensure shapes match
print(l4.shape)
print(l6.shape)

# --- Concatenate
concat = lambda a, b : layers.Concatenate()([a, b])
concat(l4, l6)
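
The requirement that only the spatial sizes must agree can be seen with plain NumPy arrays. In this sketch the arrays are stand-ins for feature maps with made-up channel depths; `np.concatenate` along the last axis mirrors the default `axis=-1` behavior of `layers.Concatenate`.

```python
import numpy as np

# Stand-in feature maps: same (batch, z, y, x) but different channel depths
a = np.zeros((1, 1, 30, 30, 64), dtype='float32')
b = np.zeros((1, 1, 30, 30, 48), dtype='float32')

# Concatenate along the last (channel) axis
merged = np.concatenate([a, b], axis=-1)

print(merged.shape)  # (1, 1, 30, 30, 112)
```

The channel depths simply add up, so the convolution that follows a concatenation sees a wider input than either branch alone.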

Note that since `l4` and `l6` are **exactly the same shape** (including matching channel depth), what additional operation could be used here instead of a concatenation?

### Full expansion

Alternate the use of `conv1` and `tran2` blocks to build the remainder of the expanding pathway:

In [None]:
# --- Define expanding layers
l7 = tran2(48, conv1(64, concat(l4, l6)))
l8 = tran2(32, conv1(48, concat(l3, l7)))
l9 = tran2(16,  conv1(32, concat(l2, l8)))
l10 = conv1(16, l9)

# Logits

The last convolution projects the `l10` feature map into a total of just `n` feature maps, one for each possible class prediction. In this 2-class prediction task, a total of `2` feature maps will be needed. Recall that these feature maps essentially act as a set of **logit scores** for each voxel location throughout the image. As with a standard CNN architecture, **do not** use an activation here in the final convolution:

In [None]:
# --- Create logits
logits = {}
logits['tumor'] = layers.Conv3D(filters=2, name='tumor', **kwargs)(l10)
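
To illustrate why no activation is needed before taking a prediction, the following NumPy sketch uses made-up logits for a 2 x 2 patch. It shows that the argmax over raw logits selects the same class as the argmax over softmax probabilities.

```python
import numpy as np

# Made-up logits for a 2 x 2 patch with two classes (background, tumor)
logits = np.array([[[ 2.0, -1.0], [ 0.5,  1.5]],
                   [[-0.3,  0.3], [ 4.0,  0.0]]])

# Softmax over the class axis (shifted by the max for numerical stability)
shifted = logits - logits.max(axis=-1, keepdims=True)
probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)

# Argmax of the logits and of the probabilities pick the same class
pred_from_logits = np.argmax(logits, axis=-1)
pred_from_probs = np.argmax(probs, axis=-1)

print(pred_from_logits)
```

Because softmax preserves the ordering of the scores, it is only needed when calibrated probabilities are required, for example inside the loss function.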

# Model

Let us first create our model:

In [None]:
# --- Create model
model = Model(inputs=inputs, outputs=logits)

### Custom Dice score metric

The metric of choice for tracking performance of a medical image segmentation algorithm is the **Dice score**. The Dice score is not a default metric built into the Tensorflow library; however, a custom metric is available for your convenience as part of the `jarvis-md` package. It is invoked using the `custom.dsc(cls=...)` call, where the argument `cls` refers to the number of *non-zero* classes to track (e.g. the background Dice score is typically not tracked). In this exercise, it will be important to track the performance of segmentation for **tumor** (class = 1) only, thus set the `cls` argument to `1`.

In [None]:
# --- Compile model
model.compile(
    optimizer=optimizers.Adam(learning_rate=2e-4),
    loss={'tumor': losses.SparseCategoricalCrossentropy(from_logits=True)},
    metrics={'tumor': custom.dsc(cls=1)},
    experimental_run_tf_function=False)
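
The loss used above operates directly on logits and integer labels. As a rough illustration of what a sparse categorical cross-entropy computes, the sketch below works through three made-up voxels in NumPy. It assumes a simple mean over voxels and is not the Keras implementation.

```python
import numpy as np

# Made-up logits for three voxels and two classes, with integer labels
logits = np.array([[2.0, 0.0], [0.0, 0.0], [-1.0, 3.0]])
labels = np.array([0, 1, 1])

# Log-softmax over the class axis
shifted = logits - logits.max(axis=-1, keepdims=True)
log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

# Negative log-probability of the true class at each voxel, then the mean
per_voxel = -log_probs[np.arange(len(labels)), labels]
loss = per_voxel.mean()

print(per_voxel, loss)
```

The second voxel has equal logits for both classes, so its loss is `log(2)`; confident correct predictions such as the third voxel contribute values close to zero.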

# Model Training

### In-Memory Data

The following line of code will load all training data into RAM. This strategy can be effective for increasing speed of training for small to medium-sized datasets.

In [None]:
# --- Load data into memory
client.load_data_in_memory()

### Training

Once the model has been compiled and the data prepared (via a generator), training can be invoked using the `model.fit(...)` method. Ensure that both the training and validation data generators are used. In this particular example, we are defining arbitrary epochs of 100 steps each. Training will proceed for 8 epochs in total. Validation statistics will be assessed every fourth epoch. Tune these arguments as needed.

In [None]:
model.fit(
    x=gen_train, 
    steps_per_epoch=100, 
    epochs=8,
    validation_data=gen_valid,
    validation_steps=100,
    validation_freq=4)
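
The schedule implied by these arguments can be worked out with a few lines of plain Python. This sketch only restates the arithmetic; it assumes the documented Keras behavior that an integer `validation_freq` triggers validation at the end of every `validation_freq`-th epoch.

```python
steps_per_epoch = 100
epochs = 8
validation_freq = 4

# Total number of training batches drawn from the training generator
total_steps = steps_per_epoch * epochs

# Epochs (1-indexed) at which validation statistics are computed
validation_epochs = [e for e in range(1, epochs + 1) if e % validation_freq == 0]

print(total_steps, validation_epochs)  # 800 [4, 8]
```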

# Evaluation

To test the trained model, the following steps are required:

* load data
* use `model.predict(...)` to obtain logit scores
* use `np.argmax(...)` to obtain prediction
* compare prediction with ground-truth

Recall that the generator used to train the model simply iterates through the dataset randomly. For model evaluation, the cohort must instead be loaded manually in an orderly way. For this tutorial, we will create new **test mode** data generators, which will simply load each example individually once for testing. 

In [None]:
# --- Create validation generator
test_train, test_valid = client.create_generators(test=True)

### Dice score

While the Dice score metric for Tensorflow has been provided already, an implementation must still be used to manually calculate the performance during validation. Use the following code cell block to implement:

In [None]:
def dice(y_true, y_pred, c=1, epsilon=1):
    """
    Method to calculate the Dice score coefficient for given class
    
    :params
    
      (np.ndarray) y_true  : ground-truth label
      (np.ndarray) y_pred  : predicted logits scores
      (int)        c       : class to calculate DSC on
      (int)        epsilon : constant added to the denominator to avoid division by zero
    
    """
    assert y_true.ndim == y_pred.ndim
    
    true = y_true[..., 0] == c
    pred = np.argmax(y_pred, axis=-1) == c 

    A = np.count_nonzero(true & pred) * 2
    B = np.count_nonzero(true) + np.count_nonzero(pred) + epsilon
    
    return A / B
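
For two binary masks `T` (ground-truth) and `P` (prediction), the Dice score is commonly defined as `2 * |T ∩ P| / (|T| + |P|)`. The tiny worked example below repeats the arithmetic of `dice(...)` on made-up 2 x 2 masks rather than calling the function, so that it stands alone.

```python
import numpy as np

# Made-up ground-truth and predicted binary masks
true = np.array([[1, 1], [0, 0]], dtype=bool)
pred = np.array([[1, 0], [1, 0]], dtype=bool)

overlap = np.count_nonzero(true & pred)                  # 1 voxel in common
total = np.count_nonzero(true) + np.count_nonzero(pred)  # 2 + 2 voxels

# Without smoothing: 2 * 1 / 4
dsc_plain = 2 * overlap / total

# With epsilon = 1 added to the denominator only, as in dice(...): 2 * 1 / 5
dsc_eps = 2 * overlap / (total + 1)

print(dsc_plain, dsc_eps)
```

Because `epsilon` appears only in the denominator, the score is biased slightly downward, and two empty masks score 0 rather than 1. This is one reason to evaluate only on slices that contain tumor.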

Use the following lines of code to loop through the test set generator and run model prediction on each example:

In [None]:
# --- Test model
dsc = []

for x, y in test_valid:
    
    if y['tumor'].any():
        
        # --- Predict
        x['dat'] = x['dat'].reshape(-1, 1, 240, 240, 4)
        logits = model.predict(x['dat'])

        if isinstance(logits, dict):
            logits = logits['tumor']

        # --- Calculate Dice score (argmax is applied inside dice)
        dsc.append(dice(y['tumor'][0], logits[0], c=1))

dsc = np.array(dsc)

Use the following lines of code to calculate validation cohort performance:

In [None]:
# --- Summarize Dice score statistics
print('{}: {:0.5f}'.format('Mean Dice'.ljust(20), np.mean(dsc)))
print('{}: {:0.5f}'.format('Median Dice'.ljust(20), np.median(dsc)))
print('{}: {:0.5f}'.format('25th-centile Dice'.ljust(20), np.percentile(dsc, 25)))
print('{}: {:0.5f}'.format('75th-centile Dice'.ljust(20), np.percentile(dsc, 75)))

## Saving and Loading a Model

After a model has been successfully trained, it can be saved and/or loaded by simply using the `model.save()` and `models.load_model()` methods. 

In [None]:
# --- Serialize a model
model.save('./cnn.hdf5')

In [None]:
# --- Load a serialized model
del model
model = models.load_model('./cnn.hdf5', compile=False)