<a href="https://colab.research.google.com/github/suhitaghosh10/unet/blob/master/unet_gn_wn_wdsc.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import sys
import numpy as np
import os

from keras import backend as K
K.set_image_data_format('channels_last')  # TF dimension ordering in this code

from weight_norm import AdamWithWeightnorm 
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, ReduceLROnPlateau, EarlyStopping, TensorBoard

from keras.models import Model
from keras.layers import concatenate, Input, Conv3D, MaxPooling3D, Conv3DTranspose, Lambda, Dropout
from keras.callbacks import CSVLogger
#from keras.utils.vis_utils import plot_model

smooth = 1.

class anisotopic_UNET:
    """Builder for a 3D U-Net with five per-class outputs (pz, cz, us, afs, bg).

    The first two analysis levels pool in-plane only (``(1, 2, 2)``); the third
    level pools in all three dimensions (``(2, 2, 2)``).  The class also bundles
    the Dice / Tversky metrics and losses used when compiling the model.

    The methods rely on names that must be in scope when they are *called*:
    the Keras backend ``K``, the Keras layers, ``AdamWithWeightnorm``,
    ``GroupNormalization`` and the module-level ``smooth`` constant.
    """

    @staticmethod
    def get_Tversky(alpha=.3, beta=.7, verb=0):
        """Return a ``(Tversky, Tversky_loss)`` pair of Keras-style functions.

        Declared as a static method so that it behaves the same whether it is
        called on the class or on an instance (without the decorator an
        instance call would bind the instance to ``alpha``).

        Args:
            alpha: weight of the "predicted but not in ground truth" term.
            beta: weight of the "in ground truth but not predicted" term.
            verb: unused; kept for interface compatibility.
        """
        def Tversky(y_true, y_pred):
            y_true_f = K.flatten(y_true)
            y_pred_f = K.flatten(y_pred)
            intersection = K.sum(y_true_f * y_pred_f)
            # predicted mass outside the ground truth, weighted by alpha
            false_pos = alpha * K.sum((1 - y_true_f) * y_pred_f)
            # ground-truth mass that was not predicted, weighted by beta
            false_neg = beta * K.sum(y_true_f * (1 - y_pred_f))
            # `smooth` is the module-level constant
            return (intersection + smooth) / (intersection + smooth + false_pos + false_neg)

        def Tversky_loss(y_true, y_pred):
            return -Tversky(y_true, y_pred)

        return Tversky, Tversky_loss

    # Tversky, Tversky_loss = get_Tversky(alpha = .3,beta= .7,verb=0)
    # Metrics = [dice_coef_loss ,Tversky

    def dice_coefficient(self, y_true, y_pred, axis=(-4, -3, -2), smooth=0.00001):
        """Return the mean smoothed Dice coefficient, summing over ``axis``.

        NOTE(review): the per-class outputs built in ``get_net`` are 4-D
        (the class axis is sliced away), so axis -4 is the batch axis there
        -- confirm this is intended.
        """
        intersection = K.sum(y_true * y_pred, axis=axis)
        denominator = K.sum(y_true, axis=axis) + K.sum(y_pred, axis=axis)
        return K.mean(2. * (intersection + smooth / 2) / (denominator + smooth))

    def W_dice_coefficient(self, y_true, y_pred, smooth=0.00001):
        """Return the smoothed Dice coefficient computed over all elements."""
        intersection = K.sum(y_true * y_pred)
        denominator = K.sum(y_true) + K.sum(y_pred)
        return K.mean(2. * (intersection + smooth / 2) / (denominator + smooth))

    def dice_coefficient_loss(self, y_true, y_pred):
        """Return the negated ``dice_coefficient`` (to be minimised)."""
        return -self.dice_coefficient(y_true, y_pred)

    def W_dice_coefficient_loss(self, y_true, y_pred):
        """Return the negated ``W_dice_coefficient`` (to be minimised)."""
        return -self.W_dice_coefficient(y_true, y_pred)

    # downsampling, analysis path
    def downLayer(self, inputLayer, filterSize, i, bn=False):
        """Build one analysis level: two 3x3x3 convolutions, then in-plane pooling.

        Args:
            inputLayer: tensor fed into the first convolution.
            filterSize: filters of the first convolution; the second uses twice as many.
            i: level index, used in the layer names ``conv<i>_1`` / ``conv<i>_2``.
            bn: when True, a ``GroupNormalization`` layer follows each convolution.

        Returns:
            ``(pool, conv)``: the pooled tensor and the un-pooled tensor that is
            later used as skip connection.
        """
        # NOTE(review): the file sets 'channels_last', so axis=1 is the first
        # spatial axis rather than the channel axis -- confirm this is intended.
        conv = Conv3D(filterSize, (3, 3, 3), activation='relu', padding='same',
                      name='conv' + str(i) + '_1')(inputLayer)
        if bn:
            conv = GroupNormalization(groups=8, axis=1)(conv)
        conv = Conv3D(filterSize * 2, (3, 3, 3), activation='relu', padding='same',
                      name='conv' + str(i) + '_2')(conv)
        if bn:
            conv = GroupNormalization(groups=8, axis=1)(conv)
        pool = MaxPooling3D(pool_size=(1, 2, 2))(conv)

        return pool, conv

    # upsampling, synthesis path
    def upLayer(self, inputLayer, concatLayer, filterSize, i, bn=False, do=False):
        """Build one synthesis level: in-plane up-convolution, skip concat, two convolutions.

        Args:
            inputLayer: tensor coming from the deeper level.
            concatLayer: skip-connection tensor concatenated after upsampling.
            filterSize: filters of the transposed convolution; both following
                convolutions use ``filterSize / 2``.
            i: level index, used in the layer names.
            bn: when True, a ``GroupNormalization`` layer follows each convolution.
            do: when True, a 0.5 dropout is inserted between the two convolutions.

        Returns:
            The output tensor of the second convolution (normalised if ``bn``).
        """
        up = Conv3DTranspose(filterSize, (2, 2, 2), strides=(1, 2, 2), activation='relu',
                             padding='same', name='up' + str(i))(inputLayer)
        up = concatenate([up, concatLayer])
        conv = Conv3D(int(filterSize / 2), (3, 3, 3), activation='relu', padding='same',
                      name='conv' + str(i) + '_1')(up)
        if bn:
            conv = GroupNormalization(groups=8, axis=1)(conv)
        if do:
            conv = Dropout(0.5, seed=3, name='Dropout_' + str(i))(conv)
        conv = Conv3D(int(filterSize / 2), (3, 3, 3), activation='relu', padding='same',
                      name='conv' + str(i) + '_2')(conv)
        if bn:
            conv = GroupNormalization(groups=8, axis=1)(conv)

        return conv

    # TODO:
    # correct position of batch normalization. instead of using it after relu
    # activation, place it between convolutions and activation function

    def get_net(self, nrInputChannels=1, learningRate=5e-5, bn=True, do=False):
        """Build and compile the network.

        Args:
            nrInputChannels: number of channels of the ``(32, 168, 168, C)`` input.
            learningRate: learning rate passed to ``AdamWithWeightnorm``.
            bn: enable the ``GroupNormalization`` layers.
            do: enable the dropout layers.

        Returns:
            A compiled Keras ``Model`` with five outputs named ``pz``, ``cz``,
            ``us``, ``afs`` and ``bg``; each uses the negated Dice coefficient
            as loss and the Dice coefficient as metric.
        """
        sfs = 16  # start filter size

        inputs = Input((32, 168, 168, nrInputChannels))

        conv1, conv1_b_m = self.downLayer(inputs, sfs, 1, bn)
        conv2, conv2_b_m = self.downLayer(conv1, sfs * 2, 2, bn)

        # Third analysis level is written out because it pools in all three
        # dimensions, unlike downLayer.
        conv3 = Conv3D(sfs * 4, (3, 3, 3), activation='relu', padding='same', name='conv3_1')(conv2)
        if bn:
            conv3 = GroupNormalization(groups=8, axis=1)(conv3)
        conv3 = Conv3D(sfs * 8, (3, 3, 3), activation='relu', padding='same', name='conv3_2')(conv3)
        if bn:
            conv3 = GroupNormalization(groups=8, axis=1)(conv3)
        pool3 = MaxPooling3D(pool_size=(2, 2, 2))(conv3)

        # Deepest level.
        conv4 = Conv3D(sfs * 16, (3, 3, 3), activation='relu', padding='same', name='conv4_1')(pool3)
        if bn:
            conv4 = GroupNormalization(groups=8, axis=1)(conv4)
        if do:
            conv4 = Dropout(0.5, seed=4, name='Dropout_4')(conv4)
        conv4 = Conv3D(sfs * 16, (3, 3, 3), activation='relu', padding='same', name='conv4_2')(conv4)
        if bn:
            conv4 = GroupNormalization(groups=8, axis=1)(conv4)

        # First synthesis level is written out because it upsamples in all
        # three dimensions, unlike upLayer.
        up1 = Conv3DTranspose(sfs * 16, (2, 2, 2), strides=(2, 2, 2), activation='relu', padding='same',
                              name='up5')(conv4)
        up1 = concatenate([up1, conv3])
        conv5 = Conv3D(sfs * 8, (3, 3, 3), activation='relu', padding='same', name='conv5_1')(up1)
        if bn:
            conv5 = GroupNormalization(groups=8, axis=1)(conv5)
        if do:
            conv5 = Dropout(0.5, seed=5, name='Dropout_5')(conv5)
        conv5 = Conv3D(sfs * 8, (3, 3, 3), activation='relu', padding='same', name='conv5_2')(conv5)
        if bn:
            conv5 = GroupNormalization(groups=8, axis=1)(conv5)

        conv6 = self.upLayer(conv5, conv2_b_m, sfs * 8, 6, bn, do)
        conv7 = self.upLayer(conv6, conv1_b_m, sfs * 4, 7, bn, do)

        conv_out = Conv3D(5, (1, 1, 1), activation='softmax', name='conv_final_softmax')(conv7)

        # One named output per softmax channel.  The lambdas are spelled out
        # (rather than built in a loop) so each one captures a literal index.
        pz = Lambda(lambda x: x[:, :, :, :, 0], name='pz')(conv_out)
        cz = Lambda(lambda x: x[:, :, :, :, 1], name='cz')(conv_out)
        us = Lambda(lambda x: x[:, :, :, :, 2], name='us')(conv_out)
        afs = Lambda(lambda x: x[:, :, :, :, 3], name='afs')(conv_out)
        bg = Lambda(lambda x: x[:, :, :, :, 4], name='bg')(conv_out)

        outputNames = ('pz', 'cz', 'us', 'afs', 'bg')
        model = Model(inputs=[inputs], outputs=[pz, cz, us, afs, bg])
        model.compile(optimizer=AdamWithWeightnorm(lr=learningRate, beta_1=0.9, beta_2=0.999),
                      loss={name: self.dice_coefficient_loss for name in outputNames},
                      metrics={name: self.dice_coefficient for name in outputNames})

        return model





    

