# Overview

In this tutorial we will create a model for deep clustering as a pretext task for kidney segmentation.

This tutorial is part of the class **Introduction to Deep Learning for Medical Imaging** at University of California Irvine (CS190); more information can be found at: https://github.com/peterchang77/dl_tutor/tree/master/cs190.

# Google Colab

The following steps will configure your Google Colab environment for this tutorial.

### Enable GPU runtime

Use the following instructions to switch the default Colab instance into a GPU-enabled runtime:

```
Runtime > Change runtime type > Hardware accelerator > GPU
```

# Environment

### Jarvis library

In this notebook we will use Jarvis, a custom Python package to facilitate data science and deep learning for healthcare. Among other things, this library will be used for low-level data management, stratification and visualization of high-dimensional medical data.

In [None]:
# --- Install jarvis (only in Google Colab or local runtime)
%pip install jarvis-md

### faiss library

To facilitate fast kmeans clustering, we will use an efficient algorithm implemented by the Facebook AI Research team as part of the `faiss` library. In brief, `faiss` is a library for efficient similarity search and clustering of dense vectors. More information can be found here: https://github.com/facebookresearch/faiss.

In [None]:
# --- Install faiss
%pip install faiss-cpu

### Imports

Use the following lines to import any additional needed libraries:

In [None]:
import numpy as np, pandas as pd
from scipy import ndimage
import tensorflow as tf
from tensorflow.keras import Input, Model, losses, metrics, layers, optimizers
import faiss
from jarvis.train import datasets
from jarvis.utils import io
from jarvis.utils.display import imshow

# Data

The data used in this tutorial will consist of kidney tumor CT exams derived from the Kidney Tumor Segmentation Challenge (KiTS). More information about the KiTS Challenge can be found here: https://kits21.kits-challenge.org/. In this exercise, we will use this dataset to derive a model for slice-by-slice kidney segmentation. The custom `datasets.download(...)` method can be used to download a local copy of the dataset. By default the dataset will be archived at `/data/raw/ct_kits`; as needed an alternate location may be specified using `datasets.download(name=..., path=...)`.

In [None]:
# --- Download dataset
datasets.download(name='ct/kits')

### Data loader

In this tutorial, only the middle 2D slice of each volume will be used to promote fast model convergence. Since this small dataset fits easily into RAM, the following code block may be used to load these slices into a single NumPy array. Preparing data in this manner will also facilitate rapid iteration, including efficient dataset clustering during the training process.

In [None]:
def load_data(label=1, flip=True, a_min=-128, a_max=256):

    # --- Create data client
    _, _, client = datasets.prepare(name='ct/kits', keyword='3d')

    dats, lbls = [], []

    for sid, fnames, header in client.db.cursor():

        lbl, _ = io.load(fnames['lbl-crp'])
        
        if label in lbl:
            
            dat, _ = io.load(fnames['dat-crp'])
            dats.append(dat[48:49])
            lbls.append(lbl[48:49] >= label)

            if header['cohort-left'] and flip:
                dats[-1] = dats[-1][..., ::-1, :]
                lbls[-1] = lbls[-1][..., ::-1, :]

    dats = np.stack(dats, axis=0)
    lbls = np.stack(lbls, axis=0)
    
    # --- Normalize dats
    dats = (dats - a_min) / (a_max - a_min)
    dats = dats.clip(min=0, max=1)

    return {'dat': dats, 'lbl': lbls}

In [None]:
# --- Load data
xs = load_data()
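
The intensity normalization at the end of `load_data(...)` can be illustrated in isolation. The following is a minimal sketch using only NumPy; the helper name `window_normalize` and the sample values are hypothetical and chosen for illustration. It applies the same default window (`a_min=-128`, `a_max=256`) so that values below the window map to 0 and values above it map to 1.

```python
import numpy as np

def window_normalize(values, a_min=-128, a_max=256):
    """Hypothetical helper: rescale values to [0, 1] using a fixed intensity window."""
    scaled = (values - a_min) / (a_max - a_min)
    return np.clip(scaled, 0, 1)

# --- Made-up sample intensities spanning values below, inside and above the window
sample = np.array([-1000., -128., 64., 256., 1000.])
normalized = window_normalize(sample)

print(normalized)
```

Any value outside the window is saturated, so the model only sees contrast within the chosen intensity range.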

### KiTS data

The input images in the variable `dat` are arrays of shape `1 x 96 x 96 x 1`. Even though each image is 2D, it is stored as a 3D tensor `(z, y, x)` plus a trailing channel dimension, where `z = 1` in this implementation. Although the z-axis is redundant for a single-slice input, more complex models and architectures will commonly require a full 3D tensor. Because of this, we will directly use 3D convolutions throughout the tutorial materials for consistency.
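
As a small illustration of this layout, the sketch below builds a placeholder array with the same axis order (the batch size of 4 and the zero-filled contents are arbitrary assumptions) and shows how to move between the stored `1 x 96 x 96 x 1` tensor and a plain 2D image.

```python
import numpy as np

# --- Placeholder batch with axes (N, z, y, x, channel); N=4 is arbitrary
batch = np.zeros((4, 1, 96, 96, 1), dtype='float32')

# --- A single example keeps the singleton z-axis and channel axis
single = batch[0]

# --- Drop the z-axis and channel axis to recover a plain 2D image
plane = single[0, ..., 0]

# --- Restore the (z, y, x, channel) layout expected by 3D convolutions
restored = plane[np.newaxis, ..., np.newaxis]

print(single.shape, plane.shape, restored.shape)
```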

Use the following lines of code to visualize using the `imshow(...)` method:

In [None]:
# --- Show the first example
imshow(xs['dat'][0])

Passing multiple images to `imshow(...)` displays them as an N x N mosaic (montage):

In [None]:
# --- Show "montage" of 16 images
imshow(xs['dat'][:16])

### Kidney masks

The ground-truth labels are two class masks of the same matrix shape as the model input:

In [None]:
print(xs['lbl'][0].shape)

The two classes represent:

* class 0: background
* class 1: kidney
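
The binary mask is produced by the `lbl >= label` comparison in `load_data(...)`. The sketch below illustrates the idea on a tiny made-up label patch, assuming the raw labels follow a convention of 0 = background, 1 = kidney, 2 = tumor (an assumption about the archived data that this notebook does not confirm). Under that assumption, thresholding at 1 merges kidney and tumor voxels into a single foreground class.

```python
import numpy as np

# --- Made-up raw label patch (assumed convention: 0 = background, 1 = kidney, 2 = tumor)
raw = np.array([
    [0, 1, 1],
    [0, 2, 1],
    [0, 0, 0]])

# --- Same comparison as in load_data(...) with label=1
mask = raw >= 1

print(mask.astype('uint8'))
```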

Use the `imshow(...)` method to visualize the ground-truth kidney mask labels:

In [None]:
# --- Show kidney masks overlaid on original data
imshow(xs['dat'][:16], xs['lbl'][:16])

# --- Show kidney masks isolated
imshow(xs['lbl'][:16])

# Clusters

In [None]:
def create_features(x, x_weight=1, x_blur=3, coords_weight=1., backbone=None, backbone_weight=1., **kwargs):
    """
    Method to construct feature vector for clustering
    
    """
    x_ = [] 

    # --- Use features from raw data voxels
    if x_weight > 0:
        xx = x.copy()
        if x_blur > 0:
            xx[:, 0] = ndimage.gaussian_filter(xx[:, 0], sigma=(0, x_blur, x_blur, 0))
        x_.append(xx * x_weight)

    # --- Use features from coordinate location
    if coords_weight > 0:
        ij = np.meshgrid(*tuple([np.linspace(0, 1, 96) for _ in range(2)]), indexing='ij')
        ij = np.expand_dims(np.stack(ij, axis=-1), axis=0)
        ij = np.stack([ij] * x.shape[0], axis=0)
        x_.append(ij * coords_weight)

    # --- Use features from CNN-derived backbone
    if backbone is not None:
        yy = backbone.predict(x)
        x_.append(yy * backbone_weight)

    return np.concatenate(x_, axis=-1).reshape(x.size, -1)
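
To make the resulting feature layout concrete, here is a simplified sketch of the same idea using only NumPy. It omits the Gaussian blur and the backbone features, and the helper name `toy_features` and the tiny 4 x 4 input are hypothetical. Each voxel becomes one row containing its intensity followed by its normalized `(i, j)` coordinates.

```python
import numpy as np

def toy_features(x, x_weight=1.0, coords_weight=1.0):
    """Hypothetical simplification: one row per voxel with [intensity, i, j]."""
    n, z, h, w, _ = x.shape

    # --- Normalized coordinate grid of shape (h, w, 2)
    ii, jj = np.meshgrid(np.linspace(0, 1, h), np.linspace(0, 1, w), indexing='ij')
    coords = np.stack([ii, jj], axis=-1)

    # --- Repeat the grid for every example and slice: (n, z, h, w, 2)
    coords = np.broadcast_to(coords, (n, z, h, w, 2))

    # --- Concatenate along the channel axis, then flatten to (n_voxels, n_features)
    feats = np.concatenate([x * x_weight, coords * coords_weight], axis=-1)
    return feats.reshape(-1, feats.shape[-1])

x = np.random.rand(2, 1, 4, 4, 1).astype('float32')
feats = toy_features(x)

print(feats.shape)
```

The relative weights control how strongly clusters are driven by intensity versus spatial location.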

In [None]:
def create_clusters(x, n_clusters=8, **kwargs):

    # --- Create features
    x_ = create_features(x=x, **kwargs)

    # --- Apply kmeans clustering
    kmeans = faiss.Kmeans(x_.shape[-1], n_clusters)
    kmeans.train(x_.astype('float32'))
    clusters = kmeans.index.search(x_.astype('float32'), 1)[1].reshape(x.shape)

    return kmeans, clusters
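
For intuition about what the clustering step computes, the following is a plain NumPy stand-in for k-means (Lloyd's algorithm). It is not the `faiss` implementation, which is considerably faster; the function name `kmeans_numpy` and the synthetic two-group data are assumptions made for illustration. As in `create_clusters(...)`, the per-voxel assignments are reshaped back into an image-shaped array.

```python
import numpy as np

def kmeans_numpy(feats, n_clusters=2, n_iter=10, seed=0):
    """Hypothetical stand-in: basic Lloyd's algorithm on (n_samples, n_features)."""
    rng = np.random.default_rng(seed)
    centroids = feats[rng.choice(len(feats), n_clusters, replace=False)].copy()

    for _ in range(n_iter):
        # --- Assign each sample to its nearest centroid (squared Euclidean distance)
        dists = ((feats[:, None, :] - centroids[None]) ** 2).sum(axis=-1)
        assign = dists.argmin(axis=1)

        # --- Move each non-empty centroid to the mean of its assigned samples
        for k in range(n_clusters):
            if np.any(assign == k):
                centroids[k] = feats[assign == k].mean(axis=0)

    return centroids, assign

# --- Synthetic features: two well-separated groups of 8 samples each
feats = np.concatenate([np.zeros((8, 2)), np.ones((8, 2))]).astype('float32')
feats += np.random.default_rng(1).normal(scale=0.01, size=feats.shape).astype('float32')

centroids, assign = kmeans_numpy(feats)

# --- Reshape flat assignments into a (N, z, y, x, 1) label map
clusters_toy = assign.reshape(1, 1, 4, 4, 1)

print(clusters_toy[0, 0, ..., 0])
```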

In [None]:
kmeans, clusters = create_clusters(x=xs['dat'], n_clusters=8)

In [None]:
imshow(xs['dat'][:16], clusters[:16])
imshow(clusters[:16])

# Model

For this task, we will implement a standard contracting-expanding network for semantic segmentation (e.g. U-Net). In the assignment, feel free to try various architecture permutations.

### Create backbone

Define standard lambda functions:

In [None]:
# --- Define kwargs dictionary
kwargs = {
    'kernel_size': (1, 3, 3),
    'padding': 'same'}

# --- Define lambda functions
conv = lambda x, filters, strides : layers.Conv3D(filters=filters, strides=strides, **kwargs)(x)
norm = lambda x : layers.BatchNormalization()(x)
relu = lambda x : layers.ReLU()(x)
tran = lambda x, filters, strides : layers.Conv3DTranspose(filters=filters, strides=strides, **kwargs)(x)

concat = lambda a, b : layers.Concatenate()([a, b])

# --- Define stride-1, stride-2 blocks
conv1 = lambda filters, x : relu(norm(conv(x, filters, strides=1)))
conv2 = lambda filters, x : relu(norm(conv(x, filters, strides=(1, 2, 2))))
tran2 = lambda filters, x : relu(norm(tran(x, filters, strides=(1, 2, 2))))

Define standard U-Net backbone:

In [None]:
def create_unet():
    
    # --- Define input
    x = Input(shape=(None, 96, 96, 1), dtype='float32')

    # --- Define contracting layers
    l1 = conv1(8, x)
    l2 = conv1(16, conv2(16, l1))
    l3 = conv1(32, conv2(32, l2))
    l4 = conv1(48, conv2(48, l3))
    l5 = conv1(64, conv2(64, l4))

    # --- Define expanding layers
    l6  = tran2(48, l5)
    l7  = tran2(32, conv1(48, concat(l4, l6)))
    l8  = tran2(16, conv1(32, concat(l3, l7)))
    l9  = tran2(8,  conv1(16, concat(l2, l8)))
    l10 = conv1(8,  l9)

    # --- Create embedding
    outputs = layers.Conv3D(filters=8, **kwargs)(l10)

    # --- Create model
    backbone = Model(inputs=x, outputs=outputs) 
    
    return backbone
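
The in-plane feature map sizes through this network can be checked with simple arithmetic. The sketch below assumes the standard Keras behavior for `padding='same'`: a strided convolution produces `ceil(size / stride)` outputs, and a strided transposed convolution produces `size * stride` outputs. The helper names are hypothetical.

```python
import math

def conv_same(size, stride):
    """Output size of a 'same'-padded strided convolution."""
    return math.ceil(size / stride)

def tran_same(size, stride):
    """Output size of a 'same'-padded strided transposed convolution."""
    return size * stride

# --- Contracting path: l1 through l5 (four stride-2 convolutions)
down = [96]
for _ in range(4):
    down.append(conv_same(down[-1], 2))

# --- Expanding path: l5 followed by l6 through l9 (four stride-2 transposed convolutions)
up = [down[-1]]
for _ in range(4):
    up.append(tran_same(up[-1], 2))

print(down)
print(up)
```

Because each expanding layer returns to the size of a matching contracting layer, the skip connections (`concat(l4, l6)`, `concat(l3, l7)`, `concat(l2, l8)`) join tensors of equal spatial shape.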

# Training model

Using this `backbone` model, a total of two separate `training` models need to be created:

* training for deep clustering pretext task (pretraining)
* training for kidney segmentation task (fine-tuning)

For both `training` models, the `backbone` model architecture is wrapped in a second model with additional layer(s) that define optimization behavior including loss function derivations.

### Inputs

In [None]:
def create_inputs(use_augmentation=True, **kwargs):
    """
    Method to create generic model inputs (for pretraining and fine-tuning)
    
    """
    x = Input(shape=(None, 96, 96, 1))
    y = Input(shape=(None, 96, 96, 1))

    inputs = {'x': x, 'y': y}

    # --- Data augmentation
    if use_augmentation:
        
        a = layers.Concatenate()((inputs['x'][:, 0], inputs['y'][:, 0]))
        a = layers.experimental.preprocessing.RandomRotation(factor=0.2, interpolation='nearest')(a)
        a = layers.experimental.preprocessing.RandomTranslation(0.2, 0.2, interpolation='nearest')(a)
        a = layers.experimental.preprocessing.RandomZoom(0.2, interpolation='nearest')(a)
        a = tf.expand_dims(a, axis=1)

        x = a[..., 0:1]
        y = a[..., 1:2]
        
    return inputs, x, y
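
The image and label are concatenated along the channel axis before augmentation so that both receive the identical random transform. The sketch below illustrates this principle with NumPy only, using a fixed shift via `np.roll` as a stand-in for the random Keras preprocessing layers; the toy data and shift values are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# --- Toy image batch and a label derived from it, both of shape (N, y, x, 1)
x = rng.random((2, 8, 8, 1))
y = (x > 0.5).astype('float32')

# --- Stack image and label as two channels of one tensor
a = np.concatenate([x, y], axis=-1)

# --- Apply a single spatial transform to the combined tensor
a = np.roll(a, shift=(2, -1), axis=(1, 2))

# --- Split back into augmented image and augmented label
x_aug = a[..., 0:1]
y_aug = a[..., 1:2]

# --- The label still matches the image after augmentation
print(np.array_equal(y_aug, (x_aug > 0.5).astype('float32')))
```

Using `interpolation='nearest'` in the real augmentation layers serves a related purpose: it keeps label values as valid integer classes rather than blending them.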

In [None]:
def calculate_dsc(y_true, y_pred, cls=1):
    """ 
    Method to calculate Dice score for given class
    
    """
    true = y_true[..., 0] == cls
    pred = tf.math.argmax(y_pred, axis=-1) == cls

    A = tf.math.count_nonzero(true & pred) * 2
    B = tf.math.count_nonzero(true) + tf.math.count_nonzero(pred)

    return tf.math.divide_no_nan(tf.cast(A, tf.float32), tf.cast(B, tf.float32))
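
The Dice score is twice the overlap between two masks divided by the sum of their sizes. The following NumPy sketch mirrors that calculation on a small made-up example; the helper name `dice_numpy` is hypothetical, and returning 0 when both masks are empty matches the behavior of `divide_no_nan` above.

```python
import numpy as np

def dice_numpy(true, pred):
    """Hypothetical helper: Dice score between two binary masks."""
    true = np.asarray(true).astype(bool)
    pred = np.asarray(pred).astype(bool)

    overlap = (true & pred).sum()
    total = true.sum() + pred.sum()

    # --- Return 0 when both masks are empty to avoid division by zero
    return 2.0 * overlap / total if total > 0 else 0.0

true = np.array([[1, 1, 0], [0, 0, 0]])
pred = np.array([[1, 0, 0], [1, 0, 0]])

score = dice_numpy(true, pred)
empty = dice_numpy(np.zeros((2, 3)), np.zeros((2, 3)))

print(score, empty)
```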

In [None]:
def compile_model(training, y, outputs, use_dsc=False, **kwargs):

    sce = losses.SparseCategoricalCrossentropy(from_logits=True)(y_true=y, y_pred=outputs)
    acc = metrics.sparse_categorical_accuracy(y_true=y, y_pred=outputs)

    training.add_loss(sce)
    training.add_metric(acc, 'acc')

    if use_dsc:
        dsc = calculate_dsc(y_true=y, y_pred=outputs)
        training.add_metric(dsc, 'dsc')

    training.compile(optimizer=optimizers.Adam(learning_rate=1e-3))

    return training
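
For reference, sparse categorical cross-entropy with `from_logits=True` applies a softmax to the raw outputs and then averages the negative log-probability of the true class. The sketch below reproduces that computation in NumPy on made-up values; the function name is hypothetical and it is not the Keras implementation.

```python
import numpy as np

def sparse_ce_from_logits(labels, logits):
    """Hypothetical helper: mean cross-entropy from integer labels and raw logits."""
    # --- Numerically stable log-softmax over the class axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    # --- Select the log-probability of the true class for each sample
    picked = np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]

    return -picked.mean()

labels = np.array([0, 1, 1, 0])

# --- Uninformative logits: every class is equally likely
loss_uniform = sparse_ce_from_logits(labels, np.zeros((4, 2)))

# --- Confident and correct logits
confident = np.where(np.eye(2)[labels] == 1, 10.0, -10.0)
loss_confident = sparse_ce_from_logits(labels, confident)

print(loss_uniform, loss_confident)
```

With two classes and uninformative logits the loss equals `log(2)`, which is a useful baseline when reading the training logs.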

In [None]:
def pretrain(xs, clusters, backbone=None, epochs=50, batch_size=15, **kwargs):

    # --- Create inputs
    inputs, x, y = create_inputs(**kwargs)

    # --- Create backbone (unet)
    if backbone is None:
        backbone = create_unet()

    # --- Create training outputs
    outputs = backbone(x)
    outputs = layers.Conv3D(kernel_size=1, filters=(clusters.max() + 1))(outputs)
    
    # --- Create training model and losses
    training = Model(inputs=inputs, outputs=outputs)
    training = compile_model(training, y, outputs)

    # --- Train
    training.fit(x={'x': xs['dat'], 'y': clusters}, epochs=epochs, batch_size=batch_size)

    return backbone, training

In [None]:
backbone, training = pretrain(xs, clusters)

In [None]:
def finetune(xs, backbone=None, N=10, epochs=500, batch_size=15, validation_freq=50, **kwargs): 

    # --- Create inputs
    inputs, x, y = create_inputs(**kwargs)

    # --- Create backbone (unet)
    if backbone is None:
        backbone = create_unet()

    # --- Create training outputs
    outputs = backbone(x)
    outputs = layers.Conv3D(kernel_size=1, filters=2)(outputs)
    
    # --- Create training model and losses
    training = Model(inputs=inputs, outputs=outputs)
    training = compile_model(training, y, outputs, use_dsc=True)

    # --- Train
    training.fit(
        x={'x': xs['dat'][:N], 'y': xs['lbl'][:N]}, 
        validation_data={'x': xs['dat'][N:], 'y': xs['lbl'][N:]}, 
        validation_freq=validation_freq,
        batch_size=max(N, batch_size),
        epochs=epochs)

    return backbone, training

In [None]:
backbone, training = finetune(xs, backbone=None)

### Training Loops

In [None]:
def run_experiment(xs, n_clusters=8, epochs=3, pretrain_epochs=50, finetune_epochs=500, N_finetune=10):

    backbone = None
    
    for epoch in range(epochs):

        print('==================================================================')
        print('STARTING EPOCH: {}'.format(epoch + 1))
        print('==================================================================')

        # --- Create clusters
        kwargs = {
            'backbone': backbone, 
            'x_weight': 1. if epoch == 0 else 0, 
            'coords_weight': 1. if epoch == 0 else 0}

        kmeans, clusters = create_clusters(x=xs['dat'], n_clusters=n_clusters, **kwargs)

        # --- Perform pretraining
        backbone, training = pretrain(xs=xs, clusters=clusters, backbone=backbone, epochs=pretrain_epochs)

    # --- Perform fine-tuning
    backbone, training = finetune(xs=xs, backbone=backbone, N=N_finetune, epochs=finetune_epochs)
    
    return backbone, training

In [None]:
backbone, training = run_experiment(xs)

# Evaluation

In [88]:
# --- Run prediction
logits = training.predict({'x': xs['dat'], 'y': xs['lbl']})

dice = []
for y_true, y_pred in zip(xs['lbl'], logits):
    
    dsc = calculate_dsc(y_true=y_true, y_pred=y_pred).numpy()
    dice.append(dsc)

### Saving results

In [90]:
# --- Define columns
df = pd.DataFrame(np.arange(len(dice)))
df['dice'] = dice
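
To persist and summarize a table like this, one option is to write it to a CSV file and read it back. The sketch below is self-contained: the Dice values are illustrative placeholders (not results from this tutorial), and the file name `results.csv` and temporary directory are arbitrary assumptions.

```python
import os, tempfile
import numpy as np, pandas as pd

# --- Placeholder scores for illustration only (not actual model results)
dice_demo = [0.91, 0.84, 0.0, 0.77]

df_demo = pd.DataFrame({'index': np.arange(len(dice_demo)), 'dice': dice_demo})

# --- Write to a hypothetical location inside a temporary directory
path = os.path.join(tempfile.mkdtemp(), 'results.csv')
df_demo.to_csv(path, index=False)

# --- Reload and summarize
restored = pd.read_csv(path)
summary = restored['dice'].agg(['mean', 'median'])

print(summary)
```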

## Saving and Loading a Model

After a model has been successfully trained, it can be saved using the `training.save(...)` method and loaded again using the Keras `load_model(...)` function.

In [None]:
# --- Serialize a model
training.save('./model.hdf5')

In [None]:
# --- Load a serialized model
del training
training = tf.keras.models.load_model('./model.hdf5', compile=False)