In [16]:
import torch  # 导入torch的相关库
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [17]:
# Download training data from open datasets.下载训练数据集
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.下载测试数据集
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

In [18]:
batch_size = 64  # 批量大小为64，即每次训练64个样本，Dataloader中的每个元素会返回64个特征和标签

# Create data loaders.#创建数据加载器
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64


In [19]:
# Get cpu, gpu or mps device for training. 在PyTorch中定义一个神经网络，需要创建一个继承自nn.Module的类，定义网络的结构和传播的过程
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# Define model


class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = NeuralNetwork().to(device)
print(model)

Using cuda device
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)


In [20]:
# Optimizing the Model Parameters 为了训练模型，我们需要损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

In [21]:
# In a single training loop, the model makes predictions on the training dataset (fed to it in batches), and backpropagates the prediction error to adjust the model’s parameters. 在训练循环当中，模型对数据集进行预测，并通过反向传播预测误差来调整模型的参数
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error 计算预测误差，通过损失函数计算
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation 反向传播 通过优化器调整模型参数
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:    # 每100个batch打印一次损失
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [22]:
# We also check the model’s performance against the test dataset to ensure it is learning. 根据数据集检查模型的性能，保证模型在学习

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [23]:
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.316073  [   64/60000]
loss: 2.292903  [ 6464/60000]
loss: 2.271219  [12864/60000]
loss: 2.259429  [19264/60000]
loss: 2.242147  [25664/60000]
loss: 2.219713  [32064/60000]
loss: 2.232277  [38464/60000]
loss: 2.196522  [44864/60000]
loss: 2.187490  [51264/60000]
loss: 2.166581  [57664/60000]
Test Error: 
 Accuracy: 43.0%, Avg loss: 2.152578 

Epoch 2
-------------------------------
loss: 2.162163  [   64/60000]
loss: 2.147994  [ 6464/60000]
loss: 2.085974  [12864/60000]
loss: 2.107597  [19264/60000]
loss: 2.059603  [25664/60000]
loss: 1.999512  [32064/60000]
loss: 2.037723  [38464/60000]
loss: 1.949230  [44864/60000]
loss: 1.954740  [51264/60000]
loss: 1.904990  [57664/60000]
Test Error: 
 Accuracy: 55.8%, Avg loss: 1.884129 

Epoch 3
-------------------------------
loss: 1.909419  [   64/60000]
loss: 1.883172  [ 6464/60000]
loss: 1.757221  [12864/60000]
loss: 1.811207  [19264/60000]
loss: 1.705239  [25664/60000]
loss: 1.650380  [32064/600

In [24]:
#Save model  保存模型
torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

Saved PyTorch Model State to model.pth


In [25]:
# Loading Models 加载模型
model = NeuralNetwork().to(device)
model.load_state_dict(torch.load("model.pth"))
model.load_state_dict(torch.load("model.pth", weights_only=True))
# wights_only 指示是否应限制 unpickler 只加载张量、基元类型、字典和任何通过 torch.serialization.add_safe_globals() 添加的类型。 Indicates whether unpickler should be restricted to loading only tensors, primitive types, dictionaries and any types added via torch.serialization.add_safe_globals().


  model.load_state_dict(torch.load("model.pth"))


<All keys matched successfully>

In [26]:
#使用刚刚创建的模型进行预测
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    x = x.to(device)
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')
    

Predicted: "Ankle boot", Actual: "Ankle boot"
