# RISEI Stability Experiment

In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../../../..')

import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import datetime

import numpy as np

from src.data import train_test_split, MRISequence
from src.model import create_model, compile_model, load_checkpoint
from src.model.evaluation import show_metrics

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Global plotting defaults for every figure produced by this notebook.
sns.set(style="white")

# Default figure size and colormap for matplotlib.
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['image.cmap'] = 'viridis'

# IPython magic: sets the inline backend's figure format to 'retina'.
%config InlineBackend.figure_format='retina'
# Default font size for all text in figures.
plt.rcParams.update({'font.size': 15})

In [None]:
import tensorflow as tf

# RANDOM_SEED = 250398
# tf.random.set_seed(RANDOM_SEED)

# Report the TensorFlow version and how many GPU devices it can see.
print(tf.version.VERSION)
visible_gpus = tf.config.experimental.list_physical_devices('GPU')
print("Num GPUs Available: ", len(visible_gpus))

## Setup

In [None]:
%%time

# Every artefact of the experiment (data, checkpoints, logs) lives under ROOT_DIR.
ROOT_DIR = '../../../../tmp'
DEFAULT_CHECKPOINT_DIRECTORY_LOCAL = os.path.join(ROOT_DIR, 'checkpoints')
DEFAULT_BCKP_CHECKPOINT_DIRECTORY_LOCAL = os.path.join(ROOT_DIR, 'bckp-checkpoints')

LOG_DIRECTORY = os.path.join(ROOT_DIR, 'logs')
CHECKPOINT_DIRECTORY = DEFAULT_CHECKPOINT_DIRECTORY_LOCAL

# Aliases kept so code that refers to the *_LOCAL names keeps working.
LOG_DIRECTORY_LOCAL = LOG_DIRECTORY
CHECKPOINT_DIRECTORY_LOCAL = CHECKPOINT_DIRECTORY

DATA_DIR_NAME = 'data-v3'
DATA_DIR = os.path.join(ROOT_DIR, DATA_DIR_NAME)

saliencies_and_segmentations_v2_path = os.path.join(ROOT_DIR, 'saliencies_and_segmentations_v2')

# makedirs(exist_ok=True) also creates a missing ROOT_DIR and has no
# check-then-create race, unlike `os.path.exists` followed by `os.mkdir`.
os.makedirs(CHECKPOINT_DIRECTORY, exist_ok=True)
os.makedirs(LOG_DIRECTORY, exist_ok=True)

# Whether to build a dedicated validation sequence; when False, the test
# sequence is reused as the validation sequence further below.
val = False

class_names = ['AD', 'CN']

# get paths to data
train_dir, test_dir, val_dir = train_test_split(
    saliencies_and_segmentations_v2_path,
    ROOT_DIR,
    split=(0.8, 0.15, 0.05),
    dirname=DATA_DIR_NAME)

# set the batch size for mri seq
batch_size = 12
input_shape = (104, 128, 104, 1) # (112, 112, 105, 1)
resize_img = True
crop_img = True

# if y is one-hot encoded or just scalar number
one_hot = True

# class weights (see analysis notebook)
class_weights = {0: 0.8072289156626505, 1: 1.3137254901960784}

# description statistics of the dataset
desc = {'mean': -3.6344006e-09, 'std': 1.0000092, 'min': -1.4982183, 'max': 10.744175}

# NOTE(review): `desc` is assigned unconditionally just above, so this branch
# never runs as written. It presumably documents how the cached statistics
# were produced. `get_description` is not imported in the visible cells, so
# the branch would raise NameError if the hard-coded `desc` were removed --
# confirm the import before relying on it.
if 'desc' not in locals():
    print('initializing desc...')
    desc = get_description(MRISequence(
        train_dir,
        64,
        class_names=class_names,
        input_shape=input_shape),
        max_samples=None)
    print(desc)

# Normalization settings handed to every MRISequence below.
normalization = {'type': 'normalization', 'desc': desc}
# normalization={'type':'standardization', 'desc':desc }

