In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found anywhere under /kaggle/input, one per line.
for current_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for full_path in (os.path.join(current_dir, name) for name in file_names):
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# ====== Training Code ======
import os
import random
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt

# ================= Configuration =================
class Config:
    """
    Experiment-wide hyperparameters shared by every margin mode.

    Every value is a class attribute, so ``Config.X`` and ``config.X`` are
    interchangeable.
    """

    # Compute device: CUDA when available, otherwise CPU.
    device: torch.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Root directory of the face images (one sub-folder per identity).
    MS1M_PATH: str = "/kaggle/input/ms1m-arcface-dataset/ms1m-arcface"

    # Optimisation settings.
    BATCH_SIZE: int = 32
    NUM_EPOCHS: int = 15
    LEARNING_RATE: float = 0.01

    # Head / margin settings.
    FEATURE_DIM: int = 512
    SCALE_FACTOR: float = 30.0
    BASE_MARGIN: float = 0.5
    ALPHA: float = 0.3

    # Margin strategies, in the order they are trained and plotted.
    MODES: list[str] = [
        "fixed_margin",
        "quality_adaptive",
        "confidence_adaptive",
        "easy_hard_norm",
    ]


config = Config()
print(f"Using device: {config.device}")

# ================= Dataset =================
class FaceDataset(Dataset):
    """
    Custom face dataset built from a folder-per-identity directory tree.

    1. Randomly selects up to ``num_identities`` identity folders under
       ``root_dir`` and up to ``samples_per_identity`` images from each.
       Identities with fewer than ``min_images_per_identity`` images are skipped.
    2. ``__getitem__`` returns ``(image, label)``: the RGB PIL image passed
       through ``transform`` (if any) and an integer class label.

    Directory listings and the identity -> label mapping are sorted, so for a
    given ``random`` module state the selection and the labels are reproducible
    across processes (``set`` iteration order of strings is not).
    """
    IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, root_dir, num_identities=200, samples_per_identity=15, transform=None,
                 min_images_per_identity=5):
        self.transform = transform
        self.samples = []

        all_folders = sorted(
            f for f in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, f))
        )
        selected_identities = random.sample(all_folders, min(num_identities, len(all_folders)))

        for identity in selected_identities:
            identity_path = os.path.join(root_dir, identity)
            image_files = sorted(
                f for f in os.listdir(identity_path) if f.lower().endswith(self.IMAGE_EXTENSIONS)
            )
            if len(image_files) < min_images_per_identity:
                continue

            chosen = random.sample(image_files, min(samples_per_identity, len(image_files)))
            self.samples.extend((os.path.join(identity_path, f), identity) for f in chosen)

        # Sorted so that label indices do not depend on hash randomization.
        unique_identities = sorted({identity for _, identity in self.samples})
        self.identity_to_label = {identity: idx for idx, identity in enumerate(unique_identities)}
        self.num_classes = len(unique_identities)

        print(f"Loaded {len(self.samples)} images, {self.num_classes} identities")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, identity = self.samples[idx]
        # The context manager closes the file handle; convert() returns a loaded copy.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.identity_to_label[identity]
        if self.transform:
            image = self.transform(image)
        return image, label

