# MoBioFP - Fingertip Segmentation using U-Net

## Environment Variables

In [None]:
%env SM_FRAMEWORK=tf.keras

## Imported libraries

In [None]:
import os
import cv2
import random
import albumentations as A
import numpy as np
import matplotlib.pyplot as plt
import platform

# TODO: this should be removed as already part of tensorflow.keras
import keras

# TODO: this should be removed as already part of tensorflow.keras
import segmentation_models as sm

from tensorflow.keras import backend as K
from tensorflow.keras.layers import (
    Input,
    Conv2D,
    MaxPooling2D,
    concatenate,
    Conv2DTranspose,
    BatchNormalization,
    Activation,
    Dropout,
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.utils import Sequence
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from sklearn.model_selection import train_test_split
from scipy import ndimage

## Functions used

### Utility functions

In [None]:
def visualize(**images):
    """
    Plots the given images side by side in a single row.

    Parameters:
    **images: Keyword arguments mapping a panel name to the image shown in that panel.
        Underscores in the name are replaced by spaces and the result is title-cased
        to build the panel title.
    """
    total = len(images)
    plt.figure(figsize=(16, 5))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, total, position)
        # Hide the axis ticks: only the image content matters here.
        plt.xticks([])
        plt.yticks([])
        heading = label.replace("_", " ").title()
        plt.title(heading)
        plt.imshow(picture)
    plt.show()


def make_prediction(model, image, shape):
    """
    Makes a prediction using a trained model on an image.

    Parameters:
    model (keras.Model): The trained model to use for prediction.
    image (str): The path to the image file to predict on.
    shape (tuple): The target size to resize the image to before prediction,
        e.g. (256, 256, 3).

    Returns:
    np.array: A 2D integer array of 0/1 values holding the predicted binary mask.
        Its height and width are those of the model output.
    """
    img = img_to_array(load_img(image, target_size=shape))

    # Add the batch axis and scale the pixel values from [0, 255] to [0, 1].
    batch = np.expand_dims(img, axis=0) / 255.0
    probabilities = model.predict(batch)

    # Binarize the first (and only) prediction of the batch. Thresholding the raw
    # output directly gives the same result as rounding first and then thresholding.
    mask = (probabilities[0] > 0.5) * 1

    # Collapse the single-channel prediction to 2D, using the spatial size of the
    # prediction itself instead of a hard-coded (256, 256).
    mask = np.reshape(mask, mask.shape[:2])

    return mask


# Get image prediction, merge it with the original image and display or save it
def mask_and_segmented_image(
    image, save=False, output_path=None, segmentation_model=None
):
    """
    Segments an image with the predicted mask and either saves or displays the result.

    Pixels outside the predicted mask are set to black; pixels inside keep their
    original colour.

    Parameters:
    image (str): The path to the image file to segment.
    save (bool, optional): If True, write the segmented image to `output_path`;
        otherwise draw it with matplotlib. Defaults to False.
    output_path (str, optional): The directory to write to. Required when `save` is True.
        The file keeps the input's base name, with ".jpg" replaced by ".png".
    segmentation_model (keras.Model, optional): The model used for the prediction.
        Defaults to the notebook-level `model` variable.

    Raises:
    ValueError: If `save` is True but no `output_path` is given.
    """
    if save and output_path is None:
        raise ValueError("output_path must be provided when save=True")

    if segmentation_model is None:
        # Backward-compatible fallback: use the model defined at notebook level.
        segmentation_model = model

    # Load the image at its original resolution
    img = img_to_array(load_img(image))
    h, w = img.shape[:2]

    # Predict at the network resolution, then bring the mask back to the image size
    mask = make_prediction(segmentation_model, image, (256, 256, 3))
    mask_resized = cv2.resize(np.uint8(mask), (w, h)) != 0

    # Zero out every pixel outside the mask (the mask is broadcast over the channels)
    segment = img * mask_resized[:, :, np.newaxis]

    if save:
        file_name = os.path.basename(image).replace(".jpg", ".png")
        image_output_path = os.path.join(output_path, file_name)
        # OpenCV expects BGR channel order; the values are whole numbers in [0, 255],
        # so the cast to 8-bit is lossless.
        cv2.imwrite(
            image_output_path,
            cv2.cvtColor(segment.astype("uint8"), cv2.COLOR_RGB2BGR),
        )
    else:
        plt.imshow(segment / 255.0)

