# RISEI Parameters Experiment

In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../../../..')

import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import datetime

import numpy as np

from src.data import train_test_split, MRISequence
from src.model import create_model, compile_model, load_checkpoint
from src.model.evaluation import show_metrics

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

sns.set(style="white")

plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['image.cmap'] = 'viridis'

%config InlineBackend.figure_format='retina'
plt.rcParams.update({'font.size': 15})

In [None]:
import tensorflow as tf

# RANDOM_SEED = 250398
# tf.random.set_seed(RANDOM_SEED)

# Report the TensorFlow version and how many GPUs it can see.
# `tf.config.list_physical_devices` is the stable name of the former
# `tf.config.experimental.list_physical_devices`.
print(tf.version.VERSION)
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

## Setup

In [None]:
%%time

# Every path used by this notebook is derived from ROOT_DIR.
ROOT_DIR = '../../../../tmp'
DEFAULT_CHECKPOINT_DIRECTORY_LOCAL = os.path.join(ROOT_DIR, 'checkpoints')
DEFAULT_BCKP_CHECKPOINT_DIRECTORY_LOCAL = os.path.join(ROOT_DIR, 'bckp-checkpoints')

LOG_DIRECTORY = os.path.join(ROOT_DIR, 'logs')
CHECKPOINT_DIRECTORY = DEFAULT_CHECKPOINT_DIRECTORY_LOCAL

LOG_DIRECTORY_LOCAL = LOG_DIRECTORY
CHECKPOINT_DIRECTORY_LOCAL = CHECKPOINT_DIRECTORY

DATA_DIR_NAME = 'data-v3'
DATA_DIR = os.path.join(ROOT_DIR, DATA_DIR_NAME)

saliencies_and_segmentations_v2_path = os.path.join(ROOT_DIR, 'saliencies_and_segmentations_v2')

# makedirs(exist_ok=True) also creates missing parents and does not fail when the
# directory already exists, so no separate existence check is needed.
os.makedirs(CHECKPOINT_DIRECTORY, exist_ok=True)
os.makedirs(LOG_DIRECTORY, exist_ok=True)

# When False, no separate validation sequence is built.
val = False

class_names = ['AD', 'CN']

# get paths to data
train_dir, test_dir, val_dir = train_test_split(
    saliencies_and_segmentations_v2_path,
    ROOT_DIR,
    split=(0.8, 0.15, 0.05),
    dirname=DATA_DIR_NAME)

# set the batch size for mri seq
batch_size = 12
input_shape = (104, 128, 104, 1) # (112, 112, 105, 1)
resize_img = True
crop_img = True

# if y is one-hot encoded or just scalar number
one_hot = True

# class weights (see analysis notebook)
class_weights = {0: 0.8072289156626505, 1: 1.3137254901960784}

# description statistics of the dataset (hard-coded)
# The former `if 'desc' not in locals()` fallback was removed: it could never run
# because `desc` is assigned right here, and it called `get_description`, which this
# notebook does not import. To recompute the statistics, import `get_description` and
# call it as
#   get_description(MRISequence(train_dir, 64, class_names=class_names,
#                               input_shape=input_shape), max_samples=None)
desc = {'mean': -3.6344006e-09, 'std': 1.0000092, 'min': -1.4982183, 'max': 10.744175}


normalization={ 'type':'normalization', 'desc': desc }
# normalization={'type':'standardization', 'desc':desc }

augmentations = None
augmentations_inplace = True
# enable augmentations in mri seq (otherwise it can be enabled in dataset)
# augmentations={ 'random_swap_hemispheres': 0.5 }

# Keyword arguments shared by every MRISequence created in this cell.
_sequence_kwargs = dict(
    class_names=class_names,
    input_shape=input_shape,
    resize_img=resize_img,
    crop_img=crop_img,
    one_hot=one_hot,
    normalization=normalization,
)

# Training sequence: the only one that receives the augmentation settings.
print('initializing train_seq...')
train_seq = MRISequence(
    train_dir,
    batch_size,
    augmentations=augmentations,
    augmentations_inplace=augmentations_inplace,
    class_weights=class_weights,
    **_sequence_kwargs)