In [0]:
def train_model(epochs, learningRate, imgs, gt_list, val_imgs, val_gt_list, foldNr):
    """Build, train and save the zonal U-Net for one cross-validation fold.

    Output files are named ``UNet_zones_Fold<foldNr>_LR_<learningRate>`` plus
    ``.csv`` (training log), ``.h5`` (best checkpoint by ``val_loss``) and
    ``_final.h5`` (model after the last epoch).

    Args:
        epochs: maximum number of training epochs.
        learningRate: optimizer learning rate; also part of the output file names.
        imgs: training images passed to ``model.fit``.
        gt_list: training targets, one entry per model output.
        val_imgs: validation images.
        val_gt_list: validation targets, one entry per model output.
        foldNr: fold number, used in the output file names only.

    Returns:
        The Keras ``History`` object returned by ``model.fit``.
    """
    name = 'UNet_zones_Fold' + str(foldNr) + '_LR_' + str(learningRate)

    # keras callbacks
    csv_logger = CSVLogger(name + '.csv', append=True, separator=';')
    model_checkpoint = ModelCheckpoint(name + '.h5', monitor='val_loss', save_best_only=True,
                                       verbose=1, mode='min')
    earlyStopImprovement = EarlyStopping(monitor='val_loss', min_delta=0.001, patience=100,
                                         verbose=1, mode='min')
    LRDecay = ReduceLROnPlateau(monitor='val_loss', factor=0.8, patience=25, verbose=1, mode='min',
                                min_lr=1e-8, epsilon=0.01)
    # Created but not added to the callback list below.
    tensorboard = TensorBoard(log_dir='./zonal2_tensorboard_logs/', write_graph=False, write_grads=False,
                              histogram_freq=0, batch_size=5, write_images=False)

    print('-' * 30)
    print('Creating and compiling model...')
    print('-' * 30)

    network = anisotopic_UNET()
    # Use the function argument, not a global: the file name above is derived
    # from `learningRate`, so the model must be trained with the same value.
    model = network.get_net(learningRate=learningRate, bn=True, do=True)
    # plot_model(model, to_file='model.png')

    print('-' * 30)
    print('Fitting model...')
    print('-' * 30)

    cb = [csv_logger, model_checkpoint, earlyStopImprovement, LRDecay]

    print('Callbacks: ', cb)

    # validation_data is passed as the (x_val, y_val) tuple Keras documents.
    history = model.fit(imgs, gt_list, batch_size=2, epochs=epochs,
                        verbose=1, validation_data=(val_imgs, val_gt_list), shuffle=True, callbacks=cb)
    model.save(name + '_final.h5')

    return history
    

