Skip to content

Commit

Permalink
[tune] Avoid overwriting checkpoint file (#3781)
Browse files Browse the repository at this point in the history
  • Loading branch information
richardliaw committed Jan 16, 2019
1 parent a237b4a commit c28e6d4
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 19 deletions.
14 changes: 4 additions & 10 deletions python/ray/tune/test/cluster_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,9 +390,7 @@ def test_cluster_rllib_restore(start_connected_cluster, tmpdir):
# the checkpoint.
metadata_checkpoint_dir = os.path.join(dirpath, "experiment")
for i in range(100):
if os.path.exists(
os.path.join(metadata_checkpoint_dir,
TrialRunner.CKPT_FILE_NAME)):
if TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
# Inspect the internal trialrunner
runner = TrialRunner.restore(metadata_checkpoint_dir)
trials = runner.get_trials()
Expand All @@ -401,8 +399,7 @@ def test_cluster_rllib_restore(start_connected_cluster, tmpdir):
break
time.sleep(0.3)

if not os.path.exists(
os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME)):
if not TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
raise RuntimeError("Checkpoint file didn't appear.")

ray.shutdown()
Expand Down Expand Up @@ -485,9 +482,7 @@ def _restore(self, state):
# the checkpoint.
metadata_checkpoint_dir = os.path.join(dirpath, "experiment")
for i in range(50):
if os.path.exists(
os.path.join(metadata_checkpoint_dir,
TrialRunner.CKPT_FILE_NAME)):
if TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
# Inspect the internal trialrunner
runner = TrialRunner.restore(metadata_checkpoint_dir)
trials = runner.get_trials()
Expand All @@ -496,8 +491,7 @@ def _restore(self, state):
break
time.sleep(0.2)

if not os.path.exists(
os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME)):
if not TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
raise RuntimeError("Checkpoint file didn't appear.")

ray.shutdown()
Expand Down
27 changes: 27 additions & 0 deletions python/ray/tune/test/trial_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1796,6 +1796,33 @@ def testCheckpointWithFunction(self):
self.assertTrue("on_episode_start" in new_trial.config["callbacks"])
shutil.rmtree(tmpdir)

def testCheckpointOverwrite(self):
def count_checkpoints(cdir):
return sum((fname.startswith("experiment_state")
and fname.endswith(".json"))
for fname in os.listdir(cdir))

ray.init()
trial = Trial("__fake", checkpoint_freq=1)
tmpdir = tempfile.mkdtemp()
runner = TrialRunner(
BasicVariantGenerator(), metadata_checkpoint_dir=tmpdir)
runner.add_trial(trial)
for i in range(5):
runner.step()
# force checkpoint
runner.checkpoint()
self.assertEquals(count_checkpoints(tmpdir), 1)

runner2 = TrialRunner.restore(tmpdir)
for i in range(5):
runner2.step()
self.assertEquals(count_checkpoints(tmpdir), 2)

runner2.checkpoint()
self.assertEquals(count_checkpoints(tmpdir), 2)
shutil.rmtree(tmpdir)


class SearchAlgorithmTest(unittest.TestCase):
def testNestedSuggestion(self):
Expand Down
39 changes: 32 additions & 7 deletions python/ray/tune/trial_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import print_function

import collections
from datetime import datetime
import json
import logging
import os
Expand All @@ -28,6 +29,15 @@ def _naturalize(string):
return [int(text) if text.isdigit() else text.lower() for text in splits]


def _find_newest_ckpt(ckpt_dir):
"""Returns path to most recently modified checkpoint."""
full_paths = [
os.path.join(ckpt_dir, fname) for fname in os.listdir(ckpt_dir)
if fname.startswith("experiment_state") and fname.endswith(".json")
]
return max(full_paths)


class TrialRunner(object):
"""A TrialRunner implements the event loop for scheduling trials on Ray.
Expand All @@ -50,7 +60,7 @@ class TrialRunner(object):
misleading benchmark results.
"""

CKPT_FILE_NAME = "experiment_state.json"
CKPT_FILE_TMPL = "experiment_state-{}.json"

def __init__(self,
search_alg,
Expand Down Expand Up @@ -102,8 +112,22 @@ def __init__(self,
self._stop_queue = []
self._metadata_checkpoint_dir = metadata_checkpoint_dir

self._session = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")

@classmethod
def checkpoint_exists(cls, directory):
if not os.path.exists(directory):
return False
return any(
(fname.startswith("experiment_state") and fname.endswith(".json"))
for fname in os.listdir(directory))

def checkpoint(self):
"""Saves execution state to `self._metadata_checkpoint_dir`."""
"""Saves execution state to `self._metadata_checkpoint_dir`.
Overwrites the current session checkpoint, which starts when self
is instantiated.
"""
if not self._metadata_checkpoint_dir:
return
metadata_checkpoint_dir = self._metadata_checkpoint_dir
Expand All @@ -121,7 +145,8 @@ def checkpoint(self):

os.rename(
tmp_file_name,
os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME))
os.path.join(metadata_checkpoint_dir,
TrialRunner.CKPT_FILE_TMPL.format(self._session)))
return metadata_checkpoint_dir

@classmethod
Expand All @@ -146,9 +171,9 @@ def restore(cls,
Returns:
runner (TrialRunner): A TrialRunner to resume experiments from.
"""
with open(
os.path.join(metadata_checkpoint_dir,
TrialRunner.CKPT_FILE_NAME), "r") as f:

newest_ckpt_path = _find_newest_ckpt(metadata_checkpoint_dir)
with open(newest_ckpt_path, "r") as f:
runner_state = json.load(f)

logger.warning("".join([
Expand Down Expand Up @@ -520,7 +545,7 @@ def __getstate__(self):
state = self.__dict__.copy()
for k in [
"_trials", "_stop_queue", "_server", "_search_alg",
"_scheduler_alg", "trial_executor"
"_scheduler_alg", "trial_executor", "_session"
]:
del state[k]
state["launch_web_server"] = bool(self._server)
Expand Down
3 changes: 1 addition & 2 deletions python/ray/tune/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,7 @@ def run_experiments(experiments,
runner = None
restore = False

if os.path.exists(
os.path.join(checkpoint_dir, TrialRunner.CKPT_FILE_NAME)):
if TrialRunner.checkpoint_exists(checkpoint_dir):
if resume == "prompt":
msg = ("Found incomplete experiment at {}. "
"Would you like to resume it?".format(checkpoint_dir))
Expand Down

0 comments on commit c28e6d4

Please sign in to comment.