In [25]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.datasets import mnist

# 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# 预处理数据
train_images = train_images.reshape(-1, 28, 28, 1).astype('float32') / 255.0
test_images = test_images.reshape(-1, 28, 28, 1).astype('float32') / 255.0
train_labels = tf.one_hot(train_labels, 10)
test_labels = tf.one_hot(test_labels, 10)

batch_size = 100
learning_rate = 1e-4
keep_prob_rate = 0.7
max_epoch = 2000


def compute_accuracy(v_xs, v_ys):
    y_pre = forward(v_xs, training=False)
    correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(v_ys, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    return accuracy.numpy()


def weight_variable(shape):
    initial = tf.random.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def max_pool_2x2(x):
    return tf.nn.max_pool2d(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


# define placeholder for inputs to network
xs = tf.Variable(tf.zeros([batch_size, 28, 28, 1]), dtype=tf.float32)
ys = tf.Variable(tf.zeros([batch_size, 10]), dtype=tf.float32)
keep_prob = tf.Variable(keep_prob_rate, dtype=tf.float32)

# 卷积层 1
W_conv1 = weight_variable([7, 7, 1, 32])  # patch 7x7, in size 1, out size 32
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(xs, W_conv1) + b_conv1)  # 卷积 选择激活函数
h_pool1 = max_pool_2x2(h_conv1)  # 池化

# 卷积层 2
W_conv2 = weight_variable([5, 5, 32, 64])  # patch 5x5, in size 32, out size 64
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)  # 卷积 选择激活函数
h_pool2 = max_pool_2x2(h_conv2)  # 池化

# 全连接层 1
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
h_fc1_drop = tf.nn.dropout(h_fc1, rate=1-keep_prob)

# 全连接层 2
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])


def forward(x, training):
    h_conv1 = tf.nn.relu(conv2d(x, W_conv1) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)
    h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
    h_fc1_drop = tf.nn.dropout(h_fc1, rate=1-keep_prob if training else 0.0)
    logits = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
    return tf.nn.softmax(logits)


# 交叉熵函数
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=ys, logits=forward(xs, training=True)))
optimizer = tf.optimizers.Adam(learning_rate)

# 训练步骤


@tf.function
def train_step(batch_xs, batch_ys):
    with tf.GradientTape() as tape:
        logits = forward(batch_xs, training=True)
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
            labels=batch_ys, logits=logits))
    grads = tape.gradient(
        loss, [W_conv1, b_conv1, W_conv2, b_conv2, W_fc1, b_fc1, W_fc2, b_fc2])
    optimizer.apply_gradients(
        zip(grads, [W_conv1, b_conv1, W_conv2, b_conv2, W_fc1, b_fc1, W_fc2, b_fc2]))
    return loss


# 训练模型
for epoch in range(max_epoch):
    indices = np.random.choice(train_images.shape[0], batch_size)
    batch_xs = tf.gather(train_images, indices)
    batch_ys = tf.gather(train_labels, indices)
    loss = train_step(batch_xs, batch_ys)
    if epoch % 100 == 0:
        accuracy = compute_accuracy(test_images[:1000], test_labels[:1000])
        print(f'Epoch {epoch}, Loss: {loss.numpy()}, Accuracy: {accuracy}')

Epoch 0, Loss: 2.3644542694091797, Accuracy: 0.1599999964237213
Epoch 100, Loss: 2.106959342956543, Accuracy: 0.4050000011920929
Epoch 200, Loss: 1.8580803871154785, Accuracy: 0.6119999885559082
Epoch 300, Loss: 1.7715262174606323, Accuracy: 0.7260000109672546
Epoch 400, Loss: 1.7532286643981934, Accuracy: 0.7379999756813049
Epoch 500, Loss: 1.733815312385559, Accuracy: 0.7490000128746033
Epoch 600, Loss: 1.7386484146118164, Accuracy: 0.7570000290870667
Epoch 700, Loss: 1.6645724773406982, Accuracy: 0.7590000033378601
Epoch 800, Loss: 1.5931472778320312, Accuracy: 0.8360000252723694
Epoch 900, Loss: 1.5848629474639893, Accuracy: 0.8519999980926514
Epoch 1000, Loss: 1.5748093128204346, Accuracy: 0.8569999933242798
Epoch 1100, Loss: 1.586379051208496, Accuracy: 0.8539999723434448
Epoch 1200, Loss: 1.5114654302597046, Accuracy: 0.953000009059906
Epoch 1300, Loss: 1.508358359336853, Accuracy: 0.9599999785423279
Epoch 1400, Loss: 1.5209400653839111, Accuracy: 0.9639999866485596
Epoch 1500, 