[tune] Cluster Fault Tolerance (#3309)
This PR introduces cluster-level fault tolerance for Tune by checkpointing global experiment state. Checkpointing happens at a relatively high frequency, allowing users to easily resume an experiment after the cluster crashes.

Note that this PR may affect automated workflows because of the auto-prompting on resume; this can be avoided by passing an explicit resume value (resume=True or resume=False).
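
A rough sketch of the intended usage (the trainable, experiment name, and result directory below are illustrative, not taken from this PR):

    import ray
    from ray.tune import run_experiments

    def my_trainable(config, reporter):
        # Stand-in training function; any registered Trainable or function works.
        for step in range(100):
            reporter(timesteps_total=step, mean_accuracy=step / 100.0)

    ray.init()

    # If the driver or cluster crashes partway through, re-running this script
    # restores the experiment from the global state that Tune periodically
    # checkpoints under local_dir, instead of starting from scratch.
    run_experiments(
        {
            "fault_tolerant_exp": {
                "run": my_trainable,
                "local_dir": "~/ray_results",
            },
        },
        resume=True,  # resume=None (the default) prompts before resuming
    )

RLlib's train.py exposes the same behavior through the new --resume flag added in this PR.
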
richardliaw committed Dec 29, 2018
1 parent 382b138 commit aad3c50
Showing 16 changed files with 805 additions and 127 deletions.
25 changes: 25 additions & 0 deletions doc/source/tune-usage.rst
@@ -296,6 +296,31 @@ of a trial, you can additionally set the checkpoint_at_end to True. An example i
},
})
Recovering From Failures (Experimental)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Tune automatically persists the progress of your experiments, so if an experiment crashes or is otherwise cancelled, it can be resumed after prompting. The default setting of ``resume=None`` causes Tune to prompt you for whether you want to resume. The prompt can be skipped with ``resume=True``; with ``resume=False``, a new experiment is created instead. You can always force a new experiment to be created by changing the experiment name.

Note that trials will be restored to their last checkpoint. If trial checkpointing is not enabled, unfinished trials will be restarted from scratch.

E.g.:

.. code-block:: python

    run_experiments({
        "my_experiment_name": {
            "run": my_trainable,
            "checkpoint_freq": 10,
            "local_dir": "~/path/to/results"
        },
    }, resume=True)

Upon a second run, this will restore the entire experiment state from ``~/path/to/results/my_experiment_name``. Importantly, any changes to the experiment specification upon resume will be ignored.

This feature is still experimental, so any provided Trial Scheduler or Search Algorithm will not be preserved; only the ``FIFOScheduler`` and ``BasicVariantGenerator`` are supported.
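
For illustration only (this sketch is not part of the example above), the other ``resume`` settings described earlier behave as follows:

.. code-block:: python

    # resume=None (the default): Tune prompts before resuming if prior
    # experiment state exists under local_dir.
    run_experiments({
        "my_experiment_name": {
            "run": my_trainable,
            "checkpoint_freq": 10,
            "local_dir": "~/path/to/results"
        },
    })

    # resume=False: ignore any prior state and start a new experiment.
    run_experiments({
        "my_experiment_name": {
            "run": my_trainable,
            "checkpoint_freq": 10,
            "local_dir": "~/path/to/results"
        },
    }, resume=False)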


Handling Large Datasets
-----------------------

10 changes: 8 additions & 2 deletions python/ray/rllib/train.py
@@ -9,7 +9,8 @@

import ray
from ray.test.cluster_utils import Cluster
from ray.tune.config_parser import make_parser, resources_to_json
from ray.tune.config_parser import make_parser
from ray.tune.trial import resources_to_json
from ray.tune.tune import _make_scheduler, run_experiments

EXAMPLE_USAGE = """
@@ -70,6 +71,10 @@ def create_parser(parser_creator=None):
default="default",
type=str,
help="Name of the subdirectory under `local_dir` to put results in.")
parser.add_argument(
"--resume",
action="store_true",
help="Whether to attempt to resume previous Tune experiments.")
parser.add_argument(
"--env", default=None, type=str, help="The gym environment to use.")
parser.add_argument(
@@ -138,7 +143,8 @@ def run(args, parser):
run_experiments(
experiments,
scheduler=_make_scheduler(args),
queue_trials=args.queue_trials)
queue_trials=args.queue_trials,
resume=args.resume)


if __name__ == "__main__":
4 changes: 3 additions & 1 deletion python/ray/test/cluster_utils.py
@@ -51,7 +51,9 @@ def connect(self, head_node_args):
assert not self.connected
redis_password = head_node_args.get("redis_password")
output_info = ray.init(
redis_address=self.redis_address, redis_password=redis_password)
ignore_reinit_error=True,
redis_address=self.redis_address,
redis_password=redis_password)
logger.info(output_info)
self.connected = True

32 changes: 1 addition & 31 deletions python/ray/tune/config_parser.py
@@ -11,40 +11,10 @@

from ray.tune import TuneError
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.trial import Resources, Trial
from ray.tune.trial import Trial, json_to_resources
from ray.tune.logger import _SafeFallbackEncoder


def json_to_resources(data):
if data is None or data == "null":
return None
if isinstance(data, string_types):
data = json.loads(data)
for k in data:
if k in ["driver_cpu_limit", "driver_gpu_limit"]:
raise TuneError(
"The field `{}` is no longer supported. Use `extra_cpu` "
"or `extra_gpu` instead.".format(k))
if k not in Resources._fields:
raise TuneError(
"Unknown resource type {}, must be one of {}".format(
k, Resources._fields))
return Resources(
data.get("cpu", 1), data.get("gpu", 0), data.get("extra_cpu", 0),
data.get("extra_gpu", 0))


def resources_to_json(resources):
if resources is None:
return None
return {
"cpu": resources.cpu,
"gpu": resources.gpu,
"extra_cpu": resources.extra_cpu,
"extra_gpu": resources.extra_gpu,
}


def make_parser(parser_creator=None, **kwargs):
"""Returns a base argument parser for the ray.tune tool.
12 changes: 6 additions & 6 deletions python/ray/tune/experiment.py
@@ -4,11 +4,11 @@

import copy
import logging
import os
import six
import types

from ray.tune.error import TuneError
from ray.tune.log_sync import validate_sync_function
from ray.tune.registry import register_trainable
from ray.tune.result import DEFAULT_RESULTS_DIR

@@ -122,7 +122,6 @@ def __init__(self,
restore=None,
repeat=None,
trial_resources=None):
validate_sync_function(sync_function)
if sync_function:
assert upload_dir, "Need `upload_dir` if sync_function given."

@@ -134,16 +133,16 @@
resources_per_trial = trial_resources

spec = {
"run": self._register_if_needed(run),
"run": Experiment._register_if_needed(run),
"stop": stop or {},
"config": config or {},
"resources_per_trial": resources_per_trial,
"num_samples": num_samples,
"local_dir": local_dir or DEFAULT_RESULTS_DIR,
"local_dir": os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR),
"upload_dir": upload_dir or "", # argparse converts None to "null"
"trial_name_creator": trial_name_creator,
"custom_loggers": custom_loggers,
"sync_function": sync_function or "", # See `upload_dir`.
"sync_function": sync_function,
"checkpoint_freq": checkpoint_freq,
"checkpoint_at_end": checkpoint_at_end,
"max_failures": max_failures,
@@ -180,7 +179,8 @@ def from_json(cls, name, spec):
raise TuneError("Improper argument from JSON: {}.".format(spec))
return exp

def _register_if_needed(self, run_object):
@classmethod
def _register_if_needed(cls, run_object):
"""Registers Trainable or Function at runtime.
Assumes already registered if run_object is a string. Does not
30 changes: 20 additions & 10 deletions python/ray/tune/logger.py
@@ -106,19 +106,19 @@ def _init(self):
self.logdir, self.uri, sync_function=self._sync_function)

def on_result(self, result):
for logger in self._loggers:
logger.on_result(result)
for _logger in self._loggers:
_logger.on_result(result)
self._log_syncer.set_worker_ip(result.get(NODE_IP))
self._log_syncer.sync_if_needed()

def close(self):
for logger in self._loggers:
logger.close()
for _logger in self._loggers:
_logger.close()
self._log_syncer.sync_now(force=True)

def flush(self):
for logger in self._loggers:
logger.flush()
for _logger in self._loggers:
_logger.flush()
self._log_syncer.sync_now(force=True)
self._log_syncer.wait()

@@ -142,7 +142,7 @@ def _init(self):
with open(config_pkl, "wb") as f:
cloudpickle.dump(self.config, f)
local_file = os.path.join(self.logdir, "result.json")
self.local_out = open(local_file, "w")
self.local_out = open(local_file, "a")

def on_result(self, result):
json.dump(result, self, cls=_SafeFallbackEncoder)
@@ -152,6 +152,9 @@ def write(self, b):
self.local_out.write(b)
self.local_out.flush()

def flush(self):
self.local_out.flush()

def close(self):
self.local_out.close()

@@ -182,7 +185,8 @@ def on_result(self, result):
for k in [
"config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION
]:
del tmp[k] # not useful to tf log these
if k in tmp:
del tmp[k] # not useful to tf log these
values = to_tf_values(tmp, ["ray", "tune"])
train_stats = tf.Summary(value=values)
t = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
@@ -205,15 +209,21 @@ class _VisKitLogger(Logger):
def _init(self):
"""CSV outputted with Headers as first set of results."""
# Note that we assume params.json was already created by JsonLogger
self._file = open(os.path.join(self.logdir, "progress.csv"), "w")
progress_file = os.path.join(self.logdir, "progress.csv")
self._continuing = os.path.exists(progress_file)
self._file = open(progress_file, "a")
self._csv_out = None

def on_result(self, result):
if self._csv_out is None:
self._csv_out = csv.DictWriter(self._file, result.keys())
self._csv_out.writeheader()
if not self._continuing:
self._csv_out.writeheader()
self._csv_out.writerow(result.copy())

def flush(self):
self._file.flush()

def close(self):
self._file.close()

20 changes: 11 additions & 9 deletions python/ray/tune/ray_trial_executor.py
@@ -38,6 +38,8 @@ def _setup_runner(self, trial):
num_gpus=trial.resources.gpu)(trial._get_trainable_cls())

trial.init_logger()
# We checkpoint metadata here to try mitigating logdir duplication
self.try_checkpoint_metadata(trial)
remote_logdir = trial.logdir

def logger_creator(config):
@@ -60,7 +62,7 @@ def _train(self, trial):

def _start_trial(self, trial, checkpoint=None):
prior_status = trial.status
trial.status = Trial.RUNNING
self.set_status(trial, Trial.RUNNING)
trial.runner = self._setup_runner(trial)
if not self.restore(trial, checkpoint):
return
@@ -87,10 +89,13 @@ def _stop_trial(self, trial, error=False, error_msg=None,
stop_logger (bool): Whether to shut down the trial logger.
"""

if stop_logger:
trial.close_logger()

if error:
trial.status = Trial.ERROR
self.set_status(trial, Trial.ERROR)
else:
trial.status = Trial.TERMINATED
self.set_status(trial, Trial.TERMINATED)

try:
trial.write_error_log(error_msg)
@@ -103,13 +108,10 @@
stop_tasks, num_returns=2, timeout=250)
except Exception:
logger.exception("Error stopping runner.")
trial.status = Trial.ERROR
self.set_status(trial, Trial.ERROR)
finally:
trial.runner = None

if stop_logger:
trial.close_logger()

def start_trial(self, trial, checkpoint=None):
"""Starts the trial.
@@ -302,7 +304,7 @@ def restore(self, trial, checkpoint=None):
return True
if trial.runner is None:
logger.error("Unable to restore - no runner.")
trial.status = Trial.ERROR
self.set_status(trial, Trial.ERROR)
return False
try:
value = checkpoint.value
@@ -316,5 +318,5 @@
return True
except Exception:
logger.exception("Error restoring runner.")
trial.status = Trial.ERROR
self.set_status(trial, Trial.ERROR)
return False
