In [None]:
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from mobiofp.iqa import estimate_noise, laplacian_sharpness, rms_contrast, gradient_magnitude

%matplotlib inline

In [None]:
# Input locations. The directory/file values below look like placeholders — set them to
# real paths before running the notebook.
RAW_IMAGE_DIR = "RAW_IMAGE_DIR"  # raw input images, read in grayscale by read_images()
MASK_IMAGE_DIR = "MASK_IMAGE_DIR"  # masks; file names are the image names with .jpg -> .png

# IQA report CSV, generated by running `fpctl quality report INPUT_DIR OUTPUT_DIR`.
IQA_FILE = "IQA_FILE"

# Model and feature-range files passed to cv2.quality.QualityBRISQUE_create() in compute_metrics().
BRISQUE_MODEL = "../models/brisque_model_live.yml"
BRISQUE_RANGE = "../models/brisque_range_live.yml"

In [None]:
def enh1(image, diameter=7, sigma_color=100, sigma_space=100, clip_limit=2.0, tile_grid_size=8):
    """Enhancement 1: bilateral smoothing followed by CLAHE.

    Args:
        image: grayscale image accepted by both cv2.bilateralFilter and CLAHE.apply
            (presumably the 8-bit images returned by read_images — confirm for other callers).
        diameter: pixel-neighbourhood diameter for the bilateral filter.
        sigma_color: bilateral filter sigma in intensity space.
        sigma_space: bilateral filter sigma in coordinate space.
        clip_limit: CLAHE contrast-limiting threshold.
        tile_grid_size: CLAHE uses a square tile_grid_size x tile_grid_size grid.

    Returns:
        The smoothed, contrast-equalised image.
    """
    equalizer = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(tile_grid_size,) * 2)
    smoothed = cv2.bilateralFilter(
        image, d=diameter, sigmaColor=sigma_color, sigmaSpace=sigma_space
    )
    return equalizer.apply(smoothed)


def enh2(
    image,
    clip_limit=2.0,
    tile_grid_size=8,
    median_ksize=5,
    sharpen_sigma=2,
    sharpen_amount=0.5,
):
    """Enhancement variant: median denoising, CLAHE, then unsharp-mask sharpening.

    The defaults reproduce the previously hard-coded behaviour (median kernel 5, Gaussian
    sigma 2, weights 1.5 / -0.5).

    Args:
        image: grayscale image accepted by cv2.medianBlur and CLAHE.apply
            (presumably 8-bit, as returned by read_images — confirm for other callers).
        clip_limit: CLAHE contrast-limiting threshold.
        tile_grid_size: CLAHE uses a square tile_grid_size x tile_grid_size grid.
        median_ksize: aperture of the median filter; OpenCV requires an odd value > 1.
        sharpen_sigma: sigma of the Gaussian blur used to build the unsharp mask.
        sharpen_amount: sharpening strength `a` in `out = img + a * (img - blur(img))`.

    Returns:
        The enhanced image (cv2.addWeighted saturates to the input depth).
    """
    result = cv2.medianBlur(image, median_ksize)
    result = cv2.createCLAHE(
        clipLimit=clip_limit, tileGridSize=(tile_grid_size, tile_grid_size)
    ).apply(result)

    # Unsharp masking: (1 + a) * img - a * blur(img) == img + a * (img - blur(img)).
    # ksize=(0, 0) lets OpenCV derive the kernel size from sigma.
    blurred = cv2.GaussianBlur(result, (0, 0), sharpen_sigma)
    return cv2.addWeighted(result, 1 + sharpen_amount, blurred, -sharpen_amount, 0)


def plot_dataframe(df, plot_func, title, suptitle, **kwargs):
    """Draw one `plot_func` panel per metric column of `df`, side by side.

    The first column is skipped; it is expected to be the "Image name" identifier column
    (as produced by compute_metrics and presumably by the IQA CSV).

    Args:
        df: DataFrame whose columns after the first are plotted.
        plot_func: seaborn-style function called as `plot_func(series, ax=ax, **kwargs)`.
        title: per-panel title prefix ("<title> of <column>").
        suptitle: figure-level title.
        **kwargs: forwarded to `plot_func` (e.g. kde=True, showfliers=False).
    """
    columns = df.columns[1:]
    if len(columns) == 0:
        return

    # One axis per metric column (previously fixed at 4, silently dropping extra columns).
    # squeeze=False keeps `axes` an array even when there is a single column.
    fig, axes = plt.subplots(1, len(columns), figsize=(5 * len(columns), 5), squeeze=False)
    for col, ax in zip(columns, axes.flat):
        plot_func(df[col], ax=ax, **kwargs)
        ax.set_title(f"{title} of {col}")

    fig.suptitle(suptitle)
    fig.tight_layout()
    plt.show()