# enable augmentations in mri seq (otherwise it can be enabled in dataset)
# augmentations={ 'random_swap_hemispheres': 0.5 }
augmentations = None
augmentations_inplace = True

# Keyword arguments shared by the train, test and validation sequences.
_seq_kwargs = dict(
    class_names=class_names,
    input_shape=input_shape,
    resize_img=resize_img,
    crop_img=crop_img,
    one_hot=one_hot,
    normalization=normalization,
)

# initialize sequences
print('initializing train_seq...')
train_seq = MRISequence(
    train_dir,
    batch_size,
    augmentations=augmentations,
    augmentations_inplace=augmentations_inplace,
    class_weights=class_weights,
    **_seq_kwargs)

print('initializing test_seq...')
test_seq = MRISequence(test_dir, batch_size, **_seq_kwargs)

if not val:
    # No dedicated validation split requested: validate on the test sequence.
    print('val_seq = test_seq')
    val_seq = test_seq
else:
    print('initializing val_seq...')
    val_seq = MRISequence(
        val_dir,
        batch_size,
        class_weights=class_weights,
        **_seq_kwargs)

# Timestamp-based key that names this run's log directory.
model_key = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
log_dir = os.path.join(LOG_DIRECTORY, model_key)
print(f'log_dir: {log_dir}')

In [None]:
# https://www.tensorflow.org/tutorials/structured_data/imbalanced_data#class_weights
# pos / neg
# Log of the two class ratios; passed to the model as 'output_bias' below.
# NOTE(review): 159 and 243 are presumably the per-class sample counts -- confirm.
initial_bias = np.log([159/243, 243/159])

# Architecture settings passed to create_model().
model_type = '3d_cnn'
model_config = {
    'input_shape': input_shape,
    'class_names': class_names,
#     'l2_beta': 0.001,
#     'l2_beta': 0.0005,
    'l2_beta': None,
#     'dropout': 0.05,
    'dropout': 0.10,
    'output_bias': initial_bias,
#     'output_bias': None,
    # https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization
    'batch_norm': True,
    'is_complex': False, # a complex layer from the paper, max batch_size is 3
}

# Optimizer settings unpacked into compile_model().
compile_config = {
    # default is 0.001
#     'learning_rate': 0.000075,
    'learning_rate': 0.00010,
    'decay_steps': 25,
    'decay_rate': 0.96,
#     'beta_1': 0.85,
    'beta_1': 0.90,
#     'beta_2': 0.990,
    'beta_2': 0.999,
}

# Training settings. In this notebook only 'batch_size' is read (by
# generate_heatmaps); the remaining keys are not used by the visible cells.
train_config = {
    'model_key': model_key,
    'epochs': 150,
    'patience': 75,
    'tensorboard_update_freq': 'epoch',
    'mri_tensorboard_callback': False,
    'model_checkpoint_callback': {'monitor': 'val_auc', 'mode': 'max', 'save_best_only': True},
    'early_stopping_monitor': {'monitor': 'val_auc', 'mode': 'max'},
#     'augmentations': False,
    'augmentations': {
        'invert': (0.5, None),
        'rotate': (0.2, 5), # probability, degrees
        'zoom': (0., 0.),
        'shear': (0.2, 0.5), # probability, degrees
        'blur': (0.2, 0.85),
        'noise': (0.2, 0.00020)
    },
    'batch_size': 8,
#     'model_checkpoint_callback': False,
}

## Model

In [None]:
# Create the network described by model_type / model_config.
model = create_model(model_type, model_config)
# compile_model returns the model first; any further return values are discarded.
model, *_ = compile_model(model, **compile_config)
# NOTE(review): input_shape is (104, 128, 104, 1), i.e. it has no leading
# batch dimension. Keras' Model.build normally expects one -- confirm this is
# intended for the model that create_model returns.
model.build(input_shape=input_shape)
model.summary()

In [None]:
# Load checkpoint 'cp-0058.ckpt' of run '20210308-175324' from the backup checkpoint directory.
load_checkpoint(model, DEFAULT_BCKP_CHECKPOINT_DIRECTORY_LOCAL, '20210308-175324', 'cp-0058.ckpt')

