In [None]:
import os
# Expose only GPU index 0 to CUDA. Set here, before cupy and tensorflow are
# imported further down, because the variable is read when CUDA initialises.
os.environ["CUDA_VISIBLE_DEVICES"]="0" # first gpu

In [None]:
# True: wipe/recreate the output directories and run model.fit.
# False: skip the training-only cells and only evaluate existing checkpoints.
training = True
# Number of epochs passed to model.fit.
num_epochs = 2

In [None]:
import warnings
import time
import datetime
import numpy as np
import cupy as cp
import random
import shutil
from tabulate import tabulate
from itables import init_notebook_mode
from itables import show
import logging
logging.getLogger('tensorflow').disabled = True
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
warnings.filterwarnings('ignore')


from hist_metrics import *

# Output directories, relative to the working directory.
# path_1: per-epoch weight checkpoints and the final 'last_weights.hdf5'.
path_1 = 'results/'
# path_2: CSVLogger output (logs.csv).
path_2 = 'logs/'
# Evaluation outputs; wiped and recreated together with the two above.
eval_results_path = 'eval_results/'

In [None]:
def create_and_remove_dirs():
    """Delete and recreate the training output directories.

    Each of ``path_1`` (checkpoints), ``path_2`` (logs) and ``eval_results_path``
    is removed if it exists and then recreated empty. A new training run
    therefore never mixes its checkpoints/logs with those of a previous run.
    Every directory is handled independently, so all three are guaranteed to
    exist afterwards.
    """
    for dir_path in (path_1, path_2, eval_results_path):
        if os.path.exists(dir_path):
            shutil.rmtree(dir_path)
        # exist_ok: do not fail if the directory reappeared in the meantime.
        os.makedirs(dir_path, exist_ok=True)


if (training):
    create_and_remove_dirs()

In [None]:
import tensorflow as tf

# Turn on memory growth for every visible GPU, so TensorFlow allocates GPU
# memory as needed instead of reserving it all at start-up (per TF docs).
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
  try:
    #tf.config.experimental.set_visible_devices(gpus[g_n], 'GPU')
    # Currently, memory growth needs to be the same across GPUs
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    # Listing logical devices initialises the GPUs, so it must come after
    # set_memory_growth.
    logical_gpus = tf.config.experimental.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized
    print(e)
    
    

print(tf. __version__)
# print(tf.keras. __version__)
    

In [None]:
# Seed every RNG used in this notebook. `seed` is also passed to the data
# generators so image and mask augmentation stay in sync.
seed = 13334
np.random.seed(seed)
random.seed(seed)
tf.random.set_seed(seed)
# Per the Keras docs, set_random_seed seeds Python `random`, NumPy and TF.
# The three calls above are therefore redundant, but harmless.
tf.keras.utils.set_random_seed(seed)
# NOTE(review): PYTHONHASHSEED is only read when an interpreter starts. Setting it
# here does not change hash randomisation in the running kernel; it only
# affects child processes started afterwards.
os.environ['PYTHONHASHSEED']=str(seed)
cp.random.seed(seed=seed)

In [None]:
import sys
import pandas as pd
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import datetime
from PIL import Image, ImageStat
import math
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import UpSampling2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import Add
from tensorflow.keras.initializers import RandomNormal

from tensorflow import keras
from tensorflow.keras.callbacks import TensorBoard
import pickle
import json


import tensorflow.experimental.numpy as tnp
tnp.experimental_enable_numpy_behavior()


In [None]:
from matplotlib import pyplot as plt
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from skimage.color import rgb2gray


# Batch size used by the training generators and the SWA callback.
batch_size= 8

# Images and masks are resized to imgz_size x imgz_size by flow_from_directory.
imgz_size = 256


# Placeholder dataset roots; fill in before running. The steps-per-epoch cell
# counts files in train_image_dir + 'train/', i.e. one sub-folder is expected.
train_image_dir = ""
train_mask_dir = ""
# val_image_dir = ""
# val_mask_dir = ""

import albumentations as A
# With mean 0, std 1 and max_pixel_value 255, albumentations' Normalize reduces
# to image / 255, so pixel values are scaled from [0, 255] to [0, 1].
aug = A.Compose([
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        )
    ]
)


def binarize_mask(mask):
    """Turn a mask array into a 0/1 integer mask.

    Any strictly positive pixel becomes 1; zero or negative pixels become 0.
    Used as the mask generator's preprocessing_function.
    """
    is_foreground = mask > 0
    return is_foreground.astype(np.int_)


def normalize_image(image):
    """Run one image through the module-level ``aug`` pipeline and return the result.

    Used as the image generator's preprocessing_function.
    """
    return aug(image=image)['image']


# Images and masks share identical geometric augmentation settings; only the
# preprocessing_function differs (normalize_image vs binarize_mask).
# Per the ImageDataGenerator docs, preprocessing_function runs after resizing and
# augmentation, so masks are re-binarised after interpolation.
img_data_gen_args = dict( preprocessing_function = normalize_image , rotation_range=90, width_shift_range=0.3, height_shift_range=0.3, shear_range=0.5, zoom_range=0.3, horizontal_flip=True, vertical_flip=True, fill_mode='reflect')
mask_data_gen_args = dict( preprocessing_function = binarize_mask , rotation_range=90, width_shift_range=0.3, height_shift_range=0.3, shear_range=0.5, zoom_range=0.3, horizontal_flip=True, vertical_flip=True, fill_mode='reflect')

# Alternative augmentation settings (lighter / none), kept for experiments:
# img_data_gen_args = dict( preprocessing_function = normalize_image , rotation_range=90, horizontal_flip=True, vertical_flip=True, fill_mode='reflect')
# mask_data_gen_args = dict( preprocessing_function = binarize_mask , rotation_range=90, horizontal_flip=True, vertical_flip=True, fill_mode='reflect')

# img_data_gen_args = dict( preprocessing_function = normalize_image)
# mask_data_gen_args = dict( preprocessing_function = binarize_mask )

image_data_generator = ImageDataGenerator(**img_data_gen_args)
mask_data_generator = ImageDataGenerator(**mask_data_gen_args)




# The two iterators use the same seed, shuffle flag and batch_size, so their
# shuffling and random transforms line up batch for batch. This assumes both
# directories contain identically named and ordered files (TODO confirm).
image_generator = image_data_generator.flow_from_directory(train_image_dir, 
                                                           seed=seed,target_size=(imgz_size, imgz_size), shuffle = True,
                                                           batch_size=batch_size,
                                                           class_mode=None)  #Very important to set this otherwise it returns multiple numpy arrays 
                                                                            #thinking class mode is binary.


mask_generator = mask_data_generator.flow_from_directory(train_mask_dir, 
                                                         seed=seed,target_size=(imgz_size, imgz_size), shuffle = True,
                                                         batch_size=batch_size,
                                                         color_mode = 'grayscale',   #Read masks in grayscale
                                                         class_mode=None)



# Validation pipeline (currently disabled):
# valid_img_data_gen_args = dict( rescale = 1/255. , rotation_range=90, width_shift_range=0.3, height_shift_range=0.3, shear_range=0.5, zoom_range=0.3, horizontal_flip=True, vertical_flip=True, fill_mode='reflect')
# valid_mask_data_gen_args = dict( preprocessing_function = binarize_mask , rotation_range=90, width_shift_range=0.3, height_shift_range=0.3, shear_range=0.5, zoom_range=0.3, horizontal_flip=True, vertical_flip=True, fill_mode='reflect')

# valid_img_data_gen_args = dict( preprocessing_function = normalize_image)
# valid_mask_data_gen_args = dict( preprocessing_function = binarize_mask )

# valid_image_data_generator = ImageDataGenerator(**valid_img_data_gen_args)
# valid_mask_data_generator = ImageDataGenerator(**valid_mask_data_gen_args)