def plot_correlation_heatmap(df):
    """Show the lower triangle of the pairwise correlation matrix of the metric columns.

    The "Image name" column is dropped before computing `DataFrame.corr()`; the remaining
    columns are expected to be numeric.
    """
    plt.figure(figsize=(10, 7))
    metrics = df.drop(columns=["Image name"])
    correlations = metrics.corr()

    # np.triu with k=0 masks the upper triangle *and* the diagonal (always 1), leaving
    # only the non-redundant lower-triangle cells visible.
    hidden = np.triu(np.ones_like(correlations))

    axes = sns.heatmap(correlations, annot=True, mask=hidden, vmin=-1, vmax=1, cmap="BrBG")
    axes.set_title("Correlation Heatmap", fontdict={"fontsize": 12}, pad=12)
    plt.show()


def remove_outliers_iqr(df, factor=1.5):
    """Drop rows that fall outside Tukey's IQR fences in ANY metric column.

    For every column after the first (the "Image name" identifier), a row is kept only if
    Q1 - factor * IQR <= value <= Q3 + factor * IQR. Quartiles are computed on the input
    frame, so the result does not depend on column order. NaN values fail the check.

    Args:
        df: DataFrame whose first column is an identifier and the rest are numeric metrics.
        factor: fence width in IQR units (1.5 is Tukey's conventional value).

    Returns:
        The filtered DataFrame (original index preserved). If there are no metric columns,
        all rows are returned.
    """
    keep = pd.Series(True, index=df.index)
    for col in df.columns[1:]:
        q1 = df[col].quantile(0.25)
        q3 = df[col].quantile(0.75)
        iqr = q3 - q1
        # Accumulate across columns; previously the function returned inside the loop and
        # only ever filtered the first metric column.
        keep &= df[col].between(q1 - factor * iqr, q3 + factor * iqr)

    return df[keep]


def filter_top_images(df, column, n=10, largest=True):
    """Return the rows of `df` whose image is among the `n` best (or worst) by `column`.

    Rows are selected by "Image name" membership, so the result keeps the original row
    order of `df` (not the ranking order).

    Args:
        df: DataFrame with an "Image name" column.
        column: column to rank by.
        n: number of images to keep.
        largest: rank by largest values if True, smallest otherwise.
    """
    rank = df.nlargest if largest else df.nsmallest
    chosen_names = rank(n, column)["Image name"]
    return df.loc[df["Image name"].isin(chosen_names)]


def read_images(src_dir, image_names):
    """Load images from `src_dir` in grayscale.

    Args:
        src_dir: directory containing the files.
        image_names: iterable of file names relative to `src_dir` (e.g. a DataFrame column).

    Returns:
        (images, names): the loaded single-channel 8-bit images and the corresponding names,
        both in input order.

    Raises:
        FileNotFoundError: if OpenCV cannot read a file. cv2.imread reports failure by
            returning None rather than raising; without this check a missing mask would later
            reach cv2.bitwise_and(..., mask=None), which silently applies no mask at all.
    """
    images = []
    names = []
    for image_name in image_names:
        path = f"{src_dir}/{image_name}"
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        images.append(image)
        names.append(image_name)

    return images, names


def plot_images(images, titles=None, rows=1, cols=None, sup_title=None, show_axis=False):
    """Show `images` on a rows x cols grid.

    Args:
        images: sequence of image arrays; 2-D arrays are drawn with the "gray" colormap.
        titles: per-image titles; defaults to "Image 1", "Image 2", ...
        rows: number of grid rows.
        cols: number of grid columns; defaults to ceil(len(images) / rows).
        sup_title: optional figure-level title.
        show_axis: draw axes (ticks/frame) on image panels when True.

    Images beyond rows * cols are not shown; surplus grid cells are left blank.
    Nothing is drawn when `images` is empty.
    """
    num_images = len(images)
    if num_images == 0:
        return

    if titles is None:
        titles = ["Image {}".format(i + 1) for i in range(num_images)]

    if cols is None:
        cols = num_images // rows + (1 if num_images % rows else 0)

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so `.flat` always works.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i < num_images:
            # Index `images` only inside the guard: the grid may have more cells than images.
            image = images[i]
            ax.imshow(image, cmap="gray" if image.ndim == 2 else None)
            ax.set_title(titles[i])
            ax.axis("on" if show_axis else "off")
        else:
            # Hide the unused trailing cells.
            ax.axis("off")

    if sup_title:
        fig.suptitle(sup_title)

    fig.tight_layout()
    plt.show()


