In [1]:
# cifar10_asctl.py
# A-SCTL on CIFAR-10 with architecture and training flow matching MNIST implementation

import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler, Subset, random_split
from torchvision import datasets, transforms
from tqdm import tqdm
from collections import defaultdict


# ==================== Configuration ====================
class Config:
    data_dir = "./data"
    out_dir = "./output_cifar10"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Network parameters
    embedding_dim = 128
    
    # PK Sampling
    batch_p = 10      # P classes per batch
    batch_k = 6       # K samples per class
    
    # A-SCTL parameters (same as MNIST / paper)
    margin_intra = 0.01   # m1
    margin_inter = 1.0    # m2
    beta_min = 0.1
    beta_max = 0.9
    eps = 1e-8
    
    # Training parameters
    lr = 1e-3
    momentum = 0.9
    weight_decay = 1e-5
    epochs = 20
    seed = 42
    use_lr_scheduler = True  # cosine annealing, same as MNIST
    
    # Prediction network training
    pred_lr = 1e-3
    pred_momentum = 0.9
    pred_epochs = 10
    pred_batch_size = 256
    
    # Dataset split (80/10/10 as per paper)
    train_split = 0.8
    test_split = 0.1
    val_split = 0.1
    
    # Monitoring and debugging
    num_workers = 2
    pin_memory = True
    print_every = 50
    save_embeddings = True
    embeddings_file = "cifar10_train_embeddings.npy"
    labels_file = "cifar10_train_labels.npy"
    debug_triplets = True
    log_beta_history = True
    beta_history_file = "cifar10_beta_evolution.npy"
    validate_every = 5  # Validate every N epochs
    monitor_collapse = True  # Check for embedding collapse


