# Overview

In this tutorial we will explore several strategies to address class imbalance as well as how to tune a network with weighted loss functions (e.g. class weights and masks). Strategies discussed include:

* stratified sampling
* pixel-level class weights
* pixel-level masked loss

Ultimately, the goal of this tutorial (and class assignment) is to create a high-sensitivity detector for pulmonary infection (pneumonia) on chest radiographs.

This tutorial is part of the class **Introduction to Deep Learning for Medical Imaging** at University of California Irvine (CS190); more information can be found at: https://github.com/peterchang77/dl_tutor/tree/master/cs190.

# Google Colab

The following lines of code will configure your Google Colab environment for this tutorial.

### Enable GPU runtime

Use the following instructions to switch the default Colab instance into a GPU-enabled runtime:

```
Runtime > Change runtime type > Hardware accelerator > GPU
```

### Mount Google Drive

The Google Colab environment is transient and will reset after any prolonged break in activity. To retain important and/or large files between sessions, use the following lines of code to mount your personal Google Drive to this Colab instance:

In [None]:
try:
    # --- Mount gdrive to /content/drive/My Drive/
    from google.colab import drive
    drive.mount('/content/drive')
    
except: pass

Throughout this tutorial we will use the following global `MOUNT_ROOT` variable to reference a location to store long-term data. If you are using a local Jupyter server and/or wish to store your data elsewhere, please update this variable now.

In [None]:
# --- Set data directory
MOUNT_ROOT = '/content/drive/My Drive'

### Select Tensorflow library version

This tutorial will use the Tensorflow 2.1 library. Use the following line of code to select and download this specific version:

In [None]:
# --- Select Tensorflow 2.x (only in Google Colab)
%tensorflow_version 2.x
%pip install tensorflow-gpu==2.1

# Environment

### Jarvis library

In this notebook we will use Jarvis, a custom Python package to facilitate data science and deep learning for healthcare. Among other things, this library will be used for low-level data management, stratification and visualization of high-dimensional medical data.

In [None]:
# --- Install jarvis (only in Google Colab or local runtime)
%pip install jarvis-md

### Imports

Use the following lines to import any additional needed libraries:

In [None]:
import numpy as np, pandas as pd
from tensorflow import losses, optimizers
from tensorflow.keras import Input, Model, models, layers, metrics
from jarvis.train import datasets, custom
from jarvis.train.client import Client
from jarvis.utils.general import overload, tools as jtools
from jarvis.utils.display import imshow

# Data

The data used in this tutorial will consist of (frontal projection) chest radiographs from the RSNA / Kaggle pneumonia challenge (https://www.kaggle.com/c/rsna-pneumonia-detection-challenge). The chest radiograph is the standard screening exam of choice to identify and trend changes in lung disease including infection (pneumonia). 

The custom `datasets.download(...)` method can be used to download a local copy of the dataset. By default the dataset will be archived at `/data/raw/xr_pna`; as needed an alternate location may be specified using `datasets.download(name=..., path=...)`. 

In [None]:
# --- Download dataset
datasets.download(name='xr/pna')

Once downloaded, the `datasets.prepare(...)` method can be used to generate the required Python generators to iterate through the dataset, as well as a `client` object for any needed advanced functionality. As needed, pass any custom configurations (e.g. batch size, normalization parameters, etc.) into the optional `configs` dictionary argument.

In [None]:
# --- Prepare generators
configs = {'batch': {'size': 8}}
gen_train, gen_valid, client = datasets.prepare(name='xr/pna', configs=configs, keyword='seg')

Each iteration of the created generators yields a batch of `n` training samples, where `n` is the specified batch size. As before, each iteration yields two variables, `xs` and `ys`, each representing a dictionary of model input(s) and output(s).

Compared to prior tutorials with just a single input and output, there are two separate inputs in the `xs` dictionary. Specifically, a new array named `msk` is available as an additional `xs` input. The purpose of this `msk` object is described in detail below. For now, note that the `msk` conforms to the boundaries of the right and left lungs individually.

Let us examine the generator data:

In [None]:
# --- Yield one example
xs, ys = next(gen_train)

# --- Print dict keys
print('xs keys: {}'.format(xs.keys()))
print('ys keys: {}'.format(ys.keys()))

In [None]:
# --- Print data shape
print('xs dat shape: {}'.format(xs['dat'].shape))
print('xs msk shape: {}'.format(xs['msk'].shape))
print('ys pna shape: {}'.format(ys['pna'].shape))

Use the following lines of code to visualize both the image data and corresponding mask label using the `imshow(...)` method:

In [None]:
# --- Show the first example, msk
xs, ys = next(gen_train)
imshow(xs['dat'][0], xs['msk'][0], radius=3)

In [None]:
# --- Show the first example, pna
xs, ys = next(gen_train)
imshow(xs['dat'][0], ys['pna'][0], radius=3)

To create an N x N mosaic (montage) of all images in the batch, pass the entire batch to `imshow(...)`:

In [None]:
# --- Show "montage" of all images, msk
imshow(xs['dat'], xs['msk'], figsize=(12, 12), radius=3)

In [None]:
# --- Show "montage" of all images, pna
imshow(xs['dat'], ys['pna'], figsize=(12, 12), radius=3)

### 3D operations

Note that the model input shapes for this exercise will be provided as 3D tensors. Even if your current model does not require 3D data (as in this tutorial), any 2D tensor can be represented by a 3D tensor with a z-axis shape of 1. In addition, designing all models with this configuration (i.e. 3D operations) ensures that minimal code changes are needed when testing various 2D and 3D network architectures.
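As a quick illustration of the 2D-as-3D convention, the NumPy sketch below adds a z-axis of size 1 (and a channel axis) to a single 512 x 512 image. The `(Z, Y, X, C)` layout is an assumption inferred from the `(1, 3, 3)` kernel and `(1, 2, 2)` strides used in the model below; check the shapes printed by the generator cells above to confirm.

```python
import numpy as np

# --- A single 2D chest radiograph (Y, X), e.g. 512 x 512 pixels
img_2d = np.zeros((512, 512), dtype='float32')

# --- Represent it as a 3D volume with a z-axis of size 1, plus a channel axis: (Z, Y, X, C)
img_3d = img_2d[np.newaxis, ..., np.newaxis]

# --- Stacking 8 such volumes adds the batch axis expected by Conv3D layers: (N, Z, Y, X, C)
batch = np.stack([img_3d] * 8)

print(img_3d.shape)  # (1, 512, 512, 1)
print(batch.shape)   # (8, 1, 512, 512, 1)
```

Because the z-axis has size 1, a `(1, 3, 3)` kernel behaves exactly like a 2D `3 x 3` convolution on each slice.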

# Weighted Loss

To implement custom loss weights (and/or masks), a generic `msk` array will be used to perform a point-wise multiplication against the final pixel-by-pixel loss. For locations where the loss should be **weighted**, use a constant value > 1. For locations where the loss should be ignored (**masked**), use a constant value of 0. Locations with a value of 1 contribute to the loss as usual.
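To make the point-wise multiplication concrete, here is a small NumPy sketch of a weighted/masked pixel-wise cross-entropy. The helper `masked_pixel_ce` is hypothetical and not part of `jarvis`; it simply averages over all pixels, whereas some implementations normalize by the sum of the weights instead.

```python
import numpy as np


def masked_pixel_ce(probs, labels, msk):
    """
    Hypothetical helper: pixel-wise cross-entropy multiplied point-wise by msk

    probs  : (..., C) softmax probabilities
    labels : (...)    integer class labels
    msk    : (...)    per-pixel values: 0 = ignore, 1 = normal, > 1 = up-weighted

    """
    # --- Probability assigned to the true class at every pixel
    p_true = np.take_along_axis(probs, labels[..., None], axis=-1)[..., 0]
    ce = -np.log(p_true)

    # --- Point-wise multiplication by msk, then average over all pixels
    return np.mean(ce * msk)


# --- Toy example: 4 pixels, 2 classes (0 = background, 1 = pneumonia)
probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.5, 0.5]])
labels = np.array([0, 1, 1, 0])

unweighted = masked_pixel_ce(probs, labels, np.ones(4))

# --- Up-weight the two pneumonia pixels (x5) and ignore the last pixel (x0)
weighted = masked_pixel_ce(probs, labels, np.array([1., 5., 5., 0.]))
```

Whatever the model predicts at the last pixel no longer affects `weighted`, while errors on the two pneumonia pixels now cost five times as much.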

### `Client` object

The `Client` object is used as an interface to training data. Among other useful features, it is used to create Python generators for model training:

```python
# --- Create training generators
gen_train, gen_valid = client.create_generators()

# --- Create testing generators
test_train, test_valid = client.create_generators(test=True)
```

Note that in prior assignments, the standard `datasets.prepare(...)` method implicitly created a new `Client()` object and invoked the `create_generators(...)` function as shown above. While this approach is sufficient for many projects, if the data yielded by the generator needs to be manually altered in any way, the original `Client()` object needs to be overloaded (modified). To facilitate this, use the `@overload` Python decorator (part of the `jarvis` library). See below for more information.
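The `jarvis` implementation of `@overload` is not shown in this tutorial. As a rough mental model only, the sketch below implements a hypothetical decorator, `overload_sketch`, that attaches a function to a class under the function's own name; `ToyClient` is likewise a made-up stand-in for `Client`. The distinct names are deliberate, so that running this cell does not shadow the real `overload` and `Client` imported above.

```python
import numpy as np


def overload_sketch(cls):
    """
    Hypothetical @overload-style decorator: attach the decorated function to
    `cls` under its own name, replacing any existing method of that name

    """
    def decorator(func):
        setattr(cls, func.__name__, func)
        return func
    return decorator


class ToyClient:
    """Hypothetical stand-in for a data client with a default preprocess hook"""

    def preprocess(self, arrays, **kwargs):
        return arrays

    def next_batch(self):
        arrays = {'xs': {'msk': np.array([0., 1., 2.])}, 'ys': {}}
        # --- The (possibly overloaded) hook is called on every batch
        return self.preprocess(arrays)


client_a = ToyClient()


@overload_sketch(ToyClient)
def preprocess(self, arrays, **kwargs):
    # --- Example modification: binarize the msk array
    arrays['xs']['msk'] = (arrays['xs']['msk'] > 0).astype('float32')
    return arrays


# --- Instances created before the overload also use the new method
print(client_a.next_batch()['xs']['msk'])  # [0. 1. 1.]
```

Because the method is replaced on the class itself, every existing and future instance picks up the new behavior, which is why the overload cells below must be run before the generators are created.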

### Creating custom loss weights and masks

As above, the `msk` array that is yielded by the generator will need to be customized to account for any desired weights and/or masked values. Recall also from the data exploration step above that the default `msk` array is populated with a mask of the right (value == 1) and left (value == 2) lungs.

To modify this baseline `msk` array, we will overload the `preprocess` method of the `Client` class. The `arrays` variable is a dictionary that conforms to the various `xs` and `ys` arrays used to train the model. In this particular experiment, the following array dictionary is defined (conforms to the exact same nomenclature as above):

```python
arrays = {
    
    # --- All xs input(s)
    'xs': {
        'dat': ...,
        'msk': ...},
    
    # --- All ys output(s)
    'ys': {
        'pna': ...}
}
```

**Variant 1**: Mask the loss function to ignore non-lung and non-pneumonia pixels

In [None]:
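@overload(Client)
def preprocess(self, arrays, **kwargs):
    """
    Method to create a custom msk array for class weights and/or masks

    Variant 1: weight of 1 inside the lungs and/or pneumonia, 0 (ignored) elsewhere
    (assumes the dat, msk and pna arrays share the same shape)

    """
    # --- Binary lung (right == 1, left == 2) and pneumonia masks
    lng = arrays['xs']['msk'] > 0
    pna = arrays['ys']['pna'] > 0

    # --- Create msk
    msk = np.zeros(arrays['xs']['dat'].shape, dtype='float32')
    msk[lng | pna] = 1

    arrays['xs']['msk'] = msk

    return arrays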

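**Variant 2**: Use class weights to increase penalty for pneumonia pixels

In [None]:
@overload(Client)
def preprocess(self, arrays, **kwargs):
    """
    Method to create a custom msk array for class weights and/or masks

    Variant 2: weight of 1 everywhere, pneumonia pixels up-weighted
    (a weight of 5 is only an example value to tune)

    """
    # --- Create msk
    msk = np.ones(arrays['xs']['dat'].shape, dtype='float32')
    msk[arrays['ys']['pna'] > 0] = 5

    arrays['xs']['msk'] = msk

    return arrays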

**Variant 3**: Use a combination of both class weights and masked losses

In [None]:
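@overload(Client)
def preprocess(self, arrays, **kwargs):
    """
    Method to create a custom msk array for class weights and/or masks

    Variant 3: lungs weighted 1, pneumonia up-weighted (example value of 5), all else ignored

    """
    lng = arrays['xs']['msk'] > 0
    pna = arrays['ys']['pna'] > 0

    # --- Create msk
    msk = np.zeros(arrays['xs']['dat'].shape, dtype='float32')
    msk[lng] = 1
    msk[pna] = 5

    arrays['xs']['msk'] = msk

    return arrays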

### Visualization

To manually create your new custom `Client` object, pass the location of the `client*.yml` file that defines all the configurations for your data client. This file is located in the `ymls/` directory of the data download. For convenience, the `jarvis.utils.general.tools` module (imported as `jtools` above) can be used to find this location automatically: 

In [None]:
# --- Find client yml file
yml = '{}/data/ymls/client-seg.yml'.format(jtools.get_paths('xr/pna')['code'])

Use the following block to create a new `Client` object and visualize the custom `msk`: 

In [None]:
# --- Manually create Client
client = Client(yml)

# --- Manually create generators
gen_train, gen_valid = client.create_generators()

# --- Show
xs, ys = next(gen_train)
imshow(xs['dat'][0], xs['msk'][0])

# Stratified Sampling

The data consists of a total of three separate cohorts: negative; positive (infection); not normal but not infection. Note that the positive cases of infection account for approximately 1/5th (20%) of the data. As a result it will be necessary to increase the frequency of this cohort to optimize training efficiency.

Furthermore, the third indeterminate category (i.e. not normal but not infection) is quite subjective. In general it is quite difficult to differentiate between various disease processes on chest radiograph alone. As a result, in this particular algorithm (where the primary goal is to optimize for algorithm sensitivity), we will **ignore** this third category.

To implement stratified sampling, pass the appropriate `sampling` specifications to the `configs` variable when creating the `Client()` object (this is the exact same configs variable used by the `datasets.prepare(...)` method).

In [None]:
# --- Configs dict
configs = {
    'batch': {'size': 8},
    'sampling': {
        'cohort-neg': 0.5,
        'cohort-pna': 0.5}}

# --- Manually create Client
client = Client(yml, configs=configs)

# --- Manually create generators
gen_train, gen_valid = client.create_generators()
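Conceptually, stratified sampling can be thought of as a two-step draw: pick a cohort according to the configured probabilities, then pick an example within that cohort. The sketch below is a hypothetical illustration of that idea using toy example ids (including a made-up `cohort-oth` key for the indeterminate group); it is not necessarily how `jarvis` implements sampling internally.

```python
import numpy as np


def sample_stratified(cohorts, probs, n, seed=None):
    """
    Hypothetical sketch: draw `n` example ids by first choosing a cohort with
    probability `probs[cohort]`, then an id uniformly at random within it.
    Cohorts missing from `probs` (e.g. the indeterminate group) are never drawn.

    """
    rng = np.random.default_rng(seed)
    names = list(probs)
    p = np.array([probs[k] for k in names], dtype=float)
    p = p / p.sum()

    picks = rng.choice(len(names), size=n, p=p)
    return np.array([rng.choice(cohorts[names[i]]) for i in picks])


# --- Toy cohorts: pneumonia is 20% of the negative + pneumonia examples
cohorts = {
    'cohort-neg': np.arange(0, 80),
    'cohort-pna': np.arange(80, 100),
    'cohort-oth': np.arange(100, 120)}   # not normal but not infection

ids = sample_stratified(cohorts, {'cohort-neg': 0.5, 'cohort-pna': 0.5}, n=10000, seed=0)

# --- Fraction of pneumonia examples drawn: close to 0.5 rather than 0.2
print(np.mean((ids >= 80) & (ids < 100)))
```

The pneumonia cohort is oversampled to roughly half of all draws, while the indeterminate cohort is simply never selected.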

# Model

To localize pneumonia (lung infection) on chest radiographs, we will implement a standard contracting-expanding network (e.g. U-Net). Note that the original model input shape of `(512, 512)` is larger than in prior exercises; as a result, the contracting path uses a total of 6 blocks, 5 of which subsample the input by a factor of 2. In the assignment, feel free to try various architecture permutations.
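In the model code below, the contracting path consists of six blocks (`l1` to `l6`), and only `l2` through `l6` begin with a stride-`(1, 2, 2)` convolution. The short calculation below traces the resulting in-plane feature map size, assuming a 512 x 512 input; the 16 x 16 bottleneck is then upsampled five times by the transposed convolutions, which is why tensors such as `l11 + l1` have matching shapes.

```python
# --- In-plane spatial size (Y, X) after each block of the contracting path
size = 512
sizes = [size]                 # l1: stride-1 block, no subsampling
for _ in range(5):             # l2 ... l6: each begins with a stride-(1, 2, 2) conv
    size //= 2
    sizes.append(size)

print(sizes)  # [512, 256, 128, 64, 32, 16]
```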

### Create Inputs

As before, use the `client.get_inputs(...)` method to create model inputs:

In [None]:
# --- Create inputs
inputs = client.get_inputs(Input)

### Create model

In [None]:
# --- Define kwargs dictionary
kwargs = {
    'kernel_size': (1, 3, 3),
    'padding': 'same'}

# --- Define lambda functions
conv = lambda x, filters, strides : layers.Conv3D(filters=filters, strides=strides, **kwargs)(x)
norm = lambda x : layers.BatchNormalization()(x)
relu = lambda x : layers.LeakyReLU()(x)
tran = lambda x, filters, strides : layers.Conv3DTranspose(filters=filters, strides=strides, **kwargs)(x)

# --- Define stride-1, stride-2 blocks
conv1 = lambda filters, x : relu(norm(conv(x, filters, strides=1)))
conv2 = lambda filters, x : relu(norm(conv(x, filters, strides=(1, 2, 2))))
tran2 = lambda filters, x : relu(norm(tran(x, filters, strides=(1, 2, 2))))

In [None]:
# --- Define contracting layers
l1 = conv1(8, inputs['dat'])
l2 = conv1(16, conv2(16, l1))
l3 = conv1(32, conv2(32, l2))
l4 = conv1(48, conv2(48, l3))
l5 = conv1(64, conv2(64, l4))
l6 = conv1(80, conv2(80, l5))

# --- Define expanding layers
l7  = tran2(64, l6)
l8  = tran2(48, conv1(64, l7  + l5))
l9  = tran2(32, conv1(48, l8  + l4))
l10 = tran2(16, conv1(32, l9  + l3))
l11 = tran2(8,  conv1(16, l10 + l2))
l12 = conv1(8,  conv1(8,  l11 + l1))

# --- Create logits
logits = {}
logits['pna'] = layers.Conv3D(filters=2, name='pna', **kwargs)(l12)

# --- Create model
model = Model(inputs=inputs, outputs=logits) 

### Compile model

To compile this model, several custom `loss` and `metrics` objects will need to be defined.

#### Loss

As in prior tutorials, a standard (sparse) softmax cross-entropy loss will be used to optimize the segmentation model. A custom softmax cross entropy loss function is available as part of the `jarvis.train.custom` module to implement the necessary modifications for weighted and/or masked loss functions. To use this object, simply pass the `inputs['msk']` array as the first argument into the loss function initializer.

For many common loss functions, the low-level Tensorflow or Keras loss object does support weighted loss calculations (via the `sample_weight` argument); however, this functionality is not available by default through the standard `model.fit(...)` API. To accommodate this, a Python closure can be used to create a wrapper around the default loss function calculation:

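```python
from tensorflow import losses

def sce(weights, scale=1.0):
    """
    Closure returning a weighted sparse softmax cross-entropy loss

    """
    # --- Standard Keras loss object (expects raw logits)
    loss = losses.SparseCategoricalCrossentropy(from_logits=True)

    def sce_loss(y_true, y_pred):
        # --- weights (e.g. inputs['msk']) is captured from the enclosing scope
        return loss(y_true=y_true, y_pred=y_pred, sample_weight=weights) * scale

    return sce_loss
```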

In [None]:
# --- Create custom weighted loss
loss = {'pna': custom.sce(inputs['msk'])}

#### Metrics

The goal of this model is to optimize for overall algorithm sensitivity. However, to ensure that the model does not simply predict positive for *every* pixel (which would in fact yield a sensitivity of 100%), we will also monitor the overall Dice score to confirm that it remains within a reasonable range.

A series of custom metrics, including Dice score and sensitivity, are available as part of the `jarvis.train.custom` module to implement weighted and/or masked metrics. To use these objects, simply pass the `inputs['msk']` array as the `weights` argument into each metric initializer. Since we are using two separate metrics in this example, pass both as part of a Python list.

In [None]:
# --- Create metrics
metrics = custom.dsc(weights=inputs['msk'])
metrics += [custom.softmax_ce_sens(weights=inputs['msk'])]

metrics = {'pna': metrics}
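The internals of `custom.dsc(...)` are not shown in this tutorial. As a hedged illustration of what a weighted (masked) Dice score can look like, the hypothetical NumPy helper `weighted_dice` below multiplies every term by the per-pixel weight, so that pixels with a weight of 0 count as neither false positives nor false negatives.

```python
import numpy as np


def weighted_dice(pred, true, weights, eps=1e-7):
    """
    Hypothetical sketch of a weighted (masked) Dice score on binary masks:

      2 * sum(w * p * t) / (sum(w * p) + sum(w * t))

    """
    pred = pred.astype(float)
    true = true.astype(float)
    num = 2 * np.sum(weights * pred * true)
    den = np.sum(weights * pred) + np.sum(weights * true)
    return num / (den + eps)


# --- Toy example: one true positive and one false positive
pred = np.array([1, 1, 0, 0])
true = np.array([1, 0, 0, 0])

dsc_all = weighted_dice(pred, true, np.ones(4))                  # ~0.667
dsc_msk = weighted_dice(pred, true, np.array([1., 0., 1., 1.]))  # ~1.0, false positive ignored
```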

To compile the final model:

In [None]:
# --- Compile the model
model.compile(
    optimizer=optimizers.Adam(learning_rate=2e-4),
    loss=loss,
    metrics=metrics,
    experimental_run_tf_function=False)

# Training

### In-memory data

For moderately sized datasets that are small enough to fit into RAM, it is often a good idea to first load all training data into memory to increase training speed. The `client` can be used for this purpose as follows:

In [None]:
# --- Load data into memory for faster training
client.load_data_in_memory()

*Important*: For the current dataset, which is relatively large, your Google Colab instance may not be able to load all data into memory. If so, just continue on to training below.

### Tensorboard

To use Tensorboard, create the necessary Keras callbacks:

In [None]:
from tensorflow.keras import callbacks  
tensorboard_callback = callbacks.TensorBoard('./logs')

Now, let us train the model:

In [None]:
# --- Train model
model.fit(
    x=gen_train, 
    steps_per_epoch=50, 
    epochs=120,
    validation_data=gen_valid,
    validation_steps=50,
    validation_freq=4,
    use_multiprocessing=True,
    callbacks=[tensorboard_callback])

### Launching Tensorboard

After running several iterations, start Tensorboard using the following cells. After Tensorboard has registered the first several checkpoints, subsequent data will be updated automatically (asynchronously) and model training can be resumed:

In [None]:
%load_ext tensorboard
%tensorboard --logdir logs

# Evaluation

To test the trained model, the following steps are required:

* load data
* use `model.predict(...)` to obtain logit scores
* compare prediction with ground-truth (Dice score, sensitivity)
* serialize in Pandas DataFrame

Recall that the generator used to train the model simply iterates through the dataset randomly. For model evaluation, the cohort must instead be loaded manually in an orderly way. For this tutorial, we will create new **test mode** data generators, which will simply load each example individually once for testing. 

In [None]:
# --- Create validation generator
test_train, test_valid = client.create_generators(test=True)

To run prediction on a single (first) example from the generator:

In [None]:
# --- Run a single prediction
x, y = next(test_valid)
logits = model.predict(x)

Let us visualize the predicted results. Recall that the `np.argmax(...)` function can be used to convert raw logit scores to predictions:

In [None]:
# --- Create prediction
pred = np.argmax(logits[0], axis=-1)

# --- Show
imshow(x['dat'][0], pred[0])

**Checkpoint** What is the problem with this mask?

Recall that during training, the algorithm is never penalized, regardless of class, for predictions *outside of the mask* (i.e. where values == 0). Thus, to generate the final prediction, one needs to similarly remove the masked values of the prediction:

In [None]:
# --- Clean up pred using mask
pred[x['msk'][0, ..., 0] == 0] = 0

# --- Show
imshow(x['dat'][0], pred[0], radius=3)

That is much better. Let us look at the ground-truth:

In [None]:
# --- Show
imshow(x['dat'][0], y['pna'][0], radius=3)

### Testing for sensitivity

In addition to evaluating overall model Dice score, the goal of this exercise is to create a high-sensitivity model. Recall that sensitivity is defined as TP / (TP + FN), i.e. the proportion of positive findings that are correctly identified.

In [None]:
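def calculate_sens(pred, true):
    """
    Method to calculate pixel-level sensitivity, TP / (TP + FN), from pred and true masks

    """
    pred, true = pred > 0, true > 0
    tp = np.sum(pred & true)   # --- positive pixels correctly predicted
    fn = np.sum(~pred & true)  # --- positive pixels missed
    # --- Sensitivity is undefined if the ground-truth has no positive pixels
    return tp / (tp + fn) if (tp + fn) > 0 else None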

In [None]:
# --- Calculate sens
calculate_sens(
    pred=pred,
    true=y['pna'][0, ..., 0])

### Running evaluation

In [None]:
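# --- Create validation generator
test_train, test_valid = client.create_generators(test=True)

# --- Collect per-example results
results = []

for x, y in test_valid:

    # --- Run prediction on the current example
    logits = model.predict(x)

    # --- Create prediction
    pred = np.argmax(logits[0], axis=-1)

    # --- Clean up pred using mask
    pred[x['msk'][0, ..., 0] == 0] = 0

    # --- Calculate Dice (undefined if both pred and true are empty)
    true = y['pna'][0, ..., 0] > 0
    denom = np.sum(pred > 0) + np.sum(true)
    dice = 2 * np.sum((pred > 0) & true) / denom if denom > 0 else None

    # --- Calculate sens
    sens = calculate_sens(pred=pred, true=true)

    results.append({'dice': dice, 'sens': sens})

### Saving results

In [None]:
# --- Serialize results into a Pandas DataFrame
df = pd.DataFrame(results)
print(df.describe())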

## Saving and Loading a Model

After a model has been successfully trained, it can be saved and/or loaded by simply using the `model.save()` and `models.load_model()` methods. 

In [None]:
# --- Serialize a model
model.save('./lesion_segmentation.hdf5')

In [None]:
# --- Load a serialized model
del model
model = models.load_model('./lesion_segmentation.hdf5', compile=False)