In [None]:
import torch
import numpy as np
from torchvision import datasets, transforms
from torch import nn
from torch.utils import data
from matplotlib import pyplot as plt

In [None]:
class Reshape(nn.Module):
    """Layer that views its input as a batch of fixed-size images.

    Args:
        shape: Per-sample ``(channels, height, width)``. Defaults to
            ``(1, 224, 224)``, so ``Reshape()`` produces tensors of shape
            ``(batch, 1, 224, 224)``.
    """

    def __init__(self, shape=(1, 224, 224)):
        super().__init__()
        self.shape = tuple(shape)

    def forward(self, x):
        """Return ``x`` viewed as ``(-1, *shape)``.

        The leading ``-1`` lets ``view`` infer the batch size from the number
        of elements in ``x``; ``view`` raises if that number is not a multiple
        of the per-sample size or if ``x`` cannot be viewed without a copy.
        """
        return x.view(-1, *self.shape)

In [None]:
def init_weight(m):
    """Xavier-normal initialize the weight of a Linear or Conv2d layer.

    Meant to be passed to ``Module.apply``. Only modules whose exact type is
    ``nn.Linear`` or ``nn.Conv2d`` are touched; every other module (and the
    bias of matching modules) is left as it is.
    """
    # Guard clause: skip anything that is not exactly one of the two layer types.
    if type(m) not in (nn.Linear, nn.Conv2d):
        return
    nn.init.xavier_normal_(m.weight)

In [None]:
# AlexNet-style CNN for single-channel 224x224 inputs and 10 output classes.
# Every convolution is followed by a ReLU: without a non-linearity between
# them, consecutive convolutions collapse into a single linear map.
# NOTE: inserting the ReLU modules shifts the positional indices used as
# state_dict keys, so checkpoints saved from the activation-free layout must
# be regenerated by retraining.
net = nn.Sequential(
    Reshape(),
    # 224 -> 107 (11x11 conv, stride 2) -> 53 (3x3 max-pool, stride 2)
    nn.Conv2d(1, 96, 11, 2),
    nn.ReLU(),
    nn.MaxPool2d(3, 2),
    # 53 -> 53 (5x5 conv, padding 2) -> 26 (3x3 max-pool, stride 2)
    nn.Conv2d(96, 256, 5, padding=2),
    nn.ReLU(),
    nn.MaxPool2d(3, 2),
    # Three size-preserving 3x3 convolutions (26 -> 26), then pool to 12.
    nn.Conv2d(256, 384, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(384, 384, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(384, 256, 3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(3, 2),
    nn.Flatten(),
    # Classifier head: 256 channels x 12 x 12 spatial positions.
    nn.Linear(256 * 12 * 12, 4096),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(4096, 4096),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(4096, 10)
)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Module.apply returns the module itself, so weight initialization and the
# device transfer can be chained; the cell still evaluates to the module.
net.apply(init_weight).to(device=device)

In [None]:
# Fashion-MNIST images are 28x28, but the network's first layer views its
# input as (-1, 1, 224, 224). Upscale each image to 224x224 first; otherwise
# the view fails (or silently mixes samples when the element count happens
# to be a multiple of 224 * 224).
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
# download=True fetches the dataset only when it is not already in ./data,
# so the notebook also runs on a fresh checkout.
train = datasets.FashionMNIST(
    "./data", train=True, transform=transform, download=True
)
test = datasets.FashionMNIST(
    "./data", train=False, transform=transform, download=True
)
# Use the fully qualified module path: this line rebinds the name `data`
# (imported as torch.utils.data) to the loader, so `data.DataLoader` would
# fail the second time the cell is run.
data = torch.utils.data.DataLoader(train, 4, shuffle=True, num_workers=4)
optimizer = torch.optim.SGD(net.parameters(), lr=0.08)
loss = nn.CrossEntropyLoss()
epochs = 35

In [None]:
def trainer(data, optimizer, epochs, loss, model=None):
    """Train a model and record the mean batch loss of every epoch.

    Args:
        data: Iterable yielding ``(inputs, targets)`` tensor pairs, e.g. a
            ``DataLoader``.
        optimizer: Optimizer constructed over the model's parameters.
        epochs: Number of full passes over ``data``.
        loss: Callable ``loss(predictions, targets)`` returning a scalar tensor.
        model: Module to train. Defaults to the notebook-level ``net``.

    Returns:
        ``numpy.ndarray`` of shape ``(epochs,)`` with the mean of the batch
        losses observed in each epoch.
    """
    if model is None:
        model = net
    # Send batches to wherever the model lives instead of hard-coding CUDA,
    # so the loop also runs on machines without a GPU.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else torch.device("cpu")
    model.train()
    epoch_means = []
    for _ in range(epochs):
        batch_losses = []
        for x, y in data:
            x, y = x.to(target), y.to(target)
            optimizer.zero_grad()
            l = loss(model(x), y)
            # .item() copies the scalar to the host as a Python float; a list
            # avoids the full-array copy np.append performs on every batch.
            batch_losses.append(l.item())
            l.backward()
            optimizer.step()
        epoch_means.append(np.mean(batch_losses))
    return np.array(epoch_means, dtype=float)

In [None]:
# Toggle for the expensive step: when True, train the network and write its
# weights to disk; when False, nothing in this cell runs.
switch = True
if switch:
    trainer(data=data, optimizer=optimizer, epochs=epochs, loss=loss)
    torch.save(obj=net.state_dict(), f="./data/LeNet.pt")

In [None]:
# Reload the saved weights and measure accuracy on the CPU.
# map_location="cpu" lets a checkpoint written from a GPU model be read on a
# machine without CUDA; the model is moved to the CPU right after anyway.
net.load_state_dict(torch.load("./data/LeNet.pt", map_location="cpu"))
net.to(device="cpu")
net.eval()  # disable dropout for inference
score_test = list()
score_train = list()
# Inference only: skip autograd bookkeeping for the forward passes.
with torch.no_grad():
    for a, b in test:
        # .item() stores a plain Python bool instead of a 0-d tensor.
        score_test.append((net(a).argmax() == b).item())
    # Only the first 10000 training samples are scored.
    for c, d in train:
        if len(score_train) >= 10000:
            break
        score_train.append((net(c).argmax() == d).item())
# Summing Python bools yields ints, so both accuracies are plain floats.
accuracy_test = sum(score_test) / len(score_test)
accuracy_train = sum(score_train) / len(score_train)

In [None]:
# Report both accuracies, one per line.
print(f"acc_test:{accuracy_test}", f"acc_train:{accuracy_train}", sep="\n")