In [None]:
from   ipywidgets import interactive, fixed
import matplotlib.pyplot as plt
from   model.cnn import CNN
import numpy as np
import pandas as pd
import random
from   source.dataloader import ValidationDataGenerator
from   source.gradcampp import model_modifier
from   source.metrics import roc_auc
from   source.preprocessing import get_preprocessed_polyp_segmentation_mask_info
from   source.visualization import plotter_batch, plotter_gradcam
import tensorflow as tf
%config Completer.use_jedi = False

### Setup

In [None]:
# Directory that is scanned for the preprocessed CT scans / segmentation masks
data_dir = "preprocessed-data"

# Number of trained estimators in the ensemble
n_estimators = 2

# Image properties
raw_image_size = (100,100,100)
patch_size     = (50,50,50) # Field-of-view of the network
n_channels     = 1 # 1 = CT image only, 2 = CT image + manual expert segmentation mask

# Fix global seeds (as good as possible) to ensure reproducibility of results.
# TensorFlow keeps its own global random generator, independent of NumPy and
# Python's `random`, so it has to be seeded explicitly as well.
seed = 42
np.random.seed(seed)
random.seed(seed)
tf.random.set_seed(seed)

### Meta information

Load information about the CT scans from 'ct_info.csv' and get a list of the preprocessed CT scans and segmentation masks that are available in 'preprocessed-data/'.

In [None]:
# CT scan meta information; the path is relative to the working directory
ct_info_csv = 'ct_info.csv'
df_ct_info = pd.read_csv(ct_info_csv)
df_ct_info

In [None]:
# Collect the preprocessed CT scans and segmentation masks that are available in `data_dir`
df_preprocessed_info = get_preprocessed_polyp_segmentation_mask_info(data_dir)
# Last expression of the cell: display the result
df_preprocessed_info

Merge the CT meta information and the list of preprocessed files into a single dataframe (inner join on patient, polyp and segmentation).

In [None]:
# Inner join: keep only rows that are present in both tables
merge_keys = ['patient', 'polyp', 'segmentation']
df_data = pd.merge(df_ct_info, df_preprocessed_info, on=merge_keys, how='inner')
df_data

### Datagenerator

In [None]:
# The batch size equals the number of rows in df_data and shuffling is off
generator_config = dict(
    data=df_data,
    batch_size=len(df_data),
    patch_size=patch_size,
    n_channels=n_channels,
    num_threads=1,
    shuffle=False,
)
data_generator = ValidationDataGenerator(**generator_config)

### Inspect test batches (sanity check)

In [None]:
# First batch of the generator: element 0 holds the inputs, element 1 the labels
test_batch = data_generator[0]
X_test_batch, y_test_batch = test_batch[0], test_batch[1]
print('X_test_batch:', X_test_batch.shape, ', y_test_batch:', y_test_batch.shape)

In [None]:
# Upper slider bounds: last valid index along each of the first five axes of
# the batch, which are used below as (sample, x, y, z, channel)
last_sample, last_x, last_y, last_z, last_channel = (n - 1 for n in X_test_batch.shape[:5])

interactive(plotter_batch,
            batch        = fixed(test_batch),
            sample_nr    = (0, last_sample),
            channel      = (0, last_channel),
            slice_x      = (0, last_x),
            slice_y      = (0, last_y),
            slice_z      = (0, last_z),
            cmap         = ["gist_yarg", "cool", "inferno", "magma", "plasma", "viridis"],
            reverse_cmap = [True, False])

### Model

In [None]:
# Derive the spatial input size from `patch_size` (the same value the data
# generator is configured with) instead of repeating the numbers, so the
# network input and the generated patches cannot silently diverge.
# NOTE(review): `mc` presumably toggles Monte-Carlo dropout at inference — confirm in model.cnn.
model = CNN(input_shape=(*patch_size, n_channels), classes=1, dropout=0.1, mc=False)
model.summary()

### Predict on test set

In [None]:
# Location of the trained weights; the placeholder takes the 1-based estimator number
weights_template = 'weights/ResNet18_3D_Dropout_SingleChannel_Ensemble_{:s}'

# One prediction array per ensemble member
predictions = []

for est in range(n_estimators):

    print('\nEstimator: {:d}'.format(est))

    # Weight files are numbered from 1, the loop index starts at 0
    weights = weights_template.format(str(est+1))

    # Restore the trained weights of this ensemble member into the shared model
    print('\tLoad weights: {:s}'.format(weights))
    model.load_weights(weights)

    # Predict on the complete generator and keep the result
    predictions.append(np.asarray(model.predict(data_generator)))

# Stack the per-estimator results along a new leading (estimator) axis
predictions = np.asarray(predictions)

### Evaluation

In [None]:
# Ground truth: labels (element 1) of the first batch of the generator
y_true = np.asarray(data_generator[0][1])

# Ensemble predictions: average over the estimator axis (axis 0) first, then
# drop only a trailing singleton axis. A bare `squeeze()` would also collapse
# the estimator axis when n_estimators == 1 (or the sample axis for a single
# sample), so the mean would run over the wrong axis.
y_pred_ensemble = np.mean(predictions, axis=0)
if y_pred_ensemble.ndim > 1 and y_pred_ensemble.shape[-1] == 1:
    y_pred_ensemble = np.squeeze(y_pred_ensemble, axis=-1)

