# Experiment - Grad Cam (Captum)

In [None]:
%load_ext autoreload
%autoreload 2

import sys
# Extend the import search path four directory levels up so that the `src.*`
# imports further down resolve (presumably the repository root, relative to the
# notebook's working directory - TODO confirm).
sys.path.append('../../../..')

import os
# Restrict CUDA to the device with index 1. This is set before the project /
# deep-learning imports below so that it is already in place when CUDA is
# initialised.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import datetime

import numpy as np

from src.data import train_test_split, MRISequence
from src.model import create_model, compile_model, load_checkpoint
from src.model.evaluation import show_metrics

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

sns.set(style="white")

plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['image.cmap'] = 'viridis'

%config InlineBackend.figure_format='retina'
plt.rcParams.update({'font.size': 15})

## Setup

In [None]:
%%time

# Every path used by this notebook hangs off ROOT_DIR.
ROOT_DIR = '../../../../tmp'

# Checkpoint locations (primary and backup).
DEFAULT_CHECKPOINT_DIRECTORY_LOCAL = os.path.join(ROOT_DIR, 'checkpoints')
DEFAULT_BCKP_CHECKPOINT_DIRECTORY_LOCAL = os.path.join(ROOT_DIR, 'bckp-checkpoints')

# The "*_LOCAL" names are plain aliases of the corresponding directories.
LOG_DIRECTORY = LOG_DIRECTORY_LOCAL = os.path.join(ROOT_DIR, 'logs')
CHECKPOINT_DIRECTORY = CHECKPOINT_DIRECTORY_LOCAL = DEFAULT_CHECKPOINT_DIRECTORY_LOCAL

# Dataset directory name and its full path under ROOT_DIR.
DATA_DIR_NAME = 'data-v3'
DATA_DIR = os.path.join(ROOT_DIR, DATA_DIR_NAME)

# Source directory handed to train_test_split further down in this cell.
saliencies_and_segmentations_v2_path = os.path.join(ROOT_DIR, 'saliencies_and_segmentations_v2')

# Create the checkpoint and log directories if they are missing.
# os.makedirs(..., exist_ok=True) replaces the exists()/mkdir() pair: it has no
# check-then-create race and also creates a missing parent directory (ROOT_DIR)
# instead of raising FileNotFoundError.
for _directory in (CHECKPOINT_DIRECTORY, LOG_DIRECTORY):
    os.makedirs(_directory, exist_ok=True)

# Class labels passed to every MRISequence created in this cell.
class_names = ['AD', 'CN']

# When False, no separate validation sequence is built; val_seq simply aliases
# test_seq further down in this cell.
val = False

# get paths to data
# The three fractions presumably map to the returned directories in order
# (train, test, val) - TODO confirm against train_test_split.
_split_fractions = (0.8, 0.15, 0.05)
train_dir, test_dir, val_dir = train_test_split(
    saliencies_and_segmentations_v2_path,
    ROOT_DIR,
    split=_split_fractions,
    dirname=DATA_DIR_NAME,
)

# Batch size used by the MRI sequences.
batch_size = 12

# Shape handed to MRISequence as `input_shape`; the trailing 1 is the last axis
# that the heatmap functions move to the front for torch.
input_shape = (104, 128, 104, 1)  # (112, 112, 105, 1)

# Image pre-processing switches forwarded to MRISequence.
resize_img = crop_img = True

# Whether y is one-hot encoded (True) or a plain scalar label (False).
one_hot = True

# Class weights (see the analysis notebook).
class_weights = {
    0: 0.8072289156626505,
    1: 1.3137254901960784,
}

# Descriptive statistics of the dataset (pre-computed values).
desc = dict(
    mean=-3.6344006e-09,
    std=1.0000092,
    min=-1.4982183,
    max=10.744175,
)