def compute_metrics(images, images_name):
    """Compute no-reference image-quality metrics for each image.

    Args:
        images: sequence of images (the grayscale images produced in this notebook).
        images_name: names aligned with `images`; pairs are formed with zip(), so extra
            entries in the longer sequence are ignored.

    Returns:
        DataFrame with columns "Image name", "Laplacian", "Noise", "Contrast", "BRISQUE"
        and one row per image (RangeIndex). Rows are collected first and the frame is built
        once, instead of growing it with `df.loc[i] = ...` on every iteration.
    """
    brisquer = cv2.quality.QualityBRISQUE_create(BRISQUE_MODEL, BRISQUE_RANGE)

    records = [
        (
            image_name,
            laplacian_sharpness(image),
            estimate_noise(image),
            rms_contrast(image),
            # compute() returns a 4-element Scalar; the BRISQUE score is the first entry.
            brisquer.compute(image)[0],
        )
        for image, image_name in zip(images, images_name)
    ]

    return pd.DataFrame(
        records, columns=["Image name", "Laplacian", "Noise", "Contrast", "BRISQUE"]
    )

In [None]:
# Load the precomputed IQA report. Later cells expect an "Image name" column first,
# followed by numeric metric columns (including "BRISQUE").
df = pd.read_csv(IQA_FILE)
df.head()

In [None]:
# Summary statistics of the numeric columns of the raw IQA report.
df.describe()

In [None]:
# Per-metric distributions (histogram + KDE, then boxplots) and pairwise correlations
# of the full, unfiltered IQA report.
plot_dataframe(df, sns.histplot, title="Histogram", suptitle="Distribution", kde=True)
plot_dataframe(df, sns.boxplot, title="Boxplot", suptitle="Distribution")
plot_correlation_heatmap(df)

## Remove Outliers

In [None]:
# Drop rows outside the IQR fences, then repeat the exploratory plots on what remains.
df_no_outliers = remove_outliers_iqr(df)
plot_dataframe(
    df_no_outliers, sns.histplot, title="Histogram", suptitle="Distribution", kde=True
)
# Fliers are hidden in the boxplots of the filtered data.
plot_dataframe(
    df_no_outliers, sns.boxplot, title="Boxplot", suptitle="Distribution", showfliers=False
)
plot_correlation_heatmap(df_no_outliers)
df_no_outliers.describe()

## Pre-Enhancement (BRISQUE Evaluation)

In [None]:
# BRISQUE is a no-reference score where LOWER means better perceptual quality, so
# `BRISQUE <= threshold` keeps the best-scoring images. The "high_brisque" variable names
# are kept because later cells reference them.
brisque_threshold = 5
df_high_brisque = df_no_outliers[df_no_outliers["BRISQUE"] <= brisque_threshold]
# display() is required: only a cell's last expression is rendered automatically, so the
# bare `.describe()` that used to be here produced no output.
display(df_high_brisque.describe())
high_brisque_images, high_brisque_titles = read_images(RAW_IMAGE_DIR, df_high_brisque["Image name"])

# NOTE(review): these counts match the letters 'w' / 'n' anywhere in the file name rather
# than a specific name field; confirm against the file-naming scheme before relying on them.
print(f"# of WI images: {df_high_brisque['Image name'].str.contains('w').sum()}")
print(f"# of NO images: {df_high_brisque['Image name'].str.contains('n').sum()}")
plot_images(high_brisque_images, high_brisque_titles, rows=14, cols=8, show_axis=False)

In [None]:
# Gradient-magnitude view of the selected raw images.
brisque_magnitude_images = list(map(gradient_magnitude, high_brisque_images))
plot_images(
    brisque_magnitude_images, high_brisque_titles, rows=14, cols=8, show_axis=False
)