### Functions for dataset generation, augmentation and metrics

In [None]:
# Class for generating dataset
class DataGenerator(Sequence):
    """
    Keras Sequence that loads image/mask pairs from disk one batch at a time.

    Parameters:
    images (list): File names of the input images, relative to `image_dir`.
    image_dir (str): Directory containing the input images.
    labels (list): File names of the masks, relative to `label_dir`. Entry i is the
        mask of `images[i]`.
    label_dir (str): Directory containing the masks.
    augmentation (callable, optional): Called as `augmentation(image=..., mask=...)`;
        must return a mapping with "image" and "mask" keys.
    preprocessing (callable, optional): Same calling convention as `augmentation`;
        applied after it.
    batch_size (int, optional): Number of samples per batch. Defaults to 8.
    dim (tuple, optional): Size the images and masks are resized to on load.
        Defaults to (256, 256, 3).
    shuffle (bool, optional): Whether to reshuffle the sample order at creation and
        after every epoch. Defaults to True.
    """

    def __init__(
        self,
        images,
        image_dir,
        labels,
        label_dir,
        augmentation=None,
        preprocessing=None,
        batch_size=8,
        dim=(256, 256, 3),
        shuffle=True,
    ):
        "Initialization"
        # Let the Keras base class set up its own state (recent Keras versions warn
        # when a Sequence subclass skips this call).
        super().__init__()
        self.dim = dim
        self.images = images
        self.image_dir = image_dir
        self.labels = labels
        self.label_dir = label_dir
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Build the initial (possibly shuffled) sample order.
        self.on_epoch_end()

    def __len__(self):
        "Denotes the number of batches per epoch"
        # Only full batches are counted; a trailing partial batch is dropped.
        return len(self.images) // self.batch_size

    def __getitem__(self, index):
        "Generate one batch of data"
        start = index * self.batch_size
        batch_indexes = self.indexes[start : start + self.batch_size]

        return self.__data_generation(batch_indexes)

    def on_epoch_end(self):
        "Updates indexes after each epoch"
        self.indexes = np.arange(len(self.images))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        "Generates data containing batch_size samples"
        # Only height and width are relevant when resizing on load.
        target_size = self.dim[:2]

        batch_imgs = []
        batch_labels = []

        for i in list_IDs_temp:
            # Image: load, resize and scale the pixel values to [0, 1]
            img_path = os.path.join(self.image_dir, self.images[i])
            img = img_to_array(load_img(img_path, target_size=target_size)) / 255.0

            # Mask: any non-zero value in the first channel counts as foreground
            label_path = os.path.join(self.label_dir, self.labels[i])
            label = img_to_array(load_img(label_path, target_size=target_size))
            label = label[:, :, 0] != 0
            # Erode twice, then dilate twice (a morphological opening)
            label = ndimage.binary_erosion(ndimage.binary_erosion(label))
            label = ndimage.binary_dilation(ndimage.binary_dilation(label))
            # Convert the boolean mask to 0/1 integers with an explicit channel axis
            label = np.expand_dims(label * 1, axis=2)

            # apply augmentations
            if self.augmentation:
                augmented = self.augmentation(image=img, mask=label)
                img, label = augmented["image"], augmented["mask"]

            # apply preprocessing
            if self.preprocessing:
                preprocessed = self.preprocessing(image=img, mask=label)
                img, label = preprocessed["image"], preprocessed["mask"]

            batch_imgs.append(img)
            batch_labels.append(label)

        return np.array(batch_imgs, dtype=np.float32), np.array(
            batch_labels, dtype=np.float32
        )