In [0]:
from google.colab import drive
# Mount Google Drive into the Colab file system; the .npy data sets loaded
# further down are read from paths below 'drive/My Drive/'.  The call prompts
# for an authorization code interactively (see the cell output).
drive.mount('/content/drive/')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive/


In [0]:
from group_norm import GroupNormalization
if __name__ == '__main__':

    """
    TODO: 
    correct position of batch normalization. instead of using it after relu activation, place it between
    convolutions and activation function
    """

    # GPU selection was formerly done through CUDA_VISIBLE_DEVICES (sys.argv[3]).
    #os.environ["CUDA_VISIBLE_DEVICES"] = str(sys.argv[3])
    #os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    # Run configuration; formerly taken from sys.argv[1], [2] and [4].
    # These stay module-level names because other cells read them.
    epochs, LR, foldNr = 400, 5e-4, 1
    '''
    #val_imgs = np.load('final_test_array_imgs.npy')
    val_imgs = np.load('image_npy_orig4.npy')
    print(val_imgs.shape)
    #val_gt = np.load('/home/suhita/zonals/data/test/image_npy_gt.npy')
    val_gt = np.load('image_npy_gt4.npy')

    print(val_gt.shape)
    val_gt = val_gt.astype(np.uint8)
    val_gt_list = [val_gt[:, :, :, :, 0], val_gt[:, :, :, :, 1], val_gt[:, :, :, :, 2], val_gt[:, :, :, :, 3],
                   val_gt[:, :, :, :, 4]]
    predict(val_imgs, val_gt_list, 'model.h5')
'''
    # Validation fold: images, then ground truth split into one array per
    # class channel (the model has five separate outputs).
    val_imgs = np.load('drive/My Drive/prostrate/valArray_imgs_fold{}.npy'.format(foldNr))
    print(val_imgs.shape)
    val_gt = np.load('drive/My Drive/prostrate/valArray_GT_fold{}.npy'.format(foldNr))
    print(val_gt.shape)
    val_gt = val_gt.astype(np.uint8)
    val_gt_list = [val_gt[:, :, :, :, channel] for channel in range(5)]

    # Training fold.  The "_red" arrays hold the first 30 cases; only their
    # shapes are printed, training below uses the full arrays.
    train_imgs = np.load('drive/My Drive/prostrate/trainArray_imgs_fold{}.npy'.format(foldNr))
    train_imgs_red = train_imgs[:30, :, :, :, :]
    print(train_imgs_red.shape)
    train_gt = np.load('drive/My Drive/prostrate/trainArray_GT_fold{}.npy'.format(foldNr))
    train_gt = train_gt.astype(np.uint8)
    train_gt_red = train_gt[:30, :, :, :, :]
    print(train_gt_red.shape)

    train_gt_list = [train_gt[:, :, :, :, channel] for channel in range(5)]

    train_model(epochs=epochs, learningRate=LR, imgs=train_imgs, gt_list=train_gt_list,
                val_imgs=val_imgs, val_gt_list=val_gt_list, foldNr=foldNr)



(20, 32, 168, 168, 1)
(20, 32, 168, 168, 5)
(30, 32, 168, 168, 1)
(30, 32, 168, 168, 5)
------------------------------
Creating and compiling model...
------------------------------




------------------------------
Fitting model...
------------------------------
Callbacks:  [<keras.callbacks.CSVLogger object at 0x7fdacdcf8f28>, <keras.callbacks.ModelCheckpoint object at 0x7fdacdcf8940>, <keras.callbacks.EarlyStopping object at 0x7fdad1e876a0>, <keras.callbacks.ReduceLROnPlateau object at 0x7fdad1f277f0>]
Train on 58 samples, validate on 20 samples
Epoch 1/400

Epoch 00001: val_loss improved from inf to -1.28483, saving model to UNet_zones_Fold1_LR_0.0005.h5
Epoch 2/400

Epoch 00002: val_loss improved from -1.28483 to -1.60942, saving model to UNet_zones_Fold1_LR_0.0005.h5
Epoch 3/400

Epoch 00003: val_loss improved from -1.60942 to -1.72040, saving model to UNet_zones_Fold1_LR_0.0005.h5
Epoch 4/400

Epoch 00004: val_loss improved from -1.72040 to -1.82111, saving model to UNet_zones_Fold1_LR_0.0005.h5
Epoch 5/400

Epoch 00005: val_loss improved from -1.82111 to -1.93943, saving model to UNet_zones_Fold1_LR_0.0005.h5
Epoch 6/400

Epoch 00006: val_loss improved from -

In [0]:
def predict(val_imgs, val_gt_list, modelName):
    """Evaluate stored weights on a data set and save the raw predictions.

    Rebuilds the network with dropout disabled, loads the weights from
    ``modelName``, prints the result of ``model.evaluate`` and writes the
    output of ``model.predict`` to ``predicted_<modelName without suffix>.npy``
    in the current working directory.

    Args:
        val_imgs: input images.
        val_gt_list: ground truth, one entry per model output.
        modelName: path of the weights file, e.g.
            'model_LR_5e-5_BN_DO_0.5_fold1.h5'.  The last three characters are
            stripped to build the output name, i.e. a '.h5' suffix is assumed.
    """
    network = anisotopic_UNET()
    model = network.get_net(bn=True, do=False)
    model.load_weights(modelName)
    print(model.evaluate([val_imgs], val_gt_list, batch_size=1, verbose=1))
    out = model.predict([val_imgs], batch_size=2, verbose=1)

    np.save('predicted_' + modelName[:-3] + '.npy', out)

