In [1]:
import torch
from pathlib import Path
from matplotlib import pyplot as plt
from torch import nn
from torch.utils import data
from torchvision import transforms
from torchvision import datasets

In [None]:
# Resize the FashionMNIST images (28x28 grayscale) to the 224x224 input size the
# network below is built for, then convert them to float tensors.
transform = transforms.Compose(
    [
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
    ]
)
# download=True fetches the files into ./data on first use; without it a fresh
# checkout raises a RuntimeError ("Dataset not found") on Restart & Run All.
mnist_train = datasets.FashionMNIST("./data", train=True, transform=transform, download=True)
mnist_test = datasets.FashionMNIST("./data", train=False, transform=transform, download=True)

In [None]:
# AlexNet-style classifier for 1x224x224 inputs and 10 classes.
# Spatial size: 224 -conv11/s2-> 107 -pool-> 53 -conv5/p2-> 53 -pool-> 26
#               -3x conv3/p1-> 26 -pool-> 12, hence the 256 * 12 * 12 flatten size
#               (the first Linear therefore holds ~151M weights).
# Every convolution is followed by a ReLU; without them the three stacked 3x3
# convolutions would compose into a single linear map.
# NOTE: state_dict keys are the Sequential indices, so a checkpoint saved from a
# different layer list will not load into this one.
torch.manual_seed(0)  # seed the global RNG so weight initialisation is reproducible on Run All
net = nn.Sequential(
    nn.Conv2d(1, 96, 11, 2),
    nn.ReLU(),
    nn.MaxPool2d(3, 2),
    nn.Conv2d(96, 256, 5, padding=2),
    nn.ReLU(),
    nn.MaxPool2d(3, 2),
    nn.Conv2d(256, 384, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(384, 384, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(384, 256, 3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(3, 2),
    nn.Flatten(),
    nn.Linear(256 * 12 * 12, 4096),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(4096, 4096),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(4096, 10)
)

In [None]:
def init_weight(m):
    """Xavier-normal initialise the weight of Linear and Conv2d modules.

    Intended for ``nn.Module.apply``, which calls it on every submodule. Modules of
    other types are left untouched, and biases keep their existing values.

    Args:
        m: the module to (possibly) initialise in place.
    """
    # isinstance (rather than `type(m) == ...`) also matches subclasses of these layers.
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        # nn.init functions already run under no_grad, so going through `.data` is unnecessary.
        nn.init.xavier_normal_(m.weight)

In [None]:
net.apply(init_weight)
# Push one random 1x1x224x224 image through the network layer by layer and report
# each layer's output shape. This is inspection only, so no_grad skips building an
# autograd graph through the (large) fully connected layers.
probe = torch.randn(1, 1, 224, 224)
with torch.no_grad():
    for layer in net:
        probe = layer(probe)
        print(layer.__class__.__name__, "out size:", probe.shape)

In [None]:
def trainer(net, data, loss, optimizer, epochs, path="./data/alexnet.pt", device=None):
    """Load cached weights into ``net`` from ``path``, or train it and cache them.

    If ``path`` exists, its state dict is loaded into ``net`` and no training happens.
    Otherwise ``net`` is trained for ``epochs`` passes over ``data`` and its state dict
    is saved to ``path``.

    Args:
        net: the model to load or train (modified in place).
        data: iterable of ``(inputs, targets)`` batches, e.g. a DataLoader.
        loss: callable mapping ``(outputs, targets)`` to a scalar loss tensor.
        optimizer: optimizer constructed over ``net.parameters()``.
        epochs: number of passes over ``data`` when training.
        path: checkpoint file to read from / write to.
        device: device each batch is moved to. Defaults to the device of ``net``'s
            first parameter, so the function also works on CPU-only machines.

    Raises:
        RuntimeError: propagated from ``load_state_dict`` if the checkpoint does not
            match ``net``'s layers.
    """
    if device is None:
        device = next(net.parameters()).device
    try:
        print("读取参数")
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
        state = torch.load(path, map_location=device)
    except FileNotFoundError:
        print("参数文件不存在,开始训练网络")
        net.train()
        for epoch in range(epochs):
            print(f"第{epoch}次迭代,网络训练中...")
            for x, y in data:
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                batch_loss = loss(net(x), y)
                batch_loss.backward()
                optimizer.step()
        # Make sure the checkpoint directory exists before saving.
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(net.state_dict(), path)
        print("网络训练完成,存储参数")
    else:
        net.load_state_dict(state)
        print("参数读取成功")

In [None]:
loss = nn.CrossEntropyLoss()
epochs = 1
device = "cuda" if torch.cuda.is_available() else "cpu"
# Move the model to its device before constructing the optimizer, as the torch.optim
# documentation recommends.
net.to(device)
optimizer = torch.optim.SGD(net.parameters(), 0.01)
# NOTE: the name `data` also refers to the `torch.utils.data` module imported at the top
# and is rebound here to the training DataLoader. Going through `torch.utils.data`
# (instead of `data.DataLoader`) keeps this cell re-runnable after the rebinding.
data = torch.utils.data.DataLoader(mnist_train, batch_size=32, shuffle=True, num_workers=4)
net

In [None]:
# Load the cached weights if the checkpoint exists, otherwise train and cache them.
trainer(net=net, data=data, loss=loss, optimizer=optimizer, epochs=epochs)

In [None]:
def evaluate_accuracy(model, dataset, max_samples=None, batch_size=64):
    """Return the fraction of ``dataset`` samples whose argmax prediction equals the label.

    Args:
        model: classifier returning one row of class scores per input. It is switched
            to eval mode and left there.
        dataset: indexable dataset yielding ``(image, label)`` pairs.
        max_samples: if given, only the first ``max_samples`` samples are scored.
        batch_size: number of samples per forward pass.

    Returns:
        Accuracy as a Python float in [0, 1].

    Raises:
        ZeroDivisionError: if no samples are scored.
    """
    # Run on whatever device the model's parameters live on instead of assuming CUDA.
    device = next(model.parameters()).device
    if max_samples is not None:
        dataset = torch.utils.data.Subset(dataset, range(min(max_samples, len(dataset))))
    # The global name `data` is rebound to a DataLoader earlier in the notebook, so the
    # module is reached through `torch.utils.data`.
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # inference only: no autograd graphs needed
        for images, labels in loader:
            predictions = model(images.to(device)).argmax(dim=1).cpu()
            correct += (predictions == labels).sum().item()
            total += labels.numel()
    return correct / total


accuracy_test = evaluate_accuracy(net, mnist_test)
# Training accuracy is estimated on the first 10000 training samples to save time.
accuracy_train = evaluate_accuracy(net, mnist_train, max_samples=10000)

In [None]:
# Show both accuracies. float() accepts either a Python number or a 0-d tensor.
{"test_accuracy": float(accuracy_test), "train_accuracy": float(accuracy_train)}