In [None]:
# Mount Google Drive at /content/drive (Colab only).
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:

import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from glob import glob
import tensorflow as tf
import tensorflow.keras as keras
import keras.backend as K
import tensorflow.keras.layers as L
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, MaxPool2D, Add, Dropout, Concatenate, Conv2DTranspose, Dense, Reshape, Flatten, Softmax, Lambda, UpSampling2D, AveragePooling2D, Activation, BatchNormalization, GlobalAveragePooling2D, SeparableConv2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import MeanIoU
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.metrics import BinaryAccuracy, Precision, Recall
from sklearn.model_selection import train_test_split
from tensorflow.keras.applications import DenseNet121
!pip install tensorflow-wavelets
import tensorflow_wavelets.Layers.DWT as DWT
import os
import numpy as np
import cv2
from glob import glob
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from PIL import Image

from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.metrics import Recall, Precision, MeanIoU

from tqdm import tqdm

from tensorflow.keras import backend as K
import time
import zipfile
import shutil

In [None]:
# NOTE(review): `train_path` is not defined anywhere in this notebook; define it
# (e.g. in a configuration cell) before running this cell.
# Tissue images are filtered to PNG, as for the test split; masks keep every format
# `tf.io.decode_png` can read. This drops stray entries (hidden files, checkpoint
# folders) that would otherwise shift the sorted image/mask pairing.
tissue_train = sorted(
    name for name in os.listdir(os.path.join(train_path, "TissueImages"))
    if name.lower().endswith(".png")
)
mask_train = sorted(
    name for name in os.listdir(os.path.join(train_path, "GroundTruth"))
    if name.lower().endswith((".png", ".jpg", ".jpeg", ".gif"))
)
# Images and masks are paired purely by sorted position, so the counts must agree.
if len(tissue_train) != len(mask_train):
    raise ValueError(
        f"Found {len(tissue_train)} tissue images but {len(mask_train)} masks under {train_path}"
    )

In [None]:
# NOTE(review): `test_path` is not defined anywhere in this notebook; define it
# (e.g. in a configuration cell) before running this cell.
# Masks are filtered too (to formats `tf.io.decode_png` can read) so stray entries do
# not shift the sorted image/mask pairing.
tissue_test = sorted(
    name for name in os.listdir(os.path.join(test_path, "TissueImages"))
    if name.lower().endswith(".png")
)
mask_test = sorted(
    name for name in os.listdir(os.path.join(test_path, "GroundTruth"))
    if name.lower().endswith((".png", ".jpg", ".jpeg", ".gif"))
)
# Images and masks are paired purely by sorted position, so the counts must agree.
if len(tissue_test) != len(mask_test):
    raise ValueError(
        f"Found {len(tissue_test)} tissue images but {len(mask_test)} masks under {test_path}"
    )