# valid_img_generator = valid_image_data_generator.flow_from_directory(val_image_dir, 
#                                                                seed=seed,target_size=(imgz_size, imgz_size),
#                                                                batch_size=batch_size, 
#                                                                class_mode=None) #Default batch size 32, if not specified here
# valid_mask_generator = valid_mask_data_generator.flow_from_directory(val_mask_dir, 
#                                                                seed=seed,target_size=(imgz_size, imgz_size),
#                                                                batch_size=batch_size, 
#                                                                color_mode = 'grayscale',   #Read masks in grayscale
#                                                                class_mode=None)  #Default batch size 32, if not specified here


# zip yields (image_batch, mask_batch) pairs. Keras directory iterators loop
# forever, so the zip never runs out; model.fit bounds it via steps_per_epoch.
train_generator = zip(image_generator, mask_generator)
# val_generator = zip(valid_img_generator, valid_mask_generator)
# val_generator = zip(valid_img_generator, valid_mask_generator)


In [None]:
# Sanity check: show the first image/mask pair of one augmented batch.
# Both generators advance by exactly one batch here, so they stay aligned.
x = image_generator.next()
y = mask_generator.next()
for i in range(0,1):
    image = x[i]
    mask = y[i]
    plt.subplot(1,2,1)
    plt.imshow(image)
    plt.subplot(1,2,2)
    plt.imshow(mask)
    plt.show()

In [None]:
# x = valid_img_generator.next()
# y = valid_mask_generator.next()
# for i in range(0,1):
#     image = x[i]
#     mask = y[i]
#     plt.subplot(1,2,1)
#     plt.imshow(image)
#     plt.subplot(1,2,2)
#     plt.imshow(mask)
#     plt.show()

In [None]:
############################################################


# Global counter read and incremented by attention_block to give each attention
# BatchNormalization a unique name ('attention_1', 'attention_2', ...).
layer_count_attn = 1

# NOTE(review): tensorflow is already imported above and TF_CPP_MIN_LOG_LEVEL was
# set to '3' earlier. This later '2' probably has little or no effect — confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"


from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPool2D, UpSampling2D, Concatenate, Add

from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, UpSampling2D
from tensorflow.keras.layers import AveragePooling2D, Conv2DTranspose, Concatenate, Input

from tensorflow.keras import models, layers, regularizers
from tensorflow.keras import backend as K
from swa.keras import SWA

from tensorflow.keras.layers import GlobalAveragePooling2D, Reshape, Dense, Input


# He-uniform kernel initializer shared by conv_block, ASPP and the side-output
# convs. seed=None: no initializer-specific seed is fixed here.
initializer = keras.initializers.HeUniform(seed=None)


def tf_avg_func(tensor):
    """Element-wise mean of the first six entries of ``tensor`` along axis 0.

    ``tensor`` is converted with ``tnp.array``. Entries 0..5 are summed left to
    right and the sum is divided by 6.
    """
    stacked = tnp.array(tensor)
    running = stacked[0]
    for idx in range(1, 6):
        running = running + stacked[idx]
    return tnp.divide(running, 6)

def ASPP(inputs):
    """Atrous Spatial Pyramid Pooling head with 256-channel branches.

    Five parallel branches run over ``inputs``:
      * image-level branch: global average pooling -> 1x1 conv -> BN ('bn_1')
        -> LeakyReLU -> bilinear upsampling back to the input size;
      * four conv + BN branches (no activation): 1x1 with dilation 1, and 3x3
        with dilations 6, 12 and 18.
    The branches are concatenated and fused by 1x1 conv -> BN -> LeakyReLU.

    The spatial size of ``inputs`` must be static, since it sets the pooling
    window. The fixed layer names 'average_pooling' and 'bn_1' mean ASPP can be
    used only once per model.
    """
    height, width = inputs.shape[1], inputs.shape[2]

    def _conv_bn(tensor, ksize, dilation):
        # Conv (no bias, shared He init) followed by BatchNorm, created in that order.
        out = Conv2D(filters=256, kernel_size=ksize, dilation_rate=dilation, padding='same', use_bias=False, kernel_initializer = initializer)(tensor)
        return BatchNormalization()(out)

    # Image-level context branch.
    pooled = AveragePooling2D(pool_size=(height, width), name='average_pooling')(inputs)
    pooled = Conv2D(filters=256, kernel_size=1, padding='same', use_bias=False, kernel_initializer = initializer)(pooled)
    pooled = BatchNormalization(name='bn_1')(pooled)
    pooled = layers.LeakyReLU()(pooled)
    pooled = UpSampling2D((height, width), interpolation="bilinear")(pooled)

    # Atrous branches, built in the original 1 -> 6 -> 12 -> 18 order.
    branches = [pooled]
    for ksize, dilation in ((1, 1), (3, 6), (3, 12), (3, 18)):
        branches.append(_conv_bn(inputs, ksize, dilation))

    fused = Concatenate()(branches)
    fused = _conv_bn(fused, 1, 1)
    return layers.LeakyReLU()(fused)



