From 2949cfd885794b563f1408a838b6385bd6faadbc Mon Sep 17 00:00:00 2001 From: MyungSung Kwak Date: Tue, 14 Aug 2018 13:35:28 +0900 Subject: [PATCH] Modify flag name for the checkpoint path change flag name to checkpoint_dir according to the variable name used by the checkpoint_utils within tensorflow python framework. The important point is that when run the run_eval script, an error occurs due to the different flag name. Signed-off-by: MyungSung Kwak --- .../learning_unsupervised_learning/run_eval.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/research/learning_unsupervised_learning/run_eval.py b/research/learning_unsupervised_learning/run_eval.py index 7e64e8d6acb..dcb2529dd4c 100644 --- a/research/learning_unsupervised_learning/run_eval.py +++ b/research/learning_unsupervised_learning/run_eval.py @@ -35,13 +35,13 @@ from tensorflow.contrib.framework.python.framework import checkpoint_utils -flags.DEFINE_string("checkpoint", None, "Dir to load pretrained update rule from") +flags.DEFINE_string("checkpoint_dir", None, "Dir to load pretrained update rule from") flags.DEFINE_string("train_log_dir", None, "Training log directory") FLAGS = flags.FLAGS -def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000): +def train(train_log_dir, checkpoint_dir, eval_every_n_steps=10, num_steps=3000): dataset_fn = datasets.mnist.TinyMnist w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess @@ -77,8 +77,8 @@ def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000): summary_op = tf.summary.merge_all() file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"]) - if checkpoint: - str_var_list = checkpoint_utils.list_variables(checkpoint) + if checkpoint_dir: + str_var_list = checkpoint_utils.list_variables(checkpoint_dir) name_to_v_map = {v.op.name: v for v in tf.all_variables()} var_list = [ name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map @@ -99,9 +99,9 @@ def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000): # global step should be restored from the evals job checkpoint or zero for fresh. step = sess.run(global_step) - if step == 0 and checkpoint: + if step == 0 and checkpoint_dir: tf.logging.info("force restore") - saver.restore(sess, checkpoint) + saver.restore(sess, checkpoint_dir) tf.logging.info("force restore done") sess.run(reset_global_step) step = sess.run(global_step) @@ -115,7 +115,7 @@ def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000): def main(argv): - train(FLAGS.train_log_dir, FLAGS.checkpoint) + train(FLAGS.train_log_dir, FLAGS.checkpoint_dir) if __name__ == "__main__":