In [1]:
import torch
from pathlib import Path
from matplotlib import pyplot as plt
from torch import nn
from torch.utils import data
from torchvision import transforms
from torchvision import datasets
from torchsummary import summary

In [2]:
transform = transforms.Compose(
    [
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
    ]
)
mnist_train = datasets.FashionMNIST("./dataset", train=True, transform=transform)
mnist_test = datasets.FashionMNIST("./dataset", train=False, transform=transform)

In [3]:
net = nn.Sequential(
    nn.Conv2d(1, 96, 11, 2),
    nn.MaxPool2d(3, 2),
    nn.Conv2d(96, 256, 5, padding=2),
    nn.MaxPool2d(3, 2),
    nn.Conv2d(256, 384, 3, padding=1),
    nn.Conv2d(384, 384, 3, padding=1),
    nn.Conv2d(384, 256, 3, padding=1),
    nn.MaxPool2d(3, 2),
    nn.Flatten(),
    nn.Linear(256 * 12 * 12, 4096),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(4096, 4096),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(4096, 10)
)

In [4]:
summary(net, (1, 224, 224), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 96, 107, 107]          11,712
         MaxPool2d-2           [-1, 96, 53, 53]               0
            Conv2d-3          [-1, 256, 53, 53]         614,656
         MaxPool2d-4          [-1, 256, 26, 26]               0
            Conv2d-5          [-1, 384, 26, 26]         885,120
            Conv2d-6          [-1, 384, 26, 26]       1,327,488
            Conv2d-7          [-1, 256, 26, 26]         884,992
         MaxPool2d-8          [-1, 256, 12, 12]               0
           Flatten-9                [-1, 36864]               0
           Linear-10                 [-1, 4096]     150,999,040
             ReLU-11                 [-1, 4096]               0
          Dropout-12                 [-1, 4096]               0
           Linear-13                 [-1, 4096]      16,781,312
             ReLU-14                 [-

In [5]:
def init_weight(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_normal_(m.weight.data)

In [6]:
net.apply(init_weight)
temp = torch.randn(1, 1, 224, 224)
for layer in net:
    temp = layer(temp)
    print(layer.__class__.__name__, "out size:", temp.shape)

Conv2d out size: torch.Size([1, 96, 107, 107])
MaxPool2d out size: torch.Size([1, 96, 53, 53])
Conv2d out size: torch.Size([1, 256, 53, 53])
MaxPool2d out size: torch.Size([1, 256, 26, 26])
Conv2d out size: torch.Size([1, 384, 26, 26])
Conv2d out size: torch.Size([1, 384, 26, 26])
Conv2d out size: torch.Size([1, 256, 26, 26])
MaxPool2d out size: torch.Size([1, 256, 12, 12])
Flatten out size: torch.Size([1, 36864])
Linear out size: torch.Size([1, 4096])
ReLU out size: torch.Size([1, 4096])
Dropout out size: torch.Size([1, 4096])
Linear out size: torch.Size([1, 4096])
ReLU out size: torch.Size([1, 4096])
Dropout out size: torch.Size([1, 4096])
Linear out size: torch.Size([1, 10])


In [7]:
def trainer(net, trainset, loss, optimizer, epochs):
    try:
        print("读取参数")
        net.load_state_dict(torch.load("./dataset/alexnet.pt"))
    except FileNotFoundError:
        print("参数文件不存在,开始训练网络")
        net.train()
        for epoch in range(epochs):
            print(f"第{epoch}次迭代,网络训练中...")
            for x, y in trainset:
                x = x.cuda()
                y = y.cuda()
                optimizer.zero_grad()
                l = loss(net(x), y)
                l.backward()
                optimizer.step()
        torch.save(net.state_dict(), "./dataset/alexnet.pt")
        print("网络训练完成,存储参数")
    else:
        print("参数读取成功")

In [8]:
epochs = 1
device = "cuda" if torch.cuda.is_available() else "cpu"
net.to(device)
loss = nn.CrossEntropyLoss()
trainset = data.DataLoader(mnist_train, batch_size=16, shuffle=True)
optimizer = torch.optim.SGD(net.parameters(), 0.01)

In [9]:
trainer(net, trainset, loss, optimizer, epochs)

读取参数
参数文件不存在,开始训练网络
第0次迭代,网络训练中...


KeyboardInterrupt: 

In [None]:
net.eval()
score_test = list()
score_train = list()
for a, b in mnist_test:
    a = a.reshape(-1, 1, 224, 224)
    score_test.append(net(a.cuda()).argmax() == b)
for c, d in mnist_train:
    if len(score_train) < 10000:
        c = c.reshape(-1, 1, 224, 224)
        score_train.append(net(c.cuda()).argmax() == d)
    else:
        break
accuracy_test = sum(score_test) / len(score_test)
accuracy_train = sum(score_train) / len(score_train)

In [None]:
print(accuracy_test)