def SqueezeAndExcitation(inputs, ratio=8):
    """Squeeze-and-Excitation channel re-weighting.

    The block works in three steps:
      * squeeze: global average pooling of the 4-D ``inputs`` to one value per channel;
      * excite: Dense(channels // ratio, relu) -> Dense(channels, sigmoid), both without bias;
      * rescale: each channel of ``inputs`` is multiplied by its gate.
    """
    _, _, _, channels = inputs.shape
    squeezed = GlobalAveragePooling2D()(inputs)
    hidden = Dense(channels // ratio, activation="relu", use_bias=False)(squeezed)
    gates = Dense(channels, activation="sigmoid", use_bias=False)(hidden)
    return layers.multiply([inputs, gates])


def conv_block(inputs, out_ch, rate=1):
    """3x3 conv (optionally dilated, no bias, He init) -> BatchNorm -> LeakyReLU.

    Args:
        inputs: 4-D feature tensor.
        out_ch: number of output channels.
        rate: dilation rate of the 3x3 convolution.

    Returns:
        Tensor with the spatial size of ``inputs`` ('same' padding) and ``out_ch`` channels.
    """
    layer_stack = (
        Conv2D(out_ch, 3, padding="same", dilation_rate=rate, use_bias=False, kernel_initializer=initializer),
        BatchNormalization(),
        layers.LeakyReLU(),
    )
    out = inputs
    for layer in layer_stack:
        out = layer(out)
    return out


def conv_block_simple(inputs, out_ch, rate=1):
    """Residual conv block with a convolutional shortcut.

    The block has three parts:
      * main path: Conv3x3(dilation=rate) -> BN -> ReLU -> Conv3x3(dilation=rate) -> BN;
      * shortcut: Conv3x3 -> BN, which maps ``inputs`` to ``out_ch`` channels;
      * output: ReLU(shortcut + main path).

    Unlike conv_block, these convs keep Keras' default bias and initializer.

    Args:
        inputs: 4-D NHWC feature tensor.
        out_ch: number of output channels.
        rate: dilation rate of the two main-path convolutions. This argument
            used to be silently ignored; the default of 1 gives the same
            layers as before.

    Returns:
        Tensor with the spatial size of ``inputs`` and ``out_ch`` channels.
    """
    x = Conv2D(out_ch, 3, padding="same", dilation_rate=rate)(inputs)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    x = Conv2D(out_ch, 3, padding="same", dilation_rate=rate)(x)
    x = BatchNormalization()(x)

    #x = layers.Dropout(0.3)(x)

    # Shortcut keeps dilation 1; it only needs to match the main path's shape.
    shortcut = Conv2D(out_ch, 3, padding="same", dilation_rate=1)(inputs)
    shortcut = layers.BatchNormalization(axis=3)(shortcut)

    res_path = layers.add([shortcut, x])
    res_path = layers.Activation('relu')(res_path)

    return res_path


# Attention structure
# Attention-gate sizing. FILTER_NUM is the base channel count; the gating
# signals and attention blocks below use multiples of it (8x, 4x, 2x, 1x).
FILTER_NUM = 64 # number of basic filters for the first layer
# FILTER_SIZE and UP_SAMP_SIZE are not referenced by the model-building code in this cell.
FILTER_SIZE = 3 # size of the convolutional filter
UP_SAMP_SIZE = 2 # size of upsampling filters

def repeat_elem(tensor, rep):
    """Repeat ``tensor`` ``rep`` times along axis 3 (the channel axis) inside a Lambda layer.

    For example, an (N, H, W, 1) attention map becomes (N, H, W, rep).
    """
    def _repeat_channels(t, repnum):
        return K.repeat_elements(t, repnum, axis=3)

    return layers.Lambda(_repeat_channels, arguments={'repnum': rep})(tensor)

def gating_signal(input, out_size):
    """1x1 conv -> BatchNorm -> ReLU projection used as an attention gating signal.

    Args:
        input: 4-D decoder feature map. The name shadows the builtin but is kept
            for keyword-argument compatibility.
        out_size: number of output channels.
    """
    projected = layers.Conv2D(out_size, (1, 1), padding='same')(input)
    normalized = layers.BatchNormalization()(projected)
    return layers.Activation('relu')(normalized)

def attention_block(x, gating, inter_shape):
    """Additive attention gate for a skip connection (Attention U-Net style).

    ``x`` is the skip feature and ``gating`` the coarser decoder feature.
      1. theta_x: 2x2 conv, stride 2, on ``x`` -> half of x's spatial size,
         ``inter_shape`` channels.
      2. phi_g: 1x1 conv on ``gating`` -> ``inter_shape`` channels, then a 3x3
         transposed conv with stride theta_x_size // gating_size. The stride is
         1 whenever ``gating`` is already at half of x's resolution, which holds
         at every call site in this notebook.
      3. relu(theta_x + upsampled phi_g) -> 1x1 conv to 1 channel -> sigmoid,
         giving one coefficient in (0, 1) per location.
      4. UpSampling2D (default nearest) back to x's spatial size, repeated
         across x's channels, then multiplied element-wise with ``x``.
      5. 1x1 conv back to x's channel count, then BatchNormalization named
         'attention_<layer_count_attn>'.

    Side effect: increments the module-level ``layer_count_attn`` so each call
    produces a uniquely named BN layer.

    Args:
        x: 4-D NHWC skip tensor with static shape; spatial dims presumably even.
        gating: 4-D NHWC gating tensor (e.g. from gating_signal).
        inter_shape: number of intermediate channels.

    Returns:
        Tensor with the same shape as ``x``.
    """
    shape_x = K.int_shape(x)
    shape_g = K.int_shape(gating)

    # Getting the x signal to the same shape as the gating signal
    theta_x = layers.Conv2D(inter_shape, (2, 2), strides=(2, 2), padding='same')(x)  # half of x's spatial size
    shape_theta_x = K.int_shape(theta_x)

    # Getting the gating signal to the same number of filters as the inter_shape
    phi_g = layers.Conv2D(inter_shape, (1, 1), padding='same')(gating)
    # Upsample phi_g to theta_x's spatial size (stride = size ratio).
    upsample_g = layers.Conv2DTranspose(inter_shape, (3, 3),
                                strides=(shape_theta_x[1] // shape_g[1], shape_theta_x[2] // shape_g[2]),
                                padding='same')(phi_g)  # same spatial size as theta_x

    # Additive attention: relu(theta_x + phi_g) -> 1 channel -> sigmoid coefficients.
    concat_xg = layers.add([upsample_g, theta_x])
    act_xg = layers.Activation('relu')(concat_xg)
    psi = layers.Conv2D(1, (1, 1), padding='same')(act_xg)
    sigmoid_xg = layers.Activation('sigmoid')(psi)
    shape_sigmoid = K.int_shape(sigmoid_xg)
    # Bring the coefficients back to x's spatial size.
    upsample_psi = layers.UpSampling2D(size=(shape_x[1] // shape_sigmoid[1], shape_x[2] // shape_sigmoid[2]))(sigmoid_xg)  # x's spatial size

    # Copy the single-channel map across all of x's channels.
    upsample_psi = repeat_elem(upsample_psi, shape_x[3])

    y = layers.multiply([upsample_psi, x])

    global layer_count_attn
    result = layers.Conv2D(shape_x[3], (1, 1), padding='same' )(y)
    
    # Unique BN name per call, driven by the global counter.
    result_bn = layers.BatchNormalization(name='attention_'+str(layer_count_attn))(result)
    layer_count_attn = layer_count_attn + 1
    return result_bn


def I_UNet_7(inputs, out_ch, int_ch, num_layers, rate=2):
    """Residual U-block with five 2x2 max-pools (H -> H/32) and attention-gated skips.

    Structure:
      * initial conv_block to ``out_ch`` channels, kept as ``init_feats``;
      * encoder: conv_blocks with ``int_ch`` channels at H, H/2, H/4, H/8,
        H/16 and H/32, separated by 2x2 max-pools;
      * bridge: one conv_block dilated by ``rate``;
      * decoder: the H/16, H/8, H/4 and H/2 skips pass through
        attention_block before being concatenated. The deepest (H/32) and
        full-resolution skips are concatenated directly.
    The final conv_block (``out_ch``) is added to ``init_feats``.

    Args:
        inputs: 4-D NHWC tensor; H and W should be divisible by 32.
        out_ch: output channels of the block (and of the initial conv).
        int_ch: channels of the internal encoder/decoder convs.
        num_layers: unused.
        rate: dilation rate of the bridge conv.

    Both call sites in histosegplusplus pass 256x256 inputs, which the shape
    comments below assume.
    """
    # Initial conv
    x = conv_block(inputs, out_ch) # (None, 256, 256, out_ch)
    init_feats = x
    

    """ Encoder """
    skip = []
    x = conv_block(x, int_ch) # (None, 256, 256, int_ch)
    skip.append(x)

    x = MaxPool2D((2, 2))(x)
    x_128 = conv_block(x, int_ch) # (None, 128, 128, int_ch)
    skip.append(x_128)
    
    x = MaxPool2D((2, 2))(x_128)
    x_64 = conv_block(x, int_ch) # (None, 64, 64, int_ch)
    skip.append(x_64)
    
    x = MaxPool2D((2, 2))(x_64)
    x_32 = conv_block(x, int_ch) # (None, 32, 32, int_ch)
    skip.append(x_32)
    
    x = MaxPool2D((2, 2))(x_32)
    x_16 = conv_block(x, int_ch) # (None, 16, 16, int_ch)
    skip.append(x_16)
    
    x = MaxPool2D((2, 2))(x_16)
    x = conv_block(x, int_ch) # (None, 8, 8, int_ch)
    skip.append(x)



    """ Bridge """
    x = conv_block(x, int_ch, rate=rate) # (None, 8, 8, int_ch)
    
    
    
    """ Decoder """
    # After reversing, skip[0] is the deepest (8x8) feature and skip[-1] the full-resolution one.
    skip.reverse()

    x = Concatenate()([x, skip[0]]) # (None, 8, 8, 2*int_ch)
    x = conv_block(x, int_ch) # (None, 8, 8, int_ch)
    
    #-------------------------------------------------------------#
    # 8 -> 16, fused with the attention-gated x_16 skip.
    gating_16 = gating_signal(x, 8*FILTER_NUM)
    att_16 = attention_block(x_16, gating_16, 8*FILTER_NUM)
    x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x)
    x = layers.concatenate([x, att_16], axis=3)
    x = conv_block(x, int_ch)
    #-------------------------------------------------------------#
    
    
    #-------------------------------------------------------------#
    # 16 -> 32, fused with the attention-gated x_32 skip.
    gating_32 = gating_signal(x, 4*FILTER_NUM)
    att_32 = attention_block(x_32, gating_32, 4*FILTER_NUM)
    x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x)
    x = layers.concatenate([x, att_32], axis=3)
    x = conv_block(x, int_ch)
    #-------------------------------------------------------------#
    

    #-------------------------------------------------------------#
    # 32 -> 64, fused with the attention-gated x_64 skip.
    gating_64 = gating_signal(x, 2*FILTER_NUM)
    att_64 = attention_block(x_64, gating_64, 2*FILTER_NUM)
    x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x)
    x = layers.concatenate([x, att_64], axis=3)
    x = conv_block(x, int_ch)
    #-------------------------------------------------------------#
    
    
    #-------------------------------------------------------------#
    # 64 -> 128, fused with the attention-gated x_128 skip.
    gating_128 = gating_signal(x, FILTER_NUM)
    att_128 = attention_block(x_128, gating_128, FILTER_NUM)
    x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x)
    x = layers.concatenate([x, att_128], axis=3)
    x = conv_block(x, int_ch)
    #-------------------------------------------------------------#


    # 128 -> 256, fused directly (no attention) with the full-resolution skip.
    x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x)
    x = Concatenate()([x, skip[-1]])
    x = conv_block(x, out_ch) # (None, 256, 256, out_ch)

    """ Add """
    # Residual connection to the initial conv output.
    x = Add()([x, init_feats])
    return x


