# Setup

## Imports

In [None]:
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

from pylibdmtx.pylibdmtx import decode

from ultralytics import YOLO

import os

## Template Matching Definitions

In [None]:
def single_scale_retinex(image, sigma=30):
    """
    Split an image into reflectance and illumination with single scale retinex.

    The illumination is estimated as a Gaussian-blurred copy of the image and
    the reflectance is the log-domain difference between image and illumination.

    Args:
        image: Input image (numpy array)
        sigma: Standard deviation of the Gaussian blur (default is 30); the
            kernel size is passed as (0, 0), so OpenCV derives it from sigma

    Returns:
        Tuple (reflectance, illumination), each min-max normalized to the
        0-255 range and converted to uint8
    """
    def _as_display(values):
        # Stretch to the full 8-bit range so the map can be shown or matched directly.
        return cv2.normalize(values, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    # Shift by one so that log() never sees a zero pixel.
    shifted = image.astype(np.float32) + 1.0
    blurred = cv2.GaussianBlur(shifted, (0, 0), sigma)
    blurred += 1.0
    log_ratio = np.log(shifted) - np.log(blurred)
    return _as_display(log_ratio), _as_display(blurred)

In [None]:
def non_max_suppression_fast(boxes, scores, overlap_thresh=0.3):
    """
    Greedy non-maximum suppression over axis-aligned boxes.

    Boxes are visited from highest to lowest score; each visited box is kept
    and every remaining box whose intersection-over-union with it is at least
    ``overlap_thresh`` is discarded.

    Args:
        boxes: Sequence of bounding boxes as (x, y, width, height)
        scores: Sequence of scores, one per bounding box
        overlap_thresh: Boxes overlapping a kept box by this IoU or more are
            suppressed (default is 0.3)

    Returns:
        Numpy array with the kept boxes (rows of x, y, width, height) ordered
        by descending score, or an empty list when no boxes were given
    """
    if len(boxes) == 0:
        return []

    boxes = np.array(boxes)
    scores = np.array(scores)

    left, top = boxes[:, 0], boxes[:, 1]
    right = left + boxes[:, 2]
    bottom = top + boxes[:, 3]
    # Inclusive pixel convention: a box from x1 to x2 covers x2 - x1 + 1 pixels.
    areas = (right - left + 1) * (bottom - top + 1)

    order = scores.argsort()[::-1]
    selected = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        selected.append(best)

        # Intersection rectangle between the best box and every remaining box.
        inter_w = np.maximum(
            0, np.minimum(right[best], right[rest]) - np.maximum(left[best], left[rest]) + 1
        )
        inter_h = np.maximum(
            0, np.minimum(bottom[best], bottom[rest]) - np.maximum(top[best], top[rest]) + 1
        )
        intersection = inter_w * inter_h
        iou = intersection / (areas[best] + areas[rest] - intersection)

        # Only boxes that are sufficiently distinct stay in the running.
        order = rest[iou < overlap_thresh]

    return boxes[selected]

In [None]:
def extract_dominant_dot_template(image, min_area=20, max_area=300, patch_size=(24, 24), offset=5, size_tol=0.5):
    """
    Extracts the dominant dot template from the image.

    The image is denoised (bilateral filter), contrast-equalized (CLAHE) and
    passed through a black-hat transform; the Otsu-thresholded result is
    searched for blob contours. Padded crops around blobs of plausible area
    are collected, crops whose bounding-box area is far from the median are
    dropped, and the crop closest to the pixel-wise median of the remaining
    crops is returned.

    Args:
        image: Input image (numpy array); must be single-channel, since it is
            fed to CLAHE and unpacked as a 2-D array
        min_area: Minimum contour area of the dot to be considered, exclusive (default is 20)
        max_area: Maximum contour area of the dot to be considered, exclusive (default is 300)
        patch_size: Size every candidate is resized to before candidates are
            compared with each other (default is (24, 24))
        offset: Padding in pixels added on every side of a dot's bounding box (default is 5)
        size_tol: Maximum relative deviation of a candidate's bounding-box area
            from the median area (default is 0.5)

    Returns:
        Tuple of the extracted patch (cropped from the original image at its
        native size, not resized) and all contours found in the thresholded
        image, including the ones that were rejected as candidates.

    Raises:
        ValueError: If no valid dot candidates are found or if no size-consistent patches are found.
    """
    image_clean = cv2.bilateralFilter(image, d=15, sigmaColor=50, sigmaSpace=5)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    image_clean = clahe.apply(image_clean)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (16, 16))
    # Black-hat = closing minus input: responds to dark structures smaller than the kernel.
    blackhat = cv2.morphologyEx(image_clean, cv2.MORPH_BLACKHAT, kernel)

    _, binary = cv2.threshold(blackhat, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # NumPy images are indexed [row, col], so shape is (height, width).
    img_h, img_w = image.shape[:2]

    candidates = []
    for cnt in contours:
        area = cv2.contourArea(cnt)
        if not (min_area < area < max_area):
            continue

        x, y, w, h = cv2.boundingRect(cnt)
        crop_x_start = x - offset
        crop_x_end = x + w + offset
        crop_y_start = y - offset
        crop_y_end = y + h + offset

        # Skip dots whose padded box would reach or cross the image border,
        # so that every candidate patch is complete.
        if crop_x_start < 0 or crop_x_end >= img_w or crop_y_start < 0 or crop_y_end >= img_h:
            continue

        patch = image[crop_y_start:crop_y_end, crop_x_start:crop_x_end]
        candidates.append((patch, h, w))

    if not candidates:
        raise ValueError("No valid dot candidates found.")

    # Reference size: product of the median bounding-box height and width.
    median_area = np.median([h for _, h, _ in candidates]) * np.median([w for _, _, w in candidates])

    # Keep only patches with similar size
    patches_filtered = []
    resized_for_matching = []
    for patch, h, w in candidates:
        if abs(h * w - median_area) / median_area < size_tol:
            patches_filtered.append(patch)
            # A common size is needed to stack and compare the patches pixel by pixel.
            resized_for_matching.append(cv2.resize(patch, patch_size))

    if not patches_filtered:
        raise ValueError("No size-consistent patches found.")

    # Find patch closest to the median template
    stack = np.stack(resized_for_matching, axis=0).astype(np.float32)
    median_template = np.median(stack, axis=0)
    diffs = [np.linalg.norm(p - median_template) for p in stack]
    best_idx = np.argmin(diffs)

    return patches_filtered[best_idx], contours

In [None]:
def contours_from_patch(patch):
    """
    Binarize a patch with Otsu's threshold and return its outer contours.

    Args:
        patch: Input patch (numpy array) to be thresholded

    Returns:
        The external contours found in the binarized patch
    """
    otsu_binary = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    # threshold() returns (chosen_threshold, image); only the image is needed.
    binary_patch = cv2.threshold(patch, 0, 255, otsu_binary)[1]
    outlines, _hierarchy = cv2.findContours(binary_patch, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return outlines

In [None]:
def display_image(image, size=(300, 300)):
    """
    Displays the numpy image using PIL and notebook display functionality.

    Multi-channel images are treated as BGR (OpenCV order) and converted to
    RGB before display; 2-D (grayscale) images are shown as they are.

    Args:
        image: Input image (numpy array), either 2-D grayscale or BGR
        size: Size to which the image should be resized, passed to
            cv2.resize as its target size (default is (300, 300))

    Returns:
        None
    """
    # A 2-D array has no channel order to swap, and cv2.cvtColor would reject
    # it for COLOR_BGR2RGB, so only convert multi-channel input.
    if image.ndim != 2:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, size)
    pil_image = Image.fromarray(image)
    # `display` is the notebook-provided rich display function.
    display(pil_image)

In [None]:
def display_yucheng_methods(nms_boxes, reflectance, dot_contours, img, illumination, dot_template):
    """
    Plots a 2x3 overview of Yucheng's dot detection and template matching steps.

    Top row: original image, estimated illumination, reflectance map.
    Bottom row: contours drawn on the reflectance map, the dot template, and
    the matched boxes drawn on the reflectance map.

    Args:
        nms_boxes: Bounding boxes (x, y, w, h) after non-maximum suppression
        reflectance: Reflectance map (single-channel numpy array)
        dot_contours: Contours of the detected dots
        img: Original image (numpy array)
        illumination: Estimated illumination (numpy array)
        dot_template: Dot template (numpy array)

    Returns:
        None
    """
    # Overlay 1: green rectangles for every surviving template match.
    matches_bgr = cv2.cvtColor(reflectance, cv2.COLOR_GRAY2BGR)
    for box_x, box_y, box_w, box_h in nms_boxes:
        top_left = (box_x, box_y)
        bottom_right = (box_x + box_w, box_y + box_h)
        cv2.rectangle(matches_bgr, top_left, bottom_right, (0, 255, 0), 2)

    # Overlay 2: red outlines of the supplied contours.
    contours_bgr = cv2.cvtColor(reflectance, cv2.COLOR_GRAY2BGR)
    cv2.drawContours(contours_bgr, dot_contours, -1, (0, 0, 255), 1)

    # (picture, title, imshow keyword arguments), listed in row-major grid order.
    gray = {"cmap": "gray"}
    panels = [
        (img, "Original Image", gray),
        (illumination, "Estimated Illumination", gray),
        (reflectance, "Reflectance Map (SSR)", gray),
        (cv2.cvtColor(contours_bgr, cv2.COLOR_BGR2RGB), "Dot Contours", {}),
        (dot_template, "Dot template (median of patches)", gray),
        (cv2.cvtColor(matches_bgr, cv2.COLOR_BGR2RGB), "Template matching", {}),
    ]

    _, axs = plt.subplots(2, 3, figsize=(10, 6))
    for ax, (picture, title, imshow_kwargs) in zip(axs.flat, panels):
        ax.imshow(picture, **imshow_kwargs)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()

# YOLO Model Loading

In [None]:
# # train4 is original oriented YOLO no rotation lock
# model = YOLO(f'../yolo/runs/obb/train4/weights/best.pt') # load best model from training
# model.eval() # Set the model to evaluation mode

# Keeping Relevant YOLO Crops Only

In [27]:
# raw_imgs = os.listdir('../data/dot_detection/MAN/raw/')
# raw_imgs = [os.path.join('../data/dot_detection/MAN/raw/', img) for img in raw_imgs]

# misses = 0
# for img_path in raw_imgs:
#     result = model(img_path, verbose=False)[0]

#     if len(result) == 0:
#         misses += 1
#         print(f'Missed image: {img_path}')
#         continue

#     # get xywhr from result
#     xywhr = result.obb.xywhr.cpu().numpy()

#     if len(xywhr) == 0:
#         misses += 1
#         print(f'Missed image: {img_path}')
#         continue

#     # load image
#     img_path = result.path
#     img = cv2.imread(img_path)

#     # display image
#     # display_image(img, size=(300, 300))

#     # get coords for cropping
#     x, y, w, h, r = xywhr[0]
#     x = int(x)
#     y = int(y)
#     w = int(w)
#     h = int(h)
#     r = int(r)
#     x1 = int(x - w / 2)
#     y1 = int(y - h / 2)
#     x2 = int(x + w / 2)
#     y2 = int(y + h / 2)

#     # adding some padding to the cropped image
#     pac_perc = 0.5
#     pad = int(pac_perc * max(w, h))
#     x1 -= pad
#     y1 -= pad
#     x2 += pad
#     y2 += pad

#     x1 = max(0, x1)
#     y1 = max(0, y1)
#     x2 = min(img.shape[1], x2)
#     y2 = min(img.shape[0], y2)
#     img = img[y1:y2, x1:x2]
#     if img.size == 0:  # Check if the cropped image is empty
#         misses += 1
#         print(f'Missed image: {img_path}')
#         continue

#     # display cropped image
#     display_image(img, size=(300, 300))

# print(f'Misses: {misses}')

# Mapping Dictionary for Originals and Manual Template Crops

In [None]:
# Map every raw image filename to the path of its manually cropped template.
# A template is expected to share the raw file's name with a '.jpg' extension.
raw_dir = '../data/dot_detection/MAN/raw/'
template_dir = '../data/dot_detection/MAN/templates/'

raw_files = os.listdir(raw_dir)

# sanity check equal number of files in both directories
if len(raw_files) != len(os.listdir(template_dir)):
    print("MISSING TEMPLATES!!!")
else:
    print("got enough templates :)")

mapping = {}
for raw_file in raw_files:
    # splitext removes only the final extension, so names that contain
    # additional dots keep their full stem.
    template_file = os.path.splitext(raw_file)[0] + '.jpg'
    template_file_path = os.path.join(template_dir, template_file)
    if os.path.exists(template_file_path):
        mapping[raw_file] = template_file_path
    else:
        print(f"Missing template for {raw_file}")
print(mapping)

# Testing Template Matching

In [None]:
# === Load image and its manual dot template (both grayscale) ===
img_to_test = "../data/dot_detection/MAN/raw/1D1165212740006.jpeg"
template_to_test = "../data/dot_detection/MAN/templates/1D1165212740006.jpg"

# cv2.imread returns None instead of raising when a file is missing or
# unreadable, so fail here with the offending path rather than inside cv2.resize.
img = cv2.imread(img_to_test, cv2.IMREAD_GRAYSCALE)
if img is None:
    raise FileNotFoundError(f"Could not read image: {img_to_test}")
img = cv2.resize(img, (320, 320))

dot_template = cv2.imread(template_to_test, cv2.IMREAD_GRAYSCALE)
if dot_template is None:
    raise FileNotFoundError(f"Could not read template: {template_to_test}")
dot_template = cv2.resize(dot_template, (19, 19))

In [None]:
# Split the test image into its reflectance and illumination components (SSR).
reflectance, illumination = single_scale_retinex(image=img, sigma=64)

In [None]:
# Outline of the template itself, obtained by Otsu-thresholding the patch.
dot_contours = contours_from_patch(patch=dot_template)

In [None]:
# === Template matching ===
result = cv2.matchTemplate(reflectance, dot_template, cv2.TM_CCOEFF_NORMED)
threshold = 0.7
match_mask = result >= threshold
# np.where yields (rows, cols) in row-major order, the same order boolean
# indexing uses, so locations[i] and scores[i] describe the same match.
# Reversing gives (x, y) pairs; materialize them as a list so that the cells
# consuming `locations` can be re-run without finding an exhausted iterator.
locations = list(zip(*np.where(match_mask)[::-1]))
scores = result[match_mask]

In [None]:
# Every match has the template's footprint: record each one as (x, y, w, h),
# where (x, y) is the top-left corner of the match.
h, w = dot_template.shape
boxes = [(int(col), int(row), w, h) for col, row in locations]

In [None]:
# Collapse clusters of overlapping matches down to their best-scoring box.
nms_boxes = non_max_suppression_fast(boxes=boxes, scores=scores, overlap_thresh=0.3)

In [None]:
# Show every intermediate result of the pipeline in a single figure.
display_yucheng_methods(
    nms_boxes=nms_boxes,
    reflectance=reflectance,
    dot_contours=dot_contours,
    img=img,
    illumination=illumination,
    dot_template=dot_template,
)

# Template Matching for Label Data