# Overview

In this tutorial we will explore several strategies to address class imbalance as well as how to tune a network with weighted loss functions (e.g. class weights and masks). Strategies discussed include:

* stratified sampling
* pixel-level class weights
* pixel-level masked loss

Ultimately, the goal of this tutorial (and class assignment) is to create a high-sensitivity detector for contrast-enhancing tumor on brain MRI.

This tutorial is part of the class **Introduction to Deep Learning for Medical Imaging** at University of California Irvine (CS190); more information can be found at: https://github.com/peterchang77/dl_tutor/tree/master/cs190.

# Google Colab

The following lines of code will configure your Google Colab environment for this tutorial.

### Enable GPU runtime

Use the following instructions to switch the default Colab instance into a GPU-enabled runtime:

```
Runtime > Change runtime type > Hardware accelerator > GPU
```

### Select Tensorflow library version

This tutorial will use the Tensorflow 2.1 library. Use the following line of code to select and download this specific version:

In [None]:
# --- Download Tensorflow 2.x (only in Google Colab)
%pip install tensorflow-gpu==2.1

# Environment

### Jarvis library

In this notebook we will use Jarvis, a custom Python package to facilitate data science and deep learning for healthcare. Among other things, this library will be used for low-level data management, stratification and visualization of high-dimensional medical data.

In [None]:
# --- Install jarvis (only in Google Colab or local runtime)
%pip install jarvis-md

### Imports

Use the following lines to import any additional needed libraries:

In [None]:
import numpy as np, pandas as pd
from tensorflow import losses, optimizers
from tensorflow.keras import Input, Model, models, layers, metrics
from jarvis.train import datasets, custom
from jarvis.train.client import Client
from jarvis.utils.general import overload, tools as jtools
from jarvis.utils.display import imshow

# Data

The data used in this tutorial will consist of brain tumor MRI exams derived from the MICCAI Brain Tumor Segmentation Challenge (BRaTS). More information about the BRaTS Challenge can be found here: http://braintumorsegmentation.org/. Each single 2D slice will consist of one of four different sequences (T2, FLAIR, T1 pre-contrast and T1 post-contrast). In this exercise, we will use this dataset to derive a model for slice-by-slice tumor segmentation.

The custom `datasets.download(...)` method can be used to download a local copy of the dataset. By default the dataset will be archived at `/data/raw/mr_brats_2020`; as needed an alternate location may be specified using `datasets.download(name=..., path=...)`.

In [None]:
# --- Download dataset
datasets.download(name='mr/brats-2020-mip')

Once downloaded, the `datasets.prepare(...)` method can be used to generate the required python Generators to iterate through the dataset, as well as a `client` object for any needed advanced functionality.

To specify the correct Generator template file, pass a designated `keyword` string. In this tutorial, we will be using brain MRI volumes that have been preprocessed using a *mean intensity projection* (MIP) algorithm to subsample the original 155-slice inputs to 40-50 slices, facilitating ease of algorithm training within the Google Colab platform. In addition we will be performing voxel-level tumor prediction (i.e., a prediction for every single voxel in the 3D volume). To select the correct Client template for this task, use the keyword string `mip*vox`.

In [None]:
# --- Prepare generators
gen_train, gen_valid, client = datasets.prepare(name='mr/brats-2020-mip', keyword='mip*vox')

As before, each iteration yields two variables, `xs` and `ys`, each representing a dictionary of model input(s) and output(s). In the current example, there is just a single input and output. Let us examine the generator data:

In [None]:
# --- Yield one example
xs, ys = next(gen_train)

# --- Print dict keys
print('xs keys: {}'.format(xs.keys()))
print('ys keys: {}'.format(ys.keys()))

In [None]:
# --- Print data shape
print('xs shape: {}'.format(xs['dat'].shape))
print('ys shape: {}'.format(ys['tumor'].shape))

### Tumor masks

The ground-truth labels are four-class masks of the same matrix shape as the model input:

In [None]:
print(ys['tumor'][0].shape)

Use the `imshow(...)` method to visualize the ground-truth tumor mask labels:

In [None]:
# --- Show tumor masks overlaid on original data
imshow(xs['dat'], ys['tumor'])

# --- Show tumor masks isolated
imshow(ys['tumor'])

### Enhancing tumor

In this tutorial, we will examine a challenging class imbalanced problem of segmenting enhancing tumor components. As a ratio of the overall tumor volume, enhancing tumor comprises a minority of foreground voxels (and an even smaller proportion of overall voxels in the entire image). Enhancing tumor is labeled as class `3` in this cohort.

In [None]:
# --- Print fraction of ground-truth voxels with enhancing tumor
print(np.sum(ys['tumor'] == 3) / ys['tumor'].size)
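
To see the imbalance across all four classes rather than just class `3`, the per-class voxel fractions can be tallied with `np.bincount`. The sketch below is only illustrative: it uses a small synthetic label array (an assumption, standing in for `ys['tumor']`) so that it runs without the dataset.

```python
import numpy as np

# Synthetic stand-in for ys['tumor']: a 4-class label volume (hypothetical values)
rng = np.random.default_rng(0)
labels = rng.choice(4, size=(1, 8, 8, 1), p=[0.9, 0.04, 0.04, 0.02]).astype('uint8')

# Count voxels per class; minlength guarantees one bin for each of the 4 classes
counts = np.bincount(labels.ravel(), minlength=4)
fractions = counts / labels.size

for c, f in enumerate(fractions):
    print('class {}: {:.1%}'.format(c, f))
```

With the real data, replacing `labels` by `ys['tumor']` should give the corresponding breakdown for one batch.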

### Stratified Sampling

The first strategy we explore to address class imbalance is stratified sampling, i.e. we will increase the sampling frequency of slices with enhancing tumor to approximately 50%. More precisely, we will use the following sampling distribution:

* class 0: 30% (background)
* class 1: 10% (tumor necrosis)
* class 2: 10% (tumor edema)
* class 3: 50% (tumor enhancement)

To do so, we pass the appropriate `sampling` specifications to the `configs` variable when creating the data generators and the `Client()` object:

In [None]:
# --- Configs dict
configs = {
    'batch': {'size': 8},
    'sampling': {
        'lbl-mip-00': 0.3,
        'lbl-mip-01': 0.1,
        'lbl-mip-02': 0.1,
        'lbl-mip-03': 0.5}}

# --- Prepare generators
gen_train, gen_valid, client = datasets.prepare(name='mr/brats-2020-mip', keyword='mip*vox', configs=configs)
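
How Jarvis implements stratified sampling internally is not shown in this tutorial. As a rough mental model only, the idea can be sketched in NumPy: first draw a class according to the target rates, then draw a random slice belonging to that class. The cohort, the per-slice class tags and the `sample_batch` helper below are all hypothetical.

```python
import numpy as np

# Hypothetical cohort: each slice is tagged with a single class (an assumption)
rng = np.random.default_rng(0)
slice_class = rng.choice(4, size=10000, p=[0.70, 0.12, 0.13, 0.05])

# Target sampling rates, mirroring the `sampling` entries in configs
target = {0: 0.3, 1: 0.1, 2: 0.1, 3: 0.5}

# Pre-compute the pool of slice indices available for each class
pools = {c: np.flatnonzero(slice_class == c) for c in target}

def sample_batch(size=8):
    """Pick a class by target rate, then a random slice index from that class."""
    classes = rng.choice(list(target.keys()), size=size, p=list(target.values()))
    return [int(rng.choice(pools[int(c)])) for c in classes]

# Empirical check: class frequencies over many sampled batches
drawn = np.concatenate([slice_class[sample_batch()] for _ in range(500)])
print(np.bincount(drawn, minlength=4) / drawn.size)
```

Although class `3` makes up only about 5% of this synthetic cohort, it accounts for roughly half of the sampled slices.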

### 3D operations

Note that the model input shapes for this exercise will be provided as 3D tensors. Even if your current model does not require 3D data (as in this current tutorial), all 2D tensors can be represented by a 3D tensor with a z-axis shape of 1. In addition, designing all models with this configuration (e.g. 3D operations) ensures that minimal code changes are needed when testing various 2D and 3D network architectures. 
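
As a small illustration of the convention described above, a 2D slice can be promoted to a 3D tensor by adding a z-axis of length 1. The array below is a hypothetical placeholder, not data from the generator.

```python
import numpy as np

# Hypothetical single-channel 2D slice of size 240 x 240
slice_2d = np.zeros((240, 240, 1), dtype='float32')

# Add a leading z-axis of length 1: (z, y, x, channels)
slice_3d = np.expand_dims(slice_2d, axis=0)
print(slice_3d.shape)  # (1, 240, 240, 1)

# Dropping the z-axis recovers the original 2D shape
restored = np.squeeze(slice_3d, axis=0)
print(restored.shape)  # (240, 240, 1)
```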

# Weighted Loss

To implement custom loss weights (and/or masks), a generic `msk` array will be used to perform a point-wise multiplication against the final pixel-by-pixel loss. For locations where the loss should be **weighted**, use a constant value > 1. For locations where the loss should be ignored (**masked**), use a constant value of 0.
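
The effect of the point-wise multiplication can be checked by hand on a tiny example. The NumPy sketch below is not the Jarvis or Keras implementation; the helper name `weighted_sce` and the plain mean over all pixels are assumptions made for illustration.

```python
import numpy as np

def weighted_sce(logits, labels, msk):
    """Sparse softmax cross-entropy with point-wise weights (NumPy sketch)."""
    # Numerically stable log-softmax over the last (class) axis
    z = logits - logits.max(axis=-1, keepdims=True)
    log_p = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

    # Per-pixel loss: negative log-probability of the true class
    loss = -np.take_along_axis(log_p, labels[..., None], axis=-1)[..., 0]

    # Point-wise multiplication: 0 ignores a pixel, > 1 up-weights it
    return (loss * msk).mean()

logits = np.zeros((2, 2, 2))           # uniform two-class predictions
labels = np.array([[0, 1], [1, 0]])    # ground-truth class per pixel
ones = np.ones((2, 2))

print(weighted_sce(logits, labels, ones))                            # log(2)
print(weighted_sce(logits, labels, np.where(labels == 1, 5., 1.)))   # 3 * log(2)
print(weighted_sce(logits, labels, np.where(labels == 1, 1., 0.)))   # 0.5 * log(2)
```

Because every pixel here has the same unweighted loss of log(2), the three results differ only through the mean of the `msk` values (1, 3 and 0.5 respectively).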

### Creating custom loss weights and masks

The `msk` array for weighted loss is considered a model input in the Tensorflow 2 / Keras API. Thus, in this implementation, the `xs` variable yielded by the Jarvis Python generator should contain two separate arrays:

```python
xs, ys = next(gen_train)

xs = {
    'dat': ... (as usual) ...,
    'msk': ... weighted loss modifier ...}
```

Additionally the `ys` dictionary currently contains a four-class segmentation label, whereas the target enhancing tumor for the current task is represented as class `3`.

Thus two modifications are needed to the current Python generators:

* modify `xs` to yield an additional tensor `msk` equal in size to the output (and input) tensor
* modify `ys` to yield a binarized tumor segmentation mask (equal to class `3`)

To modify the existing Python generators, use the nested generator strategy:

```python
def CustomGenerator(G):
    
    for xs, ys in G:
        
        # --- Add customization code here
        
        yield xs, ys
        
# --- Create custom generators
gen_train_custom = CustomGenerator(gen_train)
gen_valid_custom = CustomGenerator(gen_valid)
```

### Implementation

There are three different custom class weight `msk` tensors that will be explored in this tutorial:

**Variant 1**: Use class weights to increase the penalty for enhancing tumor voxels

In [None]:
def CustomGenerator(G):
    
    for xs, ys in G:
        
        # --- Define msk
        xs['msk'] = np.ones(ys['tumor'].shape, dtype='float32')
        xs['msk'][ys['tumor'] == 3] = 5.0
        
        # --- Binarize ys
        ys['tumor'] = ys['tumor'] == 3
        ys['tumor'] = ys['tumor'].astype('uint8')
        
        yield xs, ys

**Variant 2**: Use a masked loss function to ignore the contribution of non-tumor voxels

In [None]:
def CustomGenerator(G):
    
    for xs, ys in G:
        
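        # --- Define msk (one possible implementation: 0 outside of tumor, 1 inside)
        xs['msk'] = (ys['tumor'] > 0).astype('float32')
        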
        # --- Binarize ys
        ys['tumor'] = ys['tumor'] == 3
        ys['tumor'] = ys['tumor'].astype('uint8')
        
        yield xs, ys

**Variant 3**: Use a combination of both class weights and masked losses

In [None]:
def CustomGenerator(G):
    
    for xs, ys in G:
        
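        # --- Define msk (one possible implementation: 0 outside of tumor, 1 inside, 5 for enhancing tumor)
        xs['msk'] = (ys['tumor'] > 0).astype('float32')
        xs['msk'][ys['tumor'] == 3] = 5.0
        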
        # --- Binarize ys
        ys['tumor'] = ys['tumor'] == 3
        ys['tumor'] = ys['tumor'].astype('uint8')
        
        yield xs, ys

### Visualization

Use the following block to create the custom generators and visualize the custom `msk`:

In [None]:
# --- Create custom generators
gen_train_custom = CustomGenerator(gen_train)
gen_valid_custom = CustomGenerator(gen_valid)

xs, ys = next(gen_train_custom)
imshow(xs['dat'], xs['msk'])

# Model

To localize enhancing tumor on brain MRI, we will implement a standard contracting-expanding network (e.g. U-Net). In the assignment, feel free to try various architecture permutations.

### Create Inputs

As before, use the `client.get_inputs(...)` method to create model inputs:

In [None]:
# --- Create inputs
inputs = client.get_inputs(Input)

In addition, be sure to include the additional new entry in `xs` representing the custom class weights `msk` tensor:

In [None]:
inputs['msk'] = Input(shape=(None, 240, 240, 1), dtype='float32', name='msk')

### Create model

In [None]:
# --- Define kwargs dictionary
kwargs = {
    'kernel_size': (1, 3, 3),
    'padding': 'same'}

# --- Define lambda functions
conv = lambda x, filters, strides : layers.Conv3D(filters=filters, strides=strides, **kwargs)(x)
norm = lambda x : layers.BatchNormalization()(x)
relu = lambda x : layers.ReLU()(x)
tran = lambda x, filters, strides : layers.Conv3DTranspose(filters=filters, strides=strides, **kwargs)(x)

concat = lambda a, b : layers.Concatenate()([a, b])

# --- Define stride-1, stride-2 blocks
conv1 = lambda filters, x : relu(norm(conv(x, filters, strides=1)))
conv2 = lambda filters, x : relu(norm(conv(x, filters, strides=(1, 2, 2))))
tran2 = lambda filters, x : relu(norm(tran(x, filters, strides=(1, 2, 2))))

In [None]:
# --- Define contracting layers
l1 = conv1(8, inputs['dat'])
l2 = conv1(16, conv2(16, l1))
l3 = conv1(32, conv2(32, l2))
l4 = conv1(48, conv2(48, l3))
l5 = conv1(64, conv2(64, l4))

# --- Define expanding layers
l6  = tran2(48, l5)
l7  = tran2(32, conv1(48, concat(l4, l6)))
l8  = tran2(16, conv1(32, concat(l3, l7)))
l9  = tran2(8,  conv1(16, concat(l2, l8)))
l10 = conv1(8,  l9)

# --- Create logits
logits = {}
logits['tumor'] = layers.Conv3D(filters=2, name='tumor', **kwargs)(l10)

# --- Create model
model = Model(inputs=inputs, outputs=logits) 
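
The skip connections (`concat`) only work if the contracting and expanding feature maps have matching in-plane shapes. The short sketch below traces the in-plane size through the four stride-2 blocks, assuming a 240 x 240 input (as used for the `msk` input above) and `padding='same'`.

```python
# Trace in-plane size through the four stride-2 contracting blocks
size = 240
sizes = [size]
for _ in range(4):
    # With padding='same', a stride-2 convolution yields ceil(size / 2)
    size = -(-size // 2)
    sizes.append(size)

print(sizes)  # [240, 120, 60, 30, 15]

# Each stride-2 transpose convolution doubles the size on the way back up
up = [sizes[-1] * 2 ** i for i in range(5)]
print(up)     # [15, 30, 60, 120, 240]
```

Since 240 is divisible by 16, every upsampled map lines up exactly with its contracting counterpart. An input size that is not divisible by 16 would round up on the way down and no longer match after doubling.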

### Compile model

To compile this model, several custom `loss` and `metrics` objects will need to be defined.

#### Loss

As in prior tutorials, a standard (sparse) softmax cross-entropy loss will be used to optimize the segmentation model. A custom softmax cross entropy loss function is available as part of the `jarvis.train.custom` module to implement the necessary modifications for weighted and/or masked loss functions. To use this object, simply pass the `inputs['msk']` array as the first argument into the loss function initializer.

For many common loss functions, the low-level Tensorflow or Keras loss object does support weighted loss calculations; however, these are not available by default through the standard `model.fit(...)` API. To accommodate this, Python closures can be used to create a wrapper around the default loss function calculation:

```python
def sce(weights, scale=1.0):

    loss = losses.SparseCategoricalCrossentropy(from_logits=True)

    def sce(y_true, y_pred):

        return loss(y_true=y_true, y_pred=y_pred, sample_weight=weights) * scale

    return sce 
```

In [None]:
# --- Create custom weighted loss
loss = {'tumor': custom.sce(inputs['msk'])}

#### Metrics

For class imbalanced datasets, Dice score may be a limited evaluation metric. Thus we will also track foreground sensitivity as a second measure of overall model performance.

A series of custom metrics including Dice score and sensitivity calculation are available as part of the `jarvis.train.custom` module to implement weighted and/or masked metrics. To use this object, simply pass the `inputs['msk']` array as the first argument into the metrics initializer. Since we are using two separate metrics in this example, pass both as part of a Python list.

In [None]:
# --- Create metrics
metrics = custom.dsc(weights=inputs['msk'])
metrics += [custom.softmax_ce_sens(weights=inputs['msk'])]

metrics = {'tumor': metrics}

To compile the final model:

In [None]:
# --- Compile the model
model.compile(
    optimizer=optimizers.Adam(learning_rate=2e-4),
    loss=loss,
    metrics=metrics,
    experimental_run_tf_function=False)

# Training

### In-memory data

For moderately sized datasets that are too large to fit into the hard-drive cache but small enough to fit into RAM, it is often a good idea to first load all training data into memory to speed up training. The `client` can be used for this purpose as follows:

In [None]:
# --- Load data into memory for faster training
client.load_data_in_memory()

*Important*: For the current dataset, which is relatively large, your Google Colab instance may not be able to load all data into memory. If so, just continue on to training below.

### Tensorboard

To use Tensorboard, create the necessary Keras callbacks:

In [None]:
from tensorflow.keras import callbacks  
tensorboard_callback = callbacks.TensorBoard('./logs')

Now, let us train the model:

In [None]:
# --- Train model
model.fit(
    x=gen_train_custom, 
    steps_per_epoch=100, 
    epochs=10,
    validation_data=gen_valid_custom,
    validation_steps=100,
    validation_freq=4,
    use_multiprocessing=True,
    callbacks=[tensorboard_callback])

### Launching Tensorboard

After running several iterations, start Tensorboard using the following cells. After Tensorboard has registered the first several checkpoints, subsequent data will be updated automatically (asynchronously) and model training can be resumed:

In [None]:
%load_ext tensorboard
%tensorboard --logdir logs

# Evaluation

To test the trained model, the following steps are required:

* load data
* use `model.predict(...)` to obtain logit scores
* compare prediction with ground-truth (Dice score, sensitivity)
* serialize in Pandas DataFrame

Recall that the generator used to train the model simply iterates through the dataset randomly. For model evaluation, the cohort must instead be loaded manually in an orderly way. For this tutorial, we will create new **test mode** data generators, which will simply load each example individually once for testing. 

In [None]:
# --- Create validation generator
test_train, test_valid = client.create_generators(test=True)
test_train = CustomGenerator(test_train)
test_valid = CustomGenerator(test_valid)

To run prediction on a single (first) example from the generator:

In [None]:
# --- Run a single prediction
x, y = next(test_valid)
logits = model.predict(x)

Let us visualize the predicted results. Recall that the `np.argmax(...)` function can be used to convert raw logit scores to predictions:

In [None]:
# --- Create prediction
pred = np.argmax(logits[0], axis=-1)

# --- Show
imshow(x['dat'][0, ..., 0], pred)

**Checkpoint** What is the problem with this mask?

Recall that during training, the algorithm is never penalized, regardless of class, for predictions *outside of the mask* (i.e. where `msk` == 0). Thus, to generate the final prediction, one needs to similarly remove the masked values of the prediction:

In [None]:
# --- Clean up pred using mask
pred[x['msk'][0, ..., 0] == 0] = 0

# --- Show
imshow(x['dat'][0, ..., 0], pred, radius=3)

That is much better. Let us look at the ground-truth:

In [None]:
# --- Show
imshow(x['dat'][0], y['tumor'][0], radius=3)
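
The argmax and mask clean-up steps above can also be checked on a small synthetic array, independent of the trained model. All values below are hypothetical and chosen only to make the result easy to verify by eye.

```python
import numpy as np

# Synthetic two-class logits for a 1 x 2 x 3 volume (hypothetical values)
logits = np.array([[[[0.2, 1.5], [2.0, -1.0], [0.1, 0.3]],
                    [[1.0, 0.0], [-0.5, 0.5], [0.0, 2.0]]]])

# Matching mask: 0 marks voxels that were ignored during training
msk = np.array([[[1, 1, 0],
                 [1, 0, 1]]])

# Class with the highest logit at each voxel
pred = np.argmax(logits, axis=-1)
print(pred[0].tolist())  # [[1, 0, 1], [0, 1, 1]]

# Zero out predictions where the loss was masked
pred[msk == 0] = 0
print(pred[0].tolist())  # [[1, 0, 0], [0, 0, 1]]
```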

### Testing for sensitivity

In addition to evaluating overall model Dice score, the goal of this exercise is to create a high-sensitivity model. Recall that sensitivity is defined as the number of true positive (TP) predictions divided by the number of all positive ground-truth voxels, TP / (TP + FN) (i.e. the proportion of positive findings that are correctly identified).

In [None]:
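def calculate_sens(pred, true):
    """
    Method to calculate sensitivity from pred and true masks
    
    """
    # --- One possible implementation (binary masks, foreground == 1)
    tp = np.sum((pred == 1) & (true == 1))
    fn = np.sum((pred == 0) & (true == 1))
    
    # --- Sensitivity is undefined when there are no positive voxels
    return tp / (tp + fn) if (tp + fn) > 0 else np.nan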

In [None]:
# --- Calculate sens
calculate_sens(
    pred=pred,
    true=y['tumor'][0, ..., 0])

### Running evaluation

In [None]:
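# --- Create test-mode validation generator (wrapped to add msk and binarize labels)
test_train, test_valid = client.create_generators(test=True)
test_valid = CustomGenerator(test_valid)

for x, y in test_valid:
    
    # --- Run prediction on the current example
    logits = model.predict(x)
    
    # --- Create prediction
    pred = np.argmax(logits[0], axis=-1)
    
    # --- Clean up pred using mask
    pred[x['msk'][0, ..., 0] == 0] = 0
    
    # --- Calculate Dice
    dice = ...
    
    # --- Calculate sens
    sens = calculate_sens(pred=pred, true=y['tumor'][0, ..., 0])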
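
The `dice = ...` placeholder in the loop is left open for the assignment. As one hedged possibility, Dice score for a pair of binary masks can be computed directly in NumPy. The helper name `calculate_dice` is hypothetical and is not part of Jarvis.

```python
import numpy as np

def calculate_dice(pred, true):
    """Dice = 2 * |pred AND true| / (|pred| + |true|) for binary masks."""
    pred, true = pred.astype(bool), true.astype(bool)
    denom = pred.sum() + true.sum()
    # Both masks empty: Dice is undefined, so return NaN rather than guess
    return 2.0 * np.sum(pred & true) / denom if denom > 0 else np.nan

# Toy masks: 3 true voxels, 2 of them predicted, plus 1 false positive
true = np.array([1, 1, 1, 0, 0, 0])
pred = np.array([1, 1, 0, 1, 0, 0])

print(calculate_dice(pred, true))  # 2 * 2 / (3 + 3), approximately 0.667
```

Returning NaN for empty masks keeps exams without enhancing tumor from silently inflating or deflating an average; whether to exclude or score such exams is a design choice for the assignment.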

### Saving results

In [None]:
# --- Define columns
df = pd.DataFrame(...)
df['dice'] = ...
df['sens'] = ...
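
As a rough sketch of what the serialized results might look like, the per-exam scores can be collected into lists during the evaluation loop and then wrapped in a DataFrame. The numbers and index names below are placeholders for illustration only, not results from this model.

```python
import numpy as np, pandas as pd

# Placeholder per-exam values (NOT real results); NaN marks an undefined score
results = {
    'dice': [0.81, 0.64, np.nan],
    'sens': [0.92, 0.70, np.nan]}

# One row per exam; the index names are hypothetical
df = pd.DataFrame(results, index=['exam-000', 'exam-001', 'exam-002'])

# Column means skip NaN entries by default
print(df.mean())

# Serialize to CSV text (pass a file path to write to disk instead)
csv_text = df.to_csv()
```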

## Saving and Loading a Model

After a model has been successfully trained, it can be saved and/or loaded by simply using the `model.save()` and `models.load_model()` methods. 

In [None]:
# --- Serialize a model
model.save('./class_imbalance.hdf5')

In [None]:
# --- Load a serialized model
del model
model = models.load_model('./class_imbalance.hdf5', compile=False)