In [17]:
    # Evaluation run configuration; formerly sys.argv[1], [2] and [4].
    # Kept as module-level names because other cells read them.
    epochs, LR, foldNr = 5, 5e-5, 1

    # Final test set (earlier local paths:
    # /home/suhita/zonals/data/test_anneke/final_test_array_imgs.npy and
    # /home/suhita/zonals/data/test/image_npy_gt.npy).
    val_imgs = np.load('drive/My Drive/prostrate/final_test_array_imgs.npy')
    print(val_imgs.shape)
    val_gt = np.load('drive/My Drive/prostrate/final_test_array_GT.npy')
    print(val_gt.shape)
    val_gt = val_gt.astype(np.uint8)

    # One ground-truth array per class channel, matching the five model outputs.
    val_gt_list = [val_gt[:, :, :, :, channel] for channel in range(5)]
    predict(val_imgs, val_gt_list, modelName='UNet_zones_Fold1_LR_0.0005_final.h5')

(20, 32, 168, 168, 1)
(20, 32, 168, 168, 5)
[-4.125465369224548, -0.7731551110744477, -0.8570601373910904, -0.9271850526332855, -0.579103235900402, -0.9889618843793869, 0.7731551110744477, 0.8570601373910904, 0.9271850526332855, 0.579103235900402, 0.9889618843793869]


In [23]:
!pip install SimpleITK

import numpy as np
import sys
import SimpleITK as sitk
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import os
import csv
import math
import matplotlib.cm as cm
import utils


def getDice(prediction, groundTruth):
    """Return the Dice coefficient between two SimpleITK label images.

    The value is the one reported by ``sitk.LabelOverlapMeasuresImageFilter``
    after executing it on ``(prediction, groundTruth)``.
    """
    overlapMeasures = sitk.LabelOverlapMeasuresImageFilter()
    overlapMeasures.Execute(prediction, groundTruth)
    return overlapMeasures.GetDiceCoefficient()


def relativeAbsoluteVolumeDifference(prediction, groundTruth):
    """Return the relative absolute volume difference in percent.

    Computed as ``|pixels(prediction) / pixels(groundTruth) - 1| * 100``, where
    ``pixels`` counts every foreground pixel.  Previously only the pixels of
    connected-component label 1 were counted, which under-reports the volume
    whenever a segmentation consists of more than one component.

    Args:
        prediction: SimpleITK image of the predicted segmentation.
        groundTruth: SimpleITK image of the reference segmentation.

    Returns:
        ``0.0`` when the prediction is empty (kept from the original
        behaviour), ``inf`` when only the ground truth is empty, otherwise
        the percentage difference.
    """
    def _countForegroundPixels(image):
        # Label the connected components, then add up the size of every label.
        components = sitk.ConnectedComponentImageFilter().Execute(image)
        shapeStats = sitk.LabelShapeStatisticsImageFilter()
        shapeStats.Execute(components)
        return sum(shapeStats.GetNumberOfPixels(label) for label in shapeStats.GetLabels())

    pixelsPrediction = _countForegroundPixels(prediction)
    if pixelsPrediction == 0:
        return 0.0

    pixelsGT = _countForegroundPixels(groundTruth)
    if pixelsGT == 0:
        # The ratio is undefined without a reference volume; avoid the
        # ZeroDivisionError and report an unbounded difference instead.
        return float('inf')

    # compute and return relative absolute Volume Difference
    return abs((pixelsPrediction / pixelsGT) - 1) * 100


def getBoundaryDistances(prediction, groundTruth):
    """Return ``(hausdorff95, meanDistance)`` between two segmentation contours.

    The contour of each image is extracted, a distance map is computed from
    each contour (using the image spacing), and each map is sampled on the
    other contour.  The first return value is the larger of the two 95th
    percentiles, the second the mean over both sets of sampled distances.

    Side effect: writes ``contourP_masked.nrrd`` to the current working
    directory on every call.

    Args:
        prediction: SimpleITK binary image of the predicted segmentation.
        groundTruth: SimpleITK binary image of the reference segmentation.

    Returns:
        ``(0.0, 0.0)`` when the prediction has no contour, otherwise the
        tuple described above.
    """
    # get surfaces
    contourP = sitk.BinaryContour(prediction, fullyConnected=True)
    maxFilter = sitk.MinimumMaximumImageFilter()
    maxFilter.Execute(contourP)
    if maxFilter.GetMaximum() == 0:
        return (0.0, 0.0)
    contourGT = sitk.BinaryContour(groundTruth, fullyConnected=True)

    contourP_dist = sitk.DanielssonDistanceMap(contourP, inputIsBinary=True, squaredDistance=False,
                                               useImageSpacing=True)
    contourGT_dist = sitk.DanielssonDistanceMap(contourGT, inputIsBinary=True, squaredDistance=False,
                                                useImageSpacing=True)

    # distance-to-prediction-contour map, restricted to the GT contour
    contourP_masked = sitk.Mask(contourP_dist, contourGT)

    # distance-to-GT-contour map, restricted to the predicted contour
    contourGT_masked = sitk.Mask(contourGT_dist, contourP)
    # sitk.WriteImage(contourGT_masked, 'contourGT_masked.nrrd')
    sitk.WriteImage(contourP_masked, 'contourP_masked.nrrd')

    # Builtin `bool` replaces the `np.bool` alias, which NumPy has removed.
    contourP_arr_inv = np.invert(sitk.GetArrayFromImage(contourP).astype(bool))
    contourGT_arr_inv = np.invert(sitk.GetArrayFromImage(contourGT).astype(bool))

    # Keep only the distance values that lie on the respective other contour.
    dist_PredtoGT = sitk.GetArrayFromImage(contourP_masked)
    dist_PredtoGT = np.ma.masked_array(dist_PredtoGT, contourGT_arr_inv).compressed()

    dist_GTtoPred = sitk.GetArrayFromImage(contourGT_masked)
    dist_GTtoPred = np.ma.masked_array(dist_GTtoPred, contourP_arr_inv).compressed()

    hausdorff = max(np.percentile(dist_PredtoGT, 95), np.percentile(dist_GTtoPred, 95))
    distances = np.concatenate((dist_PredtoGT, dist_GTtoPred))
    mean = distances.mean()
    return (hausdorff, mean)


