In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file that os.walk finds under /kaggle/input.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        # Join the directory being visited with the bare file name.
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# ====== Testing Code ======
import os
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
from tqdm import tqdm
import random

# ====== Test configuration ======
class TestConfig:
    """Settings shared by the LFW evaluation code in this file."""

    # Run on the GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # Directory that is scanned for one sub-folder per person.
    LFW_PATH = "/kaggle/input/lfw-dataset/lfw-deepfunneled/lfw-deepfunneled"
    # Number of image pairs per DataLoader batch.
    BATCH_SIZE = 32
    # Embedding width; mirrors ArcFaceModel.feature_dim.
    FEATURE_DIM = 512


# Module-level instance read by the loading and evaluation helpers.
test_config = TestConfig()

# ====== ArcFace model definition consistent with training ======
class ArcFaceModel(nn.Module):
    """ResNet-18 feature extractor followed by a cosine-similarity classifier.

    The attribute names (``backbone``, ``feature_extractor``, ``fc``,
    ``dropout``) define the ``state_dict`` keys, so they are kept exactly as
    they are for saved checkpoints to keep loading.
    """

    def __init__(self, num_classes):
        """Create the backbone and a bias-free linear head with ``num_classes`` rows."""
        super().__init__()
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.backbone = resnet
        # Keep every child module of the ResNet except the last one.
        trunk_layers = list(resnet.children())[:-1]
        self.feature_extractor = nn.Sequential(*trunk_layers)
        self.feature_dim = 512
        self.fc = nn.Linear(self.feature_dim, num_classes, bias=False)
        nn.init.normal_(self.fc.weight, std=0.01)
        self.dropout = nn.Dropout(0.2)

    def get_features(self, x):
        """Return the feature-extractor output flattened to ``(batch, -1)``, without dropout."""
        extracted = self.feature_extractor(x)
        return extracted.view(extracted.size(0), -1)

    def forward(self, x, margin=None):
        """Return ``(features_after_dropout, scaled_logits)`` for the batch ``x``.

        ``margin`` is subtracted from every cosine value before scaling.  It
        may be ``None`` (no margin), a plain number, or a tensor.  A 1-D tensor
        with one entry per sample is applied row-wise; any other tensor is
        reduced to its mean.
        """
        # Dropout is only active in training mode; eval() turns it off.
        feat = self.dropout(self.get_features(x))

        unit_feat = F.normalize(feat, p=2, dim=1)
        unit_weight = F.normalize(self.fc.weight, p=2, dim=1)
        cosine = torch.matmul(unit_feat, unit_weight.t())

        if margin is None:
            offset = 0.0
        elif isinstance(margin, torch.Tensor):
            if margin.dim() == 1 and margin.size(0) == cosine.size(0):
                # One margin per sample: make it a column so it broadcasts over classes.
                offset = margin.view(-1, 1).to(cosine.device)
            else:
                offset = float(margin.mean().item())
        else:
            # Plain Python numbers and anything else convertible with float().
            offset = float(margin)

        # 30.0 is the scale factor used during training.
        return feat, (cosine - offset) * 30.0

