In [15]:
import os
import csv
import re
import pickle
import random
import math
import dicom
import numpy as np
import h5py
from tqdm import tqdm
from natsort import natsorted
from skimage import transform
from sklearn.externals import joblib
from scipy import ndimage
from matplotlib import path

import matplotlib.pyplot as plt
%matplotlib inline

In [16]:
with open('../../data_supp/z_exclude_ED.csv', 'r') as f:
    reader = csv.reader(f)
    z_exclude_ED = list([(int(pt), int(start), int(end)) for pt, start, end in reader])
    
with open('../../data_supp/z_exclude_ES.csv', 'r') as f:
    reader = csv.reader(f)
    z_exclude_ES = list([(int(pt), int(start), int(end)) for pt, start, end in reader])

In [17]:
with open('../../data_supp/frames_ES.pkl', 'rb') as f:
    frames_ES = pickle.load(f)
with open('../../data_supp/frames_ED.pkl', 'rb') as f:
    frames_ED = pickle.load(f)

In [18]:
z_exclude_ED[0]

(429, 1, 10)

In [19]:
frames_ED[0]

0

In [20]:
def get_filepaths():
    with open('../../data_supp/filepaths_train.pkl', 'rb') as f:
        filepaths_train = pickle.load(f)
    with open('../../data_supp/filepaths_val.pkl', 'rb') as f:
        filepaths_val = pickle.load(f)
    return filepaths_train, filepaths_val

def get_training_labels():
    systole_labels = {}
    diastole_labels = {}
    with open('../../data/train.csv', 'r') as f:
        for _id, systole, diastole in csv.reader(f):
            if _id == 'Id':
                continue
            systole_labels[int(_id)] = float(systole)
            diastole_labels[int(_id)] = float(diastole)
    return systole_labels, diastole_labels

In [21]:
filepaths_train, filepaths_val = get_filepaths()

In [22]:
systole_labels, diastole_labels = get_training_labels()

# return tuple of real value and function represetation
def create_label(pt, mode='ED'):
    if mode == 'ES':
        return systole_labels[pt], systole_labels[pt] < np.arange(600)
    elif mode == 'ED':
        return diastole_labels[pt], diastole_labels[pt] < np.arange(600)
    else:
        raise

In [23]:
def apply_window(arr, window_center, window_width):
    return np.clip(arr, window_center - window_width/2, window_center + window_width/2)


def apply_per_slice_norm(arr):
    mean = np.mean(arr.ravel())
    std = np.std(arr.ravel())
    if std == 0:
        return np.zeros(arr.shape)
    return (arr - mean) / std


def crop_to_square(arr, size):
    x_len, y_len = arr.shape
    shorter_len = min(x_len, y_len)
    x_start = (arr.shape[0] - shorter_len) // 2
    x_end = x_start + shorter_len
    y_start = (arr.shape[1] - shorter_len) // 2
    y_end = y_start + shorter_len
    return transform.resize(arr[x_start:x_end, y_start:y_end], 
                            (size, size), order=1, clip=True, preserve_range=True)


def crop_to_square_normalized(img_orig, pixel_spacing, size):
    img_norm = ndimage.interpolation.zoom(img_orig, [float(x) for x in pixel_spacing], order=0, mode='constant')
    
    length_x, length_y = img_norm.shape
    if length_x >= size:
        x_start = length_x // 2 - size // 2
        x_end = length_x // 2 + size // 2
    else:
        x_start = 0
        x_end = length_x
    if length_y >= size:
        y_start = length_y // 2 - size // 2
        y_end = length_y // 2 + size // 2
    else:
        y_start = 0
        y_end = length_y
    
    img_new = np.zeros((size, size))
    new_x_shift = (size - (x_end - x_start)) // 2
    new_y_shift = (size - (y_end - y_start)) // 2
    img_new[new_x_shift:(new_x_shift + x_end - x_start), 
            new_y_shift:(new_y_shift + y_end - y_start)] = img_norm[x_start:x_end, y_start:y_end]
    
    return img_new


def img_augmentation(img, nb_samples, rotation=True, shift=True):
    
    img_aug_collection = []
        
    for i in range(nb_samples):
    
        img_aug = img
            
        # shift +/- 0 to 0.2
        if shift:
            shift_y = round(0.2 * random.randrange(-img.shape[0], img.shape[0]))
            shift_x = round(0.2 * random.randrange(-img.shape[1], img.shape[1]))
            img_aug = ndimage.interpolation.shift(img_aug, (shift_y, shift_x), order=0, mode='constant')
        
        # rotation +/- 0 to 30 degrees with probability 0.5
        if rotation and random.random() > 0.5:
            angle = random.randrange(-30, 30)
            img_aug = ndimage.interpolation.rotate(img_aug, angle, axes=(0, 1), 
                                                   order=0, mode='constant', reshape=False)
        
        img_aug_collection.append(img_aug)
        
    return img_aug_collection


def localize_to_centroid(img, centroid, width_about_centroid):
    # assumes already cropped to square
    x, y = centroid
    x = int(round(x))
    y = int(round(y))
    x_start = x - width_about_centroid // 2
    x_end = x + width_about_centroid // 2
    y_start = y - width_about_centroid // 2
    y_end = y + width_about_centroid // 2
    
    if x_start < 0:
        x_end += (0 - x_start)
        x_start = 0
    if x_end > img.shape[0]:
        x_start -= (img.shape[0] - x_end)
        x_end = img.shape[0]
    if y_start < 0:
        y_end += (0 - y_start)
        y_start = 0
    if y_end > img.shape[1]:
        y_start -= (img.shape[1] - y_end)
        y_end = img.shape[1]
        
    return img[x_start:x_end, y_start:y_end], (x_start, x_end), (y_start, y_end)


