<a href="https://colab.research.google.com/github/thai94/d2l/blob/main/3.linear-neural-networks/3_7_concise_implementation_of_softmax_regression_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [27]:
import torch
from torch import nn
import torchvision
from torch.utils import data
from torchvision import transforms

In [28]:
def get_dataloader_workers():
    """Use 4 processes to read the data."""
    return 8

In [29]:
def load_data_fashion_mnist(batch_size, resize=None):
    """Download the Fashion-MNIST dataset and then load it into memory."""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))

In [30]:
batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)

  cpuset_checked))


In [31]:
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
def init_weights(m):
  if type(m) == nn.Linear:
    nn.init.normal_(m.weight, std=0.01)
net.apply(init_weights)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=10, bias=True)
)

In [32]:
loss = nn.CrossEntropyLoss()

In [33]:
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

In [34]:
class Accumulator:
  def __init__(self, n):
    self.data = [0.0] * n
  
  def add(self, *args):
    self.data = [a + float(b) for a,b in zip(self.data, args)]

  def reset(self):
    self.data = [0.0] * len(self.data)
    
  def __getitem__(self, idx):
    return self.data[idx]

In [35]:
def accuracy(y_hat, y):
  if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
    y_hat = y_hat.argmax(axis=1)
  cmp = y_hat.type(y.dtype) == y
  return float(cmp.type(y.dtype).sum())

In [36]:
def evaluate_accuracy(net, data_iter):
  metric = Accumulator(2)
  with torch.no_grad():
    for X, y in data_iter:
      metric.add(accuracy(net(X), y), y.numel())
  return metric[0] / metric[1]

In [37]:
def train_epoch_ch3(net, train_iter, loss, updater):
  metric = Accumulator(3)
  for X, y in train_iter:
    y_hat = net(X)
    l = loss(y_hat, y)
    if isinstance(updater, torch.optim.Optimizer):
      updater.zero_grad()
      l.sum().backward()
      updater.step()
    else:  
      l.sum().backward()
      updater(X.shape[0])
  metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
  return metric[0] / metric[2], metric[1] / metric[2]

In [38]:
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):

  for epoch in range(num_epochs):
    train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
    test_acc = evaluate_accuracy(net, test_iter)
    print('epoch: %s' % epoch)
    print(train_metrics)
    print(test_acc)
  
  train_loss, train_acc = train_metrics
  assert train_loss < 0.5, train_loss
  assert train_acc <= 1 and train_acc > 0.7, train_acc
  assert test_acc <= 1 and test_acc > 0.7, test_acc

In [39]:
num_epochs = 10
train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

  cpuset_checked))


epoch: 0
(0.00667466161151727, 0.8020833333333334)
0.7904
epoch: 1
(0.005972965930898984, 0.8020833333333334)
0.8084
epoch: 2
(0.004637866901854674, 0.8541666666666666)
0.8119
epoch: 3
(0.004795749671757221, 0.8541666666666666)
0.7995
epoch: 4
(0.0055068352570136385, 0.7395833333333334)
0.8155
epoch: 5
(0.005575697248180707, 0.7708333333333334)
0.8224
epoch: 6
(0.004443559485177199, 0.875)
0.8267
epoch: 7
(0.0038528948401411376, 0.8645833333333334)
0.8302
epoch: 8
(0.005135933558146159, 0.8541666666666666)
0.8294
epoch: 9
(0.006221183265248935, 0.75)
0.8142
