In [None]:
import sys
import os
from pathlib import Path
import numpy as np
import pickle
import json
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc
from tqdm import tqdm
import pandas as pd
from datetime import datetime
NOTEBOOK_DIR = Path.cwd()
# Provisional guess that the notebook sits three directories below the project
# root; the following cell re-resolves PROJECT_ROOT by searching for the "RIG"
# folder. Guarded so a shallow working directory (fewer than three ancestors)
# does not raise IndexError and abort Run-All before that cell runs.
PROJECT_ROOT = NOTEBOOK_DIR.parents[2] if len(NOTEBOOK_DIR.parents) > 2 else NOTEBOOK_DIR

OUTPUT_ROOT = PROJECT_ROOT / "output" / "v0"

In [None]:
def find_project_root(start: Path, target_folder="RIG"):
    """Return the nearest directory named ``target_folder``, searching upward.

    ``start`` itself is checked first, then each ancestor from nearest to the
    filesystem root.

    Raises:
        RuntimeError: if neither ``start`` nor any ancestor has that name.
    """
    search_path = (start, *start.parents)
    match = next((candidate for candidate in search_path
                  if candidate.name == target_folder), None)
    if match is None:
        raise RuntimeError(f"Could not find project root '{target_folder}'")
    return match

NOTEBOOK_DIR = Path.cwd()
# Walk up from the kernel's working directory to the "RIG" project folder.
PROJECT_ROOT = find_project_root(NOTEBOOK_DIR)
OUTPUT_ROOT = PROJECT_ROOT.joinpath("output", "v0")

# Echo the resolved locations so a kernel started in the wrong place is obvious.
for resolved_path in (NOTEBOOK_DIR, PROJECT_ROOT, OUTPUT_ROOT):
    print(resolved_path)


In [None]:
import numpy as np
import pickle
import json
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc
from tqdm import tqdm
import pandas as pd
from datetime import datetime

output_root = Path(OUTPUT_ROOT)
embeddings_root = output_root.joinpath('embeddings')
results_dir = output_root.joinpath('evaluation_results')
# Ensure the results directory exists before any evaluation writes into it.
results_dir.mkdir(exist_ok=True, parents=True)

# Model identifiers to evaluate; presumably each matches a sub-directory of
# embeddings_root — verify against the embedding export.
models = ['adaface_ir_50', 'adaface_ir_101', 'arcface_ir_50', 'arcface_ir_101']

In [None]:
# Result key -> pickle filename inside a model's embeddings directory.
# Order matters: it is the key order of the dict returned by load_embeddings.
_EMBEDDING_FILES = {
    'gallery_oneshot_base': 'gallery_one-shot_base.pkl',
    'gallery_oneshot_augmented': 'gallery_one-shot_augmented.pkl',
    'gallery_fewshot_base': 'gallery_few-shot_base.pkl',
    'gallery_fewshot_augmented': 'gallery_few-shot_augmented.pkl',
    'probe_positive_unsegmented': 'probe_positive_unsegmented.pkl',
    'probe_positive_segmented': 'probe_positive_segmented.pkl',
    'probe_negative': 'probe_negative.pkl',
}


def load_embeddings(model_name: str, embeddings_dir: Path = None) -> Dict:
    """Load every available gallery/probe embedding pickle for one model.

    Args:
        model_name: name of the model's sub-directory.
        embeddings_dir: directory containing the per-model sub-directories.
            Defaults to the notebook-level ``embeddings_root``.

    Returns:
        Dict with the keys of ``_EMBEDDING_FILES`` (always all of them, in that
        order). Each value is the unpickled object, or ``None`` when that file
        does not exist.

    Raises:
        FileNotFoundError: if the model directory itself does not exist.
    """
    base_dir = embeddings_root if embeddings_dir is None else Path(embeddings_dir)
    model_dir = base_dir / model_name

    if not model_dir.exists():
        raise FileNotFoundError(f"Model directory not found: {model_dir}")

    embeddings = dict.fromkeys(_EMBEDDING_FILES)
    for key, filename in _EMBEDDING_FILES.items():
        path = model_dir / filename
        if path.exists():
            # SECURITY: pickle.load can execute arbitrary code — only load
            # embedding files produced by a trusted pipeline.
            with open(path, 'rb') as f:
                embeddings[key] = pickle.load(f)

    return embeddings

In [None]:
def cosine_similarity(emb1: np.ndarray, emb2: np.ndarray) -> float:
    """Cosine similarity between two embedding vectors.

    Both vectors are always normalized by their norms. Returns 0.0 when either
    vector has zero norm (similarity is undefined), instead of NaN.
    """
    norm_product = np.linalg.norm(emb1) * np.linalg.norm(emb2)
    if norm_product == 0:
        return 0.0
    # Always divide: returning the raw dot product for "nearly unit" vectors
    # (norms within ~1%) biases scores by up to ~2% and makes them
    # inconsistent with fully normalized comparisons.
    return np.dot(emb1, emb2) / norm_product

In [None]:
def identify_probe(probe_embedding: np.ndarray, 
                   gallery_embeddings: Dict[str, Dict],
                   threshold: float) -> Tuple[str, float]:
    """Nearest-neighbour match of a probe against every gallery embedding.

    The identity owning the single most similar gallery embedding wins; on a
    tie the first one encountered is kept.

    Args:
        probe_embedding: embedding of the probe.
        gallery_embeddings: identity -> dict whose ``'embeddings'`` entry is an
            iterable of that identity's gallery embeddings.
        threshold: minimum best similarity needed to accept the match.

    Returns:
        ``(identity, best_similarity)`` when the best similarity is not below
        ``threshold``, otherwise ``(None, best_similarity)``. The running best
        starts at -1, so an empty gallery reports a similarity of -1.
    """
    scored_references = (
        (cosine_similarity(probe_embedding, reference), identity)
        for identity, record in gallery_embeddings.items()
        for reference in record['embeddings']
    )

    best_identity, best_score = None, -1
    for score, identity in scored_references:
        if score > best_score:
            best_score, best_identity = score, identity

    if best_score < threshold:
        return None, best_score
    return best_identity, best_score

In [None]:
def identify_probe_mean(probe_embedding: np.ndarray, 
                        gallery_embeddings: Dict[str, Dict],
                        threshold: float) -> Tuple[str, float]:
    """Match a probe by its mean similarity to each identity's gallery embeddings.

    Args:
        probe_embedding: embedding of the probe.
        gallery_embeddings: identity -> dict whose ``'embeddings'`` entry is an
            iterable of that identity's gallery embeddings.
        threshold: minimum winning mean similarity needed to accept the match.

    Returns:
        ``(identity, mean_similarity)`` for the identity with the highest mean
        (first one wins ties) when that mean is not below ``threshold``,
        otherwise ``(None, mean_similarity)``. The running best starts at -1.
        An identity with no embeddings gets a NaN mean and is never selected.
    """
    mean_scores = {
        identity: np.mean([cosine_similarity(probe_embedding, reference)
                           for reference in record['embeddings']])
        for identity, record in gallery_embeddings.items()
    }

    best_identity, best_mean = None, -1
    for identity, score in mean_scores.items():
        if score > best_mean:
            best_identity, best_mean = identity, score

    if best_mean < threshold:
        return None, best_mean
    return best_identity, best_mean

In [None]:
def identify_probe_topk(probe_embedding: np.ndarray, 
                        gallery_embeddings: Dict[str, Dict],
                        threshold: float,
                        k: int = 3) -> Tuple[str, float]:
    """Match a probe by the mean of its k best similarities per identity.

    Args:
        probe_embedding: embedding of the probe.
        gallery_embeddings: identity -> dict whose ``'embeddings'`` entry is an
            iterable of that identity's gallery embeddings.
        threshold: minimum winning score needed to accept the match.
        k: number of highest similarities averaged per identity. Identities
            with fewer than k embeddings average all of them.

    Returns:
        ``(identity, topk_score)`` for the best-scoring identity (first one wins
        ties) when its score is not below ``threshold``, otherwise
        ``(None, topk_score)``. The running best starts at -1.
    """
    def _topk_mean(references):
        ranked = sorted(
            (cosine_similarity(probe_embedding, reference) for reference in references),
            reverse=True,
        )
        return np.mean(ranked[:k])

    best_identity, best_score = None, -1
    for identity, record in gallery_embeddings.items():
        score = _topk_mean(record['embeddings'])
        if score > best_score:
            best_identity, best_score = identity, score

    if best_score < threshold:
        return None, best_score
    return best_identity, best_score

