In [None]:
import sys
import os
from pathlib import Path
import numpy as np
import pickle
import json
from typing import Dict, List, Tuple, Optional
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc, average_precision_score
from scipy import stats
from tqdm import tqdm
import pandas as pd
from datetime import datetime

In [None]:
NOTEBOOK_DIR = Path.cwd()


def find_project_root(start: Path, target_folder="RIG"):
    """Return the nearest directory named ``target_folder``, searching ``start`` then its ancestors.

    Args:
        start: Directory to begin the upward search from (it is itself a candidate).
        target_folder: Folder name that marks the project root.

    Raises:
        RuntimeError: if neither ``start`` nor any ancestor has that name.
    """
    candidates = (start, *start.parents)
    found = next((candidate for candidate in candidates if candidate.name == target_folder), None)
    if found is None:
        raise RuntimeError(f"Could not find project root '{target_folder}'")
    return found

# Resolve the repository root by walking up from the notebook's working
# directory until a folder named "RIG" is found (find_project_root raises
# RuntimeError if there is none).
PROJECT_ROOT = find_project_root(NOTEBOOK_DIR)
# All artefacts for this evaluation version live under <root>/output/v2.
OUTPUT_ROOT = PROJECT_ROOT / "output" / "v2"

print(f"Project Root: {PROJECT_ROOT}")
print(f"Output Root: {OUTPUT_ROOT}")

# OUTPUT_ROOT is already a Path; the Path() wrap is a no-op kept so the
# lower-case alias used by later cells exists.
output_root = Path(OUTPUT_ROOT)
# Input: one sub-directory of pickled embeddings per model (load_embeddings
# reads embeddings_root / <model_name>).
embeddings_root = output_root / 'embeddings'
# Output locations; plots_dir (and its parent results_dir) are created up
# front so later cells can write into them.
results_dir = output_root / 'evaluation_results'
plots_dir = results_dir / 'plots'
plots_dir.mkdir(parents=True, exist_ok=True)

# Embedding models to evaluate. Presumably each name is passed to
# load_embeddings, so a matching folder must exist under embeddings_root
# — TODO confirm against the cells that iterate over `models`.
models = [
  'adaface_ir_50',
  'adaface_ir_101',
  'arcface_ir_50',
  'arcface_ir_101'
]

In [None]:
def load_embeddings(model_name: str, embeddings_dir: Optional[Path] = None) -> Dict:
    """Load every known embedding pickle for one model.

    Args:
        model_name: Name of the model's sub-directory under the embeddings root.
        embeddings_dir: Root directory holding one folder per model. Defaults to
            the notebook-level ``embeddings_root`` so existing calls are unchanged.

    Returns:
        Dict mapping each split key ('gallery_oneshot_base', ..., 'probe_negative')
        to the unpickled object, or None when that split's file is absent.

    Raises:
        FileNotFoundError: if the model directory does not exist.
    """
    root = embeddings_root if embeddings_dir is None else Path(embeddings_dir)
    model_dir = root / model_name

    if not model_dir.exists():
        raise FileNotFoundError(f"Model directory not found: {model_dir}")

    embedding_files = {
        'gallery_oneshot_base': 'gallery_one-shot_base.pkl',
        'gallery_oneshot_augmented': 'gallery_one-shot_augmented.pkl',
        'gallery_fewshot_base': 'gallery_few-shot_base.pkl',
        'gallery_fewshot_augmented': 'gallery_few-shot_augmented.pkl',
        'probe_positive_unsegmented': 'probe_positive_unsegmented.pkl',
        'probe_positive_segmented': 'probe_positive_segmented.pkl',
        'probe_negative': 'probe_negative.pkl'
    }

    embeddings = {}
    for key, filename in embedding_files.items():
        file_path = model_dir / filename
        if not file_path.exists():
            # Missing splits are tolerated; downstream cells check for None.
            embeddings[key] = None
            continue
        # SECURITY: pickle.load can execute arbitrary code. Only point this at
        # files produced by this pipeline, never at untrusted downloads.
        with open(file_path, 'rb') as f:
            embeddings[key] = pickle.load(f)

    return embeddings


In [None]:
def cosine_similarity(emb1: np.ndarray, emb2: np.ndarray) -> float:
    """Cosine similarity between two embedding vectors.

    Always divides by the product of the norms. A "skip normalisation when the
    norms are already ~1" shortcut would be off by up to ~2% and make scores
    depend on how close each vector happens to be to unit length. The division
    is essentially free once the norms are known.

    Returns:
        Similarity as a Python float, or 0.0 if either vector has zero norm
        (instead of NaN, which would poison later max/mean aggregation).
    """
    denom = np.linalg.norm(emb1) * np.linalg.norm(emb2)
    if denom == 0:
        return 0.0
    return float(np.dot(emb1, emb2) / denom)

def compute_all_similarities(probe_emb: np.ndarray,
                             gallery_embeddings: Dict[str, Dict]) -> List[Tuple[str, float]]:
    """Score one probe against every individual gallery embedding.

    Returns one ``(identity_name, similarity)`` pair per gallery embedding, in
    gallery iteration order. An identity with several embeddings appears
    several times.
    """
    return [
        (identity, cosine_similarity(probe_emb, gallery_vec))
        for identity, entry in gallery_embeddings.items()
        for gallery_vec in entry['embeddings']
    ]

In [None]:
def aggregate_max(similarities: List[float]) -> float:
    """Best single similarity; -1 signals that there was nothing to compare against."""
    if not similarities:
        return -1
    return max(similarities)

def aggregate_mean(similarities: List[float]) -> float:
    """Average similarity across all gallery images; -1 when the list is empty."""
    if not similarities:
        return -1
    return np.mean(similarities)

def aggregate_topk(similarities: List[float], k: int = 3) -> float:
    """Mean of the ``k`` highest similarities (all of them when fewer than ``k``).

    Returns -1 for an empty list, matching the other aggregators.

    Raises:
        ValueError: if ``k`` < 1. The mean of zero values is undefined and
            would otherwise silently come back as NaN.
    """
    if not similarities:
        return -1
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    # Slicing past the end is safe, so no explicit min(k, len) is needed.
    top = sorted(similarities, reverse=True)[:k]
    return np.mean(top)

In [None]:
def identify_probe(probe_embedding: np.ndarray,
                   gallery_embeddings: Dict[str, Dict],
                   threshold: float,
                   aggregation: str = 'mean',
                   k: int = 3) -> Tuple[Optional[str], float, Dict[str, float]]:
    """Open-set identification of one probe against the whole gallery.

    Each gallery identity receives one score: the probe's cosine similarities
    to that identity's embeddings, collapsed with ``aggregation``. Supported
    values are 'max', 'mean', and 'topk' (which uses ``k``); any other name
    falls back to 'max'.

    Returns:
        ``(predicted_name, best_score, identity_scores)``. ``predicted_name`` is
        the top-scoring identity, or None when its score is below ``threshold``.
        An empty gallery yields ``(None, -1, {})``.
    """
    def collapse(sims):
        if aggregation == 'mean':
            return aggregate_mean(sims)
        if aggregation == 'topk':
            return aggregate_topk(sims, k)
        return aggregate_max(sims)

    identity_scores = {
        identity: collapse([cosine_similarity(probe_embedding, vec) for vec in entry['embeddings']])
        for identity, entry in gallery_embeddings.items()
    }

    if not identity_scores:
        return None, -1, {}

    ranked = sorted(identity_scores.items(), key=lambda item: item[1], reverse=True)
    best_name, best_score = ranked[0]
    if best_score < threshold:
        best_name = None
    return best_name, best_score, identity_scores

In [None]:
def compute_rank_metrics(identity_scores: Dict[str, float],
                         true_identity: str,
                         ranks: Tuple[int, ...] = (1, 5, 10)) -> Dict[str, float]:
    """Rank-k hits and reciprocal rank of the true identity in a score ranking.

    Args:
        identity_scores: identity -> score; higher is better.
        true_identity: The identity the probe actually belongs to.
        ranks: Cut-offs k to report. An immutable tuple default avoids the
            shared-mutable-default pitfall; any sequence of ints works.

    Returns:
        ``{'rank<k>': bool, ..., 'reciprocal_rank': float}``. If the true
        identity is absent from ``identity_scores``, every rank is False and
        the reciprocal rank is 0.0.
    """
    ranked_names = [name for name, _ in
                    sorted(identity_scores.items(), key=lambda kv: kv[1], reverse=True)]

    # Locate the true identity once; rank-k membership then reduces to a comparison.
    try:
        true_rank = ranked_names.index(true_identity) + 1
    except ValueError:
        true_rank = None

    rank_results = {
        f'rank{k}': true_rank is not None and true_rank <= k
        for k in ranks
    }
    rank_results['reciprocal_rank'] = 1.0 / true_rank if true_rank is not None else 0.0
    return rank_results

def compute_dprime(genuine_scores: List[float], impostor_scores: List[float]) -> float:
    """Decidability index d' = (mean_g - mean_i) / sqrt((std_g^2 + std_i^2) / 2).

    Uses the sample standard deviation (ddof=1), matching the d-prime and
    separation computed elsewhere in this notebook. A group with a single
    score is treated as having zero spread, since ddof=1 is undefined for n=1.

    Returns:
        d' as a float, or 0.0 if either list is empty or the pooled std is 0.
    """
    # len() rather than truthiness so numpy arrays are accepted too.
    if len(genuine_scores) == 0 or len(impostor_scores) == 0:
        return 0.0

    genuine = np.asarray(genuine_scores, dtype=float)
    impostor = np.asarray(impostor_scores, dtype=float)

    std_genuine = genuine.std(ddof=1) if genuine.size > 1 else 0.0
    std_impostor = impostor.std(ddof=1) if impostor.size > 1 else 0.0

    pooled_std = np.sqrt((std_genuine**2 + std_impostor**2) / 2)
    if pooled_std == 0:
        return 0.0

    return float((genuine.mean() - impostor.mean()) / pooled_std)

def bootstrap_confidence_interval(data: List[float],
                                  n_bootstrap: int = 1000,
                                  confidence: float = 0.95,
                                  seed: Optional[int] = None) -> Tuple[float, float]:
    """Percentile-bootstrap confidence interval for the mean of ``data``.

    Args:
        data: Sample values (a list or a 1-D numpy array).
        n_bootstrap: Number of resamples drawn with replacement.
        confidence: Two-sided coverage, e.g. 0.95 for a 95% interval.
        seed: If given, resampling uses a dedicated
            ``np.random.default_rng(seed)`` so the interval is reproducible.
            If None (the default), numpy's global RNG is used, exactly as before.

    Returns:
        ``(lower, upper)`` as floats, or ``(0.0, 0.0)`` for empty data.
    """
    values = np.asarray(data, dtype=float)
    # size check instead of `not data`, which is ambiguous for numpy arrays.
    if values.size == 0:
        return (0.0, 0.0)

    rng = np.random if seed is None else np.random.default_rng(seed)
    n = values.size
    bootstrap_means = [rng.choice(values, size=n, replace=True).mean()
                       for _ in range(n_bootstrap)]

    alpha = 1 - confidence
    lower = np.percentile(bootstrap_means, alpha / 2 * 100)
    upper = np.percentile(bootstrap_means, (1 - alpha / 2) * 100)

    return (float(lower), float(upper))

In [None]:
def _verification_aggregate(similarities: List[float], aggregation: str, k: int) -> float:
    """Collapse one probe's similarities to one identity's gallery images into a single score.

    'mean' and 'topk' use the matching aggregator. 'max' and any unrecognised
    name fall back to the maximum.
    """
    if aggregation == 'mean':
        return aggregate_mean(similarities)
    if aggregation == 'topk':
        return aggregate_topk(similarities, k)
    return aggregate_max(similarities)


def evaluate_verification_comprehensive(gallery_embeddings: Dict[str, Dict],
                                       probe_positive: Dict[str, Dict],
                                       probe_negative: Dict[str, Dict],
                                       thresholds: List[float],
                                       aggregation: str = 'mean',
                                       k: int = 3) -> Dict:
    """
    Verification (genuine vs impostor) evaluation using:
    - probe_positive: genuine scores. Each probe is scored against its OWN
      gallery identity; identities missing from the gallery are skipped with a warning.
    - probe_negative: impostor scores. Each unknown probe keeps its BEST score
      over ALL gallery identities (the worst case for the system).
    Either mapping may wrap its identities under an 'all' key.

    Args:
        thresholds: Acceptance thresholds; a score >= threshold is accepted.
        aggregation: 'max', 'mean' or 'topk' (anything else behaves like 'max').
        k: Number of top similarities averaged for 'topk'.

    Returns:
        Dict with the per-threshold table ('threshold_results'); ROC AUC and
        curve ('roc_auc', 'fpr', 'tpr'); EER and its grid threshold; TAR at
        FAR 0.1%/1%/10% ('tar_at_far_0.001', ...); score means, stds and
        bootstrap CIs; d-prime; separation; the raw score lists; and their counts.

    Raises:
        ValueError: if no genuine or no impostor scores could be collected.
    """
    probe_pos_data = probe_positive.get("all", probe_positive)
    probe_neg_data = probe_negative.get("all", probe_negative) if probe_negative else {}

    genuine_scores = []
    impostor_scores = []

    print("Computing verification scores...")

    # ============================================
    # GENUINE SCORES: Positive probes vs their own gallery
    # ============================================
    print("  Computing genuine scores (positive probes vs own gallery)...")
    for true_name, data in tqdm(probe_pos_data.items(), desc=f"Genuine ({aggregation})"):
        probe_embs = data['embeddings']

        if true_name not in gallery_embeddings:
            print(f"    Warning: {true_name} not in gallery, skipping...")
            continue

        gallery_embs = gallery_embeddings[true_name]['embeddings']
        for probe_emb in probe_embs:
            similarities = [cosine_similarity(probe_emb, g_emb) for g_emb in gallery_embs]
            genuine_scores.append(_verification_aggregate(similarities, aggregation, k))

    # ============================================
    # IMPOSTOR SCORES: Negative probes vs ALL gallery identities
    # ============================================
    if probe_neg_data:
        print("  Computing impostor scores (negative probes vs all gallery)...")
        for _, data in tqdm(probe_neg_data.items(), desc=f"Impostor ({aggregation})"):
            for probe_emb in data['embeddings']:
                # Worst case: the impostor is credited with its best-matching identity.
                best_impostor_score = -1
                for gallery_data in gallery_embeddings.values():
                    similarities = [cosine_similarity(probe_emb, g_emb)
                                    for g_emb in gallery_data['embeddings']]
                    best_impostor_score = max(
                        best_impostor_score,
                        _verification_aggregate(similarities, aggregation, k),
                    )
                impostor_scores.append(best_impostor_score)
    else:
        print("  Warning: No negative probes available. Impostor scores will be empty!")

    print(f"  Collected {len(genuine_scores)} genuine pairs")
    print(f"  Collected {len(impostor_scores)} impostor scores")

    if not genuine_scores:
        raise ValueError("No genuine scores collected! Check probe_positive data.")
    if not impostor_scores:
        raise ValueError("No impostor scores collected! Check probe_negative data.")

    n_genuine = len(genuine_scores)
    n_impostor = len(impostor_scores)

    # ============================================
    # Verification metrics at each threshold (accept iff score >= threshold)
    # ============================================
    threshold_results = []
    for threshold in thresholds:
        tp = sum(1 for s in genuine_scores if s >= threshold)
        fn = n_genuine - tp
        tn = sum(1 for s in impostor_scores if s < threshold)
        fp = n_impostor - tn

        threshold_results.append({
            'threshold': threshold,
            'tar': tp / n_genuine,
            'far': fp / n_impostor,
            'frr': fn / n_genuine,
            'tp': tp,
            'fp': fp,
            'tn': tn,
            'fn': fn
        })

    df = pd.DataFrame(threshold_results)

    # ============================================
    # ROC curve (sklearn also treats score >= threshold as positive)
    # ============================================
    y_true = [1] * n_genuine + [0] * n_impostor
    y_scores = genuine_scores + impostor_scores

    fpr, tpr, _ = roc_curve(y_true, y_scores)
    roc_auc = auc(fpr, tpr)

    # ============================================
    # EER on the threshold grid (resolution limited by `thresholds`)
    # ============================================
    eer_idx = (df['far'] - df['frr']).abs().idxmin()
    eer = (df.loc[eer_idx, 'far'] + df.loc[eer_idx, 'frr']) / 2
    eer_threshold = df.loc[eer_idx, 'threshold']

    # ============================================
    # TAR at specific FAR values
    # ============================================
    # Highest TAR achievable while FAR stays at or below the target, read from
    # the exact ROC curve. Picking the grid row whose FAR is merely *closest*
    # to the target could report a TAR obtained at a FAR above the target.
    # roc_curve always includes the (0, 0) point, so `within` is never empty.
    tar_at_far = {}
    for target_far in [0.001, 0.01, 0.1]:
        within = fpr <= target_far
        tar_at_far[f'tar_at_far_{target_far}'] = float(tpr[within].max()) if within.any() else 0.0

    # ============================================
    # Score statistics (sample std; a single score has zero spread)
    # ============================================
    genuine_mean = np.mean(genuine_scores)
    genuine_std = np.std(genuine_scores, ddof=1) if n_genuine > 1 else 0.0
    impostor_mean = np.mean(impostor_scores)
    impostor_std = np.std(impostor_scores, ddof=1) if n_impostor > 1 else 0.0

    # d-prime
    pooled_std = np.sqrt((genuine_std**2 + impostor_std**2) / 2)
    dprime = (genuine_mean - impostor_mean) / pooled_std if pooled_std > 0 else 0

    # Separation: |d-prime|
    pooled_variance = (genuine_std**2 + impostor_std**2) / 2
    separation = abs(genuine_mean - impostor_mean) / np.sqrt(pooled_variance) if pooled_variance > 0 else 0

    genuine_ci = bootstrap_confidence_interval(genuine_scores)
    impostor_ci = bootstrap_confidence_interval(impostor_scores)

    return {
        'threshold_results': df,
        'roc_auc': roc_auc,
        'dprime': dprime,
        'separation': separation,
        'eer': eer,
        'eer_threshold': eer_threshold,
        **tar_at_far,
        'genuine_mean': genuine_mean,
        'genuine_std': genuine_std,
        'impostor_mean': impostor_mean,
        'impostor_std': impostor_std,
        'genuine_scores': genuine_scores,
        'impostor_scores': impostor_scores,
        'genuine_ci': genuine_ci,
        'impostor_ci': impostor_ci,
        'fpr': fpr,
        'tpr': tpr,
        'aggregation': aggregation,
        'n_genuine_pairs': n_genuine,
        'n_impostor_pairs': n_impostor
    }

