# Modified U-net with Time Context - Training

In [None]:
%matplotlib inline
import os
import matplotlib.pylab as plt
import numpy as np
import keras
from keras import backend as K
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, concatenate
from keras.layers import UpSampling2D, Dropout 
from keras.layers.noise import GaussianNoise
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.preprocessing.image import ImageDataGenerator
import sys
import time
import nibabel as nib
np.random.seed(302)

# --- Segmentation metrics used for CNN training ---
# Additive smoothing term of the Dice coefficient below; it keeps the ratio
# finite when both masks are empty.
smooth = 1.0

def dice_coef(y_true, y_pred):
    ''' Smoothed Dice coefficient, used as the CNN training metric.

    Both tensors are flattened, so a single score is computed over all
    elements passed in. The module-level `smooth` term is added to the
    numerator and the denominator.
    '''
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(pred) + smooth
    return numerator / denominator

def dice_coef_loss(y_true, y_pred):
    ''' Loss function: the negated Dice coefficient.

    Minimising this value maximises `dice_coef`.
    '''
    score = dice_coef(y_true, y_pred)
    return -score

def _double_conv(tensor, filters):
    '''Apply two 3x3, same-padded, ReLU Conv2D layers with `filters` channels.'''
    tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
    return Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)


def get_unet_mod(patch_size=(None, None), learning_rate=1e-5,
                 gaussian_noise_std=0.025, dropout=0.25):
    '''Build and compile a U-Net with input Gaussian noise and dropout.

    Architecture:
      * encoder: four stages (32, 64, 128, 256 filters), each made of two 3x3
        ReLU convolutions followed by 2x2 max-pooling; dropout is applied
        after the last pooling;
      * bottleneck: two 3x3 convolutions with 512 filters;
      * decoder: four stages that upsample by 2, concatenate the matching
        encoder output, apply dropout and then two 3x3 convolutions
        (256, 128, 64, 32 filters);
      * head: a 1x1 sigmoid convolution giving a single-channel map.

    Parameters
    ----------
    patch_size : (int or None, int or None)
        Spatial input size; the input has 3 channels. ``None`` leaves a
        dimension unspecified. Concrete sizes should be divisible by 16, so
        that the four poolings and upsamplings line up for the skip
        concatenations.
    learning_rate : float
        Learning rate of the Adam optimizer (which also uses ``decay=1e-6``).
    gaussian_noise_std : float
        Standard deviation of the GaussianNoise layer applied to the input.
    dropout : float
        Dropout rate used in every Dropout layer.

    Returns
    -------
    A Keras ``Model`` compiled with ``dice_coef_loss`` as loss and
    ``dice_coef`` as metric.
    '''
    inputs = Input((patch_size[0], patch_size[1], 3))
    x = GaussianNoise(gaussian_noise_std)(inputs)

    # Encoder: remember each stage's (pre-pooling) output for the skips.
    # Layers are created in the same order as the original hand-unrolled
    # version, so auto-generated layer names and saved weights stay compatible.
    skips = []
    for filters in (32, 64, 128, 256):
        x = _double_conv(x, filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Dropout(dropout)(x)

    # Bottleneck.
    x = _double_conv(x, 512)

    # Decoder: consume the skips deepest-first.
    for filters, skip in zip((256, 128, 64, 32), reversed(skips)):
        x = concatenate([UpSampling2D(size=(2, 2))(x), skip], axis=-1)
        x = Dropout(dropout)(x)
        x = _double_conv(x, filters)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)

    # Keras 2 keyword names; `input=`/`output=` are the deprecated Keras 1 API.
    model = Model(inputs=inputs, outputs=outputs)
    opt = Adam(lr=learning_rate, decay=1e-6)
    model.compile(optimizer=opt, loss=dice_coef_loss, metrics=[dice_coef])

    return model

In [None]:
# Leave-one-subject-out training. For each fold, the patch files whose name
# starts with prefixes[fold] are excluded and a U-Net is trained on the rest.
patches_path = "/media/roberto/DATA/GDrive/Jupyter-Scripts/Carotid-CNN/Patches"
prefixes = ["Kmd","Djm","ARS","AS","BAB","ETB","IC","LKS","MBT","MK","RAB","SNL"]
patch_files = [f for f in os.listdir(patches_path)
               if f.endswith(("_orig.npy", "_orig2.npy"))]

model_name = "carotid_unet_tr_cnn"
n_val = 5500      # number of (shuffled) patches kept aside for validation
batch_size = 32
seed = 905        # shared by the image and mask generators


def _seg_file_for(patch_file):
    '''Return the label filename paired with a patch filename.

    "<base>_orig.npy" pairs with "<base>_seg.npy" and "<base>_orig2.npy"
    pairs with "<base>_seg2.npy".
    '''
    if patch_file.endswith("_orig2.npy"):
        return patch_file[:-len("_orig2.npy")] + "_seg2.npy"
    return patch_file[:-len("_orig.npy")] + "_seg.npy"


