# Project 3 - Fingerprint verification

[GitHub](https://github.com/siyu-hu/TMS016_Spatial_Statistics_and_Image_Analysis)

# Preprocess.py

In [None]:
"""
Unified fingerprint-image preprocessing module
Key features
1. Advanced enhancement pipeline
   Gamma correction → Zero-mean / unit-variance gray normalization →
   Orientation field estimation → Overlapping-block Gabor filtering → CLAHE

2. Backward-compatible public entry point
   `normalize(image_path) → float32 array [300 * 300], range 0-1`

3. Batch processing 
   Saves two copies simultaneously  
     • Enhanced **tif** → ./data/original/<*_new>  
     • Training **npy**  → ./data/processed/<*_new>

What changed compared with the old version?
-------------------------------------------
✓ **CHANGED** - The former “simple” normalization was replaced; the external
                calling method stays the same.  
✓ **NEW**      - `normalize_gray`, core preprocessing functions,
                and a re-worked `batch_preprocess`.  
✓ **KEPT**     - `check_image_sizes` 
"""

import os
import cv2
import numpy as np
from tqdm import tqdm
import argparse


def check_image_sizes(folder_path, expected_size=(300, 300)):
    """Report every ``.tif`` file in ``folder_path`` whose shape is unexpected.

    Each file whose name ends in ``.tif`` (case-insensitive) is read as a
    grayscale image. A message is printed when the file cannot be read or
    when its array shape differs from ``expected_size``. Nothing is returned.
    """
    tif_names = (name for name in os.listdir(folder_path)
                 if name.lower().endswith('.tif'))
    for fname in tif_names:
        image = cv2.imread(os.path.join(folder_path, fname),
                           cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals a failed read by returning None.
            print(f" ERROR:  Failed to read {fname}")
        elif image.shape != expected_size:
            print(f"  {fname} has size {image.shape} ≠ {expected_size}")

# ------------------------------------------------------------------------- #
#                      advanced pipeline by Qi Wang                     #
# ------------------------------------------------------------------------- #
def normalize_gray(img):
    """Standardise an image, then stretch it to the full 0-255 uint8 range.

    The image is first shifted to zero mean and scaled to (approximately)
    unit variance, then min-max stretched to 0-255.

    Parameters
    ----------
    img : array-like
        Grayscale image.

    Returns
    -------
    ndarray, uint8, same shape as ``img``
        Stretched image. A constant (or empty) image has no contrast to
        stretch, so an all-zero array is returned instead of dividing by
        zero and casting NaN to uint8.
    """
    img = np.asarray(img)
    if img.size == 0:
        return np.zeros(img.shape, dtype=np.uint8)

    mean, std = img.mean(), img.std()
    z = (img - mean) / (std + 1e-5)      # epsilon guards against std == 0

    lo, hi = z.min(), z.max()
    span = hi - lo
    if span == 0:
        # Flat image: the min-max stretch would be 0 / 0.
        return np.zeros(img.shape, dtype=np.uint8)

    z = (z - lo) / span * 255
    return z.astype(np.uint8)

def gamma_correction(img, gamma=0.3):
    """Remap intensities through a 256-entry power-law lookup table.

    Each level ``v`` in 0..255 is mapped to ``(v / 255) ** (1 / gamma) * 255``
    (truncated to uint8) and the table is applied to ``img`` with ``cv2.LUT``.
    """
    exponent = 1.0 / gamma
    levels = np.arange(256) / 255.0
    lookup = (np.power(levels, exponent) * 255).astype("uint8")
    return cv2.LUT(img, lookup)

def compute_orientation_field(img, block_size=16):
    """Estimate one ridge-orientation angle per non-overlapping block.

    Sobel gradients are computed for the whole image; for every
    ``block_size`` x ``block_size`` block that fits entirely inside the image
    the angle is ``0.5 * arctan2(2 * sum(gx * gy), sum(gx**2 - gy**2))``.

    Parameters
    ----------
    img : ndarray, 2-D
        Grayscale image.
    block_size : int, default 16
        Side length of the square blocks.

    Returns
    -------
    ndarray, float64, shape (rows // block_size, cols // block_size)
        Orientation angle (radians) for each block.
    """
    rows, cols = img.shape
    sobelx = cv2.Sobel(img, cv2.CV_64F, 1, 0, ksize=3)
    sobely = cv2.Sobel(img, cv2.CV_64F, 0, 1, ksize=3)

    n_block_rows = rows // block_size
    n_block_cols = cols // block_size
    orient = np.zeros((n_block_rows, n_block_cols))

    # Visit every block that fits. The former bound
    # ``range(0, rows - block_size, block_size)`` skipped the last block
    # whenever a dimension was an exact multiple of block_size, leaving
    # that row/column of the field at its initial 0.
    for bi in range(n_block_rows):
        i = bi * block_size
        for bj in range(n_block_cols):
            j = bj * block_size
            gx = sobelx[i:i+block_size, j:j+block_size]
            gy = sobely[i:i+block_size, j:j+block_size]
            Vx = 2 * np.sum(gx * gy)
            Vy = np.sum(gx**2 - gy**2)
            orient[bi, bj] = 0.5 * np.arctan2(Vx, Vy)
    return orient

def gabor_filter_overlap(img, orientation_field,
                         block_size=16, freq=0.1, sigma=4.0):
    """Gabor-filter the image block by block with 50 % overlap.

    Blocks of ``block_size`` x ``block_size`` are visited with a stride of
    ``block_size // 2``. Each block is filtered with a Gabor kernel oriented
    by the ``orientation_field`` cell the block's top-left corner falls in,
    and overlapping results are averaged.

    Parameters
    ----------
    img : ndarray, 2-D
        Grayscale image.
    orientation_field : ndarray
        Per-block angles, indexed as ``[i // block_size, j // block_size]``.
    block_size : int, default 16
        Block side length; also the Gabor kernel size.
    freq : float, default 0.1
        Ridge frequency; the kernel wavelength is ``1 / freq``.
    sigma : float, default 4.0
        Gaussian envelope standard deviation of the kernel.

    Returns
    -------
    ndarray, uint8, same shape as ``img``
        Averaged filter response, min-max normalised to 0-255. Border pixels
        that no block reaches keep a response of 0 before normalisation.
    """
    rows, cols = img.shape
    enhanced = np.zeros((rows, cols), np.float32)
    weights  = np.zeros((rows, cols), np.float32)
    stride   = block_size // 2

    # "+ 1" lets the last fully fitting block (i == rows - block_size) be
    # processed; the former exclusive bound dropped it whenever
    # (rows - block_size) was a multiple of the stride.
    for i in range(0, rows - block_size + 1, stride):
        for j in range(0, cols - block_size + 1, stride):
            theta  = orientation_field[i//block_size, j//block_size]
            kernel = cv2.getGaborKernel((block_size, block_size),
                                        sigma, theta, 1.0/freq, 0.5, 0,
                                        ktype=cv2.CV_32F)
            block    = img[i:i+block_size, j:j+block_size]
            filtered = cv2.filter2D(block, cv2.CV_32F, kernel)

            enhanced[i:i+block_size, j:j+block_size] += filtered
            weights[i:i+block_size, j:j+block_size]  += 1.0

    # Average the overlaps; uncovered pixels (weight 0) are divided by 1.
    enhanced /= np.where(weights == 0, 1.0, weights)
    return cv2.normalize(enhanced, None, 0, 255,
                         cv2.NORM_MINMAX).astype(np.uint8)

def apply_clahe(img, clipLimit=3.0, tileGridSize=(8, 8)):
    """Apply CLAHE (local contrast enhancement) to ``img`` and return the result."""
    equalizer = cv2.createCLAHE(clipLimit=clipLimit,
                                tileGridSize=tileGridSize)
    enhanced = equalizer.apply(img)
    return enhanced

def preprocess_fingerprint(img_path):
    """
    Run the full enhancement pipeline on one image file.

    Steps, in order: gamma correction, gray normalization, orientation-field
    estimation, overlapping-block Gabor filtering, CLAHE.

    Raises
    ------
    ValueError
        If the image at ``img_path`` cannot be read.
    """
    raw = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if raw is None:
        raise ValueError(f"ERROR: Cannot open {img_path}")

    normalized = normalize_gray(gamma_correction(raw))
    orientation = compute_orientation_field(normalized)
    filtered = gabor_filter_overlap(normalized, orientation)
    return apply_clahe(filtered)


def normalize(image_path, size=(300, 300)):
    """
    Backward-compatible entry point: enhance, resize and scale one image.

    Parameters
    ----------
    image_path : str
        Path to a grayscale fingerprint image.
    size : tuple, default (300, 300)
        Target size handed to ``cv2.resize``.

    Returns
    -------
    ndarray, float32
        Enhanced image divided by 255, i.e. values in 0 - 1.
    """
    enhanced = preprocess_fingerprint(image_path)     # uint8
    resized = cv2.resize(enhanced, size)
    scaled = resized.astype('float32') / 255.0
    return scaled

# ------------------------------------------------------------------------- #
#       Batch processing (save tif + npy to the specified locations)        #
# ------------------------------------------------------------------------- #
def batch_preprocess(input_dir,
                     output_dir_npy="./data/processed/DB1_B_new_1",
                     output_dir_tif="./data/original/DB1_B_new_1",
                     size=(300, 300)):
    """
    Enhance every tif/bmp/jpg/png file found in `input_dir`.

    For each image two files are written under the same stem:
    • enhanced tif (uint8, 0 - 255) in `output_dir_tif`
    • training npy (float32, 0 - 1) in `output_dir_npy`

    Both output directories are created if missing. A file that fails is
    reported on stdout and skipped; a summary is printed at the end.
    """
    for out_dir in (output_dir_npy, output_dir_tif):
        os.makedirs(out_dir, exist_ok=True)

    accepted_ext = ('.tif', '.bmp', '.jpg', '.png')
    count = 0
    for fname in tqdm(os.listdir(input_dir), desc="preprocessing"):
        if not fname.lower().endswith(accepted_ext):
            continue

        source = os.path.join(input_dir, fname)
        try:
            enhanced = cv2.resize(preprocess_fingerprint(source), size)

            stem, _ = os.path.splitext(fname)
            tif_target = os.path.join(output_dir_tif, f"{stem}.tif")
            npy_target = os.path.join(output_dir_npy, f"{stem}.npy")

            cv2.imwrite(tif_target, enhanced)
            np.save(npy_target, enhanced.astype('float32') / 255.0)
            count += 1
        except Exception as e:
            # Best effort: report the failure and continue with the next file.
            print(f" {fname}: {e}")

    print(f"\n {count} images saved to")
    print(f"   tif : {output_dir_tif}")
    print(f"   npy : {output_dir_npy}")

# ------------------------------------------------------------------------- #
#                                  CLI                                    #
# ------------------------------------------------------------------------- #
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the batch job.
    cli = argparse.ArgumentParser(
        description="Fingerprint batch preprocessing (advanced pipeline)")
    cli.add_argument("--input", "-i", required=True,
                     help="Directory containing raw fingerprint images")
    cli.add_argument("--out-npy", default="./data/processed/DB1_B_new_1",
                     help="Destination folder for processed .npy files")
    cli.add_argument("--out-tif", default="./data/original/DB1_B_new_1",
                     help="Destination folder for enhanced .tif files")
    cli.add_argument("--size", type=int, default=300,
                     help="Output width/height after resize (default: 300)")
    options = cli.parse_args()

    # A single --size value is used for both output dimensions.
    square = (options.size, options.size)
    batch_preprocess(input_dir=options.input,
                     output_dir_npy=options.out_npy,
                     output_dir_tif=options.out_tif,
                     size=square)


# augment_images.py

In [None]:
import numpy as np
import cv2
import random

def random_affine_transform(image, max_rotation=10, max_scale=0.1):
    """
    Randomly rotate and scale a single image about its centre.

    Args:
        image: numpy array, single-channel (H, W)
        max_rotation: rotation angle is drawn uniformly from
            [-max_rotation, +max_rotation] degrees
        max_scale: scale factor is drawn uniformly from
            [1 - max_scale, 1 + max_scale]

    Returns:
        The warped image, same width and height as the input. Pixels that
        fall outside the source are filled by reflection.
    """
    height, width = image.shape

    # Draw the rotation first, then the scale (the order fixes the RNG stream).
    rotation_deg = random.uniform(-max_rotation, max_rotation)
    zoom = 1.0 + random.uniform(-max_scale, max_scale)

    pivot = (width // 2, height // 2)
    matrix = cv2.getRotationMatrix2D(pivot, rotation_deg, zoom)

    return cv2.warpAffine(image, matrix, (width, height),
                          flags=cv2.INTER_LINEAR,
                          borderMode=cv2.BORDER_REFLECT_101)


# create_train_pairs.py

In [None]:
import os
import numpy as np
import random
from itertools import combinations
from augment_images import random_affine_transform
from sklearn.utils import shuffle


def load_images_by_finger(data_path):
    """Group the ``.npy`` files in ``data_path`` by finger id.

    The finger id is the part of the file name before the first underscore.
    Returns ``{finger_id: sorted list of file paths}``.
    """
    grouped = {}
    for name in os.listdir(data_path):
        if not name.endswith('.npy'):
            continue
        finger_id = name.partition('_')[0]
        full_path = os.path.join(data_path, name)
        grouped.setdefault(finger_id, []).append(full_path)

    for paths in grouped.values():
        paths.sort()
    return grouped

def load_images_by_finger_tif(data_path):
    """Group the ``.tif`` files in ``data_path`` by finger id (for DB3_B test images).

    The extension match is case-insensitive; the finger id is the part of the
    file name before the first underscore.
    Returns ``{finger_id: sorted list of file paths}``.
    """
    grouped = {}
    for name in os.listdir(data_path):
        if not name.lower().endswith('.tif'):
            continue
        finger_id = name.partition('_')[0]
        full_path = os.path.join(data_path, name)
        grouped.setdefault(finger_id, []).append(full_path)

    for paths in grouped.values():
        paths.sort()
    return grouped


def create_pairs(finger_dict, selected_fingers, augment_positive=False, num_augments=2,balance_negatives=False):
    """
    Build shuffled positive (label 1) and negative (label 0) image pairs.

    Positive pairs are all 2-combinations of images of the same finger;
    negative pairs are all cross-finger image combinations.

    Args:
        finger_dict: dict of {finger_id: list of file paths}
        selected_fingers: finger ids to draw pairs from
        augment_positive: if True, each positive pair is loaded with
            ``np.load`` and ``num_augments`` randomly transformed copies are
            appended as extra positive pairs (stored as in-memory arrays,
            not paths)
        num_augments: number of augmented copies per positive pair
        balance_negatives: if True, randomly sample at most as many negative
            pairs as there are positive pairs

    Returns:
        (pairs, labels), shuffled together with a fixed random_state of 42.
    """
    pairs, labels = [], []

    # Positive pairs (same finger).
    for fid in selected_fingers:
        for first, second in combinations(finger_dict[fid], 2):
            first = os.path.relpath(first, start=".")
            second = os.path.relpath(second, start=".")
            pairs.append([first, second])
            labels.append(1)

            if not augment_positive:
                continue

            first_arr = np.load(first)
            second_arr = np.load(second)
            for _ in range(num_augments):
                aug_first = random_affine_transform(first_arr)
                aug_second = random_affine_transform(second_arr)
                # Augmented images are kept in memory rather than on disk.
                pairs.append([aug_first, aug_second])
                labels.append(1)

    num_positive = labels.count(1)

    # Negative pairs (different fingers), enumerated in a fixed order.
    fingers = list(selected_fingers)
    negative_pairs = [
        ([os.path.relpath(path_a, start="."),
          os.path.relpath(path_b, start=".")], 0)
        for fid_a, fid_b in combinations(fingers, 2)
        for path_a in finger_dict[fid_a]
        for path_b in finger_dict[fid_b]
    ]

    if balance_negatives:
        keep = min(num_positive, len(negative_pairs))
        negative_pairs = random.sample(negative_pairs, keep)

    for pair, label in negative_pairs:
        pairs.append(pair)
        labels.append(label)

    pairs, labels = shuffle(pairs, labels, random_state=42)
    return pairs, labels



def save_pairs(pairs, labels, output_file):
    """Save pairs and labels into one ``.npz`` archive and report the path.

    Args:
        pairs: sequence of two-element pair entries.
        labels: sequence of labels, one per pair.
        output_file: destination handed to ``np.savez``, which appends
            ``.npz`` when the name does not already end with it.
    """
    pairs = np.array(pairs, dtype=object)  # [img1_path, img2_path]
    labels = np.array(labels)
    np.savez(output_file, pairs=pairs, labels=labels)

    # Report the name np.savez actually wrote: only add ".npz" when it is
    # missing, so "x.npz" is no longer printed as "x.npz.npz".
    saved_as = str(output_file)
    if not saved_as.endswith(".npz"):
        saved_as += ".npz"
    print(f"Saved {len(pairs)} pairs to {saved_as}")


# siamese_model.py

In [None]:
# siamese_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class SiameseNetwork(nn.Module):
    """
    Small Siamese CNN for 300 * 300 grayscale fingerprints.

    Input  : two tensors of shape [B, 1, 300, 300]
    Output : two L2-normalised embedding tensors of shape [B, embedding_dim]

    Both inputs go through the same weights. With the default
    embedding_dim the model has about 0.11 M trainable parameters.
    """

    def __init__(self, embedding_dim: int = 128):
        super().__init__()

        # One row per convolutional stage:
        # (in_channels, out_channels, kernel, padding, max-pool kernel or None)
        # Spatial size for a 300 x 300 input: 300 -> 150 -> 75 -> 25 -> 25.
        stage_specs = (
            (1, 16, 5, 2, 2),
            (16, 32, 3, 1, 2),
            (32, 64, 3, 1, 3),
            (64, 128, 3, 1, None),
        )

        # Layers are appended in a fixed order so the nn.Sequential indices
        # (and therefore the state_dict keys) stay stable for saved checkpoints.
        layers = []
        for c_in, c_out, kernel, pad, pool in stage_specs:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=kernel, padding=pad))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(inplace=True))
            if pool is not None:
                layers.append(nn.MaxPool2d(kernel_size=pool))
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))       # [B, 128, 1, 1]
        self.features = nn.Sequential(*layers)

        # Flatten to [B, 128], then project to [B, embedding_dim].
        self.projection = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, embedding_dim),
        )

    def forward_once(self, x: torch.Tensor) -> torch.Tensor:
        """Embed one batch and L2-normalise each embedding vector."""
        embedding = self.projection(self.features(x))
        return F.normalize(embedding, p=2, dim=1)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor):
        """Return the embeddings of both inputs, computed by the shared branch."""
        first = self.forward_once(x1)
        second = self.forward_once(x2)
        return first, second


