In [1]:
import re
import os
import sys
import shutil
import numpy as np
import pandas as pd
import nibabel as nib
import tensorflow as tf
import albumentations as A
import matplotlib.pyplot as plt

from Models.models import Prob_Unet, AttXnet, Xnet, Unet

from glob import glob
from keras import backend as K
from sklearn.model_selection import KFold
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, ReduceLROnPlateau

# 1. Helpers

In [2]:
def read_data(file_path, start=0, end=0, step=1):
    """Load a NIfTI file and return selected slices stacked on the first axis.

    Parameters
    ----------
    file_path : path of the file handed to ``nib.load``.
    start, end, step : ``range`` arguments selecting indices along the third
        axis of the loaded array. ``end == 0`` means "up to the size of the
        last axis".

    Returns
    -------
    float32 ``np.ndarray`` whose first axis enumerates the selected slices.
    """
    volume = nib.load(file_path).get_fdata()
    if end == 0:
        end = volume.shape[-1]

    slices = []
    for k in range(start, end, step):
        slices.append(volume[:, :, k])
    return np.array(slices, dtype='float32')

def adjust_data(img, cf=1, mask=False):
    """Prepare one slice: optional z-score, add a channel axis, optional crop.

    Parameters
    ----------
    img : array for a single slice.
    cf : central-crop fraction; the crop is applied only when ``cf < 1``.
    mask : when exactly ``False`` the slice is standardised (zero mean, unit
        std, with the Keras epsilon added to the denominator); otherwise the
        values are left untouched.

    Returns
    -------
    The slice with a trailing axis of size 1 (a TensorFlow tensor when the
    central crop was applied, otherwise a numpy array).
    """
    if mask is False:
        centered = img - np.mean(img)
        img = centered / (np.std(img) + K.epsilon())

    img = np.expand_dims(img, axis=-1)
    # Central crop only when a fraction below 1 was requested
    return tf.image.central_crop(img, cf) if cf < 1 else img

def check_data(img, seg):
    """Drop slices whose image holds a single distinct value.

    A slice of ``img`` is kept only when it contains more than one unique
    value; the same positions are kept in ``seg``.

    Returns
    -------
    (img_kept, seg_kept) indexed with the list of surviving positions.
    """
    keep = []
    for pos, slice_ in enumerate(img):
        if np.unique(slice_).size > 1:
            keep.append(pos)
    return img[keep], seg[keep]

def data_augmentation(images, masks):
    """Augment image/mask pairs in place.

    Each pair goes through an albumentations pipeline made of
    ``HorizontalFlip()`` and ``Rotate(p=0.8)``; the transformed arrays are
    written back into ``images`` and ``masks`` at the same position.
    Nothing is returned.
    """
    pipeline = A.Compose([A.HorizontalFlip(), A.Rotate(p=0.8)])
    for pos, pair in enumerate(zip(images, masks)):
        augmented = pipeline(image=pair[0], mask=pair[1])
        images[pos] = augmented['image']
        masks[pos] = augmented['mask']

def find_best(path_model, type_='min'):
    """Return the checkpoint path whose trailing filename field ranks best.

    The text after the last ``'-'`` of every ``*.hdf5`` path found in
    ``path_model`` is ordered with ``natural_keys``. ``type_ == 'min'`` picks
    the first entry of that ordering, anything else picks the last one. The
    first globbed path containing the picked text is returned.
    """
    candidates = glob(path_model+'/*.hdf5')
    suffixes = sorted((c.split('-')[-1] for c in candidates), key=natural_keys)

    chosen = suffixes[0] if type_ == 'min' else suffixes[-1]
    matches = [c for c in candidates if chosen in c]
    return matches[0]

def show_img(img, y_true, y_pred, idx=0):
    """Plot slice ``idx`` (channel 0) of the image, ground truth and prediction
    side by side in grayscale."""
    plt.figure(figsize=(15,15))
    panels = (('FLAIR image', img), ('Ground Truth', y_true), ('Prediction', y_pred))
    for column, (title, volume) in enumerate(panels, start=1):
        plt.subplot(1, 3, column)
        plt.title(title)
        plt.imshow(volume[idx,:,:,0], cmap='gray')
        plt.axis('off')
    plt.show()

def save_wmh(y_pred, file_in, file_out, name='ADNI'):
    """Write a prediction volume to ``file_out`` using ``file_in`` as template.

    A zero array with the template's shape is filled slice by slice from
    channel 0 of ``y_pred`` (first axis of ``y_pred`` = slice index), then
    saved with the template's affine and header.

    Placement depends on ``name``:
    * ``'Singapore'``: written into indices 12:220 of the first axis, or, when
      the template's second dimension differs from ``y_pred.shape[2]``, the
      prediction is rotated with ``tf.image.rot90(k=3)`` and written into
      indices 12:220 of the second axis.
    * ``'GE3T'``: written into the first 128 rows.
    * anything else: written over the full slice.
    """
    reference = nib.load(file_in)
    volume = np.zeros(reference.shape)

    # Pick the source array and the in-plane region that receives each slice
    source = y_pred
    if name == 'Singapore':
        if volume.shape[1] != y_pred.shape[2]:
            source = tf.image.rot90(y_pred, k=3).numpy()
            region = (slice(None), slice(12, 220))
        else:
            region = (slice(12, 220), slice(None))
    elif name == 'GE3T':
        region = (slice(None, 128), slice(None))
    else:
        region = (slice(None), slice(None))

    for z in range(volume.shape[-1]):
        volume[region + (z,)] = source[z, :, :, 0]

    nib.save(nib.Nifti1Image(volume, reference.affine, reference.header), file_out)