# ================= ArcFace Model =================
class ArcFaceModel(nn.Module):
    """
    ResNet18 backbone + cosine classifier with an optional ArcFace margin.

    1. Supports a per-sample margin (scalar, or tensor of shape (B,)).
    2. When ``labels`` are passed to ``forward`` together with a margin, the
       margin is applied as an additive angular margin to the target-class
       logit only: cos(theta_y + m) (ArcFace). Without labels the legacy
       behaviour is kept: the margin is subtracted from every class cosine,
       which shifts each row by a constant and therefore changes neither the
       softmax probabilities nor the argmax.
    3. ``get_features`` exposes raw backbone features (no dropout), e.g. for
       margin calculation.
    """
    def __init__(self, num_classes):
        super().__init__()
        # NOTE(review): self.backbone still holds ResNet18's ImageNet fc layer,
        # whose parameters are registered although forward never uses them;
        # kept so that state_dict keys stay unchanged.
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Every backbone child except the final fc layer (shares modules with self.backbone).
        self.feature_extractor = nn.Sequential(*list(self.backbone.children())[:-1])
        self.feature_dim = 512
        self.fc = nn.Linear(self.feature_dim, num_classes, bias=False)
        nn.init.normal_(self.fc.weight, std=0.01)
        self.dropout = nn.Dropout(0.2)

    def get_features(self, x):
        """
        Extract raw backbone features (without dropout), flattened to (B, -1).
        """
        feat = self.feature_extractor(x)
        return feat.view(feat.size(0), -1)

    @staticmethod
    def _resolve_margin(margin, cosine):
        """
        Convert ``margin`` to a float, or to a (B, 1) tensor on ``cosine``'s device.

        None -> 0.0; a 1-D tensor of length B -> (B, 1) per-sample margins;
        any other tensor -> its mean as a float; anything else -> float(margin).
        """
        if margin is None:
            return 0.0
        if isinstance(margin, torch.Tensor):
            if margin.dim() == 1 and margin.size(0) == cosine.size(0):
                return margin.view(-1, 1).to(cosine.device)
            return float(margin.mean().item())
        return float(margin)

    @staticmethod
    def _angular_margin(cosine, labels, margin):
        """
        Replace each target cosine cos(theta) with cos(theta + m).

        cosine: (B, C) cosines; labels: (B,) class indices; margin: (B, 1) radians.
        Non-target entries are returned unchanged.
        """
        index = labels.view(-1, 1).long()
        m = margin.to(dtype=cosine.dtype)
        # Clamp so that sqrt(1 - cos^2) keeps a finite gradient.
        cos_t = cosine.gather(1, index).clamp(-1.0 + 1e-6, 1.0 - 1e-6)
        sin_t = torch.sqrt(1.0 - cos_t * cos_t)
        phi = cos_t * torch.cos(m) - sin_t * torch.sin(m)  # cos(theta + m)
        # When theta + m > pi, cos(theta + m) would grow with theta; fall back to
        # cos(theta) - m * sin(m) there (InsightFace convention) to stay monotonic.
        phi = torch.where(cos_t > torch.cos(np.pi - m), phi, cos_t - torch.sin(np.pi - m) * m)
        return cosine.scatter(1, index, phi)

    def forward(self, x, margin=None, labels=None):
        """
        Forward pass.

        margin: None / scalar / tensor(B,) for per-sample margin (other tensor
            shapes are reduced to their mean).
        labels: optional (B,) target indices. With labels and a margin, the
            margin is applied ArcFace-style to the target logit only; otherwise
            the legacy uniform subtraction is used.
        Returns (feat, logits): ``feat`` is the backbone feature after dropout
        (not normalized); ``logits`` are cosines scaled by config.SCALE_FACTOR.
        """
        feat = self.dropout(self.get_features(x))
        feat_norm = F.normalize(feat, p=2, dim=1)
        weight_norm = F.normalize(self.fc.weight, p=2, dim=1)
        cosine = torch.matmul(feat_norm, weight_norm.t())

        margin_value = self._resolve_margin(margin, cosine)
        if labels is not None and margin is not None:
            if not isinstance(margin_value, torch.Tensor):
                margin_value = cosine.new_full((cosine.size(0), 1), margin_value)
            cosine = self._angular_margin(cosine, labels.to(cosine.device), margin_value)
        else:
            # Legacy path: subtracting the same value from every class cosine in a
            # row leaves softmax / cross-entropy and argmax unchanged.
            cosine = cosine - margin_value

        output = cosine * config.SCALE_FACTOR
        return feat, output

# ================= Margin Calculation =================
# Normalize() statistics used by the training transform in train_single_model.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def _zscore(values):
    """
    Return (values - mean) / (std + 1e-8); all zeros when fewer than 2 values.

    torch's default (unbiased) std of a single element is NaN, which would
    otherwise poison every margin of a size-1 batch.
    """
    if values.numel() < 2:
        return torch.zeros_like(values)
    return (values - values.mean()) / (values.std() + 1e-8)


