[tune] Fix hyperband r calculation and stopping #39157
Signed-off-by: Kai Fricke <kai@anyscale.com>
I think these fixes make sense. I'm not so familiar with our hyperband implementation so would need to spend more time on this for a full review.
If tests are passing, happy to get this in now (if aiming for 2.7)
@@ -276,7 +276,7 @@ def _process_bracket(
 # kill bad trials
 self._num_stopped += len(bad)
 for t in bad:
-    if t.status == Trial.PAUSED:
+    if t.status == Trial.PAUSED or t.is_saving:
May lose the checkpoint here if we stop while saving?
This change is avoiding the situation when the trial is in the middle of pausing:
    if should_checkpoint:
        self._cached_trial_decisions[trial.trial_id] = TrialScheduler.PAUSE
        future_result = self._schedule_trial_save(
            trial=trial, storage=CheckpointStorage.PERSISTENT
        )
        trial.temporary_state.saving_to = future_result
where the status is technically still "RUNNING" while the save is resolving.
I think it's fine to miss the checkpoint here. The trial is bad anyway and is going to be early-terminated.
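To make the mid-pause state concrete, here is a toy model of the condition in the diff. It is only a sketch: `ToyTrial` is a made-up stand-in, not the real `Trial` class, and it assumes that `is_saving` is true exactly when `saving_to` holds a pending save result, which the snippet above suggests but does not confirm.

```python
from dataclasses import dataclass
from typing import Optional

PAUSED, RUNNING = "PAUSED", "RUNNING"


@dataclass
class ToyTrial:
    """Hypothetical stand-in for a Tune trial (not the real Trial class)."""

    status: str
    saving_to: Optional[object] = None  # pending save result, if any

    @property
    def is_saving(self) -> bool:
        # Assumption: a trial counts as saving while a save result is pending.
        return self.saving_to is not None


def old_check(t: ToyTrial) -> bool:
    # Condition before this PR.
    return t.status == PAUSED


def new_check(t: ToyTrial) -> bool:
    # Condition after this PR.
    return t.status == PAUSED or t.is_saving


# A trial that is in the middle of pausing: still RUNNING, save in flight.
mid_pause = ToyTrial(status=RUNNING, saving_to=object())
paused = ToyTrial(status=PAUSED)

print(old_check(mid_pause), new_check(mid_pause))
print(old_check(paused), new_check(paused))
```

Under these assumptions, only the new condition treats the mid-pause trial the same way as an already paused one.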
Tests seem to be passing fine - and yes, let's pick for 2.7 as it fixes behavior relating to the new storage path
Why are these changes needed?
This PR fixes two flaws in the current hyperband implementation.

#### 1. Bug in the `r` calculation.

#1620 introduced a minimum constraint in the `r` calculation during successive halving with `r = min(r, max_t - prev_r)`. It's unclear to me where this is coming from (cc @richardliaw), but in my opinion this is flawed.

E.g. for `s=1`, `max_t=8` and `eta=2`, we get `r0 = 4`. Then `r = r0 * 2 = 8`. With the current formula, we then get `r1 = min(r, max_t - r0) = min(8, 8-4) = 4`. Thus, `r0 = r1` and the bracket already "finished" after 4 iterations and should terminate all trials. Or in other words, none of the trials in this bracket will ever proceed.

I believe the correct fix here is to set `r = min(r, max_t)`. I couldn't find a reference implementation for comparison, but it logically makes sense and seems to match the formula behavior described in the paper.

#### 2. Stopping of "overstepped" trials.
The first bug revealed a second shortcoming in the current implementation. When a trial reports a timestep that is higher than `r_i` and `r_(i+1)`, it can hang forever. This is because "good" trials are only continued if `bracket.continue_trial(t)` returns `True`. However, if the trial already overstepped `r_(i+1)`, this can return `False` (specifically when `stop_last_trials=True`). In that case, a paused or running trial will be neither terminated nor continued, and will instead hang forever.

This second case is fixed in this PR by introducing another clause in the processing of "good" trials that checks for this condition.
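The arithmetic from item 1 can be replayed in a few lines. This is a standalone sketch, not the scheduler code itself: the helper names `next_r_old` and `next_r_fixed` are made up for illustration, and each one applies a single successive-halving update to `r` using the formulas quoted above.

```python
def next_r_old(r: int, eta: int, max_t: int) -> int:
    """One update with the pre-fix formula: r = min(r, max_t - prev_r)."""
    prev_r = r
    r = r * eta
    return min(r, max_t - prev_r)


def next_r_fixed(r: int, eta: int, max_t: int) -> int:
    """One update with the proposed formula: r = min(r, max_t)."""
    r = r * eta
    return min(r, max_t)


# Values from the example in item 1: s=1, max_t=8, eta=2 gives r0 = 4.
r0, eta, max_t = 4, 2, 8

r1_old = next_r_old(r0, eta, max_t)      # min(8, 8 - 4)
r1_fixed = next_r_fixed(r0, eta, max_t)  # min(8, 8)

print(f"old:   r0={r0}, r1={r1_old}")
print(f"fixed: r0={r0}, r1={r1_fixed}")
```

With the old formula `r1` equals `r0`, so the milestone never advances. With the fixed formula it moves to `max_t`.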
Related issue number
Closes #37605