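# Experiment - Grad-CAM (Captum)

Evaluates Captum's `LayerGradCam` heatmaps (computed on the last convolutional layer, `conv3`, of the 3D CNN) with the insertion and deletion metrics of `HeatmapEvaluationV2`.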

In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../../../..')

import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import datetime

import numpy as np

from src.data import train_test_split, MRISequence
from src.model import create_model, compile_model, load_checkpoint
from src.model.evaluation import show_metrics

In [2]:
import seaborn as sns
import matplotlib.pyplot as plt

sns.set(style="white")

plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['image.cmap'] = 'viridis'

%config InlineBackend.figure_format='retina'
plt.rcParams.update({'font.size': 15})

## Setup

In [3]:
%%time

ROOT_DIR = '../../../../tmp'
DEFAULT_CHECKPOINT_DIRECTORY_LOCAL = os.path.join(ROOT_DIR, 'checkpoints')
DEFAULT_BCKP_CHECKPOINT_DIRECTORY_LOCAL = os.path.join(ROOT_DIR, 'bckp-checkpoints')

LOG_DIRECTORY = os.path.join(ROOT_DIR, 'logs')
CHECKPOINT_DIRECTORY = DEFAULT_CHECKPOINT_DIRECTORY_LOCAL

LOG_DIRECTORY_LOCAL = LOG_DIRECTORY
CHECKPOINT_DIRECTORY_LOCAL = CHECKPOINT_DIRECTORY

DATA_DIR_NAME = 'data-v3'
DATA_DIR = os.path.join(ROOT_DIR, DATA_DIR_NAME)

saliencies_and_segmentations_v2_path = os.path.join(ROOT_DIR, 'saliencies_and_segmentations_v2')

if not os.path.exists(CHECKPOINT_DIRECTORY):
    os.mkdir(CHECKPOINT_DIRECTORY)

if not os.path.exists(LOG_DIRECTORY):
    os.mkdir(LOG_DIRECTORY)

val = False

class_names = ['AD', 'CN']

# get paths to data
train_dir, test_dir, val_dir = train_test_split(
    saliencies_and_segmentations_v2_path,
    ROOT_DIR,
    split=(0.8, 0.15, 0.05),
    dirname=DATA_DIR_NAME)

# set the batch size for mri seq
batch_size = 12
input_shape = (104, 128, 104, 1) # (112, 112, 105, 1)
resize_img = True
crop_img = True

# if y is one-hot encoded or just scalar number
one_hot = True

# class weightss (see analysis notebook)
class_weights = {0: 0.8072289156626505, 1: 1.3137254901960784}

# description statistics of the dataset
desc = {'mean': -3.6344006e-09, 'std': 1.0000092, 'min': -1.4982183, 'max': 10.744175}

if 'desc' not in locals():
    print('initializing desc...')
    desc = get_description(MRISequence(
        train_dir,
        64,
        class_names=class_names,
        input_shape=input_shape),
        max_samples=None)
    print(desc)


normalization={ 'type':'normalization', 'desc': desc }
# normalization={'type':'standardization', 'desc':desc }

augmentations = None
augmentations_inplace = True
# enable augmentations in mri seq (otherwise it can be enabled in dataset)
# augmentations={ 'random_swap_hemispheres': 0.5 }

# initialize sequences
print('initializing train_seq...')
train_seq = MRISequence(
    train_dir,
    batch_size,
    class_names=class_names,
    augmentations=augmentations,
    augmentations_inplace=augmentations_inplace,
    input_shape=input_shape,
    resize_img=resize_img,
    crop_img=crop_img,
    one_hot=one_hot,
    class_weights=class_weights,
    normalization=normalization)

print('initializing test_seq...')
test_seq = MRISequence(
    test_dir,
    batch_size,
    class_names=class_names,
    input_shape=input_shape,
    resize_img=resize_img,
    crop_img=crop_img,
    one_hot=one_hot,
    normalization=normalization)

if val:
    print('initializing val_seq...')
    val_seq = MRISequence(
        val_dir,
        batch_size,
        class_names=class_names,
        input_shape=input_shape,
        resize_img=resize_img,
        crop_img=crop_img,
        one_hot=one_hot,
        class_weights=class_weights,
        normalization=normalization)
else:
    print('val_seq = test_seq')
    val_seq = test_seq

model_key = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
log_dir = os.path.join(LOG_DIRECTORY, model_key)
print(f'log_dir: {log_dir}')

not copying files since the destination directory already exists
initializing train_seq...
initializing test_seq...
val_seq = test_seq
log_dir: ../../../../tmp\logs\20210410-144337
Wall time: 30.8 ms


## Model

In [4]:
import torch

from torchsummary import summary

from src.model.torch import Net3DCNN, load_weights

In [5]:
net = Net3DCNN()

# NOTE: allow_pickle=True lets np.load unpickle arbitrary Python objects, which
# can execute code — only load these weight files from a trusted local source.
weights = np.load(os.path.join(ROOT_DIR, 'tf-weights.npy'), allow_pickle=True)
weights_bn = np.load(os.path.join(ROOT_DIR, 'tf-weights-bn.npy'), allow_pickle=True)

load_weights(net, weights, weights_bn)

# Inference mode: dropout disabled, BatchNorm uses its running statistics.
net.eval()
net.cuda()

# torchsummary takes a channels-first shape without the batch axis; derive it
# from the channels-last input_shape instead of hard-coding (1, 104, 128, 104).
summary(net, (input_shape[-1],) + tuple(input_shape[:-1]))

[Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)), BatchNorm3d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True), MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)), BatchNorm3d(64, eps=0.001, momentum=0.99, affine=True, track_running_stats=True), MaxPool3d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False), Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)), BatchNorm3d(128, eps=0.001, momentum=0.99, affine=True, track_running_stats=True), MaxPool3d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False), Flatten(start_dim=1, end_dim=-1), Dropout(p=0.1, inplace=False), Linear(in_features=10240, out_features=256, bias=True), Dropout(p=0.1, inplace=False), Linear(in_features=256, out_features=2, bias=True)]
copy trainable parameters...


--- conv1.weight / conv3d/kernel:0 ---
(32, 1, 3, 3, 3)
torch

## Experiments

In [6]:
# Prefix of every evaluation history saved by this notebook.
NOTEBOOK_KEY = 'captum--grad-cam'
# Batch size of the evaluation sequences and of the heatmap evaluator.
BATCH_SIZE = 24
VERBOSE = 1
LOG = True
SEED = 42

# Options forwarded to HeatmapEvaluationV2: 1000 voxels per perturbation step;
# -1 presumably means "no limit on the number of steps" — confirm in src.heatmaps.
heatmap_evaluation_options = dict(
    evaluation_step_size=1000,
    evaluation_max_steps=-1,
    evaluation_batch_size=BATCH_SIZE,
)

In [7]:
from skimage.transform import resize

from src.heatmaps.evaluation import HeatmapEvaluationV2
from src.data import tf_predict, torch_predict

from captum.attr import LayerGradCam


# Grad-CAM on the output of net.conv3, the last convolutional layer of the network.
attribution_method = LayerGradCam(net, net.conv3)

def heatmap_fn(image_x, image_y, *, relu_attributions=False, **kwargs): # kwargs - log, seed
    """Compute a Grad-CAM heatmap for a single channels-last image.

    Args:
        image_x: one image in channels-last layout (D, H, W, C), as yielded by
            the MRI sequences (input_shape is (104, 128, 104, 1) here).
        image_y: one-hot label vector; its argmax is the attribution target.
        relu_attributions: forwarded to ``LayerGradCam.attribute``. The default
            (False) keeps both positive and negative attributions, matching
            previous runs; True clamps negatives to 0 as in the original
            Grad-CAM formulation.
        **kwargs: extra evaluator options (e.g. log, seed); unused.

    Returns:
        np.ndarray: the attribution map upsampled to input_shape[:-1], with the
        attribution's trailing channel axis kept (channels-last).
    """
    target = int(np.argmax(image_y, axis=0))

    # Build the batch on whichever device the network lives on instead of
    # assuming CUDA.
    device = next(net.parameters()).device
    # (D, H, W, C) -> (C, D, H, W) for torch, then prepend the batch axis.
    image_cdhw = np.transpose(image_x, axes=(3, 0, 1, 2))
    batch_x = torch.from_numpy(np.array([image_cdhw])).float().to(device).requires_grad_(True)

    attribution = attribution_method.attribute(
        batch_x, target=target, relu_attributions=relu_attributions)
    # Drop the batch axis and move the channel axis back to the end.
    attribution = attribution.detach().cpu().numpy()[0]
    attribution = np.transpose(attribution, axes=(1, 2, 3, 0))
    # The map has the spatial size of the conv3 output, so upsample it to the
    # input volume; skimage keeps the trailing channel axis unchanged.
    return resize(attribution, input_shape[:-1])

predict_fn = torch_predict(net)

In [8]:
# from src.heatmaps.evaluation import HeatmapEvaluationV2
# from src.data import tf_predict, torch_predict

# from captum.attr import GuidedBackprop, GuidedGradCam


# attribution_method = GuidedGradCam(net, net.conv3)
# attribution_method = GuidedBackprop(net)

# def heatmap_fn(image_x, image_y, **kwargs): # kwargs - log, seed
#     target = np.argmax(image_y, axis=0)
#     image_x = np.transpose(image_x, axes=(3, 0, 1, 2)) # transpose to the torch axes
#     batch_x = torch.from_numpy(np.array([image_x])).float().cuda().requires_grad_(True)
#     attribution = attribution_method.attribute(batch_x, int(target))
#     attribution = attribution.to('cpu').detach().numpy()[0]
#     attribution = np.transpose(np.array(attribution), axes=(1, 2, 3, 0)) # transpose back
#     return attribution

# predict_fn = torch_predict(net)

### 10 TP, 10 TN, 10 FP, 10 FN

In [9]:
from src.data import select_from_dataset, numpy_to_sequence

# Pick up to 10 test examples per outcome (TP, TN, FP, FN), as classified by the
# same torch network the heatmaps are computed for.
images_x, images_y, images_y_pred = select_from_dataset(
    torch_predict(net),
    test_seq,
    max_category=10,
)
print(images_x.shape)

# Wrap the selected arrays back into a sequence for HeatmapEvaluationV2.
sequence = numpy_to_sequence(
    images_x,
    images_y,
    batch_size=BATCH_SIZE,
)

tp: 10, tn: 10, fp: 10, fn: 10
(40, 104, 128, 104, 1)


In [10]:
%%time

he = HeatmapEvaluationV2(predict_fn, heatmap_fn, sequence, **heatmap_evaluation_options)

history = he.evaluate('insertion', log=LOG, verbose=VERBOSE, seed=SEED)

history.save(os.path.join(ROOT_DIR, 'risei-history'), f'{NOTEBOOK_KEY}-insertion-TP-TN-FP-FN')

sequence len: 40, method: insertion
evaluation 1/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:01
auc: 1014381.1405599117 (0:01:01s)

evaluation 2/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:03
auc: 869835.8499407768 (0:01:03s)

evaluation 3/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:01
auc: 1013504.7940611839 (0:01:02s)

evaluation 4/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:02
auc: 1064146.3383436203 (0:01:02s)

evaluation 5/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:01
auc: 900922.3870038986 (0:01:01s)

evalua

In [10]:
%%time

he = HeatmapEvaluationV2(predict_fn, heatmap_fn, sequence, **heatmap_evaluation_options)

history = he.evaluate('deletion', log=LOG, verbose=VERBOSE, seed=SEED)

history.save(os.path.join(ROOT_DIR, 'risei-history'), f'{NOTEBOOK_KEY}-deletion-TP-TN-FP-FN')

sequence len: 40, method: deletion
evaluation 1/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:00:53
auc: 509256.44458830357 (0:00:54s)

evaluation 2/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:00:56
auc: 633737.4076247215 (0:00:56s)

evaluation 3/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:00:57
auc: 760434.4952106476 (0:00:58s)

evaluation 4/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:00:54
auc: 750248.056024313 (0:00:55s)

evaluation 5/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:00:53
auc: 1087589.0807509422 (0:00:54s)

evaluatio

### 20 TP, 20 TN

In [11]:
from src.data import select_from_dataset, numpy_to_sequence

# Pick up to 20 correctly classified test examples per class; fp_max=0 and
# fn_max=0 exclude misclassified examples from the selection.
images_x, images_y, images_y_pred = select_from_dataset(
    torch_predict(net),
    test_seq,
    max_category=20,
    fp_max=0,
    fn_max=0,
)
print(images_x.shape)

# Wrap the selected arrays back into a sequence for HeatmapEvaluationV2.
sequence = numpy_to_sequence(
    images_x,
    images_y,
    batch_size=BATCH_SIZE,
)

tp: 20, tn: 20, fp: 0, fn: 0
(40, 104, 128, 104, 1)


In [12]:
%%time

he = HeatmapEvaluationV2(predict_fn, heatmap_fn, sequence, **heatmap_evaluation_options)

history = he.evaluate('insertion', log=LOG, verbose=VERBOSE, seed=SEED)

history.save(os.path.join(ROOT_DIR, 'risei-history'), f'{NOTEBOOK_KEY}-insertion-TP-TN')

sequence len: 40, method: insertion
evaluation 1/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:01
auc: 1014381.1405599117 (0:01:01s)

evaluation 2/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:03
auc: 869835.8499407768 (0:01:03s)

evaluation 3/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:04
auc: 1013504.7940611839 (0:01:04s)

evaluation 4/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:01
auc: 1064146.3383436203 (0:01:01s)

evaluation 5/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:02
auc: 900922.3870038986 (0:01:03s)

evalua

In [13]:
%%time

he = HeatmapEvaluationV2(predict_fn, heatmap_fn, sequence, **heatmap_evaluation_options)

history = he.evaluate('deletion', log=LOG, verbose=VERBOSE, seed=SEED)

history.save(os.path.join(ROOT_DIR, 'risei-history'), f'{NOTEBOOK_KEY}-deletion-TP-TN')

sequence len: 40, method: deletion
evaluation 1/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:00:52
auc: 509256.44458830357 (0:00:53s)

evaluation 2/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:00:55
auc: 633737.4076247215 (0:00:55s)

evaluation 3/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:00:57
auc: 760434.4952106476 (0:00:57s)

evaluation 4/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:00:56
auc: 750248.056024313 (0:00:56s)

evaluation 5/40
generating heatmap...
...finished in 0:00:00s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:00:54
auc: 1087589.0807509422 (0:00:54s)

evaluatio