In [11]:
import tensorflow as tf
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Concatenate, Input, Multiply, Conv2DTranspose, Subtract, Add, Lambda
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.utils import to_categorical
import cv2
import numpy as np
import os
from tensorflow.keras.utils import Sequence



In [12]:
def encoder_block(x, num_filters):
    """Run one encoder stage: two 3x3 ReLU convolutions, then 2x2 max pooling.

    Args:
        x: Input feature tensor.
        num_filters: Number of filters used by both convolutions.

    Returns:
        A ``(pooled, features)`` tuple: the down-sampled tensor that feeds the
        next stage, and the pre-pooling features to be used as a skip connection.
    """
    features = x
    # Two identical 'same'-padded convolutions applied back to back.
    for _ in range(2):
        features = Conv2D(num_filters, 3, activation='relu', padding='same')(features)
    pooled = MaxPooling2D((2, 2))(features)
    return pooled, features

def decoder_block(x, skip, num_filters):
    """Run one decoder stage: up-sample, merge the skip connection, refine.

    Args:
        x: Feature tensor coming from the lower-resolution stage.
        skip: Encoder features concatenated with the up-sampled tensor.
        num_filters: Number of filters for the transposed convolution and for
            both refining convolutions.

    Returns:
        The refined feature tensor at twice the spatial resolution of ``x``.
    """
    # Stride-2 transposed convolution doubles height and width.
    upsampled = Conv2DTranspose(num_filters, (2, 2), strides=(2, 2), padding='same')(x)
    merged = Concatenate()([upsampled, skip])
    refined = merged
    for _ in range(2):
        refined = Conv2D(num_filters, 3, activation='relu', padding='same')(refined)
    return refined

def residual_guided_fusion(rgb_decoder_out, depth_decoder_out, y_n, filters, num_classes):
    """Fuse RGB and depth decoder features, guided by a predicted residual mask.

    Args:
        rgb_decoder_out: Feature maps from the RGB decoder.
        depth_decoder_out: Feature maps from the depth decoder; must have the
            same shape as ``rgb_decoder_out`` because the two are subtracted.
        y_n: Reference mask that the RGB prediction is subtracted from.
            Presumably the ground-truth mask with ``num_classes`` channels at
            this resolution -- TODO confirm against the caller.
        filters: Number of output channels of the fused feature maps.
        num_classes: Number of segmentation classes.

    Returns:
        A 4-tuple ``(fused, y_hat_n, y_hat_res, y_res)``:
            fused: Fused feature maps with ``filters`` channels.
            y_hat_n: RGB predicted mask (1x1 convolution of the RGB features).
            y_hat_res: Predicted residual mask produced by the depth path.
            y_res: Reference residual ``y_n - y_hat_n``.
    """
    # RGB PATH
    # Generate the RGB predicted mask y_hat_n through a 1x1 convolution.
    y_hat_n = Conv2D(num_classes, kernel_size=(1, 1), padding='same')(rgb_decoder_out)
    # Reference residual: what the RGB prediction is still missing w.r.t. y_n.
    y_res = Subtract()([y_n, y_hat_n])

    # DEPTH PATH
    # Element-wise difference between the depth and RGB feature maps.
    difference_maps = Subtract()([depth_decoder_out, rgb_decoder_out])
    # Bring the difference features to `num_classes` channels with a 1x1 convolution.
    depth_conv = Conv2D(num_classes, kernel_size=(1, 1), padding='same')(difference_maps)

    # Residual unit (3x3 convolution plus identity skip) yields the predicted
    # residual mask y_hat_res.
    residual_branch = Conv2D(num_classes, kernel_size=(3, 3), padding='same')(depth_conv)
    y_hat_res = Add()([residual_branch, depth_conv])

    # Project the PREDICTED residual to the RGB channel count and use it to gate
    # the RGB features. The reference residual y_res must not be used here: it
    # is derived from y_n and would leak the target into the forward path while
    # leaving y_hat_res out of the fusion entirely.
    channels_rgb = rgb_decoder_out.shape[-1]
    y_hat_res_conv = Conv2D(channels_rgb, kernel_size=(1, 1), padding='same')(y_hat_res)

    combined_path = Multiply()([y_hat_res_conv, rgb_decoder_out])

    stacked = Concatenate()([combined_path, rgb_decoder_out, y_hat_res_conv])

    fused = Conv2D(filters, kernel_size=(3, 3), padding='same')(stacked)
    return fused, y_hat_n, y_hat_res, y_res


