In [1]:
import tensorflow as tf
import numpy as np

<h1>Save and Restore tf.Variables</h1>

<h2>Create tf.Variables</h2>
<p style="font-size:20px">In the previous tutorial, we used <b>tf.Variable</b> to create variables. Here we introduce a second way to create TensorFlow variables: <b>tf.get_variable</b>. It creates a new variable with the given name or, when variable reuse is enabled in the enclosing variable scope, returns the variable already defined under that name. You can set the initial value with the <b>initializer</b> argument and update the variable with the <b>assign</b> operation, just as we did before. You still need an "init_op" to initialize all variables.</p>
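
<p style="font-size:20px">The lookup behaviour of <b>tf.get_variable</b> can be sketched as follows. This is a minimal illustration, not part of the original tutorial: it assumes the <b>tensorflow.compat.v1</b> shim (so the TF1-style code also runs on TensorFlow 2.x), and the scope name "demo" and the variable names "u" and "v" are made up for the example. It builds everything in a throwaway graph so the names used in the cells below stay free.</p>

```python
import tensorflow.compat.v1 as tf

# Assumption: the compat.v1 shim, so this TF1-style graph code also runs on TF 2.x.
tf.disable_v2_behavior()

# A throwaway graph keeps these demo names out of the notebook's default graph.
with tf.Graph().as_default():
    # Without reuse, asking for an existing name raises a ValueError.
    tf.get_variable("u", shape=[3], initializer=tf.zeros_initializer())
    try:
        tf.get_variable("u", shape=[3])
        duplicate_raises = False
    except ValueError:
        duplicate_raises = True

    # With reuse enabled, the second call returns the existing variable.
    with tf.variable_scope("demo", reuse=tf.AUTO_REUSE):
        v_first = tf.get_variable("v", shape=[3], initializer=tf.zeros_initializer())
        v_again = tf.get_variable("v", shape=[3])

    same_variable = v_first is v_again
    variable_name = v_first.name

print(duplicate_raises, same_variable, variable_name)
```

<p style="font-size:20px">The variable created inside the scope is named with the scope as a prefix, which is also the name a <b>Saver</b> would use for it in a checkpoint.</p>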

In [9]:
#create some variables
W = tf.get_variable("W",shape=[3],initializer = tf.zeros_initializer)
b = tf.get_variable("b",shape=[5],initializer = tf.zeros_initializer)

#increase operation for W
increase_op = W.assign(W+1)
#decrease operation for b
decrease_op = b.assign(b-1)

#initialize variables
init_op = tf.global_variables_initializer()

<h2>Saver</h2>
<p style="font-size:20px">You can create a <b>Saver</b> with <b>tf.train.Saver()</b> to manage all variables in the model.

In [10]:
#create saver operation to save and restore all the variables
saver = tf.train.Saver()

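<p style="font-size:20px">Inside a <b>tf.Session()</b>, you can save the variables with <b>saver.save(sess, "path_prefix")</b>. The second argument is a path prefix for the checkpoint files rather than a directory: saving to "tmp/model.ckpt" writes files whose names start with "model.ckpt" into the tmp folder.</p>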

In [11]:
with tf.Session() as sess:
    
    sess.run(init_op)
    
    #run the increase and decrease operations
    sess.run(increase_op)
    sess.run(decrease_op)
    
    #save the variables to the directory
    save_path = saver.save(sess,"tmp/model.ckpt")
    print("Model saved in path: %s" % save_path)

Model saved in path: tmp/model.ckpt


<p style="font-size:20px">In the tmp folder, you will find four files:
    <ol style="font-size:20px">
        <li><b>checkpoint</b>: A small bookkeeping file that records checkpoint information, such as the file name and path of the most recent checkpoint.</li>
        <li><b>model.ckpt.meta</b>: TensorFlow stores the graph structure separately from the variable values. The .ckpt.meta file contains the complete graph.</li>
        <li><b>model.ckpt.data-00000-of-00001</b>: This contains the values of the saved variables (weights, biases, and any other tf.Variable in the graph). Placeholders are not variables, so their values are not stored.</li>
        <li><b>model.ckpt.index</b>: A table in which each key is the name of a tensor and each value is a serialized BundleEntryProto. The BundleEntryProto holds metadata about the tensor, such as which of the "data" files contains its content, the offset into that file, and a checksum.</li>
    </ol>
</p>
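
<p style="font-size:20px">To see what a checkpoint actually contains, you can list the files and the stored variables. The sketch below is an illustration rather than part of the original tutorial: it assumes the <b>tensorflow.compat.v1</b> shim, writes to a temporary scratch directory instead of tmp, and uses <b>tf.train.latest_checkpoint</b> and <b>tf.train.list_variables</b> to read the bookkeeping file and the variable names and shapes.</p>

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

# Assumption: the compat.v1 shim, so this TF1-style graph code also runs on TF 2.x.
tf.disable_v2_behavior()

# Hypothetical scratch location, so the notebook's tmp folder is left untouched.
ckpt_dir = tempfile.mkdtemp()
ckpt_prefix = os.path.join(ckpt_dir, "model.ckpt")

# Build and save the same two variables in a separate graph.
with tf.Graph().as_default():
    W = tf.get_variable("W", shape=[3], initializer=tf.zeros_initializer())
    b = tf.get_variable("b", shape=[5], initializer=tf.zeros_initializer())
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_prefix)

# The files written next to the prefix.
files = sorted(os.listdir(ckpt_dir))
# The most recent checkpoint prefix, read from the "checkpoint" file.
latest = tf.train.latest_checkpoint(ckpt_dir)
# (name, shape) pairs for every variable stored in the checkpoint.
stored = tf.train.list_variables(ckpt_prefix)

print(files)
print(latest)
print(stored)
```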

In [15]:
tf.reset_default_graph()

<h2>Restore variables</h2>

<p style="font-size:20px">We can restore our variables from disk with <b>saver.restore(sess, "path_prefix")</b>. Note that you do not need to run <b>tf.global_variables_initializer</b> here, because restoring a variable from the checkpoint also initializes it.</p>

<p style="font-size:20px">In addition, we introduce another way to read a value: <b>W.eval()</b>, which inside a <b>with tf.Session()</b> block does exactly the same thing as <b>sess.run(W)</b>. At the end, you can see that the variables are restored and printed.</p>

In [16]:
W = tf.get_variable("W",shape=[3])
b = tf.get_variable("b",shape=[5])

saver = tf.train.Saver()
with tf.Session() as sess:
    #restore variables from disk
    saver.restore(sess,"tmp/model.ckpt")
    
    print("Model restored.")
    #check the values of the variables, W.eval() == sess.run(W)
    print("W: %s"%W.eval())
    print("b: %s"%b.eval())

INFO:tensorflow:Restoring parameters from tmp/model.ckpt
Model restored.
W: [1. 1. 1.]
b: [-1. -1. -1. -1. -1.]
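
<p style="font-size:20px">The relationship between <b>W.eval()</b> and <b>sess.run(W)</b> can be sketched as below. This is an illustration under stated assumptions (the <b>tensorflow.compat.v1</b> shim and a throwaway graph): <b>eval()</b> uses the default session, so outside a <b>with</b> block the session has to be passed explicitly.</p>

```python
import numpy as np
import tensorflow.compat.v1 as tf

# Assumption: the compat.v1 shim, so this TF1-style graph code also runs on TF 2.x.
tf.disable_v2_behavior()

with tf.Graph().as_default():
    W = tf.get_variable("W", shape=[3], initializer=tf.ones_initializer())

    sess = tf.Session()
    sess.run(W.initializer)

    # Explicit session: fetch the value through the session object.
    via_run = sess.run(W)
    # No default session is active here, so eval() needs the session argument.
    via_eval = W.eval(session=sess)
    # Inside as_default(), eval() picks up the session on its own.
    with sess.as_default():
        via_default = W.eval()

    sess.close()

print(np.array_equal(via_run, via_eval), np.array_equal(via_run, via_default))
```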


<h1>Choose variables to save and restore</h1>

<p style="font-size:20px">If you do not pass any arguments to <b>tf.train.Saver()</b>, the saver handles all variables in the graph. Each variable is saved under the name that was passed when the variable was created.</p>
<p style="font-size:20px">It is sometimes useful to explicitly specify names for variables in the checkpoint files. For example, you may have trained a model with a variable named "weights" whose value you want to restore into a variable named "params".</p>

<p style="font-size:20px">It is also sometimes useful to only save or restore a subset of the variables used by a model. For example, you may have trained a neural net with five layers, and you now want to train a new model with six layers that reuses the existing weights of the five trained layers. You can use the saver to restore the weights of just the first five layers.

<p style="font-size:20px">You can easily specify the names and variables to save or load by passing to the <b>tf.train.Saver()</b> constructor either of the following:

<ul style="font-size:20px">
    <li>A list of variables (which will be stored under their own names).</li>
    <li>A Python dictionary in which keys are the names to use and the values are the variables to manage.</li>
</ul>

In [17]:
tf.reset_default_graph()
W = tf.get_variable("W",[3],initializer=tf.zeros_initializer)
b = tf.get_variable("b",[5],initializer=tf.zeros_initializer)

#Add ops to save and restore only 'b' using the name "b"
saver = tf.train.Saver({"b":b})

with tf.Session() as sess:
    #initialize W since the saver will not.
    W.initializer.run()
    
    saver.restore(sess,'tmp/model.ckpt')
    
    print("W: %s"%W.eval())
    print("b: %s"%b.eval())

INFO:tensorflow:Restoring parameters from tmp/model.ckpt
W: [0. 0. 0.]
b: [-1. -1. -1. -1. -1.]
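
<p style="font-size:20px">The renaming case mentioned earlier, where a checkpoint entry named "weights" is restored into a variable named "params", can be sketched with the dictionary form of the <b>Saver</b>. This is a minimal illustration, not the tutorial's own code: it assumes the <b>tensorflow.compat.v1</b> shim, and the checkpoint file name "rename.ckpt" and the temporary directory are made up for the example.</p>

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

# Assumption: the compat.v1 shim, so this TF1-style graph code also runs on TF 2.x.
tf.disable_v2_behavior()

# Hypothetical scratch checkpoint, separate from tmp/model.ckpt.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "rename.ckpt")

# Graph 1: save a variable under the checkpoint name "weights".
with tf.Graph().as_default():
    weights = tf.get_variable("weights", shape=[2], initializer=tf.ones_initializer())
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_prefix)

# Graph 2: restore that value into a variable with a different name.
with tf.Graph().as_default():
    params = tf.get_variable("params", shape=[2], initializer=tf.zeros_initializer())
    # Key = name in the checkpoint, value = variable that receives it.
    saver = tf.train.Saver({"weights": params})
    with tf.Session() as sess:
        saver.restore(sess, ckpt_prefix)
        restored = sess.run(params)

print(restored)
```

<p style="font-size:20px">Because the dictionary key is the name looked up in the checkpoint, "params" receives the saved ones rather than its own zeros initializer.</p>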


<h2>Restore graph with meta file</h2>

<p style="font-size:20px">Since the <b>.meta</b> file holds the structure of the graph, you can load this structure with <b>tf.train.import_meta_graph()</b>. This adds the saved graph to the default graph and returns a <b>Saver</b> instance that you can then use to restore the graph's state (i.e., the variable values).</p>

<p style="font-size:20px"> This allows you to fully restore a saved model, including both the graph structure and the variable values, without having to search for the code that built it.</p>

In [22]:
tf.reset_default_graph()
saver = tf.train.import_meta_graph("tmp/model.ckpt.meta")
W = tf.get_default_graph().get_tensor_by_name("W:0")
b = tf.get_default_graph().get_tensor_by_name("b:0")
with tf.Session() as sess:
    saver.restore(sess,"tmp/model.ckpt")
    
    print("W: %s" %sess.run(W))
    print("b: %s" %sess.run(b))

INFO:tensorflow:Restoring parameters from tmp/model.ckpt
W: [1. 1. 1.]
b: [-1. -1. -1. -1. -1.]
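
<p style="font-size:20px">Importing the meta graph brings back operations as well as variables, so a restored model can keep running its saved ops. The sketch below is an illustration under stated assumptions: it uses the <b>tensorflow.compat.v1</b> shim and a temporary checkpoint, and the op names "increase_op" and "W_value" are hypothetical names assigned at build time so they can be looked up after the import.</p>

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

# Assumption: the compat.v1 shim, so this TF1-style graph code also runs on TF 2.x.
tf.disable_v2_behavior()

# Hypothetical scratch checkpoint, separate from tmp/model.ckpt.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Build and save a graph whose ops have explicit names.
with tf.Graph().as_default():
    W = tf.get_variable("W", shape=[3], initializer=tf.zeros_initializer())
    tf.assign(W, W + 1, name="increase_op")  # named so it can be found later
    tf.identity(W, name="W_value")           # named read of the variable
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_prefix)

# Rebuild the graph from the .meta file alone, then restore the values.
with tf.Graph().as_default() as graph:
    saver = tf.train.import_meta_graph(ckpt_prefix + ".meta")
    increase_op = graph.get_operation_by_name("increase_op")
    w_value = graph.get_tensor_by_name("W_value:0")
    with tf.Session() as sess:
        saver.restore(sess, ckpt_prefix)
        before = sess.run(w_value)
        sess.run(increase_op)
        after = sess.run(w_value)

print(before, after)
```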
