In [None]:
WIDTH, HEIGHT = 1024, 1024
# Label encoding produced by Dataset_Generator._preprocessor (saved mask value / 100):
#   0 = Background
#   1 = Red lesions: MA (microaneurysms) + HE (hemorrhages / red dots)  -> mask value 100
#   2 = Exudates: EX (hard) + SE (soft)                                 -> mask value 200
NUM_CLASSES = 3
BATCH_SIZE = 1


EPOCHS = 50

# On Colab
MODEL_DIR = "./Diabetic Retinopathy/"
BASE_DIR = './DR_data/'

# Gray level of each class in the saved mask images; list index == class id.
palette = [[0], [100], [200]]
# Class names in class-id order (must agree with the label encoding above).
category_types = ["Background", "RHM", "EX"]

# Construct Model

In [None]:
from keras.preprocessing import image
from keras.models import Model, load_model, Sequential
from keras import backend as K
from keras.utils import np_utils
from keras.preprocessing.image import img_to_array
from sklearn.preprocessing import LabelEncoder
from keras import metrics
from keras.losses import binary_crossentropy

import matplotlib.pyplot as plt
from tensorflow import keras

# from keras.layers.Layers import Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Reshape, Permute, Activation, Input
# from keras.keras.layers import DepthwiseConv2D, ZeroPadding2D, AveragePooling2D, Concatenate, Dropout, Conv2DTranspose
# from keras import layers
from keras.utils.np_utils import to_categorical
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, CSVLogger
# from keras.layers.merge import concatenate
from PIL import Image
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
from tensorflow.keras.applications import ResNet50

import matplotlib as mpl
import seaborn as sns

import os
import random
import numpy as np
import cv2
import matplotlib.pyplot as plt
import shutil
import pandas as pd
import time
from tqdm import *

In [None]:
def _aspp_branch(tensor, kernel_size, dilation_rate, index):
    '''One ASPP convolution branch: Conv2D(512, dilated) -> BatchNorm -> ReLU.

    Layers are named ASPP_conv2d_d{dilation_rate}, bn_{index} and relu_{index}.
    '''
    y = keras.layers.Conv2D(filters=512, kernel_size=kernel_size, dilation_rate=dilation_rate, padding='same',
                            kernel_initializer='he_normal', name=f'ASPP_conv2d_d{dilation_rate}',
                            use_bias=False)(tensor)
    y = keras.layers.BatchNormalization(name=f'bn_{index}')(y)
    return keras.layers.Activation('relu', name=f'relu_{index}')(y)


def ASPP(tensor):
    '''Atrous spatial pyramid pooling (DeepLabV3 encoder head).

    Runs five parallel 512-channel branches on ``tensor``: an image-level pooling branch,
    a 1x1 conv, and 3x3 convs with dilation 6, 12 and 18. The branches are concatenated
    and fused with a 1x1 conv + BatchNorm + ReLU.

    Args:
        tensor: 4-D feature map, assumed channels-last (batch, H, W, C) with static H and W.
    Returns:
        Tensor with the same spatial size as ``tensor`` and 512 channels.
    '''
    dims = K.int_shape(tensor)

    # Image-level features: global average -> 1x1 conv -> BN -> ReLU ...
    y_pool = keras.layers.AveragePooling2D(pool_size=(dims[1], dims[2]), name='average_pooling')(tensor)
    y_pool = keras.layers.Conv2D(filters=512, kernel_size=1, padding='same',
                                 kernel_initializer='he_normal', name='pool_1x1conv2d', use_bias=False)(y_pool)
    # NOTE(review): with BATCH_SIZE = 1 this BatchNorm sees one value per channel while training.
    y_pool = keras.layers.BatchNormalization(name='bn_1')(y_pool)
    y_pool = keras.layers.Activation('relu', name='relu_1')(y_pool)
    # ... then broadcast back to dims[1] x dims[2]. Bilinear upsampling of a 1x1 map copies it
    # to every position; a dilated Conv2DTranspose would only write the four corner pixels.
    y_pool = keras.layers.UpSampling2D(size=(dims[1], dims[2]), interpolation='bilinear',
                                       name='pool_upsample')(y_pool)

    branches = [
        y_pool,
        _aspp_branch(tensor, kernel_size=1, dilation_rate=1, index=2),
        _aspp_branch(tensor, kernel_size=3, dilation_rate=6, index=3),
        _aspp_branch(tensor, kernel_size=3, dilation_rate=12, index=4),
        _aspp_branch(tensor, kernel_size=3, dilation_rate=18, index=5),
    ]
    y = keras.layers.concatenate(branches, name='ASPP_concat')

    y = keras.layers.Conv2D(filters=512, kernel_size=1, dilation_rate=1, padding='same',
                            kernel_initializer='he_normal', name='ASPP_conv2d_final', use_bias=False)(y)
    y = keras.layers.BatchNormalization(name='bn_final')(y)
    y = keras.layers.Activation('relu', name='relu_final')(y)
    return y