def _sharpness_scores(images, image_mean=None, image_std=None, limit=None):
    """
    Per-image Laplacian-variance sharpness, divided by 1000 and capped at 1.

    images: batch tensor, assumed (B, 3, H, W) RGB — TODO confirm for other callers.
    image_mean / image_std: if both are given, the Normalize step is undone
        first so pixels are back in [0, 1] before the uint8 conversion.
    limit: score only the first ``limit`` images (None = all).
    Returns a float32 NumPy array with one score in [0, 1] per scored image.
    """
    batch = images.detach().cpu().float()
    if limit is not None:
        batch = batch[:limit]
    if image_mean is not None and image_std is not None:
        mean = torch.tensor(image_mean, dtype=batch.dtype).view(1, -1, 1, 1)
        std = torch.tensor(image_std, dtype=batch.dtype).view(1, -1, 1, 1)
        batch = batch * std + mean
    # Clamp so the uint8 cast below cannot wrap negative / >1 values around.
    batch = batch.clamp(0.0, 1.0)

    scores = []
    for img in batch:
        img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        sharpness = cv2.Laplacian(gray, cv2.CV_64F).var()
        scores.append(min(sharpness / 1000.0, 1.0))
    return np.asarray(scores, dtype=np.float32)


def calculate_margin(mode, features=None, images=None, logits=None, device=None,
                     image_mean=_IMAGENET_MEAN, image_std=_IMAGENET_STD):
    """
    Returns a per-sample margin tensor of shape (B,).

    1. fixed_margin: all samples share config.BASE_MARGIN.
    2. quality_adaptive: BASE_MARGIN * (1 + ALPHA * sharpness), per image.
    3. confidence_adaptive: combines the z-scored feature norm (clamped to
       [-1, 1], per sample) with the mean sharpness of the first 5 images.
    4. easy_hard_norm: BASE_MARGIN * (1 + ALPHA) for samples whose feature norm
       is at or above the batch mean, BASE_MARGIN * (1 - ALPHA) otherwise.
    Unknown modes fall back to fixed_margin.

    features: (B, D) backbone features (required by modes 3 and 4).
    images: image batch (required by modes 2 and 3); by default assumed to be
        normalized with ImageNet mean/std, which is undone before scoring.
        Pass image_mean=None / image_std=None for images already in [0, 1].
    logits: unused by every mode; accepted for interface compatibility.
    device: target device (defaults to config.device).
    """
    device = device if device is not None else config.device
    reference = features if features is not None else images
    B = reference.size(0)

    if mode == "quality_adaptive":
        quality = _sharpness_scores(images, image_mean, image_std)
        margins = config.BASE_MARGIN * (1.0 + config.ALPHA * quality)
        return torch.from_numpy(margins.astype(np.float32)).to(device)

    if mode == "confidence_adaptive":
        feat_norms = torch.norm(features, p=2, dim=1)
        conf_norm = torch.clamp(_zscore(feat_norms), -1.0, 1.0)
        # Image quality from a small subset of the batch (first 5 images).
        quality = _sharpness_scores(images, image_mean, image_std, limit=5)
        avg_quality = float(quality.mean()) if quality.size else 0.5
        combined_conf = 0.7 * conf_norm + 0.3 * avg_quality
        margins = config.BASE_MARGIN * (0.8 + 0.2 * combined_conf)
        return margins.to(device)

    if mode == "easy_hard_norm":
        norm_scores = _zscore(torch.norm(features, p=2, dim=1))
        easy = torch.full_like(norm_scores, config.BASE_MARGIN * (1 + config.ALPHA))
        hard = torch.full_like(norm_scores, config.BASE_MARGIN * (1 - config.ALPHA))
        return torch.where(norm_scores >= 0, easy, hard).to(device)

    # "fixed_margin" and any unrecognized mode: constant margin for every sample.
    return torch.full((B,), config.BASE_MARGIN, device=device, dtype=torch.float32)