# ------------- quick self-test -------------
if __name__ == "__main__":
    # Smoke test: push two random batches through the network and print the
    # embedding shapes.
    batch = 4
    net = SiameseNetwork()
    dummy_a = torch.randn(batch, 1, 300, 300)
    dummy_b = torch.randn(batch, 1, 300, 300)
    emb_a, emb_b = net(dummy_a, dummy_b)
    # Expected: torch.Size([4, 128]) torch.Size([4, 128])
    print(emb_a.shape, emb_b.shape)


# train.py

In [None]:
import torch
from torch.utils.data import DataLoader
from siamese_model import SiameseNetwork
from utils import SiameseDataset, ContrastiveLoss, plot_loss, save_checkpoint
import os
from tqdm import tqdm
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau 
# --- add CLI -------------------------------------------
import argparse
# Command-line options for training. Parsing runs at module level, so it
# happens as soon as this module is executed or imported; the resulting
# global `args` is read by main().
parser = argparse.ArgumentParser()
# Pair files (.npz); when omitted, main() falls back to default paths.
parser.add_argument("--train_pairs", default=None)
parser.add_argument("--val_pairs",   default=None)
parser.add_argument("--finetune", action="store_true",
                    help="Continue training from --best_ckpt with typically fewer epochs / lower lr")

