In [None]:
import torch
import torch.nn as nn
import torchvision
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
import datetime
import math

In [None]:
# Per-channel mean / std handed to Normalize by both pipelines below.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.247, 0.243, 0.261)

# Training pipeline: random crop (with 4px padding) and horizontal flip for
# augmentation, then tensor conversion and normalization.
transform_train = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(p=0.5),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

# Evaluation pipeline: no augmentation, only tensor conversion and normalization.
transform_test = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)


# The training split is loaded WITHOUT a transform on purpose: it is later
# divided into train/validation parts that each get their own pipeline.
cifar_train = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=None,
)
# The held-out test split always uses the evaluation pipeline.
cifar_test = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)

In [None]:
from torch.utils.data import DataLoader, random_split

# Seed the global RNG before random_split so the partition is reproducible.
torch.manual_seed(42)

# 80% of the original training images stay in the training part; the rest
# become the validation part.
train_ratio = 0.8
n_total = len(cifar_train)
train_size = int(train_ratio * n_total)
val_size = n_total - train_size

# The validation part is carved out of the original training set. If that
# dataset carried the augmentation pipeline, the validation part would inherit
# it too. Augmentation is fine for the training part but must not be applied to
# the validation part, so the split is done first and transforms are attached
# per part afterwards.
cifar_train_sub, cifar_val_sub = random_split(cifar_train, [train_size, val_size])
class TransformSubset(torch.utils.data.Dataset):
    """Dataset view that applies an optional transform to another dataset's samples.

    Each item of the wrapped ``subset`` is expected to unpack into a
    ``(sample, label)`` pair; only the sample is passed through ``transform``.

    Args:
        subset: Indexable, sized collection of ``(sample, label)`` pairs.
        transform: Callable applied to the sample. When it is ``None`` (or
            otherwise falsy) the sample is returned unchanged.
    """

    def __init__(self, subset, transform=None):
        self.transform = transform
        self.subset = subset

    def __getitem__(self, idx):
        """Return ``(sample, label)`` for ``idx``, transforming the sample if configured."""
        sample, label = self.subset[idx]
        if not self.transform:
            return sample, label
        return self.transform(sample), label

    def __len__(self):
        """Number of items in the wrapped subset."""
        return len(self.subset)
    
# Attach a pipeline to each part of the split: augmentation for training,
# the plain evaluation pipeline for validation.
# NOTE(review): this rebinds `cifar_train` from the raw dataset to the wrapped
# training part, so re-running the split afterwards without reloading the raw
# dataset would operate on the wrapped object — confirm this is acceptable.
cifar_train = TransformSubset(cifar_train_sub, transform=transform_train)
cifar_val = TransformSubset(cifar_val_sub, transform=transform_test)

batch_size = 128

# Settings shared by all three loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=batch_size, num_workers=4)

train_loader = DataLoader(cifar_train, shuffle=True, **loader_kwargs)
val_loader = DataLoader(cifar_val, shuffle=False, **loader_kwargs)
test_loader = DataLoader(cifar_test, shuffle=False, **loader_kwargs)

In [None]:
# resnet18 = torchvision.models.resnet18(pretrained=True)
# print(resnet18)

In [None]:
def prepare_resnet18(pretrained=True, num_classes=10):
    """Build a torchvision ResNet-18 with its first conv, max-pool and head replaced.

    The first convolution becomes a 3x3, stride-1, padding-1 layer, the
    max-pool that follows it becomes an identity, and the final fully connected
    layer is swapped for a new ``num_classes``-way linear layer. Both replaced
    layers are freshly initialised, so any loaded weights for them are dropped.

    Args:
        pretrained: Forwarded to ``torchvision.models.resnet18`` to choose
            whether pretrained weights are loaded.
        num_classes: Output size of the new classification layer.

    Returns:
        The modified ResNet-18 module.
    """
    # NOTE(review): newer torchvision releases reportedly prefer `weights=`
    # over `pretrained=` — confirm against the installed version.
    net = torchvision.models.resnet18(pretrained=pretrained)

    # Stem: small-kernel, stride-1 convolution with Kaiming-normal weights.
    stem_conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    nn.init.kaiming_normal_(stem_conv.weight, mode='fan_out', nonlinearity='relu')
    net.conv1 = stem_conv

    # Drop the first max-pooling stage by turning it into a pass-through.
    net.maxpool = nn.Identity()

    # Head: new linear classifier with small-normal weights and zero bias.
    head = nn.Linear(net.fc.in_features, num_classes)
    nn.init.normal_(head.weight, 0, 0.01)
    nn.init.constant_(head.bias, 0)
    net.fc = head

    # One summary line per print argument (same output as separate prints).
    print(
        "模型准备完成:",
        "  - 输入: 3x32x32",
        f"  - 输出: {num_classes}类",
        f"  - 使用预训练: {pretrained}",
        sep="\n",
    )

    return net

