[Train] Skip incrementing failure counter on preemption node died failures #41285
Conversation
@@ -1432,9 +1436,6 @@ def _schedule_trial_stop(self, trial: Trial, exception: Optional[Exception] = None):
     self._set_trial_status(trial, Trial.ERROR if exception else Trial.TERMINATED)
     trial.set_location(_Location())
-
-    if exception:
-        trial.handle_error(exc=exception)
Key change 1: The reason for moving this `trial.handle_error` is:
- `handle_error` is what increments the number of failures.
- Upon a trial failure, we used to check `trial.should_recover` BEFORE incrementing the new failure, so `trial.num_failures` is 1 less than it should be at that check. (Let's say `num_failures=2, max_failures=3` at this point.) `handle_error` would happen afterwards right here. (`num_failures=3` now.)
- The old `should_recover` condition made it impossible for us to try recovering on a preemption error. (Even though `handle_error` would no-op for the preemption error, we're already at `num_failures == max_failures == 3`, so `should_recover=False`.)
- It's more intuitive to have `num_failures` updated by the time of the `trial.should_recover` check, so now we just handle the error separately. (A toy sketch of the new ordering follows below.)
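To make the new ordering concrete, here's a minimal toy sketch (a stand-in class, not Tune's actual `Trial`; the preemption handling is simplified for illustration):

class ToyTrial:
    """Toy stand-in for Tune's Trial; only models the failure counter."""

    def __init__(self, max_failures: int):
        self.num_failures = 0
        self.max_failures = max_failures

    def handle_error(self, preempted: bool = False) -> None:
        # Preempted errors no longer count toward the failure budget.
        if not preempted:
            self.num_failures += 1

    @property
    def should_recover(self) -> bool:
        # num_failures already includes the failure being handled,
        # so the comparison uses <= instead of <.
        return self.num_failures <= self.max_failures or self.max_failures < 0


trial = ToyTrial(max_failures=3)
for _ in range(3):
    trial.handle_error()         # count the failure first...
    assert trial.should_recover  # ...then check: 3 <= 3 still recovers

trial.handle_error(preempted=True)
assert trial.should_recover      # preemption doesn't consume the budget

trial.handle_error()
assert not trial.should_recover  # 4 > 3: stop recovering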
Seems like we need to call `trial.handle_error` for all the other places that are currently calling `_schedule_trial_stop`?
In general though, I think this does move us in the right direction. In the long term we should clean up the state machine in the tune controller; in its current state it's not really clear where error handling is supposed to take place. 😵‍💫
Yep, I call it in the 2 other places.
python/ray/tune/experiment/trial.py
Outdated
def _handle_ray_actor_error(self, exc: RayActorError):
    exc._preempted = True  # TODO(justinvyu): Test the real integration
    if not exc._preempted:
        # Only count non-preempted actor errors as failures.
        self.run_metadata.num_failures += 1

def _handle_ray_task_error(self, exc: RayTaskError):
    if isinstance(exc.cause, RayActorError):
        # Handle the RayActorError directly (ex: Ray Train worker actor errors)
        return self._handle_ray_actor_error(exc.cause)

    # Increment failures for all user errors (which get raised as RayTaskError)
    self.run_metadata.num_failures += 1
Key change 2: this is the actual logic of the PR.

Question: Is it ok to treat `RayTaskError` with a cause of `RayActorError` so broadly like this? One strawman counterexample:

def tune_fn_trainable(config):
    e = RayActorError()
    e._preempted = True
    raise e

tune.Tuner(tune_fn_trainable).fit()
Another possibility would be to have the `DataParallelTrainer` pass through the pre-emption `RayActorError` as a special case, but I feel like that's more misleading, as it's disguising the coordinator's error with the worker's error.
TODO: use `exc.as_instanceof_cause()` instead of the private `cause` attr once that is fixed by @rkooo567
This seems okay for now... it's clean enough that if new use cases come up in the future, we can separate this logic and improve it further.
TODO: add configurability for whether or not to count preemption errors here.
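A hedged sketch of what that toggle could look like, reusing the `RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE` variable mentioned in the PR description (the helper name and wiring are hypothetical, not the merged implementation):

import os

def _count_preemption_as_failure() -> bool:
    # Hypothetical opt-in: count preemptions as failures again.
    return bool(int(os.environ.get("RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE", "0")))

def handle_actor_error(num_failures: int, preempted: bool) -> int:
    # Increment unless this is a preemption and preemptions are ignored.
    if not preempted or _count_preemption_as_failure():
        num_failures += 1
    return num_failures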
    `num_failures` should represent the number of times the trial has
    failed *up to the moment this method is called.* If we've failed
    5 times and `max_failures=5`, then we should recover, since
    we only pass the limit on the 6th failure.

    Note this may return true even when there is no checkpoint, either because
    `self.checkpoint_freq` is `0` or because the trial failed before
    a checkpoint has been made.
    """
    return (
-        self.run_metadata.num_failures < self.max_failures
-        or self.max_failures < 0
-        or (
-            self.run_metadata.num_failures == self.max_failures
-            and self.temporary_state.num_restore_failures
-            < int(os.environ.get("TUNE_RESTORE_RETRY_NUM", 0))
-        )
+        self.run_metadata.num_failures <= self.max_failures or self.max_failures < 0
    )
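As a worked example, here is the simplified check as a standalone function mirroring the new return expression (names illustrative):

def should_recover(num_failures: int, max_failures: int) -> bool:
    # max_failures < 0 means unlimited retries.
    return num_failures <= max_failures or max_failures < 0

assert should_recover(5, 5)      # failed 5 times with max_failures=5: still recover
assert not should_recover(6, 5)  # the 6th failure passes the limit
assert should_recover(100, -1)   # unlimited retries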
See key change 1 comment.
            self.run_metadata.num_failures == self.max_failures
            and self.temporary_state.num_restore_failures
            < int(os.environ.get("TUNE_RESTORE_RETRY_NUM", 0))
        )
I believe this condition is not needed anymore. `TUNE_RESTORE_RETRY_NUM` configures how many attempts we try to restore before it counts as a real error.

The behavior with this condition removed makes sense to me: If I'm at `num_failures == max_failures`, then I should try up to `TUNE_RESTORE_RETRY_NUM` times to restore. If all of those attempts fail, then we'll increment so that `num_failures > max_failures`, and the run will not try to recover anymore.
General question - what's the recommendation to users on how to configure checkpointing for Train jobs running on top of preemptible clusters?
UPDATE - Addressed offline.
python/ray/air/tests/test_errors.py
Outdated
- Round 0: Actor error in the training worker. (shouldn't be counted)
- Round 1: User error in the training worker.
- Round 2: Actor error in the coordinator actor. (shouldn't be counted)
- Round 3: No error.
Should we just run this as 4 separate jobs and check for each one whether it failed/counted?
Yeah good idea.
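A rough sketch of the split-up version (hypothetical names; one parametrized case per round instead of one chained run):

import pytest

@pytest.mark.parametrize(
    "failure_mode, counted",
    [
        ("worker_actor_error", False),       # Round 0: shouldn't be counted
        ("worker_user_error", True),         # Round 1
        ("coordinator_actor_error", False),  # Round 2: shouldn't be counted
        (None, False),                       # Round 3: no error
    ],
)
def test_failure_counting(failure_mode, counted):
    # Hypothetical body: launch one run per failure mode and assert on
    # whether num_failures was incremented, instead of chaining rounds.
    ...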
I am not able to figure out how to mock a property on the `RayActorError` that core raises -- any ideas here?

I tried this:

class MockRayActorError(ray.exceptions.RayActorError):
    preempted = True

monkeypatch.setattr(
    ray.tune.execution.tune_controller, "RayActorError", MockRayActorError
)
monkeypatch.setattr(ray.exceptions, "RayActorError", MockRayActorError)
I was planning on reworking this test to use the actual `gcs_client.drain_node` API to mock the preemption instead of mocking the attribute. (example here)
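One untested possibility: since Ray core raises the real `RayActorError` (so rebinding the name to a subclass never affects what actually gets raised), patch the `preempted` attribute on the real class instead:

import ray.exceptions

def test_preempted_error(monkeypatch):
    # Make every RayActorError raised during this test report preemption.
    monkeypatch.setattr(
        ray.exceptions.RayActorError,
        "preempted",
        property(lambda self: True),
        raising=False,
    )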
@jjyao any ideas here?
python/ray/tune/experiment/trial.py
Outdated
def _handle_restore_error(self, exc: _TuneRestoreError):
    exc = exc.exc
    if self.temporary_state.num_restore_failures >= int(
        os.environ.get("TUNE_RESTORE_RETRY_NUM", 0)
    ):
        # Restore was unsuccessful, try again without checkpoint.
        self.clear_checkpoint()
        self.run_metadata.num_failures += 1
    else:
        self.temporary_state.num_restore_failures += 1
Orthogonal to this change, but I'm wondering if we even want to keep this logic... it's not really clear to me why we remove the checkpoint and increase the number of failures.
This `handle_restore_error` happens when the call to `Trainable.restore` fails:
- This may be caused by a checkpoint download from cloud failing. Retrying without adding to the total failures counter may help here.
- There may be a bug in a user's `load_checkpoint` code. Retrying wouldn't help here.
- Function trainables don't do any logic in `restore`/`load_checkpoint`, leaving it to the user instead -- so this only really applies to class trainables.
The default behavior is a little strange though: `TUNE_RESTORE_RETRY_NUM=0` --> failures during restore clear the checkpoint, count toward `num_failures`, and the run starts from scratch immediately.

If we remove this logic, the behavior becomes: failures during restore are treated normally and keep retrying from the checkpoint until `max_failures`. I think it makes sense to remove this and `restoring_from` so that we have to keep track of less state in total. Let's do that in a separate PR.
    else:
        self.run_metadata.num_failures += 1

    if self.local_path:
        self.run_metadata.error_filename = EXPR_ERROR_FILE
-       if isinstance(exc, RayTaskError):
+       if isinstance(exc, (RayTaskError, RayActorError)):
Hmmm, given that we've never logged these before, when does `RayActorError` actually happen? Would it be a `RayActorError` or `RayTaskError` if the trial node gets preempted?
If the trial node gets preempted, it's a `RayActorError`.

ray.get(A.task.remote()) -> RayActorError if A's node dies
ray.get(A.task.remote()) -> RayTaskError(OriginalError) if A.task raises an OriginalError inside it.
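A runnable sketch of the second case, showing the user's exception surfacing as the `cause` attribute that the new `_handle_ray_task_error` inspects:

import ray
from ray.exceptions import RayTaskError

ray.init()

@ray.remote
class A:
    def task(self):
        raise ValueError("OriginalError")

a = A.remote()
try:
    ray.get(a.task.remote())
except RayTaskError as e:
    # The user's exception is preserved as the wrapper's cause.
    assert isinstance(e.cause, ValueError)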
I think it was just an oversight not to log `RayActorError`.
@@ -1896,8 +1899,6 @@ def _try_recover(self, trial: Trial, exc: Union[TuneError, RayTaskError]):
     # Resetting this, in case that the trial is in saving status when it crashes.
     if trial.is_saving:
         trial.temporary_state.saving_to = None
-    if trial.is_restoring and exc:
-        exc = _TuneRestoreError(exc)
     self._schedule_trial_stop(trial, exception=exc)
Do we need to call `trial.handle_error(exception)` before this one?
`try_recover` only gets called in `process_trial_failure`, which already calls `handle_error`.
Awesome!
@justinvyu Nit:
…d failures (#41285) (#41609)

Users expect different failure types to be handled differently in step 4 above:
* The current behavior is that the count decrements regardless of the error type. For example, if 3 pre-emptions happen with `max_failures=3`, then the run will end without continuing to recover through preemptions.
* With `max_failures=-1` or some large value, there will be an infinite number of retries, but this could crash-loop on an application error (ex: a bug in the user code). This can be very expensive.

This PR changes the failure counting of Ray Train/Tune to ignore spot instance preemption failures by default. This behavior is enabled by the new `RayActorError.preempted` flag introduced in #41102 that is set if the underlying cluster setup handles the cloud preemption signals properly and sets the preempting node to the `DRAINING` status.

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
@justinvyu did we address @zhe-thoughts's last comment above? (It was made post-merge and I didn't see any additional links in this ticket.)
Yeah, see the updated PR description.
Why are these changes needed?

Users expect different failure types to be handled differently:
* The current behavior is that the count decrements regardless of the error type. For example, if 3 pre-emptions happen with `max_failures=3`, then the run will end without continuing to recover through preemptions.
* With `max_failures=-1` or some large value, there will be an infinite number of retries, but this could crash-loop on an application error (ex: a bug in the user code). This can be very expensive.

This PR changes the failure counting of Ray Train/Tune to ignore spot instance preemption failures by default. This behavior is enabled by the new `RayActorError.preempted` flag introduced in #41102 that is set if the underlying cluster setup handles the cloud preemption signals properly and sets the preempting node to the `DRAINING` status.

Example
Here is an example scenario:
* Some Train workers fail with errors raised through `train.report`/`torch.distributed`, while worker D's node is preempted.
* The preemption surfaces to the driver as a `RayActorError(preempted=True)`.
* Failures with `preempted=True` do not increment the failure counter. This allows preemption failures to be retried repeatedly without contributing to `train.FailureConfig(max_failures=X)`.
* Preemption errors can still be configured to count toward the `X` max failures by setting the environment variable `RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE=1`. (A usage sketch follows below.)
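A usage sketch under these defaults (the training function is a placeholder; `TorchTrainer`, `ScalingConfig`, `RunConfig`, and `FailureConfig` are standard Ray Train APIs, but this snippet is illustrative rather than taken from the PR):

from ray import train
from ray.train.torch import TorchTrainer

def train_fn(config):
    # Placeholder training loop.
    train.report({"loss": 0.0})

trainer = TorchTrainer(
    train_loop_per_worker=train_fn,
    scaling_config=train.ScalingConfig(num_workers=4),
    # By default, only non-preemption failures count toward max_failures;
    # set RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE=1 to count preemptions too.
    run_config=train.RunConfig(failure_config=train.FailureConfig(max_failures=3)),
)
result = trainer.fit()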
Miscellaneous

This is the current output in `error.txt`. TODO: the numbering should be fixed, and some indication of ignored errors should be added in.

Related issue number
Checks
* I've signed off every commit (`git commit -s`) in this PR.
* I've run `scripts/format.sh` to lint the changes in this PR.
* If I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.