In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
import os
import logging
logging.basicConfig(level=logging.ERROR)

#### 1. 定义神经网络结构相关的参数。

In [2]:
BATCH_SIZE = 100 
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 10000
MOVING_AVERAGE_DECAY = 0.99 
MODEL_SAVE_PATH = "MNIST_model/"
MODEL_NAME = "mnist_model"

#### 2. 定义训练过程，支持程序关闭后从checkpoint恢复训练。

In [3]:
def train(mnist):
    # 定义checkpoint保存点
    ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
    # 定义输入输出placeholder。
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')

    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    # 直接使用mnist_inference.py中定义的前向传播过程
    y = mnist_inference.inference(x, regularizer)
    global_step = tf.Variable(0, trainable=False)
    
    # 定义损失函数、学习率、滑动平均操作以及训练过程。
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY,
        staircase=True)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    with tf.control_dependencies([train_step, variables_averages_op]):
        train_op = tf.no_op(name='train')
        
    # 初始化TensorFlow持久化类。
    saver = tf.train.Saver()  
    with tf.Session() as sess:
        saved_step = 0
        if ckpt and ckpt.model_checkpoint_path:
            print("checkpoint存在，直接恢复变量")
            saver.restore(sess, ckpt.model_checkpoint_path)
            # 恢复global_step
            saved_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            sess.run(global_step.assign(saved_step))
        else:
            print("checkpoint不存在，进行变量初始化")
            tf.global_variables_initializer().run()

        for i in range(saved_step, TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            if i % 1000 == 0:
                # 输出模型在当前batch上的损失函数大小
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                # 保存模型
                # global_step参数：每个保存模型的文件名末尾加上训练的轮数
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
                last_step = step - 1000
                """
                if last_step > 0:
                    try:
                        os.remove(MODEL_SAVE_PATH+MODEL_NAME+"-"+str(last_step)+".index")
                        os.remove(MODEL_SAVE_PATH+MODEL_NAME+"-"+str(last_step)+".data-00000-of-00001")
                        os.remove(MODEL_SAVE_PATH+MODEL_NAME+"-"+str(last_step)+".meta")
                    except:
                        print("删除数据异常")
                    else:
                        print("成功删除：", MODEL_SAVE_PATH+MODEL_NAME+"-"+str(last_step)+".*")
                """

#### 3. 主程序入口。

In [4]:
def main(argv=None):
    mnist = input_data.read_data_sets("../../datasets/MNIST_data", one_hot=True)
    train(mnist)

if __name__ == '__main__':
    main()

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


Instructions for updating:
Please write your own downloading logic.


Instructions for updating:
Please write your own downloading logic.


Instructions for updating:
Please use tf.data to implement this functionality.


Instructions for updating:
Please use tf.data to implement this functionality.


Extracting ../../datasets/MNIST_data/train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.


Instructions for updating:
Please use tf.data to implement this functionality.


Extracting ../../datasets/MNIST_data/train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.


Instructions for updating:
Please use tf.one_hot on tensors.


Extracting ../../datasets/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../datasets/MNIST_data/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


checkpoint不存在，进行变量初始化
After 1 training step(s), loss on training batch is 2.85186.
After 1001 training step(s), loss on training batch is 0.204792.
After 2001 training step(s), loss on training batch is 0.170533.
After 3001 training step(s), loss on training batch is 0.135178.
After 4001 training step(s), loss on training batch is 0.119899.
After 5001 training step(s), loss on training batch is 0.101341.
After 6001 training step(s), loss on training batch is 0.0990699.
After 7001 training step(s), loss on training batch is 0.089756.
After 8001 training step(s), loss on training batch is 0.0776623.
After 9001 training step(s), loss on training batch is 0.070653.