# ====== LFW test dataset ======
class LFWDataset(Dataset):
    """Image pairs for face verification, sampled from a folder-per-person directory.

    Each item is ``(image1, image2, label)`` where ``label`` is 1.0 when both
    images come from the same person folder and 0.0 otherwise.
    """

    def __init__(self, lfw_path, transform=None, num_pairs=600):
        """Store the settings and sample ``num_pairs`` pairs right away."""
        self.transform = transform
        self.lfw_path = lfw_path
        self.pairs = self._build_pairs_from_lfw(num_pairs)

    def _build_pairs_from_lfw(self, num_pairs):
        """Sample pairs from ``self.lfw_path``.

        Returns a list of ``(path1, path2, label)`` tuples: first the
        same-person pairs (at most one per person, at most ``num_pairs // 2``),
        then the different-person pairs.  Returns an empty list when the
        directory is missing or holds fewer than two persons with images.
        """
        root = self.lfw_path
        print(f"Building test pairs from LFW dataset: {root}")

        if not os.path.exists(root):
            print(f"ERROR: LFW path does not exist: {root}")
            return []

        person_folders = [
            entry for entry in os.listdir(root)
            if os.path.isdir(os.path.join(root, entry))
        ]
        print(f"Found {len(person_folders)} person folders")

        if not person_folders:
            print("ERROR: No person folders found")
            return []

        # Map each person to their image files; persons without images are left out.
        image_suffixes = ('.jpg', '.jpeg', '.png')
        person_images = {}
        for person in person_folders:
            found = [
                name for name in os.listdir(os.path.join(root, person))
                if name.lower().endswith(image_suffixes)
            ]
            if found:
                person_images[person] = found
        valid_persons = list(person_images)

        print(f"Valid persons: {len(valid_persons)}")

        if len(valid_persons) < 2:
            print("ERROR: Not enough valid persons to build pairs")
            return []

        same_person_count = num_pairs // 2
        diff_person_count = num_pairs - same_person_count

        persons_with_multiple = [
            person for person in valid_persons if len(person_images[person]) >= 2
        ]
        print(f"Persons with multiple images: {len(persons_with_multiple)}")

        pairs = []

        # Positive pairs: two distinct images of one person, one pair per person.
        for person in persons_with_multiple[:max(same_person_count, 0)]:
            first, second = random.sample(person_images[person], 2)
            pairs.append((
                os.path.join(root, person, first),
                os.path.join(root, person, second),
                1,
            ))
        same_count = len(pairs)

        # Negative pairs: one random image from each of two distinct persons.
        for _ in range(diff_person_count):
            person1, person2 = random.sample(valid_persons, 2)
            first = random.choice(person_images[person1])
            second = random.choice(person_images[person2])
            pairs.append((
                os.path.join(root, person1, first),
                os.path.join(root, person2, second),
                0,
            ))
        diff_count = len(pairs) - same_count

        print(f"Built {len(pairs)} test pairs (same: {same_count}, different: {diff_count})")

        if pairs:
            print("\nExample pairs:")
            for position, (path1, path2, label) in enumerate(pairs[:3], start=1):
                name1 = os.path.basename(os.path.dirname(path1))
                name2 = os.path.basename(os.path.dirname(path2))
                img1 = os.path.basename(path1)
                img2 = os.path.basename(path2)
                relation = "same person" if label == 1 else "different person"
                print(f"  {position}. {name1}/{img1} vs {name2}/{img2} - {relation}")

        return pairs

    def __len__(self):
        """Number of sampled pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(img1, img2, label_tensor)``.

        If either image cannot be opened, the error is printed and BOTH
        images are replaced by 112x112 gray placeholders.
        """
        img1_path, img2_path, label = self.pairs[idx]

        try:
            img1, img2 = (
                Image.open(path).convert('RGB') for path in (img1_path, img2_path)
            )
        except Exception as e:
            print(f"Failed to load image: {e}")
            img1, img2 = (
                Image.new('RGB', (112, 112), color='gray') for _ in range(2)
            )

        if self.transform:
            img1, img2 = self.transform(img1), self.transform(img2)

        return img1, img2, torch.tensor(label, dtype=torch.float32)

# ====== Improved model loading function ======
def load_trained_model(model_path):
    """Load a trained ArcFaceModel checkpoint and return it ready for inference.

    Two checkpoint layouts are accepted:

    * a dict with a ``'state_dict'`` entry and, optionally, ``'num_classes'``;
    * a bare state dict.

    A leading ``'module.'`` on any key is stripped BEFORE the number of
    classes is inferred from ``fc.weight``, so checkpoints whose keys all
    carry that prefix are handled as well.

    Args:
        model_path: Path of the checkpoint file passed to ``torch.load``.

    Returns:
        The model moved to ``test_config.device`` and switched to eval mode,
        or ``None`` if anything fails (the error and traceback are printed).
    """
    try:
        # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=test_config.device)

        # Separate the weights from optional metadata.
        num_classes = None
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            raw_state_dict = checkpoint['state_dict']
            if 'num_classes' in checkpoint:
                num_classes = checkpoint['num_classes']
                print(f"Loaded number of classes from metadata: {num_classes}")
        else:
            raw_state_dict = checkpoint

        # Normalise the keys first so the 'fc.weight' lookup below also works
        # for checkpoints whose keys start with 'module.'.
        prefix = 'module.'
        new_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in raw_state_dict.items()
        }

        if num_classes is None:
            num_classes = new_state_dict['fc.weight'].shape[0]
            print(f"Inferred number of classes from fc.weight: {num_classes}")

        model = ArcFaceModel(num_classes=num_classes)

        try:
            model.load_state_dict(new_state_dict, strict=True)
            print("Strict load succeeded")
        except RuntimeError as e:
            if "size mismatch" not in str(e):
                # Missing/unexpected keys etc. are real errors; keep the traceback.
                raise

            print("Size mismatch detected, attempting flexible loading...")
            # Flexible loading: keep only tensors whose name and shape match the model.
            model_state_dict = model.state_dict()
            filtered_state_dict = {}

            for key, value in new_state_dict.items():
                if key not in model_state_dict:
                    print(f"  Skipping non-existing layer: {key}")
                elif model_state_dict[key].shape != value.shape:
                    print(f"  Skipping mismatched layer: {key} (expected: {model_state_dict[key].shape}, actual: {value.shape})")
                else:
                    filtered_state_dict[key] = value

            model.load_state_dict(filtered_state_dict, strict=False)
            print(f"Flexible load complete, loaded {len(filtered_state_dict)}/{len(new_state_dict)} layers")

        model.to(test_config.device)
        model.eval()  # important: disable dropout and set BN to eval

        # Smoke test: one random 112x112 RGB image must pass through the model.
        with torch.no_grad():
            test_input = torch.randn(1, 3, 112, 112).to(test_config.device)
            features, outputs = model(test_input)
            print(f"Model verification: feature shape {features.shape}, output shape {outputs.shape}")

        return model

    except Exception as e:
        # Best-effort loader: report the problem and let the caller skip this model.
        print(f"Failed to load model {model_path}: {e}")
        import traceback
        traceback.print_exc()
        return None