def evaluateFiles_zones(GT_array, pred_directory, csvName):
    """Write per-case Dice and mean boundary distance of four zones to a CSV file.

    For every case ``n`` the file ``<pred_directory>predicted_<n>.npy`` is
    loaded, its first four entries (PZ, CZ, US, AFS per the CSV header) are
    thresholded at 0.3 and compared against ``GT_array[n, :, :, :, zone]``.
    After the per-case rows an empty row and the average, median and
    standard deviation rows are appended; the averages are also printed.

    Side effect: ``predImg.nrrd`` and ``GT_label.nrrd`` are (over)written in
    the current working directory for every case and zone.

    Args:
        GT_array: ground-truth array indexed as ``[case, :, :, :, zone]``.
        pred_directory: directory prefix of the prediction files (used by
            plain string concatenation, so it must end with a separator).
        csvName: output file name without the ``.csv`` extension.
    """
    nrZones = 4

    with open(csvName + '.csv', 'w') as csvfile:
        csvwriter = csv.writer(csvfile, delimiter=';', lineterminator='\n',
                               quotechar='|', quoting=csv.QUOTE_MINIMAL)
        csvwriter.writerow(
            ['Case', 'PZ Dice', 'CZ Dice', 'US Dice', 'AFS Dice', 'PZ MeanDis', 'CZ MeanDis', 'US MeanDis',
             'AFS MeanDis'])

        nrImgs = GT_array.shape[0]
        dices = np.zeros((nrImgs, nrZones), dtype=np.float32)
        print(dices.shape)
        mad = np.zeros((nrImgs, nrZones), dtype=np.float32)

        for imgNumber in range(nrImgs):
            print('Case' + str(imgNumber))
            # Load the prediction once per case instead of once per zone.
            prediction = np.load(pred_directory + 'predicted_' + str(imgNumber) + '.npy')
            caseDices = []
            caseMads = []

            for zoneIndex in range(nrZones):
                pred_arr = thresholdArray(prediction[zoneIndex], 0.3)
                pred_img = sitk.GetImageFromArray(pred_arr)

                GT_label = sitk.GetImageFromArray(GT_array[imgNumber, :, :, :, zoneIndex])
                pred_img = utils.castImage(pred_img, sitk.sitkUInt8)
                GT_label = utils.castImage(GT_label, sitk.sitkUInt8)

                sitk.WriteImage(pred_img, 'predImg.nrrd')
                sitk.WriteImage(GT_label, 'GT_label.nrrd')

                dice = getDice(pred_img, GT_label)
                caseDices.append(dice)
                print(dice)
                # avd = relativeAbsoluteVolumeDifference(pred_img, GT_label)
                hausdorff, avgDist = getBoundaryDistances(pred_img, GT_label)
                caseMads.append(avgDist)
                dices[imgNumber, zoneIndex] = dice
                mad[imgNumber, zoneIndex] = avgDist

            csvwriter.writerow(['Case' + str(imgNumber)] + caseDices + caseMads)

        # Blank separator row, then one summary row per statistic.
        csvwriter.writerow([])
        for label, statistic in (('Average', np.average), ('Median', np.median), ('STD', np.std)):
            csvwriter.writerow([label]
                               + [statistic(dices[:, zone]) for zone in range(nrZones)]
                               + [statistic(mad[:, zone]) for zone in range(nrZones)])

        print('Dices')
        for zone in range(nrZones):
            print(np.average(dices[:, zone]))

        print('Mean Dist')
        for zone in range(nrZones):
            print(np.average(mad[:, zone]))


def evaluateFiles(GT_directory, pred_directory, csvName):
    """Evaluate every case found in ``GT_directory`` and write the metrics to a CSV file.

    For each directory entry ``case`` the prediction
    ``<pred_directory><case>/predicted_CC.nrrd`` is resampled onto the ground
    truth ``<GT_directory><case>/Segmentation-label_whole.nrrd``; Dice, relative
    absolute volume difference and the two boundary distances are written as
    one row of ``<csvName>.csv``.
    """
    columnNames = ['Case', 'Dice', 'Average Volume Difference', '95-Hausdorff', 'Avg Hausdorff']

    with open(csvName + '.csv', 'w') as csvfile:
        csvwriter = csv.writer(csvfile, delimiter=';', lineterminator='\n',
                               quotechar='|', quoting=csv.QUOTE_MINIMAL)
        csvwriter.writerow(columnNames)

        for case in os.listdir(GT_directory):
            predictionPath = pred_directory + case + '/predicted_CC.nrrd'
            referencePath = GT_directory + case + '/Segmentation-label_whole.nrrd'

            pred_img = sitk.ReadImage(predictionPath)
            print(case[:-1])
            GT_label = utils.castImage(sitk.ReadImage(referencePath), sitk.sitkUInt8)

            # Bring the prediction onto the ground-truth grid before comparing.
            pred_img = utils.castImage(
                utils.resampleToReference(pred_img, GT_label, sitk.sitkNearestNeighbor, 0),
                sitk.sitkUInt8)

            dice = getDice(pred_img, GT_label)
            print(dice)
            avd = relativeAbsoluteVolumeDifference(pred_img, GT_label)
            hausdorff, avgDist = getBoundaryDistances(pred_img, GT_label)

            csvwriter.writerow([case, dice, avd, hausdorff, avgDist])


