In [1]:
from   ipywidgets import interactive, fixed
import matplotlib.pyplot as plt
from   model.cnn import CNN
import numpy as np
import pandas as pd
import random
from   source.dataloader import ValidationDataGenerator
from   source.metrics import roc_auc
from   source.preprocessing import get_preprocessed_polyp_segmentation_mask_info
from   source.visualization import plotter_batch
import tensorflow as tf
%config Completer.use_jedi = False

### Settings

In [2]:
# Directory that is scanned for the preprocessed CT scans / segmentation masks
data_dir = "preprocessed-data"

# Repeated runs: number of ensemble members whose weights are loaded below
n_estimators = 2

# Data
raw_image_size = (100,100,100)

# Neural network
patch_size = (50,50,50)

# Data, second channel and gaussian blob
n_channels = 2

# Fix global seed (as good as possible): seed every RNG this notebook can
# touch, including TensorFlow's, which backs the model.
seed = 42
np.random.seed(seed)
random.seed(seed)
tf.random.set_seed(seed)

### Meta information

Load information about the CT scans from 'ct_info.csv' and list the preprocessed CT scans and segmentation masks that are available in 'preprocessed-data/'.

In [3]:
# Read the per-scan metadata table; the bare name on the last line displays it
df_ct_info = pd.read_csv('ct_info.csv')
df_ct_info

Unnamed: 0,patient,polyp,segmentation,histopathology,class_label,position,ct_file,segmentation_file
0,1,1,1,XXX,benign,prone,demo-data/demo_ct_001.npy,demo-data/demo_seg_001.npy
1,2,2,2,XXX,benign,prone,demo-data/demo_ct_002.npy,demo-data/demo_seg_002.npy
2,3,3,3,XXX,premalignant,prone,demo-data/demo_ct_003.npy,demo-data/demo_seg_003.npy


In [4]:
# Table of preprocessed CT / segmentation files for `data_dir`, as returned by
# the project helper (one row per patient / polyp / segmentation in the output below)
df_preprocessed_info = get_preprocessed_polyp_segmentation_mask_info(data_dir)
df_preprocessed_info

Unnamed: 0,patient,polyp,segmentation,preprocessed_ct_file,preprocessed_segmentation_file
0,1,1,1,/home/philipp/Projects/deep-learning-ct-colono...,/home/philipp/Projects/deep-learning-ct-colono...
1,2,2,2,/home/philipp/Projects/deep-learning-ct-colono...,/home/philipp/Projects/deep-learning-ct-colono...
2,3,3,3,/home/philipp/Projects/deep-learning-ct-colono...,/home/philipp/Projects/deep-learning-ct-colono...


Merge information

In [5]:
# Inner join on the (patient, polyp, segmentation) key: only rows present in
# both the metadata table and the preprocessed-file table are kept
df_data = df_ct_info.merge(df_preprocessed_info, how='inner', on=['patient', 'polyp', 'segmentation'])
df_data

Unnamed: 0,patient,polyp,segmentation,histopathology,class_label,position,ct_file,segmentation_file,preprocessed_ct_file,preprocessed_segmentation_file
0,1,1,1,XXX,benign,prone,demo-data/demo_ct_001.npy,demo-data/demo_seg_001.npy,/home/philipp/Projects/deep-learning-ct-colono...,/home/philipp/Projects/deep-learning-ct-colono...
1,2,2,2,XXX,benign,prone,demo-data/demo_ct_002.npy,demo-data/demo_seg_002.npy,/home/philipp/Projects/deep-learning-ct-colono...,/home/philipp/Projects/deep-learning-ct-colono...
2,3,3,3,XXX,premalignant,prone,demo-data/demo_ct_003.npy,demo-data/demo_seg_003.npy,/home/philipp/Projects/deep-learning-ct-colono...,/home/philipp/Projects/deep-learning-ct-colono...


In [6]:
# Display only the rows at positions 0 and 2 of the merged table
df_data.iloc[[0,2]]

