In [13]:
#
#  前向传播示例，通过底层实现简单三层网络结构了解深度学习的基本原理
#
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets

# Load MNIST through the Keras dataset helper. It returns a (train, test)
# pair of (samples, labels) tuples: x / y receive the training split and
# x_test / y_test receive the test split.
(x, y), (x_test, y_test) = datasets.mnist.load_data()


def _as_tensors(images, labels):
    """Convert one split to tensors: float32 images divided by 255, int32 labels."""
    scaled_images = tf.convert_to_tensor(images, dtype=tf.float32) / 255.
    return scaled_images, tf.convert_to_tensor(labels, dtype=tf.int32)


# The loader hands back NumPy arrays; turn both splits into tensors and
# normalise the images in the same step.
x, y = _as_tensors(x, y)
x_test, y_test = _as_tensors(x_test, y_test)

print(x.shape, y.shape, x_test.shape, y_test.shape)

# Pair each split's samples with its labels and cut it into batches of 128.
train_db, test_db = (
    tf.data.Dataset.from_tensor_slices(split).batch(128)
    for split in ((x, y), (x_test, y_test))
)

# Sanity check: pull one batch from the training dataset and print the shape
# of its samples and of its labels.
train_iter = iter(train_db)
sample = next(train_iter)
print(*(part.shape for part in sample))

# Parameters of a three-layer fully connected network. The training loop
# feeds it inputs reshaped to [b, 784]; the shapes evolve as follows:
#   layer 1: x[b, 784]  @ w1[784, 256] + b1[256] -> [b, 256]
#   layer 2: h1[b, 256] @ w2[256, 128] + b2[128] -> [b, 128]
#   layer 3: h2[b, 128] @ w3[128, 10]  + b3[10]  -> [b, 10]
# (each bias is broadcast across the batch dimension).
# Every parameter is a tf.Variable so that tf.GradientTape tracks it and can
# compute its gradient.
def _layer_params(fan_in, fan_out):
    """Create one layer's (weights, bias) as tf.Variables.

    Weights are drawn from a truncated normal with stddev 0.1 (the original
    author's note: a small stddev is used to guard against exploding
    gradients); the bias is initialised to ones.
    """
    weights = tf.Variable(tf.random.truncated_normal([fan_in, fan_out], stddev=0.1))
    bias = tf.Variable(tf.ones([fan_out]))
    return weights, bias


w1, b1 = _layer_params(784, 256)
w2, b2 = _layer_params(256, 128)
w3, b3 = _layer_params(128, 10)

# Learning rate used for the manual gradient-descent updates.
lr = 1e-3

def _forward(inputs):
    """Run the three-layer network on a [b, 784] batch and return [b, 10] outputs.

    Shared by training and evaluation so both use exactly the same forward
    computation. The last layer is linear (no ReLU).
    """
    hidden1 = tf.nn.relu(inputs @ w1 + b1)   # [b, 784] -> [b, 256]
    hidden2 = tf.nn.relu(hidden1 @ w2 + b2)  # [b, 256] -> [b, 128]
    return hidden2 @ w3 + b3                 # [b, 128] -> [b, 10]


# All trainable parameters, in the order their gradients are requested.
params = [w1, b1, w2, b2, w3, b3]

