PHASE 1: Baseline Metric Selection for Pre-training (K-Fold Comparison)

- This notebook runs K-Fold cross-validation to compare fixed distance metrics (Euclidean, Cosine, Manhattan) against a learnable distance network (DistanceNet), and prepares the best pre-trained feature extractor for meta-training.

- The environment setup and file paths have been updated to clone the repository and use paths relative to the cloned repo and Kaggle input structure.

=== Environment Setup ===

In [None]:
# Ensure the latest version of the code is used:
#   1. delete any previous checkout of the repository,
#   2. clone it again,
#   3. switch the working directory into the clone (later cells take the
#      current working directory as the repo root).
!rm -rf Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update
!git clone https://github.com/trongjhuongwr/Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update.git
%cd Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update

Cloning into 'Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update'...
remote: Enumerating objects: 249, done.[K
remote: Counting objects: 100% (249/249), done.[K
remote: Compressing objects: 100% (175/175), done.[K
remote: Total 249 (delta 126), reused 190 (delta 70), pack-reused 0 (from 0)[K
Receiving objects: 100% (249/249), 1.10 MiB | 9.35 MiB/s, done.
Resolving deltas: 100% (126/126), done.
/kaggle/working/Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update


In [None]:
# !pip install -q ipywidgets
# !jupyter nbextension enable --py widgetsnbextension --sys-prefix

Enabling notebook extension jupyter-js-widgets/extension...
Paths used for configuration of notebook: 
    	/usr/etc/jupyter/nbconfig/notebook.json
Paths used for configuration of notebook: 
    	
      - Validating: [32mOK[0m
Paths used for configuration of notebook: 
    	/usr/etc/jupyter/nbconfig/notebook.json


# === Imports and repo path setup ===

In [None]:
import os
import sys
import json
import random
import re
from itertools import combinations
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm, trange
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.transforms as transforms

# Make repo root importable. The current working directory is taken as the
# repo root. The append is guarded so that re-running this cell does not keep
# adding duplicate entries to sys.path.
REPO_ROOT = os.path.abspath(os.getcwd())
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)
print("Repo root:", REPO_ROOT)

# Project imports
from utils.helpers import load_config
from models.Triplet_Siamese_Similarity_Network import tSSN
from models.feature_extractor import ResNetFeatureExtractor
from losses.triplet_loss import TripletLoss
from dataloader.tSSN_trainloader import SignaturePretrainDataset

print("Project modules imported.")

Repo root: /kaggle/working/Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update
Project modules imported.


In [None]:
# Seed every random number generator used here (Python, NumPy, PyTorch CPU
# and, when CUDA is available, all CUDA devices) with one shared value.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# cuDNN: request deterministic behaviour and switch off benchmark mode.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

print(f"Manual seed set to {SEED}")

Manual seed set to 42


In [None]:
# === Helpers: filename -> user_id, SignaturePairDataset, collate function ===
def _get_user_id_from_filename(filename: str):
    """
    Extract user ID from filename (tries CEDAR and BHSig filename patterns).
    Returns int user id or None.
    """
    if filename is None:
        return None
    # CEDAR pattern: something_like_XX_... (example: sig_001_...)
    m = re.search(r'_(\d+)_', filename)
    if m:
        return int(m.group(1))
    # BHSig pattern: something-like-123-...
    m = re.search(r'-(\d+)-', filename)
    if m:
        return int(m.group(1))
    return None

from torch.utils.data import Dataset

class SignaturePairDataset(Dataset):
    """
    Evaluation dataset of signature PAIRS for a chosen set of users.

    For every selected user two kinds of pairs are built:
      * genuine-genuine -> label 1 (every unordered pair of genuine images)
      * genuine-forgery -> label 0 (every genuine image with every forgery)

    Directory listings are sorted before use so that the order of
    `pairs` is reproducible (os.listdir returns entries in arbitrary order).

    Args:
        org_dir: directory holding genuine signature images.
        forg_dir: directory holding forged signature images.
        user_ids: iterable of user IDs to include.
        transform: optional callable applied to each loaded PIL image.

    Attributes:
        pairs: list of (path1, path2, label) tuples.
        user_map: {user_id: {'genuine': [paths], 'forged': [paths]}}.
    """

    # File extensions (lower-case) that are treated as images.
    SUPPORTED_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp')

    def __init__(self, org_dir: str, forg_dir: str, user_ids, transform=None):
        self.transform = transform
        self.pairs = []
        self.user_map = {}
        user_ids_set = set(user_ids)

        # Genuine images create the per-user entries; forgeries are only
        # attached to users that already have at least one genuine image.
        self._collect(org_dir, 'genuine', user_ids_set, create_missing=True)
        self._collect(forg_dir, 'forged', user_ids_set, create_missing=False)

        # build pairs
        for lists in self.user_map.values():
            genuines = lists['genuine']
            forgeries = lists['forged']
            self.pairs.extend((p1, p2, 1) for p1, p2 in combinations(genuines, 2))
            self.pairs.extend((g, fg, 0) for g in genuines for fg in forgeries)

        if not self.pairs:
            print(f"Warning: No pairs created for users: {list(user_ids)[:10]} (check folders)")

    def _collect(self, directory, kind, wanted_ids, create_missing):
        """Add the image paths of `directory` to user_map[uid][kind] for wanted users."""
        if not os.path.isdir(directory):
            return
        for fname in sorted(os.listdir(directory)):
            if not fname.lower().endswith(self.SUPPORTED_EXTENSIONS):
                continue
            uid = _get_user_id_from_filename(fname)
            if uid not in wanted_ids:
                continue
            if uid not in self.user_map:
                if not create_missing:
                    continue
                self.user_map[uid] = {'genuine': [], 'forged': []}
            self.user_map[uid][kind].append(os.path.join(directory, fname))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """
        Return (img1, img2, label) for pair `idx`, with label as a float32 tensor.

        Best-effort loading: if anything goes wrong the error is printed and
        None is returned, so a collate function that drops None samples can
        skip the pair instead of aborting the whole run.
        """
        path1, path2, label = self.pairs[idx]
        try:
            # The context managers close the source files even if the
            # conversion fails; convert() hands back independent images.
            with Image.open(path1) as raw1, Image.open(path2) as raw2:
                img1 = raw1.convert('L')
                img2 = raw2.convert('L')
            if self.transform is not None:
                img1 = self.transform(img1)
                img2 = self.transform(img2)
            return img1, img2, torch.tensor(label, dtype=torch.float32)
        except Exception as e:
            print(f"Error loading pair ({path1}, {path2}): {e}")
            return None

def collate_fn_skip_none(batch):
    """
    Collate a batch after dropping the samples that are None.

    If no sample is left, a tuple of three empty tensors is returned as a
    placeholder that callers can recognise; otherwise the remaining samples
    are passed to PyTorch's default collate function.
    """
    valid_samples = list(filter(lambda sample: sample is not None, batch))
    if len(valid_samples) == 0:
        return tuple(torch.empty(0) for _ in range(3))
    return torch.utils.data.dataloader.default_collate(valid_samples)


In [None]:
# === EER helper ===
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score

def calculate_far_frr_eer(true_labels, distances):
    """
    Estimate the Equal Error Rate (EER) of a distance-based verifier.

    A sample is accepted (predicted 1) when its distance is strictly below
    the threshold. 500 evenly spaced thresholds covering the finite
    distances are swept; FAR is the accepted share of label-0 samples and
    FRR the rejected share of label-1 samples. The threshold with the
    smallest |FAR - FRR| is selected (the first one on ties).

    Returns:
        (eer, threshold), where eer is the mean of FAR and FRR at the
        selected threshold. (1.0, nan) is returned when there is no finite
        distance or fewer than two distinct labels among the finite entries.
    """
    labels = np.asarray(true_labels)
    dists = np.asarray(distances)

    # Non-finite distances (NaN / inf) are ignored.
    finite = np.isfinite(dists)
    if not finite.any():
        return 1.0, np.nan
    labels, dists = labels[finite], dists[finite]
    if len(dists) == 0 or len(np.unique(labels)) < 2:
        return 1.0, np.nan

    genuine = dists[labels == 1]
    impostor = dists[labels == 0]
    thresholds = np.linspace(np.min(dists) - 1e-6, np.max(dists) + 1e-6, num=500)

    def _rates(t):
        # false accepts: label-0 samples below t; false rejects: label-1 samples not below t
        false_accepts = np.count_nonzero(impostor < t)
        false_rejects = genuine.size - np.count_nonzero(genuine < t)
        far = false_accepts / impostor.size if impostor.size > 0 else 0.0
        frr = false_rejects / genuine.size if genuine.size > 0 else 0.0
        return far, frr

    rates = np.array([_rates(t) for t in thresholds])
    far_arr, frr_arr = rates[:, 0], rates[:, 1]
    best = np.nanargmin(np.abs(far_arr - frr_arr))
    eer = (far_arr[best] + frr_arr[best]) / 2.0
    return eer, thresholds[best]

In [None]:
# === Define the Distance Learning Network (DistanceNet) ===
class DistanceNet(nn.Module):
    """
    Small MLP that scores how far apart two embeddings are.

    The two embeddings are concatenated, sent through one hidden layer with
    ReLU, projected to a single value and passed through a sigmoid, so each
    returned distance is in the sigmoid's (0, 1) range.
    """
    def __init__(self, input_dim, hidden_dim=512):
        super().__init__()
        # input_dim is the width of ONE embedding; the first layer receives
        # both embeddings side by side.
        pair_dim = 2 * input_dim
        self.fc1 = nn.Linear(pair_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, 1)
        # The sigmoid keeps the output positive. ReLU or abs() on the final
        # value would be alternatives if unbounded distances were wanted.
        self.sigmoid = nn.Sigmoid()

    def forward(self, emb1, emb2):
        pair = torch.cat((emb1, emb2), dim=1)      # (B, 2 * input_dim)
        hidden = self.relu(self.fc1(pair))
        score = self.fc2(hidden)                    # (B, 1)
        return self.sigmoid(score).squeeze(-1)      # (B,)

class DistanceNetTripletLoss(nn.Module):
    """
    Margin-based triplet loss whose distances come from a DistanceNet.

    The DistanceNet is created here and stored as the `distance_net`
    attribute, so its weights are parameters of this loss module.
    """
    def __init__(self, margin, input_dim, hidden_dim=512):
        super().__init__()
        self.margin = margin
        self.distance_net = DistanceNet(input_dim, hidden_dim)

    def forward(self, anchor, positive, negative):
        # Hinge on (anchor-positive distance) - (anchor-negative distance) + margin,
        # averaged over the batch.
        pos_dist = self.distance_net(anchor, positive)
        neg_dist = self.distance_net(anchor, negative)
        hinge = F.relu(pos_dist - neg_dist + self.margin)
        return hinge.mean()

print("Defined DistanceNet and DistanceNetTripletLoss.")

Defined DistanceNet and DistanceNetTripletLoss.


=== Training epoch and pair evaluation (STATIC vs BASIC LEARNABLE) ===

In [None]:
# --- FUNCTIONS FOR STATIC METRICS (euclidean, cosine, manhattan) ---
def train_epoch(model, dataloader, loss_fn, optimizer, device):
    """
    Train `model` for one pass over a triplet dataloader.

    Args:
        model: called as model(anchor, positive, negative); must return the
            three embeddings.
        dataloader: iterable of (anchor, positive, negative) batches. Empty
            placeholder batches (three items, the first with zero elements,
            as emitted by collate_fn_skip_none) are skipped.
        loss_fn: callable mapping the three embeddings to a scalar loss.
        optimizer: stepped once per processed batch.
        device: device the batch tensors are moved to.

    Returns:
        Mean loss over the processed batches as a float, or 0.0 when no
        batch was processed.
    """
    model.train()
    total_loss = 0.0
    count = 0
    for item in dataloader:
        # The placeholder may arrive as a tuple or as a list: with
        # pin_memory=True the DataLoader rebuilds plain tuples as lists, so a
        # tuple-only check would let empty batches through to the model.
        if (isinstance(item, (tuple, list)) and len(item) == 3
                and getattr(item[0], 'nelement', lambda: 1)() == 0):
            continue
        anchor, positive, negative = item
        anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
        optimizer.zero_grad()
        a_emb, p_emb, n_emb = model(anchor, positive, negative)
        loss = loss_fn(a_emb, p_emb, n_emb)
        loss.backward()
        optimizer.step()
        total_loss += float(loss.item())
        count += 1
    return total_loss / count if count > 0 else 0.0


def evaluate_on_pairs(model, dataloader, device, distance_mode='euclidean'):
    """
    Evaluate a model on labelled image pairs with a fixed distance metric.

    Args:
        model: object exposing `feature_extractor` (looked up on
            `model.module` when the model is wrapped in nn.DataParallel).
        dataloader: iterable of (img1, img2, label) batches. Empty
            placeholder batches (three items, the first with zero elements)
            are skipped.
        device: device the images are moved to.
        distance_mode: 'euclidean', 'cosine' (1 - cosine similarity) or
            'manhattan'.

    Returns:
        (eer, roc_auc, accuracy, precision, recall, f1). Predictions use the
        EER threshold (distance < threshold -> 1). When nothing could be
        evaluated the worst-case tuple (1.0, 0.0, 0.0, 0.0, 0.0, 0.0) is
        returned. When the EER helper reports a NaN threshold (no finite
        distance or fewer than two classes), roc_auc is set to 0.0 and all
        predictions to 0.

    Raises:
        ValueError: if `distance_mode` is not one of the supported values.
    """
    distance_fns = {
        'euclidean': lambda a, b: F.pairwise_distance(a, b, p=2),
        'cosine': lambda a, b: 1.0 - F.cosine_similarity(a, b, dim=1),
        'manhattan': lambda a, b: F.pairwise_distance(a, b, p=1),
    }
    # Fail early with a clear message instead of a NameError inside the loop.
    if distance_mode not in distance_fns:
        raise ValueError(
            f"Unsupported distance_mode: {distance_mode!r}. Expected one of {sorted(distance_fns)}."
        )
    compute_distance = distance_fns[distance_mode]

    model.eval()
    all_labels = []
    all_distances = []
    feature_extractor = model.module.feature_extractor if isinstance(model, nn.DataParallel) else model.feature_extractor
    with torch.no_grad():
        for item in dataloader:
            # The placeholder may arrive as a tuple or as a list: with
            # pin_memory=True the DataLoader rebuilds plain tuples as lists.
            if (isinstance(item, (tuple, list)) and len(item) == 3
                    and getattr(item[0], 'nelement', lambda: 1)() == 0):
                continue
            img1, img2, label = item
            img1, img2 = img1.to(device), img2.to(device)
            f1 = feature_extractor(img1)
            f2 = feature_extractor(img2)
            d = compute_distance(f1, f2)
            all_distances.extend(d.cpu().numpy())
            all_labels.extend(label.cpu().numpy())

    if not all_labels or not all_distances:
        return 1.0, 0.0, 0.0, 0.0, 0.0, 0.0

    labels_arr = np.asarray(all_labels)
    dist_arr = np.asarray(all_distances)
    eer, thr = calculate_far_frr_eer(labels_arr, dist_arr)
    if np.isnan(thr):
        # Degenerate evaluation set: ROC AUC is undefined here and
        # roc_auc_score would raise, so use the same worst-case value as the
        # "nothing evaluated" return above.
        preds = np.zeros_like(labels_arr)
        roc_auc = 0.0
    else:
        preds = (dist_arr < thr).astype(int)
        # Score = -distance; non-finite distances are left out, as they are
        # in the EER computation.
        finite = np.isfinite(dist_arr)
        roc_auc = roc_auc_score(labels_arr[finite], -dist_arr[finite])
    acc = accuracy_score(labels_arr, preds)
    precision = precision_score(labels_arr, preds, zero_division=0)
    recall = recall_score(labels_arr, preds, zero_division=0)
    f1 = f1_score(labels_arr, preds, zero_division=0)
    return eer, roc_auc, acc, precision, recall, f1


# --- FUNCTIONS FOR LEARNABLE DISTANCE_NET ---
def train_epoch_distnet(model, dataloader, distnet_loss_fn, optimizer, device):
    """
    Train for one pass over a triplet dataloader using a learnable-distance loss.

    Both `model` and `distnet_loss_fn` are switched to train mode; the loss
    module computes its distances internally from the three embeddings.

    Args:
        model: called as model(anchor, positive, negative); must return the
            three embeddings.
        dataloader: iterable of (anchor, positive, negative) batches. Empty
            placeholder batches (three items, the first with zero elements)
            are skipped.
        distnet_loss_fn: loss module called with the three embeddings.
        optimizer: stepped once per processed batch.
        device: device the batch tensors are moved to.

    Returns:
        Mean loss over the processed batches as a float, or 0.0 when no
        batch was processed.
    """
    model.train()
    distnet_loss_fn.train()  # Set both model and loss module to .train()
    total_loss = 0.0
    count = 0
    for item in dataloader:
        # The placeholder may arrive as a tuple or as a list: with
        # pin_memory=True the DataLoader rebuilds plain tuples as lists, so a
        # tuple-only check would let empty batches through to the model.
        if (isinstance(item, (tuple, list)) and len(item) == 3
                and getattr(item[0], 'nelement', lambda: 1)() == 0):
            continue
        anchor, positive, negative = item
        anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
        optimizer.zero_grad()
        # 1. Extract embeddings
        a_emb, p_emb, n_emb = model(anchor, positive, negative)
        # 2. Compute loss (the loss module produces the distances itself)
        loss = distnet_loss_fn(a_emb, p_emb, n_emb)
        loss.backward()
        optimizer.step()
        total_loss += float(loss.item())
        count += 1
    return total_loss / count if count > 0 else 0.0


def evaluate_on_pairs_distnet(model, dataloader, distnet_loss_fn, device):
    """
    Evaluate a model on labelled image pairs using the learnable distance network.

    Args:
        model: object exposing `feature_extractor` (looked up on
            `model.module` when the model is wrapped in nn.DataParallel).
        dataloader: iterable of (img1, img2, label) batches. Empty
            placeholder batches (three items, the first with zero elements)
            are skipped.
        distnet_loss_fn: loss module whose `distance_net` attribute maps two
            embedding batches to distances.
        device: device the images are moved to.

    Returns:
        (eer, roc_auc, accuracy, precision, recall, f1). Predictions use the
        EER threshold (distance < threshold -> 1). When nothing could be
        evaluated the worst-case tuple (1.0, 0.0, 0.0, 0.0, 0.0, 0.0) is
        returned. When the EER helper reports a NaN threshold (no finite
        distance or fewer than two classes), roc_auc is set to 0.0 and all
        predictions to 0.
    """
    model.eval()
    distnet_loss_fn.eval()  # Set both model and loss module to .eval()
    all_labels = []
    all_distances = []
    # Get feature extractor and distance_net
    feature_extractor = model.module.feature_extractor if isinstance(model, nn.DataParallel) else model.feature_extractor
    distance_net = distnet_loss_fn.distance_net  # Retrieve distance_net from loss function
    with torch.no_grad():
        for item in dataloader:
            # The placeholder may arrive as a tuple or as a list: with
            # pin_memory=True the DataLoader rebuilds plain tuples as lists.
            if (isinstance(item, (tuple, list)) and len(item) == 3
                    and getattr(item[0], 'nelement', lambda: 1)() == 0):
                continue
            img1, img2, label = item
            img1, img2 = img1.to(device), img2.to(device)
            # 1. Extract embeddings
            f1 = feature_extractor(img1)
            f2 = feature_extractor(img2)
            # 2. Compute distance using distance_net
            d = distance_net(f1, f2)
            all_distances.extend(d.cpu().numpy())
            all_labels.extend(label.cpu().numpy())

    if not all_labels or not all_distances:
        return 1.0, 0.0, 0.0, 0.0, 0.0, 0.0

    labels_arr = np.asarray(all_labels)
    dist_arr = np.asarray(all_distances)
    eer, thr = calculate_far_frr_eer(labels_arr, dist_arr)
    if np.isnan(thr):
        # Degenerate evaluation set: ROC AUC is undefined here and
        # roc_auc_score would raise, so use the same worst-case value as the
        # "nothing evaluated" return above.
        preds = np.zeros_like(labels_arr)
        roc_auc = 0.0
    else:
        preds = (dist_arr < thr).astype(int)  # Assume smaller distance = genuine
        # Score = -distance; non-finite distances are left out, as they are
        # in the EER computation.
        finite = np.isfinite(dist_arr)
        roc_auc = roc_auc_score(labels_arr[finite], -dist_arr[finite])
    acc = accuracy_score(labels_arr, preds)
    precision = precision_score(labels_arr, preds, zero_division=0)
    recall = recall_score(labels_arr, preds, zero_division=0)
    f1 = f1_score(labels_arr, preds, zero_division=0)
    return eer, roc_auc, acc, precision, recall, f1


print("Defined training and evaluation functions (for both Static and Learnable DistanceNet).")

Defined training and evaluation functions (for both Static and Learnable DistanceNet).


In [9]:
NUM_SPLITS = 5
BASE_DATA_DIR = '/kaggle/input/cedardataset/signatures'
SPLIT_FILES_DIR = '/kaggle/working/Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update/scripts/prepare_kfold_splits'

print("Generating K-Fold split files...")
os.makedirs(SPLIT_FILES_DIR, exist_ok=True)

script_path = 'scripts/prepare_kfold_splits.py'
command = f"python {script_path} --base_data_dir {BASE_DATA_DIR} --output_dir {SPLIT_FILES_DIR} --seed {SEED} --num_splits {NUM_SPLITS}"

print(f"Running command: {command}")
!{command}

created_files = os.listdir(SPLIT_FILES_DIR)
print(f"Generated files in {SPLIT_FILES_DIR}: {created_files}")
if len(created_files) != NUM_SPLITS:
    print(f"Warning: Expected {NUM_SPLITS} split files, but found {len(created_files)}.")
else:
    print("K-Fold split files generated successfully.")

Generating K-Fold split files...
Running command: python scripts/prepare_kfold_splits.py --base_data_dir /kaggle/input/cedardataset/signatures --output_dir /kaggle/working/Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update/scripts/prepare_kfold_splits --seed 42 --num_splits 5
--- Starting CEDAR K-Fold Split Generation ---
Scanning genuine signatures in: /kaggle/input/cedardataset/signatures/full_org
Scanning forged signatures in: /kaggle/input/cedardataset/signatures/full_forg
Found 55 unique users with genuine signatures.
Splitting users into 5 folds...
  Fold 1: 44 train users, 11 test users. Saved to: /kaggle/working/Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update/scripts/prepare_kfold_splits/cedar_meta_split_fold_0.json
  Fold 2: 44 train users, 11 test users. Saved to: /kaggle/working/Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update/scripts/prepare_

=== Configuration and paths ===

In [None]:
# Name of the Kaggle input dataset that holds the signatures (edit here if
# your dataset is named differently).
KAGGLE_CEDAR_DATASET_NAME = 'cedardataset'  # change if needed
KAGGLE_BASE = os.path.join('/kaggle/input', KAGGLE_CEDAR_DATASET_NAME)

DATA_DIR = os.path.join(KAGGLE_BASE, 'signatures')
ORG_DIR, FORG_DIR = (os.path.join(DATA_DIR, sub) for sub in ('full_org', 'full_forg'))

# Config file shipped with the repository
CONFIG_FILE_PATH = os.path.join(REPO_ROOT, 'configs', 'config_tSSN.yaml')

# Folders that may contain the K-Fold split files, in order of preference.
POSSIBLE_SPLIT_DIRS = [
    os.path.join(REPO_ROOT, 'scripts', 'prepare_kfold_splits'),
    os.path.join(REPO_ROOT, 'prepare_kfold_splits'),
    '/kaggle/working/Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update/scripts/prepare_kfold_splits',
    '/kaggle/working/kfold_splits',
    '/mnt/data'  # alternative location where split files may have been saved
]

# Use the first candidate that exists on disk; if none does, fall back to
# REPO_ROOT/scripts/kfold_splits (the split-loading step reports an error
# when no split files are found).
SPLITS_DIR = next(
    (candidate for candidate in POSSIBLE_SPLIT_DIRS if os.path.isdir(candidate)),
    os.path.join(REPO_ROOT, 'scripts', 'kfold_splits'),
)

print("Paths summary:")
print(f"  CONFIG_FILE_PATH = {CONFIG_FILE_PATH}")
print(f"  ORG_DIR = {ORG_DIR}")
print(f"  FORG_DIR = {FORG_DIR}")
print(f"  SPLITS_DIR (candidate) = {SPLITS_DIR}")

# Load config (raises if missing)
config = load_config(CONFIG_FILE_PATH)
print("Loaded config.")


Paths summary:
  CONFIG_FILE_PATH = /kaggle/working/Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update/configs/config_tSSN.yaml
  ORG_DIR = /kaggle/input/cedardataset/signatures/full_org
  FORG_DIR = /kaggle/input/cedardataset/signatures/full_forg
  SPLITS_DIR (candidate) = /kaggle/working/Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update/scripts/prepare_kfold_splits
Loaded config.


In [None]:
# === Safe transform & device setup (robust handling of config['input_size']) ===
import numbers

# Look for the input size in several places of the config. Nested
# locations are tried first, in this order; the top-level 'input_size' key
# is the last resort. The first hit wins.
raw_input_size = None
if isinstance(config, dict):
    _nested_candidates = (
        ('dataset', 'input_size'),
        ('dataset_params', 'input_size'),
        ('dataset', 'img_size'),
    )
    for _section, _key in _nested_candidates:
        _sub = config.get(_section)
        if isinstance(_sub, dict) and _key in _sub:
            raw_input_size = _sub[_key]
            break
    else:
        if 'input_size' in config:
            raw_input_size = config['input_size']

print("Raw input_size from config:", raw_input_size, type(raw_input_size))

# Helper to coerce to (H, W)
def coerce_to_hw(x):
    """
    Coerce a config value describing an image size into an (H, W) tuple of ints.

    Accepted forms:
      * a number                      -> (n, n)
      * a list/tuple of two numbers   -> (h, w)
      * a single-element list/tuple   -> its element, coerced recursively
      * a string such as "220,150", "220x150", "220 150" or "220",
        optionally wrapped in brackets, e.g. "(220,150)" or "[220, 150]"
      * a dict with 'height'/'width' (or 'h'/'w', 'H'/'W') entries

    Returns:
        (H, W) as a tuple of ints, or None when the value cannot be
        interpreted.
    """
    unconvertible = (TypeError, ValueError, OverflowError)

    # Accept numbers -> (int, int)
    if isinstance(x, numbers.Number):
        try:
            side = int(x)
        except unconvertible:
            # complex values, NaN and infinities have no int form
            return None
        return (side, side)

    if isinstance(x, (list, tuple)):
        # Accept tuple/list of 2 numbers
        if len(x) == 2:
            try:
                return (int(x[0]), int(x[1]))
            except unconvertible:
                return None
        # Accept nested single-element list/tuple like [(220, 150)]
        if len(x) == 1:
            return coerce_to_hw(x[0])
        return None

    if isinstance(x, str):
        # A size may arrive as the text of a tuple/list literal, e.g.
        # "(220,150)", so surrounding brackets are removed first.
        s = x.strip().strip('()[]{}').strip()
        parts = [p for p in re.split(r'[,xX×\s]+', s) if p]
        if len(parts) >= 2:
            try:
                return (int(parts[0]), int(parts[1]))
            except ValueError:
                pass
        # single number in string
        try:
            v = int(s)
        except ValueError:
            return None
        return (v, v)

    # Accept dicts with keys like 'height' and 'width' or 'h' 'w'
    if isinstance(x, dict):
        h = x.get('height') or x.get('h') or x.get('H')
        w = x.get('width') or x.get('w') or x.get('W')
        if h is not None and w is not None:
            try:
                return (int(h), int(w))
            except unconvertible:
                return None
    return None

# Resolve the final input size, using (220, 150) when the config value
# could not be interpreted.
_DEFAULT_INPUT_SIZE = (220, 150)
input_size = coerce_to_hw(raw_input_size)
if input_size is None:
    print("Warning: could not coerce config input_size into (H, W). Falling back to (220, 150).")
    input_size = _DEFAULT_INPUT_SIZE

# Final sanity check: exactly two entries, both of type int
_size_ok = (
    isinstance(input_size, (tuple, list))
    and len(input_size) == 2
    and all(isinstance(side, int) for side in input_size)
)
if not _size_ok:
    raise ValueError(f"input_size after coercion is invalid: {input_size}")

print("Using input_size (H, W):", input_size)

# Preprocessing pipeline: resize, convert to grayscale, turn into a tensor,
# repeat the tensor three times along its first dimension, then normalise
# with mean 0.5 and std 0.5 per channel.
_transform_steps = [
    transforms.Resize(input_size),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_transform_steps)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using device:", device)


Raw input_size from config: (220,150) <class 'str'>
Using input_size (H, W): (220, 150)
Using device: cuda


In [None]:
# === Load per-fold split JSON files produced by prepare_kfold_splits.py ===
pattern = re.compile(r"cedar_meta_split_fold_(\d+)\.json$")

# SPLITS_DIR is searched first; /mnt/data is only consulted when SPLITS_DIR
# yields no split file.
found = []
for search_dir in (SPLITS_DIR, '/mnt/data'):
    if found:
        break
    if not os.path.isdir(search_dir):
        continue
    for entry in os.listdir(search_dir):
        hit = pattern.match(entry)
        if hit:
            found.append((int(hit.group(1)), os.path.join(search_dir, entry)))

if not found:
    raise FileNotFoundError(f"No per-fold split JSON files found. Looked in {SPLITS_DIR} and /mnt/data. "
                            "Please run prepare_kfold_splits.py --output_dir <dir> first and point SPLITS_DIR to it.")

# Order the folds by their index.
found.sort(key=lambda pair: pair[0])

kfold_splits = []
for fold_no, split_path in found:
    with open(split_path, 'r', encoding='utf-8') as fh:
        payload = json.load(fh)
    if not isinstance(payload, dict):
        print(f"Warning: fold file {split_path} loaded but has unexpected format (not dict).")
        kfold_splits.append({'fold': fold_no, 'raw': payload})
    elif 'train_users' in payload and 'val_users' in payload:
        # Flat layout: keep only the two user lists
        kfold_splits.append({'fold': fold_no, 'train_users': payload['train_users'], 'val_users': payload['val_users']})
    else:
        # Any other dict layout is kept as-is next to the fold index
        kfold_splits.append({'fold': fold_no, **payload})

print(f"Loaded {len(kfold_splits)} fold files. Folds: {[f['fold'] for f in kfold_splits]}")
print("Sample fold keys:", list(kfold_splits[0].keys()))


Loaded 5 fold files. Folds: [0, 1, 2, 3, 4]
Sample fold keys: ['fold', 'meta-train', 'meta-test']


In [13]:
# === Experiment configs ===
# Training hyper-parameters are read from config['training'] when that
# section exists, otherwise from config['train_params'] (or defaults).
_train_cfg = config['training'] if 'training' in config else config.get('train_params', {})
num_epochs = int(_train_cfg.get('num_epochs', 50))
learning_rate = float(_train_cfg.get('learning_rate', 1e-4))
early_stop_patience = _train_cfg.get('early_stop', 3)
batch_size = 64  # fixed here, not taken from the config

feature_dim = config['model'].get('feature_dim', 512)

NUM_EPOCHS_PER_FOLD = num_epochs
MODES_TO_TEST = ['learnable', 'euclidean', 'manhattan', 'cosine']
MARGINS_TO_TEST = [0.5]

for _summary_line in (
    f"Set Batch Size for T4x2: {batch_size}",
    f"NUM_EPOCHS_PER_FOLD={NUM_EPOCHS_PER_FOLD}, LR={learning_rate}, BATCH_SIZE={batch_size}",
    f"EARLY_STOP_PATIENCE={early_stop_patience}, FEATURE_DIM={feature_dim}",
    f"MODES_TO_TEST={MODES_TO_TEST}",
    f"MARGINS_TO_TEST={MARGINS_TO_TEST}",
):
    print(_summary_line)

Set Batch Size for T4x2: 64
NUM_EPOCHS_PER_FOLD=50, LR=0.0001, BATCH_SIZE=64
EARLY_STOP_PATIENCE=3, FEATURE_DIM=512
MODES_TO_TEST=['learnable', 'euclidean', 'manhattan', 'cosine']
MARGINS_TO_TEST=[0.5]


In [14]:
# === Run K-Fold experiments (WITH EARLY STOPPING) ===
results_data = []

EVAL_FREQ = 10  # Evaluate every 10 epochs
print(f"Validation evaluation will be performed every {EVAL_FREQ} epochs.")

print(f"Using Early Stopping with patience = {early_stop_patience} epochs.")
print(f"Starting experiments: modes={MODES_TO_TEST}, margins={MARGINS_TO_TEST}, folds={len(kfold_splits)}")

n_gpu = torch.cuda.device_count()
print(f"Number of GPUs available: {n_gpu}")
if n_gpu > 1:
    print("Using nn.DataParallel for multi-GPU training.")

for mode in MODES_TO_TEST:
    for margin in MARGINS_TO_TEST:

        fold_metrics = {'eer': [], 'roc_auc': [], 'accuracy': [], 'precision': [], 'recall': [], 'f1': [], 'train_loss': [], 'best_epoch': []}
        config_name = f"mode={mode}_margin={margin}"
        print(f"\n--- Experiment: {config_name} ---")

        for fold_entry in tqdm(kfold_splits, desc=f"Folds for {config_name}", leave=False):
            fold_index = fold_entry.get('fold', None)

            train_user_dict = fold_entry.get('meta-train', {})
            val_user_dict = fold_entry.get('meta-test', {})
            try:
                train_user_ids = {int(uid) for uid in train_user_dict.keys()}
                val_user_ids = {int(uid) for uid in val_user_dict.keys()}
            except ValueError as e:
                print(f"    ERROR: Could not convert user IDs to int in fold {fold_index}. Skipping. Error: {e}")
                continue

            # Initialize Model
            model = tSSN(backbone_name=config['model'].get('backbone') if isinstance(config.get('model'), dict) else config['model'],
                         output_dim=feature_dim,
                         pretrained=True).to(device)
            if n_gpu > 1:
                model = nn.DataParallel(model)

            loss_fn_module = None
            optimizer = None

            if mode == 'learnable':
                loss_fn = DistanceNetTripletLoss(margin=margin, input_dim=feature_dim).to(device)
                loss_fn_module = loss_fn
                params_to_optimize = list(model.parameters()) + list(loss_fn.distance_net.parameters())
                optimizer = optim.Adam(params_to_optimize, lr=learning_rate)
            else:
                loss_fn = TripletLoss(margin=margin, mode=mode).to(device)
                loss_fn_module = loss_fn
                optimizer = optim.Adam(model.parameters(), lr=learning_rate)

            # Dataloaders
            # print(f"    [DEBUG] Fold {fold_index}: Initializing training triplet dataset...")
            train_triplet_dataset = SignaturePretrainDataset(org_dir=ORG_DIR, forg_dir=FORG_DIR, transform=transform)
            original_count = len(train_triplet_dataset.triplets)
            # print(f"    [DEBUG] Fold {fold_index}: Filtering {original_count} triplets...")
            train_triplet_dataset.triplets = [
                t for t in train_triplet_dataset.triplets
                if _get_user_id_from_filename(os.path.basename(t[0])) in train_user_ids
            ]
            # print(f"    [DEBUG] Fold {fold_index}: Initializing validation pair dataset...")
            val_pair_dataset = SignaturePairDataset(org_dir=ORG_DIR, forg_dir=FORG_DIR, user_ids=val_user_ids, transform=transform)

            # print(f"    [DEBUG] Fold {fold_index}: Creating DataLoaders (num_workers=2, pin_memory=True)...")
            train_loader = DataLoader(train_triplet_dataset, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=collate_fn_skip_none, pin_memory=True)
            val_loader = DataLoader(val_pair_dataset, batch_size=max(1, batch_size*2), shuffle=False, num_workers=2, collate_fn=collate_fn_skip_none, pin_memory=True)

            if len(train_loader) == 0 or len(val_loader) == 0:
                print(f"    Skipping fold {fold_index} due to missing data.")
                continue

            best_val_eer = 1.0
            best_metrics = {}
            best_model_state_to_save = None
            epochs_no_improve = 0
            last_eval_epoch = 0

            # print(f"    [DEBUG] Fold {fold_index}: Starting epoch loop...")
            # Train this fold for up to NUM_EPOCHS_PER_FOLD epochs. Validation only runs on
            # evaluation epochs (every EVAL_FREQ epochs and on the final epoch); the lowest
            # validation EER seen so far decides which metrics/weights are kept and when
            # early stopping fires.
            epoch_iterator = trange(NUM_EPOCHS_PER_FOLD, desc=f"  Fold {fold_index} Epochs", leave=False)
            for epoch in epoch_iterator:
                # 'learnable' mode uses its own train step; the fixed metrics share train_epoch.
                if mode == 'learnable':
                    train_loss = train_epoch_distnet(model, train_loader, loss_fn_module, optimizer, device)
                else:
                    train_loss = train_epoch(model, train_loader, loss_fn_module, optimizer, device)

                # Evaluate only every EVAL_FREQ epochs OR on the final epoch
                if (epoch + 1) % EVAL_FREQ == 0 or (epoch + 1) == NUM_EPOCHS_PER_FOLD:
                    last_eval_epoch = epoch + 1
                    if mode == 'learnable':
                        eer, roc_auc, acc, precision, recall, f1 = evaluate_on_pairs_distnet(model, val_loader, loss_fn_module, device)
                    else:
                        eer, roc_auc, acc, precision, recall, f1 = evaluate_on_pairs(model, val_loader, device, distance_mode=mode)

                    epoch_iterator.set_postfix(TrainLoss=f"{train_loss:.4f}", ValEER=f"{eer:.4f} (ep {epoch+1})")

                    if eer < best_val_eer:
                        # New best validation EER: reset the patience counter and record the metrics.
                        best_val_eer = eer
                        epochs_no_improve = 0
                        best_metrics = {
                            'eer': eer, 'roc_auc': roc_auc, 'accuracy': acc,
                            'precision': precision, 'recall': recall, 'f1': f1,
                            'train_loss': train_loss, 'best_epoch': epoch + 1
                        }
                        # Snapshot the backbone weights (unwrapping nn.DataParallel first).
                        # state_dict() hands back references to the live parameter tensors, so
                        # each tensor is cloned; otherwise later epochs would silently overwrite
                        # the "best" weights with whatever the model holds when training stops.
                        live_backbone_state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
                        best_model_state_to_save = {
                            param_name: param_tensor.detach().clone()
                            for param_name, param_tensor in live_backbone_state.items()
                        }
                        if mode == 'learnable':
                            # Same snapshot treatment for the learnable distance network.
                            best_distnet_state_to_save = {
                                param_name: param_tensor.detach().clone()
                                for param_name, param_tensor in loss_fn_module.distance_net.state_dict().items()
                            }

                    else:
                        # The counter is kept in epochs, so one non-improving evaluation
                        # adds EVAL_FREQ to it.
                        # NOTE(review): the logged run reports patience = 3 with evaluation every
                        # 10 epochs, which means the first non-improving evaluation already stops
                        # the fold - confirm whether patience was meant to count evaluations.
                        epochs_no_improve += EVAL_FREQ

                    if epochs_no_improve >= early_stop_patience:
                        print(f"    Early stopping triggered after epoch {epoch + 1} ({epochs_no_improve} epochs without improvement since epoch {best_metrics.get('best_epoch', '?')}). Best EER: {best_val_eer:.4f}")
                        break
                else:
                    # Non-evaluation epoch: only the training loss is available for the progress bar.
                    epoch_iterator.set_postfix(TrainLoss=f"{train_loss:.4f}", ValEER=f"...")

            # Record the best-epoch metrics of this fold, or report that none were captured.
            if not best_metrics:
                print(f"    Fold {fold_index} did not complete (no metrics after {last_eval_epoch} epochs).")
            else:
                print(f"    Fold {fold_index} finished (last eval at epoch {last_eval_epoch}). Best EER: {best_metrics['eer']:.4f} (at epoch {best_metrics['best_epoch']}) | ROC_AUC: {best_metrics['roc_auc']:.4f} | F1: {best_metrics['f1']:.4f}")
                # Each tracked metric has its own per-fold list in fold_metrics under the same key.
                for metric_key in ('eer', 'roc_auc', 'accuracy', 'precision',
                                   'recall', 'f1', 'train_loss', 'best_epoch'):
                    fold_metrics[metric_key].append(best_metrics[metric_key])

        # Collapse the per-fold lists into one mean per metric for this configuration.
        # A metric whose list is empty (no fold recorded it) is reported as NaN.
        (mean_eer, mean_roc_auc, mean_acc, mean_precision,
         mean_recall, mean_f1, mean_loss, mean_best_epoch) = (
            float(np.mean(fold_metrics[metric_key])) if fold_metrics[metric_key] else float('nan')
            for metric_key in ('eer', 'roc_auc', 'accuracy', 'precision',
                               'recall', 'f1', 'train_loss', 'best_epoch')
        )

        print(f"\n  >> {config_name} -> Mean EER: {mean_eer:.4f} | Mean ROC-AUC: {mean_roc_auc:.4f} | Mean F1: {mean_f1:.4f} | Mean Acc: {mean_acc:.4f} | Mean Best Epoch: {mean_best_epoch:.1f}")

        # One summary row per (mode, margin) configuration; key order fixes the column order
        # of any DataFrame later built from results_data.
        results_data.append(dict(
            mode=mode,
            margin=margin,
            mean_eer=mean_eer,
            mean_roc_auc=mean_roc_auc,
            mean_accuracy=mean_acc,
            mean_precision=mean_precision,
            mean_recall=mean_recall,
            mean_f1=mean_f1,
            mean_train_loss=mean_loss,
            mean_best_epoch=mean_best_epoch,
        ))

# Top-level status message, emitted once the indented experiment loops above have ended.
print("\nAll experiments finished.")

Validation evaluation will be performed every 10 epochs.
Using Early Stopping with patience = 3 epochs.
Starting experiments: modes=['learnable', 'euclidean', 'manhattan', 'cosine'], margins=[0.5], folds=5
Number of GPUs available: 2
Using nn.DataParallel for multi-GPU training.

--- Experiment: mode=learnable_margin=0.5 ---


Folds for mode=learnable_margin=0.5:   0%|          | 0/5 [00:00<?, ?it/s]

Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth

  0%|          | 0.00/83.3M [00:00<?, ?B/s][A
  6%|▋         | 5.25M/83.3M [00:00<00:01, 54.4MB/s][A
 13%|█▎        | 10.5M/83.3M [00:00<00:01, 50.1MB/s][A
 33%|███▎      | 27.8M/83.3M [00:00<00:00, 107MB/s] [A
 56%|█████▌    | 46.5M/83.3M [00:00<00:00, 141MB/s][A
100%|██████████| 83.3M/83.3M [00:00<00:00, 146MB/s]


Generated 1320 triplets for pre-training.


  Fold 0 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 30 (10 epochs without improvement since epoch 20). Best EER: 0.0454
    Fold 0 finished (last eval at epoch 30). Best EER: 0.0454 (at epoch 20) | ROC_AUC: 0.9700 | F1: 0.9314
Generated 1320 triplets for pre-training.


  Fold 1 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 30 (10 epochs without improvement since epoch 20). Best EER: 0.0637
    Fold 1 finished (last eval at epoch 30). Best EER: 0.0637 (at epoch 20) | ROC_AUC: 0.9819 | F1: 0.9047
Generated 1320 triplets for pre-training.


  Fold 2 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 20 (10 epochs without improvement since epoch 10). Best EER: 0.0181
    Fold 2 finished (last eval at epoch 20). Best EER: 0.0181 (at epoch 10) | ROC_AUC: 0.9983 | F1: 0.9724
Generated 1320 triplets for pre-training.


  Fold 3 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 30 (10 epochs without improvement since epoch 20). Best EER: 0.0201
    Fold 3 finished (last eval at epoch 30). Best EER: 0.0201 (at epoch 20) | ROC_AUC: 0.9921 | F1: 0.9694
Generated 1320 triplets for pre-training.


  Fold 4 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 20 (10 epochs without improvement since epoch 10). Best EER: 0.0000
    Fold 4 finished (last eval at epoch 20). Best EER: 0.0000 (at epoch 10) | ROC_AUC: 1.0000 | F1: 1.0000

  >> mode=learnable_margin=0.5 -> Mean EER: 0.0294 | Mean ROC-AUC: 0.9885 | Mean F1: 0.9556 | Mean Acc: 0.9705 | Mean Best Epoch: 16.0

--- Experiment: mode=euclidean_margin=0.5 ---


Folds for mode=euclidean_margin=0.5:   0%|          | 0/5 [00:00<?, ?it/s]

Generated 1320 triplets for pre-training.


  Fold 0 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 20 (10 epochs without improvement since epoch 10). Best EER: 0.0331
    Fold 0 finished (last eval at epoch 20). Best EER: 0.0331 (at epoch 10) | ROC_AUC: 0.9960 | F1: 0.9494
Generated 1320 triplets for pre-training.


  Fold 1 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 50 (10 epochs without improvement since epoch 40). Best EER: 0.0354
    Fold 1 finished (last eval at epoch 50). Best EER: 0.0354 (at epoch 40) | ROC_AUC: 0.9953 | F1: 0.9464
Generated 1320 triplets for pre-training.


  Fold 2 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 30 (10 epochs without improvement since epoch 20). Best EER: 0.0472
    Fold 2 finished (last eval at epoch 30). Best EER: 0.0472 (at epoch 20) | ROC_AUC: 0.9907 | F1: 0.9289
Generated 1320 triplets for pre-training.


  Fold 3 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 20 (10 epochs without improvement since epoch 10). Best EER: 0.0367
    Fold 3 finished (last eval at epoch 20). Best EER: 0.0367 (at epoch 10) | ROC_AUC: 0.9945 | F1: 0.9445
Generated 1320 triplets for pre-training.


  Fold 4 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 30 (10 epochs without improvement since epoch 20). Best EER: 0.0381
    Fold 4 finished (last eval at epoch 30). Best EER: 0.0381 (at epoch 20) | ROC_AUC: 0.9945 | F1: 0.9425

  >> mode=euclidean_margin=0.5 -> Mean EER: 0.0381 | Mean ROC-AUC: 0.9942 | Mean F1: 0.9423 | Mean Acc: 0.9618 | Mean Best Epoch: 20.0

--- Experiment: mode=manhattan_margin=0.5 ---


Folds for mode=manhattan_margin=0.5:   0%|          | 0/5 [00:00<?, ?it/s]

Generated 1320 triplets for pre-training.


  Fold 0 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 20 (10 epochs without improvement since epoch 10). Best EER: 0.0767
    Fold 0 finished (last eval at epoch 20). Best EER: 0.0767 (at epoch 10) | ROC_AUC: 0.9745 | F1: 0.8862
Generated 1320 triplets for pre-training.


  Fold 1 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 20 (10 epochs without improvement since epoch 10). Best EER: 0.0461
    Fold 1 finished (last eval at epoch 20). Best EER: 0.0461 (at epoch 10) | ROC_AUC: 0.9915 | F1: 0.9308
Generated 1320 triplets for pre-training.


  Fold 2 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 30 (10 epochs without improvement since epoch 20). Best EER: 0.0479
    Fold 2 finished (last eval at epoch 30). Best EER: 0.0479 (at epoch 20) | ROC_AUC: 0.9884 | F1: 0.9278
Generated 1320 triplets for pre-training.


  Fold 3 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 20 (10 epochs without improvement since epoch 10). Best EER: 0.0510
    Fold 3 finished (last eval at epoch 20). Best EER: 0.0510 (at epoch 10) | ROC_AUC: 0.9886 | F1: 0.9238
Generated 1320 triplets for pre-training.


  Fold 4 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 30 (10 epochs without improvement since epoch 20). Best EER: 0.0135
    Fold 4 finished (last eval at epoch 30). Best EER: 0.0135 (at epoch 20) | ROC_AUC: 0.9992 | F1: 0.9791

  >> mode=manhattan_margin=0.5 -> Mean EER: 0.0471 | Mean ROC-AUC: 0.9884 | Mean F1: 0.9295 | Mean Acc: 0.9530 | Mean Best Epoch: 14.0

--- Experiment: mode=cosine_margin=0.5 ---


Folds for mode=cosine_margin=0.5:   0%|          | 0/5 [00:00<?, ?it/s]

Generated 1320 triplets for pre-training.


  Fold 0 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 20 (10 epochs without improvement since epoch 10). Best EER: 0.0607
    Fold 0 finished (last eval at epoch 20). Best EER: 0.0607 (at epoch 10) | ROC_AUC: 0.9884 | F1: 0.9092
Generated 1320 triplets for pre-training.


  Fold 1 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 20 (10 epochs without improvement since epoch 10). Best EER: 0.0093
    Fold 1 finished (last eval at epoch 20). Best EER: 0.0093 (at epoch 10) | ROC_AUC: 0.9993 | F1: 0.9856
Generated 1320 triplets for pre-training.


  Fold 2 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 30 (10 epochs without improvement since epoch 20). Best EER: 0.0334
    Fold 2 finished (last eval at epoch 30). Best EER: 0.0334 (at epoch 20) | ROC_AUC: 0.9944 | F1: 0.9492
Generated 1320 triplets for pre-training.


  Fold 3 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 20 (10 epochs without improvement since epoch 10). Best EER: 0.0324
    Fold 3 finished (last eval at epoch 20). Best EER: 0.0324 (at epoch 10) | ROC_AUC: 0.9961 | F1: 0.9509
Generated 1320 triplets for pre-training.


  Fold 4 Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

    Early stopping triggered after epoch 30 (10 epochs without improvement since epoch 20). Best EER: 0.0098
    Fold 4 finished (last eval at epoch 30). Best EER: 0.0098 (at epoch 20) | ROC_AUC: 0.9997 | F1: 0.9851

  >> mode=cosine_margin=0.5 -> Mean EER: 0.0291 | Mean ROC-AUC: 0.9956 | Mean F1: 0.9560 | Mean Acc: 0.9708 | Mean Best Epoch: 14.0

All experiments finished.


=== Analyze results and pick best config ===

In [None]:
# Build the per-configuration summary table from the K-Fold averages, rank the
# configurations, and select the best one for pre-training.
results_df = pd.DataFrame(results_data)
if results_df.empty:
    raise RuntimeError("No results were produced. Check data directories and split files.")

# Preferred display order; columns missing from the results are dropped so the
# table still prints.
all_metrics_columns = [
    'mode', 'margin',
    'mean_eer', 'mean_roc_auc', 'mean_f1',
    'mean_accuracy', 'mean_precision', 'mean_recall',
    'mean_train_loss', 'mean_best_epoch'
]
all_metrics_columns = [col for col in all_metrics_columns if col in results_df.columns]

print("Summary table of results (unsorted):")
print(results_df[all_metrics_columns].to_markdown(index=False, floatfmt=".4f"))

# --- Arrange to find the best configuration (Ranking) ---
# Sort by: 1. EER (lowest), 2. ROC-AUC (highest), 3. F1 (highest).
# Configurations whose mean EER is NaN are pushed to the bottom of the ranking.
ranking_df = results_df.sort_values(
    by=['mean_eer', 'mean_roc_auc', 'mean_f1'],
    ascending=[True, False, False],
    na_position='last',
).reset_index(drop=True)

print("\nLeaderboard (best at the top):")
print(ranking_df[all_metrics_columns].to_markdown(index=False, floatfmt=".4f"))

# A NaN mean EER marks a configuration with no usable fold results. If that is true
# for every row, the top row is not a real winner, so fail loudly instead of
# handing a meaningless configuration to later cells.
if ranking_df['mean_eer'].isna().all():
    raise RuntimeError(
        "Every configuration has a NaN mean EER (no fold produced metrics); "
        "cannot select a best pre-training configuration."
    )

best_config = ranking_df.iloc[0]
print("\n--- BEST PRE-TRAINING CONFIGURATION CHOSEN (K-Fold Average) ---")
print(best_config.to_dict())

Summary table of results (unsorted):
| mode      |   margin |   mean_eer |   mean_roc_auc |   mean_f1 |   mean_accuracy |   mean_precision |   mean_recall |   mean_train_loss |   mean_best_epoch |
|:----------|---------:|-----------:|---------------:|----------:|----------------:|-----------------:|--------------:|------------------:|------------------:|
| learnable |   0.5000 |     0.0294 |         0.9885 |    0.9556 |          0.9705 |           0.9412 |        0.9708 |            0.0013 |           16.0000 |
| euclidean |   0.5000 |     0.0381 |         0.9942 |    0.9423 |          0.9618 |           0.9234 |        0.9621 |            0.0046 |           20.0000 |
| manhattan |   0.5000 |     0.0471 |         0.9884 |    0.9295 |          0.9530 |           0.9076 |        0.9528 |            0.0255 |           14.0000 |
| cosine    |   0.5000 |     0.0291 |         0.9956 |    0.9560 |          0.9708 |           0.9416 |        0.9710 |            0.0006 |           14.0000 |

Le