In [1]:
import os
import csv
import re
import pickle
import random
import math
import dicom
import numpy as np
from tqdm import tqdm
from natsort import natsorted
from skimage import transform
from sklearn.externals import joblib
from scipy import ndimage
from matplotlib import path

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
from keras.models import Sequential, Graph
from keras.layers.core import Activation, Dense, Dropout, Flatten, Merge, Reshape, Lambda
from keras.layers.core import TimeDistributedDense, TimeDistributedMerge
from keras.layers.recurrent import LSTM, GRU
from keras.layers.convolutional import Convolution2D, MaxPooling2D, AveragePooling2D, UpSampling2D, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU, PReLU, ParametricSoftplus, ELU
from keras.layers.normalization import BatchNormalization
from keras.layers.noise import GaussianDropout, GaussianNoise
from keras.utils import np_utils, generic_utils
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras import initializations
from keras.layers.core import Layer
from keras import backend as K
import theano
from theano import tensor as T

# for preventing python max recursion limit error
import sys
sys.setrecursionlimit(50000)

Using Theano backend.
Using gpu device 3: Tesla K80 (CNMeM is disabled)


In [3]:
def get_filepaths():
    with open('../../data_supp/filepaths_train.pkl', 'rb') as f:
        filepaths_train = pickle.load(f)
    with open('../../data_supp/filepaths_val.pkl', 'rb') as f:
        filepaths_val = pickle.load(f)
    return filepaths_train, filepaths_val

def get_training_labels():
    systole_labels = {}
    diastole_labels = {}
    with open('../../data/train.csv', 'r') as f:
        for _id, systole, diastole in csv.reader(f):
            if _id == 'Id':
                continue
            systole_labels[int(_id)] = float(systole)
            diastole_labels[int(_id)] = float(diastole)
    return systole_labels, diastole_labels

In [4]:
filepaths_train, filepaths_val = get_filepaths()

systole_labels, diastole_labels = get_training_labels()

In [5]:
# return tuple of real value and function represetation
def create_label(pt, mode='ED'):
    if mode == 'ES':
        return systole_labels[pt], systole_labels[pt] < np.arange(600)
    elif mode == 'ED':
        return diastole_labels[pt], diastole_labels[pt] < np.arange(600)
    else:
        raise


def apply_window(arr, window_center, window_width):
    return np.clip(arr, window_center - window_width/2, window_center + window_width/2)


def apply_per_slice_norm(arr):
    mean = np.mean(arr.ravel())
    std = np.std(arr.ravel())
    if std == 0:
        return np.zeros(arr.shape)
    return (arr - mean) / std


def crop_to_square(arr, size):
    x_len, y_len = arr.shape
    shorter_len = min(x_len, y_len)
    x_start = (arr.shape[0] - shorter_len) // 2
    x_end = x_start + shorter_len
    y_start = (arr.shape[1] - shorter_len) // 2
    y_end = y_start + shorter_len
    return transform.resize(arr[x_start:x_end, y_start:y_end], 
                            (size, size), order=1, clip=True, preserve_range=True)


def crop_to_square_normalized(img_orig, pixel_spacing, size):
    img_norm = ndimage.interpolation.zoom(img_orig, [float(x) for x in pixel_spacing], order=0, mode='constant')
    
    length_x, length_y = img_norm.shape
    if length_x >= size:
        x_start = length_x // 2 - size // 2
        x_end = length_x // 2 + size // 2
    else:
        x_start = 0
        x_end = length_x
    if length_y >= size:
        y_start = length_y // 2 - size // 2
        y_end = length_y // 2 + size // 2
    else:
        y_start = 0
        y_end = length_y
    
    img_new = np.zeros((size, size))
    new_x_shift = (size - (x_end - x_start)) // 2
    new_y_shift = (size - (y_end - y_start)) // 2
    img_new[new_x_shift:(new_x_shift + x_end - x_start), 
            new_y_shift:(new_y_shift + y_end - y_start)] = img_norm[x_start:x_end, y_start:y_end]
    
    return img_new


def localize_to_centroid(img, centroid, width_about_centroid):
    # assumes already cropped to square
    x, y = centroid
    x = int(round(x))
    y = int(round(y))
    x_start = x - width_about_centroid // 2
    x_end = x + width_about_centroid // 2
    y_start = y - width_about_centroid // 2
    y_end = y + width_about_centroid // 2
    
    if x_start < 0:
        x_end += (0 - x_start)
        x_start = 0
    if x_end > img.shape[0]:
        x_start -= (img.shape[0] - x_end)
        x_end = img.shape[0]
    if y_start < 0:
        y_end += (0 - y_start)
        y_start = 0
    if y_end > img.shape[1]:
        y_start -= (img.shape[1] - y_end)
        y_end = img.shape[1]
        
    return img[x_start:x_end, y_start:y_end], (x_start, x_end), (y_start, y_end)


def get_all_series_filepaths(filepaths):
    t_slices = 30
    
    # create sax series filepaths
    # handles irregularies such as those including z-slices and t-slices in the same folder
    series_filepaths_all = []
    for view in filepaths.keys(): 
        if not re.match(r'^sax', view):
            continue
        
        if len(filepaths[view]) == t_slices:
            series_filepaths_all.append(filepaths[view])
        elif len(filepaths[view]) < t_slices:
            series_filepaths_all.append(filepaths[view][:] + filepaths[view][:(t_slices - len(filepaths[view]))])
        else:
            if re.match(r'^\w+-\d+-\d+-\d+.dcm$', filepaths[view][0][0]) is not None:
                series_filepaths_split = []
                slices_list = []
                series_filepaths_sort_by_slice = sorted(filepaths[view][:], 
                                                        key=lambda x: '{}-{}'.format(x[0].split('-')[-1].split('.')[0], 
                                                                                     x[0].split('-')[-2]))
                for fname, fpath in series_filepaths_sort_by_slice:
                    nslice = fname.split('-')[-1].split('.')[0]
                    tframe = fname.split('-')[-2]
                    if nslice not in slices_list:
                        if len(series_filepaths_split) == t_slices:
                            series_filepaths_all.append(series_filepaths_split)
                        elif len(series_filepaths_split) < t_slices and len(series_filepaths_split) > 0:
                            series_filepaths_all.append((series_filepaths_split[:] + 
                                                         series_filepaths_split[:(t_slices - len(series_filepaths_split))]))
                        series_filepaths_split = []
                        series_filepaths_split.append((fname, fpath))
                        slices_list.append(nslice)
                    else:
                        series_filepaths_split.append((fname, fpath))
                        
    return series_filepaths_all

In [6]:
with open('../../data_supp/pt_es_frame.pkl', 'rb') as f:
    pt_es_frame = pickle.load(f)

In [None]:
LV_loc_model = Graph()

LV_loc_model.add_input(name='input', input_shape=(1, 256, 256))

