In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [2]:
# Reproducibility: seed torch's global RNG before any model weights are
# initialised or any DataLoader shuffles, so Restart & Run All reproduces the
# recorded numbers. (Full bit-for-bit determinism on CUDA may also require
# cuDNN/determinism settings.)
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


Using device: cpu


In [3]:
# Per-sample preprocessing pipeline:
#   1. ToTensor converts the image to a float tensor scaled to [0, 1].
#   2. torch.flatten turns the 28x28 image tensor into a 784-element vector
#      for the linear model. A library function is used instead of a lambda
#      because lambdas cannot be pickled, which breaks DataLoader with
#      num_workers > 0 on platforms that spawn worker processes.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(torch.flatten),
])

In [4]:
# Download MNIST into ./data if it is not already there, then load both splits
# (train first, then test), applying the same preprocessing transform to each.
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(
        root='./data', train=is_train, download=True, transform=transform
    )
    for is_train in (True, False)
)

# Wrap the datasets in loaders: the training set is reshuffled every epoch,
# while the test set keeps a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:10<00:00, 961793.61it/s] 


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 90921.96it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:08<00:00, 201879.66it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4775765.55it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [5]:
# A single linear layer with no hidden layers: 784 flattened pixel inputs ->
# 10 output scores (one per digit class), placed on the selected device.
model = nn.Linear(in_features=784, out_features=10)
model = model.to(device)

In [6]:
# Plain SGD over all model parameters with a fixed learning rate of 0.01.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)
# Multi-class classification loss computed on raw (un-softmaxed) scores.
criterion = nn.CrossEntropyLoss()

In [7]:
num_epochs = 20
for epoch in range(num_epochs):
    # Enable training-mode behaviour. This has no effect on a bare nn.Linear,
    # but keeps the loop correct for models with dropout / batch-norm.
    model.train()
    total_loss = 0.0
    samples_seen = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass: scores for each class, then the batch loss.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass + parameter update.
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss.backward()        # compute gradients
        optimizer.step()       # update parameters

        # CrossEntropyLoss (default reduction='mean') returns the batch mean.
        # Weight it by the batch size so the smaller final batch does not
        # skew the epoch average.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        samples_seen += batch_size

    # True per-sample mean loss over the epoch (guarded against an empty loader).
    epoch_loss = total_loss / max(samples_seen, 1)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

Epoch [1/20], Loss: 0.9809
Epoch [2/20], Loss: 0.5508
Epoch [3/20], Loss: 0.4715
Epoch [4/20], Loss: 0.4326
Epoch [5/20], Loss: 0.4086
Epoch [6/20], Loss: 0.3918
Epoch [7/20], Loss: 0.3793
Epoch [8/20], Loss: 0.3693
Epoch [9/20], Loss: 0.3613
Epoch [10/20], Loss: 0.3546
Epoch [11/20], Loss: 0.3487
Epoch [12/20], Loss: 0.3437
Epoch [13/20], Loss: 0.3392
Epoch [14/20], Loss: 0.3355
Epoch [15/20], Loss: 0.3321
Epoch [16/20], Loss: 0.3288
Epoch [17/20], Loss: 0.3259
Epoch [18/20], Loss: 0.3232
Epoch [19/20], Loss: 0.3209
Epoch [20/20], Loss: 0.3187


In [8]:
# Evaluate on the held-out test split and report top-1 accuracy (percent).
model.eval()  # switch to evaluation mode
correct = 0
total = 0
with torch.no_grad():  # no gradients are tracked, so inference is cheaper
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Predicted class = index of the highest score. argmax avoids the
        # discouraged `.data` attribute, which is redundant under no_grad.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy:.2f}%')
# Inline check: fail loudly if accuracy does not exceed the 90% target.
assert accuracy > 90, "❌ 未达到目标！检查学习率、epoch 或数据预处理"
print("✅ 成功！准确率 > 90%")

# Save only the learned parameters (state_dict); the file is written only when
# the accuracy check above passes.
torch.save(model.state_dict(), 'mnist_model.pth')
print("✅ 模型已保存为 mnist_model.pth")

Test Accuracy: 91.61%
✅ 成功！准确率 > 90%
✅ 模型已保存为 mnist_model.pth
