# Overview

In this tutorial we will explore best practices for building modern CNN models, focusing on four key design components: layer-level (micro-)architecture, model size and topology, loss functions, and block-level (macro-)architecture. For each design component, we perform a grid search to characterize the effect of each choice on overall model performance. By systematically progressing through each hyperparameter configuration, we attempt to derive general recommendations and best approaches for empiric experimentation. As a representative use case, we will build various convolutional neural networks (CNNs) for segmentation of renal tumors from CT abdomen scans from the KiTS 2021 Challenge. 

## Workshop Links

This tutorial focuses on specific considerations related to network architecture and hyperparameter tuning. Other useful tutorials can be found at this link: https://github.com/peterchang77/dl_tutor/tree/master/workshops

# Environment

The following lines of code will configure your Google Colab environment for this tutorial.

### Enable GPU runtime

Use the following instructions to switch the default Colab instance into a GPU-enabled runtime:

```
Runtime > Change runtime type > Hardware accelerator > GPU
```

### Jarvis library

In this notebook we will use Jarvis, a custom Python package to facilitate data science and deep learning for healthcare. Among other things, this library will be used for low-level data management, stratification and visualization of high-dimensional medical data.

In [None]:
# --- Install Jarvis library
%pip install tensorflow==2.14.0 
%pip install jarvis-md

### Imports

Use the following lines to import any needed libraries:

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import Input, Model, layers, models, optimizers, losses, callbacks
from jarvis.train import datasets
from jarvis.utils.display import imshow

# Data

The data used in this tutorial will consist of kidney tumor CT exams derived from the Kidney Tumor Segmentation Challenge (KiTS). More information about the KiTS Challenge can be found here: https://kits21.kits-challenge.org/. The custom `datasets.download(...)` method can be used to download a local copy of the dataset. By default the dataset will be archived at `/data/raw/ct_kits`; as needed, an alternate location may be specified using `datasets.download(name=..., path=...)`.

In [None]:
# --- Download dataset
datasets.download(name='ct/kits')

Once downloaded, the `datasets.prepare(...)` method can be used to generate the required python Generators to iterate through the dataset, as well as a `client` object for any needed advanced functionality. As needed, pass any custom configurations (e.g. batch size, normalization parameters, etc) into the optional `configs` dictionary argument. 

To specify the correct Generator template file, pass a designated `keyword` string. In this tutorial, we will be using abdominal CT volumes that have been preprocessed into 96 x 96 x 96 matrix volumes, each cropped to the right and left kidney, facilitating ease of algorithm training within the Google Colab platform. To select the correct Client template for this task, use the keyword string `3d-pos`.

In [None]:
# --- Prepare generators
gen_train, gen_valid, client = datasets.prepare(name='ct/kits', keyword='3d-pos', configs={'batch': {'size': 2}})

The created generators yield a total of `batch['size']` training samples based on the specified batch size. Each iteration yields a dictionary of model inputs, `data`. In the current example, there is just a single input volume `data['x']` and a single target `data['y']`. Let us examine the generator data:

In [None]:
# --- Yield one example
data, _ = next(gen_train)

# --- Print dict keys
for k, v in data.items():
    print('key = {} : shape = {}'.format(k.ljust(7), v.shape))

### KiTS Data

As noted above, the input images in the variable `data['x']` are matrices of shape `96 x 96 x 96 x 1` cropped to the right and left kidneys.

Use the following lines of code to visualize using the `imshow(...)` method:

In [None]:
# --- Show the first example
imshow(data['x'][0], vm=[-256, +256])

Use the `imshow(...)` method on the stack of middle slices to display a mosaic ("montage") of all images in the batch:

In [None]:
# --- Show "montage" of all images
imshow(data['x'][:, 48], vm=[-256, +256])

### Kidney masks

The ground-truth labels are masks of the same matrix shape as the model input (the model code later in this tutorial treats voxels with a label value of 2 as tumor):

In [None]:
print(data['y'][0].shape)

Use the `imshow(...)` method to visualize the ground-truth tumor mask labels:

In [None]:
# --- Show tumor masks overlaid on original data
imshow(data['x'][:, 48], data['y'][:, 48], vm=[-256, +256])

# --- Show tumor masks isolated
imshow(data['y'][:, 48])

# Model Design

This tutorial focuses on the following key components of modern neural network model design:

* Layer-level (Micro-)Architecture
* Model Size and Topology
* Loss Function
* Block-level (Macro-)Architecture

While there are certainly numerous additional key considerations, a deliberate and organized approach to these four primary design components captures the majority of performance variance for most tasks, enabling a customized high-performing state-of-the-art model for nearly every use case. Let us take a closer look at each of these components.

## Layer-level (Micro-)Architecture

A standard convolutional neural network (CNN) block is typically composed of the following three layers:

* linear operation
* feature normalization
* nonlinear activation

**Linear operation**: A convolution (or convolution transpose) layer remains a high-performing popular choice for medical imaging applications. If relevant, a 3D kernel is preferred (see Other Design Considerations below for data preparation recommendations). For the most part, a 3x3(x3) kernel size is near optimal (VGG-style) although depthwise (and/or separable) convolutions (ConvNeXt-style) do tend to yield marginal but robust improvements.

**Feature normalization**: While classic approaches utilize the `BatchNormalization` method, more recent data suggests that batch-independent normalization schemes such as `GroupNormalization` or `LayerNormalization` tend to lead to more stable optimization and improved generalization. This is especially true when batch-level statistics are unstable due to high data variance (e.g., wide dynamic range as seen in medical imaging). 
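To make the batch dependence concrete, the following NumPy sketch normalizes the same sample inside two different batches. It is an illustration only, not the Keras implementation: the learned scale and offset are omitted, and the helper names `batch_norm` and `layer_norm` are hypothetical. The layer-normalized output of the sample does not change, whereas the batch-normalized output shifts with the intensity range of its batch neighbor.

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    """Normalize each channel using statistics pooled over batch and spatial axes."""
    axes = tuple(range(x.ndim - 1))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def layer_norm(x, eps=1e-5):
    """Normalize each voxel over its channel axis (mirrors the Keras default axis=-1)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
sample = rng.normal(size=(1, 4, 4, 4, 8))

# --- Same sample paired with a low-intensity or a high-intensity neighbor
batch_a = np.concatenate([sample, rng.normal(size=sample.shape)])
batch_b = np.concatenate([sample, rng.normal(size=sample.shape) * 1000])

# --- Largest change in the normalized output of the shared sample
bn_shift = np.abs(batch_norm(batch_a)[0] - batch_norm(batch_b)[0]).max()
ln_shift = np.abs(layer_norm(batch_a)[0] - layer_norm(batch_b)[0]).max()

print('BatchNorm output shift: {:.4f}'.format(bn_shift))
print('LayerNorm output shift: {:.4f}'.format(ln_shift))
```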

**Nonlinear activation**: While classic approaches utilize the `ReLU` nonlinearity, the resulting sparse activations may lead to unstable gradients during optimization. While modern optimizers tend to mitigate this theoretical risk, simple replacements such as the `LeakyReLU` offer a good alternative with only minimal incremental computational cost. Of note, more complex nonlinear functions such as `eLU` and `GeLU` have shown potential marginal improvement when used in combination with recent advanced architectures.
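As a rough illustration of the sparse-gradient concern, the sketch below compares the fraction of inputs that receive a zero gradient under `ReLU` and `LeakyReLU`. The helper names are hypothetical, the inputs are random, and the slope `alpha=0.3` is an assumption chosen to mirror the Keras 2.x default.

```python
import numpy as np

def relu_grad(x):
    """Gradient of ReLU: 1 for positive inputs, 0 otherwise."""
    return (x > 0).astype(np.float32)

def leaky_relu_grad(x, alpha=0.3):
    """Gradient of LeakyReLU: 1 for positive inputs, alpha otherwise."""
    return np.where(x > 0, 1.0, alpha).astype(np.float32)

# --- Assumed toy pre-activations drawn from a standard normal distribution
rng = np.random.default_rng(0)
x = rng.normal(size=10000).astype(np.float32)

dead_relu = float(np.mean(relu_grad(x) == 0))
dead_leaky = float(np.mean(leaky_relu_grad(x) == 0))

print('Fraction of zero gradients (ReLU)     : {:.2f}'.format(dead_relu))
print('Fraction of zero gradients (LeakyReLU): {:.2f}'.format(dead_leaky))
```

Roughly half of the zero-centered inputs pass no gradient through `ReLU`, while every input passes at least a scaled gradient through `LeakyReLU`.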

### Code Implementation

The following modular template can be used to create VGG-style blocks using interchangeable layer components:

In [None]:
def create_layers(kernel_size=3, norm_name='LayerNormalization', func_name='LeakyReLU', **kwargs):

    # --- Define kwargs 
    kwargs_ = {
        'kernel_size': kernel_size,
        'padding': 'same',
        'kernel_initializer': 'he_normal'}

    # --- Define layers 
    conv = lambda x, filters, strides : layers.Conv3D(filters=filters, strides=strides, **kwargs_)(x)
    tran = lambda x, filters, strides : layers.Conv3DTranspose(filters=filters, strides=strides, **kwargs_)(x)

    norm = lambda x : getattr(layers, norm_name)()(x)
    relu = lambda x : getattr(layers, func_name)()(x) 

    concat = lambda *x : layers.Concatenate()(list(x))

    return conv, tran, norm, relu, concat

In [None]:
def create_blocks_vgg(**kwargs):

    conv, tran, norm, relu, concat = create_layers(**kwargs)

    conv1 = lambda filters, x : relu(norm(conv(x, filters, strides=1)))
    conv2 = lambda filters, x : relu(norm(conv(x, filters, strides=2)))
    tran2 = lambda filters, x : relu(norm(tran(x, filters, strides=2)))

    return conv1, conv2, tran2

### Experimental Results

In this section, the base network architecture is implemented using a standard VGG-style block, with each block representing the serial application of a linear operation, feature normalization, and nonlinear activation. In this context, we perform the following hyperparameter grid search:

Feature normalization: 

* `BatchNormalization` 
* `LayerNormalization`

Nonlinear activation: 

* `ReLU`
* `LeakyReLU`

See __[link](https://docs.google.com/viewer?url=https://raw.githubusercontent.com/peterchang77/dl_tutor/master/workshops/model-design/pdfs/00_micro.pdf)__ for summary of results.

### Recommendation

In general, we recommend using a combination of `LayerNormalization` and `LeakyReLU` in conjunction with 3D convolution operations. As needed, consider a hyperparameter sweep of depthwise convolutions and GeLU nonlinearities (see ConvNeXt architecture below).


## Model Size and Topology

For any given design strategy, overall model size can often be scaled by varying either the total number of network layers or the total number of features (channel depth) per layer. By contrast, note that the total number of feature map resolutions (dictated by the total number of downsampling and, if relevant, upsampling operations) is commonly fixed such that the smallest feature map size is between 3x3(x3) and 5x5(x5). As an important consideration detailed below, the high-dimensional nature of medical imaging data makes it such that GPU memory is primarily constrained by the size of feature map activations, *not* the actual model weights.
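A back-of-the-envelope calculation illustrates the gap between weights and activations. The layer below (a single 3x3x3 convolution with 16 input and 16 output channels, float32 values, one sample) is an assumed example rather than a measurement. Training typically holds many such activations at once for backpropagation, so the true footprint is larger still.

```python
# --- Assumed setup: one 3x3x3 Conv3D layer, 16 -> 16 channels, float32 values
kernel, c_in, c_out, bytes_per_value = 3, 16, 16, 4
shape = (96, 96, 96)

# --- Weights: one kernel per (input, output) channel pair plus biases
n_weights = kernel ** 3 * c_in * c_out + c_out

# --- Activations: one value per voxel per output channel (single sample)
n_activations = shape[0] * shape[1] * shape[2] * c_out

print('weights     : {:>10,} values ({:.1f} KiB)'.format(
    n_weights, n_weights * bytes_per_value / 2 ** 10))
print('activations : {:>10,} values ({:.1f} MiB)'.format(
    n_activations, n_activations * bytes_per_value / 2 ** 20))
print('ratio       : {:,.0f}x'.format(n_activations / n_weights))
```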

**Number of layers**: At minimum, a single layer is required for each target feature map resolution. In this minimal use case, each convolutional block will include a downsampling mechanism (e.g., a series of stride-2 convolutional blocks). However, with the use of relatively small kernel sizes, it is recommended that a minimum of two convolution layers are applied at each resolution for adequate spatial coverage (e.g., alternating stride-2 and stride-1 convolutions). As an exception, the first (full resolution) feature map may sometimes be implemented with just a single layer. This aggressive downsampling strategy is typical of many state-of-the-art models with the assumption that high-frequency (small) features are not critical for high performance. In addition, specific to high-dimensional medical imaging, a disproportionate amount of GPU memory is required to instantiate full resolution feature maps. However, given the heterogeneity of medical imaging problems, this assumption should be evaluated empirically.
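One way to reason about spatial coverage is to compute the receptive field of the stacked convolutions. The sketch below uses the standard recurrence for stacked convolutions; the function name `receptive_field` and the two layer lists are illustrative assumptions, with 3x3x3 kernels throughout and only the encoder path considered.

```python
def receptive_field(layers):
    """Return the receptive field (input voxels per axis) of stacked convolutions.

    layers: list of (kernel_size, stride) tuples, applied in order
    """
    field, jump = 1, 1
    for kernel_size, stride in layers:
        # --- Each layer widens the field by (k - 1) steps of the current spacing
        field += (kernel_size - 1) * jump
        jump *= stride
    return field

# --- One layer per resolution: a stride-1 conv followed by five stride-2 convs
minimal = [(3, 1)] + [(3, 2)] * 5

# --- Alternating stride-2 and stride-1 convs after a single full-resolution layer
small = [(3, 1)] + [(3, 2), (3, 1)] * 5

print('minimal :', receptive_field(minimal))
print('small   :', receptive_field(small))
```

Under these assumptions the one-layer-per-resolution encoder sees 65 voxels per axis, less than the 96-voxel input, while the alternating design reaches 189.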

**Feature map depth**: As noted above, full resolution feature maps account for a disproportionate amount of GPU memory when working with high-dimensional medical imaging. With respect to this consideration, relatively small feature map depths are favored in the earlier model layers, especially when building fully-convolutional encoder-decoder (U-Net) architectures. By contrast, relatively large feature map depths can be used in the deeper network layers without significant incremental computational burden.
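The following sketch tallies activation counts per resolution for the two feature map depth schedules evaluated in this tutorial. As a simplifying assumption it counts a single feature map per resolution for one sample, ignoring the decoder and any additional layers at each level.

```python
# --- Spatial size (per axis) at each of the six resolutions
sizes = [96, 48, 24, 12, 6, 3]

# --- Feature map depth schedules from this tutorial
filters = {
    'linear': [16, 32, 48, 64, 80, 96],
    'exponential': [16, 32, 64, 128, 196, 256]}

totals = {}
for name, depths in filters.items():
    # --- Activation values per resolution: voxels x channels
    counts = [s ** 3 * f for s, f in zip(sizes, depths)]
    totals[name] = sum(counts)
    print('{:<11} | full-resolution share = {:.1%} | total = {:,}'.format(
        name, counts[0] / totals[name], totals[name]))

print('exponential / linear = {:.3f}'.format(totals['exponential'] / totals['linear']))
```

In this simplified count the full resolution map dominates both schedules, and switching from linear to exponential growth adds only about 2% more activation values.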

### Code Implementation

The following code snippet can be used to define a series of alternating stride-2 and stride-1 convolutional blocks with variable feature map depth. Note that this specific implementation corresponds to the `small` topology and `linear` feature map growth rate described below. For full implementation, see additional code in the sections below.

```python
# --- Define filters
filters = [16, 32, 48, 64, 80, 96]

# --- Create encoder
l0 = conv1(filters[0], x)
l1 = conv1(filters[1], conv2(filters[1], l0))
l2 = conv1(filters[2], conv2(filters[2], l1))
l3 = conv1(filters[3], conv2(filters[3], l2))
l4 = conv1(filters[4], conv2(filters[4], l3))
l5 = conv1(filters[5], conv2(filters[5], l4))

# --- Create decoder
t4 = conv1(filters[4], tran2(filters[4], l5) + l4)
t3 = conv1(filters[3], tran2(filters[3], t4) + l3)
t2 = conv1(filters[2], tran2(filters[2], t3) + l2)
t1 = conv1(filters[1], tran2(filters[1], t2) + l1)
t0 = conv1(filters[0], tran2(filters[0], t1) + l0)
```

### Experimental Results

In this section, the base network architecture is implemented using standard VGG-style blocks with `LayerNormalization` and `LeakyReLU` nonlinearity. The original `(96, 96, 96)` input volume is downsampled five times yielding a total of 6 feature map resolutions (96, 48, 24, 12, 6, and 3). At each of the 6 feature map resolutions we perform the following hyperparameter grid search:

Number of layers (at each resolution):

* Small: `1-2-2-2-2-2`
* Medium: `2-2-2-2-2-2`
* Large: `2-3-3-3-3-3`

Feature map depth (at each resolution):

* Linear: `16-32-48-64-80-96`
* Exponential: `16-32-64-128-196-256`

See __[link](https://docs.google.com/viewer?url=https://raw.githubusercontent.com/peterchang77/dl_tutor/master/workshops/model-design/pdfs/01_scale.pdf)__ for summary of results.

### Recommendation

With sufficient training data size (and the adequate use of data augmentation), larger and deeper models show a trend towards better performance over smaller and shallower models and, at minimum, tend to be non-inferior. In the absence of a rigorous grid search, we recommend using the `large` model topology with `exponential` feature map growth at each resolution.

## Loss Function

Of the various model design components, perhaps the most important consideration is the combination of loss functions used for optimization. In the context of medical imaging, high performance often requires accounting for the inherent class imbalance of data distributions, both on a per-exam (rare incidence of disease) and per-voxel (small size of finding) basis. To address class imbalance in classification and segmentation tasks, standard cross-entropy loss may be complemented with the use of class weights, focal loss, and/or differentiable Dice score loss.

**Class weights**: After calculating the standard per instance cross entropy loss and prior to loss reduction or aggregation, the individual loss values of imbalanced classes can be multiplied by a scalar value > 1.0 to increase the contribution of those specific exams and/or regions. For semantic segmentation tasks, this is implemented as a per-voxel weight tensor that is multiplied (point-wise) by the standard cross-entropy loss prior to reduction. Note that class weights do *not* need to fully compensate for the degree of class imbalance inherent to the data; rather a relatively small range of class weights between `[1, 10]` (and up to 20-40 in certain use cases) is typically more than adequate.
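The per-voxel weighting can be sketched in NumPy as follows. The function name `weighted_bce` and the toy data (one foreground voxel among 100, all predicted as background) are assumptions for illustration; the weight map is built with the same `y_true * (class_weight - 1) + 1` expression used in the model template later in this tutorial.

```python
import numpy as np

def weighted_bce(y_true, p, class_weight=1.0, eps=1e-7):
    """Mean binary cross-entropy with foreground voxels scaled by class_weight."""
    p = np.clip(p, eps, 1 - eps)
    loss = -(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))

    # --- Per-voxel weight map: class_weight on foreground, 1 on background
    weights = y_true * (class_weight - 1) + 1

    return float(np.mean(loss * weights))

# --- Toy example: 1 foreground voxel among 100, all predicted as background
y_true = np.zeros(100)
y_true[0] = 1
p = np.full(100, 0.1)

for w in [1, 10]:
    print('class_weight = {:>2} : loss = {:.4f}'.format(w, weighted_bce(y_true, p, w)))
```

With a weight of 1 the missed foreground voxel contributes under a fifth of the total loss; with a weight of 10 it contributes roughly two thirds.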

**Focal loss**: Compared to manual class weights which uniformly increase the contribution of certain classes to the loss function, focal loss dynamically weighs the contribution of each instance proportional to the *difficulty* of classification. Here, the proxy for *difficulty* is defined using the loss value itself, with high loss values suggesting poor classification. The *gamma* hyperparameter defines the relative emphasis on difficult cases, with higher *gamma* values placing greater emphasis on challenging predictions. The default *gamma* value of 2 often yields robust performance, although a grid search across values between `[1, 3]` may be used to fine-tune model accuracy.
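The effect of *gamma* can be seen by evaluating the focal modulation factor, `(1 - p) ** gamma`, where `p` is the probability assigned to the correct class. The three example probabilities below are assumed values standing in for an easy, an ambiguous and a hard voxel, and the function name is hypothetical.

```python
import numpy as np

def focal_modulation(p_true, gamma=2.0):
    """Focal loss scaling factor given the probability assigned to the correct class."""
    return (1 - p_true) ** gamma

# --- Probability assigned to the correct class: easy, ambiguous and hard voxels
p_true = np.array([0.9, 0.5, 0.1])

for gamma in [1.0, 2.0, 3.0]:
    print('gamma = {} : {}'.format(gamma, np.round(focal_modulation(p_true, gamma), 3)))
```

At a *gamma* of 2, the easy voxel is scaled by 0.01 and the hard voxel by 0.81, so the hard voxel is weighted 81 times more heavily before the cross-entropy term is even considered.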

**Dice loss**: The numerator of the Dice score coefficient is defined based on the sensitivity of foreground detection and is thus a strong regularizer for class imbalanced segmentation tasks. To convert the standard discrete Dice score coefficient to a continuous, differentiable function, replace the binarized prediction masks with a sigmoid-activated (binary) or softmax-activated (multi-class) logit score. Used in isolation, a Dice score loss may lead to unstable optimization dynamics (especially for small target predictions) and is thus often paired with a more stable standard cross-entropy or focal loss objective.
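The difference between the discrete and the differentiable score can be sketched in NumPy. The function names and the six-voxel toy example are assumptions for illustration; the smoothing term `epsilon=1` mirrors the default used in the soft Dice implementation in this tutorial.

```python
import numpy as np

def soft_dice(y_true, logits, epsilon=1.0):
    """Differentiable Dice score using sigmoid probabilities in place of a binary mask."""
    p = 1 / (1 + np.exp(-logits))
    return 2 * np.sum(y_true * p) / (np.sum(y_true) + np.sum(p) + epsilon)

def hard_dice(y_true, logits):
    """Discrete Dice score using predictions thresholded at logit > 0."""
    pred = logits > 0
    true = y_true > 0
    return 2 * np.sum(true & pred) / (np.sum(true) + np.sum(pred))

y_true = np.array([1, 1, 0, 0, 0, 0], dtype=np.float32)

# --- Two predictions with identical binary masks but different confidence
weak = np.array([0.1, 0.1, -0.1, -0.1, -0.1, -0.1])
strong = np.array([5.0, 5.0, -5.0, -5.0, -5.0, -5.0])

for name, logits in [('weak', weak), ('strong', strong)]:
    print('{:<6} : hard = {:.2f} | soft = {:.2f}'.format(
        name, hard_dice(y_true, logits), soft_dice(y_true, logits)))
```

Both predictions achieve a perfect discrete score, yet only the soft score distinguishes the confident prediction from the marginal one, which is what gives the optimizer a usable gradient.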

### Code Implementation

The following code implements weighted cross-entropy loss, focal loss, and differentiable soft Dice score loss.

In [None]:
def create_bce(y_true, y_pred, w_samp=None, from_logits=True, **kwargs):
    """
    Method to create weighted cross-entropy loss 
    
    """
    return losses.BinaryCrossentropy(from_logits=from_logits, **kwargs)(
        y_true=y_true, 
        y_pred=y_pred, 
        sample_weight=w_samp)

In [None]:
def create_foc(y_true, y_pred, gamma=2.0, alpha=1.0, w_samp=1.0, **kwargs):
    """
    Method to create focal loss 
    
    """
    # --- Calculate standard cross entropy with alpha weighting
    loss = tf.nn.weighted_cross_entropy_with_logits(
        labels=y_true, logits=y_pred, pos_weight=alpha)

    # --- Calculate modulation to pos and neg labels 
    p = tf.math.sigmoid(y_pred)
    modulation_pos = (1 - p) ** gamma
    modulation_neg = p ** gamma

    mask = tf.dtypes.cast(y_true > 0.5, dtype=tf.bool)
    modulation = tf.where(mask, modulation_pos, modulation_neg)

    return tf.math.reduce_mean(modulation * loss * w_samp)

In [None]:
def create_sft(y_true, y_pred, axis=(0, 1, 2, 3), epsilon=1, **kwargs):
    """
    Method to create soft Dice score loss 
    
    """
    true = y_true[..., 0]
    pred = tf.nn.sigmoid(y_pred[..., 0])

    A = tf.math.reduce_sum(true * pred, axis=axis) * 2
    B = tf.math.reduce_sum(true, axis=axis) + tf.math.reduce_sum(pred, axis=axis) + epsilon

    return -tf.math.reduce_mean(A / B)

For reference, the following code implements a regular (discrete) Dice score error metric for tracking performance during model optimization.

In [None]:
def create_dsc(y_true, y_pred, axis=(1, 2, 3), **kwargs):
    """
    Method to create Dice score error 
    
    """
    y_true = y_true > 0
    y_pred = y_pred > 0

    A = tf.math.count_nonzero(y_true & y_pred, axis=axis) * 2
    B = tf.math.count_nonzero(y_true, axis=axis) + tf.math.count_nonzero(y_pred, axis=axis)

    return tf.math.reduce_mean(tf.math.divide_no_nan(
        tf.cast(A, tf.float32), 
        tf.cast(B, tf.float32)))

### Experimental Results

In this section, the base network architecture is implemented using standard VGG-style blocks with `large` model topology and `exponential` feature map growth. Subsequently, we perform the following hyperparameter grid search:

* Binary cross-entropy (BCE): with class weights ranging from `[1, 2, 5, 10]`
* Focal loss: with *gamma* values ranging from `[1.5, 2.0, 3.0]`
* Soft Dice score loss: paired with each permutation of BCE or focal loss, as above

See __[link](https://docs.google.com/viewer?url=https://raw.githubusercontent.com/peterchang77/dl_tutor/master/workshops/model-design/pdfs/02_loss.pdf)__ for summary of results.

### Recommendation

In general, we recommend using a combination of focal loss with *gamma* value of 2.0 and soft Dice score loss for robust performance across various tasks. However, given the importance of loss function choice for optimal performance, a hyperparameter sweep of loss formulations often yields favorable compute-to-accuracy ratio gains.

## Block-level (Macro-)Architecture

While standard VGG-style blocks are both simple and high-performing across various medical imaging tasks, newer modeling strategies may yield marginal gains. Perhaps the most popular modification is the use of ResNet-style residual connections. More recently, the ConvNeXt-style block is a powerful motif that compiles various state-of-the-art advances in a single high-performing heuristic.

**ResNet**: The use of the residual connections improves overall gradient stability during the optimization process, allowing for faster and more robust model convergence as well as the potential for very deep network topologies spanning hundreds of layers. In addition, this strategy is simple to add alongside the standard VGG-style block and is thus a popular motif across various tasks.
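A scalar toy model gives some intuition for the gradient stability argument. A residual layer `y = x + f(x)` has local gradient `1 + f'(x)`, so the chain-rule product across many layers cannot collapse the way a plain product of small gradients does. The layer count and the local gradient value below are arbitrary assumptions, and real networks are of course not scalar.

```python
def end_to_end_gradient(local_gradients, residual=False):
    """Chain-rule product of per-layer gradients for a stack of scalar layers."""
    total = 1.0
    for g in local_gradients:
        # --- A residual layer y = x + f(x) has gradient 1 + f'(x)
        total *= (1.0 + g) if residual else g
    return total

# --- Assumed toy setting: 50 layers, each with a small local gradient f'(x) = 0.01
local_gradients = [0.01] * 50

plain = end_to_end_gradient(local_gradients)
resid = end_to_end_gradient(local_gradients, residual=True)

print('plain stack    : {:.3e}'.format(plain))
print('residual stack : {:.3f}'.format(resid))
```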

**ConvNeXt**: By compiling and testing various model improvements introduced in recent years, the authors of ConvNeXt highlight key high-performing motifs. Perhaps the most important modification is factorization of the standard convolution into independent spatial (depthwise convolution) and feature (pointwise convolution) components, significantly reducing the overall number of model weights. In addition, empiric results inspired by advances in Transformer architecture alternate the use of GeLU nonlinearities with layer normalization. 
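The weight savings from factorization can be estimated with simple arithmetic. The example below assumes a 3x3x3 kernel with 64 input and 64 output channels and ignores biases; the function names are hypothetical.

```python
def conv_weights(kernel_size, c_in, c_out, dims=3):
    """Weight count (ignoring biases) of a standard convolution."""
    return kernel_size ** dims * c_in * c_out

def separable_weights(kernel_size, c_in, c_out, dims=3):
    """Weight count (ignoring biases) of a depthwise conv followed by a pointwise conv."""
    depthwise = kernel_size ** dims * c_in
    pointwise = c_in * c_out
    return depthwise + pointwise

# --- Assumed example: 3x3x3 kernel, 64 -> 64 channels
standard = conv_weights(3, 64, 64)
factored = separable_weights(3, 64, 64)

print('standard  : {:,}'.format(standard))
print('factored  : {:,}'.format(factored))
print('reduction : {:.1f}x'.format(standard / factored))
```

Note that the ConvNeXt-style block in this tutorial also expands the channel count inside the block (the `factor` argument), so the net savings for a full block are smaller than this single-layer comparison suggests.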

### Code Implementation

The following code implements ResNet-style and ConvNeXt-style blocks. Notice that by design, each block can be used interchangeably with the VGG-style blocks defined earlier in this tutorial.

In [None]:
def create_blocks_resnet(**kwargs):

    conv1_, conv2_, tran2_ = create_blocks_vgg(**kwargs)
    conv, tran, norm, _, _ = create_layers(kernel_size=2, **kwargs)

    conv1 = lambda filters, x : conv1_(filters, conv1_(filters, x)) + x
    conv2 = lambda filters, x : conv1_(filters, conv2_(filters, x)) + conv(x=x, filters=filters, strides=2)
    tran2 = lambda filters, x : conv1_(filters, tran2_(filters, x)) + tran(x=x, filters=filters, strides=2)

    return conv1, conv2, tran2

In [None]:
def create_blocks_convnext(**kwargs):

    conv, tran, norm, _, _ = create_layers(kernel_size=2, **kwargs)

    conv1 = lambda filters, x : create_blocks_convnext_conv1(x=x, filters=filters) + x
    conv2 = lambda filters, x : conv(norm(x), filters, strides=2)
    tran2 = lambda filters, x : tran(norm(x), filters, strides=2)

    return conv1, conv2, tran2

In [None]:
def create_blocks_convnext_conv1(x, filters, kernel_size=3, factor=2, name=None, configs=None, **kwargs):
    """
    Method to create ConvNeXt-style block

    NOTE: compared to original implementation, an extra Layer Norm is added at beginning

    """
    # --- Define layers
    conv = lambda x, **kwargs : layers.Conv3D(
        kernel_initializer='he_normal', padding='same', **kwargs)(x)

    norm = lambda x : layers.LayerNormalization()(x)
    gelu = lambda x : layers.Activation('gelu')(x)

    # --- Apply ConvNext block
    l0 = conv(x=norm(x), filters=x.shape[-1], groups=x.shape[-1], kernel_size=kernel_size)
    l1 = conv(x=norm(l0), filters=int(filters * factor), kernel_size=1)
    l2 = conv(x=gelu(l1), filters=filters, kernel_size=1, name=name)

    return l2 

### Experimental Results

In this section, the base network architecture is implemented using standard VGG-style blocks with `large` model topology and `exponential` feature map growth. This base model comprising standard VGG-style blocks is compared to models implemented with ResNet-style blocks or ConvNeXt-style blocks, while holding all other design choices and network topology considerations constant. Models are trained using weighted binary cross-entropy (class weight of 10) or focal loss (gamma of 2.0), either without or with a soft Dice score loss component.

See __[link](https://docs.google.com/viewer?url=https://raw.githubusercontent.com/peterchang77/dl_tutor/master/workshops/model-design/pdfs/03_macro.pdf)__ for summary of results.

### Recommendation

Both ResNet and ConvNeXt models yield marginal but robust gains over the base VGG architecture, with a trend towards overall more consistent improvements by the ConvNeXt strategy. In the absence of a rigorous grid search, we recommend training a baseline VGG model as well as a ConvNeXt model using previously identified optimal configurations for other hyperparameters to assess for incremental performance improvement.

# Other Design Considerations

Though the following design considerations have a theoretical range of valid operating choices, base assumptions and other predefined heuristics can often achieve satisfactory performance without the need for rigorous hyperparameter search.  

**Data preprocessing**: In general, increased data context leads to better performance (e.g., larger field-of-view, higher resolution, 3D instead of 2D). We therefore recommend using the largest allowable model input within the constraints of available GPU memory and other model design choices noted above. As needed, consider a two-step cascaded approach to combine global context with higher resolution local features.

**Optimization**: In general, deep supervision with auxiliary loss objectives at each available feature map resolution will improve performance across all tasks. Additionally, robust performance can often be achieved with the Adam optimizer coupled with a learning rate decay, which together supports a wide range of independently optimized per-weight updates without the need for carefully calibrating the learning rate.

In [None]:
lr_callback = callbacks.LearningRateScheduler(lambda epoch, lr : lr * 0.99)
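To see what this schedule does over a full run, the following sketch simulates the multiplicative decay without Keras. It assumes the initial learning rate of `1e-3` and the 200 epochs used later in this tutorial, and that the scheduler is applied at the start of every epoch (including the first); the function name is hypothetical.

```python
def simulate_schedule(initial_lr=1e-3, decay=0.99, epochs=200):
    """Learning rate in effect during each epoch when scaled at every epoch start."""
    lrs, lr = [], initial_lr
    for _ in range(epochs):
        lr = lr * decay
        lrs.append(lr)
    return lrs

lrs = simulate_schedule()

print('epoch   1 : {:.2e}'.format(lrs[0]))
print('epoch 100 : {:.2e}'.format(lrs[99]))
print('epoch 200 : {:.2e}'.format(lrs[-1]))
```

Under these assumptions the learning rate falls to roughly 13% of its initial value by the final epoch.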

# Model Training

Use the following modular code template to create models spanning the various hyperparameter configurations described in this tutorial:

In [None]:
def create_model(shape=(96, 96, 96, 1), style='vgg', losses=['bce'], class_weight=1.0, filters=[16, 32, 48, 64, 80, 96], **kwargs):

    inputs = {
        'x': Input(shape=shape, dtype='float32'),
        'y': Input(shape=shape, dtype='float32')}

    # --- Create layers
    conv, tran, norm, relu, concat = create_layers(**kwargs)

    # --- Create blocks 
    conv1, conv2, tran2 = {
        'vgg': create_blocks_vgg,
        'resnet': create_blocks_resnet,
        'convnext': create_blocks_convnext}[style](**kwargs)

    # --- Create first layer conv
    if style in ['resnet', 'convnext']:
        x = conv(inputs['x'], filters=filters[0], strides=1)
    else:
        x = inputs['x']

    # --- Create encoder
    l0 = conv1(filters[0], x)
    l1 = conv1(filters[1], conv2(filters[1], l0))
    l2 = conv1(filters[2], conv2(filters[2], l1))
    l3 = conv1(filters[3], conv2(filters[3], l2))
    l4 = conv1(filters[4], conv2(filters[4], l3))
    l5 = conv1(filters[5], conv2(filters[5], l4))

    # --- Create decoder
    t4 = conv1(filters[4], tran2(filters[4], l5) + l4)
    t3 = conv1(filters[3], tran2(filters[3], t4) + l3)
    t2 = conv1(filters[2], tran2(filters[2], t3) + l2)
    t1 = conv1(filters[1], tran2(filters[1], t2) + l1)
    t0 = conv1(filters[0], tran2(filters[0], t1) + l0)

    # --- Create logits
    logits = {}
    if style == 'convnext':
        logits['t0'] = conv(norm(t0), filters=1, strides=1)
        logits['t1'] = conv(norm(t1), filters=1, strides=1)
        logits['t2'] = conv(norm(t2), filters=1, strides=1)
        logits['t3'] = conv(norm(t3), filters=1, strides=1)
        logits['t4'] = conv(norm(t4), filters=1, strides=1)
        logits['l5'] = conv(norm(l5), filters=1, strides=1)
    else:
        logits['t0'] = conv(t0, filters=1, strides=1)
        logits['t1'] = conv(t1, filters=1, strides=1)
        logits['t2'] = conv(t2, filters=1, strides=1)
        logits['t3'] = conv(t3, filters=1, strides=1)
        logits['t4'] = conv(t4, filters=1, strides=1)
        logits['l5'] = conv(l5, filters=1, strides=1)

    # --- Create model
    model = Model(inputs=inputs, outputs=logits)

    # --- Create deep supervision: pool op
    mpool = lambda x : layers.MaxPooling3D(pool_size=2, strides=2, padding='valid')(x)

    # --- Create deep supervision: subsampled ground-truth mask
    y_true = {}
    y_true['t0'] = tf.cast(inputs['y'] == 2, tf.float32)
    y_true['t1'] = mpool(y_true['t0'])
    y_true['t2'] = mpool(y_true['t1'])
    y_true['t3'] = mpool(y_true['t2'])
    y_true['t4'] = mpool(y_true['t3'])
    y_true['l5'] = mpool(y_true['t4'])

    # --- Create deep supervision: subsampled weight mask
    w_samp = {}
    w_samp['t0'] = y_true['t0'] * (class_weight - 1) + 1
    w_samp['t1'] = mpool(w_samp['t0'])
    w_samp['t2'] = mpool(w_samp['t1'])
    w_samp['t3'] = mpool(w_samp['t2'])
    w_samp['t4'] = mpool(w_samp['t3'])
    w_samp['l5'] = mpool(w_samp['t4'])

    # --- Create losses
    loss_functions = {
        'bce': create_bce,
        'foc': create_foc,
        'sft': create_sft}

    for k in logits:
        for l in losses:

            loss = loss_functions[l](
                y_true=y_true[k],
                y_pred=logits[k],
                w_samp=w_samp[k])

            model.add_loss(loss)
            model.add_metric(loss, name='{}-{}'.format(k, l))

    # --- Create DSC
    dsc = create_dsc(y_true=y_true['t0'], y_pred=logits['t0'])
    model.add_metric(dsc, name='dsc')

    # --- Compile model
    model.compile(
        optimizer=optimizers.Adam(learning_rate=1e-3))

    return model
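The subsampled ground-truth masks used for deep supervision in `create_model` rely on repeated 2x2x2 max pooling. The NumPy sketch below (an illustration with a hypothetical helper and a toy mask, not the Keras layer itself) suggests why max pooling is a reasonable choice here: even a single foreground voxel survives every downsampling step.

```python
import numpy as np

def max_pool_3d(mask):
    """2x2x2 max pooling with stride 2 on a 3D array with even dimensions."""
    d, h, w = mask.shape
    blocks = mask.reshape(d // 2, 2, h // 2, 2, w // 2, 2)
    return blocks.max(axis=(1, 3, 5))

# --- Toy ground-truth mask: a single foreground voxel in a 96^3 volume
mask = np.zeros((96, 96, 96), dtype=np.float32)
mask[40, 50, 60] = 1

# --- Five pooling steps: 96 -> 48 -> 24 -> 12 -> 6 -> 3
for _ in range(5):
    mask = max_pool_3d(mask)
    print('shape = {} | foreground voxels = {}'.format(mask.shape, int(mask.sum())))
```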

Here is a representative call to create the base VGG model with a combination of focal and Dice score loss:

In [None]:
model = create_model(style='vgg', losses=['foc', 'sft'])

To check the properties of the created model object, use the `model.summary()` method:

In [None]:
# --- Print model summary
model.summary(120)

### In-Memory Data

The following line of code will load all training data into RAM. This strategy can be effective for increasing speed of training for small to medium-sized datasets.

In [None]:
# --- Load data into memory
client.load_data_in_memory()

### TensorBoard

Optionally, to use TensorBoard, create the necessary Keras callbacks:

In [None]:
tensorboard_callback = callbacks.TensorBoard('./logs')

### Training

Once the model has been compiled and the data prepared (via a generator), training can be invoked using the `model.fit(...)` method. Ensure that both the training and validation data generators are used. In this particular example, we are defining arbitrary epochs of 100 steps each. Training will proceed for 200 epochs in total. Validation statistics will be assessed every fifth epoch. Tune these arguments as needed.

In [None]:
model.fit(
    x=gen_train, 
    steps_per_epoch=100, 
    epochs=200,
    validation_data=gen_valid,
    validation_steps=100,
    validation_freq=5,
    callbacks=[tensorboard_callback, lr_callback]
)
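For planning purposes, the overall training budget implied by these arguments can be tallied as follows. The values are taken from the `datasets.prepare(...)` and `model.fit(...)` calls above; the variable names are illustrative.

```python
# --- Values taken from the `datasets.prepare(...)` and `model.fit(...)` calls in this tutorial
batch_size, steps_per_epoch, epochs, validation_freq = 2, 100, 200, 5

total_steps = steps_per_epoch * epochs
samples_seen = total_steps * batch_size
validation_runs = epochs // validation_freq

print('gradient updates : {:,}'.format(total_steps))
print('samples drawn    : {:,}'.format(samples_seen))
print('validation runs  : {}'.format(validation_runs))
```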

### Launching TensorBoard

After running several iterations, start TensorBoard using the following cells. After TensorBoard has registered the first several checkpoints, subsequent data will be updated automatically (asynchronously) and model training can be resumed:

In [None]:
%load_ext tensorboard
%tensorboard --logdir logs

## Saving and Loading a Model

After a model has been successfully trained, it can be saved and/or loaded by simply using the `model.save()` and `models.load_model()` methods. 

In [None]:
# --- Serialize a model
model.save('./model.hdf5')

In [None]:
# --- Load a serialized model
del model
model = models.load_model('./model.hdf5', compile=False)