In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from torch.cuda.amp import GradScaler, autocast
import os
import time

In [2]:
# 设置随机种子，确保结果可重复
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.deterministic = True  # 设置为True以确保每次运行结果相同
    cudnn.benchmark = False  # 关闭自动优化以确保可重复性

In [3]:
# 定义设备，优先使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# 数据处理
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

In [5]:
# 加载CIFAR-10数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=512, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=512, shuffle=False, num_workers=2)

In [6]:
# 定义模型
model = torchvision.models.resnet50(weights=None, progress=True, num_classes=10)
model = model.to(device)

In [7]:
# 如果使用多个gpu，则进行数据并行
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

In [8]:
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

In [9]:
# 混合精度训练和梯度缩放器
scaler = torch.amp.GradScaler()
# 学习率调度器
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

In [10]:
# 训练函数
def train(epoch):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    start_time = time.time()
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        # 自动混合精度前向传播
        with torch.amp.autocast(device_type=device.type):
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        # 梯度缩放
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 100 == 0:
            print(f'Epoch {epoch} | Batch {batch_idx} | Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}%')
    end_time = time.time()
    print(f'Epoch {epoch} completed in {end_time - start_time:.2f} seconds.')

In [11]:
# 测试函数
def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            with torch.amp.autocast(device_type=device.type):
                outputs = model(inputs)
                loss = criterion(outputs, targets)
            
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    acc = 100.* correct / total
    print(f'Test Epoch {epoch} | Test Loss: {test_loss/(batch_idx+1):.3f} | Acc: {acc:.3f}%')
    return acc

In [12]:
# 主训练循环
best_acc = 0
num_epochs = 10
set_seed(42)
for epoch in range(num_epochs):
    train(epoch)
    acc = test(epoch)
    scheduler.step()
    # 保存最佳模型
    if acc > best_acc:
        print(f'New best accuracy: {acc:.3f}%, saving model...')
        best_acc = acc
        state = {
            'model': model.state_dict(),
            'acc': best_acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
    print(f'Best Acc: {best_acc:.3f}%\n')

Epoch 0 | Batch 0 | Loss: 2.488 | Acc: 11.719%
Epoch 0 completed in 19.29 seconds.
Test Epoch 0 | Test Loss: 205.518 | Acc: 9.710%
New best accuracy: 9.710%, saving model...
Best Acc: 9.710%

Epoch 1 | Batch 0 | Loss: 201.528 | Acc: 8.203%
Epoch 1 completed in 18.01 seconds.
Test Epoch 1 | Test Loss: 221.339 | Acc: 9.480%
Best Acc: 9.710%

Epoch 2 | Batch 0 | Loss: 193.375 | Acc: 9.375%
Epoch 2 completed in 18.51 seconds.
Test Epoch 2 | Test Loss: nan | Acc: 9.810%
New best accuracy: 9.810%, saving model...
Best Acc: 9.810%

Epoch 3 | Batch 0 | Loss: 195.976 | Acc: 11.328%
Epoch 3 completed in 18.77 seconds.
Test Epoch 3 | Test Loss: nan | Acc: 9.740%
Best Acc: 9.810%

Epoch 4 | Batch 0 | Loss: 186.836 | Acc: 9.570%
Epoch 4 completed in 18.29 seconds.
Test Epoch 4 | Test Loss: nan | Acc: 9.470%
Best Acc: 9.810%

Epoch 5 | Batch 0 | Loss: 198.444 | Acc: 8.008%
Epoch 5 completed in 18.09 seconds.
Test Epoch 5 | Test Loss: 199.256 | Acc: 9.590%
Best Acc: 9.810%

Epoch 6 | Batch 0 | Loss: 

In [13]:
# 加载最佳模型进行最终测试
checkpoint = torch.load('./checkpoint/ckpt.pth')
model.load_state_dict(checkpoint['model'])
final_acc = test('Final')
print(f'Final Test Accuracy: {final_acc:.3f}%')

Test Epoch Final | Test Loss: 267.067 | Acc: 9.760%
Final Test Accuracy: 9.760%
