In [None]:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Lambda, Conv2D, Conv2DTranspose, Dropout, MaxPooling2D,
                                     AveragePooling2D, UpSampling2D, DepthwiseConv2D, SeparableConv2D,
                                     GlobalAveragePooling2D, GlobalMaxPooling2D, Dense, Reshape, Multiply,
                                     Concatenate, Add, BatchNormalization, Activation, ReLU, LayerNormalization)
from tensorflow.keras.activations import gelu
tf.keras.backend.clear_session()
from keras.applications import EfficientNetV2M


def repeat_elem(tensor, rep):
    """Repeat every element of ``tensor`` ``rep`` times along axis 3.

    The repetition is done by ``K.repeat_elements`` wrapped in a Keras
    ``Lambda`` layer so it becomes part of the functional graph. ``rep`` is
    handed to the wrapped function through ``arguments`` as ``repnum``.
    """
    def _repeat_along_axis3(t, repnum):
        return K.repeat_elements(t, repnum, axis=3)

    repeat_layer = tf.keras.layers.Lambda(_repeat_along_axis3, arguments={'repnum': rep})
    return repeat_layer(tensor)

def ECA(x, k_size=5):
    """Efficient Channel Attention for channels-last feature maps.

    Squeezes ``x`` to one descriptor per channel by averaging over the two
    spatial axes. A bias-free 1-D convolution of width ``k_size`` then runs
    across the channel axis to model local cross-channel interaction. Finally
    ``x`` is rescaled by the sigmoid of the result.

    Args:
        x: 4-D feature map laid out as (batch, height, width, channels).
        k_size: Width of the 1-D convolution over channels. Defaults to 5,
            the value this block always used.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Global average pool over height and width -> (batch, channels).
    # Reducing axes (2, 3) would be the channels-first form; for NHWC it
    # averages over width *and* channels and destroys the per-channel signal.
    squeeze = tf.reduce_mean(x, axis=(1, 2))
    # Conv1D expects (batch, steps, features): make channels the "steps" axis
    # so the kernel slides across neighbouring channels -> (batch, channels, 1).
    squeeze = tf.expand_dims(squeeze, axis=-1)
    attn = tf.keras.layers.Conv1D(1, k_size, padding='same', use_bias=False)(squeeze)
    # (batch, channels, 1) -> (batch, 1, channels) -> (batch, 1, 1, channels)
    # so the weights broadcast over every spatial position of x.
    attn = tf.expand_dims(tf.transpose(attn, [0, 2, 1]), 1)
    attn = tf.math.sigmoid(attn)
    return x * attn

def spatial_attention(input_tensor):
    """Gate ``input_tensor`` by one learned sigmoid weight per sample.

    Global average- and max-pooled descriptors are concatenated and fed to a
    one-unit sigmoid ``Dense`` layer. The resulting value is reshaped to
    (1, 1, 1) and multiplied into every position and channel of the input.

    NOTE(review): despite the name, the gate carries no per-pixel (spatial)
    information, because both poolings remove the spatial axes first.
    """
    avg_desc = GlobalAveragePooling2D()(input_tensor)
    max_desc = GlobalMaxPooling2D()(input_tensor)
    descriptor = Concatenate()([avg_desc, max_desc])
    gate = Dense(1, activation='sigmoid')(descriptor)
    gate = Reshape((1, 1, -1))(gate)
    return Multiply()([input_tensor, gate])

def channel_shuffle(x, groups):
    """Interleave the channels of a 4-D tensor across ``groups`` groups.

    The last axis is split into ``groups`` contiguous groups. The group axis
    and the within-group axis are swapped, and the result is flattened back,
    so the output has the same shape as ``x``. The channel count must be
    divisible by ``groups``, otherwise the reshape fails.
    """
    dims = tf.shape(x)
    n, h, w, c = dims[0], dims[1], dims[2], dims[3]
    per_group = c // groups
    grouped = tf.reshape(x, [n, h, w, groups, per_group])
    swapped = tf.transpose(grouped, [0, 1, 2, 4, 3])
    return tf.reshape(swapped, [n, h, w, c])

def SEDNet(x, repetition, filters, strides, dilation, groups):
    """Densely connected stack of pre-activated bottleneck units with shuffle.

    Each of the ``repetition`` iterations does the following:
      1. BatchNorm + ReLU on the running tensor;
      2. a 1x1 conv to ``4 * filters`` channels (using ``strides``);
      3. a 3x3 ReLU conv to ``filters`` channels (using ``dilation``);
      4. concatenates those new features in front of the normalized running
         tensor, so the channel count grows by ``filters`` per iteration;
      5. a ``channel_shuffle`` over ``groups``.

    NOTE: the concatenation needs the new features and the running tensor to
    share spatial size. ``strides`` other than 1 will therefore fail there;
    the visible call sites all pass 1. The concatenated channel count must
    also be divisible by ``groups`` for the shuffle.
    """
    out = x
    for _ in range(repetition):
        out = BatchNormalization()(out)
        out = ReLU()(out)
        bottleneck = Conv2D(4 * filters, 1, strides=strides, padding="same")(out)
        fresh = Conv2D(filters, 3, activation='relu', padding="same",
                       dilation_rate=dilation)(bottleneck)
        merged = Concatenate()([fresh, out])
        out = channel_shuffle(merged, groups)
    return out

def SAEF(filters, length, inp):
    """Sum a projected shortcut with two attention-gated conv branches.

    Shortcut: 1x1 bias-free conv -> BatchNorm.
    Branches: 3x3 and 5x5 conv -> BatchNorm -> ReLU -> ``spatial_attention``.
    The three tensors are added, then passed through ReLU and a final
    BatchNorm.

    Args:
        filters: Channel count of every path and of the output.
        length: Not used by this implementation; kept for call compatibility.
        inp: Input feature map.
    """
    def _gated_branch(kernel):
        # conv -> BN -> ReLU -> scalar attention gate
        feat = Conv2D(filters, kernel, padding='same')(inp)
        feat = BatchNormalization()(feat)
        feat = Activation('relu')(feat)
        return spatial_attention(feat)

    projected = Conv2D(filters, (1, 1), padding='same', use_bias=False)(inp)
    projected = BatchNormalization()(projected)
    small_rf = _gated_branch((3, 3))
    large_rf = _gated_branch((5, 5))
    total = Add()([projected, small_rf, large_rf])
    total = Activation('relu')(total)
    return BatchNormalization()(total)


def EMSCA(x, filters):
    """Multi-scale context aggregation with ECA-refined dilated branches.

    Four parallel branches are concatenated and fused:
      * y1, image-level context: average pool over the whole spatial
        extent, 1x1 conv + BN + ReLU, then bilinear upsampling back to the
        input size;
      * y2..y4: residual depthwise-separable branches with dilation 4, 8
        and 12, each refined by ``ECA``.
    The concatenation is fused by a 3x3 depthwise conv, then a 1x1 conv to
    ``filters`` channels, BN and ReLU.

    Args:
        x: 4-D channels-last feature map. Height and width must be known
            statically because they size the pooling and upsampling layers.
        filters: Channel count of every branch and of the output.

    Returns:
        Tensor of shape (batch, H, W, filters).
    """
    shape = x.shape
    in_channels = shape[-1]

    # Image-level context: collapse the whole map to 1x1, project, broadcast back.
    y1 = AveragePooling2D(pool_size=(shape[1], shape[2]))(x)
    y1 = Conv2D(filters, (1, 1), padding="same")(y1)
    y1 = BatchNormalization()(y1)
    y1 = Activation("relu")(y1)
    y1 = UpSampling2D((shape[1], shape[2]), interpolation='bilinear')(y1)

    def branch(inp, dilation):
        """Residual dilated depthwise-separable branch refined by ECA."""
        y = Conv2D(filters, (1, 1), padding="same")(inp)
        y = DepthwiseConv2D((3, 3), dilation_rate=(dilation, dilation), padding="same")(y)
        y = Conv2D(filters, (1, 1), padding="same", use_bias=False)(y)
        y = BatchNormalization()(y)
        y = Activation("relu")(y)
        y = ECA(y)
        y = Conv2D(filters, (1, 1), padding="same")(y)
        shortcut = inp
        if in_channels is not None and in_channels != filters:
            # Add() needs matching channel counts; project the residual only
            # when they differ, so the matching case is unchanged.
            shortcut = Conv2D(filters, (1, 1), padding="same", use_bias=False)(inp)
        return Add()([y, shortcut])

    y2 = branch(x, 4)
    y3 = branch(x, 8)
    y4 = branch(x, 12)

    y = Concatenate()([y1, y2, y3, y4])
    y = DepthwiseConv2D((3, 3), padding="same")(y)
    y = Conv2D(filters, (1, 1), padding="same", use_bias=False)(y)
    y = BatchNormalization()(y)
    y = Activation("relu")(y)
    return y

def mscc_branch(x, kernel_size, filters=16):
    """One strip-convolution branch of MSCC.

    The branch applies a vertical (k x 1) depthwise conv and then a
    horizontal (1 x k) depthwise conv, each followed by ReLU. A 1x1 ReLU conv
    then mixes channels down to ``filters`` outputs. Spatial size is
    preserved ('same' padding).
    """
    out = x
    for strip in ((kernel_size, 1), (1, kernel_size)):
        out = DepthwiseConv2D(strip, padding='same', activation='relu')(out)
    return Conv2D(filters, (1, 1), padding='same', activation='relu')(out)

def MSCC(input_tensor, filters=16, kernel_sizes=(11, 9, 7, 5)):
    """Multi-scale strip-convolution context block.

    Runs one ``mscc_branch`` per entry of ``kernel_sizes`` on
    ``input_tensor``. The branch outputs are concatenated along the channel
    axis, so the result has ``filters * len(kernel_sizes)`` channels.

    Args:
        input_tensor: Input feature map.
        filters: Output channels of each branch.
        kernel_sizes: Strip lengths, one branch each. Defaults to the
            previously hard-coded (11, 9, 7, 5).

    Raises:
        ValueError: if ``kernel_sizes`` is empty.
    """
    kernel_sizes = tuple(kernel_sizes)
    if not kernel_sizes:
        raise ValueError("MSCC needs at least one kernel size")
    branches = [mscc_branch(input_tensor, k, filters) for k in kernel_sizes]
    if len(branches) == 1:
        # A merge layer is pointless (and may be rejected) for a single input.
        return branches[0]
    return Concatenate()(branches)


def gating_signal(inputs, out_size):
    """Project ``inputs`` to ``out_size`` channels for use as an attention gate.

    Applies a 1x1 conv, then BatchNorm, then ReLU. Spatial size is preserved.
    """
    projected = Conv2D(out_size, (1, 1), padding='same')(inputs)
    normalized = BatchNormalization()(projected)
    return Activation('relu')(normalized)

def attention_block(x, gating, inter_shape):
    """Additive attention gate applied to ``x``, driven by ``gating``.

    Steps:
      * ``x`` is downsampled 2x by a strided 2x2 conv (theta).
      * ``gating`` is projected by a 1x1 conv (phi) and brought to theta's
        size by a transposed conv whose stride is the size ratio.
      * The two are added, passed through ReLU, and reduced to a 1-channel
        sigmoid map (psi).
      * That map is upsampled back to ``x``'s size, repeated across ``x``'s
        channels, and multiplied into ``x``.
      * A 1x1 conv + BatchNorm produce the output, which has ``x``'s
        channel count.

    ``x`` and ``gating`` must have statically known shapes (``K.int_shape``
    values are used in integer arithmetic).
    """
    x_shape = K.int_shape(x)
    g_shape = K.int_shape(gating)

    theta_x = Conv2D(inter_shape, (2, 2), strides=(2, 2), padding='same')(x)
    phi_g = Conv2D(inter_shape, (1, 1), padding='same')(gating)
    gate_strides = (theta_x.shape[1] // g_shape[1], theta_x.shape[2] // g_shape[2])
    phi_up = Conv2DTranspose(inter_shape, (3, 3), strides=gate_strides, padding='same')(phi_g)

    summed = Add()([phi_up, theta_x])
    summed = Activation('relu')(summed)
    psi = Conv2D(1, (1, 1), padding='same')(summed)
    coeffs = Activation('sigmoid')(psi)

    up_factor = (x_shape[1] // coeffs.shape[1], x_shape[2] // coeffs.shape[2])
    coeffs = UpSampling2D(size=up_factor)(coeffs)
    coeffs = repeat_elem(coeffs, x_shape[3])
    gated = Multiply()([coeffs, x])
    gated = Conv2D(x_shape[3], (1, 1), padding='same')(gated)
    return BatchNormalization()(gated)


def SS_MLP(input_tensor, num_channels, shift_amount=5, dropout_rate=0.1):
    """Shifted depthwise-conv mixing along height and width, with a residual sum.

    The input is layer-normalized and then processed by two parallel paths.
    Each path:
      * shifts the normalized map by ``shift_amount`` pixels along one
        spatial axis, zero-filling the vacated border;
      * applies a 3x3 depthwise conv, GELU and dropout;
      * shifts the result back by the same amount, so the features realign
        with the unshifted map.
    The normalized map and the two paths are summed.

    Args:
        input_tensor: 4-D tensor; axes 1 and 2 are treated as the spatial axes.
        num_channels: Not used by this implementation; kept for call compatibility.
        shift_amount: Pixel offset of the shift (0 disables shifting).
        dropout_rate: Dropout rate applied after GELU in each path.

    Returns:
        Tensor with the same shape as ``input_tensor``.
    """
    def shift_tensor(x, shift, axis):
        """Zero-filled shift of ``x`` by ``shift`` pixels along ``axis`` (1 or 2).

        A positive ``shift`` moves content toward higher indices and a
        negative one toward lower indices. The output keeps ``x``'s shape.
        Other axes, or a zero shift, return ``x`` unchanged.
        """
        if shift == 0 or axis not in (1, 2):
            return x
        amount = abs(shift)
        paddings = [[0, 0], [0, 0], [0, 0], [0, 0]]
        index = [slice(None)] * 4
        if shift > 0:
            paddings[axis] = [amount, 0]
            index[axis] = slice(None, -amount)
        else:
            # Negative shifts must be handled too: the "shift back" below
            # relies on them, and ignoring them left features misaligned.
            paddings[axis] = [0, amount]
            index[axis] = slice(amount, None)
        return tf.pad(x, paddings)[tuple(index)]

    def apply_shifted_mlp(x, shift, axis):
        x_shifted = shift_tensor(x, shift, axis)
        x_shifted = DepthwiseConv2D(kernel_size=(3, 3), padding='same')(x_shifted)
        x_shifted = Activation(gelu)(x_shifted)
        x_shifted = Dropout(dropout_rate)(x_shifted)
        # Undo the shift so the processed features line up with `x` in the Add.
        return shift_tensor(x_shifted, -shift, axis)

    x = LayerNormalization()(input_tensor)
    x_h = apply_shifted_mlp(x, shift_amount, axis=1)
    x_w = apply_shifted_mlp(x, shift_amount, axis=2)
    return Add()([x, x_h, x_w])

def ERC(input, num_filters, dilation_rate=1):
    """Conv block with a concatenated (not added) shortcut, refined by ECA.

    Main path: two stages of 3x3 conv (with ``dilation_rate``), BatchNorm and
    ReLU. Shortcut: 1x1 conv followed by BatchNorm on the block input. The
    shortcut and main path are concatenated, giving ``2 * num_filters``
    channels, then passed through ReLU and ``ECA``.
    """
    main = input
    for _ in range(2):
        main = Conv2D(num_filters, 3, padding="same", dilation_rate=dilation_rate)(main)
        main = BatchNormalization()(main)
        main = Activation("relu")(main)
    skip = Conv2D(num_filters, 1, padding='same')(input)
    skip = BatchNormalization()(skip)
    merged = Concatenate()([skip, main])
    merged = Activation("relu")(merged)
    return ECA(merged)

def cross_attention_module(high_level_features, low_level_features, filters):
    """Fuse coarse and fine features, then re-weight the result per channel.

    Each input passes through a 1x1 then a 3x3 ReLU conv to ``filters``
    channels. The high-level path is bilinearly resized to the spatial size
    of ``low_level_features``. The two are added and passed through ReLU.
    The sum is then multiplied by a channel-attention vector (global average
    pool, reshape to 1x1, 1x1 sigmoid conv).
    """
    def _conv_pair(t):
        t = Conv2D(filters, (1, 1), padding='same', activation='relu')(t)
        return Conv2D(filters, (3, 3), padding='same', activation='relu')(t)

    def _resize_to_second(pair):
        return tf.image.resize(pair[0], tf.shape(pair[1])[1:3], method='bilinear')

    high = _conv_pair(high_level_features)
    high = Lambda(_resize_to_second)([high, low_level_features])
    low = _conv_pair(low_level_features)

    fused = Add()([high, low])
    fused = ReLU()(fused)
    channel_weights = GlobalAveragePooling2D()(fused)
    channel_weights = Reshape((1, 1, filters))(channel_weights)
    channel_weights = Conv2D(filters, (1, 1), activation='sigmoid')(channel_weights)
    return Multiply()([fused, channel_weights])

def decoder_block(input, skip, filters, dilation_rate=1):
    """One decoder stage: attend to the skip connection, upsample, fuse, refine.

    Steps:
      1. ``cross_attention_module`` fuses ``input`` into ``skip``.
      2. A ``gating_signal`` from ``input`` drives ``attention_block`` on
         that fused map.
      3. ``input`` is bilinearly upsampled 2x and concatenated with the
         attended features.
      4. ``SS_MLP`` and then ``ERC`` (with ``dilation_rate``) refine the
         result. ERC's concatenated output has ``2 * filters`` channels.
    """
    fused_skip = cross_attention_module(input, skip, filters)
    gate = gating_signal(input, filters)
    attended = attention_block(fused_skip, gate, filters)
    upsampled = UpSampling2D((2, 2), interpolation="bilinear")(input)
    merged = Concatenate()([upsampled, attended])
    mixed = SS_MLP(merged, filters, shift_amount=5, dropout_rate=0.1)
    return ERC(mixed, filters, dilation_rate=dilation_rate)

# --- Model Architecture ---
inputs = Input((256, 256, 3))
# NOTE(review): the backbone and the custom encoder (MSCC and the stage-1
# skip below) both consume raw `inputs`; no rescaling is applied in this
# script. Keras' EfficientNetV2 rescales internally by default
# (include_preprocessing=True), but MSCC sees unscaled pixels. Confirm this
# is intended.
# NOTE(review): EfficientNetV2M is imported from standalone `keras`, while
# every other layer comes from `tensorflow.keras`. Under Keras 3 these are
# different packages; confirm the installed versions are compatible.
irnet = EfficientNetV2M(include_top=False, weights="imagenet", input_tensor=inputs)

# Backbone outputs.
# `input_tensor=inputs` makes the backbone use this very tensor as its image
# input, so reference it directly. This avoids depending on the
# auto-generated layer name "input_1", which changes if another Input is
# created first.
s = inputs
s2 = irnet.get_layer("block1c_project_activation").output
s3 = irnet.get_layer("block2c_expand_activation").output
s4 = irnet.get_layer("block3c_expand_activation").output

print("Backbone s shape:", s.shape)
print("Backbone s2 shape:", s2.shape)
print("Backbone s3 shape:", s3.shape)
print("Backbone s4 shape:", s4.shape)

mscc_out = MSCC(s)
print("After MSCC:", mscc_out.shape)

# Encoder: each stage concatenates a backbone feature map of matching size,
# pools for the next stage, then refines its own skip output with SAEF.
c1 = SEDNet(mscc_out, repetition=2, filters=16, strides=1, dilation=1, groups=4)
c1 = Concatenate()([c1, s])
print("After Encoder Stage 1 (Concatenated with s):", c1.shape)
p1 = MaxPooling2D((2, 2))(c1)
c1 = SAEF(16, 4, c1)

c2 = SEDNet(ERC(p1, 32), 2, 32, 1, 2, 4)
c2 = Concatenate()([c2, s2])
print("After Encoder Stage 2:", c2.shape)
p2 = MaxPooling2D((2, 2))(c2)
c2 = SAEF(32, 3, c2)

c3 = SEDNet(ERC(p2, 64), 2, 64, 1, 3, 4)
c3 = Concatenate()([c3, s3])
print("After Encoder Stage 3:", c3.shape)
p3 = MaxPooling2D((2, 2))(c3)
c3 = SAEF(64, 2, c3)

c4 = SEDNet(ERC(p3, 128), 2, 128, 1, 4, 4)
c4 = Concatenate()([c4, s4])
print("After Encoder Stage 4:", c4.shape)
p4 = MaxPooling2D((2, 2))(c4)
c4 = SAEF(128, 1, c4)

# Bottleneck: a shifted-MLP path from the last skip and a plain conv path
# from the pooled features, both at 256 channels, summed.
mlp_out = SS_MLP(c4, 128, shift_amount=5, dropout_rate=0.1)
mlp_out = MaxPooling2D(pool_size=(2, 2))(mlp_out)
mlp_out = Conv2D(256, (1, 1), padding='same', activation='relu')(mlp_out)

conv_out = Conv2D(256, (3, 3), activation='relu', padding='same')(
               Dropout(0.2)(
                   Conv2D(256, (3, 3), activation='relu', padding='same')(p4)))

fused_bottleneck = Add()([mlp_out, conv_out])
print("After Bottleneck Fusion:", fused_bottleneck.shape)

c5 = EMSCA(fused_bottleneck, 256)
print("After EMSCA:", c5.shape)

d1 = decoder_block(c5, c4, 256, 1)
print("After Decoder Stage 1:", d1.shape)
d2 = decoder_block(d1, c3, 128, 2)
print("After Decoder Stage 2:", d2.shape)
d3 = decoder_block(d2, c2, 128, 4)
print("After Decoder Stage 3:", d3.shape)
d4 = decoder_block(d3, c1, 64, 8)
print("After Decoder Stage 4:", d4.shape)

outputs = Conv2D(1, (1, 1), activation='sigmoid')(d4)
print("Final Output:", outputs.shape)

# Name the model through the public constructor argument rather than the
# private `_name` attribute.
model = Model(inputs=inputs, outputs=outputs, name="MSAC-Net")

model.summary()

Backbone s shape: (None, 256, 256, 3)
Backbone s2 shape: (None, 128, 128, 24)
Backbone s3 shape: (None, 64, 64, 192)
Backbone s4 shape: (None, 32, 32, 320)
After MSCC: (None, 256, 256, 64)
After Encoder Stage 1 (Concatenated with s): (None, 256, 256, 99)
After Encoder Stage 2: (None, 128, 128, 152)
After Encoder Stage 3: (None, 64, 64, 448)
After Encoder Stage 4: (None, 32, 32, 832)
After Bottleneck Fusion: (None, 16, 16, 256)
After EMSCA: (None, 16, 16, 256)
After Decoder Stage 1: (None, 32, 32, 512)
After Decoder Stage 2: (None, 64, 64, 256)
After Decoder Stage 3: (None, 128, 128, 256)
After Decoder Stage 4: (None, 256, 256, 128)
Final Output: (None, 256, 256, 1)
Model: "MSAC-Net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 256, 256, 3  0           []                               
           