In [None]:
from google.colab import drive
# Mount Google Drive so the project folder (and its helper modules) is reachable.
DRIVE_ROOT = '/content/drive'
drive.mount(DRIVE_ROOT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import tensorflow as tf
import numpy as np
from tensorflow.keras import models,layers
from tensorflow.keras.utils import get_file
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
import tensorflow.keras.backend as K
import re
import sys
from collections import namedtuple
#import tensorflow.keras.backend as K
import math

%cd /content/drive/MyDrive/MRI_ACDCA/
from keras_vision_transformer import swin_layers
from keras_vision_transformer import transformer_layers
from keras_vision_transformer import utils
from tensorflow.python.framework.ops import disable_eager_execution

/content/drive/.shortcut-targets-by-id/1kdKWMtQq063wklZtXKlzGd-nUeOnEiDw/MRI_ACDCA


## Preparation

In [None]:
class Swish(tf.keras.layers.Layer):
    """Keras layer applying ``tf.nn.swish`` (x * sigmoid(x)) element-wise.

    Using a dedicated named Layer instead of a Lambda makes the activation show
    up as ``Swish`` in ``model.summary()``.
    """

    def __init__(self, name=None, **kwargs):
        super(Swish, self).__init__(name=name, **kwargs)

    def call(self, inputs, **kwargs):
        activated = tf.nn.swish(inputs)
        return activated

    def get_config(self):
        # Start from the base Layer config and make sure the layer name is present.
        base_config = super(Swish, self).get_config()
        base_config.update({'name': self.name})
        return base_config

def squeeze_excite_block(reduce_ratio=0.25,name_block=None):
  """Build a squeeze-and-excitation (SE) block.

  Args:
    reduce_ratio: fraction of the input channels kept in the bottleneck
      (at least one channel is always kept).
    name_block: optional name for the final Multiply layer, so the block's
      output can later be fetched with ``model.get_layer``.

  Returns:
    A function mapping a 4-D tensor (spatial axes 1 and 2, channels last) to a
    tensor of the same shape whose channels are rescaled by sigmoid gates.
  """
  def call(inputs):
    channels = inputs.shape[-1]
    bottleneck = max(1, int(channels * reduce_ratio))

    # Squeeze: per-channel spatial average, kept as a 1x1 map.
    gate = Lambda(lambda t: K.mean(t, axis=[1,2], keepdims=True))(inputs)

    # Excite: 1x1 bottleneck -> swish -> 1x1 expansion -> sigmoid gate.
    conv_kwargs = dict(kernel_size=[1, 1], strides=[1, 1],
                       kernel_initializer='he_normal', padding='same', use_bias=True)
    gate = Conv2D(bottleneck, **conv_kwargs)(gate)
    gate = Swish()(gate)
    gate = Conv2D(channels, **conv_kwargs)(gate)
    gate = Activation('sigmoid')(gate)

    # Only pass a name when one was requested; otherwise Keras auto-names it.
    multiply_kwargs = {} if name_block is None else {'name': name_block}
    return Multiply(**multiply_kwargs)([gate, inputs])
  return call

def conv_block(filters, block_name=None, kernel_size=(3, 3), dilation_rate=3):
  """Two dilated Conv-BN-Swish layers followed by a squeeze-and-excitation block.

  Args:
    filters: number of output channels of both convolutions.
    block_name: optional name for the SE block's final Multiply layer, which
      lets the block output be fetched later with ``Model.get_layer``.
    kernel_size: kernel size of both convolutions (default (3, 3)).
    dilation_rate: dilation of both convolutions. The default of 3 is what this
      block has always used (an old inline comment wrongly claimed 1).

  Returns:
    A function mapping an input tensor to the block output.
  """
  def call(inputs):
    x = inputs
    for _ in range(2):
      x = Conv2D(filters, kernel_size=kernel_size, padding="same",
                 dilation_rate=dilation_rate, use_bias=False,
                 kernel_initializer='he_normal')(x)
      x = BatchNormalization()(x)
      x = Swish()(x)
    # Channel recalibration; its (optional) name identifies the block output.
    return squeeze_excite_block(name_block=block_name)(x)
  return call

In [None]:
def up_and_concate(down_layer, layer, data_format='channels_last'):
    """Upsample ``down_layer`` 2x and concatenate it with the skip tensor ``layer``.

    Args:
        down_layer: lower-resolution tensor from the previous decoder stage.
        layer: skip-connection tensor whose spatial size is twice that of
            ``down_layer``.
        data_format: 'channels_last' (default) or 'channels_first'. This
            argument is now honoured instead of being silently overwritten.

    Returns:
        The upsampled tensor and ``layer`` concatenated along the channel axis.
    """
    channel_axis = 1 if data_format == 'channels_first' else 3
    up = UpSampling2D(size=(2, 2), data_format=data_format)(down_layer)
    # A Concatenate layer is equivalent to the former Lambda(K.concatenate)
    # but validates the shapes and serialises without pickling a lambda.
    return Concatenate(axis=channel_axis)([up, layer])
def attention_up_and_concate(down_layer, layer, data_format='channels_last'):
    """Upsample ``down_layer`` 2x, gate the skip tensor with it, then concatenate.

    Args:
        down_layer: lower-resolution decoder tensor; after 2x upsampling it is
            used as the gating signal.
        layer: skip-connection tensor to be gated by ``attention_block_2d``.
        data_format: 'channels_last' (default) or 'channels_first'. This
            argument is now honoured instead of being silently overwritten.

    Returns:
        Concatenation of the upsampled tensor and the gated skip tensor along
        the channel axis.
    """
    channel_axis = 1 if data_format == 'channels_first' else 3
    in_channel = down_layer.get_shape().as_list()[channel_axis]

    up = UpSampling2D(size=(2, 2), data_format=data_format)(down_layer)

    # Intermediate width is a quarter of the decoder channels, but never 0
    # (Conv2D rejects zero filters when in_channel < 4).
    gated = attention_block_2d(x=layer, g=up, inter_channel=max(1, in_channel // 4),
                               data_format=data_format)

    return Concatenate(axis=channel_axis)([up, gated])
def attention_block_2d(x, g, inter_channel, data_format='channels_last'):
    """Additive attention gate: rescale ``x`` by a map computed from ``x`` and ``g``.

    ``x`` and ``g`` must share the same spatial size, because their 1x1
    projections are summed element-wise.

    Args:
        x: tensor to be gated (e.g. an encoder skip connection).
        g: gating tensor (e.g. upsampled decoder features).
        inter_channel: number of channels of the intermediate projections.
        data_format: 'channels_last' (default) or 'channels_first'. This
            argument is now honoured instead of being silently overwritten.

    Returns:
        ``x`` multiplied by a single-channel sigmoid attention map, broadcast
        across all of ``x``'s channels.
    """
    # Project both inputs to inter_channel channels with 1x1 convolutions.
    theta_x = Conv2D(inter_channel, [1, 1], strides=[1, 1], data_format=data_format)(x)
    phi_g = Conv2D(inter_channel, [1, 1], strides=[1, 1], data_format=data_format)(g)

    f = Activation('swish')(add([theta_x, phi_g]))

    # Collapse to one channel and squash to (0, 1): one weight per location.
    # Shape is (batch, H, W, 1) for channels_last, (batch, 1, H, W) otherwise.
    psi_f = Conv2D(1, [1, 1], strides=[1, 1], data_format=data_format)(f)
    rate = Activation('sigmoid')(psi_f)

    return multiply([x, rate])

In [None]:
def convolution_block(x, filters, strides, dilation_rate=(1,1), padding='same', activation=True):
    """Conv2D -> BatchNormalization -> optional Swish.

    Args:
        x: input tensor.
        filters: number of output channels.
        strides: despite its name, this is passed to Conv2D as the *kernel size*
            (e.g. (3, 3) or (1, 5)); the convolution keeps Conv2D's default
            stride of 1. The name is kept because callers pass it by keyword.
        dilation_rate: dilation rate of the convolution.
        padding: Conv2D padding mode.
        activation: when False, the Swish is skipped and the batch-normalised
            linear output is returned. Defaults to True, so existing calls
            behave exactly as before.

    Returns:
        The output tensor.
    """
    x = Conv2D(filters, strides, dilation_rate=dilation_rate, padding=padding)(x)
    x = BatchNormalization()(x)
    if activation:
        x = Swish()(x)
    return x

def residual_block(blockInput, num_filters):
    """Pre-activation residual unit with squeeze-and-excitation.

    Computes ``SE(BN(Conv(Swish(BN(Conv(BN(Swish(in))))))) + BN(in)``.

    Args:
        blockInput: input tensor. It must already have ``num_filters`` channels,
            because it is added to the block output with no projection.
        num_filters: channel count of both 3x3 convolutions.

    Returns:
        Tensor with the same shape as ``blockInput``.
    """
    x = Swish()(blockInput)
    x = BatchNormalization()(x)
    shortcut = BatchNormalization()(blockInput)

    x = convolution_block(x, num_filters, (3,3))

    # The second convolution is kept linear (Conv -> BN, no Swish) before the
    # residual addition. It is built inline because the previous
    # `convolution_block(..., activation=False)` call raised a TypeError.
    x = Conv2D(num_filters, (3, 3), padding='same')(x)
    x = BatchNormalization()(x)

    x = squeeze_excite_block()(x)
    return Add()([x, shortcut])

In [None]:
def ASPP(x, filter):
    """Atrous Spatial Pyramid Pooling with squeeze-and-excitation on every branch.

    Five parallel branches are concatenated and fused by a 1x1 convolution:
      * image-level pooling (global average -> 1x1 conv -> upsample back),
      * 1x1 convolution,
      * 3x3 conv dilated by 6, 5x5 conv dilated by 12, 7x7 conv dilated by 18.
    Every branch and the fused output go through Conv -> BN -> Swish -> SE.

    Args:
        x: 4-D tensor whose spatial dims (axes 1 and 2) are statically known;
            they are used as the pooling window and the upsampling factor.
        filter: number of channels produced by each branch and by the output.

    Returns:
        Tensor with the spatial size of ``x`` and ``filter`` channels.
    """
    spatial = (x.shape[1], x.shape[2])

    def conv_bn_swish(tensor, kernel, dilation):
        tensor = Conv2D(filter, kernel, dilation_rate=dilation, padding="same",
                        use_bias=False, kernel_initializer='he_normal')(tensor)
        tensor = BatchNormalization()(tensor)
        return Swish()(tensor)

    # Image-level branch: collapse to 1x1, project, then stretch back to full size.
    pooled = AveragePooling2D(pool_size=spatial)(x)
    pooled = conv_bn_swish(pooled, 1, 1)
    pooled = UpSampling2D(spatial, interpolation='bilinear')(pooled)
    branches = [squeeze_excite_block()(pooled)]

    # Convolutional branches with growing kernels and dilation rates.
    for kernel, dilation in ((1, 1), (3, 6), (5, 12), (7, 18)):
        branches.append(squeeze_excite_block()(conv_bn_swish(x, kernel, dilation)))

    fused = Concatenate()(branches)
    fused = conv_bn_swish(fused, 1, 1)
    return squeeze_excite_block()(fused)

In [None]:
def RFB(x, filter):
    """Receptive Field Block: parallel dilated branches plus a residual path.

    Branch 0 is a single 1x1 conv. For k = 3, 5, 7 there is a branch of
    1x1 conv -> 1xk conv -> kx1 conv -> 3x3 conv dilated by k. The branches are
    concatenated and fused by a 1x1 conv. The result is added to a 1x1-projected
    shortcut and passed through Swish. Every conv is Conv -> BN -> Swish.

    Args:
        x: input tensor.
        filter: channel count of every intermediate and output tensor.

    Returns:
        Tensor with the spatial size of ``x`` and ``filter`` channels.
    """
    def unit(tensor, kernel, dilation=(1, 1)):
        # convolution_block's `strides` argument is really the kernel size.
        return convolution_block(tensor, filter, strides=kernel,
                                 dilation_rate=dilation, padding='same')

    branches = [unit(x, (1, 1))]
    for k in (3, 5, 7):
        branch = unit(x, (1, 1))
        branch = unit(branch, (1, k))
        branch = unit(branch, (k, 1))
        branch = unit(branch, (3, 3), dilation=(k, k))
        branches.append(branch)

    shortcut = unit(x, (1, 1))

    fused = concatenate(branches, axis=-1)
    fused = unit(fused, (1, 1))
    return Swish()(fused + shortcut)

In [None]:
def _bernoulli(shape, mean):
    """Sample a float32 0/1 tensor of ``shape`` that is 1 where a uniform draw < ``mean``.

    ``sign(mean - u)`` is +1 where the uniform draw ``u`` is below ``mean`` and
    0 or -1 elsewhere; ``relu`` then clamps everything non-positive to 0. ReLU
    is used purely as a clamp here: a smooth activation such as swish would
    not yield a binary mask.
    """
    uniform = tf.random.uniform(shape, minval=0, maxval=1, dtype=tf.float32)
    return tf.nn.relu(tf.sign(mean - uniform))


class DropBlock2D(tf.keras.layers.Layer):
    """DropBlock regularisation for 4-D ``(batch, height, width, channels)`` inputs.

    While training, square ``block_size x block_size`` regions of each channel
    are zeroed. Block seeds are drawn with rate ``gamma`` (see
    ``set_keep_prob``), which is chosen so that roughly ``1 - keep_prob`` of
    the activations are dropped. At inference, or when ``keep_prob == 1``, the
    input is returned unchanged.

    Args:
        keep_prob: probability of keeping an activation (ints are converted to float).
        block_size: side length of each dropped square.
        scale: when true, kept activations are multiplied by
            ``mask_size / kept_count`` to compensate for the dropped fraction.
    """

    def __init__(self, keep_prob, block_size, scale=True, **kwargs):
        super(DropBlock2D, self).__init__(**kwargs)
        self.keep_prob = float(keep_prob) if isinstance(keep_prob, int) else keep_prob
        self.block_size = int(block_size)
        # Stored as a bool tensor so it can drive tf.cond in call().
        self.scale = tf.constant(scale, dtype=tf.bool) if isinstance(scale, bool) else scale

    def compute_output_shape(self, input_shape):
        return input_shape

    def build(self, input_shape):
        # Keras may hand over a plain tuple/list; normalise so .as_list() works.
        input_shape = tf.TensorShape(input_shape)
        assert len(input_shape) == 4, 'DropBlock2D expects a 4-D (batch, h, w, c) input'
        _, self.h, self.w, self.channel = input_shape.as_list()
        # Seeds are sampled on an (h - bs + 1) x (w - bs + 1) grid; this padding
        # restores it to (h, w), with the extra pixel (even sizes) placed first.
        p1 = (self.block_size - 1) // 2
        p0 = (self.block_size - 1) - p1
        self.padding = [[0, 0], [p0, p1], [p0, p1], [0, 0]]
        self.set_keep_prob()
        super(DropBlock2D, self).build(input_shape)

    def call(self, inputs, training=None, **kwargs):
        def drop():
            mask = self._create_mask(tf.shape(inputs))
            output = inputs * mask
            # divide_no_nan: if every activation was dropped the output is
            # already all zeros; avoid 0 * inf = NaN from a zero kept-count.
            return tf.cond(
                self.scale,
                true_fn=lambda: output * tf.math.divide_no_nan(
                    tf.cast(tf.size(mask), tf.float32), tf.reduce_sum(mask)),
                false_fn=lambda: output)

        if training is None:
            # Fall back to the global Keras learning phase when the caller did
            # not pass `training` explicitly.
            training = K.learning_phase()
        # The learning phase may be an int (0/1) rather than a bool, and
        # tf.logical_not only accepts booleans, so normalise it first.
        training = tf.cast(training, tf.bool)
        return tf.cond(tf.logical_or(tf.logical_not(training), tf.equal(self.keep_prob, 1.0)),
                       true_fn=lambda: inputs,
                       false_fn=drop)

    def set_keep_prob(self, keep_prob=None):
        """Optionally update ``keep_prob``, then recompute the seed rate ``gamma``.

        Reads ``self.h`` / ``self.w``, so ``build`` must have run first.
        This method only supports Eager Execution.
        """
        if keep_prob is not None:
            self.keep_prob = keep_prob
        w, h = tf.cast(self.w, tf.float32), tf.cast(self.h, tf.float32)
        self.gamma = (1. - self.keep_prob) * (w * h) / (self.block_size ** 2) / \
                     ((w - self.block_size + 1) * (h - self.block_size + 1))

    def _create_mask(self, input_shape):
        """Return a float32 mask shaped like the input: 0 inside dropped blocks, 1 elsewhere."""
        # Seeds are only sampled where a full block fits inside the feature map.
        sampling_mask_shape = tf.stack([input_shape[0],
                                        self.h - self.block_size + 1,
                                        self.w - self.block_size + 1,
                                        self.channel])
        mask = _bernoulli(sampling_mask_shape, self.gamma)
        # Pad back to (h, w), then grow every sampled seed into a full block.
        mask = tf.pad(mask, self.padding)
        mask = tf.nn.max_pool(mask, [1, self.block_size, self.block_size, 1], [1, 1, 1, 1], 'SAME')
        return 1 - mask

In [None]:
def mvn(tensor):
    '''Per-channel spatial mean-variance normalisation.

    Every channel of every sample is shifted to zero mean and divided by its
    standard deviation. Both statistics are computed over axes 1 and 2 (the
    spatial axes, assuming channels-last input). A 1e-6 epsilon keeps the
    division finite for constant channels.
    '''
    spatial_axes = (1, 2)
    centred = tensor - K.mean(tensor, axis=spatial_axes, keepdims=True)
    spread = K.std(tensor, axis=spatial_axes, keepdims=True) + 1e-6
    return centred / spread

def crop(tensors):
    '''Centre-crop the second of two tensors to the spatial size of the first.

    Args:
        tensors: sequence ``[reference, larger]`` of 4-D tensors unpacked as
            (batch, height, width, channels). ``larger`` must be at least as
            big as ``reference`` in both height and width.

    Returns:
        ``larger`` cropped with ``Cropping2D``. When a size difference is odd,
        the extra row/column is removed from the bottom/right.

    Raises:
        ValueError: if a spatial size is not statically known, or if the second
            tensor is smaller than the first along height or width.
    '''
    _, ref_h, ref_w, _ = K.int_shape(tensors[0])
    _, big_h, big_w, _ = K.int_shape(tensors[1])
    if None in (ref_h, ref_w, big_h, big_w):
        raise ValueError('crop() needs static spatial dimensions on both tensors')

    crop_h, crop_w = big_h - ref_h, big_w - ref_w
    if crop_h < 0 or crop_w < 0:
        raise ValueError(
            'crop() expects the second tensor to be at least as large as the first; '
            'got %sx%s vs %sx%s' % (big_h, big_w, ref_h, ref_w))

    # Split each difference in half; an odd leftover goes to the bottom/right.
    top, rem_h = divmod(crop_h, 2)
    left, rem_w = divmod(crop_w, 2)
    return Cropping2D(cropping=((top, top + rem_h), (left, left + rem_w)))(tensors[1])

## Residual Attention U-Net (attention gates currently disabled; skips use plain concatenation)

In [None]:
def _unet_encoder_stage(x, filters):
    """Encoder stage: two 3x3 Conv-BN-Swish layers followed by squeeze-and-excitation."""
    for _ in range(2):
        x = Conv2D(filters, 3, padding='same')(x)
        x = BatchNormalization()(x)
        x = Swish()(x)
    return squeeze_excite_block()(x)


def _unet_residual_stage(x, filters):
    """3x3 projection to `filters` channels, two residual blocks, then LeakyReLU(0.1)."""
    x = Conv2D(filters, (3, 3), activation=None, padding="same")(x)
    x = residual_block(x, filters)
    x = residual_block(x, filters)
    return LeakyReLU(alpha=0.1)(x)


def _unet_decoder_stage(x, skip, filters):
    """Decoder stage: 2x transposed-conv upsampling, concat with `skip`, dropout, residual refinement."""
    up = Conv2DTranspose(filters, (3, 3), strides=(2, 2), padding="same")(x)
    x = concatenate([up, skip])
    x = Dropout(0.2)(x)
    return _unet_residual_stage(x, filters)


def residual_attention_concate_UNet(input_shape=(128, 128, 1), out_channels=3):
    """Residual U-Net with an RFB stem, SE-gated encoder and an ASPP bottleneck.

    Skip connections are joined by plain concatenation. The attention-gated
    variant (``attention_up_and_concate``) is not used by this model.

    Args:
        input_shape: (height, width, channels) of the input. Height and width
            should be divisible by 16 so the four 2x poolings and the four 2x
            transposed convolutions line up with the skip tensors.
        out_channels: number of output maps; a softmax is applied across them.

    Returns:
        An uncompiled Keras ``Model``.
    """
    inputs = Input(shape=input_shape, dtype='float', name='data')
    x = BatchNormalization()(inputs)
    x = RFB(x, 32)

    # Encoder.
    conv1 = _unet_encoder_stage(x, 32)
    conv2 = _unet_encoder_stage(MaxPooling2D(pool_size=(2, 2))(conv1), 64)
    conv3 = _unet_encoder_stage(MaxPooling2D(pool_size=(2, 2))(conv2), 128)
    conv4 = _unet_encoder_stage(MaxPooling2D(pool_size=(2, 2))(conv3), 256)

    # Dropout only on the path into the bottleneck; the skip uses conv4 itself.
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
    pool4 = ASPP(pool4, 512)

    # Bottleneck. drop5 feeds the decoder so the bottleneck dropout is
    # actually part of the graph (it used to be built but left unconnected).
    conv5 = _unet_residual_stage(pool4, 512)
    drop5 = Dropout(0.5)(conv5)

    # Decoder.
    conv6 = _unet_decoder_stage(drop5, conv4, 256)
    conv7 = _unet_decoder_stage(conv6, conv3, 128)
    conv8 = _unet_decoder_stage(conv7, conv2, 64)
    conv9 = _unet_decoder_stage(conv8, conv1, 32)

    output = Conv2D(out_channels, (1, 1), padding="same", activation="softmax")(conv9)
    return Model(inputs, output)

In [None]:
#generator = residual_attention_concate_UNet(input_shape=(128, 128, 1))
#generator.summary()

## SegNet

In [None]:
def decoder_block(n_filter,skip=None):
  """Build a SegNet decoder stage.

  The returned function upsamples its input 2x with a 2x2 transposed
  convolution. If ``skip`` is given, that tensor first goes through its own
  ``conv_block`` and is concatenated with the upsampled features. A final
  ``conv_block`` with ``n_filter`` channels produces the output.
  """
  def call(inputs):
    upsampled = Conv2DTranspose(n_filter, (2,2), strides=(2, 2), padding='same',
                                kernel_initializer = 'he_normal')(inputs)
    if skip is None:
      merged = upsampled
    else:
      refined_skip = conv_block(n_filter)(skip)
      merged = Concatenate()([upsampled, refined_skip])
    return conv_block(n_filter)(merged)
  return call
  
def dow_block(kernel_size=(2,2),stride=(2,2)):
  """Return a function that downsamples its input with ``MaxPooling2D``.

  Args:
    kernel_size: pooling window (default 2x2).
    stride: pooling stride (default 2x2, i.e. height and width are halved,
      rounding down).
  """
  return lambda inputs: MaxPooling2D(kernel_size, strides=stride)(inputs)

In [None]:
def encoderSegnet(input_s=(128,128,1)):
  """SegNet-style encoder: four conv_block + max-pool stages, then one more conv_block.

  Each conv_block's output layer is named ``output_block_<i>`` (i = 1..5), so
  the decoder can fetch skip connections with ``Model.get_layer``.

  Args:
    input_s: input shape (height, width, channels), without the batch axis.

  Returns:
    A Keras ``Model`` mapping the input to the last conv_block's output.
  """
  inp = Input(shape=input_s)
  filters_per_stage = [64, 128, 256, 512, 512]

  features = inp
  for stage, n_filters in enumerate(filters_per_stage[:-1], start=1):
    features = conv_block(n_filters, block_name='output_block_' + str(stage))(features)
    features = dow_block()(features)

  last_stage = len(filters_per_stage)
  features = conv_block(filters_per_stage[-1],
                        block_name='output_block_' + str(last_stage))(features)
  return Model(inp, features)

# Encoder layer names tapped as skip connections, deepest stage first.
list_skip = ['output_block_%d' % i for i in range(4, 0, -1)]

In [None]:
def seg_net(input_shape= (192,256,3), list_skip = list_skip,out_channels=3):
  """Build the SegNet-style segmentation model.

  Args:
    input_shape: (height, width, channels) of the input images.
    list_skip: encoder layer names used as skip connections, deepest first;
      the four decoder stages consume the first four entries.
    out_channels: number of output maps. A softmax (layer name 'softmax') is
      applied when there is more than one, otherwise a sigmoid ('sigmoid').

  Returns:
    A Keras ``Model`` from the encoder input to the activated output.
  """
  encoder = encoderSegnet(input_s = input_shape)
  skips = [encoder.get_layer(layer_name).output for layer_name in list_skip]

  features = ASPP(encoder.output, 128)
  for stage, n_filters in enumerate([512, 256, 128, 64]):
    features = decoder_block(n_filters, skip=skips[stage])(features)

  logits = Conv2D(out_channels, (3, 3), padding='same', kernel_initializer='he_normal')(features)
  if out_channels <= 1:
    return Model(encoder.input, Activation('sigmoid', name = 'sigmoid')(logits))
  return Model(encoder.input, Activation('softmax', name = 'softmax')(logits))

In [None]:
# SegNet variant for 128x128 single-channel inputs with 3 output channels.
generator = seg_net(out_channels=3, input_shape=(128, 128, 1))
generator.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 128, 128, 1) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 128, 128, 64) 192         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 128, 128, 64) 256         conv2d[0][0]                     
__________________________________________________________________________________________________
swish (Swish)                   (None, 128, 128, 64) 0           batch_normalization[0][0]        
____________________________________________________________________________________________