def save_wmh_challenge(y_pred, file_in, file_out, name):
    """Write a prediction made on the padded Challenge data back to disk.

    ``file_in`` is the template (shape, affine, header). For every slice the
    window ``start:end`` is cut out of channel 0 of ``y_pred`` and copied into
    a zero volume of the template's shape. The window depends on ``name``:
    Singapore 4:236, GE3T 54:186, Utrecht 8:248, otherwise 0:240.

    The window is applied to the last spatial axis for ``'Utrecht'`` and for
    ``'Singapore'`` volumes whose second dimension differs from
    ``y_pred.shape[2]`` (those are first rotated with
    ``tf.image.rot90(k=3)``); in every other case it is applied to the first
    spatial axis.
    """
    reference = nib.load(file_in)
    volume = np.zeros(reference.shape)

    windows = {'Singapore': (4, 236), 'GE3T': (54, 186), 'Utrecht': (8, 248)}
    start, end = windows.get(name, (0, 240))

    source = y_pred
    window_on_last_axis = name == 'Utrecht'
    if name == 'Singapore' and volume.shape[1] != y_pred.shape[2]:
        source = tf.image.rot90(y_pred, k=3).numpy()
        window_on_last_axis = True

    for z in range(volume.shape[-1]):
        if window_on_last_axis:
            volume[:, :, z] = source[z, :, start:end, 0]
        else:
            volume[:, :, z] = source[z, start:end, :, 0]

    nib.save(nib.Nifti1Image(volume, reference.affine, reference.header), file_out)

# https://stackoverflow.com/questions/5967500/how-to-correctly-sort-a-string-with-a-number-inside
def atoi(text):
    """Return ``int(text)`` when ``text`` consists only of digits, else ``text``."""
    if text.isdigit():
        return int(text)
    return text

# https://stackoverflow.com/questions/5967500/how-to-correctly-sort-a-string-with-a-number-inside
def natural_keys(text):
    """Sort key for natural ordering: split ``text`` around digit runs and
    convert every digit run to ``int`` through ``atoi``."""
    pieces = re.split(r'(\d+)', text)
    return list(map(atoi, pieces))

