In [1]:
from utils import *

In [2]:
# Read the MNIST dataset from the local "tmp/" directory

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(train_dir="tmp/")

Extracting tmp/train-images-idx3-ubyte.gz
Extracting tmp/train-labels-idx1-ubyte.gz
Extracting tmp/t10k-images-idx3-ubyte.gz
Extracting tmp/t10k-labels-idx1-ubyte.gz


Let's say we have a pretrained model, but we want to train it on a different but similar task. The weights of the pretrained model should already be close to where they "need to be" for the new task, so starting from them is helpful. However, this applies mostly to the lower layers, which capture low-level features, while the higher layers may actually be *too* specialized to the original task to be useful for the new one.

In this case, we want to use the lower layers of the pretrained model but use new upper layers with randomly initialized weights. One way is to load the entire pretrained model and, at the point where you want brand new layers, create a *new set of layers* that branches off, taking as its input the output of the last pretrained layer you want to keep.

The upper layers of the old graph are still there - they're just not being used, which is why I call this the "branching" or "Siamese twin" method.

In [3]:
# Start from a fresh graph before importing the pretrained one.
# NOTE(review): reset_graph is not defined in this notebook (presumably it
# comes from utils and clears the default graph) -- confirm.
reset_graph()

# Rebuild the pretrained model's graph from its .meta file. The returned
# saver is used later to restore the pretrained weights from the checkpoint.
saver = tf.train.import_meta_graph(
    meta_graph_or_file="savedmodels/11_07_gradientclipping.ckpt.meta"
)

In [4]:
# Grab handles into the imported (pretrained) graph by tensor name.
graph = tf.get_default_graph()
X = graph.get_tensor_by_name("X:0")
y = graph.get_tensor_by_name("y:0")

# Branch point: the new layers are stacked on the final activations of hidden3.
hidden3 = graph.get_tensor_by_name("dnn/hidden3/Relu:0")

# Hyperparameters for the new branch
n_hidden4, n_hidden5, n_outputs = 50, 50, 10
learning_rate = 0.01

# From here on the graph is built as usual. Every name carries a "new"
# prefix so it cannot collide with names already present in the imported graph.
new_hidden4 = tf.layers.dense(
    hidden3, n_hidden4, name="new_hidden4", activation=tf.nn.relu)
new_hidden5 = tf.layers.dense(
    new_hidden4, n_hidden5, name="new_hidden5", activation=tf.nn.relu)
new_logits = tf.layers.dense(new_hidden5, n_outputs, name="new_outputs")

with tf.name_scope("new_loss"):
    # Mean cross-entropy between the labels and the new branch's logits
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=new_logits, labels=y)
    loss = tf.reduce_mean(xentropy, name="loss")

with tf.name_scope("new_train"):
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    training_op = optimizer.minimize(loss)

with tf.name_scope("new_eval"):
    # A prediction counts as correct when the label is the top-1 logit
    correct = tf.nn.in_top_k(new_logits, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy")
    # Two summaries over the same accuracy tensor, so the train and
    # validation values are logged under separate tags
    train_summary = tf.summary.scalar('train_accuracy', accuracy)
    valid_summary = tf.summary.scalar('valid_accuracy', accuracy)

init = tf.global_variables_initializer()
new_saver = tf.train.Saver()
file_writer = tf.summary.FileWriter("to_tensorboard/11_10_branching", graph)

In [5]:
n_epochs = 20
batch_size = 200

# The writer is created in the graph-building cell but closed at the end of
# this one. reopen() makes this cell safe to run again on its own; it does
# nothing if the writer has not been closed yet.
file_writer.reopen()

try:
    with tf.Session() as sess:
        # Initialize every variable first (this covers the "new" layers) ...
        init.run()
        # ... then restore weights from the old run. `saver` was built from
        # the imported meta graph, so the pretrained variables are overwritten
        # with their checkpoint values while the "new" prefixed variables keep
        # their default initialization.
        saver.restore(sess, "savedmodels/11_07_gradientclipping.ckpt")

        for epoch in range(n_epochs):
            # One pass over the training set in mini-batches
            for iteration in range(mnist.train.num_examples // batch_size):
                X_batch, y_batch = mnist.train.next_batch(batch_size)
                sess.run(training_op, feed_dict={X: X_batch, y: y_batch})

            # Evaluate on the full training and validation sets once per epoch
            acc_train, tra_str = sess.run(
                [accuracy, train_summary],
                feed_dict={X: mnist.train.images, y: mnist.train.labels})
            acc_val, val_str = sess.run(
                [accuracy, valid_summary],
                feed_dict={X: mnist.validation.images, y: mnist.validation.labels})
            file_writer.add_summary(tra_str, epoch)
            file_writer.add_summary(val_str, epoch)

            # Print every 5th epoch and the final one
            if epoch % 5 == 0 or epoch == n_epochs - 1:
                print(epoch, "train acc:", acc_train, "val acc:", acc_val)

        acc_test = accuracy.eval(feed_dict={X: mnist.test.images, y: mnist.test.labels})
        print("Test acc:", acc_test)
finally:
    # Always flush and close the event file, even if training raised
    file_writer.close()

INFO:tensorflow:Restoring parameters from savedmodels/11_07_gradientclipping.ckpt
0 train acc: 0.920455 val acc: 0.9252
5 train acc: 0.956782 val acc: 0.9582
10 train acc: 0.966455 val acc: 0.9642
15 train acc: 0.973945 val acc: 0.9686
19 train acc: 0.977236 val acc: 0.9694
Test acc: 0.9666
