## MRI T1 maps segmentation algorithm (V1.0)
Load packages and initialize global variables.

In [None]:
import numpy as np 
import tensorflow as tf
import pandas as pd
import os
import cv2
%matplotlib inline
import matplotlib.pyplot as plt
from T1maps_CImage import CImage
import re
import random

# Spreadsheet whose sheet "matches" lists the DICOM image / MHA mask file pairs.
dataset_file = "sources_files_625_T1maps.xlsx"
# Root folder for all training results; os.path.join keeps the path valid on any OS.
result_rootpath = os.path.join("..", "results")

# Side length (pixels) that images and masks are resized to.
dim = 512
# Epoch sweep of the training loop: start_epoch, start_epoch + inc_epoch, ... (< num_epochs).
start_epoch = 5
inc_epoch = 5
num_epochs = 75

# Mini-batch size passed to model.fit().
_batchsize = 16

# Loss/metric variants and activation functions swept by the training loop.
metrics = ["jacc_metric", "dice_metric"]
activations = ["LeakyReLU", "ReLU"]


# Data loader, file handling and helper routines

- load_shuffeled_images_and_masks()
- plotMask()
- create_excel_file()



In [None]:
def load_shuffeled_images_and_masks(file_matrix, X_shape):
    """
    ===========================================================================================
    @fn         load_shuffeled_images_and_masks()
    @details    Randomly split the rows of file_matrix into train / validation / test sets
                (80% / 10% / 10%), then load each DICOM image and the referenced slice of
                its .mha mask from disk, resize both to X_shape x X_shape and append them
                to the lists of the set the row belongs to.
    @param[in]  file_matrix - dataframe with one row per sample; the columns "DICOM" and
                "MHA_MASK" (file paths) and "MHA_INDEX" (index into the third axis of the
                mask volume) are read. Rows are addressed by position, so any index works.
    @param[in]  X_shape - target side length (pixels) of the resized images and masks
    @return     [train_dcms, train_masks, test_dcms, test_masks, val_dcms, val_masks] -
                lists of uint16 arrays. Masks are min-max scaled to [0, 4095]; a mask that
                holds a single constant value becomes all zeros.
    @note       The split is drawn with the `random` module; seed it for reproducibility.
    ===========================================================================================
    """
    train_dcms, train_masks = [], []
    test_dcms, test_masks = [], []
    val_dcms, val_masks = [], []

    num_data = file_matrix.shape[0]

    # Random permutation of the row positions, used for the data set split below.
    rand_idx = random.sample(range(num_data), num_data)

    def split_three(lst, ratio=(0.8, 0.1, 0.1)):
        """Split lst into consecutive train / val / test parts sized by ratio."""
        train_r, val_r, _ = ratio
        # isclose, not ==: e.g. 0.7 + 0.2 + 0.1 evaluates to 0.9999999999999999.
        assert np.isclose(np.sum(ratio), 1.0), "split ratios must sum to 1"
        # Only the first two cut points are needed; the last part is the remainder.
        cuts = [int(len(lst) * train_r), int(len(lst) * (train_r + val_r))]
        train, val, test = np.split(lst, cuts)
        return train, val, test

    ## Create random fractional index lists.
    train_idx, val_idx, test_idx = split_three(rand_idx)
    print("Train-List: ")
    print(train_idx)
    print("Val-List: ")
    print(val_idx)
    print("Test-List: ")
    print(test_idx)

    # Sets give O(1) membership tests in the loader loop (`in` on an ndarray is a linear scan).
    train_set = set(train_idx.tolist())
    val_set = set(val_idx.tolist())
    test_set = set(test_idx.tolist())

    ## Image reader loop over the shuffled rows of the path dataframe
    for idx in rand_idx:
        # .iloc: idx is a row position, not an index label.
        mha_idx = file_matrix["MHA_INDEX"].iloc[idx]
        dcm_img = CImage(file_matrix["DICOM"].iloc[idx], "DCM_DUMMY", 'image', ".dcm")
        mha_mask = CImage(file_matrix["MHA_MASK"].iloc[idx], "MHA_MASK", 'mask', ".mha")

        mha_mask_data = mha_mask.img[:, :, mha_idx]
        # When neither image dimension matches the mask, CImage.imgT is used instead of
        # CImage.img. NOTE(review): imgT is presumably the transposed pixel array, so that
        # image and mask orientation agree — confirm against T1maps_CImage.
        if (mha_mask.imagesize[0] != dcm_img.imagesize[0]
                and mha_mask.imagesize[1] != dcm_img.imagesize[1]):
            dcm_img_data = dcm_img.imgT
        else:
            dcm_img_data = dcm_img.img

        # Nearest-neighbour resizing only reuses existing pixel values (no blending).
        dcm_mod = cv2.resize(dcm_img_data, (X_shape, X_shape),
                             interpolation=cv2.INTER_NEAREST).astype(np.uint16)
        mask_mod = cv2.resize(mha_mask_data, (X_shape, X_shape),
                              interpolation=cv2.INTER_NEAREST)

        # Min-max scale the mask to [0, 4095]. A constant mask (e.g. no labelled pixels)
        # would give 0/0 = NaN, whose uint16 cast is undefined, so map it to zeros.
        mask_min = mask_mod.min()
        mask_range = mask_mod.max() - mask_min
        if mask_range > 0:
            mask_mod16 = (mask_mod - mask_min) / mask_range * 4095
        else:
            mask_mod16 = np.zeros(mask_mod.shape)
        mask_mod16 = mask_mod16.astype(np.uint16)

        if idx in train_set:
            train_dcms.append(dcm_mod)
            train_masks.append(mask_mod16)
        elif idx in test_set:
            test_dcms.append(dcm_mod)
            test_masks.append(mask_mod16)
        elif idx in val_set:
            val_dcms.append(dcm_mod)
            val_masks.append(mask_mod16)

    return [train_dcms, train_masks, test_dcms, test_masks, val_dcms, val_masks]

