In [1]:
import tensorflow as tf
from tensorflow import keras
from keras import datasets

In [2]:
# Load the MNIST train/test split.
# Training images x: [60k, 28, 28]; training labels y: [60k].
(x, y), (x_test, y_test) = datasets.mnist.load_data()

# Images are converted to float32 tensors and labels to int32 tensors.
x, x_test = (
    tf.convert_to_tensor(images, dtype=tf.float32) for images in (x, x_test)
)
y, y_test = (
    tf.convert_to_tensor(labels, dtype=tf.int32) for labels in (y, y_test)
)

# Sanity checks: shapes and dtypes first, then the min/max of images and labels.
print(x.shape, y.shape, x.dtype, y.dtype)
for tensor in (x, y):
    print(tf.reduce_min(tensor), tf.reduce_max(tensor))



(60000, 28, 28) (60000,) <dtype: 'float32'> <dtype: 'int32'>
tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(255.0, shape=(), dtype=float32)
tf.Tensor(0, shape=(), dtype=int32) tf.Tensor(9, shape=(), dtype=int32)


In [3]:
# Wrap the tensors in tf.data datasets so that mini-batches are easy to draw.
BATCH_SIZE = 128

train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.batch(BATCH_SIZE)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.batch(BATCH_SIZE)

# Grab an iterator and pull one batch to confirm the batch shapes:
# each read yields BATCH_SIZE images together with their labels.
train_iter = iter(train_db)
sample = next(train_iter)
sample_images, sample_labels = sample
print('batch:', sample_images.shape, sample_labels.shape)

batch: (128, 28, 28) (128,)


In [8]:
# Three-layer fully connected classifier trained with hand-written SGD.
# Every layer computes y = x @ w + b:
#   [b, 784] => [b, 256] => [b, 128] => [b, 10]
# Weights have shape [dim_in, dim_out] and biases have shape [dim_out].
# The initial weight range matters: too wide a range can easily make the
# gradients explode, hence the small stddev.
# The parameters are tf.Variable so that tf.GradientTape records them
# automatically (otherwise they would need an explicit tape.watch).
w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev=0.01))
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.01))
b2 = tf.Variable(tf.zeros([128]))
w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.01))
b3 = tf.Variable(tf.zeros([10]))
params = [w1, b1, w2, b2, w3, b3]

lr = 1e-1            # too large a learning rate leads to exploding gradients
EPOCHS = 10          # the dataset is small, so iterate over it several times
GRAD_CLIP_NORM = 15  # global-norm threshold used to tame exploding gradients
LOG_EVERY = 100      # print the training loss every LOG_EVERY steps


def forward(images):
    """Return the [b, 10] logits of the network for a batch of images.

    `images` is flattened to [b, 784] and pushed through the two ReLU hidden
    layers and the linear output layer, using the current values of the
    notebook-level variables w1, b1, w2, b2, w3, b3.
    """
    flat = tf.reshape(images, [-1, 28 * 28])  # [b, 28, 28] => [b, 784]
    h1 = tf.nn.relu(flat @ w1 + b1)           # [b, 784] => [b, 256]
    h2 = tf.nn.relu(h1 @ w2 + b2)             # [b, 256] => [b, 128]
    return h2 @ w3 + b3                       # [b, 128] => [b, 10]


for epoch in range(EPOCHS):
    # ---- training ----
    # The batch tensors get their own names (x_batch / y_batch) so that the
    # notebook-level x / y dataset tensors are NOT overwritten; clobbering
    # them would silently break a re-run of the cell that builds train_db.
    for step, (x_batch, y_batch) in enumerate(train_db):
        # x_batch: [b, 28, 28], y_batch: [b]
        with tf.GradientTape() as tape:  # records operations on tf.Variable
            out = forward(x_batch)       # [b, 10]

            # Labels: [b] => one-hot [b, 10]
            y_onehot = tf.one_hot(y_batch, depth=10)

            # MSE: mean of the squared error over every element of [b, 10]
            loss = tf.reduce_mean(tf.square(y_onehot - out))

        # Compute the gradients and clip them by global norm.
        grads = tape.gradient(loss, params)
        grads, _ = tf.clip_by_global_norm(grads, GRAD_CLIP_NORM)

        # In-place SGD step: p = p - lr * grad. assign_sub keeps each
        # parameter the same tf.Variable object.
        for param, grad in zip(params, grads):
            param.assign_sub(lr * grad)

        if step % LOG_EVERY == 0:
            print(epoch, step, 'loss:', float(loss))

    # ---- test / evaluation with the current [w1, b1, w2, b2, w3, b3] ----
    total_correct, total_num = 0, 0
    for x_batch, y_batch in test_db:
        out = forward(x_batch)                 # logits: [b, 10], real-valued
        prob = tf.nn.softmax(out, axis=1)      # probabilities: [b, 10] in [0, 1]
        # [b, 10] => [b]; argmax returns int64, labels are int32
        pred = tf.cast(tf.argmax(prob, axis=1), dtype=tf.int32)
        correct = tf.cast(tf.equal(pred, y_batch), dtype=tf.int32)
        correct = tf.reduce_sum(correct)
        total_correct += int(correct)
        total_num += x_batch.shape[0]
    acc = total_correct / total_num
    print('test acc:', acc)

    
        

0 0 loss: 0.12915447354316711
0 100 loss: 0.02183777466416359
0 200 loss: 0.014754688367247581
0 300 loss: 0.014491741545498371
0 400 loss: 0.01534547470510006
test acc: 0.938
1 0 loss: 0.011050170287489891
1 100 loss: 0.01046844944357872
1 200 loss: 0.012057008221745491
1 300 loss: 0.009639924392104149
1 400 loss: 0.011321398429572582
test acc: 0.9566
2 0 loss: 0.008592674508690834
2 100 loss: 0.008449897170066833
2 200 loss: 0.010562802664935589
2 300 loss: 0.007304591126739979
2 400 loss: 0.009837128221988678
test acc: 0.9652
3 0 loss: 0.007479909807443619
3 100 loss: 0.007153123617172241
3 200 loss: 0.009035671129822731
3 300 loss: 0.006315006408840418
3 400 loss: 0.008303827606141567
test acc: 0.9679
4 0 loss: 0.006521566770970821
4 100 loss: 0.006263895891606808
4 200 loss: 0.007643335964530706
4 300 loss: 0.0056782858446240425
4 400 loss: 0.007644444704055786
test acc: 0.9712
5 0 loss: 0.005835401825606823
5 100 loss: 0.005733797326683998
5 200 loss: 0.006747194565832615
5 300 l