<a href="https://colab.research.google.com/github/thai94/d2l/blob/main/3.linear-neural-networks/3_6_implementation_of_softmax_regression_from_%1Dscratch_tensorflow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [159]:
import tensorflow as tf

In [160]:
def load_data_fashion_mnist(batch_size, resize=None):
    """Download Fashion-MNIST and wrap it in batched ``tf.data`` pipelines.

    Args:
        batch_size: Number of examples per batch.
        resize: Optional target size; when truthy, every image batch is
            resized (with padding) to ``resize x resize``.

    Returns:
        ``(train_dataset, test_dataset)``. The training dataset shuffles
        individual examples before batching; the test dataset keeps the
        original order.
    """
    mnist_train, mnist_test = tf.keras.datasets.fashion_mnist.load_data()

    def process(X, y):
        # Scale pixel values into [0, 1], append a trailing (channel) axis,
        # and cast the labels to int32.
        return tf.expand_dims(X, axis=3) / 255, tf.cast(y, dtype='int32')

    def resize_fn(X, y):
        return (tf.image.resize_with_pad(X, resize, resize) if resize else X,
                y)

    # Shuffle individual examples *before* batching: shuffling after
    # `.batch()` only permutes whole batches, so each batch would contain the
    # same examples in every epoch.
    train_ds = (tf.data.Dataset.from_tensor_slices(process(*mnist_train))
                .shuffle(len(mnist_train[0]))
                .batch(batch_size)
                .map(resize_fn))
    test_ds = (tf.data.Dataset.from_tensor_slices(process(*mnist_test))
               .batch(batch_size)
               .map(resize_fn))
    return train_ds, test_ds

In [161]:
# Mini-batch size shared by the training and test pipelines.
batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)

In [162]:
# 784 input features (one per pixel of a flattened image), 10 output classes.
num_inputs, num_outputs = 784, 10

# Weights start as small Gaussian noise; the bias starts at zero.
W = tf.Variable(tf.random.normal((num_inputs, num_outputs), mean=0, stddev=0.01))
b = tf.Variable(tf.zeros(num_outputs))

In [163]:
def softmax(X, axis=1):
    """Softmax implemented from scratch, normalizing along ``axis``.

    Args:
        X: Tensor of logits (for the model here, one row per example).
        axis: Axis to normalize over; defaults to 1, the class axis of a
            2-D batch of logits.

    Returns:
        Tensor shaped like ``X`` whose slices along ``axis`` are
        non-negative and sum to 1.
    """
    # Softmax is unchanged by subtracting a per-slice constant, and removing
    # the max keeps exp() from overflowing to inf (inf / inf would be nan)
    # when logits are large.
    X_shifted = X - tf.reduce_max(X, axis=axis, keepdims=True)
    X_exp = tf.exp(X_shifted)
    partition = tf.reduce_sum(X_exp, axis, keepdims=True)
    return X_exp / partition

In [164]:
def net(X):
    """Softmax-regression forward pass using the global parameters ``W``, ``b``.

    Each example in ``X`` is flattened to ``W.shape[0]`` features, mapped to
    logits by an affine layer, and turned into probabilities by ``softmax``.
    """
    flat = tf.reshape(X, (-1, W.shape[0]))
    logits = tf.matmul(flat, W) + b
    return softmax(logits)

In [165]:
def cross_entropy(y_hat, y, eps=1e-12):
    """Per-example cross-entropy loss from predicted class probabilities.

    Args:
        y_hat: 2-D tensor of predicted probabilities, one row per example.
        y: 1-D integer tensor holding the true class index of each row.
        eps: Lower bound applied to the selected probability before the log,
            so an underflowed probability of exactly 0 yields a large finite
            loss instead of inf (whose gradient would be nan).

    Returns:
        1-D tensor with ``-log(y_hat[i, y[i]])`` for every row ``i``.
    """
    # Index row i with y[i] directly; this avoids masking with a float
    # one-hot, since tf.boolean_mask documents a boolean mask.
    picked = tf.gather(y_hat, y, axis=1, batch_dims=1)
    return -tf.math.log(tf.clip_by_value(picked, eps, 1.0))

In [166]:
# Two examples over three classes; the labels select probabilities 0.1 and 0.5.
y_hat = tf.constant([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]], dtype=tf.float32)
y = tf.constant([0, 2], dtype=tf.int32)
cross_entropy(y_hat, y)

<tf.Tensor: shape=(2,), dtype=float32, numpy=array([2.3025851, 0.6931472], dtype=float32)>

