diff --git a/python/ray/tune/durable_trainable.py b/python/ray/tune/durable_trainable.py
index d48053bab1512..150d98c6b9323 100644
--- a/python/ray/tune/durable_trainable.py
+++ b/python/ray/tune/durable_trainable.py
@@ -12,6 +12,11 @@ class DurableTrainable(Trainable):
     """A fault-tolerant Trainable.

     Supports checkpointing to and restoring from remote storage.
+
+    The storage client must provide durability for restoration to work. That
+    is, once ``storage_client.wait()`` returns after a checkpoint ``sync_up``,
+    the checkpoint is considered committed and can be used to restore the
+    trainable.
     """

     def __init__(self, *args, **kwargs):
@@ -34,6 +39,7 @@ def restore(self, checkpoint_path):
         if not os.path.exists(local_dirpath):
             os.makedirs(local_dirpath)
         self.storage_client.sync_down(storage_dirpath, local_dirpath)
+        self.storage_client.wait()
         super(DurableTrainable, self).restore(checkpoint_path)

     def delete_checkpoint(self, checkpoint_path):
@@ -43,9 +49,13 @@ def delete_checkpoint(self, checkpoint_path):
         self.storage_client.delete(self._storage_path(local_dirpath))

     def _storage_path(self, local_path):
-        rel_local_path = os.path.relpath(local_path, self.logdir)
+        logdir_parent = os.path.dirname(self.logdir)
+        rel_local_path = os.path.relpath(local_path, logdir_parent)
         return os.path.join(self._root_storage_path(), rel_local_path)

     def _root_storage_path(self):
-        """Path to directory in which checkpoints and logs are stored."""
+        """Path to directory in which checkpoints are stored.
+
+        You can also use `self.storage_client` to store logs here.
+        """
         raise NotImplementedError("Storage path must be provided by subclass.")
diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py
index c9efee0e8d971..368a26a510a8e 100644
--- a/python/ray/tune/ray_trial_executor.py
+++ b/python/ray/tune/ray_trial_executor.py
@@ -148,9 +148,11 @@ def _start_trial(self, trial, checkpoint=None, runner=None):
         """
         prior_status = trial.status
         self.set_status(trial, Trial.RUNNING)
-        trial.set_runner(runner or self._setup_remote_runner(
-            trial,
-            reuse_allowed=checkpoint is not None or trial.has_checkpoint()))
+        trial.set_runner(
+            runner or self._setup_remote_runner(
+                trial,
+                reuse_allowed=checkpoint is not None
+                or trial.has_checkpoint()))
         self.restore(trial, checkpoint)

         previous_run = self._find_item(self._paused, trial)
@@ -208,7 +210,7 @@ def start_trial(self, trial, checkpoint=None):
             checkpoint (Checkpoint): A Python object or path storing the state
                 of trial.
         """
-        attempts = trial.num_failures_between_results
+        attempts = trial.num_failures_since_result
         if attempts >= TRIAL_START_ATTEMPTS:
             return  # Exceeded restoration attempts.
         self._commit_resources(trial.resources)
diff --git a/python/ray/tune/sync_client.py b/python/ray/tune/sync_client.py
index 6c10053c1bc16..a62284d5abd4b 100644
--- a/python/ray/tune/sync_client.py
+++ b/python/ray/tune/sync_client.py
@@ -29,7 +29,16 @@ def noop(*args):


 def get_sync_client(sync_function, delete_function=None):
-    """Gets sync client."""
+    """Returns a sync client.
+
+    Args:
+        sync_function (str|function): Sync function.
+        delete_function (Optional[str|function]): Delete function. Must be
+            the same type as sync_function if it is provided.
+
+    Raises:
+        ValueError: If sync_function or delete_function is malformed.
+    """
     if delete_function and type(sync_function) != type(delete_function):
         raise ValueError("Sync and delete functions must be of same type.")
     if isinstance(sync_function, types.FunctionType):
@@ -45,6 +54,14 @@ def get_sync_client(sync_function, delete_function=None):


 def get_cloud_sync_client(remote_path):
+    """Returns a CommandBasedClient that can sync to/from remote storage.
+
+    Args:
+        remote_path (str): Path to remote storage (S3 or GS).
+
+    Raises:
+        ValueError: If remote_path is malformed.
+    """
     if remote_path.startswith(S3_PREFIX):
         if not distutils.spawn.find_executable("aws"):
             raise ValueError(
@@ -67,20 +84,24 @@ def get_cloud_sync_client(remote_path):

 class SyncClient(object):
     def sync_up(self, source, target):
-        """Sync up from source to target.
+        """Syncs up from source to target.
+
         Args:
             source (str): Source path.
             target (str): Target path.
+
         Returns:
             True if sync initiation successful, False otherwise.
         """
         raise NotImplementedError

     def sync_down(self, source, target):
-        """Sync down from source to target.
+        """Syncs down from source to target.
+
         Args:
             source (str): Source path.
             target (str): Target path.
+
         Returns:
             True if sync initiation successful, False otherwise.
         """
@@ -98,7 +119,7 @@ def delete(self, target):
         raise NotImplementedError

     def wait(self):
-        """Wait for current sync to complete, if asynchronously started."""
+        """Waits for current sync to complete, if asynchronously started."""
         pass

     def reset(self):
diff --git a/python/ray/tune/syncer.py b/python/ray/tune/syncer.py
index f442a8288c1e4..0c5f333a9bd76 100644
--- a/python/ray/tune/syncer.py
+++ b/python/ray/tune/syncer.py
@@ -221,6 +221,9 @@ def get_cloud_syncer(local_dir, remote_dir=None, sync_function=None):
             remote_dir. If string, then it must be a string template for
             syncer to run. If not provided, it defaults to standard
             S3 or gsutil sync commands.
+
+    Raises:
+        ValueError: If remote_dir is malformed.
     """

     key = (local_dir, remote_dir)
diff --git a/python/ray/tune/tests/test_checkpoint_manager.py b/python/ray/tune/tests/test_checkpoint_manager.py
index 4567db6c15a36..7085cecdce40c 100644
--- a/python/ray/tune/tests/test_checkpoint_manager.py
+++ b/python/ray/tune/tests/test_checkpoint_manager.py
@@ -21,8 +21,8 @@ def mock_result(i):
         return {"i": i}

     def checkpoint_manager(self, keep_checkpoints_num):
-        return CheckpointManager(keep_checkpoints_num, "i",
-                                 delete_fn=lambda c: None)
+        return CheckpointManager(
+            keep_checkpoints_num, "i", delete_fn=lambda c: None)

     def testOnCheckpointOrdered(self):
         """
diff --git a/python/ray/tune/tests/test_ray_trial_executor.py b/python/ray/tune/tests/test_ray_trial_executor.py
index db0e945df69d4..abdb49a9a893c 100644
--- a/python/ray/tune/tests/test_ray_trial_executor.py
+++ b/python/ray/tune/tests/test_ray_trial_executor.py
@@ -4,11 +4,9 @@ from __future__ import print_function

 import json
-import sys
 import unittest

 import ray
-from ray.exceptions import RayTimeoutError
 from ray.rllib import _register_all
 from ray.tune import Trainable
 from ray.tune.ray_trial_executor import RayTrialExecutor
@@ -18,11 +16,6 @@ from ray.tune.resources import Resources
 from ray.cluster_utils import Cluster

-if sys.version_info >= (3, 3):
-    from unittest.mock import patch
-else:
-    from mock import patch
-

 class RayTrialExecutorTest(unittest.TestCase):
     def setUp(self):
diff --git a/python/ray/tune/tests/test_sync.py b/python/ray/tune/tests/test_sync.py
index 64621dc3e1a3e..85c36470b48d3 100644
--- a/python/ray/tune/tests/test_sync.py
+++ b/python/ray/tune/tests/test_sync.py
@@ -13,6 +13,7 @@

 from ray.rllib import _register_all
 from ray import tune
+from ray.tune import TuneError
 from ray.tune.syncer import CommandBasedClient

 if sys.version_info >= (3, 3):
@@ -29,7 +30,7 @@ def tearDown(self):
         ray.shutdown()
         _register_all()  # re-register the evicted objects

-    @patch("ray.tune.syncer.S3_PREFIX", "test")
+    @patch("ray.tune.sync_client.S3_PREFIX", "test")
     def testNoUploadDir(self):
         """No Upload Dir is given."""
         with self.assertRaises(AssertionError):
@@ -44,7 +45,7 @@ def testNoUploadDir(self):
                     "sync_to_cloud": "echo {source} {target}"
                 }).trials

-    @patch("ray.tune.syncer.S3_PREFIX", "test")
+    @patch("ray.tune.sync_client.S3_PREFIX", "test")
     def testCloudProperString(self):
         with self.assertRaises(ValueError):
             [trial] = tune.run(
@@ -93,7 +94,8 @@ def testCloudProperString(self):

     def testClusterProperString(self):
         """Tests that invalid commands throw.."""
-        with self.assertRaises(ValueError):
+        with self.assertRaises(TuneError):
+            # Raises TuneError: the logger is initialized in the safe zone.
             [trial] = tune.run(
                 "__fake",
                 name="foo",
@@ -105,7 +107,8 @@ def testCloudProperString(self):
                     "sync_to_driver": "ls {target}"
                 }).trials

-        with self.assertRaises(ValueError):
+        with self.assertRaises(TuneError):
+            # Raises TuneError: the logger is initialized in the safe zone.
             [trial] = tune.run(
                 "__fake",
                 name="foo",
@@ -117,7 +120,7 @@ def testClusterProperString(self):
                     "sync_to_driver": "ls {source}"
                 }).trials

-        with patch.object(CommandBasedClient, "execute") as mock_fn:
+        with patch.object(CommandBasedClient, "_execute") as mock_fn:
             with patch("ray.services.get_node_ip_address") as mock_sync:
                 mock_sync.return_value = "0.0.0.0"
                 [trial] = tune.run(
@@ -195,7 +198,7 @@ def testNoSync(self):
         def sync_func(source, target):
             pass

-        with patch.object(CommandBasedClient, "execute") as mock_sync:
+        with patch.object(CommandBasedClient, "_execute") as mock_sync:
             [trial] = tune.run(
                 "__fake",
                 name="foo",
diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py
index 80572a97689a9..4e3b4937cff40 100644
--- a/python/ray/tune/trainable.py
+++ b/python/ray/tune/trainable.py
@@ -373,9 +373,14 @@ def restore_from_object(self, obj):
         shutil.rmtree(tmpdir)

     def delete_checkpoint(self, checkpoint_path):
+        """Deletes local copy of checkpoint.
+
+        Args:
+            checkpoint_path (str): Path to checkpoint.
+        """
         checkpoint_dir = os.path.dirname(checkpoint_path)
         if os.path.exists(checkpoint_dir):
-            os.rmdir(checkpoint_dir)
+            shutil.rmtree(checkpoint_dir)

     def export_model(self, export_formats, export_dir=None):
         """Exports model based on export_formats.
diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py
index ffad1f778126e..135371c6f6bcc 100644
--- a/python/ray/tune/trial.py
+++ b/python/ray/tune/trial.py
@@ -199,7 +199,7 @@ def __init__(self,
         # Restoration fields
         self.restoring_from = None
         self.num_failures = 0
-        self.num_failures_between_results = 0  # Name this something better...
+        self.num_failures_since_result = 0

         # AutoML fields
         self.results = None
@@ -293,7 +293,7 @@ def close_logger(self):
     def write_error_log(self, error_msg):
         if error_msg and self.logdir:
             self.num_failures += 1
-            self.num_failures_between_results += 1
+            self.num_failures_since_result += 1
             self.error_file = os.path.join(self.logdir, "error.txt")
             with open(self.error_file, "a+") as f:
                 f.write("Failure # {} (occurred at {})\n".format(
@@ -350,8 +350,6 @@ def on_checkpoint(self, checkpoint):
             # after this to handle checkpoints taken mid-sync.
             self.result_logger.wait()
             # Force sync down and wait before tracking the new checkpoint.
-            # This prevents attempts to restore from partially synced
-            # checkpoints.
             if self.result_logger.sync_down():
                 self.result_logger.wait()
             else:
@@ -361,7 +359,7 @@ def on_checkpoint(self, checkpoint):
         self.checkpoint_manager.on_checkpoint(checkpoint)

     def on_begin_restore(self, checkpoint):
-        """Handles newly dispatched restore.
+        """Handles dispatched async restore.

         This can be called multiple times without subsequently calling
         `on_restore` since a restoration attempt can fail.
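# --- Reviewer note (illustrative sketch, not part of the patch) --------------
# Why delete_checkpoint() in trainable.py above switches from os.rmdir to
# shutil.rmtree: os.rmdir only removes *empty* directories, and a checkpoint
# directory normally contains the checkpoint file plus metadata, so the old
# call raised OSError instead of deleting anything. Paths below are
# hypothetical placeholders for a local demonstration.
import os
import shutil
import tempfile

checkpoint_dir = tempfile.mkdtemp()
open(os.path.join(checkpoint_dir, "checkpoint"), "w").close()

try:
    os.rmdir(checkpoint_dir)  # raises OSError: directory is not empty
except OSError:
    shutil.rmtree(checkpoint_dir)  # removes the directory and its contents
# ------------------------------------------------------------------------------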
@@ -398,7 +396,7 @@ def update_last_result(self, result, terminate=False):
             print("  {}".format(pretty_print(result).replace("\n", "\n  ")))
             self.last_debug = time.time()
         self.set_location(Location(result.get("node_ip"), result.get("pid")))
-        self.num_failures_between_results = 0
+        self.num_failures_since_result = 0
         self.last_result = result
         self.last_update_time = time.time()
         self.result_logger.on_result(self.last_result)
diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py
index 921b325e32670..95b420dccb2e9 100644
--- a/python/ray/tune/trial_executor.py
+++ b/python/ray/tune/trial_executor.py
@@ -144,9 +144,9 @@ def reset_trial(self, trial, new_config, new_experiment_tag):
         raise NotImplementedError

     def get_running_trials(self):
-        """Returns all active trials."""
+        """Returns all running trials."""
         raise NotImplementedError("Subclasses of TrialExecutor must provide "
-                                  "get_active_trials() method")
+                                  "get_running_trials() method")

     def on_step_begin(self, trial_runner):
         """A hook called before running one step of the trial event loop."""
diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py
index 54d2d0df7c2f1..cbd522ee92dd8 100644
--- a/python/ray/tune/trial_runner.py
+++ b/python/ray/tune/trial_runner.py
@@ -416,20 +416,6 @@ def _process_events(self):
             with warn_if_slow("process_trial"):
                 self._process_trial(trial)

-    # def _process_trial_checkpoint(self, trial):
-    #     """Processes a trial's persisted checkpoint."""
-    #     value = None
-    #     try:
-    #         value = self.trial_executor.fetch_result(trial)
-    #     except Exception:
-    #         error_msg = "Trial {}: Error processing checkpoint".format(trial)
-    #         error_msg += "{}.".format(value) if value else ""
-    #         logger.exception(error_msg)
-    #     else:
-    #         checkpoint = Checkpoint(Checkpoint.PERSISTENT, value, trial.last_result)
-    #         trial.on_checkpoint(checkpoint)
-    #         # notify trial_executor according to cached decision.
-
     def _process_trial(self, trial):
         """Processes a trial result."""
         try:
@@ -501,7 +487,6 @@ def _process_trial_restore(self, trial):
         except Exception:
             logger.exception("Trial %s: Error processing restore.", trial)
             self._process_trial_failure(trial, traceback.format_exc())
-            return

     def _process_trial_failure(self, trial, error_msg):
         """Handle trial failure.
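# --- Usage sketch (illustrative, not part of the patch) ----------------------
# A minimal sketch of how a user-defined DurableTrainable might look after
# this change. Method names follow the Trainable API of this Ray version
# (_train/_save/_restore); the bucket path is a hypothetical placeholder, and
# the base class is assumed to build self.storage_client from the path
# returned by _root_storage_path().
import os

from ray.tune.durable_trainable import DurableTrainable


class MyDurableTrainable(DurableTrainable):
    def _root_storage_path(self):
        # Remote prefix that checkpoints are synced to. With this PR,
        # restore() calls storage_client.wait() after sync_down(), so the
        # checkpoint is fully downloaded before Trainable.restore() reads it.
        return "s3://my-bucket/tune-checkpoints"

    def _train(self):
        return {"mean_accuracy": 0.5}

    def _save(self, checkpoint_dir):
        path = os.path.join(checkpoint_dir, "checkpoint")
        with open(path, "w") as f:
            f.write("state")
        return path

    def _restore(self, checkpoint_path):
        with open(checkpoint_path) as f:
            self.state = f.read()
# ------------------------------------------------------------------------------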