[tune] move logger and syncer handling to callbacks #11699
Closed
24 commits
cbd4240
Move loggers from Trial to Callbacks
e4747cc
Merge remote-tracking branch 'upstream/master' into tune-logger-callback
c0a245b
Intermediate commit
b944ccc
Create callbacks on tune.run
cb29ed0
Fix syncer IP lookup
79ae993
Update syncer + tests
7959ea1
Merge remote-tracking branch 'upstream/master' into tune-logger-callback
ecde308
Fix cluster tests
9fa3131
fix linter error
acf730f
fix rllib pretty_print import
f273f96
Fix result extra fields
2805d90
Fix run_experiment test
1a18648
Fix stop_logger argument
7053dc0
Fix sync test
cbb3973
Fix lint
14a0f72
Fix `done` result update in trial runner
c5a7008
Move result update
0391c17
Re-order callbacks
382f237
Apply changes from code review
89105af
Move default callback creation
65e2e6f
Better error message
21bbb2d
Flush on trial save/restore
943a8b1
Merge remote-tracking branch 'upstream/master' into tune-logger-callback
f9e6552
Update tests
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -145,7 +145,7 @@ Here is an example of the basic usage (for a more complete example, see `custom_ | |
|
||
import ray | ||
import ray.rllib.agents.ppo as ppo | ||
- from ray.tune.logger import pretty_print | ||
Review comment: (backwards compat) Can we keep this for this PR? |
||
+ from ray.tune.utils.util import pretty_print | ||
|
||
ray.init() | ||
config = ppo.DEFAULT_CONFIG.copy() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
from typing import Dict, List | ||
|
||
from ray.tune.checkpoint_manager import Checkpoint | ||
from ray.tune.trial import Trial | ||
|
||
|
||
class Callback: | ||
"""Tune base callback that can be extended and passed to a ``TrialRunner``. | ||
|
||
Tune callbacks are called from within the ``TrialRunner`` class. Several | ||
hooks are available, each of which is defined as a method of this | ||
base class. | ||
|
||
The parameters passed to the ``**info`` dict vary between hooks. They | ||
are described in the docstrings of the respective methods. | ||
|
||
This example will print a metric each time a result is received: | ||
|
||
.. code-block:: python | ||
|
||
from ray import tune | ||
from ray.tune import Callback | ||
|
||
|
||
class MyCallback(Callback): | ||
def on_trial_result(self, iteration, trials, trial, result, | ||
**info): | ||
print(f"Got result: {result['metric']}") | ||
|
||
|
||
def train(config): | ||
for i in range(10): | ||
tune.report(metric=i) | ||
|
||
|
||
tune.run( | ||
train, | ||
callbacks=[MyCallback()]) | ||
|
||
""" | ||
|
||
def on_step_begin(self, iteration: int, trials: List[Trial], **info): | ||
"""Called at the start of each tuning loop step. | ||
|
||
Arguments: | ||
iteration (int): Number of iterations of the tuning loop. | ||
trials (List[Trial]): List of trials. | ||
**info: Kwargs dict for forward compatibility. | ||
""" | ||
pass | ||
|
||
def on_step_end(self, iteration: int, trials: List[Trial], **info): | ||
"""Called at the end of each tuning loop step. | ||
|
||
The iteration counter is increased before this hook is called. | ||
|
||
Arguments: | ||
iteration (int): Number of iterations of the tuning loop. | ||
trials (List[Trial]): List of trials. | ||
**info: Kwargs dict for forward compatibility. | ||
""" | ||
pass | ||
|
||
def on_trial_start(self, iteration: int, trials: List[Trial], trial: Trial, | ||
**info): | ||
"""Called after starting a trial instance. | ||
|
||
Arguments: | ||
iteration (int): Number of iterations of the tuning loop. | ||
trials (List[Trial]): List of trials. | ||
trial (Trial): Trial that has just been started. | ||
**info: Kwargs dict for forward compatibility. | ||
|
||
""" | ||
pass | ||
|
||
def on_trial_restore(self, iteration: int, trials: List[Trial], | ||
trial: Trial, **info): | ||
"""Called after restoring a trial instance. | ||
|
||
Arguments: | ||
iteration (int): Number of iterations of the tuning loop. | ||
trials (List[Trial]): List of trials. | ||
trial (Trial): Trial that has just been restored. | ||
**info: Kwargs dict for forward compatibility. | ||
""" | ||
pass | ||
|
||
def on_trial_save(self, iteration: int, trials: List[Trial], trial: Trial, | ||
**info): | ||
"""Called after receiving a checkpoint from a trial. | ||
|
||
Arguments: | ||
iteration (int): Number of iterations of the tuning loop. | ||
trials (List[Trial]): List of trials. | ||
trial (Trial): Trial that just saved a checkpoint. | ||
**info: Kwargs dict for forward compatibility. | ||
""" | ||
pass | ||
|
||
def on_trial_result(self, iteration: int, trials: List[Trial], | ||
trial: Trial, result: Dict, **info): | ||
"""Called after receiving a result from a trial. | ||
|
||
The search algorithm and scheduler are notified before this | ||
hook is called. | ||
|
||
Arguments: | ||
iteration (int): Number of iterations of the tuning loop. | ||
trials (List[Trial]): List of trials. | ||
trial (Trial): Trial that just sent a result. | ||
result (Dict): Result that the trial sent. | ||
**info: Kwargs dict for forward compatibility. | ||
""" | ||
pass | ||
|
||
def on_trial_complete(self, iteration: int, trials: List[Trial], | ||
trial: Trial, **info): | ||
"""Called after a trial instance has completed. | ||
|
||
The search algorithm and scheduler are notified before this | ||
hook is called. | ||
|
||
Arguments: | ||
iteration (int): Number of iterations of the tuning loop. | ||
trials (List[Trial]): List of trials. | ||
trial (Trial): Trial that has just been completed. | ||
**info: Kwargs dict for forward compatibility. | ||
""" | ||
pass | ||
|
||
def on_trial_error(self, iteration: int, trials: List[Trial], trial: Trial, | ||
**info): | ||
"""Called after a trial instance has failed (errored). | ||
|
||
The search algorithm and scheduler are notified before this | ||
hook is called. | ||
|
||
Arguments: | ||
iteration (int): Number of iterations of the tuning loop. | ||
trials (List[Trial]): List of trials. | ||
trial (Trial): Trial that has just errored. | ||
**info: Kwargs dict for forward compatibility. | ||
""" | ||
pass | ||
|
||
def on_checkpoint(self, iteration: int, trials: List[Trial], trial: Trial, | ||
checkpoint: Checkpoint, **info): | ||
"""Called after a trial saved a checkpoint with Tune. | ||
|
||
Arguments: | ||
iteration (int): Number of iterations of the tuning loop. | ||
trials (List[Trial]): List of trials. | ||
trial (Trial): Trial that just saved a checkpoint. | ||
checkpoint (Checkpoint): Checkpoint object that has been saved | ||
by the trial. | ||
**info: Kwargs dict for forward compatibility. | ||
""" | ||
pass | ||
|
||
|
||
class CallbackList: | ||
"""Call multiple callbacks at once.""" | ||
|
||
def __init__(self, callbacks: List[Callback]): | ||
self._callbacks = callbacks | ||
|
||
def on_step_begin(self, **info): | ||
for callback in self._callbacks: | ||
callback.on_step_begin(**info) | ||
|
||
def on_step_end(self, **info): | ||
for callback in self._callbacks: | ||
callback.on_step_end(**info) | ||
|
||
def on_trial_start(self, **info): | ||
for callback in self._callbacks: | ||
callback.on_trial_start(**info) | ||
|
||
def on_trial_restore(self, **info): | ||
for callback in self._callbacks: | ||
callback.on_trial_restore(**info) | ||
|
||
def on_trial_save(self, **info): | ||
for callback in self._callbacks: | ||
callback.on_trial_save(**info) | ||
|
||
def on_trial_result(self, **info): | ||
for callback in self._callbacks: | ||
callback.on_trial_result(**info) | ||
|
||
def on_trial_complete(self, **info): | ||
for callback in self._callbacks: | ||
callback.on_trial_complete(**info) | ||
|
||
def on_trial_error(self, **info): | ||
for callback in self._callbacks: | ||
callback.on_trial_error(**info) | ||
|
||
def on_checkpoint(self, **info): | ||
for callback in self._callbacks: | ||
callback.on_checkpoint(**info) |
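
The dispatch pattern in `CallbackList` can be tried out without Ray installed. The following is a minimal sketch, not the code from this PR: `Callback` and `CallbackList` are cut-down stand-ins that mirror only two of the hooks in the diff above, `RecordingCallback` is a hypothetical subclass, and trials are represented by plain strings instead of `Trial` objects.

```python
from typing import Any, Dict, List


class Callback:
    """Stand-in for the base class in the diff: every hook is a no-op."""

    def on_trial_result(self, iteration: int, trials: List[Any], trial: Any,
                        result: Dict, **info):
        pass

    def on_trial_complete(self, iteration: int, trials: List[Any],
                          trial: Any, **info):
        pass


class CallbackList:
    """Stand-in dispatcher: forwards keyword arguments to each callback."""

    def __init__(self, callbacks: List[Callback]):
        self._callbacks = callbacks

    def on_trial_result(self, **info):
        for callback in self._callbacks:
            callback.on_trial_result(**info)

    def on_trial_complete(self, **info):
        for callback in self._callbacks:
            callback.on_trial_complete(**info)


class RecordingCallback(Callback):
    """Hypothetical callback that records which hooks fired, in order."""

    def __init__(self, name: str, events: List[str]):
        self.name = name
        self.events = events

    def on_trial_result(self, iteration, trials, trial, result, **info):
        self.events.append(f"{self.name}:result:{trial}:{result['metric']}")

    def on_trial_complete(self, iteration, trials, trial, **info):
        self.events.append(f"{self.name}:complete:{trial}")


events: List[str] = []
callbacks = CallbackList(
    [RecordingCallback("a", events), RecordingCallback("b", events)])

# The list's hooks only accept **info, so they must be called with keyword
# arguments. Keys a callback does not name (here: extra) land in its **info.
callbacks.on_trial_result(
    iteration=1, trials=["trial_0"], trial="trial_0",
    result={"metric": 0.5}, extra="ignored")
callbacks.on_trial_complete(iteration=2, trials=["trial_0"], trial="trial_0")

print(events)
```

Two properties of the design show up here: callbacks run in the order in which they were passed to the list, and a hook that receives an argument it does not know about keeps working because of the `**info` catch-all.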
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,7 @@ | |
from ray.tune import TuneError | ||
from ray.tune.trial import Trial | ||
from ray.tune.resources import json_to_resources | ||
- from ray.tune.logger import _SafeFallbackEncoder | ||
+ from ray.tune.utils.util import SafeFallbackEncoder | ||
|
||
|
||
def make_parser(parser_creator=None, **kwargs): | ||
|
@@ -143,7 +143,7 @@ def to_argv(config): | |
elif isinstance(v, bool): | ||
pass | ||
else: | ||
- argv.append(json.dumps(v, cls=_SafeFallbackEncoder)) | ||
+ argv.append(json.dumps(v, cls=SafeFallbackEncoder)) | ||
return argv | ||
|
||
|
||
|
@@ -182,17 +182,14 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs): | |
remote_checkpoint_dir=spec.get("remote_checkpoint_dir"), | ||
checkpoint_freq=args.checkpoint_freq, | ||
checkpoint_at_end=args.checkpoint_at_end, | ||
sync_on_checkpoint=args.sync_on_checkpoint, | ||
keep_checkpoints_num=args.keep_checkpoints_num, | ||
checkpoint_score_attr=args.checkpoint_score_attr, | ||
export_formats=spec.get("export_formats", []), | ||
# str(None) doesn't create None | ||
restore_path=spec.get("restore"), | ||
trial_name_creator=spec.get("trial_name_creator"), | ||
trial_dirname_creator=spec.get("trial_dirname_creator"), | ||
loggers=spec.get("loggers"), | ||
Review comment: interesting... I think this is OK, but will continue reading. |
||
log_to_file=spec.get("log_to_file"), | ||
# str(None) doesn't create None | ||
sync_to_driver_fn=spec.get("sync_to_driver"), | ||
max_failures=args.max_failures, | ||
**trial_kwargs) |
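
The `to_argv` hunk above passes a custom encoder class to `json.dumps`. As a rough illustration of why that matters, here is a minimal sketch that assumes the encoder's job is to stringify values the standard `json` module cannot serialize. `FallbackEncoder` is a hypothetical stand-in, not Ray's `SafeFallbackEncoder`, and this `to_argv` is a simplified approximation of the function in the diff.

```python
import json


class FallbackEncoder(json.JSONEncoder):
    """Hypothetical stand-in: stringify anything json cannot serialize."""

    def default(self, value):
        # json only calls default() for objects it cannot encode natively.
        return str(value)


def to_argv(config: dict) -> list:
    """Simplified sketch of turning a config dict into CLI arguments."""
    argv = []
    for key, value in config.items():
        if value is None or value is False:
            continue  # assumption: unset options and false flags are omitted
        argv.append("--{}".format(key.replace("_", "-")))
        if isinstance(value, str):
            argv.append(value)
        elif isinstance(value, bool):
            pass  # a true flag carries no value
        else:
            argv.append(json.dumps(value, cls=FallbackEncoder))
    return argv


args = to_argv({
    "run": "PPO",
    "checkpoint_at_end": True,
    "config": {"lr": 0.01, "tags": {1}},  # a set is not JSON serializable
})
print(args)
```

Without the `cls=` argument, `json.dumps` would raise a `TypeError` on the set. With the fallback, the value is written as a string instead, which is lossy but keeps argument construction from failing.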