def plotMask(X, y):
    """
    ===========================================================================================
    @fn         plotMask()
    @details    Display up to six source-image / reference-mask pairs in one 2 x 3 grid
                figure; each cell shows the image (left) next to its mask (right).
    @param[in]  X - sequence of source images (2-D arrays)
    @param[in]  y - sequence of reference masks, same order as X and the same height as
                the matching image (they are stacked horizontally)
    @return     void
    @note       Plots min(6, len(X), len(y)) pairs; nothing is drawn for empty input.
    ===========================================================================================
    """
    num_samples = min(6, len(X), len(y))
    if num_samples == 0:
        return

    # A single figure for all pairs so the 2 x 3 subplot grid is actually filled.
    plt.figure(figsize=(25, 10))
    for i in range(num_samples):
        plt.subplot(2, 3, i + 1)
        plt.imshow(np.hstack((X[i], y[i])), 'gray')
    plt.show()

def create_excel_file(filepath, loss_arr, val_loss_arr, acc_arr, val_acc_arr, sheet):
    """
    ===========================================================================================
    @fn         create_excel_file()
    @details    Write the per-epoch training curves to an .xlsx file (one column per curve,
                no index column) with the xlsxwriter engine and a formatted header row.
    @param[in]  filepath - output .xlsx path
    @param[in]  loss_arr - training loss per epoch
    @param[in]  val_loss_arr - validation loss per epoch
    @param[in]  acc_arr - training accuracy per epoch
    @param[in]  val_acc_arr - validation accuracy per epoch
    @param[in]  sheet - worksheet name
    @return     void
    ===========================================================================================
    """
    df = pd.DataFrame({'Loss': loss_arr,
                       'Validation Loss': val_loss_arr,
                       'Accuracy': acc_arr,
                       'Validation Accuracy': val_acc_arr})

    # The context manager closes (and saves) the workbook even if a step below raises.
    with pd.ExcelWriter(filepath, engine='xlsxwriter') as writer:
        # Convert the dataframe to an XlsxWriter Excel object.
        df.to_excel(writer, sheet_name=sheet, index=False)

        workbook = writer.book
        worksheet = writer.sheets[sheet]

        # Add a header format.
        header_format = workbook.add_format({
            'bold': True,
            'text_wrap': True,
            'valign': 'top',
            'fg_color': '#D7E4BC',
            'border': 1})

        # Re-write the header cells (row 0) so the format is actually applied.
        for col_num, value in enumerate(df.columns):
            worksheet.write(0, col_num, value, header_format)


## Define Metrics, loss functions and model architecture