LV_loc_model.add_node(Convolution2D(32, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-1-1', input='input')
LV_loc_model.add_node(BatchNormalization(), name='conv-1-1-bn', input='conv-1-1')
LV_loc_model.add_node(ELU(), name='conv-1-1-activ', input='conv-1-1-bn')
LV_loc_model.add_node(Convolution2D(32, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-1-2', input='conv-1-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='conv-1-2-bn', input='conv-1-2')
LV_loc_model.add_node(ELU(), name='conv-1-2-activ', input='conv-1-2-bn')
LV_loc_model.add_node(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'),
               name='pool-1', input='conv-1-2-activ')

LV_loc_model.add_node(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-2-1', input='pool-1')
LV_loc_model.add_node(BatchNormalization(), name='conv-2-1-bn', input='conv-2-1')
LV_loc_model.add_node(ELU(), name='conv-2-1-activ', input='conv-2-1-bn')
LV_loc_model.add_node(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-2-2', input='conv-2-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='conv-2-2-bn', input='conv-2-2')
LV_loc_model.add_node(ELU(), name='conv-2-2-activ', input='conv-2-2-bn')
LV_loc_model.add_node(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'),
               name='pool-2', input='conv-2-2-activ')

LV_loc_model.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-3-1', input='pool-2')
LV_loc_model.add_node(BatchNormalization(), name='conv-3-1-bn', input='conv-3-1')
LV_loc_model.add_node(ELU(), name='conv-3-1-activ', input='conv-3-1-bn')
LV_loc_model.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-3-2', input='conv-3-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='conv-3-2-bn', input='conv-3-2')
LV_loc_model.add_node(ELU(), name='conv-3-2-activ', input='conv-3-2-bn')
LV_loc_model.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-3-3', input='conv-3-2-activ')
LV_loc_model.add_node(BatchNormalization(), name='conv-3-3-bn', input='conv-3-3')
LV_loc_model.add_node(ELU(), name='conv-3-3-activ', input='conv-3-3-bn')
LV_loc_model.add_node(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'),
               name='pool-3', input='conv-3-3-activ')

LV_loc_model.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-4-1', input='pool-3')
LV_loc_model.add_node(BatchNormalization(), name='conv-4-1-bn', input='conv-4-1')
LV_loc_model.add_node(ELU(), name='conv-4-1-activ', input='conv-4-1-bn')
LV_loc_model.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-4-2', input='conv-4-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='conv-4-2-bn', input='conv-4-2')
LV_loc_model.add_node(ELU(), name='conv-4-2-activ', input='conv-4-2-bn')
LV_loc_model.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-4-3', input='conv-4-2-activ')
LV_loc_model.add_node(BatchNormalization(), name='conv-4-3-bn', input='conv-4-3')
LV_loc_model.add_node(ELU(), name='conv-4-3-activ', input='conv-4-3-bn')
LV_loc_model.add_node(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'),
               name='pool-4', input='conv-4-3-activ')

LV_loc_model.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-5-1', input='pool-4')
LV_loc_model.add_node(BatchNormalization(), name='conv-5-1-bn', input='conv-5-1')
LV_loc_model.add_node(ELU(), name='conv-5-1-activ', input='conv-5-1-bn')
LV_loc_model.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-5-2', input='conv-5-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='conv-5-2-bn', input='conv-5-2')
LV_loc_model.add_node(ELU(), name='conv-5-2-activ', input='conv-5-2-bn')
LV_loc_model.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='conv-5-3', input='conv-5-2-activ')
LV_loc_model.add_node(BatchNormalization(), name='conv-5-3-bn', input='conv-5-3')
LV_loc_model.add_node(ELU(), name='conv-5-3-activ', input='conv-5-3-bn')
LV_loc_model.add_node(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'),
               name='pool-5', input='conv-5-3-activ')

LV_loc_model.add_node(Flatten(), name='flatten', input='pool-5')
LV_loc_model.add_node(Dense(4096, activation='relu'), name='fc-1', input='flatten')
LV_loc_model.add_node(Dropout(0.5), name='dropout-1', input='fc-1')
LV_loc_model.add_node(Dense(4096, activation='relu'), name='fc-2', input='dropout-1')
LV_loc_model.add_node(Dropout(0.5), name='dropout-2', input='fc-2')
LV_loc_model.add_node(Reshape((64, 8, 8)), name='reshape', input='dropout-2')

LV_loc_model.add_node(UpSampling2D(size=(2, 2), dim_ordering='th'), name='unpool-1', input='reshape')
LV_loc_model.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-1-1', input='unpool-1')
LV_loc_model.add_node(BatchNormalization(), name='deconv-1-1-bn', input='deconv-1-1')
LV_loc_model.add_node(ELU(), name='deconv-1-1-activ', input='deconv-1-1-bn')
LV_loc_model.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-1-2', input='deconv-1-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='deconv-1-2-bn', input='deconv-1-2')
LV_loc_model.add_node(ELU(), name='deconv-1-2-activ', input='deconv-1-2-bn')
LV_loc_model.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-1-3', input='deconv-1-2-activ')
LV_loc_model.add_node(BatchNormalization(), name='deconv-1-3-bn', input='deconv-1-3')
LV_loc_model.add_node(ELU(), name='deconv-1-3-activ', input='deconv-1-3-bn')

LV_loc_model.add_node(UpSampling2D(size=(2, 2), dim_ordering='th'), name='unpool-2', input='deconv-1-3-activ')
LV_loc_model.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-2-1', input='unpool-2')
LV_loc_model.add_node(BatchNormalization(), name='deconv-2-1-bn', input='deconv-2-1')
LV_loc_model.add_node(ELU(), name='deconv-2-1-activ', input='deconv-2-1-bn')
LV_loc_model.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-2-2', input='deconv-2-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='deconv-2-2-bn', input='deconv-2-2')
LV_loc_model.add_node(ELU(), name='deconv-2-2-activ', input='deconv-2-2-bn')
LV_loc_model.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-2-3', input='deconv-2-2-activ')
LV_loc_model.add_node(BatchNormalization(), name='deconv-2-3-bn', input='deconv-2-3')
LV_loc_model.add_node(ELU(), name='deconv-2-3-activ', input='deconv-2-3-bn')

LV_loc_model.add_node(UpSampling2D(size=(2, 2), dim_ordering='th'), name='unpool-3', input='deconv-2-3-activ')
LV_loc_model.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-3-1', input='unpool-3')
LV_loc_model.add_node(BatchNormalization(), name='deconv-3-1-bn', input='deconv-3-1')
LV_loc_model.add_node(ELU(), name='deconv-3-1-activ', input='deconv-3-1-bn')
LV_loc_model.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-3-2', input='deconv-3-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='deconv-3-2-bn', input='deconv-3-2')
LV_loc_model.add_node(ELU(), name='deconv-3-2-activ', input='deconv-3-2-bn')
LV_loc_model.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-3-3', input='deconv-3-2-activ')
LV_loc_model.add_node(BatchNormalization(), name='deconv-3-3-bn', input='deconv-3-3')
LV_loc_model.add_node(ELU(), name='deconv-3-3-activ', input='deconv-3-3-bn')

