<a href="https://colab.research.google.com/github/wangyf5996/tensorflow_learning/blob/master/mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
###
# The sample data used here is MNIST, for which some ready-made helper functions exist.
# If you use your own data, see
# https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10 for how to read it.
###
import math
import time

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.contrib import slim

# Global flag container; the flags registered below are read as FLAGS.<name>.
FLAGS = tf.app.flags.FLAGS
# Image geometry used when reshaping decoded records to [height, width, 1].
tf.app.flags.DEFINE_integer('image_height', 28, 'the height of image')
tf.app.flags.DEFINE_integer('image_width', 28, 'the width of image')
# Batch size; also used to derive the learning-rate decay schedule and the
# number of evaluation iterations.
tf.app.flags.DEFINE_integer('batch_size', 128, 'Number of images to process in a batch')
# Example counts per dataset split (presumably matching the TFRecord files
# consumed by this code -- TODO confirm against the data preparation step).
TRAIN_EXAMPLES_NUM = 55000
VALIDATION_EXAMPLES_NUM = 5000
TEST_EXAMPLES_NUM = 10000

# STEP 1: parse and preprocess the input data
def parse_data(example_proto):
    """Decode one serialized example into an ``(image, label)`` pair.

    Args:
        example_proto: Serialized ``tf.Example`` carrying an ``img_raw``
            bytes feature and a ``label`` int64 feature.

    Returns:
        A tuple ``(image, label)`` where ``image`` is a float32 tensor of
        shape ``[FLAGS.image_height, FLAGS.image_width, 1]`` and ``label``
        is an int64 tensor.
    """
    feature_spec = {
        'img_raw': tf.FixedLenFeature([], tf.string, ''),
        'label': tf.FixedLenFeature([], tf.int64, 0),
    }
    parsed = tf.parse_single_example(example_proto, feature_spec)

    # The raw bytes are interpreted as one uint8 value per pixel.
    pixels = tf.decode_raw(parsed['img_raw'], tf.uint8)
    label = tf.cast(parsed['label'], tf.int64)

    image_shape = [FLAGS.image_height, FLAGS.image_width, 1]
    image = tf.cast(tf.reshape(pixels, image_shape), tf.float32)
    return image, label
  
def read_mnist_tfrecords(filename_queue):
    """Read and decode a single example from a queue of TFRecord files.

    Args:
        filename_queue: Queue of TFRecord file names to read from.

    Returns:
        A tuple ``(image, label)`` where ``image`` is a uint8 tensor of shape
        ``[FLAGS.image_height, FLAGS.image_width, 1]`` and ``label`` is an
        int64 tensor. Unlike ``parse_data``, the image is NOT cast to float32
        here; callers do that themselves.
    """
    record_reader = tf.TFRecordReader()
    # The record key returned by the reader is not needed.
    _, serialized = record_reader.read(filename_queue)

    feature_spec = {
        'img_raw': tf.FixedLenFeature([], tf.string, ''),
        'label': tf.FixedLenFeature([], tf.int64, 0),
    }
    parsed = tf.parse_single_example(serialized, features=feature_spec)

    pixels = tf.decode_raw(parsed['img_raw'], tf.uint8)
    label = tf.cast(parsed['label'], tf.int64)
    image = tf.reshape(pixels, [FLAGS.image_height, FLAGS.image_width, 1])
    return image, label
  
def inputs(filenames, examples_num, batch_size, shuffle):
    """Build a queue-based input pipeline yielding batches of images/labels.

    Args:
        filenames: TFRecord file paths to read.
        examples_num: Number of examples in the dataset; 40% of it is used
            to size the batching queue.
        batch_size: Number of examples per batch.
        shuffle: If true, batches come from ``tf.train.shuffle_batch``;
            otherwise from ``tf.train.batch``.

    Returns:
        A tuple ``(images, labels)`` of batched tensors; images are float32.

    Raises:
        ValueError: If any path in ``filenames`` does not exist.
    """
    # Report the first missing file, checking paths in the given order.
    missing = next((path for path in filenames if not tf.gfile.Exists(path)), None)
    if missing is not None:
        raise ValueError('Failed to find file: ' + missing)

    with tf.name_scope('inputs'):
        name_queue = tf.train.string_input_producer(filenames)
        raw_image, label = read_mnist_tfrecords(name_queue)
        image = tf.cast(raw_image, tf.float32)

        # Keep at least 40% of the dataset buffered in the queue.
        min_queue_examples = int(0.4 * examples_num)
        batching = dict(batch_size=batch_size,
                        num_threads=16,
                        capacity=min_queue_examples + batch_size * 3)

        if shuffle:
            images, labels = tf.train.shuffle_batch(
                [image, label],
                min_after_dequeue=min_queue_examples,
                **batching)
        else:
            images, labels = tf.train.batch([image, label], **batching)
        return images, labels

      
