=== STEP 1. SETUP AND IMPORTS ===

In [None]:
# Always pull the latest repo to ensure consistency
!rm -rf Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update
!git clone https://github.com/trongjhuongwr/Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update.git
%cd Deep-Learning-Based-Signature-Forgery-Detection-for-Personal-Identity-Authentication-Update

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
import random
import os
import json
import sys
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import time

# Make the cloned repository's packages (dataloader/, models/, utils/) importable.
# The repo root is inserted at the front so it takes precedence over any installed
# packages with the same names. The membership check keeps re-runs of this cell
# from piling duplicate entries onto sys.path. (os.getcwd() is already absolute.)
repo_root = os.getcwd()
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)

from dataloader.meta_dataloader import SignatureEpisodeDataset
from models.feature_extractor import ResNetFeatureExtractor
from models.meta_learner import MetricGenerator
from utils.model_evaluation import evaluate_meta_model, plot_roc_curve, plot_confusion_matrix
from utils.helpers import MemoryTracker

print("Setup and Imports successful!")

=== STEP 2. CONFIGURATION ===

In [None]:
# Episode sizes passed to SignatureEpisodeDataset (presumably K_SHOT reference
# samples per user plus this many genuine / forged query samples).
K_SHOT = 10
N_QUERY_GENUINE = 15
N_QUERY_FORGERY = 15

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

BEST_CEDAR_MODEL_DIR = '/kaggle/input/best-cedar-model-weights/best_model_fold_2'  # adjust as needed
BHSIG_RAW_BASE_DIR = '/kaggle/input/cedarbhsig-260/'  # adjust if using a different dataset
SPLIT_OUTPUT_DIR = '/kaggle/working/'
BHSIG_SPLIT_FILE = os.path.join(SPLIT_OUTPUT_DIR, 'bhsig_restructured_split.json')

# Seed the Python, NumPy and torch (CPU + all CUDA devices) RNGs for reproducibility.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# On Kaggle, KAGGLE_KERNEL_RUN_TYPE holds a run type such as 'Interactive' or
# 'Batch'; its value never contains the substring 'kaggle'. Detect Kaggle by the
# variable's presence instead. Elsewhere, fall back to in-process loading (0).
# NOTE(review): NUM_WORKERS is not passed to evaluate_meta_model in STEP 6.
NUM_WORKERS = 2 if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ else 0

print(f"Best CEDAR model path: {BEST_CEDAR_MODEL_DIR}")
print(f"BHSig raw path: {BHSIG_RAW_BASE_DIR}")
print(f"Split JSON will be saved to: {BHSIG_SPLIT_FILE}")

=== STEP 3. PREPARE BHSIG-260 SPLIT FILE ===

In [None]:
import shlex
import subprocess  # local imports: only this cell launches a subprocess

print("📦 Restructuring BHSig-260 dataset...")
script_path = 'scripts/restructure_bhsig.py'

# Run the restructuring script with the kernel's own interpreter.
# - An argument list (no shell) keeps paths containing spaces intact.
# - check=True stops the notebook on a non-zero exit. Otherwise a failed run
#   would fall through to the existence check below, which a stale split file
#   from an earlier run in /kaggle/working would silently pass.
command = [
    sys.executable, script_path,
    '--base_dir', BHSIG_RAW_BASE_DIR,
    '--output_dir', SPLIT_OUTPUT_DIR,
    '--seed', str(SEED),
    '--num_bengali', '50',
    '--num_hindi', '30',
]

print(f"Running command: {shlex.join(command)}")
subprocess.run(command, check=True)

# Check if the file was created
if not os.path.exists(BHSIG_SPLIT_FILE):
    raise FileNotFoundError(f"❌ Split file not found at {BHSIG_SPLIT_FILE}. Please check script output.")
else:
    print("BHSig-260 split file generated successfully!")

=== STEP 4. LOAD THE BEST MODEL TRAINED ON CEDAR ===

In [None]:
print(f"--- Loading best model from {BEST_CEDAR_MODEL_DIR} ---")

# Architecture hyper-parameters must match those used when the CEDAR checkpoints
# were saved, or load_state_dict will reject the weights.
feature_extractor = ResNetFeatureExtractor(backbone_name='resnet34', output_dim=512)
metric_generator = MetricGenerator(embedding_dim=512)

fe_path = os.path.join(BEST_CEDAR_MODEL_DIR, 'best_feature_extractor.pth')
mg_path = os.path.join(BEST_CEDAR_MODEL_DIR, 'best_metric_generator.pth')

if not os.path.exists(fe_path):
    raise FileNotFoundError(f"Feature extractor weights not found: {fe_path}")
if not os.path.exists(mg_path):
    raise FileNotFoundError(f"Metric generator weights not found: {mg_path}")

# The checkpoints come from an external Kaggle dataset and are used only as state
# dicts. weights_only=True restricts unpickling to tensors and plain containers
# rather than arbitrary Python objects.
feature_extractor.load_state_dict(torch.load(fe_path, map_location=DEVICE, weights_only=True))
metric_generator.load_state_dict(torch.load(mg_path, map_location=DEVICE, weights_only=True))
feature_extractor.to(DEVICE)
metric_generator.to(DEVICE)