In [None]:
def plot_impostor_metrics(results_df: pd.DataFrame, title: str, save_path: Path):
    """Save a 2x2 panel of impostor-evaluation curves versus threshold.

    Panels, row-major: rejection rate; false acceptance rate; both rates
    overlaid; average impostor similarity (all / accepted / rejected). The
    figure is written to ``save_path`` at 300 dpi and then closed, so nothing is
    displayed inline.
    """
    threshold_axis = results_df['threshold']
    # (panel title, y-axis label, [(column, line format, legend label), ...])
    panel_specs = [
        ('Impostor Rejection Rate', 'Rate',
         [('rejection_rate', 'g-', 'Rejection Rate (TNR)')]),
        ('False Acceptance Rate (Impostors)', 'Rate',
         [('false_acceptance_rate', 'r-', 'False Acceptance Rate')]),
        ('Rejection vs False Acceptance Trade-off', 'Rate',
         [('rejection_rate', 'g-', 'Rejection Rate'),
          ('false_acceptance_rate', 'r-', 'FAR')]),
        ('Impostor Similarity Distributions', 'Similarity',
         [('avg_impostor_similarity', 'b-', 'Avg All'),
          ('avg_accepted_similarity', 'r--', 'Avg Accepted (FP)'),
          ('avg_rejected_similarity', 'g--', 'Avg Rejected (TN)')]),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    for panel, (panel_title, y_label, series) in zip(axes.flat, panel_specs):
        for column, line_format, legend_label in series:
            panel.plot(threshold_axis, results_df[column], line_format,
                       linewidth=2, label=legend_label)
        panel.set_xlabel('Threshold')
        panel.set_ylabel(y_label)
        panel.set_title(panel_title)
        panel.grid(True, alpha=0.3)
        panel.legend()

    fig.suptitle(title, fontsize=14, fontweight='bold')
    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Plot saved to: {save_path}")

In [None]:
def plot_metrics(results_df: pd.DataFrame, model_name: str, save_path: Path = None):
    """Show (and optionally save) a 2x2 panel of identification metrics vs threshold.

    Panels, row-major: rank-1 accuracy; FAR and FRR; precision / recall / F1;
    average similarity of correct vs incorrect matches plus the overall
    average (taken from the first row's ``avg_similarity``).

    Args:
        results_df: one row per threshold, with the columns written by the
            identification sweeps (``threshold``, ``rank1_accuracy``,
            ``false_acceptance_rate``, ``false_rejection_rate``, ``precision``,
            ``recall``, ``f1_score``, ``avg_correct_similarity``,
            ``avg_incorrect_similarity``, ``avg_similarity``).
        model_name: prefix used in every panel title.
        save_path: if given, the figure is also saved there at 300 dpi.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    thresholds = results_df['threshold']

    ax = axes[0, 0]
    ax.plot(thresholds, results_df['rank1_accuracy'], 'b-', linewidth=2)
    ax.set_xlabel('Threshold')
    ax.set_ylabel('Rank-1 Accuracy')
    ax.set_title(f'{model_name}: Rank-1 Accuracy vs Threshold')
    ax.grid(True, alpha=0.3)

    ax = axes[0, 1]
    ax.plot(thresholds, results_df['false_acceptance_rate'], 'r-', linewidth=2, label='FAR')
    ax.plot(thresholds, results_df['false_rejection_rate'], 'g-', linewidth=2, label='FRR')
    ax.set_xlabel('Threshold')
    ax.set_ylabel('Rate')
    ax.set_title(f'{model_name}: FAR and FRR vs Threshold')
    ax.legend()
    ax.grid(True, alpha=0.3)

    ax = axes[1, 0]
    ax.plot(thresholds, results_df['precision'], 'b-', linewidth=2, label='Precision')
    ax.plot(thresholds, results_df['recall'], 'orange', linewidth=2, label='Recall')
    ax.plot(thresholds, results_df['f1_score'], 'purple', linewidth=2, label='F1-Score')
    ax.set_xlabel('Threshold')
    ax.set_ylabel('Score')
    ax.set_title(f'{model_name}: Precision, Recall, F1 vs Threshold')
    ax.legend()
    ax.grid(True, alpha=0.3)

    ax = axes[1, 1]
    ax.plot(thresholds, results_df['avg_correct_similarity'],
            'g-', linewidth=2, label='Correct Matches')
    ax.plot(thresholds, results_df['avg_incorrect_similarity'],
            'r-', linewidth=2, label='Incorrect Matches')
    # Guard: .iloc[0] would raise IndexError on an empty sweep.
    if not results_df.empty:
        ax.axhline(y=results_df['avg_similarity'].iloc[0],
                   color='black', linestyle='--', alpha=0.5, label='Overall Avg')
    ax.set_xlabel('Threshold')
    ax.set_ylabel('Cosine Similarity')
    ax.set_title(f'{model_name}: Average Similarities')
    ax.legend()
    ax.grid(True, alpha=0.3)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    # This is called inside evaluation loops; release the figure so figures do
    # not accumulate on backends where plt.show() does not close them.
    plt.close(fig)

In [None]:
def evaluate_identification(gallery_embeddings: Dict[str, Dict],
                            probe_embeddings: Dict[str, Dict],
                            thresholds: List[float],
                            aggregation: str = 'max',
                            topk: int = 3) -> pd.DataFrame:
    """Sweep acceptance thresholds for identification of genuine (enrolled) probes.

    Args:
        gallery_embeddings: identity -> dict with an ``'embeddings'`` iterable.
        probe_embeddings: identity -> dict with an ``'embeddings'`` iterable, or
            a dict whose ``'all'`` entry has that shape (``'all'`` is used if present).
        thresholds: similarity thresholds. A probe whose best score is below the
            threshold is rejected and counted as a false negative.
        aggregation: ``'mean'`` or ``'topk'``; any other value uses the
            max-over-embeddings matcher (``identify_probe``).
        topk: k for the ``'topk'`` aggregation.

    Returns:
        DataFrame with one row per threshold: rank-1 accuracy, identification
        rate, FAR, FRR, precision, recall, F1, raw counts, average similarities
        and the aggregation name.
    """
    if aggregation == 'mean':
        identify_func = identify_probe_mean
    elif aggregation == 'topk':
        identify_func = lambda probe_emb, gallery, thresh: identify_probe_topk(
            probe_emb, gallery, thresh, k=topk
        )
    else:
        identify_func = identify_probe

    probe_data = probe_embeddings.get("all", probe_embeddings)
    labelled_probes = [
        (true_name, probe_emb)
        for true_name, data in probe_data.items()
        for probe_emb in data['embeddings']
    ]

    # The matchers choose the best identity and its score without looking at
    # the threshold, and only compare the score against it at the very end.
    # So score each probe once (with a threshold that never rejects) and apply
    # every threshold afterwards, instead of re-scanning the whole gallery for
    # every threshold. Results are identical.
    scored = []  # (true_name, best_name, best_similarity)
    for true_name, probe_emb in tqdm(labelled_probes, desc=f"Scoring probes ({aggregation})"):
        best_name, similarity = identify_func(probe_emb, gallery_embeddings, -np.inf)
        scored.append((true_name, best_name, similarity))

    total_probes = len(scored)
    similarities = [similarity for _, _, similarity in scored]
    avg_similarity = np.mean(similarities) if similarities else 0

    results = []
    for threshold in thresholds:
        true_positives = 0
        false_positives = 0
        false_negatives = 0
        correct_similarities = []
        incorrect_similarities = []

        for true_name, best_name, similarity in scored:
            # Same rejection rule the matchers apply internally.
            predicted_name = None if similarity < threshold else best_name
            if predicted_name is None:
                false_negatives += 1
            elif predicted_name == true_name:
                true_positives += 1
                correct_similarities.append(similarity)
            else:
                false_positives += 1
                incorrect_similarities.append(similarity)

        rank1_accuracy = true_positives / total_probes if total_probes > 0 else 0
        identification_rate = true_positives / total_probes if total_probes > 0 else 0
        false_acceptance_rate = false_positives / total_probes if total_probes > 0 else 0
        false_rejection_rate = false_negatives / total_probes if total_probes > 0 else 0

        identified = true_positives + false_positives
        precision = true_positives / identified if identified > 0 else 0
        recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
        f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        results.append({
            'threshold': threshold,
            'rank1_accuracy': rank1_accuracy,
            'identification_rate': identification_rate,
            'false_acceptance_rate': false_acceptance_rate,
            'false_rejection_rate': false_rejection_rate,
            'precision': precision,
            'recall': recall,
            'f1_score': f1_score,
            'true_positives': true_positives,
            'false_positives': false_positives,
            'false_negatives': false_negatives,
            'total_probes': total_probes,
            'avg_similarity': avg_similarity,
            'avg_correct_similarity': np.mean(correct_similarities) if correct_similarities else 0,
            'avg_incorrect_similarity': np.mean(incorrect_similarities) if incorrect_similarities else 0,
            'aggregation': aggregation
        })

    return pd.DataFrame(results)

In [None]:
def evaluate_impostors(gallery_embeddings: Dict[str, Dict],
                       impostor_embeddings: Dict[str, Dict],
                       thresholds: List[float],
                       aggregation: str = 'max',
                       topk: int = 3) -> pd.DataFrame:
    """Sweep acceptance thresholds for impostors (people not in the gallery).

    Every impostor matched to any gallery identity is a false positive; every
    rejected impostor is a true negative.

    Args:
        gallery_embeddings: identity -> dict with an ``'embeddings'`` iterable.
        impostor_embeddings: dataset name (e.g. 'real', 'lfw') -> dict with an
            ``'embeddings'`` iterable. Each embedding is treated as one impostor.
        thresholds: similarity thresholds to evaluate.
        aggregation: ``'mean'`` or ``'topk'``; any other value uses the
            max-over-embeddings matcher (``identify_probe``).
        topk: k for the ``'topk'`` aggregation.

    Returns:
        DataFrame with one row per threshold: rejection rate, false acceptance
        rate, counts, average similarities (all / accepted / rejected) and the
        aggregation name.
    """
    if aggregation == 'mean':
        identify_func = identify_probe_mean
    elif aggregation == 'topk':
        identify_func = lambda probe_emb, gallery, thresh: identify_probe_topk(
            probe_emb, gallery, thresh, k=topk
        )
    else:
        identify_func = identify_probe

    # Pool impostors across datasets. Only 'embeddings' is required per dataset.
    impostor_embs = [
        impostor_emb
        for data in impostor_embeddings.values()
        for impostor_emb in data['embeddings']
    ]

    # Best match and score are threshold-independent, so score each impostor
    # once (with a threshold that never rejects) and apply thresholds afterwards.
    scored = [
        identify_func(impostor_emb, gallery_embeddings, -np.inf)
        for impostor_emb in tqdm(impostor_embs, desc=f"Scoring impostors ({aggregation})")
    ]

    total_impostors = len(scored)
    impostor_similarities = [similarity for _, similarity in scored]
    avg_impostor_similarity = np.mean(impostor_similarities) if impostor_similarities else 0

    results = []
    for threshold in thresholds:
        true_negatives = 0
        false_positives = 0
        accepted_impostor_similarities = []
        rejected_impostor_similarities = []

        for best_name, similarity in scored:
            # Same rejection rule the matchers apply internally.
            predicted_name = None if similarity < threshold else best_name
            if predicted_name is None:
                # Correctly rejected (true negative)
                true_negatives += 1
                rejected_impostor_similarities.append(similarity)
            else:
                # Incorrectly accepted (false positive)
                false_positives += 1
                accepted_impostor_similarities.append(similarity)

        rejection_rate = true_negatives / total_impostors if total_impostors > 0 else 0
        false_acceptance_rate = false_positives / total_impostors if total_impostors > 0 else 0

        results.append({
            'threshold': threshold,
            'rejection_rate': rejection_rate,
            'false_acceptance_rate': false_acceptance_rate,
            'true_negatives': true_negatives,
            'false_positives': false_positives,
            'total_impostors': total_impostors,
            'avg_impostor_similarity': avg_impostor_similarity,
            'avg_accepted_similarity': np.mean(accepted_impostor_similarities) if accepted_impostor_similarities else 0,
            'avg_rejected_similarity': np.mean(rejected_impostor_similarities) if rejected_impostor_similarities else 0,
            'aggregation': aggregation
        })

    return pd.DataFrame(results)

In [None]:
def evaluate_segmented_probes(gallery_embeddings: Dict[str, Dict],
                              probe_embeddings: Dict[str, Dict],
                              thresholds: List[float],
                              aggregation: str = 'max',
                              topk: int = 3) -> Dict[str, pd.DataFrame]:
    """Run the identification threshold sweep separately for each probe segment.

    Args:
        gallery_embeddings: identity -> dict with an ``'embeddings'`` iterable.
        probe_embeddings: segment name -> (identity -> dict with an
            ``'embeddings'`` iterable). A key named ``'all'`` is skipped.
        thresholds: similarity thresholds. A probe whose best score is below the
            threshold is rejected and counted as a false negative.
        aggregation: ``'mean'`` or ``'topk'``; any other value uses the
            max-over-embeddings matcher (``identify_probe``).
        topk: k for the ``'topk'`` aggregation.

    Returns:
        segment name -> DataFrame with one row per threshold (same metrics as
        the unsegmented sweep plus a ``'segment'`` column). Empty if no segments.
    """
    if aggregation == 'mean':
        identify_func = identify_probe_mean
    elif aggregation == 'topk':
        identify_func = lambda probe_emb, gallery, thresh: identify_probe_topk(
            probe_emb, gallery, thresh, k=topk
        )
    else:
        identify_func = identify_probe

    segment_results = {}
    segments = [k for k in probe_embeddings.keys() if k != 'all']

    if not segments:
        print("Warning: No segments found in probe embeddings")
        return segment_results

    print(f"Found {len(segments)} segments: {segments}")

    for segment_name in segments:
        segment_data = probe_embeddings[segment_name]
        labelled_probes = [
            (true_name, probe_emb)
            for true_name, data in segment_data.items()
            for probe_emb in data['embeddings']
        ]

        # Best match and score are threshold-independent, so score each probe
        # once (with a threshold that never rejects) and apply thresholds after.
        scored = []  # (true_name, best_name, best_similarity)
        for true_name, probe_emb in tqdm(labelled_probes,
                                         desc=f"Evaluating {segment_name} ({aggregation})",
                                         leave=False):
            best_name, similarity = identify_func(probe_emb, gallery_embeddings, -np.inf)
            scored.append((true_name, best_name, similarity))

        total_probes = len(scored)
        similarities = [similarity for _, _, similarity in scored]
        avg_similarity = np.mean(similarities) if similarities else 0

        results = []
        for threshold in thresholds:
            true_positives = 0
            false_positives = 0
            false_negatives = 0
            correct_similarities = []
            incorrect_similarities = []

            for true_name, best_name, similarity in scored:
                # Same rejection rule the matchers apply internally.
                predicted_name = None if similarity < threshold else best_name
                if predicted_name is None:
                    false_negatives += 1
                elif predicted_name == true_name:
                    true_positives += 1
                    correct_similarities.append(similarity)
                else:
                    false_positives += 1
                    incorrect_similarities.append(similarity)

            rank1_accuracy = true_positives / total_probes if total_probes > 0 else 0
            identification_rate = true_positives / total_probes if total_probes > 0 else 0
            false_acceptance_rate = false_positives / total_probes if total_probes > 0 else 0
            false_rejection_rate = false_negatives / total_probes if total_probes > 0 else 0

            identified = true_positives + false_positives
            precision = true_positives / identified if identified > 0 else 0
            recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
            f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

            results.append({
                'threshold': threshold,
                'segment': segment_name,
                'rank1_accuracy': rank1_accuracy,
                'identification_rate': identification_rate,
                'false_acceptance_rate': false_acceptance_rate,
                'false_rejection_rate': false_rejection_rate,
                'precision': precision,
                'recall': recall,
                'f1_score': f1_score,
                'true_positives': true_positives,
                'false_positives': false_positives,
                'false_negatives': false_negatives,
                'total_probes': total_probes,
                'avg_similarity': avg_similarity,
                'avg_correct_similarity': np.mean(correct_similarities) if correct_similarities else 0,
                'avg_incorrect_similarity': np.mean(incorrect_similarities) if incorrect_similarities else 0,
                'aggregation': aggregation
            })

        segment_results[segment_name] = pd.DataFrame(results)

    return segment_results

In [None]:
def run_impostor_evaluation(model_name: str, embeddings: Dict,
                            results_dir: Path, gallery_type: str = 'base',
                            impostor_type: str = 'real',
                            thresholds: np.ndarray = None,
                            aggregations: List[str] = ['max']):
    """Run the impostor threshold sweep for one model and save CSVs and plots.

    Args:
        model_name: used for the output sub-directory and plot titles.
        embeddings: dict of gallery/probe embeddings keyed like
            ``'gallery_oneshot_base'`` and ``'probe_negative'``.
        results_dir: outputs go to ``results_dir / model_name``.
        gallery_type: 'oneshot_aug' / 'fewshot_aug' are short aliases for the
            augmented galleries; any other value X selects
            ``embeddings['gallery_X']`` (e.g. 'oneshot_base'). Note the default
            'base' looks up 'gallery_base'.
        impostor_type: 'real', 'lfw', or 'combined' (whichever of the two exist).
        thresholds: defaults to 0.20, 0.25, ..., 0.90.
        aggregations: matcher aggregations to evaluate ('max', 'mean', 'topk').

    Returns:
        aggregation -> results DataFrame, or None when embeddings are missing,
        no impostors are available, or ``impostor_type`` is unknown.
    """
    if thresholds is None:
        thresholds = np.arange(0.2, 0.91, 0.05)

    print(f"\n{'='*60}")
    print(f"Evaluating {model_name} - Impostor ({impostor_type.upper()}, {gallery_type.upper()} gallery)")
    print(f"{'='*60}")

    gallery_aliases = {
        'oneshot_aug': 'gallery_oneshot_augmented',
        'fewshot_aug': 'gallery_fewshot_augmented',
    }
    gallery = embeddings.get(gallery_aliases.get(gallery_type, f'gallery_{gallery_type}'))
    impostor_embeddings_full = embeddings.get('probe_negative')

    if gallery is None or impostor_embeddings_full is None:
        print("Missing embeddings for impostor evaluation")
        return None

    if impostor_type in ('real', 'lfw'):
        if impostor_type not in impostor_embeddings_full:
            print(f"No '{impostor_type}' impostors found in probe_negative")
            return None
        impostor_embeddings = {impostor_type: impostor_embeddings_full[impostor_type]}
        count = len(impostor_embeddings[impostor_type]['embeddings'])
        print(f"Using {count} {impostor_type.upper()} impostors")

    elif impostor_type == 'combined':
        impostor_embeddings = {
            name: impostor_embeddings_full[name]
            for name in ('real', 'lfw')
            if name in impostor_embeddings_full
        }
        total_count = sum(len(data['embeddings']) for data in impostor_embeddings.values())
        print(f"Using {total_count} combined impostors (real + lfw)")
        # Guard: an empty impostor set would otherwise produce an all-zero sweep
        # that looks like a real result.
        if total_count == 0:
            print("No impostors found in probe_negative")
            return None

    else:
        print(f"Unknown impostor_type: {impostor_type}. Use 'real', 'lfw', or 'combined'")
        return None

    model_results_dir = results_dir / model_name
    model_results_dir.mkdir(parents=True, exist_ok=True)

    all_results = {}

    for agg in aggregations:
        print(f"\n--- Testing aggregation: {agg.upper()} ---")
        results_df = evaluate_impostors(gallery, impostor_embeddings, thresholds, aggregation=agg)
        all_results[agg] = results_df

        csv_path = model_results_dir / f'impostor_{impostor_type}_{gallery_type}_{agg}_metrics.csv'
        results_df.to_csv(csv_path, index=False)
        print(f"Results saved to: {csv_path}")

        plot_path = model_results_dir / f'impostor_{impostor_type}_{gallery_type}_{agg}_plot.png'
        plot_impostor_metrics(results_df,
                              f"{model_name} - Impostor ({impostor_type}, {gallery_type}) - {agg.upper()}",
                              plot_path)

        # idxmax returns an index *label*, so look the row up with .loc.
        best_rejection = results_df.loc[results_df['rejection_rate'].idxmax()]
        print(f"Best Rejection Rate ({agg}): {best_rejection['rejection_rate']:.4f} at threshold {best_rejection['threshold']:.2f}")
        print(f"  FAR at this threshold: {best_rejection['false_acceptance_rate']:.4f}")
        print(f"  Total impostors tested: {int(best_rejection['total_impostors'])}")

    return all_results

In [None]:
def run_segmented_evaluation(model_name: str, embeddings: Dict,
                             results_dir: Path, gallery_type: str = 'base',
                             thresholds: np.ndarray = None,
                             aggregations: List[str] = ['max']):
    """Evaluate segmented positive probes per segment and save CSVs and plots.

    Args:
        model_name: used for the output sub-directory and plot titles.
        embeddings: dict of gallery/probe embeddings keyed like
            ``'gallery_oneshot_base'`` and ``'probe_positive_segmented'``.
        results_dir: outputs go to ``results_dir / model_name``.
        gallery_type: 'oneshot_aug' / 'fewshot_aug' are short aliases for the
            augmented galleries (same aliases the impostor runner accepts); any
            other value X selects ``embeddings['gallery_X']``.
        thresholds: defaults to 0.20, 0.25, ..., 0.90.
        aggregations: matcher aggregations to evaluate ('max', 'mean', 'topk').

    Returns:
        aggregation -> {segment name -> results DataFrame}, or None when the
        gallery or segmented probe embeddings are missing.
    """
    if thresholds is None:
        thresholds = np.arange(0.2, 0.91, 0.05)

    print(f"\n{'='*60}")
    print(f"Evaluating {model_name} - Segmented Probes ({gallery_type.upper()} gallery)")
    print(f"{'='*60}")

    gallery_aliases = {
        'oneshot_aug': 'gallery_oneshot_augmented',
        'fewshot_aug': 'gallery_fewshot_augmented',
    }
    gallery = embeddings.get(gallery_aliases.get(gallery_type, f'gallery_{gallery_type}'))
    probe = embeddings.get('probe_positive_segmented')

    if gallery is None or probe is None:
        print("Missing embeddings for segmented evaluation")
        return None

    model_results_dir = results_dir / model_name
    model_results_dir.mkdir(parents=True, exist_ok=True)

    all_results = {}

    for agg in aggregations:
        print(f"\n--- Testing aggregation: {agg.upper()} ---")
        segment_results = evaluate_segmented_probes(gallery, probe, thresholds, aggregation=agg)
        all_results[agg] = segment_results

        for segment_name, results_df in segment_results.items():
            csv_path = model_results_dir / f'segmented_{segment_name}_{gallery_type}_{agg}_metrics.csv'
            results_df.to_csv(csv_path, index=False)

            plot_path = model_results_dir / f'segmented_{segment_name}_{gallery_type}_{agg}_plot.png'
            plot_metrics(results_df,
                         f"{model_name} - {segment_name} ({gallery_type}) - {agg.upper()}",
                         plot_path)

            # idxmax returns an index *label*, so look the row up with .loc.
            best_rank1 = results_df.loc[results_df['rank1_accuracy'].idxmax()]
            print(f"\n{segment_name}:")
            print(f"  Best Rank-1 ({agg}): {best_rank1['rank1_accuracy']:.4f} at threshold {best_rank1['threshold']:.2f}")

        # Comparison across segments
        if len(segment_results) > 1:
            print(f"\n{'='*70}")
            print(f"SEGMENT COMPARISON ({agg.upper()})")
            print(f"{'='*70}")
            for segment_name, results_df in segment_results.items():
                best_rank1 = results_df['rank1_accuracy'].max()
                best_f1 = results_df['f1_score'].max()
                print(f"{segment_name:20s} → Rank-1: {best_rank1:.4f} | F1: {best_f1:.4f}")

    return all_results

In [None]:
def run_oneshot_evaluation(model_name: str, embeddings: Dict, 
                           results_dir: Path, gallery_type: str = 'base',
                           thresholds: np.ndarray = None,
                           aggregations: List[str] = ['max']):
    """Evaluate unsegmented positive probes against a one-shot gallery.

    Args:
        model_name: used for the output sub-directory and plot titles.
        embeddings: dict of gallery/probe embeddings keyed like
            ``'gallery_oneshot_base'`` and ``'probe_positive_unsegmented'``.
        results_dir: outputs go to ``results_dir / model_name``.
        gallery_type: suffix X selecting ``embeddings['gallery_oneshot_X']``
            (e.g. 'base', 'augmented').
        thresholds: defaults to 0.20, 0.25, ..., 0.90.
        aggregations: matcher aggregations to evaluate ('max', 'mean', 'topk').

    Returns:
        aggregation -> results DataFrame, or None when the gallery or the
        unsegmented probe embeddings are missing.
    """
    if thresholds is None:
        thresholds = np.arange(0.2, 0.91, 0.05)

    print(f"\n{'='*60}")
    print(f"Evaluating {model_name} - One-Shot Gallery ({gallery_type.upper()})")
    print(f"{'='*60}")

    gallery = embeddings.get(f'gallery_oneshot_{gallery_type}')
    # .get so a missing probe set reaches the "Missing embeddings" branch
    # below instead of raising KeyError.
    probe = embeddings.get('probe_positive_unsegmented')

    if gallery is None or probe is None:
        print(f"Missing embeddings for {model_name} - {gallery_type}")
        return None

    print(f"Gallery identities: {len(gallery)}")
    print(f"Probe identities: {len(probe.get('all', probe))}")

    model_results_dir = results_dir / model_name
    model_results_dir.mkdir(parents=True, exist_ok=True)

    all_results = {}

    for agg in aggregations:
        print(f"\n--- Testing aggregation: {agg.upper()} ---")
        results_df = evaluate_identification(gallery, probe, thresholds, aggregation=agg)
        all_results[agg] = results_df

        csv_path = model_results_dir / f'oneshot_{gallery_type}_{agg}_metrics.csv'
        results_df.to_csv(csv_path, index=False)
        print(f"Results saved to: {csv_path}")

        plot_path = model_results_dir / f'oneshot_{gallery_type}_{agg}_metrics_plot.png'
        plot_metrics(results_df, f"{model_name} - One-Shot ({gallery_type.capitalize()}) - {agg.upper()}", plot_path)

        # idxmax returns an index *label*, so look the row up with .loc.
        best_rank1 = results_df.loc[results_df['rank1_accuracy'].idxmax()]
        print(f"Best Rank-1 Accuracy ({agg}): {best_rank1['rank1_accuracy']:.4f} at threshold {best_rank1['threshold']:.2f}")

    if len(aggregations) > 1:
        print(f"\n{'='*70}")
        print(f"AGGREGATION COMPARISON - {gallery_type.upper()}")
        print(f"{'='*70}")
        for agg, results_df in all_results.items():
            best_rank1 = results_df['rank1_accuracy'].max()
            best_f1 = results_df['f1_score'].max()
            print(f"{agg.upper():8s} → Best Rank-1: {best_rank1:.4f} | Best F1: {best_f1:.4f}")

    return all_results


def run_fewshot_evaluation(model_name: str, embeddings: Dict, 
                          results_dir: Path, gallery_type: str = 'base',
                          thresholds: np.ndarray = None,
                          aggregations: List[str] = None,
                          min_required: int = None):
    """Evaluate a few-shot gallery against the unsegmented genuine probes.

    Identities with fewer than ``min_required`` gallery embeddings are dropped,
    and only the remaining identities are evaluated. For each aggregation
    strategy this function:
      * calls ``evaluate_identification``;
      * writes the per-threshold metrics to CSV;
      * saves a metrics plot;
      * prints the best Rank-1 operating point.

    Args:
        model_name: Model name; also the sub-directory created under ``results_dir``.
        embeddings: Dict as returned by ``load_embeddings``. Reads
            ``gallery_fewshot_{gallery_type}`` and ``probe_positive_unsegmented``.
        results_dir: Root directory for per-model CSVs and plots.
        gallery_type: ``'base'`` or ``'augmented'``. Selects the gallery key and
            the default ``min_required``.
        thresholds: Similarity thresholds to sweep (default 0.20..0.90, step 0.05).
        aggregations: Aggregation names passed to ``evaluate_identification``
            (default ``['max', 'mean', 'topk']``).
        min_required: Minimum embeddings an identity needs in order to be kept.
            Defaults to 2 for ``'base'`` and 16 otherwise.

    Returns:
        Dict mapping aggregation name -> metrics DataFrame. Returns ``None`` when
        the gallery or probe embeddings are missing, or when no identity
        survives filtering.
    """
    if thresholds is None:
        thresholds = np.arange(0.2, 0.91, 0.05)
    if aggregations is None:
        aggregations = ['max', 'mean', 'topk']
    if min_required is None:
        min_required = 2 if gallery_type == 'base' else 16

    print(f"\n{'='*60}")
    print(f"Evaluating {model_name} - Few-Shot Gallery ({gallery_type.upper()})")
    print(f"{'='*60}")

    gallery = embeddings.get(f'gallery_fewshot_{gallery_type}')
    probe = embeddings.get('probe_positive_unsegmented')

    if gallery is None or probe is None:
        print(f"Missing embeddings for {model_name} - {gallery_type}")
        return None

    # Split identities into kept / rejected. The (identity, count) pairs are
    # kept only for the warning below.
    filtered_gallery = {}
    invalid_identities = []
    for identity, data in gallery.items():
        num_embeddings = len(data['embeddings'])
        if num_embeddings < min_required:
            invalid_identities.append((identity, num_embeddings))
        else:
            filtered_gallery[identity] = data

    if invalid_identities:
        print(f"\nWARNING: Filtered out {len(invalid_identities)} identities with insufficient embeddings:")
        print(f"   Required: >= {min_required} embeddings per identity")
        for identity, count in invalid_identities[:5]:
            print(f"   - {identity}: {count} embeddings")
        if len(invalid_identities) > 5:
            print(f"   ... and {len(invalid_identities) - 5} more")
        if filtered_gallery:
            print(f"\n✓ Continuing with {len(filtered_gallery)} valid identities")

    if not filtered_gallery:
        print(f"No identities with >= {min_required} embeddings for {model_name} - {gallery_type}")
        return None

    print(f"Gallery identities: {len(filtered_gallery)}")
    print(f"Probe identities: {len(probe.get('all', probe))}")

    embedding_counts = [len(data['embeddings']) for data in filtered_gallery.values()]
    print(f"Embeddings per identity - Min: {min(embedding_counts)}, "
          f"Max: {max(embedding_counts)}, Mean: {np.mean(embedding_counts):.1f}")

    model_results_dir = results_dir / model_name
    model_results_dir.mkdir(parents=True, exist_ok=True)

    per_agg_results = {}
    for agg in aggregations:
        print(f"\n--- Testing aggregation: {agg.upper()} ---")
        # Only identities that survived filtering are evaluated.
        # NOTE(review): probes whose identity was filtered out remain in `probe`.
        # How evaluate_identification scores them is decided there; confirm
        # that this is the intended behaviour.
        results_df = evaluate_identification(filtered_gallery, probe, thresholds, aggregation=agg)
        per_agg_results[agg] = results_df

        csv_path = model_results_dir / f'fewshot_{gallery_type}_{agg}_metrics.csv'
        results_df.to_csv(csv_path, index=False)
        print(f"Results saved to: {csv_path}")

        plot_path = model_results_dir / f'fewshot_{gallery_type}_{agg}_metrics_plot.png'
        plot_metrics(results_df, f"{model_name} - Few-Shot ({gallery_type.capitalize()}) - {agg.upper()}", plot_path)

        # idxmax returns an index *label*, so look the row up with .loc, not .iloc.
        best_rank1 = results_df.loc[results_df['rank1_accuracy'].idxmax()]
        print(f"Best Rank-1 Accuracy ({agg}): {best_rank1['rank1_accuracy']:.4f} at threshold {best_rank1['threshold']:.2f}")

    if len(aggregations) > 1:
        print(f"\n{'='*70}")
        print(f"AGGREGATION COMPARISON - {gallery_type.upper()}")
        print(f"{'='*70}")
        for agg, results_df in per_agg_results.items():
            best_rank1 = results_df['rank1_accuracy'].max()
            best_f1 = results_df['f1_score'].max()
            print(f"{agg.upper():8s} → Best Rank-1: {best_rank1:.4f} | Best F1: {best_f1:.4f}")

    return per_agg_results

In [None]:
# print(f"\n{'#'*70}")
# print(f"# OVERALL BEST MODEL COMPARISON")
# print(f"{'#'*70}")

# all_results = {
#     'oneshot_base': {},
#     'oneshot_aug': {},
#     'fewshot_base': {},
#     'fewshot_aug': {}
# }

# for model_name in models:
#     embeddings = load_embeddings(model_name)

#     results_oneshot_base = run_oneshot_evaluation(
#         model_name, 
#         embeddings, 
#         results_dir, 
#         gallery_type='base',
#         aggregations=['max']
#     )
#     results_oneshot_aug = run_oneshot_evaluation(
#         model_name, 
#         embeddings, 
#         results_dir, 
#         gallery_type='augmented',
#         aggregations=['max', 'mean', 'topk']
#     )
#     all_results['oneshot_base'][model_name] = results_oneshot_base
#     all_results['oneshot_aug'][model_name] = results_oneshot_aug

#     if results_oneshot_base is not None and results_oneshot_aug is not None:
#         print(f"\n{'='*70}")
#         print(f"ONE-SHOT COMPARISON: BASE vs AUGMENTED - {model_name}")
#         print(f"{'='*70}")

#         best_base = results_oneshot_base['max']['rank1_accuracy'].max()

#         print(f"Base (max):       Rank-1 = {best_base:.4f}")
#         print(f"\nAugmented:")
#         for agg in ['max', 'mean', 'topk']:
#             best_aug = results_oneshot_aug[agg]['rank1_accuracy'].max()
#             improvement = best_aug - best_base
#             pct_improvement = (best_aug/best_base - 1)*100 if best_base > 0 else 0
#             print(f"  {agg:6s}:      Rank-1 = {best_aug:.4f} | "
#                   f"Δ = {improvement:+.4f} ({pct_improvement:+.2f}%)")

#     results_fewshot_base = run_fewshot_evaluation(
#         model_name, 
#         embeddings, 
#         results_dir, 
#         gallery_type='base',
#         aggregations=['max', 'mean', 'topk']
#     )
#     results_fewshot_aug = run_fewshot_evaluation(
#         model_name, 
#         embeddings, 
#         results_dir, 
#         gallery_type='augmented',
#         aggregations=['max', 'mean', 'topk']
#     )

#     all_results['fewshot_base'][model_name] = results_fewshot_base
#     all_results['fewshot_aug'][model_name] = results_fewshot_aug

#     if results_fewshot_base is not None and results_fewshot_aug is not None:
#         print(f"\n{'='*70}")
#         print(f"FEW-SHOT COMPARISON: BASE vs AUGMENTED - {model_name}")
#         print(f"{'='*70}")

#         for agg in ['max', 'mean', 'topk']:
#             best_base = results_fewshot_base[agg]['rank1_accuracy'].max()
#             best_aug = results_fewshot_aug[agg]['rank1_accuracy'].max()
#             improvement = best_aug - best_base
#             pct_improvement = (best_aug/best_base - 1)*100 if best_base > 0 else 0
            
#             print(f"\n{agg.upper()} Aggregation:")
#             print(f"  Base:       Rank-1 = {best_base:.4f}")
#             print(f"  Augmented:  Rank-1 = {best_aug:.4f}")
#             print(f"  Improvement: {improvement:+.4f} ({pct_improvement:+.2f}%)")

#         print(f"\n{'─'*70}")
#         print("OVERALL BEST (any aggregation):")
#         best_base_overall = max(results_fewshot_base[agg]['rank1_accuracy'].max() 
#                                 for agg in ['max', 'mean', 'topk'])
#         best_aug_overall = max(results_fewshot_aug[agg]['rank1_accuracy'].max() 
#                                for agg in ['max', 'mean', 'topk'])

#         best_base_agg = max(['max', 'mean', 'topk'], 
#                            key=lambda a: results_fewshot_base[a]['rank1_accuracy'].max())
#         best_aug_agg = max(['max', 'mean', 'topk'], 
#                           key=lambda a: results_fewshot_aug[a]['rank1_accuracy'].max())
        
#         print(f"  Base best:       {best_base_overall:.4f} ({best_base_agg})")
#         print(f"  Augmented best:  {best_aug_overall:.4f} ({best_aug_agg})")
#         improvement = best_aug_overall - best_base_overall
#         pct_improvement = (best_aug_overall/best_base_overall - 1)*100 if best_base_overall > 0 else 0
#         print(f"  Improvement:     {improvement:+.4f} ({pct_improvement:+.2f}%)")


In [None]:
print(f"\n{'#'*70}")
print("# COMPREHENSIVE MODEL EVALUATION")
print(f"{'#'*70}")

# Result store: scenario key -> {model_name -> results}. Keys, in insertion order:
#   * identification:    oneshot_base, oneshot_aug, fewshot_base, fewshot_aug
#   * impostor rejection: impostor_<gallery>_<real|lfw|combined>
#   * segmented probes:  segmented_<gallery>
_gallery_configs = ['oneshot_base', 'oneshot_aug', 'fewshot_base', 'fewshot_aug']

all_results = {config: {} for config in _gallery_configs}
all_results.update({
    f'impostor_{config}_{impostor_kind}': {}
    for config in _gallery_configs
    for impostor_kind in ('real', 'lfw', 'combined')
})
all_results.update({f'segmented_{config}': {} for config in _gallery_configs})

# Aggregations run for every gallery except one-shot base.
EVAL_AGGREGATIONS = ['max', 'mean', 'topk']
EVAL_IMPOSTOR_TYPES = ['real', 'lfw', 'combined']
EVAL_GALLERY_CONFIGS = ['oneshot_base', 'oneshot_aug', 'fewshot_base', 'fewshot_aug']


def _aggs_for(gallery_config: str) -> List[str]:
    """Return a fresh list of aggregations for a gallery config ('max' only for one-shot base)."""
    return ['max'] if gallery_config == 'oneshot_base' else list(EVAL_AGGREGATIONS)


def _section(title: str, char: str = '─', width: int = 80) -> None:
    """Print `title` framed by two rules of `char`, preceded by a blank line."""
    print(f"\n{char * width}")
    print(title)
    print(f"{char * width}")


def _pct_change(new: float, old: float) -> float:
    """Relative change in percent, or 0 when the baseline is not positive."""
    return (new / old - 1) * 100 if old > 0 else 0


def _print_oneshot_comparison(model_name: str, results_base: Dict, results_aug: Dict) -> None:
    """Compare the best Rank-1 of one-shot base ('max') with each augmented aggregation."""
    _section(f"ONE-SHOT IDENTIFICATION: BASE vs AUGMENTED - {model_name}", '=', 70)
    best_base = results_base['max']['rank1_accuracy'].max()
    print(f"Base (max):       Rank-1 = {best_base:.4f}")
    print("\nAugmented:")
    for agg in EVAL_AGGREGATIONS:
        best_aug = results_aug[agg]['rank1_accuracy'].max()
        print(f"  {agg:6s}:      Rank-1 = {best_aug:.4f} | "
              f"Δ = {best_aug - best_base:+.4f} ({_pct_change(best_aug, best_base):+.2f}%)")


def _print_fewshot_comparison(model_name: str, results_base: Dict, results_aug: Dict) -> None:
    """Compare few-shot base vs augmented per aggregation and for the overall best aggregation."""
    _section(f"FEW-SHOT IDENTIFICATION: BASE vs AUGMENTED - {model_name}", '=', 70)

    def best(results: Dict, agg: str) -> float:
        return results[agg]['rank1_accuracy'].max()

    for agg in EVAL_AGGREGATIONS:
        best_base = best(results_base, agg)
        best_aug = best(results_aug, agg)
        print(f"\n{agg.upper()} Aggregation:")
        print(f"  Base:       Rank-1 = {best_base:.4f}")
        print(f"  Augmented:  Rank-1 = {best_aug:.4f}")
        print(f"  Improvement: {best_aug - best_base:+.4f} ({_pct_change(best_aug, best_base):+.2f}%)")

    print(f"\n{'─'*70}")
    print("OVERALL BEST (any aggregation):")
    best_base_agg = max(EVAL_AGGREGATIONS, key=lambda a: best(results_base, a))
    best_aug_agg = max(EVAL_AGGREGATIONS, key=lambda a: best(results_aug, a))
    best_base_overall = best(results_base, best_base_agg)
    best_aug_overall = best(results_aug, best_aug_agg)
    print(f"  Base best:       {best_base_overall:.4f} ({best_base_agg})")
    print(f"  Augmented best:  {best_aug_overall:.4f} ({best_aug_agg})")
    print(f"  Improvement:     {best_aug_overall - best_base_overall:+.4f} "
          f"({_pct_change(best_aug_overall, best_base_overall):+.2f}%)")


def _run_impostor_suite(model_name: str, embeddings: Dict, store: Dict, out_dir: Path) -> None:
    """Run impostor rejection for every gallery config x impostor type and record the results in `store`."""
    for gallery_config in EVAL_GALLERY_CONFIGS:
        print(f"\nEvaluating {gallery_config} against impostors...")
        for imp_type in EVAL_IMPOSTOR_TYPES:
            store[f'impostor_{gallery_config}_{imp_type}'][model_name] = run_impostor_evaluation(
                model_name, embeddings, out_dir,
                gallery_type=gallery_config,
                impostor_type=imp_type,
                aggregations=_aggs_for(gallery_config),
            )


def _print_impostor_summary(model_name: str, store: Dict) -> None:
    """Print, per gallery/impostor type, the aggregation with the highest rejection rate."""
    _section(f"IMPOSTOR REJECTION SUMMARY - {model_name}", '=', 70)
    rows = []
    for gallery_config in EVAL_GALLERY_CONFIGS:
        for imp_type in EVAL_IMPOSTOR_TYPES:
            results = store[f'impostor_{gallery_config}_{imp_type}'].get(model_name)
            if not results:
                continue
            # NOTE(review): this maximises rejection over *all* thresholds. If
            # rejection grows with threshold, that favours the strictest threshold.
            # Consider reporting at a fixed FAR or at the genuine-probe operating threshold.
            best_agg = max(results, key=lambda a: results[a]['rejection_rate'].max())
            best_df = results[best_agg]
            peak_label = best_df['rejection_rate'].idxmax()
            rows.append({
                'gallery': gallery_config,
                'impostor': imp_type,
                'agg': best_agg,
                'rejection': best_df['rejection_rate'].max(),
                'far': best_df.loc[peak_label, 'false_acceptance_rate'],
                'n_impostors': best_df['total_impostors'].iloc[0],
            })

    if rows:
        print(f"\n{'Gallery':<20} {'Impostor':<10} {'Agg':<8} {'N':<6} {'Rejection':<12} {'FAR':<10}")
        print(f"{'-'*80}")
        for item in rows:
            print(f"{item['gallery']:<20} {item['impostor']:<10} {item['agg']:<8} "
                  f"{item['n_impostors']:<6} {item['rejection']:<12.4f} {item['far']:<10.4f}")


def _run_segmented_suite(model_name: str, embeddings: Dict, store: Dict, out_dir: Path) -> None:
    """Run segmented-probe evaluation for every gallery config and record the results in `store`."""
    for gallery_config in EVAL_GALLERY_CONFIGS:
        store[f'segmented_{gallery_config}'][model_name] = run_segmented_evaluation(
            model_name, embeddings, out_dir,
            gallery_type=gallery_config,
            aggregations=_aggs_for(gallery_config),
        )


def _print_segmented_summary(model_name: str, store: Dict) -> None:
    """Print a segment x aggregation table of best Rank-1 plus summary statistics per gallery config."""
    _section(f"SEGMENTED PERFORMANCE SUMMARY - {model_name}", '=', 70)
    for gallery_config in EVAL_GALLERY_CONFIGS:
        results = store[f'segmented_{gallery_config}'].get(model_name)
        if not results:
            continue
        print(f"\n{gallery_config.upper()}:")

        # segment name -> {aggregation -> best Rank-1 over thresholds}
        segment_performance = {}
        for agg, seg_dict in results.items():
            if not isinstance(seg_dict, dict):
                continue
            for seg_name, seg_df in seg_dict.items():
                segment_performance.setdefault(seg_name, {})[agg] = seg_df['rank1_accuracy'].max()
        if not segment_performance:
            continue

        all_aggs = sorted({agg for perf in segment_performance.values() for agg in perf})
        print(f"  {'Segment':<25}" + "".join(f" {agg.upper():<10}" for agg in all_aggs) + " BEST")
        print(f"  {'-'*70}")
        for seg_name in sorted(segment_performance):
            perf = segment_performance[seg_name]
            # A segment missing for an aggregation is shown as '-' rather than 0.0,
            # and it does not take part in BEST.
            cells = "".join(
                f" {perf[agg]:<10.4f}" if agg in perf else f" {'-':<10}" for agg in all_aggs
            )
            print(f"  {seg_name:<25}{cells} {max(perf.values()):.4f}")

        all_values = [v for perf in segment_performance.values() for v in perf.values()]
        print("\n  Statistics:")
        print(f"    Mean:   {np.mean(all_values):.4f}")
        print(f"    Std:    {np.std(all_values):.4f}")
        print(f"    Min:    {min(all_values):.4f}")
        print(f"    Max:    {max(all_values):.4f}")
        print(f"    Range:  {max(all_values) - min(all_values):.4f}")


for model_name in models:
    _section(f"PROCESSING MODEL: {model_name}", '=', 80)

    try:
        embeddings = load_embeddings(model_name)
    except FileNotFoundError as exc:
        # Skip this model instead of aborting all remaining models and the final
        # pickle dump. Downstream code treats None as "no results".
        print(f"Skipping {model_name}: {exc}")
        for per_model in all_results.values():
            per_model[model_name] = None
        continue

    # 1. Identification (genuine probes)
    _section("1. IDENTIFICATION EVALUATIONS")

    results_oneshot_base = run_oneshot_evaluation(
        model_name, embeddings, results_dir,
        gallery_type='base', aggregations=_aggs_for('oneshot_base'))
    results_oneshot_aug = run_oneshot_evaluation(
        model_name, embeddings, results_dir,
        gallery_type='augmented', aggregations=_aggs_for('oneshot_aug'))
    all_results['oneshot_base'][model_name] = results_oneshot_base
    all_results['oneshot_aug'][model_name] = results_oneshot_aug
    if results_oneshot_base is not None and results_oneshot_aug is not None:
        _print_oneshot_comparison(model_name, results_oneshot_base, results_oneshot_aug)

    results_fewshot_base = run_fewshot_evaluation(
        model_name, embeddings, results_dir,
        gallery_type='base', aggregations=_aggs_for('fewshot_base'))
    results_fewshot_aug = run_fewshot_evaluation(
        model_name, embeddings, results_dir,
        gallery_type='augmented', aggregations=_aggs_for('fewshot_aug'))
    all_results['fewshot_base'][model_name] = results_fewshot_base
    all_results['fewshot_aug'][model_name] = results_fewshot_aug
    if results_fewshot_base is not None and results_fewshot_aug is not None:
        _print_fewshot_comparison(model_name, results_fewshot_base, results_fewshot_aug)

    # 2. Impostor rejection, separated by impostor type
    _section("2. IMPOSTOR REJECTION EVALUATIONS")
    _run_impostor_suite(model_name, embeddings, all_results, results_dir)
    _print_impostor_summary(model_name, all_results)

    # 3. Segmented probes
    _section("3. SEGMENTED PROBE EVALUATIONS")
    _run_segmented_suite(model_name, embeddings, all_results, results_dir)
    _print_segmented_summary(model_name, all_results)

# ============================================================================
# FINAL COMPREHENSIVE COMPARISON
# ============================================================================

print(f"\n\n{'#'*80}")
print("# FINAL COMPREHENSIVE COMPARISON ACROSS ALL MODELS")
print(f"{'#'*80}\n")

# Identification: rank every (model, scenario, aggregation) by its best Rank-1.
print(f"{'='*80}")
print("1. IDENTIFICATION PERFORMANCE RANKING")
print(f"{'='*80}\n")

identification_ranking = sorted(
    (
        {
            'model': model_name,
            'scenario': scenario,
            'aggregation': agg,
            'rank1': df['rank1_accuracy'].max(),
            'f1': df['f1_score'].max(),
        }
        for model_name in models
        for scenario in ('oneshot_base', 'oneshot_aug', 'fewshot_base', 'fewshot_aug')
        for agg, df in (all_results[scenario].get(model_name) or {}).items()
    ),
    key=lambda row: row['rank1'],
    reverse=True,
)

print(f"{'Rank':<6} {'Model':<20} {'Scenario':<18} {'Agg':<8} {'Rank-1':<10} {'F1':<10}")
print('-' * 80)
for position, row in enumerate(identification_ranking[:20], start=1):
    print(f"{position:<6} {row['model']:<20} {row['scenario']:<18} "
          f"{row['aggregation']:<8} {row['rank1']:<10.4f} {row['f1']:<10.4f}")

# Compare impostor rejection - separated by type
print(f"\n{'='*80}")
print("2. IMPOSTOR REJECTION PERFORMANCE RANKING")
print(f"{'='*80}\n")

for impostor_type in ('real', 'lfw', 'combined'):
    print(f"\n{impostor_type.upper()} IMPOSTORS:")
    print('-' * 80)

    # One row per (model, gallery, aggregation): the peak rejection rate over
    # thresholds, plus the FAR observed at that same threshold.
    candidates = []
    for model_name in models:
        for gallery in ('oneshot_base', 'oneshot_aug', 'fewshot_base', 'fewshot_aug'):
            per_agg = all_results[f'impostor_{gallery}_{impostor_type}'].get(model_name) or {}
            for agg, df in per_agg.items():
                peak_label = df['rejection_rate'].idxmax()
                candidates.append({
                    'model': model_name,
                    'gallery': gallery,
                    'agg': agg,
                    'rejection': df['rejection_rate'].max(),
                    'far': df.loc[peak_label, 'false_acceptance_rate'],
                    'n': df['total_impostors'].iloc[0],
                })

    impostor_ranking = sorted(candidates, key=lambda row: row['rejection'], reverse=True)

    print(f"{'Rank':<6} {'Model':<15} {'Gallery':<15} {'Agg':<8} "
          f"{'N':<6} {'Rejection':<12} {'FAR':<10}")
    print('-' * 90)
    for position, row in enumerate(impostor_ranking[:15], start=1):
        print(f"{position:<6} {row['model']:<15} {row['gallery']:<15} "
              f"{row['agg']:<8} {row['n']:<6} {row['rejection']:<12.4f} {row['far']:<10.4f}")

# Segmented performance analysis
print(f"\n{'='*80}")
print("3. SEGMENTED PERFORMANCE ANALYSIS")
print(f"{'='*80}\n")

# segment name -> best Rank-1 for every (model, gallery config, aggregation) that has it
all_segment_scores = {}
for model_name in models:
    for gallery_config in ('oneshot_base', 'oneshot_aug', 'fewshot_base', 'fewshot_aug'):
        per_agg = all_results[f'segmented_{gallery_config}'].get(model_name) or {}
        for seg_dict in per_agg.values():
            if not isinstance(seg_dict, dict):
                continue
            for seg_name, seg_df in seg_dict.items():
                all_segment_scores.setdefault(seg_name, []).append(seg_df['rank1_accuracy'].max())

# Per-segment summary, best mean first.
segment_stats = sorted(
    (
        {
            'segment': seg_name,
            'mean': np.mean(scores),
            'std': np.std(scores),
            'min': min(scores),
            'max': max(scores),
            'count': len(scores),
        }
        for seg_name, scores in all_segment_scores.items()
    ),
    key=lambda stat: stat['mean'],
    reverse=True,
)

print(f"{'Segment':<25} {'Mean':<10} {'Std':<10} {'Min':<10} {'Max':<10} {'N':<6}")
print('-' * 80)
for stat in segment_stats:
    print(f"{stat['segment']:<25} {stat['mean']:<10.4f} {stat['std']:<10.4f} "
          f"{stat['min']:<10.4f} {stat['max']:<10.4f} {stat['count']:<6}")

# Overall summary
print(f"\n{'='*80}")
print("4. OVERALL SUMMARY")
print(f"{'='*80}\n")

# Every best/worst lookup is guarded. If no model produced results (e.g. all
# embeddings missing), indexing [0]/[-1] would raise IndexError, and the
# results pickle later in this cell would never be written.
print("Best Identification Performance:")
if identification_ranking:
    best_id = identification_ranking[0]
    print(f"  Model: {best_id['model']}")
    print(f"  Config: {best_id['scenario']} ({best_id['aggregation']})")
    print(f"  Rank-1 Accuracy: {best_id['rank1']:.4f}")
else:
    print("  No identification results available.")

print("\nBest Impostor Rejection (by type):")
for impostor_type in ['real', 'lfw', 'combined']:
    impostor_ranking = []
    for model_name in models:
        for gallery in ['oneshot_base', 'oneshot_aug', 'fewshot_base', 'fewshot_aug']:
            results = all_results[f'impostor_{gallery}_{impostor_type}'].get(model_name)
            if not results:
                continue
            for agg, df in results.items():
                impostor_ranking.append({
                    'model': model_name,
                    'gallery': gallery,
                    'agg': agg,
                    'rejection': df['rejection_rate'].max(),
                    'type': impostor_type
                })

    if impostor_ranking:
        best_imp = max(impostor_ranking, key=lambda x: x['rejection'])
        print(f"  {impostor_type.upper():10s}: {best_imp['model']:<15s} "
              f"{best_imp['gallery']:<15s} ({best_imp['agg']:<5s}) - {best_imp['rejection']:.4f}")
    else:
        print(f"  {impostor_type.upper():10s}: no results")

print("\nBest/Worst Segments:")
if segment_stats:
    best_seg, worst_seg = segment_stats[0], segment_stats[-1]
    print(f"  Best: {best_seg['segment']} (mean: {best_seg['mean']:.4f})")
    print(f"  Worst: {worst_seg['segment']} (mean: {worst_seg['mean']:.4f})")
    print(f"  Performance Gap: {best_seg['mean'] - worst_seg['mean']:.4f}")
else:
    print("  No segmented results available.")

print(f"\n{'#'*80}")
print("# EVALUATION COMPLETE")
print(f"{'#'*80}\n")

# Persist the whole nested result store (scenario -> model -> results) for later cells/sessions.
# Only reload this pickle from a trusted location: unpickling can execute arbitrary code.
results_pickle_path = results_dir / 'all_results_comprehensive.pkl'
with results_pickle_path.open('wb') as fh:
    pickle.dump(all_results, fh)
print(f"All results saved to: {results_pickle_path}")

In [None]:
def _plot_top_configs(ax, top_df, column: str, xlabel: str, title: str) -> None:
    """Draw a horizontal bar ranking of ``top_df[column]``, with the first row at the top.

    Bars for models whose name contains 'adaface' are sky blue; all others are light coral.
    """
    positions = range(len(top_df))
    colors = ['skyblue' if 'adaface' in model else 'lightcoral' for model in top_df['model']]
    ax.barh(positions, top_df[column], color=colors)
    ax.set_yticks(positions)
    ax.set_yticklabels(
        [f"{model}\n{scenario}\n({agg})"
         for model, scenario, agg in zip(top_df['model'], top_df['scenario'], top_df['aggregation'])],
        fontsize=8,
    )
    ax.set_xlabel(xlabel)
    ax.set_title(title)
    ax.invert_yaxis()
    ax.grid(axis='x', alpha=0.3)


def create_comparison_plot(comparison_df, results_dir, top_n: int = 10):
    """Render and save a 2x2 figure comparing model configurations.

    The panels show:
      * top-N configurations by Rank-1 accuracy;
      * top-N by F1;
      * top-N by lowest EER;
      * mean/max Rank-1 per scenario.

    Args:
        comparison_df: DataFrame with one row per configuration and the columns
            ``model``, ``scenario``, ``aggregation``, ``rank1_accuracy``,
            ``f1_score`` and ``eer``.
        results_dir: Directory (Path or str) where ``model_comparison_plot.png``
            is written.
        top_n: Number of configurations shown in each ranking panel.
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # Each ranking panel sorts explicitly instead of assuming comparison_df arrives
    # pre-sorted. The stable mergesort keeps the caller's order on ties.
    ranking_panels = [
        (axes[0, 0], 'rank1_accuracy', False, 'Rank-1 Accuracy',
         f'Top {top_n} Models by Rank-1 Accuracy'),
        (axes[0, 1], 'f1_score', False, 'F1-Score',
         f'Top {top_n} Models by F1-Score'),
        (axes[1, 0], 'eer', True, 'Equal Error Rate (EER)',
         f'Top {top_n} Models by EER (Lower is Better)'),
    ]
    for ax, column, ascending, xlabel, title in ranking_panels:
        top_df = comparison_df.sort_values(column, ascending=ascending, kind='mergesort').head(top_n)
        _plot_top_configs(ax, top_df, column, xlabel, title)

    ax = axes[1, 1]
    scenario_stats = comparison_df.groupby('scenario')['rank1_accuracy'].agg(['mean', 'max'])
    x = np.arange(len(scenario_stats))
    ax.bar(x - 0.2, scenario_stats['mean'], width=0.4, label='Mean', alpha=0.7)
    ax.bar(x + 0.2, scenario_stats['max'], width=0.4, label='Max', alpha=0.7)
    ax.set_xticks(x)
    ax.set_xticklabels(scenario_stats.index, rotation=45, ha='right')
    ax.set_ylabel('Rank-1 Accuracy')
    ax.set_title('Performance by Scenario')
    ax.legend()
    ax.grid(axis='y', alpha=0.3)

    fig.tight_layout()
    plot_path = Path(results_dir) / 'model_comparison_plot.png'
    fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    print(f"Comparison plot saved to: {plot_path}")
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

In [None]:
def _iter_result_frames(scenario, results):
    """Yield ``(aggregation, sweep_df)`` pairs for one model in one scenario.

    ``results`` is either a single DataFrame (treated as the ``'max'``
    aggregation) or a dict keyed by aggregation name. The one-shot base
    scenario only reports ``'max'``; every other scenario reports whichever
    of ``'max'``, ``'mean'`` and ``'topk'`` are present, in that order.
    Missing or empty frames are skipped because they have no row to rank.
    """
    if isinstance(results, dict):
        wanted = ['max'] if scenario == 'oneshot_base' else ['max', 'mean', 'topk']
        candidates = [(agg, results[agg]) for agg in wanted if agg in results]
    else:
        candidates = [('max', results)]

    for agg, sweep_df in candidates:
        if sweep_df is None or sweep_df.empty:
            continue
        yield agg, sweep_df


def _summarize_config(model_name, scenario, agg, sweep_df):
    """Collapse one threshold sweep into a single comparison row (a dict).

    The "best" operating point is the row with the highest rank-1 accuracy.
    The EER point is the row where |FAR - FRR| is smallest, and the EER is
    approximated as the mean of FAR and FRR at that row.
    """
    # idxmax/idxmin return index *labels*; resetting to a 0..n-1 index makes
    # the label lookup unambiguous even if the caller's frame was filtered,
    # reordered or has duplicate labels.
    sweep = sweep_df.reset_index(drop=True)

    best_row = sweep.loc[sweep['rank1_accuracy'].idxmax()]

    far_frr_gap = (sweep['false_acceptance_rate'] - sweep['false_rejection_rate']).abs()
    eer_row = sweep.loc[far_frr_gap.idxmin()]

    return {
        'model': model_name,
        'scenario': scenario,
        'aggregation': agg,
        'best_threshold': best_row['threshold'],
        'rank1_accuracy': best_row['rank1_accuracy'],
        'precision': best_row['precision'],
        'recall': best_row['recall'],
        'f1_score': best_row['f1_score'],
        'far': best_row['false_acceptance_rate'],
        'frr': best_row['false_rejection_rate'],
        'eer_threshold': eer_row['threshold'],
        'eer': (eer_row['false_acceptance_rate'] +
                eer_row['false_rejection_rate']) / 2,
        'eer_accuracy': eer_row['rank1_accuracy'],
    }


def find_best_model(all_results, output_dir=None, make_plot=True):
    """Rank every (model, scenario, aggregation) configuration and report the best.

    Args:
        all_results: dict mapping scenario name (``'oneshot_base'``,
            ``'oneshot_aug'``, ``'fewshot_base'``, ``'fewshot_aug'``) to a dict
            of ``model_name -> results``. ``results`` is either a threshold-sweep
            DataFrame or a dict of such DataFrames keyed by aggregation
            (``'max'``, ``'mean'``, ``'topk'``). Missing scenarios and ``None``
            results are skipped.
        output_dir: directory for ``model_comparison.csv`` (and the plot).
            Defaults to the notebook-level ``results_dir``.
        make_plot: when True, also call ``create_comparison_plot``.

    Returns:
        DataFrame with one row per configuration, sorted by rank-1 accuracy
        (descending). It is empty when there is nothing to rank; in that case
        no CSV or plot is written.
    """
    print(f"\n{'#'*70}")
    print("# FINAL MODEL RANKING")
    print(f"{'#'*70}")

    if output_dir is None:
        output_dir = results_dir  # notebook-level output directory from the setup cell
    output_dir = Path(output_dir)

    configs = []
    for scenario in ('oneshot_base', 'oneshot_aug', 'fewshot_base', 'fewshot_aug'):
        # .get: a scenario that was never evaluated should not abort the ranking
        for model_name, results in (all_results.get(scenario) or {}).items():
            if results is None:
                continue
            for agg, sweep_df in _iter_result_frames(scenario, results):
                configs.append(_summarize_config(model_name, scenario, agg, sweep_df))

    if not configs:
        print("\nNo evaluation results available; nothing to rank.")
        return pd.DataFrame(configs)

    # Stable sorts keep ties in scenario/aggregation order, so reruns are deterministic.
    comparison_df = pd.DataFrame(configs).sort_values(
        'rank1_accuracy', ascending=False, kind='stable')

    print("\n" + "="*70)
    print("RANKING BY RANK-1 ACCURACY")
    print("="*70)
    print(comparison_df[['model', 'scenario', 'aggregation', 'rank1_accuracy',
                         'best_threshold', 'f1_score']].head(10).to_string(index=False))

    print("\n" + "="*70)
    print("RANKING BY F1-SCORE")
    print("="*70)
    comparison_df_f1 = comparison_df.sort_values('f1_score', ascending=False, kind='stable')
    print(comparison_df_f1[['model', 'scenario', 'aggregation', 'f1_score',
                            'best_threshold', 'rank1_accuracy']].head(10).to_string(index=False))

    print("\n" + "="*70)
    print("RANKING BY EER (Lower is better)")
    print("="*70)
    comparison_df_eer = comparison_df.sort_values('eer', ascending=True, kind='stable')
    print(comparison_df_eer[['model', 'scenario', 'aggregation', 'eer',
                             'eer_threshold', 'eer_accuracy']].head(10).to_string(index=False))

    print(f"\n{'#'*70}")
    print("# OVERALL BEST MODEL")
    print(f"{'#'*70}")

    # Positional: the first row after sorting by rank-1 accuracy.
    best_config = comparison_df.iloc[0]
    print(f"\nModel: {best_config['model']}")
    print(f"Scenario: {best_config['scenario']}")
    print(f"Aggregation: {best_config['aggregation']}")
    print(f"\nBest Performance (Threshold = {best_config['best_threshold']:.2f}):")
    print(f"  Rank-1 Accuracy: {best_config['rank1_accuracy']:.4f}")
    print(f"  Precision:       {best_config['precision']:.4f}")
    print(f"  Recall:          {best_config['recall']:.4f}")
    print(f"  F1-Score:        {best_config['f1_score']:.4f}")
    print(f"  FAR:             {best_config['far']:.4f}")
    print(f"  FRR:             {best_config['frr']:.4f}")
    print(f"\nAt EER (Threshold = {best_config['eer_threshold']:.2f}):")
    print(f"  EER:             {best_config['eer']:.4f}")
    print(f"  Rank-1 Accuracy: {best_config['eer_accuracy']:.4f}")

    output_dir.mkdir(parents=True, exist_ok=True)
    comparison_path = output_dir / 'model_comparison.csv'
    comparison_df.to_csv(comparison_path, index=False)
    print(f"\nFull comparison saved to: {comparison_path}")

    if make_plot:
        create_comparison_plot(comparison_df, output_dir)

    return comparison_df

In [None]:
# Rank every evaluated (model, scenario, aggregation) configuration; the returned table is kept for later cells.
final_comparison = find_best_model(all_results=all_results)