In [None]:
import os
import warnings
import logging

# ============================================================================
# SUPPRESS ALL TENSORFLOW WARNINGS AND CUDA MESSAGES
# ============================================================================

# 1. Silence TensorFlow's native (C++) log output. This is set before the
#    `import tensorflow` further down in this cell, presumably because the
#    variable is read when TensorFlow loads.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # 0=show all, 1=hide INFO, 2=also hide WARNING, 3=also hide ERROR

# 2. Turn off oneDNN custom ops. This is a CPU-kernel setting (not CUDA/cuDNN);
#    NOTE(review): it presumably also removes the oneDNN start-up notice, but it
#    can change CPU performance/numerics - confirm this is intended.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

# 3. Ask TensorFlow not to register the XLA devices.
#    NOTE(review): assumed to reduce XLA-related start-up messages - confirm.
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices=false'

# 4. Suppress every Python warning for the whole kernel session. This also
#    hides deprecation/runtime warnings raised by the notebook's own code.
warnings.filterwarnings('ignore')

# 5. Raise the Python-side log threshold of the TensorFlow and absl loggers
#    so that only ERROR (and above) records are emitted.
logging.getLogger('tensorflow').setLevel(logging.ERROR)
logging.getLogger('absl').setLevel(logging.ERROR)

# 6. Import TensorFlow only now (after the environment variables above are
#    in place) and silence the TF1-compat logger as well.
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

print("✅ All warnings suppressed - training output will be clean!\n")


In [None]:
# Cell 1: Import all required libraries
import os
import random
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.metrics import (
    roc_curve, auc, precision_recall_curve, average_precision_score,
    confusion_matrix, classification_report, f1_score,
    accuracy_score, precision_score, recall_score,
    roc_auc_score
)
from sklearn.preprocessing import label_binarize
from scipy import stats
import itertools
from sklearn.calibration import calibration_curve
from tqdm import tqdm
import warnings
import math
from scipy.spatial.distance import pdist, squareform
from itertools import combinations
from sklearn.metrics.pairwise import cosine_similarity as sklearn_cosine_similarity
import pickle
import json

# TensorFlow imports for model
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Input, Dense, BatchNormalization, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
# Suppress warnings
warnings.filterwarnings('ignore')

# Plot appearance for every figure created later in the notebook:
# seaborn dark-grid theme, "husl" colour cycle, larger default size and font.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 12,
})

print("✅ All libraries imported successfully")

In [None]:
# Cell 2: Global configuration
# Everything a reader might want to tune lives here.

# --- Reproducibility -------------------------------------------------------
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# --- Input pipeline --------------------------------------------------------
IMG_SIZE = 112                   # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 64                  # larger batches give smoother gradient estimates
AUTOTUNE = tf.data.AUTOTUNE      # let tf.data choose the parallelism
USE_AUGMENTATION = True          # random augmentation on the training split

# --- Data location ---------------------------------------------------------
DATASET_DIR = "/kaggle/input/datasetsforrestnet/ThirdLap"

print("✅ Configuration loaded")
print(f"Image Size: {IMG_SIZE}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Dataset Directory: {DATASET_DIR}")

In [None]:
def parse_image_with_augmentation(file_path, label, training=True):
    """Load one image file and turn it into a model-ready tensor.

    Args:
        file_path: Scalar string tensor with the path of a JPEG/PNG image.
        label: Class index for the image; returned unchanged.
        training: When true (and the global ``USE_AUGMENTATION`` flag is set)
            random augmentation is applied.

    Returns:
        ``(image, label)`` where ``image`` is a float32 tensor of shape
        ``(IMG_SIZE, IMG_SIZE, 3)`` with values in ``[0, 1]``.
    """
    raw_bytes = tf.io.read_file(file_path)
    # decode_image detects the container format itself, so the .png files the
    # dataset builder collects are handled as well as .jpg/.jpeg.
    # expand_animations=False keeps the result rank-3 (required by resize).
    image = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))

    # Scale to [0, 1] BEFORE augmenting: the brightness delta (0.2) and the
    # HSV-based hue/saturation ops are defined for floats in [0, 1]. On the
    # raw 0-255 range the brightness shift is negligible and the hue and
    # saturation adjustments are not well defined.
    image = tf.cast(image, tf.float32) / 255.0

    if training and USE_AUGMENTATION:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, 0.2)
        image = tf.image.random_contrast(image, 0.8, 1.2)
        # Brightness/contrast can leave [0, 1]; clip so the HSV conversion
        # used by the next two ops receives valid input.
        image = tf.clip_by_value(image, 0.0, 1.0)
        image = tf.image.random_saturation(image, 0.8, 1.2)
        image = tf.image.random_hue(image, 0.1)

        # Random square crop covering 90-100% of the side, resized back to
        # the target size (simulates a slight zoom / translation).
        crop_fraction = tf.random.uniform([], 0.9, 1.0)
        crop_px = tf.cast(IMG_SIZE * crop_fraction, tf.int32)
        image = tf.image.random_crop(image, [crop_px, crop_px, 3])
        image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))

    image = tf.clip_by_value(image, 0.0, 1.0)
    # Restore the static shape that the dynamic-size crop discards.
    image = tf.ensure_shape(image, (IMG_SIZE, IMG_SIZE, 3))

    return image, label


def build_train_val_datasets_with_augmentation(root_dir, val_split=0.2):
    """Build per-identity train/validation ``tf.data`` pipelines.

    Every sub-directory of ``root_dir`` is one identity (class). The images of
    each identity are split, so the same label index appears in both the
    training and the validation set.

    Args:
        root_dir: Directory containing one sub-directory per identity.
        val_split: Fraction of each identity's images reserved for validation
            (at least one image whenever the identity has two or more).

    Returns:
        ``(train_ds, val_ds, num_classes, class_names)``; both datasets yield
        batches of ``(image, label)``.

    Raises:
        ValueError: If either split ends up without any image.
    """
    class_names = sorted(d for d in os.listdir(root_dir)
                         if os.path.isdir(os.path.join(root_dir, d)))

    num_classes = len(class_names)
    print(f"Total identities: {num_classes}")

    train_paths, train_labels = [], []
    val_paths, val_labels = [], []

    for idx, cls in enumerate(class_names):
        cls_dir = os.path.join(root_dir, cls)
        # os.listdir order is arbitrary; sort first so the seeded shuffle
        # below produces the same split on every run.
        images = sorted(os.path.join(cls_dir, img)
                        for img in os.listdir(cls_dir)
                        if img.lower().endswith((".jpg", ".jpeg", ".png")))

        random.shuffle(images)

        if len(images) < 2:
            # A lone image cannot be split: keep it for training so the
            # validation set never contains an identity the model never saw.
            n_val = 0
        else:
            # At least one validation image, but always leave one to train on.
            n_val = min(len(images) - 1, max(1, int(len(images) * val_split)))

        val_images = images[:n_val]
        train_images = images[n_val:]

        # Same label index for both splits of an identity.
        val_paths.extend(val_images)
        val_labels.extend([idx] * len(val_images))

        train_paths.extend(train_images)
        train_labels.extend([idx] * len(train_images))

    if not train_paths or not val_paths:
        raise ValueError(
            f"Could not build both splits from {root_dir!r}: "
            f"{len(train_paths)} training and {len(val_paths)} validation images found"
        )

    print(f"\nDataset Statistics:")
    print(f"  Total classes: {num_classes}")
    print(f"  Training images: {len(train_paths)}")
    print(f"  Validation images: {len(val_paths)}")
    print(f"  Train/Val split: {100*(1-val_split):.0f}% / {100*val_split:.0f}%\n")

    train_ds = tf.data.Dataset.from_tensor_slices((train_paths, train_labels))
    # The path list is ordered by class, so the shuffle buffer must span the
    # whole list to mix identities properly. Only paths are buffered here
    # (shuffle runs before decoding), so this is cheap.
    train_ds = train_ds.shuffle(len(train_paths), seed=SEED)
    train_ds = train_ds.map(
        lambda x, y: parse_image_with_augmentation(x, y, training=True),
        num_parallel_calls=AUTOTUNE
    )
    train_ds = train_ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)

    # Validation: fixed order, no augmentation.
    val_ds = tf.data.Dataset.from_tensor_slices((val_paths, val_labels))
    val_ds = val_ds.map(
        lambda x, y: parse_image_with_augmentation(x, y, training=False),
        num_parallel_calls=AUTOTUNE
    )
    val_ds = val_ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)

    print("✅ Datasets created successfully\n")
    return train_ds, val_ds, num_classes, class_names

print("✅ Data loading functions defined")


In [None]:
print("Loading datasets...")
# 20% of the images of every identity are held out for validation.
(train_ds, val_ds,
 num_classes, all_class_names) = build_train_val_datasets_with_augmentation(
    DATASET_DIR, val_split=0.2
)

print(f"\nFinal numbers for model:")
print(f"Number of classes: {num_classes}")

# Sanity check: pull a single batch and show the tensor shapes.
sample_batch = next(iter(train_ds.take(1)))
print(f"Training batch shape: {sample_batch[0].shape}")
print(f"Labels shape: {sample_batch[1].shape}")

In [None]:
# Cell 5: Define model architecture for face verification
def l2_norm(x):
    """Return ``x`` with every row scaled to unit L2 length (axis 1)."""
    unit_rows = tf.math.l2_normalize(x, axis=1)
    return unit_rows

def build_resnet_embedding():
    """Build the ResNet50 backbone that maps a face image to an embedding.

    The input size follows the global ``IMG_SIZE`` so the backbone always
    matches what the data pipeline and the training model use.

    Returns:
        A Keras ``Model`` taking ``(IMG_SIZE, IMG_SIZE, 3)`` images and
        producing L2-normalised 512-dimensional embeddings.
    """
    inputs = Input(shape=(IMG_SIZE, IMG_SIZE, 3), name="input_image")

    # ImageNet-pretrained trunk with global average pooling instead of the
    # classification head.
    base = ResNet50(
        include_top=False,
        weights="imagenet",
        input_tensor=inputs,
        pooling="avg"
    )

    # BN -> 512-d projection (no bias, BN follows) -> BN -> unit-length output.
    x = BatchNormalization()(base.output)
    x = Dense(512, use_bias=False)(x)
    x = BatchNormalization()(x)
    embeddings = Lambda(l2_norm, name="embeddings")(x)

    return Model(inputs, embeddings, name="ResNet50_Embedding")

