# Setup

In [None]:
# standard imports
import os
import math
from collections import Counter

# third-party imports
from ultralytics import YOLO

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.io import read_image
from torchvision import transforms

import cv2
from PIL import Image

import numpy as np

import matplotlib.pyplot as plt

from scipy.optimize import minimize, linear_sum_assignment
from scipy.spatial.distance import cdist
from scipy.spatial import ConvexHull
import scipy.ndimage as ndi

from sklearn.cluster import DBSCAN

from skimage.feature import peak_local_max
from skimage.segmentation import watershed

from pylibdmtx.pylibdmtx import decode

In [None]:
# set device for PyTorch: use the GPU whenever CUDA is available, else the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## Simple Helper Functions

In [None]:
# simple helper functions
def get_mode(lst):
    """Return the most frequent element of ``lst``.

    Ties go to the element encountered first; an empty ``lst`` raises IndexError.
    """
    ranked = Counter(lst).most_common(1)
    value, _frequency = ranked[0]
    return value

def load_image_to_device(image_path):
    """
    Read an image file onto ``device`` as a float batch scaled to [0, 1].

    Args:
        image_path (str): Path of the image file to read.

    Returns:
        torch.Tensor: Tensor of shape (1, C, H, W). A 4-channel (RGBA) image has
        its 4th channel dropped; other channel counts are kept unchanged.
    """
    pixels = read_image(image_path).to(device)
    pixels = pixels / 255.0  # uint8 [0, 255] -> float [0, 1]

    if pixels.shape[0] == 4:
        pixels = pixels[:3]  # drop the alpha channel

    return pixels[None]  # prepend the batch axis

# YOLO Cropping

In [None]:
def load_yolo(model_path):
    """
    Load YOLOv11 weights and prepare the model for inference.

    Args:
        model_path (str): Path to the YOLOv11 weights file (e.g. a ``best.pt``).

    Returns:
        YOLO: The fused model, switched to evaluation mode.
    """
    detector = YOLO(model_path)
    detector.fuse()  # fuse model for faster inference
    # NOTE(review): ultralytics' YOLO class has its own `train()` method, and
    # nn.Module.eval() is implemented as `self.train(False)`; confirm with the
    # installed ultralytics version that this only switches mode.
    detector.eval()
    return detector

def yolo_detect(image, model, debug=False):
    """
    Detect the object with the highest confidence in an image using YOLOv11.

    Args:
        image (torch.Tensor): The input image as a tensor.
        model (YOLO): The YOLOv11 model.
        debug (bool): If True, save YOLO's annotated prediction to
            'deleteme.jpg' (read and deleted by yolo_detect_and_crop).

    Returns:
        torch.Tensor | None: (x_center, y_center, width, height) of the most
        confident detection, in pixels of ``image``; None if nothing was detected.
    """
    results = model.predict(image, verbose=False)

    if len(results) == 0:
        print("YOLO did not detect any objects.")
        return None

    result = results[0]
    boxes = result.boxes
    # result.boxes can be None (presumably for task types such as oriented-box
    # models that report detections elsewhere) - guard before indexing it.
    if boxes is None or len(boxes.xywh) == 0:
        print("YOLO did not detect any objects.")
        return None

    if debug:
        result.save(filename='deleteme.jpg')

    # Select the most confident box explicitly instead of relying on the
    # order in which the detections are returned.
    best = int(boxes.conf.argmax())
    return boxes.xywh[best]

def yolo_crop(image, xywh, pad):
    """
    Crop a padded window around a centre-format box.

    Args:
        image (torch.Tensor): Image batch of shape (N, C, H, W).
        xywh (torch.Tensor): Box (x_center, y_center, width, height) in pixels of ``image``.
        pad (float): Padding added on every side, as a fraction of the image's
            longest side (e.g. 0.01 = 1%).

    Returns:
        torch.Tensor: Crop of shape (N, C, int(h + 2p), int(w + 2p)) with
        p = pad * max(H, W), values clamped to [0, 1]. The window is clipped to
        the image; when clipping happens, the smaller region is bilinearly
        stretched to the full padded size.
    """
    cx, cy, bw, bh = (float(v) for v in xywh)
    _, _, img_h, img_w = image.shape

    margin = pad * max(image.shape[2:])
    crop_w = int(bw + 2 * margin)
    crop_h = int(bh + 2 * margin)

    # window edges, clipped to the image
    left = int(max(0, cx - crop_w / 2))
    top = int(max(0, cy - crop_h / 2))
    right = int(min(img_w, cx + crop_w / 2))
    bottom = int(min(img_h, cy + crop_h / 2))

    window = image[:, :, top:bottom, left:right]
    resized = F.interpolate(window, size=(crop_h, crop_w), mode='bilinear', align_corners=False)
    return resized.clamp(0, 1)  # bilinear resampling may overshoot [0, 1]

def yolo_detect_and_crop(image, model, pad, debug=False):
    """
    Detect the main object with YOLOv11 and crop the image around it.

    Detection runs on a 640x640 resized copy of the image. The resulting box is
    rescaled to the original resolution and cut out with ``yolo_crop``.

    Args:
        image (torch.Tensor): Image batch of shape (N, C, H, W), expected in [0, 1].
        model (YOLO): The YOLOv11 model.
        pad (float): Padding per side as a fraction of the image's longest side
            (forwarded to ``yolo_crop``).
        debug (bool): If True, plot the original image, YOLO's annotated
            detection and the final crop side by side.

    Returns:
        torch.Tensor: The cropped image, or ``image`` itself when nothing is detected.
    """
    yolo_input = transforms.Resize((640, 640))(image.clone())
    # resizing can overshoot slightly above 1.0, so bring values back into [0, 1]
    yolo_input = yolo_input.clamp(0, 1)

    with torch.no_grad():
        box = yolo_detect(yolo_input, model, debug)

    if box is None:
        print('WARNING, no object detected, returning original image')
        return image

    # map the box from the 640x640 detector frame back to the original resolution
    _, _, orig_h, orig_w = image.shape
    sx = orig_w / 640
    sy = orig_h / 640
    bx, by, bw, bh = (float(v) for v in box)
    scaled_box = torch.tensor([bx * sx, by * sy, bw * sx, bh * sy], dtype=torch.float32, device=image.device)

    cropped = yolo_crop(image, scaled_box, pad)

    if debug:
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        annotated = Image.open('deleteme.jpg')  # written by yolo_detect when debug=True
        panels = (
            (image.squeeze(0).permute(1, 2, 0).cpu().numpy(), 'Original Image'),
            (annotated, 'YOLO Detection'),
            (cropped.squeeze(0).permute(1, 2, 0).cpu().numpy(), 'YOLO Crop (Post Padding)'),
        )
        for ax, (picture, title) in zip(axes, panels):
            ax.imshow(picture)
            ax.axis('off')
            ax.set_title(title)
        os.remove('deleteme.jpg')
        plt.show()

    return cropped

# === load image ===
# Single sample image used to exercise the YOLO cropping step; load_image_to_device
# returns it as a (1, C, H, W) float tensor in [0, 1] on `device`.
img_to_test = '../data/MAN/raw/train/1D1165212740006.jpeg'
image = load_image_to_device(img_to_test)

# === load & run YOLOv11 ===
yolo_path = '../yolo/runs/obb/train7/weights/best.pt' # train4 is original oriented YOLO no rotation lock
pad = 0.01 # fraction of the image's longest side added as padding on each side of the YOLO box (0.01 = 1%)
yolo_model = load_yolo(yolo_path)
# debug=True plots the original image, YOLO's annotated detection and the padded crop
image_yolo = yolo_detect_and_crop(image, yolo_model, pad, debug=True)

# Template Acquiring

## Model Definition

In [None]:
class DoubleConv(nn.Module):
    """
    Two (Conv 3x3 → BatchNorm → ReLU) stages: (Conv → BN → ReLU) * 2.

    The 3x3 convolutions use padding=1 and stride 1, so the spatial size is
    preserved; only the channel count changes (in_channels → out_channels).

    Args:
        in_channels (int): Channels of the input feature map.
        out_channels (int): Channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # NOTE: the attribute name and Sequential indices (conv.0, conv.1, ...)
        # form the state_dict keys; changing them breaks existing checkpoints.
        self.conv = nn.Sequential(
            # bias=False: the BatchNorm that follows supplies its own shift
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Apply both conv stages; output has out_channels and x's spatial size."""
        return self.conv(x)


class Down(nn.Module):
    """
    Encoder step: MaxPool2d(2) → DoubleConv.

    The 2x2 max-pool halves the spatial size (floor for odd sizes) before the
    channel count is changed from in_channels to out_channels.

    Args:
        in_channels (int): Channels of the input feature map.
        out_channels (int): Channels of the output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # attribute name is part of the checkpoint's state_dict keys
        self.down = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        """Downsample x by 2 and apply the double convolution."""
        return self.down(x)


class Up(nn.Module):
    """
    Decoder step: upsample x2 → concatenate the skip connection → DoubleConv.

    Args:
        in_channels (int): Channels after concatenation (upsampled features +
            skip features), i.e. the DoubleConv input width.
        out_channels (int): Channels of the output feature map.
        bilinear (bool): If True, use parameter-free bilinear upsampling.
            Otherwise use a learned 2x2, stride-2 ConvTranspose2d that expects
            the decoder input to have in_channels // 2 channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """
        Args:
            x1: Decoder features from the deeper level (upsampled here).
            x2: Encoder skip features whose spatial size the output takes.
        """
        x1 = self.up(x1)

        # pad if needed: when an encoder size was odd, upsampled x1 can be one
        # pixel smaller than x2; pad x1 (split between both sides) to match
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        # skip features first, then upsampled features, along the channel axis
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class UNet(nn.Module):
    """
    Four-level UNet producing a per-pixel probability map.

    Channel widths: in_conv → base_c, down1 → 2*base_c, down2 → 4*base_c,
    down3 → 8*base_c, down4 → 8*base_c (kept at 8*base_c so that up1's
    concatenated input is 16*base_c). The decoder mirrors this back to base_c,
    and a 1x1 conv maps to out_channels followed by a sigmoid.

    Args:
        in_channels (int): Input channels (default 1).
        out_channels (int): Output channels (default 1).
        base_c (int): Channel width of the first level (default 64).

    NOTE: attribute names define the state_dict keys loaded by load_unet;
    renaming them breaks existing checkpoints.
    """

    def __init__(self, in_channels=1, out_channels=1, base_c=64):
        super().__init__()
        self.in_conv = DoubleConv(in_channels, base_c)
        self.down1 = Down(base_c, base_c * 2)
        self.down2 = Down(base_c * 2, base_c * 4)
        self.down3 = Down(base_c * 4, base_c * 8)
        self.down4 = Down(base_c * 8, base_c * 8)

        # Up(in, out): `in` = upsampled channels + skip channels
        self.up1 = Up(base_c * 16, base_c * 4)
        self.up2 = Up(base_c * 8, base_c * 2)
        self.up3 = Up(base_c * 4, base_c)
        self.up4 = Up(base_c * 2, base_c)

        self.out_conv = nn.Conv2d(base_c, out_channels, kernel_size=1)

    def forward(self, x):
        """Return sigmoid probabilities with out_channels and x's spatial size."""
        # encoder: keep every level's output as a skip connection
        x1 = self.in_conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # decoder: upsample and fuse with the matching encoder level
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        return torch.sigmoid(self.out_conv(x))

