In [1]:
# To support both python 2 and python 3
from __future__ import division, print_function, unicode_literals

# Common imports
import numpy as np
import os
import tensorflow as tf

# Make this notebook's output reproducible across runs.
def reset_graph(seed=42):
    """Clear the default TensorFlow graph and re-seed TensorFlow and NumPy.

    Args:
        seed: Integer passed to both ``tf.set_random_seed`` and
            ``np.random.seed``. Defaults to 42.
    """
    # NumPy's global seed is independent of the TensorFlow graph.
    np.random.seed(seed)
    # Reset first, then seed, so the seed is set on the fresh default graph.
    tf.reset_default_graph()
    tf.set_random_seed(seed)

reset_graph()

In [2]:
# Load MNIST and keep only the digits 5 to 9. Their labels are shifted down
# by 5 so they run from 0 to 4, because TensorFlow expects integer labels
# in the range 0 to n_classes-1.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/")

# One boolean mask per split, selecting the instances whose digit is >= 5.
train_mask = mnist.train.labels >= 5
valid_mask = mnist.validation.labels >= 5
test_mask = mnist.test.labels >= 5

X_train2_full = mnist.train.images[train_mask]
y_train2_full = mnist.train.labels[train_mask] - 5
X_valid2_full = mnist.validation.images[valid_mask]
y_valid2_full = mnist.validation.labels[valid_mask] - 5
X_test2 = mnist.test.images[test_mask]
y_test2 = mnist.test.labels[test_mask] - 5

Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz


In [3]:
# We want to keep only 100 instances per class in the training set
# and only 30 instances per class in the validation set.
# The testing set is already loaded above.
def sample_n_instances_per_class(X, y, n=100):
    """Keep at most the first ``n`` instances of every class.

    Args:
        X: Array of instances, indexed along its first axis.
        y: Array of labels aligned with ``X``.
        n: Maximum number of instances kept per distinct label.

    Returns:
        A tuple ``(X_subset, y_subset)``. Instances are grouped by label in
        the order returned by ``np.unique(y)``; within a label they keep
        their original order.
    """
    # One boolean mask per distinct label, reused for both X and y.
    class_masks = [y == label for label in np.unique(y)]
    X_kept = [X[mask][:n] for mask in class_masks]
    y_kept = [y[mask][:n] for mask in class_masks]
    return np.concatenate(X_kept), np.concatenate(y_kept)

# Small training set: at most 100 instances per class.
X_train2, y_train2 = sample_n_instances_per_class(X_train2_full, y_train2_full, n=100)
# Small validation set: at most 30 instances per class.
X_valid2, y_valid2 = sample_n_instances_per_class(X_valid2_full, y_valid2_full, n=30)

In [26]:
reset_graph()

# Learning rate used to fine-tune the output layer. It has to be defined
# before the optimizer below is constructed; otherwise this cell raises a
# NameError when the notebook is run top to bottom on a fresh kernel.
learning_rate = 0.01

# Import the graph structure of the HW2 model. The variable values are
# restored from the checkpoint later, inside a session.
restore_saver = tf.train.import_meta_graph("../HW2_DNN/Team20_HW2.ckpt.meta")

# Step 1: look up the tensors of the HW2 model by name.
hw2_graph = tf.get_default_graph()
x = hw2_graph.get_tensor_by_name("X:0")
y = hw2_graph.get_tensor_by_name("Y:0")
loss = hw2_graph.get_tensor_by_name("loss:0")
accuracy = hw2_graph.get_tensor_by_name("accuracy:0")
training_mode = hw2_graph.get_tensor_by_name("is_training:0")

# Step 2: collect the trainable variables under the 'logits' scope.
# Note that this is a list of variables, not the layer's output tensor.
output_layer = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='logits')
print(output_layer)

# Step 3: create a new optimizer whose training op updates only the 'logits'
# variables; variables not listed in var_list are not modified by training_op.
optimizer = tf.train.AdamOptimizer(learning_rate, name='opt')
training_op = optimizer.minimize(loss, var_list=output_layer)

[<tf.Variable 'logits/kernel:0' shape=(128, 5) dtype=float32_ref>, <tf.Variable 'logits/bias:0' shape=(5,) dtype=float32_ref>]


In [None]:
# For debugging: print the values of the 'logits' variables as stored in the
# HW2 checkpoint and, when it is available, in the HW3 checkpoint.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    restore_saver.restore(sess, "../HW2_DNN/Team20_HW2.ckpt")
    # output_layer is a list of variables, so this fetches their current
    # values (kernel and bias), not the network's predictions.
    y_ = sess.run(output_layer)
    print(y_)

    # The HW3 checkpoint is only written by the training cell. Skip it rather
    # than raising when that cell has not been run yet.
    hw3_checkpoint = "./Team20_HW3.ckpt"
    if tf.train.checkpoint_exists(hw3_checkpoint):
        restore_saver.restore(sess, hw3_checkpoint)
        y_ = sess.run(output_layer)
        print(y_)
    else:
        print("Checkpoint {} not found; run the training cell first.".format(hw3_checkpoint))

