# In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from tqdm import tqdm

# =====================
# 配置参数
# =====================
class Config:
    """Hyper-parameters and paths for the training run, read as class attributes."""

    # ---- data ----
    data_dir: str = "/Study/imagenet_mini/imagenet-mini/"  # dataset root (holds train/ and val/)
    num_classes: int = 500  # width of the classifier output layer
    input_size: int = 128  # side length of the square network input

    # ---- optimisation ----
    batch_size: int = 32  # physical (per-step) batch size
    accum_steps: int = 4  # gradient-accumulation steps -> effective batch 32 * 4 = 128
    num_epochs: int = 100  # maximum number of training epochs
    lr: float = 2e-4  # base learning rate
    weight_decay: float = 0.05  # AdamW weight decay

    # ---- runtime ----
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    early_stop_patience: int = 50  # epochs without val improvement before stopping

# =====================
# 数据增强与加载
# =====================
# ImageNet per-channel statistics, shared by the training and evaluation pipelines.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: random resized crop, flip and colour jitter, plus an occasional
# blur and small rotation, then tensor conversion and normalisation.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(Config.input_size, scale=(0.6, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
        transforms.RandomApply([transforms.GaussianBlur(3)], p=0.3),
        transforms.RandomApply([transforms.RandomRotation(15)], p=0.3),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Evaluation pipeline (deterministic): resize the shorter side to input_size + 32,
# then take the central input_size x input_size crop.
test_transform = transforms.Compose(
    [
        transforms.Resize(Config.input_size + 32),
        transforms.CenterCrop(Config.input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

class ImageNetMini(Dataset):
    """Image-folder dataset for the ImageNet-mini directory layout.

    Expects ``root/train/<class_name>/<image>`` and ``root/val/<class_name>/<image>``.
    Class indices come from the sorted class-directory names of the chosen split.
    The train and val splits therefore agree on labels only if they contain the
    same set of class directories.

    Args:
        root: Dataset root containing ``train`` and ``val`` sub-directories.
        train: Read from ``train`` when True, otherwise from ``val``.
        transform: Optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, root, train=True, transform=None):
        self.root = os.path.join(root, "train" if train else "val")
        # Only directories are classes. A stray file (e.g. ``.DS_Store``) would otherwise
        # become a bogus class and make ``os.listdir`` on it fail below.
        self.classes = sorted(
            entry
            for entry in os.listdir(self.root)
            if os.path.isdir(os.path.join(self.root, entry))
        )
        self.transform = transform
        self.samples = []
        for class_idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(self.root, class_name)
            # Sorted so the sample order does not depend on filesystem listing order.
            for img_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, img_name)
                if os.path.isfile(img_path):
                    self.samples.append((img_path, class_idx))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is transformed if a transform was given."""
        img_path, label = self.samples[idx]
        # The context manager closes the underlying file; convert() returns a new,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label

# =====================
# 模型定义（适配128x128输入）
# =====================
def create_resnet18(num_classes=None, device=None):
    """Build a randomly initialised ResNet-18 adapted to small (e.g. 128x128) inputs.

    Changes relative to the stock torchvision model:
      * the 7x7 stride-2 stem conv is replaced by a 3x3 stride-1 conv,
      * the stem max-pool is removed,
      * layer4 no longer downsamples (main path AND residual shortcut),
      * the classifier is replaced by ``Dropout(0.3) -> Linear``.

    Args:
        num_classes: Number of output classes; defaults to ``Config.num_classes``.
        device: Device to move the model to; defaults to ``Config.device``.

    Returns:
        The model, already moved to ``device``.
    """
    if num_classes is None:
        num_classes = Config.num_classes
    if device is None:
        device = Config.device

    model = models.resnet18(pretrained=False)

    # Small-input stem: no early downsampling from the first conv ...
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    # ... nor from the max-pool.
    model.maxpool = nn.Identity()

    # Keep layer4 at layer3's resolution. The first block's main-path conv and its 1x1
    # shortcut conv must BOTH use stride 1. Changing only conv1 leaves the shortcut at half
    # resolution, and the residual addition fails with a shape mismatch.
    first_block = model.layer4[0]
    first_block.conv1.stride = (1, 1)
    if first_block.downsample is not None:
        first_block.downsample[0].stride = (1, 1)

    model.fc = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(model.fc.in_features, num_classes),
    )
    return model.to(device)

# =====================
# 训练工具函数
# =====================
def cutmix_data(x, y, alpha=1.0):
    """Apply CutMix to a batch: paste a random box from a shuffled copy of the batch.

    Args:
        x: Image batch indexed as ``x[n, c, d2, d3]`` (NCHW assumed). It is modified
            in place and also returned.
        y: Labels aligned with the first dimension of ``x``.
        alpha: Parameter of the Beta(alpha, alpha) draw for the target mixing ratio.

    Returns:
        ``(x, y_a, y_b, lam)``:
          * ``y_a``: the original labels.
          * ``y_b``: the labels of the images whose pixels were pasted in.
          * ``lam``: a float, the fraction of each image that kept its own pixels.
        The matching loss is ``lam * L(out, y_a) + (1 - lam) * L(out, y_b)``.
    """
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size)

    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    # The right-hand side uses advanced indexing, so it is a copy; no aliasing on assignment.
    x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]

    # The box is clipped at the border (and its sides are rounded), so the pasted area
    # generally differs from (1 - lam). Recompute lam from the actual box so the label
    # weights match the pixels.
    box_area = float((bbx2 - bbx1) * (bby2 - bby1))
    lam = 1.0 - box_area / float(x.size(2) * x.size(3))

    return x, y, y[index], lam


