In [16]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import moxing.tensorflow as mox
import numpy

import tensorflow as tf
import tensorflow.contrib.slim as slim

import mnist_data
import cnn_model
# Start from an empty default graph so train() builds its ops into a fresh
# graph (re-running the notebook would otherwise stack duplicate ops).
tf.reset_default_graph()

In [17]:
def _recreate_dir(path):
    """Delete ``path`` recursively if it exists, then create it empty (via moxing)."""
    if mox.file.exists(path):
        mox.file.remove(path, recursive=True)
    mox.file.make_dirs(path)


train_url = './cache/model/'  # training output location (model checkpoints)
log_url = './cache/log/'      # log output location (TensorBoard summaries)

# Start every run from empty output directories.
_recreate_dir(train_url)
_recreate_dir(log_url)

In [18]:
# ---- Output locations (built from the directories prepared earlier) ----
MODEL_DIRECTORY = '{}model.ckpt'.format(train_url)  # checkpoint path used by saver.save / saver.restore
LOGS_DIRECTORY = log_url                            # directory handed to tf.summary.FileWriter

# ---- Training parameters ----
# Passes over the training set: 10 for augmented training data,
# 20 for the plain training data.
training_epochs = 10
TRAIN_BATCH_SIZE = 50
display_step = 100     # print training accuracy every this many batches
validation_step = 500  # evaluate on the validation set every this many batches

# ---- Test parameters ----
TEST_BATCH_SIZE = 5000

In [19]:
def train():

    # Some parameters
    batch_size = TRAIN_BATCH_SIZE
    num_labels = mnist_data.NUM_LABELS

    # Prepare mnist data
    train_total_data, train_size, validation_data, validation_labels, test_data, test_labels = mnist_data.prepare_MNIST_data(True)

    # Boolean for MODE of train or test
    is_training = tf.placeholder(tf.bool, name='MODE')

    # tf Graph input
    x = tf.placeholder(tf.float32, [None, 784])
    y_ = tf.placeholder(tf.float32, [None, 10]) #answer

    # Predict
    y = cnn_model.CNN(x)

    # Get loss of model
    with tf.name_scope("LOSS"):
        loss = slim.losses.softmax_cross_entropy(y,y_)

    # Create a summary to monitor loss tensor
    tf.summary.scalar('loss', loss)

    # Define optimizer
    with tf.name_scope("ADAM"):
        # Optimizer: set up a variable that's incremented once per batch and
        # controls the learning rate decay.
        batch = tf.Variable(0)

        learning_rate = tf.train.exponential_decay(
            1e-4,  # Base learning rate.
            batch * batch_size,  # Current index into the dataset.
            train_size,  # Decay step.
            0.95,  # Decay rate.
            staircase=True)
        # Use simple momentum for the optimization.
        train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss,global_step=batch)

    # Create a summary to monitor learning_rate tensor
    tf.summary.scalar('learning_rate', learning_rate)

    # Get accuracy of model
    with tf.name_scope("ACC"):
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Create a summary to monitor accuracy tensor
    tf.summary.scalar('acc', accuracy)

    # Merge all summaries into a single op
    merged_summary_op = tf.summary.merge_all()

    # Add ops to save and restore all the variables
    saver = tf.train.Saver()
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer(), feed_dict={is_training: True})

    # Training cycle
    total_batch = int(train_size / batch_size)

    # op to write logs to Tensorboard
    summary_writer = tf.summary.FileWriter(LOGS_DIRECTORY, graph=tf.get_default_graph())

    # Save the maximum accuracy value for validation data
    max_acc = 0.

    # Loop for epoch
    for epoch in range(training_epochs):

        # Random shuffling
        numpy.random.shuffle(train_total_data)
        train_data_ = train_total_data[:, :-num_labels]
        train_labels_ = train_total_data[:, -num_labels:]

        # Loop over all batches
        for i in range(total_batch):

            # Compute the offset of the current minibatch in the data.
            offset = (i * batch_size) % (train_size)
            batch_xs = train_data_[offset:(offset + batch_size), :]
            batch_ys = train_labels_[offset:(offset + batch_size), :]

            # Run optimization op (backprop), loss op (to get loss value)
            # and summary nodes
            _, train_accuracy, summary = sess.run([train_step, accuracy, merged_summary_op] , feed_dict={x: batch_xs, y_: batch_ys, is_training: True})

            # Write logs at every iteration
            summary_writer.add_summary(summary, epoch * total_batch + i)

            # Display logs
            if i % display_step == 0:
                print("Epoch:", '%04d,' % (epoch + 1),
                "batch_index %4d/%4d, training accuracy %.5f" % (i, total_batch, train_accuracy))

            # Get accuracy for validation data
            if i % validation_step == 0:
                # Calculate accuracy
                validation_accuracy = sess.run(accuracy,
                feed_dict={x: validation_data, y_: validation_labels, is_training: False})

                print("Epoch:", '%04d,' % (epoch + 1),
                "batch_index %4d/%4d, validation accuracy %.5f" % (i, total_batch, validation_accuracy))

            # Save the current model if the maximum accuracy is updated
            if validation_accuracy > max_acc:
                max_acc = validation_accuracy
                save_path = saver.save(sess, MODEL_DIRECTORY)
                print("Model updated and saved in file: %s" % save_path)

    print("Optimization Finished!")

    # Restore variables from disk
    saver.restore(sess, MODEL_DIRECTORY)

    # Calculate accuracy for all mnist test images
    test_size = test_labels.shape[0]
    batch_size = TEST_BATCH_SIZE
    total_batch = int(test_size / batch_size)

    acc_buffer = []

    # Loop over all batches
    for i in range(total_batch):
        # Compute the offset of the current minibatch in the data.
        offset = (i * batch_size) % (test_size)
        batch_xs = test_data[offset:(offset + batch_size), :]
        batch_ys = test_labels[offset:(offset + batch_size), :]

        y_final = sess.run(y, feed_dict={x: batch_xs, y_: batch_ys, is_training: False})
        correct_prediction = numpy.equal(numpy.argmax(y_final, 1), numpy.argmax(batch_ys, 1))
        acc_buffer.append(numpy.sum(correct_prediction) / batch_size)

    print("test accuracy for the stored model: %g" % numpy.mean(acc_buffer))