Unnamed: 0,patient,polyp,segmentation,histopathology,class_label,position,ct_file,segmentation_file,preprocessed_ct_file,preprocessed_segmentation_file
0,1,1,1,XXX,benign,prone,demo-data/demo_ct_001.npy,demo-data/demo_seg_001.npy,/home/philipp/Projects/deep-learning-ct-colono...,/home/philipp/Projects/deep-learning-ct-colono...
2,3,3,3,XXX,premalignant,prone,demo-data/demo_ct_003.npy,demo-data/demo_seg_003.npy,/home/philipp/Projects/deep-learning-ct-colono...,/home/philipp/Projects/deep-learning-ct-colono...


#### Keras data generator (for testing)

In [7]:
# Generator over the merged table. `batch_size` is set to the number of rows,
# so presumably the whole data set is delivered as a single batch, in table
# order (shuffle=False) — TODO confirm against ValidationDataGenerator.
data_generator = ValidationDataGenerator(data=df_data,
                                         batch_size=df_data.shape[0],
                                         patch_size=patch_size,
                                         n_channels=n_channels,
                                         num_threads=1,
                                         shuffle=False)

#### Inspect test batches

In [8]:
# Fetch the first batch and unpack it into inputs (index 0) and labels (index 1)
test_batch = data_generator[0]
X_test_batch, y_test_batch = test_batch[0], test_batch[1]
print('X_test_batch:', X_test_batch.shape, ', y_test_batch:', y_test_batch.shape)

X_test_batch: (3, 50, 50, 50, 2) , y_test_batch: (3,)


In [9]:
# Interactive viewer for the batch. Every slider range is derived from the
# batch shape: axis 0 = sample, axes 1-3 = the three slice directions,
# axis 4 = channel. `fixed(...)` passes the batch through without a widget.
interactive(plotter_batch,
            batch        = fixed(test_batch),
            sample_nr    = (0,X_test_batch.shape[0]-1),
            channel      = (0,X_test_batch.shape[4]-1),
            slice_x      = (0,X_test_batch.shape[1]-1),
            slice_y      = (0,X_test_batch.shape[2]-1),
            slice_z      = (0,X_test_batch.shape[3]-1),
            cmap         = ["gist_yarg", "cool", "inferno", "magma", "plasma", "viridis"],
            reverse_cmap = [True, False])