# Test sequence: no class weights, no augmentations.
print('initializing test_seq...')
test_seq = MRISequence(test_dir, batch_size, **_sequence_kwargs)

# Validation sequence: the test sequence is reused unless `val` is set.
if not val:
    print('val_seq = test_seq')
    val_seq = test_seq
else:
    print('initializing val_seq...')
    val_seq = MRISequence(
        val_dir,
        batch_size,
        class_weights=class_weights,
        **_sequence_kwargs)

# Timestamp key identifying this run; also names its log directory.
model_key = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
log_dir = os.path.join(LOG_DIRECTORY, model_key)
print(f'log_dir: {log_dir}')

In [None]:
# Initial output bias (pos / neg), following
# https://www.tensorflow.org/tutorials/structured_data/imbalanced_data#class_weights
initial_bias = np.log([159/243, 243/159])

model_type = '3d_cnn'
model_config = dict(
    input_shape=input_shape,
    class_names=class_names,
    l2_beta=None,  # other values tried: 0.001, 0.0005
    dropout=0.10,  # other value tried: 0.05
    output_bias=initial_bias,  # other value tried: None
    # https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization
    batch_norm=True,
    is_complex=False,  # a complex layer from the paper, max batch_size is 3
)

# Keyword arguments unpacked into compile_model().
compile_config = dict(
    learning_rate=0.00010,  # default is 0.001; other value tried: 0.000075
    decay_steps=25,
    decay_rate=0.96,
    beta_1=0.90,  # other value tried: 0.85
    beta_2=0.999,  # other value tried: 0.990
)

# Training settings, keyed by the run's model_key.
train_config = dict(
    model_key=model_key,
    epochs=150,
    patience=75,
    tensorboard_update_freq='epoch',
    mri_tensorboard_callback=False,
    # other value tried: False
    model_checkpoint_callback=dict(monitor='val_auc', mode='max', save_best_only=True),
    early_stopping_monitor=dict(monitor='val_auc', mode='max'),
    # other value tried: False
    augmentations=dict(
        invert=(0.5, None),
        rotate=(0.2, 5),  # probability, degrees
        zoom=(0., 0.),
        shear=(0.2, 0.5),  # probability, degrees
        blur=(0.2, 0.85),
        noise=(0.2, 0.00020),
    ),
    batch_size=8,
)

## Model

In [None]:
# Create the model from model_type / model_config.
model = create_model(model_type, model_config)
# compile_model returns a sequence; only its first element (the model) is kept.
model, *_ = compile_model(model, **compile_config)
# NOTE(review): input_shape is (104, 128, 104, 1), i.e. it has no batch axis, while
# Keras Model.build normally expects the batch dimension to be included — confirm.
model.build(input_shape=input_shape)
model.summary()

In [None]:
# Load previously trained weights into `model` from the backup checkpoint directory.
# NOTE(review): '20210308-175324' is presumably the model_key of the training run and
# 'cp-0058.ckpt' the checkpoint file inside it — confirm against load_checkpoint.
load_checkpoint(model, DEFAULT_BCKP_CHECKPOINT_DIRECTORY_LOCAL, '20210308-175324', 'cp-0058.ckpt')

In [None]:
%%time

# Just to test whether it is OK to clear the session after loading the weights.
tf.keras.backend.clear_session()

show_metrics(model, test_seq, class_names)

## Experiments

In [None]:
import time

from src.heatmaps.evaluation import get_heatmap
from src.risei import RISEI
from src.data import tf_predict

In [None]:
from src.data import select_from_dataset, numpy_to_sequence

# Batch size of the evaluation sequence built at the end of this cell.
BATCH_SIZE = 24

# Select the images to evaluate from the test sequence, using the model's predictions.
# NOTE(review): max_category=5 with fn_max=0 / fp_max=0 presumably keeps up to five
# images per class and no false negatives / false positives — confirm against
# select_from_dataset.
images_x, images_y, images_y_pred = select_from_dataset(
    tf_predict(model),
    test_seq,
    max_category=5,
    fn_max=0,
    fp_max=0,
)
print(images_x.shape)

# Wrap the selected arrays in a sequence for the heatmap evaluation.
sequence = numpy_to_sequence(images_x, images_y, batch_size=BATCH_SIZE)

