In [8]:
import torch
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import FakeData
from torch.utils.data import DataLoader
from torch import nn

%load_ext nbimporter
from ShffleNet import ShffleNet

# 设定图像大小和归一化参数
image_size = (224, 224)
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 生成模拟训练和测试数据集
train_dataset = FakeData(size=1000, image_size=(3, *image_size), num_classes=1000, transform=transform)
test_dataset = FakeData(size=100, image_size=(3, *image_size), num_classes=1000, transform=transform)

# 数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# 初始化模型和优化器
model = ShffleNet(num_classes=1000, groups=3)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# 训练循环
num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss: {running_loss / (i + 1)}')

# 测试循环
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the test images: {100 * correct / total}%')

The nbimporter module is not an IPython extension.
Epoch 1, Loss: 6.9321678429841995
Epoch 2, Loss: 6.876694455742836
Epoch 3, Loss: 6.823976039886475
Epoch 4, Loss: 6.781430810689926
Epoch 5, Loss: 6.755142405629158
Epoch 6, Loss: 6.713146954774857
Epoch 7, Loss: 6.687558740377426
Epoch 8, Loss: 6.655869677662849
Epoch 9, Loss: 6.637718424201012
Epoch 10, Loss: 6.620083764195442
Accuracy of the network on the test images: 0.0%