In [None]:
%%time

# Evaluate the model on the test sequence before running the RISEI experiment.
show_metrics(model, test_seq, class_names)

## RISEI

In [None]:
from src.risei import RISEI

### Config

In [None]:
# Keyword arguments unpacked into every RISEI(...) constructor call in this
# notebook. The meaning of each key is defined by src.risei.RISEI.
risei_config = {
    's': 8, 
    'p1': 1/3, 
    'b1': 0,
    'b2': 1,
    'b2_value': 1,
    'in_paint': '2d', 
    'in_paint_blending': True, 
    'in_paint_radius': 5,
    'in_paint_2d_to_3d': True,
    'processes': 8,
}

### Evaluate

In [None]:
import time

from src.heatmaps.evaluation import get_heatmap

# RISEI instance over the input shape without its last (channel) entry.
risei_batch_size = 480
risei = RISEI(input_shape[:-1], debug=False, **risei_config)

# we will test this only with one image: the first sample of the first test batch
idx = 0
batch_x, batch_y, *_ = test_seq[0]
image_x = batch_x[idx]
image_y = batch_y[idx]

In [None]:
from multiprocessing import Pool

from tqdm import tqdm


class Serializer():
    """Collect ``n`` heatmaps and compute their per-voxel standard deviation.

    Heatmaps are kept either in one pre-allocated NumPy array or as one
    ``.npy`` file per heatmap in ``<ROOT_DIR>/risei-stability-cache``.

    NOTE: despite its name, ``in_memory=True`` selects the on-disk cache and
    ``in_memory=False`` keeps everything in RAM. The flag's meaning is left
    untouched because ``generate_heatmaps`` passes ``False`` positionally to
    get the in-RAM array.

    Args:
        n: number of heatmaps that will be added (indices ``0..n-1``).
        input_shape: model input shape; its last entry is dropped to obtain
            the shape of a single heatmap.
        in_memory: storage selector, see the note above.
    """

    def __init__(self, n, input_shape, in_memory=True):
        self.n = n
        # Second-resolution timestamp used to namespace the cache files.
        # NOTE(review): two disk-backed serializers created within the same
        # second share an id and would overwrite each other's files.
        self.id = int(time.time())
        self.input_shape = input_shape
        self.save_dir = None

        if in_memory:
            self.save_dir = os.path.join(ROOT_DIR, "risei-stability-cache")
            os.makedirs(self.save_dir, exist_ok=True) # mkdir -p
        else:
            self.heatmaps = np.zeros((n, *input_shape[:-1]))

    def add_heatmap(self, i, heatmap):
        """Store ``heatmap`` as the ``i``-th heatmap (on disk or in the array)."""
        if self.save_dir is not None:
            np.save(self.__get_fname(i), heatmap)
        else:
            self.heatmaps[i] = heatmap

    def get_std(self, processes=8):
        """Return the per-voxel standard deviation over all stored heatmaps.

        The result matches ``np.std(heatmaps, axis=0)`` (population standard
        deviation). In disk mode the heatmaps are streamed one file at a time
        with Welford's algorithm, so only a single heatmap is held in memory.

        Args:
            processes: currently unused; kept for interface compatibility.

        Raises:
            FileNotFoundError: in disk mode, if a heatmap in ``0..n-1`` was
                never added.
        """
        if self.save_dir is None:
            return np.std(self.heatmaps, axis=0)

        mean = None
        m2 = None  # running sum of squared deviations from the mean
        for i in range(self.n):
            heatmap = np.asarray(np.load(self.__get_fname(i)), dtype=np.float64)
            if mean is None:
                mean = np.zeros_like(heatmap)
                m2 = np.zeros_like(heatmap)
            delta = heatmap - mean
            mean += delta / (i + 1)
            m2 += delta * (heatmap - mean)

        if mean is None:
            # n == 0: mirror what np.std yields for an empty in-memory stack.
            return np.full(self.input_shape[:-1], np.nan)
        return np.sqrt(m2 / self.n)

    def __get_fname(self, i):
        """Return the cache file path of the ``i``-th heatmap."""
        return get_fname(self.save_dir, self.id, i)