In [167]:
def accuracy(y_hat, y):
    """Return the number (as a float) of predictions in ``y_hat`` matching ``y``.

    A 2-D ``y_hat`` with more than one column is treated as per-class scores
    and reduced to predicted labels with ``argmax`` over axis 1; otherwise
    it is compared to ``y`` as-is.
    """
    has_class_scores = len(y_hat.shape) > 1 and y_hat.shape[1] > 1
    predictions = tf.argmax(y_hat, axis=1) if has_class_scores else y_hat
    hits = tf.cast(predictions, y.dtype) == y
    return float(tf.reduce_sum(tf.cast(hits, y.dtype)))

In [168]:
def evaluate_accuracy(net, data_iter):
    """Return the fraction of examples in ``data_iter`` that ``net`` gets right."""
    # Running totals: [correct predictions, examples seen].
    totals = Accumulator(2)
    for features, labels in data_iter:
        totals.add(accuracy(net(features), labels), len(labels))
    correct, seen = totals[0], totals[1]
    return correct / seen

In [169]:
class Accumulator:
    """Keep running sums over a fixed number ``n`` of variables."""

    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        """Add one value to each running sum, element-wise.

        Raises:
            ValueError: if the number of values differs from ``n``. Without
                this check ``zip`` would silently truncate ``self.data``.
        """
        if len(args) != len(self.data):
            raise ValueError(
                f"expected {len(self.data)} values, got {len(args)}")
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        """Set every running sum back to 0.0."""
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [170]:
def train_epoch_ch3(net, train_iter, loss, updater, params, lr):
    """Run one pass over ``train_iter``, updating ``params`` after each batch.

    Returns:
        ``(summed loss / examples, correct predictions / examples)`` over
        the whole epoch.
    """
    # Running totals: [summed loss, correct predictions, examples seen].
    stats = Accumulator(3)
    for features, labels in train_iter:
        with tf.GradientTape() as tape:
            predictions = net(features)
            batch_loss = loss(predictions, labels)
        grads = tape.gradient(batch_loss, params)
        updater(features.shape[0], grads, params, lr)
        stats.add(tf.reduce_sum(batch_loss),
                  accuracy(predictions, labels),
                  tf.size(labels))
    return stats[0] / stats[2], stats[1] / stats[2]

In [171]:
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater, params, lr):
    """Train ``net`` for ``num_epochs`` epochs, printing progress each epoch.

    After every epoch this prints the epoch index, the
    ``(train_loss, train_acc)`` tuple from ``train_epoch_ch3`` and the test
    accuracy from ``evaluate_accuracy``. Returns None; with
    ``num_epochs == 0`` it does nothing.
    """
    for epoch in range(num_epochs):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater,
                                        params, lr)
        test_acc = evaluate_accuracy(net, test_iter)
        print('epoch: %s' % epoch)
        print(train_metrics)
        print(test_acc)

In [172]:
def sgd(params, grads, lr, batch_size):
    """Minibatch SGD step: subtract ``lr * grad / batch_size`` from each param in place."""
    for variable, gradient in zip(params, grads):
        step = lr * gradient / batch_size
        variable.assign_sub(step)

In [173]:
def updater(batch_size, grads, params, lr):
    """Adapter that forwards ``train_epoch_ch3``'s argument order to ``sgd``."""
    sgd(params=params, grads=grads, lr=lr, batch_size=batch_size)

In [174]:
# Training hyperparameters.
num_epochs, lr = 20, 0.01
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater,
          [W, b], lr)

epoch: 0
(1.3711719393412272, 0.6414333333333333)
0.6809
epoch: 1
(0.9178400165557862, 0.7162333333333334)
0.7239
epoch: 2
(0.8033531344095866, 0.75105)
0.7474
epoch: 3
(0.742962868754069, 0.7679)
0.7588
epoch: 4
(0.7032462577819825, 0.77935)
0.7714
epoch: 5
(0.6744265129089355, 0.7885833333333333)
0.7775
epoch: 6
(0.6521441453297933, 0.7951333333333334)
0.7825
epoch: 7
(0.6340786262512207, 0.7996333333333333)
0.7885
epoch: 8
(0.6191635148366292, 0.8029)
0.7914
epoch: 9
(0.6064373116811117, 0.80595)
0.7954
epoch: 10
(0.5955943726857503, 0.8093)
0.7982
epoch: 11
(0.5860953343073527, 0.8118)
0.8011
epoch: 12
(0.577628319422404, 0.8142333333333334)
0.8036
epoch: 13
(0.5700759317398071, 0.8165333333333333)
0.8052
epoch: 14
(0.5632773867289226, 0.8184666666666667)
0.8072
epoch: 15
(0.557164646021525, 0.8199)
0.8083
epoch: 16
(0.5514022529602051, 0.8218333333333333)
0.809
epoch: 17
(0.5463516469319661, 0.8231833333333334)
0.8107
epoch: 18
(0.5416445468266805, 0.8247166666666667)
0.8115
epoch