class ArcFace(tf.keras.layers.Layer):
    """ArcFace (additive angular margin) classification head, training only.

    Produces scaled cosine logits between the input embeddings and one
    learnable centre per class. When ``labels`` are supplied, the logit of the
    true class is computed from ``cos(theta + margin)`` instead of
    ``cos(theta)``.

    Args:
        num_classes: Number of identities (columns of the weight matrix).
        margin: Angular margin in radians added to the true-class angle.
        scale: Factor the cosine logits are multiplied by.
    """
    def __init__(self, num_classes, margin=0.5, scale=64, **kwargs):
        super(ArcFace, self).__init__(**kwargs)
        self.num_classes = num_classes
        self.margin = margin
        self.scale = scale
        # Plain Python floats: they are constants, not tensors to be tracked.
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        # Beyond this cosine, theta + margin would exceed pi and
        # cos(theta + margin) stops being monotonic.
        self.threshold = math.cos(math.pi - margin)
        # Linear penalty used in that region: cos(theta) - sin(pi - m) * m.
        # It must be positive so the true-class logit is lowered, never raised.
        # (Attribute name kept for compatibility.)
        self.cos_mt = math.sin(math.pi - margin) * margin

    def build(self, input_shape):
        # One centre per class, stored column-wise.
        self.W = self.add_weight(
            name="W",
            shape=(input_shape[-1], self.num_classes),
            initializer="glorot_uniform",
            trainable=True
        )

    def call(self, inputs, labels=None):
        """Return scaled logits of shape ``(batch, num_classes)``.

        Args:
            inputs: Embedding batch ``(batch, dim)``.
            labels: Optional integer class indices; when given, the margin is
                applied to the true-class logit.
        """
        # Unit-length centres and embeddings -> their product is cos(theta).
        W_norm = tf.nn.l2_normalize(self.W, axis=0)
        x_norm = tf.nn.l2_normalize(inputs, axis=1)
        cosine = tf.matmul(x_norm, W_norm)

        if labels is None:
            return cosine * self.scale

        labels = tf.reshape(tf.cast(labels, tf.int32), [-1])
        one_hot_labels = tf.one_hot(labels, depth=self.num_classes, dtype=cosine.dtype)

        # Cosine of the true class, kept strictly inside (-1, 1): sqrt has an
        # infinite gradient at 0 and rounding can push 1 - cos^2 below zero,
        # both of which would produce NaNs.
        cos_y = tf.reduce_sum(one_hot_labels * cosine, axis=1, keepdims=True)
        cos_y = tf.clip_by_value(cos_y, -1.0 + 1e-7, 1.0 - 1e-7)
        sin_y = tf.math.sqrt(1.0 - tf.math.square(cos_y))

        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
        cos_theta_m = cos_y * self.cos_m - sin_y * self.sin_m

        # Fall back to the linear penalty where theta + m would pass pi.
        cos_theta_m = tf.where(cos_y > self.threshold, cos_theta_m, cos_y - self.cos_mt)

        # Swap in the margin-adjusted value for the true class only, then scale.
        output = cosine + one_hot_labels * (cos_theta_m - cos_y)
        return output * self.scale

    def get_config(self):
        """Return the constructor arguments so the layer can be serialised."""
        config = super(ArcFace, self).get_config()
        config.update({
            "num_classes": self.num_classes,
            "margin": self.margin,
            "scale": self.scale,
        })
        return config

print("✅ Model architecture defined (face verification version)")

In [None]:
# Cell 6: Build and compile model for face verification

print("Building face verification model...")

# Embedding backbone (shared by the training and the inference model).
backbone = build_resnet_embedding()
print("Backbone summary:")
backbone.summary()

# ArcFace head: only needed while training, it consumes the labels as input.
arcface = ArcFace(num_classes, margin=0.5, scale=64)

# Training model: (image, label) -> margin-adjusted logits.
inputs = Input(shape=(IMG_SIZE, IMG_SIZE, 3), name="input_image")
labels = Input(shape=(), name="label", dtype=tf.int32)

embeddings = backbone(inputs)
logits = arcface(embeddings, labels, training=True)

train_model = Model([inputs, labels], logits, name="ArcFace_Trainer")

# ============================================================================
# Learning-rate policy: use EITHER a LearningRateSchedule OR the
# ReduceLROnPlateau callback, never both.
# ============================================================================

# OPTION 1 (active): fixed learning rate, adjusted by ReduceLROnPlateau.
train_model.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss=SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

print("✅ Training model compiled with fixed LR (will be adjusted by ReduceLROnPlateau)")

# OPTION 2 (not active): cosine decay WITHOUT ReduceLROnPlateau.
# To switch, replace the compile call above with the lines below, remove
# ReduceLROnPlateau from the callbacks list, and define EPOCHS before this
# cell. The step count comes from the batched dataset's cardinality.
#
#   steps_per_epoch = int(tf.data.experimental.cardinality(train_ds).numpy())
#   lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
#       initial_learning_rate=1e-4,
#       decay_steps=EPOCHS * steps_per_epoch,
#       alpha=1e-6
#   )
#   train_model.compile(
#       optimizer=Adam(learning_rate=lr_schedule),
#       loss=SparseCategoricalCrossentropy(from_logits=True),
#       metrics=['accuracy']
#   )

# A functional model is built on construction, so no build() call is needed.
print(f"Total parameters: {train_model.count_params():,}")

# Inference model: image -> embedding (the backbone without the ArcFace head).
inference_model = Model(inputs=backbone.input, outputs=backbone.output,
                        name="Face_Embedding_Model")
print("\n✅ Inference model for embedding extraction created")

In [None]:
# Cell 7: Prepare datasets for model training
print("Preparing datasets for training...")

# The training model expects [images, labels] as input
def prepare_train_dataset(dataset):
    """Reshape ``(images, labels)`` batches into ``((images, labels), labels)``.

    The ArcFace training model takes the labels as a second input, so they
    are passed both as a model input and as the target.
    """
    return dataset.map(
        lambda batch_images, batch_labels: ((batch_images, batch_labels), batch_labels),
        num_parallel_calls=AUTOTUNE,
    )

def prepare_val_dataset(dataset):
    """Give validation batches the ``((images, labels), labels)`` layout.

    Mirrors the training layout: the labels feed the model's label input and
    also serve as the target.
    """
    def attach_label_input(batch_images, batch_labels):
        model_inputs = (batch_images, batch_labels)
        return model_inputs, batch_labels

    return dataset.map(attach_label_input, num_parallel_calls=AUTOTUNE)

# Wrap both splits in the two-input layout the ArcFace trainer expects.
train_ds_for_model = prepare_train_dataset(train_ds)
val_ds_for_model = prepare_val_dataset(val_ds)

# Report the batch counts from the dataset cardinality (no iteration needed).
train_card = train_ds_for_model.cardinality().numpy()
val_card = val_ds_for_model.cardinality().numpy()
print(f"Training batches  : {train_card}")
print(f"Validation batches: {val_card}")
print("✅ Datasets prepared for ArcFace training")

In [None]:
# Cell 8: Define training callbacks with embedding saver
print("Setting up callbacks...")

def compute_similarity_scores_for_auc(embeddings, labels, max_samples=1000, max_pairs=10000):
    """Sample random embedding pairs and score them for a verification ROC-AUC.

    Only the sampled pairs are scored, so no N x N similarity matrix is built.

    Args:
        embeddings: Array ``(n, dim)`` of embeddings (need not be normalised).
        labels: Array ``(n,)`` of identity labels aligned with ``embeddings``.
        max_samples: Embeddings are randomly subsampled to this many rows.
        max_pairs: Upper bound on the number of pairs drawn.

    Returns:
        ``(scores, pair_labels)``: cosine similarity of each sampled pair and
        1/0 flags telling whether the pair shares a label. Both are empty when
        fewer than two embeddings are available.

    Raises:
        ValueError: If ``embeddings`` and ``labels`` differ in length.
    """
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)
    if len(embeddings) != len(labels):
        raise ValueError(
            f"embeddings ({len(embeddings)}) and labels ({len(labels)}) differ in length"
        )

    n = len(embeddings)
    if n > max_samples:
        keep = np.random.choice(n, max_samples, replace=False)
        embeddings, labels = embeddings[keep], labels[keep]
        n = max_samples

    num_pairs = min(max_pairs, n * (n - 1) // 2)
    if num_pairs <= 0:
        return np.array([]), np.array([])

    # Unit-length rows: a dot product is then the cosine similarity.
    # All-zero rows keep a divisor of 1 so they score 0 instead of NaN.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.where(norms == 0, 1.0, norms)

    # Draw i uniformly, then j as i plus a non-zero offset (mod n). That is
    # uniform over ordered pairs with i != j, without a rejection loop.
    first = np.random.randint(0, n, size=num_pairs)
    second = (first + np.random.randint(1, n, size=num_pairs)) % n

    scores = np.einsum('ij,ij->i', unit[first], unit[second])
    pair_labels = (labels[first] == labels[second]).astype(int)

    return scores, pair_labels


class EmbeddingSaverCallback(tf.keras.callbacks.Callback):
    """Track the validation verification ROC-AUC after every epoch.

    Despite the name, no embeddings are stored: each epoch the validation
    embeddings are recomputed, reduced to one ROC-AUC value and discarded.
    The ROC-AUC history and a small JSON summary are written to ``save_dir``.

    Args:
        backbone: Embedding model called on the validation images.
        val_ds: Dataset yielding ``(images, labels)`` batches. It must yield
            the same order on every pass, because the labels are read once.
        save_dir: Output directory; created if missing.
    """
    def __init__(self, backbone, val_ds, save_dir='embeddings_by_epoch'):
        super().__init__()
        self.backbone = backbone
        self.val_ds = val_ds
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)
        self.epoch_roc_aucs = []  # one float per finished epoch

        # Labels do not change between epochs, so read them a single time.
        print("Extracting validation labels...")
        self.val_labels = np.concatenate(
            [labels.numpy() for _, labels in self.val_ds]
        )

    def on_epoch_end(self, epoch, logs=None):
        """Compute, log and persist the validation ROC-AUC for ``epoch``."""
        # Keras may pass logs=None; use a local dict so the reads below are safe.
        logs = logs if logs is not None else {}
        print(f"\nEpoch {epoch+1}: Extracting embeddings for ROC-AUC...")

        embeddings = np.vstack([
            tf.math.l2_normalize(self.backbone(images, training=False), axis=1).numpy()
            for images, _ in self.val_ds
        ])

        # The helper subsamples internally, so this stays cheap.
        scores, pair_labels = compute_similarity_scores_for_auc(embeddings, self.val_labels)
        del embeddings  # release the large array before the file I/O below

        if len(np.unique(pair_labels)) > 1:
            roc_auc = float(roc_auc_score(pair_labels, scores))
            print(f"Epoch {epoch+1}: ROC-AUC = {roc_auc:.4f}")
        else:
            # ROC-AUC is undefined with a single class; record chance level.
            roc_auc = 0.5
            print(f"Epoch {epoch+1}: ROC-AUC undefined (sampled pairs contain one class); recording 0.5")

        # Record the value on every epoch so the history and `logs` stay aligned.
        self.epoch_roc_aucs.append(roc_auc)
        logs['val_roc_auc'] = roc_auc

        np.save(os.path.join(self.save_dir, 'roc_auc_history.npy'),
                np.array(self.epoch_roc_aucs))

        # Small per-epoch JSON summary (scalars only).
        summary = {
            'epoch': epoch + 1,
            'val_loss': float(logs.get('val_loss', 0)),
            'val_accuracy': float(logs.get('val_accuracy', 0)),
            'val_roc_auc': float(logs.get('val_roc_auc', 0.5)),
            'loss': float(logs.get('loss', 0)),
            'accuracy': float(logs.get('accuracy', 0))
        }
        with open(os.path.join(self.save_dir, f'epoch_{epoch+1:03d}_summary.json'), 'w') as f:
            json.dump(summary, f, indent=4)


