# MoBioFP - Fingertip Enhancement

In [None]:
# Export SM_FRAMEWORK=tf.keras into the kernel's environment.
# NOTE(review): presumably consumed by the `segmentation_models` package, which
# this notebook never imports -- confirm whether this cell is still needed.
%env SM_FRAMEWORK=tf.keras

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from tensorflow.keras.layers import (
    Input,
    Conv2D,
    MaxPooling2D,
    concatenate,
    Conv2DTranspose,
    BatchNormalization,
    Activation,
    Dropout,
)
from tensorflow.keras.models import Model

In [None]:
# Weights file restored into the U-Net with `model.load_weights` further down.
MODEL_CHECKPOINT_PATH = "../models/best-iiitd-unet-arm64.h5"
# NOTE(review): despite the "PREDICTED_MASK_DIR" name, this is the path of the
# single input image that is read with `cv2.imread`, not a directory of masks.
PREDICTED_MASK_DIR_PATH = "../data/raw/iiitd-sample/1_i_1_n_1.jpg"

In [None]:
# Function to create convolutional block
def conv_block(tensor, nfilters, size=3, padding="same", initializer="he_normal"):
    """Apply two successive Conv2D -> BatchNormalization -> ReLU stages.

    Args:
        tensor: Input Keras tensor.
        nfilters: Number of filters used by both convolutions.
        size: Side length of the square convolution kernel.
        padding: Padding mode forwarded to ``Conv2D``.
        initializer: Kernel initializer forwarded to ``Conv2D``.

    Returns:
        The output tensor of the second ReLU activation.
    """
    out = tensor
    # Both stages are configured identically, so build them in a loop; the
    # layers are still created in Conv2D, BatchNormalization, Activation order.
    for _ in range(2):
        out = Conv2D(
            nfilters,
            (size, size),
            padding=padding,
            kernel_initializer=initializer,
        )(out)
        out = BatchNormalization()(out)
        out = Activation("relu")(out)

    return out


# Function to create encoder block
def encoder_block(inputs, n_filters):
    """Run a convolutional block, then downsample its output with 2x2 max pooling.

    Args:
        inputs: Input Keras tensor.
        n_filters: Number of filters passed to ``conv_block``.

    Returns:
        Tuple ``(pooled, features)``: the max-pooled tensor and the un-pooled
        ``conv_block`` output.
    """
    features = conv_block(inputs, n_filters)
    downsampled = MaxPooling2D(pool_size=(2, 2))(features)
    return downsampled, features


# Function to create decoder block
def decoder_block(inputs, conv_output, n_filters):
    """Upsample ``inputs``, merge it with ``conv_output`` and apply a conv block.

    Args:
        inputs: Tensor to upsample with a stride-2 transposed convolution.
        conv_output: Tensor concatenated with the upsampled result on axis 3.
        n_filters: Number of filters for the transposed convolution and for
            the following ``conv_block``.

    Returns:
        The output tensor of the ``conv_block`` applied to the concatenation.
    """
    upsampled = Conv2DTranspose(
        filters=n_filters,
        kernel_size=(3, 3),
        strides=(2, 2),
        padding="same",
    )(inputs)
    merged = concatenate([upsampled, conv_output], axis=3)
    return conv_block(merged, n_filters)


