<a href="https://colab.research.google.com/github/thai94/d2l/blob/main/4.multilayer_perceptrons/4_2_implementation_of_multilayer_perceptrons_from_scratch_tensorflow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [55]:
import tensorflow as tf

In [56]:
def load_data_fashion_mnist(batch_size, resize=None):
    """Download the Fashion-MNIST dataset and then load it into memory.

    Args:
        batch_size: Number of examples per mini-batch.
        resize: Optional target size passed as both height and width to
            `tf.image.resize_with_pad`. Images are left untouched when this
            is falsy (the default).

    Returns:
        A `(train, test)` tuple of batched `tf.data.Dataset` objects yielding
        `(images, labels)` pairs. Only the training dataset is shuffled.
    """
    mnist_train, mnist_test = tf.keras.datasets.fashion_mnist.load_data()

    def process(X, y):
        # Divide all numbers by 255 so that all pixel values are between
        # 0 and 1, append a trailing channel axis (axis=3), and cast the
        # labels to int32.
        return tf.expand_dims(X, axis=3) / 255, tf.cast(y, dtype='int32')

    def resize_fn(X, y):
        # Applied per batch; a no-op on the images when `resize` is falsy.
        return (
            tf.image.resize_with_pad(X, resize, resize) if resize else X, y)

    # Shuffle individual examples *before* batching. Shuffling after
    # `batch()` only reorders whole batches, so every batch would contain
    # the same examples in every epoch.
    train = tf.data.Dataset.from_tensor_slices(process(*mnist_train))
    train = train.shuffle(len(mnist_train[0])).batch(batch_size).map(resize_fn)

    test = tf.data.Dataset.from_tensor_slices(process(*mnist_test))
    test = test.batch(batch_size).map(resize_fn)
    return train, test

In [57]:
# Mini-batch size shared by the training and test iterators.
batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)

In [58]:
# Layer sizes: `net` flattens every example to `num_inputs` features, maps
# them to `num_hiddens` hidden units and finally to `num_outputs` outputs.
num_inputs = 784
num_outputs = 10
num_hiddens = 256

# Hidden layer: small Gaussian weights, zero bias.
W1 = tf.Variable(
    tf.random.normal(shape=(num_inputs, num_hiddens), mean=0, stddev=0.01))
b1 = tf.Variable(tf.zeros(num_hiddens))

# Output layer: small Gaussian weights and a small Gaussian bias.
W2 = tf.Variable(
    tf.random.normal(shape=(num_hiddens, num_outputs), mean=0, stddev=0.01))
b2 = tf.Variable(tf.random.normal([num_outputs], stddev=.01))

# Every trainable variable of the model, in layer order.
params = [W1, b1, W2, b2]

In [59]:
def relu(X):
    """Return the element-wise maximum of `X` and zero."""
    return tf.math.maximum(x=X, y=0)

In [60]:
def net(X):
    """Apply the one-hidden-layer perceptron to a batch `X`.

    The input is flattened to `num_inputs` features per example, passed
    through the hidden layer (`W1`, `b1`) followed by `relu`, and then
    through the output layer (`W2`, `b2`). No softmax is applied here.
    """
    flat = tf.reshape(X, (-1, num_inputs))
    hidden = relu(tf.matmul(flat, W1) + b1)
    outputs = tf.matmul(hidden, W2) + b2
    return outputs

In [61]:
def loss(y_hat, y):
    """Return the sparse categorical cross-entropy of logits `y_hat` vs `y`.

    `y_hat` is treated as unnormalised logits (`from_logits=True`) and `y`
    as the true labels.
    """
    return tf.losses.sparse_categorical_crossentropy(
        y_true=y, y_pred=y_hat, from_logits=True)

In [62]:
def accuracy(y_hat, y):
    """Count how many predictions in `y_hat` equal the labels `y`.

    When `y_hat` has more than one column it is reduced to class indices
    with an arg-max over axis 1 first. The count is returned as a float.
    """
    has_class_scores = len(y_hat.shape) > 1 and y_hat.shape[1] > 1
    predicted = tf.argmax(y_hat, axis=1) if has_class_scores else y_hat
    matches = tf.cast(predicted, y.dtype) == y
    n_correct = tf.reduce_sum(tf.cast(matches, y.dtype))
    return float(n_correct)

In [63]:
def evaluate_accuracy(net, data_iter):
    """Return the fraction of examples in `data_iter` that `net` gets right."""
    # Slot 0: number of correct predictions; slot 1: number of predictions.
    totals = Accumulator(2)
    for features, labels in data_iter:
        n_correct = accuracy(net(features), labels)
        totals.add(n_correct, len(labels))
    return totals[0] / totals[1]