# ==================== PK Batch Sampler ====================
class PKSampler(Sampler):
    """Samples P classes, then K samples per class."""
    def __init__(self, labels, P, K):
        self.labels = np.array(labels)
        self.P = P
        self.K = K
        
        # Group indices by label
        self.label_to_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            self.label_to_indices[label].append(idx)
        
        # Verify each class has at least K samples
        for label, indices in self.label_to_indices.items():
            if len(indices) < self.K:
                raise ValueError(f"Label {label} has only {len(indices)} samples, need {self.K}")
        
        self.labels_set = list(self.label_to_indices.keys())
        self.num_samples = len(labels)
        
        # Calculate batches per epoch
        self.batches_per_epoch = max(1, self.num_samples // (self.P * self.K))
        
        print(f"PKSampler initialized: {len(self.labels_set)} classes, "
              f"{self.batches_per_epoch} batches/epoch, "
              f"batch_size={self.P*self.K}")
    
    def __iter__(self):
        for _ in range(self.batches_per_epoch):
            # Randomly select P classes
            if self.P > len(self.labels_set):
                raise ValueError(f"P={self.P} exceeds number of classes {len(self.labels_set)}")
            
            selected_classes = np.random.choice(self.labels_set, self.P, replace=False)
            batch_indices = []
            
            # Sample K instances from each selected class
            for cls in selected_classes:
                indices = np.random.choice(self.label_to_indices[cls], self.K, replace=False)
                batch_indices.extend(indices.tolist())
            
            yield batch_indices
    
    def __len__(self):
        return self.batches_per_epoch


# ==================== CNN Encoder (Paper Architecture, 3-channel CIFAR) ====================
class EncoderCNN(nn.Module):
    """
    CNN architecture from paper Section IV-B:
    Input: 3 channels (CIFAR-10)
    Conv layers: (3, 5x5) -> (64, 5x5) -> (128, 3x3) -> (256, 3x3) -> (128, 2x2)
    Output: 128-dimensional L2-normalized embeddings
    """
    def __init__(self, embedding_dim=128, input_channels=3):
        super().__init__()
        
        # Convolutional layers as specified in paper
        self.features = nn.Sequential(
            # Layer 1: (3, 5x5) -> 64 channels
            nn.Conv2d(input_channels, 64, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 32x32 -> 16x16
            
            # Layer 2: (64, 5x5) -> 128 channels
            nn.Conv2d(64, 128, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 16x16 -> 8x8
            
            # Layer 3: (128, 3x3) -> 256 channels
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            
            # Layer 4: (256, 3x3) -> 128 channels
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            
            # Layer 5: (128, 2x2) -> 128 channels
            nn.Conv2d(128, 128, kernel_size=2, padding=0),
            nn.ReLU(inplace=True),
            
            # Global average pooling
            nn.AdaptiveAvgPool2d(1)
        )
        
        # Fully connected to embedding dimension
        self.fc = nn.Linear(128, embedding_dim)
        self.bn = nn.BatchNorm1d(embedding_dim)
    
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.bn(x)
        # L2 normalization - ensures embeddings have unit norm
        x = F.normalize(x, p=2, dim=1)
        return x


# ==================== Prediction Network ====================
class PredictionNetwork(nn.Module):
    """
    Simple 1-layer MLP for classification from embeddings.
    Same as MNIST implementation: 128-dim -> hidden -> 10-dim.
    """
    def __init__(self, embedding_dim=128, num_classes=10, hidden_dim=256):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, num_classes)
        )
    
    def forward(self, x):
        return self.fc(x)


# ==================== Enhanced Triplet Mining ====================
def mine_hard_triplets_batch(embeddings, labels, debug=False, return_stats=False):
    """
    Same as MNIST implementation.
    Enhanced online hard triplet mining with quality metrics.
    """
    embeddings = embeddings.detach()
    labels_np = labels.cpu().numpy()
    N = embeddings.size(0)
    
    # Validate batch composition
    unique_classes = np.unique(labels_np)
    if len(unique_classes) < 2:
        if debug:
            print(f"⚠️  WARNING: Only {len(unique_classes)} class(es) in batch!")
        return [] if not return_stats else ([], {})
    
    # Compute pairwise distances
    dists = torch.cdist(embeddings, embeddings, p=2)
    
    # Initialize statistics
    stats = {
        'avg_d_ap': 0.0,
        'avg_d_an': 0.0,
        'margin_violations': 0.0,
        'semi_hard_ratio': 0.0,
        'min_dist': 0.0,
        'max_dist': 0.0,
        'mean_dist': 0.0
    }
    
    # Debug: Check distance statistics
    if debug or return_stats:
        max_dist = dists.max().item()
        min_dist = dists[dists > 0].min().item() if (dists > 0).any() else 0
        mean_dist = dists[dists > 0].mean().item() if (dists > 0).any() else 0
        stats['min_dist'] = min_dist
        stats['max_dist'] = max_dist
        stats['mean_dist'] = mean_dist
        
        if debug:
            print(f"  Distance range: [{min_dist:.4f}, {max_dist:.4f}], mean: {mean_dist:.4f}")
            norms = torch.norm(embeddings, p=2, dim=1)
            print(f"  Embedding norms: mean={norms.mean():.4f}, std={norms.std():.4f}")
    
    triplets = []
    no_pos_count = 0
    no_neg_count = 0
    semi_hard_count = 0
    d_ap_sum = 0.0
    d_an_sum = 0.0
    violation_count = 0
    
    for i in range(N):
        anchor_label = labels_np[i]
        
        pos_mask = (labels_np == anchor_label)
        pos_mask[i] = False
        pos_indices = np.where(pos_mask)[0]
        
        if len(pos_indices) == 0:
            no_pos_count += 1
            continue
        
        neg_indices = np.where(labels_np != anchor_label)[0]
        if len(neg_indices) == 0:
            no_neg_count += 1
            continue
        
        pos_dists = dists[i, pos_indices]
        hardest_pos_idx = pos_indices[torch.argmax(pos_dists).item()]
        d_ap = dists[i, hardest_pos_idx].item()
        
        neg_dists = dists[i, neg_indices]
        
        semi_hard_mask = neg_dists > d_ap
        if semi_hard_mask.any():
            semi_hard_indices = neg_indices[semi_hard_mask.cpu().numpy()]
            semi_hard_dists = neg_dists[semi_hard_mask]
            hardest_neg_idx = semi_hard_indices[torch.argmin(semi_hard_dists).item()]
            semi_hard_count += 1
        else:
            hardest_neg_idx = neg_indices[torch.argmin(neg_dists).item()]
        
        d_an = dists[i, hardest_neg_idx].item()
        
        d_ap_sum += d_ap
        d_an_sum += d_an
        if d_ap >= d_an:
            violation_count += 1
        
        triplets.append((i, hardest_pos_idx, int(hardest_neg_idx)))
    
    if debug and (no_pos_count > 0 or no_neg_count > 0):
        print(f"  Skipped: {no_pos_count} (no pos), {no_neg_count} (no neg)")
    
    if len(triplets) == 0:
        if debug:
            print(f"⚠️  WARNING: No valid triplets found! Batch size={N}, Classes={len(unique_classes)}")
        return [] if not return_stats else ([], stats)
    
    if return_stats:
        stats['avg_d_ap'] = d_ap_sum / len(triplets)
        stats['avg_d_an'] = d_an_sum / len(triplets)
        stats['margin_violations'] = violation_count / len(triplets)
        stats['semi_hard_ratio'] = semi_hard_count / len(triplets)
    
    return triplets if not return_stats else (triplets, stats)


# ==================== Embedding Collapse Monitor ====================
def check_embedding_collapse(embeddings, threshold=0.1):
    with torch.no_grad():
        pairwise_dists = torch.cdist(embeddings, embeddings, p=2)
        mask = ~torch.eye(pairwise_dists.size(0), dtype=torch.bool, device=pairwise_dists.device)
        valid_dists = pairwise_dists[mask]
        
        mean_dist = valid_dists.mean().item()
        std_dist = valid_dists.std().item()
        is_collapsing = mean_dist < threshold
        
    return is_collapsing, mean_dist, std_dist


# ==================== Validation Function ====================
def validate_embedding(model, val_dataset, cfg):
    model.eval()
    
    val_labels = [int(val_dataset.dataset[idx][1]) for idx in val_dataset.indices]
    pk_sampler = PKSampler(val_labels, cfg.batch_p, cfg.batch_k)
    
    val_loader = DataLoader(
        val_dataset,
        batch_sampler=pk_sampler,
        num_workers=cfg.num_workers,
        pin_memory=cfg.pin_memory
    )
    
    total_loss = 0.0
    total_lintra = 0.0
    total_linter = 0.0
    num_batches = 0
    
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(cfg.device)
            labels = labels.to(cfg.device)
            
            embeddings = model(images)
            triplets = mine_hard_triplets_batch(embeddings, labels, debug=False)
            
            if len(triplets) == 0:
                continue
            
            d_ap_list = []
            d_an_list = []
            
            for (anchor_idx, pos_idx, neg_idx) in triplets:
                d_ap = torch.norm(embeddings[anchor_idx] - embeddings[pos_idx], p=2)
                d_an = torch.norm(embeddings[anchor_idx] - embeddings[neg_idx], p=2)
                d_ap_list.append(d_ap)
                d_an_list.append(d_an)
            
            d_ap = torch.stack(d_ap_list)
            d_an = torch.stack(d_an_list)
            
            L_intra = F.relu(d_ap - cfg.margin_intra).mean()
            L_inter = F.relu(cfg.margin_inter - d_an).mean()
            
            beta = L_intra / (L_intra + L_inter + cfg.eps)
            beta = torch.clamp(beta, cfg.beta_min, cfg.beta_max).item()
            
            loss = (1.0 - beta) * L_intra + beta * L_inter
            
            total_loss += loss.item()
            total_lintra += L_intra.item()
            total_linter += L_inter.item()
            num_batches += 1
    
    if num_batches == 0:
        return 0.0, 0.0, 0.0
    
    avg_loss = total_loss / num_batches
    avg_lintra = total_lintra / num_batches
    avg_linter = total_linter / num_batches
    
    return avg_loss, avg_lintra, avg_linter


# ==================== Enhanced Training with Monitoring ====================
def train_epoch_embedding(model, loader, optimizer, scheduler, cfg, epoch, val_dataset=None):
    model.train()
    
    epoch_loss = 0.0
    epoch_lintra = 0.0
    epoch_linter = 0.0
    epoch_betas = []
    epoch_stats = defaultdict(list)
    num_batches = 0
    total_triplets = 0
    
    pbar = tqdm(loader, desc=f"Epoch {epoch}/{cfg.epochs}", leave=False)
    
    for batch_idx, (images, labels) in enumerate(pbar):
        images = images.to(cfg.device)
        labels = labels.to(cfg.device)
        
        optimizer.zero_grad()
        embeddings = model(images)
        
        if cfg.monitor_collapse:
            is_collapsing, mean_dist, std_dist = check_embedding_collapse(embeddings)
            if is_collapsing and batch_idx % 100 == 0:
                print(f"\n⚠️  WARNING: Potential embedding collapse! Mean dist: {mean_dist:.4f}")
        
        debug_mode = (cfg.debug_triplets and epoch == 1 and batch_idx == 0)
        
        return_stats = (batch_idx == 0)
        result = mine_hard_triplets_batch(embeddings, labels, debug=debug_mode, return_stats=return_stats)
        
        if return_stats:
            triplets, stats = result
            for key, val in stats.items():
                epoch_stats[key].append(val)
        else:
            triplets = result
        
        if len(triplets) == 0:
            continue
        
        d_ap_list = []
        d_an_list = []
        
        for (anchor_idx, pos_idx, neg_idx) in triplets:
            d_ap = torch.norm(embeddings[anchor_idx] - embeddings[pos_idx], p=2)
            d_an = torch.norm(embeddings[anchor_idx] - embeddings[neg_idx], p=2)
            d_ap_list.append(d_ap)
            d_an_list.append(d_an)
        
        d_ap = torch.stack(d_ap_list)
        d_an = torch.stack(d_an_list)
        
        L_intra_terms = F.relu(d_ap - cfg.margin_intra)
        L_inter_terms = F.relu(cfg.margin_inter - d_an)
        
        L_intra = L_intra_terms.mean()
        L_inter = L_inter_terms.mean()
        
        with torch.no_grad():
            beta = L_intra / (L_intra + L_inter + cfg.eps)
            beta = torch.clamp(beta, cfg.beta_min, cfg.beta_max).item()
        
        loss = (1.0 - beta) * L_intra + beta * L_inter
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_lintra += L_intra.item()
        epoch_linter += L_inter.item()
        epoch_betas.append(beta)
        num_batches += 1
        total_triplets += len(triplets)
        
        if (batch_idx + 1) % cfg.print_every == 0 or batch_idx == 0:
            avg_loss = epoch_loss / num_batches
            avg_lintra = epoch_lintra / num_batches
            avg_linter = epoch_linter / num_batches
            avg_beta = np.mean(epoch_betas)
            
            pbar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'L_in': f'{avg_lintra:.4f}',
                'L_out': f'{avg_linter:.4f}',
                'β': f'{avg_beta:.3f}',
                'trip': total_triplets
            })
    
    if scheduler is not None:
        scheduler.step()
    
    val_info = ""
    if val_dataset is not None and epoch % cfg.validate_every == 0:
        val_loss, val_lintra, val_linter = validate_embedding(model, val_dataset, cfg)
        val_info = f" | Val Loss: {val_loss:.4f}"
    
    avg_loss = epoch_loss / max(num_batches, 1)
    avg_lintra = epoch_lintra / max(num_batches, 1)
    avg_linter = epoch_linter / max(num_batches, 1)
    avg_beta = np.mean(epoch_betas) if len(epoch_betas) > 0 else 0.5
    
    if epoch_stats:
        stats_str = " | ".join([f"{k}: {np.mean(v):.4f}" for k, v in epoch_stats.items() if v])
        if stats_str:
            print(f"  Triplet stats: {stats_str}")
    
    return avg_loss, avg_lintra, avg_linter, avg_beta, val_info


