<a href="https://colab.research.google.com/github/thai94/d2l/blob/main/3.linear-neural-networks/3_7_concise_implementation_of_softmax_regression_tensorflow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [29]:
import tensorflow as tf

In [30]:
def load_data_fashion_mnist(batch_size, resize=None):
    """Download the Fashion-MNIST dataset and then load it into memory."""
    mnist_train, mnist_test = tf.keras.datasets.fashion_mnist.load_data()
    # Divide all numbers by 255 so that all pixel values are between
    # 0 and 1, add a batch dimension at the last. And cast label to int32
    process = lambda X, y: (tf.expand_dims(X, axis=3) / 255,
                            tf.cast(y, dtype='int32'))
    resize_fn = lambda X, y: (
        tf.image.resize_with_pad(X, resize, resize) if resize else X, y)
    return (
        tf.data.Dataset.from_tensor_slices(process(*mnist_train)).batch(
            batch_size).shuffle(len(mnist_train[0])).map(resize_fn),
        tf.data.Dataset.from_tensor_slices(process(*mnist_test)).batch(
            batch_size).map(resize_fn))

In [31]:
batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)

In [32]:
net = tf.keras.models.Sequential()
net.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
weight_initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.01)
net.add(tf.keras.layers.Dense(10, kernel_initializer=weight_initializer))

In [33]:
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
trainer = tf.keras.optimizers.SGD(learning_rate=0.1)

In [34]:
def accuracy(y_hat, y):
    """Compute the number of correct predictions."""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = tf.argmax(y_hat, axis=1)
    cmp = tf.cast(y_hat, y.dtype) == y
    return float(tf.reduce_sum(tf.cast(cmp, y.dtype)))

In [35]:
def evaluate_accuracy(net, data_iter):
    """Compute the accuracy for a model on a dataset."""
    metric = Accumulator(2)  # No. of correct predictions, no. of predictions
    for X, y in data_iter:
        metric.add(accuracy(net(X), y), len(y))
    return metric[0] / metric[1]

In [36]:
class Accumulator:
    """For accumulating sums over `n` variables."""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [37]:
def train_epoch_ch3(net, train_iter, loss, updater, params, lr):
  metric = Accumulator(3)
  for X, y in train_iter:
    with tf.GradientTape() as tape:
      y_hat = net(X)
      if isinstance(loss, tf.keras.losses.Loss):
        l = loss(y, y_hat)
      else:
        l = loss(y_hat, y)

    if isinstance(updater, tf.keras.optimizers.Optimizer):
      params = net.trainable_variables
      grads = tape.gradient(l, params)
      updater.apply_gradients(zip(grads, params))
    else:
      updater(X.shape[0], tape.gradient(l, params), params, lr)
    l_sum = tf.reduce_sum(l)
    metric.add(l_sum, accuracy(y_hat, y), tf.size(y))
  return metric[0] / metric[2], metric[1] / metric[2]

In [38]:
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater, params, lr):
  for epoch in range(num_epochs):
    train_metrics = train_epoch_ch3(net, train_iter, loss, updater, params, lr)
    test_acc = evaluate_accuracy(net, test_iter)
    print('epoch: %s' % epoch)
    print(train_metrics)
    print(test_acc)
  train_loss, train_acc = train_metrics

In [39]:
num_epochs = 20
lr = 0.01
train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer, None, lr)

epoch: 0
(0.003080140237013499, 0.7473333333333333)
0.7924
epoch: 1
(0.0022357578267653785, 0.8120833333333334)
0.8136
epoch: 2
(0.0020538648664951325, 0.8266833333333333)
0.8205
epoch: 3
(0.0019628437598546346, 0.8315833333333333)
0.8205
epoch: 4
(0.0018993083328008652, 0.8366)
0.8253
epoch: 5
(0.0018540361866354943, 0.84025)
0.83
epoch: 6
(0.0018206486160556475, 0.8433333333333334)
0.8247
epoch: 7
(0.001792177265882492, 0.8460666666666666)
0.8305
epoch: 8
(0.0017696840226650238, 0.8470166666666666)
0.8326
epoch: 9
(0.001750953158736229, 0.8483833333333334)
0.8346
epoch: 10
(0.0017347917646169662, 0.8496333333333334)
0.8344
epoch: 11
(0.0017190943916638693, 0.8502)
0.836
epoch: 12
(0.0017068222537636756, 0.8511833333333333)
0.8349
epoch: 13
(0.0016964888880650203, 0.85225)
0.8373
epoch: 14
(0.0016826957042018573, 0.8539)
0.8386
epoch: 15
(0.001676203283170859, 0.8533)
0.8386
epoch: 16
(0.0016683359781901042, 0.8542)
0.8396
epoch: 17
(0.0016559156447649003, 0.8559)
0.8395
epoch: 18
(0.