In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models

# Select the compute device: use CUDA when a GPU is visible to PyTorch,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'Using device: {device}')

# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalize each RGB channel.
# The mean/std values match the ImageNet channel statistics commonly used
# with torchvision's pretrained weights.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
transform = transforms.Compose(preprocess_steps)

# Load the CIFAR-10 training split (downloaded into ./data if it is not
# already there) and wrap it in a loader that yields shuffled batches of 200.
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=200,
)

# Load MobileNetV3-Small with ImageNet-pretrained weights.
# torchvision 0.13 deprecated the boolean `pretrained=` flag in favour of the
# `weights=` enum; IMAGENET1K_V1 is the checkpoint `pretrained=True` selected.
# Older torchvision releases do not expose the enum, so fall back to the
# legacy flag there.
weights_enum = getattr(models, "MobileNet_V3_Small_Weights", None)
if weights_enum is not None:
    model = models.mobilenet_v3_small(weights=weights_enum.IMAGENET1K_V1)
else:
    model = models.mobilenet_v3_small(pretrained=True)

# Replace the classifier layer at index 3 with a fresh linear layer that keeps
# the same input width but emits one logit per CIFAR-10 class.
num_classes = 10  # CIFAR-10 has 10 categories
model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)

# Move the model's parameters to the selected device (GPU when available).
model = model.to(device)

# Training objective and optimizer: cross-entropy loss, minimized with Adam
# over every parameter of the model.
learning_rate = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model and report the training loss and accuracy after each epoch.
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-sample losses seen so far this epoch
    correct = 0         # number of correctly classified training samples
    total = 0           # number of training samples seen this epoch
    for inputs, labels in train_loader:
        # Move the batch to the same device as the model.
        inputs, labels = inputs.to(device), labels.to(device)
        batch_size = labels.size(0)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The criterion is assumed to return the batch mean (the default for
        # nn.CrossEntropyLoss); weighting by batch size keeps the epoch
        # average correct even when the final batch is smaller.
        running_loss += loss.item() * batch_size

        # Running accuracy from the forward pass made before this step's
        # weight update. detach() replaces the legacy `.data` attribute.
        predicted = outputs.detach().argmax(dim=1)
        total += batch_size
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / total
    accuracy = 100 * correct / total
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.2f}%')

# Persist the trained parameters (the state_dict only, not the whole model
# object) to a file in the current working directory.
checkpoint_path = 'mobilenetv3_classifier.pth'
torch.save(model.state_dict(), checkpoint_path)


Using device: cuda
Files already downloaded and verified




Epoch [1/10], Loss: 0.3415, Accuracy: 88.46%
Epoch [2/10], Loss: 0.1430, Accuracy: 95.04%
Epoch [3/10], Loss: 0.0978, Accuracy: 96.60%
Epoch [4/10], Loss: 0.0748, Accuracy: 97.34%
Epoch [5/10], Loss: 0.0580, Accuracy: 97.96%
Epoch [6/10], Loss: 0.0444, Accuracy: 98.49%
Epoch [7/10], Loss: 0.0455, Accuracy: 98.42%
Epoch [8/10], Loss: 0.0371, Accuracy: 98.70%
Epoch [9/10], Loss: 0.0302, Accuracy: 98.96%
Epoch [10/10], Loss: 0.0369, Accuracy: 98.75%
