In [1]:
import numpy as np
import tensorflow as tf

In [2]:
# Parameters
learning_rate = 0.001
training_epochs = 6
batch_size = 600

# Import MNIST data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

train_dataset = (
    tf.data.Dataset.from_tensor_slices((tf.reshape(x_train, [-1, 784]), y_train))
    .batch(batch_size)
    .shuffle(1000)
)

train_dataset = (
    train_dataset.map(lambda x, y:
                      (tf.divide(tf.cast(x, tf.float32), 255.0),
                       tf.reshape(tf.one_hot(y, 10), (-1, 10))))
)

In [3]:
# Set model weights
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

In [4]:
# Construct model
model = lambda x: tf.nn.softmax(tf.matmul(x, W) + b) # Softmax
# Minimize error using cross entropy
compute_loss = lambda true, pred: tf.reduce_mean(tf.reduce_sum(tf.losses.binary_crossentropy(true, pred), axis=-1))
# caculate accuracy
compute_accuracy = lambda true, pred: tf.reduce_mean(tf.keras.metrics.categorical_accuracy(true, pred))
# Gradient Descent
optimizer = tf.optimizers.Adam(learning_rate)

for epoch in range(training_epochs):
    for i, (x_, y_) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            pred = model(x_)
            loss = compute_loss(y_, pred)
        acc = compute_accuracy(y_, pred)
        grads = tape.gradient(loss, [W, b])
        optimizer.apply_gradients(zip(grads, [W, b]))
        print("=> loss %.2f acc %.2f" %(loss.numpy(), acc.numpy()))

=> loss 195.05 acc 0.10
=> loss 191.90 acc 0.49
=> loss 188.57 acc 0.57
=> loss 185.54 acc 0.59
=> loss 182.33 acc 0.62
=> loss 178.60 acc 0.72
=> loss 176.53 acc 0.67
=> loss 174.14 acc 0.68
=> loss 170.18 acc 0.71
=> loss 168.72 acc 0.69
=> loss 167.69 acc 0.66
=> loss 163.87 acc 0.68
=> loss 161.03 acc 0.70
=> loss 158.00 acc 0.72
=> loss 155.47 acc 0.74
=> loss 152.42 acc 0.75
=> loss 152.62 acc 0.76
=> loss 145.45 acc 0.78
=> loss 144.87 acc 0.76
=> loss 147.08 acc 0.73
=> loss 141.74 acc 0.74
=> loss 136.57 acc 0.78
=> loss 138.47 acc 0.76
=> loss 133.01 acc 0.74
=> loss 131.67 acc 0.78
=> loss 138.72 acc 0.75
=> loss 126.61 acc 0.83
=> loss 127.99 acc 0.80
=> loss 131.20 acc 0.74
=> loss 133.83 acc 0.73
=> loss 114.84 acc 0.83
=> loss 113.11 acc 0.82
=> loss 113.97 acc 0.81
=> loss 117.88 acc 0.77
=> loss 116.33 acc 0.76
=> loss 113.45 acc 0.79
=> loss 116.09 acc 0.78
=> loss 108.93 acc 0.81
=> loss 101.96 acc 0.84
=> loss 102.48 acc 0.84
=> loss 111.73 acc 0.79
=> loss 105.36 a

=> loss 28.13 acc 0.95
=> loss 38.98 acc 0.89
=> loss 31.79 acc 0.92
=> loss 44.82 acc 0.88
=> loss 30.24 acc 0.92
=> loss 37.18 acc 0.90
=> loss 36.42 acc 0.91
=> loss 33.38 acc 0.92
=> loss 37.14 acc 0.90
=> loss 38.45 acc 0.89
=> loss 34.60 acc 0.92
=> loss 36.08 acc 0.90
=> loss 42.70 acc 0.88
=> loss 45.39 acc 0.88
=> loss 28.66 acc 0.93
=> loss 34.36 acc 0.90
=> loss 39.99 acc 0.89
=> loss 29.69 acc 0.93
=> loss 33.27 acc 0.92
=> loss 40.77 acc 0.88
=> loss 33.60 acc 0.90
=> loss 38.83 acc 0.90
=> loss 37.09 acc 0.89
=> loss 36.53 acc 0.90
=> loss 32.11 acc 0.93
=> loss 47.90 acc 0.87
=> loss 32.43 acc 0.92
=> loss 31.44 acc 0.92
=> loss 37.04 acc 0.90
=> loss 36.31 acc 0.90
=> loss 35.98 acc 0.89
=> loss 43.75 acc 0.88
=> loss 33.04 acc 0.92
=> loss 42.75 acc 0.88
=> loss 37.65 acc 0.89
=> loss 39.26 acc 0.89
=> loss 39.57 acc 0.90
=> loss 42.53 acc 0.88
=> loss 48.80 acc 0.87
=> loss 30.81 acc 0.93
=> loss 34.82 acc 0.91
=> loss 37.10 acc 0.88
=> loss 36.43 acc 0.90
=> loss 35.