# Instantiate the network without pretrained weights, with a 10-way output layer.
resnet18 = prepare_resnet18(num_classes=10, pretrained=False)


In [None]:
# Training configuration shared by the cells below.
epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.CrossEntropyLoss()
# The optimizer is intentionally NOT created here: it is built once, further
# down, from `eta_max`. A second construction at this point was dead code
# because it was overwritten before ever being used.

# Warmup + cosine annealing settings read by `warmup_cosine_annealing`.
warmup_epochs = 30  # epochs of linear warmup before the cosine phase
eta_max = 0.1       # peak learning rate (also the optimizer's base LR)
eta_min = 0.001     # learning rate the cosine phase decays towards
def warmup_cosine_annealing(epoch):
    """Return the learning-rate multiplier for ``epoch``.

    Reads the module-level ``warmup_epochs``, ``epochs``, ``eta_max`` and
    ``eta_min``. During the first ``warmup_epochs`` epochs the multiplier grows
    linearly up to 1. Afterwards it follows a half-cosine from 1 towards
    ``eta_min / eta_max``, i.e. it is expressed relative to the peak rate.

    Args:
        epoch: Zero-based epoch index.

    Returns:
        Factor to multiply the base learning rate by.
    """
    if epoch >= warmup_epochs:
        # Fraction of the post-warmup phase that has elapsed.
        elapsed = (epoch - warmup_epochs) / (epochs - warmup_epochs)
        cosine_weight = 0.5 * (1 + math.cos(math.pi * elapsed))
        # Convert the absolute target rate into a ratio of the peak rate.
        return (eta_min + (eta_max - eta_min) * cosine_weight) / eta_max
    # Linear warmup: 1/warmup_epochs at epoch 0, reaching 1 on the last warmup epoch.
    return (epoch + 1) / warmup_epochs
# SGD starts from the peak rate; LambdaLR multiplies that base rate by the
# factor returned from warmup_cosine_annealing.
optimizer = torch.optim.SGD(
    resnet18.parameters(),
    lr=eta_max,
    momentum=0.9,
    weight_decay=5e-4,
)
scheduler = LambdaLR(optimizer, lr_lambda=warmup_cosine_annealing)

# TensorBoard: a timestamped directory per run keeps runs separate.
current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
log_dir = "runs/cifair10_resnet18_" + current_time
writer = SummaryWriter(log_dir)

# Best validation accuracy seen so far; updated inside train().
best_acc = 0.0