# Create the ROC-AUC tracking callback. Its constructor iterates val_ds once
# to collect the validation labels.
embedding_saver = EmbeddingSaverCallback(backbone, val_ds, save_dir='embeddings_by_epoch')

# Define all callbacks. All three built-in callbacks monitor `val_loss`.
callbacks = [
    # Halve the learning rate after 3 epochs without improvement, never
    # going below 1e-6.
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=3,
        min_lr=1e-6,
        verbose=1
    ),
    # Stop after 5 epochs without improvement and restore the best weights.
    EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
        verbose=1
    ),
    # Keep only the best model seen so far.
    # NOTE(review): this saves the full model to HDF5. The model contains a
    # custom ArcFace layer and a Lambda layer, so saving/reloading may need
    # the layer's get_config and custom_objects - confirm, or consider
    # saving weights only.
    ModelCheckpoint(
        'best_arcface_model.h5',
        monitor='val_loss',
        save_best_only=True,
        verbose=1
    ),
    # Listed last: presumably Keras runs callbacks in list order, so the
    # `val_roc_auc` it adds to `logs` is not visible to the callbacks above
    # within the same epoch - TODO confirm if it should be monitored.
    embedding_saver
]

print("✅ Callbacks defined")

In [None]:
# Cell 9: Train the model (validation ROC-AUC is tracked after every epoch)
print("Starting training...")
# The callback stores only the ROC-AUC history and a JSON summary per epoch;
# the messages below say so instead of claiming that embeddings are saved.
print("Note: validation ROC-AUC and a JSON summary are saved after each epoch (no embeddings are stored)")
print("This may increase training time slightly\n")

EPOCHS = 20

# The trainer takes ((images, labels), labels) batches, hence the *_for_model
# datasets rather than the raw train_ds / val_ds.
history = train_model.fit(
    train_ds_for_model,
    validation_data=val_ds_for_model,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1
)

print("✅ Training completed")
print(f"✅ Per-epoch ROC-AUC history and summaries saved in '{embedding_saver.save_dir}' folder")

In [None]:
# Cell 10: Save training history
print("Saving training history...")

# JSON needs plain Python floats, so every list-valued metric is converted
# element by element; non-list entries are left out.
history_dict = {
    metric: [float(point) for point in series]
    for metric, series in history.history.items()
    if isinstance(series, list)
}

# Human-readable copy.
with open('training_history.json', 'w') as json_file:
    json.dump(history_dict, json_file, indent=4)

# Exact copy of the raw Keras history dict.
with open('training_history.pkl', 'wb') as pickle_file:
    pickle.dump(history.history, pickle_file)

print("✅ Training history saved")

In [None]:
SAVE_DIR = "/kaggle/working/arcface_model"
os.makedirs(SAVE_DIR, exist_ok=True)

# Complete training model (backbone + ArcFace head).
train_model.save(os.path.join(SAVE_DIR, "arcface_resnet50_final.keras"))

# Backbone only: this is what inference needs (image -> embedding).
embedding_model = tf.keras.Model(
    inputs=backbone.input,
    outputs=backbone.output,
    name="embedding_model"
)
embedding_model.save(os.path.join(SAVE_DIR, "embedding_model.keras"))

# Class names, in label-index order.
with open(os.path.join(SAVE_DIR, "class_names.pkl"), "wb") as f:
    pickle.dump(all_class_names, f)

print("✅ Model saved successfully")
print("Saved files:")
print(f"1. {SAVE_DIR}/arcface_resnet50_final.keras - Complete model")
print(f"2. {SAVE_DIR}/embedding_model.keras - Embedding model for inference")
print(f"3. {SAVE_DIR}/class_names.pkl - Class names")
# The history cell writes training_history.json to the working directory,
# not to SAVE_DIR, so report the location it actually has.
print(f"4. {os.path.abspath('training_history.json')} - Training metrics")