def train_prediction_network(encoder, pred_net, train_dataset, test_dataset, cfg):
    print("\n" + "="*60)
    print("Training Prediction Network (CIFAR-10)")
    print("="*60)
    
    encoder.eval()
    pred_net.train()
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.pred_batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=cfg.pin_memory
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=cfg.pred_batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=cfg.pin_memory
    )
    
    optimizer = torch.optim.SGD(
        pred_net.parameters(),
        lr=cfg.pred_lr,
        momentum=cfg.pred_momentum,
        weight_decay=cfg.weight_decay
    )
    
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.pred_epochs)
    
    criterion = nn.CrossEntropyLoss()
    
    best_acc = 0.0
    
    for epoch in range(1, cfg.pred_epochs + 1):
        pred_net.train()
        train_loss = 0
        correct = 0
        total = 0
        
        for images, labels in tqdm(train_loader, desc=f"Pred Epoch {epoch}/{cfg.pred_epochs}", leave=False):
            images = images.to(cfg.device)
            labels = labels.to(cfg.device)
            
            with torch.no_grad():
                embeddings = encoder(images)
            
            optimizer.zero_grad()
            outputs = pred_net(embeddings)
            loss = criterion(outputs, labels)
            
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        
        train_acc = 100. * correct / total
        
        scheduler.step()
        
        test_acc = evaluate_classification(encoder, pred_net, test_loader, cfg)
        
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Pred Epoch {epoch}/{cfg.pred_epochs} | "
              f"LR: {current_lr:.6f} | "
              f"Train Loss: {train_loss/len(train_loader):.4f} | "
              f"Train Acc: {train_acc:.2f}% | "
              f"Test Acc: {test_acc:.2f}%")
        
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save({
                'encoder': encoder.state_dict(),
                'predictor': pred_net.state_dict(),
                'accuracy': best_acc
            }, os.path.join(cfg.out_dir, 'best_cifar10_model.pt'))
            print(f"  ✓ New best model saved! Accuracy: {best_acc:.2f}%")
    
    print(f"\nBest Test Accuracy (CIFAR-10): {best_acc:.2f}%")
    return best_acc