In [None]:
def evaluate_probes_comprehensive(gallery_embeddings: Dict[str, Dict],
                                 probe_embeddings: Dict[str, Dict],
                                 thresholds: List[float],
                                 aggregation: str = 'mean',
                                 k: int = 3) -> Dict:
    """Closed-set identification evaluation of positive probes against a gallery.

    Every probe is ranked against all gallery identities via ``identify_probe``.
    Rank-1/5/10 accuracy and MRR do not depend on the acceptance threshold, so
    they are computed once. TP/FP/FN, TAR/FAR/FRR and precision/recall/F1 are
    computed for each value in ``thresholds``.

    Note: 'far' here is fp / n_probes, where fp counts genuine probes whose top
    match is the WRONG identity yet still clears the threshold. It is a
    closed-set misidentification rate, not a verification FAR measured on
    impostors. The 'eer' and 'tar_at_far_*' values derived from it share that meaning.

    Args:
        probe_embeddings: identity -> {'embeddings': ...}, optionally wrapped under 'all'.
        aggregation / k: Forwarded to ``identify_probe``.

    Returns:
        Dict with 'threshold_results' (DataFrame, one row per threshold),
        'aggregation', 'all_predictions' and 'per_identity'. It also holds the
        score summary that was previously computed and discarded:
        'genuine_scores', 'impostor_scores', 'dprime', 'separation', 'eer',
        'eer_threshold', 'tar_at_far_0.001/0.01/0.1', 'genuine_ci' and 'impostor_ci'.
    """
    probe_data = probe_embeddings.get("all", probe_embeddings)
    all_predictions = []
    genuine_scores = []   # true-identity score, one per probe whose identity is enrolled
    impostor_scores = []  # best wrong-identity score, one per probe
    per_identity = {}

    for true_name, data in tqdm(probe_data.items(), desc=f"Processing probes ({aggregation})"):
        probe_embs = data['embeddings']

        for probe_emb in probe_embs:
            # threshold=0.0: predicted_name is None only when the best score is
            # negative; per-threshold acceptance is evaluated further below.
            predicted_name, best_score, identity_scores = identify_probe(
                probe_emb, gallery_embeddings, threshold=0.0,
                aggregation=aggregation, k=k
            )

            rank_metrics = compute_rank_metrics(identity_scores, true_name)

            all_predictions.append({
                'true_identity': true_name,
                'predicted_identity': predicted_name,
                'score': best_score,
                'identity_scores': identity_scores,
                'rank_metrics': rank_metrics
            })

            if true_name not in per_identity:
                per_identity[true_name] = {
                    "rank1_total": 0,
                    "rank5_total": 0,
                    "rank10_total": 0,
                    "mrr_sum": 0,
                    "count": 0,
                    "genuine_scores": [],
                    "impostor_scores": []
                }

            wrong_scores = [score for name, score in identity_scores.items() if name != true_name]

            pid = per_identity[true_name]
            pid["rank1_total"] += 1 if rank_metrics["rank1"] else 0
            pid["rank5_total"] += 1 if rank_metrics["rank5"] else 0
            pid["rank10_total"] += 1 if rank_metrics["rank10"] else 0
            pid["mrr_sum"] += rank_metrics["reciprocal_rank"]
            pid["count"] += 1

            # Per-identity genuine / impostor histogram material
            # (0 stands in for a true identity missing from the gallery).
            pid["genuine_scores"].append(identity_scores.get(true_name, 0))
            pid["impostor_scores"].extend(wrong_scores)

            if true_name in identity_scores:
                genuine_scores.append(identity_scores[true_name])

            # Only collect impostor scores once per probe (use the best impostor match)
            if wrong_scores:
                impostor_scores.append(max(wrong_scores))

    n_probes = len(all_predictions)

    def _rate(count):
        return count / n_probes if n_probes > 0 else 0

    # Rank metrics are threshold-independent: compute them once, not per threshold.
    rank1_acc = _rate(sum(1 for p in all_predictions if p['rank_metrics']['rank1']))
    rank5_acc = _rate(sum(1 for p in all_predictions if p['rank_metrics']['rank5']))
    rank10_acc = _rate(sum(1 for p in all_predictions if p['rank_metrics']['rank10']))
    mrr = _rate(sum(p['rank_metrics']['reciprocal_rank'] for p in all_predictions))

    threshold_results = []
    for threshold in thresholds:
        tp = fp = fn = 0
        correct_scores = []
        incorrect_scores = []

        for pred in all_predictions:
            score = pred['score']
            if score >= threshold:
                if pred['predicted_identity'] == pred['true_identity']:
                    tp += 1
                    correct_scores.append(score)
                else:
                    fp += 1
                    incorrect_scores.append(score)
            else:
                fn += 1

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        threshold_results.append({
            'threshold': threshold,
            'rank1_accuracy': rank1_acc,
            'rank5_accuracy': rank5_acc,
            'rank10_accuracy': rank10_acc,
            'mrr': mrr,
            'tar': _rate(tp),
            'far': _rate(fp),
            'frr': _rate(fn),
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'tp': tp,
            'fp': fp,
            'fn': fn,
            'n_probes': n_probes,
            'avg_correct_score': np.mean(correct_scores) if correct_scores else 0,
            'avg_incorrect_score': np.mean(incorrect_scores) if incorrect_scores else 0,
        })

    df_thresh = pd.DataFrame(threshold_results)

    # ------------------------------------------------------------------
    # Score-distribution summary. Guarded so that an empty genuine or
    # impostor list (e.g. a single-identity gallery) yields NaN or 0
    # instead of raising.
    # ------------------------------------------------------------------
    dprime = compute_dprime(genuine_scores, impostor_scores)

    nan = float('nan')
    genuine_mean = np.mean(genuine_scores) if genuine_scores else nan
    impostor_mean = np.mean(impostor_scores) if impostor_scores else nan
    # Sample std (ddof=1); a single score has zero spread.
    genuine_std = np.std(genuine_scores, ddof=1) if len(genuine_scores) > 1 else 0.0
    impostor_std = np.std(impostor_scores, ddof=1) if len(impostor_scores) > 1 else 0.0

    pooled_variance = (genuine_std**2 + impostor_std**2) / 2
    separation = abs(genuine_mean - impostor_mean) / np.sqrt(pooled_variance) if pooled_variance > 0 else 0

    tar_at_far = {}
    if df_thresh.empty:
        eer = eer_threshold = nan
        for target_far in [0.001, 0.01, 0.1]:
            tar_at_far[f'tar_at_far_{target_far}'] = nan
    else:
        eer_idx = (df_thresh['far'] - df_thresh['frr']).abs().idxmin()
        eer = (df_thresh.loc[eer_idx, 'far'] + df_thresh.loc[eer_idx, 'frr']) / 2
        eer_threshold = df_thresh.loc[eer_idx, 'threshold']
        for target_far in [0.001, 0.01, 0.1]:
            # Best TAR among thresholds whose FAR does not exceed the target.
            eligible = df_thresh.loc[df_thresh['far'] <= target_far, 'tar']
            tar_at_far[f'tar_at_far_{target_far}'] = eligible.max() if not eligible.empty else 0.0

    genuine_ci = bootstrap_confidence_interval(genuine_scores)
    impostor_ci = bootstrap_confidence_interval(impostor_scores)

    per_identity_results = {}
    # `counts`, not `stats`, so the scipy.stats import is not shadowed.
    for identity, counts in per_identity.items():
        c = counts["count"]
        per_identity_results[identity] = {
            "rank1": counts["rank1_total"] / c,
            "rank5": counts["rank5_total"] / c,
            "rank10": counts["rank10_total"] / c,
            "mrr": counts["mrr_sum"] / c,
            "n_samples": c,
            "genuine_scores": counts["genuine_scores"],
            "impostor_scores": counts["impostor_scores"]
        }

    def _describe(scores):
        # (mean, min, max), or NaNs when empty so the diagnostics never raise.
        if not scores:
            return nan, nan, nan
        return np.mean(scores), np.min(scores), np.max(scores)

    g_mean, g_min, g_max = _describe(genuine_scores)
    i_mean, i_min, i_max = _describe(impostor_scores)

    print("\nDEBUG - Score Collection:")
    print(f"  Total predictions: {len(all_predictions)}")
    print(f"  Genuine scores collected: {len(genuine_scores)}")
    print(f"  Impostor scores collected: {len(impostor_scores)}")
    print(f"  Genuine mean: {g_mean:.4f}")
    print(f"  Impostor mean: {i_mean:.4f}")
    print(f"  Genuine min/max: {g_min:.4f} / {g_max:.4f}")
    print(f"  Impostor min/max: {i_min:.4f} / {i_max:.4f}")

    print("\nDEBUG - EER Calculation:")
    if not df_thresh.empty:
        print(f"  FAR range: {df_thresh['far'].min():.4f} to {df_thresh['far'].max():.4f}")
        print(f"  FRR range: {df_thresh['frr'].min():.4f} to {df_thresh['frr'].max():.4f}")
    print(f"  EER: {eer:.4f} at threshold {eer_threshold:.4f}")

    return {
        'threshold_results': df_thresh,
        'aggregation': aggregation,
        'all_predictions': all_predictions,
        'per_identity': per_identity_results,
        'genuine_scores': genuine_scores,
        'impostor_scores': impostor_scores,
        'dprime': dprime,
        'separation': separation,
        'eer': eer,
        'eer_threshold': eer_threshold,
        **tar_at_far,
        'genuine_ci': genuine_ci,
        'impostor_ci': impostor_ci,
    }

In [None]:

def evaluate_impostors_comprehensive(gallery_embeddings: Dict[str, Dict],
                                    impostor_embeddings: Dict[str, Dict],
                                    thresholds: List[float],
                                    aggregation: str = 'mean',
                                    k: int = 3) -> Dict:
    """Open-set rejection of impostor probes (people not enrolled in the gallery).

    Each impostor embedding receives its best identity score from
    ``identify_probe``. At every threshold, a score >= threshold is a false
    accept (fp) and a lower score is a correct rejection (tn).

    Returns:
        Dict with 'threshold_results' (DataFrame with rejection_rate, far, tn,
        fp, n_impostors and avg_impostor_score per threshold),
        'impostor_scores', a bootstrap 'impostor_ci', 'mean_impostor_score',
        'std_impostor_score' and 'aggregation'. The mean is NaN and the CI is
        (0.0, 0.0) when there are no impostors. The std is the sample std
        (ddof=1), consistent with the rest of this notebook, and 0.0 for a
        single score.
    """
    impostor_scores = []

    for _, data in tqdm(impostor_embeddings.items(), desc=f"Processing impostors ({aggregation})"):
        for impostor_emb in data['embeddings']:
            _, score, _ = identify_probe(
                impostor_emb, gallery_embeddings, threshold=0.0,
                aggregation=aggregation, k=k
            )
            impostor_scores.append(score)

    n_impostors = len(impostor_scores)
    # Threshold-independent statistics, computed once. Guards avoid NaN
    # RuntimeWarnings from numpy on an empty list.
    mean_score = np.mean(impostor_scores) if n_impostors > 0 else float('nan')
    if n_impostors > 1:
        std_score = np.std(impostor_scores, ddof=1)
    elif n_impostors == 1:
        std_score = 0.0
    else:
        std_score = float('nan')

    threshold_results = []

    for threshold in thresholds:
        tn = sum(1 for s in impostor_scores if s < threshold)
        fp = sum(1 for s in impostor_scores if s >= threshold)

        threshold_results.append({
            'threshold': threshold,
            'rejection_rate': tn / n_impostors if n_impostors > 0 else 0,
            'far': fp / n_impostors if n_impostors > 0 else 0,
            'tn': tn,
            'fp': fp,
            'n_impostors': n_impostors,
            'avg_impostor_score': mean_score
        })

    impostor_ci = bootstrap_confidence_interval(impostor_scores)

    return {
        'threshold_results': pd.DataFrame(threshold_results),
        'impostor_scores': impostor_scores,
        'impostor_ci': impostor_ci,
        'mean_impostor_score': mean_score,
        'std_impostor_score': std_score,
        'aggregation': aggregation
    }

In [None]:
def evaluate_segmented_comprehensive(gallery_embeddings: Dict[str, Dict],
                                    probe_positive: Dict[str, Dict],
                                    probe_negative: Dict[str, Dict],
                                    thresholds: List[float],
                                    aggregation: str = 'mean',
                                    k: int = 3,
                                    include_verification: bool = True) -> Dict[str, Dict]:
    """
    Run identification (and optionally verification) separately for each probe segment.

    Args:
        probe_positive: Segmented positive probes. Every key except 'all' is
            treated as a segment name.
        probe_negative: Negative probes (not segmented, but used for all
            segments). Verification is skipped when this is None or when
            include_verification is false.
        aggregation / k: Forwarded to the identification and verification evaluators.

    Returns:
        segment name -> {'identification': ..., 'verification': ... or None}.
        A verification error in one segment is printed and recorded as None
        rather than aborting the remaining segments.
    """
    segments = [seg for seg in probe_positive.keys() if seg != 'all']
    print(f"Found {len(segments)} segments: {segments}")

    run_verification = include_verification and probe_negative is not None
    segment_results = {}

    for segment_name in tqdm(segments, desc=f"Processing segments ({aggregation})"):
        segment_probe = {'all': probe_positive[segment_name]}

        id_results = evaluate_probes_comprehensive(
            gallery_embeddings, segment_probe, thresholds,
            aggregation=aggregation, k=k
        )

        ver_results = None
        if run_verification:
            try:
                ver_results = evaluate_verification_comprehensive(
                    gallery_embeddings, segment_probe, probe_negative, thresholds,
                    aggregation=aggregation, k=k
                )
            except Exception as err:
                # Best-effort: report and continue with the next segment.
                print(f"  Warning: Verification failed for {segment_name}: {err}")
                ver_results = None

        segment_results[segment_name] = {
            'identification': id_results,
            'verification': ver_results,
        }

    return segment_results

In [None]:
def generate_comparison_summary(all_model_results: Dict) -> pd.DataFrame:
    """Generate a comprehensive comparison table across all models.

    Produces one row per (model, gallery, aggregation) found under each
    model's 'basic_probe' results. Identification columns come from the
    threshold row with the highest rank-1 accuracy (the first such row on
    ties). Verification columns are NaN when no verification results were stored.
    """
    # Output column -> key in the verification results dict (order = column order).
    verification_columns = {
        'ROC-AUC': 'roc_auc',
        'EER': 'eer',
        'd-prime': 'dprime',
        'Separation': 'separation',
        'TAR@0.1%FAR': 'tar_at_far_0.001',
        'TAR@1%FAR': 'tar_at_far_0.01',
        'TAR@10%FAR': 'tar_at_far_0.1',
    }

    rows = []
    for model_name, model_data in all_model_results.items():
        for gallery_name, gallery_results in model_data.get('basic_probe', {}).items():
            # 'per_identity' is metadata stored next to the gallery entries.
            if gallery_name == 'per_identity':
                continue

            for agg_method, combined in gallery_results.items():
                if not isinstance(combined, dict):
                    continue
                identification = combined.get('identification')
                verification = combined.get('verification')
                if identification is None:
                    continue

                id_table = identification['threshold_results']
                top_row = id_table.loc[id_table['rank1_accuracy'].idxmax()]

                row = {
                    'Model': model_name,
                    'Gallery': gallery_name,
                    'Aggregation': agg_method,
                    'Rank-1': top_row['rank1_accuracy'],
                    'Rank-5': top_row['rank5_accuracy'],
                    'Rank-10': top_row['rank10_accuracy'],
                    'MRR': top_row['mrr'],
                    'Best_Threshold': top_row['threshold'],
                }
                for column, key in verification_columns.items():
                    row[column] = np.nan if verification is None else verification[key]

                rows.append(row)

    return pd.DataFrame(rows)