def get_fname(save_dir, _id, i):
    """Return the ``.npy`` path of heatmap ``i`` of serializer ``_id`` in ``save_dir``."""
    return os.path.join(save_dir, f'hmap_{_id}_{i}.npy')


# Backward-compatible alias for the helper's original name. A dunder-prefixed
# module-level name cannot be referenced from inside a class body (name
# mangling), which is why the class uses `get_fname`.
__get_fname = get_fname

In [None]:
def generate_heatmaps(n, masks_count, image_x, image_y, tf_reset=5, max_retries=10):
    """Generate ``n`` RISEI heatmaps of one image and collect them in a Serializer.

    A fresh ``RISEI`` instance is created for every heatmap. A heatmap whose
    generation raises is retried; after ``max_retries`` consecutive failures
    of the same heatmap the last exception is re-raised instead of looping
    forever on a deterministic error.

    Args:
        n: number of heatmaps to generate.
        masks_count: number of masks passed to ``get_heatmap``.
        image_x: the input image passed to ``get_heatmap``.
        image_y: the label passed to ``get_heatmap``.
        tf_reset: clear the Keras session after every ``tf_reset`` heatmaps;
            a falsy value (``0`` / ``None``) disables the reset.
        max_retries: maximum number of consecutive retries per heatmap;
            ``None`` retries without limit.

    Returns:
        The ``Serializer`` holding the ``n`` generated heatmaps.
    """
    serializer = Serializer(n, input_shape, False)

    i = 0
    failures = 0  # consecutive failures of the heatmap currently being generated
    while i < n:
        print(f"generating heatmap #{i}...")
        try:
            explainer = RISEI(input_shape[:-1], debug=False, **risei_config)
            heatmap, _, _ = get_heatmap(
                image_x,
                image_y,
                model,
                explainer,
                batch_size=train_config['batch_size'],
                masks_count=masks_count,
                risei_batch_size=risei_batch_size,
                debug=False,
                log=True
            )
        except Exception as e:
            print("there was an error, we will try to generate this heatmap again...")
            print(e)
            failures += 1
            if max_retries is not None and failures > max_retries:
                raise
            continue

        # save the heatmap to all generated heatmaps; kept outside the try so a
        # storage error (e.g. a shape mismatch) surfaces instead of being retried
        serializer.add_heatmap(i, heatmap)
        i += 1
        failures = 0

        # https://github.com/tensorflow/tensorflow/issues/35010
        if tf_reset and i % tf_reset == 0:
            tf.keras.backend.clear_session()

    return serializer

def get_std_heatmaps_v1(heatmaps):
    """Return the per-voxel standard deviation across a stack of heatmaps.

    Explicit voxel-by-voxel reference implementation, kept as an independent
    (slow) cross-check of ``np.std(heatmaps, axis=0)``.

    Args:
        heatmaps: array whose first axis indexes the heatmaps; the remaining
            axes are the voxel grid.

    Returns:
        A float64 array of shape ``heatmaps.shape[1:]``.
    """
    # Take the output shape from the argument rather than from the global
    # `input_shape`, so the function works for any stack it is given.
    std_heatmaps = np.zeros(heatmaps.shape[1:])

    for voxel in np.ndindex(*std_heatmaps.shape):
        # select the same voxel from all images and store its standard deviation
        std_heatmaps[voxel] = np.std(heatmaps[(slice(None), *voxel)])

    return std_heatmaps

def get_std_heatmaps_v2(heatmaps):
    """Return the standard deviation of ``heatmaps`` taken along its first axis."""
    per_voxel_std = np.std(heatmaps, axis=0)
    return per_voxel_std

In [None]:
# h1 = generate_heatmaps(10, 12, image_x, image_y)
# h2 = generate_heatmaps(10, 12, image_x, image_y)

