In [1]:
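# Downloads the training data (folder name oxford-iiit-pet, about 1.5 GB) into the same directory as this code.
# Training all layers is slow: with epoch=10, training plus testing takes about forty minutes.
# Freezing every parameter except the last layer and training only that layer greatly reduces training time.
# The final line, which saves the model as swinv2_oxford_pet.pth, is commented out; uncomment it to save.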

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import timm

# Preprocessing applied to every image in both splits: resize to 256x256,
# convert to a tensor, then normalize each channel (the mean/std values
# match the widely used ImageNet channel statistics).
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Oxford-IIIT Pet datasets; download=True fetches the data into root='.'.
_dataset_kwargs = dict(root='.', download=True, transform=transform)
train_dataset = datasets.OxfordIIITPet(split='trainval', **_dataset_kwargs)
test_dataset = datasets.OxfordIIITPet(split='test', **_dataset_kwargs)

# Batches of 16 with 4 worker processes; only the training split is shuffled.
_loader_kwargs = dict(batch_size=16, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

# Load a Swin Transformer V2 model from timm with pretrained weights and a
# 37-class classification head (num_classes=37).
# To train from scratch instead, pass pretrained=False.
model = timm.create_model('swinv2_base_window8_256', pretrained=True, num_classes=37)

# Set to True to train only the classification head: every parameter is
# frozen first and then the parameters under model.head are unfrozen again.
# Leaving it False trains all layers.
FREEZE_BACKBONE = False
if FREEZE_BACKBONE:
    for param in model.parameters():
        param.requires_grad = False
    for param in model.head.parameters():
        param.requires_grad = True

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Loss function and optimizer. Only parameters that still require gradients
# are handed to Adam, so frozen parameters are excluded when
# FREEZE_BACKBONE is enabled; with it disabled this is every parameter.
criterion = nn.CrossEntropyLoss()
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-4)

# Train for a fixed number of epochs, printing after each epoch the loss
# averaged over the number of batches in train_loader.
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Clear gradients from the previous step before computing new ones.
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        # .item() extracts the Python float so the running sum holds no tensors.
        running_loss += loss.item()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')

# Evaluate on the test split: count how many predicted classes match the
# labels and report the result as a percentage.
model.eval()
correct = 0
total = 0

# Gradients are not needed for evaluation, so autograd tracking is disabled.
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # argmax over dim 1 yields the predicted class index per sample; this
        # replaces torch.max(outputs.data, 1), whose max values were unused
        # and whose legacy `.data` access is unnecessary inside no_grad().
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Test Accuracy: {100 * correct / total:.2f}%')

# Save the model (disabled; uncomment the next line to write the weights to disk)
# torch.save(model.state_dict(), 'swinv2_oxford_pet.pth')

Epoch [1/10], Loss: 0.7256
Epoch [2/10], Loss: 0.1220
Epoch [3/10], Loss: 0.0804
Epoch [4/10], Loss: 0.0304
Epoch [5/10], Loss: 0.0397
Epoch [6/10], Loss: 0.0981
Epoch [7/10], Loss: 0.0428
Epoch [8/10], Loss: 0.0230
Epoch [9/10], Loss: 0.0336
Epoch [10/10], Loss: 0.0513
Test Accuracy: 87.41%
