=========================================================

Meta-Training with 5-Fold Cross-Validation on CEDAR

=========================================================

Objective: 
  Train and rigorously evaluate the adaptive metric learning framework 
  (Feature Extractor + Metric Generator) on the CEDAR dataset.

Methodology:
1. K-Fold Splits: Use `prepare_kfold_splits.py` to create 5 folds.

2. Model: Load pre-trained ResNetFeatureExtractor + initialize MetricGenerator.

3. Dataloader: SignatureEpisodeDataset (meta_dataloader.py)
      - augment=True for training (meta-train)
      - augment=False for validation (meta-test)

4. Training: Online Hard Triplet Mining with adaptive Mahalanobis distance.

5. Evaluation: evaluate_meta_model() → Acc, Precision, Recall, F1, ROC-AUC

6. Result: Mean ± Std of each metric across the 5 folds, where each fold contributes the metrics from its best-validation-accuracy epoch.

=== 1. SETUP AND IMPORTS ===

In [None]:
# Ensure the latest version of the code is used:
# remove any previous checkout, clone the repository afresh, then change the
# working directory into it so that the relative paths used later
# (e.g. 'scripts/prepare_kfold_splits.py') resolve.
!rm -rf Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update
!git clone https://github.com/trongjhuongwr/Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update.git
%cd Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.optim.lr_scheduler import CosineAnnealingLR
import numpy as np
import random
import os
import json
import sys
from tqdm.notebook import tqdm
from copy import deepcopy
import matplotlib.pyplot as plt
import pandas as pd
import time

# Put the current working directory on the import path so the project packages
# imported below (dataloader, models, losses, utils) can be resolved.
# NOTE(review): this presumes the notebook's cwd is the repository root (set by
# the `%cd` in the clone cell) — confirm if the cells are run out of order.
sys.path.append(os.path.abspath(os.getcwd()))

from dataloader.meta_dataloader import SignatureEpisodeDataset
from models.feature_extractor import ResNetFeatureExtractor
from models.meta_learner import MetricGenerator
from losses.triplet_loss import pairwise_mahalanobis_distance
from utils.model_evaluation import evaluate_meta_model, plot_roc_curve, plot_confusion_matrix
from utils.helpers import MemoryTracker

print("Setup and Imports successful!")

=== 2. GLOBAL CONFIGURATION ===

In [None]:
# Global configuration for the K-Fold meta-training run. Every value defined
# here is a module-level constant read by the later cells.

# Meta-Learning Parameters
K_SHOT = 10          # Number of support samples per user/episode
N_QUERY_GENUINE = 14 # Number of genuine query samples per user/episode
N_QUERY_FORGERY = 14 # Number of forgery query samples per user/episode

# Training Hyperparameters
NUM_EPOCHS = 60
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4
MARGIN = 0.5          # Margin of the triplet hinge loss
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
NUM_SPLITS = 5        # Number of cross-validation folds
PATIENCE = 20         # Epochs without validation-accuracy improvement before early stopping

# Paths
BASE_DATA_DIR = '/kaggle/input/cedar-dataset-for-signature-verification/signatures'
SPLIT_FILES_DIR = '/kaggle/working/Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update/scripts/kfold_splits'
PRETRAINED_WEIGHTS_PATH = '/kaggle/input/my-pretrained-weights/pretrained_feature_extractor.pth'
BEST_MODEL_SAVE_DIR_TEMPLATE = '/kaggle/working/best_model_fold_{}'  # formatted with the 1-based fold number

# Seed every RNG in use so runs are repeatable
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Trade cuDNN autotuning for deterministic kernels
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

print(f"Using device: {DEVICE}")
print(f"Configuration loaded. Running {NUM_SPLITS}-Fold Cross-Validation.")

# Use DataLoader worker processes only when running inside a Kaggle kernel.
# KAGGLE_KERNEL_RUN_TYPE holds the run type (e.g. "Interactive" / "Batch"), not
# the word "kaggle", so the previous substring test was always False and the
# loader never used workers. Any non-empty value now counts as "on Kaggle".
NUM_WORKERS = 2 if os.environ.get('KAGGLE_KERNEL_RUN_TYPE') else 0

=== 3. GENERATE K-FOLD SPLIT FILES ===

In [None]:
print("Generating K-Fold split files...")
os.makedirs(SPLIT_FILES_DIR, exist_ok=True)

script_path = 'scripts/prepare_kfold_splits.py'
command = f"python {script_path} --base_data_dir {BASE_DATA_DIR} --output_dir {SPLIT_FILES_DIR} --seed {SEED} --num_splits {NUM_SPLITS}"

print(f"Running command: {command}")
!{command}

created_files = os.listdir(SPLIT_FILES_DIR)
print(f"Generated files in {SPLIT_FILES_DIR}: {created_files}")
if len(created_files) != NUM_SPLITS:
    print(f"Warning: Expected {NUM_SPLITS} split files, but found {len(created_files)}.")
else:
    print("K-Fold split files generated successfully.")

In [None]:
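# NOTE: The JSON split files defining the user splits are already generated by
# the cell above. The manual invocation below is kept only as a reference for
# regenerating them by hand.
# !python prepare_kfold_splits.py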

=== 4. K-FOLD CROSS-VALIDATION LOOP ===

In [None]:
# K-Fold cross-validation driver.
#
# For every fold this cell: builds the meta-train / meta-test episode datasets,
# creates a fresh feature extractor + metric generator, trains them with an
# online hard-triplet loss on an adaptive Mahalanobis distance, validates after
# every epoch, checkpoints the best-accuracy epoch, and applies early stopping.
#
# Outputs read by later cells: fold_results_list, fold_best_epochs,
# fold_training_times, memory_tracker, initial_gpu_mem.

# Store results from each fold
fold_results_list = []
fold_best_epochs = []
fold_training_times = []

# Cap on per-epoch "episode skipped" messages so a systematic failure is
# visible without flooding the notebook output.
MAX_LOGGED_EPISODE_FAILURES = 3


def _report_episode_failure(fold_number, epoch_number, stage, error, failure_count):
    """Print why a training episode was skipped (first few failures of an epoch only)."""
    if failure_count <= MAX_LOGGED_EPISODE_FAILURES:
        print(f"Fold {fold_number} Epoch {epoch_number}: skipping episode, {stage} failed: {error}")


# Initialize MemoryTracker if GPU is available
if DEVICE.type == 'cuda':
    memory_tracker = MemoryTracker(DEVICE)
    initial_gpu_mem = memory_tracker.get_used_memory_mb()
else:
    memory_tracker = None
    initial_gpu_mem = 0

print(f"\n--- Starting {NUM_SPLITS}-Fold Cross-Validation ---")

for fold_idx in range(NUM_SPLITS):
    fold_start_time = time.time()
    print(f"\n================== Starting Fold {fold_idx + 1}/{NUM_SPLITS} ==================")

    # --- 4.1. Load Data for the Current Fold ---
    split_file_path = os.path.join(SPLIT_FILES_DIR, f'cedar_meta_split_fold_{fold_idx}.json')
    if not os.path.exists(split_file_path):
        print(f"Error: Split file not found for fold {fold_idx}: {split_file_path}. Skipping fold.")
        continue

    try:
        print(f"Fold {fold_idx + 1}: Creating training dataset (Augmentation: True)...")
        train_dataset = SignatureEpisodeDataset(
            split_file_path=split_file_path,
            base_data_dir=BASE_DATA_DIR,
            split_name='meta-train',
            k_shot=K_SHOT,
            n_query_genuine=N_QUERY_GENUINE,
            n_query_forgery=N_QUERY_FORGERY,
            augment=True,
            use_full_path=False
        )

        print(f"Fold {fold_idx + 1}: Creating validation dataset (Augmentation: False)...")
        val_dataset = SignatureEpisodeDataset(
            split_file_path=split_file_path,
            base_data_dir=BASE_DATA_DIR,
            split_name='meta-test',
            k_shot=K_SHOT,
            n_query_genuine=N_QUERY_GENUINE,
            n_query_forgery=N_QUERY_FORGERY,
            augment=False,
            use_full_path=False
        )

        # batch_size=1: each batch is one episode; the leading batch dim is squeezed away below.
        train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=NUM_WORKERS, pin_memory=(DEVICE.type == 'cuda'))
    except Exception as e:
        print(f"Error creating datasets/loaders for fold {fold_idx + 1}: {e}. Skipping fold.")
        continue

    # --- 4.2. Initialize Model and Optimizer for Each Fold ---
    # Every fold starts from the same initial weights so folds stay independent.
    feature_extractor = ResNetFeatureExtractor(backbone_name='resnet34', output_dim=512, pretrained=False)

    if os.path.exists(PRETRAINED_WEIGHTS_PATH):
        try:
            pretrained_state_dict = torch.load(PRETRAINED_WEIGHTS_PATH, map_location=DEVICE)
            feature_extractor.load_state_dict(pretrained_state_dict, strict=True)
            print(f"Fold {fold_idx + 1}: Successfully loaded pre-trained weights.")
        except Exception as e:
            print(f"Fold {fold_idx + 1}: Error loading pre-trained weights: {e}. Using ImageNet fallback.")
            feature_extractor = ResNetFeatureExtractor(backbone_name='resnet34', output_dim=512, pretrained=True)
    else:
        print(f"Fold {fold_idx + 1}: WARNING - Pre-trained weights not found.")
        feature_extractor = ResNetFeatureExtractor(backbone_name='resnet34', output_dim=512, pretrained=True)

    metric_generator = MetricGenerator(embedding_dim=512)
    feature_extractor.to(DEVICE)
    metric_generator.to(DEVICE)

    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs for Fold {fold_idx + 1}.")
        feature_extractor = nn.DataParallel(feature_extractor)
        # metric_generator is deliberately NOT wrapped in nn.DataParallel.
        # DataParallel scatters its input along dim 0 and concatenates the
        # replica outputs, so each replica would receive only a slice of the
        # support set and W would be assembled from per-slice outputs, while
        # validation below is given the unwrapped module. Keeping it on a
        # single device makes training and validation use the same module.

    optimizer = optim.AdamW(
        list(feature_extractor.parameters()) + list(metric_generator.parameters()),
        lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )

    scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-6)

    # Per-fold bookkeeping for model selection and early stopping
    best_val_acc = -1.0
    best_val_results = None
    best_epoch = -1
    epochs_no_improve = 0
    fold_train_losses = []
    fold_val_accuracies = []

    # --- 4.3. Epoch Loop ---
    for epoch in range(NUM_EPOCHS):
        epoch_start_time = time.time()
        feature_extractor.train()
        metric_generator.train()
        total_epoch_loss = 0.0
        num_valid_episodes = 0    # episodes that produced a non-zero loss and an optimizer step
        num_failed_episodes = 0   # episodes skipped because a RuntimeError was raised

        progress_bar = tqdm(train_loader, desc=f"Fold {fold_idx+1} Epoch {epoch+1}/{NUM_EPOCHS} Training", leave=False)

        for batch in progress_bar:
            if batch is None or 'support_images' not in batch:
                continue

            # Drop the DataLoader's batch dimension (batch_size=1)
            support_images = batch['support_images'].squeeze(0).to(DEVICE)
            query_images = batch['query_images'].squeeze(0).to(DEVICE)
            query_labels = batch['query_labels'].squeeze(0).to(DEVICE)

            optimizer.zero_grad()

            if support_images.shape[0] == 0:
                continue

            # Embed support and query images in a single forward pass
            all_images = torch.cat([support_images, query_images], dim=0)
            try:
                all_embeddings = feature_extractor(all_images)
            except RuntimeError as e:
                num_failed_episodes += 1
                _report_episode_failure(fold_idx + 1, epoch + 1, "feature extraction", e, num_failed_episodes)
                continue

            support_embeddings = all_embeddings[:support_images.shape[0]]
            query_embeddings = all_embeddings[support_images.shape[0]:]

            if query_embeddings.shape[0] == 0:
                continue

            # Episode-specific metric generated from the support embeddings
            try:
                W = metric_generator(support_embeddings)
            except RuntimeError as e:
                num_failed_episodes += 1
                _report_episode_failure(fold_idx + 1, epoch + 1, "metric generation", e, num_failed_episodes)
                continue

            # Label 1 = genuine query, label 0 = forgery query
            genuine_em = query_embeddings[query_labels == 1]
            forgery_em = query_embeddings[query_labels == 0]

            # Need at least one positive pair (2 genuine) and one negative
            if genuine_em.shape[0] < 2 or forgery_em.shape[0] < 1:
                continue

            try:
                dist_ap_mat = pairwise_mahalanobis_distance(genuine_em, genuine_em, W)
                dist_an_mat = pairwise_mahalanobis_distance(genuine_em, forgery_em, W)
            except RuntimeError as e:
                num_failed_episodes += 1
                _report_episode_failure(fold_idx + 1, epoch + 1, "distance computation", e, num_failed_episodes)
                continue

            # Exclude each anchor's distance to itself from the hardest-positive
            # search. Done out-of-place so a tensor that autograd may have saved
            # for the backward pass is never modified in place.
            diagonal_mask = torch.eye(dist_ap_mat.shape[0], dtype=torch.bool, device=dist_ap_mat.device)
            dist_ap_mat = dist_ap_mat.masked_fill(diagonal_mask, float('-inf'))
            hardest_positive_dist, _ = torch.max(dist_ap_mat, dim=1)
            hardest_negative_dist, _ = torch.min(dist_an_mat, dim=1)

            # Triplet hinge loss per genuine anchor
            losses = F.relu(hardest_positive_dist - hardest_negative_dist + MARGIN)

            # Average only over anchors that still violate the margin
            num_active_triplets = torch.sum(losses > 1e-6).item()
            if num_active_triplets > 0:
                episode_loss = torch.sum(losses) / num_active_triplets
                episode_loss.backward()
                optimizer.step()
                total_epoch_loss += episode_loss.item()
                num_valid_episodes += 1
                progress_bar.set_postfix(loss=f"{episode_loss.item():.4f}", active=num_active_triplets)
            else:
                progress_bar.set_postfix(loss="0.0000", active=0)

        # End of Epoch
        if num_failed_episodes > 0:
            print(f"Fold {fold_idx+1} Epoch {epoch+1}: {num_failed_episodes} episode(s) skipped due to RuntimeError.")

        # Mean loss over episodes that took an optimizer step (0.0 if there were none)
        avg_epoch_loss = total_epoch_loss / num_valid_episodes if num_valid_episodes > 0 else 0.0
        fold_train_losses.append(avg_epoch_loss)

        epoch_duration = time.time() - epoch_start_time

        # Evaluate and checkpoint the underlying modules, not the DataParallel wrappers
        fe_eval = feature_extractor.module if isinstance(feature_extractor, nn.DataParallel) else feature_extractor
        mg_eval = metric_generator.module if isinstance(metric_generator, nn.DataParallel) else metric_generator

        # Switch to inference mode explicitly; whether evaluate_meta_model does
        # this itself is not visible here. train() is re-applied at the start
        # of the next epoch.
        fe_eval.eval()
        mg_eval.eval()

        val_results_dict, _, _, _ = evaluate_meta_model(fe_eval, mg_eval, val_dataset, DEVICE)
        val_acc = val_results_dict['accuracy']
        fold_val_accuracies.append(val_acc)

        # The LR shown is the one used during this epoch (scheduler steps afterwards)
        print(f"Fold {fold_idx+1} Epoch {epoch+1}/{NUM_EPOCHS} - "
              f"Train Loss: {avg_epoch_loss:.4f} - "
              f"Val Acc: {val_acc*100:.2f}% - "
              f"Val F1: {val_results_dict['f1_score']:.4f} - "
              f"Val AUC: {val_results_dict['roc_auc']:.4f} - "
              f"LR: {optimizer.param_groups[0]['lr']:.6f} - "
              f"Time: {epoch_duration:.2f}s")

        scheduler.step()

        # Early stopping / model selection on validation accuracy
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_val_results = val_results_dict
            best_epoch = epoch + 1
            epochs_no_improve = 0

            fold_save_dir = BEST_MODEL_SAVE_DIR_TEMPLATE.format(fold_idx + 1)
            os.makedirs(fold_save_dir, exist_ok=True)
            try:
                torch.save(fe_eval.state_dict(), os.path.join(fold_save_dir, 'best_feature_extractor.pth'))
                torch.save(mg_eval.state_dict(), os.path.join(fold_save_dir, 'best_metric_generator.pth'))
                print(f"✨ Saved best model for Fold {fold_idx + 1} at Epoch {best_epoch} to {fold_save_dir}")
            except Exception as e:
                # A failed checkpoint write should not abort training
                print(f"Error saving model: {e}")
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= PATIENCE:
            print(f"--- Early stopping at Epoch {epoch + 1} for Fold {fold_idx + 1} ---")
            break

    # --- 4.4. Record the Fold's Best Result ---
    fold_duration = time.time() - fold_start_time
    if best_val_results:
        fold_results_list.append(best_val_results)
        fold_best_epochs.append(best_epoch)
        fold_training_times.append(fold_duration)
        print(f"\n====== Fold {fold_idx + 1} Finished ======")
        print(f"Best Epoch: {best_epoch}")
        for metric, value in best_val_results.items():
            print(f"  - {metric.capitalize()}: {value*100:.2f}%" if metric == 'accuracy' else f"  - {metric.capitalize()}: {value:.4f}")
        print(f"Total Time: {fold_duration:.2f}s")
    else:
        print(f"\n====== Fold {fold_idx + 1} Finished with no improvement ======")

    # Drop this fold's objects before the next fold allocates its own
    del feature_extractor, metric_generator, optimizer, scheduler, train_dataset, val_dataset, train_loader
    if DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

=== 5. AGGREGATE AND REPORT FINAL K-FOLD RESULTS ===

In [None]:
# Aggregate the per-fold best-epoch metrics collected by the cross-validation
# cell and print their mean / standard deviation, the best epoch of every fold,
# the average fold duration and, when a GPU was used, the change in GPU memory.
print("\n--- Overall K-Fold Cross-Validation Results ---")

valid_fold_results = [res for res in fold_results_list if res is not None]


def _format_metric(metric_name, value):
    """Render accuracy as a percentage and every other metric with 4 decimals."""
    return f"{value*100:.2f}%" if metric_name == 'accuracy' else f"{value:.4f}"


if not valid_fold_results:
    print("No valid results obtained.")
else:
    # One row per fold, one column per metric
    results_df = pd.DataFrame(valid_fold_results)
    mean_metrics = results_df.mean()
    # pandas uses the sample std (ddof=1), so a single fold yields NaN here
    std_metrics = results_df.std()

    print("\nPerformance Summary (Mean ± Std):")
    summary_df = pd.DataFrame(
        {
            'Mean': [_format_metric(name, mean_metrics[name]) for name in mean_metrics.index],
            'Std Dev': [_format_metric(name, std_metrics[name]) for name in mean_metrics.index],
        },
        index=mean_metrics.index,
    )
    print(summary_df)

    print(f"\nBest Epoch per Fold: {fold_best_epochs}")
    print(f"Average Training Time per Fold: {np.mean(fold_training_times):.2f}s")

# Explicit None check: the tracker is None when no GPU was used.
if memory_tracker is not None:
    final_gpu_mem = memory_tracker.get_used_memory_mb()
    print(f"\nInitial GPU Memory: {initial_gpu_mem:.2f} MB")
    print(f"Final GPU Memory: {final_gpu_mem:.2f} MB")
    print(f"Approx. Memory Used: {final_gpu_mem - initial_gpu_mem:.2f} MB")
    # Release the tracker but keep the name bound, so re-running this cell
    # skips the memory report instead of raising NameError.
    memory_tracker = None