[train] Storage refactor: Support PBT and BOHB #38736
Conversation
Leaving some review guidance
```diff
@@ -286,6 +286,8 @@ def __init__(
         # Used for keeping top K checkpoints.
         self._top_persisted_checkpoints: List[_HeapCheckpointWrapper] = []
+        # Also keep a set of all existing checkpoints
+        self._all_persisted_checkpoint_data: Set[_TrackedCheckpoint] = set()
```
We may now schedule multiple saves in a row (e.g. a save is scheduled, the trial is then instructed to pause, which schedules another save, or PBT schedules an additional save). Previously, these additional saves were memory checkpoints that were not registered. We want to safeguard here against the same checkpoint being registered multiple times.
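As a minimal illustration of that safeguard (the class, method name, and string IDs below are hypothetical, not the actual Ray internals):

```python
from typing import Set


class _CheckpointManagerSketch:
    """Sketch only: avoid registering the same checkpoint twice."""

    def __init__(self) -> None:
        # Every checkpoint that has already been registered.
        self._all_persisted_checkpoint_data: Set[str] = set()

    def register_checkpoint(self, checkpoint_id: str) -> bool:
        # A second save for the same checkpoint (e.g. a pause and PBT both
        # scheduling one) should be a no-op rather than a duplicate entry.
        if checkpoint_id in self._all_persisted_checkpoint_data:
            return False
        self._all_persisted_checkpoint_data.add(checkpoint_id)
        return True
```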
```diff
@@ -80,6 +80,31 @@ class TrainingResult:
     metadata: Optional[Dict] = None


+class _FutureTrainingResult:
```
This wrapper is only used by PBT, and it serves a similar purpose as the future previously held in _TrackedCheckpoint. It's not used anywhere else (except for being created in the tune controller).
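Roughly, the wrapper resolves a pending save future into a training result. The sketch below is an assumption for illustration, not the actual implementation:

```python
from typing import Any, Optional

import ray


class _FutureTrainingResultSketch:
    """Wraps an in-flight save() future scheduled by the tune controller."""

    def __init__(self, future: ray.ObjectRef):
        self.future = future

    def resolve(self, timeout: Optional[float] = None) -> Optional[Any]:
        # Block until the scheduled save finishes and return the training
        # result (metrics + checkpoint), or None if it cannot be fetched.
        try:
            return ray.get(self.future, timeout=timeout)
        except Exception:
            return None
```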
```diff
@@ -1,3 +1,6 @@
+import os
```
This just switches the code to the new Checkpoint API.
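For context, the saving side of the new Checkpoint API looks roughly like this (the file name and metric are made up):

```python
import os
import tempfile

from ray import train
from ray.train import Checkpoint


def train_fn(config):
    loss = 0.1  # placeholder metric
    with tempfile.TemporaryDirectory() as tmpdir:
        # Write the checkpoint contents to a local directory...
        with open(os.path.join(tmpdir, "model.txt"), "w") as f:
            f.write("weights")
        # ...and report it together with the metrics.
        train.report({"loss": loss}, checkpoint=Checkpoint.from_directory(tmpdir))
```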
```diff
@@ -1,5 +1,6 @@
 import argparse
 import os
+import tempfile
```
Same here, this just switches the code to the new Checkpoint API.
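And the corresponding restore side, again only as a rough sketch:

```python
import os

from ray import train


def train_fn(config):
    checkpoint = train.get_checkpoint()
    if checkpoint:
        # Materialize the checkpoint locally and read back the saved state.
        with checkpoint.as_directory() as checkpoint_dir:
            with open(os.path.join(checkpoint_dir, "model.txt")) as f:
                state = f.read()
```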
```python
                self._callbacks.on_trial_recover(
                    iteration=self._iteration, trials=self._trials, trial=trial
                )
            elif trial.status in {Trial.RUNNING, Trial.PENDING}:
```
Trials may now be PENDING and failing, e.g. when a save resolves late.
This is if a trial is paused, schedules a save, then gets unpaused (set to PENDING), and then the save fails?
That's right!
```python
        if trial.temporary_state.saving_to:
            # If a save is already in progress, don't schedule another one.
            return trial.temporary_state.saving_to
```
This also de-dupes checkpoints
old codepath
Can we also add a type hint to `saving_to` in the `_TemporaryTrialState` class? It's always a `_FutureTrainingResult` now.
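Something like this, as a simplified sketch of the suggested type hint (the surrounding class is stripped down for illustration):

```python
from typing import Optional


class _TemporaryTrialState:
    def __init__(self) -> None:
        # After this PR, this is always a _FutureTrainingResult (or None).
        self.saving_to: Optional["_FutureTrainingResult"] = None
```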
```diff
         if checkpoint.dir_or_data is None:
             logger.debug(f"Not restoring trial {trial}: No checkpoint found.")
             return False

         kwargs = {}

-        if checkpoint.storage_mode == CheckpointStorage.MEMORY:
+        if checkpoint.storage_mode == CheckpointStorage.MEMORY or isinstance(
+            checkpoint.dir_or_data, ray.ObjectRef
```
We schedule some restores directly from the object ref
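Conceptually, restoring from an object ref just means fetching the data from the object store first; a hedged sketch (the helper name is hypothetical):

```python
import ray


def _load_checkpoint_data(dir_or_data):
    # If the checkpoint data lives in the Ray object store, fetch it first;
    # otherwise it is already a local path or an in-memory object.
    if isinstance(dir_or_data, ray.ObjectRef):
        return ray.get(dir_or_data)
    return dir_or_data
```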
old codepath
```python
                trial.status == Trial.PAUSED
                and trial in bracket.trials_to_unpause
```
This is a change in the HyperBand scheduler: instead of directly setting the trial status to PENDING, we add the trials to the trials_to_unpause set. The scheduler will then select them in choose_trial_to_run.
Nice, I like this.
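A simplified sketch of that flow (the bracket iteration and attribute names follow the comment above, but the body is illustrative only, not the actual HyperBand implementation):

```python
from ray.tune.experiment import Trial


def choose_trial_to_run(brackets):
    """Sketch only: hand paused-but-unpausable trials back to the control loop."""
    for bracket in brackets:
        # Brackets no longer flip trials to PENDING themselves; they only
        # record them in `trials_to_unpause`.
        for trial in bracket.trials_to_unpause:
            if trial.status == Trial.PAUSED:
                return trial
    return None
```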
```python
            )

            if isinstance(last_checkpoint, _FutureTrainingResult):
                training_result = last_checkpoint.resolve()
```
Here we actually synchronously wait for the save to resolve. This may impact performance slightly, but let's see if it's actually a problem. We only wait for trials we actually exploit, and we do need the attached result.
I think it's good to go.

I have one suggestion for clarity that I'd like to get in, but it's not 100% needed. Basically, let's make `PBTTrialState.last_checkpoint` always either a `_FutureTrainingResult` or a `_TrainingResult`. Right now it can be those as well as a `Checkpoint`.

Basically, here, just set `clone_state.last_checkpoint = training_result`:

ray/python/ray/tune/schedulers/pbt.py, lines 698 to 712 in c3bf12b:
```python
last_checkpoint = clone_state.last_checkpoint
logger.debug(
    f"Trial {trial} is in lower quantile. "
    f"Exploiting trial {trial_to_clone}."
)
if isinstance(last_checkpoint, _FutureTrainingResult):
    training_result = last_checkpoint.resolve()
    if training_result:
        clone_state.last_result = training_result.metrics
        clone_state.last_checkpoint = training_result.checkpoint
        last_checkpoint = clone_state.last_checkpoint
    else:
```
Then there's no need to construct a new "training result" on exploit.
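In other words, roughly (a sketch of the suggested change applied to the snippet above, not the merged code):

```python
if isinstance(last_checkpoint, _FutureTrainingResult):
    training_result = last_checkpoint.resolve()
    if training_result:
        clone_state.last_result = training_result.metrics
        # Keep the invariant: last_checkpoint is always either a
        # _FutureTrainingResult or a _TrainingResult, never a raw Checkpoint.
        clone_state.last_checkpoint = training_result
```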
It makes sense, but it's a bit confusing (why is a ...). Let's keep it as is for now (the whole situation is very confusing right now anyway: two different checkpoints, three checkpoint managers, training results, tracked checkpoints...). But I'd like to specifically make time to clean up the leftovers from the old code path.
Ok, good with me. Here's what should remain after everything is all cleaned up:
Necessary change for Train
Why are these changes needed?
Previously, PBT and BOHB did not work with the new storage context. This is because they directly affect the Ray Tune control flow and rely on specific behavior.
In particular, PBT and BOHB made heavy use of pausing trials, and PBT also saved and restored from "objects".
However, the new storage backend does not support memory checkpoints anymore, and saving and restoring from objects was also removed. This means PBT and BOHB have to be adjusted and the pausing logic has to be revamped.
For BOHB/HyperBand in general, we now avoid controlling trial status directly and instead rely on choose_trial_to_run to unpause trials. This means the Tune control loop has greater control over the pausing logic.

In the Tune control loop, we now detect double saves. If a save() future is scheduled while another one is in flight (which can happen in PBT), we don't schedule another save.
In PBT, we now schedule persistent checkpoints instead of memory checkpoints. The main difference here is that a persistent checkpoint may get deleted before the exploiting trial gets a chance to load from it. For this reason, we detect too-small num_to_keep values and log a warning.
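A hedged sketch of that warning (the threshold, helper name, and message wording are made up for illustration):

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)


def _warn_if_num_to_keep_too_small(num_to_keep: Optional[int], min_needed: int = 2) -> None:
    # PBT exploitation loads from persistent checkpoints, which may already
    # have been deleted if num_to_keep is very small.
    if num_to_keep is not None and num_to_keep < min_needed:
        logger.warning(
            f"CheckpointConfig(num_to_keep={num_to_keep}) is very low for PBT: "
            "exploited checkpoints may be deleted before they can be loaded. "
            f"Consider setting num_to_keep >= {min_needed}, or None to keep all checkpoints."
        )
```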
Related issue number
Closes #38569
Checks
- I've signed off every commit (git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I added a new method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.