In [20]:
train()

Extracting ./cache/local_data/train-images-idx3-ubyte.gz
Extracting ./cache/local_data/train-labels-idx1-ubyte.gz
Extracting ./cache/local_data/t10k-images-idx3-ubyte.gz
Extracting ./cache/local_data/t10k-labels-idx1-ubyte.gz
expanding data : 100 / 55000
expanding data : 200 / 55000
expanding data : 300 / 55000
expanding data : 400 / 55000
expanding data : 500 / 55000
expanding data : 600 / 55000
expanding data : 700 / 55000
expanding data : 800 / 55000
expanding data : 900 / 55000
expanding data : 1000 / 55000
expanding data : 1100 / 55000
expanding data : 1200 / 55000
expanding data : 1300 / 55000
expanding data : 1400 / 55000
expanding data : 1500 / 55000
expanding data : 1600 / 55000
expanding data : 1700 / 55000
expanding data : 1800 / 55000
expanding data : 1900 / 55000
expanding data : 2000 / 55000
expanding data : 2100 / 55000
expanding data : 2200 / 55000
expanding data : 2300 / 55000
expanding data : 2400 / 55000
expanding data : 2500 / 55000
expanding data : 2600 / 55000
exp

expanding data : 26200 / 55000
expanding data : 26300 / 55000
expanding data : 26400 / 55000
expanding data : 26500 / 55000
expanding data : 26600 / 55000
expanding data : 26700 / 55000
expanding data : 26800 / 55000
expanding data : 26900 / 55000
expanding data : 27000 / 55000
expanding data : 27100 / 55000
expanding data : 27200 / 55000
expanding data : 27300 / 55000
expanding data : 27400 / 55000
expanding data : 27500 / 55000
expanding data : 27600 / 55000
expanding data : 27700 / 55000
expanding data : 27800 / 55000
expanding data : 27900 / 55000
expanding data : 28000 / 55000
expanding data : 28100 / 55000
expanding data : 28200 / 55000
expanding data : 28300 / 55000
expanding data : 28400 / 55000
expanding data : 28500 / 55000
expanding data : 28600 / 55000
expanding data : 28700 / 55000
expanding data : 28800 / 55000
expanding data : 28900 / 55000
expanding data : 29000 / 55000
expanding data : 29100 / 55000
expanding data : 29200 / 55000
expanding data : 29300 / 55000
expandin

