# basic model


In [1]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, BatchNormalization, ReLU, MaxPooling2D, Concatenate, Add, Conv2D, Subtract, Multiply, concatenate

Model with RGF (residual-guided fusion)

In [2]:
def calculate_loss_between_segmentation_imgs(y_hat_n, y_n):
    """Placeholder loss between predicted and reference segmentation images.

    Not implemented yet: both arguments are ignored and ``None`` is returned.
    """
    return None

def residual_guided_fusion(input_shape_rgb, input_shape_depth, y_n, filters, num_classes):
    """Build a residual-guided fusion (RGF) block on freshly created inputs.

    The RGB path predicts a mask and its residual against ``y_n``. The depth
    path predicts that residual from the depth/RGB feature difference, and
    the predicted residual is used to gate the RGB features.

    Args:
        input_shape_rgb: Shape (without batch axis) of the RGB feature maps.
        input_shape_depth: Shape (without batch axis) of the depth feature
            maps. It is subtracted element-wise from the RGB maps, so it has
            to be compatible with ``input_shape_rgb``.
        y_n: Reference mask tensor that the predicted RGB mask is subtracted
            from (presumably ``num_classes`` channels -- TODO confirm).
        filters: Number of filters of the final 3x3 fusion convolution.
        num_classes: Channel count of the predicted mask and residual mask.

    Returns:
        The fused feature tensor with ``filters`` channels.

    NOTE(review): the ``Input`` placeholders are created here and are not
    returned, and the two loss values are computed but not returned either.
    """
    rgb_input = Input(shape=input_shape_rgb)
    depth_input = Input(shape=input_shape_depth)

    # RGB PATH
    # Predicted mask y_hat_n via a 1x1 convolution. It lives in mask space,
    # so it carries num_classes channels (it is subtracted from y_n below).
    y_hat_n = Conv2D(num_classes, kernel_size=(1, 1), padding='same')(rgb_input)
    loss_n = calculate_loss_between_segmentation_imgs(y_hat_n, y_n)
    # Target residual: what the RGB prediction still gets wrong.
    y_res = Subtract()([y_n, y_hat_n])

    # DEPTH PATH
    # Element-wise difference between the depth and RGB feature maps.
    difference_maps = Subtract()([depth_input, rgb_input])
    # Adjust the difference features to num_classes channels (1x1 conv).
    depth_conv = Conv2D(num_classes, kernel_size=(1, 1), padding='same')(difference_maps)
    skip = depth_conv

    # Residual unit with a 3x3 convolution producing the predicted residual
    # mask. It must keep num_classes channels so it can be added to ``skip``.
    depth_conv = Conv2D(num_classes, kernel_size=(3, 3), padding='same')(depth_conv)
    y_hat_res = Add()([depth_conv, skip])

    loss_res = calculate_loss_between_segmentation_imgs(y_hat_res, y_res)

    # Project the *predicted* residual (not the target y_res) back to the RGB
    # channel count and gate the RGB features with it.
    channels_rgb = input_shape_rgb[-1]
    y_hat_res_conv = Conv2D(channels_rgb, kernel_size=(1, 1), padding='same')(y_hat_res)
    combined_path = Multiply()([y_hat_res_conv, rgb_input])

    stacked = Concatenate()([combined_path, rgb_input, y_hat_res_conv])

    return Conv2D(filters, kernel_size=(3, 3), padding='same')(stacked)






In [9]:
def calculate_loss_between_segmentation_imgs(y_hat_n, y_n):
    """Stub for the segmentation loss between ``y_hat_n`` and ``y_n``.

    The loss is not implemented yet; the inputs are not inspected and the
    call always evaluates to ``None``.
    """
    return None