def unet(input_shape=(128, 256, 4), num_classes=20):
    """Build a two-stream (RGB + depth) U-Net with two segmentation heads.

    The single input tensor carries RGB in channels 0-2 and depth in the
    remaining channel(s). Each modality has its own encoder, bottleneck and
    decoder; the depth decoder features are added into the RGB decoder before
    its second and third stages.

    Args:
        input_shape: ``(height, width, channels)`` of the combined input.
            ``channels`` must be at least 4 (3 RGB + 1 or more depth). Height
            and width must be divisible by 8 so the three pooling stages and
            the three up-sampling stages line up with the skip connections.
        num_classes: Number of classes predicted per pixel by each head.

    Returns:
        An uncompiled Keras ``Model`` with outputs named ``'rgb_segmentation'``
        and ``'depth_segmentation'``, each a per-pixel softmax over
        ``num_classes``.

    Raises:
        ValueError: If ``input_shape`` has fewer than 4 channels, which would
            leave the depth stream with an empty tensor.
    """
    channels = input_shape[-1]
    if channels is not None and channels < 4:
        raise ValueError(
            f"input_shape must have at least 4 channels (RGB + depth), got {input_shape}"
        )

    inputs = Input(shape=input_shape)

    # Split the combined input tensor into RGB and depth
    inputs_rgb = Lambda(lambda x: x[..., :3])(inputs)  # first 3 channels: RGB
    inputs_depth = Lambda(lambda x: x[..., 3:])(inputs)  # remaining channel(s): depth

    # Shape comments below are for the default 128x256 input.
    # Encoders: the first value is the pooled output, the second the pre-pool skip.
    rgb_enc1, rgb_enc1_skip = encoder_block(inputs_rgb, 32)  # (64,128,32), skip (128,256,32)
    rgb_enc2, rgb_enc2_skip = encoder_block(rgb_enc1, 64)  # (32,64,64), skip (64,128,64)
    rgb_enc3, rgb_enc3_skip = encoder_block(rgb_enc2, 128)  # (16,32,128), skip (32,64,128)

    depth_enc1, depth_enc1_skip = encoder_block(inputs_depth, 32)  # (64,128,32)
    depth_enc2, depth_enc2_skip = encoder_block(depth_enc1, 64)  # (32,64,64)
    depth_enc3, depth_enc3_skip = encoder_block(depth_enc2, 128)  # (16,32,128)

    # Bottleneck (one per modality)
    c5 = Conv2D(256, 3, activation='relu', padding='same')(rgb_enc3)  # (16,32,256)
    c5 = Conv2D(256, 3, activation='relu', padding='same')(c5)  # (16,32,256)

    c6 = Conv2D(256, 3, activation='relu', padding='same')(depth_enc3)  # (16,32,256)
    c6 = Conv2D(256, 3, activation='relu', padding='same')(c6)  # (16,32,256)

    # Decoder: every stage doubles the spatial resolution.
    depth_dec3 = decoder_block(c6, depth_enc3_skip, 128)  # (32,64,128)
    depth_dec2 = decoder_block(depth_dec3, depth_enc2_skip, 64)  # (64,128,64)
    depth_dec1 = decoder_block(depth_dec2, depth_enc1_skip, 32)  # (128,256,32)

    rgb_dec3 = decoder_block(c5, rgb_enc3_skip, 128)  # (32,64,128)

    # Inject depth decoder features into the RGB decoder by element-wise addition.
    rgb_dec2 = decoder_block(Add()([rgb_dec3, depth_dec3]), rgb_enc2_skip, 64)  # (64,128,64)
    rgb_dec1 = decoder_block(Add()([rgb_dec2, depth_dec2]), rgb_enc1_skip, 32)  # (128,256,32)

    # Segmentation output: one softmax head per modality
    rgb_segmentation_output = Conv2D(num_classes, 1, activation='softmax', name='rgb_segmentation')(rgb_dec1)
    depth_segmentation_output = Conv2D(num_classes, 1, activation='softmax', name='depth_segmentation')(depth_dec1)

    model = Model(inputs=inputs, outputs=[rgb_segmentation_output, depth_segmentation_output])
    return model

In [13]:
# Dataset directories
rgb_dir, label_dir, depth_dir = './train/image', './train/label', './train/depth'

# Tensor geometry: every shape is derived from the spatial size of a label map
label_shape = (128, 256)
rgb_shape = (*label_shape, 3)
depth_shape = (*label_shape, 1)
# The network input stacks the RGB and depth channels along the last axis
input_shape = (*label_shape, rgb_shape[-1] + depth_shape[-1])

# Training parameters
num_classes = 20
batch_size = 8