LV_loc_model.add_node(UpSampling2D(size=(2, 2), dim_ordering='th'), name='unpool-4', input='deconv-3-3-activ')
LV_loc_model.add_node(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-4-1', input='unpool-4')
LV_loc_model.add_node(BatchNormalization(), name='deconv-4-1-bn', input='deconv-4-1')
LV_loc_model.add_node(ELU(), name='deconv-4-1-activ', input='deconv-4-1-bn')
LV_loc_model.add_node(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-4-2', input='deconv-4-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='deconv-4-2-bn', input='deconv-4-2')
LV_loc_model.add_node(ELU(), name='deconv-4-2-activ', input='deconv-4-2-bn')

LV_loc_model.add_node(UpSampling2D(size=(2, 2), dim_ordering='th'), name='unpool-5', input='deconv-4-2-activ')
LV_loc_model.add_node(Convolution2D(32, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-5-1', input='unpool-5')
LV_loc_model.add_node(BatchNormalization(), name='deconv-5-1-bn', input='deconv-5-1')
LV_loc_model.add_node(ELU(), name='deconv-5-1-activ', input='deconv-5-1-bn')
LV_loc_model.add_node(Convolution2D(32, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
               name='deconv-5-2', input='deconv-5-1-activ')
LV_loc_model.add_node(BatchNormalization(), name='deconv-5-2-bn', input='deconv-5-2')
LV_loc_model.add_node(ELU(), name='deconv-5-2-activ', input='deconv-5-2-bn')

LV_loc_model.add_node(Convolution2D(1, 1, 1, activation='sigmoid', init='uniform', border_mode='same', dim_ordering='th'),
               name='prob-map', input='deconv-5-2-activ')
LV_loc_model.add_node(Reshape((256, 256)), name='prob-map-reshape', input='prob-map')
LV_loc_model.add_node(Dropout(0.5), name='prob-map-dropout', input='prob-map-reshape')

LV_loc_model.add_node(GRU(256, activation='tanh', inner_activation='hard_sigmoid', return_sequences=True),
               name='rnn-we', input='prob-map-dropout')
LV_loc_model.add_node(GRU(256, activation='tanh', inner_activation='hard_sigmoid', go_backwards=True, return_sequences=True),
               name='rnn-ew', input='prob-map-dropout')
LV_loc_model.add_node(TimeDistributedDense(256, init='uniform', activation='sigmoid'),
               name='rnn-1', inputs=['rnn-we', 'rnn-ew'], merge_mode='concat', concat_axis=-1)

LV_loc_model.add_node(Rotate90(direction='counterclockwise'), name='rotate', input='prob-map-dropout')
LV_loc_model.add_node(GRU(256, activation='tanh', inner_activation='hard_sigmoid', return_sequences=True),
               name='rnn-ns', input='rotate')
LV_loc_model.add_node(GRU(256, activation='tanh', inner_activation='hard_sigmoid', go_backwards=True, return_sequences=True),
               name='rnn-sn', input='rotate')
LV_loc_model.add_node(TimeDistributedDense(256, init='uniform', activation='sigmoid'),
               name='rnn-2-rotated', inputs=['rnn-ns', 'rnn-sn'], merge_mode='concat', concat_axis=-1)
LV_loc_model.add_node(Rotate90(direction='clockwise'), name='rnn-2', input='rnn-2-rotated')

LV_loc_model.add_node(Activation('linear'), name='pre-output', inputs=['rnn-1', 'rnn-2'], merge_mode='mul')
LV_loc_model.add_output(name='output', input='pre-output')

LV_loc_model.compile('adam', {'output': binaryCE})

LV_loc_model.load_weights('../../model_weights/weights_trainset2_full.hdf5')

def pred_loc_map(image_stack):
    preds = LV_loc_model.predict({'input': np.expand_dims(np.array(image_stack).astype(np.float32), axis=1)}, 
                                 verbose=0)['output']
    return [preds[i,:,:] for i in range(preds.shape[0])]

In [27]:
def create_MIP(filepaths, full_size=256, frame=0):
    series_filepaths_all = get_all_series_filepaths(filepaths)

    vol3d = []
    for series_filepaths in series_filepaths_all:
        fname, fpath = natsorted(series_filepaths, lambda x: x[0])[frame]
        df = dicom.read_file(fpath)
        img2d = df.pixel_array
        vol3d.append(apply_per_slice_norm(crop_to_square(img2d, full_size)).astype(np.float32))
    orig_shape = img2d.shape
    pixel_spacing = df.PixelSpacing
    vol3d_mask = pred_loc_map(vol3d)

    vol3d_MIP = np.mean(np.array(vol3d), axis=0)
    vol3d_mask_MIP = np.mean(np.array(vol3d_mask), axis=0)

    return vol3d_MIP, vol3d_mask_MIP, orig_shape, pixel_spacing


def get_MIP_centroid(vol3d_mask_MIP):
    return ndimage.measurements.center_of_mass(vol3d_mask_MIP)


def create_localized_image_stack(filepaths, centroid, full_size=256, local_size=96, frame=0):
    series_filepaths_all = get_all_series_filepaths(filepaths)
                        
    # sort series by z-locations
    z_locs = []
    for series_filepaths in series_filepaths_all:
        df = dicom.read_file(natsorted(series_filepaths, lambda x: x[0])[frame][1])
        z_locs.append(abs(float(df.SliceLocation)))
    series_filepaths_all_zsorted = sorted(zip(z_locs, series_filepaths_all), key=lambda pair: pair[0])
    
    series_filepaths_all_zsorted_with_depths = []
    for i in range(len(series_filepaths_all_zsorted)):
        z_loc, series_filepaths = series_filepaths_all_zsorted[i]
        if i == len(series_filepaths_all_zsorted) - 1:
            z_depth = float(dicom.read_file(natsorted(series_filepaths, lambda x: x[0])[frame][1]).SliceThickness)
        else:
            z_depth = series_filepaths_all_zsorted[i+1][0] - z_loc

        # filter out tiny depths, which are likely repeats
        if z_depth > 0.01:
            series_filepaths_all_zsorted_with_depths.append((z_depth, series_filepaths))

    img_stack = []
    z_depths = []
    for z_depth, series_filepaths in series_filepaths_all_zsorted_with_depths:
        fname, fpath = natsorted(series_filepaths, lambda x: x[0])[frame]
        df = dicom.read_file(fpath)
        img = df.pixel_array
        img_localized, _, _ = localize_to_centroid(crop_to_square(img, full_size), centroid, local_size)
        img_processed = apply_per_slice_norm(img_localized)
        img_stack.append(img_processed.astype(np.float32))
        z_depths.append(z_depth)
    orig_shape = img.shape
    pixel_spacing = [float(s) for s in df.PixelSpacing]
    
    #img_stack_masks = pred_seg_map(img_stack)
    img_stack_masks = None

    return img_stack, img_stack_masks, z_depths, orig_shape, pixel_spacing


def create_full_image_stack(filepaths, full_size=256, frame=0):
    series_filepaths_all = get_all_series_filepaths(filepaths)
                        
    # sort series by z-locations
    z_locs = []
    for series_filepaths in series_filepaths_all:
        df = dicom.read_file(natsorted(series_filepaths, lambda x: x[0])[frame][1])
        z_locs.append(abs(float(df.SliceLocation)))
    series_filepaths_all_zsorted = sorted(zip(z_locs, series_filepaths_all), key=lambda pair: pair[0])
    
    series_filepaths_all_zsorted_with_depths = []
    for i in range(len(series_filepaths_all_zsorted)):
        z_loc, series_filepaths = series_filepaths_all_zsorted[i]
        if i == len(series_filepaths_all_zsorted) - 1:
            z_depth = float(dicom.read_file(natsorted(series_filepaths, lambda x: x[0])[frame][1]).SliceThickness)
        else:
            z_depth = series_filepaths_all_zsorted[i+1][0] - z_loc

        # filter out tiny depths, which are likely repeats
        if z_depth > 0.01:
            series_filepaths_all_zsorted_with_depths.append((z_depth, series_filepaths))

    img_stack = []
    z_depths = []
    for z_depth, series_filepaths in series_filepaths_all_zsorted_with_depths:
        fname, fpath = natsorted(series_filepaths, lambda x: x[0])[frame]
        df = dicom.read_file(fpath)
        img = df.pixel_array
        img_processed = apply_per_slice_norm(crop_to_square(img, full_size))
        img_stack.append(img_processed.astype(np.float32))
        z_depths.append(z_depth)
    orig_shape = img.shape
    pixel_spacing = [float(s) for s in df.PixelSpacing]
    
    img_stack_masks = pred_loc_map(img_stack)

    return img_stack, img_stack_masks, z_depths, orig_shape, pixel_spacing

### create training data

In [28]:
train_val_split = 0.05
pts_train_val = random.sample(list(range(1, 501)), int(500 * train_val_split))
pts_train = list(set(range(1, 501)) - set(pts_train_val))

In [30]:
data_ED_train_batches = []
data_ED_train_val_batches = []
data_ES_train_batches = []
data_ES_train_val_batches = []

data_ED_val_batches = []
data_ES_val_batches = []
pt_indices_val = []

for pt in tqdm(range(1, 701)):

    if pt <= 500:
        filepaths = filepaths_train[pt]
    else:
        filepaths = filepaths_val[pt]
    
    vol3d_MIP_ED, vol3d_mask_MIP_ED, orig_shape_ED, pixel_spacing_ED = \
        create_MIP(filepaths, full_size=256, frame=0)
    centroid_ED = get_MIP_centroid(vol3d_mask_MIP_ED)
    img_stack_ED, img_stack_masks_ED, z_depths_ED, orig_shape_ED, pixel_spacing_ED = \
        create_localized_image_stack(filepaths, centroid_ED, full_size=256, local_size=96, frame=0)
    
    vol3d_MIP_ES, vol3d_mask_MIP_ES, orig_shape_ES, pixel_spacing_ES = \
        create_MIP(filepaths, full_size=256, frame=pt_es_frame[pt])
    centroid_ES = get_MIP_centroid(vol3d_mask_MIP_ES)
    img_stack_ES, img_stack_masks_ES, z_depths_ES, orig_shape_ES, pixel_spacing_ES = \
        create_localized_image_stack(filepaths, centroid_ES, full_size=256, local_size=96, frame=pt_es_frame[pt])
        
    scaling_array_ED = (np.array(z_depths_ED).astype(np.float32) / 10) * \
        (pixel_spacing_ED[0] / 10) * (pixel_spacing_ED[1] / 10) * \
        (min(orig_shape_ED) / 256) * (min(orig_shape_ED) / 256)
        
    scaling_array_ES = (np.array(z_depths_ES).astype(np.float32) / 10) * \
        (pixel_spacing_ES[0] / 10) * (pixel_spacing_ES[1] / 10) * \
        (min(orig_shape_ES) / 256) * (min(orig_shape_ES) / 256)
        
    data_ED_batch = {
        'input': np.expand_dims(np.array(img_stack_ED).astype(np.float32), axis=1),
        'scaling': np.expand_dims(scaling_array_ED, axis=1)
    }
    data_ES_batch = {
        'input': np.expand_dims(np.array(img_stack_ES).astype(np.float32), axis=1),
        'scaling': np.expand_dims(scaling_array_ES, axis=1)
    }
    
    if pt <= 500:
        vol_ED, stepfunc_ED = create_label(pt, mode='ED')
        data_ED_batch['volume'] = vol_ED
        data_ED_batch['volume_cdf'] = stepfunc_ED.astype(np.float32)
        vol_ES, stepfunc_ES = create_label(pt, mode='ES')
        data_ES_batch['volume'] = vol_ES
        data_ES_batch['volume_cdf'] = stepfunc_ES.astype(np.float32)
        
        if pt in pts_train:
            data_ED_train_batches.append(data_ED_batch)
            data_ES_train_batches.append(data_ES_batch)
        else:
            data_ED_train_val_batches.append(data_ED_batch)
            data_ES_train_val_batches.append(data_ES_batch)
    
    else:
        data_ED_val_batches.append(data_ED_batch)
        data_ES_val_batches.append(data_ES_batch)
        pt_indices_val.append(pt)

print('train ED\n',
      len(data_ED_train_batches), data_ED_train_batches[0]['input'].shape, data_ED_train_batches[0]['scaling'].shape, 
      data_ED_train_batches[0]['volume_cdf'].shape,
      '\ntrain val ED\n',
      len(data_ED_train_val_batches), data_ED_train_val_batches[0]['input'].shape, data_ED_train_val_batches[0]['scaling'].shape, 
      data_ED_train_val_batches[0]['volume_cdf'].shape,
      '\ntrain ES\n',
      len(data_ES_train_batches), data_ES_train_batches[0]['input'].shape, data_ES_train_batches[0]['scaling'].shape, 
      data_ES_train_batches[0]['volume_cdf'].shape,
      '\ntrain val ES\n',
      len(data_ES_train_val_batches), data_ES_train_val_batches[0]['input'].shape, data_ES_train_val_batches[0]['scaling'].shape, 
      data_ES_train_val_batches[0]['volume_cdf'].shape,
      '\nval ED\n',
      len(data_ED_val_batches), data_ED_val_batches[0]['input'].shape, data_ED_val_batches[0]['scaling'].shape, 
      '\nval ES\n',
      len(data_ES_val_batches), data_ES_val_batches[0]['input'].shape, data_ES_val_batches[0]['scaling'].shape)

                                                 

train ED
 475 (11, 1, 96, 96) (11, 1) (600,) 
train val ED
 25 (10, 1, 96, 96) (10, 1) (600,) 
train ES
 475 (11, 1, 96, 96) (11, 1) (600,) 
train val ES
 25 (10, 1, 96, 96) (10, 1) (600,) 
val ED
 200 (9, 1, 96, 96) (9, 1) 
val ES
 200 (9, 1, 96, 96) (9, 1)




In [31]:
with open('../../data_proc/data_localized_transfer_learning.pkl', 'wb') as fout:
    pickle.dump((pts_train, pts_train_val, 
                 data_ED_train_batches, data_ED_train_val_batches, 
                 data_ES_train_batches, data_ES_train_val_batches,
                 data_ED_val_batches, data_ES_val_batches, pt_indices_val), fout)

### test CDF creation from segmentation probability map thresholds

In [7]:
import theano
from theano import tensor as T

psteps = 10
y = theano.shared(value=np.array([i / psteps for i in range(psteps + 1)]), strict=False)
x = T.tensor3()
z = T.sum(T.switch(T.gt(T.tile(T.flatten(x, outdim=2), (psteps + 1, 1, 1)).dimshuffle(1, 2, 0), y), 1, 0), axis=1)
f = theano.function([x], z)
f(np.arange(0, 1, 0.05).reshape((2,2,5)).astype(np.float32))

array([[ 9,  8,  6,  4,  2,  0,  0,  0,  0,  0,  0],
       [10, 10, 10, 10, 10,  9,  8,  5,  4,  1,  0]])

In [11]:
psteps = 10
zsteps = 20
y = theano.shared(value=np.array([i/psteps for i in range(psteps+1)]), strict=False)
x = T.tensor3()
z = T.sum(T.switch(T.gt(T.tile(T.flatten(x), (psteps+1,1)).dimshuffle(1,0), y), 1, 0), axis=0)
zprime = y[::-1][T.argmin(T.switch(T.gt(T.tile(z, (zsteps, 1)), 
                                        T.tile(theano.shared(value=np.arange(zsteps)), (psteps+1,1)).T), 1, 0), 
                          axis=1)]
fprime = theano.function([x], zprime)
fprime(np.arange(0, 1, 0.05).reshape((1,4,5)).astype(np.float32))

array([ 0. ,  0.1,  0.1,  0.1,  0.2,  0.3,  0.3,  0.3,  0.4,  0.5,  0.5,
        0.5,  0.6,  0.6,  0.7,  0.7,  0.8,  0.8,  0.9,  1. ])

In [29]:
import theano
from theano import tensor as T


x = T.tensor3()
z = T.sum(T.switch(T.ge(T.flatten(x, outdim=2), 0.5), 1, 0), axis=1, keepdims=True)
f = theano.function([x], z)
f(np.arange(0, 1, 0.05).reshape((2,2,5)).astype(np.float32)).shape

(2, 1)

In [68]:
np.tile(np.zeros((1,1)), (5,1)).shape

(5, 1)

In [90]:
import theano
from theano import tensor as T


x = T.matrix()
zz = T.mean(T.abs_(T.sum(T.switch(T.ge(T.flatten(x, outdim=2), 0.5), 1, 0), axis=1, keepdims=True) - \
                  theano.shared(value=np.ones((5,1)))), axis=0)
z = T.mean(T.abs_(T.sum(x, axis=0, keepdims=True) - theano.shared(value=np.ones((5,1)))), axis=None)
f = theano.function([x], z)
f(np.zeros((5,1)).astype(np.float32))

array(1.0)

### define neural net for transfer learning

In [91]:
PROB_STEPS = 1000
VOL_STEPS = 600

################################################
# DEFINE 2D LOSS FUNCTIONS
################################################

def MAE(y_true, y_pred):
    return K.mean(K.abs(K.sum(y_pred, axis=0, keepdims=True) - y_true), axis=None)

def CRPS(y_true, y_pred):
    # not differentiable currently
    thresh = K.variable(np.array([i / PROB_STEPS for i in range(PROB_STEPS + 1)]))
    vol = K.sum(y_pred, axis=0)
    cdf = thresh[::-1][T.argmin(T.switch(T.gt(T.tile(vol, (VOL_STEPS, 1)), 
                                              T.tile(K.variable(np.arange(VOL_STEPS)), (PROB_STEPS + 1, 1)).T), 
                                         1, 0), 
                                axis=1)]
    return K.mean(K.square(cdf - y_true), axis=None)
    
    
################################################
# DEFINE LAMBDA LAYER FUNCTIONS
################################################

def thresholds_to_areas(X):
    from keras import backend as K
    from theano import tensor as T
    PROB_STEPS = 1000
    VOL_STEPS = 600
    return K.sum(T.switch(T.gt(T.tile(T.flatten(X, outdim=2), (PROB_STEPS + 1, 1, 1)).dimshuffle(1, 2, 0), 
                                K.variable(np.array([i / PROB_STEPS for i in range(PROB_STEPS + 1)]))), 
                           1, 0), 
                  axis=1)

def scaling_tensor(X):
    from keras import backend as K
    from theano import tensor as T
    PROB_STEPS = 1000
    VOL_STEPS = 600
    return T.tile(X, (1, PROB_STEPS + 1))

def to_area(X):
    from theano import tensor as T
    return T.sum(T.switch(T.ge(T.flatten(X, outdim=2), 0.5), 1, 0), axis=1, keepdims=True)


################################################
# DEFINE ROTATE90 LAYER
################################################

class Rotate90(Layer):
    def __init__(self, direction='clockwise', **kwargs):
        super(Rotate90, self).__init__(**kwargs)
        self.direction = direction

    def get_output(self, train):
        X = self.get_input(train)
        if self.direction == 'clockwise':
            return X.transpose((0, 2, 1))[:, :, ::-1]
        elif self.direction == 'counterclockwise':
            return X.transpose((0, 2, 1))[:, ::-1, :]
        else:
            raise

    def get_config(self):
        config = {"name": self.__class__.__name__}
        base_config = super(Rotate90, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
    

################################################
# DEFINE NEURAL NETWORKS
################################################

# ED

vol_model_ED = Graph()

vol_model_ED.add_input(name='input', input_shape=(1, 96, 96))

vol_model_ED.add_node(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-1-1', input='input')
vol_model_ED.add_node(BatchNormalization(), name='conv-1-1-bn', input='conv-1-1')
vol_model_ED.add_node(ELU(), name='conv-1-1-activ', input='conv-1-1-bn')
vol_model_ED.add_node(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-1-2', input='conv-1-1-activ')
vol_model_ED.add_node(BatchNormalization(), name='conv-1-2-bn', input='conv-1-2')
vol_model_ED.add_node(ELU(), name='conv-1-2-activ', input='conv-1-2-bn')
vol_model_ED.add_node(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'),
                      name='pool-1', input='conv-1-2-activ')

vol_model_ED.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-2-1', input='pool-1')
vol_model_ED.add_node(BatchNormalization(), name='conv-2-1-bn', input='conv-2-1')
vol_model_ED.add_node(ELU(), name='conv-2-1-activ', input='conv-2-1-bn')
vol_model_ED.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-2-2', input='conv-2-1-activ')
vol_model_ED.add_node(BatchNormalization(), name='conv-2-2-bn', input='conv-2-2')
vol_model_ED.add_node(ELU(), name='conv-2-2-activ', input='conv-2-2-bn')
vol_model_ED.add_node(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'),
                      name='pool-2', input='conv-2-2-activ')

vol_model_ED.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-3-1', input='pool-2')
vol_model_ED.add_node(BatchNormalization(), name='conv-3-1-bn', input='conv-3-1')
vol_model_ED.add_node(ELU(), name='conv-3-1-activ', input='conv-3-1-bn')
vol_model_ED.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-3-2', input='conv-3-1-activ')
vol_model_ED.add_node(BatchNormalization(), name='conv-3-2-bn', input='conv-3-2')
vol_model_ED.add_node(ELU(), name='conv-3-2-activ', input='conv-3-2-bn')
vol_model_ED.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-3-3', input='conv-3-2-activ')
vol_model_ED.add_node(BatchNormalization(), name='conv-3-3-bn', input='conv-3-3')
vol_model_ED.add_node(ELU(), name='conv-3-3-activ', input='conv-3-3-bn')
vol_model_ED.add_node(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'),
                      name='pool-3', input='conv-3-3-activ')

vol_model_ED.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-4-1', input='pool-3')
vol_model_ED.add_node(BatchNormalization(), name='conv-4-1-bn', input='conv-4-1')
vol_model_ED.add_node(ELU(), name='conv-4-1-activ', input='conv-4-1-bn')
vol_model_ED.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-4-2', input='conv-4-1-activ')
vol_model_ED.add_node(BatchNormalization(), name='conv-4-2-bn', input='conv-4-2')
vol_model_ED.add_node(ELU(), name='conv-4-2-activ', input='conv-4-2-bn')
vol_model_ED.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-4-3', input='conv-4-2-activ')
vol_model_ED.add_node(BatchNormalization(), name='conv-4-3-bn', input='conv-4-3')
vol_model_ED.add_node(ELU(), name='conv-4-3-activ', input='conv-4-3-bn')
vol_model_ED.add_node(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'),
                      name='pool-4', input='conv-4-3-activ')

vol_model_ED.add_node(Flatten(), name='flatten', input='pool-4')
vol_model_ED.add_node(Dense(2304, activation='relu'), name='fc-1', input='flatten')
vol_model_ED.add_node(Dropout(0.5), name='dropout-1', input='fc-1')
vol_model_ED.add_node(Dense(2304, activation='relu'), name='fc-2', input='dropout-1')
vol_model_ED.add_node(Dropout(0.5), name='dropout-2', input='fc-2')
vol_model_ED.add_node(Reshape((64, 6, 6)), name='reshape', input='dropout-2')

vol_model_ED.add_node(UpSampling2D(size=(2, 2), dim_ordering='th'), name='unpool-1', input='reshape')
vol_model_ED.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-1-1', input='unpool-1')
vol_model_ED.add_node(BatchNormalization(), name='deconv-1-1-bn', input='deconv-1-1')
vol_model_ED.add_node(ELU(), name='deconv-1-1-activ', input='deconv-1-1-bn')
vol_model_ED.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-1-2', input='deconv-1-1-activ')
vol_model_ED.add_node(BatchNormalization(), name='deconv-1-2-bn', input='deconv-1-2')
vol_model_ED.add_node(ELU(), name='deconv-1-2-activ', input='deconv-1-2-bn')
vol_model_ED.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-1-3', input='deconv-1-2-activ')
vol_model_ED.add_node(BatchNormalization(), name='deconv-1-3-bn', input='deconv-1-3')
vol_model_ED.add_node(ELU(), name='deconv-1-3-activ', input='deconv-1-3-bn')

vol_model_ED.add_node(UpSampling2D(size=(2, 2), dim_ordering='th'), name='unpool-2', input='deconv-1-3-activ')
vol_model_ED.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-2-1', input='unpool-2')
vol_model_ED.add_node(BatchNormalization(), name='deconv-2-1-bn', input='deconv-2-1')
vol_model_ED.add_node(ELU(), name='deconv-2-1-activ', input='deconv-2-1-bn')
vol_model_ED.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-2-2', input='deconv-2-1-activ')
vol_model_ED.add_node(BatchNormalization(), name='deconv-2-2-bn', input='deconv-2-2')
vol_model_ED.add_node(ELU(), name='deconv-2-2-activ', input='deconv-2-2-bn')
vol_model_ED.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-2-3', input='deconv-2-2-activ')
vol_model_ED.add_node(BatchNormalization(), name='deconv-2-3-bn', input='deconv-2-3')
vol_model_ED.add_node(ELU(), name='deconv-2-3-activ', input='deconv-2-3-bn')

vol_model_ED.add_node(UpSampling2D(size=(2, 2), dim_ordering='th'), name='unpool-3', input='deconv-2-3-activ')
vol_model_ED.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-3-1', input='unpool-3')
vol_model_ED.add_node(BatchNormalization(), name='deconv-3-1-bn', input='deconv-3-1')
vol_model_ED.add_node(ELU(), name='deconv-3-1-activ', input='deconv-3-1-bn')
vol_model_ED.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-3-2', input='deconv-3-1-activ')
vol_model_ED.add_node(BatchNormalization(), name='deconv-3-2-bn', input='deconv-3-2')
vol_model_ED.add_node(ELU(), name='deconv-3-2-activ', input='deconv-3-2-bn')

vol_model_ED.add_node(UpSampling2D(size=(2, 2), dim_ordering='th'), name='unpool-4', input='deconv-3-2-activ')
vol_model_ED.add_node(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-4-1', input='unpool-4')
vol_model_ED.add_node(BatchNormalization(), name='deconv-4-1-bn', input='deconv-4-1')
vol_model_ED.add_node(ELU(), name='deconv-4-1-activ', input='deconv-4-1-bn')
vol_model_ED.add_node(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-4-2', input='deconv-4-1-activ')
vol_model_ED.add_node(BatchNormalization(), name='deconv-4-2-bn', input='deconv-4-2')
vol_model_ED.add_node(ELU(), name='deconv-4-2-activ', input='deconv-4-2-bn')

vol_model_ED.add_node(Convolution2D(1, 1, 1, activation='sigmoid', init='uniform', border_mode='same', dim_ordering='th'),
                      name='prob-map', input='deconv-4-2-activ')
vol_model_ED.add_node(Reshape((96, 96)), name='prob-map-reshape', input='prob-map')
vol_model_ED.add_node(Dropout(0.5), name='prob-map-dropout', input='prob-map-reshape')

vol_model_ED.add_node(GRU(96, activation='tanh', inner_activation='hard_sigmoid', return_sequences=True),
                      name='rnn-we', input='prob-map-dropout')
vol_model_ED.add_node(GRU(96, activation='tanh', inner_activation='hard_sigmoid', go_backwards=True, return_sequences=True),
                      name='rnn-ew', input='prob-map-dropout')
vol_model_ED.add_node(TimeDistributedDense(96, init='uniform', activation='sigmoid'),
                      name='rnn-1', inputs=['rnn-we', 'rnn-ew'], merge_mode='concat', concat_axis=-1)

vol_model_ED.add_node(Rotate90(direction='counterclockwise'), name='rotate', input='prob-map-dropout')
vol_model_ED.add_node(GRU(96, activation='tanh', inner_activation='hard_sigmoid', return_sequences=True),
                      name='rnn-ns', input='rotate')
vol_model_ED.add_node(GRU(96, activation='tanh', inner_activation='hard_sigmoid', go_backwards=True, return_sequences=True),
                      name='rnn-sn', input='rotate')
vol_model_ED.add_node(TimeDistributedDense(96, init='uniform', activation='sigmoid'),
                      name='rnn-2-rotated', inputs=['rnn-ns', 'rnn-sn'], merge_mode='concat', concat_axis=-1)
vol_model_ED.add_node(Rotate90(direction='clockwise'), name='rnn-2', input='rnn-2-rotated')

vol_model_ED.add_node(Activation('linear'), name='pre-output', inputs=['rnn-1', 'rnn-2'], merge_mode='mul')

vol_model_ED.load_weights('../../model_weights/weights_trainset2_local.hdf5')

# freeze early layers
for name, layer in vol_model_ED.nodes.items():
    #if name not in ['prob-map', 'rnn-we', 'rnn-ew', 'rnn-1', 'rnn-ns', 'rnn-sn', 'rnn-2-rotated']:
    layer.trainable = False
        
# additional layers on top of segmentation layers        

#vol_model_ED.add_input(name='scaling', input_shape=(1,))
#vol_model_ED.add_node(Lambda(scaling_tensor, output_shape=(PROB_STEPS + 1,)),
#                      name='scaling-tensor', input='scaling')
#vol_model_ED.add_node(Lambda(thresholds_to_areas, output_shape=(PROB_STEPS + 1,)), 
#                      name='areas', input='pre-output')
#vol_model_ED.add_node(Activation('linear'), 
#                      name='areas-scaled', inputs=['areas', 'scaling-tensor'], merge_mode='mul')
#vol_model_ED.add_output(name='volume_cdf', input='areas-scaled')
#vol_model_ED.compile('adam', {'volume_cdf': CRPS})

vol_model_ED.add_input(name='scaling', input_shape=(1,))
vol_model_ED.add_node(Lambda(to_area, output_shape=(1,)),
                      name='area', input='pre-output')
vol_model_ED.add_node(Activation('linear'),
                      name='area-scaled', inputs=['area', 'scaling'], merge_mode='mul')
vol_model_ED.add_output(name='volume', input='area-scaled')

vol_model_ED.compile('adam', {'volume': MAE})


# ES

vol_model_ES = Graph()

vol_model_ES.add_input(name='input', input_shape=(1, 96, 96))

vol_model_ES.add_node(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-1-1', input='input')
vol_model_ES.add_node(BatchNormalization(), name='conv-1-1-bn', input='conv-1-1')
vol_model_ES.add_node(ELU(), name='conv-1-1-activ', input='conv-1-1-bn')
vol_model_ES.add_node(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-1-2', input='conv-1-1-activ')
vol_model_ES.add_node(BatchNormalization(), name='conv-1-2-bn', input='conv-1-2')
vol_model_ES.add_node(ELU(), name='conv-1-2-activ', input='conv-1-2-bn')
vol_model_ES.add_node(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'),
                      name='pool-1', input='conv-1-2-activ')

vol_model_ES.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-2-1', input='pool-1')
vol_model_ES.add_node(BatchNormalization(), name='conv-2-1-bn', input='conv-2-1')
vol_model_ES.add_node(ELU(), name='conv-2-1-activ', input='conv-2-1-bn')
vol_model_ES.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-2-2', input='conv-2-1-activ')
vol_model_ES.add_node(BatchNormalization(), name='conv-2-2-bn', input='conv-2-2')
vol_model_ES.add_node(ELU(), name='conv-2-2-activ', input='conv-2-2-bn')
vol_model_ES.add_node(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'),
                      name='pool-2', input='conv-2-2-activ')

vol_model_ES.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-3-1', input='pool-2')
vol_model_ES.add_node(BatchNormalization(), name='conv-3-1-bn', input='conv-3-1')
vol_model_ES.add_node(ELU(), name='conv-3-1-activ', input='conv-3-1-bn')
vol_model_ES.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-3-2', input='conv-3-1-activ')
vol_model_ES.add_node(BatchNormalization(), name='conv-3-2-bn', input='conv-3-2')
vol_model_ES.add_node(ELU(), name='conv-3-2-activ', input='conv-3-2-bn')
vol_model_ES.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-3-3', input='conv-3-2-activ')
vol_model_ES.add_node(BatchNormalization(), name='conv-3-3-bn', input='conv-3-3')
vol_model_ES.add_node(ELU(), name='conv-3-3-activ', input='conv-3-3-bn')
vol_model_ES.add_node(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'),
                      name='pool-3', input='conv-3-3-activ')

vol_model_ES.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-4-1', input='pool-3')
vol_model_ES.add_node(BatchNormalization(), name='conv-4-1-bn', input='conv-4-1')
vol_model_ES.add_node(ELU(), name='conv-4-1-activ', input='conv-4-1-bn')
vol_model_ES.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-4-2', input='conv-4-1-activ')
vol_model_ES.add_node(BatchNormalization(), name='conv-4-2-bn', input='conv-4-2')
vol_model_ES.add_node(ELU(), name='conv-4-2-activ', input='conv-4-2-bn')
vol_model_ES.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='conv-4-3', input='conv-4-2-activ')
vol_model_ES.add_node(BatchNormalization(), name='conv-4-3-bn', input='conv-4-3')
vol_model_ES.add_node(ELU(), name='conv-4-3-activ', input='conv-4-3-bn')
vol_model_ES.add_node(MaxPooling2D(pool_size=(2,2), strides=None, border_mode='valid', dim_ordering='th'),
                      name='pool-4', input='conv-4-3-activ')

vol_model_ES.add_node(Flatten(), name='flatten', input='pool-4')
vol_model_ES.add_node(Dense(2304, activation='relu'), name='fc-1', input='flatten')
vol_model_ES.add_node(Dropout(0.5), name='dropout-1', input='fc-1')
vol_model_ES.add_node(Dense(2304, activation='relu'), name='fc-2', input='dropout-1')
vol_model_ES.add_node(Dropout(0.5), name='dropout-2', input='fc-2')
vol_model_ES.add_node(Reshape((64, 6, 6)), name='reshape', input='dropout-2')

vol_model_ES.add_node(UpSampling2D(size=(2, 2), dim_ordering='th'), name='unpool-1', input='reshape')
vol_model_ES.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-1-1', input='unpool-1')
vol_model_ES.add_node(BatchNormalization(), name='deconv-1-1-bn', input='deconv-1-1')
vol_model_ES.add_node(ELU(), name='deconv-1-1-activ', input='deconv-1-1-bn')
vol_model_ES.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-1-2', input='deconv-1-1-activ')
vol_model_ES.add_node(BatchNormalization(), name='deconv-1-2-bn', input='deconv-1-2')
vol_model_ES.add_node(ELU(), name='deconv-1-2-activ', input='deconv-1-2-bn')
vol_model_ES.add_node(Convolution2D(512, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-1-3', input='deconv-1-2-activ')
vol_model_ES.add_node(BatchNormalization(), name='deconv-1-3-bn', input='deconv-1-3')
vol_model_ES.add_node(ELU(), name='deconv-1-3-activ', input='deconv-1-3-bn')

vol_model_ES.add_node(UpSampling2D(size=(2, 2), dim_ordering='th'), name='unpool-2', input='deconv-1-3-activ')
vol_model_ES.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-2-1', input='unpool-2')
vol_model_ES.add_node(BatchNormalization(), name='deconv-2-1-bn', input='deconv-2-1')
vol_model_ES.add_node(ELU(), name='deconv-2-1-activ', input='deconv-2-1-bn')
vol_model_ES.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-2-2', input='deconv-2-1-activ')
vol_model_ES.add_node(BatchNormalization(), name='deconv-2-2-bn', input='deconv-2-2')
vol_model_ES.add_node(ELU(), name='deconv-2-2-activ', input='deconv-2-2-bn')
vol_model_ES.add_node(Convolution2D(256, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-2-3', input='deconv-2-2-activ')
vol_model_ES.add_node(BatchNormalization(), name='deconv-2-3-bn', input='deconv-2-3')
vol_model_ES.add_node(ELU(), name='deconv-2-3-activ', input='deconv-2-3-bn')

vol_model_ES.add_node(UpSampling2D(size=(2, 2), dim_ordering='th'), name='unpool-3', input='deconv-2-3-activ')
vol_model_ES.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-3-1', input='unpool-3')
vol_model_ES.add_node(BatchNormalization(), name='deconv-3-1-bn', input='deconv-3-1')
vol_model_ES.add_node(ELU(), name='deconv-3-1-activ', input='deconv-3-1-bn')
vol_model_ES.add_node(Convolution2D(128, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-3-2', input='deconv-3-1-activ')
vol_model_ES.add_node(BatchNormalization(), name='deconv-3-2-bn', input='deconv-3-2')
vol_model_ES.add_node(ELU(), name='deconv-3-2-activ', input='deconv-3-2-bn')

vol_model_ES.add_node(UpSampling2D(size=(2, 2), dim_ordering='th'), name='unpool-4', input='deconv-3-2-activ')
vol_model_ES.add_node(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-4-1', input='unpool-4')
vol_model_ES.add_node(BatchNormalization(), name='deconv-4-1-bn', input='deconv-4-1')
vol_model_ES.add_node(ELU(), name='deconv-4-1-activ', input='deconv-4-1-bn')
vol_model_ES.add_node(Convolution2D(64, 3, 3, init='he_uniform', border_mode='same', dim_ordering='th'),
                      name='deconv-4-2', input='deconv-4-1-activ')
vol_model_ES.add_node(BatchNormalization(), name='deconv-4-2-bn', input='deconv-4-2')
vol_model_ES.add_node(ELU(), name='deconv-4-2-activ', input='deconv-4-2-bn')

vol_model_ES.add_node(Convolution2D(1, 1, 1, activation='sigmoid', init='uniform', border_mode='same', dim_ordering='th'),
                      name='prob-map', input='deconv-4-2-activ')
vol_model_ES.add_node(Reshape((96, 96)), name='prob-map-reshape', input='prob-map')
vol_model_ES.add_node(Dropout(0.5), name='prob-map-dropout', input='prob-map-reshape')

vol_model_ES.add_node(GRU(96, activation='tanh', inner_activation='hard_sigmoid', return_sequences=True),
                      name='rnn-we', input='prob-map-dropout')
vol_model_ES.add_node(GRU(96, activation='tanh', inner_activation='hard_sigmoid', go_backwards=True, return_sequences=True),
                      name='rnn-ew', input='prob-map-dropout')
vol_model_ES.add_node(TimeDistributedDense(96, init='uniform', activation='sigmoid'),
                      name='rnn-1', inputs=['rnn-we', 'rnn-ew'], merge_mode='concat', concat_axis=-1)

vol_model_ES.add_node(Rotate90(direction='counterclockwise'), name='rotate', input='prob-map-dropout')
vol_model_ES.add_node(GRU(96, activation='tanh', inner_activation='hard_sigmoid', return_sequences=True),
                      name='rnn-ns', input='rotate')
vol_model_ES.add_node(GRU(96, activation='tanh', inner_activation='hard_sigmoid', go_backwards=True, return_sequences=True),
                      name='rnn-sn', input='rotate')
vol_model_ES.add_node(TimeDistributedDense(96, init='uniform', activation='sigmoid'),
                      name='rnn-2-rotated', inputs=['rnn-ns', 'rnn-sn'], merge_mode='concat', concat_axis=-1)
vol_model_ES.add_node(Rotate90(direction='clockwise'), name='rnn-2', input='rnn-2-rotated')

vol_model_ES.add_node(Activation('linear'), name='pre-output', inputs=['rnn-1', 'rnn-2'], merge_mode='mul')

vol_model_ES.load_weights('../../model_weights/weights_trainset2_local.hdf5')

# freeze early layers
for name, layer in vol_model_ES.nodes.items():
    #if name not in ['prob-map', 'rnn-we', 'rnn-ew', 'rnn-1', 'rnn-ns', 'rnn-sn', 'rnn-2-rotated']:
    layer.trainable = False
        
# additional layers on top of segmentation layers        

#vol_model_ES.add_input(name='scaling', input_shape=(1,))
#vol_model_ES.add_node(Lambda(scaling_tensor, output_shape=(PROB_STEPS + 1,)),
#                      name='scaling-tensor', input='scaling')
#vol_model_ES.add_node(Lambda(thresholds_to_areas, output_shape=(PROB_STEPS + 1,)), 
#                      name='areas', input='pre-output')
#vol_model_ES.add_node(Activation('linear'), 
#                      name='areas-scaled', inputs=['areas', 'scaling-tensor'], merge_mode='mul')
#vol_model_ES.add_output(name='volume_cdf', input='areas-scaled')
#vol_model_ES.compile('adam', {'volume_cdf': CRPS})

vol_model_ES.add_input(name='scaling', input_shape=(1,))
vol_model_ES.add_node(Lambda(to_area, output_shape=(1,)),
                      name='area', input='pre-output')
vol_model_ES.add_node(Activation('linear'),
                      name='area-scaled', inputs=['area', 'scaling'], merge_mode='mul')
vol_model_ES.add_output(name='volume', input='area-scaled')

vol_model_ES.compile('adam', {'volume': MAE})

### training

In [92]:
with open('../../data_proc/data_localized_transfer_learning.pkl', 'rb') as f:
    (pts_train, pts_train_val, 
     data_ED_train_batches, data_ED_train_val_batches, 
     data_ES_train_batches, data_ES_train_val_batches,
     data_ED_val_batches, data_ES_val_batches, pt_indices_val) = pickle.load(f)

In [93]:
# train ED

nb_epochs = 1
train_rand_index = list(range(475))

loss_best = 1e6
val_loss_best = 1e6
patience = 0

for epoch in range(nb_epochs):
    random.shuffle(train_rand_index)
    loss_tot = 0
    val_loss_tot = 0
    loss_avg = 1e6
    val_loss_avg = 1e6
    
    print('ED training (epoch {})...'.format(epoch))
    for m, idx in enumerate(range(475)):
        batch = data_ED_train_batches[idx]
        batch['volume'] = np.tile(np.array([batch['volume']]), (batch['input'].shape[0], 1))
        outs = vol_model_ED.train_on_batch(batch)
        loss_tot += outs[0]
        loss_avg = loss_tot / (m+1)
        print('{}, {}'.format(outs[0], loss_avg))
    print('ED validation (epoch {})...'.format(epoch))
    for m, idx in enumerate(range(25)):
        batch = data_ED_train_val_batches[idx]
        batch['volume'] = np.tile(np.array([batch['volume']]), (batch['input'].shape[0], 1))
        outs = vol_model_ED.test_on_batch(batch)
        val_loss_tot += outs[0]
        val_loss_avg = val_loss_tot / (m+1)
        print('{}, {}'.format(outs[0], val_loss_avg))
    
    if val_loss_avg < val_loss_best:
        print('~~~ saving weights to ../../model_weights/weights_trainset2_local_ED_mae_transfer.hdf5')
        vol_model_ED.save_weights('../../model_weights/weights_trainset2_local_ED_mae_transfer.hdf5', 
                                  overwrite=True)
        val_loss_best = val_loss_avg
        patience = 0
    else:
        patience += 1
    
    if patience > 10:
        print('~~~~~~ EARLY STOPPING ~~~~~~')
        break

ED training (epoch 0)...
3.2034385800361633, 3.2034385800361633
9.482682943344116, 6.34306076169014
0.9927557706832886, 4.559625764687856
7.682967662811279, 5.340461239218712
13.649464845657349, 7.00226196050644
23.951502323150635, 9.827135354280472
18.51903772354126, 11.0688356927463
27.547135829925537, 13.128623209893703
13.634790897369385, 13.184864064057669
15.14991420507431, 13.381369078159333
33.926537454128265, 15.249111657792872
2.1001343727111816, 14.15336355070273
30.18085289001465, 15.386247346034416
10.868175506591797, 15.063527928931373
3.877896785736084, 14.317819186051686
20.397149622440338, 14.697777338325977
83.99943995475769, 18.774345727527844
8.322925567626953, 18.193711274200016
3.163785219192505, 17.402662534462777
34.85467317700386, 18.275263066589833
38.94246733188629, 19.25941565065157
20.77475929260254, 19.328294907103885
32.912532448768616, 19.918913930654526
48.54814112186432, 21.111798396954935
6.293223857879639, 20.51905541539192
6.749307096004486, 19.9894

Above training is conducted in separate script