# Overview

## Setting Up Your Environment

The following code downloads several external Python libraries prepared for this tutorial. More information and additional tutorials can be found in the following GitHub repository: https://github.com/peterchang77/dl_train.

In [None]:
!wget -O setenv.py https://raw.githubusercontent.com/peterchang77/dl_utils/master/setenv.py
from setenv import prepare_env
prepare_env()

Next, we will download and prepare the data for this tutorial:

In [None]:
from dl_train import datasets
datasets.download(name='brats')

## Imports

The following modules will be used in this tutorial:

In [None]:
# --- Select Tensorflow 2.x (only in Google Colab)
%tensorflow_version 2.x

In [None]:
import glob, os
import numpy as np

from tensorflow.keras import Input, Model, layers, models, losses, optimizers
from tensorflow import math

from dl_utils.display import imshow

# Data

The data you have downloaded above contains preprocessed images and labels for this tutorial. To access the data, we will prepare Python generators (`gen_train` and `gen_valid`) using the custom `datasets` module. Specifically, we will create generators that:

* yield a total of 16 training examples per batch
* use the first fold (=0) for validation
* sample equally from foreground (`fg`) and background (`bg`) cases at 50% frequency each

In [None]:
gen_train, gen_valid = datasets.prepare(name='brats', configs={
    'batch': {
        'size': 16,
        'fold': 0,
        'sampling': {
            'fg': 0.5,
            'bg': 0.5
}}})
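The `datasets` module handles sampling internally, and its implementation is not shown here. As a rough illustration of what a 50/50 foreground/background configuration means, the following sketch uses a hypothetical `sample_batch()` helper operating on hypothetical lists of example IDs (`fg_ids`, `bg_ids`). It is an assumption for illustration, not the library's code.

```python
import numpy as np


def sample_batch(fg_ids, bg_ids, batch_size=16, rates=None, seed=None):
    """
    Hypothetical sketch of stratified sampling (not the `datasets` implementation):
    each slot in the batch is assigned to the 'fg' or 'bg' pool according to
    `rates`, and an example ID is then drawn uniformly from that pool.
    """
    rates = rates or {'fg': 0.5, 'bg': 0.5}
    rng = np.random.default_rng(seed)
    pools = {'fg': np.asarray(fg_ids), 'bg': np.asarray(bg_ids)}

    # --- Normalize the sampling rates into probabilities
    keys = list(rates)
    probs = np.array([rates[k] for k in keys], dtype=float)
    probs /= probs.sum()

    # --- Choose a pool for each slot, then one example from that pool
    strata = rng.choice(keys, size=batch_size, p=probs)
    batch = np.array([rng.choice(pools[s]) for s in strata])
    return strata, batch
```

Because each slot is drawn independently, a given batch of 16 may not split exactly 8/8; only the long-run proportions match the configured rates. A fixed per-batch split would be an equally reasonable design.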

Each returned Python generator yields a tuple `(xs, ys)` of dictionaries that conforms to the Tensorflow/Keras 2.0 API for training inputs and outputs. Let us take a closer look:

In [None]:
# --- Yield the next batch
xs, ys = next(gen_train)

# --- Inspect xs and ys dictionaries
print(xs.keys())
print(ys.keys())

# --- Inspect the `t2` array
print(xs['t2'].shape)

### Visualization

Let us now view the underlying voxel data using the `imshow()` method available in the custom `dl_utils.display` module. This function directly visualizes any 2D slice of data (first argument) and, if provided, overlays a mask (second argument).

Example usage:

In [None]:
imshow(xs['flair'], ys['lbl'], figsize=(12, 12))
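For intuition about what overlaying a mask involves, here is a minimal NumPy sketch using a hypothetical `overlay()` helper (not part of `dl_utils`) that alpha-blends a binary mask onto a single 2D grayscale slice. The `color` and `alpha` defaults are arbitrary choices.

```python
import numpy as np


def overlay(image, mask, color=(1.0, 0.0, 0.0), alpha=0.4):
    """
    Hypothetical helper: blend a binary mask onto a 2D grayscale slice,
    returning an RGB array with values in [0, 1].
    """
    image = np.asarray(image, dtype=float)

    # --- Rescale intensities to [0, 1] (constant images map to 0)
    lo, hi = image.min(), image.max()
    gray = (image - lo) / (hi - lo) if hi > lo else np.zeros_like(image)

    # --- Replicate grayscale into three RGB channels
    rgb = np.stack([gray] * 3, axis=-1)

    # --- Alpha-blend the mask color into voxels where the mask is nonzero
    m = np.asarray(mask) > 0
    rgb[m] = (1 - alpha) * rgb[m] + alpha * np.asarray(color)
    return rgb
```

The resulting RGB array could then be displayed with, for example, `matplotlib.pyplot.imshow(rgb)`.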

# Model

In this exercise, we will create a custom variant of the standard contracting-expanding network topology, commonly referred to as a U-Net architecture. We will define the entire model in the next several code cells using the functional API of Tensorflow/Keras. For a more general overview of basic Tensorflow/Keras usage, see the following tutorial links (remote/local).

## Creating a Convolutional Block

To help facilitate concise and readable code, we will create template Python lambda functions to succinctly define convolutional blocks, each consisting of the following series of consecutive operations:

* convolution (or convolutional-transpose)
* batch normalization
* activation function (ReLU, leaky ReLU, etc.)

In [None]:
# --- Define convolution parameters
kwargs = {
    'kernel_size': (1, 3, 3),
    'padding': 'same',
    'kernel_initializer': 'he_normal'}

# --- Define block components
conv = lambda x, filters, strides : layers.Conv3D(filters=filters, strides=strides, **kwargs)(x)
tran = lambda x, filters, strides : layers.Conv3DTranspose(filters=filters, strides=strides, **kwargs)(x)
norm = lambda x : layers.BatchNormalization()(x)
relu = lambda x : layers.LeakyReLU()(x)

# --- Define stride-1, stride-2 blocks
conv1 = lambda filters, x : relu(norm(conv(x, filters, strides=1)))
conv2 = lambda filters, x : relu(norm(conv(x, filters, strides=(1, 2, 2))))
tran2 = lambda filters, x : relu(norm(tran(x, filters, strides=(1, 2, 2)))) 
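The `norm` and `relu` steps can be sketched in NumPy to show what each block computes after the convolution. This sketch assumes training-mode batch statistics and the tf.keras defaults `epsilon=1e-3` for `BatchNormalization` and `alpha=0.3` for `LeakyReLU`. At inference time, Keras normalizes with moving averages instead, and the learnable `gamma`/`beta` are simply fixed at 1 and 0 here.

```python
import numpy as np


def batch_norm(x, gamma=1.0, beta=0.0, epsilon=1e-3):
    """Training-mode batch normalization over every axis except the last (channel) axis."""
    axes = tuple(range(x.ndim - 1))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + epsilon) + beta


def leaky_relu(x, alpha=0.3):
    """Pass positive values through unchanged; scale negative values by alpha."""
    return np.where(x > 0, x, alpha * x)


def norm_then_activate(conv_output):
    """The block tail: normalization followed by activation, applied to a convolution output."""
    return leaky_relu(batch_norm(conv_output))
```

After normalization, each channel has roughly zero mean and unit variance across the batch, which is what makes stacking many such blocks easier to train.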

Now, we are ready to define the full model.

In [None]:
def create_model(input_shape):
    """
    Method to create a simple U-Net architecture

    """
    # --- Define input (named to match the assumed 'dat' key of the xs dictionary)
    inputs = Input(shape=input_shape, name='dat')

    # --- Define contracting layers
    l1 = conv1(16, inputs)
    l2 = conv1(32, conv2(32, l1))
    l3 = conv1(48, conv2(48, l2))
    l4 = conv1(64, conv2(64, l3))
    l5 = conv1(80, conv2(80, l4))

    # --- Define expanding layers
    l6 = tran2(64, l5)
    l7 = tran2(48, conv1(64, l6 + l4))
    l8 = tran2(32, conv1(48, l7 + l3))
    l9 = tran2(16, conv1(32, l8 + l2))

    # --- Name the output to match the 'lbl' key of the ys dictionary
    logits = layers.Conv3D(filters=2, name='lbl', **kwargs)(conv1(16, l1 + l9))

    return Model(inputs=inputs, outputs=logits)
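Each `+` skip connection requires the upsampled tensor and the matching contracting-path tensor to have identical shapes. The hypothetical helper below traces the spatial shape through four stride-(1, 2, 2) downsamplings (with `'same'` padding, output = ceil(input / stride)) and four transposed convolutions (output = input * stride). In practice this means the in-plane dimensions should be divisible by 16; the example sizes are illustrative and are not taken from the BraTS data.

```python
import math


def trace_unet_shapes(shape, levels=4, strides=(1, 2, 2)):
    """
    Hypothetical helper: return True if every skip connection in a U-Net
    with `levels` strided downsamplings adds tensors of matching shape.
    """
    # --- Contracting path: 'same' padding gives ceil(input / stride)
    down = [tuple(shape)]
    for _ in range(levels):
        down.append(tuple(math.ceil(d / s) for d, s in zip(down[-1], strides)))

    # --- Expanding path: transposed conv with 'same' padding gives input * stride
    up = down[-1]
    for skip in reversed(down[:-1]):
        up = tuple(d * s for d, s in zip(up, strides))
        if up != skip:
            return False
    return True
```

For example, a 250 x 250 slice shrinks to 125, 63, 32 and 16, but upsampling 32 yields 64 rather than 63, so the second skip connection would fail.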

In [None]:
# --- Create model and show summary
# --- Input shape (excluding the batch axis), assuming the image array is stored under xs['dat']
model = create_model(xs['dat'].shape[1:])
model.summary()

## Preparing the Model

Next, we must compile the model with the requisite objects that define training dynamics (i.e., how the algorithm will learn). These include classes that encapsulate the model `optimizer`, `loss` and `metrics` for evaluating algorithm performance. For more information about how these strategies are determined and defined, see the following tutorial links (remote/local).

In [None]:
# --- Define optimizer
optimizer = optimizers.Adam(learning_rate=2e-4)

# --- Define loss
loss = losses.SparseCategoricalCrossentropy(from_logits=True)
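Setting `from_logits=True` tells the loss that the final `Conv3D` layer outputs raw, unnormalized scores, so the softmax is applied inside the loss. A NumPy sketch of the same computation follows, using a hypothetical `sparse_ce_from_logits()` helper and assuming integer labels without a trailing channel axis and a plain mean over all voxels.

```python
import numpy as np


def sparse_ce_from_logits(labels, logits):
    """
    Hypothetical helper: mean cross-entropy between integer labels of shape (...,)
    and unnormalized logits of shape (..., num_classes).
    """
    # --- Numerically stable log-softmax over the class axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    # --- Pick the log-probability of the true class at each voxel
    picked = np.take_along_axis(log_probs, labels[..., None].astype(int), axis=-1)
    return -picked.mean()
```

Subtracting the per-voxel maximum before exponentiating avoids overflow without changing the softmax.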

### Custom Metric

In addition to the standard evaluation metric of `accuracy` (the percentage of pixels or voxels that are predicted correctly), we will also track a common metric of spatial overlap between masks: the Dice score. To do so, we need to create a custom metric function. For more information about creating custom metrics (and losses) in Tensorflow/Keras, see the following tutorial links (remote/local).

In [None]:
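def metric_dice():

    def dice(y_true, y_pred):

        # --- y_true has a trailing channel axis of size 1; drop it to match y_pred[..., k]
        true = y_true[..., 0] == 1
        pred = y_pred[..., 1] > y_pred[..., 0]

        # --- Count voxels as float32 so the ratio is a float tensor
        A = math.count_nonzero(true & pred, dtype='float32') * 2
        B = math.count_nonzero(true, dtype='float32') + math.count_nonzero(pred, dtype='float32')

        # --- Return 0 (rather than NaN) when both masks are empty
        return math.divide_no_nan(A, B)

    return dice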
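The metric logic can be checked on small NumPy arrays before training. The hypothetical `dice_np()` helper below mirrors the Dice definition, 2|T ∩ P| / (|T| + |P|), with the predicted mask taken as the voxels where the class-1 logit exceeds the class-0 logit. Returning 0 when both masks are empty is a convention chosen here; some implementations return 1 instead.

```python
import numpy as np


def dice_np(y_true, logits):
    """Hypothetical NumPy Dice score for integer labels (...,) and 2-class logits (..., 2)."""
    true = y_true == 1
    pred = logits[..., 1] > logits[..., 0]

    # --- Overlap (intersection) and total mask size
    intersection = np.count_nonzero(true & pred)
    total = np.count_nonzero(true) + np.count_nonzero(pred)
    return 2.0 * intersection / total if total > 0 else 0.0
```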

### Compile

At last, we are ready to compile the model, which requires a single call to the `model.compile()` method:

In [None]:
model.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=['accuracy', metric_dice()])

# Training

One of the primary advantages of the Tensorflow/Keras API is that by following the above "recipe" to customize, instantiate and compile a `model` object, several easy-to-use interfaces become available for algorithm training. In this tutorial, we will use data from the Python generators (`gen_train` and `gen_valid` as above) to train the model with the `model.fit()` method, which accepts generators directly in Tensorflow 2.x (the older `model.fit_generator()` method is deprecated). Training requires only a single method call:

In [None]:
# --- Train the model
model.fit(
    x=gen_train,
    steps_per_epoch=500,
    epochs=4,
    validation_data=gen_valid,
    validation_steps=100)
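For a sense of scale: with a batch size of 16, the call above draws 500 steps for each of 4 epochs (2,000 training batches, or 32,000 examples) and 100 validation batches (1,600 examples) at the end of each epoch. The pure-Python sketch below mimics how a fixed number of steps is pulled from an endless generator each epoch. Both `dummy_gen()` (a stand-in for `gen_train`) and `count_examples()` are hypothetical and illustrative, not how Keras is implemented.

```python
from itertools import islice


def dummy_gen(batch_size=16):
    """Hypothetical stand-in for gen_train: yields (xs, ys) dictionaries forever."""
    while True:
        yield {'dat': [0] * batch_size}, {'lbl': [0] * batch_size}


def count_examples(generator, steps_per_epoch, epochs):
    """Pull `steps_per_epoch` batches per epoch and return the total number of examples seen."""
    seen = 0
    for _ in range(epochs):
        for xs, ys in islice(generator, steps_per_epoch):
            seen += len(ys['lbl'])
    return seen
```

Because the same generator is reused across epochs, it must keep yielding; a finite generator that runs out before `steps_per_epoch * epochs` batches would interrupt training.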

# Prediction

How did we do? The validation performance metrics (accuracy, Dice score) give us a reasonable benchmark, but the most important thing to do at the end of the day is to visually check some examples for yourself. Let us pass some validation data manually into the model using the `model.predict()` method and see some results:

In [None]:
# --- Load the next batch of validation data
xs, ys = next(gen_valid)

# --- Run prediction (assumes the model input is stored under the 'dat' key)
pred = model.predict(xs['dat'])
mask = np.argmax(pred[0], axis=-1)

# --- Show prediction for the first example (channel 1) in the batch
imshow(xs['dat'][0, ..., 1], mask)
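Because the final layer has no activation, `model.predict()` returns raw logits, and taking the `argmax` over the last (class) axis gives the label mask. Applying a softmax first would not change the result; for two classes it is equivalent to thresholding the class-1 probability at 0.5, which is also the rule used in the Dice metric above. A small NumPy check, with random logits standing in for model output and hypothetical helper names:

```python
import numpy as np


def logits_to_mask(logits):
    """Class index with the highest score along the last axis."""
    return np.argmax(logits, axis=-1)


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


# --- Random logits shaped like a (batch, Z, H, W, classes) prediction
logits = np.random.default_rng(0).normal(size=(2, 1, 8, 8, 2))
mask = logits_to_mask(logits)
probs = softmax(logits)
```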

## Saving and Loading a Model

After a model has been trained, it can be saved and loaded using the `model.save()` and `models.load_model()` methods, respectively. Note that any custom losses and/or metrics must be provided to `load_model()` via the `custom_objects` dictionary, keyed by name.

In [None]:
# --- Serialize a model
os.makedirs('./models', exist_ok=True)
model.save('./models/unet.hdf5')

In [None]:
# --- Load a serialized model
model = models.load_model('./models/unet.hdf5', custom_objects={'dice': metric_dice()})