In [None]:
def run(experiments, masks_counts):
    """Measure RISEI heatmap stability for each mask count.

    For every entry of ``masks_counts`` this generates ``experiments``
    heatmaps of the notebook-level ``image_x`` / ``image_y``, computes their
    per-voxel standard deviation, prints summary statistics and saves the
    result as a ``.npy`` file under ``<ROOT_DIR>/risei-stability``.

    Args:
        experiments: number of heatmaps to generate per mask count.
        masks_counts: iterable of mask counts to evaluate.
    """
    for masks_count in masks_counts:
        print(f"generating {experiments}x {masks_count} masks...")

        # time only the heatmap generation itself
        # [experiments, z, x, y]
        started = time.time()
        serializer = generate_heatmaps(experiments, masks_count, image_x, image_y)
        elapsed = datetime.timedelta(seconds=int(time.time() - started))

        # [z, x, y]
        std_heatmaps = serializer.get_std()

        print(f"result for {masks_count} masks (t: {elapsed})")
        print(f"\tmean std: {np.mean(std_heatmaps)}")
        print(f"\tmin std: {np.min(std_heatmaps)}")
        print(f"\tmax std: {np.max(std_heatmaps)}")
        print(f"\tstd std: {np.std(std_heatmaps)}")

        # the output file name carries the save time and the mask count
        out_dir = os.path.join(ROOT_DIR, "risei-stability")
        out_file = os.path.join(out_dir, f"{int(time.time())}_m{masks_count}.npy")
        os.makedirs(out_dir, exist_ok=True) # mkdir -p
        print(f"saving std_heatmaps to {out_file} ...")
        np.save(out_file, std_heatmaps)

In [None]:
# experiments = 100
# masks_counts = [16, 128, 256, 512]

# run(experiments, masks_counts)

In [None]:
import tracemalloc

In [None]:
# 100 repetitions with 1024 masks each.
experiments = 100
masks_counts = [1024]

# Start tracing Python memory allocations before the run; the snapshot is
# taken and inspected in a later cell.
tracemalloc.start()

run(experiments, masks_counts)

In [None]:
# from pympler import asizeof
# a = np.zeros((1024, *input_shape[:-1]))
# asizeof.asizeof(a) / 1024 / 1024

In [None]:
import linecache


def display_top(snapshot, key_type='lineno', limit=300):
    """Print the first ``limit`` allocation statistics of a tracemalloc snapshot.

    Traces from ``<frozen importlib._bootstrap>`` and ``<unknown>`` are
    excluded. Each entry shows the shortened file name, the line number, the
    allocated size in MiB and, when available, the source line.

    Args:
        snapshot: a ``tracemalloc.Snapshot``.
        key_type: grouping key passed to ``Snapshot.statistics``.
        limit: maximum number of entries to print.
    """
    ignored = (
        tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
        tracemalloc.Filter(False, "<unknown>"),
    )
    top_stats = snapshot.filter_traces(ignored).statistics(key_type)

    print("Top %s lines" % limit)
    rank = 0
    for stat in top_stats[:limit]:
        rank += 1
        frame = stat.traceback[0]
        # replace "/path/to/module/file.py" with "module/file.py"
        path_parts = frame.filename.split(os.sep)
        filename = os.sep.join(path_parts[-2:])
        size_mib = stat.size / 1024 / 1024
        print("#%s: %s:%s: %.1f MiB"
              % (rank, filename, frame.lineno, size_mib))
        line = linecache.getline(frame.filename, frame.lineno).strip()
        if not line:
            continue
        print('    %s' % line)
            
# Capture the allocations traced since tracemalloc.start() and print the report.
snapshot = tracemalloc.take_snapshot()

display_top(snapshot)

In [None]:
from guppy import hpy

# Create a guppy heap inspector and query the current heap.
h = hpy()

h.heap()

In [None]:
# 100 repetitions with 2048 masks each.
experiments = 100
masks_counts = [2048]

run(experiments, masks_counts)

In [None]:
# 100 repetitions with 4096 masks each.
experiments = 100
masks_counts = [4096]

run(experiments, masks_counts)