In [None]:
%reset

Once deleted, variables cannot be recovered. Proceed (y/[n])? y


In [1]:
import tensorflow as tf
import numpy as np

from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import concatenate

import os
import imageio
import matplotlib.pyplot as plt
import re
from PIL import Image
import shutil

from skimage.transform import resize
from sklearn.model_selection import train_test_split

import random
from tqdm import tqdm

from skimage.io import imread, imshow

In [2]:
# Low-fidelity training split: model inputs and their targets.
x_train, y_train = np.load("x_train.npy"), np.load("y_train.npy")

# Low-fidelity test split, loaded the same way.
x_test, y_test = np.load("x_test.npy"), np.load("y_test.npy")

In [None]:
# Sanity check: one array shape per line, in train-x, train-y, test-x, test-y order.
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape, sep="\n")

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Conv2D, BatchNormalization, Dropout, MaxPooling2D, Conv2DTranspose, concatenate
from tensorflow.keras.models import Model

seed = 42
# Seed NumPy's global RNG by *calling* np.random.seed. Assigning to it
# (np.random.seed = seed) replaces the function with an int, seeds nothing,
# and breaks any later np.random.seed(...) call.
# NOTE(review): this seeds NumPy only; the later sections in this notebook
# additionally call tf.random.set_seed(seed).
np.random.seed(seed)

# Input geometry expected by the network.
IMG_WIDTH = 128
IMG_HEIGHT = 128
IMG_CHANNELS = 3

def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Dice coefficient between two tensors, reduced over every element.

    Both inputs are cast to float32 and summed over all axes (batch included),
    so a single scalar is returned for the whole call.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor, multiplied element-wise with ``y_true``.
        smooth: Constant added to numerator and denominator; keeps the ratio
            defined when both sums are zero.

    Returns:
        Scalar tensor ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)

    overlap = tf.reduce_sum(truth * pred)
    # Sum of the two totals (the Dice denominator), not a set union.
    total = tf.reduce_sum(truth) + tf.reduce_sum(pred)
    return (2.0 * overlap + smooth) / (total + smooth)


# Build the plain U-Net: five encoder levels (16 -> 256 filters) mirrored by
# four decoder levels joined through skip connections.
inputs = tf.keras.layers.Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
# Divide by 255 inside the model (assumes 0-255 pixel inputs — TODO confirm).
s = tf.keras.layers.Lambda(lambda x: x / 255)(inputs)

# Contraction path: every level is conv -> dropout -> conv, followed by a
# 2x2 max-pool; c1..c4 are kept for the skip connections.
c1 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(s)
c1 = tf.keras.layers.Dropout(0.1)(c1)
c1 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c1)
p1 = tf.keras.layers.MaxPooling2D((2, 2))(c1)

c2 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p1)
c2 = tf.keras.layers.Dropout(0.1)(c2)
c2 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c2)
p2 = tf.keras.layers.MaxPooling2D((2, 2))(c2)

c3 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p2)
c3 = tf.keras.layers.Dropout(0.2)(c3)
c3 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c3)
p3 = tf.keras.layers.MaxPooling2D((2, 2))(c3)

c4 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p3)
c4 = tf.keras.layers.Dropout(0.2)(c4)
c4 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c4)
p4 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(c4)

# Bottleneck: no pooling after this level.
c5 = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p4)
c5 = tf.keras.layers.Dropout(0.3)(c5)
c5 = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c5)

# Expansive path: stride-2 transposed conv, concatenate with the matching
# encoder output, then conv -> dropout -> conv.
u6 = tf.keras.layers.Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(c5)
u6 = tf.keras.layers.concatenate([u6, c4])
c6 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u6)
c6 = tf.keras.layers.Dropout(0.2)(c6)
c6 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c6)

u7 = tf.keras.layers.Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(c6)
u7 = tf.keras.layers.concatenate([u7, c3])
c7 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u7)
c7 = tf.keras.layers.Dropout(0.2)(c7)
c7 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c7)

