# RISEI - Images

In [1]:
%load_ext autoreload
%autoreload 2

import sys
# Make the project-local `src` package (two directories up) importable.
sys.path.append('../..')

import datetime
import os
# Expose only CUDA device 1 to this process; set here, before TensorFlow is imported in a later cell.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import numpy as np

from src.data import train_test_split, MRISequence
from src.model import create_model, compile_model, load_checkpoint
from src.model.training import train
from src.model.evaluation import show_metrics

In [3]:
import seaborn as sns
import matplotlib.pyplot as plt

sns.set(style="white")

plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['image.cmap'] = 'viridis'

%config InlineBackend.figure_format='retina'
plt.rcParams.update({'font.size': 15})

In [4]:
import tensorflow as tf

RANDOM_SEED = 250398
# tf.random.set_seed does not touch NumPy's global RNG, and this notebook also draws
# random numbers with np.random, so seed both for reproducible runs.
tf.random.set_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

print(tf.version.VERSION)
# tf.config.list_physical_devices is the stable (non-experimental) API since TF 2.1.
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

2.3.1
Num GPUs Available:  1


## Setup

In [5]:
%%time

ROOT_DIR = '../../tmp'
DEFAULT_CHECKPOINT_DIRECTORY_LOCAL = os.path.join(ROOT_DIR, 'checkpoints')
DEFAULT_BCKP_CHECKPOINT_DIRECTORY_LOCAL = os.path.join(ROOT_DIR, 'bckp-checkpoints')

LOG_DIRECTORY = os.path.join(ROOT_DIR, 'logs')
CHECKPOINT_DIRECTORY = DEFAULT_CHECKPOINT_DIRECTORY_LOCAL

LOG_DIRECTORY_LOCAL = LOG_DIRECTORY
CHECKPOINT_DIRECTORY_LOCAL = CHECKPOINT_DIRECTORY

DATA_DIR_NAME = 'data-v3'
DATA_DIR = os.path.join(ROOT_DIR, DATA_DIR_NAME)

saliencies_and_segmentations_v2_path = os.path.join(ROOT_DIR, 'saliencies_and_segmentations_v2')

if not os.path.exists(CHECKPOINT_DIRECTORY):
    os.mkdir(CHECKPOINT_DIRECTORY)

if not os.path.exists(LOG_DIRECTORY):
    os.mkdir(LOG_DIRECTORY)

val = False
    
class_names = ['AD', 'CN']

# get paths to data
train_dir, test_dir, val_dir = train_test_split(
    saliencies_and_segmentations_v2_path, 
    ROOT_DIR, 
    split=(0.8, 0.15, 0.05), 
    dirname=DATA_DIR_NAME)

# set the batch size for mri seq
batch_size = 12
input_shape = (104, 128, 104, 1) # (112, 112, 105, 1)
resize_img = True
crop_img = True

# if y is one-hot encoded or just scalar number
one_hot = True

# class weightss (see analysis notebook)
class_weights = {0: 0.8072289156626505, 1: 1.3137254901960784}

# description statistics of the dataset
desc = {'mean': -3.6344006e-09, 'std': 1.0000092, 'min': -1.4982183, 'max': 10.744175}

if 'desc' not in locals():
    print('initializing desc...')
    desc = get_description(MRISequence(
        train_dir,
        64,
        class_names=class_names,
        input_shape=input_shape),
        max_samples=None)
    print(desc)


normalization={ 'type':'normalization', 'desc': desc }
# normalization={'type':'standardization', 'desc':desc }

augmentations = None
augmentations_inplace = True
# enable augmentations in mri seq (otherwise it can be enabled in dataset)
# augmentations={ 'random_swap_hemispheres': 0.5 }

# initialize sequences
print('initializing train_seq...')
train_seq = MRISequence(
    train_dir,
    batch_size,
    class_names=class_names,
    augmentations=augmentations,
    augmentations_inplace=augmentations_inplace,
    input_shape=input_shape,
    resize_img=resize_img,
    crop_img=crop_img,
    one_hot=one_hot,
    class_weights=class_weights,
    normalization=normalization)

print('initializing test_seq...')
test_seq = MRISequence(
    test_dir,
    batch_size,
    class_names=class_names,
    input_shape=input_shape,
    resize_img=resize_img,
    crop_img=crop_img,
    one_hot=one_hot,
    normalization=normalization)

if val:
    print('initializing val_seq...')
    val_seq = MRISequence(
        val_dir,
        batch_size,
        class_names=class_names,
        input_shape=input_shape,
        resize_img=resize_img,
        crop_img=crop_img,
        one_hot=one_hot,
        class_weights=class_weights,
        normalization=normalization)