class NPYDataGenerator(Sequence):
    """Keras ``Sequence`` that streams RGB + depth batches from ``.npy`` files.

    Samples are identified by an integer stem shared across the three
    directories: ``<rgb_dir>/<n>.npy``, ``<label_dir>/<n>.npy`` and
    ``<depth_dir>/<n>.npy``. Each batch is returned as ``(inputs, targets)``:

    * ``inputs``: float32 array ``(batch, H, W, 4)`` -- RGB scaled by 1/255
      with the depth map (also scaled by 1/255) appended as the last channel.
    * ``targets``: dict with keys ``'rgb_segmentation'`` and
      ``'depth_segmentation'``, both holding the same one-hot label array
      ``(batch, H, W, num_classes)``.
    """

    def __init__(self, rgb_dir, label_dir, depth_dir, batch_size=8, image_size=(128, 256), num_classes=20, shuffle=True, **kwargs):
        """Index the dataset and prepare the first epoch.

        Args:
            rgb_dir: Directory with the RGB ``.npy`` files; its listing defines
                the set of sample indices.
            label_dir: Directory with the label ``.npy`` files.
            depth_dir: Directory with the depth ``.npy`` files.
            batch_size: Number of samples per batch.
            image_size: ``(height, width)`` of every sample.
            num_classes: Depth of the one-hot encoding of the labels.
            shuffle: Whether to reshuffle the sample order at construction and
                after every epoch.
            **kwargs: Forwarded to ``Sequence.__init__`` (e.g. ``workers``,
                ``use_multiprocessing`` on Keras 3).
        """
        # Keras 3 warns (and ignores worker settings) if the base initialiser is skipped.
        super().__init__(**kwargs)
        self.rgb_dir = rgb_dir
        self.label_dir = label_dir
        self.depth_dir = depth_dir
        self.batch_size = batch_size
        self.image_size = image_size
        self.num_classes = num_classes
        self.shuffle = shuffle
        # Only consider .npy entries so stray files (.DS_Store, checkpoints
        # folders, ...) do not break the integer parse of the file stem.
        self.file_indices = sorted(
            int(name.split('.')[0])
            for name in os.listdir(rgb_dir)
            if name.endswith('.npy')
        )
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch (a trailing partial batch is dropped)."""
        return len(self.file_indices) // self.batch_size

    def __getitem__(self, index):
        """Return batch number ``index`` as an ``(inputs, targets)`` pair."""
        batch_indices = self.file_indices[index * self.batch_size:(index + 1) * self.batch_size]
        X, y = self.__data_generation(batch_indices)
        return X, y

    def on_epoch_end(self):
        """Reshuffle the sample order in place when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.file_indices)

    def __data_generation(self, batch_indices):
        """Load, normalise and stack the samples listed in ``batch_indices``."""
        # Size the buffers from the indices actually requested so a short batch
        # never carries uninitialised rows left over from np.empty.
        n_samples = len(batch_indices)
        X = np.empty((n_samples, *self.image_size, 3), dtype=np.float32)
        depth = np.empty((n_samples, *self.image_size, 1), dtype=np.float32)  # Depth with 1 channel
        y = np.empty((n_samples, *self.image_size, 1), dtype=np.int32)

        for i, idx in enumerate(batch_indices):
            rgb_path = os.path.join(self.rgb_dir, f"{idx}.npy")
            label_path = os.path.join(self.label_dir, f"{idx}.npy")
            depth_path = os.path.join(self.depth_dir, f"{idx}.npy")

            # Load and normalize RGB image
            X[i] = np.load(rgb_path).astype(np.float32) / 255.0
            # Load and normalize depth map
            depth[i] = np.load(depth_path).astype(np.float32) / 255.0  # Assuming depth is in [0, 255] range

            # Shift labels by +1 so that -1 (background) becomes class 0 and the
            # remaining classes follow from 1 upwards.
            y[i] = np.load(label_path).astype(np.int32).reshape((*self.image_size, 1)) + 1

        # Use the instance setting rather than whatever global `num_classes`
        # happens to exist in the notebook namespace.
        y_one_hot = to_categorical(y, num_classes=self.num_classes)
        combined_input = np.concatenate([X, depth], axis=-1)

        return combined_input, {'rgb_segmentation': y_one_hot, 'depth_segmentation': y_one_hot}

# Stream training batches from disk
train_generator = NPYDataGenerator(
    rgb_dir,
    label_dir,
    depth_dir,
    batch_size=batch_size,
    image_size=rgb_shape[:2],
    num_classes=num_classes,
)

# Build the two-headed RGB + depth network
unet_model = unet(input_shape, num_classes)

# Both heads are trained with the same loss and tracked with the same metric;
# each head still gets its own loss instance.
output_heads = ('rgb_segmentation', 'depth_segmentation')
unet_model.compile(
    optimizer=Adam(),
    loss={head: CategoricalCrossentropy(from_logits=False) for head in output_heads},
    metrics={head: 'accuracy' for head in output_heads},
)

unet_model.fit(train_generator, epochs=1)


  self._warn_if_super_not_called()


[1m  7/297[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m16:19[0m 3s/step - depth_segmentation_accuracy: 0.2062 - depth_segmentation_loss: 2.9834 - loss: 5.9733 - rgb_segmentation_accuracy: 0.0978 - rgb_segmentation_loss: 2.9899

KeyboardInterrupt: 