# https://github.com/MrGiovanni/UNetPlusPlus/blob/master/keras/helper_functions.py#L37-L42
def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient over the flattened inputs.

    Computes ``(2 * sum(t * p) + 1) / (sum(t) + sum(p) + 1)`` with Keras
    backend ops, where ``t`` and ``p`` are the flattened ``y_true`` and
    ``y_pred``.
    """
    smooth = 1.
    truth_flat = K.flatten(y_true)
    pred_flat = K.flatten(y_pred)
    overlap = K.sum(truth_flat * pred_flat)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth_flat) + K.sum(pred_flat) + smooth
    return numerator / denominator

# https://github.com/umbertogriffo/focal-loss-keras/blob/master/src/loss_function/losses.py#L11-L53
def binary_focal_loss(gamma=2., alpha=.25):
    """Build a binary focal loss closure.

    :param gamma: exponent applied to ``(1 - p_t)``.
    :param alpha: weight used where ``y_true == 1``; ``1 - alpha`` elsewhere.
    :return: function ``(y_true, y_pred) -> scalar loss tensor``.
    """
    def binary_focal_loss_fixed(y_true, y_pred):
        """
        :param y_true: A tensor of the same shape as `y_pred`
        :param y_pred:  A tensor resulting from a sigmoid
        :return: Output tensor.
        """
        y_true = tf.cast(y_true, tf.float32)
        # Keep predictions strictly inside (0, 1) so the log below stays finite
        eps = K.epsilon()
        y_pred = K.clip(y_pred, eps, 1.0 - eps)

        is_positive = K.equal(y_true, 1)
        # Probability assigned to the true class of each element
        p_t = tf.where(is_positive, y_pred, 1 - y_pred)
        # Class-dependent balancing weight
        alpha_factor = K.ones_like(y_true) * alpha
        alpha_t = tf.where(is_positive, alpha_factor, 1 - alpha_factor)

        modulating = alpha_t * K.pow((1 - p_t), gamma)
        per_element = modulating * (-K.log(p_t))
        # Sum over axis 1, then average the remaining entries
        return K.mean(K.sum(per_element, axis=1))
    return binary_focal_loss_fixed

# https://github.com/baumgach/PHiSeg-code/blob/c43f3b32e1f434aecba936ff994b6f743ba7a5f8/utils.py#L326-L370
def ambiguity_map(y_gen, seg=None):
    """Average pixel-wise cross entropy of generated samples against a target.

    ``y_gen`` is indexed as ``[sample, slice, ...]``. For every sample and
    slice the term ``-sum(target * log(sample + eps), axis=-1)`` is computed,
    where the target is the mean over samples when ``seg`` is ``None`` and
    ``seg`` otherwise. The result is the average of these maps over the
    sample axis, so it has the shape of ``y_gen`` without its first axis.
    """
    def pixel_wise_xent(m_samp, m_gt, eps=1e-8):
        return -1.0 * np.sum(m_gt * np.log(m_samp + eps), axis=-1)

    mean_pred = np.average(y_gen, axis=0)
    target = mean_pred if seg is None else seg

    entropy = np.zeros(y_gen.shape)
    for s, z in np.ndindex(*y_gen.shape[:2]):
        xent = pixel_wise_xent(y_gen[s, z, ...], target[z, ...])
        entropy[s, z, ...] = np.expand_dims(xent, axis=-1)

    return np.average(entropy, axis=0)

# 2. Datasets

In [3]:
def get_dataset(name, end=0):
    """Load a dataset by name.

    ``end`` acts as a switch: when it is at least 1 and ``name`` is one of
    ADNI / Singapore / GE3T / Utrecht, a fixed slice window is used
    (ADNI 5:30, Singapore 0:45, GE3T 25:70, Utrecht 0:45). In every other
    case all slices are read (``start = end = 0``). The step is always 1.

    ``name == 'Challenge'`` is served by ``read_challenge``; everything else
    by ``read_dataset``.

    Returns
    -------
    (fold_img, fold_seg) exactly as produced by the reader that was called.
    """
    windows = {'ADNI': (5, 30), 'Singapore': (0, 45), 'GE3T': (25, 70), 'Utrecht': (0, 45)}
    if end >= 1 and name in windows:
        start, end = windows[name]
    else:
        start, end = 0, 0
    step = 1

    if name == 'Challenge':
        fold_img, fold_seg = read_challenge(start, end, step=1, cf=1)
    else:
        fold_img, fold_seg = read_dataset(name, start, end, step=step, cf=1)
    return fold_img, fold_seg

def read_dataset(name, start, end, step=1, cf=1):
    """Read FLAIR images and lesion masks of one dataset, grouped into folds.

    Files are located with glob patterns under the module-level
    ``path_dataset`` directory: ``ADNI/*/*/`` for ``name == 'ADNI'`` and
    ``Challenge/<name>/*/`` otherwise. Images match ``*brain.nii.gz``, masks
    match ``*wmh.nii.gz``.

    Parameters
    ----------
    name : dataset name; 'ADNI', 'Singapore' and 'GE3T' get special handling.
    start, end, step : slice selection forwarded to ``read_data``.
    cf : central-crop fraction forwarded to ``adjust_data``.

    Returns
    -------
    (np.array(fold_img), np.array(fold_seg)) with one entry per fold: four
    folds for ADNI (fold1..fold4), a single fold otherwise.
    """
    if name == 'ADNI':
        flair_path = glob(path_dataset+'ADNI/*/*/*brain.nii.gz')
        lesion_path = glob(path_dataset+'ADNI/*/*/*wmh.nii.gz')
    else:
        flair_path = glob(path_dataset+'Challenge/{}/*/*brain.nii.gz'.format(name))
        lesion_path = glob(path_dataset+'Challenge/{}/*/*wmh.nii.gz'.format(name))

    # Natural ordering keeps image and mask lists aligned by position
    flair_path.sort(key=natural_keys) 
    lesion_path.sort(key=natural_keys)
    
    if name == 'ADNI':
        # Group paths by the directory three levels from the end ('fold1'..'fold4')
        flair_path = [[i for i in flair_path if i.split('/')[-3] == 'fold'+str(j)] for j in range(1,5)]
        lesion_path = [[i for i in lesion_path if i.split('/')[-3] == 'fold'+str(j)] for j in range(1,5)]
    else:
        # Every other dataset is handled as a single fold
        flair_path = [flair_path]
        lesion_path = [lesion_path]

    fold_img, fold_seg = [], []
    for fp,lp in zip(flair_path,lesion_path):
        if name == 'Singapore':
            # The first two Singapore files are read separately and every slice is
            # rotated with tf.image.rot90 (presumably they are stored in a different
            # orientation than the rest -- confirm against the data)
            img_temp = [np.expand_dims(read_data(i, start=start, end=end, step=step), -1) for i in fp[:2]]
            seg_temp = [np.expand_dims(read_data(i, start=start, end=end, step=step), -1) for i in lp[:2]]
            img_temp = np.array([[np.squeeze(tf.image.rot90(j).numpy(), -1) for j in i] for i in img_temp])
            seg_temp = np.array([[np.squeeze(tf.image.rot90(j).numpy(), -1) for j in i] for i in seg_temp])
            # Remove them so the generic read below does not load them again
            del fp[:2]; del lp[:2]

        # Read the (remaining) files of this fold
        img = np.array([read_data(i, start=start, end=end, step=step) for i in fp])
        seg = np.array([read_data(i, start=start, end=end, step=step) for i in lp])

        if name == 'Singapore':
            # Put the two rotated files back in front, then keep indices 12:220
            # of the second axis of every per-file array
            img = np.concatenate([img_temp, img])
            seg = np.concatenate([seg_temp, seg])
            img = np.array([i[:,12:220,:] for i in img])
            seg = np.array([i[:,12:220,:] for i in seg])
        elif name == 'GE3T':
            # Keep the first 128 entries of the second axis of every per-file array
            img = np.array([i[:,:128,:] for i in img])
            seg = np.array([i[:,:128,:] for i in seg])

        # Per-slice preparation: images are standardised, masks only get a channel axis
        img = np.array([[adjust_data(j, cf=cf) for j in i] for i in img])
        seg = np.array([[adjust_data(j, cf=cf, mask=True) for j in i] for i in seg])

        if name not in ['ADNI']:
            # Mask value 2 is reset to background
            # (presumably the Challenge "other pathology" label -- confirm)
            seg[(seg == 2).all(axis=-1)] = 0

        fold_img.append(img); fold_seg.append(seg)
    return np.array(fold_img), np.array(fold_seg)