# STEP 2: define the model
def inference(images, training):
    """Build the CNN: two conv+pool stages, a dense layer, dropout, logits.

    Args:
        images: Batch of input images. The flatten step below hard-codes
            7*7*64 features, which corresponds to 28x28 inputs.
        training: Forwarded to ``tf.layers.dropout`` as its ``training``
            argument.

    Returns:
        A tuple ``(logits, predict)``: the 10-unit output of the final dense
        layer and its softmax.
    """
    def conv_relu(scope_name, tensor, filters):
        # 5x5 'same'-padded convolution followed by ReLU, in its own scope.
        with tf.variable_scope(scope_name):
            return tf.layers.conv2d(inputs=tensor,
                                    filters=filters,
                                    kernel_size=[5, 5],
                                    padding='same',
                                    activation=tf.nn.relu)

    def halve(tensor):
        # 2x2 max pooling with stride 2.
        return tf.layers.max_pooling2d(inputs=tensor, pool_size=[2, 2], strides=2)

    stage1 = halve(conv_relu('conv1', images, 32))    # 14*14*32
    stage2 = halve(conv_relu('conv2', stage1, 64))    # 7*7*64

    with tf.variable_scope('fc1'):
        flat = tf.reshape(stage2, [-1, 7*7*64])
        hidden = tf.layers.dense(inputs=flat, units=1024, activation=tf.nn.relu)
        hidden = tf.layers.dropout(inputs=hidden, rate=0.4, training=training)

    with tf.variable_scope('logits'):
        # The raw logits are what the cross-entropy loss is computed from.
        logits = tf.layers.dense(inputs=hidden, units=10)
        predict = tf.nn.softmax(logits)

    return logits, predict


# STEP 3: define the loss computation and the training op
def loss(logits, labels):
    """Return the mean sparse softmax cross-entropy over the batch.

    Args:
        logits: Unnormalized class scores.
        labels: Class indices; cast to int64 before use.

    Returns:
        A scalar tensor: the cross-entropy averaged over all examples.
    """
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.cast(labels, tf.int64),
        logits=logits,
        name='cross_entropy')
    return tf.reduce_mean(per_example)
 
 
def train(total_loss, global_step):
    """Create the training op: Adam with a staircase-decayed learning rate.

    The learning rate starts at 0.001 and is multiplied by 0.1 every
    10 epochs' worth of steps (derived from ``TRAIN_EXAMPLES_NUM`` and
    ``FLAGS.batch_size``).

    Args:
        total_loss: Scalar loss tensor to minimize.
        global_step: Step counter variable; drives the decay schedule and is
            passed to ``apply_gradients``.

    Returns:
        The op that applies one step of gradients.
    """
    steps_per_epoch = TRAIN_EXAMPLES_NUM / FLAGS.batch_size
    steps_per_decay = int(steps_per_epoch * 10)

    # Staircase decay: the rate drops in discrete jumps rather than smoothly.
    learning_rate = tf.train.exponential_decay(learning_rate=0.001,
                                               global_step=global_step,
                                               decay_steps=steps_per_decay,
                                               decay_rate=0.1,
                                               staircase=True)

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    grads_and_vars = optimizer.compute_gradients(total_loss)
    return optimizer.apply_gradients(grads_and_vars, global_step)




    
# STEP 5: model evaluation
def eval_once(saver, top_k_op):
    """Restore the latest checkpoint and print precision on validation data.

    Runs ``top_k_op`` for enough batches to cover ``VALIDATION_EXAMPLES_NUM``
    examples (rounded up to whole batches) and prints the fraction of
    correct predictions. Prints a message and returns early when no
    checkpoint is found.

    NOTE(review): reads ``FLAGS.train_dir``, which is not among the flags
    defined in the visible part of this file -- confirm it is defined before
    this function is called.

    Args:
        saver: Saver used to restore variables from the checkpoint.
        top_k_op: Op whose evaluated result is summed to count the correct
            predictions in one batch.
    """
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print('no checkpoint file')
            return

        saver.restore(sess, ckpt.model_checkpoint_path)
        # The step is taken from the text after the last '-' in the
        # checkpoint file name.
        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]

        coord = tf.train.Coordinator()
        # Bind before the try block so coord.join() below never hits an
        # unbound name if starting the queue runners fails.
        threads = []
        try:
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # float() keeps the ceil meaningful under integer division.
            iter_per_epoch = int(math.ceil(float(VALIDATION_EXAMPLES_NUM) / FLAGS.batch_size))

            total_sample = iter_per_epoch * FLAGS.batch_size
            correct_predict = 0
            step = 0

            while step < iter_per_epoch and not coord.should_stop():
                predict = sess.run(top_k_op)
                correct_predict += np.sum(predict)
                step += 1

            precision = float(correct_predict) / total_sample
            print('step: {}, model: {}, precision: {}'.format(global_step, ckpt.model_checkpoint_path, precision))

        except Exception as e:
            print('exception: ', e)
            # Hand the error to the coordinator; coord.join() re-raises it.
            coord.request_stop(e)
        finally:
            coord.request_stop()
        coord.join(threads)
 
 
def evaluation():
    """Build the validation graph and evaluate checkpoints in a loop.

    Reads ``./validation_img.tfrecords`` without shuffling, builds the model
    in inference mode, and calls ``eval_once`` repeatedly, sleeping
    ``FLAGS.eval_interval_secs`` between runs. The loop exits after the
    first run when ``FLAGS.run_once`` is set.

    NOTE(review): ``FLAGS.run_once`` and ``FLAGS.eval_interval_secs`` are not
    among the flags defined in the visible part of this file -- confirm they
    are defined before this function is called.
    """
    # Use the functions and constant defined in this file directly; there is
    # no imported `mnist` module to qualify them with.
    images, labels = inputs(['./validation_img.tfrecords'], VALIDATION_EXAMPLES_NUM,
                            batch_size=FLAGS.batch_size, shuffle=False)
    logits, _ = inference(images, training=False)
    # True for each example whose label is the top-1 prediction.
    top_k_op = tf.nn.in_top_k(logits, labels, 1)

    saver = tf.train.Saver()

    while True:
        eval_once(saver, top_k_op)
        if FLAGS.run_once:
            break
        time.sleep(FLAGS.eval_interval_secs)