In [None]:
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras import backend as K
from keras.layers import LeakyReLU
#from tensorflow.keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint, LearningRateScheduler

def dice_coef(y_true, y_pred):
    """
    Smoothed Dice similarity coefficient (DSC) of two masks.

    @param[in]  y_true - reference mask tensor
    @param[in]  y_pred - predicted mask tensor
    @return     (2 * sum(t * p) + 1) / (sum(t) + sum(p) + 1) over all flattened elements;
                the +1 terms keep the ratio defined when both masks are empty
    """
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    numerator = 2. * overlap + 1
    denominator = K.sum(truth) + K.sum(guess) + 1
    return numerator / denominator

def dice_coef_loss(y_true, y_pred):
    """
    Dice loss: one minus the smoothed Dice coefficient from dice_coef().

    @param[in]  y_true - reference mask tensor
    @param[in]  y_pred - predicted mask tensor
    @return     1 - dice_coef(y_true, y_pred)
    """
    score = dice_coef(y_true, y_pred)
    return 1 - score

def jaccard_coef(y_true, y_pred):
    """
    Smoothed Jaccard index (intersection over union, IoU) of two masks.

    @param[in]  y_true - reference mask tensor
    @param[in]  y_pred - predicted mask tensor
    @return     (sum(t * p) + 1) / (sum(t) + sum(p) - sum(t * p) + 1) over all
                flattened elements; the +1 terms keep the ratio defined for empty masks
    """
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    union = K.sum(truth) + K.sum(guess) - overlap
    return (overlap + 1.0) / (union + 1.0)
def jaccard_coef_loss(y_true, y_pred):
    """
    IoU (Jaccard) loss: one minus the smoothed Jaccard index from jaccard_coef().

    @param[in]  y_true - reference mask tensor
    @param[in]  y_pred - predicted mask tensor
    @return     1 - jaccard_coef(y_true, y_pred)
    """
    iou = jaccard_coef(y_true, y_pred)
    return 1 - iou

def unet(_activation, _metric, input_size=(256,256,1)):
    """
    ===========================================================================================
    @fn         unet()
    @details    Build and compile a 2-D U-Net: four encoder stages (32, 64, 128, 256 filters;
                two 3x3 convolutions + 2x2 max pooling each), a 512-filter bottleneck, four
                decoder stages (2x2 stride-2 transposed convolution, concatenation with the
                matching encoder output on axis 3, two 3x3 convolutions) and a final 1x1
                sigmoid convolution with one output channel.
    @author     MM
    @param[in]  _activation - "ReLU" or "LeakyReLU" (alpha=0.1) for every 3x3 convolution
    @param[in]  _metric - "dice_metric" (Dice loss + dice_coef) or "jacc_metric"
                (Jaccard loss + jaccard_coef); 'binary_accuracy' is tracked in both cases
    @param[in]  input_size - shape of the input tensor, e.g. (height, width, channels)
    @return     model - Keras Model compiled with Adam (learning rate 1e-5)
    @note       Raises ValueError for an unsupported _activation or _metric.
    ===========================================================================================
    """
    # Validate both options up front so a typo fails before any layer is built.
    if _activation == "ReLU":
        def make_activation():
            return 'relu'
    elif _activation == "LeakyReLU":
        def make_activation():
            # A fresh LeakyReLU instance per convolution.
            return LeakyReLU(alpha=0.1)
    else:
        raise ValueError("Unsupported activation '%s' (expected 'ReLU' or 'LeakyReLU')"
                         % (_activation,))

    if _metric not in ("dice_metric", "jacc_metric"):
        raise ValueError("Unsupported metric '%s' (expected 'dice_metric' or 'jacc_metric')"
                         % (_metric,))

    def conv_block(tensor, filters):
        """Two same-padded 3x3 convolutions with the selected activation."""
        tensor = Conv2D(filters, (3, 3), activation=make_activation(), padding='same')(tensor)
        return Conv2D(filters, (3, 3), activation=make_activation(), padding='same')(tensor)

    inputs = Input(input_size)

    # Encoder: keep each stage's output as a skip connection for the decoder.
    skips = []
    x = inputs
    for filters in (32, 64, 128, 256):
        x = conv_block(x, filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck
    x = conv_block(x, 512)

    # Decoder: upsample 2x, then concatenate the matching encoder output (deepest first).
    for filters, skip in zip((256, 128, 64, 32), reversed(skips)):
        up = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = conv_block(concatenate([up, skip], axis=3), filters)

    _outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)

    _model = Model(inputs=[inputs], outputs=[_outputs])

    if _metric == "dice_metric":
        _model.compile(optimizer=Adam(learning_rate=1e-5), loss=dice_coef_loss, metrics=[dice_coef, 'binary_accuracy'])
    else:  # "jacc_metric"
        _model.compile(optimizer=Adam(learning_rate=1e-5), loss=[jaccard_coef_loss], metrics=[jaccard_coef, 'binary_accuracy'])

    _model.summary()

    return _model