In [None]:
def apply_otsu_thresholding(image_path):
    """Return the Otsu threshold of the image stored at ``image_path``.

    The image is loaded as single-channel grayscale and binarised with Otsu's
    method; only the computed threshold value is returned (the binarised image is
    discarded).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    threshold, _ = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return threshold

In [None]:
def create_train(tissue_train, mask_train, batch, image_size=(512, 512)):
    """Build the (infinitely repeating) training ``tf.data`` pipeline.

    Args:
        tissue_train: file names of the tissue images inside ``<train_path>/TissueImages``.
        mask_train: file names of the matching masks inside ``<train_path>/GroundTruth``,
            in the same order as ``tissue_train``.
        batch: batch size.
        image_size: ``(height, width)`` that images and masks are resized to.

    Returns:
        A repeated, prefetched dataset of ``(image, mask)`` batches. Images are float32
        RGB scaled to [0, 1]; masks are float32 single-channel divided by 255.
    """
    # The trailing "" guarantees exactly one separator whether or not `train_path`
    # ends with "/"; tf.strings.join below inserts no separator of its own.
    tissue_dir = os.path.join(train_path, "TissueImages", "")
    mask_dir = os.path.join(train_path, "GroundTruth", "")

    def parse_images(tissue_name, mask_name):
        tissue_image = tf.io.read_file(tf.strings.join([tissue_dir, tissue_name]))
        mask_image = tf.io.read_file(tf.strings.join([mask_dir, mask_name]))

        tissue_image = tf.image.decode_png(tissue_image, channels=3)
        mask_image = tf.image.decode_png(mask_image, channels=1)

        tissue_image = tf.image.resize(tissue_image, image_size)
        # Nearest-neighbour keeps the mask's original pixel values; bilinear
        # interpolation would blend them into fractional values along borders.
        mask_image = tf.image.resize(mask_image, image_size, method="nearest")

        tissue_image = tf.cast(tissue_image, tf.float32) / 255.0
        mask_image = tf.cast(mask_image, tf.float32) / 255.0
        return tissue_image, mask_image

    dataset = tf.data.Dataset.from_tensor_slices((tissue_train, mask_train))
    # Parallel map keeps element order (tf.data maps are deterministic by default).
    dataset = dataset.map(parse_images, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch)
    dataset = dataset.repeat()
    return dataset.prefetch(tf.data.AUTOTUNE)

In [None]:
def create_train_gaussian(tissue_train, mask_train, batch, noise_stddev=0.01, image_size=(512, 512)):
    """Build the repeating training pipeline with additive Gaussian noise on the images.

    Args:
        tissue_train: file names of the tissue images inside ``<train_path>/TissueImages``.
        mask_train: file names of the matching masks inside ``<train_path>/GroundTruth``,
            in the same order as ``tissue_train``.
        batch: batch size.
        noise_stddev: standard deviation of the zero-mean noise added to each image
            (after scaling to [0, 1]); the noisy image is clipped back to [0, 1].
        image_size: ``(height, width)`` that images and masks are resized to.

    Returns:
        A repeated, prefetched dataset of ``(image, mask)`` batches. Masks are not
        noised; they are float32 single-channel divided by 255.
    """
    # The trailing "" guarantees exactly one separator whether or not `train_path`
    # ends with "/"; tf.strings.join below inserts no separator of its own.
    tissue_dir = os.path.join(train_path, "TissueImages", "")
    mask_dir = os.path.join(train_path, "GroundTruth", "")

    def parse_images(tissue_name, mask_name):
        tissue_image = tf.io.read_file(tf.strings.join([tissue_dir, tissue_name]))
        mask_image = tf.io.read_file(tf.strings.join([mask_dir, mask_name]))

        tissue_image = tf.image.decode_png(tissue_image, channels=3)
        mask_image = tf.image.decode_png(mask_image, channels=1)

        tissue_image = tf.image.resize(tissue_image, image_size)
        # Nearest-neighbour keeps the mask's original pixel values; bilinear
        # interpolation would blend them into fractional values along borders.
        mask_image = tf.image.resize(mask_image, image_size, method="nearest")

        tissue_image = tf.cast(tissue_image, tf.float32) / 255.0
        mask_image = tf.cast(mask_image, tf.float32) / 255.0

        noise = tf.random.normal(shape=tf.shape(tissue_image), mean=0.0, stddev=noise_stddev)
        tissue_image = tf.clip_by_value(tissue_image + noise, 0.0, 1.0)
        return tissue_image, mask_image

    dataset = tf.data.Dataset.from_tensor_slices((tissue_train, mask_train))
    dataset = dataset.map(parse_images, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch)
    dataset = dataset.repeat()
    return dataset.prefetch(tf.data.AUTOTUNE)


In [None]:
def create_test(tissue_test, mask_test, batch, image_size=(512, 512)):
    """Build the (infinitely repeating) evaluation ``tf.data`` pipeline.

    Args:
        tissue_test: file names of the tissue images inside ``<test_path>/TissueImages``.
        mask_test: file names of the matching masks inside ``<test_path>/GroundTruth``,
            in the same order as ``tissue_test``.
        batch: batch size.
        image_size: ``(height, width)`` that images and masks are resized to.

    Returns:
        A repeated, prefetched dataset of ``(image, mask)`` batches. Images are float32
        RGB scaled to [0, 1]; masks are float32 single-channel divided by 255.
    """
    # The trailing "" guarantees exactly one separator whether or not `test_path`
    # ends with "/"; tf.strings.join below inserts no separator of its own.
    tissue_dir = os.path.join(test_path, "TissueImages", "")
    mask_dir = os.path.join(test_path, "GroundTruth", "")

    def parse_images(tissue_name, mask_name):
        tissue_image = tf.io.read_file(tf.strings.join([tissue_dir, tissue_name]))
        mask_image = tf.io.read_file(tf.strings.join([mask_dir, mask_name]))

        tissue_image = tf.image.decode_png(tissue_image, channels=3)
        mask_image = tf.image.decode_png(mask_image, channels=1)

        tissue_image = tf.image.resize(tissue_image, image_size)
        # Nearest-neighbour keeps the mask's original pixel values; bilinear
        # interpolation would blend them into fractional values along borders.
        mask_image = tf.image.resize(mask_image, image_size, method="nearest")

        tissue_image = tf.cast(tissue_image, tf.float32) / 255.0
        mask_image = tf.cast(mask_image, tf.float32) / 255.0
        return tissue_image, mask_image

    dataset = tf.data.Dataset.from_tensor_slices((tissue_test, mask_test))
    dataset = dataset.map(parse_images, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch)
    dataset = dataset.repeat()
    return dataset.prefetch(tf.data.AUTOTUNE)

In [None]:
# Batch size and split sizes; the training loop derives its steps per epoch from these.
BATCH_SIZE = 2
TOTAL_TRAIN_SAMPLES = len(tissue_train)
TOTAL_TEST_SAMPLES = len(tissue_test)

train_dataset = create_train(tissue_train, mask_train, batch=BATCH_SIZE)
test_dataset = create_test(tissue_test, mask_test, batch=BATCH_SIZE)

**MODEL**

*Graph Layer*

In [None]:
from tensorflow.keras.utils import register_keras_serializable

In [None]:
@register_keras_serializable(package="MyLayers")
class SqueezeExcitation(tf.keras.layers.Layer):
    """Squeeze-and-Excitation channel gating.

    Global-average-pools the input to one value per channel, passes it through a
    1x1 bottleneck conv (ReLU) and a 1x1 expansion conv (sigmoid), and rescales each
    input channel by the resulting gate.

    Args:
        reduction_ratio: bottleneck reduction factor; the squeeze conv has
            ``max(1, channels // reduction_ratio)`` filters.
    """

    def __init__(self, reduction_ratio=16, **kwargs):
        super(SqueezeExcitation, self).__init__(**kwargs)
        self.reduction_ratio = reduction_ratio
        self.global_pooling = tf.keras.layers.GlobalAveragePooling2D()
        # Created in build() once the channel count is known.
        self.squeeze_conv = None
        self.excitation_conv = None

    def build(self, input_shape):
        channels = input_shape[-1]
        self.squeeze_conv = tf.keras.layers.Conv2D(
            # Guard against a zero-filter conv when channels < reduction_ratio.
            filters=max(1, channels // self.reduction_ratio),
            kernel_size=(1, 1),
            activation='relu',
            kernel_initializer='he_normal',
            use_bias=False
        )
        self.excitation_conv = tf.keras.layers.Conv2D(
            filters=channels,
            kernel_size=(1, 1),
            activation='sigmoid',
            kernel_initializer='he_normal',
            use_bias=False
        )
        super(SqueezeExcitation, self).build(input_shape)

    def call(self, inputs):
        # (B, C) -> (B, 1, 1, C) so the gate broadcasts over the spatial axes. A plain
        # tf.reshape avoids constructing a new Reshape layer on every call.
        gate = tf.reshape(self.global_pooling(inputs), (-1, 1, 1, inputs.shape[-1]))
        gate = self.squeeze_conv(gate)
        gate = self.excitation_conv(gate)
        return inputs * gate

    def get_config(self):
        config = super(SqueezeExcitation, self).get_config()
        config.update({'reduction_ratio': self.reduction_ratio})
        return config


In [None]:
class WeightLayer(tf.keras.layers.Layer):
    """Self-gating with a learned 1x1 convolution: returns ``inputs * conv1x1(inputs)``.

    The conv keeps the input's channel count, so the product is element-wise.
    """

    def __init__(self, **kwargs):
        # Forward kwargs (name, dtype, ...) to the base Layer instead of dropping them.
        super(WeightLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.weight_layer = keras.layers.Conv2D(
            filters=input_shape[-1],
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="same"
        )
        super(WeightLayer, self).build(input_shape)

    def call(self, inputs):
        gate = self.weight_layer(inputs)
        return tf.multiply(inputs, gate)

In [None]:
class NewGraphLayer:
    """Builds an ``(B, n, n)`` affinity map from an ``(B, n, n, C)`` feature map.

    Steps (see ``__call__``):
      1. zero out channels whose spatial variance is not above the per-sample mean
         channel variance (``prune_channel_by_variance``);
      2. treat every spatial position as a node with a C-dim feature vector
         (``custom_flatten`` -> ``(B, n*n, C)``);
      3. score each node by its dot product with the sum of all node features;
      4. reshape the scores back to ``(B, n, n)``, divide by the per-sample maximum
         and multiply by 3.

    This is a plain Python callable (not a Keras layer) and holds no weights. The
    intermediate tensors of the last call are stored on the instance
    (``input_data``, ``pruned_data``, ``node_features``, ``summed_vectors``).
    """

    def __init__(self, n):
        self.n = n

    # The selection matrices below are kept only for backward compatibility. The
    # flatten/unflatten methods use the equivalent row-major reshapes instead, which
    # avoids materialising tensors with n**4 elements (~1 GiB of float32 for n=128).
    @property
    def mask_matrix(self):
        return tf.tile(tf.eye(self.n), [1, self.n])

    @property
    def unmask_matrix(self):
        n = self.n
        return tf.constant([[int(j // n == i) for j in range(n ** 2)] for i in range(n)], dtype=tf.float32)

    @property
    def unflatten_mat(self):
        n = self.n
        outer = tf.eye(n) * tf.expand_dims(tf.expand_dims(tf.eye(n), axis=-1), axis=-1)
        return tf.transpose(tf.reshape(outer, (n, n * n, n)), perm=[2, 1, 0])

    def custom_flatten(self, A):
        """Flatten ``(B, n, n, C)`` -> ``(B, n*n, C)``; node ``k`` is position ``(k // n, k % n)``."""
        channels = A.shape[-1] if A.shape[-1] is not None else tf.shape(A)[-1]
        return tf.reshape(A, (-1, self.n * self.n, channels))

    def custom_unflatten(self, A):
        """Inverse of the node ordering above: ``(B, n*n)`` -> ``(B, n, n)``."""
        return tf.reshape(A, (-1, self.n, self.n))

    def sum_channels(self, flattened_nodes):
        """Sum node features over all nodes: ``(B, n*n, C)`` -> ``(B, 1, C)``."""
        return tf.reduce_sum(flattened_nodes, axis=1, keepdims=True)

    def compute_dot_products(self):
        """Dot product of every node with the summed node vector: ``(B, n*n)``.

        Reads ``self.node_features`` ``(B, n*n, C)`` and ``self.summed_vectors`` ``(B, 1, C)``.
        """
        return tf.reduce_sum(tf.multiply(self.node_features, self.summed_vectors), axis=-1)

    def prune_channel_by_variance(self, feature_map):
        """Zero every channel whose spatial variance is not above the sample's mean channel variance.

        feature_map: ``(B, N, N, C)``; returns a tensor of the same shape.
        """
        # Per-sample, per-channel variance over the spatial axes: (B, 1, 1, C).
        variance_per_channel = tf.math.reduce_variance(feature_map, axis=(1, 2), keepdims=True)
        # Mean of those variances for each sample: (B, 1, 1, 1).
        mean_variance = tf.reduce_mean(variance_per_channel, axis=(1, 2, 3), keepdims=True)
        pruning_mask = tf.where(variance_per_channel > mean_variance, 1.0, 0.0)
        return feature_map * pruning_mask

    def create_graph_map(self, dot_products):
        """Reshape node scores to ``(B, n, n)``, normalise by the per-sample max and scale by 3."""
        graph_map = tf.cast(self.custom_unflatten(dot_products), tf.float32)
        max_values = tf.reduce_max(graph_map, axis=(1, 2), keepdims=True)
        # divide_no_nan: an all-zero map (e.g. no channel survives pruning) yields 0
        # instead of 0/0 = NaN.
        return 3. * tf.math.divide_no_nan(graph_map, max_values)

    def __call__(self, input_data):
        self.input_data = input_data
        self.pruned_data = self.prune_channel_by_variance(input_data)
        self.node_features = self.custom_flatten(self.pruned_data)
        self.summed_vectors = self.sum_channels(self.node_features)
        dot_products = self.compute_dot_products()
        return self.create_graph_map(dot_products)

**Attention layer**

In [None]:
@register_keras_serializable(package="MyLayers")
class GraphAttentionLayer(tf.keras.layers.Layer):
    """Graph-based attention gate for a U-Net skip connection.

    ``call(input_x, input_g)`` takes
      * ``input_x``: skip features ``(B, H, W, C)`` with H and W even and
        ``H / 2 == W / 2 == n`` (required by ``NewGraphLayer``);
      * ``input_g``: gating features at ``(H / 2, W / 2)`` spatial resolution (they are
        concatenated channel-wise with the downsampled ``input_x``).

    Both inputs are projected to C channels by 1x1 convs, concatenated, self-gated by
    a ``WeightLayer`` and ReLU'd. ``NewGraphLayer`` turns the result into an ``(n, n)``
    map, which goes through a sigmoid and a 2x nearest-neighbour upsampling.

    Returns:
        ``(BatchNorm(conv1x1(input_x * map)), map)`` where ``map`` is ``(B, H, W, 1)``.

    Args:
        n: side length of the square graph (must equal H / 2).
    """

    def __init__(self, n, **kwargs):
        super(GraphAttentionLayer, self).__init__(**kwargs)
        self.n = n

    def build(self, input_shape):
        channels = input_shape[-1]
        self.phi_g = Conv2D(filters=channels, kernel_size=(1, 1), strides=(1, 1), padding="same")
        self.theta_x = Conv2D(filters=channels, kernel_size=(1, 1), strides=(1, 1), padding="same")
        self.concatenate = Concatenate(axis=-1)
        # NOTE(review): psi is created but never used in call().
        self.psi = Conv2D(filters=1, kernel_size=(1, 1), padding="same")
        self.result = Conv2D(filters=channels, kernel_size=(1, 1), padding="same")
        self.multiply = tf.keras.layers.Multiply()
        self.bn = tf.keras.layers.BatchNormalization()
        self.up = UpSampling2D(size=(2, 2))
        self.wl = WeightLayer()
        super(GraphAttentionLayer, self).build(input_shape)

    def call(self, input_x, input_g):
        # Fold each 2x2 spatial block into channels: (B, H, W, C) -> (B, H/2, W/2, 4C).
        # A plain tf.reshape to that shape would regroup pixels in memory order and
        # scramble the layout; space_to_depth keeps node (i, j) aligned with
        # input_g[:, i, j] and with the 2x2 patch the upsampled map multiplies below.
        x = tf.nn.space_to_depth(input_x, block_size=2)
        x = self.theta_x(x)
        g = self.phi_g(input_g)

        gated = self.wl(self.concatenate([x, g]))
        gated = tf.keras.activations.relu(gated)

        graph_map = NewGraphLayer(self.n)(gated)                               # (B, n, n)
        attention = tf.keras.activations.sigmoid(tf.expand_dims(graph_map, axis=-1))
        attention_up = self.up(attention)                                       # (B, H, W, 1)

        y = self.multiply([input_x, attention_up])
        y = self.bn(self.result(y))
        return y, attention_up

    def get_config(self):
        # `n` is a required constructor argument, so it must be serialised for
        # from_config / load_model to rebuild the layer.
        config = super(GraphAttentionLayer, self).get_config()
        config.update({"n": self.n})
        return config

*Instance and group normalization*

In [None]:
class InstanceNormalization(tf.keras.layers.Layer):
    """Normalise each sample and channel over axes 1 and 2 (height/width for NHWC
    input), then apply a learned per-channel scale and shift.

    Args:
        epsilon: added to the variance for numerical stability.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (``name``, ``dtype``, ...).
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        # One scale and one shift parameter per channel.
        self.scale = self.add_weight(
            name='scale',
            shape=(input_shape[-1],),
            initializer='ones',
            trainable=True
        )
        self.shift = self.add_weight(
            name='shift',
            shape=(input_shape[-1],),
            initializer='zeros',
            trainable=True
        )
        super(InstanceNormalization, self).build(input_shape)

    def call(self, inputs):
        mean = tf.reduce_mean(inputs, axis=[1, 2], keepdims=True)
        variance = tf.reduce_mean(tf.square(inputs - mean), axis=[1, 2], keepdims=True)
        normalized = (inputs - mean) / tf.sqrt(variance + self.epsilon)
        return self.scale * normalized + self.shift

    def get_config(self):
        config = super(InstanceNormalization, self).get_config()
        config.update({'epsilon': self.epsilon})
        return config

In [None]:
@register_keras_serializable(package="MyLayers")
class GroupNormalization(tf.keras.layers.Layer):
    """Group normalisation for 4-D ``(B, H, W, C)`` inputs.

    Channels are split into ``groups`` contiguous groups. Each group of each sample is
    normalised over height, width and its channels, then a learned per-channel scale
    and shift are applied. ``groups=1`` normalises all channels together.

    Args:
        groups: number of channel groups; must divide the channel count.
        epsilon: added to the variance for numerical stability.

    Raises:
        ValueError: at build time, if the channel count is not divisible by ``groups``.
    """

    def __init__(self, groups=1, epsilon=1e-5, **kwargs):
        super(GroupNormalization, self).__init__(**kwargs)
        self.groups = groups
        self.epsilon = epsilon

    def build(self, input_shape):
        channels = input_shape[-1]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if channels % self.groups != 0:
            raise ValueError(
                f"GroupNormalization: {channels} channels not divisible by groups={self.groups}"
            )
        # One scale and one shift parameter per channel.
        self.scale = self.add_weight(
            name='scale',
            shape=(channels,),
            initializer='ones',
            trainable=True
        )
        self.shift = self.add_weight(
            name='shift',
            shape=(channels,),
            initializer='zeros',
            trainable=True
        )
        super(GroupNormalization, self).build(input_shape)

    def call(self, inputs):
        dynamic_shape = tf.shape(inputs)
        batch_size, height, width = dynamic_shape[0], dynamic_shape[1], dynamic_shape[2]
        # Static channel count (validated in build) keeps the output's last dim known.
        channels = inputs.shape[-1]

        # [batch, height, width, groups, channels_per_group]
        grouped = tf.reshape(inputs, [batch_size, height, width, self.groups, channels // self.groups])

        mean = tf.reduce_mean(grouped, axis=[1, 2, 4], keepdims=True)
        variance = tf.reduce_mean(tf.square(grouped - mean), axis=[1, 2, 4], keepdims=True)
        normalized = (grouped - mean) / tf.sqrt(variance + self.epsilon)

        normalized = tf.reshape(normalized, [batch_size, height, width, channels])
        return self.scale * normalized + self.shift

    def get_config(self):
        # Previously nested inside call() after `return`, so it was never used and
        # groups/epsilon were lost on serialisation.
        config = super(GroupNormalization, self).get_config()
        config.update({
            'groups': self.groups,
            'epsilon': self.epsilon,
        })
        return config


*Weighted global average pooling (GAP)*

In [None]:
class WeightedGlobalAveragePooling2D(tf.keras.layers.Layer):
    """Spatial pooling with one trainable weight per channel (initialised to ones).

    Output ``[b, c] = sum_{h,w} inputs[b, h, w, c] * w[c] / sum_c' w[c']``, shape
    ``(B, num_channels)``.

    NOTE(review): despite the name, this is a spatial *sum* divided by the sum of the
    channel weights, so at initialisation it equals ``sum_{h,w} x / C`` rather than
    plain global average pooling. Confirm whether a mean over the spatial axes was
    intended.

    Args:
        num_channels: number of input channels (length of the weight vector).
    """

    def __init__(self, num_channels, **kwargs):
        super(WeightedGlobalAveragePooling2D, self).__init__(**kwargs)
        self.num_channels = num_channels
        self.channel_weights = self.add_weight(name='channel_weights',
                                               shape=(num_channels,),
                                               initializer='ones',
                                               trainable=True)

    def call(self, inputs):
        weighted_sum = tf.reduce_sum(inputs * self.channel_weights, axis=[1, 2])
        return weighted_sum / tf.reduce_sum(self.channel_weights)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.num_channels)

    def get_config(self):
        # num_channels is a required constructor argument; without it from_config fails.
        config = super(WeightedGlobalAveragePooling2D, self).get_config()
        config.update({'num_channels': self.num_channels})
        return config

In [None]:
def WGCAM(x):
    """Wavelet-guided spatial and channel attention applied to ``x`` (B, H, W, C).

    Spatial branch: a DWT of ``x`` is upsampled by a stride-2 transposed conv (so it
    must come out at half resolution to match ``x``). It is added to a 1x1 separable
    conv of ``x``, and a sigmoid separable conv then produces a C-channel gate.
    Channel branch: weighted global pooling followed by a ``C -> C // 4 -> C`` MLP
    with a sigmoid gate.

    Returns ``x * spatial_gate * channel_gate`` with the same shape as ``x``.
    """
    num_filters = x.shape[-1]

    wav = DWT.DWT(concat=0)(x)
    wav = Conv2DTranspose(num_filters * 4, (2, 2), strides=2, padding="same")(wav)
    wav = SeparableConv2D(num_filters, (1, 1), padding="same")(wav)
    x_sam = SeparableConv2D(num_filters, (1, 1), padding="same")(x)
    x_sam = Add()([wav, x_sam])
    x_sam = SeparableConv2D(num_filters, (1, 1), padding="same", activation='sigmoid')(x_sam)

    x_cam = WeightedGlobalAveragePooling2D(num_filters)(x)
    # Dense needs a positive integer unit count; `num_filters / 4` was a float and
    # rounded to 0 units for fewer than 4 channels.
    x_cam = Dense(max(1, num_filters // 4), activation='relu')(x_cam)
    x_cam = Dense(num_filters, activation='sigmoid')(x_cam)
    x_cam = Reshape((1, 1, num_filters))(x_cam)

    x = tf.keras.layers.Multiply()([x, x_sam])
    return tf.keras.layers.Multiply()([x, x_cam])

In [None]:
def CombinedUpsampleLayer(inputs):
    """Upsample ``inputs`` 2x with two parallel branches and concatenate them.

    Branch 1: Lanczos-5 interpolation followed by a 1x1 conv.
    Branch 2: a learned stride-2 transposed conv.
    Both keep the input channel count C, so the result has 2*C channels.
    """
    channels = inputs.shape[-1]

    interpolated = UpSampling2D(size=(2, 2), interpolation="lanczos5")(inputs)
    interpolated = Conv2D(channels, 1, padding="same")(interpolated)

    learned = Conv2DTranspose(channels, (2, 2), strides=2, padding="same")(inputs)

    return Concatenate()([interpolated, learned])

In [None]:
# @title FunctionalModel: DenseNet121 encoder + graph-attention decoder
class FunctionalModel:
    """Builds the DenseNet121-encoder / graph-attention-decoder segmentation model.

    ``get()`` returns a Keras ``Model`` that takes a (512, 512, 3) image and returns 12
    outputs, in this order:
      * the sigmoid mask;
      * the encoder features ``pool2_relu`` and ``pool3_relu``;
      * the nine upsampled ``GraphAttentionLayer`` maps (map11 .. map33).

    NOTE(review): the backbone uses ImageNet weights, while the input pipeline in this
    notebook only scales pixels to [0, 1]. Keras DenseNet ImageNet weights are normally
    paired with ``tf.keras.applications.densenet.preprocess_input``. Confirm that
    skipping it is intentional.
    """

    def __init__(self):
        self.pre_trained_backbone = tf.keras.applications.DenseNet121(
            include_top=False,
            weights='imagenet',
            input_shape=(512, 512, 3))

    def conv_block(self, inputs, num_filters):
        """Multi-kernel conv block with ``num_filters`` output channels and unchanged spatial size.

        Structure:
          * parallel 5x5 and 3x3 branches (``num_filters // 2`` each), concatenated
            and fused by a 1x1 conv;
          * a parallel 1x1 branch;
          * both results concatenated, then a 3x3 conv.
        Every conv except the 1x1 fusion is followed by GroupNormalization and ReLU.
        """
        x1 = Conv2D(num_filters // 2, 5, padding="same")(inputs)
        x1 = GroupNormalization()(x1)
        x1 = Activation("relu")(x1)

        x2 = Conv2D(num_filters // 2, 3, padding="same")(inputs)
        x2 = GroupNormalization()(x2)
        x2 = Activation("relu")(x2)

        x2 = Concatenate()([x1, x2])
        x2 = Conv2D(num_filters, 1, padding="same")(x2)

        x3 = Conv2D(num_filters, 1, padding="same")(inputs)
        x3 = GroupNormalization()(x3)
        x3 = Activation("relu")(x3)

        x3 = Concatenate()([x2, x3])

        x = Conv2D(num_filters, 3, padding="same")(x3)
        x = GroupNormalization()(x)
        x = Activation("relu")(x)

        return x

    def bridge_layer(self, x, num_filters):
        """Two (4x4 conv -> GroupNormalization(groups=4) -> ReLU) stages.

        NOTE(review): not used by ``get()``.
        """
        x = Conv2D(num_filters, 4, padding="same")(x)
        x = GroupNormalization(groups=4)(x)
        x = Activation("relu")(x)

        x = Conv2D(num_filters, 4, padding="same")(x)
        x = GroupNormalization(groups=4)(x)
        x = Activation("relu")(x)
        return x

    def decoder_block(self, inputs, skip_features, num_filters):
        """Upsample ``inputs`` 2x, concatenate ReLU'd ``skip_features``, then apply ``conv_block``."""
        x = CombinedUpsampleLayer(inputs)
        skip_features = Activation("relu")(skip_features)
        x = Concatenate()([x, skip_features])
        x = self.conv_block(x, num_filters)
        return x

    def get(self):
        """Assemble and return the Keras ``Model`` (outputs listed in the class docstring)."""
        inputs = self.pre_trained_backbone.input

        # Encoder taps (trailing comments: spatial size for a 512x512 input).
        s1 = inputs
        s2 = self.pre_trained_backbone.get_layer('conv1_relu').output   # 256
        s3 = self.pre_trained_backbone.get_layer('pool2_relu').output   # 128
        s4 = self.pre_trained_backbone.get_layer('pool3_relu').output   # 64
        b = self.pre_trained_backbone.get_layer("pool4_relu").output    # 32

        # Level 1: SE-gated s4, gated by the bottleneck b; graph size n=32.
        x1 = SqueezeExcitation()(s4)
        G11, map11 = GraphAttentionLayer(n=32, name="Graph11_64")(x1, b)
        G12, map12 = GraphAttentionLayer(n=32, name="Graph12_64")(x1, b)
        G13, map13 = GraphAttentionLayer(n=32, name="Graph13_64")(x1, b)
        G1 = Concatenate(axis=-1)([G11, G12, G13])

        # b is upsampled 2x inside CombinedUpsampleLayer to match G1.
        d1 = self.decoder_block(b, G1, 512)                                # 64

        # Level 2: SE-gated s3, gated by d1; graph size n=64.
        x2 = SqueezeExcitation()(s3)
        G21, map21 = GraphAttentionLayer(n=64, name="Graph21_64")(x2, d1)
        G22, map22 = GraphAttentionLayer(n=64, name="Graph22_64")(x2, d1)
        G23, map23 = GraphAttentionLayer(n=64, name="Graph23_64")(x2, d1)
        G2 = Concatenate(axis=-1)([G21, G22, G23])

        d2 = self.decoder_block(d1, G2, 256)                               # 128

        # Level 3: SE-gated s2, gated by d2; graph size n=128.
        x3 = SqueezeExcitation()(s2)
        G31, map31 = GraphAttentionLayer(n=128, name="Graph31_256")(x3, d2)
        G32, map32 = GraphAttentionLayer(n=128, name="Graph32_256")(x3, d2)
        G33, map33 = GraphAttentionLayer(n=128, name="Graph33_256")(x3, d2)
        G3 = Concatenate(axis=-1)([G31, G32, G33])

        d3 = self.decoder_block(d2, G3, 128)                               # 256
        d4 = self.decoder_block(d3, s1, 64)                                # 512

        outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)
        model = Model(inputs=inputs, outputs=[outputs, s3, s4, map11, map12, map13, map21, map22, map23, map31, map32, map33])
        return model


In [None]:
smooth = 1e-3  # additive smoothing used by DiceCoeff


class DiceCoeff:
    """Soft Dice coefficient ``(2 * sum(t * p) + s) / (sum(t) + sum(p) + s)``.

    The sums run over every element of the batch, so the whole batch yields one
    pooled score (not a per-sample average).
    """

    def __init__(self):
        pass

    def dice_coef(self, y_true, y_pred):
        # Sums over all elements make flattening unnecessary (and avoid building a new
        # Flatten layer on every call).
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        intersection = tf.reduce_sum(y_true * y_pred)
        return (2. * intersection + smooth) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + smooth)

    def __call__(self, y_true, y_pred):
        return self.dice_coef(y_true, y_pred)

