In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import LeNet5_infernece
import os
import numpy as np

#### 1. 定义神经网络相关的参数

In [2]:
# --- Training-loop sizing -------------------------------------------------
# Number of examples drawn from the training split on every step.
BATCH_SIZE = 100
# Total number of optimisation steps executed by train().
TRAINING_STEPS = 200

# --- Learning-rate schedule -----------------------------------------------
# Initial learning rate handed to tf.train.exponential_decay.
LEARNING_RATE_BASE = 0.01
# Decay rate handed to tf.train.exponential_decay.
LEARNING_RATE_DECAY = 0.99

# --- Regularisation / weight averaging ------------------------------------
# Scale passed to the L2 regularizer.
REGULARIZATION_RATE = 0.0001
# Decay passed to tf.train.ExponentialMovingAverage.
MOVING_AVERAGE_DECAY = 0.99

#### 2. 定义训练过程

In [3]:
def train(mnist):
    """Build the LeNet-5 graph, train it on MNIST and print validation accuracy.

    Args:
        mnist: Dataset object exposing ``train`` (with ``num_examples`` and
            ``next_batch``) and ``validation`` (with ``images`` and
            ``labels``). Labels must be one-hot encoded, because they are
            converted to class indices with ``tf.argmax(..., 1)``.

    Returns:
        None. The training-batch loss is printed every 10 steps, and the
        accuracy on the validation split is printed once training finishes.
    """
    # Target shape for image batches; -1 lets NumPy infer the batch dimension,
    # so the same shape serves both the training batches and the validation set.
    image_shape = (
        -1,
        LeNet5_infernece.IMAGE_SIZE,
        LeNet5_infernece.IMAGE_SIZE,
        LeNet5_infernece.NUM_CHANNELS)

    # Input placeholder is a 4-D tensor: [batch, height, width, channels].
    x = tf.placeholder(
        tf.float32,
        [None,
         LeNet5_infernece.IMAGE_SIZE,
         LeNet5_infernece.IMAGE_SIZE,
         LeNet5_infernece.NUM_CHANNELS],
        name='x-input')
    y_ = tf.placeholder(tf.float32, [None, LeNet5_infernece.OUTPUT_NODE], name='y-input')

    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    # NOTE(review): the second argument is presumably the "train" flag of
    # inference() -- confirm against LeNet5_infernece.
    y = LeNet5_infernece.inference(x, False, regularizer)
    global_step = tf.Variable(0, trainable=False)

    # Define the loss, learning rate, moving-average op and the training step.
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    # Class indices derived from the one-hot labels; shared by loss and accuracy.
    label_ids = tf.argmax(y_, 1)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=label_ids)
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    # NOTE(review): the 'losses' collection is presumably filled with the
    # regularization terms inside inference() -- confirm against LeNet5_infernece.
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE,  # decay once per pass over the training set
        LEARNING_RATE_DECAY,
        staircase=True)

    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    # Group the gradient step and the moving-average update into a single op.
    with tf.control_dependencies([train_step, variables_averages_op]):
        train_op = tf.no_op(name='train')

    # Accuracy is computed from `y`, i.e. from the raw (non-averaged) variables.
    correct_prediction = tf.equal(tf.argmax(y, 1), label_ids)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            reshaped_xs = np.reshape(xs, image_shape)
            _, loss_value, step = sess.run(
                [train_op, loss, global_step],
                feed_dict={x: reshaped_xs, y_: ys})

            if i % 10 == 0:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))

        # Evaluate on the validation (dev) split.
        validate_feed = {
            x: np.reshape(mnist.validation.images, image_shape),
            y_: mnist.validation.labels,
        }
        # Function-call form of print works on both Python 2 and Python 3.
        print(sess.run(accuracy, feed_dict=validate_feed))
    


#### 3. 主程序入口

In [4]:
def main(argv=None):
    """Load MNIST with one-hot labels and hand it to train().

    Args:
        argv: Accepted for compatibility with the caller's signature; not used.
    """
    # Change this to your own MNIST data directory.
    data_dir = "../../../datasets/MNIST_data"
    dataset = input_data.read_data_sets(data_dir, one_hot=True)
    train(dataset)


# Kick off training when this code is executed as the main module.
if __name__ == '__main__':
    main()

Extracting ../../../datasets/MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../../datasets/MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../../datasets/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../../datasets/MNIST_data/t10k-labels-idx1-ubyte.gz
After 1 training step(s), loss on training batch is 4.51752.
After 11 training step(s), loss on training batch is 2.13127.
After 21 training step(s), loss on training batch is 1.90628.
After 31 training step(s), loss on training batch is 1.47134.
After 41 training step(s), loss on training batch is 1.2439.
After 51 training step(s), loss on training batch is 1.1529.
After 61 training step(s), loss on training batch is 1.08565.
After 71 training step(s), loss on training batch is 1.17094.
After 81 training step(s), loss on training batch is 1.13228.
After 91 training step(s), loss on training batch is 1.08917.
After 101 training step(s), loss on training batch is 0.996232.
After 111 training step(s), loss on training batch is