def freeze(ckpt_path=None, save_path=None):
from tensorflow.python.tools import freeze_graph # , optimize_for_inference_lib
with tf.get_default_graph().as_default():
input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images')
global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
seg_maps_pred = model.model(input_images, is_training=False)
tf.identity(seg_maps_pred, name='seg_maps')
variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
saver = tf.train.Saver(variable_averages.variables_to_restore())
with tf.Session() as sess:
saver.restore(sess, ckpt_path)
print('Freeze Model Will Saved at ', save_path)
fdir, name = os.path.split(save_path)
tf.train.write_graph(sess.graph_def, fdir, name, as_text=True)
freeze_graph.freeze_graph(
input_graph=save_path,
input_saver='',
input_binary=False,
input_checkpoint=ckpt_path,
output_node_names='seg_maps',
restore_op_name='',
filename_tensor_name='',
output_graph=save_path,
clear_devices=True,
initializer_nodes='',
)
forked from cv-or-not-cv/psenet_cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
yiran-THU/psenet_cpp
Folders and files
Name | Name | Last commit message | Last commit date | |
---|---|---|---|---|
Repository files navigation
About
No description, website, or topics provided.
Resources
Stars
Watchers
Forks
Releases
No releases published
Packages 0
No packages published
Languages
- C++ 85.3%
- C 14.7%