# Overview

In this tutorial we will explore the building blocks required to create a contracting-expanding convolutional neural network (CNN) to perform semantic segmentation of the prostate gland on MRI. Specifically, this algorithm will be implemented as a three-class classifier: background; transitional zone of the prostate; peripheral zone of the prostate. The ability to properly capture this anatomic context is a critical first step in characterizing a potential prostate lesion.

This tutorial is part of the class **Introduction to Deep Learning for Medical Imaging** at University of California Irvine (CS190); more information can be found at: https://github.com/peterchang77/dl_tutor/tree/master/cs190.

# Google Colab

The following lines of code will configure your Google Colab environment for this tutorial.

### Enable GPU runtime

Use the following instructions to switch the default Colab instance into a GPU-enabled runtime:

```
Runtime > Change runtime type > Hardware accelerator > GPU
```

### Mount Google Drive

The Google Colab environment is transient and will reset after any prolonged break in activity. To retain important and/or large files between sessions, use the following lines of code to mount your personal Google Drive to this Colab instance:

```python
try:
    # --- Mount gdrive to /content/drive/My Drive/
    from google.colab import drive
    drive.mount('/content/drive')

except ImportError:
    # --- Not running in Google Colab; skip mounting
    pass
```

Throughout this tutorial we will use the following global `MOUNT_ROOT` variable to reference a location to store long-term data. If you are using a local Jupyter server and/or wish to store your data elsewhere, please update this variable now.

```python
# --- Set data directory
MOUNT_ROOT = '/content/drive/My Drive'
```

### Select Tensorflow library version

This tutorial will use the TensorFlow 2.x library. Use the following line of code to select this version:

```python
# --- Select TensorFlow 2.x (only in Google Colab)
%tensorflow_version 2.x
```

# Environment

### Jarvis library

In this notebook we will use Jarvis, a custom Python package that facilitates data science and deep learning for healthcare. Among other things, this library will be used for low-level data management, stratification and visualization of high-dimensional medical data.

```python
# --- Install jarvis (only in Google Colab or local runtime)
%pip install jarvis-md
```

### Imports

Use the following lines to import any additional needed libraries:

```python
import numpy as np, pandas as pd
from tensorflow import losses, optimizers
from tensorflow.keras import Input, Model, models, layers
from jarvis.train import datasets, custom
from jarvis.utils.display import imshow
```

# Data

The data used in this tutorial consists of prostate MRI exams. In this assignment, only T2-weighted images (isolated using a prior algorithm) will be used for segmentation. In prostate imaging, the T2-weighted sequence captures the greatest amount of anatomic detail and is therefore ideal for delineating prostate gland structures. The custom `datasets.download(...)` method can be used to download a local copy of the dataset. By default the dataset will be archived at `/data/raw/mr_prostatex_seg`. If needed, an alternate location may be specified using `datasets.download(name=..., path=...)`.

```python
# --- Download dataset
datasets.download(name='mr/prostatex-seg')
```

Once the dataset is downloaded, the `datasets.prepare(...)` method can be used to generate the required Python generators to iterate through the dataset, as well as a `client` object for any needed advanced functionality. If needed, pass any custom configurations (e.g. batch size, normalization parameters, etc.) into the optional `configs` dictionary argument.

### Prostate segmentation

Given the relatively high resolution of the input and output data for a U-Net architecture, a relatively small batch size (12) will be used in this exercise. In addition, this dataset contains two separately processed cohorts. This first tutorial will use full field-of-view T2 images resampled to 256 x 256 resolution. This specific cohort can be selected using the `keyword='seg-256'` argument:

```python
# --- Prepare generators
configs = {'batch': {'size': 12}}
gen_train, gen_valid, client = datasets.prepare(name='mr/prostatex-seg', configs=configs, keyword='seg-256')
```

As before, each iteration of the created generators yields two variables, `xs` and `ys`. Each is a dictionary of model input(s) and output(s) containing a batch of `n` samples, where `n` is the specified batch size. In the current example, there is just a single input and output. Let us examine the generator data:

```python
# --- Yield one example
xs, ys = next(gen_train)

# --- Print dict keys
print('xs keys: {}'.format(xs.keys()))
print('ys keys: {}'.format(ys.keys()))
```

```python
# --- Print data shape
print('xs shape: {}'.format(xs['dat'].shape))
print('ys shape: {}'.format(ys['zones'].shape))
```

One key difference from prior tutorials is that the target `ys` outputs now consist of a full 2D matrix **equal in resolution** to the input `xs`. Use the following lines of code to visualize both the image data and the corresponding mask label using the `imshow(...)` method:

```python
# --- Show the first example
xs, ys = next(gen_train)
imshow(xs['dat'][0], ys['zones'][0])
```

In this exercise, each voxel in the label mask contains one of three separate classes for segmentation:

* 0 = background (non-prostate)
* 1 = peripheral zone (prostate gland)
* 2 = transitional zone (prostate gland)

The appearance of prostate cancer is different in the various prostate gland zones; thus it is critical to first determine the location of each potential prostate lesion prior to further characterization.
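Most voxels in each slice are background, so the three classes are highly imbalanced. The NumPy sketch below uses a made-up 4 x 4 mask (not taken from the dataset). It shows how the integer labels map to per-zone binary masks and how voxels can be counted per class:

```python
import numpy as np

# --- Toy 4 x 4 label mask (made-up values for illustration only)
mask = np.array([[0, 0, 0, 0],
                 [0, 1, 2, 0],
                 [0, 1, 2, 2],
                 [0, 0, 0, 0]])

# --- Voxel count per class: index 0 = background, 1 = PZ, 2 = TZ
counts = np.bincount(mask.ravel(), minlength=3)

# --- Binary masks for each zone and for the whole gland
pz = mask == 1
tz = mask == 2
gland = mask > 0

print(counts.tolist())  # [11, 2, 3]
print(gland.mean())     # 0.3125 -> fraction of voxels inside the prostate
```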

Pass the entire batch to `imshow(...)` to display a montage (an N x N mosaic) of all images:

```python
# --- Show "montage" of all images
imshow(xs['dat'], ys['zones'], figsize=(12, 12))
```

### Model inputs

For every input in `xs`, a corresponding `Input(...)` variable can be created and returned in an `inputs` dictionary for ease of model development:

```python
# --- Create model inputs
inputs = client.get_inputs(Input)

print(inputs.keys())
print(inputs['dat'].shape)
```

In this example, the equivalent Python code to generate `inputs` would be:

```python
inputs = {}
inputs['dat'] = Input(shape=(1, 256, 256, 1))
```

### 3D operations

Note that the model input shapes for this exercise (and all subsequent exercises) will be provided as 3D tensors. Even if your current model does not require 3D data (as in this tutorial), any 2D tensor can be represented by a 3D tensor with a z-axis shape of 1. In addition, designing all models with this configuration (i.e. with 3D operations) minimizes the code changes needed when testing various 2D and 3D network architectures. The key difference is that `kernel_size` and `strides` must be specified as full 3D tuples (e.g. `(1, 3, 3)`). See below for more information.
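The NumPy sketch below makes this layout concrete. It adds a z-axis of length 1 and a channel axis to a single 2D slice, then stacks 12 such examples into a batch. It assumes Keras' default channels-last ordering (batch, z, y, x, channels), which is consistent with the `Input(shape=(1, 256, 256, 1))` example above. The random array is only a placeholder for image data.

```python
import numpy as np

# --- Placeholder 2D slice (rows x cols); random values stand in for image data
slice_2d = np.random.rand(256, 256).astype('float32')

# --- Add a z-axis of length 1 in front and a channel axis at the end: (Z, Y, X, C)
vol = slice_2d[np.newaxis, ..., np.newaxis]
print(vol.shape)  # (1, 256, 256, 1)

# --- Stack 12 such examples to form a batch: (N, Z, Y, X, C)
batch = np.stack([vol] * 12)
print(batch.shape)  # (12, 1, 256, 256, 1)

# --- The extra axes add no information: squeezing recovers the original slice
assert np.array_equal(batch[0].squeeze(), slice_2d)
```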

# Contracting Layers

As discussed in lecture, the contracting layers of a U-Net architecture are essentially identical to a standard feed-forward CNN. In addition, this implementation makes several key modifications to the original architecture, including:

* same padding (vs. valid padding)
* strided convolutions (vs. max-pooling)
* fewer filters (smaller channel depths)
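The first two modifications change how feature map sizes evolve. The original U-Net used unpadded (valid) 3 x 3 convolutions, so each convolution trims the border and the skip connections must be cropped. With same padding the size is preserved, and a stride-2 convolution halves each axis, much like 2 x 2 max-pooling. The sketch below computes the per-axis output length under TensorFlow's padding conventions. The helper `out_size` is hypothetical and written only for illustration.

```python
import math

def out_size(n, kernel=3, stride=1, padding='same'):
    """Hypothetical helper: output length along one axis (TensorFlow padding rules)."""
    if padding == 'same':
        # --- Zero-padding keeps every position; only the stride shrinks the axis
        return math.ceil(n / stride)
    # --- 'valid': keep only positions where the kernel fits inside the input
    return (n - kernel) // stride + 1

same_twice = out_size(out_size(256))
valid_twice = out_size(out_size(256, padding='valid'), padding='valid')
halved = out_size(256, stride=2)

print(same_twice)   # 256
print(valid_twice)  # 252
print(halved)       # 128
```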

Let us start by defining the contracting layer architecture below. Recall that slight adjustments need to be made to account for 3D operations and inputs (despite using a 2D architecture):

```python
# --- Define kwargs dictionary
kwargs = {
    'kernel_size': (1, 3, 3),
    'padding': 'same'}

# --- Define lambda functions
conv = lambda x, filters, strides : layers.Conv3D(filters=filters, strides=strides, **kwargs)(x)
norm = lambda x : layers.BatchNormalization()(x)
relu = lambda x : layers.LeakyReLU()(x)

# --- Define stride-1, stride-2 blocks
conv1 = lambda filters, x : relu(norm(conv(x, filters, strides=1)))
conv2 = lambda filters, x : relu(norm(conv(x, filters, strides=(1, 2, 2))))
```
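Because the z-dimension of the kernel is 1, each `Conv3D` layer here has exactly the same number of weights as a 2D 3 x 3 convolution. The hypothetical helper below counts trainable parameters as kernel volume x input channels x filters, plus one bias per filter (the Keras default `use_bias=True`).

```python
def conv_params(kernel_size, in_channels, filters, bias=True):
    """Hypothetical helper: trainable weights in a Keras ConvND layer."""
    n = 1
    for k in kernel_size:
        n *= k  # kernel volume, e.g. 1 * 3 * 3 = 9
    return n * in_channels * filters + (filters if bias else 0)

print(conv_params((1, 3, 3), 1, 8))   # 80   first conv1 block (1 input channel)
print(conv_params((3, 3), 1, 8))      # 80   same as a 2D 3 x 3 kernel
print(conv_params((3, 3, 3), 1, 8))   # 224  a full 3D kernel has 3x the kernel weights
print(conv_params((1, 3, 3), 8, 16))  # 1168 first conv2 block (8 -> 16 channels)
```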

Using these lambda functions, let us define a simple 9-layer contracting network topology with a total of four subsampling (stride-2 convolution) operations:

```python
# --- Define contracting layers
l1 = conv1(8, inputs['dat'])
l2 = conv1(16, conv2(16, l1))
l3 = conv1(32, conv2(32, l2))
l4 = conv1(48, conv2(48, l3))
l5 = conv1(64, conv2(64, l4))
```

What is the shape of the `l5` feature map?

# Expanding Layers

As discussed in lecture, the expanding layers are implemented by simply reversing the operations found in the contracting layers above. Specifically, each subsampling operation is now replaced by a **convolutional transpose**. Due to the use of **same** padding, defining a transpose operation with the exact same parameters as a strided convolution ensures that each layer in the expanding pathway exactly matches the shape of the corresponding contracting layer.
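This shape bookkeeping can be checked with plain arithmetic. With same padding, a strided convolution maps an axis of length n to ceil(n / s), and a transposed convolution maps it to n * s. A down/up round trip therefore returns to the original length only when that length is divisible by 2^4 = 16, one factor of 2 for each of the four stride-2 steps used here. The helpers below are hypothetical illustrations of these TensorFlow conventions and are not part of any library.

```python
import math

def conv_same(n, stride):
    # --- Strided convolution with 'same' padding: ceil(n / stride)
    return math.ceil(n / stride)

def tran_same(n, stride):
    # --- Transposed convolution with 'same' padding: n * stride
    return n * stride

def round_trip(n, depth=4, stride=2):
    """Axis sizes after `depth` downsamples, plus the size after `depth` upsamples."""
    down = [n]
    for _ in range(depth):
        down.append(conv_same(down[-1], stride))
    up = down[-1]
    for _ in range(depth):
        up = tran_same(up, stride)
    return down, up

print(round_trip(256))  # ([256, 128, 64, 32, 16], 256)
print(round_trip(250))  # ([250, 125, 63, 32, 16], 256) -> no longer matches 250
```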

### Convolutional transpose

Let us start by defining an additional lambda function for the convolutional transpose:

```python
# --- Define single transpose
tran = lambda x, filters, strides : layers.Conv3DTranspose(filters=filters, strides=strides, **kwargs)(x)

# --- Define transpose block
tran2 = lambda filters, x : relu(norm(tran(x, filters, strides=(1, 2, 2))))
```

Carefully compare these functions to the single `conv` operations as well as the `conv1` and `conv2` blocks above. Notice that they share the exact same configurations.

Let us now apply the first convolutional transpose block to the `l5` feature map:

```python
# --- Define expanding layers
l6 = tran2(48, l5)
```

What is the shape of the `l6` feature map?

### Concatenation

The first connection in this specific U-Net derived architecture is a link between the `l4` and the `l6` layers:

```
l1 -------------------> l9
  \                    /
   l2 -------------> l8
     \              /   
      l3 -------> l7
        \        /
         l4 -> l6
           \  /
            l5
```

To mediate the first connection between contracting and expanding layers, we must ensure that `l4` and `l6` match in feature map size (the number of filters, i.e. the channel depth, does *not* necessarily need to match). Using `same` padding as above should ensure that this is the case, which simplifies the connection operation:

```python
# --- Ensure shapes match
print(l4.shape)
print(l6.shape)

# --- Concatenate
concat = lambda a, b : layers.Concatenate()([a, b])
concat(l4, l6)
```
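Concatenation stacks the two feature maps along the channel axis, which `layers.Concatenate()` uses by default (`axis=-1`). The NumPy sketch below uses placeholder arrays with an arbitrary spatial size to show the resulting channel depth:

```python
import numpy as np

# --- Placeholder feature maps, layout (N, Z, Y, X, C); spatial size chosen arbitrarily
a = np.zeros((1, 1, 8, 8, 48), dtype='float32')  # stands in for a contracting-path map
b = np.ones((1, 1, 8, 8, 48), dtype='float32')   # stands in for an expanding-path map

# --- Concatenate along the last (channel) axis, the Keras Concatenate default
c = np.concatenate([a, b], axis=-1)
print(c.shape)  # (1, 1, 8, 8, 96)

# --- Channels 0-47 come from `a`, channels 48-95 from `b`
assert (c[..., :48] == 0).all() and (c[..., 48:] == 1).all()
```

The channel depth doubles. The `conv1` block that follows each concatenation in the expanding pathway mixes these channels and brings the depth back down (e.g. from 96 to 48).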

Note that since `l4` and `l6` are **exactly the same shape** (including matching channel depth), what additional operation could be used here instead of a concatenation?

### Full expansion

Alternate the use of `conv1` and `tran2` blocks to build the remainder of the expanding pathway:

```python
# --- Define expanding layers
l7 = tran2(32, conv1(48, concat(l4, l6)))
l8 = tran2(16, conv1(32, concat(l3, l7)))
l9 = tran2(8, conv1(16, concat(l2, l8)))

# --- Final skip connection between l1 and l9 (see diagram above)
l10 = conv1(8, concat(l1, l9))
```

# Logits

The last convolution projects the `l10` feature map into a total of just `n` feature maps, one for each possible class prediction. In this 3-class prediction task, a total of `3` feature maps will be needed. Recall that these feature maps essentially act as a set of **logit scores** for each voxel location throughout the image.

As in all prior exercises, **do not** use an activation here in the final convolution:

```python
# --- Create logits
logits = {}
logits['zones'] = layers.Conv3D(filters=3, name='zones', **kwargs)(l10)
```
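Leaving the final layer without an activation works because the loss defined during compilation uses `from_logits=True`, which applies the softmax internally. At prediction time, the class with the highest logit is also the class with the highest softmax probability, so the argmax can be taken directly on the logits. The NumPy sketch below uses random stand-in logits, not real model output:

```python
import numpy as np

def softmax(z, axis=-1):
    # --- Subtract the max for numerical stability before exponentiating
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# --- Random stand-in for model output: (N, Z, Y, X, 3 classes)
rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 1, 4, 4, 3))

probs = softmax(logits)               # per-voxel class probabilities
labels = np.argmax(logits, axis=-1)   # per-voxel predicted class (0, 1 or 2)

# --- Softmax is monotonic, so it never changes which class wins
assert np.array_equal(labels, np.argmax(probs, axis=-1))
assert np.allclose(probs.sum(axis=-1), 1.0)
print(labels.shape)  # (1, 1, 4, 4)
```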

# Model

Let us first create our model:

```python
# --- Create model
model = Model(inputs=inputs, outputs=logits)
```

### Custom Dice score metric

Compared with prior tutorials, model compilation requires a single modification: a Keras metric to track segmentation performance. Recall from lecture that the metric of choice for this task is the **Dice score**. The Dice score is not a default metric built into the TensorFlow library. However, a custom metric is available for your convenience as part of the `jarvis-md` package. It is invoked using the `custom.dsc(cls=...)` call, where the argument `cls` refers to the number of *non-zero* classes to track (i.e. the background Dice score is typically not tracked). In this exercise, it is important to track segmentation performance for the **peripheral zone** (class = 1) and the **transitional zone** (class = 2), so set the `cls` argument to `2`.
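For reference, the Dice score for a class $c$ compares the set of voxels predicted as class $c$, $P_c$, with the set of ground-truth voxels of class $c$, $T_c$:

$$
\mathrm{DSC}_c = \frac{2\,|P_c \cap T_c|}{|P_c| + |T_c|}
$$

It ranges from 0 (no overlap) to 1 (perfect overlap). Unlike voxel-wise accuracy, it ignores voxels that are neither predicted nor labeled as class $c$. As a result, the large number of correctly predicted background voxels cannot inflate the score for a small structure such as a prostate zone.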

```python
# --- Compile model
model.compile(
    optimizer=optimizers.Adam(learning_rate=2e-4),
    loss={'zones': losses.SparseCategoricalCrossentropy(from_logits=True)},
    metrics={'zones': custom.dsc(cls=2)},
    experimental_run_tf_function=False)
```
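`SparseCategoricalCrossentropy` accepts integer class labels (0, 1, 2) rather than one-hot vectors, and `from_logits=True` tells it to apply a softmax to the raw scores first. The NumPy sketch below reproduces that per-voxel computation, averaged over all voxels. It illustrates the formula only and is not TensorFlow's implementation. The helper name `sparse_ce_from_logits` is hypothetical.

```python
import numpy as np

def sparse_ce_from_logits(labels, logits):
    """Hypothetical helper: mean per-voxel cross-entropy from integer labels and raw logits."""
    # --- Numerically stable log-softmax over the class axis
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    # --- Pick the log-probability of the true class at every voxel
    picked = np.take_along_axis(log_probs, labels[..., np.newaxis], axis=-1)
    return -picked.mean()

# --- Two voxels, three classes; labels are integers, not one-hot vectors
logits = np.array([[2.0, 0.0, 0.0],    # confident and correct (true class 0)
                   [0.0, 0.0, 0.0]])   # uniform scores (true class 2)
labels = np.array([0, 2])
print(sparse_ce_from_logits(labels, logits))
```

The uniform voxel contributes log(3), the maximum-uncertainty loss for three classes, while the confident, correct voxel contributes much less.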

### In-memory data

For moderately sized datasets that are small enough to fit into RAM, it is often a good idea to first load all training data into memory. This avoids repeated disk reads and speeds up training. The `client` can be used for this purpose as follows:

```python
# --- Load data into memory for faster training
client.load_data_in_memory()
```

Now, let us train the model:

```python
# --- Train model
model.fit(
    x=gen_train,
    steps_per_epoch=500,
    epochs=4,
    validation_data=gen_valid,
    validation_steps=500,
    validation_freq=4,
    use_multiprocessing=True)
```

# Evaluation

To test the trained model, the following steps are required:

* load data
* use `model.predict(...)` to obtain logit scores
* compare prediction with ground-truth
* serialize in Pandas DataFrame

Recall that the generator used to train the model simply iterates through the dataset randomly. For model evaluation, the cohort must instead be loaded manually in an orderly way. For this tutorial, we will create new **test mode** data generators, which will simply load each example individually once for testing. 

```python
# --- Create test-mode generators (each example loaded once)
test_train, test_valid = client.create_generators(test=True, expand=True)
```

**Important note**: although the model is trained using 2D slices, nothing precludes passing an entire 3D volume through the model at once (e.g. consider the entire 3D volume as a single *batch* of data). In fact, performance metrics for medical imaging models are typically reported on a volume-by-volume basis, not slice-by-slice. Thus, use the `expand=True` flag in `client.create_generators(...)` as above to yield entire 3D volumes instead of slices.

```python
# --- Run entire volume through model
x, y = next(test_train)
logits = model.predict(x['dat'])
```

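The output `logits` contains one score per class for every voxel in the volume. To convert these scores into a final segmentation, take the argmax across the last (class) axis, which assigns each voxel to its highest-scoring class. This is the same operation used inside the `dice(...)` function below.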

### Dice score

Although a Dice score metric for TensorFlow has already been provided, a separate NumPy implementation is still needed to calculate performance manually during validation. Use the following code cell to implement it:

```python
def dice(y_true, y_pred, c=1, epsilon=1):
    """
    Method to calculate the Dice score coefficient for a given class

    :params

      (np.ndarray) y_true : ground-truth label
      (np.ndarray) y_pred : predicted logit scores
      (int)             c : class to calculate DSC on
      (float)     epsilon : smoothing term added to the denominator

    """
    assert y_true.ndim == y_pred.ndim

    # --- Binary masks of class c in the ground truth and the prediction
    true = y_true[..., 0] == c
    pred = np.argmax(y_pred, axis=-1) == c

    A = np.count_nonzero(true & pred) * 2
    B = np.count_nonzero(true) + np.count_nonzero(pred) + epsilon

    return A / B
```
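A worked example clarifies the role of `epsilon`. It prevents division by zero when a class is absent from both the ground truth and the prediction. In that case, however, the score is 0 rather than 1, and for small structures it also slightly lowers the Dice score. The helper below is a hypothetical restatement of the same formula applied to boolean masks, written so the example runs on its own:

```python
import numpy as np

def dice_from_masks(true, pred, epsilon=1):
    # --- Same formula as dice(...) above, applied directly to boolean masks
    A = np.count_nonzero(true & pred) * 2
    B = np.count_nonzero(true) + np.count_nonzero(pred) + epsilon
    return A / B

# --- 4 true voxels (0-3) and 4 predicted voxels (1-4) overlap in 3 voxels
true = np.zeros(10, dtype=bool); true[0:4] = True
pred = np.zeros(10, dtype=bool); pred[1:5] = True

print(dice_from_masks(true, pred, epsilon=0))  # 0.75 = 2 * 3 / (4 + 4)
print(dice_from_masks(true, pred, epsilon=1))  # 6 / 9, about 0.667

# --- Class absent from both masks: scored 0, not 1
empty = np.zeros(10, dtype=bool)
print(dice_from_masks(empty, empty))  # 0.0
```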

Use the following lines of code to run prediction through the **valid** cohort generator. Note that the **Dice score** must be calculated individually for each of the two classes of interest:

```python
# --- Re-create test-mode generators
test_train, test_valid = client.create_generators(test=True, expand=True)

dsc_pz = []
dsc_tz = []

for x, y in test_valid:

    # --- Predict
    logits = model.predict(x['dat'])

    if isinstance(logits, dict):
        logits = logits['zones']

    # --- Dice score per zone (argmax is applied inside dice)
    dsc_pz.append(dice(y['zones'][0], logits[0], c=1))
    dsc_tz.append(dice(y['zones'][0], logits[0], c=2))

dsc_pz = np.array(dsc_pz)
dsc_tz = np.array(dsc_tz)
```

Prepare results in Pandas DataFrame for ease of analysis and sharing:

```python
# --- Define columns
df = pd.DataFrame(index=np.arange(dsc_tz.size))
df['dsc_pz'] = dsc_pz
df['dsc_tz'] = dsc_tz

# --- Print mean Dice score per zone
print(df['dsc_pz'].mean())
print(df['dsc_tz'].mean())
```

## Saving and Loading a Model

After a model has been successfully trained, it can be saved and/or loaded by simply using the `model.save()` and `models.load_model()` methods. 

```python
# --- Serialize a model
model.save('./organ_segmentation.hdf5')
```

```python
# --- Load a serialized model
del model
model = models.load_model('./organ_segmentation.hdf5', compile=False)
```

# Exercises

### Exercise 1

Instead of a concatenation operation, implement a **residual** connection to connect the contracting and expanding layers. What are the potential advantages or disadvantages of this technique?

Use the following code cell to experiment:

### Exercise 2

Instead of a single concatenation or residual connection, implement a series of convolutional blocks (with activation functions) to connect the contracting and expanding layers. What are the potential advantages or disadvantages of this technique?

Use the following code cell to experiment: