# Heatmap Parameters Analysis

In [1]:
# Load the autoreload extension; mode 2 reloads imported modules before each
# cell execution so edits to project code are picked up without a restart.
%load_ext autoreload
%autoreload 2

import sys
# Add the directory four levels above this notebook to the import path.
# Presumably this is the repository root that contains the `src` package
# imported below — TODO confirm.
sys.path.append('../../../..')

import os
# Expose only the CUDA device with index 1 to this process. This is set before
# any TensorFlow import in this notebook.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import datetime

import numpy as np
import pandas as pd

from src.data import train_test_split, MRISequence
from src.model import create_model, compile_model, load_checkpoint
from src.model.evaluation import show_metrics

In [2]:
import seaborn as sns
import matplotlib.pyplot as plt

# Apply the seaborn "white" style first; the rcParams overrides below are
# written afterwards, in this order.
sns.set(style="white")

# Default figure size and colormap for every plot in the notebook.
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['image.cmap'] = 'viridis'

# Ask the inline backend for 'retina' figure output, then set the base font size.
%config InlineBackend.figure_format='retina'
plt.rcParams.update({'font.size': 15})

In [3]:
import tensorflow as tf

# RANDOM_SEED = 250398
# tf.random.set_seed(RANDOM_SEED)

# Report the TensorFlow version and the number of GPUs it can see.
print(tf.version.VERSION)
# Use the stable `tf.config.list_physical_devices`; the `experimental`
# alias used previously is deprecated.
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

2.3.2
Num GPUs Available:  1


## Setup

In [4]:
%%time

# All paths used by this notebook are derived from ROOT_DIR.
ROOT_DIR = '../../../../tmp'
DEFAULT_CHECKPOINT_DIRECTORY_LOCAL = os.path.join(ROOT_DIR, 'checkpoints')
DEFAULT_BCKP_CHECKPOINT_DIRECTORY_LOCAL = os.path.join(ROOT_DIR, 'bckp-checkpoints')

LOG_DIRECTORY = os.path.join(ROOT_DIR, 'logs')
CHECKPOINT_DIRECTORY = DEFAULT_CHECKPOINT_DIRECTORY_LOCAL

# The *_LOCAL names are aliases of the directories above.
LOG_DIRECTORY_LOCAL = LOG_DIRECTORY
CHECKPOINT_DIRECTORY_LOCAL = CHECKPOINT_DIRECTORY

DATA_DIR_NAME = 'data-v3'
DATA_DIR = os.path.join(ROOT_DIR, DATA_DIR_NAME)

saliencies_and_segmentations_v2_path = os.path.join(ROOT_DIR, 'saliencies_and_segmentations_v2')

# makedirs(exist_ok=True) replaces the "check, then mkdir" pair: it has no
# window between the check and the creation, and it also creates ROOT_DIR
# itself when it does not exist yet (os.mkdir would raise in that case).
os.makedirs(CHECKPOINT_DIRECTORY, exist_ok=True)
os.makedirs(LOG_DIRECTORY, exist_ok=True)

# When False, no separate validation sequence is built and the test sequence
# is reused as `val_seq` further down in this cell.
val = False

# Class labels handed to MRISequence (presumably Alzheimer's disease vs.
# cognitively normal — TODO confirm).
class_names = ['AD', 'CN']

# Get the paths to the train / test / validation data directories.
# NOTE(review): which fraction of `split` maps to which returned directory is
# decided inside train_test_split — confirm against its implementation.
train_dir, test_dir, val_dir = train_test_split(
    saliencies_and_segmentations_v2_path,
    ROOT_DIR,
    split=(0.8, 0.15, 0.05),
    dirname=DATA_DIR_NAME)

# Batch size and image-shape options passed to every MRISequence below.
batch_size = 12
input_shape = (104, 128, 104, 1) # (112, 112, 105, 1)
resize_img = True
crop_img = True

# Whether y is one-hot encoded (True) or a scalar class number (False).
one_hot = True

# Class weights (see the analysis notebook).
class_weights = {0: 0.8072289156626505, 1: 1.3137254901960784}

# Descriptive statistics of the dataset, hard-coded here.
desc = {'mean': -3.6344006e-09, 'std': 1.0000092, 'min': -1.4982183, 'max': 10.744175}

# Fallback that recomputes the dataset statistics from the training data.
# NOTE(review): `desc` is assigned unconditionally just above, so this guard
# is never true when the cell runs top to bottom and the branch is skipped.
# `get_description` is also not imported in any cell visible in this notebook;
# confirm where it lives and import it before relying on this path.
if 'desc' not in locals():
    print('initializing desc...')
    desc = get_description(MRISequence(
        train_dir,
        64,
        class_names=class_names,
        input_shape=input_shape),
        max_samples=None)
    print(desc)


# Normalization settings handed to every MRISequence; `desc` holds the
# dataset statistics defined earlier in this cell.
normalization = {'type': 'normalization', 'desc': desc}
# normalization = {'type': 'standardization', 'desc': desc}

# Augmentations are disabled here. They can be enabled in the MRI sequence
# (otherwise they can be enabled in the dataset):
# augmentations = {'random_swap_hemispheres': 0.5}
augmentations = None
augmentations_inplace = True

# Keyword arguments that are identical for every sequence built below.
_shared_seq_kwargs = dict(
    class_names=class_names,
    input_shape=input_shape,
    resize_img=resize_img,
    crop_img=crop_img,
    one_hot=one_hot,
    normalization=normalization,
)

# Training sequence: the only one that receives the augmentation settings.
print('initializing train_seq...')
train_seq = MRISequence(
    train_dir,
    batch_size,
    augmentations=augmentations,
    augmentations_inplace=augmentations_inplace,
    class_weights=class_weights,
    **_shared_seq_kwargs)

# Test sequence: built without class weights.
print('initializing test_seq...')
test_seq = MRISequence(test_dir, batch_size, **_shared_seq_kwargs)

# Validation sequence: reuse the test sequence unless `val` is set.
if not val:
    print('val_seq = test_seq')
    val_seq = test_seq
else:
    print('initializing val_seq...')
    val_seq = MRISequence(
        val_dir,
        batch_size,
        class_weights=class_weights,
        **_shared_seq_kwargs)

# Timestamp-based key identifying this run; it also names the run's log directory.
model_key = format(datetime.datetime.now(), '%Y%m%d-%H%M%S')
log_dir = os.path.join(LOG_DIRECTORY, model_key)
print(f'log_dir: {log_dir}')

not copying files since the destination directory already exists
initializing train_seq...
initializing test_seq...
val_seq = test_seq
log_dir: ../../../../tmp\logs\20210504-183338
Wall time: 0 ns


## Analysis

Each experiment consisted of 10 images: 5 true positives (TP) and 5 true negatives (TN).

In [6]:
from os import listdir
from os.path import isfile, join

# Directory with the saved heatmap-evaluation histories of this experiment.
fpath = os.path.join(ROOT_DIR, "risei-history/heatmap-parameters--b1-1-b2-0")

# os.listdir returns entries in arbitrary order; sorting makes the file list,
# and the row order of anything built from it, reproducible across runs and
# platforms.
files = sorted(f for f in listdir(fpath) if isfile(join(fpath, f)))
files[:5]

['hmap-parameters--deletion--m+1024-p1+0.25.cls',
 'hmap-parameters--deletion--m+1024-p1+0.3333333333333333.cls',
 'hmap-parameters--deletion--m+1024-p1+0.5.cls',
 'hmap-parameters--deletion--m+1024-p1+0.6666666666666666.cls',
 'hmap-parameters--deletion--m+1024-p1+0.75.cls']

In [7]:
import re


# Raw string: `\w`, `\d`, `\+` and `\.` are regex escapes, not Python string
# escapes, and a non-raw literal triggers invalid-escape warnings.
_HMAP_FNAME_PATTERN = re.compile(r"^hmap-parameters--(\w+)--m\+(\d+)-p1\+(\d+[.]?\d*)\.cls$")


def parse(fname):
    """Split a heatmap-parameters history file name into its components.

    Expected form: ``hmap-parameters--<metric>--m+<masks_count>-p1+<p1>.cls``.

    Args:
        fname: File name (without directory) to parse.

    Returns:
        A tuple ``(metric, masks_count, p1)`` of strings, exactly as they
        appear in the file name; no numeric conversion is done here.

    Raises:
        ValueError: If ``fname`` does not follow the expected form.
    """
    match = _HMAP_FNAME_PATTERN.match(fname)
    if match is None:
        # Fail with the offending name instead of an AttributeError on None.
        raise ValueError(f"unexpected heatmap-parameters file name: {fname!r}")
    return match.groups()

print(parse('hmap-parameters--deletion--m+1024-p1+0.6666666666666666.cls'))

('deletion', '1024', '0.6666666666666666')


In [8]:
from src.heatmaps.evaluation import HeatmapEvaluationHistory

# Column name -> list of values; turned into a DataFrame once it is filled.
data = {}


def append(key, value):
    """Add `value` to the list kept under `key` in `data`, creating the list on first use."""
    data.setdefault(key, []).append(value)

    
# Build one table row per history file: the parameters encoded in the file
# name plus the summary values reported by the loaded history.
for fname in files:
    metric, masks_count, p1 = parse(fname)
    append('metric', metric)
    append('masks_count', int(masks_count))
    append('p1', float(p1))

    # Drop the extension with splitext instead of a hard-coded 4-character slice.
    history = HeatmapEvaluationHistory.load(fpath, os.path.splitext(fname)[0])
    # Kept in `history_desc` so the global `desc` (the dataset statistics used
    # for normalization in the setup cell) is not overwritten by this loop.
    history_desc = history._description()
    for key, value in history_desc.items():
        append(key, value)


df = pd.DataFrame(data=data)
df.head()

Unnamed: 0,metric,masks_count,p1,heatmaps,auc_mean,auc_p25,auc_median,auc_p75,auc_max,auc_min,auc_std
0,deletion,1024,0.25,10,0.643593,0.563228,0.655659,0.724473,0.749755,0.538872,0.080467
1,deletion,1024,0.333333,10,0.640157,0.588006,0.627699,0.707924,0.738993,0.53058,0.068003
2,deletion,1024,0.5,10,0.654849,0.626846,0.657283,0.689328,0.710017,0.58455,0.038588
3,deletion,1024,0.666667,10,0.623226,0.580452,0.601526,0.653028,0.756338,0.51537,0.071507
4,deletion,1024,0.75,10,0.599749,0.545744,0.586926,0.666345,0.699307,0.496409,0.068916


In [9]:
def table(metric, value):
    """Pivot one summary column of `df` for a single metric.

    Args:
        metric: Value of the 'metric' column to keep (e.g. 'deletion').
        value: Name of the column whose values fill the table.

    Returns:
        A DataFrame indexed by 'masks_count' with one column per 'p1' value;
        combinations without data are filled with 0.
    """
    selected = df.loc[df['metric'] == metric].sort_values('masks_count')
    return selected.pivot_table(
        values=value,
        index=["masks_count"],
        columns="p1",
        fill_value=0,
    )

In [10]:
# 'auc_median' of the deletion metric, one row per masks_count and one column per p1.
table('deletion', 'auc_median')

p1,0.250000,0.333333,0.500000,0.666667,0.750000
masks_count,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
8,0.568857,0.560997,0.539627,0.574222,0.526298
16,0.569358,0.547365,0.518288,0.533546,0.510874
32,0.621808,0.579686,0.618804,0.59145,0.604745
64,0.600704,0.61524,0.6123,0.604009,0.602021
128,0.60569,0.570218,0.598747,0.608839,0.523231
256,0.627778,0.622465,0.605179,0.578617,0.588248
512,0.653583,0.622137,0.646226,0.605187,0.604104
1024,0.655659,0.627699,0.657283,0.601526,0.586926
2048,0.65049,0.621239,0.647952,0.618868,0.591306