# Fallback that recomputes `desc` from the training data.
# NOTE(review): as written this branch cannot run because `desc` is assigned
# just above; it only triggers if that assignment is commented out. It also
# calls `get_description`, which is not imported in the cells visible here.
if 'desc' not in locals():
    print('initializing desc...')
    _desc_seq = MRISequence(
        train_dir,
        64,
        class_names=class_names,
        input_shape=input_shape)
    desc = get_description(_desc_seq, max_samples=None)
    print(desc)


# Normalisation settings forwarded to MRISequence.
normalization = dict(type='normalization', desc=desc)
# normalization={'type':'standardization', 'desc':desc }

# Augmentations are switched off here; they can be enabled in the MRI sequence
# (otherwise they can be enabled in the dataset).
augmentations_inplace = True
augmentations = None
# augmentations={ 'random_swap_hemispheres': 0.5 }

# initialize sequences
# Keyword arguments that every MRISequence below receives unchanged.
_shared_seq_kwargs = dict(
    class_names=class_names,
    input_shape=input_shape,
    resize_img=resize_img,
    crop_img=crop_img,
    one_hot=one_hot,
    normalization=normalization,
)

# Training sequence: the only one that is given augmentation settings.
print('initializing train_seq...')
train_seq = MRISequence(
    train_dir,
    batch_size,
    augmentations=augmentations,
    augmentations_inplace=augmentations_inplace,
    class_weights=class_weights,
    **_shared_seq_kwargs,
)

# Test sequence: no class weights and no augmentations.
print('initializing test_seq...')
test_seq = MRISequence(test_dir, batch_size, **_shared_seq_kwargs)

# Validation sequence: reuse the test sequence unless `val` is enabled.
if not val:
    print('val_seq = test_seq')
    val_seq = test_seq
else:
    print('initializing val_seq...')
    val_seq = MRISequence(
        val_dir,
        batch_size,
        class_weights=class_weights,
        **_shared_seq_kwargs,
    )

# Timestamp-based key for this run; log_dir is the matching path under
# LOG_DIRECTORY (it is only computed and printed here, not created).
_run_started_at = datetime.datetime.now()
model_key = _run_started_at.strftime('%Y%m%d-%H%M%S')
log_dir = os.path.join(LOG_DIRECTORY, model_key)
print(f'log_dir: {log_dir}')

## Model

In [None]:
import torch

from torchsummary import summary

from src.model.torch import Net3DCNN, load_weights

In [None]:
# Build the PyTorch network and populate it from the exported weight arrays.
net = Net3DCNN()

# NOTE(review): allow_pickle=True makes np.load unpickle arbitrary objects, which
# can execute code. Only load these files if they come from a trusted source.
weights = np.load(os.path.join(ROOT_DIR, 'tf-weights.npy'), allow_pickle=True)
weights_bn = np.load(os.path.join(ROOT_DIR, 'tf-weights-bn.npy'), allow_pickle=True)

load_weights(net, weights, weights_bn)

# Inference mode, then move the model to the GPU once (a second .cuda() call on
# the same module is a no-op, so it is not repeated for `summary`).
net.eval()
net.cuda()

# Derive the channel-first size from `input_shape` (trailing axis moved to the
# front) instead of repeating the numbers, so the two cannot drift apart.
summary(net, (input_shape[-1], *input_shape[:-1]))

## Experiments

In [None]:
# Prefix used in the names of the saved evaluation histories.
NOTEBOOK_KEY = 'captum--grad-cam-v1'

# Run-wide settings passed to the evaluation calls further down.
SEED = 42
LOG = True
VERBOSE = 1
BATCH_SIZE = 24

# Keyword arguments expanded into every HeatmapEvaluationV2(...) call.
heatmap_evaluation_options = dict(
    evaluation_step_size=1000,
    evaluation_max_steps=-1,
    evaluation_batch_size=BATCH_SIZE,
)

In [None]:
from skimage.transform import resize

from src.heatmaps.evaluation import HeatmapEvaluationV2
from src.data import tf_predict, torch_predict

from captum.attr import LayerGradCam

import torch.nn.functional as F


# Grad-CAM attribution computed at the network's `conv3` layer; shared by both
# heatmap functions below.
attribution_method = LayerGradCam(forward_func=net, layer=net.conv3)