u8 = tf.keras.layers.Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(c7)
u8 = tf.keras.layers.concatenate([u8, c2])
c8 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u8)
c8 = tf.keras.layers.Dropout(0.1)(c8)
c8 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c8)

u9 = tf.keras.layers.Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same')(c8)
u9 = tf.keras.layers.concatenate([u9, c1], axis=3)
c9 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u9)
c9 = tf.keras.layers.Dropout(0.1)(c9)
c9 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c9)

# 1x1 conv + sigmoid: a single-channel output in (0, 1) per pixel.
outputs = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid')(c9)


# Assemble and compile the model: plain binary cross-entropy loss with the
# Dice coefficient tracked as a metric (no masking is applied at this stage).

model = Model(inputs=[inputs], outputs=[outputs])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[dice_coefficient])

model.summary()




In [None]:
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard
from datetime import datetime
import os

# Directory that receives the best-so-far weights file.
model_checkpoint_path = "model_checkpoints"
os.makedirs(model_checkpoint_path, exist_ok=True)

# One TensorBoard log directory per run, named after the start timestamp.
log_dir = "logs/" + datetime.now().strftime("%Y%m%d-%H%M%S")
os.makedirs(log_dir, exist_ok=True)

# Checkpointing: the weights file is rewritten only when val_loss improves.
model_checkpoint = ModelCheckpoint(
    filepath=os.path.join(model_checkpoint_path, "model_weights.weights.h5"),
    monitor="val_loss",
    verbose=1,
    save_weights_only=True,
    save_best_only=True,
)

# TensorBoard logging: per-epoch histograms plus graph and image summaries.
tensorboard = TensorBoard(
    log_dir=log_dir,
    write_graph=True,
    write_images=True,
    histogram_freq=1,
)

# Callbacks handed to fit(); early stopping is intentionally not enabled.
callbacks = [model_checkpoint, tensorboard]

# Train, holding out the last 20% of the training arrays for validation.
results = model.fit(
    x_train,
    y_train,
    validation_split=0.2,
    batch_size=128,
    epochs=100,
    callbacks=callbacks,
)

In [None]:
# Load the best checkpoint and predict on the test set.
# The ModelCheckpoint callback writes to
# model_checkpoints/model_weights.weights.h5, so the weights must be read from
# that directory; the bare file name pointed at the working directory instead.
model.load_weights(os.path.join("model_checkpoints", "model_weights.weights.h5"))
preds_test = model.predict(x_test, verbose=1)
# Binarise the sigmoid outputs with a 0.75 cut-off.
preds_test_t = (preds_test > 0.75).astype(np.uint8)

# Define the circular mask function
def create_circular_mask(height, width, radius=63):
    """Return a float32 (height, width) array that is 1 inside a disc, 0 outside.

    The disc is centred on pixel ``(int(height / 2), int(width / 2))`` and
    includes every pixel whose Euclidean distance to that centre is at most
    ``radius``.

    Args:
        height: Number of rows in the mask.
        width: Number of columns in the mask.
        radius: Disc radius in pixels (boundary pixels are included).

    Returns:
        ``numpy.ndarray`` of dtype float32 and shape ``(height, width)``.
    """
    center_row = int(height / 2)
    center_col = int(width / 2)
    # Open grids: rows has shape (height, 1), cols has shape (1, width).
    rows, cols = np.ogrid[:height, :width]
    distance = np.sqrt(np.square(cols - center_col) + np.square(rows - center_row))
    inside = distance <= radius
    return inside.astype(np.float32)

# Circular region of interest for 128x128 images, held as a float32 tensor so
# it can be multiplied directly with the per-sample masks below.
IMG_HEIGHT = 128
IMG_WIDTH = 128
circular_mask = tf.convert_to_tensor(
    create_circular_mask(IMG_HEIGHT, IMG_WIDTH),
    dtype=tf.float32,
)