def _load_training_patches(path, files, held_out_prefix):
    '''Load and stack every patch/label pair whose file is not held out.

    Returns (patches, labels) as float64 arrays stacked along axis 0; a
    trailing channel axis is added to each label array. Parts are collected
    in lists and concatenated once, instead of re-copying a growing array on
    every file.
    '''
    patch_parts, label_parts = [], []
    for fname in files:
        if fname.startswith(held_out_prefix):
            continue
        patch_parts.append(np.load(os.path.join(path, fname)))
        seg = np.load(os.path.join(path, _seg_file_for(fname)))
        label_parts.append(seg[:, :, :, np.newaxis])
    if not patch_parts:
        raise ValueError("no training patches found outside prefix %r" % held_out_prefix)
    # float64 matches the dtype produced previously by concatenating onto a
    # float64 np.zeros placeholder.
    patches = np.concatenate(patch_parts, axis=0).astype(np.float64, copy=False)
    labels = np.concatenate(label_parts, axis=0).astype(np.float64, copy=False)
    return patches, labels


def _make_augmenter():
    '''Create the augmentation generator; images and masks use identical settings.'''
    return ImageDataGenerator(
            rotation_range=30,
            width_shift_range=0.05,
            height_shift_range=0.05,
            shear_range=0.15,
            zoom_range=0.15,
            horizontal_flip=True,
            fill_mode='constant',
            cval=0)


def combine_generator(gen1, gen2):
    '''Yield (image_batch, mask_batch) tuples drawn from two generators in lock-step.'''
    while True:
        yield next(gen1), next(gen2)


# NOTE(review): folds start at 1, so prefixes[0] ("Kmd") is never held out.
# This is kept as in the original; confirm it is intentional.
for fold in range(1, len(prefixes)):
    held_out = prefixes[fold]

    # Drop the graph built for the previous fold; otherwise every fold's two
    # models accumulate in the same backend session.
    K.clear_session()

    patches, labels = _load_training_patches(patches_path, patch_files, held_out)

    # Shuffle patches and labels with the same permutation.
    indexes = np.arange(patches.shape[0], dtype=np.int32)
    np.random.shuffle(indexes)
    patches = patches[indexes]
    labels = labels[indexes]

    if patches.shape[0] <= n_val:
        raise ValueError("fold %s: only %d patches, need more than %d for training"
                         % (held_out, patches.shape[0], n_val))
    train01 = patches[:-n_val]
    labels01 = labels[:-n_val]

    val01 = patches[-n_val:]
    labels_val01 = labels[-n_val:]

    # Standardise both splits using statistics computed on the training split only.
    mean = np.mean(train01)
    std = np.std(train01)

    train01 -= mean
    train01 /= std

    val01 -= mean
    val01 /= std
    print(train01.shape)

    # Store this fold's normalisation constants as [mean, std].
    np.save(held_out + ".npy", np.array([mean, std]))

    # Stop after 15 epochs with no improvement in validation Dice. mode='max'
    # is required: with mode='auto', Keras maximises only metrics whose name
    # contains "acc", so it would treat a rising Dice as getting worse.
    earlyStopping = EarlyStopping(monitor='val_dice_coef',
                                  patience=15,
                                  verbose=1, mode='max')

    # Save the weights whenever validation Dice improves.
    checkpoint = ModelCheckpoint(held_out + model_name + '.hdf5', mode='max',
                                 monitor='val_dice_coef', verbose=0,
                                 save_best_only=True, save_weights_only=True)

    model = get_unet_mod(patch_size=(64, 64))

    # Identical settings plus the same flow() seed keep the image and mask
    # augmentations paired. ImageDataGenerator.fit() is not called: no
    # featurewise statistics are enabled, so with augment=True it only built
    # a discarded augmented copy of the whole training set.
    image_datagen = _make_augmenter()
    mask_datagen = _make_augmenter()
    image_generator = image_datagen.flow(train01, batch_size=batch_size, seed=seed)
    mask_generator = mask_datagen.flow(labels01, batch_size=batch_size, seed=seed)

    # Single generator yielding (images, masks) batches.
    combined = combine_generator(image_generator, mask_generator)

    # Integer division so steps_per_epoch is an int on Python 2 and 3.
    steps = train01.shape[0] // batch_size

    hist = model.fit_generator(combined,
                               epochs=100,
                               steps_per_epoch=steps,
                               verbose=1,
                               validation_data=(val01, labels_val01),
                               callbacks=[checkpoint, earlyStopping])

    # Phase 2: reload the best phase-1 weights and continue with a lower
    # learning rate, reusing the same callbacks.
    best_model = get_unet_mod(learning_rate=5e-6)
    best_model.load_weights(held_out + model_name + '.hdf5')

    hist = best_model.fit_generator(combined,
                                    epochs=100,
                                    steps_per_epoch=steps,
                                    verbose=1,
                                    validation_data=(val01, labels_val01),
                                    callbacks=[checkpoint, earlyStopping])