## Model Functions

In [None]:
def compute_retinex_reflectance_torch(img_tensor, sigma=30):
    """
    Single-scale Retinex: split an image into reflectance and illumination.

    Illumination is estimated as a Gaussian blur of the image (zero-padded at the
    borders). Reflectance is log(image) - log(blur). Both outputs are min-max
    normalized per sample to [0, 1].

    Args:
        img_tensor (torch.Tensor): Image batch of shape (N, C, H, W). Values are
            clamped to >= 1e-6 before taking the log.
        sigma (float): Standard deviation of the Gaussian blur, in pixels.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: (reflectance, illumination), each of
        shape (N, C, H, W) with values in [0, 1]. Illumination is the normalized
        log of the blur.
    """
    eps = 1e-6
    img = img_tensor.clamp(min=eps) # avoids log(0) without shifting scale
    log_img = torch.log(img)

    def get_gaussian_kernel2d(kernel_size, sigma):
        # symmetric integer offsets centred on 0 (kernel_size is always odd here)
        ax = torch.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.).to(img_tensor.device)
        xx, yy = torch.meshgrid(ax, ax, indexing='ij')
        kernel = torch.exp(-(xx**2 + yy**2) / (2. * sigma**2))
        kernel = kernel / torch.sum(kernel)
        return kernel

    # kernel spans +/- 3 sigma; 2*k+1 keeps it odd so "same" padding is k//2
    kernel_size = int(2 * math.ceil(3 * sigma) + 1)
    kernel = get_gaussian_kernel2d(kernel_size, sigma)
    kernel = kernel.unsqueeze(0).unsqueeze(0)  # (1, 1, kH, kW)

    # depthwise blur: the same kernel applied to each channel independently
    channels = img_tensor.shape[1]
    kernel = kernel.expand(channels, 1, -1, -1) # [C, 1, kH, kW]
    blurred = F.conv2d(img_tensor, kernel, padding=kernel_size//2, groups=channels)
    blurred = blurred.clamp(min=eps) # avoids log(0) without shifting scale
    log_blur = torch.log(blurred)

    reflectance = log_img - log_blur

    def normalize(tensor):
        # Per-sample min-max scaling to [0, 1]. Dividing by the range (not by the
        # max) is what guarantees [0, 1]; eps keeps a flat input finite (-> 0).
        t_min = tensor.amin(dim=(1, 2, 3), keepdim=True)
        t_max = tensor.amax(dim=(1, 2, 3), keepdim=True)
        return (tensor - t_min) / (t_max - t_min + eps)

    # NOTE(review): the UNet consumes this reflectance; confirm its training
    # preprocessing uses the same min-max scaling.
    reflectance = normalize(reflectance)
    illumination = normalize(log_blur)

    return reflectance, illumination

In [None]:
def load_unet(model_path):
    """
    Build a 1-channel-in / 1-channel-out UNet on ``device`` and load trained weights.

    Args:
        model_path (str): Path to a saved state_dict for UNet(in_channels=1, out_channels=1).

    Returns:
        UNet: The model in evaluation mode (BatchNorm uses its running statistics).
    """
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints
    # (torch >= 1.13 also accepts weights_only=True for plain state_dicts).
    state_dict = torch.load(model_path, map_location=device)
    unet = UNet(in_channels=1, out_channels=1).to(device)
    unet.load_state_dict(state_dict)
    unet.eval()
    return unet

# === Watershed-based template extraction ===
def extract_templates_watershed(heatmap_np, reflectance_np, k_templates, debug=False):
    """
    Cut square dot templates out of the reflectance image using the UNet heatmap.

    The heatmap is thresholded at 10% of its maximum and split into individual
    blobs with a distance-transform watershed. Blobs larger than 1.5x the median
    blob area, or with circularity below 0.5, are rejected. The remaining blobs,
    largest first, each yield a square crop of the reflectance image padded by
    10% of the blob's larger side. At most ``k_templates`` crops are returned.

    Args:
        heatmap_np (np.ndarray): 2-D heatmap (H, W).
        reflectance_np (np.ndarray): Reflectance image of shape (H, W, 1),
            pixel-aligned with the heatmap.
        k_templates (int): Maximum number of templates to return.
        debug (bool): If True, plot each template's location and crop.

    Returns:
        tuple[list[np.ndarray], list[list[int]]]: uint8 square templates scaled
        to [0, 255], and their [x, y, side, side] boxes in heatmap pixels.
        Returns ([], []) when no blobs are found.
    """
    templates = []
    template_bounds = []
    img_h, img_w = heatmap_np.shape[:2]

    # 1. Threshold to create binary mask
    thresh_val = np.max(heatmap_np) * 0.1
    binary = (heatmap_np >= thresh_val).astype(np.uint8)

    # 2. Distance transform: large values at blob centres
    distance = cv2.distanceTransform(binary, cv2.DIST_L2, 5)

    # 3. Local maxima of the distance map become one watershed marker per blob
    coords = peak_local_max(distance, min_distance=10, labels=binary) # can add num_peaks limit
    mask = np.zeros(distance.shape, dtype=bool)
    mask[tuple(coords.T)] = True
    markers, _ = ndi.label(mask)

    # 4. Watershed on the inverted distance splits touching blobs
    labels = watershed(-distance, markers, mask=binary)

    # 5. Per-blob outer contour, area and circularity
    blob_areas = []
    blob_data = []
    for label in np.unique(labels):
        if label == 0:
            continue # ignore background
        region_mask = np.zeros_like(heatmap_np, dtype=np.uint8)
        region_mask[labels == label] = 255

        contours, _ = cv2.findContours(region_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            print(f"No contours found for label {label}.")
            continue

        cnt = max(contours, key=cv2.contourArea)
        area = cv2.contourArea(cnt)
        perimeter = cv2.arcLength(cnt, True)
        if perimeter == 0:
            continue # avoid division by zero

        circularity = (4 * np.pi * area) / (perimeter ** 2)  # 1.0 for a perfect circle
        blob_areas.append(area)
        blob_data.append((cnt, area, circularity))

    if not blob_areas:
        print("No blobs found.")
        return [], []

    # 6. Reject outliers relative to the median blob
    avg_area = np.median(blob_areas)
    max_area = 1.5 * avg_area     # Reject blobs much larger than average
    min_circularity = 0.5         # Reject elongated blobs

    for cnt, area, circularity in sorted(blob_data, key=lambda blob: blob[1], reverse=True): # largest first
        if area > max_area or circularity < min_circularity:
            continue

        x, y, w, h = cv2.boundingRect(cnt)
        pad = int(0.10 * max(w, h)) # 10% padding
        side = max(w, h) + 2 * pad  # padded square
        cx, cy = x + w // 2, y + h // 2

        # Shift (rather than clip) the window so it stays inside the image.
        # Clipping produced non-square border templates and recorded bounds that
        # overhang the image.
        x1 = int(np.clip(cx - side // 2, 0, max(img_w - side, 0)))
        y1 = int(np.clip(cy - side // 2, 0, max(img_h - side, 0)))
        x2 = min(x1 + side, img_w)
        y2 = min(y1 + side, img_h)

        template = reflectance_np[y1:y2, x1:x2, 0]  # (H, W, 1) -> (H, W)
        if template.shape != (side, side):
            # only reachable when the square is larger than the image itself
            template = cv2.resize(template, (side, side))

        t_min, t_max = float(np.min(template)), float(np.max(template))
        if t_max > t_min:
            template = (template - t_min) / (t_max - t_min) * 255
        else:
            template = np.zeros_like(template)  # flat patch: avoid 0/0 -> NaN
        template = template.astype(np.uint8)

        templates.append(template)
        template_bounds.append([x1, y1, side, side])

        if len(templates) >= k_templates:
            break

    if debug and templates:
        n = len(templates)
        # squeeze=False keeps `axis` 2-D even for a single template
        figure, axis = plt.subplots(3, n, figsize=(15, 5) if n > 1 else (10, 5), squeeze=False)
        for i, (template, (bx, by, bw, bh)) in enumerate(zip(templates, template_bounds)):
            axis[0, i].imshow(heatmap_np, cmap='gray')
            axis[0, i].add_patch(plt.Rectangle((bx, by), bw, bh, edgecolor='red', facecolor='none', lw=2))
            axis[0, i].set_title(f"T {i+1}")
            axis[0, i].axis('off')

            axis[1, i].imshow(reflectance_np[:, :, 0], cmap='gray')
            axis[1, i].add_patch(plt.Rectangle((bx, by), bw, bh, edgecolor='red', facecolor='none', lw=2))
            axis[1, i].axis('off')

            axis[2, i].imshow(template, cmap='gray')
            axis[2, i].axis('off')
        plt.tight_layout()
        plt.show()

    return templates, template_bounds

def unet_get_template(image_yolo, model, k_templates, debug=False):
    """
    Extract dot templates from a YOLO crop using the UNet dot heatmap.

    The crop is resized to 384x384, converted to grayscale and split into Retinex
    reflectance and illumination. The UNet runs on the reflectance, and its
    heatmap is passed to ``extract_templates_watershed``.

    Args:
        image_yolo (torch.Tensor): Cropped image batch of shape (1, C, H, W).
        model (UNet): Dot-heatmap UNet taking a 1-channel input.
        k_templates (int): Maximum number of templates to extract.
        debug (bool): Forwarded to extract_templates_watershed to plot the templates.

    Returns:
        tuple: (reflectance, illumination, templates, template_bounds, heatmap_np)
            reflectance (np.ndarray): (384, 384) uint8 scaled to [0, 255].
            illumination (np.ndarray): (384, 384) float illumination, as returned
                by compute_retinex_reflectance_torch (not rescaled here).
            templates (list[np.ndarray]): uint8 square templates.
            template_bounds (list[list[int]]): [x, y, side, side] per template,
                in the 384x384 frame.
            heatmap_np (np.ndarray): The squeezed UNet heatmap.
    """
    # === prepare image for UNet ===
    resize = transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BILINEAR)
    image = resize(image_yolo)
    image = transforms.Grayscale(num_output_channels=1)(image) # convert to grayscale
    reflectance, illumination = compute_retinex_reflectance_torch(image, sigma=50)

    # === get UNet heatmap ===
    with torch.no_grad():
        heatmap = model(reflectance)
    heatmap_np = heatmap.squeeze().cpu().numpy()

    # (1, 1, H, W) -> (H, W, 1): the layout extract_templates_watershed indexes
    reflectance_hw1 = reflectance.squeeze(0).permute(1, 2, 0).cpu().numpy()
    templates, template_bounds = extract_templates_watershed(heatmap_np, reflectance_hw1, k_templates, debug)

    # === post-process for use in template matching ===
    reflectance_np = reflectance_hw1[:, :, 0]
    illumination_np = illumination.squeeze(0).permute(1, 2, 0).cpu().numpy()[:, :, 0]

    # only the reflectance is rescaled to uint8 [0, 255]; illumination is returned as-is
    r_min, r_max = float(reflectance_np.min()), float(reflectance_np.max())
    if r_max > r_min:
        reflectance_np = (reflectance_np - r_min) / (r_max - r_min) * 255
    else:
        reflectance_np = np.zeros_like(reflectance_np)  # flat map: avoid 0/0 -> NaN
    reflectance_np = reflectance_np.astype(np.uint8)

    return reflectance_np, illumination_np, templates, template_bounds, heatmap_np

# === load image ===
# Same sample image as above, as a (1, C, H, W) float tensor in [0, 1] on `device`.
img_to_test = '../data/MAN/raw/train/1D1165212740006.jpeg'
image = load_image_to_device(img_to_test)

# === get yolo crop ===
yolo_path = '../yolo/runs/obb/train7/weights/best.pt'
pad = 0.01 # fraction of the image's longest side added as padding on each side of the YOLO box (0.01 = 1%)
yolo_model = load_yolo(yolo_path)
image_yolo = yolo_detect_and_crop(image, yolo_model, pad, debug=True)

# === get unet template ===
# Up to k_templates square dot templates from the UNet heatmap of the crop;
# debug=True plots where each template was taken from.
unet_path = '../models/dot_detection/checkpoints/unet_best.pth'
unet_model = load_unet(unet_path)
reflectance, illumination, templates, template_bounds, heatmap_np = unet_get_template(image_yolo, unet_model, k_templates=3, debug=True)

# Template Matching

In [None]:
def non_max_suppression_fast(boxes, scores, overlap_thresh=0.3):
    """
    Greedy non-maximum suppression over axis-aligned boxes.

    Boxes are visited in descending score order. Each kept box suppresses every
    remaining box whose overlap with it (intersection / union, using the
    inclusive-pixel "+1" width/height convention) is >= overlap_thresh.

    Args:
        boxes: Sequence or (N, 4) array of bounding boxes (x, y, width, height).
        scores: Sequence of N scores, one per bounding box.
        overlap_thresh: Overlap at or above which a lower-scored box is
            suppressed (default is 0.3).

    Returns:
        np.ndarray: (K, 4) array of the kept boxes in descending score order.
        An empty input gives an empty (0, 4) array rather than a list, so callers
        can always use ``.tolist()`` and row indexing.
    """
    if len(boxes) == 0:
        return np.empty((0, 4), dtype=int)
    boxes = np.asarray(boxes)
    scores = np.asarray(scores)
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 0] + boxes[:, 2]
    y2 = boxes[:, 1] + boxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    idxs = scores.argsort()[::-1]  # highest score first
    keep = []
    while len(idxs) > 0:
        i = idxs[0]
        keep.append(i)
        # intersection of box i with every remaining box
        xx1 = np.maximum(x1[i], x1[idxs[1:]])
        yy1 = np.maximum(y1[i], y1[idxs[1:]])
        xx2 = np.minimum(x2[i], x2[idxs[1:]])
        yy2 = np.minimum(y2[i], y2[idxs[1:]])
        w = np.maximum(0, xx2 - xx1 + 1)
        h = np.maximum(0, yy2 - yy1 + 1)
        inter = w * h
        overlap = inter / (areas[i] + areas[idxs[1:]] - inter)
        # keep only the boxes that do not overlap box i too much
        idxs = idxs[1:][overlap < overlap_thresh]
    return boxes[keep]

In [None]:
def template_matching(reflectance, template, method, match_thresh=0.7, nms_thresh=0.3):
    """
    Find every location where ``template`` matches ``reflectance`` and de-duplicate them.

    A location counts as a match when its score is >= match_thresh. Use a method
    where higher scores mean better matches (e.g. cv2.TM_CCOEFF_NORMED); the
    TM_SQDIFF* methods, where lower is better, are not supported by this test.

    Args:
        reflectance: Input reflectance image (numpy array).
        template: 2-D template image (numpy array), no larger than ``reflectance``.
        method: OpenCV template matching method, e.g. cv2.TM_CCOEFF_NORMED.
        match_thresh: Minimum score for a match (default is 0.7).
        nms_thresh: Overlap threshold for non-maximum suppression (default is 0.3).

    Returns:
        np.ndarray: (K, 4) integer array of (x, y, w, h) match boxes. It is an
        empty (0, 4) array when nothing matches, so ``.tolist()`` always works.
    """
    # === Template matching ===
    result = cv2.matchTemplate(reflectance, template, method)
    hit_mask = result >= match_thresh
    ys, xs = np.nonzero(hit_mask)  # row-major order, same as result[hit_mask]
    scores = result[hit_mask]

    # === Bounding boxes (x, y, w, h) for each match ===
    h, w = template.shape
    boxes = [(int(x), int(y), w, h) for x, y in zip(xs, ys)]

    # === Apply NMS ===
    nms_boxes = non_max_suppression_fast(boxes, scores, overlap_thresh=nms_thresh)

    # NMS may hand back a plain [] when nothing matched; normalise to (K, 4)
    return np.asarray(nms_boxes, dtype=int).reshape(-1, 4)

In [None]:
def contours_from_patch(patch):
    """
    Binarize a patch with Otsu's threshold and return its outer contours.

    Args:
        patch: Single-channel image (numpy array), presumably 8-bit as Otsu
            thresholding expects.

    Returns:
        Sequence of contours as returned by cv2.findContours (RETR_EXTERNAL,
        CHAIN_APPROX_SIMPLE).
    """
    otsu_flags = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    _otsu_level, mask = cv2.threshold(patch, 0, 255, otsu_flags)
    found, _hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return found

In [None]:
def hu_descriptor(patch):
    """
    Shape descriptor of a patch: log-scaled Hu moments of its largest contour.

    Args:
        patch: Input patch (numpy array), binarized by contours_from_patch.

    Returns:
        np.ndarray: 7 values, -sign(hu) * log10(|hu| + 1e-10), or all zeros when
        the patch has no contour.
    """
    outlines = contours_from_patch(patch)
    if not outlines:
        return np.zeros(7)

    biggest = max(outlines, key=cv2.contourArea)
    hu = cv2.HuMoments(cv2.moments(biggest)).flatten()
    # Hu moments span many orders of magnitude: compare them on a signed log scale
    return -np.sign(hu) * np.log10(np.abs(hu) + 1e-10)

def select_diverse_templates(candidates, k):
    """
    Pick up to k mutually dissimilar candidates by greedy farthest-point sampling.

    The first candidate is always taken. Each further pick is the remaining
    candidate whose descriptor is farthest (Euclidean) from its nearest
    already-selected candidate; ties go to the earliest candidate.

    Args:
        candidates: Sequence of (descriptor, patch, bbox) tuples. Descriptors must
            support subtraction and np.linalg.norm (e.g. numpy arrays such as
            hu_descriptor output).
        k: Maximum number of candidates to select.

    Returns:
        tuple[list, list]: The selected patches and their bounding boxes, in
        selection order. Both lists are empty when there are no candidates or k <= 0.
    """
    if len(candidates) == 0 or k <= 0:
        return [], []

    selected = [0]
    # distance from each candidate to its nearest selected candidate
    nearest = [np.linalg.norm(c[0] - candidates[0][0]) for c in candidates]
    remaining = list(range(1, len(candidates)))

    while remaining and len(selected) < k:
        best = max(remaining, key=nearest.__getitem__)
        selected.append(best)
        remaining.remove(best)
        # only the new pick can shrink the nearest-selected distances
        best_desc = candidates[best][0]
        for i in remaining:
            nearest[i] = min(nearest[i], np.linalg.norm(candidates[i][0] - best_desc))

    # return patches and their bounding boxes separately
    return [candidates[i][1] for i in selected], [candidates[i][2] for i in selected]

def remove_outlier_boxes(boxes, eps=50, min_samples=3):
    """
    Keep only the boxes whose centres belong to the largest DBSCAN cluster.

    Args:
        boxes: (N, 4) array-like of bounding boxes (x, y, width, height).
        eps: DBSCAN neighbourhood radius, in the same units as the boxes
            (pixels), measured between box centres.
        min_samples: DBSCAN minimum neighbourhood size for a core point.

    Returns:
        list[list]: The retained boxes. If DBSCAN labels every box as noise, all
        boxes are returned unchanged. An empty input yields [].
    """
    boxes = np.asarray(boxes)
    if boxes.size == 0:
        return []  # nothing to cluster (DBSCAN rejects 0 samples)
    boxes = boxes.reshape(-1, 4)

    centers = np.column_stack((boxes[:, 0] + boxes[:, 2] / 2,
                               boxes[:, 1] + boxes[:, 3] / 2))

    clustering = DBSCAN(eps=eps, min_samples=min_samples).fit(centers)
    labels = clustering.labels_

    # keep only boxes from the largest cluster (label -1 marks noise)
    valid_labels = labels[labels >= 0]
    if len(valid_labels) == 0:
        return boxes.tolist() # fallback: keep all
    main_cluster = np.argmax(np.bincount(valid_labels))
    keep_indices = np.where(labels == main_cluster)[0]

    return boxes[keep_indices].tolist()

def iou(box1, box2):
    """
    Intersection over Union of two (x, y, width, height) boxes.

    Edges are treated as continuous coordinates (no +1 pixel convention).

    Args:
        box1: First bounding box (x, y, width, height).
        box2: Second bounding box (x, y, width, height).

    Returns:
        IoU value, or 0 when the union area is zero.
    """
    ax, ay, aw, ah = box1
    bx, by, bw, bh = box2

    overlap_w = max(0, min(ax + aw, bx + bw) - max(ax, bx))
    overlap_h = max(0, min(ay + ah, by + bh) - max(ay, by))
    intersection = overlap_w * overlap_h

    union = aw * ah + bw * bh - intersection
    if union == 0:
        return 0
    return intersection / union

def cascade_template_matching(reflectance, template_bounds, method, match_thresh=0.9, nms_thresh=0.3, debug=False, max_templates=256):
    """
    Grow a set of dot boxes by repeatedly using every new match as a template.

    Each round works as follows:
    1. Every not-yet-used box is cut out of ``reflectance`` and matched against
       the whole image.
    2. Matches that do not overlap a known box (IoU > nms_thresh) are added.
    3. The known boxes are de-duplicated with NMS.
    4. Boxes that do not overlap an already-used template seed the next round.

    The cascade stops when no unused templates remain or ``max_templates`` boxes
    are known. Boxes outside the main spatial cluster are then dropped with
    ``remove_outlier_boxes``.

    Args:
        reflectance: 2-D reflectance image (numpy array) searched for matches.
        template_bounds: Seed boxes (x, y, width, height), e.g. from
            unet_get_template. The caller's list is not modified.
        method: OpenCV matching method passed to template_matching
            (e.g. cv2.TM_CCOEFF_NORMED).
        match_thresh: Minimum match score for a location to count as a match.
        nms_thresh: IoU above which two boxes are treated as the same dot.
        debug: If True, print progress and plot each round (used templates in
            green, others in red).
        max_templates: Stop the cascade once this many boxes are known (default 256).

    Returns:
        list[list]: Final (x, y, width, height) boxes; [] when no seeds are given.
    """
    # Work on a private copy: extending the caller's list in place would leak
    # this run's matches back into `template_bounds`.
    templates = [list(box) for box in template_bounds]
    if not templates:
        print("No seed templates given, skipping cascade.")
        return []

    unused_templates = list(templates)  # boxes not yet used as a template
    used_templates = []                 # boxes already used as a template
    debug_snapshots = []                # per-round state, only filled when debug=True

    count = 0
    while unused_templates:
        if len(templates) >= max_templates:  # bound the total matching work
            print("Maximum number of templates reached, stopping cascade.")
            break

        count += 1
        if debug:
            print(f"Cascade {count}, templates found: {len(templates)}, unused: {len(unused_templates)}, used: {len(used_templates)}")

        for box in unused_templates:
            box = np.asarray(box).tolist()
            used_templates.append(box)

            # the template is cut out of the reflectance image itself
            x, y, side1, side2 = (int(v) for v in box)
            patch = reflectance[y:y + side2, x:x + side1]

            # reshape guards against an empty (list) result when nothing matched
            matches = np.asarray(template_matching(reflectance, patch, method, match_thresh, nms_thresh)).reshape(-1, 4).tolist()

            # only add matches that are not already known boxes
            matches = [m for m in matches if not any(iou(m, known) > nms_thresh for known in templates)]
            templates.extend(matches)

        # merge duplicates accumulated during this round
        reduced = list(non_max_suppression_fast(np.array(templates), [1] * len(templates), overlap_thresh=nms_thresh))
        templates = [t.tolist() for t in reduced]

        # boxes that do not overlap any used template seed the next round
        unused_templates = [t for t in templates if not any(iou(t, used) > nms_thresh for used in used_templates)]

        if debug:
            debug_snapshots.append({
                'templates': templates.copy(),
                'used_templates': used_templates.copy(),
                'count': count
            })

    if debug and debug_snapshots:
        n = len(debug_snapshots)
        fig, axes = plt.subplots(1, n, figsize=(5 * n, 5), squeeze=False)
        for ax, snapshot in zip(axes[0], debug_snapshots):
            ax.imshow(reflectance, cmap='gray')
            for box in snapshot['templates']:
                x, y, side1, side2 = box
                color = 'green' if box in snapshot['used_templates'] else 'red'
                ax.add_patch(plt.Rectangle((x, y), side1, side2, edgecolor=color, facecolor='none', lw=2))
            ax.set_title(f"Cascade {snapshot['count']}")
            ax.axis('off')
        plt.tight_layout()
        plt.show()

    return remove_outlier_boxes(np.array(templates))

In [None]:
def display_image(image, size=(300, 300)):
    """
    Show a BGR image inline in the notebook at a fixed size.

    Args:
        image: 3-channel BGR image (numpy array), as cv2.COLOR_BGR2RGB requires.
        size: (width, height) passed to cv2.resize (default is (300, 300)).

    Returns:
        None
    """
    rgb_resized = cv2.resize(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), size)
    # NOTE(review): `display` is not imported in this file; it is presumably the
    # IPython/Jupyter builtin, so this only works inside a notebook kernel.
    display(Image.fromarray(rgb_resized))

In [None]:
def display_yucheng_methods(nms_boxes, reflectance, img, illumination, template_A, template_B):
    """
    Plot the inputs and results of Yucheng's reflectance-based template matching.

    Layout (2x3): original image, estimated illumination, reflectance map;
    template A, template B, and the reflectance with the matched boxes in green.

    Args:
        nms_boxes: Iterable of (x, y, w, h) boxes in reflectance pixel coordinates.
        reflectance: Single-channel reflectance map (e.g. the uint8 map from
            unet_get_template).
        img: Image shown as "Original Image" (e.g. the YOLO crop as (H, W, C)).
        illumination: Illumination estimate, shown in grayscale.
        template_A: First template; 2-D grayscale, or 3-channel BGR.
        template_B: Second template, shown in grayscale.

    Returns:
        None
    """
    # === Draw matching result ===
    output = cv2.cvtColor(reflectance, cv2.COLOR_GRAY2BGR)
    for box in nms_boxes:
        x, y, w, h = (int(v) for v in box)  # cv2.rectangle needs Python ints
        cv2.rectangle(output, (x, y), (x + w, y + h), (0, 255, 0), 2)

    # === Show results ===
    fig, axs = plt.subplots(2, 3, figsize=(10, 6))
    axs[0, 0].imshow(img, cmap='gray')
    axs[0, 0].set_title("Original Image")
    axs[0, 0].axis("off")

    axs[0, 1].imshow(illumination, cmap='gray')
    axs[0, 1].set_title("Estimated Illumination")
    axs[0, 1].axis("off")

    axs[0, 2].imshow(reflectance, cmap='gray')
    axs[0, 2].set_title("Reflectance Map (SSR)")
    axs[0, 2].axis("off")

    # Templates from extract_templates_watershed are 2-D grayscale, which
    # cv2.COLOR_BGR2RGB rejects; only colour-convert genuine 3-channel input.
    if template_A.ndim == 3:
        axs[1, 0].imshow(cv2.cvtColor(template_A, cv2.COLOR_BGR2RGB))
    else:
        axs[1, 0].imshow(template_A, cmap='gray')
    axs[1, 0].set_title("Template A")
    axs[1, 0].axis("off")

    axs[1, 1].imshow(template_B, cmap='gray')
    axs[1, 1].set_title("Template B")
    axs[1, 1].axis("off")

    axs[1, 2].imshow(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
    axs[1, 2].set_title("Template matching")
    axs[1, 2].axis("off")

    plt.tight_layout()
    plt.show()

## Yucheng Use

In [None]:
# === load image ===
img_to_test = '../data/MAN/raw/train/1D1165212740006.jpeg'
image = load_image_to_device(img_to_test)

# === get yolo crop ===
yolo_path = '../yolo/runs/obb/train7/weights/best.pt'
pad = 0.01 # fraction of the image's longest side added as padding on each side of the YOLO box (0.01 = 1%)
yolo_model = load_yolo(yolo_path)
image_yolo = yolo_detect_and_crop(image, yolo_model, pad)

# === get image and unet template ===
unet_path = '../models/dot_detection/checkpoints/unet_best.pth'
unet_model = load_unet(unet_path)
reflectance, illumination, templates, template_bounds, heatmap_np = unet_get_template(image_yolo, unet_model, k_templates=3, debug=True)

# Template matching needs at least one seed template.
if not templates:
    raise RuntimeError("No templates were extracted from the UNet heatmap; cannot run template matching.")

# === cascade template matching ===
nms_boxes = cascade_template_matching(reflectance, template_bounds, cv2.TM_CCOEFF_NORMED, match_thresh=0.925, nms_thresh=0.2, debug=True)

# === visualize results ===
# Filtering can leave fewer than k_templates; reuse the first template if only one survived.
template_a = templates[0]
template_b = templates[1] if len(templates) > 1 else templates[0]
display_yucheng_methods(nms_boxes, reflectance, image_yolo.squeeze(0).permute(1, 2, 0).cpu().numpy(), illumination, template_a, template_b)

# Grid Fitting

In [None]:
def generate_grid(params, grid_size=20):
    """
    Build the ideal grid_size x grid_size lattice of DMC points in image coordinates.

    Integer lattice coordinates (i, j) are centred on the origin, scaled by
    (sx, sy), rotated by theta and translated to (x0, y0). Points are ordered
    with i varying slowest.

    Args:
        params: Sequence (x0, y0, sx, sy, theta): grid centre, per-axis spacing
            (each floored at 1.0), and rotation angle in radians.
        grid_size: Number of points along each side (default is 20).

    Returns:
        np.ndarray: (grid_size**2, 2) array of (x, y) points in the original
        coordinate system.
    """
    x0, y0, sx, sy, theta = params

    # floor the spacing at 1.0 so the grid never collapses
    sx, sy = max(sx, 1.0), max(sy, 1.0)

    lattice = np.array([[i, j] for i in range(grid_size) for j in range(grid_size)]).astype(float)
    lattice -= (grid_size - 1) / 2  # centre the lattice on (0, 0)

    scale = np.array([[sx, 0], [0, sy]])
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])

    points = np.dot(np.dot(lattice, scale), rotation)
    points += np.array([[x0, y0]])
    return points

def inverse_grid_transform(grid_pts, params, grid_size=20):
    """
    Map image-space points back to integer grid indices, undoing generate_grid.

    The translation, rotation and scaling applied by generate_grid are reversed. The grid is
    re-shifted so indices start at 0, each coordinate is rounded to the nearest integer, and
    the results are clipped to [0, grid_size - 1]. A warning is printed if clipping was needed.

    Args:
        grid_pts: (N, 2) array of (x, y) points in image coordinates (not modified)
        params: parameters used for the forward transform (x0, y0, sx, sy, theta)
        grid_size: number of points per grid side (default is 20)

    Returns:
        (N, 2) int array of grid indices, in the same (first, second) index order that
        generate_grid uses.
    """
    x0, y0, sx, sy, theta = params

    # force sx and sy to be minimum of 1.0 (same as in generate_grid)
    sx = max(sx, 1.0)
    sy = max(sy, 1.0)

    # undo translation (astype copies, so the caller's array is left untouched)
    grid_pts = grid_pts.astype(float)
    grid_pts -= np.array([[x0, y0]])

    # undo rotation: the transpose of the forward rotation matrix is its inverse
    rot_mat_inv = np.array([[np.cos(theta), np.sin(theta)],
                            [-np.sin(theta), np.cos(theta)]])
    grid_pts = np.dot(grid_pts, rot_mat_inv)

    # undo scaling
    grid_pts = grid_pts / np.array([sx, sy])

    # shift back so indices start at 0, then snap to the nearest integer index
    grid_pts += (grid_size - 1) / 2
    grid_pts = np.rint(grid_pts).astype(int)

    # clip to valid range (to avoid errors)
    if np.any(grid_pts < 0) or np.any(grid_pts >= grid_size):
        print("Warning: one or more grid points are out of bounds!")
        grid_pts = np.clip(grid_pts, 0, grid_size - 1)

    return grid_pts


def show_grid(img, grid_pts, title="Grid Points"):
    """
    Plot grid points as filled red dots drawn over a copy of an image.

    Args:
        img: image as a numpy array; single-channel images are converted to BGR first
            (the caller's array is never modified)
        grid_pts: iterable of (x, y) points; coordinates are truncated to ints for drawing
        title: title shown above the plot (default is "Grid Points")
    """
    is_gray = img.ndim == 2 or img.shape[2] == 1
    canvas = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) if is_gray else img.copy()

    red_bgr = (0, 0, 255)
    for point in grid_pts:
        center = (int(point[0]), int(point[1]))
        cv2.circle(canvas, center, 3, red_bgr, -1)

    plt.imshow(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
    plt.title(title)
    plt.axis("off")
    plt.show()

def show_grids(img, init_estimate, opt_center, opt_spacing, opt_theta, opt_all, mapped):
    """
    Show the grid after each fitting stage as six panels (2 rows x 3 columns).

    Every panel draws one point set as red dots on its own copy of the image.

    Args:
        img: input image (numpy array); single-channel images are converted to BGR first
        init_estimate: grid points from the initial parameter estimate
        opt_center: grid points after optimizing the centre
        opt_spacing: grid points after optimizing the spacing
        opt_theta: grid points after optimizing the rotation
        opt_all: grid points after optimizing all parameters together
        mapped: grid points chosen by the observed-to-grid mapping
    """
    if img.ndim == 2 or img.shape[2] == 1:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    _, axs = plt.subplots(2, 3, figsize=(10, 6))

    panels = (
        (init_estimate, "Initial Estimate"),
        (opt_center, "Optimized Center"),
        (opt_spacing, "Optimized Spacing"),
        (opt_theta, "Optimized Theta"),
        (opt_all, "Optimized All"),
        (mapped, "Mapped Points"),
    )
    # axs.flat walks the subplot grid row by row, matching the order of `panels`
    for ax, (points, panel_title) in zip(axs.flat, panels):
        canvas = img.copy()
        for point in points:
            cv2.circle(canvas, (int(point[0]), int(point[1])), 3, (0, 0, 255), -1)  # red in BGR
        ax.imshow(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
        ax.set_title(panel_title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()


# === Example usage of grid generation ===
# Draw a grid spanning the whole image: centred on the image, with spacing
# img_size / (grid_size - 1), so the outermost points land on the image borders.
img = reflectance
grid_size = 20

# params to cover entire image
# NOTE(review): img.shape[0] (the height) is used for both axes, which assumes a square image — confirm.
img_size = img.shape[0]
params = [
    img_size / 2, # x0
    img_size / 2, # y0
    img_size / (grid_size-1), # sx
    img_size / (grid_size-1), # sy
    0.0 # theta
]
grid_pts = generate_grid(params, grid_size)
show_grid(img, grid_pts)

In [None]:
def estimate_grid_params(nms_boxes, img, debug=False):
    """
    Estimate starting grid parameters from the detected dot boxes.

    Each box (x, y, w, h) is reduced to its centre point. The grid centre (x0, y0) is the
    mean of those points.

    Spacing (sx, sy) and rotation (theta) come from "L shapes". An L shape is a point plus
    its two nearest neighbours, with any of the three taken as the corner. It is kept only
    if its arms meet at 80-100 degrees and differ in length by at most 1% of img.shape[0].

    If no L shape is found, the modal nearest-neighbour distance is used for both spacings
    and theta is 0.

    Args:
        nms_boxes: iterable of (x, y, w, h) boxes.
        img: image the boxes were found in (numpy array). img.shape[0] sets the arm-length
            tolerance. In debug mode the L shapes are drawn on a copy of it.
        debug: if True, print intermediate values and plot every L shape found
            (green = used for the estimate, red = discarded).

    Returns:
        ([x0, y0, sx, sy, theta], observed_pts), where observed_pts is the (N, 2) array
        of box centres and theta is in radians.

    Raises:
        ValueError: if fewer than two boxes are given, since no spacing can be measured.
    """
    observed_pts = np.array(
        [[x + w / 2, y + h / 2] for (x, y, w, h) in nms_boxes], dtype=float
    ).reshape(-1, 2)
    n_pts = len(observed_pts)
    if n_pts < 2:
        raise ValueError(f"At least 2 detected points are needed to estimate grid parameters, got {n_pts}.")

    # estimating reasonable x, y based on the average of all points
    x = np.mean(observed_pts[:, 0])
    y = np.mean(observed_pts[:, 1])

    # all pairwise distances; inf on the diagonal so a point is never its own neighbour
    pair_dists = cdist(observed_pts, observed_pts)
    np.fill_diagonal(pair_dists, np.inf)

    # modal nearest-neighbour distance (the mode, not the mean, so outliers do not skew it);
    # it is the fallback spacing when no L shape is found
    mod_dist = get_mode(list(np.round(pair_dists.min(axis=1), 2)))
    if debug:
        print(f"Modal distance between (close) points: {mod_dist}")

    arm_tol = img.shape[0] * 0.01  # arm lengths may differ by up to 1% of the image height
    all_L_shapes = set()
    # with fewer than 3 points there is no second neighbour to build an L from
    if n_pts >= 3:
        # stable sort: among equally distant neighbours the lower index comes first
        neighbour_order = np.argsort(pair_dists, axis=1, kind='stable')
        for i in range(n_pts):
            p_i = observed_pts[i]
            p_a = observed_pts[neighbour_order[i, 0]]  # closest neighbour
            p_b = observed_pts[neighbour_order[i, 1]]  # second-closest neighbour

            # each of the three points can be the corner: (arm end, corner, arm end)
            for L in ((p_a, p_i, p_b), (p_i, p_a, p_b), (p_i, p_b, p_a)):
                vec_a = L[0] - L[1]
                vec_b = L[2] - L[1]
                len_a = np.linalg.norm(vec_a)
                len_b = np.linalg.norm(vec_b)
                if len_a < 1e-8 or len_b < 1e-8:
                    # coincident points have no direction; the +1e-8 below would turn
                    # their zero dot product into a fake 90 degree angle
                    continue

                cos_angle = np.dot(vec_a, vec_b) / (len_a * len_b + 1e-8)
                angle_deg = np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0)))
                if not 80 <= angle_deg <= 100:
                    continue
                if not abs(len_a - len_b) <= arm_tol:
                    continue

                # the arm with the smaller |dx| is the more vertical one
                if abs(vec_a[0]) < abs(vec_b[0]):
                    vertical_dist, horizontal_dist = len_a, len_b
                else:
                    vertical_dist, horizontal_dist = len_b, len_a

                # theta for which generate_grid maps the grid's first axis onto vec_a.
                # Rounded to 2 decimals so equal orientations collapse for the mode;
                # "+ 0.0" normalises -0.0 to 0.0.
                orientation = -np.round(np.arctan2(vec_a[1], vec_a[0]), 2) + 0.0

                all_L_shapes.add((tuple(L[0].tolist()), tuple(L[1].tolist()), tuple(L[2].tolist()),
                                  horizontal_dist, vertical_dist, orientation))

    # return default parameters if no L shapes found
    if not all_L_shapes:
        print("No valid L shapes found, using default parameters.")
        return [x, y, mod_dist, mod_dist, 0.0], observed_pts

    # Modal spacings after truncating to int. Keep shapes whose horizontal OR vertical
    # spacing is no larger than the mode (presumably to drop L shapes spanning a missing dot).
    # Shapes at the modal horizontal spacing always pass, so the list is never empty.
    modal_x = get_mode([int(L[3]) for L in all_L_shapes])
    modal_y = get_mode([int(L[4]) for L in all_L_shapes])
    L_shapes = [L for L in all_L_shapes if int(L[3]) <= modal_x or int(L[4]) <= modal_y]

    # keep only shapes whose orientation, rounded to whole degrees, equals the modal one
    def to_degrees(rad):
        return int(np.round(rad * 180 / np.pi))

    modal_theta = get_mode([to_degrees(L[5]) for L in L_shapes])
    L_shapes = [L for L in L_shapes if to_degrees(L[5]) == modal_theta]

    # estimate sx, sy, theta based on the average of the remaining L shapes
    sx = np.mean([L[3] for L in L_shapes])  # horizontal distance
    sy = np.mean([L[4] for L in L_shapes])  # vertical distance
    theta = np.mean([L[5] for L in L_shapes])  # orientation in radians

    if debug:
        # visualize the L shapes found
        if img.ndim == 2 or img.shape[2] == 1:
            img_copy = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        else:
            img_copy = img.copy()
        kept = set(L_shapes)
        for L in all_L_shapes:
            point_a, corner, point_b = ((int(p[0]), int(p[1])) for p in L[:3])
            color = (0, 255, 0) if L in kept else (0, 0, 255)  # green = used, red = discarded
            cv2.line(img_copy, point_a, corner, color, 2)
            cv2.line(img_copy, point_b, corner, color, 2)
            cv2.line(img_copy, point_a, point_b, color, 2)
        plt.imshow(cv2.cvtColor(img_copy, cv2.COLOR_BGR2RGB))
        plt.axis("off")
        plt.show()

        print(f"Estimated parameters: x={x}, y={y}, sx={sx}, sy={sy}, theta={theta}")

    return [x, y, sx, sy, theta], observed_pts

# === Example usage of grid parameter estimation ===
# Estimate starting parameters from the matched boxes and draw the resulting grid.
img = reflectance
# copies the boxes into a plain list of (x, y, w, h) tuples
nms_boxes = [(x, y, w, h) for (x, y, w, h) in nms_boxes]
init_params, observed_pts = estimate_grid_params(nms_boxes, reflectance, debug=True)
grid_pts = generate_grid(init_params)
show_grid(img, grid_pts)

In [None]:
def plot_convex_hull(observed_points, grid_points=None):
    """
    Plot observed points and the outline of their convex hull, optionally overlaying grid points.

    Args:
        observed_points: (N, 2) numpy array of points
        grid_points: optional (M, 2) numpy array of grid points, drawn as red crosses
    """
    hull = ConvexHull(observed_points)
    xs = observed_points[:, 0]
    ys = observed_points[:, 1]

    plt.figure(figsize=(8, 8))
    plt.plot(xs, ys, 'o', label='Observed Points')

    # in 2-D each hull simplex is one edge, given as a pair of point indices
    for edge in hull.simplices:
        plt.plot(xs[edge], ys[edge], 'k-')

    if grid_points is not None:
        plt.plot(grid_points[:, 0], grid_points[:, 1], 'rx', label='Grid Points')

    plt.title("Observed Points and Convex Hull")
    plt.legend()
    plt.axis('equal')
    plt.grid(True)
    plt.show()

In [None]:
def hungarian_cost(params, observed_pts, grid_size=20):
    """
    Cost of a grid fit: the minimum total Euclidean distance of a one-to-one assignment
    (Hungarian algorithm) between observed points and the generated grid points.

    Args:
        params: grid parameters (x0, y0, sx, sy, theta), as used by generate_grid
        observed_pts: (N, 2) array of observed points
        grid_size: number of grid points per side (default is 20, generate_grid's default)

    Returns:
        float: total distance of the optimal assignment, or inf if the parameters produce
        non-finite grid coordinates (e.g. NaN proposed by the optimizer).
    """
    grid_pts = generate_grid(params, grid_size)
    cost_matrix = cdist(observed_pts, grid_pts, metric='euclidean')
    # linear_sum_assignment raises ValueError on NaN entries, so invalid parameters have to
    # be rejected before the assignment, not after it
    if not np.isfinite(cost_matrix).all():
        return float('inf')
    row_ind, col_ind = linear_sum_assignment(cost_matrix)
    return cost_matrix[row_ind, col_ind].sum()

# === Example usage of cost function ===
# Compare the assignment cost of a naive full-image grid with that of the estimated start.
img = reflectance
nms_boxes = [(x, y, w, h) for (x, y, w, h) in nms_boxes]

# cost of grid covering entire image
grid_size = 20
params = [
    img.shape[0] / 2, # x0
    img.shape[0] / 2, # y0
    img.shape[0] / (grid_size-1), # sx
    img.shape[0] / (grid_size-1), # sy
    0.0 # theta
]
grid_pts = generate_grid(params)
show_grid(img, grid_pts)
# box centres, computed the same way estimate_grid_params does
observed_pts = np.array([[x + w / 2, y + h / 2] for (x, y, w, h) in nms_boxes])
cost_value = hungarian_cost(params, observed_pts)
print(f"Cost value (naive start): {cost_value}")

# cost of the grid from the L-shape based estimate (expected to be lower)
init_params, observed_pts = estimate_grid_params(nms_boxes, reflectance, debug=False)
grid_pts = generate_grid(init_params)
show_grid(img, grid_pts)
cost_value = hungarian_cost(init_params, observed_pts)
print(f"Cost value (estimated start): {cost_value}")

In [None]:
def cost_xy(xy, params, observed_pts):
    """
    Hungarian cost as a function of the grid centre only; used to optimize x0 and y0.

    Args:
        xy: (x0, y0), the grid centre being optimized
        params: fixed grid parameters (sx, sy, theta)
        observed_pts: (N, 2) array of observed points

    Returns:
        Total Euclidean distance of the optimal one-to-one assignment between the observed
        points and the grid points (the value of hungarian_cost).
    """
    x0, y0 = xy
    sx, sy, theta = params
    return hungarian_cost([x0, y0, sx, sy, theta], observed_pts)

def cost_sx_sy(sx_sy, params, observed_pts):
    """
    Hungarian cost as a function of the grid spacing only; used to optimize sx and sy.

    Args:
        sx_sy: (sx, sy), the spacings along the two grid axes being optimized
        params: fixed grid parameters (x0, y0, theta)
        observed_pts: (N, 2) array of observed points

    Returns:
        Total Euclidean distance of the optimal one-to-one assignment between the observed
        points and the grid points (the value of hungarian_cost).
    """
    sx, sy = sx_sy
    x0, y0, theta = params
    return hungarian_cost([x0, y0, sx, sy, theta], observed_pts)

def cost_theta(theta, params, observed_pts):
    """
    Hungarian cost as a function of the grid rotation only; used to optimize theta.

    Args:
        theta: rotation angle in radians, either a scalar or the 1-element array that
            scipy.optimize.minimize passes
        params: fixed grid parameters (x0, y0, sx, sy)
        observed_pts: (N, 2) array of observed points

    Returns:
        Total Euclidean distance of the optimal one-to-one assignment between the observed
        points and the grid points (the value of hungarian_cost).
    """
    theta = np.ravel(theta)[0]  # accept both a scalar and a length-1 array
    x0, y0, sx, sy = params
    return hungarian_cost([x0, y0, sx, sy, theta], observed_pts)

In [None]:
def estimate_grid(init_params, observed_pts, debug=False):
    """
    Refine grid parameters so the generated grid fits the observed points.

    The parameters (x0, y0, sx, sy, theta) are optimized with Powell's method to minimize
    hungarian_cost, in this order:
    1. Optimize x0 and y0
    2. Optimize sx and sy
    3. Optimize theta, then evaluate the other three quarter-turn rotations of the result
       and keep the cheapest of the four
    4. Optimize all parameters together

    Args:
        init_params: starting parameters [x0, y0, sx, sy, theta] (e.g. from estimate_grid_params)
        observed_pts: (N, 2) array of observed points
        debug: if True, print the parameters after every stage and also return the grid
            generated at each stage (default is False)

    Returns:
        debug=False: (params, observed_pts)
        debug=True:  (params, observed_pts, init_pts, opt_center, opt_spacing, opt_theta, opt_all),
            where each extra item is the generate_grid output for the parameters before the
            first stage and after stages 1-4 respectively.
        params is the list [x0, y0, sx, sy, theta].
    """
    x0, y0, sx, sy, theta = init_params
    if debug:
        init_pts = generate_grid([x0, y0, sx, sy, theta])

    # 1. optimize for x0, y0 only
    result = minimize(cost_xy, [x0, y0], args=([sx, sy, theta], observed_pts), method='Powell')
    x0, y0 = result.x
    if debug:
        print([x0, y0, sx, sy, theta])
        opt_center = generate_grid([x0, y0, sx, sy, theta])

    # 2. optimize for sx, sy only
    result = minimize(cost_sx_sy, [sx, sy], args=([x0, y0, theta], observed_pts), method='Powell')
    sx, sy = result.x
    if debug:
        print([x0, y0, sx, sy, theta])
        opt_spacing = generate_grid([x0, y0, sx, sy, theta])

    # 3. optimize for theta only
    result = minimize(cost_theta, [theta], args=([x0, y0, sx, sy], observed_pts), method='Powell')
    base_theta = result.x[0]
    theta = base_theta
    lowest_cost = hungarian_cost([x0, y0, sx, sy, theta], observed_pts)
    # Try the other three quarter turns of the optimized angle. Offsets are taken from the
    # fixed base_theta: offsetting from the running best would skip some rotations (e.g.
    # base_theta + pi) after an improvement is found.
    for quarter_turns in range(1, 4):
        theta_tmp = (base_theta + quarter_turns * np.pi / 2) % (2 * np.pi)
        theta_tmp_cost = hungarian_cost([x0, y0, sx, sy, theta_tmp], observed_pts)
        if theta_tmp_cost < lowest_cost:
            if debug:
                print(f"Found better theta: {theta_tmp} with cost: {theta_tmp_cost}. (previous: {theta} with cost: {lowest_cost})")
            lowest_cost = theta_tmp_cost
            theta = theta_tmp
    if debug:
        print([x0, y0, sx, sy, theta])
        opt_theta = generate_grid([x0, y0, sx, sy, theta])

    # 4. optimize for all parameters together
    result = minimize(hungarian_cost, [x0, y0, sx, sy, theta], args=(observed_pts,), method='Powell')
    x0, y0, sx, sy, theta = result.x
    if debug:
        print([x0, y0, sx, sy, theta])
        opt_all = generate_grid([x0, y0, sx, sy, theta])
        return [x0, y0, sx, sy, theta], observed_pts, init_pts, opt_center, opt_spacing, opt_theta, opt_all

    return [x0, y0, sx, sy, theta], observed_pts

# === Example usage of grid estimation ===
# NOTE: the second value returned by estimate_grid_params is the array of box centres
# (observed points), even though it is named init_pts here. It is passed to estimate_grid
# as observed_pts, and init_pts is then overwritten with the initial grid returned in debug mode.
init_params, init_pts = estimate_grid_params(nms_boxes, reflectance, debug=False)
print(f"Initial parameters: {init_params}")
grid_pts = generate_grid(init_params)
show_grid(img, grid_pts)
opt_params, observed_pts, init_pts, opt_center, opt_spacing, opt_theta, opt_all = estimate_grid(init_params, init_pts, debug=True)
grid_pts = generate_grid(opt_params)
show_grid(img, grid_pts, title="Optimized Grid Points")

In [None]:
# showing different grid points with different parameters
# Draws two example grids (different centre, spacing and rotation) on a blank canvas the size
# of `img`, using `img` and `grid_size` from earlier cells.
empty_img = np.zeros_like(img)

# alter values to be gray
# NOTE(review): 200 is a light grey for uint8 images; if img is a float image this value is
# outside [0, 1] — confirm the dtype of img.
empty_img.fill(200)  # Fill with gray color

print(empty_img.shape)

# centred grid at 1/1.5 of the full-image spacing, no rotation
example_params = [
    img.shape[0] / 2,  # x0
    img.shape[0] / 2,  # y0
    (img.shape[0] / (grid_size - 1)) / 1.5,  # sx
    (img.shape[0] / (grid_size - 1)) / 1.5,  # sy
    0.0  # theta
]
example_grid_pts = generate_grid(example_params, grid_size)
print(example_params)
show_grid(empty_img, example_grid_pts, title="")

# off-centre grid at 1/2.5 of the full-image spacing, rotated by 1 radian
example_params = [
    img.shape[0] / 1.5,  # x0
    img.shape[0] / 2,  # y0
    (img.shape[0] / (grid_size - 1)) / 2.5,  # sx
    (img.shape[0] / (grid_size - 1)) / 2.5,  # sy
    1.0  # theta
]
example_grid_pts = generate_grid(example_params, grid_size)
print(example_params)
show_grid(empty_img, example_grid_pts, title="")

In [None]:
def map_observed_to_grid(observed_pts, grid_pts):
    """
    Pair each observed point with a distinct grid point so that the summed Euclidean
    distance is minimal (Hungarian algorithm), and return the chosen grid points.

    Args:
        observed_pts: (N, 2) numpy array of observed points
        grid_pts: (M, 2) numpy array of candidate grid points

    Returns:
        numpy array of the grid points selected by the assignment, ordered by the observed
        point each was paired with (when N <= M, row k belongs to observed point k).
    """
    distances = cdist(observed_pts, grid_pts, metric='euclidean')
    _, chosen_cols = linear_sum_assignment(distances)
    return grid_pts[chosen_cols]

# Map each observed point to its own grid point (1-1), using the "Optimized All" grid from
# the previous cell, then show every fitting stage side by side.
mapped_pts = map_observed_to_grid(observed_pts, opt_all)
show_grids(img, init_pts, opt_center, opt_spacing, opt_theta, opt_all, mapped_pts)

# Decoding

In [None]:
def display_DMC(matrix, title="DMC Matrix"):
    """
    Display a DMC module matrix the way a printed code looks: 1 = dark module, 0 = light.

    Args:
        matrix: 2-D numpy array of 0/1 (or boolean) module values. None (what to_dmc_matrix
            returns on failure) is reported and nothing is drawn.
        title: Title of the plot (default is "DMC Matrix")
    """
    if matrix is None:
        print("No DMC matrix to display.")
        return
    # A fixed 0..1 range with the reversed grey map draws 1 as black and 0 as white. The old
    # autoscaled np.invert image rendered an all-0 or all-1 matrix in an arbitrary single colour.
    plt.imshow(np.asarray(matrix, dtype=float), cmap='gray_r', vmin=0, vmax=1, interpolation='nearest')
    plt.title(title)
    plt.axis("off")
    plt.show()

In [None]:
def _drop_empty_lines(dmc_matrix):
    """Return dmc_matrix without its all-zero rows and columns."""
    dmc_matrix = dmc_matrix[np.any(dmc_matrix, axis=1)]
    return dmc_matrix[:, np.any(dmc_matrix, axis=0)]


def _drop_lone_edge_modules(dmc_matrix):
    """
    Remove border lines holding exactly one 1, treating that lone module as a stray detection.

    Borders are checked in the order top, left, bottom, right, each on the result of the
    previous check.
    """
    if dmc_matrix[0, :].sum() == 1:
        dmc_matrix = dmc_matrix[1:, :]
    if dmc_matrix[:, 0].sum() == 1:
        dmc_matrix = dmc_matrix[:, 1:]
    if dmc_matrix[-1, :].sum() == 1:
        dmc_matrix = dmc_matrix[:-1, :]
    if dmc_matrix[:, -1].sum() == 1:
        dmc_matrix = dmc_matrix[:, :-1]
    return dmc_matrix


def _fill_finder_pattern(dmc_matrix):
    """
    Fill, in place, the two adjacent border lines with the most 1s (the "L" finder pattern).

    A corner wins only if its count is strictly greater than every other corner's. On a tie,
    or if bottom-left is best, "bottom left" is used.

    Returns:
        The winning corner: "bottom right", "top left", "top right" or "bottom left".
    """
    top = dmc_matrix[0, :].sum()
    left = dmc_matrix[:, 0].sum()
    bottom = dmc_matrix[-1, :].sum()
    right = dmc_matrix[:, -1].sum()
    bl, br, tl, tr = bottom + left, bottom + right, top + left, top + right

    if br > bl and br > tl and br > tr:
        orientation = "bottom right"
    elif tl > bl and tl > br and tl > tr:
        orientation = "top left"
    elif tr > bl and tr > br and tr > tl:
        orientation = "top right"
    else:
        orientation = "bottom left"

    finder_row = -1 if orientation.startswith("bottom") else 0
    finder_col = -1 if orientation.endswith("right") else 0
    dmc_matrix[finder_row, :] = 1
    dmc_matrix[:, finder_col] = 1
    return orientation


def _fix_timing_pattern(dmc_matrix):
    """
    Rewrite, in place, the timing patterns of a matrix whose finder pattern is at the bottom-left.

    Top row, columns 1..w-2: even columns become 1 and odd columns 0. A 1 found at an odd
    column is moved one row down instead of being discarded.

    Right column, rows 1..h-2: odd rows become 1 and even rows 0. A 1 found at an even row
    is moved one column left.
    """
    # NOTE(review): the top-right corner (0, -1) is never rewritten; confirm whether it should
    # be forced to 0 (light) for these symbols.
    height, width = dmc_matrix.shape
    for col in range(1, width - 1):
        if col % 2 == 0:
            dmc_matrix[0, col] = 1
            continue
        if dmc_matrix[0, col] == 1:  # unexpected 1: shift it down a row
            dmc_matrix[1, col] = 1
        dmc_matrix[0, col] = 0
    for row in range(1, height - 1):
        if row % 2 == 1:
            dmc_matrix[row, -1] = 1
            continue
        if dmc_matrix[row, -1] == 1:  # unexpected 1: shift it left a column
            dmc_matrix[row, -2] = 1
        dmc_matrix[row, -1] = 0


def _show_dmc_stages(stages):
    """Plot each (title, matrix) stage, inverted and in grayscale, on a 2x3 grid."""
    plt.figure(figsize=(12, 8))
    for idx, (stage_title, stage) in enumerate(stages, start=1):
        plt.subplot(2, 3, idx)
        plt.imshow(np.invert(stage), cmap='gray')
        plt.title(stage_title)
        plt.axis("off")
    plt.tight_layout()
    plt.show()


def to_dmc_matrix(dmc_pts, grid_size=20, debug=False, min_points=50):
    """
    Convert grid indices of detected dots into a DMC module matrix in standard orientation.

    The steps are:
    1. Mark every point in a grid_size x grid_size matrix.
    2. Drop empty rows and columns.
    3. Drop border lines that hold a single stray 1.
    4. Fill the "L" finder pattern on the two adjacent borders with the most 1s.
    5. Rotate so the finder pattern is at the bottom-left.
    6. Rewrite the timing patterns on the top row and right column.

    Args:
        dmc_pts: sequence of integer (x, y) grid indices (e.g. from inverse_grid_transform);
            x selects the column and y the row
        grid_size: side length of the index grid (default is 20)
        debug: if True, plot every intermediate stage
        min_points: minimum number of points needed (default is 50); fewer returns None

    Returns:
        2-D int numpy array of 0/1 modules, or None if there are too few points.
    """
    if len(dmc_pts) < min_points:
        print("Not enough points to create DMC matrix.")
        return None

    pts = np.asarray(dmc_pts).reshape(-1, 2)
    dmc_matrix = np.zeros((grid_size, grid_size), dtype=int)
    dmc_matrix[pts[:, 1], pts[:, 0]] = 1
    # copies of every stage for the debug plot (cheap: at most grid_size x grid_size each)
    stages = [("Direct Mapping", dmc_matrix.copy())]

    dmc_matrix = _drop_empty_lines(dmc_matrix)
    stages.append(("Zero Delete", dmc_matrix.copy()))

    dmc_matrix = _drop_lone_edge_modules(dmc_matrix)
    stages.append(("One Delete", dmc_matrix.copy()))

    orientation = _fill_finder_pattern(dmc_matrix)
    stages.append(("Finder Fill", dmc_matrix.copy()))

    # np.rot90 k=1 is 90 deg anticlockwise, k=3 clockwise; each turns its corner to bottom-left
    quarter_turns = {"bottom left": 0, "bottom right": 3, "top left": 1, "top right": 2}[orientation]
    if quarter_turns:
        dmc_matrix = np.rot90(dmc_matrix, k=quarter_turns)
    stages.append(("Rotated", dmc_matrix.copy()))

    print(f"Grid size: {dmc_matrix.shape[1]}x{dmc_matrix.shape[0]}")
    _fix_timing_pattern(dmc_matrix)
    stages.append(("Timing Fill", dmc_matrix.copy()))

    if debug:
        _show_dmc_stages(stages)

    return dmc_matrix

# === Example usage of mapping to DMC points ===
# Full grid-fitting chain for the current detections:
# estimate start -> optimize -> map observed points to grid points
# -> grid indices -> module matrix.
img = reflectance
nms_boxes = [(x, y, w, h) for (x, y, w, h) in nms_boxes]
# the second return value is the array of box centres (observed points)
init_params, init_pts = estimate_grid_params(nms_boxes, reflectance)

opt_params, observed_pts = estimate_grid(init_params, init_pts, debug=False)
grid_pts = generate_grid(opt_params)
mapped_grid_pts = map_observed_to_grid(observed_pts, grid_pts)
dmc_pts = inverse_grid_transform(mapped_grid_pts, opt_params)
dmc_matrix = to_dmc_matrix(dmc_pts, debug=True)
print("DMC matrix:")
print(dmc_matrix)

In [None]:
def decode_DMC(matrix):
    """
    Render a binary DMC module matrix as an image and decode it with pylibdmtx.

    Modules equal to 1 are drawn black on white. A 2-module white quiet zone is added around
    the symbol (a margin larger than one module, see
    https://www.keyence.eu/ss/products/auto_id/codereader/basic_2d/datamatrix.jsp). Every
    module is then scaled to 10x10 pixels before decoding.

    Args:
        matrix: 2-D numpy array of 0/1 module values

    Returns:
        The decoded payload as a str, or None if nothing was decoded or the payload is not
        valid UTF-8.
    """
    # 1 -> black (0), anything else -> white (255)
    modules = np.where(matrix == 1, 0, 255).astype(np.uint8)

    quiet_zone = 2
    padded = np.pad(modules, ((quiet_zone, quiet_zone), (quiet_zone, quiet_zone)), mode='constant', constant_values=255)
    symbol = Image.fromarray(padded, 'L')

    # nearest-neighbour upscaling keeps module edges sharp for the decoder
    scale = 10
    symbol = symbol.resize((symbol.size[0] * scale, symbol.size[1] * scale), Image.NEAREST)

    results = decode(symbol)
    if not results:
        return None
    try:
        return results[0].data.decode('utf-8')
    except UnicodeDecodeError:
        return None

# Example of Grid Fitting Followed by Decoding

In [None]:
# === Estimate initial grid parameters ===
# init_pts here holds the observed box centres returned by estimate_grid_params.
init_params, init_pts = estimate_grid_params(nms_boxes,  reflectance)
grid_pts = generate_grid(init_params)
show_grid(img, grid_pts)

In [None]:
# === Optimize grid parameters & map ===
# With debug=True estimate_grid also returns the grid after each stage (init_pts is
# overwritten with the initial grid), which show_grids plots next to the mapped points.
opt_params, observed_pts, init_pts, opt_center, opt_spacing, opt_theta, opt_all = estimate_grid(init_params, init_pts, debug=True)
grid_pts = generate_grid(opt_params)
mapped_grid_pts = map_observed_to_grid(observed_pts, grid_pts)
show_grids(img, init_pts, opt_center, opt_spacing, opt_theta, opt_all, mapped_grid_pts)

In [None]:
# === Convert to DMC matrix ===
# NOTE: to_dmc_matrix returns None when it receives fewer than its minimum number of points.
dmc_pts = inverse_grid_transform(mapped_grid_pts, opt_params)
dmc_matrix = to_dmc_matrix(dmc_pts)
display_DMC(dmc_matrix)

In [None]:
# === Decoding DMC ===
# prints the decoded string, or None if pylibdmtx could not read the reconstructed symbol
decoded_data = decode_DMC(dmc_matrix)
print(decoded_data)

# Full Decoding Pipeline

In [None]:
def _rotate_image_tensor(img, rotation):
    """
    Rotate a batched image tensor about its centre, enlarging the canvas so no corner is cut off.

    Args:
        img: tensor of shape (1, C, H, W), as produced by load_image_to_device
        rotation: angle in degrees, passed straight to cv2.getRotationMatrix2D

    Returns:
        float32 tensor of shape (1, C, new_H, new_W) on `device`.
    """
    _, _, h, w = img.shape
    center = (w / 2, h / 2)

    # rotation matrix
    rot_mat = cv2.getRotationMatrix2D(center, rotation, 1)

    # size of the rotated image's bounding box, so corners are not cut off
    cos = np.abs(rot_mat[0, 0])
    sin = np.abs(rot_mat[0, 1])
    new_w = int(h * sin + w * cos)
    new_h = int(h * cos + w * sin)

    # shift so the rotated image is centred in the enlarged canvas
    rot_mat[0, 2] += new_w / 2 - center[0]
    rot_mat[1, 2] += new_h / 2 - center[1]

    # temporarily convert tensor to cv2 image (H, W, C)
    img_np = img.squeeze(0).permute(1, 2, 0).cpu().numpy()
    rotated = cv2.warpAffine(img_np, rot_mat, (new_w, new_h))
    if rotated.ndim == 2:
        # cv2 returns a 2-D array for single-channel input; restore the channel axis
        rotated = rotated[:, :, np.newaxis]

    # converting back to a (1, C, H, W) tensor
    rotated = np.ascontiguousarray(np.transpose(rotated.astype(np.float32), (2, 0, 1)))
    return torch.from_numpy(rotated).unsqueeze(0).to(device)


def decode_pipeline(image_path, yolo_model, unet_model, yolo_pad, tm_method=cv2.TM_CCOEFF_NORMED, match_thresh=0.9, nms_thresh=0.3, k_tm_templates=3, debug=False, rotation=None, min_match_thresh=None):
    """
    Run the entire decoding pipeline on one image.

    The stages are: load, optional rotation, YOLO crop, UNet templates, cascade template
    matching, grid fitting, DMC matrix reconstruction and pylibdmtx decoding.

    Whenever template matching finds at most one box, the DMC matrix cannot be built, or
    decoding fails, match_thresh is lowered by 0.025. Everything from template matching
    onwards is then retried.

    Args:
        image_path: Path to the input image
        yolo_model: YOLO model for object detection
        unet_model: UNet model for template extraction
        yolo_pad: Padding for the YOLO crop
        tm_method: Method for template matching (default is cv2.TM_CCOEFF_NORMED)
        match_thresh: Starting threshold for template matching (default is 0.9)
        nms_thresh: Threshold for non-maximum suppression (default is 0.3)
        k_tm_templates: Number of UNet templates to extract (default is 3)
        debug: Flag for debugging plots and prints (default is False)
        rotation: Rotation angle in degrees applied before cropping (default is None, no rotation)
        min_match_thresh: Optional lower bound for match_thresh. Once the threshold drops below
            it, the pipeline gives up and returns None. With the default None there is no bound:
            the loop then only ends on success or when 256 or more matches are found.

    Returns:
        Decoded data from the DMC matrix, or None if decoding fails.
    """
    # === Load image ===
    img = load_image_to_device(image_path)

    # === Rotation ===
    if rotation is not None:
        img = _rotate_image_tensor(img, rotation)

    # === YOLO crop ===
    img_yolo = yolo_detect_and_crop(img, yolo_model, yolo_pad, debug=debug)

    # === UNet template ===
    reflectance, illumination, unet_templates, template_bounds, _heatmap = unet_get_template(img_yolo, unet_model, k_tm_templates, debug=debug)
    if reflectance is None or illumination is None or unet_templates is None or template_bounds is None or len(template_bounds) == 0:
        print("UNet failed to get templates, returning None...")
        return None

    thresh_step = 0.025  # how much match_thresh is lowered after each failed attempt

    # === Attempt decode with lower and lower thresholds until success or failure ===
    while True:
        if min_match_thresh is not None and match_thresh < min_match_thresh:
            print(f"Match threshold dropped below {min_match_thresh}, returning None...")
            return None

        # === Template matching ===
        nms_boxes = cascade_template_matching(reflectance, template_bounds, method=tm_method, match_thresh=match_thresh, nms_thresh=nms_thresh, debug=debug)
        if len(nms_boxes) <= 1:
            print("Template matching found no (or only 1) matches, lowering match threshold...")
            match_thresh -= thresh_step
            continue
        if len(nms_boxes) >= 256:
            print(f"Too many matches found ({len(nms_boxes)}), returning None...")
            return None

        if debug:
            # reuse the first template when only one was extracted
            second_template = unet_templates[1] if len(unet_templates) > 1 else unet_templates[0]
            display_yucheng_methods(nms_boxes, reflectance, img_yolo.squeeze(0).permute(1, 2, 0).cpu().numpy(), illumination, unet_templates[0], second_template)

        # === Estimate initial grid parameters ===
        init_params, init_pts = estimate_grid_params(nms_boxes, reflectance, debug)

        # === Optimizing Grid Parameters ===
        if debug:
            opt_params, observed_pts, init_pts, opt_center, opt_spacing, opt_theta, opt_all = estimate_grid(init_params, init_pts, debug)
        else:
            opt_params, observed_pts = estimate_grid(init_params, init_pts, debug)
        grid_pts = generate_grid(opt_params)

        # === Mapping observed points to grid points ===
        mapped_grid_pts = map_observed_to_grid(observed_pts, grid_pts)
        if debug:
            show_grids(reflectance, init_pts, opt_center, opt_spacing, opt_theta, opt_all, mapped_grid_pts)

        # === Convert to DMC matrix ===
        dmc_pts = inverse_grid_transform(mapped_grid_pts, opt_params)
        dmc_matrix = to_dmc_matrix(dmc_pts, debug=debug)
        if dmc_matrix is None:
            print("DMC matrix conversion failed, trying with lower match threshold...")
            match_thresh -= thresh_step
            continue

        # === Decoding DMC ===
        decoded_data = decode_DMC(dmc_matrix)
        if not isinstance(decoded_data, str):
            print("Decoding failed, trying with lower match threshold...")
            match_thresh -= thresh_step
            continue

        if debug:
            print(f"Decoded data: {decoded_data}")
        return decoded_data

# Run the whole pipeline, with debug plots, on a single training image.
img_to_test = '../data/MAN/raw/train/1D1165212740006.jpeg'

# === Load models ===
yolo_pad = 0.01 # % to pad the yolo crop by
yolo_model = load_yolo('../yolo/runs/obb/train7/weights/best.pt')
unet_model = load_unet('../models/dot_detection/checkpoints/unet_best.pth')

# === Template matching params ===
# match_thresh is only the starting value: decode_pipeline lowers it by 0.025 after each failed attempt.
tm_method = cv2.TM_CCOEFF_NORMED # method for template matching
match_thresh = 0.95 # threshold for template matching (higher = more strict matching)
nms_thresh = 0.2 # threshold for non-maximum suppression (higher = more overlap)
k_tm_templates = 3 # number of UNet templates to get (higher = more templates, but slower)

decoded_data = decode_pipeline(img_to_test, yolo_model, unet_model, yolo_pad, tm_method, match_thresh, nms_thresh, k_tm_templates, debug=True, rotation=None)

# Evaluating Pipeline

In [None]:
def _crop_to_bgr_uint8(crop):
    """
    Convert a cropped image tensor into an 8-bit BGR numpy array for pylibdmtx.

    Args:
        crop: Image tensor with a leading batch dimension of size 1, i.e.
            (1, C, H, W). Values are presumably in [0, 1], as produced by
            load_image_to_device — confirm for yolo_detect_and_crop.

    Returns:
        np.ndarray of shape (H, W, C) with dtype uint8, channels in BGR order.
    """
    img = crop.squeeze(0).permute(1, 2, 0).detach()
    # pylibdmtx casts non-uint8 arrays with astype('uint8'); that would
    # truncate [0, 1] floats to (almost) all zeros, so rescale to [0, 255].
    img = (img.clamp(0.0, 1.0) * 255.0).round().to(torch.uint8).cpu().numpy()
    # Match the BGR channel order that cv2.imread gives the plain baseline.
    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)