def heatmap_fn_v1(image_x, image_y, **kwargs): # kwargs - log, seed
    """Compute a Grad-CAM heatmap for one image, upsampled with skimage.

    Args:
        image_x: Single image as a 4-axis array whose last axis is moved to the
            front before it is given to torch (assumed to be the channel axis,
            matching `input_shape` - TODO confirm).
        image_y: Label vector for the image; the index of its maximum along
            axis 0 is used as the attribution target.
        **kwargs: Accepted for interface compatibility (log, seed) and ignored.

    Returns:
        The attribution for the target class, with the leading axis moved back
        to the end and resized with `skimage.transform.resize` to
        `input_shape[:-1]`.
    """
    class_index = int(np.argmax(image_y, axis=0))

    # Last axis first (torch layout), then wrap into a batch of one on the GPU.
    channels_first = np.transpose(image_x, axes=(3, 0, 1, 2))
    batch_x = torch.from_numpy(np.array([channels_first])).float().cuda().requires_grad_(True)

    cam = attribution_method.attribute(batch_x, class_index, relu_attributions=True)

    # Back to NumPy, drop the batch axis and restore the original axis order.
    cam = cam.to('cpu').detach().numpy()[0]
    cam = np.transpose(np.array(cam), axes=(1, 2, 3, 0))

    # The attribution has the resolution of the conv layer, so scale it up to
    # the input volume.
    return resize(cam, input_shape[:-1])


def heatmap_fn_v2(image_x, image_y, **kwargs): # kwargs - log, seed
    """Compute a Grad-CAM heatmap for one image, upsampled inside torch.

    Same pipeline as `heatmap_fn_v1`, except that the attribution is resized
    with trilinear `F.interpolate` (before leaving torch) to the spatial shape
    of the given image rather than to the global `input_shape`.

    Args:
        image_x: Single image as a 4-axis array whose last axis is moved to the
            front before it is given to torch (assumed to be the channel axis -
            TODO confirm).
        image_y: Label vector for the image; the index of its maximum along
            axis 0 is used as the attribution target.
        **kwargs: Accepted for interface compatibility (log, seed) and ignored.

    Returns:
        The upsampled attribution as a NumPy array with the leading axis moved
        back to the end.
    """
    class_index = int(np.argmax(image_y, axis=0))

    # Last axis first (torch layout), then wrap into a batch of one on the GPU.
    channels_first = np.transpose(image_x, axes=(3, 0, 1, 2))
    batch_x = torch.from_numpy(np.array([channels_first])).float().cuda().requires_grad_(True)

    cam = attribution_method.attribute(batch_x, class_index, relu_attributions=True)

    # The attribution has the resolution of the conv layer, so scale it up to
    # the three non-leading axes of the transposed image.
    spatial_size = channels_first.shape[1:]
    cam = F.interpolate(cam, spatial_size, mode='trilinear', align_corners=False)

    # Back to NumPy, drop the batch axis and restore the original axis order.
    cam = cam.to('cpu').detach().numpy()[0]
    return np.transpose(np.array(cam), axes=(1, 2, 3, 0))

# Prediction callable built from the torch model; used by the evaluation cells.
predict_fn = torch_predict(net)

# Heatmap implementation used by the evaluation cells (the skimage-resize
# variant); switch to heatmap_fn_v2 for the torch-interpolate variant.
heatmap_fn = heatmap_fn_v1

In [None]:
# from src.heatmaps.evaluation import HeatmapEvaluationV2
# from src.data import tf_predict, torch_predict

# from captum.attr import GuidedBackprop, GuidedGradCam


# attribution_method = GuidedGradCam(net, net.conv3)
# attribution_method = GuidedBackprop(net)