def visualizeResults(directory, img_tra, img_cor, img_sag, pred_img, GT_img, i):
    """Save a 3x3 overview figure of predicted and ground-truth contours.

    Each row shows one image volume (``img_tra``, ``img_cor``, ``img_sag``)
    sliced along axis 0, 1 and 2 respectively; the columns are the slice
    positions ``z/3``, ``z/2`` and ``2z/3``, where ``z`` is the first array
    dimension of ``img_tra``.  The predicted contour (red) and the
    ground-truth contour (yellow) are overlaid on every panel.  The figure is
    written to ``<directory>visualResults/img_<i>.png``.

    Fixes the last panel, which used to overlay the ground-truth contour of
    slice ``z/3`` on the image and prediction of slice ``2z/3``.

    Args:
        directory: output directory prefix (used by string concatenation).
        img_tra, img_cor, img_sag: SimpleITK image volumes to display.
        pred_img: SimpleITK binary image of the prediction.
        GT_img: SimpleITK binary image of the ground truth.
        i: case index used in the output file name.
    """
    outputDir = directory + 'visualResults'
    if not os.path.exists(outputDir):
        os.makedirs(outputDir)

    def _flippedArray(image):
        # numpy view of the image, reversed along the first array axis
        return np.flip(sitk.GetArrayFromImage(image), axis=0)

    n_bins = 5  # Discretizes the interpolation into bins
    colormap = LinearSegmentedColormap.from_list('my_list', [(1, 0, 0), (0, 0, 0)], N=n_bins)
    colormapGT = LinearSegmentedColormap.from_list('GT', [(1, 1, 0), (0, 0, 0)], N=n_bins)

    contourArray = _flippedArray(sitk.BinaryContour(pred_img, True))
    contourArrayGT = _flippedArray(sitk.BinaryContour(GT_img, True))
    imgArray_tra = _flippedArray(img_tra)
    imgArray_cor = _flippedArray(img_cor)
    imgArray_sag = _flippedArray(img_sag)

    # Hide everything except the contour pixels so the image stays visible.
    masked_data = np.ma.masked_where(contourArray == 0, contourArray)
    masked_dataGT = np.ma.masked_where(contourArrayGT == 0, contourArrayGT)
    z = imgArray_tra.shape[0]

    plt.subplots(3, 3, figsize=(30, 30))

    slicePositions = [int(z / 3), int(z / 2), int(2 * (z / 3))]
    # (volume, array axis that is sliced) for each figure row
    views = [(imgArray_tra, 0), (imgArray_cor, 1), (imgArray_sag, 2)]

    panel = 0
    for volume, axis in views:
        for position in slicePositions:
            panel += 1
            # One shared index for image, prediction and ground truth keeps
            # the three layers on the same slice.
            index = [slice(None)] * 3
            index[axis] = position
            index = tuple(index)

            plt.subplot(3, 3, panel)
            plt.imshow(volume[index], 'gray')
            plt.imshow(masked_data[index], cmap=colormap, interpolation='none')
            plt.imshow(masked_dataGT[index], cmap=colormapGT, interpolation='none')
            plt.axis('off')

    plt.savefig(directory + 'visualResults/img_' + str(i) + '.png')


def visualizeResultsSmall(directory, img_tra, img_cor, img_sag, pred_img, GT_img, i):
    """Save a 1x3 figure of predicted and ground-truth contours.

    The three panels show ``img_tra`` at index ``z/2`` of axis 0, ``img_sag``
    at index ``z/2`` of axis 2 and ``img_cor`` at index ``2z/3`` of axis 1,
    where ``z`` is the first array dimension of ``img_tra``.  The predicted
    contour (red) and the ground-truth contour (yellow) are overlaid.  The
    figure is written to ``<directory>visualResults/img_<i>.png``.

    A plain 30x30 inch figure is created instead of a 3x3 axes grid: the 1x3
    panels used to rely on matplotlib silently deleting the overlapping 3x3
    axes, which newer matplotlib versions no longer do.

    Args:
        directory: output directory prefix (used by string concatenation).
        img_tra, img_cor, img_sag: SimpleITK image volumes to display.
        pred_img: SimpleITK binary image of the prediction.
        GT_img: SimpleITK binary image of the ground truth.
        i: case index used in the output file name.
    """
    outputDir = directory + 'visualResults'
    if not os.path.exists(outputDir):
        os.makedirs(outputDir)

    def _flippedArray(image):
        # numpy view of the image, reversed along the first array axis
        return np.flip(sitk.GetArrayFromImage(image), axis=0)

    n_bins = 5  # Discretizes the interpolation into bins
    colormap = LinearSegmentedColormap.from_list('my_list', [(1, 0, 0), (0, 0, 0)], N=n_bins)
    colormapGT = LinearSegmentedColormap.from_list('GT', [(1, 1, 0), (0, 0, 0)], N=n_bins)

    contourArray = _flippedArray(sitk.BinaryContour(pred_img, True))
    contourArrayGT = _flippedArray(sitk.BinaryContour(GT_img, True))
    imgArray_tra = _flippedArray(img_tra)
    imgArray_cor = _flippedArray(img_cor)
    imgArray_sag = _flippedArray(img_sag)

    # Hide everything except the contour pixels so the image stays visible.
    masked_data = np.ma.masked_where(contourArray == 0, contourArray)
    masked_dataGT = np.ma.masked_where(contourArrayGT == 0, contourArrayGT)
    z = imgArray_tra.shape[0]

    plt.figure(figsize=(30, 30))

    # (volume, array axis that is sliced, slice position) per panel.
    # NOTE(review): the third panel uses 2z/3 while the others use z/2 --
    # kept as in the original; confirm this is intended.
    panels = [
        (imgArray_tra, 0, int(z / 2)),
        (imgArray_sag, 2, int(z / 2)),
        (imgArray_cor, 1, int(2 * (z / 3))),
    ]

    for panelNumber, (volume, axis, position) in enumerate(panels, start=1):
        index = [slice(None)] * 3
        index[axis] = position
        index = tuple(index)

        plt.subplot(1, 3, panelNumber)
        plt.imshow(volume[index], 'gray')
        plt.imshow(masked_data[index], cmap=colormap, interpolation='none')
        plt.imshow(masked_dataGT[index], cmap=colormapGT, interpolation='none')
        plt.axis('off')

    plt.savefig(directory + 'visualResults/img_' + str(i) + '.png')


