From eef2aaa3f7931cc3c0e7bbc6757386cfbcf5e3ef Mon Sep 17 00:00:00 2001
From: MarkDaoust
Date: Thu, 19 Nov 2015 13:26:35 -0500
Subject: [PATCH] Fixed saver relative paths for `latest_checkpoint`

This would be cleaner if we made all paths listed in the "latest" file
relative to its directory, allowing the removal of the added
`os.path.isabs` checks. That would make the `os.path.join` in
`saver.latest_checkpoint` much less surprising. But at least this way,
there is no effect on currently working code.

Fixes #571

Change-Id: I47d8536b9b2ed3dcc193d6e6b7f4573a4e22c9b3
---
 tensorflow/python/training/saver.py      | 17 ++++
 tensorflow/python/training/saver_test.py | 71 ++++++++++++++++++++++++
 2 files changed, 88 insertions(+)
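
Note for reviewers (below the "---" cut, so not part of the commit
message): a minimal, standalone sketch of the path handling this patch
relies on. It uses only the standard library with made-up, POSIX-style
directory names; it is an illustration under those assumptions, not code
taken from the patch.

import os.path

save_dir = "/tmp/train"   # hypothetical checkpoint directory

# A relative entry in the "checkpoint" file resolves against save_dir.
print(os.path.join(save_dir, "snapshot-2"))           # /tmp/train/snapshot-2

# If the second argument is absolute, os.path.join returns it unchanged,
# so absolute entries recorded in the "checkpoint" file keep working.
print(os.path.join(save_dir, "/data/snapshot-2"))     # /data/snapshot-2

# update_checkpoint_state() now stores relative entries relative to
# save_dir, so the join in latest_checkpoint() reconstructs the path.
print(os.path.relpath("train/snapshot-2", "train"))   # snapshot-2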

diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 34db5c3cd37897..6d86925b8fb685 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -489,6 +489,16 @@ def update_checkpoint_state(save_dir,
     all_model_checkpoint_paths.append(model_checkpoint_path)
   # Writes the "checkpoint" file for the coordinator for later restoration.
   coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
+
+  # Relative paths need to be rewritten to be relative to the "save_dir".
+  if not os.path.isabs(model_checkpoint_path):
+    model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
+
+  all_model_checkpoint_paths = [
+      p if os.path.isabs(p) else os.path.relpath(p, save_dir)
+      for p in all_model_checkpoint_paths
+  ]
+
   if coord_checkpoint_filename == model_checkpoint_path:
     raise RuntimeError("Save path '%s' conflicts with path used for "
                        "checkpoint state. Please use a different save path." %
@@ -854,6 +864,10 @@ def save(self, sess, save_path, global_step=None, latest_filename=None):
     """
     if latest_filename is None:
       latest_filename = "checkpoint"
+
+    if os.path.split(latest_filename)[0]:
+      raise ValueError("'latest_filename' must not contain path components")
+
     if global_step is not None:
       if not isinstance(global_step, compat.integral_types):
         global_step = training_util.global_step(sess, global_step)
@@ -905,8 +919,11 @@ def latest_checkpoint(checkpoint_dir, latest_filename=None):
   # Pick the latest checkpoint based on checkpoint state.
   ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
   if ckpt and ckpt.model_checkpoint_path:
+
+    # If the second argument is absolute, os.path.join returns it unchanged.
     checkpoint_pattern = os.path.join(
         checkpoint_dir, ckpt.model_checkpoint_path)
+
     if gfile.Glob(checkpoint_pattern):
       return checkpoint_pattern
 
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index af4c6a6e60b626..59df6c9ac0ae75 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -20,6 +20,9 @@
 import os.path
 import time
 
+import contextlib
+import shutil
+import tempfile
 
 import tensorflow.python.platform
 
@@ -583,5 +586,73 @@ def testNonReshape(self):
         self.assertEqual(20.0, v1.eval())
 
 
+class LatestCheckpointWithRelativePaths(tf.test.TestCase):
+
+  @staticmethod
+  @contextlib.contextmanager
+  def tempWorkingDir(temppath):
+    cwd = os.getcwd()
+    os.chdir(temppath)
+    try:
+      yield
+    finally:
+      os.chdir(cwd)
+
+  @staticmethod
+  @contextlib.contextmanager
+  def tempDir():
+    tempdir = tempfile.mkdtemp()
+    try:
+      yield tempdir
+    finally:
+      shutil.rmtree(tempdir)
+
+  def testRelativePath(self):
+    # Make sure we have a clean directory to work in.
+    with self.tempDir() as tempdir:
+
+      # Jump to that directory until this test is done.
+      with self.tempWorkingDir(tempdir):
+
+        # Save training snapshots to a relative path.
+        traindir = 'train/'
+        os.mkdir(traindir)
+
+        filename = 'snapshot'
+        filepath = os.path.join(traindir, filename)
+
+        with self.test_session() as sess:
+          # Build a simple graph.
+          v0 = tf.Variable(0.0)
+          inc = v0.assign_add(1.0)
+
+          save = tf.train.Saver({'v0': v0})
+
+          # Record a short training history.
+          tf.initialize_all_variables().run()
+          save.save(sess, filepath, global_step=0)
+          inc.eval()
+          save.save(sess, filepath, global_step=1)
+          inc.eval()
+          save.save(sess, filepath, global_step=2)
+
+        with self.test_session() as sess:
+          # Build a new graph with different initialization.
+          v0 = tf.Variable(-1.0)
+
+          # Create a new saver.
+          save = tf.train.Saver({'v0': v0})
+
+          tf.initialize_all_variables().run()
+
+          # Get the most recent checkpoint name from the training history file.
+          name = tf.train.latest_checkpoint(traindir)
+          self.assertIsNotNone(name)
+
+          # Restore "v0" from that checkpoint.
+          save.restore(sess, name)
+          self.assertEqual(2.0, v0.eval())
+
+
 if __name__ == "__main__":
   tf.test.main()