Skip to content

Commit

Permalink
Modify flag name for the checkpoint path
Browse files Browse the repository at this point in the history
Change the flag name to checkpoint_dir to match the variable name
used by checkpoint_utils within the TensorFlow Python framework.

The important point is that when running the run_eval script, an error
occurs due to the mismatched flag name.

Signed-off-by: MyungSung Kwak <yesmung@gmail.com>
  • Loading branch information
yesmung committed Aug 14, 2018
1 parent 5be3727 commit 2949cfd
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions research/learning_unsupervised_learning/run_eval.py
Expand Up @@ -35,13 +35,13 @@

from tensorflow.contrib.framework.python.framework import checkpoint_utils

# Command-line flags. The flag is named "checkpoint_dir" (not "checkpoint")
# so it matches the variable name expected by checkpoint_utils in the
# TensorFlow Python framework; a mismatched name breaks run_eval.
flags.DEFINE_string("checkpoint_dir", None, "Dir to load pretrained update rule from")
flags.DEFINE_string("train_log_dir", None, "Training log directory")

FLAGS = flags.FLAGS


def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):
def train(train_log_dir, checkpoint_dir, eval_every_n_steps=10, num_steps=3000):
dataset_fn = datasets.mnist.TinyMnist
w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner
theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess
Expand Down Expand Up @@ -77,8 +77,8 @@ def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):
summary_op = tf.summary.merge_all()

file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"])
if checkpoint:
str_var_list = checkpoint_utils.list_variables(checkpoint)
if checkpoint_dir:
str_var_list = checkpoint_utils.list_variables(checkpoint_dir)
name_to_v_map = {v.op.name: v for v in tf.all_variables()}
var_list = [
name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
Expand All @@ -99,9 +99,9 @@ def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):
# global step should be restored from the evals job checkpoint or zero for fresh.
step = sess.run(global_step)

if step == 0 and checkpoint:
if step == 0 and checkpoint_dir:
tf.logging.info("force restore")
saver.restore(sess, checkpoint)
saver.restore(sess, checkpoint_dir)
tf.logging.info("force restore done")
sess.run(reset_global_step)
step = sess.run(global_step)
Expand All @@ -115,7 +115,7 @@ def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):


def main(argv):
  """Entry point: run evaluation training using the command-line flags.

  Args:
    argv: Unused positional command-line arguments (absl app convention).
  """
  # Pass FLAGS.checkpoint_dir — the flag defined above; the old
  # FLAGS.checkpoint name no longer exists.
  train(FLAGS.train_log_dir, FLAGS.checkpoint_dir)


if __name__ == "__main__":
Expand Down

0 comments on commit 2949cfd

Please sign in to comment.