In [None]:
# Cell 12: Define evaluation functions for face verification
def extract_embeddings(model, dataset):
    """Run ``model`` over ``dataset`` and collect embeddings and labels.

    The model is called directly on every batch instead of going through
    ``model.predict``, which sets up its own prediction loop on each call.

    Args:
        model: Callable accepting ``(images, training=False)``.
        dataset: Iterable of ``(images, labels)`` batches.

    Returns:
        ``(embeddings, labels)`` as NumPy arrays, concatenated over batches.

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    embeddings_list = []
    labels_list = []

    print("Extracting embeddings...")
    for images, labels in tqdm(dataset, desc="Processing batches"):
        batch_embeddings = model(images, training=False)
        embeddings_list.append(np.asarray(batch_embeddings))
        labels_list.append(np.asarray(labels))

    if not embeddings_list:
        raise ValueError("dataset yielded no batches; nothing to embed")

    return np.vstack(embeddings_list), np.concatenate(labels_list)

def evaluate_face_verification(gallery_embeddings, gallery_labels,
                              probe_embeddings, probe_labels, threshold=0.5):
    """Evaluate nearest-neighbour face verification with cosine similarity.

    Each probe is matched to its most similar gallery embedding. The
    prediction is "same identity" when that similarity exceeds ``threshold``;
    the ground truth is whether the matched gallery label equals the probe
    label.

    Args:
        gallery_embeddings: Embeddings of the known identities ``(G, dim)``.
        gallery_labels: Labels for the gallery set ``(G,)``.
        probe_embeddings: Embeddings of the query images ``(P, dim)``.
        probe_labels: Labels for the probe set ``(P,)``.
        threshold: Similarity threshold for verification.

    Returns:
        ``(accuracy, precision, recall, f1, similarities, same_pairs, cm)``.
        ``similarities`` and ``same_pairs`` are lists describing up to 10000
        randomly sampled probe/gallery pairs; ``cm`` is always 2 x 2.

    Raises:
        ValueError: If the gallery or the probe set is empty.
    """
    from sklearn.metrics.pairwise import cosine_similarity
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

    gallery_embeddings = np.asarray(gallery_embeddings)
    gallery_labels = np.asarray(gallery_labels)
    probe_embeddings = np.asarray(probe_embeddings)
    probe_labels = np.asarray(probe_labels)

    n_probe, n_gallery = len(probe_embeddings), len(gallery_embeddings)
    if n_probe == 0 or n_gallery == 0:
        raise ValueError("gallery and probe sets must both be non-empty")

    # Nearest gallery neighbour per probe, computed in chunks so the full
    # P x G similarity matrix never has to fit in memory.
    print("Calculating similarity matrix...")
    print("Matching probe to gallery...")
    chunk = 256
    best_idx = np.empty(n_probe, dtype=np.int64)
    best_sim = np.empty(n_probe, dtype=np.float64)
    for start in tqdm(range(0, n_probe, chunk), desc="Evaluating matches"):
        stop = min(start + chunk, n_probe)
        block = cosine_similarity(probe_embeddings[start:stop], gallery_embeddings)
        best_idx[start:stop] = block.argmax(axis=1)
        best_sim[start:stop] = block.max(axis=1)

    predictions = best_sim > threshold
    true_labels = probe_labels == gallery_labels[best_idx]

    accuracy = accuracy_score(true_labels, predictions)
    precision = precision_score(true_labels, predictions, zero_division=0)
    recall = recall_score(true_labels, predictions, zero_division=0)
    f1 = f1_score(true_labels, predictions, zero_division=0)
    # Fix the label set so the matrix is 2 x 2 even when only one class
    # occurs; otherwise the indexing below raises IndexError.
    cm = confusion_matrix(true_labels, predictions, labels=[False, True])

    print(f"\nFace Verification Results (threshold={threshold}):")
    print(f"  Accuracy: {accuracy:.4f}")
    print(f"  Precision: {precision:.4f}")
    print(f"  Recall: {recall:.4f}")
    print(f"  F1-Score: {f1:.4f}")
    print(f"\nConfusion Matrix:")
    print(f"  True Negatives: {cm[0,0]}")
    print(f"  False Positives: {cm[0,1]}")
    print(f"  False Negatives: {cm[1,0]}")
    print(f"  True Positives: {cm[1,1]}")

    # Random probe/gallery pairs for the similarity-distribution plots.
    print("\nCalculating similarity distributions...")
    total_pairs = n_probe * n_gallery
    sample_size = min(10000, total_pairs)
    # Generator.choice samples without replacement without permuting the whole
    # population. It is seeded from the global NumPy state so np.random.seed
    # still controls reproducibility.
    rng = np.random.default_rng(np.random.randint(0, 2**31 - 1))
    flat_idx = rng.choice(total_pairs, size=sample_size, replace=False)
    probe_idx, gallery_idx = np.divmod(flat_idx, n_gallery)

    def _unit_rows(rows):
        # Scale rows to unit length; all-zero rows are left as zeros.
        norms = np.linalg.norm(rows, axis=1, keepdims=True)
        return rows / np.where(norms == 0, 1.0, norms)

    pair_sims = np.einsum('ij,ij->i',
                          _unit_rows(probe_embeddings[probe_idx]),
                          _unit_rows(gallery_embeddings[gallery_idx]))
    similarities = pair_sims.tolist()
    same_pairs = (probe_labels[probe_idx] == gallery_labels[gallery_idx]).tolist()

    return accuracy, precision, recall, f1, similarities, same_pairs, cm

print("✅ Evaluation functions defined")

In [None]:
# Cell 13: Evaluate the trained model (memory-safe)
import gc
print("Evaluating face verification performance...")

def extract_embeddings_safe(model, dataset):
    """Embed ``dataset`` batch by batch by calling ``model`` directly.

    Returns the stacked embeddings and the concatenated labels as NumPy
    arrays, in iteration order.
    """
    per_batch = [
        (model(images, training=False).numpy(), labels.numpy())
        for images, labels in dataset
    ]
    embedding_chunks = [chunk for chunk, _ in per_batch]
    label_chunks = [chunk for _, chunk in per_batch]
    return np.vstack(embedding_chunks), np.concatenate(label_chunks)


# 1. Extract gallery embeddings (training set)
# NOTE(review): train_ds is shuffled and randomly augmented, so the gallery
# is built from augmented images. Labels stay aligned because they are read
# from the same batches, but a non-augmented pass may be preferable - confirm.
print("\n1. Extracting gallery embeddings (training set)...")
gallery_embeddings, gallery_labels = extract_embeddings_safe(inference_model, train_ds)

# 2. Extract probe / validation embeddings
print("\n2. Extracting probe embeddings (validation set)...")
probe_embeddings, probe_labels = extract_embeddings_safe(inference_model, val_ds)

print(f"\nDataset sizes:")
print(f"  Gallery: {len(gallery_embeddings)} embeddings, {len(np.unique(gallery_labels))} identities")
print(f"  Probe:   {len(probe_embeddings)} embeddings, {len(np.unique(probe_labels))} identities")

# --------------- chunked verification eval ---------------
from sklearn.metrics.pairwise import cosine_similarity as sk_cosine

print("\n3. Computing gallery<->probe similarity (chunked to save RAM)...")

CHUNK = 256
n_probe   = len(probe_embeddings)
n_gallery = len(gallery_embeddings)

# Best gallery match (index and similarity) for every probe.
max_sim_idx = np.empty(n_probe, dtype=np.int64)
max_sim_val = np.empty(n_probe, dtype=np.float32)

for start in range(0, n_probe, CHUNK):
    end   = min(start + CHUNK, n_probe)
    block = sk_cosine(probe_embeddings[start:end], gallery_embeddings)  # (chunk, gallery)
    max_sim_idx[start:end] = block.argmax(axis=1)
    max_sim_val[start:end] = block.max(axis=1)
    del block   # free chunk immediately

# Prediction: "same identity" when the best match is similar enough.
# Ground truth: whether that best match really has the probe's label.
THRESHOLD = 0.5
predictions = max_sim_val > THRESHOLD
true_labels = probe_labels == gallery_labels[max_sim_idx]

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix as sk_cm

accuracy  = accuracy_score(true_labels, predictions)
precision = precision_score(true_labels, predictions, zero_division=0)
recall    = recall_score(true_labels, predictions, zero_division=0)
f1        = f1_score(true_labels, predictions, zero_division=0)
# Fix the label set so the matrix is 2 x 2 even when only one class occurs;
# otherwise the cm[1, 1] lookup below raises IndexError.
cm        = sk_cm(true_labels, predictions, labels=[False, True])

print(f"\nFace Verification Results (threshold={THRESHOLD}):")
print(f"  Accuracy : {accuracy:.4f}")
print(f"  Precision: {precision:.4f}")
print(f"  Recall   : {recall:.4f}")
print(f"  F1-Score : {f1:.4f}")
print(f"\nConfusion Matrix:")
print(f"  TN={cm[0,0]}  FP={cm[0,1]}  FN={cm[1,0]}  TP={cm[1,1]}")

# Sample random pairs for similarity-distribution plots (vectorised).
print("\nSampling pairs for similarity distributions...")
sample_size = min(10000, n_probe * n_gallery)
pair_probe_idx   = np.random.randint(n_probe, size=sample_size)
pair_gallery_idx = np.random.randint(n_gallery, size=sample_size)

# Embeddings are L2-normalised so a row-wise dot product == cosine similarity.
similarities = np.einsum(
    'ij,ij->i',
    probe_embeddings[pair_probe_idx],
    gallery_embeddings[pair_gallery_idx],
).astype(np.float32)
same_pairs = (probe_labels[pair_probe_idx] == gallery_labels[pair_gallery_idx]).astype(np.int8)

# ---- FREE the large gallery array — no longer needed ----
del gallery_embeddings, gallery_labels, max_sim_idx, max_sim_val
gc.collect()

# Keep only what the later cells actually need
val_embeddings = probe_embeddings
val_labels     = probe_labels

print("\n✅ Face verification evaluation completed!")

In [None]:
# Cell 14: Visualization of results
print("Visualizing results...")

# Training curves: loss (left) and accuracy (right). Skipped when the
# training cell has not been run in this session.
if 'history' in locals():
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # Plot loss
    axes[0].plot(history.history['loss'], label='Training Loss')
    axes[0].plot(history.history['val_loss'], label='Validation Loss')
    axes[0].set_title('Training and Validation Loss')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].legend()
    axes[0].grid(True)

    # Plot accuracy
    axes[1].plot(history.history['accuracy'], label='Training Accuracy')
    axes[1].plot(history.history['val_accuracy'], label='Validation Accuracy')
    axes[1].set_title('Training and Validation Accuracy')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy')
    axes[1].legend()
    axes[1].grid(True)

    plt.tight_layout()
    plt.show()

# Similarity distributions of the sampled probe/gallery pairs, plus the ROC
# curve obtained by thresholding that similarity.
if 'similarities' in locals() and 'same_pairs' in locals():
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # Convert to numpy arrays (boolean-mask indexing below requires it)
    similarities = np.array(similarities)
    same_pairs = np.array(same_pairs)

    # Plot histogram of similarities
    axes[0].hist(similarities[same_pairs == 0], bins=50, alpha=0.7,
                 label='Different Identities', density=True)
    axes[0].hist(similarities[same_pairs == 1], bins=50, alpha=0.7,
                 label='Same Identity', density=True)
    axes[0].set_title('Similarity Distribution')
    axes[0].set_xlabel('Cosine Similarity')
    axes[0].set_ylabel('Density')
    axes[0].legend()
    axes[0].grid(True)

    # Plot ROC curve (needs both same- and different-identity pairs)
    if len(np.unique(same_pairs)) > 1:
        from sklearn.metrics import roc_curve, auc

        fpr, tpr, thresholds = roc_curve(same_pairs, similarities)
        roc_auc = auc(fpr, tpr)

        axes[1].plot(fpr, tpr, color='darkorange', lw=2,
                     label=f'ROC curve (AUC = {roc_auc:.3f})')
        axes[1].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        axes[1].set_xlim([0.0, 1.0])
        axes[1].set_ylim([0.0, 1.05])
        axes[1].set_xlabel('False Positive Rate')
        axes[1].set_ylabel('True Positive Rate')
        axes[1].set_title('Receiver Operating Characteristic')
        axes[1].legend(loc="lower right")
        axes[1].grid(True)

    plt.tight_layout()
    plt.show()

print("✅ Visualizations complete!")

In [None]:
# Cell 15: Define basic evaluation functions (memory-safe)
def compute_similarity_scores(embeddings, labels, max_n=2000):
    """Score every unordered pair of embeddings with cosine similarity.

    When more than ``max_n`` embeddings are given, a random subset of
    ``max_n`` is used to keep the pairwise matrix manageable.

    Returns:
        ``(scores, labels_flat)``: float32 similarities and int8 same-label
        flags, one entry per pair ``(i, j)`` with ``i < j``.
    """
    count = len(embeddings)
    if count > max_n:
        keep = np.random.choice(count, max_n, replace=False)
        embeddings, labels = embeddings[keep], labels[keep]
        count = max_n

    # Index pairs of the strict upper triangle: each pair once, no diagonal.
    rows, cols = np.triu_indices(count, k=1)

    cosine = sklearn_cosine_similarity(embeddings)
    scores = cosine[rows, cols].astype(np.float32)
    del cosine  # the full matrix is no longer needed

    # Compare labels only for the selected pairs.
    labels_flat = (labels[rows] == labels[cols]).astype(np.int8)

    return scores, labels_flat


def get_predicted_labels(embeddings, labels):
    """Predict each sample's label as the label of its nearest other sample.

    Similarities are computed in chunks of 256 rows against the full set so
    that only a (256 x N) block is held in memory at a time.

    Args:
        embeddings: Array-like with one embedding per row.
        labels: Array-like with one label per embedding.

    Returns:
        Array with the same length and dtype as ``labels`` holding, for each
        sample, the label of the most cosine-similar sample other than itself.
        With a single sample there is no other candidate, so the sample's own
        label is returned.
    """
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)
    n     = len(embeddings)
    preds = np.empty(n, dtype=labels.dtype)
    CHUNK = 256

    for start in range(0, n, CHUNK):
        end = min(start + CHUNK, n)
        sim = sklearn_cosine_similarity(embeddings[start:end], embeddings)
        # Exclude each sample from its own candidate list. -inf is used rather
        # than -1.0 because -1.0 is itself a valid cosine similarity, so a tie
        # could otherwise let a sample be selected as its own neighbour.
        sim[np.arange(end - start), np.arange(start, end)] = -np.inf
        preds[start:end] = labels[sim.argmax(axis=1)]
        del sim  # free each chunk
    return preds

print("✅ Basic evaluation functions defined")

In [None]:
# Cell 16: Plot training history
print("Plotting training history...")

# Per-epoch metric lists recorded in the Keras History object
history_data = history.history
epochs = range(1, len(history_data['accuracy']) + 1)

# 2x2 grid: accuracy, loss, learning rate (when logged), accuracy-vs-loss
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Plot 1: Training & Validation Accuracy
axes[0, 0].plot(epochs, history_data['accuracy'], 'b-', linewidth=2, label='Training Accuracy')
axes[0, 0].plot(epochs, history_data['val_accuracy'], 'r-', linewidth=2, label='Validation Accuracy')
axes[0, 0].set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
axes[0, 0].set_xlabel('Epochs', fontsize=12)
axes[0, 0].set_ylabel('Accuracy', fontsize=12)
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Plot 2: Training & Validation Loss
axes[0, 1].plot(epochs, history_data['loss'], 'b-', linewidth=2, label='Training Loss')
axes[0, 1].plot(epochs, history_data['val_loss'], 'r-', linewidth=2, label='Validation Loss')
axes[0, 1].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
axes[0, 1].set_xlabel('Epochs', fontsize=12)
axes[0, 1].set_ylabel('Loss', fontsize=12)
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Learning Rate Schedule (if available)
# The logged key depends on the Keras version: 'lr' in older releases,
# 'learning_rate' in Keras 3, so both are checked.
lr_key = next((key for key in ('lr', 'learning_rate') if key in history_data), None)
if lr_key is not None:
    axes[1, 0].plot(epochs, history_data[lr_key], 'g-', linewidth=2)
    axes[1, 0].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
    axes[1, 0].set_xlabel('Epochs', fontsize=12)
    axes[1, 0].set_ylabel('Learning Rate', fontsize=12)
    axes[1, 0].set_yscale('log')
    axes[1, 0].grid(True, alpha=0.3)

# Plot 4: Accuracy vs Loss (takes the learning-rate slot when no LR was logged)
ax2 = axes[1, 1] if lr_key is not None else axes[1, 0]
ax2.scatter(history_data['loss'], history_data['accuracy'], alpha=0.6, 
            c='blue', label='Training', s=50)
ax2.scatter(history_data['val_loss'], history_data['val_accuracy'], alpha=0.6, 
            c='red', label='Validation', s=50)
ax2.set_title('Accuracy vs Loss', fontsize=14, fontweight='bold')
ax2.set_xlabel('Loss', fontsize=12)
ax2.set_ylabel('Accuracy', fontsize=12)
ax2.legend()
ax2.grid(True, alpha=0.3)

# Without a learning-rate panel the bottom-right slot is unused; hide it
# instead of saving an empty set of axes.
if lr_key is None:
    axes[1, 1].set_visible(False)

plt.tight_layout()
plt.savefig('training_history.png', dpi=300, bbox_inches='tight')
plt.show()
print("✅ Training history plots saved")

In [None]:
# Cell 17: Compute basic metrics
print("Computing basic metrics...")

# Check if val_embeddings exists, if not extract them
if 'val_embeddings' not in locals() or 'val_labels' not in locals():
    print("Extracting validation embeddings...")
    val_embeddings, val_labels = extract_embeddings(inference_model, val_ds)

# Pairwise cosine similarities and same-identity flags for verification
scores, pair_labels = compute_similarity_scores(val_embeddings, val_labels)

# ROC curve over the pair scores and its area
fpr, tpr, thresholds = roc_curve(pair_labels, scores)
roc_auc = auc(fpr, tpr)

# Top-1 accuracy of nearest-neighbour label prediction
preds = get_predicted_labels(val_embeddings, val_labels)
top1_accuracy = accuracy_score(val_labels, preds)

print(f"✅ ROC-AUC: {roc_auc:.4f}")
print(f"✅ Top-1 Accuracy: {top1_accuracy:.4f}")

# Save basic metrics
basic_metrics = {
    'roc_auc': float(roc_auc),
    'top1_accuracy': float(top1_accuracy),
    'fpr': fpr.tolist(),
    'tpr': tpr.tolist(),
    # Recent scikit-learn releases make the first ROC threshold np.inf, which
    # json.dump would write as the non-standard token `Infinity`. Non-finite
    # thresholds are stored as null so the file stays valid JSON.
    'thresholds': [float(t) if np.isfinite(t) else None for t in thresholds]
}

with open('basic_metrics.json', 'w') as f:
    json.dump(basic_metrics, f, indent=4)

In [None]:
# Cell 18: Plot ROC Curve
print("Plotting ROC Curve...")

# Draw on an explicit Axes object; fpr, tpr and roc_auc must already exist
# in the notebook namespace (Cell 17 assigns them).
roc_ax = plt.figure(figsize=(10, 8)).add_subplot(111)
roc_ax.plot(fpr, tpr, color='darkorange', lw=2,
            label=f'ROC curve (AUC = {roc_auc:.4f})')
roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--',
            label='Random Classifier')
roc_ax.set_xlim([0.0, 1.0])
roc_ax.set_ylim([0.0, 1.05])
roc_ax.set_xlabel('False Positive Rate (FPR)', fontsize=12)
roc_ax.set_ylabel('True Positive Rate (TPR)', fontsize=12)
roc_ax.set_title('Receiver Operating Characteristic (ROC) Curve', fontsize=14, fontweight='bold')
roc_ax.legend(loc="lower right")
roc_ax.grid(True, alpha=0.3)

# Highlight the ROC points whose FPR lies closest to a few target FARs
for far_target in [0.001, 0.01, 0.1]:
    idx = np.argmin(np.abs(fpr - far_target))
    if idx >= len(tpr):
        continue
    roc_ax.scatter(fpr[idx], tpr[idx], s=100, zorder=5,
                   label=f'FAR={far_target}: TAR={tpr[idx]:.3f}')
    roc_ax.annotate(f'TAR={tpr[idx]:.3f}',
                    (fpr[idx], tpr[idx]),
                    xytext=(10, 10), textcoords='offset points')

# Rebuild the legend so the operating points are listed as well
roc_ax.legend(loc="lower right", fontsize=10)
plt.tight_layout()
plt.savefig('roc_curve.png', dpi=300, bbox_inches='tight')
plt.show()
print("✅ ROC Curve saved")

In [None]:
# Cell 19: Plot Precision-Recall Curve
print("Plotting Precision-Recall Curve...")

# Precision/recall over the same pair scores used for the ROC curve
precision, recall, pr_thresholds = precision_recall_curve(pair_labels, scores)
avg_precision = average_precision_score(pair_labels, scores)

pr_ax = plt.figure(figsize=(10, 8)).add_subplot(111)
pr_ax.plot(recall, precision, color='darkgreen', lw=2,
           label=f'AP = {avg_precision:.4f}')
pr_ax.fill_between(recall, precision, alpha=0.2, color='green')
pr_ax.set_xlabel('Recall', fontsize=12)
pr_ax.set_ylabel('Precision', fontsize=12)
pr_ax.set_title('Precision-Recall Curve', fontsize=14, fontweight='bold')
pr_ax.set_xlim([0.0, 1.0])
pr_ax.set_ylim([0.0, 1.05])
pr_ax.legend(loc="lower left")
pr_ax.grid(True, alpha=0.3)

# Iso-F1 contours: for a fixed F1, precision as a function of recall x is
# F1 * x / (2x - F1); only the non-negative part of each curve is drawn.
f1_scores = np.linspace(0.1, 0.9, 9)
for f1 in f1_scores:
    if f1 != 0:
        x = np.linspace(0.01, 1)
        y = f1 * x / (2 * x - f1)
        nonneg = y >= 0
        pr_ax.plot(x[nonneg], y[nonneg], color='gray', alpha=0.2, linestyle='--')
        pr_ax.text(0.9, y[-1] - 0.05, f'F1={f1:.1f}', fontsize=8, alpha=0.5)

plt.tight_layout()
plt.savefig('precision_recall_curve.png', dpi=300, bbox_inches='tight')
plt.show()
print("✅ Precision-Recall Curve saved")

In [None]:
# Cell 20: Plot TAR vs FAR Curve
print("Plotting TAR vs FAR Curve...")

# Order the ROC points from the highest threshold to the lowest
sorted_idx = np.flip(np.argsort(thresholds))
sorted_fpr = fpr[sorted_idx]
sorted_tpr = tpr[sorted_idx]

# FAR is shown on a logarithmic axis
tar_ax = plt.figure(figsize=(10, 8)).add_subplot(111)
tar_ax.semilogx(sorted_fpr, sorted_tpr, 'b-', linewidth=2)
tar_ax.set_xlim([1e-4, 1.0])
tar_ax.set_ylim([0.0, 1.05])
tar_ax.set_xlabel('False Acceptance Rate (FAR)', fontsize=12)
tar_ax.set_ylabel('True Acceptance Rate (TAR)', fontsize=12)
tar_ax.set_title('TAR vs FAR Curve (Verification Performance)', fontsize=14, fontweight='bold')
tar_ax.grid(True, alpha=0.3, which='both')

# Annotate the curve point whose FAR is closest to each target value
far_targets = [1e-4, 1e-3, 1e-2, 1e-1]
for far in far_targets:
    idx = np.argmin(np.abs(sorted_fpr - far))
    if idx >= len(sorted_tpr):
        continue
    tar_ax.scatter(sorted_fpr[idx], sorted_tpr[idx], s=100, color='red', zorder=5)
    tar_ax.annotate(f'FAR={far:.0e}\nTAR={sorted_tpr[idx]:.3f}',
                    (sorted_fpr[idx], sorted_tpr[idx]),
                    xytext=(10, 10), textcoords='offset points',
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.5))

plt.tight_layout()
plt.savefig('tar_vs_far_curve.png', dpi=300, bbox_inches='tight')
plt.show()
print("✅ TAR vs FAR Curve saved")

In [None]:
# Cell 21: Compute and plot Per-Class Accuracy Distribution
print("Computing per-class accuracy...")

# Confusion matrix of true labels vs nearest-neighbour predictions
cm = confusion_matrix(val_labels, preds)

# Per-class accuracy = diagonal / row total. Rows without any samples get 0
# (same result as dividing and then replacing NaN, but without the warning).
row_totals = cm.sum(axis=1)
per_class_accuracy = np.divide(
    cm.diagonal(), row_totals,
    out=np.zeros(len(row_totals), dtype=float),
    where=row_totals > 0,
)

# Sort classes by accuracy (best first). These are row indices of `cm`,
# not necessarily the label values themselves.
sorted_idx = np.argsort(per_class_accuracy)[::-1]
sorted_acc = per_class_accuracy[sorted_idx]

# Plot distribution
plt.figure(figsize=(14, 6))

# Plot 1: Histogram of accuracies
plt.subplot(1, 2, 1)
plt.hist(per_class_accuracy, bins=50, alpha=0.7, color='steelblue', edgecolor='black')
plt.xlabel('Per-Class Accuracy', fontsize=12)
plt.ylabel('Number of Classes', fontsize=12)
plt.title('Distribution of Per-Class Accuracies', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3)

# Add statistics
mean_acc = np.mean(per_class_accuracy)
median_acc = np.median(per_class_accuracy)
std_acc = np.std(per_class_accuracy)
plt.axvline(mean_acc, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_acc:.3f}')
plt.axvline(median_acc, color='green', linestyle='--', linewidth=2, label=f'Median: {median_acc:.3f}')
plt.legend()

# Plot 2: Sorted per-class accuracy
plt.subplot(1, 2, 2)
plt.plot(range(len(sorted_acc)), sorted_acc, 'b-', linewidth=1, alpha=0.7)
plt.fill_between(range(len(sorted_acc)), sorted_acc, alpha=0.3, color='blue')
plt.xlabel('Class Rank (Sorted by Accuracy)', fontsize=12)
plt.ylabel('Accuracy', fontsize=12)
plt.title('Sorted Per-Class Accuracy', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3)

# Mark top 20 and bottom 20. The count is capped at the number of classes so
# the x and y arrays passed to scatter always have matching lengths.
n_marked = min(20, len(sorted_acc))
top_20_idx = sorted_idx[:n_marked]
bottom_20_idx = sorted_idx[len(sorted_idx) - n_marked:]
plt.scatter(range(n_marked), sorted_acc[:n_marked], color='green', s=50,
            label='Top 20 Classes', zorder=5)
plt.scatter(range(len(sorted_acc) - n_marked, len(sorted_acc)),
            sorted_acc[len(sorted_acc) - n_marked:],
            color='red', s=50, label='Bottom 20 Classes', zorder=5)
plt.legend()

plt.tight_layout()
plt.savefig('per_class_accuracy_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"✅ Per-class accuracy computed")
print(f"   Mean: {mean_acc:.4f}")
print(f"   Median: {median_acc:.4f}")
print(f"   Std: {std_acc:.4f}")
print(f"   Min: {np.min(per_class_accuracy):.4f}")
print(f"   Max: {np.max(per_class_accuracy):.4f}")

In [None]:
# Cell 22: Quick Similarity Distribution
# Split the pair scores by their same-identity flag (1 = genuine, 0 = impostor)
pos_scores, neg_scores = (scores[pair_labels == flag] for flag in (1, 0))

# Overlaid raw-count histograms of both groups
quick_ax = plt.figure(figsize=(7,5)).add_subplot(111)
quick_ax.hist(pos_scores, bins=100, alpha=0.6, label="Intra-class (Genuine)")
quick_ax.hist(neg_scores, bins=100, alpha=0.6, label="Inter-class (Impostor)")
quick_ax.set_xlabel("Cosine Similarity")
quick_ax.set_ylabel("Frequency")
quick_ax.set_title("Intra vs Inter Class Similarity Distribution")
quick_ax.legend()
quick_ax.grid()
plt.show()

print("Mean Intra-class similarity:", np.mean(pos_scores))
print("Mean Inter-class similarity:", np.mean(neg_scores))

In [None]:
# Cell 23: Intra-Class vs Inter-Class Similarity Distribution (memory-safe)
import gc
print("Computing intra-class and inter-class similarities...")

# Reuse the shared helper instead of duplicating its logic here: it subsamples
# to at most 2000 samples and returns one score plus one same-label flag per
# unordered pair.
_scores, _pair_flags = compute_similarity_scores(val_embeddings, val_labels, max_n=2000)
_same = _pair_flags.astype(bool)

intra_similarities = _scores[ _same].copy()
inter_similarities = _scores[~_same].copy()
del _scores, _pair_flags, _same
gc.collect()

# The violin plot below cannot be drawn for an empty group, so fail early
# with an explicit message instead.
if len(intra_similarities) == 0 or len(inter_similarities) == 0:
    raise ValueError(
        "Need at least one same-class pair and one different-class pair "
        "to compare intra-class and inter-class similarities"
    )

# Cap at 10000 for plotting
if len(intra_similarities) > 10000:
    intra_similarities = np.random.choice(intra_similarities, 10000, replace=False)
if len(inter_similarities) > 10000:
    inter_similarities = np.random.choice(inter_similarities, 10000, replace=False)

# ---------- plots ----------
plt.figure(figsize=(14, 10))

# Panel 1: overlaid density histograms
plt.subplot(2, 2, 1)
plt.hist(intra_similarities, bins=100, alpha=0.6, color='blue',
         label=f'Intra-Class (n={len(intra_similarities):,})', density=True)
plt.hist(inter_similarities, bins=100, alpha=0.6, color='red',
         label=f'Inter-Class (n={len(inter_similarities):,})', density=True)
plt.xlabel('Cosine Similarity', fontsize=12)
plt.ylabel('Density', fontsize=12)
plt.title('Intra-Class vs Inter-Class Similarity Distribution', fontsize=14, fontweight='bold')
plt.legend()
plt.grid(True, alpha=0.3)

# Panel 2: box plot
plt.subplot(2, 2, 2)
box_data   = [intra_similarities, inter_similarities]
box_labels = ['Intra-Class', 'Inter-Class']
# Tick labels are set through xticks (as for the violin plot below) because
# the `labels` keyword of boxplot is deprecated in recent Matplotlib releases.
bp = plt.boxplot(box_data, patch_artist=True)
plt.xticks([1, 2], box_labels)
bp['boxes'][0].set_facecolor('lightblue')
bp['boxes'][1].set_facecolor('lightcoral')
plt.ylabel('Cosine Similarity', fontsize=12)
plt.title('Similarity Distribution Comparison', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3, axis='y')

# Panel 3: violin plot
plt.subplot(2, 2, 3)
parts = plt.violinplot(box_data, showmeans=True, showmedians=True)
for i, pc in enumerate(parts['bodies']):
    pc.set_facecolor(['lightblue', 'lightcoral'][i])
    pc.set_alpha(0.7)
plt.xticks([1, 2], box_labels)
plt.ylabel('Cosine Similarity', fontsize=12)
plt.title('Violin Plot of Similarities', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3, axis='y')

# Panel 4: empirical cumulative distribution of each group
plt.subplot(2, 2, 4)
sorted_intra = np.sort(intra_similarities)
sorted_inter = np.sort(inter_similarities)
y_intra = np.arange(1, len(sorted_intra)+1) / len(sorted_intra)
y_inter = np.arange(1, len(sorted_inter)+1) / len(sorted_inter)
plt.plot(sorted_intra, y_intra, 'b-', label='Intra-Class', linewidth=2)
plt.plot(sorted_inter, y_inter, 'r-', label='Inter-Class', linewidth=2)
plt.xlabel('Cosine Similarity', fontsize=12)
plt.ylabel('Cumulative Probability', fontsize=12)
plt.title('Cumulative Distribution Function', fontsize=14, fontweight='bold')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('intra_inter_class_similarity.png', dpi=300, bbox_inches='tight')
plt.show()

intra_mean = np.mean(intra_similarities)
inter_mean = np.mean(inter_similarities)
intra_std  = np.std(intra_similarities)
inter_std  = np.std(inter_similarities)

print(f"✅ Intra-class similarity: Mean = {intra_mean:.4f}, Std = {intra_std:.4f}")
print(f"✅ Inter-class similarity: Mean = {inter_mean:.4f}, Std = {inter_std:.4f}")
print(f"✅ Separation: {intra_mean - inter_mean:.4f}")

In [None]:
# Cell 24: Confusion Matrices for Top 20 and Bottom 20 Classes
print("Plotting confusion matrices...")

# Get accuracy per class
class_accuracies = per_class_accuracy

# per_class_accuracy is indexed by confusion-matrix row. When no `labels`
# argument is given, scikit-learn orders those rows by the sorted set of
# labels found in y_true and y_pred, so row i corresponds to class_ids[i].
# Using the row index itself as a label is only right when the labels happen
# to be exactly 0..C-1.
class_ids = np.unique(np.concatenate([np.asarray(val_labels), np.asarray(preds)]))
if len(class_ids) != len(class_accuracies):
    raise ValueError(
        f"per_class_accuracy has {len(class_accuracies)} entries but "
        f"{len(class_ids)} distinct classes were found in val_labels/preds; "
        "recompute per_class_accuracy from the current val_labels and preds"
    )

# Get indices of top 20 and bottom 20 classes
top_20_idx = np.argsort(class_accuracies)[-20:][::-1]  # Highest to lowest
bottom_20_idx = np.argsort(class_accuracies)[:20]  # Lowest to highest

# Create confusion matrices for selected classes
def plot_confusion_matrix_for_classes(class_indices, title, filename):
    """Plot and save a confusion-matrix heatmap restricted to the given classes.

    Args:
        class_indices: Positions into ``per_class_accuracy`` (confusion-matrix
            rows) of the classes to show, in display order.
        title: Figure title.
        filename: Path the figure is saved to.

    Only samples whose true label is one of the selected classes are used;
    predictions that fall outside the selected classes are not counted in
    the matrix.
    """
    # Translate row positions into the actual class labels
    selected_classes = class_ids[np.asarray(class_indices)]

    # Filter predictions and labels for selected classes
    mask = np.isin(val_labels, selected_classes)
    filtered_labels = np.asarray(val_labels)[mask]
    filtered_preds = np.asarray(preds)[mask]

    # Map class labels to 0..k-1 in display order; -1 marks predictions that
    # fall outside the selected classes.
    idx_map = {label: pos for pos, label in enumerate(selected_classes.tolist())}
    mapped_labels = np.array([idx_map[l] for l in filtered_labels.tolist()], dtype=int)
    mapped_preds = np.array([idx_map.get(p, -1) for p in filtered_preds.tolist()], dtype=int)

    # Create confusion matrix
    cm = confusion_matrix(mapped_labels, mapped_preds,
                          labels=list(range(len(selected_classes))))

    # Plot
    tick_names = [f'C{label}' for label in selected_classes.tolist()]
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=tick_names,
                yticklabels=tick_names)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.ylabel('True Label', fontsize=12)
    plt.title(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    plt.show()

# Plot top 20 classes
plot_confusion_matrix_for_classes(
    top_20_idx, 
    'Confusion Matrix - Top 20 Performing Classes',
    'confusion_matrix_top20.png'
)

# Plot bottom 20 classes
plot_confusion_matrix_for_classes(
    bottom_20_idx,
    'Confusion Matrix - Bottom 20 Performing Classes',
    'confusion_matrix_bottom20.png'
)

print("✅ Confusion matrices saved")

In [None]:
# Cell 25: Calibration Curve (Reliability Diagram)
print("Plotting calibration curve...")

# For multi-class calibration, we'll use the max softmax probability as confidence
# First, get logits from the model
all_logits = []
all_labels = []

for images, labels in tqdm(val_ds, desc="Getting logits"):
    embeddings = backbone(images, training=False)
    logits = arcface(embeddings)  # Use arcface without labels for inference
    all_logits.append(logits.numpy())
    all_labels.append(labels.numpy())

all_logits = np.vstack(all_logits)
all_labels = np.concatenate(all_labels)

# Convert to probabilities using softmax
probabilities = tf.nn.softmax(all_logits).numpy()

# Get predicted class and confidence
predicted_class = np.argmax(probabilities, axis=1)
confidence = np.max(probabilities, axis=1)
correct = (predicted_class == all_labels).astype(int)

# One shared set of equal-width confidence bins for the bars and the ECE
N_CALIBRATION_BINS = 10
bin_edges = np.linspace(0, 1, N_CALIBRATION_BINS + 1)

# Create calibration curve
prob_true, prob_pred = calibration_curve(correct, confidence,
                                         n_bins=N_CALIBRATION_BINS, strategy='uniform')

plt.figure(figsize=(10, 8))
plt.plot(prob_pred, prob_true, 's-', linewidth=2, markersize=8, label='Model')
plt.plot([0, 1], [0, 1], 'k--', label='Perfectly Calibrated')

# Bars show the fraction of samples falling into each confidence bin.
# (np.histogram closes the last bin on the right, so confidence == 1.0 counts.)
bin_counts = np.histogram(confidence, bins=bin_edges)[0]
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
plt.bar(bin_centers, bin_counts / max(bin_counts.sum(), 1), width=0.08, alpha=0.3, 
        color='gray', label='Bin Density')

plt.xlabel('Mean Predicted Probability', fontsize=12)
plt.ylabel('Fraction of Positives', fontsize=12)
plt.title('Reliability Diagram (Calibration Curve)', fontsize=14, fontweight='bold')
plt.legend(loc='upper left')
plt.grid(True, alpha=0.3)

# Calculate ECE (Expected Calibration Error): sample-weighted mean of
# |mean confidence - accuracy| over the bins.
# np.digitize places confidence == 1.0 past the last edge; clipping keeps
# those samples in the final bin instead of dropping them from the sum.
bin_indices = np.clip(np.digitize(confidence, bin_edges) - 1, 0, N_CALIBRATION_BINS - 1)
ece = 0
for i in range(N_CALIBRATION_BINS):
    mask = bin_indices == i
    if np.sum(mask) > 0:
        bin_conf = np.mean(confidence[mask])
        bin_acc = np.mean(correct[mask])
        ece += np.abs(bin_conf - bin_acc) * np.sum(mask)
ece /= len(confidence)

# Placed in the lower-right corner so it does not sit on top of the legend,
# which occupies the upper-left corner.
plt.text(0.95, 0.05, f'ECE = {ece:.4f}', transform=plt.gca().transAxes,
         fontsize=12, verticalalignment='bottom', horizontalalignment='right',
         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.tight_layout()
plt.savefig('calibration_curve.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"✅ Calibration curve saved (ECE = {ece:.4f})")

In [None]:
# Cell 26: Confidence Distribution (Correct vs Incorrect)
print("Plotting confidence distribution...")

# Separate confidence for correct and incorrect predictions
correct_conf = confidence[correct == 1]
incorrect_conf = confidence[correct == 0]

plt.figure(figsize=(14, 6))

# Plot 1: Histogram
plt.subplot(1, 2, 1)
plt.hist(correct_conf, bins=50, alpha=0.6, color='green', 
         label=f'Correct (n={len(correct_conf):,})', density=True)
plt.hist(incorrect_conf, bins=50, alpha=0.6, color='red', 
         label=f'Incorrect (n={len(incorrect_conf):,})', density=True)
plt.xlabel('Confidence', fontsize=12)
plt.ylabel('Density', fontsize=12)
plt.title('Confidence Distribution: Correct vs Incorrect', fontsize=14, fontweight='bold')
plt.legend()
plt.grid(True, alpha=0.3)

# Plot 2: Box plot
plt.subplot(1, 2, 2)
box_data = [correct_conf, incorrect_conf]
box_labels = ['Correct', 'Incorrect']
# Tick labels are applied with xticks because the `labels` keyword of boxplot
# is deprecated in recent Matplotlib releases; boxes sit at x = 1, 2.
bp = plt.boxplot(box_data, patch_artist=True)
plt.xticks([1, 2], box_labels)
bp['boxes'][0].set_facecolor('lightgreen')
bp['boxes'][1].set_facecolor('lightcoral')
plt.ylabel('Confidence', fontsize=12)
plt.title('Confidence Comparison', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3, axis='y')

# Add statistics
mean_correct = np.mean(correct_conf)
mean_incorrect = np.mean(incorrect_conf)
std_correct = np.std(correct_conf)
std_incorrect = np.std(incorrect_conf)

plt.text(0.05, 0.95, f'Correct: μ={mean_correct:.3f}, σ={std_correct:.3f}\n'
                     f'Incorrect: μ={mean_incorrect:.3f}, σ={std_incorrect:.3f}',
         transform=plt.gca().transAxes, fontsize=10,
         bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

plt.tight_layout()
plt.savefig('confidence_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"✅ Confidence distribution saved")
print(f"   Correct predictions: Mean = {mean_correct:.4f}, Std = {std_correct:.4f}")
print(f"   Incorrect predictions: Mean = {mean_incorrect:.4f}, Std = {std_incorrect:.4f}")

In [None]:
# Cell 27: Confidence vs Accuracy Correlation
print("Plotting confidence vs accuracy correlation...")

# Bin confidence and compute accuracy per bin
num_bins = 20
bins = np.linspace(0, 1, num_bins + 1)
bin_centers = (bins[:-1] + bins[1:]) / 2

bin_accuracies = []
bin_confidences = []
bin_counts = []

for i in range(num_bins):
    # Bins are half-open [lo, hi) except the last one, which also includes
    # its upper edge so that samples with confidence == 1.0 are not dropped.
    if i == num_bins - 1:
        below_upper = confidence <= bins[i+1]
    else:
        below_upper = confidence < bins[i+1]
    mask = (confidence >= bins[i]) & below_upper
    if np.sum(mask) > 0:
        bin_acc = np.mean(correct[mask])
        bin_conf = np.mean(confidence[mask])
        bin_accuracies.append(bin_acc)
        bin_confidences.append(bin_conf)
        bin_counts.append(np.sum(mask))

# Only non-empty bins are kept, so these arrays may be shorter than num_bins
bin_accuracies = np.array(bin_accuracies)
bin_confidences = np.array(bin_confidences)
bin_counts = np.array(bin_counts)

plt.figure(figsize=(12, 8))

# Plot 1: Scatter plot with size proportional to bin count
scatter = plt.scatter(bin_confidences, bin_accuracies, 
                      s=bin_counts/10, alpha=0.7, 
                      c=bin_counts, cmap='viridis')
plt.colorbar(scatter, label='Number of Samples (scaled)')
plt.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Perfect Calibration')

# Add regression line
if len(bin_confidences) > 1:
    z = np.polyfit(bin_confidences, bin_accuracies, 1)
    p = np.poly1d(z)
    plt.plot(bin_confidences, p(bin_confidences), "r--", alpha=0.8, 
             label=f'Linear fit: y={z[0]:.3f}x+{z[1]:.3f}')

plt.xlabel('Mean Confidence in Bin', fontsize=12)
plt.ylabel('Accuracy in Bin', fontsize=12)
plt.title('Confidence vs Accuracy Correlation', fontsize=14, fontweight='bold')
plt.legend()
plt.grid(True, alpha=0.3)

# Calculate correlation. It is undefined with fewer than two populated bins,
# so NaN is reported directly in that case.
if len(bin_confidences) > 1:
    correlation = np.corrcoef(bin_confidences, bin_accuracies)[0, 1]
else:
    correlation = float('nan')
plt.text(0.05, 0.95, f'Correlation: {correlation:.4f}', transform=plt.gca().transAxes,
         fontsize=12, verticalalignment='top',
         bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

plt.tight_layout()
plt.savefig('confidence_vs_accuracy.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"✅ Confidence vs Accuracy plot saved (Correlation: {correlation:.4f})")

In [None]:
# Cell 28: ROC-AUC Stability Across Training (Real Analysis)
print("Analyzing ROC-AUC stability across training...")

# Helper function for fallback analysis
def analyze_from_saved_embeddings(embeddings_dir, val_labels):
    """Compute one ROC-AUC value per saved per-epoch embedding file.

    Every file in ``embeddings_dir`` whose name starts with
    ``embeddings_epoch_`` is loaded with ``np.load``, scored with
    ``compute_similarity_scores_for_auc`` and evaluated with
    ``roc_auc_score``. The resulting list is also written to
    ``roc_auc_history.npy`` inside ``embeddings_dir``.

    Args:
        embeddings_dir: Directory containing the saved embedding files.
        val_labels: Labels passed to ``compute_similarity_scores_for_auc``
            together with each loaded embedding array.

    Returns:
        List of ROC-AUC values, ordered by the epoch number in the file name.
    """
    import re

    prefix = 'embeddings_epoch_'

    def _epoch_order(filename):
        # Order by the number after the prefix so that epoch 10 follows
        # epoch 9; a plain string sort would put it before epoch 2 when the
        # numbers are not zero-padded. Names without a number sort last.
        match = re.search(r'\d+', filename[len(prefix):])
        return (int(match.group()) if match else float('inf'), filename)

    epoch_files = sorted(
        (f for f in os.listdir(embeddings_dir) if f.startswith(prefix)),
        key=_epoch_order,
    )
    epoch_roc_aucs = []

    for epoch_file in tqdm(epoch_files, desc="Analyzing saved embeddings"):
        embeddings = np.load(os.path.join(embeddings_dir, epoch_file))
        scores, pair_labels = compute_similarity_scores_for_auc(embeddings, val_labels)
        roc_auc = roc_auc_score(pair_labels, scores)
        epoch_roc_aucs.append(roc_auc)

    # Save the computed ROC-AUC values
    np.save(os.path.join(embeddings_dir, 'roc_auc_history.npy'), epoch_roc_aucs)

    return epoch_roc_aucs

def create_demonstration_plot(history):
    """Draw and save a placeholder ROC-AUC-per-epoch chart from synthetic values.

    The curve is NOT measured: it is a saturating exponential plus Gaussian
    noise, with one point per entry of ``history.history['val_accuracy']``.
    The figure is written to 'roc_auc_stability_demo.png' and shown.
    """
    epochs = range(1, len(history.history['val_accuracy']) + 1)

    # Synthetic trend (rises from 0.7 towards 1.0) with a little random noise
    synthetic_trend = 0.7 + 0.3 * (1 - np.exp(-np.array(epochs) / 8))
    roc_auc_values = synthetic_trend + np.random.normal(0, 0.01, len(epochs))

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.plot(epochs, roc_auc_values, 'b-o', linewidth=2, markersize=8)
    ax.fill_between(epochs, roc_auc_values - 0.02, roc_auc_values + 0.02,
                    alpha=0.2, color='blue')

    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('ROC-AUC', fontsize=12)
    ax.set_title('ROC-AUC Stability Across Training (Demonstration)',
                 fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)

    # Highlight the epoch with the highest (synthetic) value
    peak = np.argmax(roc_auc_values)
    peak_label = f'Best: {roc_auc_values[peak]:.4f} @ Epoch {epochs[peak]}'
    ax.scatter(epochs[peak], roc_auc_values[peak],
               color='red', s=200, zorder=5, label=peak_label)

    ax.legend()
    fig.tight_layout()
    fig.savefig('roc_auc_stability_demo.png', dpi=300, bbox_inches='tight')
    plt.show()

    for message in ("⚠️ Created demonstration plot (not actual data)",
                    "   For real analysis, ensure embeddings are saved during training"):
        print(message)

def create_stability_plot(epochs, roc_auc_values):
    """Plot ROC-AUC per epoch, mark the best epoch, and save the figure.

    ``epochs`` and ``roc_auc_values`` are parallel sequences; the chart is
    written to 'roc_auc_stability.png' and shown.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.plot(epochs, roc_auc_values, 'b-o', linewidth=2, markersize=8)

    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('ROC-AUC', fontsize=12)
    ax.set_title('ROC-AUC Stability Across Training', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)

    # Highlight the epoch with the highest ROC-AUC
    peak = np.argmax(roc_auc_values)
    peak_label = f'Best: {roc_auc_values[peak]:.4f} @ Epoch {epochs[peak]}'
    ax.scatter(epochs[peak], roc_auc_values[peak],
               color='red', s=200, zorder=5, label=peak_label)

    ax.legend()
    fig.tight_layout()
    fig.savefig('roc_auc_stability.png', dpi=300, bbox_inches='tight')
    plt.show()