# Function to create U-Net model
def create_unet(input_shape, n_filters=64):
    """Assemble the U-Net model used for fingertip segmentation.

    The encoder stacks four ``encoder_block`` levels with ``n_filters`` times
    1, 2, 4 and 8 filters, followed by a ``conv_block`` with 16 times
    ``n_filters``. The decoder mirrors the encoder with four ``decoder_block``
    levels, each fed the matching encoder output. Dropout (rate 0.5) is applied
    before and after the middle ``conv_block`` and after the first two decoder
    levels. A 1x1 convolution with one filter and sigmoid activation produces
    the output.

    Args:
        input_shape: Shape passed to the ``Input`` layer.
        n_filters: Number of filters of the first encoder level.

    Returns:
        A Keras ``Model`` named ``"Unet"``.
    """
    inputs = Input(shape=input_shape, name="image_input")

    # Encoder: remember each level's un-pooled output for the decoder.
    skips = []
    x = inputs
    for multiplier in (1, 2, 4, 8):
        x, skip = encoder_block(x, n_filters * multiplier)
        skips.append(skip)

    # Middle block, wrapped in dropout on both sides.
    x = Dropout(0.5)(x)
    x = conv_block(x, n_filters * 16)
    x = Dropout(0.5)(x)

    # Decoder: consume the stored encoder outputs deepest-first; only the two
    # deepest decoder levels are followed by dropout.
    for multiplier, with_dropout in ((8, True), (4, True), (2, False), (1, False)):
        x = decoder_block(x, skips.pop(), n_filters * multiplier)
        if with_dropout:
            x = Dropout(0.5)(x)

    outputs = Conv2D(filters=1, kernel_size=(1, 1), activation="sigmoid")(x)

    return Model(inputs=inputs, outputs=outputs, name="Unet")

In [None]:
# Create an instance of the model for inputs of shape (256, 256, 3) with 64
# filters in the first encoder level, then print its layer summary.
model = create_unet(input_shape=(256, 256, 3), n_filters=64)
model.summary()

In [None]:
def find_largest_connected_component(mask: np.ndarray) -> np.ndarray:
    """
    Finds the largest connected component in a binary mask.

    Args:
        mask: Binary mask containing connected components. It is passed
            unchanged to ``cv2.connectedComponentsWithStats``, which expects
            an 8-bit single-channel image.
    Returns:
        uint8 mask with the same height and width as ``mask`` that is 1 on the
        pixels of the largest (by area, 8-connectivity) foreground component
        and 0 elsewhere. If ``mask`` has no foreground pixels, an all-zero
        mask is returned.
    """

    # Use OpenCV's connectedComponentsWithStats to find connected components
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
        mask, connectivity=8
    )

    # Label 0 is the background. With fewer than two labels there is no
    # foreground component, and np.argmax on the empty slice below would raise.
    if num_labels < 2:
        return np.zeros(labels.shape, dtype=np.uint8)

    # Find the index of the largest connected component (excluding background)
    largest_component_index = int(np.argmax(stats[1:, cv2.CC_STAT_AREA])) + 1

    # Create a mask with only the largest connected component
    return (labels == largest_component_index).astype(np.uint8)

In [None]:
# Load image
image = cv2.imread(PREDICTED_MASK_DIR_PATH)

# cv2.imread reports a missing or unreadable file by returning None instead of
# raising, which would otherwise surface as an obscure cvtColor error below.
if image is None:
    raise FileNotFoundError(f"Could not read image: {PREDICTED_MASK_DIR_PATH}")

# OpenCV loads images in BGR channel order; matplotlib expects RGB.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(10, 10))
plt.imshow(image)
plt.axis("off")
plt.show()

In [None]:
# Preprocess image: resize to the 256x256 model input, divide by 255 and add a
# leading batch axis
image_resized = cv2.resize(image, (256, 256)) / 255.0
image_input = np.expand_dims(image_resized, axis=0)

# Load model weights
model.load_weights(MODEL_CHECKPOINT_PATH)

# Predict the mask: threshold the model output at 0.5 to get a 0/1 uint8 mask
pred_mask = (model.predict(image_input) > 0.5).astype(np.uint8).reshape(256, 256)

# Resize the mask to the original image size
fingertip_mask = cv2.resize(pred_mask, (image.shape[1], image.shape[0]))

# Apply morphological operations (closing) to remove small holes in the mask
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (11, 11))
closing_mask = cv2.morphologyEx(fingertip_mask, cv2.MORPH_CLOSE, kernel, iterations=2)

# Apply bilateral filter
d = 15  # Diameter of pixel neighborhood for filtering
sigma_color = 75  # Filter sigma in the color space
sigma_space = 75  # Filter sigma in the coordinate space
bilateral_mask = cv2.bilateralFilter(closing_mask, d, sigma_color, sigma_space)