def evaluate_classification(encoder, pred_net, test_loader, cfg):
    encoder.eval()
    pred_net.eval()
    
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(cfg.device)
            labels = labels.to(cfg.device)
            
            embeddings = encoder(images)
            outputs = pred_net(embeddings)
            
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    accuracy = 100. * correct / total
    return accuracy


def extract_embeddings(model, dataset, cfg):
    model.eval()
    loader = DataLoader(
        dataset,
        batch_size=256,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=cfg.pin_memory
    )
    
    all_embeddings = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Extracting embeddings (CIFAR-10)"):
            images = images.to(cfg.device)
            embeddings = model(images).cpu().numpy()
            all_embeddings.append(embeddings)
            all_labels.append(labels.numpy())
    
    embeddings = np.vstack(all_embeddings)
    labels = np.concatenate(all_labels)
    
    os.makedirs(cfg.out_dir, exist_ok=True)
    np.save(os.path.join(cfg.out_dir, cfg.embeddings_file), embeddings)
    np.save(os.path.join(cfg.out_dir, cfg.labels_file), labels)
    
    print(f"\n✓ Saved CIFAR-10 embeddings: {embeddings.shape}")
    return embeddings, labels


# ==================== Dataset Split (80/10/10) ====================
def create_paper_splits(full_dataset, train_ratio=0.8, test_ratio=0.1, val_ratio=0.1, seed=42):
    assert abs(train_ratio + test_ratio + val_ratio - 1.0) < 1e-6, "Ratios must sum to 1"
    
    total_size = len(full_dataset)
    train_size = int(train_ratio * total_size)
    test_size = int(test_ratio * total_size)
    val_size = total_size - train_size - test_size
    
    generator = torch.Generator().manual_seed(seed)
    train_dataset, test_dataset, val_dataset = random_split(
        full_dataset, [train_size, test_size, val_size], generator=generator
    )
    
    return train_dataset, test_dataset, val_dataset