def I_UNet_6(inputs, out_ch, int_ch, num_layers, rate=2):
    """Residual U-block with four 2x2 max-pools (H -> H/16) and attention-gated skips.

    Same pattern as I_UNet_7, one level shallower:
      * initial conv_block to ``out_ch``, kept as ``init_feats``;
      * encoder convs (``int_ch``) at H, H/2, H/4, H/8 and H/16;
      * a bridge conv dilated by ``rate``;
      * decoder: the H/8, H/4 and H/2 skips go through attention_block, and
        the full-resolution skip is concatenated directly.
    The final conv_block (``out_ch``) is added to ``init_feats``.

    Args:
        inputs: 4-D NHWC tensor; H and W should be divisible by 16.
        out_ch: output channels of the block (and of the initial conv).
        int_ch: channels of the internal encoder/decoder convs.
        num_layers: unused.
        rate: dilation rate of the bridge conv.

    Both call sites in histosegplusplus pass 128x128 inputs (for a 256x256
    model input), which the shape comments below assume.
    """
    # Initial conv
    x = conv_block(inputs, out_ch) # (None, 128, 128, out_ch)
    init_feats = x
    

    """ Encoder """
    skip = []
    x = conv_block(x, int_ch) # (None, 128, 128, int_ch)
    skip.append(x)
    


    x = MaxPool2D((2, 2))(x)
    x_64 = conv_block(x, int_ch) # (None, 64, 64, int_ch)
    skip.append(x_64)
    
    x = MaxPool2D((2, 2))(x_64)
    x_32 = conv_block(x, int_ch) # (None, 32, 32, int_ch)
    skip.append(x_32)
    
    x = MaxPool2D((2, 2))(x_32)
    x_16 = conv_block(x, int_ch) # (None, 16, 16, int_ch)
    skip.append(x_16)
    
    x = MaxPool2D((2, 2))(x_16)
    x = conv_block(x, int_ch) # (None, 8, 8, int_ch)
    skip.append(x)
    



    """ Bridge """
    x = conv_block(x, int_ch, rate=rate)  # (None, 8, 8, int_ch)
    
    
    
    """ Decoder """
    # After reversing, skip[0] is the deepest (8x8) feature and skip[-1] the full-resolution one.
    skip.reverse()

    x = Concatenate()([x, skip[0]])  # (None, 8, 8, 2*int_ch)
    x = conv_block(x, int_ch)  # (None, 8, 8, int_ch)


    #-------------------------------------------------------------#
    # 8 -> 16, fused with the attention-gated x_16 skip.
    gating_16 = gating_signal(x, 8*FILTER_NUM)
    att_16 = attention_block(x_16, gating_16, 8*FILTER_NUM)
    x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x)
    x = layers.concatenate([x, att_16], axis=3)
    x = conv_block(x, int_ch)
    #-------------------------------------------------------------#
    
    
    #-------------------------------------------------------------#
    # 16 -> 32, fused with the attention-gated x_32 skip.
    gating_32 = gating_signal(x, 4*FILTER_NUM)
    att_32 = attention_block(x_32, gating_32, 4*FILTER_NUM)
    x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x)
    x = layers.concatenate([x, att_32], axis=3)
    x = conv_block(x, int_ch)
    #-------------------------------------------------------------#
    

    #-------------------------------------------------------------#
    # 32 -> 64, fused with the attention-gated x_64 skip.
    gating_64 = gating_signal(x, 2*FILTER_NUM)
    att_64 = attention_block(x_64, gating_64, 2*FILTER_NUM)
    x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x)
    x = layers.concatenate([x, att_64], axis=3)
    x = conv_block(x, int_ch)
    #-------------------------------------------------------------#




    # 64 -> 128, fused directly (no attention) with the full-resolution skip.
    x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x)  # (None, 128, 128, int_ch)
    x = Concatenate()([x, skip[-1]])
    x = conv_block(x, out_ch)  # (None, 128, 128, out_ch)

    """ Add """
    # Residual connection to the initial conv output.
    x = Add()([x, init_feats])  # (None, 128, 128, out_ch)
    return x



def I_UNet_5(inputs, out_ch, int_ch, num_layers, rate=2):
    """Residual U-block with three 2x2 max-pools (H -> H/8) and attention-gated skips.

    Same pattern as I_UNet_6, one level shallower:
      * initial conv_block to ``out_ch``, kept as ``init_feats``;
      * encoder convs (``int_ch``) at H, H/2, H/4 and H/8;
      * a bridge conv dilated by ``rate``;
      * decoder: the H/4 and H/2 skips go through attention_block, and the
        full-resolution skip is concatenated directly.
    The final conv_block (``out_ch``) is added to ``init_feats``.

    Variable names (x_32, x_16) reflect the actual sizes for a 64x64 input.

    Args:
        inputs: 4-D NHWC tensor; H and W should be divisible by 8.
        out_ch: output channels of the block (and of the initial conv).
        int_ch: channels of the internal encoder/decoder convs.
        num_layers: unused.
        rate: dilation rate of the bridge conv.

    Both call sites in histosegplusplus pass 64x64 inputs (for a 256x256 model
    input), which the shape comments below assume.
    """
    # Initial conv
    x = conv_block(inputs, out_ch) # (None, 64, 64, out_ch)
    init_feats = x
    

    """ Encoder """
    skip = []
    x = conv_block(x, int_ch) # (None, 64, 64, int_ch)
    skip.append(x)

    
    x = MaxPool2D((2, 2))(x)
    x_32 = conv_block(x, int_ch) # (None, 32, 32, int_ch)
    skip.append(x_32)
    
    x = MaxPool2D((2, 2))(x_32)
    x_16 = conv_block(x, int_ch) # (None, 16, 16, int_ch)
    skip.append(x_16)
    
    x = MaxPool2D((2, 2))(x_16)
    x = conv_block(x, int_ch) # (None, 8, 8, int_ch)
    skip.append(x)
    



    """ Bridge """
    x = conv_block(x, int_ch, rate=rate)  # (None, 8, 8, int_ch)
    
    
    
    """ Decoder """
    # After reversing, skip[0] is the deepest (8x8) feature and skip[-1] the full-resolution one.
    skip.reverse()

    x = Concatenate()([x, skip[0]])  # (None, 8, 8, 2*int_ch)
    x = conv_block(x, int_ch)  # (None, 8, 8, int_ch)


    #-------------------------------------------------------------#
    # 8 -> 16, fused with the attention-gated x_16 skip.
    gating_16 = gating_signal(x, 8*FILTER_NUM)
    att_16 = attention_block(x_16, gating_16, 8*FILTER_NUM)
    x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x)
    x = layers.concatenate([x, att_16], axis=3)
    x = conv_block(x, int_ch)
    #-------------------------------------------------------------#
    
    
    #-------------------------------------------------------------#
    # 16 -> 32, fused with the attention-gated x_32 skip.
    gating_32 = gating_signal(x, 4*FILTER_NUM)
    att_32 = attention_block(x_32, gating_32, 4*FILTER_NUM)
    x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x)
    x = layers.concatenate([x, att_32], axis=3)
    x = conv_block(x, int_ch)
    #-------------------------------------------------------------#
    

    # 32 -> 64, fused directly (no attention) with the full-resolution skip.
    x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x)  # (None, 64, 64, int_ch)
    x = Concatenate()([x, skip[-1]])
    x = conv_block(x, out_ch)  # (None, 64, 64, out_ch)

    """ Add """
    # Residual connection to the initial conv output.
    x = Add()([x, init_feats])  # (None, 64, 64, out_ch)
    return x