In [None]:
def create_segmented_comparison_table(all_model_results: Dict,
                                      gallery_type: str = 'oneshot',
                                      metric_type: str = 'identification') -> pd.DataFrame:
    """Create a Model x Segment comparison table for segmented evaluations.

    Args:
        gallery_type: Selects ``model_data[f'segmented_{gallery_type}']``.
        metric_type: 'identification' (tabulates best Rank-1) or
            'verification' (tabulates AUC).

    Returns:
        Pivot table (index Model, one column per segment) plus Mean, Std, Min
        and Max columns. All four summaries are computed over the segment
        columns only. Segments whose results for ``metric_type`` are None
        (e.g. verification skipped or failed) are left out. Returns an empty
        DataFrame when nothing matches.
    """
    seg_key = f'segmented_{gallery_type}'
    segment_data = []

    for model_name, model_data in all_model_results.items():
        if seg_key not in model_data:
            continue

        for segment_name, results in model_data[seg_key].items():
            # New layout nests results under 'identification'/'verification';
            # older runs stored the identification results directly.
            metric_results = results[metric_type] if metric_type in results else results
            if metric_results is None:
                continue

            if metric_type == 'identification':
                df = metric_results['threshold_results']
                best_idx = df['rank1_accuracy'].idxmax()

                segment_data.append({
                    'Model': model_name,
                    'Segment': segment_name,
                    'Rank-1': df.loc[best_idx, 'rank1_accuracy'],
                    'Rank-5': df.loc[best_idx, 'rank5_accuracy'],
                    'MRR': df.loc[best_idx, 'mrr']
                })
            else:  # verification
                segment_data.append({
                    'Model': model_name,
                    'Segment': segment_name,
                    'AUC': metric_results['roc_auc'],
                    'EER': metric_results['eer'],
                    'd-prime': metric_results['dprime'],
                    'TAR@1%FAR': metric_results['tar_at_far_0.01']
                })

    df = pd.DataFrame(segment_data)
    if df.empty:
        return pd.DataFrame()

    value_col = 'Rank-1' if metric_type == 'identification' else 'AUC'
    pivot = df.pivot(index='Model', columns='Segment', values=value_col)

    # Snapshot the segment columns BEFORE adding summaries. Otherwise Std
    # would also include the freshly added Mean column.
    segment_values = pivot[list(pivot.columns)]
    pivot['Mean'] = segment_values.mean(axis=1)
    pivot['Std'] = segment_values.std(axis=1)
    pivot['Min'] = segment_values.min(axis=1)
    pivot['Max'] = segment_values.max(axis=1)

    return pivot

In [None]:
def analyze_gallery_strategies(all_model_results: Dict) -> pd.DataFrame:
    """Compare oneshot vs fewshot and base vs augmented galleries for every model.

    Each gallery is scored by its best rank-1 accuracy over all aggregation
    methods (0 when it has no identification results). Galleries that are
    missing count as 0. The improvement columns are plain differences
    between those scores.
    """
    def best_rank1_for(gallery_results):
        # Strict '>' keeps the first aggregation on ties; 0/None if nothing qualifies.
        top_score, top_agg = 0, None
        for agg_method, combined in gallery_results.items():
            if not isinstance(combined, dict):
                continue
            identification = combined.get('identification')
            if identification is None:
                continue
            candidate = identification['threshold_results']['rank1_accuracy'].max()
            if candidate > top_score:
                top_score, top_agg = candidate, agg_method
        return {'rank1': top_score, 'agg': top_agg}

    rows = []
    for model_name, model_data in all_model_results.items():
        configs = {
            gallery_name: best_rank1_for(gallery_results)
            for gallery_name, gallery_results in model_data.get('basic_probe', {}).items()
            if gallery_name != 'per_identity'
        }

        def score_of(gallery_name):
            return configs.get(gallery_name, {}).get('rank1', 0)

        one_base = score_of('oneshot_base')
        one_aug = score_of('oneshot_augmented')
        few_base = score_of('fewshot_base')
        few_aug = score_of('fewshot_augmented')

        if configs:
            best_config = max(configs, key=lambda name: configs[name]['rank1'])
            best_rank1 = max(entry['rank1'] for entry in configs.values())
        else:
            best_config, best_rank1 = 'N/A', 0

        rows.append({
            'Model': model_name,
            'Oneshot_Base': one_base,
            'Oneshot_Aug': one_aug,
            'Fewshot_Base': few_base,
            'Fewshot_Aug': few_aug,
            'Aug_Improvement_Oneshot': one_aug - one_base,
            'Aug_Improvement_Fewshot': few_aug - few_base,
            'Fewshot_Improvement_Base': few_base - one_base,
            'Fewshot_Improvement_Aug': few_aug - one_aug,
            'Best_Config': best_config,
            'Best_Rank1': best_rank1,
        })

    return pd.DataFrame(rows)

In [None]:
def summarize_aggregation_performance(all_model_results: Dict) -> pd.DataFrame:
    """Compare aggregation methods (max / mean / topk) per model and gallery.

    For each (model, gallery) the score of an aggregation method is the peak
    Rank-1 accuracy over its threshold sweep. The 'per_identity' gallery key,
    non-dict aggregation entries and entries without identification results
    are skipped. Methods that were not evaluated are reported as 0.

    Returns:
        One row per (model, gallery) with the winning method, the per-method
        peaks, the winning score and the spread between best and worst method.
    """
    rows = []

    for model_name, model_data in all_model_results.items():
        for gallery_name, gallery_results in model_data.get('basic_probe', {}).items():
            if gallery_name == 'per_identity':
                continue

            # Peak Rank-1 over the threshold sweep, keyed by aggregation method.
            peak_rank1 = {
                agg_name: combined['identification']['threshold_results']['rank1_accuracy'].max()
                for agg_name, combined in gallery_results.items()
                if isinstance(combined, dict) and combined.get('identification') is not None
            }
            if not peak_rank1:
                continue

            winner, winner_score = max(peak_rank1.items(), key=lambda item: item[1])
            rows.append({
                'Model': model_name,
                'Gallery': gallery_name,
                'Best_Aggregation': winner,
                'MAX_Score': peak_rank1.get('max', 0),
                'MEAN_Score': peak_rank1.get('mean', 0),
                'TOPK_Score': peak_rank1.get('topk', 0),
                'Best_Score': winner_score,
                'Score_Range': max(peak_rank1.values()) - min(peak_rank1.values()),
            })

    return pd.DataFrame(rows)

In [None]:
def recommend_operating_thresholds(all_model_results: Dict) -> pd.DataFrame:
    """Recommend decision thresholds for several operating points per configuration.

    For every (model, gallery, aggregation) with identification results this
    reports the threshold that maximises Rank-1 accuracy. When verification
    results are present it also reports the EER threshold and the thresholds
    for the FAR = 1% and FAR = 0.1% operating points.

    A FAR target is treated as a budget: the recommended threshold is the one
    with the highest TAR among thresholds whose FAR does not exceed the target.
    If no threshold in the sweep reaches the target, the strictest
    (lowest-FAR) threshold is reported instead.

    Args:
        all_model_results: ``{model: {'basic_probe': {gallery: {agg: combined}}}}``
            where ``combined`` is a dict with an ``'identification'`` entry and an
            optional ``'verification'`` entry, each holding a
            ``'threshold_results'`` DataFrame. The ``'per_identity'`` gallery key,
            non-dict entries and empty threshold sweeps are skipped.

    Returns:
        One row per configuration. Verification columns are NaN for rows
        without (or with empty) verification results.
    """

    def _best_index_within_far(ver_df: pd.DataFrame, target_far: float):
        # Highest-TAR threshold whose FAR stays within budget; fall back to the
        # strictest threshold when the sweep never gets that low.
        within_budget = ver_df[ver_df['far'] <= target_far]
        if within_budget.empty:
            return ver_df['far'].idxmin()
        return within_budget['tar'].idxmax()

    threshold_recs = []

    for model_name, model_data in all_model_results.items():
        basic_results = model_data.get('basic_probe') or {}

        for gallery_name, gallery_results in basic_results.items():
            if gallery_name == 'per_identity':
                continue
            for agg_method, combined_results in gallery_results.items():
                # Each aggregation maps to {'identification': ..., 'verification': ...};
                # any other shape is ignored.
                if not isinstance(combined_results, dict):
                    continue

                id_results = combined_results.get('identification')
                ver_results = combined_results.get('verification')
                if id_results is None:
                    continue

                df = id_results['threshold_results']
                if df.empty:
                    # No thresholds were evaluated; idxmax would raise.
                    continue

                rank1_max_idx = df['rank1_accuracy'].idxmax()
                threshold_rec = {
                    'Model': model_name,
                    'Gallery': gallery_name,
                    'Aggregation': agg_method,
                    'Threshold_BestRank1': df.loc[rank1_max_idx, 'threshold'],
                    'Rank1_at_BestThreshold': df.loc[rank1_max_idx, 'rank1_accuracy'],
                }

                if ver_results is not None and not ver_results['threshold_results'].empty:
                    ver_df = ver_results['threshold_results']

                    # EER: the swept threshold where FAR and FRR are closest.
                    eer_idx = (ver_df['far'] - ver_df['frr']).abs().idxmin()
                    far_001_idx = _best_index_within_far(ver_df, 0.001)
                    far_01_idx = _best_index_within_far(ver_df, 0.01)

                    threshold_rec.update({
                        'Threshold_EER': ver_df.loc[eer_idx, 'threshold'],
                        'EER': (ver_df.loc[eer_idx, 'far'] + ver_df.loc[eer_idx, 'frr']) / 2,
                        'Threshold_FAR0.1%': ver_df.loc[far_001_idx, 'threshold'],
                        'TAR_at_FAR0.1%': ver_df.loc[far_001_idx, 'tar'],
                        'Threshold_FAR1%': ver_df.loc[far_01_idx, 'threshold'],
                        'TAR_at_FAR1%': ver_df.loc[far_01_idx, 'tar'],
                    })

                threshold_recs.append(threshold_rec)

    return pd.DataFrame(threshold_recs)

In [None]:
def analyze_failure_cases(all_model_results: Dict) -> Dict:
    """Summarise identification mistakes for each model/gallery pair.

    Only the 'mean' aggregation is inspected. For every (model, gallery) with a
    non-empty ``all_predictions`` list, the report entry (keyed
    ``"{model}_{gallery}"``) holds the number of predictions and errors, the
    error rate, the ten most frequent ``"true -> predicted"`` confusions and the
    ten true identities with the most errors. A predicted identity of ``None``
    is shown as ``"REJECTED"``.
    """

    def _ten_most_frequent(counts):
        # Sorting on the negated count gives descending order while the stable
        # sort keeps ties in first-seen order.
        return sorted(counts.items(), key=lambda item: -item[1])[:10]

    report = {}

    for model_name, model_data in all_model_results.items():
        for gallery_name, gallery_results in model_data.get('basic_probe', {}).items():
            # Skip the per-identity breakdown and galleries without a MEAN run.
            if gallery_name == 'per_identity' or 'mean' not in gallery_results:
                continue

            mean_results = gallery_results['mean']
            if not isinstance(mean_results, dict):
                continue
            identification = mean_results.get('identification')
            if identification is None:
                continue
            predictions = identification.get('all_predictions', [])
            if not predictions:
                continue

            errors = [
                pred for pred in predictions
                if pred['predicted_identity'] != pred['true_identity']
            ]

            pair_counts = {}
            miss_counts = {}
            for err in errors:
                truth = err['true_identity']
                guess = err['predicted_identity']
                shown_guess = "REJECTED" if guess is None else guess
                pair_label = f"{truth} -> {shown_guess}"
                pair_counts[pair_label] = pair_counts.get(pair_label, 0) + 1
                miss_counts[truth] = miss_counts.get(truth, 0) + 1

            report[f"{model_name}_{gallery_name}"] = {
                'total_predictions': len(predictions),
                'total_errors': len(errors),
                'error_rate': len(errors) / len(predictions),
                'top_confusion_pairs': _ten_most_frequent(pair_counts),
                'most_confused_identities': _ten_most_frequent(miss_counts),
            }

    return report

In [None]:
def compare_models_statistical(all_model_results: Dict,
                               gallery_name: str = 'fewshot_augmented',
                               agg_method: str = 'mean',
                               alpha: float = 0.05) -> pd.DataFrame:
    """Pairwise statistical comparison of models on one gallery/aggregation configuration.

    Each probe is scored as its similarity score when it was identified
    correctly and 0 otherwise. For every unordered pair of models the two
    per-probe score lists are compared with a paired t-test
    (``scipy.stats.ttest_rel``). Cohen's d is computed from the two lists'
    population standard deviations.

    NOTE(review): the pairing assumes both models list ``all_predictions`` in the
    same probe order — TODO confirm against the evaluation code. Lists of
    different length are reported and skipped.

    Args:
        all_model_results: ``{model: {'basic_probe': {gallery: {agg: combined}}}}``.
        gallery_name: Gallery configuration to compare on (default: the
            previously hard-coded ``'fewshot_augmented'``).
        agg_method: Aggregation method to compare on (default ``'mean'``).
        alpha: Significance level for the ``Significant`` column.

    Returns:
        One row per compared model pair. Pairs that cannot be compared (missing
        configuration, misaligned probes) are skipped with a printed warning.
    """
    stat_comparisons = []
    model_names = list(all_model_results.keys())

    def _credited_scores(model_name: str):
        # Raises KeyError/AttributeError when the configuration is missing or
        # malformed; the caller turns that into a warning.
        combined = all_model_results[model_name]['basic_probe'][gallery_name][agg_method]
        id_results = combined.get('identification')
        if id_results is None:
            return None
        return [p['score'] if p['predicted_identity'] == p['true_identity'] else 0
                for p in id_results['all_predictions']]

    for i, model1 in enumerate(model_names):
        for model2 in model_names[i + 1:]:
            try:
                scores1 = _credited_scores(model1)
                scores2 = _credited_scores(model2)
                if scores1 is None or scores2 is None:
                    continue
                if len(scores1) != len(scores2):
                    raise ValueError(
                        f"probe counts differ ({len(scores1)} vs {len(scores2)}); "
                        "a paired test needs the same probes in the same order")

                # Paired t-test over the per-probe credited scores.
                t_stat, p_value = stats.ttest_rel(scores1, scores2)

                # Effect size: mean difference over the average of the two variances.
                mean_diff = np.mean(scores1) - np.mean(scores2)
                pooled_std = np.sqrt((np.std(scores1) ** 2 + np.std(scores2) ** 2) / 2)
                cohens_d = mean_diff / pooled_std if pooled_std > 0 else 0

                stat_comparisons.append({
                    'Model_A': model1,
                    'Model_B': model2,
                    'Mean_Diff': mean_diff,
                    't_statistic': t_stat,
                    'p_value': p_value,
                    'Significant': 'Yes' if p_value < alpha else 'No',
                    'Cohens_d': cohens_d,
                    'Effect_Size': 'Small' if abs(cohens_d) < 0.5 else ('Medium' if abs(cohens_d) < 0.8 else 'Large')
                })
            except Exception as e:
                # Best-effort: one incomparable pair must not abort the whole table.
                print(f"Warning: Could not compare {model1} vs {model2}: {e}")

    return pd.DataFrame(stat_comparisons)

In [None]:
def generate_executive_summary(all_model_results: Dict, 
                              comparison_summary: pd.DataFrame) -> str:
    """Generate a plain-text executive summary of the model comparison.

    Sections 1-4 and the first three recommendations are derived from
    ``comparison_summary``. The review note before section 5 marks the text
    that is not computed from it.

    Args:
        all_model_results: Per-model results. Not read here; kept for interface
            compatibility.
        comparison_summary: One row per evaluated configuration with at least the
            columns 'Model', 'Gallery', 'Aggregation', 'Rank-1', 'ROC-AUC',
            'd-prime' and 'Best_Threshold'.

    Returns:
        The rendered multi-line summary.

    Raises:
        ValueError: If ``comparison_summary`` has no Rank-1 values to rank.
    """
    if comparison_summary.empty or comparison_summary['Rank-1'].isna().all():
        raise ValueError("comparison_summary has no Rank-1 values to summarise")

    # Positional index so idxmax-based look-ups stay unambiguous even if the
    # caller's frame carries duplicate index labels.
    ranked = comparison_summary.reset_index(drop=True)

    # Best overall configuration
    best_row = ranked.loc[ranked['Rank-1'].idxmax()]

    # Best configuration per gallery type (galleries in sorted order). Uses
    # groupby.idxmax instead of groupby.apply, which is deprecated when it
    # operates on the grouping column.
    best_per_gallery = (
        ranked.loc[ranked.groupby('Gallery')['Rank-1'].idxmax()]
        .set_index('Gallery')
    )

    summary = f"""
================================================================================
EXECUTIVE SUMMARY - Face Recognition Evaluation
Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
================================================================================

KEY FINDINGS:

1. OVERALL BEST PERFORMANCE
   Model: {best_row['Model']}
   Configuration: {best_row['Gallery']} + {best_row['Aggregation']}
   Rank-1 Accuracy: {best_row['Rank-1']:.2%}
   ROC-AUC: {best_row['ROC-AUC']:.4f}
   d-prime: {best_row['d-prime']:.3f}

2. BEST CONFIGURATION PER GALLERY TYPE
"""

    for gallery, row in best_per_gallery.iterrows():
        summary += f"""
   {gallery.upper()}:
   - Model: {row['Model']} ({row['Aggregation']})
   - Rank-1: {row['Rank-1']:.2%}
   - ROC-AUC: {row['ROC-AUC']:.4f}
"""

    # Model rankings by each model's best Rank-1 over all configurations.
    model_rankings = ranked.groupby('Model')['Rank-1'].max().sort_values(ascending=False)

    summary += """
3. MODEL RANKINGS (by best Rank-1 accuracy)
"""
    for position, (model, score) in enumerate(model_rankings.items(), 1):
        summary += f"   {position}. {model}: {score:.2%}\n"

    # Best aggregation per gallery; idxmax yields (Gallery, Aggregation) tuples.
    agg_wins = ranked.groupby(['Gallery', 'Aggregation'])['Rank-1'].max()
    best_agg_per_gallery = agg_wins.groupby('Gallery').idxmax()

    summary += """
4. BEST AGGREGATION METHOD PER GALLERY
"""
    for gallery, (_, agg) in best_agg_per_gallery.items():
        summary += f"   {gallery}: {agg.upper()}\n"

    # NOTE(review): the last recommendation bullet (impostor rejection at 0.35)
    # and the whole LIMITATIONS section are hard-coded text, not computed from
    # comparison_summary — re-check them against current results before publishing.
    summary += f"""
5. KEY RECOMMENDATIONS
   - Use {best_row['Model']} with {best_row['Gallery']} gallery for best accuracy
   - {best_row['Aggregation'].upper()} aggregation works best for this configuration
   - Operating threshold: {best_row['Best_Threshold']:.3f} for optimal performance
   - All models achieve 100% impostor rejection at threshold ≥ 0.35

6. LIMITATIONS
   - Performance degrades significantly on high pitch and high yaw conditions
   - Low quality images reduce accuracy by ~15-30%
   - Baseline/frontal images show best performance (>90% Rank-1)

================================================================================
"""

    return summary

