[tune] Async restores and S3/GCP-capable trial FT #6376

Merged: 54 commits merged into master from ujvl:tune-async-save-restore on Jan 3, 2020

Commits
e9b641e
Initial commit for asynchronous save/restore
ujvl Dec 6, 2019
102bc30
Set stage for cloud checkpointable trainable.
ujvl Dec 15, 2019
8a31dea
Refactor log_sync and sync_client.
ujvl Dec 15, 2019
d0a6dbb
Add durable trainable impl.
ujvl Dec 15, 2019
4641f0e
Support delete in cmd based client
ujvl Dec 15, 2019
a3d4bfc
Fix some tests and such
ujvl Dec 15, 2019
95a3d13
Merge branch 'master' into tune-async-save-restore
ujvl Dec 15, 2019
028876b
Cleanup, comments.
ujvl Dec 16, 2019
50cd2c1
Use upload_dir instead.
ujvl Dec 16, 2019
859baa1
Revert files belonging to other PR in split.
ujvl Dec 16, 2019
b84be48
Pass upload_dir into trainable init.
ujvl Dec 16, 2019
abeb20f
Merge branch 'master' into tune-async-save-restore
ujvl Dec 17, 2019
9c29355
Pickle checkpoint at driver, more robust checkpoint_dir discovery.
ujvl Dec 18, 2019
7a6ebb9
Merge with master
ujvl Dec 18, 2019
a633575
Cleanup trainable helper functions, fix tests.
ujvl Dec 18, 2019
5467f5a
Addressed comments.
ujvl Dec 20, 2019
65b4058
Fix bugs from cluster testing, add parameterized cluster tests.
ujvl Dec 21, 2019
c8aaab7
Add trainable util test
ujvl Dec 21, 2019
9df308b
Merge branch 'master' into tune-async-save-restore
richardliaw Dec 25, 2019
6ebd06b
package_ref
richardliaw Dec 25, 2019
b3ccdec
pbt_address
richardliaw Dec 25, 2019
ed40f43
Fix bug after running pbt example (_save returning dir).
ujvl Dec 25, 2019
ac5fda4
Merge branch 'tune-async-save-restore' of github.com:ujvl/ray into tu…
ujvl Dec 25, 2019
422e109
Merge master
ujvl Dec 27, 2019
ecd85db
get cluster tests running, other bug fixes.
ujvl Dec 28, 2019
e6a4e14
Merge branch 'master' into tune-async-save-restore
richardliaw Dec 28, 2019
b42facb
raise_errors
richardliaw Dec 29, 2019
8457dbc
Fix deleter bug, add durable trainable example.
ujvl Dec 29, 2019
e2ae66e
Merge branch 'tune-async-save-restore' of github.com:ujvl/ray into tu…
ujvl Dec 29, 2019
71ec225
Fix cluster test bugs.
ujvl Dec 29, 2019
76e97ca
filelock
richardliaw Dec 30, 2019
c17ea1d
save/restore bug fixes
ujvl Dec 31, 2019
ddb50a8
.
ujvl Dec 31, 2019
17ce1c6
Working cluster tests.
ujvl Dec 31, 2019
82703e6
Lint, revert to tracking memory checkpoints.
ujvl Dec 31, 2019
90b0d45
Documentation, cleanup
ujvl Dec 31, 2019
bdbe96b
Merge branch 'master' into tune-async-save-restore
ujvl Dec 31, 2019
d81cdf6
fixinitialsync
richardliaw Jan 1, 2020
db18bf2
fix_one_test
richardliaw Jan 1, 2020
964af81
Merge master
ujvl Jan 1, 2020
642fe61
Fix cluster test bug
ujvl Jan 2, 2020
6532432
nit
richardliaw Jan 2, 2020
dd3321f
Merge branch 'tune-async-save-restore' of github.com:ujvl/ray into tu…
richardliaw Jan 2, 2020
e324725
lint
richardliaw Jan 2, 2020
c37ae41
Revert tune md change
ujvl Jan 2, 2020
edc1cea
Fix basename bug for directories.
ujvl Jan 2, 2020
03b2fe2
lint
ujvl Jan 2, 2020
507c885
fix_tests
richardliaw Jan 2, 2020
a3ebf67
nit_fix
richardliaw Jan 2, 2020
e74e960
Add __init__ file.
ujvl Jan 2, 2020
2153d1e
Merge branch 'tune-async-save-restore' of github.com:ujvl/ray into tu…
ujvl Jan 2, 2020
113494f
Move to utils package
ujvl Jan 3, 2020
152a0ce
Merge with master
ujvl Jan 3, 2020
ee78765
Fix merge conflicts
ujvl Jan 3, 2020

Files changed

4 changes: 3 additions & 1 deletion doc/source/tune-package-ref.rst
@@ -7,12 +7,14 @@ ray.tune
.. automodule:: ray.tune
:members:
:show-inheritance:
:exclude-members: TuneError, Trainable
:exclude-members: TuneError, Trainable, DurableTrainable

.. autoclass:: ray.tune.Trainable
:members:
:private-members:

.. autoclass:: ray.tune.DurableTrainable

.. autoclass:: ray.tune.function_runner.StatusReporter
:members: __call__, logdir

9 changes: 8 additions & 1 deletion python/ray/tune/BUILD
@@ -148,7 +148,14 @@ py_test(
deps = [":tune_lib"],
tags = ["exclusive"],
)


py_test(
name = "test_trainable_util",
size = "small",
srcs = ["tests/test_trainable_util.py"],
deps = [":tune_lib"],
)

py_test(
name = "test_trial_scheduler",
size = "medium",
2 changes: 2 additions & 0 deletions python/ray/tune/__init__.py
@@ -8,12 +8,14 @@
from ray.tune.analysis import ExperimentAnalysis, Analysis
from ray.tune.registry import register_env, register_trainable
from ray.tune.trainable import Trainable
from ray.tune.durable_trainable import DurableTrainable
from ray.tune.suggest import grid_search
from ray.tune.sample import (function, sample_from, uniform, choice, randint,
randn, loguniform)

__all__ = [
"Trainable",
"DurableTrainable",
"TuneError",
"grid_search",
"register_env",
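With the export above, the new class resolves directly from the top-level package. A trivial sketch (not part of this diff) to illustrate:

from ray.tune import DurableTrainable, Trainable

# DurableTrainable subclasses Trainable (see durable_trainable.py below),
# so it plugs into the same tune.run() machinery.
assert issubclass(DurableTrainable, Trainable)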
48 changes: 21 additions & 27 deletions python/ray/tune/checkpoint_manager.py
@@ -5,8 +5,6 @@

import heapq
import logging
import os
import shutil

try:
FileNotFoundError
@@ -23,31 +21,18 @@ class Checkpoint:

Attributes:
storage (str): Storage type.
value (str): If storage==MEMORY, value is a Python object.
If storage==DISK, value is a path points to the checkpoint in disk.
value (str): If storage==MEMORY, it is a Python object.
If storage==PERSISTENT, it is a path to persistent storage.
"""

MEMORY = "memory"
DISK = "disk"
PERSISTENT = "persistent"

def __init__(self, storage, value, result=None):
self.storage = storage
self.value = value
self.result = result or {}

def delete(self):
"""Deletes checkpoint data if disk checkpoint."""
if self.storage == Checkpoint.DISK and self.value:
checkpoint_dir = self.value
if not os.path.exists(checkpoint_dir):
raise FileNotFoundError(
"Attempted to delete checkpoint at {} but "
"path was not found.".format(checkpoint_dir))
elif os.path.isfile(checkpoint_dir):
shutil.rmtree(os.path.dirname(checkpoint_dir))
else:
shutil.rmtree(checkpoint_dir)

@staticmethod
def from_object(value=None):
"""Creates a checkpoint from a Python object."""
@@ -72,13 +57,15 @@ def __lt__(self, other):
class CheckpointManager:
"""Manages checkpoints on the driver for a trial."""

def __init__(self, keep_checkpoints_num, checkpoint_score_attr):
def __init__(self, keep_checkpoints_num, checkpoint_score_attr, delete_fn):
"""Initializes a new CheckpointManager.

Args:
keep_checkpoints_num (int): Keep at least this many checkpoints.
checkpoint_score_attr (str): Attribute to use to determine which
checkpoints to keep.
delete_fn (function): Function that deletes checkpoints. Must be
idempotent.
"""
self.keep_checkpoints_num = keep_checkpoints_num or float("inf")
assert self.keep_checkpoints_num > 0, (
@@ -88,7 +75,7 @@ def __init__(self, keep_checkpoints_num, checkpoint_score_attr):
self._checkpoint_score_attr = checkpoint_score_attr[4:]
else:
self._checkpoint_score_attr = checkpoint_score_attr

self.delete = delete_fn
self.newest_checkpoint = Checkpoint(Checkpoint.MEMORY, None)
self._best_checkpoints = []
self._membership = set()
@@ -101,9 +88,6 @@ def on_checkpoint(self, checkpoint):

Args:
checkpoint (Checkpoint): Trial state checkpoint.

Raises:
KeyError if checkpoint_score_attr not in result of checkpoint.
"""
old_checkpoint = self.newest_checkpoint
self.newest_checkpoint = checkpoint
@@ -112,7 +96,7 @@ def on_checkpoint(self, checkpoint):
queue_item = QueueItem(self._priority(checkpoint), checkpoint)
except KeyError:
if old_checkpoint not in self._membership:
old_checkpoint.delete()
self.delete(old_checkpoint)
logger.error("Result dict has no key: {}. "
"checkpoint_score_attr must be set to a key in the "
"result dict.".format(self._checkpoint_score_attr))
@@ -126,11 +110,11 @@ def on_checkpoint(self, checkpoint):
self._membership.add(checkpoint)
if worst in self._membership:
self._membership.remove(worst)
worst.delete()
self.delete(worst)

# Remove the old checkpoint if it isn't one of the best ones.
if old_checkpoint not in self._membership:
old_checkpoint.delete()
if old_checkpoint.value and old_checkpoint not in self._membership:
self.delete(old_checkpoint)

def best_checkpoints(self):
"""Returns best checkpoints, sorted by score."""
@@ -140,3 +124,13 @@ def best_checkpoints(self):
def _priority(self, checkpoint):
priority = checkpoint.result[self._checkpoint_score_attr]
return -priority if self._checkpoint_score_desc else priority

def __getstate__(self):
state = self.__dict__.copy()
# Avoid serializing lambda since it may capture cyclical dependencies.
state.pop("delete")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.delete = None
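Since Checkpoint no longer deletes itself, the deletion policy now lives with whoever constructs the CheckpointManager. A minimal sketch (not part of this diff) of an idempotent delete_fn, assuming only the constructor signature shown above; the local-disk cleanup itself is illustrative:

import shutil

from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager

def delete_checkpoint(checkpoint):
    # Idempotent: ignore_errors=True turns a repeated delete (or a delete of
    # an already-removed directory) into a no-op instead of an exception.
    if checkpoint.storage == Checkpoint.PERSISTENT and checkpoint.value:
        shutil.rmtree(checkpoint.value, ignore_errors=True)

manager = CheckpointManager(
    keep_checkpoints_num=2,
    checkpoint_score_attr="mean_accuracy",
    delete_fn=delete_checkpoint)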
1 change: 1 addition & 0 deletions python/ray/tune/config_parser.py
@@ -183,6 +183,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
local_dir=os.path.join(spec["local_dir"], output_path),
# json.load leads to str -> unicode in py2.7
stopping_criterion=spec.get("stop", {}),
remote_checkpoint_dir=spec.get("remote_checkpoint_dir"),
checkpoint_freq=args.checkpoint_freq,
checkpoint_at_end=args.checkpoint_at_end,
sync_on_checkpoint=not args.no_sync_on_checkpoint,
98 changes: 98 additions & 0 deletions python/ray/tune/durable_trainable.py
@@ -0,0 +1,98 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from ray.tune.trainable import Trainable, TrainableUtil
from ray.tune.syncer import get_cloud_sync_client


class DurableTrainable(Trainable):
"""Abstract class for a remote-storage backed fault-tolerant Trainable.

Supports checkpointing to and restoring from remote storage. To use this
class, implement the same private methods as ray.tune.Trainable (`_save`,
`_train`, `_restore`, `reset_config`, `_setup`, `_stop`).

.. warning:: This class is currently **experimental** and may
be subject to change.

    Run this with Tune as follows. Setting `sync_to_driver=False` disables
    syncing to the driver, which avoids keeping redundant checkpoints around
    and prevents the driver from syncing up the same checkpoint.

See ``tune/trainable.py``.

Attributes:
remote_checkpoint_dir (str): Upload directory (S3 or GS path).
storage_client: Tune-internal interface for interacting with external
storage.

>>> tune.run(MyDurableTrainable, sync_to_driver=False)
"""

def __init__(self, remote_checkpoint_dir, *args, **kwargs):
"""Initializes a DurableTrainable.

Args:
remote_checkpoint_dir (str): Upload directory (S3 or GS path).
"""
super(DurableTrainable, self).__init__(*args, **kwargs)
self.remote_checkpoint_dir = remote_checkpoint_dir
self.storage_client = self._create_storage_client()

def save(self, checkpoint_dir=None):
"""Saves the current model state to a checkpoint, persisted remotely.

The storage client must provide durability for
        restoration to work. That is, once ``storage_client.wait()``
returns after a checkpoint `sync up`, the checkpoint is considered
committed and can be used to restore the trainable.

Args:
checkpoint_dir (Optional[str]): Optional dir to place the
checkpoint. Must be ``logdir`` or a sub-directory.

Returns:
Checkpoint path or prefix that may be passed to restore().
"""
if checkpoint_dir:
            if not checkpoint_dir.startswith(os.path.abspath(self.logdir)):
raise ValueError("`checkpoint_dir` must be `self.logdir`, or "
"a sub-directory.")

checkpoint_path = super(DurableTrainable, self).save(checkpoint_dir)
self.storage_client.sync_up(self.logdir, self.remote_checkpoint_dir)
self.storage_client.wait()
return checkpoint_path

def restore(self, checkpoint_path):
"""Restores training state from a given checkpoint persisted remotely.

These checkpoints are returned from calls to save().

Args:
checkpoint_path (str): Local path to checkpoint.
"""
self.storage_client.sync_down(self.remote_checkpoint_dir, self.logdir)
self.storage_client.wait()
super(DurableTrainable, self).restore(checkpoint_path)

def delete_checkpoint(self, checkpoint_path):
"""Deletes checkpoint from both local and remote storage.

Args:
checkpoint_path (str): Local path to checkpoint.
"""
super(DurableTrainable, self).delete_checkpoint(checkpoint_path)
local_dirpath = TrainableUtil.find_checkpoint_dir(checkpoint_path)
self.storage_client.delete(self._storage_path(local_dirpath))

def _create_storage_client(self):
"""Returns a storage client."""
return get_cloud_sync_client(self.remote_checkpoint_dir)

def _storage_path(self, local_path):
rel_local_path = os.path.relpath(local_path, self.logdir)
return os.path.join(self.remote_checkpoint_dir, rel_local_path)
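For illustration (not part of this diff), a minimal sketch of the usage the class docstring describes. The subclass body, metric, and bucket path are hypothetical; only the DurableTrainable base class, the dict-style _save/_restore hooks, and the upload_dir / sync_to_driver=False arguments used elsewhere in this PR are assumed:

from ray import tune
from ray.tune import DurableTrainable


class MyDurableTrainable(DurableTrainable):
    def _setup(self, config):
        self.step = 0

    def _train(self):
        self.step += 1
        return {"mean_loss": 1.0 / self.step}

    def _save(self, checkpoint_dir):
        # Returning a dict lets the base Trainable persist it under checkpoint_dir.
        return {"step": self.step}

    def _restore(self, checkpoint):
        self.step = checkpoint["step"]


tune.run(
    MyDurableTrainable,
    checkpoint_freq=5,
    upload_dir="s3://my-bucket/tune-exps",  # hypothetical bucket
    sync_to_driver=False)  # per the docstring, avoid redundant driver-side copies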
126 changes: 126 additions & 0 deletions python/ray/tune/examples/durable_trainable_example.py
@@ -0,0 +1,126 @@
import argparse
import numpy as np
import time
import logging
import os
import ray
from ray import tune
from ray.tune import DurableTrainable
from ray.tune.sync_client import get_sync_client

import cloudpickle

logger = logging.getLogger(__name__)


class MockDurableTrainable(DurableTrainable):
"""Mocks the storage client on initialization to store data locally."""

def __init__(self, remote_checkpoint_dir, *args, **kwargs):
# Mock the path as a local path.
local_dir_suffix = remote_checkpoint_dir.split("://")[1]
remote_checkpoint_dir = os.path.join("/tmp", local_dir_suffix)
# Disallow malformed relative paths for delete safety.
assert os.path.abspath(remote_checkpoint_dir).startswith("/tmp")
logger.info("Using %s as the mocked remote checkpoint directory.",
                    remote_checkpoint_dir)
super(MockDurableTrainable, self).__init__(remote_checkpoint_dir,
*args, **kwargs)

def _create_storage_client(self):
sync = "mkdir -p {target} && rsync -avz {source} {target}"
delete = "rm -rf {target}"
return get_sync_client(sync, delete)


class OptimusFn(object):
def __init__(self, params, max_t=10000):
self.params = params
self.noise = np.random.normal(size=max_t) * 0.005

def eval(self, k, add_noise=True):
b0, b1, b2 = self.params
score = (b0 * k / 100 + 0.1 * b1 + 0.5)**(-1) + b2 * 0.01
if add_noise:
return score + abs(self.noise[k])
else:
return score


def get_optimus_trainable(parent_cls):
class OptimusTrainable(parent_cls):
def _setup(self, config):
self.iter = 0
if config.get("seed"):
np.random.seed(config["seed"])
time.sleep(config.get("startup_delay", 0))
params = [config["param1"], config["param2"], config["param3"]]
self.func = OptimusFn(params=params)
self.initial_samples_per_step = 500
self.mock_data = open("/dev/urandom", "rb").read(1024)

def _train(self):
self.iter += 1
new_loss = self.func.eval(self.iter)
time.sleep(0.5)
return {
"mean_loss": float(new_loss),
"mean_accuracy": (2 - new_loss) / 2,
"samples": self.initial_samples_per_step
}

def _save(self, checkpoint_dir):
time.sleep(0.5)
return {
"func": cloudpickle.dumps(self.func),
"seed": np.random.get_state(),
"data": self.mock_data,
"iter": self.iter
}

def _restore(self, checkpoint):
self.func = cloudpickle.loads(checkpoint["func"])
            self.mock_data = checkpoint["data"]
self.iter = checkpoint["iter"]
np.random.set_state(checkpoint["seed"])

return OptimusTrainable


def parse():
parser = argparse.ArgumentParser()
parser.add_argument("--local", action="store_true", default=False)
parser.add_argument("--mock-storage", action="store_true", default=False)
parser.add_argument("--remote-dir", type=str)
return parser.parse_args()


if __name__ == "__main__":
args = parse()
address = None if args.local else "auto"
ray.init(address=address)

config = {
"seed": None,
"startup_delay": 0.001,
"param1": tune.sample_from(lambda spec: np.random.exponential(0.1)),
"param2": tune.sample_from(lambda _: np.random.rand()),
"param3": tune.sample_from(lambda _: np.random.rand()),
}

parent = MockDurableTrainable if args.mock_storage else DurableTrainable
analysis = tune.run(
get_optimus_trainable(parent),
name="durableTrainable" + str(time.time()),
config=config,
num_samples=4,
verbose=1,
queue_trials=True,
# fault tolerance parameters
max_failures=-1,
checkpoint_freq=20,
sync_to_driver=False,
sync_on_checkpoint=False,
upload_dir="s3://ray-tune-test/exps/",

Review thread on the upload_dir line above:

Contributor: change to dummy?

Contributor Author: MockDurableTrainable mocks it out to a local dir

Contributor: You'll probably want to turn off sync_to_cloud then

checkpoint_score_attr="training_iteration",
)
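
As context for the review thread above, a small sketch (not part of this diff) of how MockDurableTrainable rewrites the bucket path to a local directory; the bucket value is taken from the example itself:

import os

remote_checkpoint_dir = "s3://ray-tune-test/exps/"
local_dir_suffix = remote_checkpoint_dir.split("://")[1]  # "ray-tune-test/exps/"
mocked_dir = os.path.join("/tmp", local_dir_suffix)
print(mocked_dir)  # /tmp/ray-tune-test/exps/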