# Put both modules in inference mode so BatchNorm layers (e.g. in the ResNet
# backbone) use their running statistics and any dropout is disabled. This is
# harmless if evaluate_meta_model also calls .eval().
feature_extractor.eval()
metric_generator.eval()

print("Successfully loaded CEDAR-trained model weights!")

=== STEP 5. CREATE BHSIG-260 EVALUATION DATASET ===

In [None]:
print("--- Creating evaluation dataset for BHSig-260 ---")

# Preprocessing pipeline:
# - resize to 220x150 (height x width);
# - collapse to one grayscale channel and convert to a [0, 1] tensor;
# - replicate that channel 3 times (presumably because the ResNet backbone
#   expects 3-channel input);
# - map values to [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): eval_transform is never passed to SignatureEpisodeDataset below,
# so it does not affect this evaluation. Confirm which transform the dataset
# applies internally when augment=False.
eval_transform = transforms.Compose(
    [
        transforms.Resize((220, 150)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img.repeat(3, 1, 1)),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

try:
    # Episodes come from the 'meta-test' split of the restructured BHSig JSON.
    # With use_full_path=True the split file presumably stores complete image
    # paths, which is why no base_data_dir is supplied.
    episode_config = {
        'split_file_path': BHSIG_SPLIT_FILE,
        'base_data_dir': None,
        'split_name': 'meta-test',
        'k_shot': K_SHOT,
        'n_query_genuine': N_QUERY_GENUINE,
        'n_query_forgery': N_QUERY_FORGERY,
        'augment': False,
        'use_full_path': True,
    }
    bhsig_test_dataset = SignatureEpisodeDataset(**episode_config)
    print(f"Created BHSig-260 evaluation dataset with {len(bhsig_test_dataset)} episodes.")
except Exception as err:
    # Surface the failure in the cell output, then propagate the original exception.
    print(f"Error creating dataset: {err}")
    raise

=== STEP 6. PERFORM CROSS-DATASET EVALUATION ===

In [None]:
# NOTE(review): the label says "users" while STEP 5 reports the same length as
# "episodes"; this assumes one meta-test episode per user — confirm.
print(f"\n--- Starting Cross-Dataset Evaluation on BHSig-260 ({len(bhsig_test_dataset)} users) ---")

# Snapshot GPU memory before evaluation. On CPU there is no tracker, and STEP 7
# skips the memory report when eval_memory_tracker is None.
if DEVICE.type == 'cuda':
    eval_memory_tracker = MemoryTracker(DEVICE)
    eval_initial_gpu_mem = eval_memory_tracker.get_used_memory_mb()
else:
    eval_memory_tracker = None
    eval_initial_gpu_mem = 0

# Time only the evaluation call. perf_counter is monotonic and high-resolution,
# unlike time.time(), which can jump with wall-clock adjustments.
eval_start_time = time.perf_counter()

results_dict, true_labels, predictions, distances = evaluate_meta_model(
    feature_extractor,
    metric_generator,
    bhsig_test_dataset,
    DEVICE
)

# CUDA kernels run asynchronously. Wait for any outstanding GPU work so it is
# counted in the measured duration.
if DEVICE.type == 'cuda':
    torch.cuda.synchronize(DEVICE)
eval_duration = time.perf_counter() - eval_start_time
print(f"Evaluation finished in {eval_duration:.2f} seconds.")


=== STEP 7. REPORT AND VISUALIZE RESULTS ===

In [None]:
print("\n\n--- FINAL CROSS-DATASET RESULTS (CEDAR → BHSig-260) ---")

# Print every metric returned by evaluate_meta_model. 'accuracy' is assumed to be
# a fraction in [0, 1] and is shown as a percentage; all others use 4 decimals.
if results_dict:
    for metric, value in results_dict.items():
        if metric == 'accuracy':
            print(f"  - {metric.capitalize()}: {value*100:.2f}%")
        else:
            print(f"  - {metric.capitalize()}: {value:.4f}")
else:
    print("Evaluation produced no valid results.")

# Visualizations need the true labels, the hard predictions and the raw
# distance scores (used as the ROC score), all non-empty.
if all(len(seq) > 0 for seq in (true_labels, predictions, distances)):
    print("\nGenerating ROC and Confusion Matrix plots...")
    plot_roc_curve(true_labels, distances, title='ROC Curve - Cross-Dataset (CEDAR → BHSig-260)')
    plot_confusion_matrix(true_labels, predictions, title='Confusion Matrix - Cross-Dataset (CEDAR → BHSig-260)')
else:
    print("\nSkipping visualizations — missing evaluation data.")

# Memory usage report. STEP 6 sets eval_memory_tracker to None on CPU. Compare
# against None explicitly rather than relying on the tracker object's truthiness.
if eval_memory_tracker is not None:
    eval_final_gpu_mem = eval_memory_tracker.get_used_memory_mb()
    print(f"\nInitial GPU Mem: {eval_initial_gpu_mem:.2f} MB")
    print(f"Final GPU Mem: {eval_final_gpu_mem:.2f} MB")
    print(f"≈ Used During Eval: {eval_final_gpu_mem - eval_initial_gpu_mem:.2f} MB")
    del eval_memory_tracker