# Find the largest connected component
lcc_mask = find_largest_connected_component(bilateral_mask)

# Apply dilation (same 11x11 rectangular kernel as the closing step)
mask = cv2.dilate(lcc_mask, kernel, iterations=2)

# Show every intermediate result; each entry is (array, title, colormap).
# The original image is drawn with matplotlib's default colormap handling.
panels = [
    (image, "Original Image", None),
    (fingertip_mask, "Predicted Fingertip Mask", "gray"),
    (closing_mask, "Morphological Closing", "gray"),
    (bilateral_mask, "Bi-Lateral Filter", "gray"),
    (lcc_mask, "Largest Connected Component", "gray"),
    (mask, "Morphological Dilation", "gray"),
]

plt.figure(figsize=(20, 15))
for position, (panel, panel_title, panel_cmap) in enumerate(panels, start=1):
    plt.subplot(3, 3, position)
    plt.imshow(panel, cmap=panel_cmap)
    plt.title(panel_title)
    plt.axis("off")
plt.show()

In [None]:
def extract_roi(mask: np.ndarray) -> tuple[int, int, int, int]:
    """
    Extract ROI from a binary mask.

    Args:
        mask: Binary mask. It is cast to uint8 before contour detection.
    Returns:
        Tuple ``(x, y, w, h)`` with the bounding rectangle of the external
        contour that has the largest area.
    Raises:
        ValueError: If no contour is found in ``mask`` (e.g. it is all zeros).
    """
    cnts, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # Without this guard, max() fails with the unhelpful message
    # "max() arg is an empty sequence" when the mask has no foreground.
    if len(cnts) == 0:
        raise ValueError("Cannot extract ROI: no contours found in mask")

    cnt = max(cnts, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(cnt)

    return (x, y, w, h)

In [None]:
# Bounding box of the finger region found in the post-processed mask
x, y, w, h = extract_roi(mask)

# Cut the same rectangle out of the original image and out of the mask
roi = image[y : y + h, x : x + w]
roi_mask = mask[y : y + h, x : x + w]

# Red, unfilled rectangle marking the ROI on the full image
roi_rect = patches.Rectangle(
    (x, y),
    w,
    h,
    linewidth=1,
    edgecolor="r",
    facecolor="none",
)

fig, axes = plt.subplots(1, 2, figsize=(20, 7))

# Left panel: full image with the ROI outlined
axes[0].set_title("Original Image: Fingertip ROI")
axes[0].imshow(image)
axes[0].add_patch(roi_rect)

# Right panel: the cropped fingertip
axes[1].set_title("Segmented Fingertip")
axes[1].imshow(roi)

plt.tight_layout()
plt.show()

## Fingertip - Gamma Correction

In [None]:
from skimage import img_as_float
from skimage import exposure


def plot_img_and_hist(image, axes, bins=256):
    """Plot an image along with its histogram and cumulative histogram.

    Args:
        image: Image to display; it is converted with ``img_as_float`` first.
        axes: Pair ``(image_axis, histogram_axis)`` of matplotlib axes.
        bins: Number of bins for both the histogram and the cumulative
            distribution.

    Returns:
        Tuple ``(ax_img, ax_hist, ax_cdf)`` where ``ax_cdf`` is a twin of
        ``ax_hist`` created with ``twinx()``.
    """
    float_image = img_as_float(image)
    ax_img, ax_hist = axes
    ax_cdf = ax_hist.twinx()

    # Image panel, drawn without axis decorations
    ax_img.imshow(float_image, cmap=plt.cm.gray)
    ax_img.set_axis_off()

    # Histogram panel: outline-only histogram over the [0, 1] intensity range
    ax_hist.hist(
        float_image.ravel(),
        bins=bins,
        histtype="step",
        color="black",
    )
    ax_hist.ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
    ax_hist.set_xlabel("Pixel intensity")
    ax_hist.set_xlim(0, 1)
    ax_hist.set_yticks([])

    # Cumulative distribution, drawn in red on the twin axis
    cdf_values, bin_centers = exposure.cumulative_distribution(float_image, bins)
    ax_cdf.plot(bin_centers, cdf_values, "r")
    ax_cdf.set_yticks([])

    return ax_img, ax_hist, ax_cdf

### Fingertip - Gamma and Logarithmic Correction

In [None]:
# Grayscale version of the cropped fingertip
fingertip_gray = cv2.cvtColor(roi, cv2.COLOR_RGB2GRAY)

# Two contrast adjustments of the grayscale image: gamma (gamma=2) and
# logarithmic (gain=1)
gamma_corrected = exposure.adjust_gamma(fingertip_gray, gamma=2)
logarithmic_corrected = exposure.adjust_log(fingertip_gray, gain=1)

# Display results on a 2x3 grid: images on the top row, histograms below.
# The top-row image axes share their x/y axes with the first one.
fig = plt.figure(figsize=(15, 7))
axes = np.zeros((2, 3), dtype=object)
axes[0, 0] = plt.subplot(2, 3, 1)
for col in (1, 2):
    axes[0, col] = plt.subplot(2, 3, col + 1, sharex=axes[0, 0], sharey=axes[0, 0])
for col in range(3):
    axes[1, col] = plt.subplot(2, 3, col + 4)

# Left column: uncorrected grayscale image; only this histogram gets a
# labelled y-axis with ticks
ax_img, ax_hist, ax_cdf = plot_img_and_hist(fingertip_gray, axes[:, 0])
ax_img.set_title("Grayscale Fingertip")
ax_hist.set_ylabel("Number of pixels")
y_min, y_max = ax_hist.get_ylim()
ax_hist.set_yticks(np.linspace(0, y_max, 5))

# Middle column: gamma-corrected image
ax_img, ax_hist, ax_cdf = plot_img_and_hist(gamma_corrected, axes[:, 1])
ax_img.set_title("Gamma correction")

# Right column: log-corrected image; only this CDF axis gets a label and ticks
ax_img, ax_hist, ax_cdf = plot_img_and_hist(logarithmic_corrected, axes[:, 2])
ax_img.set_title("Logarithmic correction")
ax_cdf.set_ylabel("Fraction of total intensity")
ax_cdf.set_yticks(np.linspace(0, 1, 5))

# prevent overlap of y-axis labels
fig.tight_layout()
plt.show()

### Fingertip - Enhancement

In [None]:
# 1. Median blur (5x5) on the log-corrected image to remove speckle noise
median = cv2.medianBlur(logarithmic_corrected, 5)

# 2. Adaptive histogram equalization (CLAHE) to mitigate the effect of
#    illumination variations
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
equalized = clahe.apply(median)

# 3. Sharpen by subtracting a Gaussian-blurred copy (7x7, sigma=2):
#    sharpened = 1.5 * equalized - 0.5 * gaussian
gaussian = cv2.GaussianBlur(equalized, (7, 7), 2)
sharpened = cv2.addWeighted(equalized, 1.5, gaussian, -0.5, 0)

# Stages to display, in processing order, with their matching titles
enhancements = [logarithmic_corrected, median, equalized, gaussian, sharpened]
enhancements_titles = [
    "Original Image",
    "Median Blur (kernel=5x5)",
    "Adaptive Histogram Equalization",
    "Gaussian Blur (kernel=7x7, sigma=2)",
    "Sharpened",
]

plt.figure(figsize=(15, 20))

# One row per stage: image on the left, normalized histogram on the right
n_rows = len(enhancements)
for row, (stage, stage_title) in enumerate(zip(enhancements, enhancements_titles)):
    plt.subplot(n_rows, 2, 2 * row + 1)
    plt.imshow(stage, cmap="gray")
    plt.title(stage_title)

    plt.subplot(n_rows, 2, 2 * row + 2)
    hist = cv2.calcHist([stage], [0], None, [256], [0, 256])
    hist /= hist.sum()  # counts -> relative frequencies
    plt.plot(hist)
    plt.xlim([0, 256])
    plt.xlabel("Pixel Value")
    plt.ylabel("Frequency")

plt.tight_layout()
plt.show()