In [None]:
#对于tanh,sigmod等函数，在接近边界值时，但是又与实际值误差较大时，由于此时梯度趋于0,下降速度太慢，导致参数调整太慢，训练速度太慢
#我们采用交叉熵进行处理，使得当误差越大时，梯度也就越大，参数的调整就越快，训练速度越快
#交叉熵的概念可以参考逻辑回归中的代价函数https://www.cnblogs.com/ssyfj/p/12799137.html

In [10]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

In [12]:
#数据载入---有其他方法
mnist = input_data.read_data_sets("MNIST_data",one_hot=True) #注意输出标签是one_hot类型

#按批次进行计算
#设置每个批次大小
batch_size = 100
#获取批次数量，注意：我们训练时使用的是mnist.train训练集数据
n_batch = mnist.train.num_examples // batch_size #注意：//也是除法，返回整数

#开始定义占位符号
x = tf.placeholder(tf.float32,[None,784]) #由于每张图片都是28×28像素，所以每个样本的特征都是784
y = tf.placeholder(tf.float32,[None,10]) #是分10类

#开始创建一个神经网络隐藏层L1含有10个激活单元
Weight_L1 = tf.Variable(tf.zeros([784,10]))
biases_L1 = tf.Variable(tf.zeros([1,10])) #对于偏执单元的权重，我们没有必要进行随机初始化
Wx_plus_b_L1 = tf.matmul(x,Weight_L1)+biases_L1
y_pred = tf.nn.softmax(Wx_plus_b_L1)

#定义代价函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=y_pred)) #这里我们使用交叉熵进行处理，使得加快训练速度
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

init = tf.global_variables_initializer()

#获取结果存放在bool列表中
correct_prediction = tf.equal(tf.arg_max(y,1),tf.arg_max(y_pred,1)) #注意：arg_max第二个参数是指定坐标信息，因为结果按行存储，所以1
#根据上面的bool列表，获取准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #tf.cast格式转换后，求解平均值


#开启会话
with tf.Session() as sess:
    sess.run(init)
    for iter_cnt in range(21): #整体数据迭代
        for batch_cnt in range(n_batch): #小批次数据分批次迭代
            batch_xs,batch_ys = mnist.train.next_batch(batch_size) #依照batch_size获取下一批次数据
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys}) #传入数据，进行训练
            
        #获取每一次整体数据迭代后的准确率
        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}) #注意：预测时，我们也要传入数据，并且传入的是测试集数据
        print("Iter: %d, Testing accuracy:%f"%(iter_cnt,acc))


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Iter: 0, Testing accuracy:0.868500
Iter: 1, Testing accuracy:0.897300
Iter: 2, Testing accuracy:0.903700
Iter: 3, Testing accuracy:0.906900
Iter: 4, Testing accuracy:0.907300
Iter: 5, Testing accuracy:0.909700
Iter: 6, Testing accuracy:0.912300
Iter: 7, Testing accuracy:0.913900
Iter: 8, Testing accuracy:0.915500
Iter: 9, Testing accuracy:0.915900
Iter: 10, Testing accuracy:0.917700
Iter: 11, Testing accuracy:0.917400
Iter: 12, Testing accuracy:0.918400
Iter: 13, Testing accuracy:0.919200
Iter: 14, Testing accuracy:0.920500
Iter: 15, Testing accuracy:0.919300
Iter: 16, Testing accuracy:0.919500
Iter: 17, Testing accuracy:0.921000
Iter: 18, Testing accuracy:0.920900
Iter: 19, Testing accuracy:0.920800
Iter: 20, Testing accuracy:0.922100


In [None]:
#未使用交叉熵
Iter: 0, Testing accuracy:0.828000
Iter: 1, Testing accuracy:0.871900
Iter: 2, Testing accuracy:0.882100
Iter: 3, Testing accuracy:0.888300
Iter: 4, Testing accuracy:0.893300
Iter: 5, Testing accuracy:0.897300
Iter: 6, Testing accuracy:0.899300
Iter: 7, Testing accuracy:0.901100
Iter: 8, Testing accuracy:0.903900
Iter: 9, Testing accuracy:0.905400
Iter: 10, Testing accuracy:0.906400
Iter: 11, Testing accuracy:0.907500
Iter: 12, Testing accuracy:0.908100
Iter: 13, Testing accuracy:0.909400
Iter: 14, Testing accuracy:0.909400
Iter: 15, Testing accuracy:0.911300
Iter: 16, Testing accuracy:0.911400
Iter: 17, Testing accuracy:0.911700
Iter: 18, Testing accuracy:0.913300
Iter: 19, Testing accuracy:0.913500
Iter: 20, Testing accuracy:0.914000
        
#使用交叉熵，我们发现训练速度快于未使用交叉熵时的训练速度
Iter: 0, Testing accuracy:0.868500
Iter: 1, Testing accuracy:0.897300
Iter: 2, Testing accuracy:0.903700
Iter: 3, Testing accuracy:0.906900
Iter: 4, Testing accuracy:0.907300
Iter: 5, Testing accuracy:0.909700
Iter: 6, Testing accuracy:0.912300
Iter: 7, Testing accuracy:0.913900
Iter: 8, Testing accuracy:0.915500
Iter: 9, Testing accuracy:0.915900
Iter: 10, Testing accuracy:0.917700
Iter: 11, Testing accuracy:0.917400
Iter: 12, Testing accuracy:0.918400
Iter: 13, Testing accuracy:0.919200
Iter: 14, Testing accuracy:0.920500
Iter: 15, Testing accuracy:0.919300
Iter: 16, Testing accuracy:0.919500
Iter: 17, Testing accuracy:0.921000
Iter: 18, Testing accuracy:0.920900
Iter: 19, Testing accuracy:0.920800
Iter: 20, Testing accuracy:0.922100