In [4]:
# 任务:
# 1. 读入 FashionMNIST, CIFAR10, Tiny Imagenet(来自CS231n课程项目) 数据集
# 2. 构建 CNN 网络, 写 train/eval loop, 带调试信息
# 3. 训练, 保存, 加载

# Dataset loading
from torchvision import datasets, transforms

op_dataset = 0  # 0: FashionMNIST, 1: CIFAR10

# Map each option to its torchvision dataset class so both are built by the
# same code path instead of two copy-pasted branches.
_DATASETS = {0: datasets.FashionMNIST, 1: datasets.CIFAR10}
if op_dataset not in _DATASETS:
    # Fail here with a clear message rather than a NameError on train_data below.
    raise ValueError(
        f'unknown op_dataset {op_dataset!r}; expected one of {sorted(_DATASETS)}'
    )
_dataset_cls = _DATASETS[op_dataset]

train_data = _dataset_cls(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_data = _dataset_cls(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

from torch.utils.data import DataLoader

batch_size = 64
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Evaluation metrics are order-independent, so keep the test set in a fixed order.
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Build the CNN
import torch
from torch import nn

# Image tensor (element 0) of the first training pair; its shape is used
# below to size the network's first conv layer and its flattened features.
sample = train_data[0][0]

class LetNet(nn.Module):
    """LeNet-style CNN: two conv -> ReLU -> max-pool stages, then three linear layers.

    Args:
        sample_input: one un-batched input tensor of shape (C, H, W). Only its
            shape is used: C sets conv1's input channels, and (C, H, W) is used
            to infer the size of the flattened conv features feeding fc1.
            Defaults to the notebook-level ``sample`` tensor.
        num_classes: number of output logits produced by fc3 (default 10).
    """

    def __init__(self, sample_input=None, num_classes=10):
        super().__init__()
        if sample_input is None:
            sample_input = sample  # notebook-level global (first training image)
        # Remember the per-sample shape so get_flatten_dim() needs no globals.
        self.input_shape = tuple(sample_input.shape)
        self.conv1 = nn.Conv2d(in_channels=self.input_shape[0], out_channels=6, kernel_size=5)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(self.get_flatten_dim(), 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def get_flatten_dim(self):
        """Return the number of features per sample after both conv/pool stages."""
        with torch.no_grad():
            # A dummy batch of one on the same device as the conv weights; the
            # values are irrelevant and ReLU is skipped since it keeps shapes.
            x = torch.zeros(1, *self.input_shape, device=self.conv1.weight.device)
            x = self.pool(self.conv1(x))
            x = self.pool(self.conv2(x))
            return x.numel()

    def forward(self, x):
        """Map a batch of shape (N, C, H, W) to (N, num_classes) logits."""
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        logits = self.fc3(x)
        return logits
    
# Prefer the GPU when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = LetNet().to(device)

# Training / evaluation setup
from torch import optim

# Adam over every parameter of the model, minimizing cross-entropy on the logits.
lr = 1e-3
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=lr)

def train_loop(model, dataloader, loss_fn, optimizer, device=None):
    """Run one training epoch over ``dataloader``, printing progress periodically.

    Args:
        model: network to train; it is switched to training mode.
        dataloader: iterable of ``(X, y)`` batches with a ``dataset`` attribute.
        loss_fn: callable ``loss_fn(logits, y)`` returning a scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (``'cpu'`` if it has none).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'
    model.train()
    samples_size = len(dataloader.dataset)
    batches_size = len(dataloader)
    seen = 0  # samples processed so far, using each batch's real length
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        logits_pred = model(X)
        loss = loss_fn(logits_pred, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        seen += len(y)
        # Report every 100 batches and always on the final batch.
        if batch % 100 == 0 or batch == batches_size - 1:
            print(f'loss:{loss.item():>7.2f} | {seen}/{samples_size}')

def test_loop(model, dataloader, loss_fn, device=None):
    """Evaluate ``model`` on ``dataloader``, then print and return the metrics.

    Args:
        model: network to evaluate; it is switched to eval mode.
        dataloader: iterable of ``(X, y)`` batches with a ``dataset`` attribute.
        loss_fn: callable ``loss_fn(logits, y)`` returning a scalar loss tensor;
            assumed to be a per-batch mean (nn.CrossEntropyLoss's default).
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (``'cpu'`` if it has none).

    Returns:
        ``(avg_loss, accuracy)``: loss averaged over samples and accuracy in
        percent. Both are 0.0 when the dataset is empty.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'
    model.eval()
    samples_size = len(dataloader.dataset)
    loss_sum = 0.0
    correct = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            logits_pred = model(X)
            # Weight the batch-mean loss by batch size so a smaller final batch
            # does not count as much as a full one.
            loss_sum += loss_fn(logits_pred, y).item() * len(y)
            correct += (logits_pred.argmax(dim=1) == y).sum().item()
    avg_loss = loss_sum / samples_size if samples_size else 0.0
    accuracy = correct / samples_size * 100 if samples_size else 0.0
    print(f'loss:{avg_loss:>7.2f} | acc:{accuracy:>7.2f}%')
    return avg_loss, accuracy


# The name starts with "test_"; stop pytest from collecting this helper as a test.
test_loop.__test__ = False

# Alternate one pass over the training set with an evaluation on the test set.
epochs = 10
for t in range(1, epochs + 1):
    print(f'Epoch {t} -------------------------')
    train_loop(model=model, dataloader=train_dataloader, loss_fn=loss_fn, optimizer=optimizer)
    test_loop(model=model, dataloader=test_dataloader, loss_fn=loss_fn)
    print()  # blank line between epochs

# Save / load the model
import os

model_path = './model/LetNet.pth'
# exist_ok replaces the separate exists() check (and its check-then-create race).
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)

# Rebuild a fresh network and restore the saved weights into it.
model_new = LetNet().to(device)
# map_location puts the loaded tensors on the current device, even if the
# checkpoint was written on a different one (e.g. saved on GPU, reloaded on CPU).
model_new.load_state_dict(torch.load(model_path, map_location=device))
model_new.eval()
with torch.no_grad():
    y_pred = model_new(sample.unsqueeze(0).to(device)).argmax()
# Predicted class index vs. the stored label of the first training example.
print(y_pred.item(), train_data[0][1])

Epoch 1 -------------------------
loss:   2.30 | 64/60000
loss:   1.00 | 6464/60000
loss:   0.85 | 12864/60000
loss:   0.81 | 19264/60000
loss:   0.90 | 25664/60000
loss:   0.60 | 32064/60000
loss:   0.54 | 38464/60000
loss:   0.65 | 44864/60000
loss:   0.57 | 51264/60000
loss:   0.53 | 57664/60000
loss:   0.61 | 60000/60000
loss:   0.57 | acc:  78.13%

Epoch 2 -------------------------
loss:   0.81 | 64/60000
loss:   0.43 | 6464/60000
loss:   0.60 | 12864/60000
loss:   0.49 | 19264/60000
loss:   0.65 | 25664/60000
loss:   0.50 | 32064/60000
loss:   0.39 | 38464/60000
loss:   0.59 | 44864/60000
loss:   0.53 | 51264/60000
loss:   0.50 | 57664/60000
loss:   0.39 | 60000/60000
loss:   0.47 | acc:  83.03%

Epoch 3 -------------------------
loss:   0.45 | 64/60000
loss:   0.66 | 6464/60000
loss:   0.35 | 12864/60000
loss:   0.36 | 19264/60000
loss:   0.37 | 25664/60000
loss:   0.30 | 32064/60000
loss:   0.28 | 38464/60000
loss:   0.51 | 44864/60000
loss:   0.39 | 51264/60000
loss:   0.39 | 