class Precision(tf.keras.metrics.Metric):
    """Pixel-wise precision ``TP / (TP + FP)``, accumulated over ``update_state`` calls.

    A prediction counts as positive when ``y_pred > threshold``; a target counts as
    positive when it is non-zero. ``sample_weight`` is accepted for API compatibility
    but ignored.

    NOTE: this class shadows ``tf.keras.metrics.Precision`` imported at the top of the
    notebook.

    Args:
        name: metric name.
        threshold: probability cut-off for predictions. Casting sigmoid outputs
            straight to bool would mark every non-zero probability as positive,
            reducing precision to the positive-pixel rate.
    """

    def __init__(self, name='precision', threshold=0.5, **kwargs):
        super(Precision, self).__init__(name=name, **kwargs)
        self.threshold = threshold
        self.true_positives = self.add_weight(name='tp', initializer='zeros')
        self.predicted_positives = self.add_weight(name='pp', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, tf.bool)
        y_pred = y_pred > self.threshold

        true_positives = tf.reduce_sum(tf.cast(tf.logical_and(y_true, y_pred), tf.float32))
        predicted_positives = tf.reduce_sum(tf.cast(y_pred, tf.float32))

        self.true_positives.assign_add(true_positives)
        self.predicted_positives.assign_add(predicted_positives)

    def result(self):
        return self.true_positives / (self.predicted_positives + K.epsilon())

    def reset_state(self):
        self.true_positives.assign(0.)
        self.predicted_positives.assign(0.)

    def reset_states(self):
        # Legacy name kept for callers; newer Keras calls reset_state() directly.
        self.reset_state()