else:
    print('val_seq = test_seq')
    val_seq = test_seq
    
model_key = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
log_dir = os.path.join(LOG_DIRECTORY, model_key)
print(f'log_dir: {log_dir}')

not copying files since the destination directory already exists
initializing train_seq...
initializing test_seq...
val_seq = test_seq
log_dir: ../../tmp\logs\20201226-124540
Wall time: 15.6 ms


In [7]:
import time
import cv2

import numpy as np

from tqdm import tqdm
from matplotlib import pyplot as plt
from skimage.transform import resize
from skimage.restoration import inpaint_biharmonic

from multiprocessing import Pool

from PIL import Image


def generate_mask(params):
    """Build one perturbed copy of an image from a coarse random grid.

    `params` is a dict with the keys:
      'grid'         - coarse grid that the helpers upsample into masks
      'options'      - shared option dict (this function reads 'b1' and 'debug')
      'i'            - index, returned unchanged so results can be matched to inputs
      'image_data'   - image to perturb
      'random_shift' - (shift_x, shift_y, shift_z) crop offsets into the upsampled mask

    Returns:
        (i, new_image, mask, cache), where `mask` is the soft-edged mask and `cache`
        is None unless options['debug'] is truthy.
    """
    grid = params['grid']
    options = params['options']
    index = params['i']
    image = params['image_data']
    sx, sy, sz = params['random_shift']

    # soft-edged mask, used for blending
    soft_mask = __get_mask(options, grid, sx, sy, sz)
    # hard-edged mask, used to decide which voxels get in-painted
    hard_mask = __get_binary_mask(options, grid, sx, sy, sz)

    in_paint_mask = (
        __get_in_paint_mask(options, image, soft_mask, hard_mask)
        if options['b1'] > 0
        else None
    )

    perturbed = __merge(options, image, soft_mask, in_paint_mask)

    # Intermediates are only kept for inspection in debug mode.
    cache = (
        {'grid': grid, 'binary_mask': hard_mask, 'in_paint_mask': in_paint_mask}
        if options['debug']
        else None
    )

    return index, perturbed, soft_mask, cache


def __merge(options, original_image, mask, in_paint_mask):
    """Combine the original image, the optional in-painted image and the soft mask.

    Two stages, both driven by `options`:

    1. 'b1' (in-painting strength), applied only when `in_paint_mask` is not None:
       image = (1 - b1) * original_image + b1 * in_paint_mask   if b1 < 1
       image = in_paint_mask                                     otherwise
    2. 'b2' (masking strength), applied only when b2 > 0:
       image = (mask * b2) * image + (1 - mask) * value * b2
       where the fill `value` comes from options['b2_value']:
         0                  -> no fill term (masked-out voxels go to zero)
         'mean' / 'median'  -> that statistic of `original_image`
         any other number   -> that number

    NOTE(review): unlike b1, a b2 < 1 also scales the kept (mask == 1) voxels by b2;
    confirm this is intended.

    Raises:
        ValueError: options['b2_value'] is a string other than 'mean' or 'median'.
    """
    new_image = original_image

    if in_paint_mask is not None:
        b1 = options['b1']
        if b1 < 1:
            new_image = (1 - b1) * original_image + b1 * in_paint_mask
        else:
            new_image = in_paint_mask

    b2 = options['b2']
    if b2 > 0:
        fill = options['b2_value']
        if isinstance(fill, str):
            if fill == 'mean':
                value = np.mean(original_image)
            elif fill == 'median':
                value = np.median(original_image)
            else:
                # previously any unknown string silently fell back to a fill of 1
                raise ValueError(
                    f"unsupported b2_value: {fill!r} (expected a number, 'mean' or 'median')")
            new_image = (mask * b2) * new_image + ((1 - mask) * value * b2)
        elif fill != 0:
            # numeric fill values are used as given
            new_image = (mask * b2) * new_image + ((1 - mask) * fill * b2)
        else:
            new_image = (mask * b2) * new_image

    return new_image


def __get_in_paint_mask(options, image_data, mask, binary_mask):
    """Dispatch to the volumetric or the slice-wise in-painting routine.

    options['in_paint'] == '3d' selects __get_in_paint_mask_3d; any other value
    falls back to __get_in_paint_mask_2d. All arguments are forwarded unchanged.
    """
    in_paint = __get_in_paint_mask_3d if options['in_paint'] == '3d' else __get_in_paint_mask_2d
    return in_paint(options, image_data, mask, binary_mask)