# Load required data and perform the training loop



In [None]:
# Data preparation: load, inspect, reshape and normalise the data sets used by the training loop.
from IPython.display import clear_output
from keras.optimizers import Adam 

## step 1 - read the path table (sheet "matches") and load the randomly split
## train / test / validation images and masks, each resized to dim x dim
_file_matrix = pd.read_excel(dataset_file , sheet_name="matches")
[X_train, y_train, X_test, y_test, X_val, y_val,] = load_shuffeled_images_and_masks(_file_matrix, dim)

# Report the size of every split
print("Num Training-Dataset: %d" % len(X_train))
print("Num Training-Masks: %d" % len(y_train))
print("Num Test-Dataset: %d" % len(X_test))
print("Num Test-Masks: %d" % len(y_test))
print("Num Val-Dataset: %d" % len(X_val))
print("Num Val-Masks: %d" % len(y_val))

# Visual sanity check of the first few image/mask pairs
print("training set")
plotMask(X_train,y_train)
print("testing set")
plotMask(X_test,y_test)

## step 2 - Cast image and mask lists to numpy arrays of shape (N, dim, dim, 1),
## i.e. with the trailing single-channel axis expected by the U-Net input
X_train = np.array(X_train).reshape(len(X_train),dim,dim,1)
y_train = np.array(y_train).reshape(len(y_train),dim,dim,1)
X_test = np.array(X_test).reshape(len(X_test),dim,dim,1)
y_test = np.array(y_test).reshape(len(y_test),dim,dim,1)
X_val = np.array(X_val).reshape(len(X_val),dim,dim,1)
y_val = np.array(y_val).reshape(len(y_val),dim,dim,1)
# Every image must have a mask of identical shape
assert X_train.shape == y_train.shape
assert X_test.shape == y_test.shape
assert X_val.shape == y_val.shape

## step 3 - shift and scale intensities: (x - 2048) / 2048 maps [0, 4096] to [-1, 1].
## NOTE(review): the images are cast to uint16 when loaded, so pixel values above 4096
## map above 1 — confirm the T1 value range of the DICOMs.
train_vol = (X_train-2048.0)/2048.0
validation_vol = (X_val-2048.0)/2048.0
test_vol = (X_test-2048.0)/2048.0

## Binarise the masks: they are min-max scaled to [0, 4095] by the loader, so values
## above 2048 become foreground (float32 0/1 targets for the sigmoid output)
train_seg = (y_train>2048).astype(np.float32)
validation_seg = (y_val>2048).astype(np.float32)
test_seg = (y_test>2048).astype(np.float32)

## step 4 - Start training loop: one model per (epochs, activation, metric) combination.
## The callbacks are imported once here rather than on every loop iteration.
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping, ReduceLROnPlateau, CSVLogger