# Some augmentations
def round_clip_0_1(x, **kwargs):
    """
    Snaps every value of the input to the nearest integer, then limits it to [0, 1].

    Parameters:
    x (np.array): The array to round and clip.
    **kwargs: Accepted and ignored, so that callers may pass extra keyword arguments
        (the function is used as an albumentations Lambda callback).

    Returns:
    np.array: An array of the same shape as `x` whose values are rounded and clipped.
    """
    rounded = np.round(x)
    return np.clip(rounded, 0, 1)


def get_training_augmentation():
    """
    Builds the augmentation pipeline applied to training samples.

    The pipeline pads to at least 256x256, then applies one randomly chosen transform
    from each of four groups (flips, geometric, resolution, visual), and finally
    re-binarizes the mask.

    Returns:
    albumentations.Compose: The augmentation pipeline.
    """
    # Padding is always applied so every sample is at least 256x256
    pad = A.PadIfNeeded(min_height=256, min_width=256, always_apply=True, border_mode=0)

    # Flip augmentations
    flips = A.OneOf(
        [A.HorizontalFlip(p=1), A.VerticalFlip(p=1), A.Transpose(p=1)],
        p=0.9,
    )

    # Geometric augmentations
    shift_scale_rotate = A.ShiftScaleRotate(
        scale_limit=0.3,
        rotate_limit=45,
        shift_limit=0.2,
        border_mode=0,
        p=1,
    )
    geometric = A.OneOf([shift_scale_rotate, A.Perspective(p=1)], p=0.9)

    # Resolution augmentation
    resolution = A.OneOf(
        [
            A.Sharpen(p=1),
            A.Blur(blur_limit=3, p=1),
            A.MotionBlur(blur_limit=3, p=1),
        ],
        p=0.9,
    )

    # Visual alterations
    visual = A.OneOf(
        [
            # A.GaussNoise(var_limit=(0.0, 0.01), p=1),
            A.HueSaturationValue(
                hue_shift_limit=1, sat_shift_limit=0.2, val_shift_limit=0.5, p=1
            ),
            A.RandomBrightnessContrast(p=1),
            # A.RandomGamma(p=1),
        ],
        p=0.5,
    )

    # Round and clip the mask back to {0, 1} after the transforms above
    binarize_mask = A.Lambda(mask=round_clip_0_1)

    return A.Compose([pad, flips, geometric, resolution, visual, binarize_mask])


def get_validation_augmentation():
    """
    Builds the pipeline applied to validation samples.

    Validation data is only padded, never randomly transformed.

    Returns:
    albumentations.Compose: The augmentation pipeline.
    """
    # Pad up to 256x256 so the image shape is divisible by 32
    pad_to_minimum = A.PadIfNeeded(256, 256)

    return A.Compose([pad_to_minimum])


def jaccard_distance_loss(y_true, y_pred, smooth=100):
    """
    Computes a smoothed Jaccard distance between the true and predicted labels.

    The sums run over the last axis only, so the result keeps every other axis of
    the inputs.

    Parameters:
    y_true (tf.Tensor): The true labels.
    y_pred (tf.Tensor): The predicted labels.
    smooth (int, optional): Added to numerator and denominator to avoid division by
        zero, and also used to scale the loss. Defaults to 100.

    Returns:
    tf.Tensor: The Jaccard distance loss, `(1 - jaccard) * smooth`.
    """
    overlap = K.sum(K.abs(y_true * y_pred), axis=-1)
    magnitude = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)

    # |A ∪ B| = |A| + |B| - |A ∩ B|
    union = magnitude - overlap
    jaccard = (overlap + smooth) / (union + smooth)

    return (1 - jaccard) * smooth


