# Overview

In this tutorial we will explore a few of the advanced approaches to customize a segmentation algorithm for medical imaging. These approaches will build upon the standard U-Net contracting-expanding CNN architecture. As in the prior tutorial, the goal remains to perform brain tumor segmentation on multi-sequence MRI.

This tutorial is part of the class **Introduction to Deep Learning for Medical Imaging** at University of California Irvine (CS190); more information can be found at: https://github.com/peterchang77/dl_tutor/tree/master/cs190.

# Google Colab

The following lines of code will configure your Google Colab environment for this tutorial.

### Enable GPU runtime

Use the following instructions to switch the default Colab instance into a GPU-enabled runtime:

```
Runtime > Change runtime type > Hardware accelerator > GPU
```

# Environment

### Jarvis library

In this notebook we will use Jarvis, a custom Python package to facilitate data science and deep learning for healthcare. Among other things, this library will be used for low-level data management, stratification and visualization of high-dimensional medical data.

In [None]:
# --- Install jarvis (only in Google Colab or local runtime)
%pip install jarvis-md

### Imports

Use the following lines to import any additional needed libraries:

In [None]:
import numpy as np, pandas as pd
from tensorflow import losses, optimizers
from tensorflow.keras import Input, Model, models, layers
from jarvis.train import datasets, custom
from jarvis.utils.display import imshow

# Data

The data used in this tutorial will consist of brain tumor MRI exams derived from the MICCAI Brain Tumor Segmentation Challenge (BraTS). More information about the BraTS Challenge can be found here: http://braintumorsegmentation.org/. Each 2D slice will consist of four different sequences (T2, FLAIR, T1 pre-contrast and T1 post-contrast), stacked as separate input channels. In this exercise, we will use this dataset to derive a model for slice-by-slice tumor segmentation. The custom `datasets.download(...)` method can be used to download a local copy of the dataset. By default the dataset will be archived at `/data/raw/mr_brats_2020`; as needed, an alternate location may be specified using `datasets.download(name=..., path=...)`.

In [None]:
# --- Download dataset
datasets.download(name='mr/brats-2020-mip')

Once downloaded, the `datasets.prepare(...)` method can be used to generate the required Python generators to iterate through the dataset, as well as a `client` object for any needed advanced functionality.

To specify the correct generator template file, pass a designated `keyword` string. In this tutorial, we will be using brain MRI volumes that have been preprocessed using a *mean intensity projection* (MIP) algorithm to subsample the original 155-slice inputs to 40-50 slices, facilitating algorithm training within the Google Colab platform. In addition, we will be performing voxel-level tumor prediction (i.e., a prediction for every single voxel in the 3D volume). To select the correct template for this task, use the keyword string `mip*vox`.
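To make the MIP preprocessing more concrete, here is a minimal NumPy sketch under a stated assumption: adjacent slices are averaged in non-overlapping groups. The helper `mean_intensity_projection` is hypothetical, and the group size and grouping scheme actually used to build `mr/brats-2020-mip` may differ.

```python
import numpy as np

def mean_intensity_projection(vol, group=3):
    """
    Hypothetical helper: subsample a (Z, Y, X, C) volume along z by averaging
    non-overlapping groups of `group` adjacent slices. Trailing slices that do
    not fill a complete group are dropped.
    """
    z = (vol.shape[0] // group) * group
    grouped = vol[:z].reshape(z // group, group, *vol.shape[1:])
    return grouped.mean(axis=1)

# --- Toy volume: 155 slices, small xy size, 4 sequences as channels
vol = np.random.rand(155, 8, 8, 4).astype('float32')
mip = mean_intensity_projection(vol, group=3)
print(mip.shape)  # (51, 8, 8, 4)
```

Each output slice summarizes its neighbors, so the z-resolution drops while every original slice still contributes to the result.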

Finally, for the sake of simplicity, this tutorial will binarize the ground-truth labels (instead of using the original four separate tumor classes). To do so, pass the following `configs` dictionary into the `datasets.prepare(...)` method. As needed, extend the custom `configs` dictionary with additional configurations (e.g. batch size, normalization parameters, etc).

In [None]:
# --- Prepare generators
configs = {'specs': {'ys': {'tumor': {'norms': {'clip': {'max': 1}}}}}}
gen_train, gen_valid, client = datasets.prepare(name='mr/brats-2020-mip', keyword='mip*vox', configs=configs)
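As a rough illustration of what the `{'clip': {'max': 1}}` label normalization is intended to do, the NumPy sketch below clips a toy multi-class label map so that every non-zero (tumor) label becomes 1. This assumes the `jarvis` normalization behaves like `np.clip`; the label values are arbitrary toy values and not necessarily the actual BraTS encoding.

```python
import numpy as np

# --- Toy label map: 0 = background, non-zero values = different tumor classes
lbl = np.array([[0, 1, 2],
                [3, 0, 1]], dtype='uint8')

# --- Clipping at max=1 maps every non-zero (tumor) label to 1
binary = np.clip(lbl, 0, 1)
print(binary)
# [[0 1 1]
#  [1 0 1]]
```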

### Model inputs

For every input in `xs`, a corresponding `Input(...)` variable can be created and returned in an `inputs` dictionary for ease of model development:

In [None]:
# --- Create model inputs
inputs = client.get_inputs(Input)

print(inputs.keys())
print(inputs['dat'].shape)

In this example, the equivalent Python code to generate `inputs` would be:

```python
inputs = {}
inputs['dat'] = Input(shape=(1, 240, 240, 4))
```

# Network Connections

A key component of the contracting-expanding segmentation architectures is the use of connections to combine both low- and high-level features. In addition to the standard **concatenation** operation, several variations can be used.

Let us start by building the contracting layers of a standard 2D U-Net:

In [None]:
# --- Define kwargs
kwargs = {
    'kernel_size': (1, 3, 3),
    'padding': 'same',
    'kernel_initializer': 'he_normal'}

# --- Define block components
conv = lambda x, filters, strides : layers.Conv3D(filters=filters, strides=strides, **kwargs)(x)
tran = lambda x, filters, strides : layers.Conv3DTranspose(filters=filters, strides=strides, **kwargs)(x)

norm = lambda x : layers.BatchNormalization()(x)
relu = lambda x : layers.ReLU()(x)

conv1 = lambda filters, x : relu(norm(conv(x, filters, strides=1)))
conv2 = lambda filters, x : relu(norm(conv(x, filters, strides=(1, 2, 2))))
tran2 = lambda filters, x : relu(norm(tran(x, filters, strides=(1, 2, 2))))

# --- Define contracting layers
l1 = conv1(8, inputs['dat'])
l2 = conv1(16, conv2(16, l1))
l3 = conv1(32, conv2(32, l2))
l4 = conv1(48, conv2(48, l3))
l5 = conv1(64, conv2(64, l4))

### Residual connection

If *SAME* padding is used throughout the network architecture, and if the number of filters used at each block is symmetric, then the corresponding contracting and expanding layers should have exactly the same feature map size. In this scenario, a **residual** (addition) operation can be used instead of the standard concatenation. 
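The difference between the two connection types is easiest to see at the level of tensor shapes. The NumPy sketch below uses toy random arrays with the shapes of `l4` and `l6` (batch axis omitted): concatenation doubles the channel count, whereas the residual addition requires identical shapes and keeps the channel count unchanged. The Keras `+` operator and `layers.Concatenate()` (which defaults to the last axis) behave analogously.

```python
import numpy as np

# --- Toy feature maps with identical (Z, Y, X, C) shapes, e.g. l4 and l6 above
contracting = np.random.rand(1, 30, 30, 48).astype('float32')
expanding = np.random.rand(1, 30, 30, 48).astype('float32')

# --- Concatenation stacks channels: C grows from 48 to 96
concatenated = np.concatenate([contracting, expanding], axis=-1)

# --- Residual (addition) combines element-wise: C stays at 48
residual = contracting + expanding

print(concatenated.shape)  # (1, 30, 30, 96)
print(residual.shape)      # (1, 30, 30, 48)
```

The smaller channel count after addition means the following `conv1(...)` block has fewer input channels, and therefore fewer parameters, than after concatenation.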

In [None]:
# --- Use residual connections
l6 = tran2(48, l5)
l7 = tran2(32, conv1(48, l6 + l4))
l8 = tran2(16, conv1(32, l7 + l3))
l9 = tran2(8,  conv1(16, l8 + l2))

What are the advantages or disadvantages of this approach?

### Multiple operations

As discussed in the lecture, high-resolution but shallow layers in the contracting arm of the network may sometimes be too "raw" and thus introduce noise into the network predictions. To help overcome this effect, consider the use of additional operations to refine the shallow contracting layers **prior** to combination with the expanding arm.

In [None]:
# --- Use multiple operations
l6 = tran2(48, l5)
l7 = tran2(32, conv1(48, l6 + conv1(48, l4)))
l8 = tran2(16, conv1(32, l7 + conv1(32, l3)))
l9 = tran2(8,  conv1(16, l8 + conv1(16, l2)))

What are the advantages or disadvantages of this approach?

# Hybrid 3D / 2D Network

As discussed in lecture, the **hybrid 3D/2D** network architecture is a modified approach that uses a 3D *slab* input to predict each desired 2D slice output. Accordingly, the contracting layers will consist of **3D** operations, while the expanding layers will consist of **2D** operations.

For this example, let us consider the following inputs and outputs:

* input: 3-slice volume (None, 3, 240, 240, 4)
* output: 1-slice prediction (None, 1, 240, 240, 1) 

Let us look at the required architecture modifications in detail.

### 3D inputs

As discussed, the network will use a 3-slice `3D` input. To change the input shape yielded by the Python generators, pass the `specs` dictionary shown below when preparing the dataset.

In [None]:
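# --- Prepare generators
configs = {
    'specs': {
        'xs': {'dat': {'shape': [3, 240, 240, 4]}},
        'ys': {'tumor': {'norms': {'clip': {'max': 1}}}}}}

gen_train, gen_valid, client = datasets.prepare(name='mr/brats-2020-mip', keyword='mip*vox', configs=configs)

# --- Re-create model inputs (the earlier `inputs` expect 1-slice data);
#     z is left as None so the same model can also run on padded full volumes
inputs = {'dat': Input(shape=(None, 240, 240, 4), name='dat')}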

To confirm the correct *new* 3D input shape:

In [None]:
# --- Load batch of data
xs, ys = next(gen_train)

# --- Print data shape
print('xs shape: {}'.format(xs['dat'].shape))
print('ys shape: {}'.format(ys['tumor'].shape))

### Contracting layers

Within the contracting layers, in addition to the use of 3D convolutions, downsampling techniques must be carefully implemented. Specifically:

* use **strided convolutions** (with *SAME* padding) to subsample in the *xy-* direction
* use **VALID** padding convolutions with (2, 1, 1) kernels to subsample in the *z-* direction

This design is deliberate: it gradually decreases the *n*-slice input over the course of the network (a strided convolution in the z-direction would be too aggressive). In addition, *VALID* padding in the z-direction ensures that inference can easily be performed on an arbitrary number of slices simultaneously (by contrast, using *SAME* padding would require dividing a full 3D volume into *n* separate 3-slice inputs and repeatedly running prediction).
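The claim about *VALID* padding can be checked numerically. The NumPy sketch below implements a toy *VALID* (k, 1, 1) convolution along z as a hypothetical helper `conv_z_valid` (no bias, normalization or activation; the xy-operations act on each slice independently and therefore do not change the argument). Two stacked (2, 1, 1) convolutions applied once to a z-padded volume give exactly the same output as running them separately on every 3-slice window.

```python
import numpy as np

def conv_z_valid(x, w):
    """
    Hypothetical helper: VALID-padded convolution along z (the first axis) only.
    x: (Z, Y, X, C) feature map; w: (k, C, F) weights of a (k, 1, 1) kernel.
    Returns a (Z - k + 1, Y, X, F) feature map.
    """
    k = w.shape[0]
    out = [np.einsum('zyxc,zcf->yxf', x[i:i + k], w) for i in range(x.shape[0] - k + 1)]
    return np.stack(out, axis=0)

rng = np.random.default_rng(0)
Z, C, F = 5, 4, 6
vol = rng.random((Z, 8, 8, C))
w1 = rng.random((2, C, F))
w2 = rng.random((2, F, F))

# --- Pad z by one zero slice on top and bottom
padded = np.pad(vol, ((1, 1), (0, 0), (0, 0), (0, 0)))

# --- Single full-volume pass: (Z + 2) -> (Z + 1) -> Z slices
full = conv_z_valid(conv_z_valid(padded, w1), w2)

# --- Slab-by-slab passes: each 3-slice window -> 1 slice
slabs = np.concatenate(
    [conv_z_valid(conv_z_valid(padded[i:i + 3], w1), w2) for i in range(Z)], axis=0)

print(full.shape)                # (5, 8, 8, 6)
print(np.allclose(full, slabs))  # True
```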

Use the following code cell to define these two different types of convolutions:

In [None]:
# --- Define 2D conv (xy-features)
conv_2d = lambda x, filters, strides : layers.Conv3D(
    filters=filters, 
    strides=strides, 
    kernel_size=(1, 3, 3), 
    padding='same',
    kernel_initializer='he_normal')(x)

# --- Define 1D conv (z-features)
conv_1d = lambda x, filters, k=2 : layers.Conv3D(
    filters=filters,
    strides=1,
    kernel_size=(k, 1, 1),
    padding='valid',
    kernel_initializer='he_normal')(x)

Note that both convolutions are of class `layers.Conv3D(...)`; however, the *VALID* padded (k, 1, 1) convolution is functionally a 1D convolution whose only purpose is to decrease the number of slices in the feature maps by k - 1 (i.e., by one slice for a (2, 1, 1) kernel).
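The slice bookkeeping follows the standard convolution output-size formulas, the same ones TensorFlow uses for `padding='same'` and `padding='valid'`. A minimal pure-Python sketch, with hypothetical helper names `same_out` and `valid_out`:

```python
import math

def same_out(size, stride):
    """Output length along one axis with SAME padding: ceil(size / stride)."""
    return math.ceil(size / stride)

def valid_out(size, kernel, stride=1):
    """Output length along one axis with VALID padding: floor((size - kernel) / stride) + 1."""
    return (size - kernel) // stride + 1

# --- xy-subsampling: a stride-2 SAME conv halves 240 -> 120
print(same_out(240, 2))        # 120

# --- z-subsampling: a VALID (k, 1, 1) conv removes k - 1 slices
print(valid_out(3, kernel=2))  # 2
print(valid_out(3, kernel=3))  # 1
```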

Finally, note that instead of hard-coding a (2, 1, 1) 1D kernel, we will use a (k, 1, 1) kernel that can be defined dynamically. This will be important as we create other 1D kernel shapes in the expanding layers below.

Based on this, the following lambda helper functions can be defined:

In [None]:
# --- Define lambda functions
norm = lambda x : layers.BatchNormalization()(x)
relu = lambda x : layers.LeakyReLU()(x)

# --- Define stride-1 2D, stride-2 2D (xy-subsample) and stride-1 1D (z-subsample) blocks
conv1 = lambda filters, x : relu(norm(conv_2d(x, filters, strides=(1, 1, 1))))
conv2 = lambda filters, x : relu(norm(conv_2d(x, filters, strides=(1, 2, 2))))
convZ = lambda filters, k, x : relu(norm(conv_1d(x, filters, k=k)))

Using these lambda functions, whenever we need to subsample in the *xy-* direction we will use `conv2`, and whenever we need to subsample in the *z-* direction we will use `convZ`. The following code cell demonstrates usage:

In [None]:
# --- Define arbitrary input
dat = Input(shape=(3, 240, 240, 4))

# --- Define contracting layers
l1 = conv1(8,  inputs['dat'])
l2 = conv1(16, conv2(16, l1))
l3 = conv1(32, conv2(32, l2))
l4 = conv1(48, convZ(48, 2, conv2(48, l3)))
l5 = conv1(64, convZ(64, 2, conv2(64, l4)))

What are the sizes of all the layers `l1`, `l2`, etc...? Which ones are 3-slices, 2-slices, and 1-slice?

In [None]:
# --- Create a temporary model and pass in dat
Model(inputs=inputs, outputs=l5)({'dat': dat})

### Expanding layers

In the expanding layers, in addition to the use of 2D operations, the connections between contracting and expanding layers must be carefully implemented to ensure that the *3D* contracting layers can be concatenated or added to the *2D* expanding layers. While the *xy-* feature map sizes should match, the *z-* feature map sizes may not. To convert a 3D (Z, n, n) feature map to a 2D (1, n, n) feature map, consider using a (Z, 1, 1) *VALID* padded convolution, similar to the 1D convolution defined above in the contracting layers.
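Conceptually, a (Z, 1, 1) *VALID* convolution collapses all Z slices into one by taking a learned weighted sum over slices and channels at each (y, x) location. The NumPy sketch below writes this out with `np.einsum` on random toy weights (bias set to zero; the batch normalization and activation applied inside `convZ` are omitted).

```python
import numpy as np

rng = np.random.default_rng(0)

# --- Toy 3-slice contracting feature map (Z, Y, X, C), e.g. l3 with C = 32
Z, Y, X, C, F = 3, 60, 60, 32, 32
feat = rng.standard_normal((Z, Y, X, C))

# --- A (Z, 1, 1) VALID kernel spans every slice at once: weights (Z, C, F), bias (F,)
w = rng.standard_normal((Z, C, F))
b = np.zeros(F)

# --- Each output voxel is a weighted sum over all slices and channels at that (y, x)
out = np.einsum('zyxc,zcf->yxf', feat, w) + b
out = out[np.newaxis]  # restore the singleton z-axis: (1, Y, X, F)

print(out.shape)  # (1, 60, 60, 32)
```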

The following cell demonstrates how to reuse the `convZ` lambda function to implement this concept:

In [None]:
# --- 3-slices to 1-slice
p3 = convZ(32, 3, l3)
print(Model(inputs=inputs, outputs=l3)({'dat': dat}).shape)
print(Model(inputs=inputs, outputs=p3)({'dat': dat}).shape)

# --- 2-slices to 1-slice
p4 = convZ(48, 2, l4)
print(Model(inputs=inputs, outputs=l4)({'dat': dat}).shape)
print(Model(inputs=inputs, outputs=p4)({'dat': dat}).shape)

Keep in mind that when you use this `convZ(..., k, ...)` lambda function, `k` represents the *number of slices* in the input feature map, indicating that a (k, 1, 1) `kernel_size` is required (to convert to a single-slice feature map output).

Now let us create the remaining 2D convolutional transpose operations:

In [None]:
# --- Define 2D transpose
tran = lambda x, filters : layers.Conv3DTranspose(
    filters=filters, 
    strides=(1, 2, 2),
    kernel_size=(1, 3, 3),
    padding='same',
    kernel_initializer='he_normal')(x)

# --- Define transpose block
tran2 = lambda filters, x : relu(norm(tran(x, filters)))

The following code cell demonstrates creation of the expanding layers. Recall the U-Net diagram to ensure that the correct layers are combined:

```
l1 -------------------> l9
  \                    /
   l2 -------------> l8
     \              /   
      l3 -------> l7
        \        /
         l4 -> l6
           \  /
            l5
```

As described in the lecture and previous tutorial, connections can be implemented either via concatenation or residual connections:

In [None]:
# --- Create expanding layers using concatenation
concat = lambda a, b : layers.Concatenate()([a, b])

l6 =  tran2(48, conv1(48, l5))
l7 =  tran2(32, conv1(48, concat(convZ(48, 2, l4), l6)))
l8 =  tran2(16, conv1(32, concat(convZ(32, 3, l3), l7)))
l9 =  tran2(8,  conv1(16, concat(convZ(16, 3, l2), l8)))
l10 = conv1(8,  conv1(8,  concat(convZ(8,  3, l1), l9)))

In [None]:
# --- Create expanding layers using residual
l6 =  tran2(48, conv1(48, l5))
l7 =  tran2(32, conv1(48, convZ(48, 2, l4) + l6))
l8 =  tran2(16, conv1(32, convZ(32, 3, l3) + l7))
l9 =  tran2(8,  conv1(16, convZ(16, 3, l2) + l8))
l10 = conv1(8,  conv1(8,  convZ(8,  3, l1) + l9))

What are the advantages or disadvantages of either?

# Logits

The last convolution projects the `l10` feature map into a total of just `n` feature maps, one for each possible class prediction. In this binarized (background vs. tumor) prediction task, a total of `2` feature maps will be needed. Recall that these feature maps essentially act as a set of **logit scores** for each voxel location throughout the image.

As in all prior exercises, **do not** use an activation here in the final convolution:

In [None]:
# --- Create logits
logits = {}
logits['tumor'] = layers.Conv3D(
    name='tumor',
    filters=2, 
    strides=1, 
    kernel_size=(1, 3, 3), 
    padding='same',
    kernel_initializer='he_normal')(l10)
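Because the final convolution has no activation, the softmax is applied inside the loss (`from_logits=True`), and at inference time the predicted class can be read directly from the raw logits. A small NumPy sketch with toy logit values shows why: softmax is monotonic, so `argmax` over logits and over probabilities agree (this is also why the `dice(...)` function later in this notebook applies `np.argmax` to logits).

```python
import numpy as np

def softmax(z, axis=-1):
    """Numerically stable softmax along the class axis."""
    shifted = z - z.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

# --- Toy logits for 4 voxels x 2 classes (background, tumor)
toy_logits = np.array([[ 2.0, -1.0],
                       [ 0.5,  0.7],
                       [-3.0,  1.0],
                       [ 0.0,  0.0]])

probs = softmax(toy_logits)

# --- Same predicted class either way
print(np.argmax(toy_logits, axis=-1))  # [0 1 1 0]
print(np.array_equal(np.argmax(toy_logits, axis=-1), np.argmax(probs, axis=-1)))  # True
```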

# Model

Let us first create our model:

In [None]:
# --- Create model
model = Model(inputs=inputs, outputs=logits)

### Custom Dice score metric

For model compilation, a single modification from the prior tutorials is needed: a custom Keras metric to track segmentation performance. Recall from lecture that the metric of choice for this task is the **Dice score**. The Dice score is not a default metric built into the TensorFlow library; however, a custom metric is available for your convenience as part of the `jarvis-md` package. It is invoked using the `custom.dsc(cls=...)` call, where the argument `cls` refers to the number of *non-zero* classes to track (e.g. the background Dice score is typically not tracked). In this exercise, it is important to track the performance of the tumor (foreground) class only, so set the `cls` argument to `1`.
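For reference, the standard definition of the Dice score between the set of ground-truth voxels $A$ and the set of predicted voxels $B$ for a given class is:

$$
\mathrm{DSC}(A, B) = \frac{2\,|A \cap B|}{|A| + |B|}
$$

It ranges from 0 (no overlap) to 1 (perfect overlap). The manual NumPy `dice(...)` implementation in the validation section computes the same quantity, with a smoothing term `epsilon` added to the denominator only.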

In [None]:
# --- Compile model
model.compile(
    optimizer=optimizers.Adam(learning_rate=2e-4),
    loss={'tumor': losses.SparseCategoricalCrossentropy(from_logits=True)},
    metrics={'tumor': custom.dsc(cls=1)},
    experimental_run_tf_function=False)

### Training

For moderately sized datasets that are too large to stay in the disk cache but small enough to fit into RAM, it is often a good idea to first load all training data into memory to speed up training. The `client` can be used for this purpose as follows:

In [None]:
client.load_data_in_memory()

Use the following code cell to train the algorithm:

In [None]:
model.fit(
    x=gen_train, 
    steps_per_epoch=500, 
    epochs=12,
    validation_data=gen_valid,
    validation_steps=500,
    validation_freq=4,
    use_multiprocessing=True)

### Dice score

While a Dice score metric for TensorFlow has already been provided, a separate NumPy implementation is still needed to manually calculate performance on full validation volumes. Use the following code cell block to implement:

In [None]:
def dice(y_true, y_pred, c=1, epsilon=1):
    """
    Method to calculate the Dice score coefficient for given class
    
    :params
    
      (np.ndarray) y_true : ground-truth label
      (np.ndarray) y_pred : predicted logits scores
      (int)             c : class to calculate DSC on
    
    """
    assert y_true.ndim == y_pred.ndim
    
    true = y_true[..., 0] == c
    pred = np.argmax(y_pred, axis=-1) == c 

    A = np.count_nonzero(true & pred) * 2
    B = np.count_nonzero(true) + np.count_nonzero(pred) + epsilon
    
    return A / B
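The `dice(...)` function above scores one class at a time. If per-class scores are useful (e.g., to also monitor the background), a hypothetical vectorized variant `dice_per_class` is sketched below. It is not part of `jarvis`; it follows the same conventions (argmax over logits, `epsilon` in the denominator only) and is checked on a tiny hand-computed example.

```python
import numpy as np

def dice_per_class(y_true, y_pred, classes=(0, 1), epsilon=1):
    """
    Hypothetical helper: Dice score for several classes in one call.
    y_true: (..., 1) integer labels; y_pred: (..., n_classes) logit scores.
    """
    assert y_true.ndim == y_pred.ndim
    true = y_true[..., 0]
    pred = np.argmax(y_pred, axis=-1)
    scores = {}
    for c in classes:
        t, p = true == c, pred == c
        scores[c] = 2 * np.count_nonzero(t & p) / (np.count_nonzero(t) + np.count_nonzero(p) + epsilon)
    return scores

# --- Toy 2 x 2 slice: labels [[0, 1], [1, 1]], predictions [[0, 1], [0, 1]]
y_true = np.array([[0, 1], [1, 1]])[np.newaxis, ..., np.newaxis]  # (1, 2, 2, 1)
y_pred = np.array([[[1., 0.], [0., 1.]],
                   [[1., 0.], [0., 1.]]])[np.newaxis]              # (1, 2, 2, 2)

scores = dice_per_class(y_true, y_pred)
print({c: round(s, 4) for c, s in scores.items()})  # {0: 0.5, 1: 0.6667}
```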

### Hybrid 3D / 2D inference

In addition to the standard approach for running model inference (e.g. create test generators, load the entire 3D volume and use `model.predict(...)`), recall that for a hybrid 3D/2D model one must pad the input volume in the z-direction to account for the *VALID* padded (2, 1, 1) convolutions in the contracting layers. For a model trained on 3-slice inputs, a total of two *VALID* padded (2, 1, 1) convolutions were defined, so the z-direction must be padded by 2 (one slice on top and one on bottom). Use `np.pad(...)` to perform this preprocessing step as shown below:

In [None]:
# --- Create validation generator
test_train, test_valid = client.create_generators(test=True, expand=True)

dsc = []

for x, y in test_valid:
    
    # --- Predict
    x['dat'] = np.pad(x['dat'], ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)))
    logits = model.predict(x['dat'])

    if type(logits) is dict:
        logits = logits['tumor']

    # --- Calculate Dice score (argmax over logits is applied inside dice())
    dsc.append(dice(y['tumor'][0], logits[0], c=1))
    
dsc = np.array(dsc)
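The z-padding step inside the loop above can be checked in isolation. The NumPy sketch below uses a small toy (batch, Z, Y, X, C) array in place of a real exam: padding adds one zero slice on each end, the two *VALID* (2, 1, 1) convolutions then remove two slices in total, and cropping with `[0, 1:-1]` (as in the visualization cell) recovers the original slices.

```python
import numpy as np

# --- Toy full volume: (batch, Z, Y, X, C) with Z = 40 slices
vol = np.random.rand(1, 40, 16, 16, 4).astype('float32')

# --- Pad one zero-valued slice on top and bottom of the z-axis only
padded = np.pad(vol, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)))
print(padded.shape)  # (1, 42, 16, 16, 4)

# --- Two VALID (2, 1, 1) convs remove one slice each: 42 -> 41 -> 40
z_out = padded.shape[1] - 2 * (2 - 1)
print(z_out == vol.shape[1])  # True

# --- Cropping the padding back off recovers the original slices
print(np.array_equal(padded[0, 1:-1], vol[0]))  # True
```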

Prepare results in Pandas DataFrame for ease of analysis and sharing:

In [None]:
# --- Define columns
df = pd.DataFrame(index=np.arange(dsc.size))
df['Dice score'] = dsc

# --- Print median Dice score
print(df['Dice score'].median())

Use the following code cell to visualize a single 3D volume prediction:

In [None]:
# --- Visualization
imshow(x['dat'][0, 1:-1], np.argmax(logits, axis=-1))