def DeepLabV3Plus(img_height=1024, img_width=1024, nclasses=3):
    """Build a DeepLabV3+ segmentation model on an ImageNet-pretrained ResNet50 backbone.

    Encoder: ASPP on the ResNet50 ``conv4_block6_out`` features.
    Decoder: the ASPP output is bilinearly upsampled to the resolution of
    ``conv2_block3_out``, concatenated with a 96-channel 1x1 projection of those low-level
    features, refined by two 3x3 conv/BN/ReLU layers, upsampled to the input size and
    classified per pixel by a 1x1 conv + softmax.

    Args:
        img_height, img_width: input image size.
        nclasses: number of output classes (softmax channels).
    Returns:
        keras Model mapping (img_height, img_width, 3) images to per-pixel class
        probabilities with ``nclasses`` channels.
    """
    base_model = keras.applications.ResNet50(input_shape=(img_height, img_width, 3), weights='imagenet', include_top=False)

    image_features = base_model.get_layer('conv4_block6_out').output
    x_a = ASPP(image_features)

    x_b = base_model.get_layer('conv2_block3_out').output
    x_b = keras.layers.Conv2D(filters=96, kernel_size=1, padding='same',
                              kernel_initializer='he_normal', name='low_level_projection', use_bias=False)(x_b)
    x_b = keras.layers.BatchNormalization(name='bn_low_level_projection')(x_b)
    x_b = keras.layers.Activation('relu', name='low_level_activation')(x_b)

    # Upsample the ASPP output to the low-level feature resolution. Bilinear upsampling keeps
    # the two maps spatially aligned; a dilated Conv2DTranspose would only scatter copies of
    # x_a into the corners of the output.
    high_h, high_w = K.int_shape(x_a)[1:3]
    low_h, low_w = K.int_shape(x_b)[1:3]
    x_a = keras.layers.UpSampling2D(size=(low_h // high_h, low_w // high_w), interpolation='bilinear',
                                    name='decoder_upsample')(x_a)

    x = keras.layers.concatenate([x_a, x_b], name='decoder_concat')

    # Conv -> BN -> ReLU: ReLU is applied once, after BatchNorm.
    x = keras.layers.Conv2D(filters=512, kernel_size=3, padding='same',
                            kernel_initializer='he_normal', name='decoder_conv2d_1', use_bias=False)(x)
    x = keras.layers.BatchNormalization(name='bn_decoder_1')(x)
    x = keras.layers.Activation('relu', name='activation_decoder_1')(x)

    x = keras.layers.Conv2D(filters=512, kernel_size=3, padding='same',
                            kernel_initializer='he_normal', name='decoder_conv2d_2', use_bias=False)(x)
    x = keras.layers.BatchNormalization(name='bn_decoder_2')(x)
    x = keras.layers.Activation('relu', name='activation_decoder_2')(x)

    # Back to the input resolution.
    x = keras.layers.UpSampling2D(size=(img_height // low_h, img_width // low_w), interpolation='bilinear',
                                  name='output_upsample')(x)

    x = keras.layers.Conv2D(nclasses, (1, 1), name='output_layer')(x)
    # Output probabilities, not logits: the Dice / Jaccard / focal functions in this notebook
    # operate on probabilities (the focal loss clips and takes log(y_pred) directly).
    x = keras.layers.Activation('softmax')(x)

    model = keras.models.Model(inputs=base_model.input, outputs=x, name='DeepLabV3_Plus')
    return model
# DeepLabV3Plus(nclasses=3)

# Loss Function

In [None]:
def Dice(y_true, y_pred, solo = True, num_classes = NUM_CLASSES):
    """Soft Dice coefficient pooled over the whole batch and all classes.

        Dice = 2TP / (2TP + FP + FN) = 2|X∩Y| / (|X| + |Y|)

    Args:
        y_true: integer class-id map when ``solo`` is True (one-hot encoded here to
            ``num_classes`` channels), otherwise already one-hot.
        y_pred: predicted class probabilities.
    Returns:
        Scalar tensor; a 1e-4 smoothing term is added to numerator and denominator.
    """
    eps = 0.0001
    target = tf.one_hot(tf.cast(y_true, dtype=tf.int32),
                        depth=num_classes,
                        dtype=tf.float32) if solo else y_true

    overlap = tf.reduce_sum(target * y_pred)
    total = tf.reduce_sum(target + y_pred)
    return (2 * overlap + eps) / (total + eps)

In [None]:
def Jaccard(y_true, y_pred, solo = True, num_classes = NUM_CLASSES):
    """Soft Jaccard index (IoU) pooled over the whole batch and all classes.

        IoU = TP / (TP + FP + FN) = |X∩Y| / (|X| + |Y| - |X∩Y|)

    Sums are taken over every element, the same pooling as ``Dice``. Reducing only over
    the class axis would give a per-pixel ratio p / (2 - p) instead of an IoU.

    Args:
        y_true: integer class-id map when ``solo`` is True (one-hot encoded here to
            ``num_classes`` channels), otherwise already one-hot.
        y_pred: predicted class probabilities.
    Returns:
        Scalar tensor; a 1e-4 smoothing term is added to numerator and denominator.
    """
    smooth = 0.0001
    if solo:
        y_true = tf.one_hot(tf.cast(y_true, dtype=tf.int32),
                            depth=num_classes,
                            dtype=tf.float32,
                            )
    else:
        # Allow integer one-hot labels to be multiplied with float predictions.
        y_true = tf.cast(y_true, y_pred.dtype)
    intersection = tf.reduce_sum(y_true * y_pred)
    union = tf.reduce_sum(y_true + y_pred) - intersection
    return (intersection + smooth) / (union + smooth)

In [None]:
class CategoricalFocalLoss(tf.losses.Loss):
    """Categorical focal loss for per-pixel softmax outputs.

    Formula: loss = - y_true * alpha * ((1 - y_pred)^gamma) * log(y_pred)

    Args:
        solo: if True, ``y_true`` holds integer class ids and is one-hot encoded to
            ``num_classes`` channels; otherwise it is used as given.
        num_classes: one-hot depth used when ``solo`` is True.
        gamma: focusing parameter of the modulating factor (1 - p), default 2.0.
        alpha: weighting factor as in balanced cross entropy, default 0.25; the same
            value is applied to every class.

    Example inputs (one-hot targets, so use solo=False):
        y_true = [[0., 1.0, 0.], [0., 0., 1.], [0., 1., 0.]]
        y_pred = [[0.70, 0.15, 0.15], [0.1, 0.8, 0.1], [0.25, 0.65, 0.1]]
        gamma = 3.0, alpha = 0.25
    """
    def __init__(self, solo = True, num_classes = NUM_CLASSES, gamma = 2.0, alpha=0.25):
        super().__init__(reduction = 'auto', name = "CategoricalFocalLoss")
        self.solo = solo
        self._num_classes = num_classes
        self._gamma = gamma
        self._alpha = alpha
        # Clip bound that keeps log(y_pred) finite.
        self._epsilon = 1e-07

    def call(self, y_true, y_pred):
        """Return the focal loss averaged over every element (all pixels and classes)."""
        target = y_true
        if self.solo:
            target = tf.one_hot(tf.cast(y_true, dtype=tf.int32),
                                depth=self._num_classes,
                                dtype=tf.float32)

        probs = tf.clip_by_value(y_pred, self._epsilon, 1.0 - self._epsilon)
        modulating = tf.math.pow((1 - probs), self._gamma)
        log_probs = tf.math.log(probs)
        per_element = - target * self._alpha * modulating * log_probs
        # Alternative (not used) with class-dependent alpha:
        #   alpha = tf.where(tf.equal(y_true, 1.), alpha, (1.0 - self._alpha))
        #   pt = tf.where(tf.equal(y_true, 1.), y_pred, 1 - y_pred)
        #   loss = alpha * tf.pow(1.0 - pt, self._gamma) * tf.multiply(y_true, -tf.math.log(y_pred + eps))
        return tf.reduce_mean(per_element)

In [None]:
#--------------Dice loss------------
def dice_coef_fun(smooth=1, solo = True, num_classes = NUM_CLASSES):
    """Build a soft multi-class Dice coefficient usable as a Keras metric.

    Dice is computed per sample and per class over the two spatial axes (1, 2). It is then
    averaged over the batch and finally over the classes.

    Args:
        smooth: additive smoothing for numerator and denominator.
        solo: if True, ``y_true`` holds integer class ids and is one-hot encoded to
            ``num_classes`` channels; otherwise it must already be one-hot.
        num_classes: one-hot depth used when ``solo`` is True.
    Returns:
        Function ``dice_coef(y_true, y_pred)`` returning a scalar tensor.
    """
    def dice_coef(y_true, y_pred):
        if solo:
            y_true = tf.one_hot(tf.cast(y_true, dtype=tf.int32),
                                depth=num_classes,
                                dtype=tf.float32,
                                )

        intersection = K.sum(y_true * y_pred, axis=(1, 2))
        union = K.sum(y_true, axis=(1, 2)) + K.sum(y_pred, axis=(1, 2))
        sample_dices = (2. * intersection + smooth) / (union + smooth)

        dices = K.mean(sample_dices, axis=0)   # mean over the batch -> one value per class
        return K.mean(dices)                   # mean over classes
    return dice_coef


def dice_coef_loss_fun(smooth=0, solo = True, num_classes = NUM_CLASSES):
    """Build the Dice loss ``1 - dice_coef`` (see ``dice_coef_fun`` for the arguments).

    NOTE(review): with the default smooth=0, a class absent from both target and prediction
    in a sample gives 0/0; this only stays finite while softmax outputs are strictly positive.
    """
    dice_coef = dice_coef_fun(smooth=smooth, solo=solo, num_classes=num_classes)

    def dice_coef_loss(y_true, y_pred):
        return 1 - dice_coef(y_true, y_pred)
    return dice_coef_loss

# Dataset

In [None]:
from PIL import Image
class Dataset_Generator():
    """Build the Training/Test split on disk and stream augmented (image, mask) samples.

    ``_preprocessor`` reads the raw fundus images from ``{base_dir}image/`` and the
    per-lesion masks from ``{base_dir}mask/{MA,HE,EX,SE}/``. It crops, pads and resizes
    them and writes ``{base_dir}{Training,Test}/{images,masks}/``.

    Saved masks are lossless PNG uint8 images with:
        0   = background
        100 = MA | HE (red lesions)
        200 = EX | SE (exudates)
    The generators divide by 100, so labels become class ids 0/1/2.

    Paths are built by string concatenation, so ``base_dir`` must end with '/'.
    """

    # Training images held out per cross-validation fold: fold k (1-based) validates on
    # indices [FOLD_SIZE*(k-1), FOLD_SIZE*k) of the shuffled training list.
    FOLD_SIZE = 20
    # Per-dataset split: the first N processed images of each dataset go to Training.
    N_TRAIN_IDRID = 60
    N_TRAIN_DIARETDB = 40

    def __init__(self,
                 base_dir = BASE_DIR,
                 num_classes = NUM_CLASSES,
                 batch_size = BATCH_SIZE,
                 height = HEIGHT,
                 width = WIDTH,
                 epochs = EPOCHS,
                ):
        # Honour the constructor arguments rather than the module-level constants.
        self.base_dir = base_dir
        self.num_classes = float(num_classes)
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.epochs = epochs
        self.class_values = list(range(len(category_types)))
        # Empty until _preprocessor has produced Training/images; the generators then run it.
        self.images_list = self._list_training_images()

    def __del__(self):
        print("Dataset Generator is destructed")

    def _list_training_images(self):
        """Return the shuffled file names in Training/images/, or [] if that folder is missing."""
        train_dir = self.base_dir + "Training/images/"
        if not os.path.isdir(train_dir):
            return []
        names = os.listdir(train_dir)
        random.shuffle(names)
        return names

    def _build_mask(self, stem, is_idrid, mask_class_dir, mask_file_list):
        """Merge the four per-lesion masks of one image into a single uint8 label mask (0/100/200)."""
        if is_idrid:
            shape, thres = (2848, 4288), 1
        else:
            # DIARETDB masks are gray-level; levels noted as [63, 127, 189, 252].
            shape, thres = (1152, 1500), 127

        binary_masks = []
        for cls, available in zip(mask_class_dir, mask_file_list):
            mask_file_name = f"{stem}_{cls}.tif" if is_idrid else f"{stem}.png"
            if mask_file_name in available:
                lesion = cv2.imread(f"{self.base_dir}mask/{cls}/{mask_file_name}", 0)
                _, lesion = cv2.threshold(lesion, thres, 1, cv2.THRESH_BINARY)
            else:
                lesion = np.zeros(shape, dtype=np.uint8)
            binary_masks.append(lesion)

        red_lesions = cv2.bitwise_or(binary_masks[0], binary_masks[1]) * 100
        exudates = cv2.bitwise_or(binary_masks[2], binary_masks[3]) * 200
        # np.maximum instead of '+': where both groups overlap, uint8 addition would wrap
        # 100 + 200 = 300 to 44. Exudates (200) take precedence on overlap.
        return np.maximum(red_lesions, exudates)

    def _crop_and_pad_idrid(self, img, mask, index):
        """Crop an IDRiD image/mask pair to the fundus' horizontal extent, then pad with zeros."""
        # NOTE(review): cv2.imread returns BGR, so RGB2GRAY swaps the R/B weights. Kept because the
        # thresholds below (10 for images 3 and 10, else 30) were presumably tuned on this conversion.
        gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        thres = 10 if index in (3, 10) else 30
        _, binary_img = cv2.threshold(gray_img, thres, 255, cv2.THRESH_BINARY)

        contours, _ = cv2.findContours(binary_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        x_min = int(np.min(contours[-1], axis=0)[0][0])
        x_max = int(np.max(contours[-1], axis=0)[0][0])

        img = img[:, x_min:x_max + 1]
        mask = mask[:, x_min:x_max + 1]

        pad = 0 if (x_max - x_min) / 2848 >= 1.25 else 200
        img = cv2.copyMakeBorder(img, 500, 500, pad, pad, cv2.BORDER_CONSTANT, value=0)
        mask = cv2.copyMakeBorder(mask, 500, 500, pad, pad, cv2.BORDER_CONSTANT, value=0)
        return img, mask

    def _preprocessor(self):
        """Create the Training/Test folders from the raw image/ and mask/ folders.

        Images whose merged mask is empty are skipped and counted. Masks are always
        written as PNG, because JPEG compression would alter the 0/100/200 label values.
        """
        for split in ("Training", "Test"):
            for sub in ("images", "masks"):
                os.makedirs(f"{self.base_dir}{split}/{sub}", exist_ok=True)

        idrid_cnt = diaretdb_cnt = 0
        image_list = sorted(name.split(".")[0] for name in os.listdir(self.base_dir + "image/"))

        mask_class_dir = ["MA", "HE", "EX", "SE"]
        mask_file_list = [set(os.listdir(self.base_dir + f"mask/{cls}")) for cls in mask_class_dir]

        loss_cnt = 0
        for i, stem in enumerate(image_list):
            if "IDRiD" in stem:
                is_idrid = True
            elif "image" in stem:
                is_idrid = False
            else:
                print(f"{stem} is neither an IDRiD nor a DIARETDB image; skipped")
                continue

            mask = self._build_mask(stem, is_idrid, mask_class_dir, mask_file_list)
            if not mask.any():
                loss_cnt += 1
                print(f"{stem} has no mask")
                continue

            file_name = f"{stem}.jpg" if is_idrid else f"{stem}.png"
            img = cv2.imread(f"{self.base_dir}image/{file_name}")

            if is_idrid:
                img, mask = self._crop_and_pad_idrid(img, mask, i)
            else:
                img = cv2.copyMakeBorder(img, 174, 174, 0, 0, cv2.BORDER_CONSTANT, value=0)
                mask = cv2.copyMakeBorder(mask, 174, 174, 0, 0, cv2.BORDER_CONSTANT, value=0)

            # dsize is (width, height). Nearest-neighbour keeps mask values in {0, 100, 200}.
            img = cv2.resize(img, dsize=(self.width, self.height), interpolation=cv2.INTER_AREA)
            mask = cv2.resize(mask, dsize=(self.width, self.height), interpolation=cv2.INTER_NEAREST)

            if is_idrid:
                split = "Training" if idrid_cnt < self.N_TRAIN_IDRID else "Test"
                idrid_cnt += 1
            else:
                split = "Training" if diaretdb_cnt < self.N_TRAIN_DIARETDB else "Test"
                diaretdb_cnt += 1

            cv2.imwrite(f'{self.base_dir}{split}/images/{file_name}', img)
            cv2.imwrite(f'{self.base_dir}{split}/masks/{stem}.png', mask)
            print(f"{file_name} completed!")

        self.images_list = self._list_training_images()
        print(f"Preprocessing completed!. Number of no mask data : {loss_cnt}")

    def _read_mask(self, split, file_name):
        """Read the mask paired with image ``file_name`` as a single-channel array."""
        stem = os.path.splitext(file_name)[0]
        mask = cv2.imread(f"{self.base_dir}{split}/masks/{stem}.png", 0)
        if mask is None:
            # Data preprocessed by an older version stored masks under the image's own name.
            mask = cv2.imread(f"{self.base_dir}{split}/masks/{file_name}", 0)
        return mask

    def _read_pair(self, split, file_name):
        """Read (BGR image, mask) for ``file_name`` from the given split folder."""
        img = cv2.imread(f"{self.base_dir}{split}/images/{file_name}")
        return img, self._read_mask(split, file_name)

    def _Image_Reshape(self, image, mask):
        """Add a leading batch axis of size 1 and rescale: image -> [0, 1], mask -> class ids.

        The mask is rounded so that slightly-off label values map to the nearest class id
        rather than being truncated later by an int cast.
        """
        image = np.expand_dims(image, axis=0)
        mask = np.expand_dims(mask, axis=0)
        return (image / 255, np.rint(mask / 100))

    def _augment(self, img, mask):
        """Yield 8 variants of one pair: original, h-flip, and 90/180/270 degree rotations of each."""
        center = (self.width / 2, self.height / 2)
        size = (self.width, self.height)
        flip_img = cv2.flip(img, 1)
        flip_mask = cv2.flip(mask, 1)

        yield self._Image_Reshape(img, mask)
        yield self._Image_Reshape(flip_img, flip_mask)

        for degree in range(90, 360, 90):
            matrix = cv2.getRotationMatrix2D(center, degree, 1)
            # Masks are warped with nearest-neighbour so labels are never interpolated.
            yield self._Image_Reshape(cv2.warpAffine(img, matrix, size),
                                      cv2.warpAffine(mask, matrix, size, flags=cv2.INTER_NEAREST))
            yield self._Image_Reshape(cv2.warpAffine(flip_img, matrix, size),
                                      cv2.warpAffine(flip_mask, matrix, size, flags=cv2.INTER_NEAREST))

    def _fold_bounds(self, k):
        """Return [start, stop) indices of fold k's (1-based) held-out images."""
        return self.FOLD_SIZE * (k - 1), self.FOLD_SIZE * k

    def train_generator(self, k):
        """Yield augmented training samples for fold k, excluding that fold's held-out images.

        Runs the preprocessing first if no Training images exist yet. Loops over the data
        ``self.epochs`` times.
        """
        if not self.images_list:
            self._preprocessor()
        start, stop = self._fold_bounds(k)

        for _ in range(self.epochs):
            for i, file_name in enumerate(self.images_list):
                if start <= i < stop:
                    continue
                img, mask = self._read_pair("Training", file_name)
                yield from self._augment(img, mask)

    def valid_generator(self, k):
        """Yield augmented validation samples: only fold k's held-out training images."""
        if not self.images_list:
            self._preprocessor()
        start, stop = self._fold_bounds(k)

        for _ in range(self.epochs):
            for file_name in self.images_list[start:stop]:
                img, mask = self._read_pair("Training", file_name)
                yield from self._augment(img, mask)

    def test_generator(self):
        """Yield every Test image once (no augmentation), in sorted file-name order."""
        for file_name in sorted(os.listdir(self.base_dir + "Test/images/")):
            img, mask = self._read_pair("Test", file_name)
            yield self._Image_Reshape(img, mask)
            

# Training the Model and Evaluation

In [None]:
import datetime
class MODEL():
    """K-fold trainer and test-set evaluator for the DeepLabV3+ lesion-segmentation model.

    Args:
        model_dir: directory (ending with '/') where fold checkpoints are written and read.
        batch_size: stored for reference; the data generators already yield batches of one image.
        width, height: model/data input resolution.
        k: fold whose checkpoint ``Evaluation`` loads. ``Run_training`` replaces it with the
           fold that has the best final validation Dice + IoU.
    """

    N_FOLDS = 5
    FOLD_SIZE = 20          # held-out training images per fold (Dataset_Generator uses 20 as well)
    AUG_PER_IMAGE = 8       # samples yielded per image by the augmenting generators
    LEARNING_RATE = 0.00005

    def __init__(self,
                 model_dir = MODEL_DIR,
                 batch_size = BATCH_SIZE,
                 width = WIDTH,
                 height = HEIGHT,
                 k = 0,
                ):
        self.batch_size = batch_size
        # Alternatives tried: CategoricalFocalLoss(), SoftDiceLoss(NUM_CLASSES).
        self.loss_fn = dice_coef_loss_fun()
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.LEARNING_RATE)
        self.generator = Dataset_Generator(height=height, width=width)
        self.model_dir = model_dir
        self.optimal_k = k
        self.width = width
        self.height = height

        self.test_dataset = tf.data.Dataset.from_generator(
                        self.generator.test_generator,
                        (tf.float32, tf.float32),
                        (tf.TensorShape([1, height, width, 3]), tf.TensorShape([1, height, width])),
                        )

    def __del__(self):
        print("MODEL is destructed")

    def _checkpoint_path(self, k):
        """Weights file of fold k."""
        return os.path.join(self.model_dir, f"dlbv3plus-Net_{k}.h5")

    def _build_model(self, optimizer):
        """Create and compile a fresh DeepLabV3+ with the Dice loss and Dice/Jaccard metrics."""
        model = DeepLabV3Plus(img_height=self.height, img_width=self.width, nclasses=NUM_CLASSES)
        model.compile(loss=self.loss_fn,
                      optimizer=optimizer,
                      metrics=[Dice, Jaccard])
        return model

    def _plot_history(self, hist, k):
        """Plot training/validation loss, Dice and IoU curves of fold k."""
        epochs_range = range(len(hist['loss']))
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = [('loss', 'Loss', 'Loss', 'upper right'),
                  ('Dice', 'Dice', 'Dice Coefficient', 'lower right'),
                  ('Jaccard', 'IoU', 'IoU', 'lower right')]
        for ax, (key, label, title, loc) in zip(axes, panels):
            ax.plot(epochs_range, hist[key], label=f'Training {label}')
            ax.plot(epochs_range, hist[f'val_{key}'], label=f'Validation {label}')
            ax.legend(loc=loc)
            ax.set_title(f'{title} (fold {k})')
        plt.show()
        plt.close(fig)

    def _show_predictions(self, model, last_index):
        """Show image / ground truth / argmax prediction for test samples 0..last_index.

        A negative ``last_index`` shows every test sample.
        """
        for i, (img, mask) in enumerate(self.test_dataset):
            # model.predict runs in inference mode, so BatchNorm uses its moving statistics.
            prediction = model.predict(img, verbose=0)
            prediction = np.argmax(prediction[0], axis=-1)

            img = cv2.cvtColor(img[0].numpy(), cv2.COLOR_BGR2RGB)
            mask = mask[0].numpy()

            fig, axes = plt.subplots(1, 3, figsize=(20, 20))
            for ax, data, title in zip(axes, (img, mask, prediction),
                                       ('Image', 'Ground Truth Mask', 'Prediction')):
                ax.imshow(data)
                ax.set_title(title)
                ax.axis("off")
            plt.show()
            plt.close(fig)

            if i == last_index:
                break

    def Run_training(self, epochs= EPOCHS):
        """Train one model per fold, plot its curves and sample predictions, then pick the best fold."""
        print("Model Complie....")

        n_folds = self.N_FOLDS
        mean_Dice = mean_IoU = 0
        DiceIoU_list = []
        for k in range(1, n_folds + 1):
            # Fresh optimizer per fold: in recent TF versions an optimizer instance is bound to
            # the variables of the first model it trains and cannot be reused for another model.
            model = self._build_model(tf.keras.optimizers.Adam(learning_rate=self.LEARNING_RATE))
            callbacks_list = [tf.keras.callbacks.ModelCheckpoint(
                                        filepath=self._checkpoint_path(k),
                                        monitor="val_loss",
                                        mode="min",
                                        save_best_only=True,
                                        save_weights_only=True,
                                        verbose=1,
                                        ),
                              tf.keras.callbacks.EarlyStopping(
                                        monitor='val_loss',
                                        mode="min",
                                        min_delta=0.01,
                                        patience=5,
                                        )
                              ]
            print(f"{k}th fold Start Training....")

            start_time = datetime.datetime.now()
            # batch_size/shuffle are not passed: the generators already yield batches and
            # Keras does not accept/ignores these arguments for generator inputs.
            history = model.fit(self.generator.train_generator(k),
                                steps_per_epoch=(n_folds - 1) * self.FOLD_SIZE * self.AUG_PER_IMAGE,
                                validation_data=self.generator.valid_generator(k),
                                validation_steps=self.FOLD_SIZE * self.AUG_PER_IMAGE,
                                callbacks=callbacks_list,
                                epochs=epochs,
                                )
            minutes = (datetime.datetime.now() - start_time).total_seconds() / 60
            log_time = f"Total training time: {minutes}m"
            print(log_time)
            # Append so that every fold's timing is kept.
            with open('TrainTime.txt', 'a') as f:
                f.write(f"fold {k}: {log_time}\n")

            hist = history.history
            val_dice = hist["val_Dice"][-1]
            val_iou = hist["val_Jaccard"][-1]
            DiceIoU_list.append(val_dice + val_iou)
            mean_Dice += val_dice
            mean_IoU += val_iou

            self._plot_history(hist, k)
            self._show_predictions(model, last_index=1)

            del model
            tf.keras.backend.clear_session()   # release the fold's graph/memory before the next one

        print("Training End\n\n")
        self.optimal_k = DiceIoU_list.index(max(DiceIoU_list)) + 1
        print(f"K-Fold Cross Validation Result\nmDice : {mean_Dice / n_folds * 100:.3f}, "
              f"mIoU : {mean_IoU / n_folds * 100:.3f}, Optimal_K : {self.optimal_k}\n\n")

    def Evaluation(self, num_sample):
        """Evaluate fold ``self.optimal_k``'s checkpoint on the test set and display predictions.

        Args:
            num_sample: index of the last test sample to display; -1 displays all of them.
        Raises:
            ValueError: if no fold is selected (train first or pass k >= 1 to the constructor).
        """
        if self.optimal_k < 1:
            raise ValueError("No trained fold selected: call Run_training() first or pass k >= 1 to MODEL().")

        model = self._build_model(self.optimizer)
        model.load_weights(self._checkpoint_path(self.optimal_k))

        _, dice, iou = model.evaluate(self.test_dataset, verbose=1)
        print(f"\n\nDice : {dice*100:.2f}, IoU : {iou*100:.2f}\n\n")

        print("Display predictions")
        self._show_predictions(model, last_index=num_sample)

In [None]:
# Full experiment: 5-fold training, then evaluation of the best fold on the test split.
trainer = MODEL()
trainer.Run_training()
trainer.Evaluation(num_sample = -1)   # -1: display every test image
del trainer