for epoch in range(100):

    # ---- training pass -------------------------------------------------
    # Batch variables get their own names so the module-level dataset tensors
    # (x, y, x_test, y_test) are not overwritten by the loops.
    for step, (x_batch, y_batch) in enumerate(train_db):
        # Flatten each 28x28 image into a 784-vector: [b, 28, 28] -> [b, 784].
        x_batch = tf.reshape(x_batch, [-1, 28 * 28])

        # Labels arrive as [b] class indices while the network outputs
        # [b, 10], so one-hot encode them to [b, 10] before computing the
        # loss. This involves no Variables, so it does not need the tape.
        y_onehot = tf.one_hot(y_batch, depth=10)

        # The forward computation must run inside the tape so the operations
        # on the Variables are recorded and gradients can be derived.
        with tf.GradientTape() as tape:
            out = _forward(x_batch)

            # Mean squared error. reduce_mean averages over every element
            # (i.e. sum / batch / 10); compared with a plain sum this only
            # rescales the gradient by a constant, not its direction.
            loss = tf.reduce_mean(tf.square(y_onehot - out))

        # Gradients of the loss with respect to every weight and bias.
        grads = tape.gradient(loss, params)

        # Plain gradient descent, updating each Variable in place.
        for param, grad in zip(params, grads):
            param.assign_sub(lr * grad)

        if step % 100 == 0:
            print('epoch:', epoch, '\t', 'step:', step, '\t', 'loss:', float(loss))

    # ---- evaluation pass -------------------------------------------------
    total_correct, total_num = 0, 0
    for x_eval, y_eval in test_db:
        # [b, 28, 28] -> [b, 784]
        x_eval = tf.reshape(x_eval, [-1, 28 * 28])

        # [b, 784] -> [b, 10]
        out = _forward(x_eval)

        # The predicted class is the index of the largest output. Softmax is
        # monotonic, so applying it first would not change the argmax and is
        # skipped. argmax yields int64 by default; request int32 so the result
        # can be compared with the int32 labels.
        pred = tf.argmax(out, axis=1, output_type=tf.int32)  # [b]

        # Boolean hit/miss per sample -> int32 -> number of hits in the batch.
        correct = tf.reduce_sum(tf.cast(tf.equal(pred, y_eval), dtype=tf.int32))

        # Accumulate the number of correct predictions and of samples seen.
        total_correct += int(correct)
        total_num += x_eval.shape[0]

    # Accuracy over the whole test set for this epoch.
    acc_rate = total_correct / total_num
    print('epoch {} --- accuracy rate {}'.format(epoch, acc_rate))


(60000, 28, 28) (60000,) (10000, 28, 28) (10000,)
(128, 28, 28) (128,)
epoch: 0 	 step: 0 	 loss: 3.6722922325134277
epoch: 0 	 step: 100 	 loss: 0.38453489542007446
epoch: 0 	 step: 200 	 loss: 0.3504512906074524
epoch: 0 	 step: 300 	 loss: 0.2855910658836365
epoch: 0 	 step: 400 	 loss: 0.24766042828559875
epoch 0 --- accuracy rate 0.1978
epoch: 1 	 step: 0 	 loss: 0.23160989582538605
epoch: 1 	 step: 100 	 loss: 0.21969887614250183
epoch: 1 	 step: 200 	 loss: 0.2174583226442337
epoch: 1 	 step: 300 	 loss: 0.19014450907707214
epoch: 1 	 step: 400 	 loss: 0.17668640613555908
epoch 1 --- accuracy rate 0.3227
epoch: 2 	 step: 0 	 loss: 0.17258940637111664
epoch: 2 	 step: 100 	 loss: 0.1704375147819519
epoch: 2 	 step: 200 	 loss: 0.1706511229276657
epoch: 2 	 step: 300 	 loss: 0.15302656590938568
epoch: 2 	 step: 400 	 loss: 0.14563919603824615
epoch 2 --- accuracy rate 0.4089
epoch: 3 	 step: 0 	 loss: 0.14432339370250702
epoch: 3 	 step: 100 	 loss: 0.14465390145778656
epoch: 3 	 

