In [5]:
from modules.hardnet.losses import distance_matrix_vector
from modules.hardnet.models import HardNet

import os
import re

import torch
import torchvision

device = torch.device("cpu")
model = HardNet(transform="PTN", coords="log", patch_size=32, scale=128)
model.load_state_dict(torch.load("./data/models/2025_11_01_blobinator_128/model_checkpoint_147.pth", weights_only=False)["state_dict"])
model.to(device)
positive_path_regex = re.compile("(\\d+)_(\\d+).png")
dataset_path = "./data/very_hard/validation"
patch_files = os.listdir(os.path.join(dataset_path, "patches/128/positives"))

overall_anchor_features = []
overall_positive_features = []
all_anchor_patches = []
all_positive_patches = []
overall_garbage_features = []

for i in range(10):
    regex = re.compile(f"{i:04}_\\d+\\.png")
    patches = list(filter(lambda f: regex.match(f) is not None, patch_files))

    anchor_patches = torch.empty((len(patches), 1, 32, 32)).to(device)
    positive_patches = torch.empty((len(patches), 1, 32, 32)).to(device)

    for j, patch_file in enumerate(patches):
        match = positive_path_regex.search(os.path.basename(patch_file))
        board_idx, blob_idx = match.group(1), match.group(2)
        positive_patches[j] = torchvision.io.decode_image(os.path.join(dataset_path, f"patches/128/positives/{board_idx}_{blob_idx}.png"), torchvision.io.ImageReadMode.GRAY).to(torch.float32) / 255
        anchor_patches[j] = torchvision.io.decode_image(os.path.join(dataset_path, f"patches/128/anchors/{board_idx}_{blob_idx}.png"), torchvision.io.ImageReadMode.GRAY).to(torch.float32) / 255
    garbage_patch_files = os.listdir(os.path.join(dataset_path, "patches/128/garbage"))
    garbage_patch_files = list(filter(lambda f: regex.match(f) is not None, garbage_patch_files))
    garbage_patches = torch.empty((len(garbage_patch_files), 1, 32, 32))
    for j, patch_file in enumerate(garbage_patch_files):
        garbage_patches[j] = torchvision.io.decode_image(os.path.join(dataset_path, f"patches/128/garbage/{patch_file}"), torchvision.io.ImageReadMode.GRAY)

    all_anchor_patches.append(anchor_patches)

    anchor_features, _ = model(anchor_patches)
    positive_features, _ = model(positive_patches)
    garbage_features, _ = model(garbage_patches)

    overall_anchor_features.append(anchor_features)
    overall_positive_features.append(positive_features)
    overall_garbage_features.append(garbage_features)

    distances = distance_matrix_vector(anchor_features, torch.concat((positive_features, garbage_features)))
    # indices = distances.argsort(dim=0)
    # matches = torch.stack([torch.arange(0, anchor_features.size(0), dtype=int).to(device), indices[0]])
    # true_positives = torch.where(matches[0] == matches[1], 1, 0).sum()
    # false_positives = torch.where(matches[0] != matches[1], 1, 0).sum()
    # print(f"Positive to anchor matching for image {i} FPR={false_positives / (true_positives + false_positives)}")

    # # distances = distance_matrix_vector(anchor_features, positive_features)
    indices = distances.argsort(dim=1)
    matches = torch.stack([torch.arange(0, positive_features.size(0), dtype=int).to(device), indices[:,0]])
    true_positives = torch.where(matches[0] == matches[1], 1, 0).sum()
    false_positives = torch.where(matches[0] != matches[1], 1, 0).sum()
    print(f"Anchor to positive matching for image {i} FPR={false_positives / (true_positives + false_positives)}")

distances = distance_matrix_vector(torch.concat(overall_anchor_features), torch.concat(overall_positive_features + overall_garbage_features))

indices = distances.argsort(dim=1)
matches = torch.stack([torch.arange(0, distances.size(0), dtype=int).to(device), indices[:,0]])
true_positives = torch.where(matches[0] == matches[1], 1, 0).sum()
false_positives = torch.where(matches[0] != matches[1], 1, 0).sum()
print(f"Anchor to positive matching over all images FPR={false_positives / (true_positives + false_positives)}")

Anchor to positive matching for image 0 FPR=0.475862056016922
Anchor to positive matching for image 1 FPR=0.9770833253860474
Anchor to positive matching for image 2 FPR=0.7494692206382751
Anchor to positive matching for image 3 FPR=0.023463686928153038
Anchor to positive matching for image 4 FPR=0.8358209133148193
Anchor to positive matching for image 5 FPR=0.42391303181648254
Anchor to positive matching for image 6 FPR=0.9842519760131836
Anchor to positive matching for image 7 FPR=0.8828282952308655
Anchor to positive matching for image 8 FPR=0.8045454621315002
Anchor to positive matching for image 9 FPR=0.9756097793579102
Anchor to positive matching over all images FPR=0.7369179129600525