# ====== Feature extraction ======
def extract_features(model, images):
    """Return L2-normalised embeddings for ``images``.

    The embeddings come from ``model.get_features`` (not ``forward``), and the
    whole computation runs with gradient tracking disabled.
    """
    with torch.no_grad():
        raw_embeddings = model.get_features(images)
        # Scale every row to unit length.
        return F.normalize(raw_embeddings, p=2, dim=1)

# ====== Compute accuracy and ROC curve ======
def evaluate_model(model, test_loader, model_name):
    """Score every pair in ``test_loader`` and summarise verification quality.

    Each pair is scored by the cosine similarity of its two embeddings.  The
    decision threshold is the ROC point that maximises ``sqrt(tpr * (1 - fpr))``;
    accuracy is measured at that threshold on the same pairs.

    Returns:
        ``(accuracy, roc_auc, best_threshold, fpr, tpr)``, or
        ``(0, 0, 0, [], [])`` when the loader's dataset is empty.
    """
    if len(test_loader.dataset) == 0:
        print(f"ERROR: Test dataset is empty, cannot evaluate {model_name}")
        return 0, 0, 0, [], []

    pair_scores, pair_labels = [], []

    print(f"Evaluating model {model_name}...")

    with torch.no_grad():
        for batch_a, batch_b, batch_labels in tqdm(test_loader, desc=f"Testing {model_name}"):
            batch_a = batch_a.to(test_config.device)
            batch_b = batch_b.to(test_config.device)

            # Both sides of the pair go through the same extraction routine.
            emb_a = extract_features(model, batch_a)
            emb_b = extract_features(model, batch_b)

            pair_scores.extend(F.cosine_similarity(emb_a, emb_b).cpu().numpy())
            pair_labels.extend(batch_labels.numpy())

    similarities = np.array(pair_scores)
    labels = np.array(pair_labels)

    # ROC curve and the area under it.
    fpr, tpr, thresholds = roc_curve(labels, similarities)
    roc_auc = auc(fpr, tpr)

    # Operating point with the largest geometric mean of TPR and TNR.
    best_index = np.argmax(np.sqrt(tpr * (1 - fpr)))
    best_threshold = thresholds[best_index]

    predicted_same = (similarities >= best_threshold).astype(int)
    accuracy = np.mean(predicted_same == labels)

    print(f"{model_name} results:")
    print(f"   Accuracy: {accuracy:.4f}")
    print(f"   AUC: {roc_auc:.4f}")
    print(f"   Best threshold: {best_threshold:.4f}")
    print(f"   Similarity range: [{similarities.min():.3f}, {similarities.max():.3f}]")

    return accuracy, roc_auc, best_threshold, fpr, tpr

