# basic model


In [24]:
import tensorflow as tf
from tensorflow.keras import layers, models

def encoder_block(inputs, filters):
    """Apply two Conv-BatchNorm-ReLU stages followed by 2x2 max pooling.

    Args:
        inputs: Keras tensor fed to the first 3x3 convolution.
        filters: Number of filters used by both 3x3 convolutions.

    Returns:
        A 2-tuple ``(pooled, pooled)``. Both entries are the *same* tensor,
        i.e. the max-pooled output; callers that treat the second entry as a
        skip connection therefore receive post-pooling features.
    """
    features = inputs
    # Two identical conv stages; 'same' padding keeps the spatial size.
    for _ in range(2):
        features = layers.Conv2D(filters, kernel_size=(3, 3), padding='same')(features)
        features = layers.BatchNormalization()(features)
        features = layers.ReLU()(features)

    pooled = layers.MaxPooling2D(pool_size=(2, 2))(features)
    return pooled, pooled

def decoder_block(inputs, skip_connection, filters, skip=True):
    """Upsample by 2x, refine with two conv stages, optionally add a skip.

    Args:
        inputs: Keras tensor to upsample with a stride-2 transposed
            convolution.
        skip_connection: Tensor concatenated onto the refined output when
            ``skip`` is true. Not touched when ``skip`` is false, so callers
            may pass ``None`` in that case.
        filters: Number of filters for the transposed convolution and for
            both 3x3 convolutions.
        skip: Whether to concatenate ``skip_connection``.

    Returns:
        The refined tensor. When ``skip`` is true the skip tensor is
        concatenated onto it (Keras ``Concatenate`` default axis), so the
        result carries the channels of both tensors.
    """
    x = layers.Conv2DTranspose(filters, kernel_size=(2, 2), strides=(2, 2), padding='same')(inputs)

    x = layers.Conv2D(filters, kernel_size=(3, 3), padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(filters, kernel_size=(3, 3), padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    if skip:
        # Axes 1 and 2 are read as height/width (assumes channels-last data).
        target_height, target_width = x.shape[1], x.shape[2]
        size_mismatch = (
            target_height != skip_connection.shape[1]
            or target_width != skip_connection.shape[2]
        )
        # A 1x1 convolution only changes the channel count, never the
        # height/width, so it cannot reconcile a spatial mismatch. Resample
        # the skip tensor to x's spatial size instead. Resizing needs static
        # target dimensions, so it is skipped when they are unknown.
        if size_mismatch and target_height is not None and target_width is not None:
            skip_connection = layers.Resizing(target_height, target_width)(skip_connection)

        x = layers.Concatenate()([x, skip_connection])  # Skip connection

    return x

def build_model(input_shape_rgb, input_shape_depth):
    """Build a two-branch (RGB + depth) encoder-decoder Keras model.

    Each input goes through its own three-stage encoder (32, 64, 128
    filters) and three-stage decoder (128, 64, 32 filters). After every RGB
    decoder stage the matching depth decoder output is added element-wise.
    A final 1x1 convolution with sigmoid activation produces 3 channels.

    Args:
        input_shape_rgb: Shape passed to ``layers.Input`` for the RGB branch.
        input_shape_depth: Shape passed to ``layers.Input`` for the depth
            branch.

    Returns:
        A ``models.Model`` with inputs ``[rgb_input, depth_input]`` and the
        3-channel sigmoid tensor as its output.

    Side effects:
        Prints the two input shapes and the output shape.
    """
    encoder_filters = (32, 64, 128)

    # Inputs
    rgb_input = layers.Input(shape=input_shape_rgb)
    depth_input = layers.Input(shape=input_shape_depth)
    print(rgb_input.shape, depth_input.shape)

    # RGB encoder: keep every stage's skip tensor, chain the pooled output.
    rgb_skips = []
    rgb_features = rgb_input
    for width in encoder_filters:
        rgb_features, stage_skip = encoder_block(rgb_features, width)
        rgb_skips.append(stage_skip)

    # Depth encoder, same layout as the RGB one.
    depth_skips = []
    depth_features = depth_input
    for width in encoder_filters:
        depth_features, stage_skip = encoder_block(depth_features, width)
        depth_skips.append(stage_skip)

    # Depth decoder; the last stage has no skip connection.
    depth_up3 = decoder_block(depth_features, depth_skips[1], 128)
    depth_up2 = decoder_block(depth_up3, depth_skips[0], 64)
    depth_up1 = decoder_block(depth_up2, None, 32, False)

    # RGB decoder; each stage is fused with the depth decoder by addition.
    rgb_up3 = decoder_block(rgb_features, rgb_skips[1], 128)
    fused3 = layers.Add()([rgb_up3, depth_up3])

    rgb_up2 = decoder_block(fused3, rgb_skips[0], 64)
    fused2 = layers.Add()([rgb_up2, depth_up2])

    rgb_up1 = decoder_block(fused2, None, 32, False)
    fused1 = layers.Add()([rgb_up1, depth_up1])

    # Project to 3 channels with a sigmoid activation.
    rgb_output = layers.Conv2D(3, kernel_size=(1, 1), activation='sigmoid')(fused1)
    print("output shape", rgb_output.shape)

    return models.Model(inputs=[rgb_input, depth_input], outputs=rgb_output)

# Example usage: build the two-input model for 256x256 images
input_shape_rgb = (256, 256, 3)  # shape of the RGB image input
input_shape_depth = (256, 256, 1)  # shape of the depth image input
model = build_model(
    input_shape_rgb=input_shape_rgb,
    input_shape_depth=input_shape_depth,
)

# Print the layer-by-layer summary of the model
model.summary()


(None, 256, 256, 3) (None, 256, 256, 1)
rgb
1 <KerasTensor shape=(None, 128, 128, 32), dtype=float32, sparse=False, name=keras_tensor_4742> <KerasTensor shape=(None, 128, 128, 32), dtype=float32, sparse=False, name=keras_tensor_4742>
2 <KerasTensor shape=(None, 64, 64, 64), dtype=float32, sparse=False, name=keras_tensor_4749> <KerasTensor shape=(None, 64, 64, 64), dtype=float32, sparse=False, name=keras_tensor_4749>
3 <KerasTensor shape=(None, 32, 32, 128), dtype=float32, sparse=False, name=keras_tensor_4756> <KerasTensor shape=(None, 32, 32, 128), dtype=float32, sparse=False, name=keras_tensor_4756>
depth
1 <KerasTensor shape=(None, 128, 128, 32), dtype=float32, sparse=False, name=keras_tensor_4763> <KerasTensor shape=(None, 128, 128, 32), dtype=float32, sparse=False, name=keras_tensor_4763>
2 <KerasTensor shape=(None, 64, 64, 64), dtype=float32, sparse=False, name=keras_tensor_4770> <KerasTensor shape=(None, 64, 64, 64), dtype=float32, sparse=False, name=keras_tensor_4770>
3 <KerasTe