# In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.layers import Input, DepthwiseConv2D, Conv2D, BatchNormalization, LayerNormalization, ReLU, Add, Concatenate, Multiply, GlobalAveragePooling2D, Dense, Reshape, Conv1D, Activation, Lambda, UpSampling2D
from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, Dropout, Multiply, Add, BatchNormalization, Activation, UpSampling2D, Concatenate

def depthwise_conv(x, kernel_size, name=None):
    """Run ``x`` through a bias-free, 'same'-padded depthwise convolution.

    Args:
        x: Tensor the layer is applied to.
        kernel_size: Kernel size forwarded to ``DepthwiseConv2D``.
        name: Optional name for the created layer.

    Returns:
        The output tensor of the depthwise convolution.
    """
    dw_layer = layers.DepthwiseConv2D(
        kernel_size=kernel_size,
        padding='same',
        use_bias=False,
        name=name,
    )
    return dw_layer(x)

def pointwise_conv(x, filters, name=None):
    """Project ``x`` to ``filters`` channels with a bias-free 1x1 convolution.

    Args:
        x: Tensor the layer is applied to.
        filters: Number of output channels.
        name: Optional name for the created layer.

    Returns:
        The output tensor of the 1x1 convolution.
    """
    pw_layer = layers.Conv2D(
        filters,
        kernel_size=(1, 1),
        padding='same',
        use_bias=False,
        name=name,
    )
    return pw_layer(x)