def read_challenge(start, end, step=1, cf=1):
    """Read the three Challenge institutions and zero-pad them to one size.

    Singapore, GE3T and Utrecht are read from
    ``path_dataset + 'Challenge/<name>/*/'``. Each institution's slices are
    then centred in a zero array whose last two dimensions are the maxima
    over all three institutions.

    Parameters
    ----------
    start, end, step : slice selection forwarded to ``read_data``.
    cf : central-crop fraction forwarded to ``adjust_data``.

    Returns
    -------
    (fold_img_padd, fold_seg_padd): two Python lists with one array per
    institution, in the order Singapore, GE3T, Utrecht.
    """
    # Read Challenge dataset as full dataset
    fold_img, fold_seg = [], []
    for name in ['Singapore', 'GE3T', 'Utrecht']:
        flair_path = glob(path_dataset+'Challenge/{}/*/*brain.nii.gz'.format(name))
        lesion_path = glob(path_dataset+'Challenge/{}/*/*wmh.nii.gz'.format(name))
        # Natural ordering keeps image and mask lists aligned by position
        flair_path.sort(key=natural_keys) 
        lesion_path.sort(key=natural_keys)
        flair_path = [flair_path]
        lesion_path = [lesion_path]

        for fp,lp in zip(flair_path, lesion_path):
            if name == 'Singapore':
                # The first two Singapore files are read separately and every slice
                # is rotated with tf.image.rot90 (presumably a different stored
                # orientation -- confirm against the data)
                img_temp = [np.expand_dims(read_data(i, start=start, end=end, step=step), -1) for i in fp[:2]]
                seg_temp = [np.expand_dims(read_data(i, start=start, end=end, step=step), -1) for i in lp[:2]]
                img_temp = np.array([[np.squeeze(tf.image.rot90(j).numpy(), -1) for j in i] for i in img_temp])
                seg_temp = np.array([[np.squeeze(tf.image.rot90(j).numpy(), -1) for j in i] for i in seg_temp])
                # Remove them so the generic read below does not load them again
                del fp[:2]; del lp[:2]

            # Read the (remaining) files
            img = np.array([read_data(i, start=start, end=end, step=step) for i in fp])
            seg = np.array([read_data(i, start=start, end=end, step=step) for i in lp])
            
            if name == 'Singapore':
                # Put the two rotated files back in front
                img = np.concatenate([img_temp, img])
                seg = np.concatenate([seg_temp, seg])
            fold_img.append(img); fold_seg.append(seg)

    # Largest size of the last two dimensions over all institutions
    sizes = [i.shape[-2:] for i in fold_img]
    max_sizes = np.max(list(zip(*sizes)), -1).tolist()
    # Half of the size difference (truncated) = offset that centres each institution
    diff_sizes = [(int((max_sizes[0]-i)/2), (int((max_sizes[1]-j)/2))) for i,j in sizes]
    fold_img_padd, fold_seg_padd = [], []

    for k,(i,j) in enumerate(zip(fold_img, fold_seg)):
        # Zero canvases keep the first two dimensions and use the common size for the last two
        img = np.zeros(i.shape[:2] + tuple(max_sizes))
        seg = np.zeros(i.shape[:2] + tuple(max_sizes))
        x = sizes[k][0]; x_pad = diff_sizes[k][0]
        y = sizes[k][1]; y_pad = diff_sizes[k][1]

        # Copy the original data into the centre of the canvas
        img[:,:,x_pad:x+x_pad,y_pad:y+y_pad] = i
        seg[:,:,x_pad:x+x_pad,y_pad:y+y_pad] = j

        # Per-slice preparation: images are standardised, masks only get a channel axis
        img = np.array([[adjust_data(j, cf=cf) for j in i] for i in img])
        seg = np.array([[adjust_data(j, cf=cf, mask=True) for j in i] for i in seg])
        # Mask value 2 is reset to background
        # (presumably the Challenge "other pathology" label -- confirm)
        seg[(seg == 2).all(axis=-1)] = 0
        fold_img_padd.append(img); fold_seg_padd.append(seg)

    return fold_img_padd, fold_seg_padd

# 3. Parameters

In [4]:
# ---- Paths ----
# Template for the folder that receives saved WMH predictions (three fields)
result_path = 'Results/{}/{}/{}'
# Root folder for checkpoints and TensorBoard logs
path = 'pre-trained/'
# Root folder of the datasets
path_dataset = 'Datasets/'
# Folder for the evaluation spreadsheets
eval_path = 'Evaluation_Excel/'

# ---- Hyperparameters ----
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2)
generate_pred = 30   # number of predictions generated per test volume
threshold = 0.25     # cut-off used to binarise predictions
batch_size = 16
epoch = 50
lr = 0.001
gamma = 0.25         # focal-loss gamma
alpha = 0.5          # focal-loss alpha

# Learning rate of the latent networks; set to None when the
# Probabilistic U-Net is not used
lr_latent = 0.001

# Run identifier built from the hyperparameters (lr_latent is left out when None)
uniq = '_'.join(map(str, [lr, gamma, alpha] if lr_latent is None
                    else [lr, lr_latent, gamma, alpha]))

def pick_model(model_name):
    """Instantiate the network registered under ``model_name``.

    Known names: 'U_Net', 'AttU_Net', 'XNet', 'AttXNet', 'XNet_ds',
    'AttXNet_ds' (deterministic) and 'Prob_U_Net' (probabilistic).
    Any other name yields ``None``.
    """
    n_blocks = 4
    n_classes = 1
    input_shape = (None, None, 1)
    decoder_filters = (256, 128, 64, 32)

    # Probabilistic
    if model_name == 'Prob_U_Net':
        return Prob_Unet(num_classes=n_classes, activation='sigmoid', latent_dim=6)

    # Deterministic: name -> (family, attention, deep_supervision)
    # deep_supervision is not passed to the U-Net family
    deterministic = {
        'U_Net': ('unet', False, None),
        'AttU_Net': ('unet', True, None),
        'XNet': ('xnet', False, False),
        'AttXNet': ('xnet', True, False),
        'XNet_ds': ('xnet', False, True),
        'AttXNet_ds': ('xnet', True, True),
    }
    if model_name not in deterministic:
        return None

    family, attention, deep_supervision = deterministic[model_name]
    shared = dict(use_backbone=False, input_shape=input_shape, attention=attention,
                  n_upsample_blocks=n_blocks, decoder_filters=decoder_filters)
    if family == 'unet':
        return Unet(**shared)
    return Xnet(deep_supervision=deep_supervision, **shared)

