In [4]:
import torch
import numpy as np
import sys
sys.path.append('..')
import d2lzh_pytorch as d2l
import torchvision

In [2]:
def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST', num_workers=0):
    """Download Fashion-MNIST (if not already present) and build data loaders.

    Args:
        batch_size: Number of examples per mini-batch, used for both loaders.
        resize: Optional size forwarded to ``torchvision.transforms.Resize``.
            When falsy, no resize step is added to the transform pipeline.
        root: Directory the dataset is downloaded to / read from.
        num_workers: Number of worker processes handed to each ``DataLoader``.
            The default of 0 means no extra processes are used to read data,
            which matches what this function has always done.
            NOTE(review): values > 0 may be problematic on Windows or inside
            notebooks -- confirm on the target platform before raising it.

    Returns:
        A ``(train_iter, test_iter)`` tuple of ``DataLoader`` objects. The
        training loader shuffles its data; the test loader does not.
    """
    # Build the per-image transform pipeline: optional resize, then ToTensor.
    steps = []
    if resize:
        steps.append(torchvision.transforms.Resize(size=resize))
    steps.append(torchvision.transforms.ToTensor())
    transform = torchvision.transforms.Compose(steps)

    mnist_train = torchvision.datasets.FashionMNIST(
        root=root, train=True, download=True, transform=transform)
    mnist_test = torchvision.datasets.FashionMNIST(
        root=root, train=False, download=True, transform=transform)

    # Pass the worker count through to the loaders; it used to be computed in
    # a platform branch and then ignored in favour of a hard-coded 0.
    train_iter = torch.utils.data.DataLoader(
        mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(
        mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_iter, test_iter


In [5]:
# Mini-batch size; reused later when the training routine is invoked.
batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)


In [7]:
# Layer widths: input features, output classes, hidden units.
num_inputs, num_outputs, num_hiddens = 784, 10, 256

# Weights are drawn from a normal distribution (mean 0, std 0.01) with NumPy
# and converted to float32 tensors; biases start at zero.
W1 = torch.from_numpy(np.random.normal(loc=0, scale=0.01, size=(num_inputs, num_hiddens))).float()
b1 = torch.zeros(num_hiddens, dtype=torch.float32)
W2 = torch.from_numpy(np.random.normal(loc=0, scale=0.01, size=(num_hiddens, num_outputs))).float()
b2 = torch.zeros(num_outputs, dtype=torch.float32)

# Collect every trainable tensor and switch on gradient tracking in place.
params = [W1, b1, W2, b2]
for param in params:
    param.requires_grad_(True)

In [8]:
def relu(x):
    """Element-wise rectified linear unit: ``max(x, 0)``.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape, dtype and device as ``x`` in which every
        negative entry has been replaced by zero.
    """
    # Build the zero scalar from x itself so it shares x's dtype and device.
    # A plain torch.tensor(0.0) is always a float32 CPU scalar, which silently
    # promotes integer inputs to float32.
    zero = x.new_zeros(())
    return torch.max(x, zero)

In [9]:
def net(X):
    """Forward pass of the one-hidden-layer perceptron.

    The input is viewed as rows of ``num_inputs`` features, passed through an
    affine layer (``W1``, ``b1``) followed by ``relu``, and then through a
    second affine layer (``W2``, ``b2``) whose result is returned.
    """
    flat = X.view(-1, num_inputs)
    hidden = relu(flat @ W1 + b1)
    return hidden @ W2 + b2

In [11]:
# Cross-entropy criterion; the default batch-mean reduction is spelled out.
loss = torch.nn.CrossEntropyLoss(reduction='mean')

In [12]:
# Training hyper-parameters.
# NOTE(review): lr = 100.0 is unusually large. Presumably the optimiser inside
# d2l.train_ch3 divides gradients by batch_size, which would offset it --
# confirm against d2lzh_pytorch before changing.
num_epochs = 5
lr = 100.0

d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)

epoch 1, loss 0.0030, train acc 0.715, test acc 0.787
epoch 2, loss 0.0019, train acc 0.825, test acc 0.803
epoch 3, loss 0.0017, train acc 0.845, test acc 0.844
epoch 4, loss 0.0015, train acc 0.855, test acc 0.850
epoch 5, loss 0.0015, train acc 0.863, test acc 0.842