In [64]:
class Accumulator:
    """Keep `n` running float sums that are updated together."""

    def __init__(self, n):
        # One zero-initialised slot per tracked quantity.
        self.data = [0.0 for _ in range(n)]

    def add(self, *args):
        """Add each positional value (converted to float) to its slot."""
        # Slots and values are paired positionally with `zip`.
        updated = []
        for running_total, increment in zip(self.data, args):
            updated.append(running_total + float(increment))
        self.data = updated

    def reset(self):
        """Set every slot back to zero, keeping the number of slots."""
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, idx):
        """Return the running sum stored at position `idx`."""
        return self.data[idx]

In [65]:
def train_epoch_ch3(net, train_iter, loss, updater, params, lr):
    """Run one pass over `train_iter`, updating `params` after every batch.

    Args:
        net: Callable mapping a batch of inputs to predictions.
        train_iter: Iterable yielding `(X, y)` batches.
        loss: Callable `loss(y_hat, y)` returning the loss for a batch.
        updater: Callable `updater(batch_size, grads, params, lr)` that
            applies the parameter update.
        params: Variables to differentiate with respect to and to update.
        lr: Learning rate forwarded to `updater`.

    Returns:
        A tuple `(summed loss / number of labels seen,
        correct predictions / number of labels seen)`.
    """
    # Slots: summed loss, number of correct predictions, number of labels.
    totals = Accumulator(3)
    for features, labels in train_iter:
        # Record the forward pass so gradients can be taken afterwards.
        with tf.GradientTape() as tape:
            predictions = net(features)
            batch_loss = loss(predictions, labels)

        gradients = tape.gradient(batch_loss, params)
        updater(features.shape[0], gradients, params, lr)
        totals.add(
            tf.reduce_sum(batch_loss),
            accuracy(predictions, labels),
            tf.size(labels))
    return totals[0] / totals[2], totals[1] / totals[2]

In [66]:
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater, params, lr):
    """Train `net` for `num_epochs` epochs, printing metrics after each one.

    After every epoch this prints the epoch index, the tuple returned by
    `train_epoch_ch3` and the value returned by `evaluate_accuracy` for
    `test_iter`.

    Args:
        net: Callable mapping a batch of inputs to predictions.
        train_iter: Iterable of training `(X, y)` batches.
        test_iter: Iterable of evaluation `(X, y)` batches.
        loss: Callable `loss(y_hat, y)` returning the loss for a batch.
        num_epochs: Number of passes over `train_iter`; `0` is a no-op.
        updater: Callable `updater(batch_size, grads, params, lr)`.
        params: Variables to optimise.
        lr: Learning rate forwarded to `updater`.

    Returns:
        None.
    """
    # The former trailing `train_loss, train_acc = train_metrics` only bound
    # two unused locals and raised UnboundLocalError when `num_epochs` was 0,
    # so it has been dropped.
    for epoch in range(num_epochs):
        train_metrics = train_epoch_ch3(
            net, train_iter, loss, updater, params, lr)
        test_acc = evaluate_accuracy(net, test_iter)
        print('epoch: %s' % epoch)
        print(train_metrics)
        print(test_acc)

In [67]:
def sgd(params, grads, lr, batch_size):
    """Apply one minibatch SGD step to every parameter in place.

    Each parameter is decreased by `lr * grad / batch_size` through its
    `assign_sub` method; parameters and gradients are paired positionally.
    """
    for variable, gradient in zip(params, grads):
        step = lr * gradient / batch_size
        variable.assign_sub(step)

In [68]:
def updater(batch_size, grads, params, lr):
    """Adapt `sgd` to the `(batch_size, grads, params, lr)` call order."""
    sgd(params=params, grads=grads, lr=lr, batch_size=batch_size)

In [69]:
# Train for 10 epochs with a learning rate of 0.1. Reuse the module-level
# `params` list instead of re-listing the variables in a different order, so
# there is a single source of truth for the trainable variables.
num_epochs, lr = 10, 0.1
train_ch3(net, train_iter, test_iter, loss, num_epochs, updater, params, lr)

epoch: 0
(1.0390050524393717, 0.642)
0.7292
epoch: 1
(0.5978441931406657, 0.7899666666666667)
0.7968
epoch: 2
(0.5178074743906657, 0.81875)
0.8213
epoch: 3
(0.47877785797119143, 0.83305)
0.8128
epoch: 4
(0.45603862469991047, 0.8387333333333333)
0.8338
epoch: 5
(0.43227490488688153, 0.84765)
0.8396
epoch: 6
(0.41748050651550295, 0.8532166666666666)
0.8382
epoch: 7
(0.4036259128570557, 0.8585666666666667)
0.8495
epoch: 8
(0.39335877742767333, 0.8598333333333333)
0.8528
epoch: 9
(0.3802210620880127, 0.8654166666666666)
0.8539