# 4. Experiments

## 4.1. K-Fold Cross Validation

In [None]:
# K-fold cross-validation: for every split a fresh model is trained, the
# checkpoint picked by find_best() is reloaded, and every test volume is
# predicted, scored with dice_coef and written to disk.

# U_Net, AttU_Net, XNet, AttXNet, XNet_ds, AttXNet_ds, Prob_U_Net
model_name = 'Prob_U_Net'

# ADNI, Challenge
train = 'ADNI'
amount = 2
dsc_list = []
kf = KFold(n_splits=amount)
fold_img, fold_seg = get_dataset(train)

for k, (train_idx, test_idx) in enumerate(kf.split(fold_img)):
    # Offset into each institution's volumes for the Challenge split (0 on the first fold, 10 afterwards)
    num = 0 if k == 0 else 10 
    # Training data 
    if train == 'ADNI':
        # Merge the two training folds; the first 24 volumes train, the last 6 validate
        img_temp = np.concatenate([fold_img[train_idx[0]], fold_img[train_idx[1]]])
        seg_temp = np.concatenate([fold_seg[train_idx[0]], fold_seg[train_idx[1]]])
        img_train = np.concatenate(img_temp[:24])
        seg_train = np.concatenate(seg_temp[:24])
        img_val = np.concatenate(img_temp[-6:])
        seg_val = np.concatenate(seg_temp[-6:])
    else:
        # Ten volumes per institution; the last two of them are used for validation
        img_temp = [i[num:num+10] for i in fold_img]
        seg_temp = [i[num:num+10] for i in fold_seg]
        img_train = np.concatenate([np.concatenate(i[:-2]) for i in img_temp])
        seg_train = np.concatenate([np.concatenate(i[:-2]) for i in seg_temp])
        img_val = np.concatenate([np.concatenate(i[-2:]) for i in img_temp])
        seg_val = np.concatenate([np.concatenate(i[-2:]) for i in seg_temp])

    del img_temp, seg_temp
    # Drop constant-valued slices, then augment the training set in place
    img_train, seg_train = check_data(img_train, seg_train)
    img_val, seg_val = check_data(img_val, seg_val)
    data_augmentation(img_train, seg_train)
    
    # Model
    logdir = path+'logdir/{}_{}_fold{}_2_{}'.format(model_name,uniq,k+1,train)
    # Run name, reused as the checkpoint folder under `path`
    temp = logdir.split('/')[-1]
    # Be careful with these 4 lines!
    ## Will delete existing folders and files
    if len(glob(logdir)) >= 1: shutil.rmtree(logdir)
    if len(glob(path+temp)) >= 1: shutil.rmtree(path+temp)
    ## Will create folders and files
    if len(glob(logdir)) == 0: os.mkdir(logdir)
    if len(glob(path+temp)) == 0: os.mkdir(path+temp)

    tensorboard = TensorBoard(log_dir=logdir)

    # Only the weights are checkpointed whenever lr_latent is set
    save_weights_only = True if lr_latent is not None else False
    model_checkpoint = ModelCheckpoint(
        filepath=path+temp+"/checkpoint-best-{epoch:02d}-{val_loss:.4f}.hdf5",
        monitor='val_loss', mode='min', save_best_only=True, save_weights_only=save_weights_only)
    
    # Compile model
    md = pick_model(model_name)
    if model_name == 'Prob_U_Net':
        # The probabilistic model takes separate optimizers for its U-Net, prior and posterior
        md.compile(unet_opt=Adam(learning_rate=lr), 
                   prior_opt=Adam(learning_rate=lr_latent), 
                   posterior_opt=Adam(learning_rate=lr_latent),
                   loss=binary_focal_loss(gamma=gamma, alpha=alpha), metric=dice_coef)
    
    else:
        md.compile(optimizer=Adam(learning_rate=lr), metrics=[dice_coef],
                   loss=binary_focal_loss(gamma=gamma, alpha=alpha))

    # Training
    md.fit(img_train, seg_train, epochs=epoch, batch_size=batch_size,
            validation_data=(img_val,seg_val), verbose=2,
            callbacks=[model_checkpoint, tensorboard])
    print('Finish training on fold-', k+1,'\n')

    # Testing
    # NOTE(review): `built` is forced before load_weights -- presumably required
    # by the subclassed probabilistic model; confirm
    if lr_latent is not None: md.built = True
    best_weight = find_best(path+temp)
    md.load_weights(best_weight)
    dsc_patient = []

    ## ADNI Testing
    if train == 'ADNI':
        # Patient folders of the two held-out folds (fold directories are numbered from 1)
        paths = glob(path_dataset+'ADNI/{}/*'.format('fold'+str(test_idx[0]+1)))
        paths = paths + glob(path_dataset+'ADNI/{}/*'.format('fold'+str(test_idx[1]+1)))
        paths.sort(key=natural_keys)

        img_test = np.concatenate([fold_img[test_idx[0]], fold_img[test_idx[1]]])
        seg_test = np.concatenate([fold_seg[test_idx[0]], fold_seg[test_idx[1]]])

        for idx_p, (img, seg) in enumerate(zip(img_test, seg_test)):
            if lr_latent is not None:
                # Generate `generate_pred` predictions for this volume
                y_gen = []
                for x in range(generate_pred):
                    if model_name == 'Prob_U_Net':
                        # Sample a latent code from the prior, then decode it with the U-Net
                        z_sample, _,_ = md.prior.predict(img)
                        y_pred = md.det_unet.predict([img, z_sample])
                    else:
                        y_pred = md.predict(img)

                    y_gen.append(y_pred)

                y_gen = np.array(y_gen)
                # Ambiguity maps: samples vs. their mean (ss) and samples vs. ground truth (sy)
                y_ss = ambiguity_map(y_gen)
                y_sy = ambiguity_map(y_gen, seg)
                y_pred = np.average(y_gen, axis=0)
            else:
                y_pred = md.predict(img)
                y_ss = None; y_sy = None

            # Binarise the prediction at `threshold`
            y_pred[y_pred > threshold] = 1
            y_pred[y_pred <= threshold] = 0
            dsc = dice_coef(seg, y_pred).numpy()
            dsc_patient.append(dsc)
            
            # paths[idx_p] is used as the folder of the volume at position idx_p
            patient_id = paths[idx_p].split('/')[-1]
            fold_id = paths[idx_p].split('/')[-2]
            file_in = '{}/{}_wmh.nii.gz'.format(paths[idx_p], patient_id)
            file_out = result_path.format('ADNI', fold_id, patient_id)
            if len(glob(file_out)) == 0: os.makedirs(file_out)
            file_out_pred = file_out+'/{}_wmh_{}.nii.gz'.format(patient_id, model_name)
            show_img(img, seg, y_pred, 18)
            save_wmh(y_pred, file_in, file_out_pred)

            # Ambiguity maps
            if y_ss is not None and y_sy is not None:
                file_out_1 = file_out+'/{}_am_ss_{}.nii.gz'.format(patient_id, model_name)
                file_out_2 = file_out+'/{}_am_sy_{}.nii.gz'.format(patient_id, model_name)
                save_wmh(y_ss, file_in, file_out_1)
                save_wmh(y_sy, file_in, file_out_2)

    ## Challenge Testing
    else:
        for iter_,inst in enumerate(['Singapore', 'GE3T', 'Utrecht']):
            dsc_temp = []
            paths = glob(path_dataset+'Challenge/{}/*'.format(inst))
            paths.sort(key=natural_keys)
            # Test on the ten volumes that were not selected for training in this fold
            paths = paths[10-num:20-num]
            
            img_test = np.array(fold_img[iter_][10-num:20-num], dtype='float32')
            seg_test = np.array(fold_seg[iter_][10-num:20-num], dtype='float32')
    
            for idx_p, (img, seg) in enumerate(zip(img_test, seg_test)):
                if lr_latent is not None:
                    # Generate `generate_pred` predictions for this volume
                    y_gen = []
                    for x in range(generate_pred):
                        if model_name == 'Prob_U_Net':
                            # Sample a latent code from the prior, then decode it with the U-Net
                            z_sample, _,_ = md.prior.predict(img)
                            y_pred = md.det_unet.predict([img, z_sample])
                        else:
                            y_pred = md.predict(img)

                        y_gen.append(y_pred)

                    y_gen = np.array(y_gen)
                    # Ambiguity maps: samples vs. their mean (ss) and samples vs. ground truth (sy)
                    y_ss = ambiguity_map(y_gen)
                    y_sy = ambiguity_map(y_gen, seg)
                    y_pred = np.average(y_gen, axis=0)
                else:
                    y_pred = md.predict(img)
                    y_ss = None; y_sy = None

                # Binarise the prediction at `threshold`
                y_pred[y_pred > threshold] = 1
                y_pred[y_pred <= threshold] = 0
                dsc = dice_coef(seg, y_pred).numpy()
                dsc_temp.append(dsc)
                
                patient_id = paths[idx_p].split('/')[-1]
                file_in = '{}/{}_wmh.nii.gz'.format(paths[idx_p], patient_id)
                file_out = result_path.format('Challenge', inst, patient_id)
                if len(glob(file_out)) == 0: os.makedirs(file_out)
                file_out_pred = file_out+'/{}_wmh_{}.nii.gz'.format(patient_id, model_name)
                show_img(img, seg, y_pred, 20)
                # The padded prediction is cut back to the template's shape before saving
                save_wmh_challenge(y_pred, file_in, file_out_pred, name=inst)

                # Ambiguity maps
                if y_ss is not None and y_sy is not None:
                    file_out_1 = file_out+'/{}_am_ss_{}.nii.gz'.format(patient_id, model_name)
                    file_out_2 = file_out+'/{}_am_sy_{}.nii.gz'.format(patient_id, model_name)
                    save_wmh_challenge(y_ss, file_in, file_out_1, name=inst)
                    save_wmh_challenge(y_sy, file_in, file_out_2, name=inst)

            dsc_patient.append(dsc_temp)
    dsc_list.append(dsc_patient)
    print('Average Dice Coefficient on fold-{} : {:.4f}'.format(k+1, np.average(dsc_patient)))
