Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[air/output] Add callback hook for trial recovery, only print error table at end #37572

Merged
merged 17 commits into from
Aug 1, 2023
20 changes: 20 additions & 0 deletions python/ray/tune/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,22 @@ def on_trial_complete(
"""
pass

def on_trial_recover(
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
"""Called after a trial instance failed (errored) but the trial is scheduled
for retry.

The search algorithm and scheduler are not notified.

Arguments:
iteration: Number of iterations of the tuning loop.
trials: List of trials.
trial: Trial that just has errored.
**info: Kwargs dict for forward compatibility.
"""
pass

def on_trial_error(
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
Expand Down Expand Up @@ -399,6 +415,10 @@ def on_trial_complete(self, **info):
for callback in self._callbacks:
callback.on_trial_complete(**info)

def on_trial_recover(self, **info):
for callback in self._callbacks:
callback.on_trial_recover(**info)

def on_trial_error(self, **info):
for callback in self._callbacks:
callback.on_trial_error(**info)
Expand Down
5 changes: 4 additions & 1 deletion python/ray/tune/execution/trial_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,13 +666,16 @@ def _process_trial_failure(
if trial.status == Trial.RUNNING:
if trial.should_recover():
self._try_recover(trial, exc=exception)
self._callbacks.on_trial_recover(
iteration=self._iteration, trials=self._trials, trial=trial
)
else:
self._scheduler_alg.on_trial_error(self, trial)
self._search_alg.on_trial_complete(trial.trial_id, error=True)
self._schedule_trial_stop(trial, exception=exception)
self._callbacks.on_trial_error(
iteration=self._iteration, trials=self._trials, trial=trial
)
self._schedule_trial_stop(trial, exception=exception)

###
# STOP
Expand Down
26 changes: 26 additions & 0 deletions python/ray/tune/experimental/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,28 @@ def on_trial_complete(
)
self._print_result(trial)

def on_trial_error(
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
finished_iter = 0
if trial.last_result and TRAINING_ITERATION in trial.last_result:
finished_iter = trial.last_result[TRAINING_ITERATION]

self._start_block(f"trial_{trial}_error")
print(
f"{self._addressing_tmpl.format(trial)} "
f"errored after {finished_iter} iterations "
f"at {curr_time_str}. Total running time: {running_time_str}\n"
f"Error file: {trial.error_file}"
)
self._print_result(trial)

def on_trial_recover(
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
self.on_trial_error(iteration=iteration, trials=trials, trial=trial, **info)

def on_checkpoint(
self,
iteration: int,
Expand Down Expand Up @@ -970,6 +992,10 @@ def _print_heartbeat(self, trials, *sys_args, force: bool = False):
if more_infos:
print(", ".join(more_infos))

if not force:
# Only print error table at end of training
return

trials_with_error = _get_trials_with_error(trials)
if not trials_with_error:
return
Expand Down
Loading