In [None]:
# Constants read by the heatmap callback and by run().
BATCH_SIZE = 24
VERBOSE = 1
LOG = True
SEED = 42
RISEI_BATCH_SIZE = 480

# Keyword arguments unpacked into HeatmapEvaluationV2.
heatmap_evaluation_options = dict(
    evaluation_step_size=1000,
    evaluation_max_steps=-1,
    evaluation_batch_size=BATCH_SIZE,
)

# risei config: base keyword arguments for RISEI; `p1` is overridden per experiment.
risei_config = dict(
    s=8,
    p1=1/3,
    b1=0.5,
    b2=1,
    b2_value=0,
    in_paint='2d',
    in_paint_blending=True,
    in_paint_radius=5,
    in_paint_2d_to_3d=True,
    processes=8,
)

In [None]:
import itertools

# Values swept for the `p1` entry of the RISEI config.
p1_options = [1/4, 1/3, 1/2, 2/3, 3/4]

# Mask counts swept. The full list is kept rather than being overwritten with
# `[1024, 2048]`: the run cells address `experiments` by index (`experiments[30:40]`,
# `experiments[40:]`), and those slices are only non-empty when all
# 9 mask counts x 5 p1 values = 45 experiments exist.
masks_counts = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]

def map_fn(config, value):
    """Return a copy of ``config`` whose ``'p1'`` entry is set to ``value``.

    The mapping passed in is left untouched, so a shared base config can be handed
    over without the caller having to copy it first.

    Args:
        config: dict of RISEI keyword arguments.
        value: new value for the ``'p1'`` key.

    Returns:
        A new dict holding every entry of ``config`` with ``'p1'`` replaced by
        (or, if absent, extended with) ``value``.
    """
    return {**config, 'p1': value}

# One entry per (masks_count, p1) pair, with masks_count varying slowest:
#   (masks_count, history key "m+<masks_count>-p1+<p1>", RISEI config using that p1)
experiments = [
    (
        masks_count,
        f"m+{masks_count}-p1+{value}",
        map_fn(risei_config.copy(), value),
    )
    for masks_count, value in itertools.product(masks_counts, p1_options)
]

In [None]:
# Show the first experiment tuple: (masks_count, history key, RISEI config).
experiments[0]

In [None]:
from src.heatmaps.evaluation import HeatmapEvaluationV2
from src.heatmaps.heatmaps import get_heatmap
from src.data import tf_predict, torch_predict
from src.risei import RISEI


def get_heatmap_fn(risei, masks_count):
    """Build the heatmap callback used by the heatmap evaluation.

    Args:
        risei: RISEI instance forwarded to ``get_heatmap``.
        masks_count: number of masks forwarded to ``get_heatmap``.

    Returns:
        A function ``heatmap_fn(image_x, image_y, **kwargs)`` that calls
        ``get_heatmap`` with the module-level ``model``, ``BATCH_SIZE`` and
        ``RISEI_BATCH_SIZE`` and returns the heatmap reshaped to the module-level
        ``input_shape``. Recognised kwargs are ``seed``, ``evaluation_idx`` and ``log``.
    """
    def heatmap_fn(image_x, image_y, **kwargs): # kwargs - log, seed, evaluation_idx
        seed = kwargs.get('seed')
        log = kwargs.get('log')
        # A missing evaluation_idx counts as offset 0; previously a seed without an
        # evaluation_idx raised TypeError on `seed + None`.
        evaluation_idx = kwargs.get('evaluation_idx') or 0
        # Each evaluated image gets its own seed, derived from the base seed.
        heatmap_seed = None if seed is None else seed + evaluation_idx
        print(f"generating heatmap (masks_count={masks_count}; seed={heatmap_seed})")
        heatmap, _, _ = get_heatmap(
            image_x,
            image_y,
            model,
            risei,
            batch_size=BATCH_SIZE,
            masks_count=masks_count,
            risei_batch_size=RISEI_BATCH_SIZE,
            debug=False,
            seed=heatmap_seed,
            log=log
        )
        return heatmap.reshape(input_shape)
    return heatmap_fn