def HMAA(inputs, block_size=7, grid_size=7, reduction_ratio=16, layer_name_prefix=""):
    """Refine a 4-D feature map with channel, windowed and spatial attention.

    The function chains three gating steps:

    1. Channel attention: global average- and max-pooled descriptors go
       through one shared two-layer Dense bottleneck, are summed, squashed
       with a sigmoid and multiplied onto ``inputs``.
    2. Height/width attention: the channel-refined map is regrouped into
       windows of ``block_size`` rows and, separately, ``grid_size`` columns;
       each group is layer-normalised and gated by a sigmoid Dense layer, and
       the two results are summed into a sigmoid gate.
    3. Spatial attention: a 7x7 convolution over the channel-wise mean and
       max produces a single-channel gate that is combined with the gate
       from step 2.

    NOTE(review): ``LayerNormalization`` (default axis) and ``Dense`` both act
    on the last (channel) axis only, so the window regrouping in step 2 does
    not appear to change the computed values. Confirm whether attention
    across the window axis was intended.

    Args:
        inputs: Feature map indexed as (batch, height, width, channels); the
            channel dimension must be statically known.
        block_size: Number of rows per window; must divide the height.
        grid_size: Number of columns per window; must divide the width.
        reduction_ratio: Divisor for the width of the channel bottleneck.
        layer_name_prefix: Prefix for the names of the created Keras layers.

    Returns:
        The refined feature map, produced by element-wise multiplications
        with ``inputs``.

    Raises:
        ValueError: If the statically known height (width) is not a multiple
            of ``block_size`` (``grid_size``).
    """
    # Fail early with a readable message; otherwise the reshapes below fail
    # with an opaque element-count mismatch.
    static_h, static_w = inputs.shape[1], inputs.shape[2]
    if static_h is not None and static_h % block_size != 0:
        raise ValueError(
            f"HMAA: input height {static_h} is not divisible by block_size {block_size}"
        )
    if static_w is not None and static_w % grid_size != 0:
        raise ValueError(
            f"HMAA: input width {static_w} is not divisible by grid_size {grid_size}"
        )

    B = tf.shape(inputs)[0]
    H, W, C = tf.shape(inputs)[1], tf.shape(inputs)[2], tf.shape(inputs)[3]

    channels = inputs.shape[-1]

    unique_id = tf.keras.backend.get_uid('dense')
    # Keep at least one unit so inputs narrower than reduction_ratio still work.
    squeeze_units = max(1, channels // reduction_ratio)
    shared_dense_one = layers.Dense(squeeze_units, activation='relu', name=f'{layer_name_prefix}dense_one_{unique_id}')
    shared_dense_two = layers.Dense(channels, activation='sigmoid', name=f'{layer_name_prefix}dense_two_{unique_id}')

    # --- Channel attention: the same Dense pair scores both pooled descriptors.
    local_att = layers.GlobalAveragePooling2D(name=f'{layer_name_prefix}global_avg_pool_{unique_id}')(inputs)
    local_att = shared_dense_one(local_att)
    local_att = shared_dense_two(local_att)
    local_att = layers.Reshape((1, 1, channels), name=f'{layer_name_prefix}reshape_local_{unique_id}')(local_att)

    global_att = layers.GlobalMaxPooling2D(name=f'{layer_name_prefix}global_max_pool_{unique_id}')(inputs)
    global_att = shared_dense_one(global_att)
    global_att = shared_dense_two(global_att)
    global_att = layers.Reshape((1, 1, channels), name=f'{layer_name_prefix}reshape_global_{unique_id}')(global_att)

    # Both branch outputs are already sigmoid-activated; their sum is squashed again.
    combined_channel_att = layers.Add(name=f'{layer_name_prefix}add_channel_{unique_id}')([local_att, global_att])
    combined_channel_att = layers.Activation('sigmoid', name=f'{layer_name_prefix}sigmoid_channel_{unique_id}')(combined_channel_att)

    channel_refined = layers.Multiply(name=f'{layer_name_prefix}multiply_channel_{unique_id}')([inputs, combined_channel_att])

    # --- Row windows: (B, H, W, C) -> (B * H/bs * W, bs, C).
    block_h = tf.reshape(channel_refined, (B, H // block_size, block_size, W, C))
    block_h = tf.transpose(block_h, [0, 1, 3, 2, 4])
    block_h = tf.reshape(block_h, (-1, block_size, C))

    block_h = LayerNormalization(name=f'{layer_name_prefix}layer_norm_block_h_{unique_id}')(block_h)
    block_h_att = layers.Dense(channels, activation='sigmoid', name=f'{layer_name_prefix}dense_block_h_{unique_id}')(block_h)
    block_h = block_h * block_h_att

    # Undo the regrouping (the permutation is its own inverse).
    block_h = tf.reshape(block_h, (B, H // block_size, W, block_size, C))
    block_h = tf.transpose(block_h, [0, 1, 3, 2, 4])
    block_h = tf.reshape(block_h, (B, H, W, C))

    # --- Column windows: (B, H, W, C) -> (B * W/gs * H, gs, C).
    grid_w = tf.reshape(channel_refined, (B, H, W // grid_size, grid_size, C))
    grid_w = tf.transpose(grid_w, [0, 2, 1, 3, 4])
    grid_w = tf.reshape(grid_w, (-1, grid_size, C))

    grid_w = LayerNormalization(name=f'{layer_name_prefix}layer_norm_grid_w_{unique_id}')(grid_w)
    grid_w_att = layers.Dense(channels, activation='sigmoid', name=f'{layer_name_prefix}dense_grid_w_{unique_id}')(grid_w)
    grid_w = grid_w * grid_w_att

    # Undo the regrouping (the permutation is its own inverse).
    grid_w = tf.reshape(grid_w, (B, W // grid_size, H, grid_size, C))
    grid_w = tf.transpose(grid_w, [0, 2, 1, 3, 4])
    grid_w = tf.reshape(grid_w, (B, H, W, C))

    height_width_att = layers.Add(name=f'{layer_name_prefix}add_height_width_{unique_id}')([block_h, grid_w])
    height_width_att = layers.Activation('sigmoid', name=f'{layer_name_prefix}sigmoid_height_width_{unique_id}')(height_width_att)

    spatial_refined = layers.Multiply(name=f'{layer_name_prefix}multiply_spatial_{unique_id}')([channel_refined, height_width_att])

    # --- Spatial attention from channel-wise mean and max statistics.
    avg_pool = tf.reduce_mean(spatial_refined, axis=-1, keepdims=True)
    max_pool = tf.reduce_max(spatial_refined, axis=-1, keepdims=True)
    concat = layers.Concatenate(axis=-1, name=f'{layer_name_prefix}concat_spatial_{unique_id}')([avg_pool, max_pool])
    spatial_att = layers.Conv2D(1, kernel_size=7, padding='same', activation='sigmoid', name=f'{layer_name_prefix}conv_spatial_{unique_id}')(concat)

    # The single-channel spatial gate broadcasts against the per-channel gate.
    combined_spatial_att = layers.Multiply(name=f'{layer_name_prefix}multiply_combined_spatial_{unique_id}')([height_width_att, spatial_att])
    refined_output = layers.Multiply(name=f'{layer_name_prefix}multiply_refined_output_{unique_id}')([spatial_refined, combined_spatial_att])

    return refined_output

class _GaussianMembership(tf.keras.layers.Layer):
    """Gaussian fuzzy membership with trainable centres (mu) and widths (sigma).

    For every membership function the layer evaluates
    ``exp(-(x - mu)^2 / (2 * sigma^2))`` per channel and multiplies the
    results over the channel axis, so an input whose last axis is ``channels``
    yields an output whose last axis is ``num_membership_functions``.
    """

    def __init__(self, num_membership_functions, epsilon=1e-7, **kwargs):
        super().__init__(**kwargs)
        self.num_membership_functions = num_membership_functions
        self.epsilon = epsilon

    def build(self, input_shape):
        channels = input_shape[-1]
        if channels is None:
            raise ValueError("The channel dimension of the input must be defined.")
        param_shape = (self.num_membership_functions, int(channels))
        # Separate initializer instances so mu and sigma get independent
        # draws; both use a standard normal.
        self.mu = self.add_weight(
            name="mu",
            shape=param_shape,
            initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=1.0),
            trainable=True,
        )
        self.sigma = self.add_weight(
            name="sigma",
            shape=param_shape,
            initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=1.0),
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        # (..., C) -> (..., 1, C) so it broadcasts against (M, C).
        x = tf.expand_dims(inputs, axis=-2)
        # epsilon keeps the quotient finite when a sigma entry is (near) zero.
        exponent = tf.square(x - self.mu) / (2.0 * tf.square(self.sigma) + self.epsilon)
        # Aggregate the per-channel memberships with a product.
        return tf.reduce_prod(tf.exp(-exponent), axis=-1)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_membership_functions": self.num_membership_functions,
                "epsilon": self.epsilon,
            }
        )
        return config


def fuzzy_membership_function(inputs, num_membership_functions, layer_name_prefix="Fuzzy_"):
    """
    Compute fuzzy membership values for the input features.

    The Gaussian parameters live inside a Keras layer, so they are tracked by
    any model built from the returned tensor and updated during training.
    Creating them as free ``tf.Variable`` objects and passing them through TF
    ops would freeze their initial random values into the graph as constants.

    Args:
        inputs: Input feature map; its last (channel) dimension must be
            statically known.
        num_membership_functions: Number of fuzzy membership functions for each feature channel.
        layer_name_prefix: Prefix for layer naming.

    Returns:
        Aggregated fuzzy membership values: the input with its channel axis
        replaced by an axis of size ``num_membership_functions``.
    """
    unique_id = tf.keras.backend.get_uid('fuzzy')
    membership_layer = _GaussianMembership(
        num_membership_functions,
        name=f"{layer_name_prefix}membership_{unique_id}",
    )
    return membership_layer(inputs)

def fuzzy_learning_module(inputs, num_membership_functions, uncertainty_drop_rate=0.2, filters=64, layer_name_prefix="FLM_"):
    """
    Split features by an uncertainty map, process both parts, and add the result back.

    A fuzzy-membership embedding of ``inputs`` is projected to ``filters``
    channels and passed through dropout; the sigmoid of its channel-wise
    variance is used as the uncertainty map. ``inputs`` is weighted by
    ``1 - uncertainty`` and by ``uncertainty``, each part runs through
    depthwise-separable convolutions, and the fused result is batch-normalised
    and added to ``inputs``.

    Args:
        inputs: Input feature map. The residual add is only valid when its
            channel count matches ``filters``.
        num_membership_functions: Number of fuzzy membership functions.
        uncertainty_drop_rate: Rate of the dropout layer that precedes the variance.
        filters: Output channels of every pointwise convolution.
        layer_name_prefix: Prefix for layer naming.

    Returns:
        Tuple ``(refined_output, uncertainty_map)``.
    """
    unique_id = tf.keras.backend.get_uid('fuzzy_learning')

    def _tag(label):
        # Every layer name follows "<prefix><label>_<uid>".
        return f"{layer_name_prefix}{label}_{unique_id}"

    def _separable_relu(tensor, kernel, dw_label, pw_label):
        # Depthwise kernel x kernel conv followed by a 1x1 conv, both with ReLU.
        tensor = DepthwiseConv2D((kernel, kernel), padding='same', activation='relu', name=_tag(dw_label))(tensor)
        return Conv2D(filters, (1, 1), padding='same', activation='relu', name=_tag(pw_label))(tensor)

    # Fuzzy embedding, linearly projected to `filters` channels.
    membership = fuzzy_membership_function(inputs, num_membership_functions, layer_name_prefix=layer_name_prefix)
    projected = Conv2D(filters, (1, 1), padding='same', activation=None, name=_tag("pointwise_conv"))(membership)

    # Uncertainty map: sigmoid of the variance across channels after dropout.
    dropped = Dropout(rate=uncertainty_drop_rate, name=_tag("dropout"))(projected)
    uncertainty_map = tf.math.reduce_variance(dropped, axis=-1, keepdims=True, name=_tag("uncertainty_map"))
    uncertainty_map = Activation('sigmoid', name=_tag("uncertainty_activation"))(uncertainty_map)

    # Weight the input by confidence (1 - uncertainty) and by uncertainty.
    confident_part = Multiply(name=_tag("high_conf_features"))([inputs, 1.0 - uncertainty_map])
    uncertain_part = Multiply(name=_tag("low_conf_features"))([inputs, uncertainty_map])

    # High-confidence path: one 3x3 separable convolution.
    confident_out = _separable_relu(confident_part, 3, "high_conf_dwconv", "high_conf_pwconv")

    # Low-confidence path: 5x5 separable conv + BN, then a 3x3 refinement.
    uncertain_ctx = _separable_relu(uncertain_part, 5, "low_conf_dwconv", "low_conf_pwconv")
    uncertain_ctx = BatchNormalization(name=_tag("low_conf_bn"))(uncertain_ctx)
    uncertain_out = _separable_relu(uncertain_ctx, 3, "low_conf_refine_dwconv", "low_conf_refine_pwconv")

    # Fuse both paths and bring the result back to `filters` channels.
    merged = Concatenate(name=_tag("combined_features"))([confident_out, uncertain_out])
    merged = _separable_relu(merged, 3, "combined_dwconv", "combined_pwconv")

    # Residual connection around the whole module.
    normalised = BatchNormalization(name=_tag("batch_norm"))(merged)
    refined_output = Add(name=_tag("residual_add"))([inputs, normalised])

    return refined_output, uncertainty_map


def squeeze_excite_block(input_tensor, ratio=16):
    """Rescale the channels of ``input_tensor`` with a squeeze-and-excite gate.

    Args:
        input_tensor: Feature map whose last axis holds the channels.
        ratio: Divisor for the width of the bias-free Dense bottleneck.

    Returns:
        ``input_tensor`` multiplied by the sigmoid gate of shape (1, 1, channels).
    """
    unique_id = tf.keras.backend.get_uid('squeeze_excite')
    num_channels = input_tensor.shape[-1]

    # Squeeze: one descriptor per channel, reshaped so it broadcasts over H and W.
    gate = GlobalAveragePooling2D(name=f'global_avg_pool_{unique_id}')(input_tensor)
    gate = Reshape((1, 1, num_channels), name=f'reshape_{unique_id}')(gate)

    # Excite: bottleneck Dense pair ending in a sigmoid.
    bottleneck = Dense(num_channels // ratio, activation='relu', use_bias=False, name=f'dense_1_{unique_id}')
    expand = Dense(num_channels, activation='sigmoid', use_bias=False, name=f'dense_2_{unique_id}')
    gate = expand(bottleneck(gate))

    return Multiply(name=f'multiply_{unique_id}')([input_tensor, gate])

def swish_activation(x):
    """Return ``x`` multiplied element-wise by ``sigmoid(x)``."""
    gate = tf.sigmoid(x)
    return x * gate

def eca_kernel_size(channels, gamma=2, b=1):
    """Return the odd 1-D kernel size used by ECA for ``channels`` channels.

    Computes ``t = int(|(log2(channels) + b) / gamma|)`` and rounds an even
    ``t`` up to the next odd number. ``math.log2`` is exact for powers of two,
    which avoids the float32 rounding of dividing two TF logarithms.
    """
    import math

    t = int(abs((math.log2(channels) + b) / gamma))
    return t if t % 2 else t + 1


def eca_layer(input_tensor, gamma=2, b=1, layer_name_prefix=""):
    """Apply efficient channel attention (ECA) to a 4-D feature map.

    The spatially averaged channel descriptor is laid out as a length-C
    sequence with one feature, so the 1-D convolution slides along the
    channel axis and yields one sigmoid gate per channel. Feeding the
    descriptor as a length-1 sequence with C features instead collapses the
    gate to a single scalar per sample and makes the kernel size irrelevant.

    Args:
        input_tensor: Feature map indexed as (batch, height, width, channels);
            the channel dimension must be statically known.
        gamma: Divisor in the adaptive kernel-size formula.
        b: Offset in the adaptive kernel-size formula.
        layer_name_prefix: Prefix for the names of the created layers.

    Returns:
        ``input_tensor`` scaled by the per-channel gate (same shape as the input).
    """
    channels = input_tensor.shape[-1]
    k_size = eca_kernel_size(channels, gamma=gamma, b=b)

    unique_id = tf.keras.backend.get_uid('eca')
    # (B, H, W, C) -> (B, 1, 1, C)
    gap = tf.reduce_mean(input_tensor, axis=[1, 2], keepdims=True, name=f'{layer_name_prefix}_gap_{unique_id}')
    # (B, 1, 1, C) -> (B, 1, C) -> (B, C, 1): channels become the sequence axis.
    gap = Lambda(
        lambda x: tf.transpose(tf.squeeze(x, axis=1), [0, 2, 1]),
        name=f'{layer_name_prefix}_squeeze_{unique_id}',
    )(gap)
    # Local cross-channel interaction over k_size neighbouring channels.
    gap = Conv1D(1, kernel_size=k_size, padding='same', name=f'{layer_name_prefix}_conv1d_{unique_id}')(gap)
    # (B, C, 1) -> (B, 1, C) -> (B, 1, 1, C) so the gate broadcasts over H and W.
    gap = Lambda(
        lambda x: tf.expand_dims(tf.transpose(x, [0, 2, 1]), axis=1),
        name=f'{layer_name_prefix}_expand_{unique_id}',
    )(gap)

    scale = Activation('sigmoid', name=f'{layer_name_prefix}_sigmoid_{unique_id}')(gap)

    return input_tensor * scale


def MCAU(input_tensor, num_filters, layer_name_prefix=""):
    """Gate ``input_tensor`` with a multi-scale, channel-attended feature map.

    Three branches apply a (k, 1) then a (1, k) depthwise convolution for
    k in {3, 5, 11}, each followed by layer normalisation. The branch sum is
    projected to ``num_filters`` channels, normalised, passed through
    ``eca_layer`` and multiplied with its pre-ECA version; the product is then
    multiplied with ``input_tensor``.

    Args:
        input_tensor: Input feature map.
        num_filters: Output channels of the pointwise projection.
        layer_name_prefix: Prefix for the names of the created layers.

    Returns:
        The element-wise product of ``input_tensor`` and the attention map.
    """
    unique_id = tf.keras.backend.get_uid('MSCA')

    # Multi-scale strip convolutions, one branch per kernel length.
    scale_branches = []
    for k in (3, 5, 11):
        vertical = depthwise_conv(input_tensor, (k, 1), name=f'{layer_name_prefix}_conv_{k}x1_{unique_id}')
        horizontal = depthwise_conv(vertical, (1, k), name=f'{layer_name_prefix}_conv_1x{k}_{unique_id}')
        normalised = layers.LayerNormalization(name=f'{layer_name_prefix}_norm_1x{k}_{unique_id}')(horizontal)
        scale_branches.append(normalised)

    # Sum the branches.
    summed = layers.Add(name=f'{layer_name_prefix}_add_combined_{unique_id}')(scale_branches)

    # Channel mixing with a 1x1 convolution, then normalisation.
    projected = pointwise_conv(summed, num_filters, name=f'{layer_name_prefix}_mixed_{unique_id}')
    projected = layers.LayerNormalization(name=f'{layer_name_prefix}_norm_mixed_{unique_id}')(projected)

    # Channel attention via ECA, multiplied back onto the projected features.
    channel_attended = eca_layer(projected, layer_name_prefix=f'{layer_name_prefix}_eca')
    attention_map = layers.Multiply(name=f'{layer_name_prefix}_se_attention_{unique_id}')([projected, channel_attended])

    # Gate the block input with the attention map.
    return layers.Multiply(name=f'{layer_name_prefix}_output_{unique_id}')([input_tensor, attention_map])


def FMSAN_stage(input_tensor, num_filters, num_membership_functions, stage_index):
    """Build one encoder stage: residual conv block, MCAU, fuzzy module, downsample.

    Args:
        input_tensor: Input feature map of the stage.
        num_filters: Channel width used throughout the stage.
        num_membership_functions: Forwarded to ``fuzzy_learning_module``.
        stage_index: Stage number embedded in the layer names.

    Returns:
        The stage output after a stride-2 3x3 convolution.
    """
    unique_id = tf.keras.backend.get_uid('MSCAN_stage')
    stage_prefix = f'MSCAN_stage_{stage_index}'

    def _tag(label):
        # Layer names follow "MSCAN_stage_<index>_<label>_<uid>".
        return f'{stage_prefix}_{label}_{unique_id}'

    # Convolutional block: 1x1 conv -> LN -> ReLU -> 3x3 conv -> LN.
    block = pointwise_conv(input_tensor, num_filters, name=_tag('pointwise_conv'))
    block = layers.LayerNormalization(name=_tag('layer_norm1'))(block)
    block = layers.Activation('relu', name=_tag('relu1'))(block)
    block = layers.Conv2D(num_filters, (3, 3), padding='same', name=_tag('conv'))(block)
    block = layers.LayerNormalization(name=_tag('layer_norm2'))(block)

    # Project the shortcut only when its width differs from the block output.
    shortcut = input_tensor
    if shortcut.shape[-1] != num_filters:
        shortcut = pointwise_conv(shortcut, num_filters, name=_tag('input_align'))
    residual = layers.Add(name=_tag('add'))([shortcut, block])

    # Multi-scale attention followed by normalisation.
    attended = MCAU(residual, num_filters, layer_name_prefix=f'{stage_prefix}_MSCA')
    attended = layers.LayerNormalization(name=_tag('layer_norm3'))(attended)

    # Fuzzy learning module; its uncertainty map is not used here.
    fuzzy_refined, _ = fuzzy_learning_module(
        inputs=attended,
        num_membership_functions=num_membership_functions,
        uncertainty_drop_rate=0.2,
        filters=num_filters,
        layer_name_prefix=f'{stage_prefix}_fuzzy_',
    )

    merged = layers.Add(name=_tag('output'))([residual, fuzzy_refined])

    # Halve the spatial resolution.
    return layers.Conv2D(num_filters, (3, 3), strides=(2, 2), padding='same', name=_tag('downsample'))(merged)

def FMSAN(input_tensor, num_filters_initial=48, num_stages=4, num_membership_functions=5):
    """Stack ``num_stages`` encoder stages, doubling the width after each one.

    Args:
        input_tensor: Input of the first stage.
        num_filters_initial: Channel width of the first stage.
        num_stages: Number of stages to build.
        num_membership_functions: Forwarded to every ``FMSAN_stage``.

    Returns:
        List with the output tensor of every stage, first stage first.
    """
    stage_outputs = []
    current = input_tensor
    width = num_filters_initial

    for stage_number in range(1, num_stages + 1):
        current = FMSAN_stage(current, width, num_membership_functions, stage_index=stage_number)
        stage_outputs.append(current)
        print(f"Encoder stage {stage_number}: resolution = {current.shape[1:3]}, filters = {width}")
        width *= 2

    return stage_outputs

def GMFAM(F4, F3, F_PHA):
    """Fuse three feature maps with gates derived from a multi-scale mixture.

    ``F3`` and ``F_PHA`` are resized to the spatial size of ``F4``. The
    concatenation of all three passes through depthwise-separable branches
    with kernel sizes 1, 3 and 5; the result is split into four equal channel
    groups: one is kept as features (``N``), the other three become sigmoid
    gates for ``F4``, ``F3`` and ``F_PHA``.

    Args:
        F4: Reference feature map; defines the output size and channel count.
        F3: Second feature map, resized to match ``F4``.
        F_PHA: Third feature map, resized to match ``F4``.

    Returns:
        Fused tensor with as many channels as ``F4``.
    """
    unique_id = tf.keras.backend.get_uid('GMFA')

    def shared_conv(x, filters, kernel_size):
        # Depthwise conv of the given size followed by a ReLU 1x1 conv.
        conv_id = tf.keras.backend.get_uid('shared_conv')
        x = layers.DepthwiseConv2D(kernel_size=kernel_size, padding='same', name=f'shared_depthwise_conv_{conv_id}')(x)
        return layers.Conv2D(filters, kernel_size=1, padding='same', activation='relu', name=f'shared_conv_{conv_id}')(x)

    # Bring every input to the spatial size of F4.
    spatial_size = F4.shape[1:3]
    F3_resized = tf.image.resize(F3, spatial_size, name=f'F3_resize_{unique_id}')
    F_PHA_resized = tf.image.resize(F_PHA, spatial_size, name=f'F_PHA_resize_{unique_id}')
    aligned = (F4, F3_resized, F_PHA_resized)

    concatenated = layers.Concatenate(name=f'concat_{unique_id}')(list(aligned))

    # Multi-scale branches, each half as wide as the concatenation.
    branch_width = concatenated.shape[-1] // 2
    pyramid = [shared_conv(concatenated, branch_width, kernel_size=k) for k in (1, 3, 5)]
    multi_scale = layers.Concatenate(name=f'multi_scale_concat_{unique_id}')(pyramid)

    Q = layers.DepthwiseConv2D(kernel_size=3, padding='same', name=f'multi_scale_depthwise_conv_{unique_id}')(multi_scale)
    Q = layers.BatchNormalization(name=f'multi_scale_bn_{unique_id}')(Q)

    # Four equal channel groups; channels beyond 4 * quarter are dropped.
    quarter = int(Q.shape[-1] // 4)
    N, g4, g3, g_PHA = [Q[:, :, :, i * quarter:(i + 1) * quarter] for i in range(4)]

    # One sigmoid gate per input, projected to that input's channel count.
    gates = [
        layers.Conv2D(feature.shape[-1], 1, padding='same', activation='sigmoid', name=f'{label}_conv_{unique_id}')(chunk)
        for label, feature, chunk in zip(('g4', 'g3', 'g_PHA'), aligned, (g4, g3, g_PHA))
    ]
    gated = [
        layers.Multiply(name=f'{label}_gated_{unique_id}')([feature, gate])
        for label, feature, gate in zip(('F4', 'F3', 'F_PHA'), aligned, gates)
    ]

    # Concatenate N with the gated inputs and project to the width of F4.
    gated_concat = layers.Concatenate(name=f'gated_concat_{unique_id}')([N] + gated)
    return layers.Conv2D(F4.shape[-1], kernel_size=1, activation='relu', name=f'R_conv_{unique_id}')(gated_concat)

def depthwise_separable_conv(x, filters, name=""):
    """Apply a 3x3 depthwise conv and a 1x1 conv, each followed by BN and ReLU.

    Args:
        x: Input tensor.
        filters: Output channels of the pointwise convolution.
        name: Prefix for the names of the created layers.

    Returns:
        The output tensor of the second ReLU.
    """
    unique_id = tf.keras.backend.get_uid('depthwise_separable')

    def _tag(label):
        return f'{name}_{label}_{unique_id}'

    # Depthwise half.
    out = layers.DepthwiseConv2D(kernel_size=3, padding='same', name=_tag('depthwise_conv'))(x)
    out = layers.BatchNormalization(name=_tag('bn1'))(out)
    out = layers.Activation('relu', name=_tag('activation1'))(out)

    # Pointwise half.
    out = layers.Conv2D(filters, kernel_size=1, padding='same', name=_tag('pointwise_conv'))(out)
    out = layers.BatchNormalization(name=_tag('bn2'))(out)
    return layers.Activation('relu', name=_tag('activation2'))(out)

def depthwise_separable_conv_bn(inputs, pointwise_filters, kernel_size, dilation_rate=1, name=""):
    """Apply a (dilated) depthwise conv and a 1x1 conv, each followed by BN and ReLU.

    Args:
        inputs: Input tensor.
        pointwise_filters: Output channels of the pointwise convolution.
        kernel_size: Kernel size of the depthwise convolution.
        dilation_rate: Dilation rate of the depthwise convolution.
        name: Prefix for the names of the created layers.

    Returns:
        The output tensor of the second ReLU.
    """
    unique_id = tf.keras.backend.get_uid('depthwise_separable_bn')

    def _tag(label):
        return f'{name}_{label}_{unique_id}'

    # Depthwise half, optionally dilated.
    out = DepthwiseConv2D(kernel_size, padding='same', dilation_rate=dilation_rate, name=_tag('depthwise_conv'))(inputs)
    out = BatchNormalization(name=_tag('bn1'))(out)
    out = Activation('relu', name=_tag('activation1'))(out)

    # Pointwise half.
    out = Conv2D(pointwise_filters, (1, 1), padding='same', name=_tag('pointwise_conv'))(out)
    out = BatchNormalization(name=_tag('bn2'))(out)
    return Activation('relu', name=_tag('activation2'))(out)

def AAFM(inputs, layer_name_prefix=""):
    """Cascade four dilated separable convolutions with running residual sums.

    Dilation rates 1, 2, 4 and 6 are applied in sequence; each stage output
    is added to the stage input. The four running sums are concatenated along
    the channel axis and passed through ``eca_layer``.

    Args:
        inputs: Input feature map.
        layer_name_prefix: Prefix for the names of the created layers.

    Returns:
        The ECA-weighted concatenation of the four running sums.
    """
    unique_id = tf.keras.backend.get_uid('atrous_attention_fusion')
    width = inputs.shape[-1]

    running_sum = inputs
    stage_sums = []
    for stage, rate in enumerate((1, 2, 4, 6), start=1):
        branch = depthwise_separable_conv_bn(
            running_sum,
            width,
            kernel_size=3,
            dilation_rate=rate,
            name=f'{layer_name_prefix}_F{stage}',
        )
        running_sum = Add(name=f'{layer_name_prefix}_F{stage}_sum_{unique_id}')([running_sum, branch])
        stage_sums.append(running_sum)

    stacked = Concatenate(name=f'{layer_name_prefix}_concat_{unique_id}')(stage_sums)

    return eca_layer(stacked, layer_name_prefix=f'{layer_name_prefix}_eca')

def GMSFB(encoder_features, decoder_features, filters, layer_name_prefix=""):
    """Gate encoder features with the decoder signal and fuse three dilated views.

    The encoder map is resized to the decoder's spatial size and multiplied
    by a sigmoid gate computed from the decoder map. The gated features run
    through separable convolutions with dilation 1, 2 and 4, whose
    concatenation is projected to ``filters`` channels and added back to the
    gated features.

    Args:
        encoder_features: Skip-connection feature map from the encoder.
        decoder_features: Decoder feature map that provides the gate.
        filters: Channel width of the gate and of the fused output.
        layer_name_prefix: Prefix for the names of the created layers.

    Returns:
        The sum of the fused multi-dilation features and the gated encoder features.
    """
    unique_id = tf.keras.backend.get_uid('lightweight_mcff')

    def _tag(label):
        return f'{layer_name_prefix}_{label}_{unique_id}'

    # Align the skip connection with the decoder resolution.
    skip = tf.image.resize(encoder_features, tf.shape(decoder_features)[1:3], name=_tag('encoder_resize'))

    # Gate signal: decoder map upsampled 2x, then resized to the skip size.
    gate_src = layers.UpSampling2D(size=(2, 2), interpolation='bilinear', name=_tag('upsample'))(decoder_features)
    gate_src = tf.image.resize(gate_src, tf.shape(skip)[1:3], name=_tag('gating_resize'))

    gate = layers.Conv2D(filters, 1, padding='same', activation='sigmoid', name=_tag('gating_conv'))(gate_src)
    gated_skip = layers.Multiply(name=_tag('attention_multiply'))([skip, gate])

    # Three parallel dilated views of the gated features.
    dilated_views = [
        depthwise_separable_conv_bn(gated_skip, filters, 3, dilation_rate=rate, name=_tag(f'f{idx}'))
        for idx, rate in enumerate((1, 2, 4), start=1)
    ]

    stacked = layers.Concatenate(name=_tag('concat'))(dilated_views)
    fused = layers.Conv2D(filters, 1, padding="same", name=_tag('fused_conv'))(stacked)

    return layers.Add(name=_tag('output'))([fused, gated_skip])

class _SoftmaxFusionWeights(tf.keras.layers.Layer):
    """Scale each input tensor by one trainable, softmax-normalised scalar.

    The logits start at 1.0, so every input initially receives the weight
    ``1 / num_inputs``.
    """

    def __init__(self, num_inputs, **kwargs):
        super().__init__(**kwargs)
        self.num_inputs = num_inputs

    def build(self, input_shape):
        self.logits = self.add_weight(
            name='logits',
            shape=(self.num_inputs,),
            initializer='ones',
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        weights = tf.nn.softmax(self.logits)
        return [weights[i] * feature for i, feature in enumerate(inputs)]

    def get_config(self):
        config = super().get_config()
        config.update({'num_inputs': self.num_inputs})
        return config


def WMFM(encoder_outputs):
    """Fuse multi-scale encoder outputs with learnable softmax weights.

    Every encoder output is projected with a 1x1 convolution to the channel
    count of the last encoder output and bilinearly resized to its spatial
    size. Each transformed map is scaled by a softmax-normalised weight and
    the scaled maps are combined by element-wise multiplication.

    The weights are owned by a Keras layer so the model tracks and trains
    them. Free ``tf.Variable`` objects passed through an eager
    ``tf.nn.softmax`` collapse into constant weights of ``1 / n``.

    NOTE(review): the weighted maps are combined with ``Multiply`` (a
    product), not a sum. Confirm this is the intended fusion.
    NOTE(review): the ``weighted_feature_multiply_<i>`` layer names carry no
    unique id, so two calls inside one model would collide.

    Args:
        encoder_outputs: Non-empty sequence of encoder feature maps with
            statically known shapes; the last one defines the target shape.

    Returns:
        The fused feature map, shaped like the last encoder output.

    Raises:
        ValueError: If ``encoder_outputs`` is empty.
    """
    if not encoder_outputs:
        raise ValueError("WMFM requires at least one encoder output.")

    def feature_transform(x, target_shape):
        # 1x1 projection to the target width, then bilinear resize to its size.
        unique_id = tf.keras.backend.get_uid('feature_transform')
        x = layers.Conv2D(target_shape[-1], kernel_size=1, name=f'feature_transform_conv_{unique_id}')(x)
        x = tf.image.resize(x, size=(target_shape[1], target_shape[2]), method='bilinear', name=f'feature_transform_resize_{unique_id}')
        return x

    target_shape = encoder_outputs[-1].shape
    transformed_features = [feature_transform(feature, target_shape) for feature in encoder_outputs]

    fusion_id = tf.keras.backend.get_uid('wmfm_fusion_weights')
    weighted_features = _SoftmaxFusionWeights(
        len(transformed_features),
        name=f'wmfm_fusion_weights_{fusion_id}',
    )(transformed_features)

    fused_feature = weighted_features[0]
    for i, feature in enumerate(weighted_features[1:], start=1):
        fused_feature = layers.Multiply(name=f'weighted_feature_multiply_{i}')([fused_feature, feature])

    return fused_feature

def FusionSegNet(input_shape):
    """Assemble the full encoder/decoder segmentation model.

    Pipeline: ``FMSAN`` encoder (4 stages starting at 64 filters) -> ``WMFM``
    fusion of all encoder outputs -> ``HMAA`` refinement -> ``GMFAM`` fusion
    with the two deepest encoder outputs -> ``AAFM`` bottleneck -> four
    decoder stages (upsample, conv, BN, ``GMSFB`` skip fusion, ``HMAA``) ->
    1x1 sigmoid convolution -> resize to the input's spatial size.

    Args:
        input_shape: Shape passed to ``layers.Input``; its first two entries
            are used as the output's spatial size.

    Returns:
        A ``Model`` named ``FusionSegNet_Model`` with a single-channel sigmoid output.
    """
    inputs = layers.Input(shape=input_shape, name='input_layer')

    # Encoder and bottleneck.
    skips = FMSAN(inputs, num_filters_initial=64, num_stages=4, num_membership_functions=5)
    fused = WMFM(skips)
    refined = HMAA(fused, layer_name_prefix='Hierarchical_Attention_Refinement_layer_')
    gated = GMFAM(skips[3], skips[2], refined)
    bottleneck = AAFM(gated, layer_name_prefix='atrous_attention_fusion_bottleneck_layer_')

    # Decoder: (encoder skip index, filters) per stage, deepest skip first.
    decoder_plan = ((3, 512), (2, 256), (1, 128), (0, 64))
    current = bottleneck
    for i, (skip_index, filters) in enumerate(decoder_plan, start=1):
        up = layers.UpSampling2D((2, 2), name=f'decoder_upsample_stage_{i}')(current)
        up = layers.Conv2D(filters, (3, 3), padding='same', activation='relu', name=f'decoder_conv_stage_{i}')(up)
        up = layers.BatchNormalization(name=f'decoder_bn_stage_{i}')(up)
        print(f"Decoder stage {i}: resolution = {up.shape[1:3]}, filters = {filters}")

        up = GMSFB(skips[skip_index], up, filters, layer_name_prefix=f'decoder_mcff_stage_{i}_')
        current = HMAA(up, layer_name_prefix=f'decoder_Hierarchical_Attention_Refinement_stage_{i}_')

    # Segmentation head on the last decoder stage, resized to the input size.
    mask = layers.Conv2D(1, (1, 1), activation='sigmoid', name='output_layer')(current)
    mask = layers.Lambda(
        lambda t: tf.image.resize(t, (input_shape[0], input_shape[1])),
        name='resize_output',
    )(mask)

    return Model(inputs, mask, name='FusionSegNet_Model')

# Build the network for an input shape of (224, 224, 3) and print the Keras
# layer summary. Runs at module level, so `model` is a module-level name.
model = FusionSegNet((224, 224, 3))
model.summary()