def rand_bbox(size, lam):
    """Sample a box covering roughly a ``1 - lam`` fraction of the spatial area.

    Args:
        size: Batch shape. Only ``size[2]`` and ``size[3]`` (the spatial dims, H and W
            for NCHW input) are used.
        lam: Mixing ratio in [0, 1]; each box side is ``side * sqrt(1 - lam)``.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``:
          * ``[bbx1, bbx2)`` are half-open bounds along dim 2.
          * ``[bby1, bby2)`` are half-open bounds along dim 3.
        Both are clipped to the image, so the box may be smaller than requested or empty.
    """
    height, width = size[2], size[3]
    cut_rat = np.sqrt(1.0 - lam)
    cut_h = int(height * cut_rat)
    cut_w = int(width * cut_rat)

    # The box centre is uniform over the whole image, so boxes near the edge get clipped.
    cy = np.random.randint(height)
    cx = np.random.randint(width)

    bbx1 = np.clip(cy - cut_h // 2, 0, height)
    bby1 = np.clip(cx - cut_w // 2, 0, width)
    bbx2 = np.clip(cy + cut_h // 2, 0, height)
    bby2 = np.clip(cx + cut_w // 2, 0, width)

    return bbx1, bby1, bbx2, bby2

# =====================
# 训练流程
# =====================
def _train_one_epoch(model, loader, optimizer, scheduler, scaler, criterion, epoch, use_amp):
    """Run one epoch with CutMix, mixed precision and gradient accumulation; return the mean loss."""
    model.train()
    running_loss = 0.0
    optimizer.zero_grad()

    pbar = tqdm(enumerate(loader), total=len(loader),
                desc=f"Epoch {epoch+1}/{Config.num_epochs}")

    for step, (inputs, targets) in pbar:
        inputs = inputs.to(Config.device, non_blocking=True)
        targets = targets.to(Config.device, non_blocking=True)

        # Decide on CutMix ONCE per batch and reuse that decision for the loss. Two
        # independent coin flips would sometimes compute the mixed loss without CutMix
        # (undefined or stale lam/targets) or the plain loss on mixed images.
        use_cutmix = np.random.rand() < 0.5
        if use_cutmix:
            inputs, targets_a, targets_b, lam = cutmix_data(inputs, targets)

        with autocast(enabled=use_amp):
            outputs = model(inputs)
            if use_cutmix:
                loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)
            else:
                loss = criterion(outputs, targets)
            # Average over the accumulation window so the summed gradient matches one big batch.
            loss = loss / Config.accum_steps

        scaler.scale(loss).backward()

        # One optimizer/scheduler update per accum_steps micro-batches.
        if (step + 1) % Config.accum_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()

        running_loss += loss.item() * Config.accum_steps
        pbar.set_postfix({"Loss": f"{running_loss/(step+1):.4f}"})

    # Gradients from a trailing partial window (len(loader) not divisible by accum_steps)
    # are dropped. The schedule is sized for len(loader) // accum_steps updates per epoch.
    optimizer.zero_grad()
    return running_loss / len(loader)


def _evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over ``loader``, in percent."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(Config.device)
            targets = targets.to(Config.device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return 100 * correct / total


def main():
    """Train ResNet-18 on ImageNet-mini; save the best-val-accuracy weights to best_model.pth.

    Raises:
        ValueError: the dataset splits are empty or disagree on classes, there are more
            classes than ``Config.num_classes``, or there are fewer training batches than
            ``Config.accum_steps``.
    """
    # ---- data ----
    train_dataset = ImageNetMini(Config.data_dir, train=True, transform=train_transform)
    val_dataset = ImageNetMini(Config.data_dir, train=False, transform=test_transform)

    # Labels are indices into each split's own sorted class list. If the splits differ,
    # val accuracy is silently wrong. If there are more classes than the head has outputs,
    # CrossEntropyLoss fails on out-of-range targets. Fail early with a clear message.
    if val_dataset.classes != train_dataset.classes:
        raise ValueError("train and val splits contain different class directories")
    if len(train_dataset.classes) > Config.num_classes:
        raise ValueError(
            f"found {len(train_dataset.classes)} classes but "
            f"Config.num_classes is {Config.num_classes}"
        )
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise ValueError("train or val split contains no images")

    use_cuda = Config.device == "cuda"

    train_loader = DataLoader(
        train_dataset,
        batch_size=Config.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=use_cuda,
        persistent_workers=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=Config.batch_size * 2,
        shuffle=False,
        num_workers=4,
        pin_memory=use_cuda,
    )

    # ---- model / optimisation ----
    updates_per_epoch = len(train_loader) // Config.accum_steps
    if updates_per_epoch == 0:
        raise ValueError("fewer training batches than Config.accum_steps; no update would run")

    model = create_resnet18()
    optimizer = optim.AdamW(model.parameters(), lr=Config.lr, weight_decay=Config.weight_decay)
    # Scheduler is stepped once per optimizer update, not once per micro-batch.
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=Config.lr * 2,
        steps_per_epoch=updates_per_epoch,
        epochs=Config.num_epochs,
    )
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    # Mixed precision only on CUDA; with enabled=False both scaler and autocast are no-ops.
    scaler = GradScaler(enabled=use_cuda)

    # ---- training loop with early stopping ----
    best_val_acc = 0.0
    patience_counter = 0

    for epoch in range(Config.num_epochs):
        _train_one_epoch(model, train_loader, optimizer, scheduler, scaler,
                         criterion, epoch, use_cuda)

        val_acc = _evaluate(model, val_loader)
        print(f"Validation Accuracy: {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_model.pth")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= Config.early_stop_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break


if __name__ == "__main__":
    main()