print('Average Dice Coefficient : {:.4f}'.format(np.average(dsc_list)))

## 4.2. Cross Dataset

In [None]:
# Cross-dataset experiment: train on `train`, reload the checkpoint picked by
# find_best(), then predict, score and save every volume of each dataset
# listed in `test`.

# U_Net, AttU_Net, XNet, AttXNet, XNet_ds, AttXNet_ds, Prob_U_Net
model_name = 'Prob_U_Net'

# ADNI, Challenge
train = 'ADNI'

# ADNI, Singapore, GE3T, Utrecht
test = ['Singapore', 'GE3T', 'Utrecht']

fold_img, fold_seg = get_dataset(train)
# dsc_dict[train dataset][test dataset] -> average Dice coefficient
dsc_dict = {i:{j:0 for j in test} for i in [train]}

for type_data in [train]: # train data
    # Training data
    if type_data == 'Challenge':
        # Per institution: the last four volumes validate, the rest train
        img_train = np.concatenate([np.concatenate(i[:-4]) for i in fold_img])
        seg_train = np.concatenate([np.concatenate(i[:-4]) for i in fold_seg])
        img_val = np.concatenate([np.concatenate(i[-4:]) for i in fold_img])
        seg_val = np.concatenate([np.concatenate(i[-4:]) for i in fold_seg])
    else:
        if type_data == 'ADNI':
            # Merge all ADNI folds into one list of volumes
            img_temp = np.concatenate(fold_img)
            seg_temp = np.concatenate(fold_seg)
        else:
            img_temp = fold_img[0]
            seg_temp = fold_seg[0]

        # The last six volumes validate, the rest train
        img_train = np.concatenate(img_temp[:-6])
        seg_train = np.concatenate(seg_temp[:-6])
        img_val = np.concatenate(img_temp[-6:])
        seg_val = np.concatenate(seg_temp[-6:])

    # Drop constant-valued slices, then augment the training set in place
    img_train, seg_train = check_data(img_train, seg_train)
    img_val, seg_val = check_data(img_val, seg_val)
    data_augmentation(img_train, seg_train)

    # Model
    # (fixed: the original formatted the undefined name `model_nam`, which
    # raised NameError before any training could start)
    logdir = path+'logdir/{}_{}_{}'.format(model_name, uniq, type_data)
    # Run name, reused as the checkpoint folder under `path`
    temp = logdir.split('/')[-1]
    # Be careful with these 4 lines!
    ## Will delete existing folders and files
    if len(glob(logdir)) >= 1: shutil.rmtree(logdir)
    if len(glob(path+temp)) >= 1: shutil.rmtree(path+temp)
    ## Will create folders and files
    if len(glob(logdir)) == 0: os.mkdir(logdir)
    if len(glob(path+temp)) == 0: os.mkdir(path+temp)

    tensorboard = TensorBoard(log_dir=logdir)

    # Only the weights are checkpointed whenever lr_latent is set
    save_weights_only = True if lr_latent is not None else False
    model_checkpoint = ModelCheckpoint(
        filepath=path+temp+"/checkpoint-best-{epoch:02d}-{val_loss:.4f}.hdf5",
        monitor='val_loss', mode='min', save_best_only=True, save_weights_only=save_weights_only)

    # Compile model
    md = pick_model(model_name)
    if model_name == 'Prob_U_Net':
        # The probabilistic model takes separate optimizers for its U-Net, prior and posterior
        md.compile(unet_opt=Adam(learning_rate=lr), 
                   prior_opt=Adam(learning_rate=lr_latent), 
                   posterior_opt=Adam(learning_rate=lr_latent),
                   loss=binary_focal_loss(gamma=gamma, alpha=alpha), metric=dice_coef)
    else:
        md.compile(optimizer=Adam(learning_rate=lr), metrics=[dice_coef], 
                   loss=binary_focal_loss(gamma=gamma, alpha=alpha))

    # Training
    md.fit(img_train, seg_train, epochs=epoch, batch_size=batch_size,
            validation_data=(img_val,seg_val), verbose=2,
            callbacks=[reduce_lr, model_checkpoint, tensorboard])
    print('Finished Training with Dataset ', type_data, '\n')

    # Testing
    # NOTE(review): `built` is forced before load_weights -- presumably required
    # by the subclassed probabilistic model; confirm
    if lr_latent is not None: md.built = True
    best_weight = find_best(path+temp)
    md.load_weights(best_weight)

    for i in test: # test data
        score = []
        if i == 'ADNI': paths = glob(path_dataset+'ADNI/*/*')
        else: paths = glob(path_dataset+'Challenge/{}/*'.format(i))
        paths.sort(key=natural_keys)
        
        # Flatten the folds into one list of volumes
        img_test, seg_test = get_dataset(i)
        img_test = np.concatenate(img_test)
        seg_test = np.concatenate(seg_test)
        
        for idx_p, (img, seg) in enumerate(zip(img_test, seg_test)):
            if lr_latent is not None:
                # Generate `generate_pred` predictions for this volume
                y_gen = []
                for x in range(generate_pred):
                    if model_name == 'Prob_U_Net':
                        # Sample a latent code from the prior, then decode it with the U-Net
                        z_sample, _,_ = md.prior.predict(img)
                        y_pred = md.det_unet.predict([img, z_sample])
                    else:
                        y_pred = md.predict(img)

                    y_gen.append(y_pred)

                y_gen = np.array(y_gen)
                # Ambiguity maps: samples vs. their mean (ss) and samples vs. ground truth (sy)
                y_ss = ambiguity_map(y_gen)
                y_sy = ambiguity_map(y_gen, seg)
                y_pred = np.average(y_gen, axis=0)
            else:
                y_pred = md.predict(img)
                y_ss = None; y_sy = None

            # Binarise the prediction at `threshold`
            y_pred[y_pred > threshold] = 1
            y_pred[y_pred <= threshold] = 0
            dsc = dice_coef(seg, y_pred).numpy()
            score.append(dsc)

            # paths[idx_p] is used as the folder of the volume at position idx_p
            patient_id = paths[idx_p].split('/')[-1]
            if i == 'ADNI':
                fold_id = paths[idx_p].split('/')[-2]
                file_out = result_path.format(i, fold_id, patient_id)
            else:
                fold_id = 'Challenge'
                file_out = result_path.format(fold_id, i, patient_id)

            file_in = '{}/{}_wmh.nii.gz'.format(paths[idx_p], patient_id)
            if len(glob(file_out)) == 0: os.makedirs(file_out)
            # The training dataset name is appended to the output file names
            file_out_pred = file_out+'/{}_wmh_{}_{}.nii.gz'.format(patient_id, model_name,type_data)
            show_img(img, seg, y_pred, 18)
            save_wmh(y_pred, file_in, file_out_pred, name=i)

            # Ambiguity maps
            if y_ss is not None and y_sy is not None:
                file_out_1 = file_out+'/{}_am_ss_{}_{}.nii.gz'.format(patient_id, model_name, type_data)
                file_out_2 = file_out+'/{}_am_sy_{}_{}.nii.gz'.format(patient_id, model_name, type_data)
                save_wmh(y_ss, file_in, file_out_1, name=i)
                save_wmh(y_sy, file_in, file_out_2, name=i)

        dsc_dict[type_data][i] = np.average(score)
        print('Dice Coefficient with dataset {} trained by {}: {:.4f}'.format(i, type_data, dsc_dict[type_data][i]))

    avg_ = np.average([dsc_dict[train][i] for i in test])
    print('Average: {:.4f}'.format(avg_))