def get_all_series_filepaths(filepaths):
    t_slices = 30
    
    # create sax series filepaths
    # handles irregularies such as those including z-slices and t-slices in the same folder
    series_filepaths_all = []
    for view in filepaths.keys(): 
        if not re.match(r'^sax', view):
            continue
        
        if len(filepaths[view]) == t_slices:
            series_filepaths_all.append(filepaths[view])
        elif len(filepaths[view]) < t_slices:
            series_filepaths_all.append(filepaths[view][:] + filepaths[view][:(t_slices - len(filepaths[view]))])
        else:
            if re.match(r'^\w+-\d+-\d+-\d+.dcm$', filepaths[view][0][0]) is not None:
                series_filepaths_split = []
                slices_list = []
                series_filepaths_sort_by_slice = sorted(filepaths[view][:], 
                                                        key=lambda x: '{}-{}'.format(x[0].split('-')[-1].split('.')[0], 
                                                                                     x[0].split('-')[-2]))
                for fname, fpath in series_filepaths_sort_by_slice:
                    nslice = fname.split('-')[-1].split('.')[0]
                    tframe = fname.split('-')[-2]
                    if nslice not in slices_list:
                        if len(series_filepaths_split) == t_slices:
                            series_filepaths_all.append(series_filepaths_split)
                        elif len(series_filepaths_split) < t_slices and len(series_filepaths_split) > 0:
                            series_filepaths_all.append((series_filepaths_split[:] + 
                                                         series_filepaths_split[:(t_slices - len(series_filepaths_split))]))
                        series_filepaths_split = []
                        series_filepaths_split.append((fname, fpath))
                        slices_list.append(nslice)
                    else:
                        series_filepaths_split.append((fname, fpath))
                        
    return series_filepaths_all


def normalized_z_loc(df):
    # assumes patient position HFS
    
    position = [float(s) for s in df.ImagePositionPatient]
    orientation = [float(s) for s in df.ImageOrientationPatient]
    
    # first voxel coordinates from DICOM ImagePositionPatient field
    x_loc, y_loc, z_loc = position
    
    # row/column direction cosines from DICOM ImageOrientationPatient field
    row_dircos_x, row_dircos_y, row_dircos_z, col_dircos_x, col_dircos_y, col_dircos_z = orientation
    
    # normalized direction cosines
    dircos_x = row_dircos_y * col_dircos_z - row_dircos_z * col_dircos_y
    dircos_y = row_dircos_z * col_dircos_x - row_dircos_x * col_dircos_z
    dircos_z = row_dircos_x * col_dircos_y - row_dircos_y * col_dircos_x
    
    # z-coordinate location recalculated based on reference
    z_loc_norm = dircos_x * x_loc + dircos_y * y_loc + dircos_z * z_loc
    return z_loc_norm


def create_MIP(filepaths, full_size=256, frame=0):
    series_filepaths_all = get_all_series_filepaths(filepaths)

    vol3d = []
    for series_filepaths in series_filepaths_all:
        fname, fpath = natsorted(series_filepaths, lambda x: x[0])[frame]
        df = dicom.read_file(fpath)
        img2d = df.pixel_array
        vol3d.append(apply_per_slice_norm(crop_to_square(img2d, full_size)).astype(np.float32))
    orig_shape = img2d.shape
    pixel_spacing = df.PixelSpacing
    vol3d_mask = pred_loc_map(vol3d)

    vol3d_MIP = np.mean(np.array(vol3d), axis=0)
    vol3d_mask_MIP = np.mean(np.array(vol3d_mask), axis=0)

    return vol3d_MIP, vol3d_mask_MIP, orig_shape, pixel_spacing


def get_MIP_centroid(vol3d_mask_MIP):
    return ndimage.measurements.center_of_mass(vol3d_mask_MIP)


def create_localized_image_stack(filepaths, centroid, full_size=256, local_size=96, frame=0):
    series_filepaths_all = get_all_series_filepaths(filepaths)
                        
    # sort series by z-locations
    z_locs = []
    for series_filepaths in series_filepaths_all:
        df = dicom.read_file(natsorted(series_filepaths, lambda x: x[0])[frame][1])
        z_locs.append(normalized_z_loc(df))
    series_filepaths_all_zsorted = sorted(zip([min(z_locs) - z_loc for z_loc in z_locs], 
                                              series_filepaths_all), key=lambda pair: pair[0])
    
    series_filepaths_all_zsorted_with_depths = []
    for i in range(len(series_filepaths_all_zsorted)):
        z_loc, series_filepaths = series_filepaths_all_zsorted[i]
        if i == len(series_filepaths_all_zsorted) - 1:
            z_depth = float(dicom.read_file(natsorted(series_filepaths, lambda x: x[0])[frame][1]).SliceThickness)
        else:
            z_depth = series_filepaths_all_zsorted[i+1][0] - z_loc

        # filter out tiny depths, which are likely repeats
        if z_depth > 0.01:
            series_filepaths_all_zsorted_with_depths.append((z_depth, series_filepaths))

    img_stack = []
    z_depths = []
    for z_depth, series_filepaths in series_filepaths_all_zsorted_with_depths:
        fname, fpath = natsorted(series_filepaths, lambda x: x[0])[frame]
        df = dicom.read_file(fpath)
        img = df.pixel_array
        img_localized, _, _ = localize_to_centroid(crop_to_square(img, full_size), centroid, local_size)
        img_processed = apply_per_slice_norm(img_localized)
        img_stack.append(img_processed.astype(np.float32))
        z_depths.append(z_depth)
    orig_shape = img.shape
    pixel_spacing = [float(s) for s in df.PixelSpacing]
    
    #img_stack_masks = pred_seg_map(img_stack)
    img_stack_masks = None

    return img_stack, img_stack_masks, z_depths, orig_shape, pixel_spacing