def I_UNet_4(inputs, out_ch, int_ch, num_layers, rate=2):
    """Residual U-block with two 2x2 max-pools (H -> H/4) and one attention-gated skip.

    Structure:
      * initial conv_block to ``out_ch``, kept as ``init_feats``;
      * encoder convs (``int_ch``) at H, H/2 and H/4;
      * a bridge conv dilated by ``rate``;
      * decoder: the H/2 skip goes through attention_block, and the
        full-resolution skip is concatenated directly.
    The final conv_block (``out_ch``) is added to ``init_feats``.

    Args:
        inputs: 4-D NHWC tensor; H and W should be divisible by 4.
        out_ch: output channels of the block (and of the initial conv).
        int_ch: channels of the internal encoder/decoder convs.
        num_layers: unused.
        rate: dilation rate of the bridge conv.

    Both call sites in histosegplusplus pass 32x32 inputs (for a 256x256 model
    input), which the shape comments below assume.
    """
    # Initial conv
    x = conv_block(inputs, out_ch) # (None, 32, 32, out_ch)
    init_feats = x
    

    """ Encoder """
    skip = []
    x = conv_block(x, int_ch) # (None, 32, 32, int_ch)
    skip.append(x)


    
    x = MaxPool2D((2, 2))(x)
    x_16 = conv_block(x, int_ch) # (None, 16, 16, int_ch)
    skip.append(x_16)
    
    x = MaxPool2D((2, 2))(x_16)
    x = conv_block(x, int_ch) # (None, 8, 8, int_ch)
    skip.append(x)
    



    """ Bridge """
    x = conv_block(x, int_ch, rate=rate)  # (None, 8, 8, int_ch)
    
    
    
    """ Decoder """
    # After reversing, skip[0] is the deepest (8x8) feature and skip[-1] the full-resolution one.
    skip.reverse()

    x = Concatenate()([x, skip[0]])  # (None, 8, 8, 2*int_ch)
    x = conv_block(x, int_ch)  # (None, 8, 8, int_ch)


    #-------------------------------------------------------------#
    # 8 -> 16, fused with the attention-gated x_16 skip.
    gating_16 = gating_signal(x, 8*FILTER_NUM)
    att_16 = attention_block(x_16, gating_16, 8*FILTER_NUM)
    x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x)
    x = layers.concatenate([x, att_16], axis=3)
    x = conv_block(x, int_ch)
    #-------------------------------------------------------------#
    
    

    # 16 -> 32, fused directly (no attention) with the full-resolution skip.
    x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x)  # (None, 32, 32, int_ch)
    x = Concatenate()([x, skip[-1]])
    x = conv_block(x, out_ch)  # (None, 32, 32, out_ch)

    """ Add """
    # Residual connection to the initial conv output.
    x = Add()([x, init_feats])  # (None, 32, 32, out_ch)

    return x



def bridge_block(inputs, out_ch, int_ch):
    """Resolution-preserving residual U-block that uses dilation instead of pooling.

    The block is built from conv_block calls, in this order:
      * stem: conv_block to ``out_ch`` channels;
      * encoder: conv_blocks with dilations 1, 2 and 4 (``int_ch`` channels);
      * bridge: a conv_block with dilation 8;
      * decoder: each stage concatenates with the matching encoder output and
        applies a conv_block with dilations 4, 2, 1 (the last one back to ``out_ch``);
      * the decoder output is added to the stem.

    Spatial size is never changed.
    """
    stem = conv_block(inputs, out_ch, rate=1)

    # Encoder: dilation 1 -> 2 -> 4.
    encoded = []
    feat = stem
    for dilation in (1, 2, 4):
        feat = conv_block(feat, int_ch, rate=dilation)
        encoded.append(feat)

    # Bridge.
    feat = conv_block(feat, int_ch, rate=8)

    # Decoder: deepest skip first, dilation 4 -> 2 -> 1, last stage returns to out_ch.
    for skip_feat, dilation, channels in zip(reversed(encoded), (4, 2, 1), (int_ch, int_ch, out_ch)):
        feat = Concatenate()([feat, skip_feat])
        feat = conv_block(feat, channels, rate=dilation)

    return Add()([feat, stem])