def dice_coef(y_true, y_pred):
    """
    Computes the Dice coefficient between the true and predicted labels.

    Both tensors are flattened first, so a single scalar is computed over all of
    their elements. The backend epsilon is added to numerator and denominator to
    avoid division by zero.

    Parameters:
    y_true (tf.Tensor): The true labels.
    y_pred (tf.Tensor): The predicted labels.

    Returns:
    tf.Tensor: The Dice coefficient.
    """
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)

    eps = K.epsilon()
    numerator = 2.0 * overlap + eps
    denominator = K.sum(truth) + K.sum(prediction) + eps

    return numerator / denominator

## Main body

### Global Constants

In [None]:
# NOTE:
#
# Download the dataset and the model checkpoint from Google Drive and place them at
# the following locations (relative paths):
#
# IMAGE_DIR_PATH = "../data/raw/iiitd-unet/images"
# MASKS_DIR_PATH = "../data/raw/iiitd-unet/masks"
# MODEL_CHECKPOINT_PATH = "../models/best-iiitd-unet.h5"
#
# This is a temporary solution and will be replaced by a better one in the future.

# Directory listed to obtain the input image file names.
IMAGE_DIR_PATH = "../data/raw/iiitd-unet/images"
# Directory listed to obtain the mask file names.
MASKS_DIR_PATH = "../data/raw/iiitd-unet/masks"
# File the training checkpoint is written to and the best weights are loaded from.
MODEL_CHECKPOINT_PATH = "../models/best-iiitd-unet.h5"

### Reading images and create train-validation split

In [None]:
# List both directories in sorted order; the image and the mask found at the same
# position are assumed to belong together.
imgs_paths = sorted(os.listdir(IMAGE_DIR_PATH))
masks_paths = sorted(os.listdir(MASKS_DIR_PATH))

# Check that the dataset is loaded correctly *before* splitting. A check after the
# split can never fail, because train_test_split itself rejects mismatched lengths.
if len(imgs_paths) != len(masks_paths):
    raise ValueError(
        f"Found {len(imgs_paths)} images but {len(masks_paths)} masks; "
        "every image needs exactly one mask."
    )

# Fixed random_state keeps the 85/15 split reproducible across runs.
train_imgs, val_imgs, train_masks, val_masks = train_test_split(
    imgs_paths, masks_paths, test_size=0.15, random_state=42
)

In [None]:
# Report the size of each split as a quick sanity check.
print(f"Number of training images: {len(train_imgs)}")
print(f"Number of validation images: {len(val_imgs)}")

### Dataset generation

In [None]:
# Create train and validation generator
train_generator = DataGenerator(
    train_imgs,
    IMAGE_DIR_PATH,
    train_masks,
    MASKS_DIR_PATH,
    augmentation=get_training_augmentation(),
    preprocessing=None,
    batch_size=8,
    dim=(256, 256, 3),
    shuffle=True,
)

# Number of full batches per training epoch
train_steps = len(train_generator)

# Validation data is not shuffled: the generator only yields full batches, so
# reshuffling would evaluate a different subset of samples every epoch and make the
# validation metrics incomparable between epochs.
val_generator = DataGenerator(
    val_imgs,
    IMAGE_DIR_PATH,
    val_masks,
    MASKS_DIR_PATH,
    augmentation=get_validation_augmentation(),
    preprocessing=None,
    batch_size=8,
    dim=(256, 256, 3),
    shuffle=False,
)

# Number of full batches per validation pass
val_steps = len(val_generator)

### Show random samples

In [None]:
# Fetch the second training batch (index 1) through the normal indexing protocol
X, y = train_generator[1]

# Pick a random position inside the batch; using the actual batch length avoids
# hard-coding the batch size.
# TODO: Rename in case random.sample is used
sample = random.randrange(len(X))

visualize(image=X[sample], label=y[sample])

## UNet model