# When --lr / --num_epochs are omitted, main() picks defaults that depend on
# --finetune.
parser.add_argument("--lr",          type=float)
parser.add_argument("--num_epochs",  type=int)
# --use_aug selects the default training-pair file and is recorded in the
# checkpoint filename; --balance_neg is only recorded in the filename.
parser.add_argument("--use_aug",     action="store_true")  
parser.add_argument("--balance_neg", action="store_true")  
# Checkpoint to resume from when fine-tuning; main() falls back to its own
# output checkpoint path when this is not given.
parser.add_argument("--best_ckpt",   default=None)   
args = parser.parse_args()
# --------------------------------------------------------

def train_one_epoch(model, dataloader, loss_fn, optimizer, device):
    """Train for one pass over ``dataloader`` and return the mean batch loss.

    Each batch must yield ``(img1, img2, label)``; all three are moved to
    ``device`` before the forward pass.
    """
    model.train()
    running_loss = 0.0

    progress = tqdm(dataloader, desc="Training", leave=False)
    for batch in progress:
        img1, img2, label = (tensor.to(device) for tensor in batch)

        optimizer.zero_grad()
        emb1, emb2 = model(img1, img2)
        batch_loss = loss_fn(emb1, emb2, label)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    n_batches = len(dataloader)
    return running_loss / n_batches


