In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
from models import SimpleCNN
import orjson
import optuna


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# -------------------------
# 2. PACR Loss (batch-wise)
# -------------------------
def pacr_loss(features, labels):
    """
    features: (B, D)
    labels:   (B,)
    """
    loss = 0.0
    num_classes = labels.max().item() + 1
    eps = 1e-8

    for c in range(num_classes):
        idx = labels == c
        if idx.sum() < 2:
            continue
        z = features[idx]  # (Nc, D)
        mean = z.mean(dim=0, keepdim=True)
        loss += ((z - mean) ** 2).sum(dim=1).mean()

    return loss / (num_classes + eps)


In [3]:
def train_epoch(model, loader, optimizer, device, lambda_pacr=0.0):
    model.train()
    total_loss, total_correct, total = 0, 0, 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        logits, feats = model(x, return_feat=True)
        loss_ce = F.cross_entropy(logits, y)

        loss = loss_ce
        if lambda_pacr > 0:
            loss_pacr = pacr_loss(feats, y)
            loss = loss + lambda_pacr * loss_pacr

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * x.size(0)
        total_correct += (logits.argmax(1) == y).sum().item()
        total += x.size(0)

    return total_loss / total, total_correct / total


In [4]:
@torch.no_grad()
def eval_epoch(model, loader, device):
    model.eval()
    total_correct, total = 0, 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        total_correct += (logits.argmax(1) == y).sum().item()
        total += x.size(0)

    return total_correct / total


In [5]:
def objective(trial):
    # 1. 定义搜索空间
    lr = trial.suggest_float("lr", 1e-4, 1e-1, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)
    optimizer_name = trial.suggest_categorical("optimizer", ["SGD", "AdamW"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform_train = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )
    transform_test = transforms.Compose([transforms.ToTensor()])
    train_set = datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform_train
    )
    test_set = datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform_test
    )
    train_loader = DataLoader(train_set, batch_size=1280, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=256)
    # 2. 初始化模型与优化器
    model = SimpleCNN().to(device)
    if optimizer_name == "SGD":
        optimizer = optim.SGD(
            model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay
        )
    else:
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # 3. 训练循环 (简化版，建议跑 20-30 epoch)
    for epoch in range(30):
        train_epoch(model, train_loader, optimizer, device)
        val_acc = eval_epoch(model, test_loader, device)

        # 允许提前停止（剪枝），节省时间
        trial.report(val_acc, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return val_acc


# 4. 启动寻优
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=50)  # 跑 50 组不同的配置

print("Best Hyperparameters:", study.best_params)


[I 2026-01-14 22:24:15,487] A new study created in memory with name: no-name-f6fb24b8-c0bd-425e-9280-b6ef366d0d7c
  entry = pickle.load(f, encoding="latin1")
[I 2026-01-14 22:26:50,278] Trial 0 finished with value: 0.3216 and parameters: {'lr': 0.0441877585642909, 'weight_decay': 0.0005406787344672163, 'optimizer': 'AdamW'}. Best is trial 0 with value: 0.3216.
[I 2026-01-14 22:29:24,042] Trial 1 finished with value: 0.1193 and parameters: {'lr': 0.0022169363827064495, 'weight_decay': 0.002385623949933042, 'optimizer': 'SGD'}. Best is trial 0 with value: 0.3216.
[I 2026-01-14 22:31:58,355] Trial 2 finished with value: 0.5326 and parameters: {'lr': 0.0014076944288068435, 'weight_decay': 0.007677664379180153, 'optimizer': 'AdamW'}. Best is trial 2 with value: 0.5326.
[I 2026-01-14 22:34:32,222] Trial 3 finished with value: 0.1012 and parameters: {'lr': 0.001276881085072189, 'weight_decay': 0.0003368665119136075, 'optimizer': 'SGD'}. Best is trial 2 with value: 0.5326.
[I 2026-01-14 22:37:

Best Hyperparameters: {'lr': 0.010406771084612782, 'weight_decay': 2.572424508783154e-05, 'optimizer': 'AdamW'}