# Load ROC-AUC history from saved files
embeddings_dir = 'embeddings_by_epoch'

# Check if embeddings were saved
if os.path.exists(embeddings_dir):
    # Load ROC-AUC history if available
    roc_auc_file = os.path.join(embeddings_dir, 'roc_auc_history.npy')
    if os.path.exists(roc_auc_file):
        epoch_roc_aucs = np.load(roc_auc_file)
        epochs = range(1, len(epoch_roc_aucs) + 1)
        
        print(f"Found ROC-AUC values for {len(epoch_roc_aucs)} epochs")
        
        # Load validation accuracy for comparison
        val_accuracies = history.history.get('val_accuracy', [])
        val_losses = history.history.get('val_loss', [])
        
        # Create comprehensive stability analysis
        plt.figure(figsize=(16, 12))
        
        # Plot 1: ROC-AUC vs Epoch
        plt.subplot(2, 2, 1)
        plt.plot(epochs, epoch_roc_aucs, 'b-o', linewidth=2, markersize=8, label='ROC-AUC')
        
        # Add smoothing
        if len(epoch_roc_aucs) > 5:
            from scipy.ndimage import gaussian_filter1d
            smoothed = gaussian_filter1d(epoch_roc_aucs, sigma=1)
            plt.plot(epochs, smoothed, 'r--', linewidth=2, alpha=0.7, label='Smoothed')
        
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('ROC-AUC', fontsize=12)
        plt.title('ROC-AUC Stability Across Training', fontsize=14, fontweight='bold')
        plt.grid(True, alpha=0.3)
        
        # Mark best epoch
        best_epoch_idx = np.argmax(epoch_roc_aucs)
        best_roc_auc = epoch_roc_aucs[best_epoch_idx]
        plt.scatter(epochs[best_epoch_idx], best_roc_auc, 
                   color='red', s=200, zorder=5, 
                   label=f'Best: {best_roc_auc:.4f} @ Epoch {epochs[best_epoch_idx]}')
        
        # Add final epoch
        final_roc_auc = epoch_roc_aucs[-1]
        plt.scatter(epochs[-1], final_roc_auc, 
                   color='green', s=200, zorder=5,
                   label=f'Final: {final_roc_auc:.4f}')
        
        plt.legend()
        plt.ylim([max(0.5, min(epoch_roc_aucs) - 0.05), min(1.0, max(epoch_roc_aucs) + 0.05)])
        
        # Plot 2: ROC-AUC vs Validation Accuracy
        plt.subplot(2, 2, 2)
        if len(val_accuracies) == len(epoch_roc_aucs):
            plt.scatter(val_accuracies, epoch_roc_aucs, c=epochs, cmap='viridis', s=100)
            plt.colorbar(label='Epoch')
            
            # Add correlation line
            z = np.polyfit(val_accuracies, epoch_roc_aucs, 1)
            p = np.poly1d(z)
            plt.plot(val_accuracies, p(val_accuracies), "r--", alpha=0.8, 
                    label=f'Correlation: {np.corrcoef(val_accuracies, epoch_roc_aucs)[0,1]:.3f}')
            
            plt.xlabel('Validation Accuracy', fontsize=12)
            plt.ylabel('ROC-AUC', fontsize=12)
            plt.title('ROC-AUC vs Validation Accuracy', fontsize=14, fontweight='bold')
            plt.grid(True, alpha=0.3)
            plt.legend()
        
        # Plot 3: ROC-AUC vs Validation Loss
        plt.subplot(2, 2, 3)
        if len(val_losses) == len(epoch_roc_aucs):
            plt.scatter(val_losses, epoch_roc_aucs, c=epochs, cmap='plasma', s=100)
            plt.colorbar(label='Epoch')
            
            # Add correlation line
            z = np.polyfit(val_losses, epoch_roc_aucs, 1)
            p = np.poly1d(z)
            plt.plot(val_losses, p(val_losses), "r--", alpha=0.8,
                    label=f'Correlation: {np.corrcoef(val_losses, epoch_roc_aucs)[0,1]:.3f}')
            
            plt.xlabel('Validation Loss', fontsize=12)
            plt.ylabel('ROC-AUC', fontsize=12)
            plt.title('ROC-AUC vs Validation Loss', fontsize=14, fontweight='bold')
            plt.grid(True, alpha=0.3)
            plt.legend()
        
        # Plot 4: ROC-AUC Improvement per Epoch
        plt.subplot(2, 2, 4)
        if len(epoch_roc_aucs) > 1:
            improvements = np.diff(epoch_roc_aucs)
            plt.bar(epochs[1:], improvements, color=['green' if x > 0 else 'red' for x in improvements])
            plt.axhline(y=0, color='black', linestyle='-', alpha=0.3)
            plt.xlabel('Epoch', fontsize=12)
            plt.ylabel('ROC-AUC Improvement', fontsize=12)
            plt.title('ROC-AUC Improvement per Epoch', fontsize=14, fontweight='bold')
            plt.grid(True, alpha=0.3, axis='y')
            
            # Add cumulative improvement
            cumulative_improvement = epoch_roc_aucs[-1] - epoch_roc_aucs[0]
            plt.text(0.05, 0.95, f'Total Improvement: {cumulative_improvement:.4f}',
                    transform=plt.gca().transAxes, fontsize=12,
                    bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
        
        plt.tight_layout()
        plt.savefig('roc_auc_stability_analysis.png', dpi=300, bbox_inches='tight')
        plt.show()
        
        # Detailed statistical analysis
        print("\n" + "="*60)
        print("ROC-AUC STABILITY ANALYSIS")
        print("="*60)
        
        # Calculate statistics
        roc_auc_array = np.array(epoch_roc_aucs)
        
        print(f"Number of Epochs Analyzed: {len(roc_auc_array)}")
        print(f"Initial ROC-AUC (Epoch 1): {roc_auc_array[0]:.4f}")