In [1]:
import tensorflow as tf

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./tf_mnist", one_hot=True)

Extracting ./tf_mnist\train-images-idx3-ubyte.gz
Extracting ./tf_mnist\train-labels-idx1-ubyte.gz
Extracting ./tf_mnist\t10k-images-idx3-ubyte.gz
Extracting ./tf_mnist\t10k-labels-idx1-ubyte.gz


In [3]:
learning_rate = 1e-3
batch_size = 100
display_step = 1
model_path = "./model.ckpt"

In [4]:
n_h1 = 256
n_h2 = 256
n_input = 28*28
n_classes = 10

In [5]:
def MLP(x, weights, biases):
    h1 = tf.nn.relu(tf.matmul(x,  weights['h1']) + biases['h1'])
    h2 = tf.nn.relu(tf.matmul(h1, weights['h2']) + biases['h2'])
    out = tf.matmul(h2, weights['out']) + biases['out']
    return out

In [6]:
graph = tf.Graph()
with graph.as_default():
    
    Xtr = tf.placeholder(tf.float32, [None, n_input])
    Ytr = tf.placeholder(tf.float32, [None, n_classes])
    
    weights = {
        'h1': tf.Variable(tf.random_normal([n_input, n_h1])),
        'h2': tf.Variable(tf.random_normal([n_h1, n_h2])),
        'out': tf.Variable(tf.random_normal([n_h2, n_classes]))
    }
    
    biases = {
        'h1': tf.Variable(tf.zeros([n_h1])),
        'h2': tf.Variable(tf.zeros([n_h2])),
        'out': tf.Variable(tf.zeros([n_classes]))
    }
    
    pred = MLP(Xtr, weights, biases)
    
    # Cost and Optimizer
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=Ytr))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
    
    # Evaluation
    correct_pred = tf.equal(tf.argmax(Ytr,1), tf.argmax(pred,1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    
    # 'Saver' op to save and restore all variables
    saver = tf.train.Saver()    
    
    # session init
    init = tf.global_variables_initializer()

In [7]:
with tf.Session(graph=graph) as sess:
    sess.run(init)
    
    for epoch in range(3):
        total_batch = int(mnist.train.num_examples/batch_size)
        avg_cost = 0.0
        for i in range(total_batch):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            feed_dict = { Xtr: batch_x, Ytr: batch_y  }
            _, c_ = sess.run([optimizer, cost], feed_dict=feed_dict)
            avg_cost += c_/total_batch
        if epoch % display_step == 0:
            train_accuracy = accuracy.eval(feed_dict=feed_dict)
            print ("Epoch: %d, avg_cost: %.4f, accuracy: %.4f"%
                  (epoch, avg_cost, train_accuracy))
    print ("First Optimization Finished!")
    
    # Test model
    test_accuracy = accuracy.eval(feed_dict={Xtr: mnist.test.images, Ytr: mnist.test.labels})
    print ("Test Accuracy: %.4f"% (test_accuracy))
    
    # Save model weights to disk
    save_path = saver.save(sess, model_path)
    print ("Model saved in file: %s"% (save_path))

Epoch: 0, avg_cost: 199.5351, accuracy: 0.8900
Epoch: 1, avg_cost: 45.5584, accuracy: 0.8300
Epoch: 2, avg_cost: 28.2303, accuracy: 0.9000
First Optimization Finished!
Test Accuracy: 0.9123
Model saved in file: ./model.ckpt


In [10]:
print ("Staring 2nd session...")
with tf.Session(graph=graph) as sess:
    sess.run(init)
    
    # Restore model weights from previously 
    load_path = saver.restore(sess, model_path)
    print ("Model restored from file: ", save_path)
    
    for epoch in range(7):
        total_batch = int(mnist.train.num_examples/batch_size)
        avg_cost = 0.0
        for i in range(total_batch):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            feed_dict = { Xtr: batch_x, Ytr: batch_y  }
            _, c_ = sess.run([optimizer, cost], feed_dict=feed_dict)
            avg_cost += c_/total_batch
        if epoch % display_step == 0:
            train_accuracy = accuracy.eval(feed_dict=feed_dict)
            print ("Epoch: %d, avg_cost: %.4f, accuracy: %.4f"%
                  (epoch, avg_cost, train_accuracy))
    print ("Second Optimization Finished!")
    
    # Test model
    test_accuracy = accuracy.eval(feed_dict={Xtr: mnist.test.images, Ytr: mnist.test.labels})
    print ("Test Accuracy: %.4f"% (test_accuracy))
        

Staring 2nd session...
Model restored from file:  ./model.ckpt
Epoch: 0, avg_cost: 19.7482, accuracy: 0.9000
Epoch: 1, avg_cost: 14.5744, accuracy: 0.9100
Epoch: 2, avg_cost: 10.9540, accuracy: 0.8900
Epoch: 3, avg_cost: 8.2293, accuracy: 0.9900
Epoch: 4, avg_cost: 6.2116, accuracy: 0.9600
Epoch: 5, avg_cost: 4.6374, accuracy: 0.9800
Epoch: 6, avg_cost: 3.5798, accuracy: 0.9900
Second Optimization Finished!
Test Accuracy: 0.9393
