## Imports

In [None]:
#%pip install -U albumentations>=3.0.0
#!git clone https://github.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models

In [None]:
%cd semanctic-segmentation-tasm
import tensorflow_advanced_segmentation_models as tasm
%cd ..

In [None]:
import os
import cv2
import numpy as np
from time import time
import tensorflow as tf
import albumentations as A
import matplotlib.pyplot as plt
import tensorflow.keras.backend as K
from keras.utils import to_categorical
%matplotlib inline
import matplotlib.pyplot as plt

## Directories

In [None]:
# Data handling: dataset root plus the image / mask folder of every split.
# The root can be overridden through the CEYMO_DATA_DIR environment variable so the
# notebook is not tied to a single machine; the original location stays the default.
DATA_DIR = os.environ.get("CEYMO_DATA_DIR", r"D:\Datasets\CeyMo")

x_train_dir = os.path.join(DATA_DIR, 'train/images')
y_train_dir = os.path.join(DATA_DIR, 'train/mask_annotations')

# Masks must come from the same split as the images they annotate; pairing
# 'test/images' with 'train/mask_annotations' would silently mis-pair every sample.
# NOTE(review): assumes the test split has a 'mask_annotations' folder like train — confirm.
x_valid_dir = os.path.join(DATA_DIR, 'test/images')
y_valid_dir = os.path.join(DATA_DIR, 'test/mask_annotations')

# NOTE(review): validation and test currently point at the same folders.
x_test_dir = os.path.join(DATA_DIR, 'test/images')
y_test_dir = os.path.join(DATA_DIR, 'test/mask_annotations')

# Quick sanity check that the dataset root is where we expect it.
print(os.path.exists(x_train_dir))

### Helper Functions

In [None]:
def flatten(list):
    """Concatenate the items of every sub-iterable in `list` into one flat list.

    Only one level of nesting is removed; the order of items is preserved.
    """
    flat = []
    for sublist in list:
        flat.extend(sublist)
    return flat

def get_folders_in_folder(folder):
    """Return the path of every directory below `folder`, at any depth.

    `folder` itself is not included; a folder without sub-directories (or one
    that os.walk cannot enter) gives an empty list.
    """
    walked_dirs = []
    for dirpath, _dirnames, _filenames in os.walk(folder):
        walked_dirs.append(dirpath)
    # os.walk yields the root first; drop it so only sub-directories remain.
    return walked_dirs[1:]

def get_files_in_folder(folder, pattern=None):
    """List the entries directly inside `folder` as sorted, folder-prefixed paths.

    Args:
        folder: directory to list (every os.listdir entry is returned, so
            sub-directory names are included as well as files).
        pattern: optional substring; when given, only entries whose name
            contains it are kept.

    Returns:
        Sorted list of `os.path.join(folder, entry)` strings.
    """
    entries = os.listdir(folder)
    if pattern is not None:
        entries = [entry for entry in entries if pattern in entry]
    paths = [os.path.join(folder, entry) for entry in entries]
    paths.sort()
    return paths

def get_files_recursive(folder, pattern=None):
    """Collect entry paths from `folder`, descending into sub-folders when present.

    Args:
        folder: root directory to scan.
        pattern: optional substring filter forwarded to get_files_in_folder.

    Returns:
        If `folder` has no sub-directories: the sorted entries of `folder`.
        Otherwise: the entries of every sub-directory, concatenated in the
        order get_folders_in_folder reports them. In that case entries that
        sit directly in `folder` are not returned.
    """
    # Walk the tree once and reuse the result (it was previously walked twice).
    subfolders = get_folders_in_folder(folder)
    if not subfolders:
        return get_files_in_folder(folder, pattern)
    return flatten([get_files_in_folder(subfolder, pattern) for subfolder in subfolders])