# ==================== Main ====================
def main():
    cfg = Config()
    
    os.makedirs(cfg.out_dir, exist_ok=True)
    
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(cfg.seed)
    
    print("="*60)
    print("A-SCTL Training on CIFAR-10 (Same Implementation as MNIST)")
    print("="*60)
    print(f"Device: {cfg.device}")
    print(f"Embedding dim: {cfg.embedding_dim}")
    print(f"Margins: m1={cfg.margin_intra}, m2={cfg.margin_inter}")
    print(f"Beta range: [{cfg.beta_min}, {cfg.beta_max}]")
    print(f"Batch sampling: P={cfg.batch_p}, K={cfg.batch_k}")
    print(f"Optimizer: SGD (lr={cfg.lr}, momentum={cfg.momentum})")
    print(f"LR Scheduler: {'Enabled (Cosine)' if cfg.use_lr_scheduler else 'Disabled'}")
    print(f"Dataset split: {cfg.train_split}/{cfg.test_split}/{cfg.val_split}")
    print(f"Monitoring: Collapse={cfg.monitor_collapse}, Validation every {cfg.validate_every} epochs")
    print("="*60 + "\n")
    
    # CIFAR-10 transforms (standard)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.4914, 0.4822, 0.4465),
            (0.2023, 0.1994, 0.2010)
        )
    ])
    
    transform_full = transform_train  # for splitting; test transform can be same or simpler
    
    full_dataset = datasets.CIFAR10(
        cfg.data_dir, train=True, download=True, transform=transform_full
    )
    
    train_dataset, test_dataset, val_dataset = create_paper_splits(
        full_dataset, cfg.train_split, cfg.test_split, cfg.val_split, cfg.seed
    )
    
    print(f"Dataset sizes (CIFAR-10):")
    print(f"  Training:   {len(train_dataset)} samples")
    print(f"  Test:       {len(test_dataset)} samples")
    print(f"  Validation: {len(val_dataset)} samples")
    print()
    
    train_labels = [int(train_dataset.dataset[idx][1]) for idx in train_dataset.indices]
    
    pk_sampler = PKSampler(train_labels, cfg.batch_p, cfg.batch_k)
    print()
    
    train_loader = DataLoader(
        train_dataset,
        batch_sampler=pk_sampler,
        num_workers=cfg.num_workers,
        pin_memory=cfg.pin_memory
    )
    
    encoder = EncoderCNN(cfg.embedding_dim, input_channels=3).to(cfg.device)
    
    optimizer = torch.optim.SGD(
        encoder.parameters(),
        lr=cfg.lr,
        momentum=cfg.momentum,
        weight_decay=cfg.weight_decay
    )
    
    scheduler = None
    if cfg.use_lr_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.epochs)
    
    print("Phase 1: Training Embedding Network with A-SCTL (CIFAR-10)")
    print("-"*60)
    print("Same features as MNIST: LR scheduling, validation, collapse monitoring\n")
    
    all_beta_history = []
    
    for epoch in range(1, cfg.epochs + 1):
        loss, L_intra, L_inter, beta, val_info = train_epoch_embedding(
            encoder, train_loader, optimizer, scheduler, cfg, epoch, 
            val_dataset=(val_dataset if epoch % cfg.validate_every == 0 else None)
        )
        
        all_beta_history.append(beta)
        
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch:2d}/{cfg.epochs} | "
              f"LR: {current_lr:.6f} | "
              f"Loss: {loss:.4f} | "
              f"L_intra: {L_intra:.4f} | "
              f"L_inter: {L_inter:.4f} | "
              f"β_avg: {beta:.4f}{val_info}")
        
        if epoch % 5 == 0 or epoch == cfg.epochs:
            ckpt_path = os.path.join(cfg.out_dir, f'encoder_cifar10_epoch{epoch}.pt')
            torch.save(encoder.state_dict(), ckpt_path)
            
            if cfg.log_beta_history:
                beta_path = os.path.join(cfg.out_dir, f'cifar10_beta_epoch{epoch}.npy')
                np.save(beta_path, np.array(all_beta_history))
    
    if cfg.log_beta_history:
        final_beta_path = os.path.join(cfg.out_dir, cfg.beta_history_file)
        np.save(final_beta_path, np.array(all_beta_history))
        print(f"\n✓ Saved CIFAR-10 beta evolution history: {final_beta_path}")
    
    print("\n" + "="*60)
    print("Phase 2: Training Prediction Network (CIFAR-10)")
    print("-"*60)
    
    pred_net = PredictionNetwork(cfg.embedding_dim, num_classes=10).to(cfg.device)
    best_acc = train_prediction_network(encoder, pred_net, train_dataset, test_dataset, cfg)
    
    if cfg.save_embeddings:
        print("\n" + "="*60)
        print("Extracting CIFAR-10 Embeddings")
        print("-"*60)
        extract_embeddings(encoder, train_dataset, cfg)
    
    print("\n" + "="*60)
    print("CIFAR-10 Training Complete!")
    print(f"Final Test Accuracy: {best_acc:.2f}%")
    print("="*60)