In [None]:
# Split each file-name stem "<subject>_<illumination>_<finger>_<background>_<impression>"
# into five columns. Vectorised replacement for iterrows() + per-cell .loc writes.
# (`Path` is imported in the first cell.)
name_fields = ["Subject ID", "Illumination", "Finger ID", "Background", "Impression ID"]
stem_parts = df_high_brisque["Image name"].map(lambda name: Path(name).stem.split("_"))

malformed = stem_parts.map(len) != len(name_fields)
if malformed.any():
    raise ValueError(
        "Image names must have exactly 5 '_'-separated fields: "
        f"{df_high_brisque.loc[malformed, 'Image name'].tolist()}"
    )

df_high_brisque_copy = df_high_brisque.join(
    pd.DataFrame(stem_parts.tolist(), index=df_high_brisque.index, columns=name_fields)
)

In [None]:
# Number of selected images per (subject, illumination, finger, background) group.
result = df_high_brisque_copy.groupby(
    ["Subject ID", "Illumination", "Finger ID", "Background"]
).size()

fig, ax = plt.subplots(figsize=(15, 10))
result.plot(kind="bar", ax=ax)
ax.set_title(f"Images with BRISQUE <= {brisque_threshold} per group")
ax.set_xlabel("(Subject ID, Illumination, Finger ID, Background)")
ax.set_ylabel("Number of images")
plt.show()

## Enhancement 1 (Bilateral Filter + CLAHE)

In [None]:
# Enhancement 1, stage 1: bilateral filtering of the selected raw images
# (same parameters as enh1's defaults), then quality metrics and an overview grid.
bil_images = [
    cv2.bilateralFilter(raw, d=7, sigmaColor=100, sigmaSpace=100)
    for raw in high_brisque_images
]
bil_metrics = compute_metrics(bil_images, high_brisque_titles)
plot_images(
    bil_images,
    high_brisque_titles,
    rows=14,
    cols=8,
    sup_title="Bilateral",
    show_axis=False,
)

In [None]:
# Gradient-magnitude view after bilateral filtering.
bilateral_magnitude_images = list(map(gradient_magnitude, bil_images))
plot_images(
    bilateral_magnitude_images, high_brisque_titles, rows=14, cols=8, show_axis=False
)

In [None]:
# Enhancement 1, stage 2: CLAHE applied to the *bilaterally filtered* images, so this stage
# measures the full "Bilateral Filter + CLAHE" chain (as enh1 does, and as Enhancement 2
# chains its stages). Previously CLAHE was applied to the raw images instead.
clahe_enh1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(16, 16))
clahe_images = [clahe_enh1.apply(image) for image in bil_images]
clahe_metrics = compute_metrics(clahe_images, high_brisque_titles)
plot_images(clahe_images, high_brisque_titles, rows=14, cols=8, sup_title="CLAHE", show_axis=False)

In [None]:
# Gradient-magnitude view after CLAHE.
clahe_magnitude_images = list(map(gradient_magnitude, clahe_images))
plot_images(
    clahe_magnitude_images, high_brisque_titles, rows=14, cols=8, show_axis=False
)

## Enhancement 1 Analysis

In [None]:
# Metric distributions before enhancement, after bilateral filtering, and after CLAHE.
for stage_metrics in (df_high_brisque, bil_metrics, clahe_metrics):
    plot_dataframe(stage_metrics, sns.histplot, "Histogram", "Distribution", kde=True)

In [None]:
# Summary statistics for the same three stages, in the same order.
for stage_metrics in (df_high_brisque, bil_metrics, clahe_metrics):
    print(stage_metrics.describe())

## Mean Adaptive Thresholding (Enhancement 1)

In [None]:
# Mask files share the image name with the ".jpg" replaced by ".png".
# NOTE(review): only the literal ".jpg" is rewritten; other extensions pass through unchanged.
masks_paths = [image_name.replace(".jpg", ".png") for image_name in high_brisque_titles]
masks = read_images(MASK_IMAGE_DIR, masks_paths)[0]

plot_images(
    masks, high_brisque_titles, rows=14, cols=8, sup_title="Masks", show_axis=False
)

In [None]:
# Keep the CLAHE output only where the mask is non-zero (background pixels become 0).
nobg_images = [cv2.bitwise_and(image, image, mask=mask) for image, mask in zip(clahe_images, masks)]
# Title fixed: this figure shows masked images, not thresholded ones.
plot_images(
    nobg_images, high_brisque_titles, rows=14, cols=8, sup_title="Background removed", show_axis=False
)