In [None]:
def plot_all_metrics(results: Dict, title: str, save_path: Path):
    """Draw a 3x4 dashboard of identification and verification metrics and save it.

    Args:
        results: Evaluation results. Reads ``threshold_results`` (a DataFrame with
            one row per threshold and the columns threshold,
            rank1/rank5/rank10_accuracy, mrr, far, frr, tar, precision, recall,
            f1_score, avg_correct_score, avg_incorrect_score) plus the keys fpr,
            tpr, roc_auc, genuine_scores, impostor_scores, dprime, genuine_ci,
            impostor_ci, aggregation and average_precision.
        title: Figure super-title.
        save_path: Output image path (saved at 300 dpi).

    TAR@FAR values use the highest-TAR threshold whose FAR does not exceed the
    target, or the lowest-FAR threshold if the sweep never reaches it. A
    reported TAR is therefore never taken at a FAR above its label.
    """
    df = results['threshold_results']

    # --- Derived numbers (computed before any figure exists) ---
    # Threshold with the best Rank-1 drives the CMC panel and the summary text.
    best_idx = df['rank1_accuracy'].idxmax()
    best_row = df.loc[best_idx]
    ranks = [1, 5, 10]
    cmc_scores = [df.loc[best_idx, f'rank{k}_accuracy'] for k in ranks]

    target_fars = [0.1, 0.01, 0.001]
    tars_at_far = []
    for target_far in target_fars:
        # A FAR target is a budget: best TAR among thresholds with FAR <= target.
        within_budget = df[df['far'] <= target_far]
        idx = within_budget['tar'].idxmax() if not within_budget.empty else df['far'].idxmin()
        tars_at_far.append(df.loc[idx, 'tar'])

    genuine_mean = np.mean(results['genuine_scores'])
    impostor_mean = np.mean(results['impostor_scores'])
    genuine_ci = results['genuine_ci']
    impostor_ci = results['impostor_ci']
    # Error-bar extents below/above each mean, from the (low, high) CI bounds.
    errors_lower = [genuine_mean - genuine_ci[0], impostor_mean - impostor_ci[0]]
    errors_upper = [genuine_ci[1] - genuine_mean, impostor_ci[1] - impostor_mean]

    summary_text = f"""
    SUMMARY STATISTICS
    ==================
    Aggregation: {results['aggregation'].upper()}
    
    Best Rank-1: {best_row['rank1_accuracy']:.4f}
    @ Threshold: {best_row['threshold']:.3f}
    
    Rank-5: {best_row['rank5_accuracy']:.4f}
    Rank-10: {best_row['rank10_accuracy']:.4f}
    MRR: {best_row['mrr']:.4f}
    
    ROC-AUC: {results['roc_auc']:.4f}
    Avg Precision: {results['average_precision']:.4f}
    d-prime: {results['dprime']:.3f}
    
    TAR@FAR=0.01: {tars_at_far[1]:.4f}
    
    Best F1: {df['f1_score'].max():.4f}
    @ Threshold: {df.loc[df['f1_score'].idxmax(), 'threshold']:.3f}
    """

    fig = plt.figure(figsize=(20, 12))
    try:
        gs = fig.add_gridspec(3, 4, hspace=0.3, wspace=0.3)

        # Row 0: Rank-k accuracy, MRR, FAR/FRR/TAR, ROC
        ax1 = fig.add_subplot(gs[0, 0])
        ax1.plot(df['threshold'], df['rank1_accuracy'], 'b-', linewidth=2, label='Rank-1')
        ax1.plot(df['threshold'], df['rank5_accuracy'], 'g-', linewidth=2, label='Rank-5')
        ax1.plot(df['threshold'], df['rank10_accuracy'], 'r-', linewidth=2, label='Rank-10')
        ax1.set_xlabel('Threshold')
        ax1.set_ylabel('Accuracy')
        ax1.set_title('Rank-k Accuracy')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        ax2 = fig.add_subplot(gs[0, 1])
        ax2.plot(df['threshold'], df['mrr'], 'purple', linewidth=2)
        ax2.set_xlabel('Threshold')
        ax2.set_ylabel('MRR')
        ax2.set_title('Mean Reciprocal Rank')
        ax2.grid(True, alpha=0.3)

        ax3 = fig.add_subplot(gs[0, 2])
        ax3.plot(df['threshold'], df['far'], 'r-', linewidth=2, label='FAR')
        ax3.plot(df['threshold'], df['frr'], 'g-', linewidth=2, label='FRR')
        ax3.plot(df['threshold'], df['tar'], 'b-', linewidth=2, label='TAR')
        ax3.set_xlabel('Threshold')
        ax3.set_ylabel('Rate')
        ax3.set_title('FAR/FRR/TAR')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        ax4 = fig.add_subplot(gs[0, 3])
        ax4.plot(results['fpr'], results['tpr'], 'b-', linewidth=2)
        ax4.plot([0, 1], [0, 1], 'k--', alpha=0.3)
        ax4.set_xlabel('False Positive Rate')
        ax4.set_ylabel('True Positive Rate')
        ax4.set_title(f'ROC Curve (AUC={results["roc_auc"]:.4f})')
        ax4.grid(True, alpha=0.3)

        # Row 1: precision/recall/F1, score histograms, DET, CMC
        ax5 = fig.add_subplot(gs[1, 0])
        ax5.plot(df['threshold'], df['precision'], 'b-', linewidth=2, label='Precision')
        ax5.plot(df['threshold'], df['recall'], 'orange', linewidth=2, label='Recall')
        ax5.plot(df['threshold'], df['f1_score'], 'purple', linewidth=2, label='F1-Score')
        ax5.set_xlabel('Threshold')
        ax5.set_ylabel('Score')
        ax5.set_title('Precision/Recall/F1')
        ax5.legend()
        ax5.grid(True, alpha=0.3)

        ax6 = fig.add_subplot(gs[1, 1])
        ax6.hist(results['genuine_scores'], bins=50, alpha=0.5, label='Genuine', color='green')
        ax6.hist(results['impostor_scores'], bins=50, alpha=0.5, label='Impostor', color='red')
        ax6.axvline(genuine_mean, color='green', linestyle='--', linewidth=2)
        ax6.axvline(impostor_mean, color='red', linestyle='--', linewidth=2)
        ax6.set_xlabel('Similarity Score')
        ax6.set_ylabel('Frequency')
        ax6.set_title(f"Score Distributions (d'={results['dprime']:.3f})")
        ax6.legend()
        ax6.grid(True, alpha=0.3)

        ax7 = fig.add_subplot(gs[1, 2])
        ax7.plot(df['far'], df['frr'], 'b-', linewidth=2)
        ax7.set_xlabel('False Accept Rate')
        ax7.set_ylabel('False Reject Rate')
        ax7.set_title('DET Curve')
        ax7.set_xscale('log')
        ax7.set_yscale('log')
        ax7.grid(True, alpha=0.3, which='both')

        ax8 = fig.add_subplot(gs[1, 3])
        ax8.plot(ranks, cmc_scores, 'bo-', linewidth=2, markersize=8)
        ax8.set_xlabel('Rank')
        ax8.set_ylabel('Identification Rate')
        ax8.set_title('CMC Curve')
        ax8.set_xticks(ranks)
        ax8.grid(True, alpha=0.3)

        # Row 2: score analysis, CI bars, TAR@FAR, summary text
        ax9 = fig.add_subplot(gs[2, 0])
        ax9.plot(df['threshold'], df['avg_correct_score'], 'g-', linewidth=2, label='Correct Matches')
        ax9.plot(df['threshold'], df['avg_incorrect_score'], 'r-', linewidth=2, label='Incorrect Matches')
        ax9.set_xlabel('Threshold')
        ax9.set_ylabel('Average Score')
        ax9.set_title('Score Analysis')
        ax9.legend()
        ax9.grid(True, alpha=0.3)

        ax10 = fig.add_subplot(gs[2, 1])
        ax10.bar(['Genuine', 'Impostor'], [genuine_mean, impostor_mean],
                 yerr=[errors_lower, errors_upper],
                 capsize=10, alpha=0.7, color=['green', 'red'])
        ax10.set_ylabel('Similarity Score')
        ax10.set_title('Mean Scores with 95% CI')
        ax10.grid(True, alpha=0.3, axis='y')

        ax11 = fig.add_subplot(gs[2, 2])
        ax11.bar([f'FAR={f}' for f in target_fars], tars_at_far, alpha=0.7)
        ax11.set_ylabel('TAR')
        ax11.set_title('TAR @ FAR')
        ax11.set_ylim([0, 1])
        ax11.grid(True, alpha=0.3, axis='y')

        ax12 = fig.add_subplot(gs[2, 3])
        ax12.axis('off')
        ax12.text(0.1, 0.5, summary_text, fontsize=10, family='monospace',
                  verticalalignment='center')

        fig.suptitle(title, fontsize=16, fontweight='bold', y=0.995)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when a missing key/column or an unwritable
        # path raises, so failed calls don't leak open figures into the notebook.
        plt.close(fig)
    print(f"Plot saved: {save_path}")

In [None]:
def plot_core_metrics_identification(results: Dict, title: str, save_path: Path):
    """Plot identification-only metrics (Rank-k, MRR, CMC, summary) and save the figure.

    Args:
        results: Reads ``threshold_results`` (DataFrame with the columns
            threshold, rank1/rank5/rank10_accuracy, mrr and n_probes) and
            ``aggregation``.
        title: Figure super-title.
        save_path: Output image path (saved at 300 dpi).
    """
    df = results['threshold_results']

    # Threshold with the best Rank-1 drives the CMC panel and the summary.
    best_idx = df['rank1_accuracy'].idxmax()
    best_row = df.loc[best_idx]
    ranks = [1, 5, 10]
    cmc_scores = [df.loc[best_idx, f'rank{k}_accuracy'] for k in ranks]
    # Read the probe count directly from its column: the whole-row Series is
    # upcast to float when all columns are numeric, which would print "120.0".
    n_probes = df.loc[best_idx, 'n_probes']

    summary_text = f"""
IDENTIFICATION SUMMARY
{'='*35}
Aggregation: {results['aggregation'].upper()}

Best Rank-1:     {best_row['rank1_accuracy']:.4f}
@ Threshold:     {best_row['threshold']:.3f}

Rank-5:          {best_row['rank5_accuracy']:.4f}
Rank-10:         {best_row['rank10_accuracy']:.4f}
MRR:             {best_row['mrr']:.4f}

Total Probes:    {n_probes}
    """

    fig = plt.figure(figsize=(12, 8))
    try:
        gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)

        # 1. Rank-k Accuracy vs Threshold
        ax1 = fig.add_subplot(gs[0, 0])
        ax1.plot(df['threshold'], df['rank1_accuracy'], 'b-', linewidth=2.5, label='Rank-1')
        ax1.plot(df['threshold'], df['rank5_accuracy'], 'g-', linewidth=2.5, label='Rank-5')
        ax1.plot(df['threshold'], df['rank10_accuracy'], 'r-', linewidth=2.5, label='Rank-10')
        ax1.set_xlabel('Threshold', fontsize=11, fontweight='bold')
        ax1.set_ylabel('Accuracy', fontsize=11, fontweight='bold')
        ax1.set_title('Rank-k Accuracy vs Threshold', fontsize=12, fontweight='bold')
        ax1.legend(fontsize=10)
        ax1.grid(True, alpha=0.3)

        # 2. MRR vs Threshold
        ax2 = fig.add_subplot(gs[0, 1])
        ax2.plot(df['threshold'], df['mrr'], 'purple', linewidth=2.5)
        ax2.set_xlabel('Threshold', fontsize=11, fontweight='bold')
        ax2.set_ylabel('MRR', fontsize=11, fontweight='bold')
        ax2.set_title('Mean Reciprocal Rank', fontsize=12, fontweight='bold')
        ax2.grid(True, alpha=0.3)

        # 3. CMC Curve (at best threshold)
        ax3 = fig.add_subplot(gs[1, 0])
        ax3.plot(ranks, cmc_scores, 'bo-', linewidth=2.5, markersize=10)
        ax3.set_xlabel('Rank', fontsize=11, fontweight='bold')
        ax3.set_ylabel('Identification Rate', fontsize=11, fontweight='bold')
        ax3.set_title('CMC Curve (at best threshold)', fontsize=12, fontweight='bold')
        ax3.set_xticks(ranks)
        ax3.set_ylim([0, 1.05])
        ax3.grid(True, alpha=0.3)
        ax3.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}'))

        # 4. Summary Statistics
        ax4 = fig.add_subplot(gs[1, 1])
        ax4.axis('off')
        ax4.text(0.1, 0.5, summary_text, fontsize=11, family='monospace',
                 verticalalignment='center')

        fig.suptitle(title, fontsize=14, fontweight='bold', y=0.995)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure so a failure doesn't leak it into later output.
        plt.close(fig)
    print(f"Identification plot saved: {save_path}")