if __name__ == "__main__":
    main()

A-SCTL Training on CIFAR-10 (Same Implementation as MNIST)
Device: cuda
Embedding dim: 128
Margins: m1=0.01, m2=1.0
Beta range: [0.1, 0.9]
Batch sampling: P=10, K=6
Optimizer: SGD (lr=0.001, momentum=0.9)
LR Scheduler: Enabled (Cosine)
Dataset split: 0.8/0.1/0.1
Monitoring: Collapse=True, Validation every 5 epochs



100%|██████████| 170M/170M [00:01<00:00, 105MB/s]  


Dataset sizes (CIFAR-10):
  Training:   40000 samples
  Test:       5000 samples
  Validation: 5000 samples

PKSampler initialized: 10 classes, 666 batches/epoch, batch_size=60

Phase 1: Training Embedding Network with A-SCTL (CIFAR-10)
------------------------------------------------------------
Same features as MNIST: LR scheduling, validation, collapse monitoring



Epoch 1/20:   0%|          | 0/666 [00:00<?, ?it/s]

  Distance range: [0.0002, 1.9249], mean: 1.3799
  Embedding norms: mean=1.0000, std=0.0000


                                                                                                                          

  Triplet stats: avg_d_ap: 1.6531 | avg_d_an: 1.4921 | margin_violations: 0.1333 | semi_hard_ratio: 0.8667 | min_dist: 0.0002 | max_dist: 1.9249 | mean_dist: 1.3799
Epoch  1/20 | LR: 0.000994 | Loss: 0.1507 | L_intra: 1.4805 | L_inter: 0.0029 | β_avg: 0.9000


                                                                                                                          

  Triplet stats: avg_d_ap: 1.5090 | avg_d_an: 1.5270 | margin_violations: 0.0000 | semi_hard_ratio: 1.0000 | min_dist: 0.0002 | max_dist: 1.8988 | mean_dist: 1.3914
Epoch  2/20 | LR: 0.000976 | Loss: 0.1451 | L_intra: 1.4320 | L_inter: 0.0021 | β_avg: 0.9000


                                                                                                                          

  Triplet stats: avg_d_ap: 1.3814 | avg_d_an: 1.4047 | margin_violations: 0.0000 | semi_hard_ratio: 1.0000 | min_dist: 0.0002 | max_dist: 1.9165 | mean_dist: 1.3361
Epoch  3/20 | LR: 0.000946 | Loss: 0.1427 | L_intra: 1.4024 | L_inter: 0.0028 | β_avg: 0.9000


                                                                                                                          

  Triplet stats: avg_d_ap: 1.4439 | avg_d_an: 1.4399 | margin_violations: 0.0333 | semi_hard_ratio: 0.9667 | min_dist: 0.0002 | max_dist: 1.8848 | mean_dist: 1.3585
Epoch  4/20 | LR: 0.000905 | Loss: 0.1414 | L_intra: 1.3867 | L_inter: 0.0030 | β_avg: 0.9000


                                                                                                                          

PKSampler initialized: 10 classes, 83 batches/epoch, batch_size=60
  Triplet stats: avg_d_ap: 1.3638 | avg_d_an: 1.3699 | margin_violations: 0.0167 | semi_hard_ratio: 0.9833 | min_dist: 0.0002 | max_dist: 1.8614 | mean_dist: 1.3645
Epoch  5/20 | LR: 0.000854 | Loss: 0.1399 | L_intra: 1.3730 | L_inter: 0.0028 | β_avg: 0.9000 | Val Loss: 0.1384


                                                                                                                          

  Triplet stats: avg_d_ap: 1.4923 | avg_d_an: 1.4900 | margin_violations: 0.0167 | semi_hard_ratio: 0.9833 | min_dist: 0.0002 | max_dist: 1.8865 | mean_dist: 1.3519