def histosegplusplus(input_shape, out_ch, int_ch, num_classes=2):
    """Build the HistoSeg++ segmentation network as a Keras functional Model.

    The network has four parts:
      * encoder: I_UNet_7/6/5/4 and a bridge_block, each followed by
        conv_block + SqueezeAndExcitation and a 2x2 max-pool;
      * bridge: bridge_block + ASPP at 1/32 of the input resolution;
      * decoder: bridge_block, then I_UNet_4/5/6/7. Each decoder stage gets the
        upsampled previous stage concatenated with an attention-gated encoder
        feature (s4, s3, s2, s1);
      * head: one 3x3 conv (``num_classes`` channels, no bias) per decoder
        stage and on the bridge. These are bilinearly upsampled to input
        resolution, concatenated, fused by a 3x3 conv and passed through a
        sigmoid. Only the fused map is returned.

    Args:
        input_shape: (H, W, C). H and W presumably need to be divisible by 32
            (five 2x2 pools, and side outputs upsampled by 2..32).
        out_ch: at least 11 output-channel counts, one per stage in build
            order (5 encoder stages, bridge, 5 decoder stages).
        int_ch: internal-channel counts, same order as ``out_ch``.
        num_classes: channels of each side output and of the final output.

    Returns:
        tf.keras Model with one sigmoid output of shape (None, H, W, num_classes).
    """
    # Input layer
    inputs = Input(input_shape)
    s0 = inputs

    """ Encoder """
    # NOTE(review): `skip` is filled and reversed below but never read; the
    # decoder uses s1..s5 directly.
    skip = []
    
    s1 = I_UNet_7(s0, out_ch[0], int_ch[0], 7)       # H
    s1 = conv_block(s1, out_ch[0])
    s1 = SqueezeAndExcitation(s1)
    skip.append(s1)
    p1 = MaxPool2D((2, 2))(s1)                       # H/2

    s2 = I_UNet_6(p1, out_ch[1], int_ch[1], 6)       # H/2
    s2 = conv_block(s2, out_ch[1])
    s2 = SqueezeAndExcitation(s2)
    skip.append(s2)
    p2 = MaxPool2D((2, 2))(s2)                       # H/4

    s3 = I_UNet_5(p2, out_ch[2], int_ch[2], 5)       # H/4
    s3 = conv_block(s3, out_ch[2])
    s3 = SqueezeAndExcitation(s3)
    skip.append(s3)
    p3 = MaxPool2D((2, 2))(s3)                       # H/8

    s4 = I_UNet_4(p3, out_ch[3], int_ch[3], 4)       # H/8
    s4 = conv_block(s4, out_ch[3])
    s4 = SqueezeAndExcitation(s4) 
    skip.append(s4)
    p4 = MaxPool2D((2, 2))(s4)                       # H/16

    s5 = bridge_block(p4, out_ch[4], int_ch[4])      # H/16
    s5 = conv_block(s5, out_ch[4])
    s5 = SqueezeAndExcitation(s5)
    skip.append(s5)
    p5 = MaxPool2D((2, 2))(s5)                       # H/32

    """ Bridge """
    b1 = bridge_block(p5, out_ch[5], int_ch[5])      # H/32
    b1 = ASPP(b1)
    b2 = UpSampling2D(size=(2, 2), interpolation="bilinear")(b1)  # H/16

    """ Decoder """
    skip.reverse()
    
    d1 = Concatenate()([b2, s5])
    d1 = bridge_block(d1, out_ch[6], int_ch[6])      # H/16


    
    # H/16 -> H/8, fused with attention-gated s4.
    gating = gating_signal(d1, 16*FILTER_NUM)
    att = attention_block(s4, gating, 8*FILTER_NUM)
    u1 = UpSampling2D(size=(2, 2), interpolation="bilinear")(d1)
    d2 = Concatenate()([u1, att])
    d2 = I_UNet_4(d2, out_ch[7], int_ch[7], 4)       # H/8

    
    
    # H/8 -> H/4, fused with attention-gated s3.
    gating = gating_signal(d2, 8*FILTER_NUM)
    att = attention_block(s3, gating, 4*FILTER_NUM)
    u2 = UpSampling2D(size=(2, 2), interpolation="bilinear")(d2)
    d3 = Concatenate()([u2, att])
    d3 = I_UNet_5(d3, out_ch[8], int_ch[8], 5)       # H/4

    
    
    # H/4 -> H/2, fused with attention-gated s2.
    gating = gating_signal(d3, 4*FILTER_NUM)
    att = attention_block(s2, gating, 2*FILTER_NUM)
    u3 = UpSampling2D(size=(2, 2), interpolation="bilinear")(d3)
    d4 = Concatenate()([u3, att])
    d4 = I_UNet_6(d4, out_ch[9], int_ch[9], 6)       # H/2

    
    
    # H/2 -> H, fused with attention-gated s1.
    # NOTE(review): the earlier stages use gating = 2 x inter channels; here both
    # are 2*FILTER_NUM. Possibly FILTER_NUM was intended for the inter size — confirm.
    gating = gating_signal(d4, 2*FILTER_NUM)
    att = attention_block(s1, gating, 2*FILTER_NUM)
    u4 = UpSampling2D(size=(2, 2), interpolation="bilinear")(d4)
    d5 = Concatenate()([u4, att])
    d5 = I_UNet_7(d5, out_ch[10], int_ch[10], 7)     # H

    

    """ Side Outputs """
    # One num_classes-channel map per stage, each upsampled to full resolution H x W.

    z1 = Conv2D(num_classes, 3, padding="same", use_bias=False, kernel_initializer = initializer)(d5)

    z2 = Conv2D(num_classes, 3, padding="same", use_bias=False, kernel_initializer = initializer)(d4)
    z2 = UpSampling2D(size=(2, 2), interpolation="bilinear")(z2)

    z3 = Conv2D(num_classes, 3, padding="same", use_bias=False, kernel_initializer = initializer)(d3)
    z3 = UpSampling2D(size=(4, 4), interpolation="bilinear")(z3)

    z4 = Conv2D(num_classes, 3, padding="same", use_bias=False, kernel_initializer = initializer)(d2)
    z4 = UpSampling2D(size=(8, 8), interpolation="bilinear")(z4)

    z5 = Conv2D(num_classes, 3, padding="same", use_bias=False, kernel_initializer = initializer)(d1)
    z5 = UpSampling2D(size=(16, 16), interpolation="bilinear")(z5)

    z6 = Conv2D(num_classes, 3, padding="same", use_bias=False, kernel_initializer = initializer)(b1)
    z6 = UpSampling2D(size=(32, 32), interpolation="bilinear")(z6)
    
    
    
    # Fuse the six side outputs into one map, then apply a sigmoid.
    can_1 = Concatenate()([z1, z2, z3, z4, z5, z6])    
    
    o_5 = Conv2D(num_classes, 3, padding="same", use_bias=False, kernel_initializer = initializer)(can_1)
    
    o_5 = Activation("sigmoid")(o_5)
    
    model = tf.keras.models.Model(inputs, outputs=[o_5])

    return model

def build_model(input_shape, num_classes=1):
    """Build HistoSeg++ with the fixed per-stage channel configuration.

    Each pair below is (out_ch, int_ch) for one U-block, in build order:
    5 encoder stages, the bridge, then 5 decoder stages.
    """
    stage_channels = (
        (64, 32), (128, 32), (256, 64), (512, 128), (512, 256),   # encoder
        (512, 256),                                               # bridge
        (512, 256), (256, 128), (128, 64), (64, 32), (64, 16),    # decoder
    )
    out_ch = [out for out, _ in stage_channels]
    int_ch = [inner for _, inner in stage_channels]
    return histosegplusplus(input_shape, out_ch, int_ch, num_classes=num_classes)


def build_model_hist(input_shape=None):
    """Build the single-output HistoSeg++ model used throughout this notebook.

    Args:
        input_shape: optional (H, W, C) tuple. Defaults to
            ``(imgz_size, imgz_size, 3)``, so the model input always matches the
            ``target_size`` the data generators resize images to. The size used
            to be hard-coded to 256 independently of ``imgz_size``. H and W
            should be multiples of 32 for the architecture's fixed pooling and
            upsampling factors.
    """
    if input_shape is None:
        input_shape = (imgz_size, imgz_size, 3)
    return build_model(input_shape)
    
model = build_model_hist()


model.summary(expand_nested=True, show_trainable=True)


In [None]:
# Parameter report. print_model_params is not defined in this notebook; it
# presumably comes from the `from hist_metrics import *` star import.
print_model_params(model)

In [None]:
# segmentation_models picks its Keras backend from SM_FRAMEWORK, so it must be
# set before the import.
os.environ["SM_FRAMEWORK"] = "tf.keras"
import segmentation_models as sm

# Candidate losses from segmentation_models.
# NOTE(review): none of these, nor total_loss, is passed to model.compile in the
# visible code; model_compile uses edge_aware_loss.
dice_loss = sm.losses.DiceLoss(per_image = True) 
focal_loss = sm.losses.CategoricalFocalLoss(gamma=2.)
jaccard_loss = sm.losses.JaccardLoss()
binary_focal_loss = sm.losses.BinaryFocalLoss(gamma=2.)
binary_crossentropy = sm.losses.BinaryCELoss()

# Weighted combination: 0.4 * BCE + 0.6 * binary focal loss.
total_loss = (0.4*binary_crossentropy) + (0.6*binary_focal_loss)

In [None]:
# IoU metrics at a 0.5 threshold. keras_iou scores only class 1 (foreground).
# Only iou_score is passed to model.compile in model_compile.
keras_iou = tf.keras.metrics.BinaryIoU(target_class_ids=[1], threshold=0.5)
iou_score = sm.metrics.IOUScore(per_image = True, threshold = 0.5)