In [None]:
# ROC-AUC of the averaged ensemble prediction; `.numpy()` converts the
# returned value before it is formatted
roc_auc_result = roc_auc(y_true, y_pred_ensemble)
ensemble_roc_auc = roc_auc_result.numpy()
print('ROC_AUC = {:.2f}'.format(ensemble_roc_auc))

### GradCAM++

In [None]:
def gc_loss(output):
    """Score function for GradCAM++: select the first output neuron of every sample.

    Parameters
    ----------
    output : indexable with one entry per sample
        Model output for a batch; ``output[i][0]`` must be valid for every
        sample ``i``.

    Returns
    -------
    tuple
        ``(output[0][0], output[1][0], ...)`` with one entry per sample in
        ``output``.
    """
    # Take the batch size from `output` itself instead of the global `df_data`,
    # so the function does not depend on notebook state and also works for
    # batches that do not span the whole dataset.
    return tuple(output[i][0] for i in range(len(output)))

In [None]:
from tf_keras_vis.gradcam import GradcamPlusPlus
from tf_keras_vis.utils import normalize

# Location of the trained weights; the placeholder takes the 1-based estimator number
weights_template = 'weights/ResNet18_3D_Dropout_SingleChannel_Ensemble_{:s}'

# One normalized heatmap array per ensemble member
cams = []

for est in range(n_estimators):

    print('\nGradCAM++ for Estimator {:d}'.format(est))

    # ---------- CNN model ----------
    # Restore the trained weights of this ensemble member into the shared model
    weights = weights_template.format(str(est+1))
    print('\tLoad weights: {:s}'.format(weights))
    model.load_weights(weights)

    # ---------- GradCAM++ ----------
    # NOTE(review): with clone=False the modifier presumably operates on `model`
    # itself rather than on a copy — confirm against the tf-keras-vis docs.
    gradcam = GradcamPlusPlus(model,
                              model_modifier=model_modifier,
                              clone=False)

    # Heatmap from the first output neuron (see gc_loss), taken at
    # model.layers[47]; the automatic search for a conv layer is disabled
    raw_cam = gradcam(gc_loss,
                      X_test_batch,
                      penultimate_layer=47,
                      seek_penultimate_conv_layer=False)

    # Normalize and store the heatmaps of this estimator
    cams.append(normalize(raw_cam))

# Stack the per-estimator heatmaps along a new leading (estimator) axis
cams = np.asarray(cams)

In [None]:
# Name of the layer at index 47, the index passed as `penultimate_layer` to GradCAM++
model.layers[47].name

#### Plot GradCAM++

In [None]:
# Second generator, identical to the first one except that it is always built
# with two channels (per the setup cell: CT image + manual expert segmentation mask)
seg_generator_config = dict(
    data=df_data,
    batch_size=len(df_data),
    patch_size=patch_size,
    n_channels=2,
    num_threads=1,
    shuffle=False,
)
seg_data_generator = ValidationDataGenerator(**seg_generator_config)

# Inputs of the first batch; keep index 1 of the last axis
# (presumably the segmentation mask channel — confirm against the dataloader)
seg_inputs = seg_data_generator[0][0]
X_seg = seg_inputs[..., 1]

In [None]:
# Largest valid index along every axis of the test batch (slider upper bounds)
max_index = [n - 1 for n in X_test_batch.shape]

interactive(plotter_gradcam,
            X            = fixed(X_test_batch),
            cam          = fixed(cams),
            seg          = fixed(X_seg),
            y_true       = fixed(y_true),
            y_pred       = fixed(y_pred_ensemble),
            sample_nr    = (0, max_index[0]),
            cam_nr       = (0, cams.shape[0]-1),
            modality     = (0, max_index[4]),
            slice_x      = (0, max_index[1]),
            slice_y      = (0, max_index[2]),
            slice_z      = (0, max_index[3]),
            alpha        = (0.0,1.0),
            threshold    = (0.0,1.0,0.05),
            cmap         = ["inferno", "cool", "gist_yarg", "jet", "magma", "plasma", "viridis"],
            reverse_cmap = [False, True])

#### GradCAM++ activation in segmented voxels

In [None]:
# Minimum aggregated GradCAM++ activation (compared with >=) for a voxel to count as highly activated
activation_threshold = 0.25

In [None]:
# Voxels whose mask value is exactly 1.0 count as segmented
inside_seg   = (X_seg == 1.0)
n_voxels_seg = int(np.count_nonzero(inside_seg))

# Average the heatmaps over the estimator axis, then blank every voxel whose
# mask value is exactly 0.0 so that only segmented voxels can pass the threshold
outside_seg                 = (X_seg == 0.0)
aggregated_cam              = cams.mean(axis=0)
aggregated_cam[outside_seg] = 0.0

# Remaining voxels whose activation reaches the threshold
n_voxels_high_act = int(np.count_nonzero(aggregated_cam >= activation_threshold))

In [None]:
# Share of segmented voxels whose aggregated GradCAM++ activation reaches the
# threshold. An empty segmentation yields NaN instead of a ZeroDivisionError.
p_high_act = n_voxels_high_act / n_voxels_seg if n_voxels_seg > 0 else float('nan')

# The voxel count is taken with `>=`, so the message states `>=` as well
print('Fraction of segmented voxels with a GradCAM++ activation >= {:.2f}: {:.2f}% ({:d}/{:d})'.format(activation_threshold, 100*p_high_act, n_voxels_high_act, n_voxels_seg))