epoch: 29 	 step: 100 	 loss: 0.059586383402347565
epoch: 29 	 step: 200 	 loss: 0.0580902099609375
epoch: 29 	 step: 300 	 loss: 0.05687941238284111
epoch: 29 	 step: 400 	 loss: 0.06273476779460907
epoch 29 --- accuracy rate 0.7627
epoch: 30 	 step: 0 	 loss: 0.05476672574877739
epoch: 30 	 step: 100 	 loss: 0.058945298194885254
epoch: 30 	 step: 200 	 loss: 0.05736606568098068
epoch: 30 	 step: 300 	 loss: 0.05620614439249039
epoch: 30 	 step: 400 	 loss: 0.06213857978582382
epoch 30 --- accuracy rate 0.7659
epoch: 31 	 step: 0 	 loss: 0.05408364534378052
epoch: 31 	 step: 100 	 loss: 0.058338046073913574
epoch: 31 	 step: 200 	 loss: 0.056673694401979446
epoch: 31 	 step: 300 	 loss: 0.055562347173690796
epoch: 31 	 step: 400 	 loss: 0.06157444790005684
epoch 31 --- accuracy rate 0.7692
epoch: 32 	 step: 0 	 loss: 0.05342373251914978
epoch: 32 	 step: 100 	 loss: 0.05775492638349533
epoch: 32 	 step: 200 	 loss: 0.05601781606674194
epoch: 32 	 step: 300 	 loss: 0.05495355650782585


epoch: 58 	 step: 100 	 loss: 0.04828064143657684
epoch: 58 	 step: 200 	 loss: 0.045518048107624054
epoch: 58 	 step: 300 	 loss: 0.0451788529753685
epoch: 58 	 step: 400 	 loss: 0.052213240414857864
epoch 58 --- accuracy rate 0.8226
epoch: 59 	 step: 0 	 loss: 0.0428839847445488
epoch: 59 	 step: 100 	 loss: 0.04804065078496933
epoch: 59 	 step: 200 	 loss: 0.04527352377772331
epoch: 59 	 step: 300 	 loss: 0.044934406876564026
epoch: 59 	 step: 400 	 loss: 0.052001167088747025
epoch 59 --- accuracy rate 0.8241
epoch: 60 	 step: 0 	 loss: 0.04264971241354942
epoch: 60 	 step: 100 	 loss: 0.047809455543756485
epoch: 60 	 step: 200 	 loss: 0.045036133378744125
epoch: 60 	 step: 300 	 loss: 0.0446971096098423
epoch: 60 	 step: 400 	 loss: 0.05179222673177719
epoch 60 --- accuracy rate 0.8249
epoch: 61 	 step: 0 	 loss: 0.04242371767759323
epoch: 61 	 step: 100 	 loss: 0.047579605132341385
epoch: 61 	 step: 200 	 loss: 0.04480490833520889
epoch: 61 	 step: 300 	 loss: 0.0444643534719944
e

epoch: 87 	 step: 100 	 loss: 0.04295936971902847
epoch: 87 	 step: 200 	 loss: 0.040199510753154755
epoch: 87 	 step: 300 	 loss: 0.03980197012424469
epoch: 87 	 step: 400 	 loss: 0.04746449738740921
epoch 87 --- accuracy rate 0.8475
epoch: 88 	 step: 0 	 loss: 0.037760794162750244
epoch: 88 	 step: 100 	 loss: 0.04281621053814888
epoch: 88 	 step: 200 	 loss: 0.04006411135196686
epoch: 88 	 step: 300 	 loss: 0.039665862917900085
epoch: 88 	 step: 400 	 loss: 0.047333039343357086
epoch 88 --- accuracy rate 0.848
epoch: 89 	 step: 0 	 loss: 0.03763044998049736
epoch: 89 	 step: 100 	 loss: 0.04267502576112747
epoch: 89 	 step: 200 	 loss: 0.039931345731019974
epoch: 89 	 step: 300 	 loss: 0.0395306721329689
epoch: 89 	 step: 400 	 loss: 0.04720446839928627
epoch 89 --- accuracy rate 0.8483
epoch: 90 	 step: 0 	 loss: 0.037500638514757156
epoch: 90 	 step: 100 	 loss: 0.04253661632537842
epoch: 90 	 step: 200 	 loss: 0.0397992879152298
epoch: 90 	 step: 300 	 loss: 0.03939611464738846
e