In [None]:
def edge_aware_loss(y_true, y_pred, edge_weight=2.0):
    """
    Edge-aware binary cross-entropy for segmentation.

    Pixels on ground-truth mask boundaries get a higher weight. Boundaries come
    from a Sobel filter on ``y_true``: a pixel whose gradient magnitude exceeds
    0.1 counts as edge (averaged over channels). Edge pixels get weight
    ``edge_weight`` and all others weight 1. The weighted per-pixel BCE is then
    averaged over batch and space.

    Args:
    - y_true: Ground-truth masks, NHWC, values 0/1. Cast to ``y_pred``'s dtype
      because ``tf.image.sobel_edges`` requires a float tensor.
    - y_pred: Predicted probabilities (post-sigmoid), same shape as ``y_true``.
    - edge_weight: Weight multiplier for edge pixels.

    Returns:
    - Scalar loss tensor.
    """
    y_true = tf.cast(y_true, y_pred.dtype)

    # Per-pixel BCE reduces the channel axis to (B, H, W); restore it for broadcasting.
    bce = tf.keras.losses.binary_crossentropy(y_true, y_pred, from_logits=False)
    bce = tf.expand_dims(bce, axis=-1)

    # Sobel gradients, computed once: shape (B, H, W, C, 2), with both directions stacked last.
    sobel = tf.image.sobel_edges(y_true)
    edge_magnitude = tf.sqrt(tf.square(sobel[..., 0]) + tf.square(sobel[..., 1]))
    edge_mask = tf.cast(edge_magnitude > 0.1, y_pred.dtype)  # Threshold for edges

    # Collapse channels to one mask so it broadcasts against bce (B, H, W, 1).
    edge_mask = tf.reduce_mean(edge_mask, axis=-1, keepdims=True)

    # 1 for interior/background pixels, edge_weight for edge pixels.
    weights = 1 + edge_mask * (edge_weight - 1)

    return tf.reduce_mean(tf.multiply(bce, weights))

In [None]:
# Legacy Adam (lr 1e-3). This one instance is shared by every model compiled
# through model_compile.
opt = tf.keras.optimizers.legacy.Adam(learning_rate=0.001)

# NOTE(review): BCE is not used by model_compile or anywhere else in the visible code.
BCE = keras.losses.BinaryCrossentropy(from_logits=False)

def model_compile(model_x):
    """Compile ``model_x`` in place with ``opt``, ``edge_aware_loss`` and the ``iou_score`` metric."""
    model_x.compile(optimizer=opt, loss = edge_aware_loss, metrics= [iou_score])

model_compile(model)

In [None]:
import time
from keras.callbacks import Callback

class TimeHistory(Callback):
    """Keras callback that records and prints the duration of each epoch.

    ``self.times`` holds one duration in seconds per completed epoch. It is
    reset at the start of every ``fit`` call. Durations use
    ``time.perf_counter`` (monotonic), so clock adjustments cannot distort them.
    """

    def on_train_begin(self, logs=None):
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start_time = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        elapsed = time.perf_counter() - self.epoch_start_time
        self.times.append(elapsed)
        # format_time is not defined in this notebook (presumably from hist_metrics).
        print(f"Epoch {epoch+1} took {format_time(elapsed)}")

time_callback = TimeHistory()

In [None]:
# Save weights after every epoch as results/<epoch>.hdf5 (e.g. results/0001.hdf5).
filepath= path_1 + "{epoch:04d}.hdf5"
checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath=filepath, save_best_only=False, save_weights_only=True, verbose= 1)


# Per-epoch metrics go to logs/logs.csv, which is overwritten on each run.
csv_logger = tf.keras.callbacks.CSVLogger(path_2 + "logs.csv", separator=",", append=False)

# NOTE(review): reduce_lr, swa, early_stopping and schedule_lr are built but not
# added to callbacks_list. reduce_lr and early_stopping monitor 'val_loss', yet
# fit() receives no validation data.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=0.000000001, verbose=1)


# Count training images. Assumes flow_from_directory's single sub-folder is named 'train/'.
num_train_imgs = len(os.listdir(train_image_dir + 'train/'))
# At least one step per epoch: with fewer images than batch_size, floor division
# gives 0 and every epoch would run no training step, which Keras does not handle.
steps_per_epoch = max(1, num_train_imgs // batch_size)
print("num_train_imgs -------------- = " + str(num_train_imgs))
print("steps_per_epoch -------------- = " + str(steps_per_epoch))


start_epoch = 100

from swa.keras import SWA
swa = SWA(start_epoch=start_epoch, lr_schedule='cyclic', swa_lr=0.0001, swa_lr2=0.01, swa_freq=3, batch_size=batch_size, verbose=1)
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode="min", restore_best_weights=True, patience=3, verbose=1)


def scheduler(epoch, lr):
    """Keep ``lr`` for the first 100 epochs, then decay it by a factor exp(-0.1) per epoch."""
    if epoch < 100:
        return lr
    else:
        return lr * tf.math.exp(-0.1)

schedule_lr = keras.callbacks.LearningRateScheduler(scheduler, verbose=1)

callbacks_list = [checkpoint, csv_logger, time_callback]

In [None]:
if (training):
    start_time = time.time()

    # train_generator already yields batches (their size is fixed by
    # flow_from_directory), so batch_size is not passed to fit(), per the Keras
    # docs for generator inputs.
    history = model.fit(train_generator,
                        steps_per_epoch=steps_per_epoch,
                        epochs=num_epochs, 
                        callbacks=callbacks_list,
                        verbose=1)



    end_time = time.time()

    total_training_time = end_time - start_time
    total_training_time_formatted = format_time(total_training_time)

    print(f"Total training time: {total_training_time:.3f} seconds")
    print(f"Total training time: {total_training_time_formatted}")

    # Final weights, in addition to the per-epoch checkpoints.
    model.save_weights(path_1 + 'last_weights.hdf5')
    
    epoch_times = time_callback.times
    calc_time(epoch_times)

In [None]:
####---------------------------------------------------------Evaluation Evaluation Evaluation Evaluation Evaluation Evaluation Evaluation Evaluation Evaluation Evaluation Evaluation---------------------------------------------------------###

In [None]:
# Paths to the pre-built test arrays (.npy files); fill in before running.
test_images_arg = ""
test_masks_arg = ""



X_test = np.load(test_images_arg)
y_test = np.load(test_masks_arg)


print(X_test.shape)
print(y_test.shape)
# Images and masks must pair up one-to-one.
assert X_test.shape[0] == y_test.shape[0], "X_test and y_test contain different numbers of samples"


# Same normalisation as training (`aug`: pixel / 255).
# NOTE(review): `aug` is applied to the whole stacked array at once rather than
# per image; confirm the installed albumentations version accepts this.
X_test = aug(image=X_test)['image']
print('Min: %.3f, Max: %.3f' % (X_test.min(), X_test.max()))
# print(np.unique(X_test))

# Binarise exactly like binarize_mask (> 0 -> 1), stored as float32.
y_test = (y_test > 0).astype(np.float32)
print('Min: %.3f, Max: %.3f' % (y_test.min(), y_test.max()))
print(np.unique(y_test))

In [None]:
# Visual check: a random normalised test image next to its binarised mask.
i = random.randint(0, X_test.shape[0]-1)
print(i)
image = X_test[i]
mask = y_test[i]
plt.subplot(1,2,1)
plt.imshow(image)
plt.subplot(1,2,2)
plt.imshow(mask)
plt.show()

In [None]:
def model_clear_and_build():
    """Reset the Keras session, then build and compile a fresh model.

    The new model is bound to the module-level name ``model``, which later cells
    use for ``model.load_weights`` / ``model.predict``. The previous version only
    assigned a local variable, so the freshly built model was thrown away. The
    model is also returned for convenience; existing callers ignore the return value.

    Returns:
        The newly built and compiled model.
    """
    global model
    # Drop the old graph/state first so repeated rebuilds don't accumulate layers.
    tf.keras.backend.clear_session()
    model = build_model_hist()
    model_compile(model)
    return model

In [None]:
model_clear_and_build()

# Evaluation state, reset every time this cell runs.
pred_index = -1
df_combined = 0
df_counter = 0

current_cwd = os.getcwd()
weights_folder_path = current_cwd + "/results/"