expanding data : 52800 / 55000
expanding data : 52900 / 55000
expanding data : 53000 / 55000
expanding data : 53100 / 55000
expanding data : 53200 / 55000
expanding data : 53300 / 55000
expanding data : 53400 / 55000
expanding data : 53500 / 55000
expanding data : 53600 / 55000
expanding data : 53700 / 55000
expanding data : 53800 / 55000
expanding data : 53900 / 55000
expanding data : 54000 / 55000
expanding data : 54100 / 55000
expanding data : 54200 / 55000
expanding data : 54300 / 55000
expanding data : 54400 / 55000
expanding data : 54500 / 55000
expanding data : 54600 / 55000
expanding data : 54700 / 55000
expanding data : 54800 / 55000
expanding data : 54900 / 55000
expanding data : 55000 / 55000




Epoch: 0001, batch_index    0/5500, training accuracy 0.18000
Epoch: 0001, batch_index    0/5500, validation accuracy 0.14240
Model updated and saved in file: ./cache/model/model.ckpt
Epoch: 0001, batch_index  100/5500, training accuracy 0.90000
Epoch: 0001, batch_index  200/5500, training accuracy 0.90000
Epoch: 0001, batch_index  300/5500, training accuracy 0.88000
Epoch: 0001, batch_index  400/5500, training accuracy 0.92000
Epoch: 0001, batch_index  500/5500, training accuracy 0.92000
Epoch: 0001, batch_index  500/5500, validation accuracy 0.96900
Model updated and saved in file: ./cache/model/model.ckpt
Epoch: 0001, batch_index  600/5500, training accuracy 0.92000
Epoch: 0001, batch_index  700/5500, training accuracy 0.94000
Epoch: 0001, batch_index  800/5500, training accuracy 0.96000
Epoch: 0001, batch_index  900/5500, training accuracy 0.94000
Epoch: 0001, batch_index 1000/5500, training accuracy 1.00000
Epoch: 0001, batch_index 1000/5500, validation accuracy 0.98020
Model upda

Epoch: 0002, batch_index 4600/5500, training accuracy 1.00000
Epoch: 0002, batch_index 4700/5500, training accuracy 0.98000
Epoch: 0002, batch_index 4800/5500, training accuracy 1.00000
Epoch: 0002, batch_index 4900/5500, training accuracy 1.00000
Epoch: 0002, batch_index 5000/5500, training accuracy 1.00000
Epoch: 0002, batch_index 5000/5500, validation accuracy 0.99260
Epoch: 0002, batch_index 5100/5500, training accuracy 0.96000
Epoch: 0002, batch_index 5200/5500, training accuracy 1.00000
Epoch: 0002, batch_index 5300/5500, training accuracy 0.96000
Epoch: 0002, batch_index 5400/5500, training accuracy 1.00000
Epoch: 0003, batch_index    0/5500, training accuracy 1.00000
Epoch: 0003, batch_index    0/5500, validation accuracy 0.98900
Epoch: 0003, batch_index  100/5500, training accuracy 1.00000
Epoch: 0003, batch_index  200/5500, training accuracy 1.00000
Epoch: 0003, batch_index  300/5500, training accuracy 0.98000
Epoch: 0003, batch_index  400/5500, training accuracy 0.98000
Epoc

Epoch: 0004, batch_index 4400/5500, training accuracy 1.00000
Epoch: 0004, batch_index 4500/5500, training accuracy 1.00000
Epoch: 0004, batch_index 4500/5500, validation accuracy 0.99340
Epoch: 0004, batch_index 4600/5500, training accuracy 1.00000
Epoch: 0004, batch_index 4700/5500, training accuracy 1.00000
Epoch: 0004, batch_index 4800/5500, training accuracy 1.00000
Epoch: 0004, batch_index 4900/5500, training accuracy 1.00000
Epoch: 0004, batch_index 5000/5500, training accuracy 1.00000
Epoch: 0004, batch_index 5000/5500, validation accuracy 0.99300
Epoch: 0004, batch_index 5100/5500, training accuracy 1.00000
Epoch: 0004, batch_index 5200/5500, training accuracy 1.00000
Epoch: 0004, batch_index 5300/5500, training accuracy 1.00000
Epoch: 0004, batch_index 5400/5500, training accuracy 0.98000
Epoch: 0005, batch_index    0/5500, training accuracy 1.00000
Epoch: 0005, batch_index    0/5500, validation accuracy 0.99360
Epoch: 0005, batch_index  100/5500, training accuracy 1.00000
Ep