# Prediction function built from the loaded model; run() passes it to HeatmapEvaluationV2.
predict_fn = tf_predict(model)

In [None]:
def run(experiments):
    """Evaluate every experiment with insertion and deletion and save both histories.

    For each experiment a RISEI instance is created from its config, a heatmap
    callback is built for it, and ``HeatmapEvaluationV2`` is run over the
    module-level ``sequence`` once per method. Each history is saved under
    ``<ROOT_DIR>/risei-history/hmap-parameters`` with the name
    ``hmap-parameters--<method>--<history_fname>``. The Keras session is cleared
    after every experiment.

    Args:
        experiments: iterable of ``(masks_count, history_fname, risei_config)`` tuples.
    """
    # Loop-invariant. Joined component by component so the path uses a single
    # separator style on every OS (an embedded '/' gave mixed separators on Windows).
    history_dir = os.path.join(ROOT_DIR, 'risei-history', 'hmap-parameters')

    for masks_count, history_fname, risei_config in experiments:
        print(f'\n*********\nrunning experiment: {history_fname}\n---------\n')
        # RISEI takes the input shape without its last axis.
        risei = RISEI(input_shape[:-1], debug=False, **risei_config)
        heatmap_fn = get_heatmap_fn(risei, masks_count)
        evaluation = HeatmapEvaluationV2(predict_fn, heatmap_fn, sequence, **heatmap_evaluation_options)

        for method in ['insertion', 'deletion']:
            print('\n')
            history = evaluation.evaluate(method, log=LOG, verbose=VERBOSE, seed=SEED)
            history.save(history_dir, f'hmap-parameters--{method}--{history_fname}')

        tf.keras.backend.clear_session()

In [None]:
# Each experiment is evaluated with two methods (insertion and deletion), hence the factor 2.
print(f'number of experiments: {len(experiments) * 2}')

In [None]:
# %%time

# run(experiments[0:])

In [None]:
%%time

# Run the experiments from index 40 onwards.
# NOTE(review): `experiments` holds len(masks_counts) * len(p1_options) entries; if that
# is 40 or fewer this slice is empty and nothing runs — confirm masks_counts first.
run(experiments[40:])

In [17]:
# %%time

# run(experiments[0:])


*********
running experiment: m+128-p1+0.25
---------



sequence len: 10, method: insertion
evaluation 1/10
generating heatmap...
generating heatmap (masks_count=128; seed=42)
...finished in 0:02:08s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:02:07
auc: 806009.8181664944 (0:04:15s)

evaluation 2/10
generating heatmap...
generating heatmap (masks_count=128; seed=43)
...finished in 0:02:06s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:02:07
auc: 701537.1916890144 (0:04:13s)

evaluation 3/10
generating heatmap...
generating heatmap (masks_count=128; seed=44)
...finished in 0:02:05s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:02:04
auc: 756034.3935489655 (0:04:10s)

evaluation 4/10
generating heatmap...
generating heatmap (masks_count=128; seed=45)
...finished in 0:02:03s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0

...finished in 0:01:56s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:45
auc: 838956.3507735729 (0:03:41s)

evaluation 4/10
generating heatmap...
generating heatmap (masks_count=128; seed=45)
...finished in 0:01:55s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:44
auc: 985945.8688795567 (0:03:40s)

evaluation 5/10
generating heatmap...
generating heatmap (masks_count=128; seed=46)
...finished in 0:01:55s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:43
auc: 950867.8500950336 (0:03:39s)

evaluation 6/10
generating heatmap...
generating heatmap (masks_count=128; seed=47)
...finished in 0:01:55s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:44
auc: 720297.0500662923 (0:03:39s)