In [None]:
# helper function for data visualization
def visualize(**images):
    """Plot the given images side by side in a single row.

    Each keyword becomes a panel title: underscores are replaced by spaces
    and the result is title-cased (e.g. `gt_mask` -> "Gt Mask").
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    for position, (name, image) in enumerate(images.items(), start=1):
        plt.subplot(1, panel_count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(name.replace('_', ' ').title())
        plt.imshow(image)
    plt.show()
    
# helper function for data visualization    
def denormalize(x):
    """Scale an image to the range 0..1 so that it plots correctly.

    The 2nd and 98th percentiles of `x` are mapped to 0 and 1 and everything
    outside that window is clipped.

    Args:
        x: array-like image data.

    Returns:
        Float array of the same shape with values in [0, 1]. A constant input
        (zero percentile spread) yields all zeros instead of NaNs.
    """
    x_max = np.percentile(x, 98)
    x_min = np.percentile(x, 2)
    span = x_max - x_min
    if span == 0:
        # Avoid dividing by zero on a flat image; every pixel sits at the lower bound.
        return np.zeros_like(x, dtype=float)
    x = (x - x_min) / span
    return x.clip(0, 1)
    
def round_clip_0_1(x, **kwargs):
    """Round `x` to the nearest integer and clip the result into [0, 1].

    Extra keyword arguments are accepted and ignored.
    """
    rounded = x.round()
    return rounded.clip(0, 1)

In [None]:
# Show the first two image paths and the first two mask paths as a quick check
# that both training folders are found and non-empty.
get_files_recursive(x_train_dir)[:2], get_files_recursive(y_train_dir)[:2]

## Define the Label classes

In [None]:
from collections import namedtuple

# A label and all of its meta information.
# NOTE(review): several field comments below mention an evaluation server and a
# 'preparation' folder that are not part of this notebook — confirm which still apply.
# Within the visible notebook only `name`, `trainId` and `color` are read.
Label = namedtuple( 'Label' , [

    'name'        , # The identifier of this label, e.g. 'bus lane', 'slow', ... .
                    # We use them to uniquely name a class

    'id'          , # An integer ID that is associated with this label.
                    # The IDs are used to represent the label in ground truth images
                    # An ID of -1 means that this label does not have an ID and thus
                    # is ignored when creating ground truth images.
                    # Do not modify these IDs, since exactly these IDs are expected by the
                    # evaluation server.

    'trainId'     , # Feel free to modify these IDs as suitable for your method. Then create
                    # ground truth images with train IDs, using the tools provided in the
                    # 'preparation' folder. However, make sure to validate or submit results
                    # to our evaluation server using the regular IDs above!
                    # For trainIds, multiple labels might have the same ID. Then, these labels
                    # are mapped to the same class in the ground truth images. For the inverse
                    # mapping, we use the label that is defined first in the list below.
                    # For example, mapping all void-type classes to the same ID in training,
                    # might make sense for some approaches.
                    # Max value is 255!
                    # In this notebook a label is kept as a class when 0 <= trainId < 255.

    'category'    , # The name of the category that this label belongs to

    'categoryId'  , # The ID of this category. Used to create ground truth images
                    # on category level.

    'hasInstances', # Whether this label distinguishes between single instances or not

    'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
                    # during evaluations or not

    'color'       , # The color of this label as an (R, G, B) tuple; masks are matched
                    # against these values after conversion to RGB
    ] )

# Class table for the dataset. The position of a row in this list is what ends up
# as the class index: labels_color is built in list order and the mask encoder
# enumerates it, so reordering rows changes the meaning of every class channel.
# NOTE(review): category / catId / ignoreInEval values look like placeholders
# (e.g. 'bus lane' is filed under 'void') — they are not read in the visible notebook.
labels = [
    #       name                      id    trainId   category  catId    hasInstances   ignoreInEval        color
    Label(  'unlabeled'             ,  0 ,        0 , 'void'    , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'bus lane'              ,  1 ,        1 , 'void'    , 0       , False        , True         , (  0,255,255) ),
    Label(  'cycle lane'            ,  2 ,        2 , 'void'    , 0       , False        , True         , (  0,128,255) ),
    Label(  'diamond'               ,  3 ,        3 , 'void'    , 0       , False        , True         , (178,102,255) ),
    Label(  'junction box'          ,  4 ,        4 , 'void'    , 0       , False        , True         , (255,255, 51) ),
    Label(  'left arrow'            ,  5 ,        5 , 'void'    , 0       , False        , True         , (255,102,178) ),
    Label(  'pedestrian crossing'   ,  6 ,        6 , 'void'    , 0       , False        , True         , (255,255,  0) ),
    Label(  'right arrow'           ,  7 ,        7 , 'flat'    , 1       , False        , False        , (255,  0,127) ),
    Label(  'straight arrow'        ,  8 ,        8 , 'flat'    , 1       , False        , False        , (255,  0,255) ),
    Label(  'slow'                  ,  9 ,        9 , 'flat'    , 1       , False        , True         , (  0,255,  0) ),
    Label(  'straight-left arrow'   , 10 ,       10 , 'flat'    , 1       , False        , True         , (255,128,  0) ),
    Label(  'straight-right arrow'  , 11 ,       11 , 'flat'    , 1       , False        , True         , (255,  0,  0) )
]

In [None]:
# Keep only the labels with a usable train ID (0 <= trainId < 255), in list order,
# and derive the parallel colour / name lists used by the rest of the notebook.
_trainable_labels = [label for label in labels if 0 <= label.trainId < 255]
labels_color = [list(label.color) for label in _trainable_labels]
labels_name = [label.name for label in _trainable_labels]

print("Number of classes - ", len(labels_color))
print("\n")
for class_name, class_color in zip(labels_name, labels_color):
    print(f"{class_name} - {class_color}")

## Define some Global variables

In [None]:
TOTAL_CLASSES = labels_name  # list of class names (a list, not a count)
N_CLASSES = len(labels_color)  # number of classes, passed to the model as n_classes
BATCH_SIZE = 2  # training batch size; must stay even for the "with and without augmentation" generator
HEIGHT = 256  # target image height fed to the augmentations and the model
WIDTH = 256  # target image width fed to the augmentations and the model
BACKBONE_NAME = "efficientnetb3"
WEIGHTS = "imagenet"
WWO_AUG = True # train data with and without augmentation (NOTE(review): not read anywhere in the visible notebook)

In [None]:
# Shuffle switches and seed handed to the data generators below.
train_shuffle = True
val_shuffle = True
seed = 29598  # forwarded to the path generator, which uses it to seed TensorFlow before shuffling

## Data Augmentation Functions

In [None]:
# Training-time augmentation pipeline (currently light: flip, pad, resize).
def get_training_augmentation(height, width):
    """Build the albumentations pipeline applied to training samples.

    Args:
        height: target height, used as the minimum padded height and the resize height.
        width: target width, used as the minimum padded width and the resize width.

    Returns:
        An `A.Compose` of: horizontal flip (p=0.3), constant-border padding up
        to at least (height, width), and a resize to exactly (height, width).
    """
    # Explored earlier but currently disabled: ShiftScaleRotate, RandomCrop,
    # GaussNoise, Perspective, the OneOf groups for contrast (CLAHE /
    # RandomBrightnessContrast / RandomGamma), sharpness (Sharpen / Blur /
    # MotionBlur) and colour (RandomContrast / HueSaturationValue), and the
    # mask Lambda using round_clip_0_1.
    return A.Compose([
        A.HorizontalFlip(p=0.3),
        A.PadIfNeeded(min_height=height, min_width=width, always_apply=True, border_mode=0),
        A.Resize(height, width, always_apply=True),
    ])

def get_validation_augmentation(height, width):
    """Build the pipeline used for validation / test samples.

    Pads the sample up to at least (height, width) and then resizes it to
    exactly (height, width); no flip is applied.
    """
    pad_to_minimum = A.PadIfNeeded(height, width)
    resize_to_target = A.Resize(height, width, always_apply=True)
    return A.Compose([pad_to_minimum, resize_to_target])

def data_get_preprocessing(preprocessing_fn):
    """Wrap a normalization function as an albumentations transform.

    Args:
        preprocessing_fn (callable): data normalization function applied to
            the image (can be specific to each pretrained network).

    Returns:
        albumentations.Compose containing a single image Lambda.
    """
    return A.Compose([A.Lambda(image=preprocessing_fn)])

## Data Generation Functions

### Functions

In [None]:
def create_image_label_path_generator(images_dir, masks_dir, shuffle=False, seed=None):
    """Yield `[image_path, mask_path]` pairs forever, cycling over the dataset.

    Images and masks are paired by their position in the sorted listings of
    the two directories, so both must contain matching files in the same order.

    Args:
        images_dir: folder (optionally with sub-folders) holding the images.
        masks_dir: folder (optionally with sub-folders) holding the masks.
        shuffle: when truthy, the pairs are permuted once before cycling
            (the same order is then repeated on every pass).
        seed: optional seed; when given, `tf.random.set_seed(seed)` is called
            before shuffling, which also resets TensorFlow's global seed.

    Yields:
        A two-element list `[image_path, mask_path]`.

    Raises:
        ValueError: on the first `next()` call if no images are found, or if
            the number of images and masks differs (the positional pairing
            would otherwise be silently wrong).
    """
    # get_files_recursive already returns paths prefixed with the folder they were
    # found in, so they are used as-is; joining them onto images_dir / masks_dir a
    # second time only worked when the directories were absolute.
    images_fps = list(get_files_recursive(images_dir))
    masks_fps = list(get_files_recursive(masks_dir))

    if not images_fps:
        # Without this guard the cycling loop below would spin forever without yielding.
        raise ValueError(f"No images found in {images_dir!r}")
    if len(images_fps) != len(masks_fps):
        raise ValueError(
            f"Found {len(images_fps)} images in {images_dir!r} but "
            f"{len(masks_fps)} masks in {masks_dir!r}; they must match one-to-one"
        )

    if shuffle:
        if seed is not None:
            tf.random.set_seed(seed)

        # One shared permutation keeps every image aligned with its mask. Indexing the
        # Python lists directly avoids a bytes -> str round-trip through TF tensors.
        order = tf.random.shuffle(tf.range(start=0, limit=len(images_fps), dtype=tf.int32)).numpy()
        images_fps = [images_fps[idx] for idx in order]
        masks_fps = [masks_fps[idx] for idx in order]

    while True:
        for image_fp, mask_fp in zip(images_fps, masks_fps):
            yield [image_fp, mask_fp]

In [None]:
def label_segmentation_mask(seg, class_labels):
    """
    Turn a colour-coded segmentation mask into a 2D map of class indices.
    # Arguments
        seg: 3D array (rows, cols, 3) whose last axis holds the r, g, b values
        class_labels: sequence of [r, g, b] colours; position = class index
    # Returns
        uint8 array of shape (rows, cols) with the class index of each pixel.
        Pixels whose colour matches no entry keep the value 0.
    """
    seg = seg.astype("uint8")

    # start with every pixel assigned to index 0
    label = np.zeros(seg.shape[:2], dtype=np.uint8)

    for class_index, class_rgb in enumerate(class_labels):
        # a pixel belongs to the class only if all three channels match
        matches = np.all(seg == class_rgb, axis=2)
        label[matches] = class_index
    return label

def one_hot_encode(seg, class_labels):
    """
    One-hot encode a colour-coded segmentation mask.
    # Arguments
        seg: The 3D (rows, cols, 3) colour segmentation mask
        class_labels: sequence of [r, g, b] colours; its length is the
            number of classes
    # Returns
        A 3D array with the same rows and cols as the input and a depth of
        len(class_labels), as produced by to_categorical
    """
    class_index_map = label_segmentation_mask(seg, class_labels)  # rows x cols of class indices
    return to_categorical(class_index_map, len(class_labels))  # rows x cols x num_classes

def decode_one_hot(label_one_hot, labels_name):
    """Map a one-hot (or per-class score) array back to per-pixel label values.

    Args:
        label_one_hot: array-like whose last axis runs over the classes.
        labels_name: sequence indexed by class; despite the name, callers in
            this notebook pass the colour list so the result is an RGB image.

    Returns:
        Array with the last axis of `label_one_hot` replaced by the entry of
        `labels_name` for the arg-max class.
    """
    class_index = np.argmax(label_one_hot, axis=-1)
    palette = np.array(labels_name)
    # argmax already returns integer indices; casting them to uint8 would wrap
    # around for class indices above 255.
    return palette[class_index]

In [None]:
def process_image_label(images_paths, masks_paths, class_labels, augmentation=None):
    """Load one image / mask pair, optionally augment it, and one-hot encode the mask.

    Args:
        images_paths: path of a single image file.
        masks_paths: path of the matching colour-coded mask file.
        class_labels: sequence of [r, g, b] class colours used for encoding.
        augmentation: optional callable taking `image=` and `mask=` keywords
            and returning a dict with 'image' and 'mask' entries.

    Returns:
        Tuple `(image, mask, mask_one_hot)`: the RGB image, the RGB mask and
        the one-hot encoded mask (all after augmentation, if any).

    Raises:
        OSError: if either file cannot be read as an image.
    """
    # cv2.imread signals failure by returning None instead of raising, which would
    # otherwise surface as an obscure error inside cvtColor.
    image = cv2.imread(images_paths)
    if image is None:
        raise OSError(f"Could not read image file: {images_paths}")
    mask = cv2.imread(masks_paths)
    if mask is None:
        raise OSError(f"Could not read mask file: {masks_paths}")

    # OpenCV loads BGR; the class colours are defined as RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)

    # apply augmentations
    if augmentation:
        sample = augmentation(image=image, mask=mask)
        image, mask = sample['image'], sample['mask']

    mask_one_hot = one_hot_encode(mask, class_labels)
    return image, mask, mask_one_hot

#### Example

In [None]:
# Preview one training sample: the raw image / mask (top row) and the augmented
# image, augmented mask and decoded one-hot mask (bottom row).
image_label_path_generator = create_image_label_path_generator(x_train_dir, y_train_dir, shuffle=True, seed=None)
image_path, label_path = next(image_label_path_generator)

print(image_path, label_path)

# cv2.imread returns BGR while matplotlib expects RGB, so convert before plotting;
# otherwise the red and blue channels are swapped in the top row.
image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
mask = cv2.cvtColor(cv2.imread(label_path), cv2.COLOR_BGR2RGB)
image_aug, label_aug, label_one_hot = process_image_label(image_path, label_path, labels_color, augmentation=get_training_augmentation(height=HEIGHT, width=WIDTH))

print(image.shape, mask.shape, image_aug.shape, label_aug.shape, label_one_hot.shape)

fig = plt.figure(figsize=(12, 8))
plt.subplot(2, 3, 1)
plt.title('Original image')
plt.imshow(image)
plt.subplot(2, 3, 2)
plt.title('Original mask')
plt.imshow(mask)
plt.subplot(2, 3, 4)
plt.title('Augmented image')
plt.imshow(image_aug)
plt.subplot(2, 3, 5)
plt.title('Augmented mask')
plt.imshow(label_aug)
plt.subplot(2, 3, 6)
plt.title('Decoded one-hot mask')
plt.imshow(decode_one_hot(label_one_hot, labels_color))
plt.show()

In [None]:
################################################################################################
# Data Generator
################################################################################################
def DataGenerator(train_dir, label_dir, batch_size, height, width, class_labels, augmentation, wwo_aug=False, shuffle=False, seed=None):
    """Endless generator of `(images, one_hot_masks)` batches as tensors.

    Args:
        train_dir: folder with the input images.
        label_dir: folder with the colour-coded masks.
        batch_size: number of samples per batch.
        height, width: spatial size every sample must have after `augmentation`.
        class_labels: sequence of [r, g, b] class colours; its length is the
            channel count of the mask tensor.
        augmentation: callable applied to every sample (see process_image_label).
        wwo_aug: "with and without augmentation". When True each source image
            contributes two consecutive batch entries: first the version
            produced by `augmentation`, then one produced by
            get_validation_augmentation. A batch then consumes
            batch_size / 2 source images.
        shuffle, seed: forwarded to create_image_label_path_generator.

    Yields:
        Tuple of tensors: images with shape (batch_size, height, width, 3) and
        float32 masks with shape (batch_size, height, width, len(class_labels)).

    Raises:
        ValueError: on the first `next()` call if `wwo_aug` is set and
            `batch_size` is odd (samples are written in pairs).
    """
    if wwo_aug and batch_size % 2 != 0:
        raise ValueError("batch_size must be even when wwo_aug=True (samples are added in pairs)")

    image_label_path_generator = create_image_label_path_generator(
        train_dir, label_dir, shuffle=shuffle, seed=seed
    )
    # Built once, and from this generator's own target size rather than the
    # notebook-level HEIGHT / WIDTH globals, so both versions always match the batch shape.
    plain_augmentation = get_validation_augmentation(height=height, width=width) if wwo_aug else None

    while True:
        images = np.zeros(shape=[batch_size, height, width, 3])
        labels = np.zeros(shape=[batch_size, height, width, len(class_labels)], dtype=np.float32)

        slot = 0
        while slot < batch_size:
            image_path, label_path = next(image_label_path_generator)

            image_aug, _, label_aug_oh = process_image_label(image_path, label_path, class_labels, augmentation=augmentation)
            images[slot], labels[slot] = image_aug, label_aug_oh
            slot += 1

            if wwo_aug:
                # Same source image again, this time without the training augmentation.
                image_plain, _, label_plain_oh = process_image_label(image_path, label_path, class_labels, augmentation=plain_augmentation)
                images[slot], labels[slot] = image_plain, label_plain_oh
                slot += 1

        yield tf.convert_to_tensor(images), tf.convert_to_tensor(labels, tf.float32)

### Get the Data

<p>There are three options for the training dataset: </p>

- Training without augmentation
- Training with augmentation
- Training with and without augmentation (twice the data)

<p>Validation and Test data are of course without augmentation</p>

In [None]:
##########################################################################################################
# Data Generator
#           augmentation - wwo_aug
# get_train_augmentation -  true    -   both (training) and (validation) augmentation   - TrainSetwwoAug
# get_train_augmentation -  false   -   only (training) augmentation                    - TrainSet
# get_valid_augmentation -  true    -   both (validation) and (validation) augmentation - xx
# get_valid_augmentation -  false   -   only (validation) augmentation                  - TrainSetwoAug, ValidationSet, TestSet
##########################################################################################################
# Training batches, augmented only.
TrainSet = DataGenerator(
    train_dir=x_train_dir,
    label_dir=y_train_dir,
    batch_size=BATCH_SIZE,
    height=HEIGHT,
    width=WIDTH,
    class_labels=labels_color,
    augmentation=get_training_augmentation(height=HEIGHT, width=WIDTH),
    shuffle=train_shuffle,
    seed=seed,
)

# Training batches without the training augmentation (pad + resize only).
TrainSetwoAug = DataGenerator(
    train_dir=x_train_dir,
    label_dir=y_train_dir,
    batch_size=BATCH_SIZE,
    height=HEIGHT,
    width=WIDTH,
    class_labels=labels_color,
    augmentation=get_validation_augmentation(height=HEIGHT, width=WIDTH),
    shuffle=train_shuffle,
    seed=seed,
)

# Training batches holding every image both with and without augmentation.
TrainSetwwoAug = DataGenerator(
    train_dir=x_train_dir,
    label_dir=y_train_dir,
    batch_size=BATCH_SIZE,
    height=HEIGHT,
    width=WIDTH,
    class_labels=labels_color,
    augmentation=get_training_augmentation(height=HEIGHT, width=WIDTH),
    wwo_aug=True,
    shuffle=train_shuffle,
    seed=seed,
)

# Validation batches of one image each.
ValidationSet = DataGenerator(
    train_dir=x_valid_dir,
    label_dir=y_valid_dir,
    batch_size=1,
    height=HEIGHT,
    width=WIDTH,
    class_labels=labels_color,
    augmentation=get_validation_augmentation(height=HEIGHT, width=WIDTH),
    shuffle=val_shuffle,
    seed=seed,
)

# Test batches of one image each, in sorted (unshuffled) order.
TestSet = DataGenerator(
    train_dir=x_test_dir,
    label_dir=y_test_dir,
    batch_size=1,
    height=HEIGHT,
    width=WIDTH,
    class_labels=labels_color,
    augmentation=get_validation_augmentation(height=HEIGHT, width=WIDTH),
)

In [None]:
# Pull one training batch to check its shapes, and keep its first image / mask as
# `sample_image` / `sample_mask` for the model sanity check further down.
sample_batch = next(TrainSet)
sample_image, sample_mask = sample_batch[0][0], sample_batch[1][0]
print(len(sample_batch))
print(sample_batch[0].shape)
print(sample_batch[1].shape)

fig = plt.figure(figsize=(8, 6))
plt.subplot(1, 2, 1)
# The batch stores raw 0-255 pixel values as floats; imshow treats float RGB data
# as 0-1 and would clip almost everything to white, so cast to uint8 for display.
plt.imshow(sample_image.numpy().astype("uint8"))
plt.subplot(1, 2, 2)
plt.imshow(decode_one_hot(sample_mask, labels_color))
plt.show()

In [None]:
# Number of images available in each split (used to size epochs and evaluation).
n_samples_train, n_samples_valid, n_samples_test = (
    len(get_files_recursive(split_dir))
    for split_dir in (x_train_dir, x_valid_dir, x_test_dir)
)

n_samples_train, n_samples_valid, n_samples_test

## Create the Model

In [None]:
# Whether the backbone weights are updated during the first training stage.
BACKBONE_TRAINABLE = False

# Backbone and the layers whose outputs feed the segmentation head.
base_model, layers, layer_names = tasm.create_base_model(
    name=BACKBONE_NAME,
    weights=WEIGHTS,
    height=HEIGHT,
    width=WIDTH,
    include_top=False,
    pooling=None,
)

# DANet segmentation model on top of that backbone.
model = tasm.DANet(
    n_classes=N_CLASSES,
    base_model=base_model,
    output_layers=layers,
    backbone_trainable=BACKBONE_TRAINABLE,
)

### Define the optimizer as well as losses, metrics and callbacks

In [None]:
# Optimizer, metric and loss for the first (frozen-backbone) training stage.
# NOTE(review): a learning rate of 0.2 is high for SGD with momentum — confirm it is intended.
opt = tf.keras.optimizers.SGD(learning_rate=0.2, momentum=0.9)
metrics = [tasm.metrics.IOUScore(threshold=0.5)]
categorical_focal_dice_loss = tasm.losses.CategoricalFocalLoss(alpha=0.25, gamma=2.0) + tasm.losses.DiceLoss()

model.compile(
    optimizer=opt,
    loss=categorical_focal_dice_loss,
    metrics=metrics,
)
model.run_eagerly = False

# This stage is fitted without validation data, so every callback has to watch a
# training metric. ModelCheckpoint defaults to monitoring "val_loss", which is not
# produced here, so with save_best_only=True it would never write a checkpoint.
callbacks = [
             tf.keras.callbacks.ModelCheckpoint("model.hdf5", monitor="iou_score", mode="max", verbose=1, save_weights_only=True, save_best_only=True),
             tf.keras.callbacks.ReduceLROnPlateau(monitor="iou_score", factor=0.2, patience=6, verbose=1, mode="max"),
             tf.keras.callbacks.EarlyStopping(monitor="iou_score", patience=16, mode="max", verbose=1, restore_best_weights=True)
]

#### Short check if model works properly

In [None]:
def display(display_list):
    """Show up to three arrays in one row, titled input / true mask / predicted mask.

    Titles are assigned by position, so the list must be ordered
    [input image, true mask, predicted mask].
    """
    title = ['Input Image', 'True Mask', 'Predicted Mask']
    panel_count = len(display_list)

    plt.figure(figsize=(10, 10))
    for idx, panel in enumerate(display_list):
        plt.subplot(1, panel_count, idx + 1)
        plt.title(title[idx])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(panel))
        plt.axis('off')
    plt.show()

def create_mask(pred_mask):
    """Reduce batched class scores to the class-index mask of the first sample.

    Takes the arg-max over the last axis, appends a trailing axis of size 1
    and returns element 0 along the first axis.
    """
    class_ids = tf.argmax(pred_mask, axis=-1)
    return class_ids[..., tf.newaxis][0]

def show_predictions(dataset=None, num=1):
    """Run the model on the stored sample and show input, true and predicted mask.

    Uses the notebook-level `sample_image` / `sample_mask`; the `dataset` and
    `num` parameters are accepted but currently unused. Prints the output
    shapes plus three scores comparing the prediction with `sample_mask`.
    """
    output_model = model(sample_image[tf.newaxis, ...])
    print(output_model.shape)

    output_mask = create_mask(output_model)
    print(output_mask.shape)

    # The masks are one-hot encoded, so this is the (non-sparse) categorical cross-entropy.
    cce = tf.keras.losses.CategoricalCrossentropy()
    print("CategoricalCrossentropy: " + str(cce(sample_mask, output_model[0]).numpy()))
    print("Iou-Score: " + str(tasm.losses.iou_score(sample_mask, output_model[0]).numpy()))
    print("categorical Focal Dice Loss: " + str(categorical_focal_dice_loss(sample_mask, output_model[0]).numpy()))

    # Decode both masks to their class colours. The previous one_hot(..., 3) only
    # covered three classes and, with two panels, put the prediction under the
    # "True Mask" title.
    display([
        sample_image,
        decode_one_hot(sample_mask, labels_color),
        decode_one_hot(output_model[0], labels_color),
    ])

show_predictions()

In [None]:
# Print the model's summary table.
model.summary()

# Training

#### Training Procedure
##### 1) Train the model with a frozen backbone, on the training data only
##### 2) Train the fully unfrozen model on the training data, with validation data

### 1) Train the model with a frozen backbone, on the training data only

In [None]:
# List the top-level layers; their names are what the freezing cell below matches on.
model.layers

In [None]:
## Freeze the backbone: every layer whose name contains "model" is set to non-trainable
## (presumably the wrapped backbone — verify against the layer list printed above).
for layer in model.layers:
    is_backbone = "model" in layer.name
    if is_backbone:
        layer.trainable = False

    print(f"{layer.name}: {layer.trainable}")

In [None]:
# Number of full batches per pass over the training images. Integer floor
# division keeps this an int (np.floor would return a float), which is what
# `steps_per_epoch` is meant to receive.
steps_per_epoch = n_samples_train // BATCH_SIZE

print(BATCH_SIZE, steps_per_epoch)

In [None]:
# Stage 1: train on the augmented training generator only (no validation data).
history = model.fit(
    TrainSet,
    epochs=15,
    steps_per_epoch=steps_per_epoch,
    callbacks=callbacks,
)

#### Plot Training IoU Scores and Losses

In [None]:
# Training curves of stage 1 (no validation curves: this stage was fitted without validation data).
fig, (iou_ax, loss_ax) = plt.subplots(1, 2, figsize=(30, 5))

# Left panel: training IoU per epoch
iou_ax.plot(history.history['iou_score'])
iou_ax.set_title('Model IOU Score')
iou_ax.set_ylabel('IOU Score')
iou_ax.set_xlabel('Epoch')
iou_ax.legend(['Train'], loc='upper left')

# Right panel: training loss per epoch
loss_ax.plot(history.history['loss'])
loss_ax.set_title('Model loss')
loss_ax.set_ylabel('Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.legend(['Train'], loc='upper left')
plt.show()

### 2) Train the fully unfrozen model on the training data, with validation data

In [None]:
# Make whole model trainable and use validation set
for layer in model.layers:
    layer.trainable = True

    print(layer.name + ": " + str(layer.trainable))

In [None]:
# Re-compile for the second stage so the changed `trainable` flags take effect,
# with a lower learning rate than in stage 1.
opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
metrics = [tasm.metrics.IOUScore(threshold=0.5)]
categorical_focal_dice_loss = tasm.losses.CategoricalFocalLoss(alpha=0.25, gamma=2.0) + tasm.losses.DiceLoss()

model.compile(optimizer=opt, loss=categorical_focal_dice_loss, metrics=metrics)
# NOTE(review): stage 1 ran with run_eagerly = False; eager execution is usually
# slower — confirm that switching it on here is intentional.
model.run_eagerly = True

# This stage has validation data, so the LR schedule and early stopping watch
# the validation IoU.
# NOTE(review): the checkpoint file is named after DeepLabV3plus although the
# model built above is a DANet — consider renaming once no consumer depends on it.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint("DeepLabV3plus.hdf5", verbose=1, save_weights_only=True, save_best_only=True)
reduce_lr_cb = tf.keras.callbacks.ReduceLROnPlateau(monitor="val_iou_score", factor=0.2, patience=6, verbose=1, mode="max")
early_stop_cb = tf.keras.callbacks.EarlyStopping(monitor="val_iou_score", patience=16, mode="max", verbose=1, restore_best_weights=True)

callbacks = [checkpoint_cb, reduce_lr_cb, early_stop_cb]

In [None]:
# Stage 2: fine-tune the whole model on batches holding each image with and
# without augmentation, validating after every epoch.
# NOTE(review): that generator consumes BATCH_SIZE / 2 source images per batch, so
# `steps_per_epoch` covers roughly half of the training images per epoch — confirm intended.
history = model.fit(
    TrainSetwwoAug,
    steps_per_epoch=steps_per_epoch,
    epochs=30,
    callbacks=callbacks,
    validation_data=ValidationSet,
    # ValidationSet yields one image per batch, so one step per validation image.
    # n_samples_valid uses the same file discovery as the generator itself, unlike
    # os.listdir, which would count sub-folders as samples.
    validation_steps=n_samples_valid,
    )

#### Plot Training IoU Scores and Losses

In [None]:
# Training and validation curves of stage 2.
fig, (iou_ax, loss_ax) = plt.subplots(1, 2, figsize=(30, 5))

# Left panel: IoU per epoch
iou_ax.plot(history.history['iou_score'])
iou_ax.plot(history.history['val_iou_score'])
iou_ax.set_title('Model IOU Score')
iou_ax.set_ylabel('IOU Score')
iou_ax.set_xlabel('Epoch')
iou_ax.legend(['Train', 'Validation'], loc='upper left')

# Right panel: loss per epoch
loss_ax.plot(history.history['loss'])
loss_ax.plot(history.history['val_loss'])
loss_ax.set_title('Model loss')
loss_ax.set_ylabel('Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.legend(['Train', 'Validation'], loc='upper left')
plt.show()

# Evaluation on Test Data

In [None]:
# Evaluate on every test image: TestSet yields one image per batch, so the number
# of steps equals the number of test samples (instead of a hard-coded 101).
scores = model.evaluate(TestSet, steps=n_samples_test)

print("Loss: {:.5}".format(scores[0]))
for metric, value in zip(metrics, scores[1:]):
    if not isinstance(metric, str):
        # Metric objects expose their label as `__name__` or `name` depending on the
        # library; fall back to the class name if neither is present.
        metric = (
            getattr(metric, "__name__", None)
            or getattr(metric, "name", None)
            or type(metric).__name__
        )
    print("mean {}: {:.5}".format(metric, value))

## Visual Examples on Test Data

In [None]:
# Take one batch from the test generator so that `image` has the batched,
# resized shape the model expects. On a fresh top-to-bottom run the name `image`
# would otherwise still hold the raw, unbatched picture loaded in the example cell.
image, gt_mask = next(TestSet)
plt.imshow(decode_one_hot(model.predict(image)[0], labels_color))
plt.show()

In [None]:
# Shape of the prediction for the first sample of `image`.
# NOTE(review): relies on `image` being a model-ready batch left behind by an earlier cell.
model.predict(image)[0].shape

In [None]:
# `pr_mask` is otherwise first assigned inside the visualisation loop below, so on
# a fresh top-to-bottom run it would be undefined here. Compute it from one test
# batch before inspecting its shape.
test_batch_images, _ = next(TestSet)
pr_mask = model.predict(test_batch_images)
pr_mask.shape

In [None]:
# Show `n` randomly chosen test samples: input image, ground-truth mask and
# predicted mask, the two masks decoded to their class colours.
n = 5
# Sample from the whole test split (rather than a hard-coded 101) and never ask
# for more examples than there are test images.
ids = np.random.choice(np.arange(n_samples_test), size=min(n, n_samples_test), replace=False)
print(ids)

shown = 0
# TestSet cycles forever, so the loop is ended explicitly once every id was shown.
for counter, (image, gt_mask) in enumerate(TestSet):
    if counter not in ids:
        continue

    pr_probs = model.predict(image)
    pr_mask = np.argmax(pr_probs, axis=-1)

    print(counter)

    visualize(
        image=denormalize(image.numpy().squeeze()),
        # The ground truth is one-hot with one channel per class, which imshow
        # cannot draw directly; decode it to an RGB colour mask first.
        gt_mask=decode_one_hot(gt_mask.numpy().squeeze(), labels_color),
        pr_mask=decode_one_hot(pr_probs.squeeze(), labels_color),
    )
    shown += 1
    if shown == len(ids):
        break