# Setup
Thank you, Yucheng, for the dataloader, model, and image-preparation code.

## Imports

In [1]:
import glob
import os
import random

import cv2
import numpy as np
import pandas as pd
import torch
from PIL import Image, ImageDraw
from torchvision.transforms import v2 as T
from torchvision.transforms.functional import to_tensor, to_pil_image

import matplotlib.cm as cm
from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F

import csv

from tqdm import tqdm

import matplotlib.pyplot as plt

import json

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


## Image Preparing Definitions

In [2]:
def image_rotate(image, angle_deg):
    """Rotate `image` about its centre onto an enlarged, white-filled canvas.

    `angle_deg` is negated before it is handed to `cv2.getRotationMatrix2D`.
    OpenCV treats positive angles there as counter-clockwise, so a positive
    `angle_deg` turns the image clockwise (and a negative one
    counter-clockwise).

    Args:
        image: array whose first two dimensions are (height, width).
        angle_deg: rotation angle in degrees.

    Returns:
        (rotated, M): the warped image and the 2x3 affine matrix that was
        used, so the same mapping can be applied to point coordinates.
    """
    src_h, src_w = image.shape[:2]
    scale = 1.0
    theta = np.deg2rad(angle_deg)
    sin_term = abs(np.sin(theta) * scale)
    cos_term = abs(np.cos(theta) * scale)

    # Size of the axis-aligned box around the rotated rectangle, truncated
    # to whole pixels.
    dst_w = int(src_h * sin_term + src_w * cos_term)
    dst_h = int(src_w * sin_term + src_h * cos_term)

    M = cv2.getRotationMatrix2D((src_w / 2.0, src_h / 2.0), -angle_deg, scale)
    # Shift the result so the source centre lands on the new canvas centre.
    M[0, 2] += (dst_w - src_w) / 2
    M[1, 2] += (dst_h - src_h) / 2

    rotated = cv2.warpAffine(
        image,
        M,
        (dst_w, dst_h),
        flags=cv2.INTER_LINEAR,
        borderValue=(255, 255, 255),
    )
    return rotated, M


def split_and_save_patches(image, dots, output_dir, base_index, patch_size=256, t_size=False):
    """Tile `image` into full square patches and save each with its dots.

    Only complete `patch_size` x `patch_size` tiles are written; partial
    tiles at the right and bottom edges are skipped. Each saved tile
    produces ``img_<index>.png`` (converted to grayscale) and
    ``img_<index>.csv`` in `output_dir`, where ``<index>`` is `base_index`
    plus the number of tiles saved so far, zero-padded to five digits.

    The CSV holds one row with `t_size`, then the ``x,y`` header, then one
    row per dot that falls inside the tile, expressed relative to the
    tile's top-left corner.

    Args:
        image: colour image array (it is passed through BGR->gray conversion).
        dots: iterable of (x, y) pairs in `image` coordinates.
        output_dir: destination directory; created if missing.
        base_index: index given to the first saved tile.
        patch_size: tile side length in pixels.
        t_size: value written verbatim as the first CSV row.

    Returns:
        Number of tiles saved.
    """
    os.makedirs(output_dir, exist_ok=True)
    img_h, img_w = image.shape[:2]
    saved = 0

    # Step over origins that still leave room for a whole tile.
    for top in range(0, img_h - patch_size + 1, patch_size):
        bottom = top + patch_size
        for left in range(0, img_w - patch_size + 1, patch_size):
            right = left + patch_size
            tile = image[top:bottom, left:right]

            tile_dots = [
                (x - left, y - top)
                for x, y in dots
                if left <= x < right and top <= y < bottom
            ]

            stem = os.path.join(output_dir, f"img_{base_index + saved:05d}")
            cv2.imwrite(stem + ".png", cv2.cvtColor(tile, cv2.COLOR_BGR2GRAY))

            with open(stem + ".csv", 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow([t_size])
                writer.writerow(["x", "y"])
                writer.writerows(tile_dots)

            saved += 1

    return saved


def parse_and_save_DSBI(image_path, output_dir, base_index, patch_size=256):
    """Read a DSBI image and its annotations, rotate it, and save patches.

    Annotations are looked up next to the image as ``<stem>+recto.txt`` and
    ``<stem>+verso.txt``. Each file is expected to contain an angle on the
    first line, whitespace-separated vertical line positions on the second,
    horizontal line positions on the third, and then one line per cell:
    ``row col f1 .. f6`` with 1-based row/col and 0/1 dot flags. Files that
    are missing or shorter than three lines are ignored.

    Args:
        image_path: path of the image file.
        output_dir: directory that receives the patches.
        base_index: index given to the first patch saved by this call.
        patch_size: patch side length in pixels.

    Returns:
        Number of patches saved; 0 when the image cannot be read or no
        usable annotation file exists.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports an unreadable or missing file by returning None.
        print(f"Error: Could not read image {image_path}. Skipping.")
        return 0

    stem = os.path.splitext(image_path)[0]
    all_coords = []
    annotated = False

    for txt_path in (stem + "+recto.txt", stem + "+verso.txt"):
        if not os.path.exists(txt_path):
            continue

        with open(txt_path, 'r') as f:
            lines = f.readlines()

        if len(lines) < 3:
            continue

        angle = float(lines[0].strip())
        verticals = list(map(int, lines[1].split()))
        horizontals = list(map(int, lines[2].split()))

        # NOTE(review): each annotation file rotates whatever `image`
        # currently is, so when both recto and verso files exist the second
        # rotation is applied on top of the first, while coordinates already
        # collected from the first file are not re-transformed. Confirm this
        # is intended.
        image, transform = image_rotate(image, -angle)
        annotated = True

        for cell in lines[3:]:
            parts = list(map(int, cell.split()))
            if not parts:
                continue  # tolerate blank lines in the annotation file
            r, c = parts[0] - 1, parts[1] - 1
            for i, val in enumerate(parts[2:]):
                if val != 1:
                    continue
                # Flags 0-2 sit on the cell's first vertical line, flags 3+
                # on its second; each cell spans three horizontal lines.
                if i < 3:
                    y = horizontals[r * 3 + i]
                    x = verticals[c * 2]
                else:
                    y = horizontals[r * 3 + i - 3]
                    x = verticals[c * 2 + 1]
                # cv2.transform expects an array of shape (N, 1, 2).
                coord = np.array([[[x, y]]], dtype=np.float32)
                rotated_coord = cv2.transform(coord, transform)[0][0]
                all_coords.append(rotated_coord.tolist())

    if not annotated:
        return 0

    return split_and_save_patches(image, all_coords, output_dir, base_index, patch_size)


def collect_images(root_dir):
    """Walk `root_dir` and return the paths of all image files found.

    A file counts as an image when its lower-cased name ends with one of the
    known image extensions. Files whose name contains ``+recto`` or
    ``+verso`` are left out. Paths are returned in `os.walk` order.
    """
    image_suffixes = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
    excluded_markers = ("+recto", "+verso")
    return [
        os.path.join(folder, entry)
        for folder, _subdirs, entries in os.walk(root_dir)
        for entry in entries
        if entry.lower().endswith(image_suffixes)
        and not any(marker in entry for marker in excluded_markers)
    ]

## Model Definition

In [3]:
class DoubleConv(nn.Module):
    """Two 3x3 convolutions in a row, each followed by BatchNorm and ReLU.

    Both convolutions use padding 1 and no bias, so the spatial size of the
    input is preserved; only the channel count changes
    (`in_channels` -> `out_channels` -> `out_channels`).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps the width.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to `x`."""
        return self.conv(x)


class Down(nn.Module):
    """Encoder stage: 2x2 max-pool followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        halve = nn.MaxPool2d(2)
        refine = DoubleConv(in_channels, out_channels)
        self.down = nn.Sequential(halve, refine)

    def forward(self, x):
        """Pool `x`, then run it through the convolution block."""
        return self.down(x)


class Up(nn.Module):
    """Decoder stage: upsample, pad to the skip tensor, concatenate, DoubleConv.

    Args:
        in_channels: channel count of the concatenated tensor fed to the
            convolution block.
        out_channels: channel count produced by the convolution block.
        bilinear: use bilinear `nn.Upsample` when true, otherwise a
            stride-2 `nn.ConvTranspose2d` over ``in_channels // 2`` channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        self.up = (
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            if bilinear
            else nn.ConvTranspose2d(half, half, 2, stride=2)
        )
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample `x1`, align it with skip tensor `x2`, and fuse them."""
        upsampled = self.up(x1)

        # Spread any size difference over both sides so the upsampled map
        # matches the skip tensor's height and width.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        # Skip channels come first, upsampled channels second.
        return self.conv(torch.cat([x2, upsampled], dim=1))


class UNet(nn.Module):
    """Four-level U-Net that ends in a 1x1 convolution and a sigmoid.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map.
        base_c: channel width of the first encoder level; deeper levels use
            2x, 4x and 8x this width.
    """

    def __init__(self, in_channels=1, out_channels=1, base_c=64):
        super().__init__()
        c1, c2, c4, c8 = (base_c * factor for factor in (1, 2, 4, 8))

        # Encoder path.
        self.in_conv = DoubleConv(in_channels, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c4)
        self.down3 = Down(c4, c8)
        self.down4 = Down(c8, c8)

        # Decoder path: each stage receives the upsampled features
        # concatenated with the matching encoder output.
        self.up1 = Up(c8 + c8, c4)
        self.up2 = Up(c4 + c4, c2)
        self.up3 = Up(c2 + c2, c1)
        self.up4 = Up(c1 + c1, c1)

        self.out_conv = nn.Conv2d(c1, out_channels, kernel_size=1)

    def forward(self, x):
        """Return the sigmoid-activated output map for input `x`."""
        # Collect encoder outputs; the last one is the bottleneck.
        skips = [self.in_conv(x)]
        for encoder in (self.down1, self.down2, self.down3, self.down4):
            skips.append(encoder(skips[-1]))

        x = skips.pop()
        # Walk back up, consuming the skips from deepest to shallowest.
        for decoder in (self.up1, self.up2, self.up3, self.up4):
            x = decoder(x, skips.pop())

        return torch.sigmoid(self.out_conv(x))

# Preparing Images

## Final Saving Format
Files are saved under data/dot_detection/prepared-patches.

Images/labels matched by filename.

Filename format is img_00000.png / img_00000.csv (five-digit, zero-padded index).

Image size is 512×512 pixels.

.csv format is:
- 25 <-- template size, either a number or False
- x,y <-- always there
- 0.3494873046875,22.6949462890625 <-- x,y coord (float due to reduced pixels)
- 22.349365234375,22.6181640625 <-- x,y coord (float due to reduced pixels)
- ...

## DSBI
### Format
Filepathing:
- data
  - Fundamentals of Massage
  - Massage
  - Math
  - Ordinary Printed Document
  - Shaver Yang Fengting
  - The Second Volume of Ninth Grade Chinese Book 1
  - The Second Volume of Ninth Grade Chinese Book 2

Different types of data are stored in different directories and are later processed into the same output directory. <br>
Images and labels are stored in the same directories and matched by filename.

example .txt data with explanation:
- 0.80 <-- angle (which gets corrected for during preprocessing)
- 47 67 94 114 142 162 189 209 236 256... <-- verticals
- 47 67 87 125 145 165 203 223 243 281... <-- horizontals
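- 4 6 1 0 1 0 0 0 <-- cell row, cell column (both 1-based), then six 0/1 dot flags: the first three lie on the cell's first vertical line, the last three on its second, each set running through the cell's three horizontal lines in order
- 4 7 1 1 1 0 0 1 <-- same layout: row 4, column 7, with flags 1, 2, 3 and 6 set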
- ...

In [4]:
# Input dataset root and the directory that receives the generated patches.
root_directory = "../data/dot_detection/DSBI-master"
output_directory = "../data/dot_detection/prepared-patches"
# Patch side length in pixels; the MAN patch cell further down reads this
# same variable.
patch_size = 512

image_list = collect_images(root_directory)
print(f"Found {len(image_list)} images to process.")

os.makedirs(output_directory, exist_ok=True)

# patch_counter is passed as base_index and advanced by the number of patches
# each call reports, so patch file indices continue across images.
patch_counter = 0
for path in sorted(image_list):
    print(f"Processing {path}...")
    n = parse_and_save_DSBI(path, output_directory, base_index=patch_counter, patch_size=patch_size)
    patch_counter += n

Found 115 images to process.
Processing ../data/dot_detection/DSBI-master\data\Fundamentals of Massage\FM+1.jpg...
Processing ../data/dot_detection/DSBI-master\data\Fundamentals of Massage\FM+10.jpg...
Processing ../data/dot_detection/DSBI-master\data\Fundamentals of Massage\FM+11.jpg...
Processing ../data/dot_detection/DSBI-master\data\Fundamentals of Massage\FM+12.jpg...
Processing ../data/dot_detection/DSBI-master\data\Fundamentals of Massage\FM+13.jpg...
Processing ../data/dot_detection/DSBI-master\data\Fundamentals of Massage\FM+14.jpg...
Processing ../data/dot_detection/DSBI-master\data\Fundamentals of Massage\FM+15.jpg...
Processing ../data/dot_detection/DSBI-master\data\Fundamentals of Massage\FM+16.jpg...
Processing ../data/dot_detection/DSBI-master\data\Fundamentals of Massage\FM+17.jpg...
Processing ../data/dot_detection/DSBI-master\data\Fundamentals of Massage\FM+18.jpg...
Processing ../data/dot_detection/DSBI-master\data\Fundamentals of Massage\FM+19.jpg...
Processing ../d

## MAN

In [5]:
def parse_and_save_MAN(image_path, output_dir):
    """Crop a MAN image around its class-2 points and save crop and points.

    The label file is looked up under the same stem with a ``.txt``
    extension, in the directory obtained by replacing ``images`` with
    ``labels`` in the image's directory path. Only the first line of the
    label file is used: its first five fields are skipped and the rest is
    read as ``x y class`` triples with x/y relative to the image size
    (0-1). Points of class 2 are kept.

    The image is cropped to a square centred on the mean of the kept points,
    with a half-side of 1.1 times the distance to the farthest point
    (clipped to the image bounds). The crop is written to
    ``<output_dir>/images/<stem>.png`` and the points, shifted into crop
    coordinates, to ``<output_dir>/labels/<stem>.csv`` under an ``x,y``
    header. Both directories are created when missing.

    Images that cannot be read, have no class-2 points, or produce an empty
    crop are reported and skipped.

    Args:
        image_path: path of the source image.
        output_dir: root directory for the ``images`` and ``labels`` outputs.

    Returns:
        None.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports an unreadable or missing file by returning None.
        print(f"Error: Could not read image {image_path}. Skipping.")
        return

    # Swap only the directory part and the real extension, so a file name
    # that happens to contain 'images' or '.jpg' is left untouched.
    image_dir, image_file = os.path.split(image_path)
    stem = os.path.splitext(image_file)[0]
    label_path = os.path.join(image_dir.replace('images', 'labels'), stem + '.txt')

    with open(label_path, 'r') as f:
        first_line = f.readline()

    # first 5 elements are bounding box related (skip)
    triples = first_line.split()[5:]

    img_h, img_w = image.shape[:2]
    points = []
    # Stop before a trailing incomplete triple instead of indexing past it.
    for i in range(0, len(triples) - 2, 3):
        cls = int(triples[i + 2])
        if cls == 2:  # only keep class 2
            # convert relative (0-1) coords to pixel coords
            points.append([float(triples[i]) * img_w, float(triples[i + 1]) * img_h])

    if not points:
        print(f"Error: No class 2 points found in {label_path}. Skipping.")
        return

    points = np.array(points)
    center = points.mean(axis=0)
    # Distance from the centre to the farthest kept point.
    max_dist = np.linalg.norm(points - center, axis=1).max()

    # Half-side of the crop square: farthest distance plus 10% margin.
    half_side = max_dist * 1.1
    x1 = max(0, int(center[0] - half_side))
    y1 = max(0, int(center[1] - half_side))
    x2 = min(img_w, int(center[0] + half_side))
    y2 = min(img_h, int(center[1] + half_side))
    img_cropped = image[y1:y2, x1:x2]
    if img_cropped.size == 0:
        # Happens e.g. when a single point gives a zero-sized square.
        print(f"Error: Empty crop for {image_path}. Skipping.")
        return
    points = points - np.array([x1, y1])

    # save cropped image and coordinates
    img_dir = os.path.join(output_dir, 'images')
    csv_dir = os.path.join(output_dir, 'labels')
    os.makedirs(img_dir, exist_ok=True)
    os.makedirs(csv_dir, exist_ok=True)

    img_path = os.path.join(img_dir, stem + ".png")
    if not cv2.imwrite(img_path, img_cropped):
        print(f"Error: Could not write {img_path}. Skipping.")
        return

    with open(os.path.join(csv_dir, stem + ".csv"), 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(["x", "y"])
        writer.writerows(points)


# Root of the MAN export that is walked for image files.
root_directory = "../data/dot_detection/MAN/roboflow"

image_list = collect_images(root_directory)
print(f"Found {len(image_list)} images to process.")

# parse_and_save_MAN writes into the images/ and labels/ subfolders of this
# directory.
output_directory = "../data/dot_detection/MAN/cropped"

for path in sorted(image_list):
    print(f"Processing {path}...")
    parse_and_save_MAN(path, output_directory)

Found 9 images to process.
Processing ../data/dot_detection/MAN/roboflow\images\1D1165212740006_jpeg.rf.d16389429a55a2c9ab5f6d02e8173fde.jpg...
Processing ../data/dot_detection/MAN/roboflow\images\1D1165212740007_jpeg.rf.26033cb881d8bd1c105f6cc48b455b82.jpg...
Processing ../data/dot_detection/MAN/roboflow\images\20231031_053432000_iOS_png.rf.4fd0ae5ea53655b08749a956753656aa.jpg...
Processing ../data/dot_detection/MAN/roboflow\images\20231113_024141652_iOS_jpg.rf.6c4fd508cbe29751ac4cb2042de3aaf5.jpg...
Processing ../data/dot_detection/MAN/roboflow\images\250115190140006_jpeg.rf.d3f59dc9b7bfc3e4fb7f7a25e8fff612.jpg...
Processing ../data/dot_detection/MAN/roboflow\images\250115190140012_jpeg.rf.86d7815fc6dbb08a22f3043e5d47a3ac.jpg...
Processing ../data/dot_detection/MAN/roboflow\images\G95-plunger-2_jpg.rf.cb1bf3703c0894c15072c90ec1c4b5b9.jpg...
Processing ../data/dot_detection/MAN/roboflow\images\IMG_1058_JPG.rf.5f34941fe887524a8cca104339e19771.jpg...
Processing ../data/dot_detection/MAN

In [6]:
# Cropped MAN data: images/, labels/ and templates/ share file stems.
root_dir = "../data/dot_detection/MAN/cropped"

image_list = collect_images(root_dir + '/images')
print(f"Found {len(image_list)} images to process.")

output_dir = "../data/dot_detection/MAN/prepared-patches"

patch_counter = 0
for path in sorted(image_list):
    print(f"Processing {path}...")

    # read template image to get patch size
    filename = os.path.basename(path)
    template_path = root_dir + '/templates/' + filename
    template = cv2.imread(template_path)
    if template is None:
        print(f"Error: Template image not found for {filename}. Skipping.")
        continue
    h, w = template.shape[:2]
    if h != w:
        print(f"Error: Template image is not square for {filename}. Skipping.")
        continue
    t_size = h

    img = cv2.imread(path) # read image
    if img is None:
        # cv2.imread returns None instead of raising on an unreadable file.
        print(f"Error: Could not read image {filename}. Skipping.")
        continue

    # Build the label path from the file stem (the same way the template is
    # located) rather than string-replacing over the whole path, which breaks
    # for non-.png images or names containing 'images'.
    label_path = root_dir + '/labels/' + os.path.splitext(filename)[0] + '.csv'
    coords = pd.read_csv(label_path).to_numpy().astype(np.float32)
    # TODO: try with lower patch_size since we only get 1 patch per image with 512x512
    # patch_size is the value defined in the DSBI preparation cell above.
    n = split_and_save_patches(img, coords, output_dir, patch_counter, patch_size, t_size)
    patch_counter += n

Found 9 images to process.
Processing ../data/dot_detection/MAN/cropped/images\1D1165212740006_jpeg.rf.d16389429a55a2c9ab5f6d02e8173fde.png...
Processing ../data/dot_detection/MAN/cropped/images\1D1165212740007_jpeg.rf.26033cb881d8bd1c105f6cc48b455b82.png...
Processing ../data/dot_detection/MAN/cropped/images\20231031_053432000_iOS_png.rf.4fd0ae5ea53655b08749a956753656aa.png...
Processing ../data/dot_detection/MAN/cropped/images\20231113_024141652_iOS_jpg.rf.6c4fd508cbe29751ac4cb2042de3aaf5.png...
Processing ../data/dot_detection/MAN/cropped/images\250115190140006_jpeg.rf.d3f59dc9b7bfc3e4fb7f7a25e8fff612.png...
Processing ../data/dot_detection/MAN/cropped/images\250115190140012_jpeg.rf.86d7815fc6dbb08a22f3043e5d47a3ac.png...
Processing ../data/dot_detection/MAN/cropped/images\G95-plunger-2_jpg.rf.cb1bf3703c0894c15072c90ec1c4b5b9.png...
Processing ../data/dot_detection/MAN/cropped/images\IMG_1058_JPG.rf.5f34941fe887524a8cca104339e19771.png...
Processing ../data/dot_detection/MAN/cropped

## Dataloader Definition

In [36]:
class BrailleDataset(torch.utils.data.Dataset):
    """Grayscale dot-image patches paired with Gaussian heatmap targets.

    Every ``*.png`` directly inside `root_dir` is one sample; its annotations
    are read from the CSV with the same stem. The CSV's first row holds the
    template size (an integer or the literal ``False``), the second row is
    the ``x,y`` header and the remaining rows are dot coordinates.

    The sorted file list is split at ``int(len * train_val_ratio)``: the
    leading part is used when `is_train` is true, the rest otherwise.
    Augmentation is applied only when `is_train` is true.

    Args:
        root_dir: directory containing the ``.png`` / ``.csv`` pairs.
        patch_size: side length of the random crop taken during training.
        transform_prob: probability of applying the colour augmentation chain.
        train_val_ratio: fraction of the sorted files assigned to training.
        is_train: selects the training part and enables augmentation.
    """

    # After a warp, dots whose pixel is at least this bright (image range
    # [0, 1]) are discarded; the warps fill uncovered regions with white.
    _WHITE_LEVEL = 0.99

    def __init__(self, root_dir, patch_size=224, transform_prob=0.1, train_val_ratio=0.8, is_train=True):
        self.root_dir = root_dir
        self.patch_size = patch_size
        self.sigma = 3
        self.transform_prob = transform_prob
        self.is_train = is_train

        all_paths = sorted(glob.glob(os.path.join(root_dir, "*.png")))
        split_idx = int(len(all_paths) * train_val_ratio)
        self.image_paths = all_paths[:split_idx] if is_train else all_paths[split_idx:]

        # light augmentation for grayscale tensor images using torchvision v2
        self.color_aug = T.Compose([
            T.RandomApply([
                T.ColorJitter(brightness=0.5, hue=0.3),
                T.RandomInvert(),
                T.RandomPosterize(bits=2),
                T.RandomSolarize(threshold=192.0 / 255.0),
                T.GaussianBlur(kernel_size=(3, 3))
            ], p=self.transform_prob)
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, heatmap)`` tensors, each of shape [1, H, W]."""
        img_path = self.image_paths[idx]
        # Close the file handle as soon as the pixels are converted.
        with Image.open(img_path) as src:
            img = src.convert("L")
        w, h = img.size

        csv_path = os.path.splitext(img_path)[0] + ".csv"
        t_size = self._read_template_size(csv_path)
        # skiprows=1 skips the template-size row so the x,y header is used.
        coords = pd.read_csv(csv_path, skiprows=1).to_numpy().astype(np.float32)
        img_tensor = to_tensor(img)  # shape: [1, H, W]

        if self.is_train:
            if random.random() < 0.5:
                img_tensor, coords = self._flip(img_tensor, coords, horizontal=True)

            if random.random() < 0.5:
                img_tensor, coords = self._flip(img_tensor, coords, horizontal=False)

            img_tensor, coords = self.apply_random_zoom(img_tensor, coords, zoom_range=(0.8, 1.2))
            img_tensor, coords = self.apply_random_crop(img_tensor, coords, crop_size=self.patch_size)

            angle = random.uniform(-45, 45)
            # Translation is drawn relative to the size of the loaded image.
            translate = (random.uniform(-0.1, 0.1) * w, random.uniform(-0.1, 0.1) * h)
            scale = random.uniform(0.9, 1.1)
            shear = random.uniform(-10, 10)
            img_tensor, coords = self.apply_affine(img_tensor, coords, angle, translate, scale, shear)

            if random.random() < 0.5:
                img_tensor, coords = self.apply_perspective(img_tensor, coords, distortion_scale=0.2)

            img_tensor = self.color_aug(img_tensor)

        heatmap = self.coords_to_heatmap(coords, img_tensor.shape[1:], sigma=self.sigma, t_size=t_size)
        return img_tensor, heatmap

    @staticmethod
    def _read_template_size(csv_path):
        """Return the template size stored in the first CSV row (int or False)."""
        with open(csv_path, 'r') as f:
            first_line = next(csv.reader(f))
        value = first_line[0]
        return False if value == 'False' else int(value)

    @staticmethod
    def _flip(img_tensor, coords, horizontal):
        """Mirror the image and its dot coordinates along one axis.

        `torch.flip` moves pixel index ``i`` to ``size - 1 - i``, so the
        coordinates are mapped the same way. `coords` is updated in place and
        also returned.
        """
        _, h, w = img_tensor.shape
        if horizontal:
            img_tensor = torch.flip(img_tensor, dims=[2])
            coords[:, 0] = (w - 1) - coords[:, 0]
        else:
            img_tensor = torch.flip(img_tensor, dims=[1])
            coords[:, 1] = (h - 1) - coords[:, 1]
        return img_tensor, coords

    @staticmethod
    def _to_uint8(img_tensor):
        """Convert a [1, H, W] tensor in [0, 1] to an H x W uint8 array."""
        return (img_tensor.squeeze(0).cpu().numpy() * 255).astype(np.uint8)

    @staticmethod
    def _to_unit_tensor(img_np):
        """Convert an H x W uint8 array to a [1, H, W] float tensor in [0, 1]."""
        return torch.from_numpy(img_np).unsqueeze(0).float() / 255.0

    @classmethod
    def _keep_visible(cls, img_tensor, coords):
        """Drop dots that fall outside the image or on a near-white pixel.

        Coordinates are truncated to integers for the lookup; the returned
        rows keep their original floating-point values.
        """
        _, h, w = img_tensor.shape
        cols = coords[:, 0].astype(int)
        rows = coords[:, 1].astype(int)
        inside = (cols >= 0) & (cols < w) & (rows >= 0) & (rows < h)

        keep = inside.copy()
        pixels = img_tensor[0].numpy()
        # Only index the image with coordinates known to be in bounds.
        keep[inside] = pixels[rows[inside], cols[inside]] < cls._WHITE_LEVEL
        return coords[keep]

    def apply_random_crop(self, img_tensor, coords, crop_size=256):
        """Take a random `crop_size` square and keep the dots inside it.

        The input is returned unchanged when its height or width is not
        larger than `crop_size`.
        """
        _, h, w = img_tensor.shape
        if h <= crop_size or w <= crop_size:
            return img_tensor, coords

        x0 = random.randint(0, w - crop_size)
        y0 = random.randint(0, h - crop_size)

        img_cropped = img_tensor[:, y0:y0 + crop_size, x0:x0 + crop_size]
        shifted = coords - np.array([x0, y0])
        inside = ((shifted >= 0) & (shifted < crop_size)).all(axis=1)

        return img_cropped, shifted[inside]

    def apply_affine(self, img_tensor, coords, angle=0, translate=(0, 0), scale=1.0, shear=0.0):
        """Warp image and dots with one rotation/scale/shear/translation matrix.

        Uncovered regions are filled with white; dots that leave the image or
        land on a near-white pixel are dropped.
        """
        img_np = self._to_uint8(img_tensor)
        h, w = img_np.shape

        M = cv2.getRotationMatrix2D((w / 2.0, h / 2.0), angle, scale)
        shear_mat = np.array([[1, -np.tan(np.deg2rad(shear))], [0, 1]])
        # Shear is applied to the linear part only, then the shift is added.
        M[:2, :2] = shear_mat @ M[:2, :2]
        M[:, 2] += translate

        img_warped = cv2.warpAffine(img_np, M, (w, h), borderValue=255)

        coords_hom = np.hstack([coords, np.ones((coords.shape[0], 1))])
        coords_warped = (M @ coords_hom.T).T

        img_tensor = self._to_unit_tensor(img_warped)
        return img_tensor, self._keep_visible(img_tensor, coords_warped)

    def apply_perspective(self, img_tensor, coords, distortion_scale=0.2):
        """Warp image and dots with a random perspective transform.

        Each image corner is moved by up to `distortion_scale` of the image
        width/height. Uncovered regions are filled with white; dots that
        leave the image or land on a near-white pixel are dropped.
        """
        img_np = self._to_uint8(img_tensor)
        h, w = img_np.shape

        def random_shift(pt):
            dx = np.random.uniform(-distortion_scale, distortion_scale) * w
            dy = np.random.uniform(-distortion_scale, distortion_scale) * h
            return [pt[0] + dx, pt[1] + dy]

        src_pts = np.array([[0, 0], [w, 0], [w, h], [0, h]], dtype=np.float32)
        dst_pts = np.array([random_shift(pt) for pt in src_pts], dtype=np.float32)

        M = cv2.getPerspectiveTransform(src_pts, dst_pts)
        warped_np = cv2.warpPerspective(img_np, M, (w, h), borderValue=255)

        coords_hom = np.hstack([coords, np.ones((coords.shape[0], 1))])
        coords_proj = (M @ coords_hom.T).T
        # Divide by the homogeneous coordinate to get back to pixel space.
        coords_proj /= coords_proj[:, [2]]
        coords_warped = coords_proj[:, :2]

        img_tensor = self._to_unit_tensor(warped_np)
        return img_tensor, self._keep_visible(img_tensor, coords_warped)

    def coords_to_heatmap(self, coords, img_size, sigma=2, t_size=False):
        """Render dots as Gaussian blobs with peak value 1.

        Args:
            coords: iterable of (x, y) dot positions.
            img_size: (H, W) of the heatmap to produce.
            sigma: Gaussian standard deviation in pixels.
            t_size: when truthy, `sigma` is multiplied by ``t_size / 15``.

        Returns:
            Float tensor of shape [1, H, W]; overlapping blobs are merged
            with an element-wise maximum.
        """
        H, W = img_size
        heatmap = torch.zeros((1, H, W), dtype=torch.float32)

        if t_size:
            sigma = sigma * (t_size / 15)
        # Blobs are only drawn within 3 sigma of their centre.
        tmp_size = int(3 * sigma)

        for x, y in coords:
            x = int(round(x))
            y = int(round(y))
            if x < 0 or y < 0 or x >= W or y >= H:
                continue

            x0 = max(0, x - tmp_size)
            x1 = min(W, x + tmp_size + 1)
            y0 = max(0, y - tmp_size)
            y1 = min(H, y + tmp_size + 1)

            yy, xx = torch.meshgrid(
                torch.arange(y0, y1, dtype=torch.float32),
                torch.arange(x0, x1, dtype=torch.float32),
                indexing='ij'
            )
            g = torch.exp(-((xx - x) ** 2 + (yy - y) ** 2) / (2 * sigma ** 2))
            g = g / g.max()

            heatmap[0, y0:y1, x0:x1] = torch.maximum(heatmap[0, y0:y1, x0:x1], g)

        return heatmap

    def apply_random_zoom(self, img_tensor, coords, zoom_range=(0.8, 1.2)):
        """Resize by a random factor while keeping the original canvas size.

        When zooming in, the centre of the enlarged image is cropped back to
        the original size; otherwise the shrunken image is centred on a white
        canvas. Dots that leave the image or land on a near-white pixel are
        dropped.
        """
        c, h, w = img_tensor.shape
        scale = random.uniform(*zoom_range)
        new_h = int(h * scale)
        new_w = int(w * scale)

        img_resized = cv2.resize(self._to_uint8(img_tensor), (new_w, new_h), interpolation=cv2.INTER_LINEAR)

        if scale > 1.0:
            x0 = (new_w - w) // 2
            y0 = (new_h - h) // 2
            img_zoomed = img_resized[y0:y0 + h, x0:x0 + w]
            offset = np.array([-x0, -y0])
        else:
            pad_left = (w - new_w) // 2
            pad_top = (h - new_h) // 2
            img_zoomed = np.full((h, w), 255, dtype=np.uint8)
            img_zoomed[pad_top:pad_top + new_h, pad_left:pad_left + new_w] = img_resized
            offset = np.array([pad_left, pad_top])

        coords = coords * scale + offset

        img_tensor = self._to_unit_tensor(img_zoomed)
        return img_tensor, self._keep_visible(img_tensor, coords)

# Dataloader Demo

In [8]:
# Output directory for debug images
output_dir = "../data/dot_detection/debug_output"
os.makedirs(output_dir, exist_ok=True)

# Create dataset and dataloader
dataset = BrailleDataset(root_dir="../data/dot_detection/prepared-patches", transform_prob=0.2)
loader = DataLoader(dataset, batch_size=1, shuffle=False)

# Counted separately from the loop index so the summary below also works when
# the loader yields nothing (the loop variable would then be undefined).
max_debug_images = 50
saved_count = 0

# Iterate and save debug patches
for idx, (img_tensor, heatmap_tensor) in enumerate(loader):
    img = to_pil_image(img_tensor[0])  # Convert image tensor to PIL

    # Convert heatmap to a colour image
    heatmap = heatmap_tensor[0, 0]  # Shape: H x W
    heatmap_np = heatmap.clamp(0, 1).numpy()
    heatmap_img = Image.fromarray((cm.inferno(heatmap_np)[:, :, :3] * 255).astype('uint8'))  # Apply colormap

    # Resize heatmap to match image size (safety)
    heatmap_img = heatmap_img.resize(img.size, resample=Image.BILINEAR)

    # Option A: Combine side by side
    combined = Image.new("RGB", (img.width * 2, img.height))
    combined.paste(img, (0, 0))
    combined.paste(heatmap_img, (img.width, 0))

    # Option B: Or overlay heatmap on image (if you prefer)
    # overlay = Image.blend(img.convert("RGB"), heatmap_img, alpha=0.5)

    combined.save(os.path.join(output_dir, f"debug_img_{idx:05d}.png"))
    saved_count += 1

    if saved_count >= max_debug_images:
        break  # Save first 50 only

print(f"Saved {saved_count} debug images to {output_dir}")

Saved 50 debug images to ../data/dot_detection/debug_output


In [37]:
# Output directory for debug images
output_dir = "../data/dot_detection/MAN/debug_output"
os.makedirs(output_dir, exist_ok=True)

# Create dataset and dataloader
dataset = BrailleDataset(root_dir="../data/dot_detection/MAN/prepared-patches", transform_prob=0.2)
loader = DataLoader(dataset, batch_size=1, shuffle=False)

# Counted separately from the loop index so the summary below also works when
# the loader yields nothing (the loop variable would then be undefined).
max_debug_images = 50
saved_count = 0

# Iterate and save debug patches
for idx, (img_tensor, heatmap_tensor) in enumerate(loader):
    img = to_pil_image(img_tensor[0])  # Convert image tensor to PIL

    # Convert heatmap to a colour image
    heatmap = heatmap_tensor[0, 0]  # Shape: H x W
    heatmap_np = heatmap.clamp(0, 1).numpy()
    heatmap_img = Image.fromarray((cm.inferno(heatmap_np)[:, :, :3] * 255).astype('uint8'))  # Apply colormap

    # Resize heatmap to match image size (safety)
    heatmap_img = heatmap_img.resize(img.size, resample=Image.BILINEAR)

    # Option A: Combine side by side
    combined = Image.new("RGB", (img.width * 2, img.height))
    combined.paste(img, (0, 0))
    combined.paste(heatmap_img, (img.width, 0))

    # Option B: Or overlay heatmap on image (if you prefer)
    # overlay = Image.blend(img.convert("RGB"), heatmap_img, alpha=0.5)

    combined.save(os.path.join(output_dir, f"debug_img_{idx:05d}.png"))
    saved_count += 1

    if saved_count >= max_debug_images:
        break  # Save first 50 only

print(f"Saved {saved_count} debug images to {output_dir}")

Saved 8 debug images to ../data/dot_detection/MAN/debug_output


# Training Model

In [10]:
# === Config ===
batch_size = 8
num_epochs = 30
lr = 1e-3
train_val_ratio = 0.8
save_dir = "../models/dot_detection/checkpoints"
os.makedirs(save_dir, exist_ok=True)

# === Load dataset ===
# Both datasets read the same directory; is_train selects which side of the
# train_val_ratio split is used and whether augmentation is applied.
train_dataset = BrailleDataset(root_dir="../data/dot_detection/prepared-patches", is_train=True, train_val_ratio=train_val_ratio)
val_dataset = BrailleDataset(root_dir="../data/dot_detection/prepared-patches", is_train=False, train_val_ratio=train_val_ratio)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

# === Model ===
model = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# L1 distance between the predicted map and the target heatmap.
criterion = nn.L1Loss()

# == Training loop ===
best_val_loss = float('inf')
for epoch in range(1, num_epochs + 1):
    model.train()
    train_loss = 0

    for img, heatmap in tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs} [Train]"):
        img = img.to(device)
        heatmap = heatmap.to(device)

        # debug display image + heatmap
        # img_np = img[0].cpu().numpy().squeeze(0)  # Shape: H x W
        # img_np = (img_np * 255).astype(np.uint8)  # Convert to uint8
        # img_pil = Image.fromarray(img_np, mode='L')
        # display(img_pil)
        # heatmap_np = heatmap[0, 0].cpu().numpy()
        # heatmap_img = Image.fromarray((cm.inferno(heatmap_np)[:, :, :3] * 255).astype('uint8'))
        # display(heatmap_img)

        pred = model(img)
        loss = criterion(pred, heatmap)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by the batch size so that dividing by
        # the dataset length below yields a per-sample average.
        train_loss += loss.item() * img.size(0)

    train_loss /= len(train_loader.dataset)

    # === Validation ===
    model.eval()
    val_loss = 0

    # No gradients are needed while scoring the validation set.
    with torch.no_grad():
        for img, heatmap in tqdm(val_loader, desc=f"Epoch {epoch}/{num_epochs} [Val]"):
            img = img.to(device)
            heatmap = heatmap.to(device)

            pred = model(img)
            loss = criterion(pred, heatmap)

            val_loss += loss.item() * img.size(0)

    val_loss /= len(val_loader.dataset)

    print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

    # Save best model
    # The same file is overwritten each time the validation loss improves.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        ckpt_path = os.path.join(save_dir, f"unet_best.pth")
        torch.save(model.state_dict(), ckpt_path)
        print(f"Saved best model to {ckpt_path}")

Epoch 1/30 [Train]: 100%|██████████| 137/137 [00:41<00:00,  3.32it/s]
Epoch 1/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.11it/s]


Epoch 1: Train Loss = 0.1251, Val Loss = 0.0564
Saved best model to ../models/dot_detection/checkpoints\unet_best.pth


Epoch 2/30 [Train]: 100%|██████████| 137/137 [00:39<00:00,  3.48it/s]
Epoch 2/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.16it/s]


Epoch 2: Train Loss = 0.0378, Val Loss = 0.0299
Saved best model to ../models/dot_detection/checkpoints\unet_best.pth


Epoch 3/30 [Train]: 100%|██████████| 137/137 [00:39<00:00,  3.50it/s]
Epoch 3/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.14it/s]


Epoch 3: Train Loss = 0.0269, Val Loss = 0.0237
Saved best model to ../models/dot_detection/checkpoints\unet_best.pth


Epoch 4/30 [Train]: 100%|██████████| 137/137 [00:39<00:00,  3.50it/s]
Epoch 4/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.10it/s]


Epoch 4: Train Loss = 0.0233, Val Loss = 0.0229
Saved best model to ../models/dot_detection/checkpoints\unet_best.pth


Epoch 5/30 [Train]: 100%|██████████| 137/137 [00:38<00:00,  3.52it/s]
Epoch 5/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.15it/s]


Epoch 5: Train Loss = 0.0218, Val Loss = 0.0224
Saved best model to ../models/dot_detection/checkpoints\unet_best.pth


Epoch 6/30 [Train]: 100%|██████████| 137/137 [00:39<00:00,  3.49it/s]
Epoch 6/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.15it/s]


Epoch 6: Train Loss = 0.0211, Val Loss = 0.0223
Saved best model to ../models/dot_detection/checkpoints\unet_best.pth


Epoch 7/30 [Train]: 100%|██████████| 137/137 [00:39<00:00,  3.50it/s]
Epoch 7/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.16it/s]


Epoch 7: Train Loss = 0.0209, Val Loss = 0.0383


Epoch 8/30 [Train]: 100%|██████████| 137/137 [00:39<00:00,  3.49it/s]
Epoch 8/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.15it/s]


Epoch 8: Train Loss = 0.0209, Val Loss = 0.0186
Saved best model to ../models/dot_detection/checkpoints\unet_best.pth


Epoch 9/30 [Train]: 100%|██████████| 137/137 [00:39<00:00,  3.49it/s]
Epoch 9/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.13it/s]


Epoch 9: Train Loss = 0.0206, Val Loss = 0.0173
Saved best model to ../models/dot_detection/checkpoints\unet_best.pth


Epoch 10/30 [Train]: 100%|██████████| 137/137 [00:39<00:00,  3.50it/s]
Epoch 10/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.15it/s]


Epoch 10: Train Loss = 0.0206, Val Loss = 0.0182


Epoch 11/30 [Train]: 100%|██████████| 137/137 [00:38<00:00,  3.52it/s]
Epoch 11/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.14it/s]


Epoch 11: Train Loss = 0.0200, Val Loss = 0.0171
Saved best model to ../models/dot_detection/checkpoints\unet_best.pth


Epoch 12/30 [Train]: 100%|██████████| 137/137 [00:38<00:00,  3.52it/s]
Epoch 12/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.15it/s]


Epoch 12: Train Loss = 0.0189, Val Loss = 0.0156
Saved best model to ../models/dot_detection/checkpoints\unet_best.pth


Epoch 13/30 [Train]: 100%|██████████| 137/137 [00:39<00:00,  3.50it/s]
Epoch 13/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.17it/s]


Epoch 13: Train Loss = 0.0186, Val Loss = 0.0186


Epoch 14/30 [Train]: 100%|██████████| 137/137 [00:39<00:00,  3.48it/s]
Epoch 14/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.16it/s]


Epoch 14: Train Loss = 0.0190, Val Loss = 0.0146
Saved best model to ../models/dot_detection/checkpoints\unet_best.pth


Epoch 15/30 [Train]: 100%|██████████| 137/137 [00:39<00:00,  3.47it/s]
Epoch 15/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.10it/s]


Epoch 15: Train Loss = 0.0186, Val Loss = 0.0145
Saved best model to ../models/dot_detection/checkpoints\unet_best.pth


Epoch 16/30 [Train]: 100%|██████████| 137/137 [00:39<00:00,  3.48it/s]
Epoch 16/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.14it/s]


Epoch 16: Train Loss = 0.0177, Val Loss = 0.0159


Epoch 17/30 [Train]: 100%|██████████| 137/137 [00:39<00:00,  3.51it/s]
Epoch 17/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.15it/s]


Epoch 17: Train Loss = 0.0184, Val Loss = 0.0134
Saved best model to ../models/dot_detection/checkpoints\unet_best.pth


Epoch 18/30 [Train]: 100%|██████████| 137/137 [00:39<00:00,  3.47it/s]
Epoch 18/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.13it/s]


Epoch 18: Train Loss = 0.0178, Val Loss = 0.0396


Epoch 19/30 [Train]: 100%|██████████| 137/137 [00:39<00:00,  3.51it/s]
Epoch 19/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.16it/s]


Epoch 19: Train Loss = 0.0181, Val Loss = 0.0140


Epoch 20/30 [Train]: 100%|██████████| 137/137 [00:38<00:00,  3.53it/s]
Epoch 20/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.17it/s]


Epoch 20: Train Loss = 0.0177, Val Loss = 0.0178


Epoch 21/30 [Train]: 100%|██████████| 137/137 [00:38<00:00,  3.52it/s]
Epoch 21/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.17it/s]


Epoch 21: Train Loss = 0.0176, Val Loss = 0.0380


Epoch 22/30 [Train]: 100%|██████████| 137/137 [00:39<00:00,  3.51it/s]
Epoch 22/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.15it/s]


Epoch 22: Train Loss = 0.0176, Val Loss = 0.0132
Saved best model to ../models/dot_detection/checkpoints\unet_best.pth


Epoch 23/30 [Train]: 100%|██████████| 137/137 [00:38<00:00,  3.53it/s]
Epoch 23/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.16it/s]


Epoch 23: Train Loss = 0.0178, Val Loss = 0.0148


Epoch 24/30 [Train]: 100%|██████████| 137/137 [00:38<00:00,  3.55it/s]
Epoch 24/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.15it/s]


Epoch 24: Train Loss = 0.0174, Val Loss = 0.0140


Epoch 25/30 [Train]: 100%|██████████| 137/137 [00:38<00:00,  3.54it/s]
Epoch 25/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.16it/s]


Epoch 25: Train Loss = 0.0171, Val Loss = 0.0137


Epoch 26/30 [Train]: 100%|██████████| 137/137 [00:39<00:00,  3.51it/s]
Epoch 26/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.17it/s]


Epoch 26: Train Loss = 0.0174, Val Loss = 0.0146


Epoch 27/30 [Train]: 100%|██████████| 137/137 [00:38<00:00,  3.54it/s]
Epoch 27/30 [Val]: 100%|██████████| 35/35 [00:15<00:00,  2.21it/s]


Epoch 27: Train Loss = 0.0170, Val Loss = 0.0141


Epoch 28/30 [Train]: 100%|██████████| 137/137 [00:38<00:00,  3.59it/s]
Epoch 28/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.18it/s]


Epoch 28: Train Loss = 0.0171, Val Loss = 0.0150


Epoch 29/30 [Train]: 100%|██████████| 137/137 [00:39<00:00,  3.51it/s]
Epoch 29/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.14it/s]


Epoch 29: Train Loss = 0.0172, Val Loss = 0.0147


Epoch 30/30 [Train]: 100%|██████████| 137/137 [00:38<00:00,  3.58it/s]
Epoch 30/30 [Val]: 100%|██████████| 35/35 [00:16<00:00,  2.18it/s]

Epoch 30: Train Loss = 0.0171, Val Loss = 0.0138





# Testing Model

In [11]:
# === Config ===
# Paths for bulk inference with the base checkpoint (unet_best.pth).
image_dir = "../data/dot_detection/test-data"
output_dir = "../data/dot_detection/results/bulk"
model_path = "../models/dot_detection/checkpoints/unet_best.pth"
# Only used to derive the NMS window radius in the inference cell
# (dist = int(sigma * 1.5)). This is NOT the Retinex blur sigma, which is
# passed to compute_retinex_reflectance separately.
# NOTE(review): presumably the Gaussian sigma of the training heatmaps - confirm.
sigma = 3
peak_threshold = 0.5  # heatmap value a local maximum must exceed to count as a dot
input_size = 384  # NxN target input

# Make sure the results folder exists before any figure is written.
os.makedirs(output_dir, exist_ok=True)

In [12]:
# === Load model ===
# Rebuild the single-channel UNet and restore the checkpoint at model_path.
model = UNet(in_channels=1, out_channels=1).to(device)
# weights_only=True limits unpickling to tensors and plain containers, which
# is all a state_dict contains, so a tampered checkpoint cannot run code.
model.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)
# Switch BatchNorm (and any dropout) layers to inference behaviour.
model.eval()

UNet(
  (in_conv): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (down): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (4): Batc

In [13]:
# === Retinex reflectance extractor ===
def compute_retinex_reflectance(img_pil, sigma=30):
    """Return a single-scale Retinex reflectance map of ``img_pil``.

    The reflectance is the log image minus the log of its Gaussian-blurred
    version, min-max rescaled to the 0-255 range.

    Args:
        img_pil: Input image (anything ``np.array`` can convert, e.g. a PIL
            image).
        sigma: Standard deviation of the Gaussian blur, applied on both axes.

    Returns:
        A PIL image built from the uint8 reflectance array. If the
        reflectance is perfectly flat (e.g. a uniform input image) an
        all-zero image is returned.
    """
    img_np = np.array(img_pil).astype(np.float32) + 1.0  # Avoid log(0)
    log_img = np.log(img_np)

    # Kernel size (0, 0) lets OpenCV derive the kernel extent from sigma.
    blurred = cv2.GaussianBlur(img_np, (0, 0), sigmaX=sigma, sigmaY=sigma)
    # NOTE(review): img_np already carries the +1 offset, so this adds a
    # second one. Left unchanged to keep outputs identical - confirm intent.
    log_blur = np.log(blurred + 1.0)

    reflectance = log_img - log_blur

    lowest = reflectance.min()
    value_range = reflectance.max() - lowest
    if value_range > 0:
        reflectance = (reflectance - lowest) / value_range * 255.0
    else:
        # A flat map would give 0/0 = NaN and an undefined uint8 cast.
        reflectance = np.zeros_like(reflectance)

    return Image.fromarray(reflectance.astype(np.uint8))

In [14]:
# === NMS ===
def extract_peaks_from_heatmap(heatmap, threshold=0.5, dist=3):
    """Find local maxima of a single-channel heatmap (max-pool based NMS).

    A pixel is reported when it equals the maximum of the
    ``(2 * dist + 1)``-wide square window centred on it and its value is
    strictly greater than ``threshold``. Pixels tied for the maximum inside
    a window (a plateau) are all reported.

    Args:
        heatmap: Tensor shaped ``[H, W]``, ``[1, H, W]`` or ``[1, 1, H, W]``;
            every dimension before the last two must have size 1.
        threshold: Minimum value (exclusive) for a peak.
        dist: Half-width of the suppression window, in pixels.

    Returns:
        ``numpy`` integer array of shape ``(K, 2)`` holding ``(x, y)`` pixel
        coordinates, ordered by row and then by column.

    Raises:
        ValueError: If ``heatmap`` has fewer than two dimensions or holds
            more than one channel/batch entry.
    """
    if heatmap.dim() < 2:
        raise ValueError("heatmap must have at least 2 dimensions (H, W)")
    height, width = heatmap.shape[-2:]
    if heatmap.numel() != height * width:
        raise ValueError(
            f"expected a single-channel heatmap, got shape {tuple(heatmap.shape)}"
        )

    # Normalise to exactly (H, W). A blanket squeeze() would also drop a
    # spatial dimension of size 1 and misalign the comparison below.
    heatmap = heatmap.reshape(height, width)

    # Stride 1 with padding == dist keeps the pooled map the same size.
    pooled = F.max_pool2d(
        heatmap[None, None], kernel_size=2 * dist + 1, stride=1, padding=dist
    )[0, 0]

    peak_mask = (heatmap == pooled) & (heatmap > threshold)
    coords = peak_mask.nonzero(as_tuple=False)  # (y, x)
    return coords[:, [1, 0]].cpu().numpy()  # (x, y)

In [15]:
# === Load all test images ===
# Collect every file in image_dir matching one of the supported image
# patterns, sorted so the processing order is deterministic.
image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.bmp', '*.tif', '*.tiff']
img_paths = sorted(
    match
    for pattern in image_extensions
    for match in glob.glob(os.path.join(image_dir, pattern))
)

In [16]:
# === Run inference ===
# For every test image: resize to the network input size, flatten the
# illumination with Retinex, predict the dot heatmap, pick peaks with NMS and
# save a three-panel figure (input / detections / heatmap) into output_dir.
nms_radius = int(sigma * 1.5)  # loop-invariant suppression half-width

for img_path in img_paths:
    original_img = Image.open(img_path).convert("L")
    original_img_resized = original_img.resize((input_size, input_size), Image.BILINEAR)

    # Compute reflectance map via Retinex
    reflectance_img = compute_retinex_reflectance(original_img_resized, sigma=50)
    img_tensor = to_tensor(reflectance_img).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = model(img_tensor)  # [1, 1, N, N]
        pred = pred.squeeze(0).cpu()

    # Extract dot coords in resized image
    coords = extract_peaks_from_heatmap(pred, threshold=peak_threshold, dist=nms_radius)

    # Draw on reflectance image
    reflectance_with_boxes = reflectance_img.convert("RGB")
    draw = ImageDraw.Draw(reflectance_with_boxes)
    for x, y in coords:
        draw.rectangle([x - 3, y - 3, x + 3, y + 3], outline="green", width=2)

    # === Save side-by-side plot with 3 panels ===
    base_name = os.path.basename(img_path)
    heatmap_save_path = os.path.join(output_dir, f"{os.path.splitext(base_name)[0]}_viz.png")

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (original_img_resized, "Original Image", {"cmap": "gray"}),
        (reflectance_with_boxes, "Reflectance + Prediction", {}),
        (pred.squeeze(), "Predicted Heatmap", {"cmap": "hot", "interpolation": "nearest"}),
    )
    for ax, (panel, title, imshow_kwargs) in zip(axs, panels):
        ax.imshow(panel, **imshow_kwargs)
        ax.set_title(title)
        ax.axis("off")

    # Use the figure's own methods rather than pyplot's implicit
    # "current figure" state, then close it so figures do not pile up.
    fig.tight_layout()
    fig.savefig(heatmap_save_path)
    plt.close(fig)

print(f"Saved {len(img_paths)} prediction results with visualizations to {output_dir}")

Saved 12 prediction results with visualizations to ../data/dot_detection/results/bulk


# Finetuning Model

In [None]:
# === Config ===
batch_size = 8
num_epochs = 30
lr = 1e-3 / 10  # Lower learning rate for fine-tuning
train_val_ratio = 0.8
save_dir = "../models/dot_detection/checkpoints"
os.makedirs(save_dir, exist_ok=True)

finetune_root = "../data/dot_detection/MAN/prepared-patches"
base_ckpt_path = os.path.join(save_dir, "unet_best.pth")  # weights to start from
ckpt_path = os.path.join(save_dir, "unet_finetuned_best.pth")  # best fine-tuned weights

# === Load dataset ===
train_dataset = BrailleDataset(root_dir=finetune_root, is_train=True, train_val_ratio=train_val_ratio)
val_dataset = BrailleDataset(root_dir=finetune_root, is_train=False, train_val_ratio=train_val_ratio)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

# === Model ===
# Same explicit single-channel configuration as the inference cells, so the
# checkpoint layout does not depend on UNet's default arguments.
model = UNet(in_channels=1, out_channels=1).to(device)
# weights_only=True limits unpickling to tensors and plain containers, which
# is all a state_dict contains.
model.load_state_dict(
    torch.load(base_ckpt_path, map_location=device, weights_only=True)
)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.L1Loss()

# === Training loop ===
best_val_loss = float('inf')
for epoch in range(1, num_epochs + 1):
    model.train()
    train_loss = 0

    for img, heatmap in tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs} [Train]"):
        img = img.to(device)
        heatmap = heatmap.to(device)

        pred = model(img)
        loss = criterion(pred, heatmap)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight it by the batch size so
        # the epoch figure is a true per-sample average even when the last
        # batch is smaller.
        train_loss += loss.item() * img.size(0)

    train_loss /= len(train_loader.dataset)

    # === Validation ===
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for img, heatmap in tqdm(val_loader, desc=f"Epoch {epoch}/{num_epochs} [Val]"):
            img = img.to(device)
            heatmap = heatmap.to(device)

            pred = model(img)
            loss = criterion(pred, heatmap)

            val_loss += loss.item() * img.size(0)

    val_loss /= len(val_loader.dataset)

    print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), ckpt_path)
        print(f"Saved best model to {ckpt_path}")

Epoch 1/30 [Train]: 100%|██████████| 1/1 [00:00<00:00,  2.29it/s]
Epoch 1/30 [Val]: 100%|██████████| 1/1 [00:00<00:00, 10.09it/s]


Epoch 1: Train Loss = 0.0779, Val Loss = 0.1374
Saved best model to ../models/dot_detection/checkpoints\unet_finetuned_best.pth


Epoch 2/30 [Train]: 100%|██████████| 1/1 [00:00<00:00,  3.69it/s]
Epoch 2/30 [Val]: 100%|██████████| 1/1 [00:00<00:00, 11.09it/s]


Epoch 2: Train Loss = 0.0532, Val Loss = 0.1388


Epoch 3/30 [Train]: 100%|██████████| 1/1 [00:00<00:00,  3.87it/s]
Epoch 3/30 [Val]: 100%|██████████| 1/1 [00:00<00:00, 11.15it/s]


Epoch 3: Train Loss = 0.0442, Val Loss = 0.1392


Epoch 4/30 [Train]: 100%|██████████| 1/1 [00:00<00:00,  3.84it/s]
Epoch 4/30 [Val]: 100%|██████████| 1/1 [00:00<00:00, 10.63it/s]


Epoch 4: Train Loss = 0.0784, Val Loss = 0.1408


Epoch 5/30 [Train]: 100%|██████████| 1/1 [00:00<00:00,  3.79it/s]
Epoch 5/30 [Val]:   0%|          | 0/1 [00:00<?, ?it/s]

# Testing Finetuned Model

In [18]:
# === Config ===
# Paths for inference with the fine-tuned checkpoint (unet_finetuned_best.pth).
# Results go to their own folder, separate from the base-model "bulk" run.
image_dir = "../data/dot_detection/test-data"
output_dir = "../data/dot_detection/results/finetuned"
model_path = "../models/dot_detection/checkpoints/unet_finetuned_best.pth"
# Only used to derive the NMS window radius in the inference cell
# (dist = int(sigma * 1.5)). This is NOT the Retinex blur sigma, which is
# passed to compute_retinex_reflectance separately.
# NOTE(review): presumably the Gaussian sigma of the training heatmaps - confirm.
sigma = 3
peak_threshold = 0.5  # heatmap value a local maximum must exceed to count as a dot
input_size = 384  # NxN target input

# Make sure the results folder exists before any figure is written.
os.makedirs(output_dir, exist_ok=True)

In [19]:
# === Load model ===
# Rebuild the single-channel UNet and restore the checkpoint at model_path.
model = UNet(in_channels=1, out_channels=1).to(device)
# weights_only=True limits unpickling to tensors and plain containers, which
# is all a state_dict contains, so a tampered checkpoint cannot run code.
model.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)
# Switch BatchNorm (and any dropout) layers to inference behaviour.
model.eval()

UNet(
  (in_conv): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (down): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (4): Batc

In [20]:
# === Load all test images ===
# Collect every file in image_dir matching one of the supported image
# patterns, sorted so the processing order is deterministic.
image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.bmp', '*.tif', '*.tiff']
img_paths = sorted(
    match
    for pattern in image_extensions
    for match in glob.glob(os.path.join(image_dir, pattern))
)

In [21]:
# === Run inference ===
# For every test image: resize to the network input size, flatten the
# illumination with Retinex, predict the dot heatmap, pick peaks with NMS and
# save a three-panel figure (input / detections / heatmap) into output_dir.
nms_radius = int(sigma * 1.5)  # loop-invariant suppression half-width

for img_path in img_paths:
    original_img = Image.open(img_path).convert("L")
    original_img_resized = original_img.resize((input_size, input_size), Image.BILINEAR)

    # Compute reflectance map via Retinex
    reflectance_img = compute_retinex_reflectance(original_img_resized, sigma=50)
    img_tensor = to_tensor(reflectance_img).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = model(img_tensor)  # [1, 1, N, N]
        pred = pred.squeeze(0).cpu()

    # Extract dot coords in resized image
    coords = extract_peaks_from_heatmap(pred, threshold=peak_threshold, dist=nms_radius)

    # Draw on reflectance image
    reflectance_with_boxes = reflectance_img.convert("RGB")
    draw = ImageDraw.Draw(reflectance_with_boxes)
    for x, y in coords:
        draw.rectangle([x - 3, y - 3, x + 3, y + 3], outline="green", width=2)

    # === Save side-by-side plot with 3 panels ===
    base_name = os.path.basename(img_path)
    heatmap_save_path = os.path.join(output_dir, f"{os.path.splitext(base_name)[0]}_viz.png")

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (original_img_resized, "Original Image", {"cmap": "gray"}),
        (reflectance_with_boxes, "Reflectance + Prediction", {}),
        (pred.squeeze(), "Predicted Heatmap", {"cmap": "hot", "interpolation": "nearest"}),
    )
    for ax, (panel, title, imshow_kwargs) in zip(axs, panels):
        ax.imshow(panel, **imshow_kwargs)
        ax.set_title(title)
        ax.axis("off")

    # Use the figure's own methods rather than pyplot's implicit
    # "current figure" state, then close it so figures do not pile up.
    fig.tight_layout()
    fig.savefig(heatmap_save_path)
    plt.close(fig)

print(f"Saved {len(img_paths)} prediction results with visualizations to {output_dir}")

KeyboardInterrupt: 