def evaluate_pipeline(image_paths, yolo_model, unet_model, yolo_pad, tm_method=cv2.TM_CCOEFF_NORMED, match_thresh=0.95, nms_thresh=0.2, k_tm_templates=3, img_idx=None, debug=False):
    """
    Evaluate the decoding pipeline against two pylibdmtx baselines and print success rates.

    For every image, three methods are tried:
      1. decode_pipeline (this notebook's method) - success if it returns a str;
      2. pylibdmtx on the full image as read by cv2.imread - success if the
         result list is non-empty;
      3. pylibdmtx on the YOLO crop of the image - success if the result list
         is non-empty.

    Args:
        image_paths: Iterable of paths to input images
        yolo_model: YOLO model for object detection
        unet_model: UNet model for template extraction
        yolo_pad: Padding for YOLO detection
        tm_method: Method for template matching (default is cv2.TM_CCOEFF_NORMED)
        match_thresh: Threshold for template matching (default is 0.95)
        nms_thresh: Threshold for non-maximum suppression (default is 0.2)
        k_tm_templates: Number of templates to use for initial template matching (default is 3)
        img_idx: Currently ignored; kept so existing calls keep working. (It was
            meant to restrict evaluation to one image by 1-based index, but
            that restriction is disabled.)
        debug: Flag for debugging (default is False)
    """
    image_paths = list(image_paths)
    n_images = len(image_paths)
    if n_images == 0:
        # Avoid dividing by zero when computing percentages below.
        print("Results for 0 images:")
        print("No images to evaluate.")
        return

    my_method = 0
    pylibdmtx_method = 0
    yolo_pylibdmtx_method = 0

    for image_path in image_paths:
        # 1. Full pipeline
        decoded_data = decode_pipeline(image_path, yolo_model, unet_model, yolo_pad, tm_method, match_thresh, nms_thresh, k_tm_templates, debug=debug)
        if isinstance(decoded_data, str):
            my_method += 1

        # 2. pylibdmtx on the raw image
        if decode(cv2.imread(image_path)):
            pylibdmtx_method += 1

        # 3. YOLO crop + pylibdmtx
        crop = yolo_detect_and_crop(load_image_to_device(image_path), yolo_model, yolo_pad, debug=debug)
        # NOTE(review): unverified whether the crop helper can return None (e.g. no
        # detection); if it does, count that image as a failure instead of crashing.
        if crop is not None and decode(_crop_to_bgr_uint8(crop)):
            yolo_pylibdmtx_method += 1

    print(f"Results for {n_images} images:")
    print(f"My method: {my_method} ({my_method / n_images * 100:.2f}%)")
    print(f"Pylibdmtx method: {pylibdmtx_method} ({pylibdmtx_method / n_images * 100:.2f}%)")
    print(f"YOLO + Pylibdmtx method: {yolo_pylibdmtx_method} ({yolo_pylibdmtx_method / n_images * 100:.2f}%)")