def residual_guided_fusion(rgb_input, depth_input, filters, y_n=None, num_classes=19):
    """Fuse RGB and depth feature maps with a residual-guided fusion block.

    The depth/RGB feature difference goes through a 1x1 convolution and a
    3x3 residual unit. The result is projected to the RGB channel count,
    multiplied with the RGB features, concatenated with them and fused by a
    final 3x3 convolution.

    Args:
        rgb_input: RGB feature map tensor.
        depth_input: Depth feature map tensor. It is subtracted element-wise
            from ``rgb_input``, so the shapes have to be compatible.
        filters: Number of filters of the intermediate and final
            convolutions.
        y_n: Reference mask; currently unused because the loss terms are
            disabled.
        num_classes: Currently unused; kept for interface compatibility.

    Returns:
        The fused feature tensor with ``filters`` channels.
    """
    # RGB PATH
    # 1x1 mask head. Its output is not consumed while the supervision below
    # stays disabled; the layer is still created so auto-generated layer
    # names of the following layers do not shift.
    y_hat_n = Conv2D(filters, kernel_size=(1, 1), padding='same')(rgb_input)
    # TODO: enable once the loss is implemented:
    #   loss_n = calculate_loss_between_segmentation_imgs(y_hat_n, y_n)
    #   y_res = Subtract()([y_n, y_hat_n])

    # DEPTH PATH
    # Element-wise difference between the depth and RGB feature maps.
    difference_maps = Subtract()([depth_input, rgb_input])
    # 1x1 convolution on the difference features.
    depth_conv = Conv2D(filters, kernel_size=(1, 1), padding='same')(difference_maps)
    skip = depth_conv

    # Residual unit with a 3x3 convolution producing the predicted residual.
    depth_conv = Conv2D(filters, kernel_size=(3, 3), padding='same')(depth_conv)
    y_hat_res = Add()([depth_conv, skip])

    # TODO: loss_res = calculate_loss_between_segmentation_imgs(y_hat_res, y_res)

    # Project the predicted residual to the channel count of the RGB feature
    # maps (read from the RGB tensor itself, not from the depth tensor) and
    # gate the RGB features with it.
    channels_rgb = rgb_input.shape[-1]
    y_hat_res_conv = Conv2D(channels_rgb, kernel_size=(1, 1), padding='same')(y_hat_res)
    combined_path = Multiply()([y_hat_res_conv, rgb_input])

    stacked = Concatenate()([combined_path, rgb_input, y_hat_res_conv])

    return Conv2D(filters, kernel_size=(3, 3), padding='same')(stacked)






In [10]:


def encoder_block(inputs, filters):
    """Apply two Conv-BatchNorm-ReLU stages followed by 2x2 max pooling.

    Args:
        inputs: Input feature map tensor.
        filters: Number of filters for both 3x3 convolutions.

    Returns:
        A pair containing the pooled feature map twice: once as the block
        output and once as the tensor handed out for skip connections.
    """
    features = inputs
    for _ in range(2):
        features = Conv2D(filters, kernel_size=(3, 3), padding='same')(features)
        features = BatchNormalization()(features)
        features = ReLU()(features)

    pooled = MaxPooling2D(pool_size=(2, 2))(features)
    return pooled, pooled