def __get_in_paint_mask_3d(options, image_data, mask, binary_mask):
    """In-paint the masked-out voxels of a volume in one biharmonic pass.

    Args:
        options: option dict; reads 'in_paint_blending'.
        image_data: volume to in-paint.
        mask: soft mask with the same shape as image_data; only used for blending.
        binary_mask: hard 0/1 mask with the same shape; voxels where it is 0 are in-painted.

    Returns:
        The in-painted volume. With options['in_paint_blending'] the result is
        image_data * mask + in_painted * (1 - mask), so the original shows through
        where the soft mask is 1. Without blending, the raw in-painted volume is
        returned (previously this path returned all zeros).
    """
    # inpaint_biharmonic fills the NON-zero entries of its mask, hence the inversion
    # (same conversion as the 2-D variant).
    inverted_binary_mask = (1 - binary_mask).astype(np.uint8)
    # NOTE(review): newer scikit-image deprecates `multichannel` in favour of
    # `channel_axis=None`; kept for the version this notebook targets.
    in_painted = inpaint_biharmonic(image_data, inverted_binary_mask, multichannel=False)

    if options['in_paint_blending']:
        # gradual blending of edges (soft edges)
        return image_data * mask + in_painted * (1 - mask)
    return in_painted


def _in_paint_slices_along_axis(options, image_data, inverted_binary_mask, mask, axis):
    """In-paint each 2-D slice of `image_data` taken along `axis` with cv2.inpaint.

    Returns a float array shaped like `image_data` that holds the in-painted slices.
    When options['in_paint_blending'] is set, each slice is blended with the original
    through the soft `mask`.
    """
    result = np.zeros(image_data.shape)
    # np.moveaxis returns views: images[k] is the k-th slice along `axis`, and writes
    # to targets[k] land in `result`.
    images = np.moveaxis(image_data, axis, 0)
    holes = np.moveaxis(inverted_binary_mask, axis, 0)
    targets = np.moveaxis(result, axis, 0)

    blend = options['in_paint_blending']
    weights = np.moveaxis(mask, axis, 0) if blend else None
    radius = options['in_paint_radius']
    algorithm = options['in_paint_algorithm']

    for k in range(images.shape[0]):
        filled = cv2.inpaint(images[k], holes[k], radius, algorithm)
        if blend:
            # in_paint with gradual blending of edges (soft edges)
            filled = images[k] * weights[k] + filled * (1 - weights[k])
        targets[k] = filled

    return result


def __get_in_paint_mask_2d(options, image_data, mask, binary_mask):
    """In-paint the masked-out voxels of a 3-D volume slice by slice with OpenCV.

    Args:
        options: option dict; reads 'in_paint_2d_to_3d', 'in_paint_radius',
            'in_paint_algorithm' and 'in_paint_blending'.
        image_data: 3-D volume in a dtype accepted by cv2.inpaint.
        mask: soft mask, same shape as image_data; used only for blending.
        binary_mask: hard 0/1 mask, same shape; voxels where it is 0 are in-painted.

    Returns:
        Without 'in_paint_2d_to_3d', the volume in-painted slice-wise along axis 0.
        With it, the volume in-painted along each of the three axes and the three
        results averaged.
    """
    # cv2.inpaint fills the NON-zero pixels of its 8-bit mask, hence the inversion.
    inverted_binary_mask = (1 - binary_mask).astype(np.uint8)

    if not options['in_paint_2d_to_3d']:
        return _in_paint_slices_along_axis(options, image_data, inverted_binary_mask, mask, axis=0)

    axes = (0, 1, 2)
    in_painted = sum(
        _in_paint_slices_along_axis(options, image_data, inverted_binary_mask, mask, axis=axis)
        for axis in axes)
    return in_painted / len(axes)


def __get_mask(options, grid, shift_x, shift_y, shift_z):
    """Upsample the coarse grid into a soft-edged mask and crop it to the input size.

    The grid is resized to options['mask_size'] with linear interpolation (order=1),
    which gives smooth transitions between cells. A window of options['input_size']
    is then cut out starting at (shift_x, shift_y, shift_z).

    shift_x offsets axis 0, shift_y axis 1 and shift_z axis 2, the same convention
    __get_binary_mask uses, so the soft and binary masks line up voxel for voxel.
    (The previous version swapped shift_x and shift_y here.)
    """
    upsampled = resize(
        grid,
        options['mask_size'],
        order=1,
        mode='reflect',
        anti_aliasing=False)

    # int(): shifts may arrive as unsigned NumPy scalars, and (on NumPy < 2)
    # uint64 + int promotes to float64, which is not a valid slice bound.
    x, y, z = int(shift_x), int(shift_y), int(shift_z)
    size = options['input_size']
    return upsampled[x:x + size[0], y:y + size[1], z:z + size[2]]