# Define the Dice score function
def dice_score(masked_y_true, masked_y_pred, smooth=1e-6):
    """Dice score of two tensors, reduced over all elements.

    Unlike ``dice_coefficient`` no casting is done here; callers pass tensors
    that are already float.

    Args:
        masked_y_true: Ground-truth tensor.
        masked_y_pred: Prediction tensor, multiplied element-wise with the truth.
        smooth: Constant added to numerator and denominator so the ratio stays
            defined when both inputs sum to zero.

    Returns:
        Scalar tensor ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    overlap = tf.reduce_sum(masked_y_true * masked_y_pred)
    total = tf.reduce_sum(masked_y_true) + tf.reduce_sum(masked_y_pred)
    return (2.0 * overlap + smooth) / (total + smooth)

# Per-sample Dice score, restricted to the circular region of interest.
dice_scores = []
for i in range(len(y_test)):
    # Ground truth: take channel 0, binarise (> 0), zero everything outside the disc.
    masked_y_test = tf.cast(y_test[i, :, :, 0] > 0, tf.float32) * circular_mask

    # Prediction: same treatment applied to the already-thresholded output.
    masked_y_pred = tf.cast(preds_test_t[i, :, :, 0] > 0, tf.float32) * circular_mask

    dice = dice_score(masked_y_test, masked_y_pred)
    dice_scores.append(dice.numpy())
    print(f"Dice Score for sample {i}: {dice.numpy()}")

# Mean of the per-sample scores over the whole test set.
average_dice_score = np.mean(dice_scores)
print(f"Average Dice Score for all test samples: {average_dice_score}")


**Attention UNet**

In [None]:
def gating_signal(input, out_size, batch_norm=False):
    """Build a gating signal: 1x1 conv -> optional batch norm -> ReLU.

    Args:
        input: Feature tensor the gate is derived from.
        out_size: Number of output channels of the 1x1 convolution.
        batch_norm: When truthy, insert BatchNormalization before the ReLU.

    Returns:
        The activated gating tensor with ``out_size`` channels.
    """
    layers = tf.keras.layers
    gate = layers.Conv2D(out_size, (1, 1), padding='same')(input)
    if batch_norm:
        gate = layers.BatchNormalization()(gate)
    return layers.Activation('relu')(gate)

In [None]:
def repeat_elem(tensor, rep):
    """Repeat every element of ``tensor`` ``rep`` times along axis 3.

    Example: an input of shape (None, 256, 256, 3) with ``rep=2`` yields a
    tensor of shape (None, 256, 256, 6).

    Args:
        tensor: Input tensor; axis 3 is the one that gets repeated.
        rep: Repetition factor.

    Returns:
        Output of a Lambda layer wrapping ``tf.repeat(..., rep, axis=3)``.
    """
    channel_repeat = tf.keras.layers.Lambda(lambda t: tf.repeat(t, rep, axis=3))
    return channel_repeat(tensor)

In [None]:
def attention_block(x, gating, inter_shape):
    """Additive attention gate: re-weight ``x`` with coefficients driven by ``gating``.

    Args:
        x: Skip-connection feature tensor (4-D; axis 3 holds the channels).
        gating: Gating tensor; its spatial size must divide that of ``x``.
        inter_shape: Channel count of the intermediate projections.

    Returns:
        Tensor with the same channel count as ``x``: the gated features passed
        through a 1x1 convolution and batch normalisation.
    """
    layers = tf.keras.layers
    int_shape = tf.keras.backend.int_shape

    x_shape = int_shape(x)
    g_shape = int_shape(gating)

    # Project x to inter_shape channels; stride 1 with 'same' padding keeps its grid.
    theta = layers.Conv2D(inter_shape, (2, 2), strides=(1, 1), padding='same')(x)
    theta_shape = int_shape(theta)

    # Project the gate to inter_shape channels, then resize it onto theta's grid.
    phi = layers.Conv2D(inter_shape, (1, 1), padding='same')(gating)
    gate_strides = (theta_shape[1] // g_shape[1], theta_shape[2] // g_shape[2])
    phi_up = layers.Conv2DTranspose(inter_shape, (3, 3),
                                    strides=gate_strides,
                                    padding='same')(phi)

    # Element-wise sum -> ReLU -> 1x1 conv -> sigmoid: one coefficient per pixel.
    summed = layers.add([phi_up, theta])
    activated = layers.Activation('relu')(summed)
    logits = layers.Conv2D(1, (1, 1), padding='same')(activated)
    coeffs = layers.Activation('sigmoid')(logits)

    # Bring the coefficients to x's spatial size and copy them across x's channels.
    coeff_shape = int_shape(coeffs)
    up_factor = (x_shape[1] // coeff_shape[1], x_shape[2] // coeff_shape[2])
    coeffs = layers.UpSampling2D(size=up_factor)(coeffs)
    coeffs = repeat_elem(coeffs, x_shape[3])

    gated = layers.multiply([coeffs, x])

    projected = layers.Conv2D(x_shape[3], (1, 1), padding='same')(gated)
    return layers.BatchNormalization()(projected)

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Conv2D, BatchNormalization, Dropout, MaxPooling2D, Conv2DTranspose, concatenate
from tensorflow.keras.models import Model


# Seed TensorFlow's global RNG so weight initialisation is repeatable.
seed = 42
tf.random.set_seed(seed)

IMG_WIDTH = 128
IMG_HEIGHT = 128
IMG_CHANNELS = 3

# Build the Attention U-Net, starting with the input and its scaling layer.
inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
# Divide by 255.0 inside the model (assumes 0-255 pixel inputs — TODO confirm).
s = Lambda(lambda x: x / 255.0)(inputs)


# Contraction path, level 1 (16 filters): conv -> batch norm -> dropout -> conv,
# then a 2x2 max-pool. c1 is kept for the last skip connection.
c1 = Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(s)
c1 = BatchNormalization()(c1)
c1 = Dropout(0.1)(c1)
c1 = Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c1)
p1 = MaxPooling2D((2, 2))(c1)

In [None]:
# Encoder level 2 (32 filters): conv -> batch norm -> dropout -> conv, then 2x2 max-pool.
c2 = Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p1)
c2 = BatchNormalization()(c2)
c2 = Dropout(0.1)(c2)
c2 = Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c2)
p2 = MaxPooling2D((2, 2))(c2)

In [None]:
# Encoder level 3 (64 filters): same layout as level 2, with dropout raised to 0.2.
c3 = Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p2)
c3 = BatchNormalization()(c3)
c3 = Dropout(0.2)(c3)
c3 = Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c3)
p3 = MaxPooling2D((2, 2))(c3)

In [None]:
# Encoder level 4 (128 filters): conv -> batch norm -> dropout(0.2) -> conv, then 2x2 max-pool.
c4 = Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p3)
c4 = BatchNormalization()(c4)
c4 = Dropout(0.2)(c4)
c4 = Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c4)
p4 = MaxPooling2D(pool_size=(2, 2))(c4)

In [None]:
# Bottleneck (256 filters): conv -> batch norm -> dropout(0.3) -> conv; no pooling follows.
c5 = Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p4)
c5 = BatchNormalization()(c5)
c5 = Dropout(0.3)(c5)
c5 = Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c5)


In [None]:
# Expansive path, level 6: stride-2 transposed conv on the bottleneck, plain
# concatenation with c4 (no attention gate at this level), then
# conv -> batch norm -> dropout -> conv with 128 filters.
u6 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(c5)
u6 = concatenate([u6, c4])
c6 = Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u6)
c6 = BatchNormalization()(c6)
c6 = Dropout(0.2)(c6)
c6 = Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c6)


In [None]:
# Attention gate on the level-4 skip features: the gate is derived from c6
# (128 channels) and used to re-weight c4.
gating_4 = gating_signal(c6, 128)
att_4 = attention_block(c4, gating_4, 128)

In [None]:
from tensorflow.keras.layers import UpSampling2D

# Decoder level 7: upsample c6 with a stride-2 transposed conv, upsample the
# gated level-4 features by 2x so the spatial sizes agree, and concatenate both
# with the level-3 skip (c3) before conv -> batch norm -> dropout -> conv.
u7 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(c6)
att_4_upsampled = UpSampling2D(size=(2, 2))(att_4)  # Upsample the attention block
u7 = concatenate([u7, att_4_upsampled, c3])
c7 = Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u7)
c7 = BatchNormalization()(c7)
c7 = Dropout(0.2)(c7)
c7 = Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c7)


In [None]:
# Attention gate on the level-3 skip features: gate derived from c7 (64 channels).
gating_3 = gating_signal(c7, 64)
att_3 = attention_block(c3, gating_3, 64)

In [None]:
# Decoder level 8 input: upsampled c7, the 2x-upsampled gated level-3 features,
# and the level-2 skip (c2), concatenated along the default (last) axis.
u8 = Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(c7)
att_3_upsampled = UpSampling2D(size=(2, 2))(att_3)  # Upsample the attention block
u8 = concatenate([u8, att_3_upsampled, c2])

In [None]:
# Decoder level 8 convolutions (32 filters): conv -> batch norm -> dropout(0.1) -> conv.
c8 = Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u8)
c8 = BatchNormalization()(c8)
c8 = Dropout(0.1)(c8)
c8 = Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c8)


In [None]:
# Attention gate on the level-2 skip features: gate derived from c8 (32 channels).
gating_2 = gating_signal(c8, 32)
att_2 = attention_block(c2, gating_2, 32)

# Decoder level 9: upsampled c8, 2x-upsampled gated level-2 features and the
# level-1 skip (c1) are concatenated on the channel axis, followed by
# conv -> batch norm -> dropout -> conv with 16 filters.
u9 = Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same')(c8)
att_2_upsampled = UpSampling2D(size=(2, 2))(att_2)  # Upsample the attention block
u9 = concatenate([u9, att_2_upsampled, c1], axis=3)
c9 = Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u9)
c9 = BatchNormalization()(c9)
c9 = Dropout(0.1)(c9)
c9 = Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c9)


In [None]:
# Output head: 1x1 conv + sigmoid gives a single-channel map in (0, 1).
outputs = Conv2D(1, (1, 1), activation='sigmoid')(c9)

# Wrap the graph in a Model so the attention network can actually be compiled
# and trained; the other architecture sections all end with a Model(...) call,
# while this one previously stopped at the output tensor. A dedicated name is
# used so that no existing `model` binding is replaced.
attention_unet = Model(inputs=[inputs], outputs=[outputs])


**DENSE UNet**

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPooling2D, UpSampling2D, concatenate, Dropout
from tensorflow.keras.models import Model

# Seed TensorFlow's global RNG so weight initialisation is repeatable.
seed = 42
tf.random.set_seed(seed)

IMG_WIDTH = 128
IMG_HEIGHT = 128
IMG_CHANNELS = 3

# Build the Dense U-Net model (each stage concatenates its first conv output
# with its second conv output).
inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
# 'linear' is the identity activation, so inputs enter the network unscaled
# (the earlier U-Net sections divide by 255 here instead).
s = Activation('linear')(inputs)  # You can change the activation function as needed

# Contraction path
# Level 1 (16 filters): conv -> batch norm -> ReLU -> conv -> dropout, then the
# first conv output (conv1) is concatenated with the processed branch (c1).
conv1 = Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(s)
c1 = BatchNormalization()(conv1)
c1 = Activation('relu')(c1)
# The second conv must consume the normalised/activated tensor c1, exactly as
# levels 2-5 do. It previously read `s`, which silently discarded the
# BatchNormalization and ReLU layers created just above.
c1 = Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c1)
c1 = Dropout(0.2)(c1)
c1 = concatenate([conv1, c1], axis=3)
p1 = MaxPooling2D((2, 2))(c1)

# Encoder levels 2-5 share one layout: conv -> batch norm -> ReLU -> conv ->
# dropout, then concatenate the first conv output with the processed branch
# (doubling the channel count) before the 2x2 max-pool.
conv2 = Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p1)
c2 = BatchNormalization()(conv2)
c2 = Activation('relu')(c2)
c2 = Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c2)
c2 = Dropout(0.2)(c2)
c2 = concatenate([conv2, c2], axis=3)
p2 = MaxPooling2D((2, 2))(c2)

conv3 = Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p2)
c3 = BatchNormalization()(conv3)
c3 = Activation('relu')(c3)
c3 = Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c3)
c3 = Dropout(0.2)(c3)
c3 = concatenate([conv3, c3], axis=3)
p3 = MaxPooling2D((2, 2))(c3)

# Level 4 applies dropout twice: 0.3 on the branch and 0.5 on the concatenation.
conv4 = Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p3)
c4 = BatchNormalization()(conv4)
c4 = Activation('relu')(c4)
c4 = Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c4)
c4 = Dropout(0.3)(c4)
c4 = concatenate([conv4, c4], axis=3)
c4 = Dropout(0.5)(c4)
p4 = MaxPooling2D((2, 2))(c4)

# Bottleneck (256 filters): same double dropout as level 4, no pooling afterwards.
conv5 = Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p4)
c5 = BatchNormalization()(conv5)
c5 = Activation('relu')(c5)
c5 = Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c5)
c5 = Dropout(0.3)(c5)
c5 = concatenate([conv5, c5], axis=3)
c5 = Dropout(0.5)(c5)


# Decoder: each level upsamples 2x (UpSampling2D followed by a 2x2 conv),
# concatenates the matching encoder output, then repeats the dense pattern
# conv -> batch norm -> ReLU -> conv -> concatenate with the first conv output.
# NOTE(review): Dropout(0) has rate 0 and therefore drops nothing; confirm
# whether a non-zero rate was intended.
u6 = Conv2D(128, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(c5))
u6 = concatenate([c4, u6], axis=3)
conv6 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(u6)
u6 = BatchNormalization(axis=3)(conv6)
u6 = Activation('relu')(u6)
u6 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(u6)
u6 = Dropout(0)(u6)
u6 = concatenate([conv6, u6], axis=3)

u7 = Conv2D(64, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(u6))
u7 = concatenate([c3, u7], axis=3)
conv7 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(u7)
u7 = BatchNormalization(axis=3)(conv7)
u7 = Activation('relu')(u7)
u7 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(u7)
u7 = Dropout(0)(u7)
u7 = concatenate([conv7, u7], axis=3)

u8 = Conv2D(32, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(u7))
u8 = concatenate([c2, u8], axis=3)
conv8 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(u8)
u8 = BatchNormalization(axis=3)(conv8)
u8 = Activation('relu')(u8)
u8 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(u8)
u8 = Dropout(0)(u8)
u8 = concatenate([conv8, u8], axis=3)

u9 = Conv2D(16, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(u8))
u9 = concatenate([c1, u9], axis=3)
conv9 = Conv2D(16, 3, activation='relu', padding='same', kernel_initializer='he_normal')(u9)
u9 = BatchNormalization(axis=3)(conv9)
u9 = Activation('relu')(u9)
u9 = Conv2D(16, 3, activation='relu', padding='same', kernel_initializer='he_normal')(u9)
u9 = Dropout(0)(u9)
u9 = concatenate([conv9, u9], axis=3)

# Output head: a 2-channel 1x1 conv followed by a 1-channel 1x1 conv.
# The name conv9 is reused here; the decoder's conv9 tensor is no longer needed.
# NOTE(review): both 1x1 convs use a sigmoid, so the final sigmoid is applied
# to values already squashed into (0, 1) — confirm this is intended.
conv9 = Conv2D(2, (1, 1), activation='sigmoid')(u9)

outputs = Conv2D(1, (1, 1), activation='sigmoid')(conv9)

# Create the model (rebinds `model`; it is not compiled in this cell).
model = Model(inputs=inputs, outputs=outputs)



**RESNet attention UNet**

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, Dropout, MaxPooling2D, UpSampling2D, concatenate
from tensorflow.keras.models import Model


In [None]:
def res_conv_block(x, filter_size, size, dropout, batch_norm=False):
    """Two-convolution block with a 1x1-projected shortcut.

    Main path: conv -> (BN) -> ReLU -> conv -> (BN) -> (dropout).
    Shortcut path: 1x1 conv -> (BN) on the block input.
    The two paths are concatenated and passed through a final ReLU.

    Args:
        x: Input feature tensor.
        filter_size: Side length of the square kernel used on the main path.
        size: Number of filters of every convolution in the block.
        dropout: Dropout rate for the main path; skipped unless greater than 0.
        batch_norm: BatchNormalization layers are added only when this is
            exactly ``True``.

    Returns:
        Activated tensor holding the shortcut and main-path features.
    """
    kernel = (filter_size, filter_size)

    # Main path, first stage.
    main = Conv2D(size, kernel, padding='same')(x)
    if batch_norm is True:
        main = BatchNormalization()(main)
    main = Activation('relu')(main)

    # Main path, second stage; the activation is deferred until after the merge.
    main = Conv2D(size, kernel, padding='same')(main)
    if batch_norm is True:
        main = BatchNormalization()(main)
    if dropout > 0:
        main = Dropout(dropout)(main)

    # Shortcut: 1x1 projection of the block input to `size` channels.
    skip = Conv2D(size, kernel_size=(1, 1), padding='same')(x)
    if batch_norm is True:
        skip = BatchNormalization()(skip)

    # NOTE(review): the paths are concatenated rather than added, so the output
    # carries 2 * size channels — confirm this is intended for a "residual" block.
    merged = concatenate([skip, main])
    return Activation('relu')(merged)


In [None]:
def gating_signal(input, out_size, batch_norm=False):
    """Build a gating signal: 1x1 conv -> optional batch norm -> ReLU.

    NOTE(review): this re-defines, with identical logic, the gating_signal
    function declared earlier in the notebook (Attention UNet section).

    Args:
        input: Feature tensor the gate is derived from.
        out_size: Number of output channels of the 1x1 convolution.
        batch_norm: When truthy, insert BatchNormalization before the ReLU.

    Returns:
        The activated gating tensor with ``out_size`` channels.
    """
    layers = tf.keras.layers
    gate = layers.Conv2D(out_size, (1, 1), padding='same')(input)
    if batch_norm:
        gate = layers.BatchNormalization()(gate)
    return layers.Activation('relu')(gate)

In [None]:
def repeat_elem(tensor, rep):
    """Repeat every element of ``tensor`` ``rep`` times along axis 3.

    Example: an input of shape (None, 256, 256, 3) with ``rep=2`` yields a
    tensor of shape (None, 256, 256, 6).

    NOTE(review): this re-defines, with identical logic, the repeat_elem
    function declared earlier in the notebook (Attention UNet section).

    Args:
        tensor: Input tensor; axis 3 is the one that gets repeated.
        rep: Repetition factor.

    Returns:
        Output of a Lambda layer wrapping ``tf.repeat(..., rep, axis=3)``.
    """
    channel_repeat = tf.keras.layers.Lambda(lambda t: tf.repeat(t, rep, axis=3))
    return channel_repeat(tensor)

In [None]:
def attention_block(x, gating, inter_shape):
    """Additive attention gate: re-weight ``x`` with coefficients driven by ``gating``.

    NOTE(review): this re-defines, with identical logic, the attention_block
    function declared earlier in the notebook (Attention UNet section).

    Args:
        x: Skip-connection feature tensor (4-D; axis 3 holds the channels).
        gating: Gating tensor; its spatial size must divide that of ``x``.
        inter_shape: Channel count of the intermediate projections.

    Returns:
        Tensor with the same channel count as ``x``: the gated features passed
        through a 1x1 convolution and batch normalisation.
    """
    layers = tf.keras.layers
    int_shape = tf.keras.backend.int_shape

    x_shape = int_shape(x)
    g_shape = int_shape(gating)

    # Project x to inter_shape channels; stride 1 with 'same' padding keeps its grid.
    theta = layers.Conv2D(inter_shape, (2, 2), strides=(1, 1), padding='same')(x)
    theta_shape = int_shape(theta)

    # Project the gate to inter_shape channels, then resize it onto theta's grid.
    phi = layers.Conv2D(inter_shape, (1, 1), padding='same')(gating)
    gate_strides = (theta_shape[1] // g_shape[1], theta_shape[2] // g_shape[2])
    phi_up = layers.Conv2DTranspose(inter_shape, (3, 3),
                                    strides=gate_strides,
                                    padding='same')(phi)

    # Element-wise sum -> ReLU -> 1x1 conv -> sigmoid: one coefficient per pixel.
    summed = layers.add([phi_up, theta])
    activated = layers.Activation('relu')(summed)
    logits = layers.Conv2D(1, (1, 1), padding='same')(activated)
    coeffs = layers.Activation('sigmoid')(logits)

    # Bring the coefficients to x's spatial size and copy them across x's channels.
    coeff_shape = int_shape(coeffs)
    up_factor = (x_shape[1] // coeff_shape[1], x_shape[2] // coeff_shape[2])
    coeffs = layers.UpSampling2D(size=up_factor)(coeffs)
    coeffs = repeat_elem(coeffs, x_shape[3])

    gated = layers.multiply([coeffs, x])

    projected = layers.Conv2D(x_shape[3], (1, 1), padding='same')(gated)
    return layers.BatchNormalization()(projected)

In [None]:
# Seed TensorFlow's global RNG so weight initialisation is repeatable.
seed = 42
tf.random.set_seed(seed)

IMG_WIDTH = 128
IMG_HEIGHT = 128
IMG_CHANNELS = 3

# Build the residual attention U-Net.
inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
# 'linear' is the identity activation, so inputs enter the network unscaled.
s = Activation('linear')(inputs)  # You can change the activation function as needed

# Contraction path: a stem conv, then one res_conv_block per level, each
# followed by a 2x2 max-pool. c1..c4 are kept for the skip connections.
c1 = Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(s)
c1 = BatchNormalization()(c1)
c1 = Dropout(0.1)(c1)
c1 = res_conv_block(c1, 3, 16, 0.1, True)
p1 = MaxPooling2D((2, 2))(c1)

c2 = res_conv_block(p1, 3, 32, 0.1, True)
p2 = MaxPooling2D((2, 2))(c2)

c3 = res_conv_block(p2, 3, 64, 0.2, True)
p3 = MaxPooling2D((2, 2))(c3)

c4 = res_conv_block(p3, 3, 128, 0.2, True)
p4 = MaxPooling2D(pool_size=(2, 2))(c4)

# Bottleneck.
c5 = res_conv_block(p4, 3, 256, 0.3, True)

# Expansive path
u6 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(c5)
u6 = concatenate([u6, c4])
c6 = res_conv_block(u6, 3, 128, 0.2, True)

# Attention gates: each skip tensor is re-weighted by attention_block using the
# gating signal computed from the decoder features. Previously the gating
# signals were computed but never used and the att_* tensors came from a plain
# res_conv_block, so no attention was applied anywhere in this model.
gating_4 = gating_signal(c6, 128)
att_4 = attention_block(c4, gating_4, 128)

u7 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(c6)
# Gated features sit one level deeper than u7/c3, hence the 2x upsampling.
att_4_upsampled = UpSampling2D(size=(2, 2))(att_4)
u7 = concatenate([u7, att_4_upsampled, c3])
c7 = res_conv_block(u7, 3, 64, 0.2, True)

gating_3 = gating_signal(c7, 64)
att_3 = attention_block(c3, gating_3, 64)

u8 = Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(c7)
att_3_upsampled = UpSampling2D(size=(2, 2))(att_3)
u8 = concatenate([u8, att_3_upsampled, c2])

c8 = res_conv_block(u8, 3, 32, 0.1, True)

gating_2 = gating_signal(c8, 32)
att_2 = attention_block(c2, gating_2, 32)

u9 = Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same')(c8)
att_2_upsampled = UpSampling2D(size=(2, 2))(att_2)
u9 = concatenate([u9, att_2_upsampled, c1])

c9 = res_conv_block(u9, 3, 16, 0.1, True)

# Output head: 1x1 conv + sigmoid gives a single-channel map in (0, 1).
outputs = Conv2D(1, (1, 1), activation='sigmoid')(c9)

# Create the model (rebinds `model`; it is not compiled in this cell).
model = Model(inputs=inputs, outputs=outputs)