In [11]:
# 'auc_median' of the insertion metric, one row per masks_count and one column per p1.
table('insertion', 'auc_median')

p1,0.250000,0.333333,0.500000,0.666667,0.750000
masks_count,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
8,0.57945,0.632806,0.572831,0.602406,0.627229
16,0.577185,0.623981,0.635889,0.646753,0.632246
32,0.596792,0.543668,0.557625,0.57027,0.588991
64,0.52921,0.526029,0.560768,0.554994,0.558734
128,0.554642,0.574473,0.605209,0.57386,0.556533
256,0.526284,0.504812,0.524492,0.544606,0.572755
512,0.508116,0.525889,0.514605,0.535332,0.567616
1024,0.484499,0.498557,0.536745,0.543785,0.554009
2048,0.492035,0.510781,0.536827,0.544496,0.552609


In [12]:
# Deletion and insertion 'auc_median' side by side: columns are (p1, metric)
# pairs, rows are masks_count. sort_values already returns a new frame, so no
# explicit copy is needed.
df_m = df.sort_values("masks_count")
df_m.pivot_table(
    values="auc_median",
    index=["masks_count"],
    columns=["p1", "metric"],
    fill_value=0,
)

p1,0.250000,0.250000,0.333333,0.333333,0.500000,0.500000,0.666667,0.666667,0.750000,0.750000
metric,deletion,insertion,deletion,insertion,deletion,insertion,deletion,insertion,deletion,insertion
masks_count,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2
8,0.568857,0.57945,0.560997,0.632806,0.539627,0.572831,0.574222,0.602406,0.526298,0.627229
16,0.569358,0.577185,0.547365,0.623981,0.518288,0.635889,0.533546,0.646753,0.510874,0.632246
32,0.621808,0.596792,0.579686,0.543668,0.618804,0.557625,0.59145,0.57027,0.604745,0.588991
64,0.600704,0.52921,0.61524,0.526029,0.6123,0.560768,0.604009,0.554994,0.602021,0.558734
128,0.60569,0.554642,0.570218,0.574473,0.598747,0.605209,0.608839,0.57386,0.523231,0.556533
256,0.627778,0.526284,0.622465,0.504812,0.605179,0.524492,0.578617,0.544606,0.588248,0.572755
512,0.653583,0.508116,0.622137,0.525889,0.646226,0.514605,0.605187,0.535332,0.604104,0.567616
1024,0.655659,0.484499,0.627699,0.498557,0.657283,0.536745,0.601526,0.543785,0.586926,0.554009
2048,0.65049,0.492035,0.621239,0.510781,0.647952,0.536827,0.618868,0.544496,0.591306,0.552609


In [13]:
# Without a metric filter, each (masks_count, p1) cell aggregates both the
# insertion and the deletion row, so the value shown is the mean of their
# 'auc_median'. The default aggregation is spelled out to make that explicit.
# NOTE(review): deletion and insertion AUC usually have opposite "better"
# directions — confirm that their mean is the intended summary.
df_m = df.sort_values("masks_count")
df_m.pivot_table(
    values="auc_median",
    index=["masks_count"],
    columns="p1",
    aggfunc="mean",
    fill_value=0,
)

p1,0.250000,0.333333,0.500000,0.666667,0.750000
masks_count,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
8,0.574154,0.596902,0.556229,0.588314,0.576764
16,0.573271,0.585673,0.577089,0.59015,0.57156
32,0.6093,0.561677,0.588215,0.58086,0.596868
64,0.564957,0.570634,0.586534,0.579502,0.580378
128,0.580166,0.572346,0.601978,0.591349,0.539882
256,0.577031,0.563638,0.564836,0.561612,0.580501
512,0.58085,0.574013,0.580415,0.570259,0.58586
1024,0.570079,0.563128,0.597014,0.572655,0.570468
2048,0.571262,0.56601,0.592389,0.581682,0.571957