In [None]:
def conv_block(tensor, nfilters, size=3, padding="same", initializer="he_normal"):
    """
    Applies two consecutive Conv2D -> BatchNormalization -> ReLU stages.

    Parameters:
    tensor (tf.Tensor): The input tensor.
    nfilters (int): The number of filters of each Conv2D layer.
    size (int, optional): The side of the square convolution kernel. Defaults to 3.
    padding (str, optional): The padding mode of the Conv2D layers. Defaults to "same".
    initializer (str, optional): The kernel initializer of the Conv2D layers.
        Defaults to "he_normal".

    Returns:
    tf.Tensor: The output of the second ReLU activation.
    """
    x = tensor
    for _ in range(2):
        x = Conv2D(
            filters=nfilters,
            kernel_size=(size, size),
            padding=padding,
            kernel_initializer=initializer,
        )(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
    return x


def deconv_block(tensor, residual, nfilters, size=3, padding="same", strides=(2, 2)):
    """
    Upsamples with a Conv2DTranspose layer, merges the skip connection and refines
    the result with a convolutional block.

    Parameters:
    tensor (tf.Tensor): The input tensor to upsample.
    residual (tf.Tensor): The skip-connection tensor concatenated (on axis 3) with
        the upsampled tensor.
    nfilters (int): The number of filters of the Conv2DTranspose layer and of the
        following convolutional block.
    size (int, optional): The side of the square transposed-convolution kernel.
        Defaults to 3.
    padding (str, optional): The padding mode of the Conv2DTranspose layer.
        Defaults to "same".
    strides (tuple, optional): The strides of the Conv2DTranspose layer.
        Defaults to (2, 2).

    Returns:
    tf.Tensor: The output of the convolutional block.
    """
    upsample = Conv2DTranspose(
        nfilters, kernel_size=(size, size), strides=strides, padding=padding
    )
    upsampled = upsample(tensor)
    merged = concatenate([upsampled, residual], axis=3)
    return conv_block(merged, nfilters)


def Unet(h, w, filters):
    """
    Builds the U-Net model.

    Parameters:
    h (int): The height of the input images.
    w (int): The width of the input images.
    filters (int): The number of filters of the first convolutional block. The encoder
        uses 1x, 2x, 4x and 8x this number, the bottleneck 16x, and the decoder
        8x, 4x, 2x and 1x.

    Returns:
    keras.Model: The U-Net model, mapping a 3-channel input to a 1-channel sigmoid output.
    """
    input_layer = Input(shape=(h, w, 3), name="image_input")

    # down: four conv blocks, each followed by 2x2 max pooling; the un-pooled
    # outputs are kept as skip connections for the decoder
    skips = []
    x = input_layer
    for multiplier in (1, 2, 4, 8):
        features = conv_block(x, nfilters=filters * multiplier)
        skips.append(features)
        x = MaxPooling2D(pool_size=(2, 2))(features)
    skip1, skip2, skip3, skip4 = skips

    # bottleneck, with dropout on both sides
    x = Dropout(0.5)(x)
    x = conv_block(x, nfilters=filters * 16)
    x = Dropout(0.5)(x)

    # up: dropout follows only the two deepest decoder stages
    x = deconv_block(x, residual=skip4, nfilters=filters * 8)
    x = Dropout(0.5)(x)
    x = deconv_block(x, residual=skip3, nfilters=filters * 4)
    x = Dropout(0.5)(x)
    x = deconv_block(x, residual=skip2, nfilters=filters * 2)
    x = deconv_block(x, residual=skip1, nfilters=filters)

    # using sigmoid activation for binary classification
    output_layer = Conv2D(filters=1, kernel_size=(1, 1), activation="sigmoid")(x)

    return Model(inputs=input_layer, outputs=output_layer, name="Unet")

In [None]:
# Build a U-Net for 256x256 inputs with 64 filters in the first block,
# then print its layer overview
model = Unet(h=256, w=256, filters=64)
model.summary()

### Compile the model and create callbacks

In [None]:
# Set the optimizer and its parameters
if platform.system() == "Darwin":
    # Use the legacy Adam optimizer on M1/M2 Macs
    optim = keras.optimizers.legacy.Adam(learning_rate=0.0001)
else:
    # Use the new Adam optimizer on other platforms
    optim = Adam(learning_rate=0.0001)

# Compile the model, create checkpoints and define an early stopping
model.compile(
    optimizer=optim,
    loss=jaccard_distance_loss,
    metrics=[
        dice_coef,
        sm.metrics.IOUScore(threshold=0.5),
        sm.metrics.FScore(threshold=0.5),
        "accuracy",
    ],
)

# NOTE(review): `mc` and `es` are created here but are not part of the `callbacks`
# list below, so they have no effect unless they are added to it.
mc = ModelCheckpoint(
    mode="max",
    filepath="top-weights.h5",
    monitor="val_dice_coef",
    # Real booleans: the strings "True" only worked because any non-empty string
    # is truthy.
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)
es = EarlyStopping(mode="max", monitor="val_dice_coef", patience=3, verbose=1)

# Callbacks actually used for training
callbacks = [
    # K.callbacks.LearningRateScheduler(scheduler),
    # Keep only the weights with the lowest validation loss ("val_loss" is the
    # Keras default monitor, spelled out here for clarity).
    keras.callbacks.ModelCheckpoint(
        MODEL_CHECKPOINT_PATH,
        monitor="val_loss",
        save_weights_only=True,
        save_best_only=True,
        mode="min",
    ),
    keras.callbacks.ReduceLROnPlateau(),
]

### Train the model

In [None]:
# Train for up to 100 epochs, validating on the validation generator after each
# one; `results.history` holds the per-epoch metrics.
results = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=100,
    steps_per_epoch=train_steps,
    validation_steps=val_steps,
    callbacks=callbacks,
)