# ================= Training Function =================
class _TransformedSubset(Dataset):
    """
    View over ``dataset[indices]`` that applies its own ``transform`` to the image.

    Lets the train and validation splits of one dataset use different
    transforms. Assumes ``dataset[i]`` returns ``(image, label)``.
    """

    def __init__(self, dataset, indices, transform):
        self.dataset = dataset
        self.indices = list(indices)
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        image, label = self.dataset[self.indices[idx]]
        if self.transform is not None:
            image = self.transform(image)
        return image, label


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU + CUDA) random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _arcface_logits(cosine, labels, margins, scale):
    """
    Return ``scale * cosine`` with an additive angular margin on each target class.

    cosine: (B, C) cosines between normalized features and class weights.
    labels: (B,) integer class indices.
    margins: (B,) per-sample margins in radians.
    The target entry becomes cos(theta_y + m); non-target entries are unchanged.
    """
    index = labels.view(-1, 1)
    m = margins.view(-1, 1).to(device=cosine.device, dtype=cosine.dtype)
    # Clamp so that sqrt(1 - cos^2) keeps a finite gradient.
    cos_t = cosine.gather(1, index).clamp(-1.0 + 1e-6, 1.0 - 1e-6)
    sin_t = torch.sqrt(1.0 - cos_t * cos_t)
    phi = cos_t * torch.cos(m) - sin_t * torch.sin(m)  # cos(theta + m)
    # When theta + m > pi, cos(theta + m) would grow with theta; fall back to
    # cos(theta) - m * sin(m) there (InsightFace convention) to stay monotonic.
    phi = torch.where(cos_t > torch.cos(np.pi - m), phi, cos_t - torch.sin(np.pi - m) * m)
    return cosine.scatter(1, index, phi) * scale