def __get_binary_mask(options, grid, shift_x, shift_y, shift_z):
    """Upsample the coarse grid into a hard-edged mask and crop it to the input size.

    Each grid cell becomes a solid block of options['cell_size'] voxels holding the
    cell value truncated to an integer. The blocks are placed in a zero array of
    options['mask_size'], and any blocks that overrun it are clipped. Returns the
    window of options['input_size'] starting at (shift_x, shift_y, shift_z), with
    shift_x offsetting axis 0.
    """
    cell = [int(c) for c in options['cell_size']]

    # np.repeat along each axis replaces the former Python loop over every cell;
    # np.trunc mirrors the per-cell int() cast.
    blocks = np.trunc(np.asarray(grid, dtype=np.float64))
    for axis, size in enumerate(cell):
        blocks = np.repeat(blocks, size, axis=axis)

    new_grid = np.zeros(options['mask_size'])
    extent = tuple(slice(0, min(b, m)) for b, m in zip(blocks.shape, new_grid.shape))
    new_grid[extent] = blocks[extent]

    # int(): shifts may arrive as unsigned NumPy scalars, and (on NumPy < 2)
    # uint64 + int promotes to float64, which is not a valid slice bound.
    x, y, z = int(shift_x), int(shift_y), int(shift_z)
    input_size = options['input_size']
    return new_grid[x:x + input_size[0], y:y + input_size[1], z:z + input_size[2]]

In [33]:
N = 10

options = dict(
    s=8,                        # grid resolution per axis (grids are s x s x s)
    p1=1/3,                     # probability that a grid cell is 1
    b1=1,                       # in-painting strength (see __merge)
    b2=0.75,                    # masking strength (see __merge)
    b2_value=0,                 # fill value for masked-out voxels: number, 'mean' or 'median'
    in_paint='2d',              # '3d' -> biharmonic volume in-painting, otherwise OpenCV per slice
    in_paint_blending=True,     # blend in-painted result with the original through the soft mask
    in_paint_radius=5,          # cv2.inpaint radius
    in_paint_2d_to_3d=True,     # 2-D mode: in-paint along all three axes and average
    processes=8,                # not read by any code visible in this notebook
    input_size=input_shape[:-1],
    in_paint_algorithm=cv2.INPAINT_TELEA,
    debug=False,
)

# Grid geometry: each cell covers ceil(input / s) voxels and is enlarged by
# ceil(cell / s) (the same value as ((s + 1) * c - s * c) / s), so the upsampled
# mask exceeds the input by over_image_size, leaving room for a random crop offset.
cell_size = np.ceil(np.array(options['input_size']) / options['s'])
over_cell_size = np.ceil(cell_size / options['s'])

new_cell_size = cell_size + over_cell_size
mask_size = (options['s'] * new_cell_size).astype(np.uint32)
over_image_size = mask_size - options['input_size']

# note: options['cell_size'] is the ENLARGED cell size, unlike the global cell_size
options.update(
    mask_size=mask_size,
    cell_size=new_cell_size.astype(np.uint32),
    over_image_size=over_image_size,
)

In [23]:
# Dedicated, seeded generator: re-running this cell reproduces the same grids and shifts.
mask_rng = np.random.default_rng(RANDOM_SEED)

# One coarse s x s x s grid per perturbation; each cell is 1 with probability p1.
grids = (mask_rng.random((N, options['s'], options['s'], options['s'])) < options['p1']).astype('float32')

# Random (x, y, z) crop offsets into the enlarged mask, one triple per grid, in [0, over_image_size).
# Signed ints: np.uint is uint64 on Linux, and uint64 + int can promote to float64, an invalid slice bound.
random_shifts = (mask_rng.random((N, 3)) * options['over_image_size']).astype(np.int64)

In [24]:
# Pick one sample (index 4) from the first training batch to perturb.
idx = 4
batch_x, batch_y, *_ = train_seq[0]
image_x, image_y = batch_x[idx], batch_y[idx]

In [35]:
# Generate one perturbed image from grid/shift number `idx`.
idx = 0

params = dict(
    grid=grids[idx],
    options=options,
    i=idx,
    # drop the trailing channel axis: the masking helpers work on the bare 3-D volume
    image_data=image_x.reshape(image_x.shape[:-1]),
    random_shift=random_shifts[idx],
)

i, new_image, mask, cache = generate_mask(params)