Epoch  6/20 | LR: 0.000794 | Loss: 0.1389 | L_intra: 1.3596 | L_inter: 0.0033 | β_avg: 0.9000


                                                                                                                          

  Triplet stats: avg_d_ap: 1.3873 | avg_d_an: 1.4061 | margin_violations: 0.0000 | semi_hard_ratio: 1.0000 | min_dist: 0.0002 | max_dist: 1.8967 | mean_dist: 1.3703
Epoch  7/20 | LR: 0.000727 | Loss: 0.1379 | L_intra: 1.3489 | L_inter: 0.0033 | β_avg: 0.9000


                                                                                                                          

  Triplet stats: avg_d_ap: 1.4020 | avg_d_an: 1.4068 | margin_violations: 0.0167 | semi_hard_ratio: 0.9833 | min_dist: 0.0002 | max_dist: 1.9281 | mean_dist: 1.3470
Epoch  8/20 | LR: 0.000655 | Loss: 0.1368 | L_intra: 1.3408 | L_inter: 0.0030 | β_avg: 0.9000


                                                                                                                          

  Triplet stats: avg_d_ap: 1.3990 | avg_d_an: 1.4029 | margin_violations: 0.0167 | semi_hard_ratio: 0.9833 | min_dist: 0.0002 | max_dist: 1.8635 | mean_dist: 1.3748
Epoch  9/20 | LR: 0.000578 | Loss: 0.1358 | L_intra: 1.3298 | L_inter: 0.0032 | β_avg: 0.9000


                                                                                                                           

PKSampler initialized: 10 classes, 83 batches/epoch, batch_size=60
  Triplet stats: avg_d_ap: 1.4091 | avg_d_an: 1.4228 | margin_violations: 0.0000 | semi_hard_ratio: 1.0000 | min_dist: 0.0002 | max_dist: 1.8326 | mean_dist: 1.3800
Epoch 10/20 | LR: 0.000500 | Loss: 0.1349 | L_intra: 1.3214 | L_inter: 0.0031 | β_avg: 0.9000 | Val Loss: 0.1312


                                                                                                                           

  Triplet stats: avg_d_ap: 1.3117 | avg_d_an: 1.3158 | margin_violations: 0.0167 | semi_hard_ratio: 0.9833 | min_dist: 0.0002 | max_dist: 1.8550 | mean_dist: 1.3044
Epoch 11/20 | LR: 0.000422 | Loss: 0.1339 | L_intra: 1.3102 | L_inter: 0.0033 | β_avg: 0.9000


                                                                                                                           

  Triplet stats: avg_d_ap: 1.3972 | avg_d_an: 1.4067 | margin_violations: 0.0167 | semi_hard_ratio: 0.9833 | min_dist: 0.0002 | max_dist: 1.9015 | mean_dist: 1.3520
Epoch 12/20 | LR: 0.000345 | Loss: 0.1320 | L_intra: 1.2920 | L_inter: 0.0031 | β_avg: 0.9000


                                                                                                                           

  Triplet stats: avg_d_ap: 1.3097 | avg_d_an: 1.3136 | margin_violations: 0.0167 | semi_hard_ratio: 0.9833 | min_dist: 0.0002 | max_dist: 1.8493 | mean_dist: 1.2862
Epoch 13/20 | LR: 0.000273 | Loss: 0.1204 | L_intra: 1.1526 | L_inter: 0.0057 | β_avg: 0.9000


                                                                                                                           

  Triplet stats: avg_d_ap: 1.0934 | avg_d_an: 1.0836 | margin_violations: 0.0333 | semi_hard_ratio: 0.9667 | min_dist: 0.0002 | max_dist: 1.5638 | mean_dist: 1.0719
Epoch 14/20 | LR: 0.000206 | Loss: 0.1164 | L_intra: 1.1040 | L_inter: 0.0066 | β_avg: 0.9000


                                                                                                                           

PKSampler initialized: 10 classes, 83 batches/epoch, batch_size=60
  Triplet stats: avg_d_ap: 1.1193 | avg_d_an: 1.1288 | margin_violations: 0.0000 | semi_hard_ratio: 1.0000 | min_dist: 0.0002 | max_dist: 1.5311 | mean_dist: 1.0932
Epoch 15/20 | LR: 0.000146 | Loss: 0.1149 | L_intra: 1.0948 | L_inter: 0.0060 | β_avg: 0.9000 | Val Loss: 0.1294


                                                                                                                           

  Triplet stats: avg_d_ap: 1.1194 | avg_d_an: 1.1322 | margin_violations: 0.0000 | semi_hard_ratio: 1.0000 | min_dist: 0.0002 | max_dist: 1.5194 | mean_dist: 1.0850