## Train

In [None]:
# Evaluate all methods on the training split.
image_dir = '../data/dot_detection/MAN/train-data/'
# os.listdir returns entries in arbitrary order; sort so the list order (which
# img_idx indexes, 1-based) is reproducible, and skip sub-directories, which
# none of the decoders can read.
image_paths = [
    os.path.join(image_dir, f)
    for f in sorted(os.listdir(image_dir))
    if os.path.isfile(os.path.join(image_dir, f))
]
img_idx = 66

evaluate_pipeline(image_paths, yolo_model, unet_model, yolo_pad, tm_method, match_thresh, nms_thresh, k_tm_templates, img_idx, debug=False)

## Validation

In [None]:
# Evaluate all methods on the validation split.
image_dir = '../data/dot_detection/MAN/val-data/'
# os.listdir returns entries in arbitrary order; sort so the list order (which
# img_idx indexes, 1-based) is reproducible, and skip sub-directories, which
# none of the decoders can read.
image_paths = [
    os.path.join(image_dir, f)
    for f in sorted(os.listdir(image_dir))
    if os.path.isfile(os.path.join(image_dir, f))
]
img_idx = 21

evaluate_pipeline(image_paths, yolo_model, unet_model, yolo_pad, tm_method, match_thresh, nms_thresh, k_tm_templates, img_idx, debug=False)

## Test

In [None]:
# Evaluate all methods on the test split.
image_dir = '../data/dot_detection/MAN/test-data/'
# os.listdir returns entries in arbitrary order; sort so the list order (which
# img_idx indexes, 1-based) is reproducible, and skip sub-directories, which
# none of the decoders can read.
image_paths = [
    os.path.join(image_dir, f)
    for f in sorted(os.listdir(image_dir))
    if os.path.isfile(os.path.join(image_dir, f))
]
img_idx = 36

evaluate_pipeline(image_paths, yolo_model, unet_model, yolo_pad, tm_method, match_thresh, nms_thresh, k_tm_templates, img_idx, debug=False)