
Cleanup, comments.
ujvl committed Dec 16, 2019
1 parent 95a3d13 commit 028876b
Showing 11 changed files with 69 additions and 49 deletions.
14 changes: 12 additions & 2 deletions python/ray/tune/durable_trainable.py
@@ -12,6 +12,11 @@ class DurableTrainable(Trainable):
"""A fault-tolerant Trainable.
Supports checkpointing to and restoring from remote storage.
The storage client must provide durability for restoration to work. That
is, once ``storage_client.wait()`` returns after a checkpoint `sync up`,
the checkpoint is considered committed and can be used to restore the
trainable.
"""

def __init__(self, *args, **kwargs):
@@ -34,6 +39,7 @@ def restore(self, checkpoint_path):
if not os.path.exists(local_dirpath):
os.makedirs(local_dirpath)
self.storage_client.sync_down(storage_dirpath, local_dirpath)
self.storage_client.wait()
super(DurableTrainable, self).restore(checkpoint_path)

def delete_checkpoint(self, checkpoint_path):
@@ -43,9 +49,13 @@ def delete_checkpoint(self, checkpoint_path):
self.storage_client.delete(self._storage_path(local_dirpath))

def _storage_path(self, local_path):
rel_local_path = os.path.relpath(local_path, self.logdir)
logdir_parent = os.path.dirname(self.logdir)
rel_local_path = os.path.relpath(local_path, logdir_parent)
return os.path.join(self._root_storage_path(), rel_local_path)

def _root_storage_path(self):
"""Path to directory in which checkpoints and logs are stored."""
"""Path to directory in which checkpoints are stored.
You can also use `self.storage_client` to store logs here.
"""
raise NotImplementedError("Storage path must be provided by subclass.")
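As a sketch of how a subclass might satisfy this interface: the class name and bucket below are hypothetical, and the usual Trainable hooks (train/save/restore) are assumed to be implemented as in any other Trainable.

    from ray.tune.durable_trainable import DurableTrainable

    class DurableExampleTrainable(DurableTrainable):
        """Hypothetical subclass; only the storage root override is shown."""

        def _root_storage_path(self):
            # With the relpath computed against the parent of self.logdir, a
            # local checkpoint under <logdir>/checkpoint_1/ maps to
            # <root>/<trial_dir_name>/checkpoint_1/ on remote storage.
            return "s3://my-bucket/tune-checkpoints"  # hypothetical bucket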
10 changes: 6 additions & 4 deletions python/ray/tune/ray_trial_executor.py
@@ -148,9 +148,11 @@ def _start_trial(self, trial, checkpoint=None, runner=None):
"""
prior_status = trial.status
self.set_status(trial, Trial.RUNNING)
trial.set_runner(runner or self._setup_remote_runner(
trial,
reuse_allowed=checkpoint is not None or trial.has_checkpoint()))
trial.set_runner(
runner or self._setup_remote_runner(
trial,
reuse_allowed=checkpoint is not None
or trial.has_checkpoint()))
self.restore(trial, checkpoint)

previous_run = self._find_item(self._paused, trial)
@@ -208,7 +210,7 @@ def start_trial(self, trial, checkpoint=None):
checkpoint (Checkpoint): A Python object or path storing the state
of trial.
"""
attempts = trial.num_failures_between_results
attempts = trial.num_failures_since_result
if attempts >= TRIAL_START_ATTEMPTS:
return # Exceeded restoration attempts.
self._commit_resources(trial.resources)
29 changes: 25 additions & 4 deletions python/ray/tune/sync_client.py
@@ -29,7 +29,16 @@ def noop(*args):


def get_sync_client(sync_function, delete_function=None):
"""Gets sync client."""
"""Returns a sync client.
Args:
sync_function (str|function): Sync function.
delete_function (Optional[str|function]): Delete function. Must be
the same type as sync_function if it is provided.
Raises:
ValueError if sync_function or delete_function are malformed.
"""
if delete_function and type(sync_function) != type(delete_function):
raise ValueError("Sync and delete functions must be of same type.")
if isinstance(sync_function, types.FunctionType):
@@ -45,6 +54,14 @@ def get_sync_client(sync_function, delete_function=None):


def get_cloud_sync_client(remote_path):
"""Returns a CommandBasedClient that can sync to/from remote storage.
Args:
remote_path (str): Path to remote storage (S3 or GS).
Raises:
ValueError if malformed remote_dir.
"""
if remote_path.startswith(S3_PREFIX):
if not distutils.spawn.find_executable("aws"):
raise ValueError(
@@ -67,20 +84,24 @@ def get_cloud_sync_client(remote_path):

class SyncClient(object):
def sync_up(self, source, target):
"""Sync up from source to target.
"""Syncs up from source to target.
Args:
source (str): Source path.
target (str): Target path.
Returns:
True if sync initiation successful, False otherwise.
"""
raise NotImplementedError

def sync_down(self, source, target):
"""Sync down from source to target.
"""Syncs down from source to target.
Args:
source (str): Source path.
target (str): Target path.
Returns:
True if sync initiation successful, False otherwise.
"""
@@ -98,7 +119,7 @@ def delete(self, target):
raise NotImplementedError

def wait(self):
"""Wait for current sync to complete, if asynchronously started."""
"""Waits for current sync to complete, if asynchronously started."""
pass

def reset(self):
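A hedged usage sketch of the helpers documented above; the bucket, local paths, and print-based sync functions are hypothetical, and the S3 branch assumes the aws CLI is installed (otherwise get_cloud_sync_client raises ValueError).

    from ray.tune.sync_client import get_sync_client, get_cloud_sync_client

    # Function-based client: sync and delete must be of the same type, so two
    # plain functions are paired here (mixing a function with a command string
    # raises ValueError).
    def sync_func(source, target):
        print("would copy", source, "->", target)

    def delete_func(target):
        print("would delete", target)

    client = get_sync_client(sync_func, delete_func)
    client.sync_up("/tmp/trial_logs", "/backup/trial_logs")
    client.wait()  # the base-class default is a no-op for synchronous clients

    # Command-based client for cloud storage (S3 shown; GS is analogous).
    s3_client = get_cloud_sync_client("s3://my-bucket/experiments")
    if s3_client.sync_down("s3://my-bucket/experiments/trial_0", "/tmp/trial_0"):
        s3_client.wait()  # block until the download has completed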
3 changes: 3 additions & 0 deletions python/ray/tune/syncer.py
@@ -221,6 +221,9 @@ def get_cloud_syncer(local_dir, remote_dir=None, sync_function=None):
remote_dir. If a string, it must be a string template for the
syncer to run. If not provided, it defaults to standard S3 or
gsutil sync commands.
Raises:
ValueError if malformed remote_dir.
"""
key = (local_dir, remote_dir)

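A small, hedged illustration of the validation documented above; the local path and remote prefix are hypothetical.

    from ray.tune.syncer import get_cloud_syncer

    try:
        # Anything other than an S3 or GS prefix counts as a malformed
        # remote_dir and raises ValueError.
        get_cloud_syncer("/tmp/ray_results/exp", "ftp://not-supported/exp")
    except ValueError as exc:
        print("rejected remote_dir:", exc)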
4 changes: 2 additions & 2 deletions python/ray/tune/tests/test_checkpoint_manager.py
@@ -21,8 +21,8 @@ def mock_result(i):
return {"i": i}

def checkpoint_manager(self, keep_checkpoints_num):
return CheckpointManager(keep_checkpoints_num, "i",
delete_fn=lambda c: None)
return CheckpointManager(
keep_checkpoints_num, "i", delete_fn=lambda c: None)

def testOnCheckpointOrdered(self):
"""
7 changes: 0 additions & 7 deletions python/ray/tune/tests/test_ray_trial_executor.py
@@ -4,11 +4,9 @@
from __future__ import print_function

import json
import sys
import unittest

import ray
from ray.exceptions import RayTimeoutError
from ray.rllib import _register_all
from ray.tune import Trainable
from ray.tune.ray_trial_executor import RayTrialExecutor
@@ -18,11 +16,6 @@
from ray.tune.resources import Resources
from ray.cluster_utils import Cluster

if sys.version_info >= (3, 3):
from unittest.mock import patch
else:
from mock import patch


class RayTrialExecutorTest(unittest.TestCase):
def setUp(self):
15 changes: 9 additions & 6 deletions python/ray/tune/tests/test_sync.py
@@ -13,6 +13,7 @@
from ray.rllib import _register_all

from ray import tune
from ray.tune import TuneError
from ray.tune.syncer import CommandBasedClient

if sys.version_info >= (3, 3):
@@ -29,7 +30,7 @@ def tearDown(self):
ray.shutdown()
_register_all() # re-register the evicted objects

@patch("ray.tune.syncer.S3_PREFIX", "test")
@patch("ray.tune.sync_client.S3_PREFIX", "test")
def testNoUploadDir(self):
"""No Upload Dir is given."""
with self.assertRaises(AssertionError):
@@ -44,7 +45,7 @@ def testNoUploadDir(self):
"sync_to_cloud": "echo {source} {target}"
}).trials

@patch("ray.tune.syncer.S3_PREFIX", "test")
@patch("ray.tune.sync_client.S3_PREFIX", "test")
def testCloudProperString(self):
with self.assertRaises(ValueError):
[trial] = tune.run(
@@ -93,7 +94,8 @@ def testCloudProperString(self):

def testClusterProperString(self):
"""Tests that invalid commands throw.."""
with self.assertRaises(ValueError):
with self.assertRaises(TuneError):
# This raises TuneError because the logger is initialized in the safe zone.
[trial] = tune.run(
"__fake",
name="foo",
@@ -105,7 +107,8 @@ def testClusterProperString(self):
"sync_to_driver": "ls {target}"
}).trials

with self.assertRaises(ValueError):
with self.assertRaises(TuneError):
# This raises TuneError because the logger is initialized in the safe zone.
[trial] = tune.run(
"__fake",
name="foo",
@@ -117,7 +120,7 @@ def testClusterProperString(self):
"sync_to_driver": "ls {source}"
}).trials

with patch.object(CommandBasedClient, "execute") as mock_fn:
with patch.object(CommandBasedClient, "_execute") as mock_fn:
with patch("ray.services.get_node_ip_address") as mock_sync:
mock_sync.return_value = "0.0.0.0"
[trial] = tune.run(
@@ -195,7 +198,7 @@ def testNoSync(self):
def sync_func(source, target):
pass

with patch.object(CommandBasedClient, "execute") as mock_sync:
with patch.object(CommandBasedClient, "_execute") as mock_sync:
[trial] = tune.run(
"__fake",
name="foo",
7 changes: 6 additions & 1 deletion python/ray/tune/trainable.py
@@ -373,9 +373,14 @@ def restore_from_object(self, obj):
shutil.rmtree(tmpdir)

def delete_checkpoint(self, checkpoint_path):
"""Deletes local copy of checkpoint.
Args:
checkpoint_path (str): Path to checkpoint.
"""
checkpoint_dir = os.path.dirname(checkpoint_path)
if os.path.exists(checkpoint_dir):
os.rmdir(checkpoint_dir)
shutil.rmtree(checkpoint_dir)

def export_model(self, export_formats, export_dir=None):
"""Exports model based on export_formats.
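As a stand-alone illustration of the directory removal in delete_checkpoint above (throwaway temp paths only, unrelated to Tune's real checkpoint layout): checkpoint directories contain files, so shutil.rmtree is needed where os.rmdir would fail.

    import os
    import shutil
    import tempfile

    checkpoint_dir = tempfile.mkdtemp()
    open(os.path.join(checkpoint_dir, "checkpoint-1"), "w").close()

    # os.rmdir(checkpoint_dir)    # would raise OSError: directory not empty
    shutil.rmtree(checkpoint_dir)  # removes the directory and everything in it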
10 changes: 4 additions & 6 deletions python/ray/tune/trial.py
@@ -199,7 +199,7 @@ def __init__(self,
# Restoration fields
self.restoring_from = None
self.num_failures = 0
self.num_failures_between_results = 0 # Name this something better...
self.num_failures_since_result = 0

# AutoML fields
self.results = None
@@ -293,7 +293,7 @@ def close_logger(self):
def write_error_log(self, error_msg):
if error_msg and self.logdir:
self.num_failures += 1
self.num_failures_between_results += 1
self.num_failures_since_result += 1
self.error_file = os.path.join(self.logdir, "error.txt")
with open(self.error_file, "a+") as f:
f.write("Failure # {} (occurred at {})\n".format(
@@ -350,8 +350,6 @@ def on_checkpoint(self, checkpoint):
# after this to handle checkpoints taken mid-sync.
self.result_logger.wait()
# Force sync down and wait before tracking the new checkpoint.
# This prevents attempts to restore from partially synced
# checkpoints.
if self.result_logger.sync_down():
self.result_logger.wait()
else:
@@ -361,7 +359,7 @@ def on_begin_restore(self, checkpoint):
self.checkpoint_manager.on_checkpoint(checkpoint)

def on_begin_restore(self, checkpoint):
"""Handles newly dispatched restore.
"""Handles dispatched async restore.
This can be called multiple times without subsequently calling
`on_restore` since a restoration attempt can fail.
@@ -398,7 +396,7 @@ def update_last_result(self, result, terminate=False):
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
self.last_debug = time.time()
self.set_location(Location(result.get("node_ip"), result.get("pid")))
self.num_failures_between_results = 0
self.num_failures_since_result = 0
self.last_result = result
self.last_update_time = time.time()
self.result_logger.on_result(self.last_result)
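A self-contained toy restating the ordering enforced in on_checkpoint above: wait for any in-flight sync, force a fresh sync down, wait again, and only then track the checkpoint. The logger class here is a stand-in, not Tune's real result logger.

    class FakeResultLogger:
        def wait(self):
            print("wait: block until the current sync finishes")

        def sync_down(self):
            print("sync_down: start pulling the checkpoint data")
            return True  # True means the sync was initiated successfully

    def track_checkpoint(logger, checkpoint):
        logger.wait()           # drain a sync that may already be running
        if logger.sync_down():  # force-sync the new checkpoint's files
            logger.wait()       # make sure nothing is partially synced
        print("tracking", checkpoint)

    track_checkpoint(FakeResultLogger(), "checkpoint_10")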
4 changes: 2 additions & 2 deletions python/ray/tune/trial_executor.py
@@ -144,9 +144,9 @@ def reset_trial(self, trial, new_config, new_experiment_tag):
raise NotImplementedError

def get_running_trials(self):
"""Returns all active trials."""
"""Returns all running trials."""
raise NotImplementedError("Subclasses of TrialExecutor must provide "
"get_active_trials() method")
"get_running_trials() method")

def on_step_begin(self, trial_runner):
"""A hook called before running one step of the trial event loop."""
15 changes: 0 additions & 15 deletions python/ray/tune/trial_runner.py
@@ -416,20 +416,6 @@ def _process_events(self):
with warn_if_slow("process_trial"):
self._process_trial(trial)

# def _process_trial_checkpoint(self, trial):
# """Processes a trial's persisted checkpoint."""
# value = None
# try:
# value = self.trial_executor.fetch_result(trial)
# except Exception:
# error_msg = "Trial {}: Error processing checkpoint".format(trial)
# error_msg += "{}.".format(value) if value else ""
# logger.exception(error_msg)
# else:
# checkpoint = Checkpoint(Checkpoint.PERSISTENT, value, trial.last_result)
# trial.on_checkpoint(checkpoint)
# # notify trial_executor according to cached decision.

def _process_trial(self, trial):
"""Processes a trial result."""
try:
@@ -501,7 +487,6 @@ def _process_trial_restore(self, trial):
except Exception:
logger.exception("Trial %s: Error processing restore.", trial)
self._process_trial_failure(trial, traceback.format_exc())
return

def _process_trial_failure(self, trial, error_msg):
"""Handle trial failure.