In [45]:
# Initializer for every variable currently in the graph.
init_g = tf.global_variables_initializer()

# Hyper-parameters
batch_size = 20
# NOTE: the optimizer is constructed in an earlier cell, so assigning
# learning_rate here does not change the step size used by training_op.
learning_rate = 0.01

training_epochs = 1000
# Stop early when the validation loss has not improved for this many epochs.
early_stopping_epoch = 20

# Early-stopping state
best_loss = float("inf")
best_epoch = 0
early_stopped = False
checkpoint_saved = False

# Training data in mini-batches, repeated once per epoch.
# NOTE(review): X_train2 is concatenated class by class and is not shuffled
# here, so each batch presumably holds a single class - consider shuffling
# before batching; TODO confirm.
train_dataset = tf.contrib.data.Dataset.from_tensor_slices((X_train2, y_train2)).batch(batch_size).repeat(training_epochs)
dataset_batch = train_dataset.make_initializable_iterator()
# Create the get_next op exactly once. Calling get_next() inside the training
# loop would add a new op to the graph on every step.
next_batch = dataset_batch.get_next()

# Number of batches in one pass over the training data. Integer ceiling
# division keeps this an int on both Python 2 and Python 3 and matches the
# number of batches the dataset yields per epoch.
train_steps = (len(X_train2) + batch_size - 1) // batch_size

# Saver used to checkpoint the fine-tuned model.
checkpoint_path = "./Team20_HW3.ckpt"
saver = tf.train.Saver()


with tf.Session() as sess:
    # Initialize all variables and the dataset iterator.
    sess.run(init_g)
    sess.run(dataset_batch.initializer)

    # Load the HW2 weights, then re-initialize the 'logits' variables so the
    # output layer is trained from scratch.
    restore_saver.restore(sess, "../HW2_DNN/Team20_HW2.ckpt")
    for var in output_layer:
        var.initializer.run()

    for epoch in range(training_epochs):
        # One pass over the training data.
        # is_training is fed as False here as well as during evaluation;
        # presumably this keeps the frozen layers in inference mode - TODO confirm.
        for _ in range(train_steps):
            X_in, y_in = sess.run(next_batch)
            sess.run(training_op, feed_dict={x: X_in, y: y_in, training_mode: False})

        # Evaluate on the validation set after every epoch.
        curr_loss, curr_accuracy = sess.run([loss, accuracy], feed_dict={x: X_valid2, y: y_valid2, training_mode: False})

        if curr_loss < best_loss:
            # New best model: checkpoint it.
            best_loss = curr_loss
            save_path = saver.save(sess, checkpoint_path)
            checkpoint_saved = True
            best_epoch = epoch
        elif epoch - best_epoch >= early_stopping_epoch:
            # No improvement for early_stopping_epoch epochs: stop training.
            early_stopped = True
            break

        print("Epoch {}: Validation loss: {} Best loss: {} Accuracy: {} ".format(epoch, curr_loss, best_loss ,curr_accuracy))

    # The best model was already checkpointed when it was found, so it must
    # not be overwritten with the last-epoch weights. Only write a checkpoint
    # here when none was written during training (for example zero epochs or
    # a loss that never compared as an improvement).
    if not checkpoint_saved:
        save_path = saver.save(sess, checkpoint_path)
    if not early_stopped:
        print("Training finished without early stopping; best epoch: {}".format(best_epoch))

    # Reload the best model before the final evaluation.
    saver.restore(sess, checkpoint_path)

    # Accuracy on the held-out test set.
    final_accuracy = sess.run(accuracy, feed_dict={x: X_test2, y: y_test2, training_mode: False})
    print("Test accuracy: ", final_accuracy)


INFO:tensorflow:Restoring parameters from ../HW2_DNN/Team20_HW2.ckpt
Epoch 0: Validation loss: 1.7691210508346558 Best loss: 1.7691210508346558 Accuracy: 0.6466666460037231 
Epoch 1: Validation loss: 0.8087993860244751 Best loss: 0.8087993860244751 Accuracy: 0.7933333516120911 
Epoch 2: Validation loss: 0.6772139072418213 Best loss: 0.6772139072418213 Accuracy: 0.8399999737739563 
Epoch 3: Validation loss: 0.5612290501594543 Best loss: 0.5612290501594543 Accuracy: 0.8666666746139526 
Epoch 4: Validation loss: 0.5395689606666565 Best loss: 0.5395689606666565 Accuracy: 0.8533333539962769 
Epoch 5: Validation loss: 0.4974088668823242 Best loss: 0.4974088668823242 Accuracy: 0.8666666746139526 
Epoch 6: Validation loss: 0.4809908866882324 Best loss: 0.4809908866882324 Accuracy: 0.8866666555404663 
Epoch 7: Validation loss: 0.4638897478580475 Best loss: 0.4638897478580475 Accuracy: 0.8866667151451111 
Epoch 8: Validation loss: 0.4550561308860779 Best loss: 0.4550561308860779 Accuracy: 0.8799