<a href="https://colab.research.google.com/github/thai94/d2l/blob/main/4.multilayer_perceptrons/4_3_concise_implementation_of_multilayer_perceptrons_tensorflow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf

In [2]:
net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(10)])

In [3]:
def load_data_fashion_mnist(batch_size, resize=None):
    """Download the Fashion-MNIST dataset and then load it into memory."""
    mnist_train, mnist_test = tf.keras.datasets.fashion_mnist.load_data()
    # Divide all numbers by 255 so that all pixel values are between
    # 0 and 1, add a batch dimension at the last. And cast label to int32
    process = lambda X, y: (tf.expand_dims(X, axis=3) / 255,
                            tf.cast(y, dtype='int32'))
    resize_fn = lambda X, y: (
        tf.image.resize_with_pad(X, resize, resize) if resize else X, y)
    return (
        tf.data.Dataset.from_tensor_slices(process(*mnist_train)).batch(
            batch_size).shuffle(len(mnist_train[0])).map(resize_fn),
        tf.data.Dataset.from_tensor_slices(process(*mnist_test)).batch(
            batch_size).map(resize_fn))

def accuracy(y_hat, y):
    """Compute the number of correct predictions."""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = tf.argmax(y_hat, axis=1)
    cmp = tf.cast(y_hat, y.dtype) == y
    return float(tf.reduce_sum(tf.cast(cmp, y.dtype)))

def evaluate_accuracy(net, data_iter):
    """Compute the accuracy for a model on a dataset."""
    metric = Accumulator(2)  # No. of correct predictions, no. of predictions
    for X, y in data_iter:
        metric.add(accuracy(net(X), y), len(y))
    return metric[0] / metric[1]

class Accumulator:
    """For accumulating sums over `n` variables."""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def train_epoch_ch3(net, train_iter, loss, updater, params, lr):
  metric = Accumulator(3)
  for X, y in train_iter:
    with tf.GradientTape() as tape:
      y_hat = net(X)
      if isinstance(loss, tf.keras.losses.Loss):
        l = loss(y, y_hat)
      else:
        l = loss(y_hat, y)

    if isinstance(updater, tf.keras.optimizers.Optimizer):
      params = net.trainable_variables
      grads = tape.gradient(l, params)
      updater.apply_gradients(zip(grads, params))
    else:
      updater(X.shape[0], tape.gradient(l, params), params, lr)
    l_sum = tf.reduce_sum(l)
    metric.add(l_sum, accuracy(y_hat, y), tf.size(y))
  return metric[0] / metric[2], metric[1] / metric[2]


def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater, params, lr):
  for epoch in range(num_epochs):
    train_metrics = train_epoch_ch3(net, train_iter, loss, updater, params, lr)
    test_acc = evaluate_accuracy(net, test_iter)
    print('epoch: %s' % epoch)
    print(train_metrics)
    print(test_acc)
  train_loss, train_acc = train_metrics

In [5]:
batch_size, lr, num_epochs = 256, 0.1, 10
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
trainer = tf.keras.optimizers.SGD(learning_rate=lr)

train_iter, test_iter = load_data_fashion_mnist(batch_size)
train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer, None, lr)

epoch: 0
(0.002883693205813567, 0.7562833333333333)
0.8011
epoch: 1
(0.0020279876967271167, 0.8216)
0.8171
epoch: 2
(0.0018227596312761307, 0.83775)
0.8311
epoch: 3
(0.0016962133139371872, 0.8480333333333333)
0.8382
epoch: 4
(0.0016101967359582582, 0.85555)
0.8416
epoch: 5
(0.0015577166840434074, 0.85985)
0.8474
epoch: 6
(0.0014895240172743797, 0.86695)
0.8308
epoch: 7
(0.0014492553025484084, 0.8690833333333333)
0.8562
epoch: 8
(0.0014141719241937002, 0.8733166666666666)
0.8574
epoch: 9
(0.0013772437175114949, 0.8750166666666667)
0.8627
