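## 5.4 TensorFlow Model Persistence
The sample code in Section 5.2.1 exits as soon as training finishes, without saving the trained model for later use. **To make training results reusable, the trained neural network model has to be persisted.**

Section 5.4.1 shows how to persist a trained model from a TensorFlow program and how to restore it from the saved files; Section 5.4.2 covers how TensorFlow persistence works and the data format of the saved files.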

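### 5.4.1 Implementing Persistence in Code
**1. Saving a model**

**TensorFlow provides a very simple API for saving and restoring a neural network model: the tf.train.Saver class**, used as follows: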

In [1]:
import tensorflow as tf

# Declare two variables and compute their sum
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(1.0, shape=[1]), name='v2')
result = v1 + v2

init_op = tf.global_variables_initializer()
# Declare a tf.train.Saver to save the model
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, "Saved_model/model.ckpt")

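In the code above, `saver.save` saves the TensorFlow model under the path *Saved_model/model.ckpt*. Although the program specifies a single path, several files appear in that directory, because TensorFlow stores the structure of the computation graph separately from the values of its parameters. The files fall into three groups (this part draws on Section 5.4.2):

**a. model.ckpt.meta** stores the structure of the TensorFlow computation graph, which can be loosely thought of as the architecture of the neural network;

**b. model.ckpt** stores the values of all variables in the model. It actually consists of two files, model.ckpt.index and model.ckpt.data-\*\*\*\*\*-of-\*\*\*\*\*. Together they use an SSTable-based format, which can be roughly understood as a list of (key, value) pairs. TensorFlow provides `tf.train.NewCheckpointReader` for inspecting the saved variables, as follows: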

In [2]:
# The reader can load every variable saved in the checkpoint; note that the .data and .index suffixes can be omitted
reader = tf.train.NewCheckpointReader('Saved_model/model.ckpt')

# Get all saved variables as a dict from variable name to variable shape
global_variables = reader.get_variable_to_shape_map()
for variable_name in global_variables:
    # variable_name is the variable's name; global_variables[variable_name] is its shape
    print(variable_name, global_variables[variable_name])

# Get the value of the variable named v1
print("Value for variable v1 is ", reader.get_tensor("v1"))

v1 [1]
v2 [1]
Value for variable v1 is  [1.]


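**c. checkpoint** is generated and maintained automatically by the `tf.train.Saver` class. It keeps the file names of all TensorFlow model files persisted by that `tf.train.Saver`. When a saved TensorFlow model file is deleted, its file name is also removed from the checkpoint file. The content of checkpoint is a CheckpointState Protocol Buffer; the definition of the CheckpointState type is given below: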

`message CheckpointState {
  string model_checkpoint_path = 1;
  repeated string all_model_checkpoint_paths = 2;
}`

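- The model_checkpoint_path field holds the file name of the most recent TensorFlow model file;
- The all_model_checkpoint_paths field lists the file names of all TensorFlow model files that have not yet been deleted.

For this example: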

 `model_checkpoint_path: "model.ckpt"
 all_model_checkpoint_paths: "model.ckpt"
`
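
Because the checkpoint file is written in the text form of CheckpointState, its two fields can be read with a few lines of plain Python. The following is a minimal sketch under stated assumptions: `parse_checkpoint_state` is a hypothetical helper, not a TensorFlow function, and it assumes one `field: "value"` pair per line with no escaped characters inside the quotes.

```python
import re


def parse_checkpoint_state(text):
    """Parse the text form of a CheckpointState message into a dict.

    Hypothetical helper for illustration; it ignores escape sequences.
    """
    state = {"model_checkpoint_path": None, "all_model_checkpoint_paths": []}
    # Each relevant line has the form: field_name: "value"
    pattern = re.compile(r'^\s*(\w+):\s*"(.*)"\s*$')
    for line in text.splitlines():
        match = pattern.match(line)
        if match is None:
            continue  # skip blank or unrecognized lines
        field, value = match.groups()
        if field == "model_checkpoint_path":
            # Single-valued field: the most recent model file
            state["model_checkpoint_path"] = value
        elif field == "all_model_checkpoint_paths":
            # Repeated field: every model file that still exists
            state["all_model_checkpoint_paths"].append(value)
    return state


example = '''model_checkpoint_path: "model.ckpt"
all_model_checkpoint_paths: "model.ckpt"
'''
state = parse_checkpoint_state(example)
print(state["model_checkpoint_path"])
print(state["all_model_checkpoint_paths"])
```

In a real program there is usually no need to parse this file by hand: TensorFlow 1.x offers `tf.train.latest_checkpoint`, which reads it and returns the path of the most recent checkpoint in a directory.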

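**2. Loading a model**

The code below shows how to **load a saved model** (it is best to restart the kernel and then run from this cell; the same applies to the later examples):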

In [1]:
import tensorflow as tf

# Declare the variables in the same way as in the code that saved the model
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(1.0, shape=[1]), name='v2')
result = v1 + v2

saver = tf.train.Saver()

with tf.Session() as sess:
    # Load the saved model and compute the sum from the saved variable values
    saver.restore(sess, "Saved_model/model.ckpt")
    print(sess.run(result))

INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt
[2.]


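This code is almost the same as the code that saved the model: it first defines all the operations on the TensorFlow computation graph and then declares a `tf.train.Saver`. **The only difference is that the loading code does not run variable initialization; the variable values are loaded from the saved model instead.**

**To avoid defining the graph operations again, the persisted graph can also be loaded directly**, as follows: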

In [1]:
import tensorflow as tf

saver = tf.train.import_meta_graph("Saved_model/model.ckpt.meta")

with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    # Fetch the tensor by its name
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))

INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt
[2.]


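**3. Saving or loading a subset of variables**

By default, the programs above save or load every variable defined on the TensorFlow computation graph, but sometimes only some of them are needed, for example the first five layers of a trained neural network. **To save or load only some variables, pass a list of those variables when declaring the `tf.train.Saver`.**

For example, changing `saver = tf.train.Saver()` to `saver = tf.train.Saver([v1])` in the loading code loads only v1, so running the addition fails with `FailedPreconditionError: Attempting to use uninitialized value v2`, as shown below: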

In [1]:
import tensorflow as tf

# Declare the variables in the same way as in the code that saved the model
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(1.0, shape=[1]), name='v2')
result = v1 + v2

saver = tf.train.Saver([v1])

with tf.Session() as sess:
    # Load the saved model and compute the sum from the saved variable values
    saver.restore(sess, "Saved_model/model.ckpt")
    print(sess.run(result))

INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt


FailedPreconditionError: Attempting to use uninitialized value v2
	 [[{{node v2/read}} = Identity[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"](v2)]]
	 [[{{node add/_3}} = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_9_add", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Caused by op 'v2/read', defined at:
  File "d:\python3\Lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "d:\python3\Lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "d:\python3\tfgpu\dl+\lib\site-packages\ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "d:\python3\tfgpu\dl+\lib\site-packages\traitlets\config\application.py", line 658, in launch_instance
    app.start()
  File "d:\python3\tfgpu\dl+\lib\site-packages\ipykernel\kernelapp.py", line 486, in start
    self.io_loop.start()
  File "d:\python3\tfgpu\dl+\lib\site-packages\tornado\platform\asyncio.py", line 112, in start
    self.asyncio_loop.run_forever()
  File "d:\python3\Lib\asyncio\base_events.py", line 421, in run_forever
    self._run_once()
  File "d:\python3\Lib\asyncio\base_events.py", line 1431, in _run_once
    handle._run()
  File "d:\python3\Lib\asyncio\events.py", line 145, in _run
    self._callback(*self._args)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tornado\platform\asyncio.py", line 102, in _handle_events
    handler_func(fileobj, events)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tornado\stack_context.py", line 276, in null_wrapper
    return fn(*args, **kwargs)
  File "d:\python3\tfgpu\dl+\lib\site-packages\zmq\eventloop\zmqstream.py", line 450, in _handle_events
    self._handle_recv()
  File "d:\python3\tfgpu\dl+\lib\site-packages\zmq\eventloop\zmqstream.py", line 480, in _handle_recv
    self._run_callback(callback, msg)
  File "d:\python3\tfgpu\dl+\lib\site-packages\zmq\eventloop\zmqstream.py", line 432, in _run_callback
    callback(*args, **kwargs)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tornado\stack_context.py", line 276, in null_wrapper
    return fn(*args, **kwargs)
  File "d:\python3\tfgpu\dl+\lib\site-packages\ipykernel\kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "d:\python3\tfgpu\dl+\lib\site-packages\ipykernel\kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "d:\python3\tfgpu\dl+\lib\site-packages\ipykernel\kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "d:\python3\tfgpu\dl+\lib\site-packages\ipykernel\ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "d:\python3\tfgpu\dl+\lib\site-packages\ipykernel\zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "d:\python3\tfgpu\dl+\lib\site-packages\IPython\core\interactiveshell.py", line 2728, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "d:\python3\tfgpu\dl+\lib\site-packages\IPython\core\interactiveshell.py", line 2850, in run_ast_nodes
    if self.run_code(code, result):
  File "d:\python3\tfgpu\dl+\lib\site-packages\IPython\core\interactiveshell.py", line 2910, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-1-9ee245d85979>", line 5, in <module>
    v2 = tf.Variable(tf.constant(1.0, shape=[1]), name='v2')
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\ops\variables.py", line 145, in __call__
    return cls._variable_call(*args, **kwargs)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\ops\variables.py", line 141, in _variable_call
    aggregation=aggregation)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\ops\variables.py", line 120, in <lambda>
    previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 2441, in default_variable_creator
    expected_shape=expected_shape, import_scope=import_scope)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\ops\variables.py", line 147, in __call__
    return super(VariableMetaclass, cls).__call__(*args, **kwargs)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\ops\variables.py", line 1104, in __init__
    constraint=constraint)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\ops\variables.py", line 1266, in _init_from_args
    self._snapshot = array_ops.identity(self._variable, name="read")
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\ops\array_ops.py", line 81, in identity
    return gen_array_ops.identity(input, name=name)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\ops\gen_array_ops.py", line 3994, in identity
    "Identity", input=input, name=name)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\util\deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\framework\ops.py", line 3272, in create_op
    op_def=op_def)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\framework\ops.py", line 1768, in __init__
    self._traceback = tf_stack.extract_stack()

FailedPreconditionError (see above for traceback): Attempting to use uninitialized value v2
	 [[{{node v2/read}} = Identity[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"](v2)]]
	 [[{{node add/_3}} = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_9_add", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]


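**4. Renaming variables**

Besides selecting which variables to save or load, **`tf.train.Saver` also supports renaming variables when saving or loading**. The program below changes the names of variables v1 and v2. If the saved model is loaded with the default `tf.train.Saver` constructor, the program fails with a variable-not-found error, `NotFoundError: Key other-v1 not found in checkpoint`, as shown below: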

In [1]:
import tensorflow as tf

# Note that the names given at definition time are different!
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
v2 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v2")
result = v1 + v2

saver = tf.train.Saver()

with tf.Session() as sess:
    # Load the saved model and compute the sum from the saved variable values
    saver.restore(sess, "Saved_model/model.ckpt")
    print(sess.run(result))

INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt


NotFoundError: Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Key other-v1 not found in checkpoint
	 [[{{node save/RestoreV2}} = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

Caused by op 'save/RestoreV2', defined at:
  File "d:\python3\Lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "d:\python3\Lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "d:\python3\tfgpu\dl+\lib\site-packages\ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "d:\python3\tfgpu\dl+\lib\site-packages\traitlets\config\application.py", line 658, in launch_instance
    app.start()
  File "d:\python3\tfgpu\dl+\lib\site-packages\ipykernel\kernelapp.py", line 486, in start
    self.io_loop.start()
  File "d:\python3\tfgpu\dl+\lib\site-packages\tornado\platform\asyncio.py", line 112, in start
    self.asyncio_loop.run_forever()
  File "d:\python3\Lib\asyncio\base_events.py", line 421, in run_forever
    self._run_once()
  File "d:\python3\Lib\asyncio\base_events.py", line 1431, in _run_once
    handle._run()
  File "d:\python3\Lib\asyncio\events.py", line 145, in _run
    self._callback(*self._args)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tornado\platform\asyncio.py", line 102, in _handle_events
    handler_func(fileobj, events)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tornado\stack_context.py", line 276, in null_wrapper
    return fn(*args, **kwargs)
  File "d:\python3\tfgpu\dl+\lib\site-packages\zmq\eventloop\zmqstream.py", line 450, in _handle_events
    self._handle_recv()
  File "d:\python3\tfgpu\dl+\lib\site-packages\zmq\eventloop\zmqstream.py", line 480, in _handle_recv
    self._run_callback(callback, msg)
  File "d:\python3\tfgpu\dl+\lib\site-packages\zmq\eventloop\zmqstream.py", line 432, in _run_callback
    callback(*args, **kwargs)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tornado\stack_context.py", line 276, in null_wrapper
    return fn(*args, **kwargs)
  File "d:\python3\tfgpu\dl+\lib\site-packages\ipykernel\kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "d:\python3\tfgpu\dl+\lib\site-packages\ipykernel\kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "d:\python3\tfgpu\dl+\lib\site-packages\ipykernel\kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "d:\python3\tfgpu\dl+\lib\site-packages\ipykernel\ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "d:\python3\tfgpu\dl+\lib\site-packages\ipykernel\zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "d:\python3\tfgpu\dl+\lib\site-packages\IPython\core\interactiveshell.py", line 2728, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "d:\python3\tfgpu\dl+\lib\site-packages\IPython\core\interactiveshell.py", line 2850, in run_ast_nodes
    if self.run_code(code, result):
  File "d:\python3\tfgpu\dl+\lib\site-packages\IPython\core\interactiveshell.py", line 2910, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-1-732e38388eb2>", line 8, in <module>
    saver = tf.train.Saver()
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\training\saver.py", line 1094, in __init__
    self.build()
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\training\saver.py", line 1106, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\training\saver.py", line 1143, in _build
    build_save=build_save, build_restore=build_restore)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\training\saver.py", line 787, in _build_internal
    restore_sequentially, reshape)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\training\saver.py", line 406, in _AddRestoreOps
    restore_sequentially)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\training\saver.py", line 854, in bulk_restore
    return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\ops\gen_io_ops.py", line 1550, in restore_v2
    shape_and_slices=shape_and_slices, dtypes=dtypes, name=name)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\util\deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\framework\ops.py", line 3272, in create_op
    op_def=op_def)
  File "d:\python3\tfgpu\dl+\lib\site-packages\tensorflow\python\framework\ops.py", line 1768, in __init__
    self._traceback = tf_stack.extract_stack()

NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Key other-v1 not found in checkpoint
	 [[{{node save/RestoreV2}} = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]


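This happens because the variable names used when the model was saved differ from those used when loading. To solve this, **TensorFlow accepts a dict that links each variable name in the saved model to the variable it should be loaded into**, as follows: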

In [1]:
import tensorflow as tf

# Note that the names given at definition time are different!
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
v2 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v2")
result = v1 + v2

# Load the saved variable named v1 into the current variable v1 (now named other-v1), and the saved variable named v2 into the current variable v2 (now named other-v2)
saver = tf.train.Saver({"v1": v1, "v2": v2})

with tf.Session() as sess:
    # Load the saved model and compute the sum from the saved variable values
    saver.restore(sess, "Saved_model/model.ckpt")
    print(sess.run(result))

INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt
[2.]
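
The behavior seen in parts 3 and 4 follows from one rule: a saver looks up each variable in the checkpoint by a string key. The toy model below sketches that rule in plain Python. It is not the real `tf.train.Saver`; `ToySaver`, its arguments, and the dict used as a checkpoint are hypothetical stand-ins chosen for illustration.

```python
class ToySaver:
    """Toy stand-in for name-based restore; NOT the real tf.train.Saver."""

    def __init__(self, graph_variables, var_list=None):
        # graph_variables: names of all variables defined in the current graph
        self.graph_variables = list(graph_variables)
        if var_list is None:
            # Default: every variable, looked up under its own name
            self.mapping = {name: name for name in self.graph_variables}
        elif isinstance(var_list, dict):
            # Dict form: checkpoint key -> variable in the current graph
            self.mapping = dict(var_list)
        else:
            # List form: only these variables, each under its own name
            self.mapping = {name: name for name in var_list}

    def restore(self, checkpoint):
        """Return (restored values, names left uninitialized)."""
        restored = {}
        for key, variable in self.mapping.items():
            if key not in checkpoint:
                raise KeyError("Key %s not found in checkpoint" % key)
            restored[variable] = checkpoint[key]
        uninitialized = [v for v in self.graph_variables if v not in restored]
        return restored, uninitialized


# A dict standing in for the saved file Saved_model/model.ckpt
checkpoint = {"v1": [1.0], "v2": [1.0]}

# List form: only v1 is restored, so v2 is left uninitialized
print(ToySaver(["v1", "v2"], ["v1"]).restore(checkpoint))

# Default form with renamed variables: the lookup by name fails
try:
    ToySaver(["other-v1", "other-v2"]).restore(checkpoint)
except KeyError as err:
    print("error:", err)

# Dict form: checkpoint keys v1 and v2 are mapped onto the renamed variables
print(ToySaver(["other-v1", "other-v2"],
               {"v1": "other-v1", "v2": "other-v2"}).restore(checkpoint))
```

The list form corresponds to `tf.train.Saver([v1])`, and the dict form to `tf.train.Saver({"v1": v1, "v2": v2})`: the keys are always names as stored in the checkpoint, and the values are the variables that receive them.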


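**5. Saving and loading moving averages**

One of the main purposes of supporting variable renaming is to make it easy to use the moving averages of variables. Section 4.4.3 showed that using moving averages can make a model more robust. In TensorFlow, the moving average of each variable is maintained by a shadow variable, so getting a variable's moving average means reading the value of that shadow variable. **If the shadow variable is mapped directly onto the variable itself when the model is loaded, then using the trained model no longer requires calling a function to obtain the moving average.** An example follows (again, restart the kernel here and run the cells from this point):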

In [1]:
# Save the moving-average model
import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")
# Before the moving-average model is declared there is only one variable, v, so the loop below prints only "v:0"
for variables in tf.global_variables():
    print(variables.name)
    
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
# After the moving-average model is declared, TensorFlow automatically creates a shadow variable, so the loop below prints "v:0" and "v/ExponentialMovingAverage:0"
for variables in tf.global_variables(): 
    print(variables.name)
    
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    sess.run(tf.assign(v, 10))
    sess.run(maintain_averages_op)
    # Saving stores both variables, v:0 and v/ExponentialMovingAverage:0
    saver.save(sess, "Saved_model/model2.ckpt")
    print(sess.run([v, ema.average(v)]))     # Output: [10.0, 0.099999905]

v:0
v:0
v/ExponentialMovingAverage:0
[10.0, 0.099999905]
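
The second number in the output can be checked by hand. A minimal sketch of one moving-average step, assuming the standard update rule `shadow = decay * shadow + (1 - decay) * value`, a fixed decay of 0.99 (no `num_updates` is passed in the cell above), and a shadow variable that starts at the initial value of `v`, which is 0. The function name `ema_update` is hypothetical.

```python
def ema_update(shadow, value, decay):
    """One moving-average step: decay * shadow + (1 - decay) * value."""
    return decay * shadow + (1.0 - decay) * value


shadow = 0.0                              # shadow variable starts at v's initial value
shadow = ema_update(shadow, 10.0, 0.99)   # v was assigned 10 before the update ran
print(shadow)
```

The result is approximately 0.1. The notebook prints 0.099999905 instead, presumably because TensorFlow carries out the computation in float32.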


In [2]:
# Load the moving-average model; no need to restart the kernel for this cell
v = tf.Variable(0, dtype=tf.float32, name="v")

# Use variable renaming to load the saved moving average of v directly into v.
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})

with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model2.ckpt")
    print(sess.run(v))

INFO:tensorflow:Restoring parameters from Saved_model/model2.ckpt
0.099999905


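To make it easier to rename moving-average variables when loading, the `tf.train.ExponentialMovingAverage` class provides the `variables_to_restore` function, which generates the renaming dict that `tf.train.Saver` needs, as follows: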

In [1]:
import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")

ema = tf.train.ExponentialMovingAverage(0.99)

# variables_to_restore directly generates the dict used in the code above: {"v/ExponentialMovingAverage": v}
print(ema.variables_to_restore())     # Output: {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}

saver = tf.train.Saver(ema.variables_to_restore())

with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model2.ckpt")
    print(sess.run(v))

{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
INFO:tensorflow:Restoring parameters from Saved_model/model2.ckpt
0.099999905
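
The naming convention behind that dict can be sketched without TensorFlow. The helper below is hypothetical (`ema_restore_map` is not a TensorFlow function) and works on plain name strings rather than variable objects. It assumes the default shadow-variable suffix `ExponentialMovingAverage` seen in the output above, and that variables without a moving average are restored under their own names.

```python
def ema_restore_map(variable_names, averaged_names,
                    ema_name="ExponentialMovingAverage"):
    """Map checkpoint keys to variable names, preferring the shadow key.

    variable_names: names of all variables to restore
    averaged_names: names of the variables that have a moving average
    """
    mapping = {}
    for name in variable_names:
        if name in averaged_names:
            # Read the shadow variable's value into the variable itself
            mapping[name + "/" + ema_name] = name
        else:
            # No moving average: restore under the variable's own name
            mapping[name] = name
    return mapping


print(ema_restore_map(["v"], {"v"}))
```

For the single variable `v` this gives the same key as the notebook output, `v/ExponentialMovingAverage`.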


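**6. Saving as constants (pb format)**

**`tf.train.Saver` saves all the information needed to run a TensorFlow program, but sometimes part of that information is not needed.** For testing or offline prediction, for example, it is enough to know how to compute the output layer from the input layer by forward propagation; auxiliary nodes such as variable initialization and model saving are not required. A similar situation comes up in Chapter 6 in the discussion of transfer learning. In addition, **storing variable values and the graph structure in separate files is sometimes inconvenient.**

TensorFlow therefore provides the `convert_variables_to_constants` function, which saves the variables in the computation graph, together with their values, as constants, so that the whole TensorFlow computation graph can be stored in a single file. **Note that saving selects the node with ['add'], the name of the operation, whereas loading uses ['add:0'], the name of the tensor.** An example follows: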

In [1]:
# Save
import tensorflow as tf
from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    
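    # Export the GraphDef part of the current graph; this part alone is enough to compute from the input layer to the output layer
    graph_def = tf.get_default_graph().as_graph_def()

    # Convert the variables in the graph, together with their values, into constants, and drop
    # nodes the graph does not need. As Section 5.4.2 shows, some system operations (such as
    # variable initialization) also become nodes in the graph. If only certain computations
    # defined in the program matter, unrelated nodes need not be exported or saved. In the line
    # below, the last argument ['add'] names the nodes to keep; add is the addition of the two
    # variables defined above. It is the name of an operation, so it has no trailing :0.
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])

    with tf.gfile.GFile("Saved_model/combined_model.pb", "wb") as f:
        f.write(output_graph_def.SerializeToString())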

INFO:tensorflow:Froze 2 variables.
INFO:tensorflow:Converted 2 variables to const ops.


In [2]:
# Load; no need to restart the kernel for this cell
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    model_filename = "Saved_model/combined_model.pb"
    
    # Read the saved model file and parse it into the corresponding GraphDef Protocol Buffer
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

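    # Load the graph stored in graph_def into the current graph. return_elements=["add:0"] gives the
    # name of the tensor to return. Saving used the name of the operation, hence "add"; loading
    # uses the name of the tensor, hence "add:0".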
    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print(sess.run(result))

[array([3.], dtype=float32)]
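
The two spellings used above are related by a simple convention: a tensor name is the name of the operation that produces it, followed by a colon and the index of that operation's output. The sketch below splits a tensor name accordingly; `split_tensor_name` is a hypothetical helper written for illustration, not part of TensorFlow.

```python
def split_tensor_name(tensor_name):
    """Split '<op_name>:<output_index>' into (op_name, output_index)."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not index.isdigit():
        # A bare operation name such as 'add' is not a tensor name
        raise ValueError(
            "expected '<op_name>:<output_index>', got %r" % tensor_name)
    return op_name, int(index)


# 'add:0' is output 0 of the operation named 'add'
print(split_tensor_name("add:0"))
```

This is why `convert_variables_to_constants` takes `'add'` (it selects operations to keep), while `return_elements` in `tf.import_graph_def` takes `'add:0'` (it selects a tensor to return).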