# Weight files saved outside the per-epoch checkpointing; they are excluded from the
# per-epoch evaluation. Filtering explicitly (instead of chained list.remove calls under
# a bare `except`) ensures every one of them is skipped even when some are absent.
NON_EPOCH_WEIGHT_FILES = {'last_weights.hdf5', 'last_weights_2.hdf5', 'last_weights_3.hdf5'}
dirs_ = [
    name
    for name in sorted(os.listdir(weights_folder_path))
    if name not in NON_EPOCH_WEIGHT_FILES
]


# Probe the model on a few samples to find out how many output heads it has.
pred = model.predict(X_test[0:3], batch_size=1, verbose=0)

# Keras returns a list of arrays for multi-output models and a single array otherwise.
# (The old check compared the number of heads with the number of test samples,
# X_test.shape[0]; these are unrelated quantities, so it has been removed.)
len_pred = len(pred) if isinstance(pred, list) else 1

print("No of Prediction Activations: ", len_pred)

# -1: single output, use the prediction as-is; 0: multi-output, start from the first head.
pred_index = 0 if len_pred > 1 else -1

print("Pred_Index: ", pred_index)



#### Controlling which prediction output(s) are evaluated
#
# control_sup selects how predictions are indexed in the evaluation loop below:
#   0 -> every output head: pred_index iterates over the heads and each is scored
#   1 -> one specific head: here pred_index is forced to 0 and only that head is scored
#   2 -> default: the prediction is scored exactly as model.predict returns it
control_sup = 2

if control_sup == 1:
    # Force a single, explicit output head.
    len_pred, pred_index = 1, 0
    print("No of Prediction Activations: ", len_pred)
    print("Pred_Index: ", pred_index)



In [None]:
if (training):
    def df_convert(df, ep_str, index_str):
        """Reshape a metrics table from eval_all_metrics into score row(s) tagged with their source.

        The table is transposed, its first row becomes the column header and is dropped,
        and the remaining row(s) get a fresh 0-based index. 'Epoch' (weight file name)
        and 'Index_Out' (output head index) columns are prepended.
        """
        df = df.transpose()
        df.columns = df.iloc[0]
        df = df[1:]
        df.reset_index(drop=True, inplace=True)
        df.insert(0, 'Index_Out', index_str)
        df.insert(0, 'Epoch', ep_str)
        return df

    # Collect per-(weights file, output head) frames and concatenate once at the end.
    # The list is created fresh here, so re-running the cell does not append to the
    # results of a previous run.
    eval_frames = []

    for epoch_f in dirs_:
        # Rebuilding, loading and predicting do not depend on the output head,
        # so do them once per weights file instead of once per head.
        model_clear_and_build()
        model.load_weights(weights_folder_path + epoch_f)
        pred_all_f = model.predict(X_test, verbose=0, batch_size=1)

        pbar = tqdm(range(0, len_pred))
        pbar.set_description("Epoch: " + epoch_f)

        for p in pbar:
            pred_sm_f = pred_all_f
            if control_sup == 0:
                # Score every output head in turn.
                pred_index = p
                pred_sm_f = pred_all_f[pred_index]
            elif control_sup == 1:
                # Score only the pre-selected head.
                pred_sm_f = pred_all_f[pred_index]

            df = eval_all_metrics(y_test, pred_sm_f)
            df = df.round(4)
            df = df_convert(df, epoch_f, pred_index)
            eval_frames.append(df)

        # Checkpoint: results accumulated so far, written after each weights file.
        df_combined = pd.concat(eval_frames, ignore_index=True)
        df_combined.to_csv(eval_results_path + "Epoch_" + epoch_f + "_eval_all_metrics.csv")

    if not eval_frames:
        raise RuntimeError("No epoch weight files found in " + weights_folder_path + "; nothing to evaluate.")

    df_combined.to_csv(eval_results_path + "all_eval.csv")

In [None]:
if training:
    df = pd.read_csv(eval_results_path + "all_eval.csv")

    # For each output head ("Index_Out"), keep the row (i.e. the weights file) with
    # the highest "sm_iou_score_per_image". Grouping is by "Index_Out" only.
    best_row_labels = df.groupby(["Index_Out"])["sm_iou_score_per_image"].idxmax()
    # 'Unnamed: 0' is the index column written by the earlier to_csv call.
    max_rows = df.loc[best_row_labels].drop(columns=['Unnamed: 0'])
    max_rows.to_csv(eval_results_path + "all_eval_max_filter.csv")
    show(max_rows)

In [None]:
# Pick the single row with the highest "sm_iou_score_per_image" across all output heads.
df = pd.read_csv(eval_results_path + "all_eval_max_filter.csv")
max_iou_index = df['sm_iou_score_per_image'].idxmax()
max_iou_row = df.loc[max_iou_index]
# Look columns up by name: positional Series[int] access (row[1], row[2]) is
# deprecated in pandas 2.x and silently breaks if the column order changes.
# str() guards against pandas parsing a purely numeric file name as a number.
best_epoch_str = str(max_iou_row['Epoch'])
max_iou_weight_file = weights_folder_path + best_epoch_str
print(max_iou_weight_file)
print(best_epoch_str)
pred_index = int(max_iou_row['Index_Out'])
print(pred_index)

In [None]:
#####################################################################################################################################################################

In [None]:
# Rebuild the network and load the checkpoint with the best IoU selected above.
# NOTE(review): this relies on the global `model` afterwards; confirm that
# model_clear_and_build() rebinds the global rather than a discarded local.
model_clear_and_build()
model.load_weights(max_iou_weight_file)

In [None]:
# Directories with the full-size test images and their masks; fill in before running.
image_full_test_directory = ""
mask_full_test_directory = ""

# Prepare the prediction output directories (create_pred_dirs is defined elsewhere).
create_pred_dirs()

In [None]:
# Sliding-window settings passed to pred_full: 256-px patches with a 128-px step
# (presumably 50% overlap between neighbouring patches — confirm in pred_full).
patch_img_size = 256
patch_step_size = 128

# Optional resize of full images inside pred_full; presumably [width, height] — TODO confirm.
resize_img = True
resize_wh = [1024, 1024]

In [None]:
# Run full-image prediction with test-time augmentation (tta=True) using the best
# weights loaded above. pred_full is defined elsewhere; it presumably writes its
# arrays under pred_dir_prefix — confirm.
pred_dir_prefix = "predictions/pred_tta/"
tta = True
pred_full(tta, control_sup, pred_index, pred_dir_prefix, model, image_full_test_directory, mask_full_test_directory, patch_img_size, patch_step_size, resize_img, resize_wh)

In [None]:
# Probability cut-off passed to eval_all_metrics (presumably used to binarise predictions).
threshold_pred = 0.5

In [None]:
# Load the full-image arrays produced by pred_full. Ground-truth masks come from
# y_tests_array.npy and model outputs from preds_array.npy (they were previously
# loaded the other way round, which swapped labels and predictions in every metric).
y_test = np.load("predictions/pred_tta/y_tests_array.npy")
pred = np.load("predictions/pred_tta/preds_array.npy")


print('Min: %.3f, Max: %.3f' % (y_test.min(), y_test.max()))
print(np.unique(y_test))

print('Min: %.3f, Max: %.3f' % (pred.min(), pred.max()))
print(np.unique(pred))

In [None]:
# Score the full-image predictions and save the table; the file prefix records
# whether this run came from training or from manual testing.
df = eval_all_metrics(y_test, pred, all = 1, threshold_pred = threshold_pred)

report_prefix = "all_metrics_" if training else "all_metrics_manual_testing_"
df.to_csv(eval_results_path + report_prefix + best_epoch_str + ".csv")

# Display a rounded, colour-highlighted copy of the scores.
df = df.round(4)
style_df = highlight_values_with_colors(df, 'Score')
style_df