Epoch: 0006, batch_index 4200/5500, training accuracy 1.00000
Epoch: 0006, batch_index 4300/5500, training accuracy 0.96000
Epoch: 0006, batch_index 4400/5500, training accuracy 0.98000
Epoch: 0006, batch_index 4500/5500, training accuracy 0.98000
Epoch: 0006, batch_index 4500/5500, validation accuracy 0.99460
Epoch: 0006, batch_index 4600/5500, training accuracy 1.00000
Epoch: 0006, batch_index 4700/5500, training accuracy 1.00000
Epoch: 0006, batch_index 4800/5500, training accuracy 1.00000
Epoch: 0006, batch_index 4900/5500, training accuracy 1.00000
Epoch: 0006, batch_index 5000/5500, training accuracy 1.00000
Epoch: 0006, batch_index 5000/5500, validation accuracy 0.99540
Epoch: 0006, batch_index 5100/5500, training accuracy 1.00000
Epoch: 0006, batch_index 5200/5500, training accuracy 1.00000
Epoch: 0006, batch_index 5300/5500, training accuracy 1.00000
Epoch: 0006, batch_index 5400/5500, training accuracy 1.00000
Epoch: 0007, batch_index    0/5500, training accuracy 1.00000
Epoc

Epoch: 0008, batch_index 4100/5500, training accuracy 0.98000
Epoch: 0008, batch_index 4200/5500, training accuracy 1.00000
Epoch: 0008, batch_index 4300/5500, training accuracy 1.00000
Epoch: 0008, batch_index 4400/5500, training accuracy 1.00000
Epoch: 0008, batch_index 4500/5500, training accuracy 1.00000
Epoch: 0008, batch_index 4500/5500, validation accuracy 0.99420
Epoch: 0008, batch_index 4600/5500, training accuracy 0.98000
Epoch: 0008, batch_index 4700/5500, training accuracy 1.00000
Epoch: 0008, batch_index 4800/5500, training accuracy 1.00000
Epoch: 0008, batch_index 4900/5500, training accuracy 0.98000
Epoch: 0008, batch_index 5000/5500, training accuracy 1.00000
Epoch: 0008, batch_index 5000/5500, validation accuracy 0.99380
Epoch: 0008, batch_index 5100/5500, training accuracy 1.00000
Epoch: 0008, batch_index 5200/5500, training accuracy 1.00000
Epoch: 0008, batch_index 5300/5500, training accuracy 1.00000
Epoch: 0008, batch_index 5400/5500, training accuracy 1.00000
Epoc

Epoch: 0010, batch_index 4100/5500, training accuracy 1.00000
Epoch: 0010, batch_index 4200/5500, training accuracy 1.00000
Epoch: 0010, batch_index 4300/5500, training accuracy 1.00000
Epoch: 0010, batch_index 4400/5500, training accuracy 1.00000
Epoch: 0010, batch_index 4500/5500, training accuracy 0.98000
Epoch: 0010, batch_index 4500/5500, validation accuracy 0.99480
Epoch: 0010, batch_index 4600/5500, training accuracy 1.00000
Epoch: 0010, batch_index 4700/5500, training accuracy 1.00000
Epoch: 0010, batch_index 4800/5500, training accuracy 1.00000
Epoch: 0010, batch_index 4900/5500, training accuracy 1.00000
Epoch: 0010, batch_index 5000/5500, training accuracy 0.98000
Epoch: 0010, batch_index 5000/5500, validation accuracy 0.99460
Epoch: 0010, batch_index 5100/5500, training accuracy 1.00000
Epoch: 0010, batch_index 5200/5500, training accuracy 1.00000
Epoch: 0010, batch_index 5300/5500, training accuracy 1.00000
Epoch: 0010, batch_index 5400/5500, training accuracy 1.00000


INFO:tensorflow:Restoring parameters from ./cache/model/model.ckpt


Optimization Finished!
test accuracy for the stored model: 0.9952


In [21]:
# Copy the local ./cache/ tree (data, logs, model) to OBS. If the target
# folder already exists on OBS, it is deleted first and copied again.
# NOTE(review): the destination is hard-coded to a specific OBS bucket.
obs_save_url = 's3://zhh-obs001/mnist-1208-v2/cache/'
if mox.file.exists(obs_save_url):
    mox.file.remove(obs_save_url, recursive=True)
# Recreate the destination folder that may have just been removed.
mox.file.make_dirs(obs_save_url)
mox.file.copy_parallel('./cache/', obs_save_url)