In [None]:
def plot_core_metrics_verification(results: Dict, title: str, save_path: Path):
    """Plot verification-only metrics (ROC, FAR/FRR/TAR, TAR@FAR, score distributions).

    Args:
        results: Reads ``threshold_results`` (DataFrame with threshold, far, frr
            and tar columns) plus the keys fpr, tpr, roc_auc, genuine_scores,
            impostor_scores, genuine_ci, impostor_ci, aggregation, eer,
            eer_threshold, dprime, separation, n_genuine_pairs and
            n_impostor_pairs.
        title: Figure super-title.
        save_path: Output image path (saved at 300 dpi).

    TAR@FAR values use the highest-TAR threshold whose FAR does not exceed the
    target, or the lowest-FAR threshold if the sweep never reaches it.
    """
    df = results['threshold_results']

    # EER marker: the swept threshold where FAR and FRR are closest.
    eer_idx = (df['far'] - df['frr']).abs().idxmin()
    eer_threshold = df.loc[eer_idx, 'threshold']
    eer_value = (df.loc[eer_idx, 'far'] + df.loc[eer_idx, 'frr']) / 2

    target_fars = [0.1, 0.01, 0.001]
    tars_at_far = []
    for target_far in target_fars:
        # A FAR target is a budget: best TAR among thresholds with FAR <= target.
        within_budget = df[df['far'] <= target_far]
        idx = within_budget['tar'].idxmax() if not within_budget.empty else df['far'].idxmin()
        tars_at_far.append(df.loc[idx, 'tar'])

    genuine_mean = np.mean(results['genuine_scores'])
    impostor_mean = np.mean(results['impostor_scores'])
    genuine_ci = results['genuine_ci']
    impostor_ci = results['impostor_ci']
    means = [genuine_mean, impostor_mean]
    # Error-bar extents below/above each mean, from the (low, high) CI bounds.
    errors_lower = [genuine_mean - genuine_ci[0], impostor_mean - impostor_ci[0]]
    errors_upper = [genuine_ci[1] - genuine_mean, impostor_ci[1] - impostor_mean]

    summary_text = f"""
VERIFICATION SUMMARY
{'='*35}
Aggregation: {results['aggregation'].upper()}

ROC-AUC:         {results['roc_auc']:.4f}
EER:             {results['eer']:.4f}
EER Threshold:   {results['eer_threshold']:.3f}

TAR @ FAR=10%:   {tars_at_far[0]:.4f}
TAR @ FAR=1%:    {tars_at_far[1]:.4f}
TAR @ FAR=0.1%:  {tars_at_far[2]:.4f}

d-prime:         {results['dprime']:.4f}
Separation:      {results['separation']:.4f}

Genuine μ:       {genuine_mean:.4f}
Impostor μ:      {impostor_mean:.4f}
Δμ:              {abs(genuine_mean - impostor_mean):.4f}

Pairs:
  Genuine:       {results['n_genuine_pairs']}
  Impostor:      {results['n_impostor_pairs']}
    """

    fig = plt.figure(figsize=(16, 10))
    try:
        gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

        # 1. ROC Curve
        ax1 = fig.add_subplot(gs[0, 0])
        ax1.plot(results['fpr'], results['tpr'], 'b-', linewidth=2.5)
        ax1.plot([0, 1], [0, 1], 'k--', alpha=0.3, linewidth=1.5)
        ax1.set_xlabel('False Positive Rate', fontsize=11, fontweight='bold')
        ax1.set_ylabel('True Positive Rate', fontsize=11, fontweight='bold')
        ax1.set_title(f'ROC Curve (AUC={results["roc_auc"]:.4f})', fontsize=12, fontweight='bold')
        ax1.grid(True, alpha=0.3)
        ax1.set_xlim([0, 1])
        ax1.set_ylim([0, 1])

        # 2. FAR/FRR/TAR with the EER point marked
        ax2 = fig.add_subplot(gs[0, 1])
        ax2.plot(df['threshold'], df['far'], 'r-', linewidth=2.5, label='FAR')
        ax2.plot(df['threshold'], df['frr'], 'orange', linewidth=2.5, label='FRR')
        ax2.plot(df['threshold'], df['tar'], 'b-', linewidth=2.5, label='TAR')
        ax2.plot(eer_threshold, eer_value, 'ko', markersize=8, label=f'EER={eer_value:.3f}')
        ax2.set_xlabel('Threshold', fontsize=11, fontweight='bold')
        ax2.set_ylabel('Rate', fontsize=11, fontweight='bold')
        ax2.set_title('FAR/FRR/TAR Analysis', fontsize=12, fontweight='bold')
        ax2.legend(fontsize=10)
        ax2.grid(True, alpha=0.3)

        # 3. TAR @ FAR
        ax3 = fig.add_subplot(gs[0, 2])
        bars = ax3.bar([f'FAR={f}' for f in target_fars], tars_at_far, alpha=0.7, color='steelblue')
        ax3.set_ylabel('TAR', fontsize=11, fontweight='bold')
        ax3.set_title('TAR @ FAR', fontsize=12, fontweight='bold')
        ax3.set_ylim([0, 1.1])
        ax3.grid(True, alpha=0.3, axis='y')
        for bar, tar in zip(bars, tars_at_far):
            ax3.text(bar.get_x() + bar.get_width()/2., bar.get_height(),
                     f'{tar:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold')

        # 4. Score Distribution
        ax4 = fig.add_subplot(gs[1, 0])
        ax4.hist(results['genuine_scores'], bins=50, alpha=0.6, label='Genuine',
                 color='green', edgecolor='black')
        ax4.hist(results['impostor_scores'], bins=50, alpha=0.6, label='Impostor',
                 color='red', edgecolor='black')
        ax4.axvline(genuine_mean, color='darkgreen',
                    linestyle='--', linewidth=2, label='Genuine Mean')
        ax4.axvline(impostor_mean, color='darkred',
                    linestyle='--', linewidth=2, label='Impostor Mean')
        ax4.set_xlabel('Similarity Score', fontsize=11, fontweight='bold')
        ax4.set_ylabel('Frequency', fontsize=11, fontweight='bold')
        ax4.set_title('Score Distributions', fontsize=12, fontweight='bold')
        ax4.legend(fontsize=9)
        ax4.grid(True, alpha=0.3, axis='y')

        # 5. Score Confidence Intervals
        ax5 = fig.add_subplot(gs[1, 1])
        bars = ax5.bar(['Genuine', 'Impostor'], means, yerr=[errors_lower, errors_upper],
                       capsize=10, alpha=0.7, color=['green', 'red'],
                       edgecolor='black', linewidth=1.5)
        ax5.set_ylabel('Similarity Score', fontsize=11, fontweight='bold')
        ax5.set_title('Mean Scores with 95% CI', fontsize=12, fontweight='bold')
        ax5.grid(True, alpha=0.3, axis='y')
        for bar, mean in zip(bars, means):
            ax5.text(bar.get_x() + bar.get_width()/2., bar.get_height(),
                     f'{mean:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold')

        # 6. Summary Statistics
        ax6 = fig.add_subplot(gs[1, 2])
        ax6.axis('off')
        ax6.text(0.1, 0.5, summary_text, fontsize=10, family='monospace',
                 verticalalignment='center')

        fig.suptitle(title, fontsize=14, fontweight='bold', y=0.995)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure so a failure doesn't leak it into later output.
        plt.close(fig)
    print(f"Verification plot saved: {save_path}")