def decoder_block(inputs, skip_connection, filters, skip=True):
    """Upsample ``inputs`` by 2 and refine with two Conv-BatchNorm-ReLU stages.

    Args:
        inputs: Feature map tensor to upsample.
        skip_connection: Encoder feature map to concatenate on the channel
            axis. Ignored (and may be ``None``) when ``skip`` is false.
        filters: Number of filters for every convolution in the block.
        skip: Whether to concatenate ``skip_connection`` to the result.

    Returns:
        The decoded tensor; when ``skip`` is true the skip features are
        appended on the last axis.

    Raises:
        ValueError: If ``skip`` is true but ``skip_connection`` is ``None``.
    """
    if skip and skip_connection is None:
        # Fail before any layer is created instead of with an AttributeError
        # on ``None.shape`` further down.
        raise ValueError("skip_connection must be provided when skip=True")

    x = Conv2DTranspose(filters, kernel_size=(2, 2), strides=(2, 2), padding='same')(inputs)
    x = Conv2D(filters, kernel_size=(3, 3), padding='same')(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(filters, kernel_size=(3, 3), padding='same')(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    if skip:
        target_height, target_width = x.shape[1], x.shape[2]
        spatial_mismatch = (
            target_height != skip_connection.shape[1]
            or target_width != skip_connection.shape[2]
        )
        if spatial_mismatch:
            # A 1x1 convolution cannot change height/width, so the skip map
            # is first resized to the decoder resolution; otherwise the
            # concatenation below would still fail. Resizing needs static
            # target sizes, so it is skipped when they are unknown.
            if target_height is not None and target_width is not None:
                skip_connection = tf.keras.layers.Resizing(
                    target_height, target_width
                )(skip_connection)
            # Keep the channel projection the mismatch branch always applied.
            skip_connection = Conv2D(filters, kernel_size=(1, 1), padding='same')(skip_connection)
        x = Concatenate()([x, skip_connection])  # Skip connection

    return x




def build_model(input_shape_rgb, input_shape_depth):
    """Build the two-stream RGB-D encoder/decoder model with an RGF head.

    Both streams use three encoder blocks (32, 64, 128 filters) and three
    decoder blocks (128, 64, 32 filters). After every decoder stage the depth
    features are added to the RGB features. The last fused stage is passed
    through the residual-guided fusion module, whose output feeds the final
    1x1 sigmoid convolution.

    Args:
        input_shape_rgb: Shape (without batch axis) of the RGB input.
        input_shape_depth: Shape (without batch axis) of the depth input.

    Returns:
        A Keras ``Model`` taking ``[rgb, depth]`` and producing a 3-channel
        sigmoid output.
    """
    rgb_input = Input(shape=input_shape_rgb)
    depth_input = Input(shape=input_shape_depth)

    # Encoder for RGB
    rgb_enc1, rgb_skip1 = encoder_block(rgb_input, 32)
    rgb_enc2, rgb_skip2 = encoder_block(rgb_enc1, 64)
    rgb_enc3, _ = encoder_block(rgb_enc2, 128)

    # Encoder for Depth
    depth_enc1, depth_skip1 = encoder_block(depth_input, 32)
    depth_enc2, depth_skip2 = encoder_block(depth_enc1, 64)
    depth_enc3, _ = encoder_block(depth_enc2, 128)

    # Decoder for Depth. The skips returned by encoder_block are already
    # pooled, so stage k of the decoder pairs with the skip of encoder k-1.
    depth_dec3 = decoder_block(depth_enc3, depth_skip2, 128)
    depth_dec2 = decoder_block(depth_dec3, depth_skip1, 64)
    depth_dec1 = decoder_block(depth_dec2, None, 32, False)

    # Decoder for RGB, fused with the depth decoder after every stage.
    rgb_dec3 = decoder_block(rgb_enc3, rgb_skip2, 128)
    rgb_dec3 = Add()([rgb_dec3, depth_dec3])

    rgb_dec2 = decoder_block(rgb_dec3, rgb_skip1, 64)
    rgb_dec2 = Add()([rgb_dec2, depth_dec2])

    rgb_dec1 = decoder_block(rgb_dec2, None, 32, False)
    rgb_dec1 = Add()([rgb_dec1, depth_dec1])

    # RGF module on the last decoder stage.
    rgf_1 = residual_guided_fusion(rgb_dec1, depth_dec1, 32)

    # Final output layer. It consumes the RGF output; previously it read
    # rgb_dec1, which left the RGF module disconnected from the model.
    rgb_output = Conv2D(3, kernel_size=(1, 1), activation='sigmoid')(rgf_1)

    # Create model
    model = Model(inputs=[rgb_input, depth_input], outputs=rgb_output)
    return model

# Example usage: build the model for 256x256 RGB (3-channel) and depth
# (1-channel) inputs.
input_shape_rgb, input_shape_depth = (256, 256, 3), (256, 256, 1)
model = build_model(input_shape_rgb, input_shape_depth)