evaluation 7/10
generating heatmap...
generating heatmap (masks_count=128; seed=48)
...finished in 0:01:55s
evaluate heatmaps (voxels: 1

...finished in 0:02:04
auc: 828830.8981359005 (0:03:18s)

evaluation 6/10
generating heatmap...
generating heatmap (masks_count=128; seed=47)
...finished in 0:01:15s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:02:04
auc: 625998.466938734 (0:03:20s)

evaluation 7/10
generating heatmap...
generating heatmap (masks_count=128; seed=48)
...finished in 0:01:15s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:02:05
auc: 691900.8631855249 (0:03:21s)

evaluation 8/10
generating heatmap...
generating heatmap (masks_count=128; seed=49)
...finished in 0:01:14s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:02:07
auc: 383038.49881887436 (0:03:21s)

evaluation 9/10
generating heatmap...
generating heatmap (masks_count=128; seed=50)
...finished in 0:01:14s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:02:04
auc: 1092576.565682888 (0:03:

...finished in 0:01:02s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:44
auc: 753555.9858083725 (0:02:46s)

evaluation 9/10
generating heatmap...
generating heatmap (masks_count=128; seed=50)
...finished in 0:01:02s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:44
auc: 754902.2604525089 (0:02:46s)

evaluation 10/10
generating heatmap...
generating heatmap (masks_count=128; seed=51)
...finished in 0:01:02s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:44
auc: 929705.2299082279 (0:02:47s)

saved to: ../../../../tmp\risei-history/hmap-parameters\hmap-parameters--deletion--m+128-p1+0.75.cls

*********
running experiment: m+256-p1+0.25
---------



sequence len: 10, method: insertion
evaluation 1/10
generating heatmap...
generating heatmap (masks_count=256; seed=42)
...finished in 0:04:04s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)..

...finished in 0:02:05
auc: 1078419.9425280094 (0:05:50s)

saved to: ../../../../tmp\risei-history/hmap-parameters\hmap-parameters--insertion--m+256-p1+0.3333333333333333.cls


sequence len: 10, method: deletion
evaluation 1/10
generating heatmap...
generating heatmap (masks_count=256; seed=42)
...finished in 0:03:47s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:44
auc: 968424.554258585 (0:05:32s)

evaluation 2/10
generating heatmap...
generating heatmap (masks_count=256; seed=43)
...finished in 0:03:43s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:44
auc: 771377.8193891048 (0:05:28s)

evaluation 3/10
generating heatmap...
generating heatmap (masks_count=256; seed=44)
...finished in 0:03:39s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:44
auc: 821620.3837096691 (0:05:24s)

evaluation 4/10
generating heatmap...
generating heatmap (masks_count=256; see

...finished in 0:01:34s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:42
auc: 693262.1288597584 (0:03:16s)

evaluation 3/10
generating heatmap...
generating heatmap (masks_count=256; seed=44)
...finished in 0:02:18s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:02:02
auc: 516627.812191844 (0:04:21s)

evaluation 4/10
generating heatmap...
generating heatmap (masks_count=256; seed=45)
...finished in 0:02:19s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:59
auc: 732583.2661539316 (0:04:19s)

evaluation 5/10
generating heatmap...
generating heatmap (masks_count=256; seed=46)
...finished in 0:02:18s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:57
auc: 748419.2916154861 (0:04:15s)

evaluation 6/10
generating heatmap...
generating heatmap (masks_count=256; seed=47)
...finished in 0:02:15s
evaluate heatmaps (voxels: 13

...finished in 0:01:57s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:39
auc: 1002371.6050982475 (0:03:37s)

evaluation 6/10
generating heatmap...
generating heatmap (masks_count=256; seed=47)
...finished in 0:01:53s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:39
auc: 874809.0052306652 (0:03:32s)

evaluation 7/10
generating heatmap...
generating heatmap (masks_count=256; seed=48)
...finished in 0:01:55s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:39
auc: 644972.5357443094 (0:03:34s)

evaluation 8/10
generating heatmap...
generating heatmap (masks_count=256; seed=49)
...finished in 0:01:54s
evaluate heatmaps (voxels: 1385000, step_size: 1000, max_steps: -1)...
...finished in 0:01:38
auc: 858667.5039976835 (0:03:32s)

evaluation 9/10
generating heatmap...
generating heatmap (masks_count=256; seed=50)
...finished in 0:01:55s
evaluate heatmaps (voxels: 

In [None]:
%%time

# Run the experiments at indices 30..39.
# NOTE(review): `experiments` holds len(masks_counts) * len(p1_options) entries; if that
# is 30 or fewer this slice is empty and nothing runs — confirm masks_counts first.
run(experiments[30:40])