def create_full_image_stack(filepaths, full_size=256, frame=0):
    series_filepaths_all = get_all_series_filepaths(filepaths)
                        
    # sort series by z-locations
    z_locs = []
    for series_filepaths in series_filepaths_all:
        df = dicom.read_file(natsorted(series_filepaths, lambda x: x[0])[frame][1])
        z_locs.append(normalized_z_loc(df))
    series_filepaths_all_zsorted = sorted(zip([min(z_locs) - z_loc - min(z_locs) for z_loc in z_locs], 
                                              series_filepaths_all), key=lambda pair: pair[0])
    
    series_filepaths_all_zsorted_with_depths = []
    for i in range(len(series_filepaths_all_zsorted)):
        z_loc, series_filepaths = series_filepaths_all_zsorted[i]
        if i == len(series_filepaths_all_zsorted) - 1:
            z_depth = float(dicom.read_file(natsorted(series_filepaths, lambda x: x[0])[frame][1]).SliceThickness)
        else:
            z_depth = series_filepaths_all_zsorted[i+1][0] - z_loc

        # filter out tiny depths, which are likely repeats
        if z_depth > 0.01:
            series_filepaths_all_zsorted_with_depths.append((z_depth, series_filepaths))

    img_stack = []
    z_depths = []
    for z_depth, series_filepaths in series_filepaths_all_zsorted_with_depths:
        fname, fpath = natsorted(series_filepaths, lambda x: x[0])[frame]
        df = dicom.read_file(fpath)
        img = df.pixel_array
        img_processed = apply_per_slice_norm(crop_to_square(img, full_size))
        img_stack.append(img_processed.astype(np.float32))
        z_depths.append(z_depth)
    orig_shape = img.shape
    pixel_spacing = [float(s) for s in df.PixelSpacing]
    
    #img_stack_masks = pred_loc_map(img_stack)
    img_stack_masks = None

    return img_stack, img_stack_masks, z_depths, orig_shape, pixel_spacing

In [24]:
from keras.models import Sequential, Graph
from keras.layers.core import Activation, Dense, Dropout, Flatten, Merge, Reshape, Lambda
from keras.layers.core import TimeDistributedDense, TimeDistributedMerge
from keras.layers.recurrent import LSTM, GRU
from keras.layers.convolutional import Convolution2D, MaxPooling2D, AveragePooling2D, UpSampling2D, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU, PReLU, ParametricSoftplus, ELU
from keras.layers.normalization import BatchNormalization
from keras.layers.noise import GaussianDropout, GaussianNoise
from keras.utils import np_utils, generic_utils
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras import initializations
from keras.layers.core import Layer
from keras import backend as K

# for preventing python max recursion limit error
import sys
sys.setrecursionlimit(50000)

def RMSE(y_true, y_pred):
    return K.sqrt(K.mean(K.square(y_pred - y_true), axis=None, keepdims=False))

def binaryCE(y_true, y_pred):
    return K.mean(K.binary_crossentropy(y_pred, y_true), axis=None, keepdims=False)

class Rotate90(Layer):
    def __init__(self, direction='clockwise', **kwargs):
        super(Rotate90, self).__init__(**kwargs)
        self.direction = direction

    def get_output(self, train):
        X = self.get_input(train)
        if self.direction == 'clockwise':
            return X.transpose((0, 2, 1))[:, :, ::-1]
        elif self.direction == 'counterclockwise':
            return X.transpose((0, 2, 1))[:, ::-1, :]
        else:
            raise

    def get_config(self):
        config = {"name": self.__class__.__name__}
        base_config = super(Rotate90, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
    
LV_loc_model = Graph()

LV_loc_model.add_input(name='input', input_shape=(1, 256, 256))

LV_loc_model.add_node(Convolution2D(32, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-1-1', input='input')
LV_loc_model.add_node(BatchNormalization(), name='conv-1-1-bn', input='conv-1-1')
LV_loc_model.add_node(ELU(), name='conv-1-1-activ', input='conv-1-1-bn')
LV_loc_model.add_node(Convolution2D(32, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-1-2', input='conv-1-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='conv-1-2-bn', input='conv-1-2')
LV_loc_model.add_node(ELU(), name='conv-1-2-activ', input='conv-1-2-bn')
LV_loc_model.add_node(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'),
               name='pool-1', input='conv-1-2-activ')

LV_loc_model.add_node(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-2-1', input='pool-1')
LV_loc_model.add_node(BatchNormalization(), name='conv-2-1-bn', input='conv-2-1')
LV_loc_model.add_node(ELU(), name='conv-2-1-activ', input='conv-2-1-bn')
LV_loc_model.add_node(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-2-2', input='conv-2-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='conv-2-2-bn', input='conv-2-2')
LV_loc_model.add_node(ELU(), name='conv-2-2-activ', input='conv-2-2-bn')
LV_loc_model.add_node(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'),
               name='pool-2', input='conv-2-2-activ')

LV_loc_model.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-3-1', input='pool-2')
LV_loc_model.add_node(BatchNormalization(), name='conv-3-1-bn', input='conv-3-1')
LV_loc_model.add_node(ELU(), name='conv-3-1-activ', input='conv-3-1-bn')
LV_loc_model.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-3-2', input='conv-3-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='conv-3-2-bn', input='conv-3-2')
LV_loc_model.add_node(ELU(), name='conv-3-2-activ', input='conv-3-2-bn')
LV_loc_model.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-3-3', input='conv-3-2-activ')
LV_loc_model.add_node(BatchNormalization(), name='conv-3-3-bn', input='conv-3-3')
LV_loc_model.add_node(ELU(), name='conv-3-3-activ', input='conv-3-3-bn')
LV_loc_model.add_node(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'),
               name='pool-3', input='conv-3-3-activ')

LV_loc_model.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-4-1', input='pool-3')
LV_loc_model.add_node(BatchNormalization(), name='conv-4-1-bn', input='conv-4-1')
LV_loc_model.add_node(ELU(), name='conv-4-1-activ', input='conv-4-1-bn')
LV_loc_model.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-4-2', input='conv-4-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='conv-4-2-bn', input='conv-4-2')
LV_loc_model.add_node(ELU(), name='conv-4-2-activ', input='conv-4-2-bn')
LV_loc_model.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-4-3', input='conv-4-2-activ')
LV_loc_model.add_node(BatchNormalization(), name='conv-4-3-bn', input='conv-4-3')
LV_loc_model.add_node(ELU(), name='conv-4-3-activ', input='conv-4-3-bn')
LV_loc_model.add_node(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'),
               name='pool-4', input='conv-4-3-activ')

LV_loc_model.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-5-1', input='pool-4')
LV_loc_model.add_node(BatchNormalization(), name='conv-5-1-bn', input='conv-5-1')
LV_loc_model.add_node(ELU(), name='conv-5-1-activ', input='conv-5-1-bn')
LV_loc_model.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-5-2', input='conv-5-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='conv-5-2-bn', input='conv-5-2')
LV_loc_model.add_node(ELU(), name='conv-5-2-activ', input='conv-5-2-bn')
LV_loc_model.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-5-3', input='conv-5-2-activ')
LV_loc_model.add_node(BatchNormalization(), name='conv-5-3-bn', input='conv-5-3')
LV_loc_model.add_node(ELU(), name='conv-5-3-activ', input='conv-5-3-bn')
LV_loc_model.add_node(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'),
               name='pool-5', input='conv-5-3-activ')

LV_loc_model.add_node(Flatten(), name='flatten', input='pool-5')
LV_loc_model.add_node(Dense(4096, activation='relu'), name='fc-1', input='flatten')
LV_loc_model.add_node(Dropout(0.5), name='dropout-1', input='fc-1')
LV_loc_model.add_node(Dense(4096, activation='relu'), name='fc-2', input='dropout-1')
LV_loc_model.add_node(Dropout(0.5), name='dropout-2', input='fc-2')
LV_loc_model.add_node(Reshape((64, 8, 8)), name='reshape', input='dropout-2')

LV_loc_model.add_node(UpSampling2D(size=(2, 2), dim_ordering='th'), name='unpool-1', input='reshape')
LV_loc_model.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-1-1', input='unpool-1')
LV_loc_model.add_node(BatchNormalization(), name='deconv-1-1-bn', input='deconv-1-1')
LV_loc_model.add_node(ELU(), name='deconv-1-1-activ', input='deconv-1-1-bn')
LV_loc_model.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-1-2', input='deconv-1-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='deconv-1-2-bn', input='deconv-1-2')
LV_loc_model.add_node(ELU(), name='deconv-1-2-activ', input='deconv-1-2-bn')
LV_loc_model.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-1-3', input='deconv-1-2-activ')
LV_loc_model.add_node(BatchNormalization(), name='deconv-1-3-bn', input='deconv-1-3')
LV_loc_model.add_node(ELU(), name='deconv-1-3-activ', input='deconv-1-3-bn')

LV_loc_model.add_node(UpSampling2D(size=(2, 2), dim_ordering='th'), name='unpool-2', input='deconv-1-3-activ')
LV_loc_model.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-2-1', input='unpool-2')
LV_loc_model.add_node(BatchNormalization(), name='deconv-2-1-bn', input='deconv-2-1')
LV_loc_model.add_node(ELU(), name='deconv-2-1-activ', input='deconv-2-1-bn')
LV_loc_model.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-2-2', input='deconv-2-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='deconv-2-2-bn', input='deconv-2-2')
LV_loc_model.add_node(ELU(), name='deconv-2-2-activ', input='deconv-2-2-bn')
LV_loc_model.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-2-3', input='deconv-2-2-activ')
LV_loc_model.add_node(BatchNormalization(), name='deconv-2-3-bn', input='deconv-2-3')
LV_loc_model.add_node(ELU(), name='deconv-2-3-activ', input='deconv-2-3-bn')

