In [None]:
from modules.hardnet.eval_metrics import ErrorRateAt95Recall
from modules.hardnet.losses import distance_matrix_vector
from modules.hardnet.models import HardNet

import os
import re

import torch
import torchvision

# Human-readable viewing geometry for selected validation images, keyed by
# image number (odd numbers 1-13). Presumably these are the same indices that
# prefix the patch file names - TODO confirm. The evaluation loop does not
# reference this mapping.
image_descriptions = dict(
    [
        (1, "very oblique, close"),
        (3, "very oblique, far"),
        (5, "medium oblique, close"),
        (7, "medium oblique, far"),
        (9, "fronto-parallel, close"),
        (11, "fronto-parallel, far"),
        (13, "fronto-parallel, very far"),
    ]
)

def _load_gray_patch(path):
    """Decode the image at ``path`` as one grayscale channel scaled to [0, 1].

    Anchor, positive and garbage patches all go through this helper so that
    every input the model sees is normalized identically.
    """
    image = torchvision.io.decode_image(path, torchvision.io.ImageReadMode.GRAY)
    return image.to(torch.float32) / 255


def _load_patch_batch(directory, file_names, device, patch_size=32):
    """Stack ``directory/<name>`` for each name into an (N, 1, patch_size, patch_size) tensor.

    The tensor is allocated directly on ``device``. ``patch_size`` must match
    the pixel size of the stored patch images. An empty ``file_names`` yields
    an (0, 1, patch_size, patch_size) tensor.
    """
    batch = torch.empty((len(file_names), 1, patch_size, patch_size), device=device)
    for j, file_name in enumerate(file_names):
        batch[j] = _load_gray_patch(os.path.join(directory, file_name))
    return batch


# Evaluation settings. The dataset layout read below is
#   <dataset_path>/patches/<scale>/{anchors,positives,garbage}/<image:04d>_<n>.png
# where anchors/<f> and positives/<f> with the same file name form a matching pair.
checkpoint = 199
patch_size = 32
num_images = 123  # image indices 0..122 are probed; indices without patches are skipped
dataset_path = "./data/datasets/new/real/validation"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for date in [
    "2025_11_05",
    "2025_11_06",
]:
    for scale in [96, 128]:
        print(f"\n# Evaluating {date} with lambda={scale}")
        model = HardNet(transform="PTN", coords="log", patch_size=patch_size, scale=scale)
        # map_location allows a checkpoint saved on GPU to be loaded on a CPU-only host.
        # NOTE: weights_only=False unpickles arbitrary objects - only load trusted checkpoints.
        state = torch.load(
            f"./data/models/{date}_blobinator_{scale}/model_checkpoint_{checkpoint}.pth",
            map_location=device,
            weights_only=False,
        )
        model.load_state_dict(state["state_dict"])
        model.to(device)
        model.eval()

        scale_dir = os.path.join(dataset_path, f"patches/{scale}")
        anchor_dir = os.path.join(scale_dir, "anchors")
        positive_dir = os.path.join(scale_dir, "positives")
        garbage_dir = os.path.join(scale_dir, "garbage")
        # List each directory once per (date, scale) rather than once per image;
        # sorting makes the processing order deterministic.
        patch_files = sorted(os.listdir(positive_dir))
        all_garbage_files = sorted(os.listdir(garbage_dir))

        fpr95_sum = 0
        fpr95_num = 0
        for i in range(num_images):
            # fullmatch (not match) so that stray files such as "0023_5.png~" are
            # ignored instead of silently duplicating an anchor/positive pair.
            regex = re.compile(f"{i:04}_\\d+\\.png")
            patches = [f for f in patch_files if regex.fullmatch(f)]
            if not patches:
                continue
            garbage_patch_files = [f for f in all_garbage_files if regex.fullmatch(f)]

            # Anchor j and positive j are loaded from the same file name, so row j
            # of the distance matrix has its true match in column j.
            anchor_patches = _load_patch_batch(anchor_dir, patches, device, patch_size)
            positive_patches = _load_patch_batch(positive_dir, patches, device, patch_size)
            garbage_patches = _load_patch_batch(
                garbage_dir, garbage_patch_files, device, patch_size
            )

            # Pure inference: skip building autograd graphs.
            with torch.no_grad():
                anchor_features, _ = model(anchor_patches)
                positive_features, _ = model(positive_patches)
                garbage_features, _ = model(garbage_patches)
                candidates = torch.concat((positive_features, garbage_features))
                distances = distance_matrix_vector(anchor_features, candidates)
            distances = distances.cpu().numpy().flatten()
            labels = torch.eye(anchor_features.size(0), candidates.size(0)).numpy().flatten()
            # Distances are inverted so that larger values mean "closer".
            fpr95 = ErrorRateAt95Recall(labels, 1.0 / (distances + 1e-8))
            print(f"{fpr95=} for image {i} with {anchor_features.size(0)} features")
            # Weight each image by its number of anchor/candidate pairs.
            fpr95_num += distances.size
            fpr95_sum += distances.size * fpr95

        # Guard against a (date, scale) with no matching patches at all.
        avg_fpr95 = fpr95_sum / fpr95_num if fpr95_num else float("nan")
        print(f"{avg_fpr95=}")


# Evaluating 2025_11_05 with lambda=96
fpr95=0.37279790293081366 for image 23 with 317 features
fpr95=0.2711958022645678 for image 28 with 213 features
fpr95=0.6423611111111112 for image 76 with 32 features
fpr95=0.19751552795031055 for image 79 with 35 features
fpr95=0.7236874038463568 for image 85 with 487 features
fpr95=0.04630341587756349 for image 99 with 436 features
fpr95=0.11487549754697769 for image 101 with 416 features
fpr95=0.0704004264897524 for image 103 with 184 features
fpr95=0.08747338735746213 for image 107 with 670 features
fpr95=0.6636419778812529 for image 108 with 514 features
avg_fpr95=0.3129410234611788

# Evaluating 2025_11_05 with lambda=128
fpr95=0.2331394740383034 for image 23 with 317 features
fpr95=0.2283899475283071 for image 28 with 213 features
fpr95=0.7599206349206349 for image 76 with 32 features
fpr95=0.15610766045548655 for image 79 with 35 features
fpr95=0.6429827097547541 for image 85 with 487 features
fpr95=0.02893963492347718 for image 99 with 