class Recall(tf.keras.metrics.Metric):
    """Pixel-wise recall ``TP / (TP + FN)``, accumulated over ``update_state`` calls.

    A prediction counts as positive when ``y_pred > threshold``; a target counts as
    positive when it is non-zero. ``sample_weight`` is accepted for API compatibility
    but ignored.

    NOTE: this class shadows ``tf.keras.metrics.Recall`` imported at the top of the
    notebook.

    Args:
        name: metric name.
        threshold: probability cut-off for predictions. Casting sigmoid outputs
            straight to bool would mark every non-zero probability as positive,
            making recall trivially ~1.
    """

    def __init__(self, name='recall', threshold=0.5, **kwargs):
        super(Recall, self).__init__(name=name, **kwargs)
        self.threshold = threshold
        self.true_positives = self.add_weight(name='tp', initializer='zeros')
        self.possible_positives = self.add_weight(name='pp', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, tf.bool)
        y_pred = y_pred > self.threshold

        true_positives = tf.reduce_sum(tf.cast(tf.logical_and(y_true, y_pred), tf.float32))
        possible_positives = tf.reduce_sum(tf.cast(y_true, tf.float32))

        self.true_positives.assign_add(true_positives)
        self.possible_positives.assign_add(possible_positives)

    def result(self):
        return self.true_positives / (self.possible_positives + K.epsilon())

    def reset_state(self):
        self.true_positives.assign(0.)
        self.possible_positives.assign(0.)

    def reset_states(self):
        # Legacy name kept for callers; newer Keras calls reset_state() directly.
        self.reset_state()