for epochs in range(start_epoch, num_epochs, inc_epoch):
    for _activation in activations:
        for _metric in metrics:
            print("Training mit %d Epochen:"%(epochs))

            # Drop the previous model's graph/state so memory does not grow with every
            # configuration of the sweep.
            K.clear_session()

            # Create the result folder (parents included; an existing folder is reused).
            resultdir = "%d_epochs_%s_%s_%dbatch" % (epochs, _metric, _activation, _batchsize)
            resultpath = os.path.join(result_rootpath, resultdir)
            try:
                os.makedirs(resultpath, exist_ok=True)
                print("Directory '%s' created" % resultpath)
            except OSError as error:
                print(error)

            ## Compile and train the U-Net Model (input matches the dim x dim x 1 arrays)
            model = unet(_activation, _metric, input_size=(dim, dim, 1))

            weight_file = "myocardseg_model_weights%depoch_%s_%s.best.hdf5" % (epochs, _activation, _metric)
            weight_path = os.path.join(resultpath, weight_file)

            ## Callbacks: save the weights with the lowest val_loss; halve the learning
            ## rate after 3 epochs without val_loss improvement (floor 1e-6)
            checkpoint = ModelCheckpoint(weight_path, monitor='val_loss', verbose=1, save_best_only=True, mode='min', save_weights_only=True)
            reduceLROnPlat = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, verbose=1, mode='min', min_delta=0.0001, cooldown=2, min_lr=1e-6)

            with_early = False
            if with_early:
                early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=15)
                callbacks_list = [checkpoint, early_stopping, reduceLROnPlat]
            else:
                callbacks_list = [checkpoint, reduceLROnPlat]

            ## TRAIN the model (recompiled with learning rate 2e-4)
            if _metric == "dice_metric":
                model.compile(optimizer=Adam(learning_rate=2e-4), loss=[dice_coef_loss], metrics=[dice_coef, 'binary_accuracy'])
            elif _metric == "jacc_metric":
                model.compile(optimizer=Adam(learning_rate=2e-4), loss=[jaccard_coef_loss], metrics=[jaccard_coef, 'binary_accuracy'])

            model.reset_states()
            # NOTE(review): the *test* split drives val_loss (checkpointing and LR schedule),
            # while the validation split is only used for the prediction figures below.
            loss_history = model.fit(x=train_vol,
                                     y=train_seg,
                                     batch_size=_batchsize,
                                     epochs=epochs,
                                     validation_data=(test_vol, test_seg),
                                     callbacks=callbacks_list)
            print(loss_history)
            history = loss_history.history

            ## Clear output in notebook
            clear_output()

            ## Plot loss and accuracy curves
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
            ax1.plot(history['loss'], '-', label='Loss')
            ax1.plot(history['val_loss'], '-', label='Validation Loss')
            ax1.set_xlabel("Epoch")
            ax1.legend()
            ax1.grid()

            ax2.plot(100*np.array(history['binary_accuracy']), '-', label='Accuracy')
            ax2.plot(100*np.array(history['val_binary_accuracy']), '-', label='Validation Accuracy')
            ax2.set_xlabel("Epoch")
            ax2.legend()
            ax2.grid()
            filename = "HR_%dEpochs_Accuracy_%s_%s.png" % (epochs, _activation, _metric)
            figurepath = os.path.join(resultpath, filename)
            fig.savefig(figurepath, dpi=600)
            # Close explicitly: open figures otherwise stay in memory for the whole sweep.
            plt.close(fig)

            ## Save accuracy results to excel
            excelfile = "%dEpochs_accuarcy_%s_%s.xlsx" % (epochs, _activation, _metric)
            excelpath = os.path.join(resultpath, excelfile)
            try:
                create_excel_file(excelpath, np.array(history['loss']), np.array(history['val_loss']), np.array(history['binary_accuracy']), np.array(history['val_binary_accuracy']), "val_acc")
                print("File '%s' was created" % excelfile)
            except OSError as error:
                print(error)

            ## Predict the validation set and save one comparison figure per sample
            preds = model.predict(validation_vol)

            for i in range(validation_vol.shape[0]):
                pred_fig = plt.figure(figsize=(16, 5))

                # [..., 0] drops the single channel axis so imshow gets a 2-D array.
                plt.subplot(1, 3, 1)
                plt.imshow(validation_vol[i, :, :, 0], 'gray', interpolation='none')
                plt.xlabel("MRI T1-Map")

                plt.subplot(1, 3, 2)
                plt.imshow(validation_seg[i, :, :, 0], 'gray', interpolation='none')
                plt.xlabel("Ground Truth")

                plt.subplot(1, 3, 3)
                plt.imshow(preds[i, :, :, 0], 'gray', interpolation='none')
                plt.xlabel("Prediction")

                filename = "%d_Epochs_-Ipol_%s_%s_%dbatch_%s_Segm%d.png" % (epochs, _metric, _activation, _batchsize, _metric, i)
                figurepath = os.path.join(resultpath, filename)
                pred_fig.savefig(figurepath, dpi=300)
                # One figure per sample: close it rather than only clearing its axes.
                plt.close(pred_fig)
                clear_output()