In [None]:
def _plot_history_pair(history, panels):
    """
    Draws two train/validation curves side by side in one wide figure.

    Parameters:
    history (dict): Per-epoch values keyed by metric name; the validation values of
        metric "x" are read from the key "val_x".
    panels (list): Two (history key, plot title, y-axis label) tuples, for the left
        and the right panel.
    """
    plt.figure(figsize=(30, 5))
    for position, (key, title, ylabel) in zip((121, 122), panels):
        plt.subplot(position)
        plt.plot(history[key])
        plt.plot(history[f"val_{key}"])
        plt.title(title)
        plt.ylabel(ylabel)
        plt.xlabel("Epoch")
        plt.legend(["Train", "Validation"], loc="upper left")
    plt.show()


# Plot training & validation Jaccard loss and Dice coeff.
_plot_history_pair(
    results.history,
    [
        ("loss", "Model Jaccard loss", "Loss"),
        ("dice_coef", "Model dice coef", "Dice coefficient"),
    ],
)

# Plot training & validation IoU and F1 score
_plot_history_pair(
    results.history,
    [
        ("iou_score", "Model IoU score", "IoU"),
        ("f1-score", "Model F1 score", "F1 score"),
    ],
)

## Show results

### Load best model weights

In [None]:
# Restore the checkpointed weights, then score the model on the validation set
model.load_weights(MODEL_CHECKPOINT_PATH)
scores = model.evaluate(val_generator)

# The first entry is reported as the loss
print(f"Loss: {scores[0]:.5}")

### Test predict on sample

In [None]:
# NOTE(review): despite its name, this constant holds the path of a single sample
# image (not a directory of predicted masks); it is the image that gets segmented.
PREDICTED_MASK_DIR_PATH = "../data/raw/iiitd-sample/1_i_1_n_1.jpg"

# Called with the defaults (save=False), so the segmented image is drawn with
# matplotlib instead of being written to disk.
mask_and_segmented_image(PREDICTED_MASK_DIR_PATH)