[tune] move logger and syncer handling to callbacks #11699

Closed · wants to merge 24 commits
2 changes: 1 addition & 1 deletion doc/source/raysgd/raysgd_pytorch.rst
@@ -103,7 +103,7 @@ You can also obtain profiling information:

.. code-block:: python

>>> from ray.tune.logger import pretty_print
Review comment (Contributor): (backwards compat) can we keep this for this PR?

>>> from ray.tune.utils.util import pretty_print
>>> print(pretty_print(trainer.train(profile=True)))

batch_count: 16
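One possible way to address the reviewer's backwards-compatibility concern is to keep the old import path working alongside the new one. This is a hypothetical shim, not part of this PR; whether Ray re-exports ``pretty_print`` from its old location is an assumption:

```python
# Hypothetical compatibility shim (not taken from this PR): prefer the new
# location introduced here and fall back to the pre-PR location so that
# existing user code keeps working.
try:
    from ray.tune.utils.util import pretty_print  # new location in this PR
except ImportError:
    from ray.tune.logger import pretty_print  # pre-PR location
```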
2 changes: 1 addition & 1 deletion doc/source/rllib-training.rst
@@ -145,7 +145,7 @@ Here is an example of the basic usage (for a more complete example, see `custom_

import ray
import ray.rllib.agents.ppo as ppo
from ray.tune.logger import pretty_print
Review comment (Contributor): (backwards compat) can we keep this for this PR?

from ray.tune.utils.util import pretty_print

ray.init()
config = ppo.DEFAULT_CONFIG.copy()
2 changes: 1 addition & 1 deletion doc/source/tune/api_docs/internals.rst
@@ -116,7 +116,7 @@ Trial
Callbacks
---------

.. autoclass:: ray.tune.trial_runner.Callback
.. autoclass:: ray.tune.callback.Callback
   :members:


11 changes: 10 additions & 1 deletion doc/source/tune/user-guide.rst
@@ -635,9 +635,18 @@ These are the environment variables Ray Tune currently considers:
* **TUNE_CLUSTER_SSH_KEY**: SSH key used by the Tune driver process to connect
  to remote cluster machines for checkpoint syncing. If this is not set,
  ``~/ray_bootstrap_key.pem`` will be used.
* **TUNE_DISABLE_AUTO_CALLBACK_LOGGERS**: Ray Tune automatically adds a CSV and
  JSON logger callback if they haven't been passed. Setting this variable to
  ``1`` disables this automatic creation. Please note that this will most likely
  affect the analysis of your results after the tuning run.
* **TUNE_DISABLE_AUTO_CALLBACK_SYNCER**: Ray Tune automatically adds a
  Syncer callback to sync logs and checkpoints between different nodes if none
  has been passed. Setting this variable to ``1`` disables this automatic creation.
  Please note that this will most likely affect advanced scheduling algorithms
  like PopulationBasedTraining.
* **TUNE_DISABLE_AUTO_INIT**: Disable automatically calling ``ray.init()`` if
  not attached to a Ray session.
* **TUNE_DISABLE_DATED_SUBDIR**: Tune automatically adds a date string to experiment
* **TUNE_DISABLE_DATED_SUBDIR**: Ray Tune automatically adds a date string to experiment
  directories when the name is not specified explicitly or the trainable isn't passed
  as a string. Setting this environment variable to ``1`` disables adding these date strings.
* **TUNE_DISABLE_STRICT_METRIC_CHECKING**: When you report metrics to Tune via
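As a quick illustration of how the two new switches above are meant to be used, a minimal sketch (the exact callbacks Ray Tune would otherwise add are described above, not in this snippet): the environment variables simply need to be set before ``tune.run()`` is called.

```python
import os

from ray import tune

# Opt out of the automatically added logger and syncer callbacks described
# above by setting the environment variables before starting the run.
os.environ["TUNE_DISABLE_AUTO_CALLBACK_LOGGERS"] = "1"
os.environ["TUNE_DISABLE_AUTO_CALLBACK_SYNCER"] = "1"


def train(config):
    for i in range(10):
        tune.report(metric=i)


tune.run(train)  # no CSV/JSON logger or syncer callbacks are auto-added
```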
4 changes: 2 additions & 2 deletions python/ray/tune/__init__.py
@@ -8,7 +8,7 @@
from ray.tune.registry import register_env, register_trainable
from ray.tune.trainable import Trainable
from ray.tune.durable_trainable import DurableTrainable
from ray.tune.trial_runner import Callback
from ray.tune.callback import Callback
from ray.tune.suggest import grid_search
from ray.tune.session import (
    report, get_trial_dir, get_trial_name, get_trial_id, make_checkpoint_dir,
@@ -22,7 +22,7 @@
from ray.tune.schedulers import create_scheduler

__all__ = [
"Trainable", "DurableTrainable", "TuneError", "Callback", "grid_search",
"Trainable", "DurableTrainable", "Callback", "TuneError", "grid_search",
"register_env", "register_trainable", "run", "run_experiments",
"with_parameters", "Stopper", "EarlyStopping", "Experiment", "function",
"sample_from", "track", "uniform", "quniform", "choice", "randint",
2 changes: 1 addition & 1 deletion python/ray/tune/analysis/experiment_analysis.py
@@ -17,7 +17,7 @@
from ray.tune.result import EXPR_PROGRESS_FILE, EXPR_PARAM_FILE,\
    CONFIG_PREFIX, TRAINING_ITERATION
from ray.tune.trial import Trial
from ray.tune.trainable import TrainableUtil
from ray.tune.utils.trainable import TrainableUtil

logger = logging.getLogger(__name__)

202 changes: 202 additions & 0 deletions python/ray/tune/callback.py
@@ -0,0 +1,202 @@
from typing import Dict, List

from ray.tune.checkpoint_manager import Checkpoint
from ray.tune.trial import Trial


class Callback:
"""Tune base callback that can be extended and passed to a ``TrialRunner``

Tune callbacks are called from within the ``TrialRunner`` class. There are
several hooks that can be used, all of which are found in the submethod
definitions of this base class.

The parameters passed to the ``**info`` dict vary between hooks. The
parameters passed are described in the docstrings of the methods.

This example will print a metric each time a result is received:

.. code-block:: python

from ray import tune
from ray.tune import Callback


class MyCallback(Callback):
def on_trial_result(self, iteration, trials, trial, result,
**info):
print(f"Got result: {result['metric']}")


def train(config):
for i in range(10):
tune.report(metric=i)


tune.run(
train,
callbacks=[MyCallback()])

"""

    def on_step_begin(self, iteration: int, trials: List[Trial], **info):
        """Called at the start of each tuning loop step.

        Arguments:
            iteration (int): Number of iterations of the tuning loop.
            trials (List[Trial]): List of trials.
            **info: Kwargs dict for forward compatibility.
        """
        pass

    def on_step_end(self, iteration: int, trials: List[Trial], **info):
        """Called at the end of each tuning loop step.

        The iteration counter is increased before this hook is called.

        Arguments:
            iteration (int): Number of iterations of the tuning loop.
            trials (List[Trial]): List of trials.
            **info: Kwargs dict for forward compatibility.
        """
        pass

    def on_trial_start(self, iteration: int, trials: List[Trial], trial: Trial,
                       **info):
        """Called after starting a trial instance.

        Arguments:
            iteration (int): Number of iterations of the tuning loop.
            trials (List[Trial]): List of trials.
            trial (Trial): Trial that has just been started.
            **info: Kwargs dict for forward compatibility.

        """
        pass

    def on_trial_restore(self, iteration: int, trials: List[Trial],
                         trial: Trial, **info):
        """Called after restoring a trial instance.

        Arguments:
            iteration (int): Number of iterations of the tuning loop.
            trials (List[Trial]): List of trials.
            trial (Trial): Trial that has just been restored.
            **info: Kwargs dict for forward compatibility.
        """
        pass

    def on_trial_save(self, iteration: int, trials: List[Trial], trial: Trial,
                      **info):
        """Called after receiving a checkpoint from a trial.

        Arguments:
            iteration (int): Number of iterations of the tuning loop.
            trials (List[Trial]): List of trials.
            trial (Trial): Trial that just saved a checkpoint.
            **info: Kwargs dict for forward compatibility.
        """
        pass

    def on_trial_result(self, iteration: int, trials: List[Trial],
                        trial: Trial, result: Dict, **info):
        """Called after receiving a result from a trial.

        The search algorithm and scheduler are notified before this
        hook is called.

        Arguments:
            iteration (int): Number of iterations of the tuning loop.
            trials (List[Trial]): List of trials.
            trial (Trial): Trial that just sent a result.
            result (Dict): Result that the trial sent.
            **info: Kwargs dict for forward compatibility.
        """
        pass

    def on_trial_complete(self, iteration: int, trials: List[Trial],
                          trial: Trial, **info):
        """Called after a trial instance has completed.

        The search algorithm and scheduler are notified before this
        hook is called.

        Arguments:
            iteration (int): Number of iterations of the tuning loop.
            trials (List[Trial]): List of trials.
            trial (Trial): Trial that has just been completed.
            **info: Kwargs dict for forward compatibility.
        """
        pass

    def on_trial_error(self, iteration: int, trials: List[Trial], trial: Trial,
                       **info):
        """Called after a trial instance failed (errored).

        The search algorithm and scheduler are notified before this
        hook is called.

        Arguments:
            iteration (int): Number of iterations of the tuning loop.
            trials (List[Trial]): List of trials.
            trial (Trial): Trial that has just errored.
            **info: Kwargs dict for forward compatibility.
        """
        pass

    def on_checkpoint(self, iteration: int, trials: List[Trial], trial: Trial,
                      checkpoint: Checkpoint, **info):
        """Called after a trial saved a checkpoint with Tune.

        Arguments:
            iteration (int): Number of iterations of the tuning loop.
            trials (List[Trial]): List of trials.
            trial (Trial): Trial that just saved the checkpoint.
            checkpoint (Checkpoint): Checkpoint object that has been saved
                by the trial.
            **info: Kwargs dict for forward compatibility.
        """
        pass


class CallbackList:
    """Call multiple callbacks at once."""

    def __init__(self, callbacks: List[Callback]):
        self._callbacks = callbacks

    def on_step_begin(self, **info):
        for callback in self._callbacks:
            callback.on_step_begin(**info)

    def on_step_end(self, **info):
        for callback in self._callbacks:
            callback.on_step_end(**info)

    def on_trial_start(self, **info):
        for callback in self._callbacks:
            callback.on_trial_start(**info)

    def on_trial_restore(self, **info):
        for callback in self._callbacks:
            callback.on_trial_restore(**info)

    def on_trial_save(self, **info):
        for callback in self._callbacks:
            callback.on_trial_save(**info)

    def on_trial_result(self, **info):
        for callback in self._callbacks:
            callback.on_trial_result(**info)

    def on_trial_complete(self, **info):
        for callback in self._callbacks:
            callback.on_trial_complete(**info)

    def on_trial_error(self, **info):
        for callback in self._callbacks:
            callback.on_trial_error(**info)

    def on_checkpoint(self, **info):
        for callback in self._callbacks:
            callback.on_checkpoint(**info)
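For orientation, here is a minimal sketch of how a list of user callbacks can be fanned out through the new ``CallbackList``. The keyword arguments mirror the hook signatures above; the empty trial list is a stand-in rather than real ``Trial`` objects, and this is illustration only, not how the ``TrialRunner`` is wired up in this PR:

```python
from ray.tune.callback import Callback, CallbackList


class PrintingCallback(Callback):
    def on_step_end(self, iteration, trials, **info):
        print(f"Step {iteration} finished with {len(trials)} trials")


# CallbackList simply forwards each hook to every registered callback.
callbacks = CallbackList([PrintingCallback(), Callback()])
callbacks.on_step_begin(iteration=0, trials=[])
callbacks.on_step_end(iteration=1, trials=[])
```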
7 changes: 2 additions & 5 deletions python/ray/tune/config_parser.py
@@ -8,7 +8,7 @@
from ray.tune import TuneError
from ray.tune.trial import Trial
from ray.tune.resources import json_to_resources
from ray.tune.logger import _SafeFallbackEncoder
from ray.tune.utils.util import SafeFallbackEncoder


def make_parser(parser_creator=None, **kwargs):
@@ -143,7 +143,7 @@ def to_argv(config):
        elif isinstance(v, bool):
            pass
        else:
            argv.append(json.dumps(v, cls=_SafeFallbackEncoder))
            argv.append(json.dumps(v, cls=SafeFallbackEncoder))
    return argv


@@ -182,17 +182,14 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
        remote_checkpoint_dir=spec.get("remote_checkpoint_dir"),
        checkpoint_freq=args.checkpoint_freq,
        checkpoint_at_end=args.checkpoint_at_end,
        sync_on_checkpoint=args.sync_on_checkpoint,
        keep_checkpoints_num=args.keep_checkpoints_num,
        checkpoint_score_attr=args.checkpoint_score_attr,
        export_formats=spec.get("export_formats", []),
        # str(None) doesn't create None
        restore_path=spec.get("restore"),
        trial_name_creator=spec.get("trial_name_creator"),
        trial_dirname_creator=spec.get("trial_dirname_creator"),
        loggers=spec.get("loggers"),
Review comment (Contributor): interesting... i think this is OK but will continue reading

        log_to_file=spec.get("log_to_file"),
        # str(None) doesn't create None
        sync_to_driver_fn=spec.get("sync_to_driver"),
        max_failures=args.max_failures,
        **trial_kwargs)
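To make the ``_SafeFallbackEncoder`` to ``SafeFallbackEncoder`` rename above concrete, a small sketch of the kind of serialization ``to_argv`` relies on. The fallback-to-string behavior for values JSON cannot serialize natively is assumed from the encoder's name and usage here, not verified against its implementation:

```python
import json

from ray.tune.utils.util import SafeFallbackEncoder  # new import path in this PR

config = {"lr": 0.01, "activation": lambda x: x}  # second value is not JSON-serializable
# Assumption: SafeFallbackEncoder falls back to a safe string representation
# for values it cannot encode, instead of raising a TypeError.
print(json.dumps(config, cls=SafeFallbackEncoder))
```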
11 changes: 11 additions & 0 deletions python/ray/tune/experiment.py
@@ -123,6 +123,17 @@ def __init__(self,
                 max_failures=0,
                 restore=None):

        if loggers is not None:
            # Most users won't run into this as `tune.run()` does not pass
            # the argument anymore. However, we want to inform users
            # if they instantiate their `Experiment` objects themselves.
            raise ValueError(
                "Passing `loggers` to an `Experiment` is deprecated. Use "
                "an `ExperimentLogger` callback instead, e.g. by passing the "
                "`Logger` classes to `tune.logger.LegacyExperimentLogger` and "
                "passing this as part of the `callbacks` parameter to "
                "`tune.run()`.")

        config = config or {}
        if callable(run) and detect_checkpoint_function(run):
            if checkpoint_at_end:
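The error message above points users toward a callback-based replacement. A minimal migration sketch might look like the following, where the ``LegacyExperimentLogger`` constructor argument is an assumption based on the message rather than a signature taken from this diff:

```python
from ray import tune
from ray.tune.logger import CSVLogger, JsonLogger, LegacyExperimentLogger


def train(config):
    for i in range(5):
        tune.report(metric=i)


# Before (deprecated): tune.run(train, loggers=[CSVLogger, JsonLogger])
# After: wrap the Logger classes in the legacy callback and pass it via
# `callbacks`. The constructor call shown here is assumed, not verified.
tune.run(train, callbacks=[LegacyExperimentLogger([CSVLogger, JsonLogger])])
```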
2 changes: 1 addition & 1 deletion python/ray/tune/integration/horovod.py
@@ -7,7 +7,7 @@
import ray
from ray import tune
from ray.tune.resources import Resources
from ray.tune.trainable import TrainableUtil
from ray.tune.utils.trainable import TrainableUtil
from ray.tune.result import RESULT_DUPLICATE
from ray.tune.logger import NoopLogger

2 changes: 1 addition & 1 deletion python/ray/tune/integration/torch.py
@@ -16,7 +16,7 @@
from ray.tune.logger import NoopLogger
from ray.tune.function_runner import wrap_function
from ray.tune.resources import Resources
from ray.tune.trainable import TrainableUtil
from ray.tune.utils.trainable import TrainableUtil
from ray.tune.utils import detect_checkpoint_function
from ray.util.sgd.torch.utils import setup_process_group, setup_address
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S