def regionBasedEvaluation(directory, csvName):
    """Evaluate 15 test cases separately for the apex, mid and base thirds.

    For each case ``i`` in ``0..14`` the prediction
    ``<directory>/predicted_test_<i>.npy`` and the ground truth
    ``<directory>imgs_test_GT_<i>.npy`` are loaded.  The prediction is
    thresholded at 0.5 and passed through
    ``utils.getLargestConnectedComponents``.  The z-range covered by both
    bounding boxes is split into three equally sized 168x168 regions, and
    Dice, relative absolute volume difference and the boundary distances of
    every region are written as one row of ``<csvName>.csv``.

    Fixes the z-extent computation, which added the *prediction's* z-size to
    the ground truth's z-start instead of the ground truth's own z-size.

    Side effect: the six region images are (over)written as
    ``apex_P.nrrd`` ... ``base_GT.nrrd`` in the current working directory.

    Assumes ``utils.getBoundingBox`` returns ``(x, y, z, sizeX, sizeY, sizeZ)``
    -- TODO confirm against ``utils``.
    """
    regionNames = ('apex', 'mid', 'base')

    with open(csvName + '.csv', 'w') as csvfile:
        csvwriter = csv.writer(csvfile, delimiter=';', lineterminator='\n',
                               quotechar='|', quoting=csv.QUOTE_MINIMAL)
        csvwriter.writerow(
            ['Case', 'Dice Apex', 'Dice Mid', 'Dice Base', 'AVD Apex', 'AVD Mid', 'AVD Base', '95-Hausdorff Apex',
             '95-Hausdorff Mid', '95-Hausdorff Base', 'Avg Dis Apex', 'Avg Dis Mid', 'Avg Dis Base'])

        for i in range(0, 15):
            predicted_array = np.load(directory + '/predicted_test_' + str(i) + '.npy')
            GT_array = np.load(directory + 'imgs_test_GT_' + str(i) + '.npy')
            GT_array = GT_array.astype(np.uint8)
            pred_img = sitk.GetImageFromArray(predicted_array[0, :, :, :, 0])

            pred_img = utils.binaryThresholdImage(pred_img, 0.5)
            GT_label = sitk.GetImageFromArray(GT_array[:, :, :, 0])
            pred_img = utils.getLargestConnectedComponents(pred_img)
            pred_img.SetSpacing([0.5, 0.5, 0.5])
            GT_label.SetSpacing([0.5, 0.5, 0.5])

            # get boundingboxes
            bb_P = utils.getBoundingBox(pred_img)
            bb_GT = utils.getBoundingBox(GT_label)
            startZ = max(0, min(bb_P[2], bb_GT[2]) - 1)

            # Upper z-bound of each box is its own start plus its own size.
            sizeZ = max(bb_P[2] + bb_P[5], bb_GT[2] + bb_GT[5]) - startZ + 1
            thirdZ = math.floor(sizeZ / 3)

            # Cut prediction and ground truth into the three z-regions.
            regions = []
            for regionIndex in range(len(regionNames)):
                size = [168, 168, thirdZ]
                origin = [0, 0, startZ + regionIndex * thirdZ]
                regions.append((sitk.RegionOfInterest(pred_img, size, origin),
                                sitk.RegionOfInterest(GT_label, size, origin)))

            for regionName, (region_P, region_GT) in zip(regionNames, regions):
                sitk.WriteImage(region_P, regionName + '_P.nrrd')
                sitk.WriteImage(region_GT, regionName + '_GT.nrrd')

            regionDices = []
            regionAvds = []
            regionHausdorffs = []
            regionAvgDists = []
            for region_P, region_GT in regions:
                regionDices.append(getDice(region_P, region_GT))
                regionAvds.append(relativeAbsoluteVolumeDifference(region_P, region_GT))
                hausdorff, avgDist = getBoundaryDistances(region_P, region_GT)
                regionHausdorffs.append(hausdorff)
                regionAvgDists.append(avgDist)

            csvwriter.writerow(['Case' + str(i)] + regionDices + regionAvds
                               + regionHausdorffs + regionAvgDists)

            print(i)


def thresholdArray(array, threshold):
    """Binarise ``array`` against ``threshold`` and return it as int16.

    Elements ``>= threshold`` become 1 and every other element becomes 0.

    The comparison is evaluated once, before anything is written.  The
    earlier two-pass form compared a second time after zeroing the low
    values, so for ``threshold <= 0`` the freshly written zeros passed the
    ``>=`` test and the whole result collapsed to ones.

    Args:
        array: NumPy array to binarise.  It is overwritten in place with the
            0/1 values (kept for backward compatibility), so pass a copy if
            the original values are still needed afterwards.
        threshold: Scalar cut-off; values equal to it count as foreground.

    Returns:
        An ``np.int16`` array of zeros and ones with the shape of ``array``.
    """
    # Freeze the decision before mutating, so writes cannot influence it.
    foreground = array >= threshold
    array[~foreground] = 0
    array[foreground] = 1

    return np.asarray(array, np.int16)


def removeSegmentationsInImagePaddedRegion(array_test, array_pred):
    """Blank the prediction wherever the reference volume has an empty slice.

    For each index ``i`` along the first axis of ``array_test`` whose
    channel-0 plane ``array_test[i, :, :, 0]`` holds no nonzero value,
    ``array_pred[:, i, :, :]`` is set to 0.

    ``array_pred`` is modified in place; the function returns nothing.
    """
    for slice_idx, volume_slice in enumerate(array_test):
        # Slices with any content in channel 0 are left untouched.
        if np.count_nonzero(volume_slice[:, :, 0]):
            continue
        array_pred[:, slice_idx, :, :] = 0


def getConnectedComponents(predictionImage):
    """Split a label image into its kept component and the remainder.

    The input is cast to 8-bit signed integers, passed through
    ``utils.getLargestConnectedComponents`` (presumably keeps only the
    largest connected region - confirm against ``utils``), and the result
    is subtracted from the cast input.

    Returns:
        A ``(kept, remainder)`` tuple of SimpleITK images, both of pixel
        type ``sitk.sitkInt8``: the component image and
        ``input - component``.
    """
    pixel_type = sitk.sitkInt8

    mask = utils.castImage(predictionImage, pixel_type)
    kept = utils.castImage(utils.getLargestConnectedComponents(mask), pixel_type)

    # Whatever is in the mask but not in the kept component.
    remainder = sitk.Subtract(mask, kept)

    return kept, remainder