def train(model, train_loader, val_loader, criterion, optimizer, scheduler, device, epochs):
    """Train ``model`` and validate it after every epoch.

    Each epoch runs one optimisation pass over ``train_loader``, steps the
    scheduler once, then evaluates on ``val_loader`` under ``torch.no_grad()``.
    Metrics are printed and written to the module-level TensorBoard ``writer``.
    Whenever validation accuracy exceeds the module-level ``best_acc``, that
    value is updated and the weights are saved to ``best_model.pth``.

    Args:
        model: Network to train; it is moved to ``device``.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches.
        criterion: Loss callable. Assumed to return the mean loss over the
            batch, which the sample-weighted averaging below relies on —
            TODO confirm if a non-default reduction is ever used.
        optimizer: Optimizer updated once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch.
        device: Device that batches and the model are moved to.
        epochs: Number of epochs to run.
    """
    global best_acc
    model = model.to(device)
    for epoch in range(epochs):
        # ---------------- training pass ----------------
        model.train()
        running_loss = 0.0
        running_correct = 0
        total_samples = 0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Weight by the real batch size: the last batch may be smaller
            # than the configured one, which would otherwise bias the mean.
            n_batch = targets.size(0)
            running_loss += loss.item() * n_batch
            _, predicted = outputs.max(1)
            running_correct += predicted.eq(targets).sum().item()
            total_samples += n_batch
            if batch_idx % 10 == 0:
                step = epoch * len(train_loader) + batch_idx
                # Separate tag from the per-epoch 'train_loss': the two series
                # use different step axes and must not share one chart.
                writer.add_scalar('train_loss_batch', loss.item(), step)

        # Read the LR that was actually used this epoch BEFORE stepping the
        # scheduler; afterwards get_last_lr() already reports the next epoch's.
        current_lr = scheduler.get_last_lr()[0]
        scheduler.step()

        train_loss = running_loss / total_samples
        train_acc = running_correct / total_samples * 100.0

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('train_acc', train_acc, epoch)
        writer.add_scalar('learning_rate', current_lr, epoch)
        print(f"Epoch [{epoch + 1}/{epochs}]")
        print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'  LR: {current_lr:.6f}')

        # ---------------- validation pass ----------------
        model.eval()
        with torch.no_grad():
            val_running_loss = 0.0
            val_correct = 0
            val_total_samples = 0
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                n_batch = targets.size(0)
                val_running_loss += loss.item() * n_batch
                _, predicted = outputs.max(1)
                val_correct += predicted.eq(targets).sum().item()
                val_total_samples += n_batch
            val_loss = val_running_loss / val_total_samples
            val_acc = val_correct / val_total_samples * 100.0

            # Keep only the checkpoint with the best validation accuracy.
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save(model.state_dict(), 'best_model.pth')
                print(f'New best model saved with val_acc: {val_acc:.2f}%')

            writer.add_scalar('val_loss', val_loss, epoch)
            writer.add_scalar('val_acc', val_acc, epoch)
            print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
            
def test(model, test_loader, criterion, device):
    """Evaluate the checkpoint stored in ``best_model.pth`` on ``test_loader``.

    The saved state dict is loaded into ``model``, the model is moved to
    ``device`` and put in eval mode, and loss/accuracy are computed without
    gradients. Results are printed and written to the module-level
    TensorBoard ``writer`` at step 0.

    Args:
        model: Network whose architecture matches the saved state dict.
        test_loader: Iterable of ``(inputs, targets)`` test batches.
        criterion: Loss callable. Assumed to return the mean loss over the
            batch — TODO confirm if a non-default reduction is ever used.
        device: Device used for the checkpoint tensors, model and batches.
    """
    # map_location lets a checkpoint saved on one device (e.g. GPU) be loaded
    # on another (e.g. a CPU-only machine) instead of failing.
    state_dict = torch.load('best_model.pth', map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    running_loss = 0.0
    running_correct = 0
    total_samples = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            # Weight by the real batch size so a smaller final batch does not
            # bias the averaged loss.
            n_batch = targets.size(0)
            running_loss += loss.item() * n_batch
            _, predicted = outputs.max(1)
            running_correct += predicted.eq(targets).sum().item()
            total_samples += n_batch
        epoch_loss = running_loss / total_samples
        epoch_acc = running_correct / total_samples * 100.0
        print(f"Test Loss: {epoch_loss:.4f}, Test Acc: {epoch_acc:.4f}%")
        writer.add_scalar('test_loss', epoch_loss, 0)
        writer.add_scalar('test_acc', epoch_acc, 0)

# Run training, then evaluate the best checkpoint on the test set.
try:
    train(resnet18, train_loader, val_loader, criterion, optimizer, scheduler, device, epochs)
    test(resnet18, test_loader, criterion, device)
finally:
    # Close the TensorBoard writer even if training or testing raises, so
    # the event file is not left open.
    writer.close()