interactive(children=(IntSlider(value=1, description='sample_nr', max=2), IntSlider(value=0, description='chan…

### Model

In [10]:
# Build the network. The spatial input size is taken from `patch_size` (the
# same value the data generator is given) instead of a hard-coded copy, so the
# two cannot drift apart; the channel axis is appended last.
model = CNN(input_shape=(*patch_size, n_channels), classes=1, dropout=0.1, mc=False)
model.summary()

Model: "cnn"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input (InputLayer)              [(None, 50, 50, 50,  0                                            
__________________________________________________________________________________________________
res1a_branch2a (Conv3D)         (None, 25, 25, 25, 1 880         input[0][0]                      
__________________________________________________________________________________________________
bn1a_branch2a (BatchNormalizati (None, 25, 25, 25, 1 64          res1a_branch2a[0][0]             
__________________________________________________________________________________________________
activation (Activation)         (None, 25, 25, 25, 1 0           bn1a_branch2a[0][0]              
________________________________________________________________________________________________

### Predictions

In [11]:
# One prediction array per ensemble member, collected in estimator order
predictions = []

for est in range(n_estimators):
    
    print('\nEstimator: {:d}'.format(est))
    
    # Load trained weights from disk
    # (weight files are numbered from 1, the loop index from 0, hence est+1)
    weights = 'weights/ResNet18_3D_Dropout_SecondChannel_Segmentation_Ensemble_{:s}'.format(str(est+1))
    
    # Load trained weights into model
    # (the same `model` object is reused; only its weights are swapped)
    print('\tLoad weights: {:s}'.format(weights))
    model.load_weights(weights)
    
    # Model predictions
    predictions_estimator = np.asarray(model.predict(data_generator))
    predictions.append(predictions_estimator)
    
# Stack the per-estimator results; axis 0 indexes the estimator
predictions = np.asarray(predictions)


Estimator: 0
	Load weights: weights/ResNet18_3D_Dropout_SecondChannel_Segmentation_Ensemble_1

Estimator: 1
	Load weights: weights/ResNet18_3D_Dropout_SecondChannel_Segmentation_Ensemble_2


### Evaluation

In [12]:
# Ground truth
y_true = np.asarray(data_generator[0][1])

# Ensemble predictions
# Average over the estimator axis (axis 0) *before* dropping the trailing
# single-output axis. A blanket `squeeze()` first would also remove the
# estimator or sample axis whenever it has length 1, so the mean would run
# over the wrong dimension.
y_pred_ensemble = np.mean(predictions, axis=0)
if y_pred_ensemble.ndim > 1 and y_pred_ensemble.shape[-1] == 1:
    y_pred_ensemble = np.squeeze(y_pred_ensemble, axis=-1)

In [13]:
# Calculate ROC-AUC
# `roc_auc` is the project metric; `.numpy()` is called on its result so it
# can be formatted as a plain number below (presumably it returns a tensor).
ensemble_roc_auc = roc_auc(y_true, y_pred_ensemble).numpy()
print('ROC_AUC = {:.2f}'.format(ensemble_roc_auc))

ROC_AUC = 0.50


### GradCAM++

In [14]:
def loss(output):
    """Score function handed to GradCAM++: first output neuron of every sample.

    Parameters
    ----------
    output : indexable batch of per-sample outputs
        ``output[i][0]`` must be valid for every sample ``i`` in the batch.

    Returns
    -------
    tuple
        ``(output[0][0], output[1][0], ...)`` with one entry per sample.
    """
    # Size the result from the batch actually received rather than from the
    # global `df_data`, so the function stays correct when only part of the
    # data (or a differently sized batch) is passed in.
    return tuple(output[i][0] for i in range(len(output)))

def model_modifier(m):
    """Give the last layer of ``m`` a linear activation and return ``m``.

    The layer is modified in place; the object returned is the one passed in.
    """
    output_layer = m.layers[-1]
    output_layer.activation = tf.keras.activations.linear
    return m

In [15]:
# Name of the layer at index 47 — the same index that is passed as
# `penultimate_layer` to GradCAM++ further down
model.layers[47].name

'add_5'

In [16]:
from tf_keras_vis.gradcam import GradcamPlusPlus
from tf_keras_vis.utils import normalize

# One normalized GradCAM++ map per ensemble member, collected in estimator order
cams = []

for est in range(n_estimators):
    
    print('\nGradCAM++ for Estimator {:d}'.format(est))
    
    ##############################
    ######### CNN Model ##########
    ##############################
    
    # Load trained weights from disk
    # (weight files are numbered from 1, the loop index from 0, hence est+1)
    weights = 'weights/ResNet18_3D_Dropout_SecondChannel_Segmentation_Ensemble_{:s}'.format(str(est+1))
    
    # Load trained weights into model
    print('\tLoad weights: {:s}'.format(weights))
    model.load_weights(weights)
    
    ############################
    ######### GradCAM ##########
    ############################
    
    # Create Gradcam object
    # NOTE(review): with clone=False the modifier presumably acts on `model`
    # itself, leaving its last activation linear for later cells — confirm.
    gradcam = GradcamPlusPlus(model,
                              model_modifier=model_modifier,
                              clone=False)

    # Generate heatmap with GradCAM from first neuron
    cam_est = gradcam(loss,
                      X_test_batch,
                      penultimate_layer=47, # model.layers number
                      seek_penultimate_conv_layer=False)
    cam_est = normalize(cam_est)
    
    # Store GradCAM
    cams.append(cam_est)
    
# Stack the per-estimator maps; axis 0 indexes the estimator
cams = np.asarray(cams)


GradCAM++ for Estimator 0
	Load weights: weights/ResNet18_3D_Dropout_SecondChannel_Segmentation_Ensemble_1

GradCAM++ for Estimator 1
	Load weights: weights/ResNet18_3D_Dropout_SecondChannel_Segmentation_Ensemble_2


#### Plot GradCAM

In [17]:
# Second generator, built with exactly the same arguments as `data_generator`
# above; it is only used here to read the segmentation channel.
seg_data_generator = ValidationDataGenerator(data=df_data,
                                             batch_size=df_data.shape[0],
                                             patch_size=patch_size,
                                             n_channels=n_channels,
                                             num_threads=1,
                                             shuffle=False)


# Inputs of the first batch, channel index 1 (used below as the segmentation)
X_seg = seg_data_generator[0][0][...,1]

3

In [18]:
from matplotlib import cm
def _imshow_orthogonal(axes_row, volume, slice_x, slice_y, slice_z, **imshow_kwargs):
    """Draw the three axis-aligned slices of a 3D `volume` onto a row of three axes."""
    axes_row[0].imshow(volume[slice_x,:,:], **imshow_kwargs)
    axes_row[1].imshow(volume[:,slice_y,:], **imshow_kwargs)
    axes_row[2].imshow(volume[:,:,slice_z], **imshow_kwargs)


def plotter_gradcam(X, cam, seg, cam_nr, modality, y_true, y_pred, sample_nr, slice_x, slice_y, slice_z, alpha, threshold, cmap, reverse_cmap):
    """Plot one sample with its GradCAM++ map overlaid, plus a thresholded view.

    Top row: the selected channel of the image with the CAM on top.
    Bottom row: the same image with a binary overlay marking voxels whose CAM
    value is at least `threshold`. A colorbar for the CAM colormap (range 0-1)
    is added on the right.

    Parameters
    ----------
    X : array indexed as [sample, x, y, z, channel]
        Input batch; not modified (a copy of the selected sample is used).
    cam : array indexed as [cam_nr, sample, x, y, z]
        GradCAM++ maps.
    seg : array
        Accepted for interface compatibility; currently not drawn.
    cam_nr, sample_nr : int
        Which CAM and which sample to show.
    modality : int
        Channel of `X` to display.
    y_true, y_pred : indexable by sample
        Label and prediction, printed for the selected sample.
    slice_x, slice_y, slice_z : int
        Slice index along each of the three spatial axes.
    alpha : float
        Opacity of the CAM overlay.
    threshold : float
        CAM cut-off for the binary overlay in the bottom row.
    cmap : str
        Name of the Matplotlib colormap used for the CAM.
    reverse_cmap : bool
        Whether to reverse that colormap.
    """

    # Pick and normalize image (channel 0 always, channel 1 only if displayed).
    # An all-zero channel is left untouched to avoid a division by zero.
    X = np.copy(X[sample_nr])
    for channel in range(2 if modality>0 else 1):
        peak = np.amax(X[...,channel])
        if peak != 0:
            X[...,channel] = np.divide(X[...,channel], peak)
    image = X[...,modality]

    # Pick cam
    cam = np.copy(cam[cam_nr,sample_nr,...])

    # Colormaps (`plt.get_cmap` replaces the deprecated `plt.cm.get_cmap`)
    color_map = plt.get_cmap(cmap)
    if reverse_cmap:
        color_map = color_map.reversed()
    gray_map = plt.get_cmap('gist_yarg').reversed()
    mask_map = plt.get_cmap('bwr')

    print('Image {:d}: [{:.2f}, {:.2f}]'.format(sample_nr,np.amin(X),np.amax(X)))
    print('Cam   {:d}: [{:.2f}, {:.2f}]'.format(sample_nr,np.amin(cam),np.amax(cam)))
    # int(): '{:d}' rejects float-typed labels
    print('True label: {:d}'.format(int(y_true[sample_nr])))
    print('Prediction: {:.2f}'.format(y_pred[sample_nr]))

    # Threshold: 1 where the CAM reaches the cut-off, 0 elsewhere
    X_in = np.zeros_like(image)
    X_in[np.where(cam>=threshold)]=1
    # All-zero second overlay, kept from the original rendering
    # NOTE(review): never filled with data — confirm whether it was meant to use `seg`.
    X_out = np.zeros_like(image)

    # Figure
    fig, ax = plt.subplots(2, 3, figsize=(14,7.6), sharex=True, sharey=True)

    # Original input image with GradCAM on top
    _imshow_orthogonal(ax[0], image, slice_x, slice_y, slice_z, cmap=gray_map)
    _imshow_orthogonal(ax[0], cam, slice_x, slice_y, slice_z,
                       cmap=color_map, vmin=0.0, vmax=1.0, alpha=alpha) # overlay

    # Binarized CT
    _imshow_orthogonal(ax[1], image, slice_x, slice_y, slice_z, cmap=gray_map)
    _imshow_orthogonal(ax[1], X_in, slice_x, slice_y, slice_z, cmap=mask_map, alpha=0.6)
    _imshow_orthogonal(ax[1], X_out, slice_x, slice_y, slice_z, cmap=mask_map, alpha=0.6)

    # Axis labels and tick size, identical for both rows
    # NOTE(review): imshow puts the first array axis on the vertical axis, so
    # these labels may be transposed — confirm the intended convention.
    for axes_row in ax:
        for axis, (x_label, y_label) in zip(axes_row, (('y','z'), ('x','z'), ('x','y'))):
            axis.set_xlabel(x_label, fontsize=16)
            axis.set_ylabel(y_label, fontsize=16)
            axis.tick_params(labelsize=14)

    plt.tight_layout()

    # Colorbar for the CAM colormap. A standalone mappable is used so that no
    # dummy image of a hard-coded size has to be drawn onto one of the panels.
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.8, 0.09, 0.025, 0.885])
    mappable = plt.cm.ScalarMappable(norm=plt.Normalize(vmin=0.0, vmax=1.0), cmap=color_map)
    mappable.set_array([]) # required by older Matplotlib releases
    fig.colorbar(mappable, cax=cbar_ax)
    cbar_ax.tick_params(labelsize=14)

    plt.show()

In [19]:
# Interactive GradCAM++ viewer. Arrays are passed through with `fixed(...)`;
# slider ranges are derived from the batch / CAM shapes (axis 0 of `cams` is
# the estimator, axes 1-3 of the batch are the slice directions, axis 4 the
# channel). `threshold` moves in steps of 0.05.
interactive(plotter_gradcam,
            X            = fixed(X_test_batch),
            cam          = fixed(cams),
            seg          = fixed(X_seg),
            y_true       = fixed(y_true),
            y_pred       = fixed(y_pred_ensemble),
            sample_nr    = (0,X_test_batch.shape[0]-1),
            cam_nr       = (0,cams.shape[0]-1),
            modality     = (0,X_test_batch.shape[4]-1),
            slice_x      = (0,X_test_batch.shape[1]-1),
            slice_y      = (0,X_test_batch.shape[2]-1),
            slice_z      = (0,X_test_batch.shape[3]-1),
            alpha        = (0.0,1.0),
            threshold    = (0.0,1.0,0.05),
            cmap         = ["inferno", "cool", "gist_yarg", "jet", "magma", "plasma", "viridis"],
            reverse_cmap = [False, True])

interactive(children=(IntSlider(value=0, description='cam_nr', max=1), IntSlider(value=0, description='modalit…

### GradCAM++ activation in segmented voxels

In [20]:
# Cut-off applied to the estimator-averaged GradCAM++ map when counting
# "high activation" voxels below
activation_threshold = 0.25

In [21]:
# Voxels that belong to the segmentation (segmentation channel equal to 1)
seg_mask     = X_seg==1.0
n_voxels_seg = int(np.count_nonzero(seg_mask))

# CAM averaged over the estimators, zeroed everywhere outside the segmentation
aggregated_cam            = np.mean(cams, axis=0)
aggregated_cam[~seg_mask] = 0.0

# Count high activations among the segmented voxels only. Using the same mask
# for numerator and denominator keeps the fraction within [0, 1] even for a
# threshold <= 0 (where the zeroed background would otherwise be counted) or
# for mask values other than exactly 0 and 1.
n_voxels_high_act = int(np.count_nonzero(aggregated_cam[seg_mask]>=activation_threshold))

In [22]:
# Fraction of segmented voxels at or above the threshold; undefined (NaN)
# instead of a ZeroDivisionError when nothing is segmented
p_high_act = float(n_voxels_high_act) / float(n_voxels_seg) if n_voxels_seg else float('nan')
# The message says ">=" to match the comparison used when counting
print('Fraction of segmented voxels with a GradCAM++ activation >= {:.2f}: {:.2f}% ({:d}/{:d})'.format(activation_threshold, 100*p_high_act, n_voxels_high_act, n_voxels_seg))

Fraction of segmented voxels with a GradCAM++ activation > 0.25: 100.00% (771/771)
