In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

# Fix the RNG seed so weight initialisation and DataLoader shuffling are
# repeatable across "Restart & Run All" (some GPU kernels may still be
# non-deterministic, so treat this as best-effort reproducibility).
SEED = 42
torch.manual_seed(SEED)

# Tunable training settings, collected in one place
BATCH_SIZE = 64
LEARNING_RATE = 0.001
LOG_EVERY = 100  # print the averaged loss once every LOG_EVERY batches

# Use a GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing pipeline (also reused by the evaluation cell)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # resize images to the input size used for ResNet
    transforms.Grayscale(num_output_channels=1),  # make sure images have a single channel
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # (x - 0.5) / 0.5 on the single channel
])

# Load the MNIST training split (downloaded to ./data on first use)
train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)

# Build an untrained ResNet-18. Calling resnet18() without arguments gives the
# same randomly initialised network as the deprecated `pretrained=False`
# keyword and works on both old and new torchvision releases.
model = models.resnet18()
# Swap the first convolution so the network accepts 1-channel input
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)  # replace the classifier head with 10 output classes

# Move the model to the selected device
model = model.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Train the model
epochs = 5
model.train()  # be explicit: BatchNorm layers must be in training mode here
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # Move the current batch to the same device as the model
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accumulate the loss and report its mean over the last LOG_EVERY batches
        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:
            print(f"[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / LOG_EVERY:.3f}")
            running_loss = 0.0

print('Finished Training')

# Save only the learned parameters (state_dict); the evaluation cell rebuilds
# the architecture before loading them.
torch.save(model.state_dict(), 'resnet_mnist_model.pth')



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 20028311.93it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 28828104.19it/s]

Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 6621753.12it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<?, ?it/s]


Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw

[Epoch 1, Batch 100] loss: 0.359
[Epoch 1, Batch 200] loss: 0.122
[Epoch 1, Batch 300] loss: 0.086
[Epoch 1, Batch 400] loss: 0.076
[Epoch 1, Batch 500] loss: 0.065
[Epoch 1, Batch 600] loss: 0.067
[Epoch 1, Batch 700] loss: 0.063
[Epoch 1, Batch 800] loss: 0.056
[Epoch 1, Batch 900] loss: 0.043
[Epoch 2, Batch 100] loss: 0.039
[Epoch 2, Batch 200] loss: 0.048
[Epoch 2, Batch 300] loss: 0.042
[Epoch 2, Batch 400] loss: 0.040
[Epoch 2, Batch 500] loss: 0.040
[Epoch 2, Batch 600] loss: 0.037
[Epoch 2, Batch 700] loss: 0.047
[Epoch 2, Batch 800] loss: 0.045
[Epoch 2, Batch 900] loss: 0.044
[Epoch 3, Batch 100] loss: 0.025
[Epoch 3, Batch 200] loss: 0.029
[Epoch 3, Batch 300] loss: 0.027
[Epoch 3, Batch 400] loss: 0.030
[Epoch 3, Batch 500] loss: 0.038
[Epoch 3, Batch 600] loss: 0.030
[Epoch 3, Batch 700] loss: 0.032
[Epoch 3, Batch 800] loss: 0.036
[Epoch 3, Batch 900] loss: 0.030
[Epoch 4, Batch 100] loss: 0.030
[

In [2]:
# Rebuild the exact architecture used for training so that the parameter names
# and shapes in the saved state_dict match. resnet18() without arguments yields
# an untrained network (equivalent to the deprecated `pretrained=False`).
model = models.resnet18()
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)  # 10 output classes
# map_location lets a checkpoint saved from a GPU model load on a CPU-only machine
model.load_state_dict(torch.load('resnet_mnist_model.pth', map_location=device))
model = model.to(device)
# Switch to inference mode: without this, BatchNorm layers normalise with
# per-batch statistics and keep updating their running stats during testing.
model.eval()

# Evaluate on the MNIST test split with the same preprocessing as training
test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False)

correct = 0
total = 0
with torch.no_grad():  # gradients are not needed for evaluation
    for data in test_loader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = model(images)
        # The predicted class is the index of the largest logit in each row
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))


Accuracy of the network on the 10000 test images: 99 %