def train_single_model(mode, seed=42):
    """
    Train one ArcFace model with the given margin strategy.

    mode: one of config.MODES, forwarded to ``calculate_margin``.
    seed: seeds all RNGs before the dataset is sampled, so every mode trains on
        the same identities, the same train/val split and the same initial
        classifier weights (otherwise the comparison between modes is confounded).

    The margin is applied as an additive angular margin to the target-class
    logit only. (Subtracting it from every logit would be a no-op for
    cross-entropy.) One backbone pass is made per training batch, so BatchNorm
    running statistics are updated once per batch.

    Returns: (model, val_loader, best_val_acc, train_loss_list, val_acc_list)
        model -- the model after the last epoch run; the best-val-acc weights
                 are saved to /kaggle/working/models/best_model_{mode}.pth
        val_loader -- DataLoader over the validation split
        best_val_acc -- highest validation accuracy observed
        train_loss_list, val_acc_list -- per-epoch curves
    """
    print(f"\n🎯 Start training {mode} model...")
    _seed_everything(seed)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.RandomCrop((112, 112)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        normalize,
    ])
    # Deterministic counterpart for validation: no random crop / flip / jitter.
    eval_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.CenterCrop((112, 112)),
        transforms.ToTensor(),
        normalize,
    ])

    # Load untransformed images; each split applies its own transform.
    dataset = FaceDataset(config.MS1M_PATH, num_identities=200, samples_per_identity=15, transform=None)

    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_split, val_split = torch.utils.data.random_split(dataset, [train_size, val_size])
    train_dataset = _TransformedSubset(dataset, train_split.indices, train_transform)
    val_dataset = _TransformedSubset(dataset, val_split.indices, eval_transform)
    train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False)

    model = ArcFaceModel(num_classes=dataset.num_classes).to(config.device)
    optimizer = optim.SGD(model.parameters(), lr=config.LEARNING_RATE, momentum=0.9, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.5)

    best_val_acc = 0.0
    patience = 5
    patience_counter = 0
    os.makedirs('/kaggle/working/models', exist_ok=True)

    # per-epoch curves
    train_loss_list = []
    val_acc_list = []

    for epoch in range(config.NUM_EPOCHS):
        # ===== Training =====
        model.train()
        epoch_loss = 0.0
        train_correct, train_total = 0, 0

        for imgs, labels in tqdm(train_loader, desc=f"Training {mode}"):
            imgs, labels = imgs.to(config.device), labels.to(config.device)

            # Single backbone pass; the same features feed the margin and the loss.
            raw_features = model.get_features(imgs)
            cosine = torch.matmul(
                F.normalize(model.dropout(raw_features), p=2, dim=1),
                F.normalize(model.fc.weight, p=2, dim=1).t(),
            )

            # Margins are treated as constants (no gradient).
            with torch.no_grad():
                margins = calculate_margin(mode, features=raw_features.detach(), images=imgs,
                                           logits=cosine.detach() * config.SCALE_FACTOR,
                                           device=config.device)

            outputs = _arcface_logits(cosine, labels, margins, config.SCALE_FACTOR)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accuracy on margin-free cosines, so it is comparable across modes.
            predicted = cosine.detach().argmax(dim=1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            epoch_loss += loss.item()

        avg_train_loss = epoch_loss / len(train_loader) if len(train_loader) > 0 else 0.0
        train_loss_list.append(avg_train_loss)
        train_acc = train_correct / train_total if train_total > 0 else 0.0

        # ===== Validation =====
        model.eval()
        val_correct, val_total = 0, 0
        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(config.device), labels.to(config.device)
                _, outputs = model(imgs, margin=None)
                predicted = outputs.argmax(dim=1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
        val_acc = val_correct / val_total if val_total > 0 else 0.0
        val_acc_list.append(val_acc)

        scheduler.step()
        print(f"{mode} Epoch {epoch+1}: train_loss={avg_train_loss:.4f}, train_acc={train_acc:.4f}, val_acc={val_acc:.4f}")

        # early stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save(model.state_dict(), f'/kaggle/working/models/best_model_{mode}.pth')
            print(f"Saved best model: val_acc = {val_acc:.4f}")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    print(f"{mode} model training completed, best val_acc: {best_val_acc:.4f}")
    return model, val_loader, best_val_acc, train_loss_list, val_acc_list

# ================= Main Loop (Train Four Models) + Plot =================
def _plot_mode_curves(curves, colors, title, ylabel, out_path, **grid_kwargs):
    """
    Plot one line per mode (in config.MODES order), save it to ``out_path`` and show it.

    curves: mapping mode -> per-epoch values; modes missing from it are skipped.
    colors: mapping mode -> matplotlib color.
    grid_kwargs: extra keyword arguments forwarded to ``plt.grid``.
    """
    plt.figure(figsize=(10, 6))
    max_epochs = max((len(values) for values in curves.values()), default=0)
    for mode in config.MODES:
        if mode in curves:
            epochs = range(1, len(curves[mode]) + 1)
            plt.plot(epochs, curves[mode], label=mode, color=colors.get(mode), linewidth=1)
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.xticks(range(1, max_epochs + 1))
    if curves:
        # plt.legend() warns when no labelled line was drawn.
        plt.legend()
    plt.grid(True, **grid_kwargs)
    plt.savefig(out_path, bbox_inches='tight')
    plt.show()
    print(f"Saved: {out_path}")


if __name__ == "__main__":
    import traceback

    print("Start training four margin models...")
    results = {}
    curve_loss = {}
    curve_valacc = {}

    for mode in config.MODES:
        try:
            model, val_loader, best_acc, train_loss_list, val_acc_list = train_single_model(mode)
        except Exception as e:
            # Continue with the remaining modes, but keep the full traceback visible.
            print(f"{mode} model training failed: {e}")
            traceback.print_exc()
            continue
        # Release the finished model before the next mode allocates its own.
        del model, val_loader
        results[mode] = best_acc
        curve_loss[mode] = train_loss_list
        curve_valacc[mode] = val_acc_list

    print("\n All models training completed!")
    print("Final results:")
    for mode, acc in results.items():
        print(f"  {mode}: val_acc = {acc:.4f}")

    # ===== Plot setup (four models in one figure) =====
    os.makedirs('/kaggle/working/plots', exist_ok=True)

    # color dictionary: distinguishable colors
    color_list = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']
    colors = {mode: color_list[i % len(color_list)] for i, mode in enumerate(config.MODES)}

    # ===== Figure 1: Train Loss =====
    _plot_mode_curves(curve_loss, colors,
                      "Train Loss Comparison (All Models)", "Train Loss",
                      '/kaggle/working/plots/train_loss_comparison.pdf')

    # ===== Figure 2: Validation Accuracy =====
    _plot_mode_curves(curve_valacc, colors,
                      "Validation Accuracy Comparison (All Models)", "Validation Accuracy",
                      '/kaggle/working/plots/val_acc_comparison.pdf',
                      linestyle='-', color='lightgray', linewidth=0.8, alpha=0.7)

    # List saved model files (if any)
    models_dir = '/kaggle/working/models'
    if os.path.exists(models_dir):
        saved_models = [f for f in os.listdir(models_dir) if f.endswith('.pth')]
    else:
        saved_models = []
    print(f"Saved model files: {saved_models}")
