# Overview

Understanding the basis for a neural network's predictions is key to troubleshooting errors, improving model architecture, and increasing data consistency. One of the most popular feature visualization techniques is Gradient-weighted Class Activation Mapping (Grad-CAM), which uses the gradients of any target concept flowing into the final convolutional layer to produce a coarse localization map highlighting the image regions most important for predicting that concept. Grad-CAM is applicable to a wide variety of CNN model families. In this tutorial, we provide implementation details for Grad-CAM in TensorFlow / Keras using a pneumonia detection algorithm on chest radiographs.

# Environment

The following lines of code will configure your Google Colab environment for this tutorial.

### Enable GPU runtime

Use the following instructions to switch the default Colab instance into a GPU-enabled runtime:

```
Runtime > Change runtime type > Hardware accelerator > GPU
```

### Jarvis library

In this notebook we will use Jarvis, a custom Python package to facilitate data science and deep learning for healthcare. Among other things, this library will be used for low-level data management, stratification and visualization of high-dimensional medical data.

In [None]:
# --- Install Jarvis library
%pip install jarvis-md

### Imports

Use the following lines to import any needed libraries:

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import Input, Model, layers, optimizers, losses, metrics
from matplotlib import pyplot
from scipy.ndimage import zoom
from jarvis.train import datasets
from jarvis.utils.display import imshow, montage

# Data

The data used in this tutorial will consist of (frontal projection) chest radiographs from a subset of the RSNA / Kaggle pneumonia challenge (https://www.kaggle.com/c/rsna-pneumonia-detection-challenge). From the complete cohort, a random subset of 1,000 exams will be used for training and evaluation.

### Download

The custom `datasets.download(...)` method can be used to download a local copy of the dataset. By default, the dataset will be archived at `/data/raw/xr_pna`; as needed, an alternate location may be specified using `datasets.download(name=..., path=...)`.

In [None]:
# --- Download dataset
datasets.download(name='xr/pna-512')

### Python generators

Once the dataset is downloaded locally, Python generators to iterate through the dataset can be easily prepared using the `datasets.prepare(...)` method:

In [None]:
# --- Prepare generators
gen_train, gen_valid, client = datasets.prepare(name='xr/pna-512', keyword='cls-512')

The created generators, `gen_train` and `gen_valid`, yield two variables per iteration: `xs` and `ys`. Each is a dictionary of NumPy arrays containing the model input(s) and output(s), respectively, for a single *batch* of training. Python generators provide a generic data-input interface for a number of machine learning libraries, including TensorFlow 2 / Keras.

Note that any standard Python iteration construct can be used to loop through the generators, which yield batches indefinitely. For example, the Python built-in `next(...)` function will return the next batch of data:

In [None]:
# --- Yield one example
xs, ys = next(gen_train)
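
To make the generator contract concrete, here is a minimal sketch of a generator that yields `(xs, ys)` dictionary pairs indefinitely. `toy_generator` is a hypothetical stand-in filled with random values, not the Jarvis implementation, and the `(batch, 1)` label shape is an assumption made for illustration.

```python
import numpy as np

def toy_generator(batch_size=4, seed=0):
    """Hypothetical stand-in for gen_train: yields (xs, ys) dictionaries forever."""
    rng = np.random.default_rng(seed)
    while True:
        # Random "images" shaped like a batch of tutorial inputs: (batch, 1, 512, 512, 1)
        xs = {'dat': rng.random((batch_size, 1, 512, 512, 1), dtype=np.float32)}
        # Random binary labels; the (batch, 1) shape is an assumption
        ys = {'pna': rng.integers(0, 2, size=(batch_size, 1))}
        yield xs, ys

gen_toy = toy_generator()
xs_toy, ys_toy = next(gen_toy)
print(xs_toy['dat'].shape, ys_toy['pna'].shape)
```

Because the generator never raises `StopIteration`, the number of batches per epoch has to be set explicitly by the consumer, as is done later with `steps_per_epoch`.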

### Data exploration

To help facilitate algorithm design, each original chest radiograph has been resampled to a uniform `(512, 512)` matrix. Overall, the dataset comprises `1,000` 2D images: `500` negative exams and `500` positive exams.

### `xs` dictionary

The `xs` dictionary contains a single batch of model inputs:

1. `dat`: input chest radiograph resampled to `(1, 512, 512, 1)` matrix shape

In [None]:
# --- Print keys 
for key, arr in xs.items():
    print('xs key: {} | shape = {}'.format(key.ljust(8), arr.shape))

### `ys` dictionary

The `ys` dictionary contains a single batch of model outputs:

1. `pna`: binary classification of pneumonia vs. not pneumonia chest radiographs

* 0 = negative
* 1 = positive for pneumonia

In [None]:
# --- Print keys 
for key, arr in ys.items():
    print('ys key: {} | shape = {}'.format(key.ljust(8), arr.shape))

### Visualization

Use the following lines of code to visualize a single input image using the `imshow(...)` method:

In [None]:
# --- Show first image in batch
xs, ys = next(gen_train)
imshow(xs['dat'][0])

Use the following lines of code to visualize an N x N mosaic of all images in the current batch using the `imshow(...)` method:

In [None]:
# --- Show "montage" of all images
imshow(xs['dat'], figsize=(12, 12))
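
As a rough illustration of what an N x N mosaic means, the sketch below tiles a batch of 2D images into a roughly square grid. `simple_montage` is a hypothetical helper written for this example; it is not the Jarvis `montage` implementation, whose exact layout rules may differ.

```python
import numpy as np

def simple_montage(batch):
    """Tile a (B, H, W) batch into an n x n grid of tiles (hypothetical helper)."""
    b, h, w = batch.shape
    n = int(np.ceil(np.sqrt(b)))                      # smallest square grid that fits B tiles
    grid = np.zeros((n * h, n * w), dtype=batch.dtype)
    for i in range(b):
        r, c = divmod(i, n)                           # row-major placement
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = batch[i]
    return grid

tiles = np.arange(5 * 2 * 3).reshape(5, 2, 3)
grid = simple_montage(tiles)
print(grid.shape)
```

Unused grid positions are left as zeros, so a batch of five images occupies five of the nine tiles in a 3 x 3 grid.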

### Model inputs

For every input in `xs`, a corresponding `Input(...)` variable can be created and returned in an `inputs` dictionary for ease of model development:

In [None]:
# --- Create model inputs
inputs = client.get_inputs(Input)

In this example, the equivalent Python code to generate `inputs` would be:

```python
inputs = {}
inputs['dat'] = Input(shape=(1, 512, 512, 1))
```

# Model

To visualize learned features using Grad-CAM, we will first train a simple binary classifier CNN model for detection of pneumonia. The model will consist of alternating stride-1 and stride-2 convolutions with a 3 x 3 kernel.

In [None]:
# --- Define lambda functions
conv = lambda x, filters, stride : layers.Conv3D(
    kernel_size=(1, 3, 3),
    filters=filters, 
    strides=stride, 
    padding='same')(x)

norm = lambda x : layers.BatchNormalization()(x)
relu = lambda x : layers.ReLU()(x)

# --- Define standard stride-1 and stride-2 blocks
conv1 = lambda filters, x : relu(norm(conv(x, filters=filters, stride=1)))
conv2 = lambda filters, x : relu(norm(conv(x, filters=filters, stride=(1, 2, 2))))

Now we are ready to build our model:

In [None]:
# --- Define contracting layers
l1 = conv1(16, inputs['dat'])
l2 = conv1(32, conv2(32, l1))
l3 = conv1(48, conv2(48, l2))
l4 = conv1(64, conv2(64, l3))
l5 = conv1(80, conv2(80, l4))
l6 = conv1(96, conv2(96, l5))

c1 = layers.GlobalAveragePooling3D()(l6)

# --- Create logits
logits = {}
logits['pna'] = layers.Dense(2, name='pna')(c1)

# --- Create model
model = Model(inputs=inputs, outputs=logits)

Note that, starting from a `(512, 512)` input, the five stride-2 subsampling operations yield a `(16, 16)` feature map in `l6`, the last convolutional block.
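
The feature-map size can be checked with a few lines of arithmetic. This sketch relies only on the rule that a convolution with `padding='same'` produces an output of size `ceil(input / stride)`; the helper name `same_padding_output` is made up for this example.

```python
import math

def same_padding_output(size, stride):
    """Spatial output size of a convolution with padding='same'."""
    return math.ceil(size / stride)

size = 512
sizes = [size]
for _ in range(5):                      # five stride-2 blocks (l2 through l6)
    size = same_padding_output(size, stride=2)
    sizes.append(size)

print(sizes)  # [512, 256, 128, 64, 32, 16]
```

Each Grad-CAM heatmap pixel therefore summarizes a 32 x 32 region of the original radiograph, which is why the resulting localization map is coarse.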

# Model Training

### Compile model

In [None]:
# --- Compile model
model.compile(
    optimizer=optimizers.Adam(learning_rate=2e-4), 
    loss={'pna': losses.SparseCategoricalCrossentropy(from_logits=True)}, 
    metrics={'pna': 'sparse_categorical_accuracy'})
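
Because the final `Dense(2)` layer emits raw logits, the loss is configured with `from_logits=True` so that the softmax is applied inside the loss. The NumPy sketch below shows the equivalent arithmetic on toy values; the function names are illustrative and this is not the Keras implementation itself.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sparse_categorical_crossentropy(labels, logits):
    """Per-example loss: negative log-probability of the integer label."""
    probs = softmax(logits)
    picked = probs[np.arange(len(labels)), labels]
    return -np.log(picked)

toy_logits = np.array([[2.0, 0.0],      # confident "negative", label is 0
                       [0.0, 0.0]])     # undecided, label is 1
toy_labels = np.array([0, 1])
toy_loss = sparse_categorical_crossentropy(toy_labels, toy_logits)
print(toy_loss)
```

The same softmax converts the model's two logits into a pneumonia probability at inference time (the value in column 1).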

### In-Memory Data

The following line of code will load all training data into RAM. This strategy can speed up training for small to medium-sized datasets.

In [None]:
# --- Load data into memory
client.load_data_in_memory()

### Training

In [None]:
model.fit(
    x=gen_train, 
    steps_per_epoch=100, 
    epochs=8,
    validation_data=gen_valid,
    validation_steps=100,
    validation_freq=4)

# Grad-CAM

Gradient-weighted class activation mapping (Grad-CAM) uses the gradients of any target concept flowing into the final convolutional layer to produce a coarse localization map highlighting the image regions most important for predicting that concept. Grad-CAM is applicable, without any architectural changes or re-training, to a wide variety of CNN model families, including: (1) CNNs with fully-connected layers, (2) CNNs used for structured outputs, and (3) CNNs used in tasks with multimodal inputs or reinforcement learning.

![Grad-CAM](https://miro.medium.com/proxy/1*hHPn81BbKEl7xDsHr5aSIA.png)

Additional information can be found here: https://arxiv.org/pdf/1610.02391.pdf.
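
In the notation of the paper linked above, let $A^k$ be the $k$-th feature map of the chosen convolutional layer, $y^c$ the score for class $c$ (before the softmax), and $Z$ the number of spatial positions in the feature map. Grad-CAM first pools the gradients into one weight per channel, then forms a ReLU-rectified weighted sum of the feature maps:

```latex
\alpha_k^c = \frac{1}{Z} \sum_{i} \sum_{j} \frac{\partial y^c}{\partial A_{ij}^k},
\qquad
L_{\text{Grad-CAM}}^c = \mathrm{ReLU}\left( \sum_{k} \alpha_k^c A^k \right)
```

The ReLU keeps only the features that have a positive influence on the class of interest.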

In [None]:
def create_heatmap(x, model, layer_index=-5, pred_class=1):
    """
    Method to create heatmap using Grad-CAM technique

    Returns a NumPy array of shape (batch, H, W) scaled to [0, 1]

    """
    # --- Create new model that outputs the target layer activations and the predictions
    grad_model = tf.keras.models.Model(
        [model.inputs], [model.layers[layer_index].output, model.output])

    # --- Record forward pass operations as TF objects
    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(x)
        class_channel = preds['pna'][:, pred_class]

    # --- Calculate gradient of the class score w.r.t. the feature maps
    grads = tape.gradient(class_channel, last_conv_layer_output)

    # --- Determine mean gradient for each channel (feature map), separately for each image
    pooled_grads = tf.reduce_mean(grads, axis=(1, 2, 3), keepdims=True)

    # --- Weight each feature map by its pooled gradient and sum over channels
    heatmap = tf.reduce_sum(last_conv_layer_output * pooled_grads, axis=-1)

    # --- Remove the singleton depth axis: (batch, 1, H, W) ==> (batch, H, W)
    heatmap = tf.squeeze(heatmap, axis=1)

    # --- Convert to Numpy
    heatmap = heatmap.numpy()

    # --- Clip negative values, then normalize each heatmap to [0 - 1]
    heatmap = heatmap.clip(min=0)
    heatmap = heatmap / (heatmap.max(axis=(1, 2), keepdims=True) + 1e-8)

    return heatmap

In [None]:
# --- Create heatmaps for data batch in xs
heatmap = create_heatmap(xs, model)
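
The core arithmetic of Grad-CAM can also be sketched without TensorFlow, which makes the tensor shapes easier to follow. The example below assumes the activations and gradients are already available as NumPy arrays of shape `(batch, H, W, channels)`; here they are random placeholders, and `grad_cam_numpy` is an illustrative name. It pools gradients per image, as in the Grad-CAM paper.

```python
import numpy as np

def grad_cam_numpy(activations, grads, eps=1e-8):
    """
    activations, grads : arrays of shape (batch, H, W, channels)
    returns            : heatmaps of shape (batch, H, W), scaled to [0, 1]
    """
    # Channel weights: mean gradient over the spatial axes, per image
    weights = grads.mean(axis=(1, 2), keepdims=True)       # (B, 1, 1, C)
    # Weighted sum of feature maps over channels
    cam = (activations * weights).sum(axis=-1)             # (B, H, W)
    # ReLU keeps only positive evidence for the class
    cam = np.maximum(cam, 0)
    # Normalize each heatmap by its own maximum (eps avoids division by zero)
    return cam / (cam.max(axis=(1, 2), keepdims=True) + eps)

rng = np.random.default_rng(0)
toy_acts = rng.random((2, 16, 16, 96))          # placeholder activations
toy_grads = rng.normal(size=(2, 16, 16, 96))    # placeholder gradients
toy_cams = grad_cam_numpy(toy_acts, toy_grads)
print(toy_cams.shape)
```

The small `eps` term matters in practice: if no location has positive evidence for the class, the clipped map is all zeros and an unguarded division would produce NaNs.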

Use the following code to overlay heatmaps onto raw data:

In [None]:
def overlay(x, heatmap, cmap='jet', alpha=0.2, figsize=(7, 7)):
    
    # --- Use Jarvis montage function to collapse into a N x N grid
    im = np.squeeze(montage(x['dat']))
    hm = np.squeeze(montage(heatmap))
    
    # --- Zoom
    hm = zoom(hm, zoom=np.array(im.shape) / np.array(hm.shape), order=1)
    
    # --- Draw figure
    pyplot.clf()
    pyplot.figure(figsize=figsize)
    pyplot.axis('off')
    pyplot.imshow(im, cmap='gray')
    pyplot.imshow(hm, cmap=cmap, alpha=alpha)
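
The zoom step is needed because the heatmap is much smaller than the image it is drawn over. As a dependency-light illustration of that resizing, the sketch below uses nearest-neighbour upsampling by integer factors. Note that this is a swapped-in, simpler technique than the linear interpolation performed by `zoom(..., order=1)`, and `upsample_nearest` is a hypothetical helper.

```python
import numpy as np

def upsample_nearest(hm, target_shape):
    """Nearest-neighbour upsampling; assumes target_shape is an exact multiple of hm.shape."""
    fy = target_shape[0] // hm.shape[0]
    fx = target_shape[1] // hm.shape[1]
    # Repeat each row fy times, then each column fx times
    return np.repeat(np.repeat(hm, fy, axis=0), fx, axis=1)

hm_small = np.arange(4, dtype=float).reshape(2, 2)
hm_large = upsample_nearest(hm_small, (4, 4))
print(hm_large)
```

Nearest-neighbour resizing produces visibly blocky overlays, which is one reason to prefer linear interpolation for the final figure.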

In [None]:
# --- Draw heatmaps for a new batch
xs, ys = next(gen_train)
heatmap = create_heatmap(xs, model)
overlay(xs, heatmap, figsize=(12, 12))