In [None]:
def plot_paper_figures_identification(results: Dict, model_name: str, save_dir: Path):
    """Save a publication-quality CMC curve (Rank-1/5/10 at the best-Rank-1 threshold).

    Args:
        results: Reads ``threshold_results`` (DataFrame with the
            rank1/rank5/rank10_accuracy columns).
        model_name: Used in the title and the output file name.
        save_dir: Output directory, created if missing. Writes
            ``{model_name}_cmc_curve.png`` at 300 dpi.
    """
    df = results['threshold_results']
    save_dir.mkdir(parents=True, exist_ok=True)

    # 1. CMC Curve at the threshold with the best Rank-1 accuracy
    best_threshold_idx = df['rank1_accuracy'].idxmax()
    available_ranks = [1, 5, 10]
    cmc_scores = [df.loc[best_threshold_idx, f'rank{k}_accuracy'] for k in available_ranks]
    cmc_path = save_dir / f'{model_name}_cmc_curve.png'

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.plot(available_ranks, cmc_scores, 'bo-', linewidth=3, markersize=10,
                markeredgecolor='white', markeredgewidth=2)
        ax.set_xlabel('Rank', fontsize=14, fontweight='bold')
        ax.set_ylabel('Identification Rate', fontsize=14, fontweight='bold')
        ax.set_title(f'CMC Curve - {model_name}', fontsize=16, fontweight='bold')
        ax.set_xticks(available_ranks)
        ax.set_ylim([0, 1.05])
        ax.grid(True, alpha=0.3, linestyle='--')
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}'))
        fig.tight_layout()
        fig.savefig(cmc_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure so a failure doesn't leak it into later output.
        plt.close(fig)
    print(f"CMC curve saved: {cmc_path}")


def plot_paper_figures_verification(results: Dict, model_name: str, save_dir: Path):
    """Save publication-quality verification figures for one model.

    Writes three PNGs at 300 dpi into ``save_dir`` (created if missing):
    ``{model_name}_roc_curve.png``, ``{model_name}_tar_at_far.png`` and
    ``{model_name}_score_distribution.png``.

    Args:
        results: Reads ``threshold_results`` (DataFrame with far and tar
            columns) plus the keys fpr, tpr, roc_auc, genuine_scores and
            impostor_scores.
        model_name: Used in titles and file names.
        save_dir: Output directory.

    TAR@FAR bars use the highest-TAR threshold whose FAR does not exceed the
    target, or the lowest-FAR threshold if the sweep never reaches it. On a
    coarse threshold grid the nearest FAR can otherwise sit well above the
    labelled target, especially at 0.01%.
    """
    df = results['threshold_results']
    save_dir.mkdir(parents=True, exist_ok=True)

    # 1. ROC Curve
    roc_path = save_dir / f'{model_name}_roc_curve.png'
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.plot(results['fpr'], results['tpr'], 'b-', linewidth=3,
                label=f'AUC = {results["roc_auc"]:.4f}')
        ax.plot([0, 1], [0, 1], 'k--', alpha=0.4, linewidth=2, label='Random')
        ax.set_xlabel('False Positive Rate', fontsize=14, fontweight='bold')
        ax.set_ylabel('True Positive Rate', fontsize=14, fontweight='bold')
        ax.set_title(f'ROC Curve - {model_name}', fontsize=16, fontweight='bold')
        ax.legend(fontsize=12, loc='lower right')
        ax.grid(True, alpha=0.3, linestyle='--')
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        fig.tight_layout()
        fig.savefig(roc_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
    print(f"ROC curve saved: {roc_path}")

    # 2. TAR @ FAR
    target_fars = [0.1, 0.01, 0.001, 0.0001]
    far_labels = ['10%', '1%', '0.1%', '0.01%']
    tars_at_far = []
    for target_far in target_fars:
        # A FAR target is a budget: best TAR among thresholds with FAR <= target.
        within_budget = df[df['far'] <= target_far]
        idx = within_budget['tar'].idxmax() if not within_budget.empty else df['far'].idxmin()
        tars_at_far.append(df.loc[idx, 'tar'])

    tar_path = save_dir / f'{model_name}_tar_at_far.png'
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        bars = ax.bar(far_labels, tars_at_far, alpha=0.7, color='steelblue',
                      edgecolor='black', linewidth=2)
        for bar, tar in zip(bars, tars_at_far):
            ax.text(bar.get_x() + bar.get_width()/2., bar.get_height(),
                    f'{tar:.3f}', ha='center', va='bottom',
                    fontsize=12, fontweight='bold')
        ax.set_xlabel('False Accept Rate', fontsize=14, fontweight='bold')
        ax.set_ylabel('True Accept Rate', fontsize=14, fontweight='bold')
        ax.set_title(f'TAR @ FAR - {model_name}', fontsize=16, fontweight='bold')
        ax.set_ylim([0, 1.1])
        ax.grid(True, alpha=0.3, axis='y', linestyle='--')
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}'))
        fig.tight_layout()
        fig.savefig(tar_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
    print(f"TAR@FAR plot saved: {tar_path}")

    # 3. Score Distribution
    genuine_mean = np.mean(results['genuine_scores'])
    impostor_mean = np.mean(results['impostor_scores'])
    dist_path = save_dir / f'{model_name}_score_distribution.png'
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.hist(results['genuine_scores'], bins=60, alpha=0.6, label='Genuine (Intra-class)',
                color='green', edgecolor='black', linewidth=1.2)
        ax.hist(results['impostor_scores'], bins=60, alpha=0.6, label='Impostor (Inter-class)',
                color='red', edgecolor='black', linewidth=1.2)
        ax.axvline(genuine_mean, color='darkgreen', linestyle='--', linewidth=2.5,
                   label=f'Genuine μ={genuine_mean:.3f}')
        ax.axvline(impostor_mean, color='darkred', linestyle='--', linewidth=2.5,
                   label=f'Impostor μ={impostor_mean:.3f}')
        ax.set_xlabel('Cosine Similarity Score', fontsize=14, fontweight='bold')
        ax.set_ylabel('Frequency', fontsize=14, fontweight='bold')
        ax.set_title(f'Score Distribution - {model_name}', fontsize=16, fontweight='bold')
        ax.legend(fontsize=11, loc='upper right')
        ax.grid(True, alpha=0.3, axis='y', linestyle='--')
        fig.tight_layout()
        fig.savefig(dist_path, dpi=300, bbox_inches='tight')
    finally:
        # Release each figure even on failure so none leak into later output.
        plt.close(fig)
    print(f"Score distribution saved: {dist_path}")

    print(f"\nAll verification figures saved to: {save_dir}")

In [None]:
def plot_model_comparison_charts(all_model_results: Dict, 
                                comparison_summary: pd.DataFrame,
                                save_dir: Path):
    """Create the cross-model comparison figures and save them as PNGs.

    Up to four figures are written to ``save_dir``:
      1. ``comparison_rank1_bar.png``: best Rank-1 per (Model, Gallery).
      2. ``comparison_roc_curves.png``: every model's ROC curve for the
         reference configuration (fewshot_augmented gallery, mean aggregation).
      3. ``comparison_aggregation_heatmap.png``: best Rank-1 per
         (Model, Aggregation) on the fewshot_augmented gallery.
      4. ``comparison_score_distributions.png``: genuine/impostor score
         histograms for the first four models (reference configuration).

    A figure whose input data is missing is skipped with a printed warning.

    Args:
        all_model_results: Mapping model name -> results dict. Each value is
            read as ``['basic_probe']['fewshot_augmented']['mean']`` with an
            optional ``'verification'`` entry.
        comparison_summary: Long-format table with at least the columns
            ``Model``, ``Gallery``, ``Aggregation`` and ``Rank-1``.
        save_dir: Output directory (created if needed).
    """
    save_dir.mkdir(parents=True, exist_ok=True)

    # Reference configuration for the verification-based figures (2 and 4).
    ref_gallery, ref_agg = 'fewshot_augmented', 'mean'

    def _reference_verification(model_data: Dict) -> Optional[Dict]:
        """Return the reference configuration's verification results (may be None)."""
        return model_data['basic_probe'][ref_gallery][ref_agg].get('verification')

    # 1. Bar chart: best Rank-1 across models & galleries
    pivot = comparison_summary.pivot_table(
        values='Rank-1',
        index='Model',
        columns='Gallery',
        aggfunc='max'
    )
    if pivot.empty:
        print("Warning: No Rank-1 data available for bar chart comparison")
    else:
        fig, ax = plt.subplots(figsize=(14, 6))
        pivot.plot(kind='bar', ax=ax, width=0.8)
        ax.set_ylabel('Rank-1 Accuracy', fontsize=12)
        ax.set_xlabel('Model', fontsize=12)
        ax.set_title('Rank-1 Accuracy Comparison Across Models and Galleries', fontsize=14, fontweight='bold')
        ax.legend(title='Gallery Type', bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.grid(True, alpha=0.3, axis='y')
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        fig.tight_layout()
        fig.savefig(save_dir / 'comparison_rank1_bar.png', dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"Rank-1 bar chart saved: {save_dir / 'comparison_rank1_bar.png'}")

    # 2. ROC curves overlaid (if verification available)
    fig, ax = plt.subplots(figsize=(10, 8))
    colors = plt.cm.tab10(np.linspace(0, 1, len(all_model_results)))
    has_verification = False

    for (model_name, model_data), color in zip(all_model_results.items(), colors):
        try:
            ver_results = _reference_verification(model_data)
            if ver_results is None:
                continue

            has_verification = True
            ax.plot(ver_results['fpr'], ver_results['tpr'],
                    label=f"{model_name} (AUC={ver_results['roc_auc']:.3f})",
                    linewidth=2, color=color)
        except Exception as e:
            # Best-effort: one malformed model must not prevent the comparison.
            print(f"Warning: Could not plot ROC for {model_name}: {e}")

    if has_verification:
        ax.plot([0, 1], [0, 1], 'k--', alpha=0.3, label='Random')
        ax.set_xlabel('False Positive Rate', fontsize=12)
        ax.set_ylabel('True Positive Rate', fontsize=12)
        ax.set_title('ROC Curve Comparison (Fewshot Augmented + Mean)', fontsize=14, fontweight='bold')
        ax.legend(loc='lower right')
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(save_dir / 'comparison_roc_curves.png', dpi=300, bbox_inches='tight')
        print(f"ROC comparison saved: {save_dir / 'comparison_roc_curves.png'}")
    else:
        print("Warning: No verification data available for ROC comparison")
    plt.close(fig)

    # 3. Heatmap: Models vs Aggregation methods (fewshot_augmented only).
    # pivot_table (unlike pivot) tolerates duplicate (Model, Aggregation) rows
    # and keeps the best Rank-1, consistent with the bar chart above.
    fewshot_rows = comparison_summary[comparison_summary['Gallery'] == ref_gallery]
    pivot_agg = (
        fewshot_rows.pivot_table(values='Rank-1', index='Model',
                                 columns='Aggregation', aggfunc='max')
        if not fewshot_rows.empty else pd.DataFrame()
    )

    if pivot_agg.empty:
        print("Warning: No fewshot_augmented data available for aggregation heatmap")
    else:
        fig, ax = plt.subplots(figsize=(12, 6))
        sns.heatmap(pivot_agg, annot=True, fmt='.3f', cmap='RdYlGn',
                    vmin=0.0, vmax=1.0, ax=ax, cbar_kws={'label': 'Rank-1 Accuracy'})
        ax.set_title('Rank-1 Accuracy: Models vs Aggregation Methods (Fewshot Augmented)',
                     fontsize=14, fontweight='bold')
        fig.tight_layout()
        fig.savefig(save_dir / 'comparison_aggregation_heatmap.png', dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"Aggregation heatmap saved: {save_dir / 'comparison_aggregation_heatmap.png'}")

    # 4. Score distributions (if verification available); the grid holds 4 models.
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    axes = axes.flatten()
    plotted_models = list(all_model_results.keys())[:len(axes)]
    if len(all_model_results) > len(axes):
        print(f"Note: score distribution grid shows only the first {len(axes)} models")

    for ax, model_name in zip(axes, plotted_models):
        try:
            ver_results = _reference_verification(all_model_results[model_name])

            if ver_results is None:
                ax.text(0.5, 0.5, f'{model_name}\n(No verification data)',
                        ha='center', va='center', transform=ax.transAxes)
                ax.set_title(f'{model_name}', fontweight='bold', fontsize=12)
                continue

            genuine = ver_results['genuine_scores']
            impostor = ver_results['impostor_scores']
            genuine_mean = np.mean(genuine)
            impostor_mean = np.mean(impostor)

            # Overlaid, density-normalised histograms so differing pair counts compare fairly
            ax.hist(genuine, bins=40, alpha=0.6, label='Genuine',
                    color='green', density=True, edgecolor='black')
            ax.hist(impostor, bins=40, alpha=0.6, label='Impostor',
                    color='red', density=True, edgecolor='black')

            ax.axvline(genuine_mean, color='darkgreen',
                       linestyle='--', linewidth=2, label=f'Genuine μ={genuine_mean:.3f}')
            ax.axvline(impostor_mean, color='darkred',
                       linestyle='--', linewidth=2, label=f'Impostor μ={impostor_mean:.3f}')

            dprime = ver_results.get('dprime', 0)
            ax.text(0.05, 0.95, f"d'={dprime:.3f}\nSep={genuine_mean - impostor_mean:.3f}",
                    transform=ax.transAxes, fontsize=10, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

            ax.set_title(f'{model_name}', fontweight='bold', fontsize=12)
            ax.set_xlabel('Similarity Score')
            ax.set_ylabel('Density')
            ax.legend(loc='upper right', fontsize=9)
            ax.grid(True, alpha=0.3, axis='y')
        except Exception as e:
            print(f"Warning: Could not plot score distribution for {model_name}: {e}")

    # Hide panels that have no model assigned (fewer than four models).
    for unused_ax in axes[len(plotted_models):]:
        unused_ax.axis('off')

    fig.suptitle('Score Distribution Comparison (Normalized)', fontsize=14, fontweight='bold')
    fig.tight_layout()
    fig.savefig(save_dir / 'comparison_score_distributions.png', dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Score distributions saved: {save_dir / 'comparison_score_distributions.png'}")

In [None]:
def plot_segmented_heatmap(segmented_table: pd.DataFrame, 
                          save_path: Path,
                          title: str = "Segmented Performance"):
    """Create a heatmap of segmented results, with columns grouped by category.

    Segment columns are ordered and grouped into the categories Quality,
    Face Size, Pose and Blur. Thick vertical lines separate the groups, and
    each group's name is written above the heatmap. The summary columns
    ``Mean``/``Std``/``Min``/``Max`` are dropped. Segment columns that belong
    to no known category are not drawn, and a note lists them. The figure is
    saved to ``save_path``, and a per-segment mean over all rows is printed.

    Args:
        segmented_table: Rows are plotted as heatmap rows (labelled 'Model');
            columns are segment names with Rank-1 values in [0, 1].
        save_path: Destination image path.
        title: Figure title.
    """
    # Drop summary columns for heatmap
    plot_data = segmented_table.drop(['Mean', 'Std', 'Min', 'Max'], axis=1, errors='ignore')

    # Define segment order with logical grouping
    segment_categories = {
        'Quality': ['high_quality', 'low_quality'],
        'Face Size': ['face_large', 'face_medium', 'face_small'],
        'Pose': ['pose_easy', 'pose_medium', 'pose_hard'],
        'Blur': ['blur_sharp', 'blur_blurry']
    }

    # (category, segments present in the table), skipping empty categories
    category_groups = []
    for category, segments in segment_categories.items():
        available = [seg for seg in segments if seg in plot_data.columns]
        if available:
            category_groups.append((category, available))
    ordered_segments = [seg for _, segs in category_groups for seg in segs]

    uncategorized = [col for col in plot_data.columns if col not in ordered_segments]
    if uncategorized:
        print(f"Note: segments without a category are not shown: {uncategorized}")

    if not ordered_segments:
        print(f"Warning: No known segment columns to plot; skipping heatmap {save_path}")
        return

    plot_data = plot_data[ordered_segments]

    fig, ax = plt.subplots(figsize=(16, 6))
    sns.heatmap(plot_data, annot=True, fmt='.3f', cmap='RdYlGn',
                vmin=0.0, vmax=1.0, ax=ax,
                cbar_kws={'label': 'Rank-1 Accuracy'},
                linewidths=0.5, linecolor='gray')

    # Category separators and labels. The x position is in data (column) units;
    # y is in axes units, so 1.01 places the label just above the heatmap, in
    # the space reserved by the title's pad. The rotated tick labels sit below.
    boundary = 0
    for category, segs in category_groups:
        ax.text(boundary + len(segs) / 2, 1.01, category,
                ha='center', va='bottom', fontsize=11, fontweight='bold',
                transform=ax.get_xaxis_transform())
        boundary += len(segs)
        if boundary < len(ordered_segments):
            ax.axvline(x=boundary, color='black', linewidth=2.5, zorder=10)

    ax.set_title(title, fontsize=14, fontweight='bold', pad=30)
    ax.set_xlabel('', fontsize=12)  # category labels replace the x label
    ax.set_ylabel('Model', fontsize=12)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Segmented heatmap saved: {save_path}")

    # Print category summary
    print("\n" + "="*70)
    print("HEATMAP CATEGORY SUMMARY")
    print("="*70)
    for category, segs in category_groups:
        print(f"\n{category}:")
        for seg in segs:
            mean_acc = plot_data[seg].mean()
            print(f"  {seg:20s}: {mean_acc:.1%} (mean)")
    print("="*70 + "\n")

In [None]:
def plot_sensitivity_analysis(segmented_table: pd.DataFrame,
                              save_path: Path,
                              title: str = "Model Sensitivity Across Segments"):
    """Plot each model's Rank-1 accuracy across segments, one subplot per row.

    The summary columns ``Mean``/``Std``/``Min``/``Max`` are dropped. Segments
    are sorted by their mean over all rows, best to worst, so a downward slope
    reads as degradation. Each subplot shows the best and worst accuracy and
    their difference. The figure is saved to ``save_path``, and a numeric
    summary is printed. If there is no data, a warning is printed and
    nothing is drawn.

    Args:
        segmented_table: Rows are models (one subplot each); columns are segment
            names with Rank-1 values in [0, 1].
        save_path: Destination image path.
        title: Figure super-title.
    """
    # Drop summary columns
    plot_data = segmented_table.drop(['Mean', 'Std', 'Min', 'Max'], axis=1, errors='ignore')

    if plot_data.empty:
        print(f"Warning: No segment data to plot; skipping sensitivity plot {save_path}")
        return

    # Sort segments by mean accuracy across all models (best → worst)
    segment_means = plot_data.mean(axis=0).sort_values(ascending=False)
    available_segments = segment_means.index.tolist()
    plot_data = plot_data[available_segments]

    n_segments = len(available_segments)
    x_positions = np.arange(n_segments)
    # Half-unit padding keeps the end markers fully visible.
    x_limits = (-0.5, n_segments - 0.5)

    colors = ['#2E86AB', '#A23B72', '#F18F01', '#06A77D']
    markers = ['o', 's', '^', 'D']

    num_models = len(plot_data)
    # squeeze=False always yields a 2-D array, so the single-model case needs no special handling
    fig, axes = plt.subplots(num_models, 1, figsize=(14, 5 * num_models), squeeze=False)
    axes = axes[:, 0]

    for idx, (model_name, row) in enumerate(plot_data.iterrows()):
        ax = axes[idx]

        ax.plot(x_positions, row.values,
                marker=markers[idx % len(markers)],
                linewidth=3,
                markersize=10,
                color=colors[idx % len(colors)],
                markeredgecolor='white',
                markeredgewidth=2)

        ax.set_ylabel('Rank-1 Accuracy', fontsize=12, fontweight='bold')
        ax.set_title(f'{model_name}', fontsize=14, fontweight='bold', pad=15)

        ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.8)
        ax.set_axisbelow(True)

        # Subtle background gradient (green → red, left to right) to show quality degradation
        ax.imshow([[0, 1]], cmap='RdYlGn_r', aspect='auto',
                  extent=[x_limits[0], x_limits[1], 0, 1.05],
                  alpha=0.1, zorder=0)

        # Set limits after imshow, which otherwise autoscales to its extent
        ax.set_xlim(*x_limits)
        ax.set_ylim(0, 1.05)
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.1%}'))

        ax.set_xticks(x_positions)
        ax.set_xticklabels(available_segments, rotation=45, ha='right')
        if idx == num_models - 1:  # x-axis label only on the bottom plot
            ax.set_xlabel('Quality Segment (Best → Worst)', fontsize=12, fontweight='bold')

        best_acc = row.max()
        worst_acc = row.min()
        degradation = best_acc - worst_acc
        stats_text = f'Best: {best_acc:.1%} | Worst: {worst_acc:.1%} | Δ: {degradation:.1%}'
        ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
                fontsize=10, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    fig.suptitle(title, fontsize=16, fontweight='bold', y=0.995)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Sensitivity plot saved: {save_path}")

    # Print numerical summary
    print("\n" + "="*70)
    print("SENSITIVITY ANALYSIS SUMMARY")
    print("="*70)
    print("\nSegment order (best → worst by mean accuracy):")
    for seg in available_segments:
        print(f"  {seg:20s}: {segment_means[seg]:.1%}")

    print("\nPerformance degradation per model:")
    for model_name, row in plot_data.iterrows():
        best_acc = row.max()
        worst_acc = row.min()
        degradation = best_acc - worst_acc
        best_seg = row.idxmax()
        worst_seg = row.idxmin()
        print(f"  {model_name:30s}: {best_acc:.1%} ({best_seg}) → {worst_acc:.1%} ({worst_seg}) | Δ = {degradation:.1%}")
    print("="*70 + "\n")

In [None]:
def plot_gallery_strategy_comparison(strategy_df: pd.DataFrame, save_path: Path):
    """Render a 2x2 dashboard for the gallery strategy analysis and save it.

    Top-left: raw Rank-1 for the four gallery configurations per model.
    Top-right: augmentation gain (one-shot and few-shot).
    Bottom-left: few-shot vs one-shot gain (base and augmented).
    Bottom-right: how many models picked each ``Best_Config``.

    Args:
        strategy_df: Table with a ``Model`` column, the plotted metric columns
            and ``Best_Config``.
        save_path: Destination image path.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    per_model = strategy_df.set_index('Model')

    # (grid cell, columns, y label, title, legend title, draw zero reference line)
    bar_panels = [
        ((0, 0), ['Oneshot_Base', 'Oneshot_Aug', 'Fewshot_Base', 'Fewshot_Aug'],
         'Rank-1 Accuracy', 'Gallery Strategy Comparison', 'Configuration', False),
        ((0, 1), ['Aug_Improvement_Oneshot', 'Aug_Improvement_Fewshot'],
         'Rank-1 Improvement', 'Augmentation Benefit', 'Gallery Type', True),
        ((1, 0), ['Fewshot_Improvement_Base', 'Fewshot_Improvement_Aug'],
         'Rank-1 Improvement', 'Fewshot vs Oneshot Benefit', 'Augmentation', True),
    ]

    for cell, columns, y_label, panel_title, legend_title, zero_line in bar_panels:
        panel = axes[cell]
        per_model[columns].plot(kind='bar', ax=panel, width=0.8)
        panel.set_ylabel(y_label)
        panel.set_title(panel_title, fontweight='bold')
        if zero_line:
            # Improvement panels: bars below this line mean the change hurt accuracy
            panel.axhline(y=0, color='k', linestyle='--', alpha=0.3)
        panel.legend(title=legend_title)
        panel.grid(True, alpha=0.3, axis='y')
        plt.setp(panel.xaxis.get_majorticklabels(), rotation=45, ha='right')

    tally_panel = axes[1, 1]
    config_counts = strategy_df.groupby('Best_Config').size()
    config_counts.plot(kind='bar', ax=tally_panel, color='steelblue')
    tally_panel.set_ylabel('Number of Models')
    tally_panel.set_title('Most Common Best Configuration', fontweight='bold')
    tally_panel.grid(True, alpha=0.3, axis='y')
    plt.setp(tally_panel.xaxis.get_majorticklabels(), rotation=45, ha='right')

    plt.suptitle('Gallery Strategy Analysis', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Gallery strategy plot saved: {save_path}")

In [None]:
def export_comprehensive_report(all_model_results: Dict,
                               all_summaries: Dict,
                               save_dir: Path):
    """Export complete results to multiple formats"""
    save_dir.mkdir(parents=True, exist_ok=True)
    
    # 1. Excel workbook with multiple sheets
    excel_path = save_dir / 'comprehensive_report.xlsx'
    with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
        all_summaries['comparison_summary'].to_excel(writer, sheet_name='Overall_Comparison', index=False)
        all_summaries['gallery_strategy'].to_excel(writer, sheet_name='Gallery_Strategy', index=False)
        all_summaries['aggregation_analysis'].to_excel(writer, sheet_name='Aggregation_Analysis', index=False)
        all_summaries['threshold_recommendations'].to_excel(writer, sheet_name='Threshold_Recommendations', index=False)
        
        if 'segmented_oneshot' in all_summaries:
            all_summaries['segmented_oneshot'].to_excel(writer, sheet_name='Segmented_Oneshot')
        if 'segmented_fewshot' in all_summaries:
            all_summaries['segmented_fewshot'].to_excel(writer, sheet_name='Segmented_Fewshot')
        if 'statistical_comparison' in all_summaries:
            all_summaries['statistical_comparison'].to_excel(writer, sheet_name='Statistical_Tests', index=False)
    
    print(f"Excel report saved: {excel_path}")
    
    # 2. JSON export
    json_data = {
        'metadata': {
            'generated': datetime.now().isoformat(),
            'models_evaluated': list(all_model_results.keys())
        },
        'summaries': {
            key: df.to_dict(orient='records') if isinstance(df, pd.DataFrame) else df
            for key, df in all_summaries.items()
            if key != 'executive_summary'
        },
        'executive_summary': all_summaries.get('executive_summary', '')
    }
    
    json_path = save_dir / 'comprehensive_report.json'
    with open(json_path, 'w') as f:
        json.dump(json_data, f, indent=2)
    
    print(f"JSON report saved: {json_path}")
    
    # 3. Text summary
    txt_path = save_dir / 'executive_summary.txt'
    with open(txt_path, 'w') as f:
        f.write(all_summaries.get('executive_summary', ''))
    
    print(f"Text summary saved: {txt_path}")
    
    # 4. LaTeX tables
    latex_path = save_dir / 'latex_tables.tex'
    with open(latex_path, 'w') as f:
        f.write("% Comparison Summary\n")
        f.write(all_summaries['comparison_summary'].to_latex(index=False, float_format="%.4f"))
        f.write("\n\n% Gallery Strategy\n")
        f.write(all_summaries['gallery_strategy'].to_latex(index=False, float_format="%.4f"))
    
    print(f"LaTeX tables saved: {latex_path}")

In [None]:
def plot_impostor_metrics(results: Dict, title: str, save_path: Path):
    """Plot a 2x2 impostor-rejection dashboard and save it to ``save_path``.

    Panels: rejection rate vs threshold, FAR vs threshold, the impostor score
    histogram with its mean and CI bounds, and a text summary taken at the
    threshold with the highest rejection rate.

    Args:
        results: Impostor evaluation output with keys ``threshold_results``
            (DataFrame with columns ``threshold``, ``rejection_rate``, ``far``,
            ``n_impostors``), ``impostor_scores``, ``impostor_ci`` (indexable
            lower/upper bounds) and ``aggregation`` (str).
        title: Figure super-title.
        save_path: Destination image path.
    """
    df = results['threshold_results']
    scores = results['impostor_scores']
    ci_low, ci_high = results['impostor_ci'][0], results['impostor_ci'][1]
    score_mean = np.mean(scores)
    score_std = np.std(scores)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    ax = axes[0, 0]
    ax.plot(df['threshold'], df['rejection_rate'], 'g-', linewidth=2)
    ax.set_xlabel('Threshold')
    ax.set_ylabel('Rejection Rate')
    ax.set_title('Impostor Rejection Rate')
    ax.grid(True, alpha=0.3)

    ax = axes[0, 1]
    ax.plot(df['threshold'], df['far'], 'r-', linewidth=2)
    ax.set_xlabel('Threshold')
    ax.set_ylabel('False Accept Rate')
    ax.set_title('False Accept Rate')
    ax.grid(True, alpha=0.3)

    ax = axes[1, 0]
    ax.hist(scores, bins=50, alpha=0.7, color='red')
    ax.axvline(score_mean, color='darkred',
               linestyle='--', linewidth=2, label='Mean')
    ax.axvline(ci_low, color='orange',
               linestyle=':', linewidth=2, label='95% CI')
    ax.axvline(ci_high, color='orange',
               linestyle=':', linewidth=2)
    ax.set_xlabel('Similarity Score')
    ax.set_ylabel('Frequency')
    ax.set_title('Impostor Score Distribution')
    ax.legend()
    ax.grid(True, alpha=0.3)

    ax = axes[1, 1]
    ax.axis('off')

    best_row = df.loc[df['rejection_rate'].idxmax()]
    # Selecting a whole row gives a single-dtype Series, which can upcast the
    # integer impostor count to float ("500.0"); cast it back for display.
    n_impostors = int(best_row['n_impostors'])

    # Built line by line so the monospace block has no stray source indentation.
    summary_text = "\n".join([
        "IMPOSTOR REJECTION SUMMARY",
        "==========================",
        f"Aggregation: {results['aggregation'].upper()}",
        "",
        f"Best Rejection: {best_row['rejection_rate']:.4f}",
        f"@ Threshold: {best_row['threshold']:.3f}",
        "",
        f"FAR at best: {best_row['far']:.4f}",
        "",
        f"Total Impostors: {n_impostors}",
        "",
        f"Mean Score: {score_mean:.4f}",
        f"Std Score: {score_std:.4f}",
        "",
        f"95% CI: [{ci_low:.4f}, {ci_high:.4f}]",
    ])

    ax.text(0.1, 0.5, summary_text, fontsize=11, family='monospace',
            verticalalignment='center')

    fig.suptitle(title, fontsize=14, fontweight='bold')
    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Plot saved: {save_path}")

In [None]:
def _print_basic_identification_summary(id_results: Dict) -> None:
    """Print Rank-1/5/10 and MRR at the threshold with the best Rank-1."""
    print("\n  IDENTIFICATION Metrics:")
    print(f"  {'─'*50}")
    df_id = id_results['threshold_results']
    best_idx = df_id['rank1_accuracy'].idxmax()
    print(f"  Best Rank-1: {df_id.loc[best_idx, 'rank1_accuracy']:.4f} "
          f"@ threshold {df_id.loc[best_idx, 'threshold']:.2f}")
    print(f"  Rank-5: {df_id.loc[best_idx, 'rank5_accuracy']:.4f}")
    print(f"  Rank-10: {df_id.loc[best_idx, 'rank10_accuracy']:.4f}")
    print(f"  MRR: {df_id.loc[best_idx, 'mrr']:.4f}")


def _print_basic_verification_summary(ver_results: Optional[Dict]) -> None:
    """Print verification metrics and score statistics, or a SKIPPED notice if None."""
    if ver_results is None:
        print("\n  VERIFICATION Metrics: SKIPPED (no negative probes)")
        return

    print("\n  VERIFICATION Metrics:")
    print(f"  {'─'*50}")
    print(f"  ROC-AUC: {ver_results['roc_auc']:.4f}")
    print(f"  EER: {ver_results['eer']:.4f} @ threshold {ver_results['eer_threshold']:.2f}")
    print(f"  TAR@FAR=0.1%: {ver_results['tar_at_far_0.001']:.4f}")
    print(f"  TAR@FAR=1%: {ver_results['tar_at_far_0.01']:.4f}")
    print(f"  TAR@FAR=10%: {ver_results['tar_at_far_0.1']:.4f}")
    print(f"  d-prime: {ver_results['dprime']:.4f}")
    print(f"  Separation: {ver_results['separation']:.4f}")

    print("\n  Score Distribution Statistics:")
    print(f"  {'─'*50}")
    print(f"  μ_genuine:  {ver_results['genuine_mean']:.4f} (σ = {ver_results['genuine_std']:.4f})")
    print(f"  μ_impostor: {ver_results['impostor_mean']:.4f} (σ = {ver_results['impostor_std']:.4f})")
    print(f"  Δμ = {abs(ver_results['genuine_mean'] - ver_results['impostor_mean']):.4f}")
    print(f"  n_genuine_pairs: {ver_results['n_genuine_pairs']}")
    print(f"  n_impostor_pairs: {ver_results['n_impostor_pairs']}")


def run_basic_probe_evaluation(model_name: str, embeddings: Dict, 
                               results_dir: Path, plots_dir: Path,
                               thresholds: Optional[np.ndarray] = None,
                               aggregations: Optional[List[str]] = None,
                               k: int = 3):
    """Run basic probe evaluation with BOTH identification AND verification.

    For each available gallery (one-shot/few-shot x base/augmented) and each
    aggregation method, this function:
      1. runs identification of ``probe_positive_unsegmented`` against the gallery;
      2. runs verification with ``probe_negative`` when present. This step is
         best-effort: an error is printed and the result becomes None;
      3. writes per-threshold CSVs under ``results_dir / model_name``;
      4. writes core plots and paper figures under ``plots_dir / model_name``;
      5. prints a metric summary.

    Args:
        model_name: Name used in output paths and titles.
        embeddings: Dict as produced by ``load_embeddings``; missing sets are None.
        results_dir: Root directory for CSV outputs.
        plots_dir: Root directory for figure outputs.
        thresholds: Similarity thresholds to sweep. Defaults to
            ``np.arange(0.2, 0.91, 0.05)``.
        aggregations: Aggregation methods to evaluate. Defaults to ``['mean']``.
        k: Forwarded as ``k`` to the evaluation functions. Defaults to 3.

    Returns:
        None if the positive probes are missing. Otherwise a dict mapping
        gallery name -> {aggregation -> {'identification': ...,
        'verification': ... or None}}, plus the key ``'per_identity'``. That
        key holds the per-identity results of the first evaluated
        configuration that provided them, or None.
    """
    if thresholds is None:
        thresholds = np.arange(0.2, 0.91, 0.05)
    aggregations = ['mean'] if aggregations is None else list(aggregations)

    print(f"\n{'='*70}")
    print(f"BASIC PROBE EVALUATION: {model_name}")
    print(f"{'='*70}")

    probe_positive = embeddings['probe_positive_unsegmented']
    probe_negative = embeddings['probe_negative']

    if probe_positive is None:
        print("Missing positive probe embeddings!")
        return None

    if probe_negative is None:
        print("Warning: Missing negative probe embeddings! Verification metrics will be skipped.")

    gallery_types = {
        'oneshot_base': 'gallery_oneshot_base',
        'oneshot_augmented': 'gallery_oneshot_augmented', 
        'fewshot_base': 'gallery_fewshot_base',
        'fewshot_augmented': 'gallery_fewshot_augmented'
    }

    all_results = {}
    per_identity_consolidated = None
    model_results_dir = results_dir / model_name
    model_plots_dir = plots_dir / model_name

    for gallery_name, gallery_key in gallery_types.items():
        gallery = embeddings.get(gallery_key)

        if gallery is None:
            print(f"Missing {gallery_name} gallery, skipping...")
            continue

        print(f"\n{'-'*70}")
        print(f"GALLERY: {gallery_name.upper()}")
        print(f"{'-'*70}")

        gallery_results = {}

        for agg in aggregations:
            # ---- Identification ----
            print(f"\n[IDENTIFICATION] Evaluating with {agg.upper()} aggregation...")
            id_results = evaluate_probes_comprehensive(
                gallery, probe_positive, thresholds, aggregation=agg, k=k
            )

            if per_identity_consolidated is None and 'per_identity' in id_results:
                per_identity_consolidated = id_results['per_identity']

            # ---- Verification (best-effort) ----
            ver_results = None
            if probe_negative is not None:
                print(f"\n[VERIFICATION] Evaluating with {agg.upper()} aggregation...")
                try:
                    ver_results = evaluate_verification_comprehensive(
                        gallery, probe_positive, probe_negative, thresholds, aggregation=agg, k=k
                    )
                except Exception as e:
                    print(f"  Error in verification evaluation: {e}")
                    ver_results = None

            combined_results = {
                'identification': id_results,
                'verification': ver_results  # None if probe_negative missing or failed
            }

            # ---- CSV results ----
            model_results_dir.mkdir(parents=True, exist_ok=True)
            id_results['threshold_results'].to_csv(
                model_results_dir / f'basic_probe_{gallery_name}_{agg}_identification.csv', index=False)
            if ver_results is not None:
                ver_results['threshold_results'].to_csv(
                    model_results_dir / f'basic_probe_{gallery_name}_{agg}_verification.csv', index=False)

            # ---- Core plots ----
            model_plots_dir.mkdir(parents=True, exist_ok=True)
            plot_core_metrics_identification(
                id_results,
                f"{model_name} - Identification - {gallery_name.upper()} ({agg.upper()})",
                model_plots_dir / f'basic_probe_{gallery_name}_{agg}_identification_core.png'
            )
            if ver_results is not None:
                plot_core_metrics_verification(
                    ver_results,
                    f"{model_name} - Verification - {gallery_name.upper()} ({agg.upper()})",
                    model_plots_dir / f'basic_probe_{gallery_name}_{agg}_verification_core.png'
                )

            # ---- Paper figures ----
            paper_figs_root = model_plots_dir / 'paper_figures'
            plot_paper_figures_identification(
                id_results,
                model_name=f"{model_name}_{gallery_name}_{agg}",
                save_dir=paper_figs_root / f'basic_{gallery_name}_{agg}_identification'
            )
            if ver_results is not None:
                plot_paper_figures_verification(
                    ver_results,
                    model_name=f"{model_name}_{gallery_name}_{agg}",
                    save_dir=paper_figs_root / f'basic_{gallery_name}_{agg}_verification'
                )

            gallery_results[agg] = combined_results

            # ---- Printed summaries ----
            _print_basic_identification_summary(id_results)
            _print_basic_verification_summary(ver_results)

        all_results[gallery_name] = gallery_results

    all_results['per_identity'] = per_identity_consolidated

    return all_results

In [None]:
def run_impostor_evaluation(model_name: str, embeddings: Dict,
                           results_dir: Path, plots_dir: Path):
    """Evaluate how well a model rejects impostor probes, then persist and report.

    Matches ``probe_negative`` against the ``gallery_oneshot_augmented`` gallery
    with MEAN aggregation (k=3). The thresholds run from 0.20 to 0.90 in steps
    of 0.05. Outputs go to ``results_dir/<model>/impostor_metrics.csv`` and
    ``plots_dir/<model>/impostor_plot.png``. The metrics at the threshold with
    the highest rejection rate are printed.

    Returns:
        The dict produced by ``evaluate_impostors_comprehensive``, or None when
        the gallery or the impostor embeddings are missing.
    """
    banner = '=' * 70
    print(f"\n{banner}")
    print(f"IMPOSTOR EVALUATION: {model_name}")
    print(banner)

    gallery = embeddings['gallery_oneshot_augmented']
    impostor = embeddings['probe_negative']

    if impostor is None or gallery is None:
        print("Missing embeddings!")
        return None

    sweep = np.arange(0.2, 0.91, 0.05)

    print("\nEvaluating with MEAN aggregation...")
    results = evaluate_impostors_comprehensive(
        gallery, impostor, sweep, aggregation='mean', k=3
    )
    table = results['threshold_results']

    model_csv_dir = results_dir / model_name
    model_csv_dir.mkdir(parents=True, exist_ok=True)
    table.to_csv(model_csv_dir / 'impostor_metrics.csv', index=False)

    model_plot_dir = plots_dir / model_name
    model_plot_dir.mkdir(parents=True, exist_ok=True)
    plot_impostor_metrics(results, f"{model_name} - Impostor Rejection",
                          model_plot_dir / 'impostor_plot.png')

    top = table['rejection_rate'].idxmax()
    ci = results['impostor_ci']
    # Scalar .loc lookups keep each column's own dtype (integer impostor count stays integral)
    print(f"  Best Rejection Rate: {table.loc[top, 'rejection_rate']:.4f} "
          f"@ threshold {table.loc[top, 'threshold']:.2f}")
    print(f"  FAR at best: {table.loc[top, 'far']:.4f}")
    print(f"  Total impostors: {table.loc[top, 'n_impostors']}")
    print(f"  Mean impostor score: {results['mean_impostor_score']:.4f}")
    print(f"  Std impostor score: {results['std_impostor_score']:.4f}")
    print(f"  95% CI: [{ci[0]:.4f}, {ci[1]:.4f}]")

    return results


In [None]:
def run_segmented_evaluation(model_name: str, embeddings: Dict,
                             results_dir: Path, plots_dir: Path,
                             gallery_type: str):
    """Evaluate segmented positive probes with identification AND verification.

    Uses the ``gallery_<gallery_type>_augmented`` gallery and the
    ``probe_positive_segmented`` probes. ``probe_negative`` is optional: without
    it, verification is skipped. ``evaluate_segmented_comprehensive`` is called
    with mean aggregation, k=3 and thresholds 0.20-0.90 (step 0.05). For every
    segment it returns, this writes identification CSVs, plus verification CSVs
    when available, under ``<results_dir>/<model_name>/``. It writes core plots
    and paper figures under ``<plots_dir>/<model_name>/`` and prints a summary.

    Args:
        model_name: Model identifier used in paths, titles and figure names.
        embeddings: Per-model embedding dict; a missing key is treated like ``None``.
        results_dir: Root directory for CSV outputs.
        plots_dir: Root directory for plot outputs.
        gallery_type: Gallery variant used in the key/path, e.g. ``'oneshot'``
            or ``'fewshot'``.

    Returns:
        ``{segment_name: {'identification': ..., 'verification': ...}}`` as
        returned by ``evaluate_segmented_comprehensive``, or ``None`` when the
        gallery or the positive probes are unavailable.
    """
    print(f"\n{'='*70}")
    print(f"SEGMENTED EVALUATION: {model_name} ({gallery_type})")
    print(f"{'='*70}")

    # .get(): embeddings dicts built only from files present on disk may omit
    # a key entirely; treat that like an explicit None instead of KeyError-ing.
    gallery = embeddings.get(f'gallery_{gallery_type}_augmented')
    probe_positive = embeddings.get('probe_positive_segmented')
    probe_negative = embeddings.get('probe_negative')

    if gallery is None or probe_positive is None:
        print("Missing embeddings!")
        return None

    if probe_negative is None:
        print("Warning: No negative probes. Verification will be skipped for segmented evaluation.")

    thresholds = np.arange(0.2, 0.91, 0.05)

    print("\nEvaluating with MEAN aggregation...")
    segment_results = evaluate_segmented_comprehensive(
        gallery, probe_positive, probe_negative, thresholds,
        aggregation='mean', k=3, include_verification=True
    )

    # Output directories are the same for every segment; create them once.
    model_results_dir = results_dir / model_name
    model_plots_dir = plots_dir / model_name
    paper_figs_root = model_plots_dir / 'paper_figures'
    model_results_dir.mkdir(parents=True, exist_ok=True)
    model_plots_dir.mkdir(parents=True, exist_ok=True)

    for segment_name, combined_results in segment_results.items():
        id_results = combined_results['identification']
        # None (or absent) when verification was skipped for lack of negatives.
        ver_results = combined_results.get('verification')
        prefix = f'segmented_{gallery_type}_{segment_name}'

        # Save CSVs
        id_results['threshold_results'].to_csv(
            model_results_dir / f'{prefix}_identification.csv', index=False)
        if ver_results is not None:
            ver_results['threshold_results'].to_csv(
                model_results_dir / f'{prefix}_verification.csv', index=False)

        # Core plots
        plot_core_metrics_identification(
            id_results,
            f"{model_name} - {segment_name} - ID ({gallery_type})",
            model_plots_dir / f'{prefix}_identification_core.png'
        )
        if ver_results is not None:
            plot_core_metrics_verification(
                ver_results,
                f"{model_name} - {segment_name} - Ver ({gallery_type})",
                model_plots_dir / f'{prefix}_verification_core.png'
            )

        # Paper figures
        plot_paper_figures_identification(
            id_results,
            model_name=f"{model_name}_{segment_name}",
            save_dir=paper_figs_root / f'{prefix}_identification'
        )
        if ver_results is not None:
            plot_paper_figures_verification(
                ver_results,
                model_name=f"{model_name}_{segment_name}",
                save_dir=paper_figs_root / f'{prefix}_verification'
            )

        # Print summaries (identification row chosen by best Rank-1 accuracy)
        print(f"\n  {segment_name}:")
        print(f"  {'─'*50}")
        print("  IDENTIFICATION:")
        df_id = id_results['threshold_results']
        best_idx = df_id['rank1_accuracy'].idxmax()
        print(f"    Rank-1: {df_id.loc[best_idx, 'rank1_accuracy']:.4f} @ threshold {df_id.loc[best_idx, 'threshold']:.2f}")
        print(f"    Rank-5: {df_id.loc[best_idx, 'rank5_accuracy']:.4f}")
        print(f"    Rank-10: {df_id.loc[best_idx, 'rank10_accuracy']:.4f}")
        print(f"    MRR: {df_id.loc[best_idx, 'mrr']:.4f}")

        if ver_results is not None:
            print("\n  VERIFICATION:")
            print(f"    ROC-AUC: {ver_results['roc_auc']:.4f}")
            print(f"    EER: {ver_results['eer']:.4f} @ threshold {ver_results['eer_threshold']:.2f}")
            print(f"    TAR@FAR=0.1%: {ver_results['tar_at_far_0.001']:.4f}")
            print(f"    d-prime: {ver_results['dprime']:.4f}")
            print(f"    Separation: {ver_results['separation']:.4f}")
        else:
            print("\n  VERIFICATION: SKIPPED")

    return segment_results

In [None]:
def plot_rank1_per_identity_all_models(all_model_results: Dict[str, Dict],
                                       save_path: Path,
                                       sort_by_average: bool = True):
    """Plot Rank-1 accuracy per identity for all models in a single line chart.

    Args:
        all_model_results: ``{model_name: model_results}``. Per-identity data is
            read from ``model_results["basic_probe"]["per_identity"]``, a mapping
            ``identity -> {..., "rank1": value}``. A model whose ``basic_probe``
            or ``per_identity`` is missing or ``None`` contributes no data.
        save_path: PNG destination; parent directories are created.
        sort_by_average: If True, order identities by their mean Rank-1 across
            models (descending). Otherwise they are ordered by their string form.

    Note:
        An identity absent from a model's per-identity data is plotted as 0
        for that model.
    """
    model_names = list(all_model_results.keys())

    # --- Per-model {identity: stats}, tolerating stages that returned None ---
    per_identity_by_model = {
        name: (results.get("basic_probe") or {}).get("per_identity") or {}
        for name, results in all_model_results.items()
    }

    # --- All identities, in a reproducible base order ---
    # Set iteration order of str keys changes between interpreter runs (hash
    # randomisation). Sorting first fixes the x-axis order, and the stable sort
    # below keeps identities with equal averages in this order.
    identities = sorted(
        set().union(*(per_id.keys() for per_id in per_identity_by_model.values())),
        key=str,
    )

    # --- Build table: identity -> {model: rank1} ---
    rank1_table = {
        idn: {
            name: per_identity_by_model[name].get(idn, {}).get("rank1", 0)
            for name in model_names
        }
        for idn in identities
    }

    # --- Sort identities by average Rank-1 (optional but recommended) ---
    if sort_by_average:
        identities.sort(
            key=lambda idn: np.mean(list(rank1_table[idn].values())),
            reverse=True
        )

    # --- Plot ---
    x = np.arange(len(identities))
    fig, ax = plt.subplots(figsize=(22, 8))

    for model_name in model_names:
        y = [rank1_table[idn][model_name] for idn in identities]
        ax.plot(x, y, marker='o', linewidth=2, label=model_name)

    ax.set_xticks(x)
    ax.set_xticklabels(identities, rotation=90)
    ax.set_ylabel("Rank-1 Accuracy")
    ax.set_title("Rank-1 Accuracy per Identity Across Models")
    ax.grid(True, linestyle="--", alpha=0.4)
    ax.legend()

    fig.tight_layout()
    save_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path, dpi=300)
    # Close explicitly so repeated pipeline runs don't accumulate open figures.
    plt.close(fig)


In [None]:
def _to_json_serializable(obj):
    """``json.dump`` fallback that converts NumPy values and paths to builtins.

    Args:
        obj: An object the standard JSON encoder could not serialise.

    Returns:
        A builtin equivalent: ``.item()`` for NumPy scalars, ``.tolist()`` for
        NumPy arrays, ``str`` for paths.

    Raises:
        TypeError: For any other type, exactly as ``json`` would without a hook.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, Path):
        return str(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _print_key_metrics_summary(all_model_results: Dict[str, Dict]) -> None:
    """Print headline verification metrics (one-shot base gallery, mean aggregation) per model."""
    for model_name, model_results in all_model_results.items():
        print(f"\n{model_name.upper()}")
        print(f"{'─'*70}")

        # Any stage may have returned None (missing embeddings); treat as absent.
        basic_probe = model_results.get('basic_probe') or {}
        combined = (basic_probe.get('oneshot_base') or {}).get('mean')
        if combined is None:
            print("  Key metrics: NOT AVAILABLE (no one-shot base / mean results)")
            continue

        ver_results = combined.get('verification')  # None without negative probes
        if ver_results is None:
            print("  Verification metrics: NOT AVAILABLE (missing negative probes)")
            continue

        print(f"  AUC:               {ver_results['roc_auc']:.4f}")
        print(f"  EER:               {ver_results['eer']*100:.2f}%")
        print(f"  EER Threshold:     {ver_results['eer_threshold']:.4f}")
        print(f"  TAR @ 0.1% FAR:    {ver_results['tar_at_far_0.001']*100:.2f}%")
        print(f"  TAR @ 1% FAR:      {ver_results['tar_at_far_0.01']*100:.2f}%")
        print(f"  TAR @ 10% FAR:     {ver_results['tar_at_far_0.1']*100:.2f}%")
        print(f"  Separation:        {ver_results['separation']:.4f}")
        print(f"  d-prime:           {ver_results['dprime']:.4f}")
        print(f"  μ_genuine:         {ver_results['genuine_mean']:.4f} (σ={ver_results['genuine_std']:.4f})")
        print(f"  μ_impostor:        {ver_results['impostor_mean']:.4f} (σ={ver_results['impostor_std']:.4f})")


def run_complete_evaluation_pipeline(all_embeddings: Dict, output_base_dir: Path):
    """Run all per-model evaluations, then cross-model comparisons, charts and reports.

    For each model this runs the basic probe, impostor, and segmented
    (one-shot and few-shot) evaluations. It then builds comparison tables
    (CSV/JSON), comparison charts, and exported reports.

    Args:
        all_embeddings: ``{model_name: embeddings_dict}``; each embeddings dict
            is passed unchanged to the ``run_*_evaluation`` functions.
        output_base_dir: Base directory. Outputs go to ``evaluation_results/``,
            ``plots/`` and ``comparisons/`` (with ``charts/`` and ``reports/``
            sub-folders) beneath it.

    Returns:
        ``(all_model_results, all_summaries)``. The first holds per-model results
        keyed by stage (``'basic_probe'``, ``'impostor'``,
        ``'segmented_oneshot'``, ``'segmented_fewshot'``). The second is the dict
        of comparison summaries.
    """
    print("\n" + "="*80)
    print("COMPLETE FACE RECOGNITION EVALUATION PIPELINE")
    print("="*80)

    results_dir = output_base_dir / 'evaluation_results'
    plots_dir = output_base_dir / 'plots'
    comparison_dir = output_base_dir / 'comparisons'
    charts_dir = comparison_dir / 'charts'
    reports_dir = comparison_dir / 'reports'
    for directory in (results_dir, plots_dir, comparison_dir, charts_dir, reports_dir):
        directory.mkdir(parents=True, exist_ok=True)

    # Segmented summaries are produced for every (task, gallery type) pair;
    # keys/file names follow the pattern segmented_<gallery>_<task>.
    gallery_types = ('oneshot', 'fewshot')
    segmented_tasks = ('identification', 'verification')

    all_model_results = {}

    # Run individual model evaluations
    for model_name, embeddings in all_embeddings.items():
        print(f"\n{'#'*80}")
        print(f"# PROCESSING MODEL: {model_name}")
        print(f"{'#'*80}")

        model_results = {}

        # 1. Basic probe evaluation
        model_results['basic_probe'] = run_basic_probe_evaluation(
            model_name, embeddings, results_dir, plots_dir
        )

        # 2. Impostor evaluation
        model_results['impostor'] = run_impostor_evaluation(
            model_name, embeddings, results_dir, plots_dir
        )

        # 3-4. Segmented evaluation - oneshot, then fewshot
        for gallery_type in gallery_types:
            model_results[f'segmented_{gallery_type}'] = run_segmented_evaluation(
                model_name, embeddings, results_dir, plots_dir, gallery_type
            )

        all_model_results[model_name] = model_results

    # ========================================================================
    # DISPLAY KEY METRICS SUMMARY
    # ========================================================================

    print(f"\n{'#'*80}")
    print("# KEY METRICS SUMMARY")
    print(f"{'#'*80}\n")

    _print_key_metrics_summary(all_model_results)

    # ========================================================================
    # COMPARATIVE ANALYSIS
    # ========================================================================

    print(f"\n{'#'*80}")
    print("# COMPARATIVE ANALYSIS")
    print(f"{'#'*80}")

    all_summaries = {}

    # 1. Generate comparison summary
    print("\n1. Generating comparison summary...")
    all_summaries['comparison_summary'] = generate_comparison_summary(all_model_results)
    all_summaries['comparison_summary'].to_csv(comparison_dir / 'comparison_summary.csv', index=False)
    print(f"   Saved: {comparison_dir / 'comparison_summary.csv'}")

    # 2. Gallery strategy analysis
    print("\n2. Analyzing gallery strategies...")
    all_summaries['gallery_strategy'] = analyze_gallery_strategies(all_model_results)
    all_summaries['gallery_strategy'].to_csv(comparison_dir / 'gallery_strategy_analysis.csv', index=False)
    print(f"   Saved: {comparison_dir / 'gallery_strategy_analysis.csv'}")

    # 3. Aggregation method analysis
    print("\n3. Analyzing aggregation methods...")
    all_summaries['aggregation_analysis'] = summarize_aggregation_performance(all_model_results)
    all_summaries['aggregation_analysis'].to_csv(comparison_dir / 'aggregation_analysis.csv', index=False)
    print(f"   Saved: {comparison_dir / 'aggregation_analysis.csv'}")

    # 4. Threshold recommendations
    print("\n4. Generating threshold recommendations...")
    all_summaries['threshold_recommendations'] = recommend_operating_thresholds(all_model_results)
    all_summaries['threshold_recommendations'].to_csv(comparison_dir / 'threshold_recommendations.csv', index=False)
    print(f"   Saved: {comparison_dir / 'threshold_recommendations.csv'}")

    # 5a/5b. Segmented comparison tables - IDENTIFICATION, then VERIFICATION.
    # Saved with the index (unlike the tables above), as before.
    for step, task in zip(('5a', '5b'), segmented_tasks):
        print(f"\n{step}. Creating segmented {task.upper()} comparison tables...")
        for gallery_type in gallery_types:
            key = f'segmented_{gallery_type}_{task}'
            all_summaries[key] = create_segmented_comparison_table(
                all_model_results, gallery_type, task)
            csv_path = comparison_dir / f'{key}.csv'
            all_summaries[key].to_csv(csv_path)
            print(f"   Saved: {csv_path}")

    # 6. Failure analysis
    print("\n6. Analyzing failure cases...")
    all_summaries['failure_analysis'] = analyze_failure_cases(all_model_results)
    with open(comparison_dir / 'failure_analysis.json', 'w') as f:
        # default= converts NumPy scalars/arrays, which json cannot encode natively.
        json.dump(all_summaries['failure_analysis'], f, indent=2,
                  default=_to_json_serializable)
    print(f"   Saved: {comparison_dir / 'failure_analysis.json'}")

    # 7. Statistical comparison
    print("\n7. Performing statistical comparisons...")
    all_summaries['statistical_comparison'] = compare_models_statistical(all_model_results)
    all_summaries['statistical_comparison'].to_csv(comparison_dir / 'statistical_comparison.csv', index=False)
    print(f"   Saved: {comparison_dir / 'statistical_comparison.csv'}")

    # 8. Executive summary
    print("\n8. Generating executive summary...")
    all_summaries['executive_summary'] = generate_executive_summary(
        all_model_results,
        all_summaries['comparison_summary']
    )
    print(all_summaries['executive_summary'])

    # ========================================================================
    # VISUALIZATIONS
    # ========================================================================

    print(f"\n{'#'*80}")
    print("# GENERATING COMPARISON VISUALIZATIONS")
    print(f"{'#'*80}")

    # 1. Model comparison charts
    print("\n1. Creating model comparison charts...")
    plot_model_comparison_charts(
        all_model_results,
        all_summaries['comparison_summary'],
        charts_dir
    )

    print("\n1b. Creating Rank-1 per-identity comparison chart...")
    rank1_chart_path = charts_dir / "rank1_per_identity_all_models.png"
    plot_rank1_per_identity_all_models(all_model_results, rank1_chart_path)
    print(f"   Saved: {rank1_chart_path}")

    # 2a/2b. Segmented heatmaps - IDENTIFICATION, then VERIFICATION
    for step, task in zip(('2a', '2b'), segmented_tasks):
        print(f"\n{step}. Creating segmented {task.upper()} heatmaps...")
        for gallery_type in gallery_types:
            plot_segmented_heatmap(
                all_summaries[f'segmented_{gallery_type}_{task}'],
                charts_dir / f'segmented_{gallery_type}_{task}_heatmap.png',
                f'Segmented {task.capitalize()} Performance - {gallery_type.capitalize()}'
            )

    # 3a/3b. Sensitivity analysis - IDENTIFICATION, then VERIFICATION
    for step, task in zip(('3a', '3b'), segmented_tasks):
        print(f"\n{step}. Creating sensitivity analysis plots - {task.upper()}...")
        for gallery_type in gallery_types:
            plot_sensitivity_analysis(
                all_summaries[f'segmented_{gallery_type}_{task}'],
                charts_dir / f'sensitivity_analysis_{gallery_type}_{task}.png',
                f'Sensitivity Analysis - {gallery_type.capitalize()} {task.capitalize()}'
            )

    # 4. Gallery strategy visualization
    print("\n4. Creating gallery strategy visualizations...")
    plot_gallery_strategy_comparison(
        all_summaries['gallery_strategy'],
        charts_dir / 'gallery_strategy_comparison.png'
    )

    # ========================================================================
    # EXPORT REPORTS
    # ========================================================================

    print(f"\n{'#'*80}")
    print("# EXPORTING COMPREHENSIVE REPORTS")
    print(f"{'#'*80}")

    export_comprehensive_report(
        all_model_results,
        all_summaries,
        reports_dir
    )

    print(f"\n{'='*80}")
    print("EVALUATION PIPELINE COMPLETE")
    print(f"{'='*80}")
    print(f"\nAll results saved to: {output_base_dir}")
    print(f"  - Individual results: {results_dir}")
    print(f"  - Individual plots: {plots_dir}")
    print(f"  - Comparisons: {comparison_dir}")
    print(f"  - Reports: {reports_dir}")

    return all_model_results, all_summaries

In [None]:
def load_pkl(path: Path):
    """Return the object deserialised from the pickle file at ``path``.

    Security: unpickling can execute arbitrary code, so only use this on
    embedding files produced by this project (trusted input).
    """
    with open(path, "rb") as handle:
        payload = handle.read()
    return pickle.loads(payload)


def load_all_embeddings(base_dir: Path):
    """Load the embeddings of every model sub-directory of ``base_dir``.

    Expected layout::

        base_dir /
            adaface_ir_101 /
                gallery_few-shot_augmented.pkl
                gallery_few-shot_base.pkl
                ...
            arcface_ir_50 /
                ...

    Args:
        base_dir: Directory containing one sub-directory per model (``str``
            paths are accepted too). Non-directory entries are ignored.

    Returns:
        ``{model_name: {embedding_key: data}}``, formatted for
        ``run_complete_evaluation_pipeline()``. Models appear in sorted name
        order. Every standard embedding key is present, and it is ``None`` when
        its ``.pkl`` file is absent. Unrecognised ``.pkl`` files are skipped
        with a warning.
    """
    mapping = {
        "gallery_one-shot_base.pkl": "gallery_oneshot_base",
        "gallery_one-shot_augmented.pkl": "gallery_oneshot_augmented",
        "gallery_few-shot_base.pkl": "gallery_fewshot_base",
        "gallery_few-shot_augmented.pkl": "gallery_fewshot_augmented",
        "probe_negative.pkl": "probe_negative",
        "probe_positive_segmented.pkl": "probe_positive_segmented",
        "probe_positive_unsegmented.pkl": "probe_positive_unsegmented",
    }

    base_dir = Path(base_dir)
    all_embeddings = {}

    # Sorted so model order, which drives plot/report ordering downstream, is
    # reproducible instead of depending on filesystem directory order.
    model_dirs = sorted(entry for entry in base_dir.iterdir() if entry.is_dir())

    for model_dir in model_dirs:
        # Pre-fill every expected key with None (as load_embeddings does) so
        # consumers can test `is None` rather than hitting KeyError.
        model_embeddings = dict.fromkeys(mapping.values())

        for file in sorted(model_dir.iterdir()):
            if file.suffix != ".pkl":
                continue

            key = mapping.get(file.name)
            if key is None:
                print(f"Warning: Unrecognized file {file.name}, skipping…")
                continue

            model_embeddings[key] = load_pkl(file)

        all_embeddings[model_dir.name] = model_embeddings

    return all_embeddings

In [None]:
# Load {model_name: {embedding_key: data}} for the models listed in `models`.
all_embeddings = {}

# Keys the run_*_evaluation functions look up. They are pre-filled with None
# (mirroring load_embeddings) so a missing .pkl is reported as "Missing
# embeddings!" rather than raising KeyError.
expected_embedding_keys = (
    "gallery_oneshot_base",
    "gallery_oneshot_augmented",
    "gallery_fewshot_base",
    "gallery_fewshot_augmented",
    "probe_positive_unsegmented",
    "probe_positive_segmented",
    "probe_negative",
)

# Filename substring -> key for gallery files, checked in order. Any other
# .pkl is keyed by its stem (e.g. probe_negative.pkl -> "probe_negative").
gallery_key_rules = (
    ("one-shot_base", "gallery_oneshot_base"),
    ("one-shot_augmented", "gallery_oneshot_augmented"),
    ("few-shot_base", "gallery_fewshot_base"),
    ("few-shot_augmented", "gallery_fewshot_augmented"),
)

for model in models:
    model_dir = embeddings_root / model
    if not model_dir.exists():
        print(f"Warning: {model_dir} not found, skipping...")
        continue

    model_embeddings = dict.fromkeys(expected_embedding_keys)

    # sorted(): deterministic result if two files map to the same key.
    for file in sorted(model_dir.glob("*.pkl")):
        key = next(
            (rule_key for pattern, rule_key in gallery_key_rules if pattern in file.name),
            file.stem,
        )
        with open(file, "rb") as f:
            model_embeddings[key] = pickle.load(f)

    all_embeddings[model] = model_embeddings

In [None]:
# Run every model evaluation plus the cross-model comparison.
all_results, all_summaries = run_complete_evaluation_pipeline(
    all_embeddings,
    output_root
)

# The pipeline already printed the executive summary mid-log; repeat it here so
# it is visible after the (long) output.
print(all_summaries['executive_summary'])

# Rich (HTML) rendering of the summary tables instead of plain-text print().
display(all_summaries['comparison_summary'], all_summaries['gallery_strategy'])