# ====== Plot ROC curves ======
def plot_roc_curves(results, save_path='/kaggle/working/roc_curves.pdf'):
    """Plot one ROC curve per model, save the figure and display it.

    Args:
        results: Mapping of model name to
            ``(accuracy, roc_auc, threshold, fpr, tpr)``.
        save_path: File the figure is written to with ``plt.savefig``.

    Any number of models is supported: the colour list is cycled, and each
    full cycle switches to the next line style so curves stay distinguishable.
    The first four curves are drawn exactly as before (solid lines).
    """
    plt.figure(figsize=(10, 8))

    colors = ['red', 'blue', 'green', 'orange']
    linestyles = ['-', '--', ':', '-.']
    for i, (model_name, (accuracy, roc_auc, _, fpr, tpr)) in enumerate(results.items()):
        # Wrap around instead of indexing past the end of the colour list.
        color = colors[i % len(colors)]
        linestyle = linestyles[(i // len(colors)) % len(linestyles)]
        plt.plot(fpr, tpr, color=color, linestyle=linestyle, lw=2,
                 label=f'{model_name} (AUC = {roc_auc:.3f}, Acc = {accuracy:.3f})')

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('LFW - ROC')
    plt.legend(loc="lower right")
    plt.grid(True)

    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

# ====== Analyze results ======
def analyze_results(results):
    """Print an accuracy ranking and a comparison against ``fixed_margin``.

    ``results`` maps a model name to ``(accuracy, roc_auc, threshold, fpr, tpr)``.
    Nothing beyond the header is printed unless at least two models are present,
    and the per-model comparison only appears when a ``"fixed_margin"`` entry exists.
    """
    print("\nResults Analysis:")
    print("=" * 30)

    if len(results) < 2:
        print("At least two model results are required for comparison")
        return

    # Highest accuracy first; ties keep their original order.
    ranking = sorted(results.items(), key=lambda item: item[1][0], reverse=True)

    print("Model ranking (by accuracy):")
    for rank, (model_name, (accuracy, _auc, _thr, _fpr, _tpr)) in enumerate(ranking, start=1):
        print(f"  {rank}. {model_name}: {accuracy:.4f}")

    top_accuracy = ranking[0][1][0]
    bottom_accuracy = ranking[-1][1][0]
    improvement = top_accuracy - bottom_accuracy

    print(f"\nBest model accuracy is higher than worst model by: {improvement:.4f} ({improvement*100:.2f}%)")

    if "fixed_margin" not in results:
        return

    baseline = results["fixed_margin"][0]
    for model_name, (accuracy, _, _, _, _) in results.items():
        if model_name == "fixed_margin":
            continue
        diff = accuracy - baseline
        if diff > 0:
            message = f"{model_name} is better than fixed_margin: +{diff:.4f} (+{diff*100:.2f}%)"
        elif diff < 0:
            message = f"{model_name} is worse than fixed_margin: {diff:.4f} ({diff*100:.2f}%)"
        else:
            message = f"{model_name} is equal to fixed_margin"
        print(message)

# ====== Main test function ======
def test_models_on_lfw():
    """Evaluate every available checkpoint on LFW pairs and report the results.

    Steps: build one pair dataset, evaluate each checkpoint that exists on
    disk against it, plot the ROC curves, print a summary plus analysis, and
    write the summary to ``/kaggle/working/lfw_test_results.txt``.
    """
    # Resize, centre-crop to 112x112, then normalise with the mean/std below.
    preprocessing = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.CenterCrop((112, 112)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    print("Creating LFW test dataset...")
    test_dataset = LFWDataset(
        lfw_path=test_config.LFW_PATH,
        transform=preprocessing,
        num_pairs=600
    )

    if len(test_dataset) == 0:
        print("ERROR: Failed to create LFW test dataset")
        return

    # shuffle=False: every model is scored on the pairs in the same order.
    test_loader = DataLoader(test_dataset, batch_size=test_config.BATCH_SIZE, shuffle=False)

    model_paths = {
        "fixed_margin": "/kaggle/working/models/best_model_fixed_margin.pth",
        "quality_adaptive": "/kaggle/working/models/best_model_quality_adaptive.pth",
        "confidence_adaptive": "/kaggle/working/models/best_model_confidence_adaptive.pth",
        "easy_hard_norm": "/kaggle/working/models/best_model_easy_hard_norm.pth"
    }

    results = {}
    banner = '=' * 50

    for model_name, model_path in model_paths.items():
        if not os.path.exists(model_path):
            print(f"WARNING: Model not found: {model_path}")
            continue

        print(f"\n{banner}")
        print(f"Loading and testing model: {model_name}...")
        print(f"{banner}")

        model = load_trained_model(model_path)
        if model is None:
            continue

        # Evaluation compares plain cosine similarities; no training-time margin is applied.
        results[model_name] = tuple(evaluate_model(model, test_loader, model_name))

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if not results:
        print("ERROR: No models were successfully tested")
        return

    plot_roc_curves(results)

    # One formatted line per model, shared by the console output and the report file.
    summary_lines = [
        f"{model_name:20} | Accuracy: {accuracy:.4f} | AUC: {roc_auc:.4f} | Threshold: {threshold:.4f}"
        for model_name, (accuracy, roc_auc, threshold, _, _) in results.items()
    ]

    print("\nLFW final results:")
    print("=" * 50)
    for line in summary_lines:
        print(line)

    best_model = max(results.items(), key=lambda x: x[1][0])
    best_line = f"Best model: {best_model[0]} (Accuracy: {best_model[1][0]:.4f})"
    print(f"\n{best_line}")

    analyze_results(results)

    with open('/kaggle/working/lfw_test_results.txt', 'w') as f:
        f.write("LFW Test Results\n")
        f.write("=" * 50 + "\n")
        for line in summary_lines:
            f.write(line + "\n")
        f.write(f"\n{best_line}\n")

    print("Results saved to: /kaggle/working/lfw_test_results.txt")

# ====== Run tests ======
if __name__ == "__main__":
    # Announce the run and the dataset location, then start the evaluation.
    print(
        "Starting LFW evaluation...",
        f"Using LFW path: {test_config.LFW_PATH}",
        sep="\n",
    )
    test_models_on_lfw()