# Overview

In this tutorial we will explore a few of the advanced approaches to customize a segmentation algorithm for medical imaging. These approaches will build upon the standard U-Net contracting-expanding CNN architecture. As in the prior tutorial, the goal remains to perform kidney segmentation on CT.

This tutorial is part of the class **Introduction to Deep Learning for Medical Imaging** at University of California Irvine (CS190); more information can be found at: https://github.com/peterchang77/dl_tutor/tree/master/cs190.

# Google Colab

The following lines of code will configure your Google Colab environment for this tutorial.

### Enable GPU runtime

Use the following instructions to switch the default Colab instance into a GPU-enabled runtime:

```
Runtime > Change runtime type > Hardware accelerator > GPU
```

# Environment

### Jarvis library

In this notebook we will use Jarvis, a custom Python package to facilitate data science and deep learning for healthcare. Among other things, this library will be used for low-level data management, stratification and visualization of high-dimensional medical data.

In [None]:
# --- Install jarvis (only in Google Colab or local runtime)
%pip install jarvis-md

### Imports

Use the following lines to import any additional needed libraries:

In [None]:
import numpy as np, pandas as pd
import tensorflow as tf
from tensorflow.keras import Input, Model, models, layers, losses, metrics, optimizers
from jarvis.train import datasets
from jarvis.utils.display import imshow

# Data

The data used in this tutorial will consist of kidney tumor CT exams derived from the Kidney Tumor Segmentation Challenge (KiTS). More information about the KiTS Challenge can be found here: https://kits21.kits-challenge.org/. In this exercise, we will use this dataset to derive a model for slice-by-slice kidney segmentation. The custom `datasets.download(...)` method can be used to download a local copy of the dataset. By default the dataset will be archived at `/data/raw/ct_kits`; as needed an alternate location may be specified using `datasets.download(name=..., path=...)`. 

In [None]:
# --- Download dataset
datasets.download(name='ct/kits')

Once downloaded, the `datasets.prepare(...)` method can be used to generate the required Python generators to iterate through the dataset, as well as a `client` object for any needed advanced functionality. As needed, pass any custom configurations (e.g. batch size, normalization parameters, etc.) into the optional `configs` dictionary argument.

In this tutorial, we will be using abdominal CT volumes that have been preprocessed into 96 x 96 x 96 matrix volumes, each cropped to the right and left kidney, which makes algorithm training feasible within the Google Colab platform. Depending on the model implementation strategy, both 2D and 3D data have been prepared for this exercise. To specify the correct Generator template file, pass a designated `keyword` string.

**2D dataset**: To select the 2D data of input size `(1, 96, 96, 1)` use the keyword `2d-bin`:

In [None]:
# --- Prepare generators
gen_train, gen_valid, client = datasets.prepare(name='ct/kits', keyword='2d-bin', custom_layers=True)

**3D dataset**: To select the 3D data of input size `(96, 96, 96, 1)` use the keyword `3d-bin`:

In [None]:
# --- Prepare generators
gen_train, gen_valid, client = datasets.prepare(name='ct/kits', keyword='3d-bin', custom_layers=True)

The created generators yield a total of `batch['size']` training samples based on the specified batch size. As before, each iteration yields a dictionary of model inputs, `xs`. In the current example, there is just a single input image `xs['dat']` and a single target `xs['lbl']`. Let us examine the generator data:

In [None]:
# --- Yield one example
xs, _ = next(gen_train)

# --- Print dict keys
for k, v in xs.items():
    print('key = {} : shape = {}'.format(k.ljust(7), v.shape))

### KITS Data


Use the following lines of code to visualize the data using the `imshow(...)` method:

In [None]:
# --- Show the first example
imshow(xs['dat'][0])

### Kidney masks

The ground-truth labels are binary masks of the same matrix shape as the model input:

In [None]:
print(xs['lbl'][0].shape)

Use the `imshow(...)` method to visualize the ground-truth kidney mask labels:

In [None]:
# --- Show kidney masks overlaid on original data
imshow(xs['dat'][0], xs['lbl'][0])

# --- Show kidney masks isolated
imshow(xs['lbl'][0])

# 3D Models

Many medical imaging modalities yield 3D volume datasets. Full 3D CNN models implemented with 3D convolutional and convolutional transpose operations can process native 3D datasets and learn hierarchical 3D features.

Given that we have *already* started to use `layers.Conv3D` and `layers.Conv3DTranspose` in our code, the choice between 2D and 3D operations requires only changing the kernel size and stride of certain model layers:

* kernel size: `(3, 3, 3)`, instead of `(1, 3, 3)`
* stride: `(2, 2, 2)`, instead of `(1, 2, 2)` (only when subsampling is required, otherwise still `1`)
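To get a feel for the cost of this change, the sketch below counts the trainable weights of a single convolutional layer for both kernel sizes. The helper `conv_params` and the channel counts (16 in, 32 out) are hypothetical values chosen for illustration; they are not taken from the tutorial model.

```python
# Weights of one convolutional layer:
# (kernel volume x input channels x output filters) + one bias per filter
def conv_params(kernel_size, in_channels, filters):
    kd, kh, kw = kernel_size
    return kd * kh * kw * in_channels * filters + filters

# Hypothetical layer mapping 16 input channels to 32 filters
params_2d = conv_params((1, 3, 3), in_channels=16, filters=32)
params_3d = conv_params((3, 3, 3), in_channels=16, filters=32)

print('(1, 3, 3) kernel : {} parameters'.format(params_2d))
print('(3, 3, 3) kernel : {} parameters'.format(params_3d))
```

Under these assumptions the 3D kernel carries roughly three times as many weights as the 2D kernel, since the kernel volume grows from 9 to 27.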

Use the following block to define 3D convolutional blocks:

In [None]:
# --- Define kwargs
kwargs = {
    'kernel_size': (3, 3, 3),
    'padding': 'same',
    'kernel_initializer': 'he_normal'}

# --- Define block components
conv = lambda x, filters, strides : layers.Conv3D(filters=filters, strides=strides, **kwargs)(x)
tran = lambda x, filters, strides : layers.Conv3DTranspose(filters=filters, strides=strides, **kwargs)(x)

norm = lambda x : layers.BatchNormalization()(x)
relu = lambda x : layers.ReLU()(x)

conv1 = lambda filters, x : relu(norm(conv(x, filters, strides=1)))
conv2 = lambda filters, x : relu(norm(conv(x, filters, strides=(2, 2, 2))))
tran2 = lambda filters, x : relu(norm(tran(x, filters, strides=(2, 2, 2))))

concat = lambda a, b : layers.Concatenate()([a, b])

A standard 3D U-Net series of contracting layers can be defined as follows: 

In [None]:
# --- Define model input 
x = Input(shape=(96, 96, 96, 1), dtype='float32')

# --- Define contracting layers
l1 = conv1(8, x)
l2 = conv1(16, conv2(16, l1))
l3 = conv1(32, conv2(32, l2))
l4 = conv1(48, conv2(48, l3))
l5 = conv1(64, conv2(64, l4))
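As a sanity check, the spatial size of each contracting layer can be traced by hand. The sketch below assumes the usual behavior of `'same'` padding, where the output size along each axis is `ceil(input / stride)`; the helper `same_output_size` is a hypothetical name introduced only for this illustration.

```python
import math

def same_output_size(size, stride):
    # With 'same' padding, output size = ceil(input size / stride)
    return math.ceil(size / stride)

# --- Trace the spatial size through the four stride-2 blocks (l2 to l5)
size = 96
sizes = [size]
for _ in range(4):
    size = same_output_size(size, 2)
    sizes.append(size)

# --- Print expected feature map shapes for l1 to l5
filters = [8, 16, 32, 48, 64]
for name, s, f in zip(['l1', 'l2', 'l3', 'l4', 'l5'], sizes, filters):
    print('{} : shape = ({}, {}, {}, {})'.format(name, s, s, s, f))
```

If the model is built as written, these values should agree with the shapes reported by Keras for the corresponding tensors.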

# Network Connections

A key component of the contracting-expanding segmentation architectures is the use of connections to combine both low- and high-level features. In addition to the standard **concatenation** operation, several variations can be used.

### Residual connection

If *SAME* padding is used throughout the network architecture, and if the number of filters used at each block is symmetric, then the corresponding contracting and expanding layers should have exactly the same feature map size. In this scenario, a **residual** (addition) operation can be used instead of the standard concatenation. 

In [None]:
# --- Use residual connections
l6 = tran2(48, l5)
l7 = tran2(32, conv1(48, l6 + l4))
l8 = tran2(16, conv1(32, l7 + l3))
l9 = tran2(8,  conv1(16, l8 + l2))
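One practical difference between the two connection types is the channel count seen by the next layer. The NumPy sketch below uses hypothetical arrays shaped like `l4` and `l6` (12 x 12 x 12 with 48 channels) to compare addition with concatenation; it illustrates the shape arithmetic only and does not use the Keras layers themselves.

```python
import numpy as np

# Hypothetical feature maps shaped like l4 and l6: (batch, 12, 12, 12, 48)
skip = np.ones((1, 12, 12, 12, 48), dtype='float32')
up = np.ones((1, 12, 12, 12, 48), dtype='float32')

# --- Residual connection: element-wise addition, shapes must match exactly
added = skip + up

# --- Standard connection: concatenation along the channel axis
concatenated = np.concatenate([skip, up], axis=-1)

print('addition      : {}'.format(added.shape))
print('concatenation : {}'.format(concatenated.shape))
```

Addition keeps 48 channels while concatenation doubles them to 96, so the convolution that follows a residual connection needs fewer weights.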

What are the advantages or disadvantages of this approach?

### Multiple operations

As discussed in the lecture, high-resolution but shallow layers in the contracting arm of the network may sometimes be too "raw" and thus introduce noise into the network predictions. To help overcome this effect, consider the use of additional operations to refine the shallow contracting layers **prior** to combination with the expanding arm.

In [None]:
# --- Use multiple operations
l6 = tran2(48, l5)
l7 = tran2(32, conv1(48, l6 + conv1(48, l4)))
l8 = tran2(16, conv1(32, l7 + conv1(32, l3)))
l9 = tran2(8,  conv1(16, l8 + conv1(16, l2)))

What are the advantages or disadvantages of this approach?

# Deep Supervision

Multi-Resolution U-Net

```
[ input ]  -------- [ logits  ] ---> loss-00

   ||                   /\
   \/                   ||
   
[  L-1  ]  ------>  [ L-(n-1) ] ---> loss-01 (subsampled 1/2)

   ||                   /\
   \/                   ||
   
[  L-2  ]  ------>  [ L-(n-2) ] ---> loss-02 (subsampled 1/4)

   ||                   /\
   \/                   ||
   
[  L-3  ]  ------>  [ L-(n-3) ] ---> loss-03 (subsampled 1/8)

   ||                   /\
   \/                   ||
   
   ...                 ...
```

In [None]:
# --- Define model input 
x = Input(shape=(96, 96, 96, 1), dtype='float32')

# --- Define contracting layers
l1 = conv1(8, x)
l2 = conv1(16, conv2(16, l1))
l3 = conv1(32, conv2(32, l2))
l4 = conv1(48, conv2(48, l3))
l5 = conv1(64, conv2(64, l4))

# --- Define expanding layers
l6  = tran2(48, l5)
l7  = tran2(32, conv1(48, concat(l4, l6)))
l8  = tran2(16, conv1(32, concat(l3, l7)))
l9  = conv1(8, tran2(8,  conv1(16, concat(l2, l8))))

In [None]:
logits = layers.Conv3D(filters=2, **kwargs)(l9)

A deeply supervised U-Net creates a series of predictions at different resolutions:

* `c0`: original resolution (96)
* `c1`: subsampled by 1/2 (48)
* `c2`: subsampled by 1/4 (24)
* ... and so on.

To produce these multi-resolution logit scores, consider a dictionary as follows:

In [None]:
# --- Create logits
logits = {
    'c0': layers.Conv3D(filters=2, **kwargs)(l9),
    'c1': layers.Conv3D(filters=2, **kwargs)(l8),
    # ... what other layers can be added?
}

Let us create a backbone model with these multi-resolution outputs:

In [None]:
# --- Create model
backbone = Model(inputs=x, outputs=logits)

### Training model

In [None]:
# --- Define inputs
inputs = {
    'dat': Input(shape=(96, 96, 96, 1), name='dat'),
    'lbl': Input(shape=(96, 96, 96, 1), name='lbl')}

# --- Define first step of new wrapper model
logits = backbone(inputs['dat'])

Now, importantly, each of the different logit scores you created above must be matched to a subsampled version of the ground-truth segmentation mask, `inputs['lbl']`. To do this, consider a simple serial application of the `MaxPool` operation; compared to a naive resampling operation, this ensures that as your prediction becomes smaller, the original target foreground does not disappear.

In [None]:
loss = {}
true = inputs['lbl']

for c in sorted(logits.keys()):
    
    # --- Subsample ground-truth label if needed
    if c != 'c0':
        true = layers.MaxPooling3D(pool_size=(2, 2, 2))(true)
    
    # --- Create loss at current resolution
    loss[c] = losses.SparseCategoricalCrossentropy(from_logits=True)(
        y_true=true,
        y_pred=logits[c])
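To see why max pooling preserves small foreground regions, consider the toy NumPy comparison below. The helper `max_pool_2x` is a hypothetical stand-in for `layers.MaxPooling3D` written for this illustration, and the 8 x 8 x 8 mask with a single foreground voxel is made up.

```python
import numpy as np

def max_pool_2x(mask):
    """2 x 2 x 2 max pooling of a 3D array with even dimensions."""
    d, h, w = mask.shape
    return mask.reshape(d // 2, 2, h // 2, 2, w // 2, 2).max(axis=(1, 3, 5))

# --- Toy mask with a single foreground voxel at odd indices
mask = np.zeros((8, 8, 8), dtype='uint8')
mask[3, 5, 1] = 1

# --- Compare max pooling with naive subsampling (keep every other voxel)
pooled = max_pool_2x(mask)
strided = mask[::2, ::2, ::2]

print('max pooling : {} foreground voxel(s)'.format(pooled.sum()))
print('subsampling : {} foreground voxel(s)'.format(strided.sum()))
```

In this example the foreground voxel survives max pooling but is dropped by strided subsampling, because it does not fall on an even index.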

### Dice score metric

The standard metric for spatial overlap is known as the **Dice score**. The Dice score is not a default metric built into the TensorFlow library, and thus we will define it here.

In [None]:
def calculate_dsc(y_true, y_pred, c=1):
    """
    Method to calculate the Dice score coefficient for given class
    
    :params
    
      y_true : ground-truth label
      y_pred : predicted logits scores
           c : class to calculate DSC on
    
    """    
    true = y_true[..., 0] == c
    pred = tf.math.argmax(y_pred, axis=-1) == c 

    A = tf.math.count_nonzero(true & pred) * 2
    B = tf.math.count_nonzero(true) + tf.math.count_nonzero(pred)
    
    return tf.math.divide_no_nan(
        tf.cast(A, tf.float32), 
        tf.cast(B, tf.float32))

In [None]:
# --- Define Dice score
dsc = calculate_dsc(y_true=inputs['lbl'], y_pred=logits['c0'])
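For intuition, the same calculation can be worked through by hand on a small example. The NumPy sketch below is a simplified analogue of `calculate_dsc(...)` that operates on boolean masks rather than logits; the function name `dice` and the 4 x 4 masks are made up for illustration.

```python
import numpy as np

def dice(true, pred):
    """Dice = 2 * |true AND pred| / (|true| + |pred|); 0 if both are empty."""
    true, pred = true.astype(bool), pred.astype(bool)
    denom = true.sum() + pred.sum()
    return 2.0 * (true & pred).sum() / denom if denom > 0 else 0.0

# --- Toy masks: 4 true pixels, 4 predicted pixels, 2 overlapping
true = np.zeros((4, 4), dtype=bool)
pred = np.zeros((4, 4), dtype=bool)
true[1:3, 1:3] = True
pred[1:3, 2:4] = True

score = dice(true, pred)
print('Dice score = {:.2f}'.format(score))
```

Here the overlap is 2 pixels and the two masks contain 4 pixels each, so the score is 2 * 2 / (4 + 4) = 0.5. Returning 0 for two empty masks mirrors the behavior of `tf.math.divide_no_nan`.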

### Create model

Now let us create the new wrapper model. The inputs are defined above already in our `inputs` Python dictionary. As outputs, let us return the `logits` tensors, the `loss` tensors and the Dice score `dsc`. We will name this new wrapper model `training` because it will be used for training only.

In [None]:
training = Model(inputs=inputs, outputs={**logits, **loss, **{'dsc': dsc}})

Now let's add the `loss` and `metric` tensors we defined above to the new `training` model:

In [None]:
# --- Add loss
for l in loss.values():
    training.add_loss(l)

# --- Add metric
training.add_metric(dsc, name='dsc')

### Compile model

To prepare the model for learning, a graph must be **compiled** with a strategy for optimization.

In [None]:
# --- Define an Adam optimizer
optimizer = optimizers.Adam(learning_rate=2e-4)

# --- Compile model
training.compile(optimizer=optimizer)

### In-memory data

For moderate-sized datasets that are too large to fit into the immediate hard-drive cache but small enough to fit into RAM, it is often a good idea to first load all training data into memory to speed up training. The `client` can be used for this purpose as follows:

In [None]:
# --- Load data into memory for faster training
client.load_data_in_memory()

Now, let us train the model:

In [None]:
# --- Train model
training.fit(
    x=gen_train, 
    steps_per_epoch=100, 
    epochs=10,
    validation_data=gen_valid,
    validation_steps=100,
    validation_freq=5)

### Evaluation

Use the following lines of code to run prediction through the **valid** cohort generator.

In [None]:
# --- Create validation generator
test_train, test_valid = client.create_generators(test=True, expand=True)

dsc = []

for x, _ in test_valid:
    
    # --- Predict
    outputs = training.predict(x)

    # --- Collect Dice score
    dsc.append(outputs['dsc'])

dsc = np.array(dsc)
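Once the per-case Dice scores are collected, a few summary statistics are usually enough to characterize performance. The sketch below uses a hypothetical array `dsc_example` with made-up values in place of the real `dsc` results, so the printed numbers are illustrative only.

```python
import numpy as np

# Hypothetical per-case Dice scores standing in for the `dsc` array
dsc_example = np.array([0.91, 0.88, 0.95, 0.42, 0.90])

# --- Summarize the distribution of scores
summary = {
    'mean': dsc_example.mean(),
    'median': np.median(dsc_example),
    'p25': np.percentile(dsc_example, 25),
    'p75': np.percentile(dsc_example, 75)}

for k, v in summary.items():
    print('{} = {:.3f}'.format(k.ljust(6), v))
```

Reporting the median and quartiles alongside the mean is a reasonable choice here, since a single poorly segmented case (0.42 in this made-up example) can pull the mean down noticeably.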