# Overview

In this tutorial we will demonstrate how to create an unsupervised CNN autoencoder architecture. Subsequently, the pretrained encoding (contracting) backbone of the autoencoder will be used to create a model for survival prediction in brain tumor patients.

This tutorial is part of the class **Introduction to Deep Learning for Medical Imaging** at University of California Irvine (CS190); more information can be found at: https://github.com/peterchang77/dl_tutor/tree/master/cs190.

# Google Colab

The following lines of code will configure your Google Colab environment for this tutorial.

### Enable GPU runtime

Use the following instructions to switch the default Colab instance into a GPU-enabled runtime:

```
Runtime > Change runtime type > Hardware accelerator > GPU
```

# Environment

### Jarvis library

In this notebook we will use Jarvis, a custom Python package to facilitate data science and deep learning for healthcare. Among other things, this library will be used for low-level data management, stratification and visualization of high-dimensional medical data.

In [None]:
# --- Install jarvis (only in Google Colab or local runtime)
%pip install jarvis-md

### Imports

Use the following lines to import any additional needed libraries:

In [None]:
import numpy as np, pandas as pd
from tensorflow import losses, optimizers
from tensorflow.keras import Input, Model, models, layers
from jarvis.train import datasets, custom
from jarvis.utils.display import imshow

# Data

The data used in this tutorial consists of brain tumor MRI exams derived from the MICCAI Brain Tumor Segmentation Challenge (BRaTS). More information about the BRaTS Challenge can be found here: http://braintumorsegmentation.org/. Each exam consists of four different 3D sequences (T2, FLAIR, T1 pre-contrast and T1 post-contrast). In this exercise, we will use this dataset to train an autoencoder on cropped 3D inputs. The custom `datasets.download(...)` method can be used to download a local copy of the dataset. By default the dataset will be archived at `/data/raw/mr_brats_2020`; as needed, an alternate location may be specified using `datasets.download(name=..., path=...)`.

In [None]:
# --- Download dataset
datasets.download(name='mr/brats-2020-mip')

Once downloaded, the `datasets.prepare(...)` method can be used to generate the required Python generators to iterate through the dataset, as well as a `client` object for any needed advanced functionality.

To specify the correct generator template, pass a designated `keyword` string. In this tutorial, we will be using brain MRI volumes that have been cropped to the boundaries of the tumor and resampled to a uniform 3D volume of shape (96, 96, 96, 4). To select the correct template for this task, use the keyword string `096*vox`.

In [None]:
# --- Prepare generators
gen_train, gen_valid, client = datasets.prepare(name='mr/brats-2020-mip', keyword='096*vox')

As before, each iteration yields two variables, `xs` and `ys`, each representing a dictionary of model input(s) and output(s). In the current example, there is just a single input and output. Let us examine the generator data:

In [None]:
# --- Yield one example
xs, ys = next(gen_train)

# --- Print dict keys
print('xs keys: {}'.format(xs.keys()))
print('ys keys: {}'.format(ys.keys()))

In [None]:
# --- Print data shape
print('xs shape: {}'.format(xs['dat'].shape))
print('ys shape: {}'.format(ys['survival'].shape))

### Survival scores

In this tutorial, the total days of patient survival have been converted to floating point scores between `[0, 1]` using the following formula:

```
score = log ( days ) / 10
```
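
The conversion can be sketched in Python as follows. This assumes that `log` denotes the natural logarithm, which the formula above does not state explicitly; the helper names `days_to_score` and `score_to_days` are hypothetical and not part of Jarvis.

```python
import math

def days_to_score(days: float) -> float:
    """Convert survival in days to a score, assuming score = ln(days) / 10."""
    return math.log(days) / 10

def score_to_days(score: float) -> float:
    """Invert days_to_score under the same assumption: days = exp(10 * score)."""
    return math.exp(10 * score)

# --- Example: one year of survival and back again
score = days_to_score(365)
print(round(score, 4))
print(round(score_to_days(score)))
```

Under this assumption, any survival between 1 day and roughly 22,026 days (e^10) maps into `[0, 1]`.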

### BRATS Data

The BRATS dataset consists of four different MRI sequences stacked along the channels dimension:

In [None]:
print(xs['dat'][0].shape)

The four channels represent four different input MRI sequences, each used to evaluate a different tissue property. T2 and FLAIR are used to evaluate edema (fluid) that results from brain injury. T1 images are used to evaluate anatomy, and the T1 post-contrast sequence shows breakdown of the blood-brain barrier through contrast leakage.

```
dat[..., 0] = FLAIR
dat[..., 1] = T1 precontrast
dat[..., 2] = T1 postcontrast
dat[..., 3] = T2
```
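
If you prefer to refer to sequences by name rather than by index, a small lookup table is enough. The sketch below assumes the channel order listed above holds for every example; `SEQUENCES` and `get_sequence` are hypothetical names, and the random array merely stands in for `xs['dat']` with a reduced spatial size.

```python
import numpy as np

# --- Channel order as listed above (assumed constant across examples)
SEQUENCES = {'FLAIR': 0, 'T1 precontrast': 1, 'T1 postcontrast': 2, 'T2': 3}

def get_sequence(dat: np.ndarray, name: str) -> np.ndarray:
    """Return one MRI sequence from an array whose last axis holds the four channels."""
    return dat[..., SEQUENCES[name]]

# --- Stand-in batch with the same layout as xs['dat']: (batch, Z, Y, X, 4)
dat = np.random.rand(2, 8, 8, 8, 4).astype('float32')
flair = get_sequence(dat, 'FLAIR')
print(flair.shape)
```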

To visualize these different modalities run the following cell:

In [None]:
imshow(xs['dat'][..., 0], title='FLAIR')
imshow(xs['dat'][..., 1], title='T1 precontrast')
imshow(xs['dat'][..., 2], title='T1 postcontrast')
imshow(xs['dat'][..., 3], title='T2')

### Model inputs

For every input in `xs`, a corresponding `Input(...)` variable can be created and returned in an `inputs` dictionary for ease of model development:

In [None]:
# --- Create model inputs
inputs = client.get_inputs(Input)

print(inputs.keys())
print(inputs['dat'].shape)

In this example, the equivalent Python code to generate `inputs` would be:

```python
inputs = {}
inputs['dat'] = Input(shape=(96, 96, 96, 4))
```

# Autoencoder

The autoencoder architecture is designed to learn a compressed latent feature representation of the original input data. Because the network must reconstruct its input from this low-dimensional bottleneck, it is pushed to retain only the most relevant details of the original input in the latent representation.

Compared to a standard contracting-expanding U-Net architecture for semantic segmentation, two important distinctions should be emphasized:

* no "skip" connections between the contracting and expanding layers (skip connections would let information bypass the compressed bottleneck)
* use of a regression loss function (e.g., MAE or MSE) for optimization
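
For an autoencoder the regression target is the input itself. To make the two losses concrete, the NumPy sketch below computes MAE and MSE between a target and a reconstruction; the arrays are toy values chosen only for illustration.

```python
import numpy as np

def mae(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean absolute error: average of |y_true - y_pred| over all elements."""
    return float(np.mean(np.abs(y_true - y_pred)))

def mse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean squared error: average of (y_true - y_pred)^2; large errors weigh more."""
    return float(np.mean((y_true - y_pred) ** 2))

target = np.zeros(4)
recon = np.array([1.0, -1.0, 2.0, 0.0])
print(mae(target, recon))  # (1 + 1 + 2 + 0) / 4 = 1.0
print(mse(target, recon))  # (1 + 1 + 4 + 0) / 4 = 1.5
```

Note how the single error of 2.0 contributes half of the MSE but only half of the MAE numerator as well; squaring makes MSE more sensitive to occasional large reconstruction errors.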

### Create model layers

In [None]:
# --- Define kwargs dictionary
kwargs = {
    'kernel_size': (3, 3, 3),
    'padding': 'same'}

# --- Define lambda functions
conv = lambda x, filters, strides : layers.Conv3D(filters=filters, strides=strides, **kwargs)(x)
norm = lambda x : layers.BatchNormalization()(x)
relu = lambda x : layers.ReLU()(x)
tran = lambda x, filters, strides : layers.Conv3DTranspose(filters=filters, strides=strides, **kwargs)(x)

# --- Define stride-1, stride-2 blocks
conv1 = lambda filters, x : relu(norm(conv(x, filters, strides=1)))
conv2 = lambda filters, x : relu(norm(conv(x, filters, strides=2)))
tran2 = lambda filters, x : relu(norm(tran(x, filters, strides=2)))

Using these lambda functions, let us define a standard contracting-expanding (encoder-decoder) architecture:

In [None]:
# --- Define contracting layers
l1 = conv1(8, inputs['dat'])
l2 = conv1(16, conv2(16, l1))
l3 = conv1(32, conv2(32, l2))
l4 = conv1(48, conv2(48, l3))
l5 = conv1(64, conv2(64, l4))

# --- Define expanding layers
l6  = tran2(48, l5)
l7  = tran2(32, conv1(48, l6))
l8  = tran2(16, conv1(32, l7))
l9  = tran2(8,  conv1(16, l8))
l10 = conv1(8,  l9)
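
As a sanity check on the architecture above, the tensor shape at each level can be traced by hand: with `padding='same'`, a stride-2 convolution maps a spatial size `n` to `ceil(n / 2)`, and a stride-2 `Conv3DTranspose` maps `n` to `2 * n`. The pure-Python sketch below (no TensorFlow required) reproduces that bookkeeping for the filter counts used above.

```python
import math

# --- Start from the input volume: (96, 96, 96) voxels x 4 channels
size = 96
print('input :', (size,) * 3, 4)

# --- l1: stride-1 block keeps the spatial size, 8 filters
channels = 8
print('l1    :', (size,) * 3, channels)

# --- l2..l5: each level starts with a stride-2 block ('same' padding -> ceil(n / 2))
for filters in [16, 32, 48, 64]:
    size = math.ceil(size / 2)
    channels = filters
    print('down  :', (size,) * 3, channels)

latent_size, latent_channels = size, channels

# --- l6..l9: each level ends with a stride-2 transposed conv ('same' padding -> 2 * n)
for filters in [48, 32, 16, 8]:
    size = size * 2
    channels = filters
    print('up    :', (size,) * 3, channels)

# --- l10 keeps the size; the final 'recon' Conv3D maps back to 4 channels
n_input = 96 ** 3 * 4
n_latent = latent_size ** 3 * latent_channels
print('latent:', (latent_size,) * 3, latent_channels)
print('compression ratio:', n_input / n_latent)
```

The latent tensor `l5` has shape (6, 6, 6, 64), a 256-fold reduction in the number of values relative to the input, which is what forces the encoder to summarize the volume rather than copy it.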

### Create autoencoder

The full autoencoder architecture will span all layers above:

In [None]:
# --- Create autoencoder
ae_outputs = {'recon': layers.Conv3D(filters=4, name='recon', **kwargs)(l10)}
ae = Model(inputs=inputs, outputs=ae_outputs)

### Create encoder

In addition to the full autoencoder, let us create a second, smaller encoder that spans only the contracting layers of the autoencoder above. Note that this encoder is defined using the **exact same** layer objects as above. Thus, updates to one model will necessarily update the corresponding parameters in the other.

In [None]:
# --- Create encoder
encoder = Model(inputs=inputs, outputs=l5)

### Generator

To train the autoencoder, the default training generators which yield labels corresponding to survival scores need to be modified to instead yield the original input data. Use the following nested Python generator strategy to accomplish this:

In [None]:
def ae_generator(G):
    """
    Method to modify standard generator for autoencoder
    
    """
    for xs, ys in G:

        ys = {'recon': xs['dat']}

        yield xs, ys
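
Because `ae_generator` only rearranges dictionaries, its behavior can be checked without TensorFlow. The sketch below wraps a hypothetical `dummy_generator` that mimics the `(xs, ys)` structure of `gen_train` with small random arrays; the shapes are illustrative only.

```python
import numpy as np

def ae_generator(G):
    """Replace the survival labels with the input volume as the reconstruction target."""
    for xs, ys in G:
        ys = {'recon': xs['dat']}
        yield xs, ys

def dummy_generator(n_batches=3):
    """Hypothetical stand-in for gen_train: yields (xs, ys) dicts of random data."""
    for _ in range(n_batches):
        xs = {'dat': np.random.rand(2, 8, 8, 8, 4).astype('float32')}
        ys = {'survival': np.random.rand(2, 1).astype('float32')}
        yield xs, ys

for xs, ys in ae_generator(dummy_generator()):
    # --- The target is the very same array object as the input
    assert ys['recon'] is xs['dat']
    assert 'survival' not in ys
print('ok')
```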

## Training

Use the following blocks to compile and train the autoencoder model:

In [None]:
# --- Compile model
ae.compile(
    optimizer=optimizers.Adam(learning_rate=1e-3),
    loss={'recon': losses.MeanSquaredError()},
    experimental_run_tf_function=False)

### In-memory data

For moderately sized datasets that fit into RAM, it is often a good idea to load all training data into memory before training, which avoids repeated disk reads and speeds up training. The `client` can be used for this purpose as follows:

In [None]:
# --- Load data into memory for faster training
client.load_data_in_memory()

Now, let us train the model:

In [None]:
# --- Train model
ae.fit(
    x=ae_generator(gen_train), 
    steps_per_epoch=500, 
    epochs=4,
    use_multiprocessing=True)

Now that the model is trained, how would you visualize some of the autoencoder results?
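
One possible approach is to display each input next to its reconstruction. The hypothetical helper below concatenates one channel of the input and of the reconstruction along the last spatial axis and reports the mean absolute error of each example. In the notebook it could be fed with `xs['dat']` and the output of `ae.predict(xs)`, which is assumed here to return a dictionary keyed by `'recon'` because the model outputs were defined as a dictionary; the resulting panel could then be passed to `imshow(...)`.

```python
import numpy as np

def side_by_side(dat: np.ndarray, recon: np.ndarray, channel: int = 0):
    """
    Pair input and reconstruction for one MRI channel.

    dat, recon : arrays of shape (batch, Z, Y, X, C)
    Returns (panel, errors): panel has shape (batch, Z, Y, 2 * X), and
    errors holds the mean absolute error of each example over all channels.
    """
    panel = np.concatenate([dat[..., channel], recon[..., channel]], axis=-1)
    errors = np.abs(dat - recon).mean(axis=(1, 2, 3, 4))
    return panel, errors

# --- Example with random stand-in arrays
dat = np.random.rand(2, 8, 8, 8, 4)
recon = dat + 0.1
panel, errors = side_by_side(dat, recon, channel=0)
print(panel.shape, errors)
```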

# Survival Model

The final model for patient survival will reuse the pre-trained encoder layers of the autoencoder (note that the `encoder` model object above shares weights with the `ae` model object used for training).

To re-use the `encoder` object, first freeze the model weights:

In [None]:
# --- Freeze encoder model weights
encoder.trainable = False

Now let us define a new survival model:

In [None]:
# --- Re-use the encoder model layers on input
latent = encoder(inputs)

# --- Finalize model
h0 = layers.Flatten()(latent)
h1 = layers.Dense(32, activation='relu')(h0)

logits = {}
logits['survival'] = layers.Dense(1, activation='sigmoid', name='survival')(h1)

# --- Create survival model
model = Model(inputs=inputs, outputs=logits)
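
Because the encoder is frozen, only the two new `Dense` layers are trained. Their parameter counts follow from the latent shape (6, 6, 6, 64) produced by four stride-2 convolutions on a 96-voxel input. The arithmetic below is a sketch of what `model.summary()` should report as trainable parameters for this head.

```python
# --- Latent shape from the encoder: (6, 6, 6, 64)
flat = 6 * 6 * 6 * 64            # Flatten output size: 13,824

dense_32 = flat * 32 + 32        # weights + biases of Dense(32)
dense_1 = 32 * 1 + 1             # weights + bias of Dense(1)

print(flat, dense_32, dense_1, dense_32 + dense_1)
```

Almost all trainable parameters sit in the first `Dense` layer, because flattening keeps every latent voxel as a separate feature.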

## Training

Use the following blocks to compile and train the survival model:

In [None]:
# --- Compile model
model.compile(
    optimizer=optimizers.Adam(learning_rate=1e-3),
    loss={'survival': losses.MeanSquaredError()},
    experimental_run_tf_function=False)

Now, let us train the model:

In [None]:
# --- Train model
model.fit(
    x=gen_train, 
    steps_per_epoch=500, 
    epochs=4,
    validation_data=gen_valid,
    validation_steps=500,
    validation_freq=4,
    use_multiprocessing=True)

# Evaluation

To test the trained model, the following steps are required:

* load data
* use `model.predict(...)` to obtain predicted survival scores
* compare prediction with ground-truth
* collect results in a Pandas DataFrame

Recall that the generator used to train the model simply iterates through the dataset randomly. For model evaluation, the cohort must instead be loaded manually in an orderly way. For this tutorial, we will create new **test mode** data generators, which will simply load each example individually once for testing. 

### Running evaluation

Use the following lines of code to run prediction through the **valid** cohort generator.

In [None]:
# --- Create test-mode generators
test_train, test_valid = client.create_generators(test=True, expand=True)

mae = []

for x, y in test_valid:
    
    # --- Predict
    logits = model.predict(x['dat'])

    if type(logits) is dict:
        logits = logits['survival']

    # --- Absolute error between predicted and true survival score
    mae.append(np.abs(logits.squeeze() - y['survival'].squeeze()))

mae = np.array(mae)

Prepare results in Pandas DataFrame for ease of analysis and sharing:

In [None]:
# --- Define columns
df = pd.DataFrame(index=np.arange(mae.size))
df['MAE'] = mae

# --- Print mean absolute error
print(df['MAE'].mean())
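
Because the targets are log-transformed, an error in score space is easier to interpret as a multiplicative error in days. Assuming natural logarithms in `score = log(days) / 10`, an absolute score error `d` means the predicted survival is too high or too low by a factor of `exp(10 * d)`. A small sketch (the function name `fold_error` is hypothetical):

```python
import math

def fold_error(score_error: float) -> float:
    """Multiplicative error in days implied by a score error, assuming score = ln(days) / 10."""
    return math.exp(10 * abs(score_error))

# --- e.g. an absolute score error of 0.05 corresponds to a factor of about 1.65
print(round(fold_error(0.05), 2))
```

Applied to `df['MAE'].mean()`, this gives a rough summary only: the exponential of a mean error is not the same as the mean of per-patient fold errors.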

## Saving and Loading a Model

After a model has been successfully trained, it can be saved and/or loaded by simply using the `model.save()` and `models.load_model()` methods. 

In [None]:
# --- Serialize a model
model.save('./survival.hdf5')

In [None]:
# --- Load a serialized model
del model
model = models.load_model('./survival.hdf5', compile=False)