# def heatmap_fn(image_x, image_y, **kwargs): # kwargs - log, seed
#     target = np.argmax(image_y, axis=0)
#     image_x = np.transpose(image_x, axes=(3, 0, 1, 2)) # transpose to the torch axes
#     batch_x = torch.from_numpy(np.array([image_x])).float().cuda().requires_grad_(True)
#     attribution = attribution_method.attribute(batch_x, int(target))
#     attribution = attribution.to('cpu').detach().numpy()[0]
#     attribution = np.transpose(np.array(attribution), axes=(1, 2, 3, 0)) # transpose back
#     return attribution

# predict_fn = torch_predict(net)

### 10TP 10TN 10FP 10FN

In [None]:
from src.data import select_from_dataset, numpy_to_sequence

# Select samples from the test sequence using the model's predictions. Per the
# section heading this is meant to yield up to 10 each of TP/TN/FP/FN; the exact
# meaning of `max_category` is defined by select_from_dataset.
images_x, images_y, images_y_pred = select_from_dataset(
    torch_predict(net),
    test_seq,
    max_category=10,
)
print(images_x.shape)

# Wrap the selected arrays into a batched sequence for the evaluation cells.
sequence = numpy_to_sequence(
    images_x,
    images_y,
    batch_size=BATCH_SIZE,
)

In [None]:
%%time

# Run the 'insertion' evaluation on the currently selected `sequence` and store
# the resulting history under ROOT_DIR/risei-history.
he = HeatmapEvaluationV2(
    predict_fn,
    heatmap_fn,
    sequence,
    **heatmap_evaluation_options,
)
history = he.evaluate('insertion', log=LOG, verbose=VERBOSE, seed=SEED)

_history_dir = os.path.join(ROOT_DIR, 'risei-history')
history.save(_history_dir, f'{NOTEBOOK_KEY}-insertion-TP-TN-FP-FN')

In [None]:
%%time

# Run the 'deletion' evaluation on the currently selected `sequence` and store
# the resulting history under ROOT_DIR/risei-history.
he = HeatmapEvaluationV2(
    predict_fn,
    heatmap_fn,
    sequence,
    **heatmap_evaluation_options,
)
history = he.evaluate('deletion', log=LOG, verbose=VERBOSE, seed=SEED)

_history_dir = os.path.join(ROOT_DIR, 'risei-history')
history.save(_history_dir, f'{NOTEBOOK_KEY}-deletion-TP-TN-FP-FN')

### 20TP 20TN

In [None]:
from src.data import select_from_dataset, numpy_to_sequence

# Select samples from the test sequence using the model's predictions. Per the
# section heading this is meant to yield up to 20 TP and 20 TN, with the FP and
# FN limits set to 0; the exact meaning of these arguments is defined by
# select_from_dataset.
images_x, images_y, images_y_pred = select_from_dataset(
    torch_predict(net),
    test_seq,
    max_category=20,
    fp_max=0,
    fn_max=0,
)
print(images_x.shape)

# Wrap the selected arrays into a batched sequence for the evaluation cells.
sequence = numpy_to_sequence(
    images_x,
    images_y,
    batch_size=BATCH_SIZE,
)

In [None]:
%%time

# Run the 'insertion' evaluation on the currently selected `sequence` and store
# the resulting history under ROOT_DIR/risei-history.
he = HeatmapEvaluationV2(
    predict_fn,
    heatmap_fn,
    sequence,
    **heatmap_evaluation_options,
)
history = he.evaluate('insertion', log=LOG, verbose=VERBOSE, seed=SEED)

_history_dir = os.path.join(ROOT_DIR, 'risei-history')
history.save(_history_dir, f'{NOTEBOOK_KEY}-insertion-TP-TN')

In [None]:
%%time

# Run the 'deletion' evaluation on the currently selected `sequence` and store
# the resulting history under ROOT_DIR/risei-history.
he = HeatmapEvaluationV2(
    predict_fn,
    heatmap_fn,
    sequence,
    **heatmap_evaluation_options,
)
history = he.evaluate('deletion', log=LOG, verbose=VERBOSE, seed=SEED)

_history_dir = os.path.join(ROOT_DIR, 'risei-history')
history.save(_history_dir, f'{NOTEBOOK_KEY}-deletion-TP-TN')