LV_loc_model.add_node(UpSampling2D(size=(2, 2), dim_ordering='th'), name='unpool-3', input='deconv-2-3-activ')
LV_loc_model.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-3-1', input='unpool-3')
LV_loc_model.add_node(BatchNormalization(), name='deconv-3-1-bn', input='deconv-3-1')
LV_loc_model.add_node(ELU(), name='deconv-3-1-activ', input='deconv-3-1-bn')
LV_loc_model.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-3-2', input='deconv-3-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='deconv-3-2-bn', input='deconv-3-2')
LV_loc_model.add_node(ELU(), name='deconv-3-2-activ', input='deconv-3-2-bn')
LV_loc_model.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-3-3', input='deconv-3-2-activ')
LV_loc_model.add_node(BatchNormalization(), name='deconv-3-3-bn', input='deconv-3-3')
LV_loc_model.add_node(ELU(), name='deconv-3-3-activ', input='deconv-3-3-bn')

LV_loc_model.add_node(UpSampling2D(size=(2, 2), dim_ordering='th'), name='unpool-4', input='deconv-3-3-activ')
LV_loc_model.add_node(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-4-1', input='unpool-4')
LV_loc_model.add_node(BatchNormalization(), name='deconv-4-1-bn', input='deconv-4-1')
LV_loc_model.add_node(ELU(), name='deconv-4-1-activ', input='deconv-4-1-bn')
LV_loc_model.add_node(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-4-2', input='deconv-4-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='deconv-4-2-bn', input='deconv-4-2')
LV_loc_model.add_node(ELU(), name='deconv-4-2-activ', input='deconv-4-2-bn')

LV_loc_model.add_node(UpSampling2D(size=(2, 2), dim_ordering='th'), name='unpool-5', input='deconv-4-2-activ')
LV_loc_model.add_node(Convolution2D(32, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-5-1', input='unpool-5')
LV_loc_model.add_node(BatchNormalization(), name='deconv-5-1-bn', input='deconv-5-1')
LV_loc_model.add_node(ELU(), name='deconv-5-1-activ', input='deconv-5-1-bn')
LV_loc_model.add_node(Convolution2D(32, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-5-2', input='deconv-5-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='deconv-5-2-bn', input='deconv-5-2')
LV_loc_model.add_node(ELU(), name='deconv-5-2-activ', input='deconv-5-2-bn')

LV_loc_model.add_node(Convolution2D(1, 1, 1, activation='sigmoid', init='uniform', border_mode='same', dim_ordering='th'),
               name='prob-map', input='deconv-5-2-activ')
LV_loc_model.add_node(Reshape((256, 256)), name='prob-map-reshape', input='prob-map')
LV_loc_model.add_node(Dropout(0.5), name='prob-map-dropout', input='prob-map-reshape')

LV_loc_model.add_node(GRU(256, activation='tanh', inner_activation='hard_sigmoid', return_sequences=True),
               name='rnn-we', input='prob-map-dropout')
LV_loc_model.add_node(GRU(256, activation='tanh', inner_activation='hard_sigmoid', go_backwards=True, return_sequences=True),
               name='rnn-ew', input='prob-map-dropout')
LV_loc_model.add_node(TimeDistributedDense(256, init='uniform', activation='sigmoid'),
               name='rnn-1', inputs=['rnn-we', 'rnn-ew'], merge_mode='concat', concat_axis=-1)

LV_loc_model.add_node(Rotate90(direction='counterclockwise'), name='rotate', input='prob-map-dropout')
LV_loc_model.add_node(GRU(256, activation='tanh', inner_activation='hard_sigmoid', return_sequences=True),
               name='rnn-ns', input='rotate')
LV_loc_model.add_node(GRU(256, activation='tanh', inner_activation='hard_sigmoid', go_backwards=True, return_sequences=True),
               name='rnn-sn', input='rotate')
LV_loc_model.add_node(TimeDistributedDense(256, init='uniform', activation='sigmoid'),
               name='rnn-2-rotated', inputs=['rnn-ns', 'rnn-sn'], merge_mode='concat', concat_axis=-1)
LV_loc_model.add_node(Rotate90(direction='clockwise'), name='rnn-2', input='rnn-2-rotated')

LV_loc_model.add_node(Activation('linear'), name='pre-output', inputs=['rnn-1', 'rnn-2'], merge_mode='mul')
LV_loc_model.add_output(name='output', input='pre-output')

LV_loc_model.compile('adam', {'output': binaryCE})

LV_loc_model.load_weights('../../model_weights/weights_trainset2_full.hdf5')

def pred_loc_map(image_stack):
    preds = LV_loc_model.predict({'input': np.expand_dims(np.array(image_stack).astype(np.float32), axis=1)}, 
                                 verbose=0)['output']
    return [preds[i,:,:] for i in range(preds.shape[0])]

Using Theano backend.
Using gpu device 3: Tesla K80 (CNMeM is disabled)


### create training data

In [35]:
top_data_full_training = []
top_labels_full_training = []
top_data_full_validation = []
top_labels_full_validation = []

apex_data_full_training = []
apex_labels_full_training = []
apex_data_full_validation = []
apex_labels_full_validation = []

top_data_localized_training = []
top_labels_localized_training = []
top_data_localized_validation = []
top_labels_localized_validation = []

apex_data_localized_training = []
apex_labels_localized_training = []
apex_data_localized_validation = []
apex_labels_localized_validation = []

In [36]:
# full images

idx_z_exclude_ED_train = random.sample(range(len(z_exclude_ED)), round(0.9*len(z_exclude_ED)))
idx_z_exclude_ES_train = random.sample(range(len(z_exclude_ES)), round(0.9*len(z_exclude_ES)))

idx_train = idx_z_exclude_ED_train + [i+len(z_exclude_ED) for i in idx_z_exclude_ES_train]

for idx, (pt, start, end) in enumerate(tqdm(z_exclude_ED + z_exclude_ES)):
    if idx < len(z_exclude_ED):
        frame = frames_ED[pt - 1]
    else:
        frame = frames_ES[pt - 1]
    filepaths = filepaths_train[pt]
    img_stack, _, _, _, _ = create_full_image_stack(filepaths, full_size=256, frame=frame)
    if idx in idx_train:
        for img in img_stack[:start]:
            top_data_full_training.append(apply_per_slice_norm(img))
            top_labels_full_training.append(True)
            for image_aug in img_augmentation(img, 20, rotation=True, shift=True):
                top_data_full_training.append(apply_per_slice_norm(image_aug))
                top_labels_full_training.append(True)
        for img in img_stack[start:end]:
            top_data_full_training.append(apply_per_slice_norm(img))
            top_labels_full_training.append(False)
            apex_data_full_training.append(apply_per_slice_norm(img))
            apex_labels_full_training.append(False)
            for image_aug in img_augmentation(img, 5, rotation=True, shift=True):
                top_data_full_training.append(apply_per_slice_norm(image_aug))
                top_labels_full_training.append(False)
                apex_data_full_training.append(apply_per_slice_norm(image_aug))
                apex_labels_full_training.append(False)
        for img in img_stack[end:]:
            apex_data_full_training.append(apply_per_slice_norm(img))
            apex_labels_full_training.append(True)
            for image_aug in img_augmentation(img, 40, rotation=True, shift=True):
                apex_data_full_training.append(apply_per_slice_norm(image_aug))
                apex_labels_full_training.append(True)
    else:
        for img in img_stack[:start]:
            top_data_full_validation.append(apply_per_slice_norm(img))
            top_labels_full_validation.append(True)
            for image_aug in img_augmentation(img, 20, rotation=True, shift=True):
                top_data_full_validation.append(apply_per_slice_norm(image_aug))
                top_labels_full_validation.append(True)
        for img in img_stack[start:end]:
            top_data_full_validation.append(apply_per_slice_norm(img))
            top_labels_full_validation.append(False)
            apex_data_full_validation.append(apply_per_slice_norm(img))
            apex_labels_full_validation.append(False)
            for image_aug in img_augmentation(img, 5, rotation=True, shift=True):
                top_data_full_validation.append(apply_per_slice_norm(image_aug))
                top_labels_full_validation.append(False)
                apex_data_full_validation.append(apply_per_slice_norm(image_aug))
                apex_labels_full_validation.append(False)
        for img in img_stack[end:]:
            apex_data_full_validation.append(apply_per_slice_norm(img))
            apex_labels_full_validation.append(True)
            for image_aug in img_augmentation(img, 40, rotation=True, shift=True):
                apex_data_full_validation.append(apply_per_slice_norm(image_aug))
                apex_labels_full_validation.append(True)



In [37]:
print(len(top_labels_full_training), sum(top_labels_full_training)/len(top_labels_full_training))
print(len(top_labels_full_validation), sum(top_labels_full_validation)/len(top_labels_full_validation))
print(len(apex_labels_full_training), sum(apex_labels_full_training)/len(apex_labels_full_training))
print(len(apex_labels_full_validation), sum(apex_labels_full_validation)/len(apex_labels_full_validation))

18342 0.5152109911678115
2079 0.5151515151515151
16231 0.45215944796993407
1828 0.4485776805251641


In [38]:
# localized images

idx_z_exclude_ED_train = random.sample(range(len(z_exclude_ED)), round(0.9*len(z_exclude_ED)))
idx_z_exclude_ES_train = random.sample(range(len(z_exclude_ES)), round(0.9*len(z_exclude_ES)))

idx_train = idx_z_exclude_ED_train + [i+len(z_exclude_ED) for i in idx_z_exclude_ES_train]

for idx, (pt, start, end) in enumerate(tqdm(z_exclude_ED + z_exclude_ES)):
    if idx < len(z_exclude_ED):
        frame = frames_ED[pt - 1]
    else:
        frame = frames_ES[pt - 1]
    test_filepaths = filepaths_train[pt]
    vol3d_MIP, vol3d_mask_MIP, _, _ = create_MIP(test_filepaths, full_size=256, frame=frame)
    centroid = get_MIP_centroid(vol3d_mask_MIP)
    img_stack, _, _, _, _ = create_localized_image_stack(filepaths, centroid, full_size=256, local_size=96, frame=frame)
    if idx in idx_train:
        for img in img_stack[:start]:
            top_data_localized_training.append(apply_per_slice_norm(img))
            top_labels_localized_training.append(True)
            for image_aug in img_augmentation(img, 20, rotation=True, shift=True):
                top_data_localized_training.append(apply_per_slice_norm(image_aug))
                top_labels_localized_training.append(True)
        for img in img_stack[start:end]:
            top_data_localized_training.append(apply_per_slice_norm(img))
            top_labels_localized_training.append(False)
            apex_data_localized_training.append(apply_per_slice_norm(img))
            apex_labels_localized_training.append(False)
            for image_aug in img_augmentation(img, 5, rotation=True, shift=True):
                top_data_localized_training.append(apply_per_slice_norm(image_aug))
                top_labels_localized_training.append(False)
                apex_data_localized_training.append(apply_per_slice_norm(image_aug))
                apex_labels_localized_training.append(False)
        for img in img_stack[end:]:
            apex_data_localized_training.append(apply_per_slice_norm(img))
            apex_labels_localized_training.append(True)
            for image_aug in img_augmentation(img, 40, rotation=True, shift=True):
                apex_data_localized_training.append(apply_per_slice_norm(image_aug))
                apex_labels_localized_training.append(True)
    else:
        for img in img_stack[:start]:
            top_data_localized_validation.append(apply_per_slice_norm(img))
            top_labels_localized_validation.append(True)
            for image_aug in img_augmentation(img, 20, rotation=True, shift=True):
                top_data_localized_validation.append(apply_per_slice_norm(image_aug))
                top_labels_localized_validation.append(True)
        for img in img_stack[start:end]:
            top_data_localized_validation.append(apply_per_slice_norm(img))
            top_labels_localized_validation.append(False)
            apex_data_localized_validation.append(apply_per_slice_norm(img))
            apex_labels_localized_validation.append(False)
            for image_aug in img_augmentation(img, 5, rotation=True, shift=True):
                top_data_localized_validation.append(apply_per_slice_norm(image_aug))
                top_labels_localized_validation.append(False)
                apex_data_localized_validation.append(apply_per_slice_norm(image_aug))
                apex_labels_localized_validation.append(False)
        for img in img_stack[end:]:
            apex_data_localized_validation.append(apply_per_slice_norm(img))
            apex_labels_localized_validation.append(True)
            for image_aug in img_augmentation(img, 40, rotation=True, shift=True):
                apex_data_localized_validation.append(apply_per_slice_norm(image_aug))
                apex_labels_localized_validation.append(True)



In [39]:
print(len(top_labels_localized_training), sum(top_labels_localized_training)/len(top_labels_localized_training))
print(len(top_labels_localized_validation), sum(top_labels_localized_validation)/len(top_labels_localized_validation))
print(len(apex_labels_localized_training), sum(apex_labels_localized_training)/len(apex_labels_localized_training))
print(len(apex_labels_localized_validation), sum(apex_labels_localized_validation)/len(apex_labels_localized_validation))

18204 0.5214238628872775
2013 0.511177347242921
18142 0.5197883364568405
2173 0.5471698113207547


In [40]:
top_data_full_training = np.expand_dims(np.array(top_data_full_training).astype(np.float32), axis=1)
top_labels_full_training = np.array(top_labels_full_training).astype(np.bool)
top_data_full_validation = np.expand_dims(np.array(top_data_full_validation).astype(np.float32), axis=1)
top_labels_full_validation = np.array(top_labels_full_validation).astype(np.bool)

apex_data_full_training = np.expand_dims(np.array(apex_data_full_training).astype(np.float32), axis=1)
apex_labels_full_training = np.array(apex_labels_full_training).astype(np.bool)
apex_data_full_validation = np.expand_dims(np.array(apex_data_full_validation).astype(np.float32), axis=1)
apex_labels_full_validation = np.array(apex_labels_full_validation).astype(np.bool)

top_data_localized_training = np.expand_dims(np.array(top_data_localized_training).astype(np.float32), axis=1)
top_labels_localized_training = np.array(top_labels_localized_training).astype(np.bool)
top_data_localized_validation = np.expand_dims(np.array(top_data_localized_validation).astype(np.float32), axis=1)
top_labels_localized_validation = np.array(top_labels_localized_validation).astype(np.bool)

apex_data_localized_training = np.expand_dims(np.array(apex_data_localized_training).astype(np.float32), axis=1)
apex_labels_localized_training = np.array(apex_labels_localized_training).astype(np.bool)
apex_data_localized_validation = np.expand_dims(np.array(apex_data_localized_validation).astype(np.float32), axis=1)
apex_labels_localized_validation = np.array(apex_labels_localized_validation).astype(np.bool)

print(top_data_full_training.shape, top_labels_full_training.shape, 
      top_data_full_validation.shape, top_labels_full_validation.shape, 
      apex_data_full_training.shape, apex_labels_full_training.shape, 
      apex_data_full_validation.shape, apex_labels_full_validation.shape, 
      top_data_localized_training.shape, top_labels_localized_training.shape, 
      top_data_localized_validation.shape, top_labels_localized_validation.shape, 
      apex_data_localized_training.shape, apex_labels_localized_training.shape, 
      apex_data_localized_validation.shape, apex_labels_localized_validation.shape)

(18342, 1, 256, 256) (18342,) (2079, 1, 256, 256) (2079,) (16231, 1, 256, 256) (16231,) (1828, 1, 256, 256) (1828,) (18204, 1, 96, 96) (18204,) (2013, 1, 96, 96) (2013,) (18142, 1, 96, 96) (18142,) (2173, 1, 96, 96) (2173,)


In [41]:
shuffle_index = list(range(top_data_full_training.shape[0]))
random.shuffle(shuffle_index)
top_data_full_training = top_data_full_training[shuffle_index]
top_labels_full_training = top_labels_full_training[shuffle_index]

shuffle_index = list(range(top_data_full_validation.shape[0]))
random.shuffle(shuffle_index)
top_data_full_validation = top_data_full_validation[shuffle_index]
top_labels_full_validation = top_labels_full_validation[shuffle_index]

shuffle_index = list(range(apex_data_full_training.shape[0]))
random.shuffle(shuffle_index)
apex_data_full_training = apex_data_full_training[shuffle_index]
apex_labels_full_training = apex_labels_full_training[shuffle_index]

shuffle_index = list(range(apex_data_full_validation.shape[0]))
random.shuffle(shuffle_index)
apex_data_full_validation = apex_data_full_validation[shuffle_index]
apex_labels_full_validation = apex_labels_full_validation[shuffle_index]

shuffle_index = list(range(top_data_localized_training.shape[0]))
random.shuffle(shuffle_index)
top_data_localized_training = top_data_localized_training[shuffle_index]
top_labels_localized_training = top_labels_localized_training[shuffle_index]

shuffle_index = list(range(top_data_localized_validation.shape[0]))
random.shuffle(shuffle_index)
top_data_localized_validation = top_data_localized_validation[shuffle_index]
top_labels_localized_validation = top_labels_localized_validation[shuffle_index]

shuffle_index = list(range(apex_data_localized_training.shape[0]))
random.shuffle(shuffle_index)
apex_data_localized_training = apex_data_localized_training[shuffle_index]
apex_labels_localized_training = apex_labels_localized_training[shuffle_index]

shuffle_index = list(range(apex_data_localized_validation.shape[0]))
random.shuffle(shuffle_index)
apex_data_localized_validation = apex_data_localized_validation[shuffle_index]
apex_labels_localized_validation = apex_labels_localized_validation[shuffle_index]

In [42]:
joblib.dump((top_data_full_training, top_labels_full_training,
             top_data_full_validation, top_labels_full_validation,
             apex_data_full_training, apex_labels_full_training,
             apex_data_full_validation, apex_labels_full_validation,
             top_data_localized_training, top_labels_localized_training,
             top_data_localized_validation, top_labels_localized_validation,
             apex_data_localized_training, apex_labels_localized_training,
             apex_data_localized_validation, apex_labels_localized_validation), 
            '../../data_proc/trainset2_zslice_classification.pkl')

['../../data_proc/trainset2_zslice_classification.pkl',
 '../../data_proc/trainset2_zslice_classification.pkl_01.npy',
 '../../data_proc/trainset2_zslice_classification.pkl_02.npy',
 '../../data_proc/trainset2_zslice_classification.pkl_03.npy',
 '../../data_proc/trainset2_zslice_classification.pkl_04.npy',
 '../../data_proc/trainset2_zslice_classification.pkl_05.npy',
 '../../data_proc/trainset2_zslice_classification.pkl_06.npy',
 '../../data_proc/trainset2_zslice_classification.pkl_07.npy',
 '../../data_proc/trainset2_zslice_classification.pkl_08.npy',
 '../../data_proc/trainset2_zslice_classification.pkl_09.npy',
 '../../data_proc/trainset2_zslice_classification.pkl_10.npy',
 '../../data_proc/trainset2_zslice_classification.pkl_11.npy',
 '../../data_proc/trainset2_zslice_classification.pkl_12.npy',
 '../../data_proc/trainset2_zslice_classification.pkl_13.npy',
 '../../data_proc/trainset2_zslice_classification.pkl_14.npy',
 '../../data_proc/trainset2_zslice_classification.pkl_15.npy',

### training

In [3]:
(top_data_full_training, top_labels_full_training,
top_data_full_validation, top_labels_full_validation,
apex_data_full_training, apex_labels_full_training,
apex_data_full_validation, apex_labels_full_validation,
top_data_localized_training, top_labels_localized_training,
top_data_localized_validation, top_labels_localized_validation,
apex_data_localized_training, apex_labels_localized_training,
apex_data_localized_validation, apex_labels_localized_validation) = \
    joblib.load('../../data_proc/trainset2_zslice_classification.pkl')

In [11]:
from keras.models import Sequential, Graph
from keras.layers.core import Activation, Dense, Dropout, Flatten, Merge, Reshape, Lambda
from keras.layers.core import TimeDistributedDense, TimeDistributedMerge
from keras.layers.recurrent import LSTM, GRU
from keras.layers.convolutional import Convolution2D, MaxPooling2D, AveragePooling2D, UpSampling2D, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU, PReLU, ParametricSoftplus, ELU
from keras.layers.normalization import BatchNormalization
from keras.layers.noise import GaussianDropout, GaussianNoise
from keras.utils import np_utils, generic_utils
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras import initializations
from keras.layers.core import Layer
from keras import backend as K

# for preventing python max recursion limit error
import sys
sys.setrecursionlimit(50000)

Using Theano backend.
Using gpu device 3: Tesla K80 (CNMeM is disabled)


In [50]:
# top-of-LV full-image classifier

model_full_top_exclusion = Sequential()

model_full_top_exclusion.add(Convolution2D(32, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th', 
                                           input_shape=(1, 256, 256)))
model_full_top_exclusion.add(BatchNormalization())
model_full_top_exclusion.add(ELU())
model_full_top_exclusion.add(Convolution2D(32, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_full_top_exclusion.add(BatchNormalization())
model_full_top_exclusion.add(ELU())
model_full_top_exclusion.add(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'))

model_full_top_exclusion.add(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_full_top_exclusion.add(BatchNormalization())
model_full_top_exclusion.add(ELU())
model_full_top_exclusion.add(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_full_top_exclusion.add(BatchNormalization())
model_full_top_exclusion.add(ELU())
model_full_top_exclusion.add(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'))

model_full_top_exclusion.add(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_full_top_exclusion.add(BatchNormalization())
model_full_top_exclusion.add(ELU())
model_full_top_exclusion.add(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_full_top_exclusion.add(BatchNormalization())
model_full_top_exclusion.add(ELU())
model_full_top_exclusion.add(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_full_top_exclusion.add(BatchNormalization())
model_full_top_exclusion.add(ELU())
model_full_top_exclusion.add(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'))

model_full_top_exclusion.add(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_full_top_exclusion.add(BatchNormalization())
model_full_top_exclusion.add(ELU())
model_full_top_exclusion.add(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_full_top_exclusion.add(BatchNormalization())
model_full_top_exclusion.add(ELU())
model_full_top_exclusion.add(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_full_top_exclusion.add(BatchNormalization())
model_full_top_exclusion.add(ELU())
model_full_top_exclusion.add(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'))

model_full_top_exclusion.add(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_full_top_exclusion.add(BatchNormalization())
model_full_top_exclusion.add(ELU())
model_full_top_exclusion.add(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_full_top_exclusion.add(BatchNormalization())
model_full_top_exclusion.add(ELU())
model_full_top_exclusion.add(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_full_top_exclusion.add(BatchNormalization())
model_full_top_exclusion.add(ELU())
model_full_top_exclusion.add(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'))

weights_file = h5py.File('../../model_weights/weights_trainset2_full.hdf5')
weights = [weights_file['graph']['param_{}'.format(p)] for p in range(weights_file['graph'].attrs['nb_params'])]
for layer in model_full_top_exclusion.layers:
    nb_param = len(layer.get_weights())
    layer.set_weights(weights[:nb_param])
    #layer.trainable = False
    weights = weights[nb_param:]
weights_file.close()

model_full_top_exclusion.add(Flatten())
model_full_top_exclusion.add(Dense(1024, activation='relu'))
model_full_top_exclusion.add(Dropout(0.5))
model_full_top_exclusion.add(Dense(1, activation='sigmoid'))

model_full_top_exclusion.compile(optimizer='adam', loss='binary_crossentropy', class_mode='binary')

In [None]:
batch_size = 32
nb_epoch = 100

checkpointer = ModelCheckpoint(filepath='../../model_weights/weights_trainset2_full_zslice_top_exclusion.hdf5',
                               verbose=1, save_best_only=True)
earlystopping = EarlyStopping(monitor='val_loss', patience=5, verbose=1)

model_full_top_exclusion.fit(top_data_full_training, top_labels_full_training,
                             batch_size=batch_size, nb_epoch=nb_epoch, verbose=1, shuffle=True, show_accuracy=True,
                             validation_data=(top_data_full_validation, top_labels_full_validation),
                             callbacks=[checkpointer, earlystopping])

In [None]:
# LV-apex full-image classifier

In [None]:
# top-of-LV localized-image classifier

model_local_top_exclusion = Sequentual()

model_local_top_exclusion.add(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th', 
                                                 input_shape=(1, 96, 96)))
model_local_top_exclusion.add(BatchNormalization())
model_local_top_exclusion.add(ELU())
model_local_top_exclusion.add(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_local_top_exclusion.add(BatchNormalization())
model_local_top_exclusion.add(ELU())
model_local_top_exclusion.add(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'))

model_local_top_exclusion.add(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_local_top_exclusion.add(BatchNormalization())
model_local_top_exclusion.add(ELU())
model_local_top_exclusion.add(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_local_top_exclusion.add(BatchNormalization())
model_local_top_exclusion.add(ELU())
model_local_top_exclusion.add(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'))

model_local_top_exclusion.add(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_local_top_exclusion.add(BatchNormalization())
model_local_top_exclusion.add(ELU())
model_local_top_exclusion.add(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_local_top_exclusion.add(BatchNormalization())
model_local_top_exclusion.add(ELU())
model_local_top_exclusion.add(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_local_top_exclusion.add(BatchNormalization())
model_local_top_exclusion.add(ELU())
model_local_top_exclusion.add(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'))

model_local_top_exclusion.add(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_local_top_exclusion.add(BatchNormalization())
model_local_top_exclusion.add(ELU())
model_local_top_exclusion.add(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_local_top_exclusion.add(BatchNormalization())
model_local_top_exclusion.add(ELU())
model_local_top_exclusion.add(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'))
model_local_top_exclusion.add(BatchNormalization())
model_local_top_exclusion.add(ELU())
model_local_top_exclusion.add(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'))

weights_file = h5py.File('../../model_weights/weights_trainset2_local.hdf5')
weights = [weights_file['graph']['param_{}'.format(p)] for p in range(weights_file['graph'].attrs['nb_params'])]
for layer in model_local_top_exclusion.layers:
    nb_param = len(layer.get_weights())
    layer.set_weights(weights[:nb_param])
    #layer.trainable = False
    weights = weights[nb_param:]
weights_file.close()

model_local_top_exclusion.add(Flatten())
model_local_top_exclusion.add(Dense(1024, activation='relu'))
model_local_top_exclusion.add(Dropout(0.5))
model_local_top_exclusion.add(Dense(1, activation='sigmoid'))

model_local_top_exclusion.compile(optimizer='adam', loss='binary_crossentropy', class_mode='binary')

In [None]:
batch_size = 128
nb_epoch = 100

checkpointer = ModelCheckpoint(filepath='../../model_weights/weights_trainset2_local_zslice_top_exclusion.hdf5',
                               verbose=1, save_best_only=True)
earlystopping = EarlyStopping(monitor='val_loss', patience=5, verbose=1)

model_local_top_exclusion.fit(top_data_localized_training, top_labels_localized_training,
                            batch_size=batch_size, nb_epoch=nb_epoch, verbose=1, shuffle=True, show_accuracy=True,
                            validation_data=(top_data_localized_validation, top_labels_localized_validation),
                            callbacks=[checkpointer, earlystopping])

In [None]:
# LV-apex localized-image classifier

The above training code is run separately in a script.