In [175]:
# Take the first test batch for a quick look at predictions vs. labels.
# next(iter(...)) raises StopIteration on an empty dataset instead of silently
# leaving X / y bound to stale values from earlier cells.
X, y = next(iter(test_iter))

In [176]:
# Predicted class index for every example in the batch.
tf.math.argmax(net(X), axis=1)

<tf.Tensor: shape=(256,), dtype=int64, numpy=
array([9, 2, 1, 1, 6, 1, 4, 6, 7, 7, 4, 5, 5, 3, 4, 1, 2, 6, 8, 0, 0, 7,
       7, 5, 1, 2, 6, 3, 9, 4, 8, 8, 3, 3, 8, 0, 7, 5, 7, 9, 0, 1, 3, 9,
       6, 7, 2, 1, 4, 6, 6, 2, 7, 6, 4, 2, 8, 2, 8, 0, 7, 7, 8, 5, 1, 1,
       3, 4, 7, 8, 7, 0, 6, 6, 2, 3, 1, 2, 8, 4, 1, 8, 5, 9, 5, 0, 3, 2,
       0, 6, 5, 3, 6, 7, 1, 8, 0, 1, 6, 2, 3, 6, 7, 2, 7, 8, 7, 9, 9, 4,
       2, 5, 7, 0, 5, 2, 8, 4, 7, 8, 0, 0, 9, 9, 3, 0, 8, 4, 1, 5, 4, 1,
       9, 1, 8, 4, 6, 1, 2, 5, 1, 0, 0, 0, 1, 6, 1, 3, 2, 2, 6, 6, 1, 3,
       5, 0, 4, 7, 9, 3, 7, 2, 3, 9, 0, 9, 4, 7, 4, 2, 6, 5, 2, 1, 2, 1,
       3, 0, 9, 1, 0, 9, 3, 8, 7, 9, 9, 4, 4, 7, 1, 2, 1, 6, 3, 2, 8, 3,
       6, 1, 1, 0, 2, 9, 2, 4, 0, 7, 9, 8, 4, 1, 8, 4, 1, 3, 1, 2, 7, 4,
       8, 5, 6, 0, 7, 7, 6, 6, 7, 0, 7, 8, 9, 2, 9, 0, 5, 1, 4, 2, 5, 4,
       9, 6, 2, 8, 6, 4, 2, 4, 9, 7, 4, 5, 5, 4])>

In [177]:
# True labels of the same batch, for comparison with the predicted classes.
y

<tf.Tensor: shape=(256,), dtype=int32, numpy=
array([9, 2, 1, 1, 6, 1, 4, 6, 5, 7, 4, 5, 7, 3, 4, 1, 2, 4, 8, 0, 2, 5,
       7, 9, 1, 4, 6, 0, 9, 3, 8, 8, 3, 3, 8, 0, 7, 5, 7, 9, 6, 1, 3, 7,
       6, 7, 2, 1, 2, 2, 4, 4, 5, 8, 2, 2, 8, 4, 8, 0, 7, 7, 8, 5, 1, 1,
       2, 3, 9, 8, 7, 0, 2, 6, 2, 3, 1, 2, 8, 4, 1, 8, 5, 9, 5, 0, 3, 2,
       0, 6, 5, 3, 6, 7, 1, 8, 0, 1, 4, 2, 3, 6, 7, 2, 7, 8, 5, 9, 9, 4,
       2, 5, 7, 0, 5, 2, 8, 6, 7, 8, 0, 0, 9, 9, 3, 0, 8, 4, 1, 5, 4, 1,
       9, 1, 8, 6, 2, 1, 2, 5, 1, 0, 0, 0, 1, 6, 1, 6, 2, 2, 4, 4, 1, 4,
       5, 0, 4, 7, 9, 3, 7, 2, 3, 9, 0, 9, 4, 7, 4, 2, 0, 5, 2, 1, 2, 1,
       3, 0, 9, 1, 0, 9, 3, 6, 7, 9, 9, 4, 4, 7, 1, 2, 1, 6, 3, 2, 8, 3,
       6, 1, 1, 0, 2, 9, 2, 4, 0, 7, 9, 8, 4, 1, 8, 4, 1, 3, 1, 6, 7, 2,
       8, 5, 2, 0, 7, 7, 6, 2, 7, 0, 7, 8, 9, 2, 9, 0, 5, 1, 4, 4, 5, 6,
       9, 2, 6, 8, 6, 4, 2, 2, 9, 7, 6, 5, 5, 2], dtype=int32)>