class F1Score(tf.keras.metrics.Metric):
    """Harmonic mean of this notebook's ``Precision`` and ``Recall`` metrics.

    Both sub-metrics accumulate across ``update_state`` calls on flattened inputs.
    """

    def __init__(self, name='f1_score', **kwargs):
        super(F1Score, self).__init__(name=name, **kwargs)
        self.precision = Precision()
        self.recall = Recall()

    def update_state(self, y_true, y_pred, sample_weight=None):
        flat_true = tf.reshape(y_true, [-1])
        flat_pred = tf.reshape(y_pred, [-1])
        for sub_metric in (self.precision, self.recall):
            sub_metric.update_state(flat_true, flat_pred, sample_weight)

    def result(self):
        p = self.precision.result()
        r = self.recall.result()
        return 2 * ((p * r) / (p + r + K.epsilon()))

    def reset_states(self):
        for sub_metric in (self.precision, self.recall):
            sub_metric.reset_states()


class VOE:
    """Soft volumetric overlap error ``1 - (sum(t*p) + s) / (union + s)`` over the whole batch.

    ``union = sum(t) + sum(p) - sum(t*p)``.

    Args:
        smooth: additive smoothing. It keeps the result finite when both masks are
            empty (union == 0), where the error then evaluates to 0.
    """

    def __init__(self, smooth=1e-3):
        self.smooth = smooth

    def voe_metric(self, y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        intersection = tf.reduce_sum(y_true * y_pred)
        union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
        return 1.0 - (intersection + self.smooth) / (union + self.smooth)

    def __call__(self, y_true, y_pred):
        return self.voe_metric(y_true, y_pred)

!pip install keras_unet_collection



In [None]:
!pip install hausdorff



In [None]:
from keras_unet_collection import losses
from hausdorff import hausdorff_distance

class DiceLoss:
    """Soft Dice loss: ``1 - DiceCoeff()(y_true, y_pred)``."""

    def __init__(self):
        pass

    def __call__(self, y_true, y_pred):
        return self.dice_loss(y_true, y_pred)

    def dice_loss(self, y_true, y_pred):
        coefficient = DiceCoeff()(y_true, y_pred)
        return 1.0 - coefficient

class HybridLoss:
    """Focal Tversky loss (keras_unet_collection, alpha=0.3, gamma=4/3) plus soft Dice loss."""

    def __init__(self):
        pass

    def __call__(self, y_true, y_pred):
        return self.hybrid_loss(y_true, y_pred)

    def hybrid_loss(self, y_true, y_pred):
        focal_term = losses.focal_tversky(y_true, y_pred, alpha=0.3, gamma=4/3)
        dice_term = DiceLoss()(y_true, y_pred)
        return focal_term + dice_term

class CustomLoss:
    """Soft Dice loss plus volumetric overlap error (``VOE``)."""

    def __init__(self):
        pass

    def __call__(self, y_true, y_pred):
        return self.custom_loss(y_true, y_pred)

    def custom_loss(self, y_true, y_pred):
        dice_term = DiceLoss()(y_true, y_pred)
        voe_term = VOE()(y_true, y_pred)
        return voe_term + dice_term

class BCELoss:
    """Mean binary cross-entropy between ``y_true`` and ``y_pred``.

    Args:
        from_logits: True if ``y_pred`` holds raw logits. Defaults to False because the
            model in this notebook ends in a sigmoid; feeding those probabilities to
            ``sigmoid_cross_entropy_with_logits`` would apply the sigmoid twice.
    """

    def __init__(self, from_logits=False):
        self.from_logits = from_logits

    def bce_loss(self, y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        if self.from_logits:
            per_element = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true, logits=y_pred)
        else:
            # Element-wise BCE on probabilities (clipped internally to avoid log(0)).
            per_element = tf.keras.backend.binary_crossentropy(y_true, y_pred)
        return tf.reduce_mean(per_element)

    def __call__(self, y_true, y_pred):
        return self.bce_loss(y_true, y_pred)

class CombinedLoss:
    """Soft Dice loss plus mean binary cross-entropy (``BCELoss``)."""

    def __init__(self):
        pass

    def __call__(self, y_true, y_pred):
        return self.combined_loss(y_true, y_pred)

    def combined_loss(self, y_true, y_pred):
        dice_term = DiceLoss()(y_true, y_pred)
        bce_term = BCELoss()(y_true, y_pred)
        return dice_term + bce_term

class HausdorffLoss:
    """Row-set Hausdorff distance between predicted and ground-truth masks, averaged over the batch.

    NOTE(review): each mask row is treated as one point in R^W (W = image width), so
    this is the Hausdorff distance between the *sets of rows* of the two masks, not
    between their foreground pixel sets. The pairwise-difference tensor has
    B * H * H * W elements (0.5 GiB of float32 per 512x512 sample).

    Call as ``HausdorffLoss()(pmask, gtmask)``: note the prediction comes first.
    """

    def __init__(self):
        pass

    def euclidean_distance(self, x, y):
        """Pairwise Euclidean distances: x ``(..., N, D)``, y ``(..., M, D)`` -> ``(..., N, M)``."""
        diff = tf.expand_dims(x, -2) - tf.expand_dims(y, -3)
        return tf.sqrt(tf.reduce_sum(tf.square(diff), axis=-1))

    def hausdorff_distance(self, x, y):
        """Symmetric Hausdorff distance between point sets x ``(..., N, D)`` and y ``(..., M, D)``.

        Returns a tensor of shape ``(...)`` (a scalar for unbatched inputs).
        """
        # The y->x distance matrix is the transpose of x->y, so compute it once and
        # reduce along both axes.
        distances = self.euclidean_distance(x, y)
        x_to_y = tf.reduce_max(tf.reduce_min(distances, axis=-1), axis=-1)
        y_to_x = tf.reduce_max(tf.reduce_min(distances, axis=-2), axis=-1)
        return tf.maximum(x_to_y, y_to_x)

    def _as_rows(self, masks):
        """``(B, H, W, 1)`` or ``(B, H, W)`` -> float32 ``(B, H, W)``."""
        masks = tf.cast(masks, tf.float32)
        if masks.shape.rank == 4:
            masks = tf.squeeze(masks, axis=-1)
        return masks

    def hausdorff_loss(self, pmask, gtmask):
        # Works for any batch size; the previous version read exactly samples 0 and 1,
        # which failed on a final batch of size 1 and ignored samples beyond the second.
        per_sample = self.hausdorff_distance(self._as_rows(pmask), self._as_rows(gtmask))
        return tf.reduce_mean(per_sample)

    def __call__(self, pmask, gtmask):
        return self.hausdorff_loss(pmask, gtmask)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tqdm import tqdm

class CustomModelWrapper:
    """Minimal custom training loop around the 12-output segmentation model.

    The model must return ``(mask, s3, s4, map11, ..., map33)``; only the mask is used
    by the losses and metrics. Call ``compile`` before ``fit``. Steps per epoch come
    from the notebook globals ``TOTAL_TRAIN_SAMPLES``, ``TOTAL_TEST_SAMPLES`` and
    ``BATCH_SIZE``.
    """

    # Panel titles for the attention maps shown every 5 epochs (model outputs 3..11).
    _MAP_TITLES = ('map11', 'map12', 'map13', 'map21', 'map22', 'map23',
                   'map31', 'map32', 'map33')

    def __init__(self, model):
        self.model = model
        self.prev_model = model
        # Ceil division: datasets batched without drop_remainder (as create_train /
        # create_test do) yield ceil(N / B) batches per pass. Floor division would leave
        # the final partial batch to the next epoch and shift epoch boundaries over time.
        self.t_steps_per_epoch = -(-TOTAL_TRAIN_SAMPLES // BATCH_SIZE)
        self.v_steps_per_epoch = -(-TOTAL_TEST_SAMPLES // BATCH_SIZE)
        self.cur_epoch = 0
        self.best_model = {'score': 0, 'model': self.model, 'weights': None, 'result': None, 'metric_vals': None}
        self.history = {}

    def compile(self, loss_objs: dict, optimizer_obj, metrics: dict):
        """Store the loss callables, optimizer and metric callables used by the loop."""
        self.loss_objs, self.optimizer_obj, self.metrics = loss_objs, optimizer_obj, metrics

    @tf.function
    def train_single_batch(self, x, y):
        """One optimisation step; returns ``(preds, total_loss, dice_loss, hausdorff_loss, custom_loss)``."""
        with tf.GradientTape() as tape:
            preds = self.model(x, training=True)[0]
            custom_loss_value = self.loss_objs['custom_loss'](y, preds)
            dice_loss = self.loss_objs['dice_loss'](y, preds)
            # NOTE(review): the 1/30 weight on the Hausdorff term is an unexplained constant.
            hausdorff_loss = self.loss_objs['hausdorff_loss'](preds, y) / 30
            loss_value = custom_loss_value + hausdorff_loss

        grads = tape.gradient(loss_value, self.model.trainable_weights)
        self.optimizer_obj.apply_gradients(zip(grads, self.model.trainable_weights))
        return preds, loss_value, dice_loss, hausdorff_loss, custom_loss_value

    def train_single_epoch(self, data):
        """Run ``t_steps_per_epoch`` training steps from iterator ``data``.

        Returns per-step lists ``(total_losses, dice_scores, hausdorff_losses, custom_losses)``.
        Dice scores are only recorded for batches where prediction and target overlap.
        """
        batch_losses = []
        dice_scores = []
        hausdorff_losses = []
        custom_losses = []

        pbar = tqdm(total=self.t_steps_per_epoch, position=0, leave=True,
                    bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} ')

        for step in range(1, self.t_steps_per_epoch + 1):
            (x, y) = next(data)
            preds, loss_value, dice_loss, hloss, custom_loss_value = self.train_single_batch(x, y)

            batch_losses.append(loss_value)
            hausdorff_losses.append(hloss)
            custom_losses.append(custom_loss_value)
            dice_score = self.metrics['train_dice'](y, preds)
            if tf.reduce_max(y * preds) > 0:
                dice_scores.append(dice_score)

            pbar.set_description(
                f"Training loss: {loss_value:.4f}, Training Custom loss: {custom_loss_value:.4f}, Training Dice loss: {dice_loss:.4f}, Hloss: {hloss:.4f} for step: {step}")
            pbar.update()
        pbar.close()
        return batch_losses, dice_scores, hausdorff_losses, custom_losses

    def validate_single_epoch(self, data):
        """Run ``v_steps_per_epoch`` inference steps; returns ``(custom_losses, dice_scores)``."""
        batch_losses = []
        val_dice_scores = []

        for _ in range(self.v_steps_per_epoch):
            (x, y) = next(data)
            preds = self.model(x, training=False)[0]

            loss_value = self.loss_objs['custom_loss'](y, preds)
            batch_losses.append(loss_value)

            val_dice_score = self.metrics['val_dice'](y, preds)
            if tf.reduce_max(y * preds) > 0:
                val_dice_scores.append(val_dice_score)

        return batch_losses, val_dice_scores

    def _plot_attention_maps(self, images):
        """Show the nine graph-attention maps and the predicted mask for one sample of ``images``."""
        outputs = self.model(images, training=False)
        pred, maps = outputs[0], outputs[3:]
        # Second sample when the batch has one (as before), otherwise the first.
        sample = min(1, int(tf.shape(images)[0]) - 1)
        panels = list(zip(self._MAP_TITLES, maps)) + [('prediction', pred)]

        fig = plt.figure(figsize=(15, 10))
        for position, (title, fmap) in enumerate(panels, start=1):
            ax = fig.add_subplot(3, 4, position)
            ax.imshow(np.asarray(fmap[sample, :, :, 0]), cmap='gray')
            ax.set_title(title)
        plt.show()
        plt.close(fig)

    def fit(self, train_data, val_data, epochs):
        """Train for ``epochs`` epochs, validating after each; returns the history dict.

        The best validation Dice is tracked in ``self.best_model``. Its ``'weights'``
        entry holds a snapshot; restore it with ``self.model.set_weights(...)``.
        """
        train_data_iter = iter(train_data)
        val_data_iter = iter(val_data)

        history = {'train_loss': [], 'val_loss': [], 'train_dice_loss': [], 'train_dice': [], 'val_dice': [], 'hausdorff_loss': [], 'custom_loss': []}
        for epoch in range(epochs):
            train_losses, train_dice_scores, hloss, closs = self.train_single_epoch(train_data_iter)
            train_dice_result = np.mean(train_dice_scores)
            train_hloss = np.mean(hloss)
            train_custom_loss = np.mean(closs)
            history['hausdorff_loss'].append(train_hloss)
            history['train_dice'].append(train_dice_result)
            history['custom_loss'].append(train_custom_loss)

            val_losses, val_dice_scores = self.validate_single_epoch(val_data_iter)
            val_dice_result = np.mean(val_dice_scores)
            history['val_dice'].append(val_dice_result)

            history['train_loss'].append(np.mean(train_losses))
            history['val_loss'].append(np.mean(val_losses))

            print(
                f'\n Epoch {epoch}: Train loss: {np.mean(train_losses):.4f}, Validation Loss: {np.mean(val_losses):.4f}, Train Dice: {train_dice_result:.4f}, Validation Dice: {val_dice_result:.4f}, Train Custom Loss: {train_custom_loss:.4f}, Train Hloss: {train_hloss:.4f}')

            self.cur_epoch += 1
            self.prev_model = self.model

            if self.cur_epoch % 5 == 0:
                # Visualise the first validation batch. A fresh iterator leaves
                # val_data_iter untouched (this previously read an undefined global).
                vis_images, _ = next(iter(val_data))
                self._plot_attention_maps(vis_images)

            if self.best_model['score'] < val_dice_result:
                print("Storing new best ...")
                self.best_model['score'] = val_dice_result
                self.best_model['model'] = self.model
                # self.model keeps training in place, so snapshot its weights.
                self.best_model['weights'] = self.model.get_weights()
                self.best_model['metric_vals'] = {'val_dice_tf': val_dice_result}

        history['model'] = self.model
        self.history = history

        return history


In [None]:
# Build the segmentation model and wrap it in the custom training loop.
model_wrapper = CustomModelWrapper(model=FunctionalModel().get())

In [None]:
# NOTE(review): training optimises loss_objs['custom_loss'], which is bound to
# HybridLoss here (not CustomLoss). Confirm this is intended.
model_wrapper.compile(
    loss_objs=dict(
        hybrid_loss=HybridLoss(),
        custom_loss=HybridLoss(),
        dice_loss=DiceLoss(),
        hausdorff_loss=HausdorffLoss(),
    ),
    optimizer_obj=tf.keras.optimizers.Nadam(learning_rate=0.0003),
    metrics=dict(
        train_acc=tf.keras.metrics.Accuracy(),
        val_acc=tf.keras.metrics.Accuracy(),
        train_dice=DiceCoeff(),
        val_dice=DiceCoeff(),
        train_f1=F1Score(),
        val_f1=F1Score(),
        train_rec=Recall(),
        val_rec=Recall(),
        train_pre=Precision(),
        val_pre=Precision(),
    ),
)

In [None]:
# Train for 50 epochs, validating on the test split after every epoch.
history = model_wrapper.fit(
    train_data=train_dataset,
    val_data=test_dataset,
    epochs=50,
)

  0%|          | 0/15 