In [5]:
import tensorflow as tf
import numpy as np


# get the data
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()


def preprocess_images(images):
    return images.reshape(-1, 784).astype(np.float32) / 255


def preprocess_labels(labels):
    return labels.reshape(-1).astype(np.int32)


train_images = preprocess_images(train_images)
test_images = preprocess_images(test_images)
train_labels = preprocess_labels(train_labels)
test_labels = preprocess_labels(test_labels)

train_data = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(60000).batch(128).repeat()
#test_data = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(128)


# define the model first, from input to output

# 2 layers again
n_units = 100
n_layers = 2
w_range = 0.1

# just set up a "chain" of hidden layers
# model is represented by a list where each element is a layer,
# and each layer is in turn a list of the layer variables (w, b)

# first layer goes from n_input to n_hidden
w_input = tf.Variable(tf.random.uniform([784, n_units], -w_range, w_range),
                      name="w0")
b_input = tf.Variable(tf.zeros(n_units), name="b0")
layers = [[w_input, b_input]]

# all other hidden layers go from n_hidden to n_hidden
for layer in range(n_layers - 1):
    w = tf.Variable(tf.random.uniform([n_units, n_units], -w_range, w_range),
                    name="w" + str(layer+1))
    b = tf.Variable(tf.zeros(n_units), name="b" + str(layer+1))
    layers.append([w, b])

# finally add the output layer
w_out = tf.Variable(tf.random.uniform([n_units, 10], -w_range, w_range),
                    name="wout")
b_out = tf.Variable(tf.zeros(10), name="bout")
layers.append([w_out, b_out])

# flatten the layers to get a list of variables
all_variables = [variable for layer in layers for variable in layer]


def model_forward(inputs):
    x = inputs
    for w, b in layers[:-1]:
        x = tf.nn.relu(tf.matmul(x, w) + b)
    # finally, the softmax classification output layer :)))
    logits = tf.nn.softmax(tf.matmul(x, layers[-1][0]) + layers[-1][1])

    return logits


lr = 0.2
train_steps = 3000
for step, (img_batch, lbl_batch) in enumerate(train_data):
    if step > train_steps:
        break

    with tf.GradientTape() as tape:
        # here we just run all the layers in sequence via a for-loop
        logits = model_forward(img_batch)
        xent = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=lbl_batch))

    grads = tape.gradient(xent, all_variables)
    for grad, var in zip(grads, all_variables):
        var.assign_sub(lr*grad)

    if not step % 100:
        preds = tf.argmax(logits, axis=1, output_type=tf.int32)
        acc = tf.reduce_mean(tf.cast(tf.equal(preds, lbl_batch), tf.float32))
        print("Loss: {} Accuracy: {}".format(xent, acc))


test_preds = model_forward(test_images)
test_preds = tf.argmax(test_preds, axis=1, output_type=tf.int32)
acc = tf.reduce_mean(tf.cast(tf.equal(test_preds, test_labels), tf.float32))
print("Final test accuracy: {}".format(acc))


Loss: 2.302792549133301 Accuracy: 0.0859375
Loss: 2.257176399230957 Accuracy: 0.1328125
Loss: 2.049697160720825 Accuracy: 0.421875
Loss: 1.8722679615020752 Accuracy: 0.6015625
Loss: 1.7706685066223145 Accuracy: 0.7109375
Loss: 1.7247099876403809 Accuracy: 0.75
Loss: 1.7379562854766846 Accuracy: 0.7578125
Loss: 1.738759994506836 Accuracy: 0.734375
Loss: 1.7297308444976807 Accuracy: 0.734375
Loss: 1.6786214113235474 Accuracy: 0.78125
Loss: 1.7088273763656616 Accuracy: 0.7421875
Loss: 1.687302827835083 Accuracy: 0.8046875
Loss: 1.6289303302764893 Accuracy: 0.8515625
Loss: 1.5865097045898438 Accuracy: 0.8984375
Loss: 1.5426028966903687 Accuracy: 0.9375
Loss: 1.5514631271362305 Accuracy: 0.9140625
Loss: 1.6319334506988525 Accuracy: 0.84375
Loss: 1.5473146438598633 Accuracy: 0.9140625
Loss: 1.554461121559143 Accuracy: 0.9140625
Loss: 1.5622340440750122 Accuracy: 0.9140625
Loss: 1.5521352291107178 Accuracy: 0.9296875
Loss: 1.55103600025177 Accuracy: 0.921875
Loss: 1.528256893157959 Accuracy: 