def validate(model, dataloader, loss_fn, device):
    """Evaluate ``model`` on ``dataloader`` without gradients; return the mean batch loss."""
    model.eval()
    running_loss = 0.0

    progress = tqdm(dataloader, desc="Validating", leave=False)
    with torch.no_grad():
        for batch in progress:
            img1, img2, label = (tensor.to(device) for tensor in batch)
            emb1, emb2 = model(img1, img2)
            running_loss += loss_fn(emb1, emb2, label).item()

    n_batches = len(dataloader)
    return running_loss / n_batches


def main():
    """Train (or fine-tune) the Siamese network using the parsed CLI `args`.

    Builds the datasets and model, optionally resumes from a checkpoint with
    a frozen convolutional backbone, then trains with a plateau LR scheduler
    and early stopping. The best model by validation loss is saved under
    ./checkpoints and the loss curve under ./outputs.
    """
    # -------------- Config --------------
    use_augmentation = args.use_aug       # set when training on augmented pairs
    balance_negatives = args.balance_neg  # should match how the training pairs were created
    finetune = args.finetune              # continue training from an existing checkpoint

    if use_augmentation:
        default_train_path = "./data/new_train_pairs_augmented.npz"
    else:
        default_train_path = "./data/new_train_pairs.npz"

    train_data_path = args.train_pairs or default_train_path
    val_data_path   = args.val_pairs or "./data/new_val_pairs.npz"

    batch_size = 8
    margin = 2.0

    # Fine-tuning defaults to a lower learning rate and fewer epochs.
    if finetune:
        default_lr     = 1e-4
        default_epoch  = 10
    else:
        default_lr     = 5e-4
        default_epoch  = 20

    learning_rate = args.lr if args.lr is not None else default_lr
    num_epochs    = args.num_epochs if args.num_epochs is not None else default_epoch

    # Output checkpoint; the filename encodes the run configuration.
    ckpt_filename = f"new_model_ft{finetune}_aug{use_augmentation}_bl{balance_negatives}_bs{batch_size}_ep{num_epochs}_lr{learning_rate}_mg{margin}.pt"
    ckpt_path = f"./checkpoints/{ckpt_filename}"
    loss_plot_path = "./outputs/loss_curve.png"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create the output folders up front so a missing directory cannot make
    # the run fail only after training has finished.
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
    os.makedirs(os.path.dirname(loss_plot_path), exist_ok=True)

    best_ckpt_path = args.best_ckpt or ckpt_path
    # -------------- Dataset + Dataloader -------------
    train_dataset = SiameseDataset(train_data_path, root_dir=".")
    val_dataset = SiameseDataset(val_data_path, root_dir=".")
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # ------------ Model, Loss, Optimizer --------------
    model = SiameseNetwork().to(device)

    if finetune:
        if os.path.exists(best_ckpt_path):
            print(f" Continue training from checkpoint: {best_ckpt_path}")
            model.load_state_dict(torch.load(best_ckpt_path, map_location=device))
            print("[INFO] Fine-tuning mode: freezing convolutional backbone...")
            for param in model.features.parameters():
                param.requires_grad = False  # only the projection layer is trained
        else:
            # Previously this case silently trained from scratch.
            print(f"[WARN] --finetune requested but checkpoint not found: {best_ckpt_path}; training from scratch.")

    loss_fn = ContrastiveLoss(margin=margin)
    # Only parameters that still require gradients are optimised.
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)

    # NOTE(review): `verbose` appears to be deprecated in recent PyTorch
    # releases — confirm against the installed version.
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=2, verbose=True)
    early_stop_patience = 7      # stop after this many epochs without improvement
    bad_epochs = 0

    # ------ Training Loop -----------------
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        print(f"\n Epoch {epoch+1}/{num_epochs}")
        train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer, device)
        val_loss = validate(model, val_loader, loss_fn, device)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        print(f"[Epoch {epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        scheduler.step(val_loss)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(model, ckpt_path)
            print(f"Saved improved model to {ckpt_path}")
            bad_epochs = 0           # reset
        else:
            bad_epochs += 1
        # early stopping
        if bad_epochs >= early_stop_patience:
            print("Early stopping triggered.")
            break

    plot_loss(train_losses, val_losses, save_path=loss_plot_path)

if __name__ == "__main__":
    main()


# validate.py

In [None]:
import torch
import numpy as np
from siamese_model import SiameseNetwork
from utils import SiameseDataset, print_classification_report
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import plot_distance_distribution, plot_roc_curve, plot_metrics_vs_threshold
import os
import argparse



def evaluate_accuracy(model, dataloader, threshold=0.5, device="cpu"):
    """Classify every pair by embedding distance and return accuracy in percent.

    A pair is predicted "same finger" (1) when the pairwise distance of its
    two embeddings is below ``threshold`` and "different" (0) otherwise. The
    confusion counts are passed to ``print_classification_report``.
    """
    print(f"\n Threshold used: {threshold}")
    model.eval()
    n_correct = n_seen = 0
    tp = tn = fp = fn = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating"):
            img1, img2, label = (tensor.to(device) for tensor in batch)

            emb1, emb2 = model(img1, img2)
            distance = F.pairwise_distance(emb1, emb2)
            prediction = (distance < threshold).float()

            n_correct += (prediction == label).sum().item()
            n_seen += label.size(0)

            said_same, said_diff = prediction == 1, prediction == 0
            is_same, is_diff = label == 1, label == 0
            tp += (said_same & is_same).sum().item()
            tn += (said_diff & is_diff).sum().item()
            fp += (said_same & is_diff).sum().item()
            fn += (said_diff & is_same).sum().item()

    accuracy = n_correct / n_seen * 100
    print_classification_report(tp, tn, fp, fn)
    return accuracy

def main():
    """Evaluate a saved checkpoint on a validation pair file and write diagnostic plots."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--val_data", required=True, help="Path to validation pairs npz file")
    # NOTE(review): --ckpt is required, so its default can never be used.
    cli.add_argument("--ckpt", required=True, help="Path to checkpoint file", default="./checkpoints/model_augTrue_blTrue_bs8_ep20_lr0.001_mg2.0.pt")
    cli.add_argument("--threshold", type=float, default=1.010204, help="Threshold for distance")
    options = cli.parse_args()

    os.makedirs("./outputs", exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 8

    # The decision threshold comes from --threshold; an earlier lookup of a
    # saved best-threshold file was disabled.
    ckpt_path = options.ckpt
    model = SiameseNetwork().to(device)
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state)
    print(f"Loaded model from {ckpt_path}")

    val_loader = DataLoader(SiameseDataset(options.val_data, root_dir="."),
                            batch_size=batch_size)

    evaluate_accuracy(model, val_loader, threshold=options.threshold, device=device)

    # Diagnostic figures, produced in a fixed order.
    figures = (
        (plot_distance_distribution, "./outputs/distance_hist.png"),
        (plot_roc_curve, "./outputs/roc_curve.png"),
        (plot_metrics_vs_threshold, "./outputs/metrics_vs_threshold.png"),
    )
    for make_plot, target in figures:
        make_plot(model, val_loader, device, save_path=target)

if __name__ == "__main__":
    main()



# inference_batch.py

In [None]:
import os
import torch
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
from project3_fingerprint_fvc2000.preprocess_old import normalize
from create_train_pairs import load_images_by_finger_tif, create_pairs
from siamese_model import SiameseNetwork
from utils import print_classification_report
from random import sample
import argparse


def inference_batch(model, pairs, labels, threshold=0.041, device="cpu", desc="Inferencing"):
    """Classify image pairs one at a time and print a classification report.

    Each image path is preprocessed with ``normalize``, embedded by
    ``model``, and the pair is predicted "same finger" when the embedding
    distance is below ``threshold``.

    Args:
        model: network returning one embedding per input image.
        pairs: sequence of (img1_path, img2_path).
        labels: sequence of 0/1 labels (1 = same finger), aligned with pairs.
        threshold: distance below which a pair counts as a match.
        device: torch device used for inference.
        desc: progress-bar caption.

    Returns:
        Accuracy in percent. Previously the value was computed and then
        discarded; returning it is backward-compatible for callers that
        ignore the result. Returns 0.0 when there are no pairs instead of
        raising ZeroDivisionError.
    """
    model.eval()
    tp = tn = fp = fn = 0

    with torch.no_grad():
        for (img1_path, img2_path), label in tqdm(zip(pairs, labels), total=len(pairs), desc=desc):
            # [H, W] float32 arrays -> [1, 1, H, W] tensors
            img1 = torch.from_numpy(normalize(img1_path)).unsqueeze(0).unsqueeze(0).to(device)
            img2 = torch.from_numpy(normalize(img2_path)).unsqueeze(0).unsqueeze(0).to(device)

            out1, out2 = model(img1, img2)
            dist = F.pairwise_distance(out1, out2).item()

            predicted_same = dist < threshold
            # Labels are plain 0/1 values; no tensor round trip is needed.
            actually_same = int(label) == 1

            if predicted_same and actually_same:
                tp += 1
            elif predicted_same:
                fp += 1
            elif actually_same:
                fn += 1
            else:
                tn += 1

    total = tp + tn + fp + fn
    if total == 0:
        print("No pairs to evaluate.")
        return 0.0

    accuracy = (tp + tn) / total * 100
    print_classification_report(tp, tn, fp, fn)
    return accuracy


def auto_calibrate_threshold(model, calib_pairs, calib_labels, device="cpu",
                             search_min=0.0, search_max=2.0, steps=120):
    """
    Scan distance thresholds on a labelled calibration set and keep the one
    with the highest F1 score.

    Every pair is preprocessed with ``normalize`` and embedded by ``model``;
    ``steps`` evenly spaced thresholds in [search_min, search_max] are then
    tried, predicting "same finger" when the distance is below the threshold.
    The first threshold reaching the maximum F1 wins.

    Returns: best_threshold, best_f1
    """
    model.eval()

    def _to_batch(path):
        # [H, W] array -> [1, 1, H, W] tensor on the target device
        return torch.from_numpy(normalize(path)).unsqueeze(0).unsqueeze(0).to(device)

    distances = []
    with torch.no_grad():
        for (path_a, path_b), _ in zip(calib_pairs, calib_labels):
            batch_a = _to_batch(path_a)
            batch_b = _to_batch(path_b)
            emb_a, emb_b = model(batch_a, batch_b)
            distances.append(F.pairwise_distance(emb_a, emb_b).item())

    distances = np.array(distances)
    truth = np.array(calib_labels)
    is_pos = truth == 1
    is_neg = truth == 0

    candidates = np.linspace(search_min, search_max, steps)
    best_t, best_f1 = candidates[0], 0.0

    for candidate in candidates:
        accepted = distances < candidate
        tp = np.sum(accepted & is_pos)
        fp = np.sum(accepted & is_neg)
        fn = np.sum(~accepted & is_pos)
        # The small epsilons keep the ratios defined when a count is zero.
        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        score = 2 * precision * recall / (precision + recall + 1e-8)

        if score > best_f1:
            best_t, best_f1 = candidate, score

    return best_t, best_f1


def main():
    """Run pairwise verification over a folder of .tif fingerprints.

    All positive and negative pairs are built from the folder. The decision
    threshold is either calibrated automatically on a subset of the pairs
    (--auto_threshold) or taken from an explicit list (--thresholds); in the
    latter case the full evaluation is repeated once per threshold.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--inference_data", type=str, required=True, help="Path to inference images folder")
    parser.add_argument("--ckpt", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--thresholds", nargs='+', type=float,
                    help="List of thresholds to try, e.g., 0.83 0.85 0.87 0.90")
    # argparse %-formats help strings, so a literal percent sign must be
    # written as '%%'; the unescaped "20% data" broke help rendering. The
    # text now also states the fraction the code actually uses (12 %).
    parser.add_argument("--auto_threshold", action="store_true",
                    help="Use auto-calibrated threshold (about 12%% of the pairs are used for calibration)")
    args = parser.parse_args()

    # ------------ paths ------------
    inference_data_path = args.inference_data
    model_path = args.ckpt

    # ------------ device / model ------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SiameseNetwork().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    print(f"Loaded model from {model_path}")

    # ------------ build pairs for this DB ------------
    finger_dict = load_images_by_finger_tif(inference_data_path)
    pairs, labels = create_pairs(finger_dict, sorted(finger_dict.keys()),
                                 augment_positive=False, num_augments=0,
                                 balance_negatives=False)

    if len(pairs) == 0:
        # Without pairs the ratio computations below would divide by zero.
        print(f"ERROR: No image pairs could be built from {inference_data_path}.")
        return

    if args.auto_threshold:
        print("\n[INFO] Auto threshold mode enabled.")
        pos_idx = [i for i, l in enumerate(labels) if l == 1]
        neg_idx = [i for i, l in enumerate(labels) if l == 0]
        np.random.seed(42)
        np.random.shuffle(pos_idx)
        np.random.shuffle(neg_idx)

        # Calibration subset: 12 % of all pairs, keeping the positive /
        # negative ratio of the full pair set. (An earlier variant that
        # calibrated on a class-balanced 20 % subset was disabled.)
        total_pairs = len(labels)
        pos_ratio = len(pos_idx) / total_pairs

        cap = int(0.12 * total_pairs)
        n_pos_calib = int(cap * pos_ratio)
        n_neg_calib = cap - n_pos_calib

        n_pos_calib = min(n_pos_calib, len(pos_idx))
        n_neg_calib = min(n_neg_calib, len(neg_idx))

        calib_idx = pos_idx[:n_pos_calib] + neg_idx[:n_neg_calib]
        np.random.shuffle(calib_idx)

        calib_pairs  = [pairs[i]  for i in calib_idx]
        calib_labels = [labels[i] for i in calib_idx]

        print(f"[Calib] Positive={sum(calib_labels)} | Negative={len(calib_labels)-sum(calib_labels)}")
        threshold, f1_calib = auto_calibrate_threshold(model, calib_pairs, calib_labels, device)
        print(f"Auto-calibrated threshold = {threshold:.4f}  (F1 on calib = {f1_calib:.4f})")

        # ------------ final inference ------------
        # The calibration pairs are part of this evaluation: it runs on
        # 100 % of the pairs.
        n_pos = sum(labels)                  # label == 1
        n_neg = len(labels) - n_pos          # label == 0
        print(f"[All pairs] Positive={n_pos}  |  Negative={n_neg}  "
            f"({n_pos/len(labels):.2%} positive)")
        inference_batch(model, pairs, labels,
                    threshold=threshold, device=device, desc="Infer-100%")
        print(f"[Inference] Total pairs: {len(pairs)}")

    elif args.thresholds:
        print("\n[INFO] Manual threshold mode enabled.")
        for t in args.thresholds:
            print(f"\n=== threshold {t} ===")
            inference_batch(model, pairs, labels, t, device, desc=f"@{t}")

    else:
        print("ERROR: You must either use --auto_threshold or provide --thresholds values.")
    
if __name__ == "__main__":
    main()


# utils.py

In [None]:
import torch
import numpy as np
from torch.utils.data import Dataset
import torch.nn as nn
import matplotlib.pyplot as plt
import os
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score

class SiameseDataset(Dataset):
    """Dataset of image pairs with one label per pair.

    The file at ``pairs_file`` must expose two entries, ``"pairs"`` and
    ``"labels"``. Every item of ``pairs`` unpacks into two elements; each
    element is either a string path to a ``.npy`` file (joined onto
    ``root_dir``) or an in-memory array that is used as is.

    ``__getitem__`` returns ``(img1, img2, label)`` where both images have a
    leading dimension added by ``unsqueeze(0)`` and ``label`` is a float tensor.
    """

    def __init__(self, pairs_file, root_dir=""):
        # NOTE(review): allow_pickle=True can execute arbitrary code while
        # unpickling; only load pair files that come from a trusted source.
        data = np.load(pairs_file, allow_pickle=True)
        try:
            self.pairs = data["pairs"]
            self.labels = data["labels"]
        finally:
            # An .npz archive keeps its file handle open until it is closed.
            # Plain arrays have no close(), hence the getattr guard.
            close = getattr(data, "close", None)
            if close is not None:
                close()
        self.root_dir = root_dir

    def __len__(self):
        return len(self.pairs)

    def _load_image(self, item):
        """Return the image for one pair element, loading it from disk if it is a path."""
        if isinstance(item, str):
            return np.load(os.path.join(self.root_dir, item))
        return item

    def __getitem__(self, idx):
        first, second = self.pairs[idx]

        img1 = torch.tensor(self._load_image(first)).unsqueeze(0)  # [1, H, W]
        img2 = torch.tensor(self._load_image(second)).unsqueeze(0)
        label = torch.tensor(self.labels[idx]).float()

        return img1, img2, label


class ContrastiveLoss(nn.Module):
    """Contrastive loss on the Euclidean distance between two embeddings.

    Pairs labelled 1 (same finger) are penalised by their squared distance.
    Pairs labelled 0 (different fingers) are penalised only while their
    distance is below ``margin``, by the squared shortfall.
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, out1, out2, label):
        """Return the mean contrastive loss over the batch."""
        distance = torch.nn.functional.pairwise_distance(out1, out2)

        # label = 1: any separation between the embeddings is penalised.
        pull_term = label * distance.pow(2)

        # label = 0: only the part of the margin not yet reached is penalised.
        shortfall = torch.clamp(self.margin - distance, min=0.0)
        push_term = (1 - label) * shortfall.pow(2)

        return (pull_term + push_term).mean()




def plot_loss(train_losses, val_losses, save_path=None):
    """Draw the training and validation loss curves on a single figure.

    Each sequence is plotted against its index, with the x axis labelled
    "Epoch". The figure is written to ``save_path`` when one is given;
    otherwise it is displayed with ``plt.show()``.
    """
    plt.figure()
    for series, legend_name in ((train_losses, "Train Loss"),
                                (val_losses, "Validation Loss")):
        plt.plot(series, label=legend_name)
    plt.title("Training vs Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()

    if not save_path:
        plt.show()
        return

    plt.savefig(save_path)
    print(f"Loss plot saved to {save_path}")


def save_checkpoint(model, path="checkpoints/best_model.pt"):
    """Save ``model.state_dict()`` to ``path``.

    The parent directory of ``path`` is created when it does not exist yet.
    """
    parent = os.path.dirname(path)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")




def plot_distance_distribution(model, dataloader, device="cpu", save_path=None):
    """Plot histograms of embedding distances, split by pair label.

    ``model`` is switched to eval mode and called as ``model(img1, img2)``; it
    must return two embeddings whose pairwise distance is computed per pair.
    Distances of pairs whose label equals 1 go into the "positive" histogram,
    all others into the "negative" one. The figure is saved to ``save_path``
    when given, otherwise shown.
    """
    import torch.nn.functional as F
    model.eval()

    same_finger = []
    other_finger = []

    with torch.no_grad():
        for left, right, targets in dataloader:
            left = left.to(device)
            right = right.to(device)
            targets = targets.to(device)

            emb_left, emb_right = model(left, right)
            batch_dist = F.pairwise_distance(emb_left, emb_right)

            # Route every distance into the bucket that matches its label.
            for value, target in zip(batch_dist, targets):
                bucket = same_finger if target == 1 else other_finger
                bucket.append(value.item())

    plt.figure(figsize=(8,5))
    for values, legend_name in ((same_finger, "Positive (same finger)"),
                                (other_finger, "Negative (different finger)")):
        plt.hist(values, bins=50, alpha=0.6, label=legend_name)
    plt.title("Distance Distribution")
    plt.xlabel("Distance")
    plt.ylabel("Count")
    plt.legend()

    if not save_path:
        plt.show()
        return

    plt.savefig(save_path)
    print(f" Distance histogram saved to {save_path}")


def plot_roc_curve(model, dataloader, device="cpu", save_path=None):
    """Plot the ROC curve of the model's pairwise distances and report its AUC.

    ``model`` is switched to eval mode and called as ``model(img1, img2)``; the
    pairwise distance between the two returned embeddings is collected for
    every pair together with its label. The figure is saved to ``save_path``
    when given, otherwise shown.
    """
    import torch.nn.functional as F
    model.eval()

    distances = []
    targets = []

    with torch.no_grad():
        for left, right, batch_labels in dataloader:
            left = left.to(device)
            right = right.to(device)
            batch_labels = batch_labels.to(device)

            emb_left, emb_right = model(left, right)
            batch_dist = F.pairwise_distance(emb_left, emb_right)

            distances.extend(batch_dist.cpu().numpy())
            targets.extend(batch_labels.cpu().numpy())

    # roc_curve treats larger scores as more positive, so the distance is
    # negated before it is used as the score.
    scores = -1 * np.array(distances)
    fpr, tpr, _ = roc_curve(targets, scores)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(6,6))
    plt.plot(fpr, tpr, label=f"ROC curve (AUC = {roc_auc:.2f})")
    # Dashed diagonal as a visual reference line.
    plt.plot([0, 1], [0, 1], linestyle='--', color='gray')
    plt.title("ROC Curve")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend(loc="lower right")

    if not save_path:
        plt.show()
        return

    plt.savefig(save_path)
    print(f" ROC curve saved to {save_path}")


def print_classification_report(tp, tn, fp, fn):
    """Print accuracy, precision, recall and F1 derived from confusion counts.

    Args:
        tp, tn, fp, fn: counts of true positives, true negatives, false
            positives and false negatives.

    Every ratio whose denominator is zero is reported as 0 rather than raising
    ``ZeroDivisionError``; this includes accuracy when all four counts are 0.
    """
    total = tp + tn + fp + fn
    # Guarded like the other metrics so an empty evaluation set does not crash.
    acc = (tp + tn) / total * 100 if total > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    print("\n Classification Report:")
    print(f"  Accuracy  : {acc:.2f}%")
    print(f"  Precision : {precision:.4f}")
    print(f"  Recall    : {recall:.4f}")
    print(f"   F1 Score  : {f1:.4f}")
    print(f"  TP={tp} | TN={tn} | FP={fp} | FN={fn}")

def plot_metrics_vs_threshold(model, dataloader, device="cpu", save_path=None):
    """Sweep distance thresholds, plot the metrics, and return the best-F1 threshold.

    ``model`` is switched to eval mode and called as ``model(img1, img2)``; the
    pairwise distance between the two returned embeddings is collected for
    every pair. A pair is predicted positive when its distance is below the
    threshold. Accuracy, precision, recall and F1 are evaluated on 50 evenly
    spaced thresholds between 0.7 and 1.1.

    Side effects:
        * The plot is saved to ``save_path`` when given, otherwise shown.
        * The best threshold is always written to
          ``./outputs/best_threshold.txt``.

    Returns:
        The threshold with the highest F1 score (the first one on ties).
    """
    model.eval()
    thresholds = np.linspace(0.7, 1.1, 50)
    all_distances = []
    all_labels = []

    print("Extracting embeddings...")
    with torch.no_grad():
        for img1, img2, label in tqdm(dataloader, desc="Forward pass"):
            img1, img2 = img1.to(device), img2.to(device)
            out1, out2 = model(img1, img2)
            dist = F.pairwise_distance(out1, out2)
            all_distances.extend(dist.cpu().numpy())
            # The label is never moved to `device`, so it converts directly.
            all_labels.extend(label.numpy())

    all_distances = np.array(all_distances)
    all_labels = np.array(all_labels)

    accuracies = []
    precisions = []
    recalls = []
    f1s = []

    print("Calculating metrics for thresholds...")
    for t in tqdm(thresholds, desc="Threshold"):
        # Smaller distance means "same finger", so predict 1 below the threshold.
        preds = (all_distances < t).astype(int)
        accuracies.append((preds == all_labels).mean())
        precisions.append(precision_score(all_labels, preds, zero_division=0))
        recalls.append(recall_score(all_labels, preds, zero_division=0))
        f1s.append(f1_score(all_labels, preds, zero_division=0))

    best_idx = np.argmax(f1s)
    best_threshold = thresholds[best_idx]
    best_f1 = f1s[best_idx]
    print(f"\n Best threshold = {best_threshold:.6f} → F1 Score = {best_f1:.6f}")

    # Plot all metrics
    plt.figure(figsize=(10, 6))
    plt.plot(thresholds, accuracies, label="Accuracy")
    plt.plot(thresholds, precisions, label="Precision")
    plt.plot(thresholds, recalls, label="Recall")
    plt.plot(thresholds, f1s, label="F1 Score")
    plt.xlabel("Threshold")
    plt.ylabel("Metric Value")
    plt.title("Metrics vs Threshold")
    plt.legend()
    plt.grid(True)

    if save_path:
        save_dir = os.path.dirname(save_path)
        # A bare filename has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError, so only create the directory when there is one.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path)
        print(f"Metrics plot saved to {save_path}")
    else:
        plt.show()

    # The threshold file is written regardless of save_path, so its directory
    # has to be ensured separately.
    os.makedirs("./outputs", exist_ok=True)
    with open("./outputs/best_threshold.txt", "w") as f:
        f.write(f"{best_threshold:.6f}")
    print("Best threshold saved to outputs/best_threshold.txt")

    return best_threshold