# 5. Evaluation

In [None]:
# Evaluation: for every test dataset, compare each saved prediction with its
# ground-truth mask (lesion volume and Dice) and export the tables to Excel.

# 'ADNI', 'Challenge'
train = 'ADNI'
# 'ADNI', 'Singapore', 'GE3T', 'Utrecht', 'Challenge'
test = ['Singapore', 'GE3T', 'Utrecht']
# U_Net, AttU_Net, XNet, AttXNet, XNet_ds, AttXNet_ds, Prob_U_Net
model_name = 'Prob_U_Net'
df_by_model, dsc, patient = [], [], []

for i in test: # testing 
    if i == 'ADNI': 
        paths = glob(path_dataset+'ADNI/*/*/*wmh.nii.gz')
    elif i == 'Challenge':
        paths = []
        for j in ['Singapore', 'GE3T', 'Utrecht']:
            # Dedicated local name: the original reused `path`, which replaced the
            # global pre-trained directory string with a list and broke any later
            # rerun of the experiment cells
            inst_paths = glob(path_dataset+'Challenge/{}/*/*wmh.nii.gz'.format(j))
            inst_paths.sort(key=natural_keys)
            paths += inst_paths
    else: 
        paths = glob(path_dataset+'Challenge/{}/*/*wmh.nii.gz'.format(i))

    # For 'Challenge' the list is already sorted per institution
    if i != 'Challenge': paths.sort(key=natural_keys)
    temp, temp_dsc, temp_patient = [], [], []

    for lp in paths:
        # Locate the prediction folder that belongs to this ground-truth file
        patient_id = lp.split('/')[-2]
        if i == 'ADNI':
            fold_id = lp.split('/')[-3]
            pp = result_path.format(i, fold_id, patient_id)
        elif i == 'Challenge':
            fold_id = lp.split('/')[-3]
            pp = result_path.format('Challenge', fold_id, patient_id)
        else:
            pp = result_path.format('Challenge', i, patient_id)
        
        # Predictions on the training dataset carry no dataset suffix;
        # cross-dataset predictions are suffixed with the training dataset name
        if i == train:
            pp = pp+'/{}_wmh_{}.nii.gz'.format(patient_id, model_name) # training
        else:
            pp = pp+'/{}_wmh_{}_{}.nii.gz'.format(patient_id, model_name, train) # training

        voxel_true = nib.load(lp); wmh_true = read_data(lp)
        voxel_pred = nib.load(pp); wmh_pred = read_data(pp)
        # Voxel volume = |det| of the 3x3 part of the affine. The original took the
        # product of the diagonal, which is only right for axis-aligned affines and
        # is wrong (possibly zero) for oblique or axis-permuted ones.
        # The /1000 presumably converts mm^3 to mL -- assumes the affine is in mm.
        voxel_volume_true = abs(np.linalg.det(voxel_true.affine[:3, :3]))/1000
        volume_true = np.count_nonzero(wmh_true)*voxel_volume_true
        voxel_volume_pred = abs(np.linalg.det(voxel_pred.affine[:3, :3]))/1000
        volume_pred = np.count_nonzero(wmh_pred)*voxel_volume_pred
        score = dice_coef(wmh_true, wmh_pred).numpy()
        temp_dsc.append(score)
        temp_patient.append(patient_id)
        temp.append([volume_true, volume_pred])
    df_by_model.append(temp)
    dsc.append(temp_dsc)
    patient.append(temp_patient)

# One frame per test dataset: columns are patients, rows are [true volume, predicted volume]
df_temp = [pd.DataFrame(i, index=j).transpose() for i,j in zip(df_by_model, patient)]
temp_str = '_Cross' if train not in test else ''
with pd.ExcelWriter(eval_path+'output{}_{}_{}.xlsx'.format(temp_str, train, model_name)) as writer:
    # Tables of successive datasets are stacked vertically, three rows apart
    num = 0
    for k,i in enumerate(test):
        df_temp[k].to_excel(writer, startrow=num, startcol=0, sheet_name='vol')
        pd.DataFrame([dsc[k]], columns=df_temp[k].columns).to_excel(
            writer, startrow=num, startcol=0, sheet_name='dsc')
        num = num + df_temp[k].shape[0] + 3