def removeIslands(predictedArray):
    """Reassign stray segmentation islands to the nearest zone.

    Each of the first five channels of ``predictedArray`` (order used
    throughout this code: pz, cz, us, afs, bg) is thresholded at 0.5 and
    split by ``getConnectedComponents`` into a kept component and the
    leftover "island" voxels.  Every voxel that is an island in at least one
    channel is then given to the zone whose kept component has the largest
    signed Maurer distance at that voxel (positive inside the component,
    measured in voxels), i.e. the zone it lies in or is closest to.  Ties go
    to the lowest channel index.

    Changes relative to the earlier implementation (results are unchanged
    for the shapes it supported):
      * the output shape follows the input instead of being fixed to
        ``[5, 32, 168, 168]``;
      * the per-voxel Python loop with ``GetPixel`` is replaced by one NumPy
        ``argmax`` over the stacked distance maps;
      * the five copy-pasted per-zone pipelines are a single loop.

    Args:
        predictedArray: Array indexed as ``[channel, z, y, x]`` with at
            least five channels.  Because ``thresholdArray`` writes in
            place, the first five channels are overwritten with 0/1 values.

    Returns:
        A float ``np.ndarray`` of shape ``(5, z, y, x)`` holding the kept
        components plus the reassigned island voxels (value 1).
    """
    pred = predictedArray
    print(pred.shape)

    # Channel order is fixed by the callers: pz, cz, us, afs, bg.
    num_zones = 5

    kept_components = []
    islands = None
    for zone in range(num_zones):
        binary = thresholdArray(pred[zone, :, :, :], 0.5)
        zone_cc, zone_islands = getConnectedComponents(sitk.GetImageFromArray(binary))
        kept_components.append(zone_cc)
        # Accumulate every zone's leftovers into one "needs reassignment" image.
        islands = zone_islands if islands is None else sitk.Add(islands, zone_islands)

    # GetArrayFromImage returns arrays indexed [z, y, x].
    component_arrays = [sitk.GetArrayFromImage(cc) for cc in kept_components]

    finalPrediction = np.zeros((num_zones,) + component_arrays[0].shape)
    for zone, component in enumerate(component_arrays):
        finalPrediction[zone] = component

    # Signed distance to each kept component, stacked as [zone, z, y, x].
    distances = np.stack([
        sitk.GetArrayFromImage(
            sitk.SignedMaurerDistanceMap(cc, insideIsPositive=True, squaredDistance=False,
                                         useImageSpacing=False))
        for cc in kept_components
    ])

    # argmax returns the first maximum, matching list.index(max(...)).
    nearest_zone = np.argmax(distances, axis=0)

    island_mask = sitk.GetArrayFromImage(islands) > 0
    zs, ys, xs = np.nonzero(island_mask)
    finalPrediction[nearest_zone[zs, ys, xs], zs, ys, xs] = 1

    return finalPrediction


def postprocesAndEvaluateFiles(name, GT_array, csvName):
    """Clean up every predicted case, save it, then run the zone evaluation.

    Loads ``name + '.npy'``, prints the ground-truth and prediction shapes,
    and for each index along axis 1 of the prediction runs ``removeIslands``
    on that case and saves the result as ``predicted_<index>.npy`` in the
    output directory.  Finally calls ``evaluateFiles_zones`` on that
    directory.

    Args:
        name: Path of the prediction file without the ``.npy`` suffix.
        GT_array: Ground-truth array, forwarded to ``evaluateFiles_zones``.
        csvName: CSV file name, forwarded to ``evaluateFiles_zones``.
    """
    stacked_predictions = np.load(name + '.npy')
    print(GT_array.shape)
    print(stacked_predictions.shape)

    # NOTE(review): ``name`` carries no extension here, so ``[:-3]`` drops the
    # last three characters of the base name itself - confirm this is intended.
    target_dir = name[:-3] + '/'
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    # Cases are indexed along axis 1 of the loaded prediction.
    num_cases = stacked_predictions.shape[1]
    for case_idx in range(num_cases):
        print(case_idx)
        cleaned = removeIslands(stacked_predictions[:, case_idx, :, :, :])
        np.save(target_dir + 'predicted_' + str(case_idx), cleaned)

    evaluateFiles_zones(GT_array, pred_directory=target_dir, csvName=csvName)


# Script entry point: post-process one saved prediction file and evaluate it
# against the ground-truth array.
if __name__ == '__main__':
    # Prediction file name without the '.npy' suffix (the suffix is appended
    # inside postprocesAndEvaluateFiles).
    name = 'predicted_UNet_zones_Fold1_LR_0.0005_final'
    # Location of the saved ground-truth array.
    GT_array_name = 'drive/My Drive/prostrate/final_test_array_GT.npy'
    # Name of the CSV file handed to the evaluation step.
    csvName = 'evaluation_sad.csv'
    GT_array = np.load(GT_array_name)

    # weights epochs LR gpu_id dist orient prediction LRDecay earlyStop
    # NOTE(review): the list above does not correspond to the three arguments
    # passed below; it looks like a leftover from another script - confirm.
    postprocesAndEvaluateFiles(name, GT_array, csvName)


(20, 32, 168, 168, 5)
(5, 20, 32, 168, 168)
0
(5, 32, 168, 168)
1
(5, 32, 168, 168)
2
(5, 32, 168, 168)
3
(5, 32, 168, 168)
4
(5, 32, 168, 168)
5
(5, 32, 168, 168)
6
(5, 32, 168, 168)
7
(5, 32, 168, 168)
8
(5, 32, 168, 168)
9
(5, 32, 168, 168)
10
(5, 32, 168, 168)
11
(5, 32, 168, 168)
12
(5, 32, 168, 168)
13
(5, 32, 168, 168)
14
(5, 32, 168, 168)
15
(5, 32, 168, 168)
16
(5, 32, 168, 168)
17
(5, 32, 168, 168)
18
(5, 32, 168, 168)
19
(5, 32, 168, 168)
(20, 4)
Case0
0.8066826996878678
0.8188700831118141
0.6883160361566789
0.5596961784951341
Case1
0.6894690573058287
0.8872335563091025
0.6938394523957685
0.38319783197831975
Case2
0.7647625398236646
0.7454995054401583
0.5904821540388228
0.3126050420168067
Case3
0.7662292125826486
0.883408119772535
0.6986016066646831
0.4259534422981674
Case4
0.7733268367382452
0.9172325132939871
0.7919923736892278
0.35330836454431963
Case5
0.8544139291641876
0.8665952210843852
0.5999129299085764
0.5703787450537027
Case6
0.7795302139137602
0.7852642632131607
0