In [None]:
# Mean adaptive threshold: a pixel becomes 255 when it exceeds the mean of its 15x15
# neighbourhood minus C=2, otherwise 0.
th_images = [
    cv2.adaptiveThreshold(
        masked,
        maxValue=255,
        adaptiveMethod=cv2.ADAPTIVE_THRESH_MEAN_C,
        thresholdType=cv2.THRESH_BINARY,
        blockSize=15,
        C=2,
    )
    for masked in nobg_images
]
plot_images(th_images, high_brisque_titles, rows=14, cols=8, sup_title="Thresholded", show_axis=False)

## Enhancement 2 (Normalization + CLAHE + Gaussian Blur)

In [None]:
# Enhancement 2, stage 1: min-max stretch of each raw image to the full [0, 255] range.
normalized_images = [
    cv2.normalize(raw, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)
    for raw in high_brisque_images
]
normalized_metrics = compute_metrics(normalized_images, high_brisque_titles)
plot_images(
    normalized_images,
    high_brisque_titles,
    rows=14,
    cols=8,
    sup_title="Normalized",
    show_axis=False,
)

In [None]:
# Enhancement 2, stage 2: CLAHE (7x7 tiles) on the normalised images.
clahe_enh2 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(7, 7))
clahe2_images = [clahe_enh2.apply(normalized) for normalized in normalized_images]
clahe2_metrics = compute_metrics(clahe2_images, high_brisque_titles)
plot_images(
    clahe2_images, high_brisque_titles, rows=14, cols=8, sup_title="CLAHE", show_axis=False
)

In [None]:
# Enhancement 2, stage 3: 5x5 Gaussian blur (sigma derived from the kernel size).
gaussian_images = [
    cv2.GaussianBlur(equalized, ksize=(5, 5), sigmaX=0) for equalized in clahe2_images
]
gaussian_metrics = compute_metrics(gaussian_images, high_brisque_titles)
plot_images(
    gaussian_images,
    high_brisque_titles,
    rows=14,
    cols=8,
    sup_title="Gaussian",
    show_axis=False,
)

In [None]:
# Metric distributions after each Enhancement 2 stage.
for stage_metrics in (normalized_metrics, clahe2_metrics, gaussian_metrics):
    plot_dataframe(stage_metrics, sns.histplot, "Histogram", "Distribution", kde=True)

In [None]:
# Side-by-side summary: baseline, then each stage of both enhancement pipelines.
print(df_high_brisque.describe())
for heading, stages in (
    ("Enhancement 1", (bil_metrics, clahe_metrics)),
    ("Enhancement 2", (normalized_metrics, clahe2_metrics, gaussian_metrics)),
):
    print(heading)
    for stage_metrics in stages:
        print(stage_metrics.describe())

## Mean Adaptive Thresholding (Enhancement 2)

In [None]:
# Mask the Enhancement 2 output, then binarise with an *inverted* mean adaptive threshold:
# pixels at or below (15x15 local mean - 2) become 255, the rest 0.
nobg2_images = [
    cv2.bitwise_and(blurred, blurred, mask=mask_img)
    for blurred, mask_img in zip(gaussian_images, masks)
]
th2_images = [
    cv2.adaptiveThreshold(
        masked,
        maxValue=255,
        adaptiveMethod=cv2.ADAPTIVE_THRESH_MEAN_C,
        thresholdType=cv2.THRESH_BINARY_INV,
        blockSize=15,
        C=2,
    )
    for masked in nobg2_images
]
plot_images(th2_images, high_brisque_titles, rows=14, cols=8, sup_title="Thresholded", show_axis=False)

In [None]:
# Apply opening to remove noise
kernel = np.ones((3, 3), np.uint8)
opening_images = [cv2.morphologyEx(image, cv2.MORPH_OPEN, kernel) for image in th2_images]
plot_images(
    opening_images, high_brisque_titles, rows=14, cols=8, sup_title="Opening", show_axis=False
)

In [None]:
# Apply dilation to fill holes
dilated_images = [cv2.dilate(image, kernel, iterations=1) for image in opening_images]
plot_images(
    dilated_images, high_brisque_titles, rows=14, cols=8, sup_title="Dilated", show_axis=False
)

In [None]:
# Apply thinninig to reduce the thickness of the ridges
skeleton_images = [cv2.ximgproc.thinning(image) for image in dilated_images]
plot_images(
    skeleton_images, high_brisque_titles, rows=1, cols=4, sup_title="Skeleton", show_axis=False
)