Epoch 16/20 | LR: 0.000095 | Loss: 0.1141 | L_intra: 1.0907 | L_inter: 0.0056 | β_avg: 0.9000


                                                                                                                           

  Triplet stats: avg_d_ap: 1.1112 | avg_d_an: 1.1140 | margin_violations: 0.0167 | semi_hard_ratio: 0.9833 | min_dist: 0.0002 | max_dist: 1.4903 | mean_dist: 1.0633
Epoch 17/20 | LR: 0.000054 | Loss: 0.1138 | L_intra: 1.0869 | L_inter: 0.0057 | β_avg: 0.9000


                                                                                                                           

  Triplet stats: avg_d_ap: 1.1048 | avg_d_an: 1.1057 | margin_violations: 0.0167 | semi_hard_ratio: 0.9833 | min_dist: 0.0002 | max_dist: 1.4441 | mean_dist: 1.0792
Epoch 18/20 | LR: 0.000024 | Loss: 0.1133 | L_intra: 1.0840 | L_inter: 0.0054 | β_avg: 0.9000


                                                                                                                           

  Triplet stats: avg_d_ap: 1.0854 | avg_d_an: 1.0927 | margin_violations: 0.0000 | semi_hard_ratio: 1.0000 | min_dist: 0.0003 | max_dist: 1.4445 | mean_dist: 1.0618
Epoch 19/20 | LR: 0.000006 | Loss: 0.1131 | L_intra: 1.0823 | L_inter: 0.0054 | β_avg: 0.9000


                                                                                                                           

PKSampler initialized: 10 classes, 83 batches/epoch, batch_size=60
  Triplet stats: avg_d_ap: 1.0889 | avg_d_an: 1.0917 | margin_violations: 0.0167 | semi_hard_ratio: 0.9833 | min_dist: 0.0002 | max_dist: 1.4485 | mean_dist: 1.0712
Epoch 20/20 | LR: 0.000000 | Loss: 0.1128 | L_intra: 1.0810 | L_inter: 0.0052 | β_avg: 0.9000 | Val Loss: 0.1128

✓ Saved CIFAR-10 beta evolution history: ./output_cifar10/cifar10_beta_evolution.npy

Phase 2: Training Prediction Network (CIFAR-10)
------------------------------------------------------------

Training Prediction Network (CIFAR-10)


                                                                  

Pred Epoch 1/10 | LR: 0.000976 | Train Loss: 2.2988 | Train Acc: 11.23% | Test Acc: 11.94%
  ✓ New best model saved! Accuracy: 11.94%


                                                                  

Pred Epoch 2/10 | LR: 0.000905 | Train Loss: 2.2901 | Train Acc: 14.72% | Test Acc: 17.76%
  ✓ New best model saved! Accuracy: 17.76%


                                                                  

Pred Epoch 3/10 | LR: 0.000794 | Train Loss: 2.2820 | Train Acc: 19.39% | Test Acc: 25.38%
  ✓ New best model saved! Accuracy: 25.38%


                                                                  

Pred Epoch 4/10 | LR: 0.000655 | Train Loss: 2.2741 | Train Acc: 22.97% | Test Acc: 31.36%
  ✓ New best model saved! Accuracy: 31.36%


                                                                  

Pred Epoch 5/10 | LR: 0.000500 | Train Loss: 2.2671 | Train Acc: 26.40% | Test Acc: 35.48%
  ✓ New best model saved! Accuracy: 35.48%


                                                                  

Pred Epoch 6/10 | LR: 0.000345 | Train Loss: 2.2607 | Train Acc: 28.68% | Test Acc: 38.30%
  ✓ New best model saved! Accuracy: 38.30%


                                                                  

Pred Epoch 7/10 | LR: 0.000206 | Train Loss: 2.2558 | Train Acc: 29.65% | Test Acc: 39.36%
  ✓ New best model saved! Accuracy: 39.36%


                                                                  

Pred Epoch 8/10 | LR: 0.000095 | Train Loss: 2.2528 | Train Acc: 30.20% | Test Acc: 38.70%


                                                                  

Pred Epoch 9/10 | LR: 0.000024 | Train Loss: 2.2507 | Train Acc: 30.92% | Test Acc: 40.56%
  ✓ New best model saved! Accuracy: 40.56%


                                                                   

Pred Epoch 10/10 | LR: 0.000000 | Train Loss: 2.2500 | Train Acc: 30.66% | Test Acc: 39.40%

Best Test Accuracy (CIFAR-10): 40.56%

Extracting CIFAR-10 Embeddings
------------------------------------------------------------


Extracting embeddings (CIFAR-10): 100%|██████████| 157/157 [00:09<00:00, 16.78it/s]


✓ Saved CIFAR-10 embeddings: (40000, 128)

CIFAR-10 Training Complete!
Final Test Accuracy: 40.56%



