[train/tune] Refactor trial metadata organization #38165

Merged · 28 commits · Aug 14, 2023
Changes from 20 commits
4 changes: 4 additions & 0 deletions python/ray/air/_internal/checkpoint_manager.py
@@ -301,6 +301,10 @@ def __init__(
 
         self.set_delete_fn(delete_fn)
 
+    @property
+    def checkpoint_config(self):
+        return self._checkpoint_strategy
+
     def set_delete_fn(
         self, delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]]
     ):
4 changes: 4 additions & 0 deletions python/ray/train/_internal/checkpoint_manager.py
@@ -82,6 +82,10 @@ def __init__(self, checkpoint_config: Optional[CheckpointConfig]):
                 f"{self._checkpoint_config.num_to_keep}"
             )
 
+    @property
+    def checkpoint_config(self):
+        return self._checkpoint_config
+
     def register_checkpoint(self, checkpoint_result: _TrainingResult):
         """Register new checkpoint and add to bookkeeping.
 
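Both additions follow the same pattern: each manager keeps the configuration it was constructed with in a private attribute and now exposes it through a read-only property. A minimal, self-contained sketch of that pattern (class and attribute names below are stand-ins for illustration, not Ray's actual internals):

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeCheckpointConfig:
    # Stand-in for ray.air.CheckpointConfig; only the field needed here.
    num_to_keep: Optional[int] = None


class FakeCheckpointManager:
    """Illustrative only: mirrors the property added in the two hunks above."""

    def __init__(self, checkpoint_config: FakeCheckpointConfig):
        self._checkpoint_config = checkpoint_config

    @property
    def checkpoint_config(self) -> FakeCheckpointConfig:
        # Callers can read the stored config but cannot rebind it from outside.
        return self._checkpoint_config


manager = FakeCheckpointManager(FakeCheckpointConfig(num_to_keep=2))
assert manager.checkpoint_config.num_to_keep == 2

Exposing the config through a property lets callers read settings such as num_to_keep without reaching into a private attribute.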
10 changes: 7 additions & 3 deletions python/ray/tune/analysis/experiment_analysis.py
@@ -211,11 +211,11 @@ def _load_checkpoints_from_latest(self, latest_checkpoint: List[str]) -> None:
                 experiment_state = json.load(f, cls=TuneFunctionDecoder)
                 self._experiment_states.append(experiment_state)
 
-            if "checkpoints" not in experiment_state:
+            if "trial_data" not in experiment_state:
                 raise TuneError("Experiment state invalid; no checkpoints found.")
 
             self._checkpoints_and_paths += [
-                (cp, Path(path).parent) for cp in experiment_state["checkpoints"]
+                (cp, Path(path).parent) for cp in experiment_state["trial_data"]
             ]
 
     def _maybe_download_experiment_checkpoint(
@@ -977,9 +977,13 @@ def _get_trial_paths(self) -> List[str]:
                 "since checkpointing is periodic."
             )
         self.trials = []
-        for trial_json_state, path in self._checkpoints_and_paths:
+        for (
+            trial_json_state,
+            trial_runtime_state,
+        ), path in self._checkpoints_and_paths:
             try:
                 trial = Trial.from_json_state(trial_json_state, stub=True)
+                trial.restore_runtime_state(trial_runtime_state)
                 trial.local_experiment_path = str(path)
             except Exception:
                 logger.warning(
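Taken together, these two hunks imply a new experiment-state layout: the top-level key is renamed from "checkpoints" to "trial_data", and each entry is a (trial_json_state, trial_runtime_state) pair rather than a single JSON blob. A rough, self-contained sketch of reading such a state back (the file path, field names, and the plain json.loads calls standing in for Trial.from_json_state / restore_runtime_state are illustrative assumptions, not Ray's exact schema):

import json
from pathlib import Path

# Hypothetical experiment-state content after this PR.
experiment_state = {
    "trial_data": [
        (
            json.dumps({"trial_id": "abc12345", "status": "TERMINATED"}),
            json.dumps({"last_result_time": 1692000000.0}),
        ),
    ],
    "runner_data": {},
}

path = "/tmp/my_experiment/experiment_state.json"  # illustrative path
checkpoints_and_paths = [
    (cp, Path(path).parent) for cp in experiment_state["trial_data"]
]

for (trial_json_state, trial_runtime_state), experiment_dir in checkpoints_and_paths:
    trial = json.loads(trial_json_state)        # stand-in for Trial.from_json_state
    runtime = json.loads(trial_runtime_state)   # stand-in for restore_runtime_state
    print(trial["trial_id"], runtime["last_result_time"], experiment_dir)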
47 changes: 25 additions & 22 deletions python/ray/tune/execution/tune_controller.py
@@ -446,7 +446,7 @@ def save_to_dir(self, experiment_dir: Optional[str] = None):
         # Get state from trial executor and runner
         runner_state = {
             # Trials
-            "checkpoints": list(self._get_trial_checkpoints().values()),
+            "trial_data": list(self._get_trial_checkpoints().values()),
             # Experiment data
             "runner_data": self.__getstate__(),
             # Metadata
@@ -529,8 +529,9 @@ def restore_from_dir(self, experiment_dir: Optional[str] = None) -> List[Trial]:
 
         # 3. Load trials
         trials = []
-        for trial_json_state in runner_state["checkpoints"]:
+        for trial_json_state, trial_runtime_metadata in runner_state["trial_data"]:
             trial = Trial.from_json_state(trial_json_state)
+            trial.restore_runtime_state(trial_runtime_metadata)
 
             # The following properties may be updated on restoration
             # Ex: moved local/cloud experiment directory
@@ -595,13 +596,15 @@ def resume(
         trials = self.restore_from_dir()
 
         # Set trial statuses according to the resume configuration
-        for trial in sorted(trials, key=lambda t: t.last_update_time, reverse=True):
+        for trial in sorted(
+            trials, key=lambda t: t.run_metadata.last_result_time, reverse=True
+        ):
             trial_to_add = trial
             if trial.status == Trial.ERROR:
                 if resume_errored:
                     # Keep trial ID on resume
-                    trial_to_add.error_filename = None
-                    trial_to_add.pickled_error_filename = None
+                    trial_to_add.run_metadata.error_filename = None
+                    trial_to_add.run_metadata.pickled_error_filename = None
                     trial_to_add.set_status(Trial.PENDING)
                     trial_to_add.restore_path = trial.checkpoint.dir_or_data
                 elif restart_errored:
@@ -1106,7 +1109,7 @@ def _maybe_reuse_cached_actor(self, trial: Trial) -> bool:
             ray_actor = self._actor_manager._live_actors_to_ray_actors_resources[
                 cached_actor
             ][0]
-            trial.set_runner(ray_actor)
+            trial.set_ray_actor(ray_actor)
 
             self._schedule_trial_reset(trial, trial.config, trial.experiment_tag)
 
@@ -1262,7 +1265,7 @@ def _maybe_cache_trial_actor(self, trial: Trial) -> bool:
         tracked_actor = self._trial_to_actor.pop(trial)
         self._actor_to_trial.pop(tracked_actor)
 
-        trial.set_runner(None)
+        trial.set_ray_actor(None)
 
         return True
 
@@ -1278,7 +1281,7 @@ def _actor_started(self, tracked_actor: TrackedActor, log: str = "STARTED"):
         ray_actor = self._actor_manager._live_actors_to_ray_actors_resources[
             tracked_actor
         ][0]
-        trial.set_runner(ray_actor)
+        trial.set_ray_actor(ray_actor)
 
         self._callbacks.on_trial_start(
             iteration=self._iteration, trials=self._trials, trial=trial
@@ -1296,7 +1299,7 @@ def _actor_stopped(self, tracked_actor: TrackedActor):
             trial = self._actor_to_trial.pop(tracked_actor)
             logger.debug(f"Actor STOPPED for trial {trial}: {tracked_actor}")
             self._trial_to_actor.pop(trial)
-            trial.set_runner(None)
+            trial.set_ray_actor(None)
 
         logger.debug(f"Actor STOPPED: {tracked_actor}")
 
@@ -1516,8 +1519,8 @@ def _schedule_trial_stop(self, trial: Trial, exception: Optional[Exception] = No
 
         logger.debug(f"Requesting to STOP actor for trial {trial}")
 
-        trial.saving_to = None
-        trial.restoring_from = None
+        trial.temporary_state.saving_to = None
+        trial.temporary_state.restoring_from = None
 
         self._set_trial_status(trial, Trial.ERROR if exception else Trial.TERMINATED)
         trial.set_location(_Location())
@@ -1544,7 +1547,7 @@ def _schedule_trial_stop(self, trial: Trial, exception: Optional[Exception] = No
         tracked_actor = self._trial_to_actor.pop(trial)
         self._actor_to_trial.pop(tracked_actor)
 
-        trial.set_runner(None)
+        trial.set_ray_actor(None)
 
         self._remove_actor(tracked_actor=tracked_actor)
 
@@ -1867,7 +1870,7 @@ def _schedule_trial_save(
             storage_mode=storage,
             metrics=result,
         )
-        trial.saving_to = checkpoint
+        trial.temporary_state.saving_to = checkpoint
 
         return checkpoint
 
@@ -1890,8 +1893,8 @@ def _on_saving_result(self, trial, checkpoint_value: Union[ray.ObjectRef, str]):
             "is being synced from the worker to the head node."
         )
 
-        if trial.location.hostname and (
-            trial.location.hostname != get_node_ip_address()
+        if trial.temporary_state.location.hostname and (
+            trial.temporary_state.location.hostname != get_node_ip_address()
         ):
             if log_once("tune_head_worker_checkpoint"):
                 logger.warning(msg)
@@ -1920,14 +1923,14 @@ def _process_trial_save(
                 self._checkpoint_manager.on_trial_checkpoint(trial)
                 self._mark_trial_to_checkpoint(trial)
             else:
-                trial.saving_to.dir_or_data = checkpoint_value
+                trial.temporary_state.saving_to.dir_or_data = checkpoint_value
                 self._callbacks.on_checkpoint(
                     iteration=self._iteration,
                     trials=self._trials,
                     trial=trial,
-                    checkpoint=trial.saving_to,
+                    checkpoint=trial.temporary_state.saving_to,
                 )
-                trial.on_checkpoint(trial.saving_to)
+                trial.on_checkpoint(trial.temporary_state.saving_to)
                 self._checkpoint_manager.on_trial_checkpoint(trial)
                 if trial.checkpoint.storage_mode != CheckpointStorage.MEMORY:
                     self._mark_trial_to_checkpoint(trial)
@@ -1941,7 +1944,7 @@ def _process_trial_save(
                 "Trial %s: Error handling checkpoint %s", trial, checkpoint_value
             )
 
-        trial.saving_to = None
+        trial.temporary_state.saving_to = None
         decision = self._cached_trial_decisions.pop(trial.trial_id, None)
         if decision and checkpoint_value:
             self._queue_decision(trial, decision)
@@ -1950,7 +1953,7 @@ def _checkpoint_trial_if_needed(self, trial, force=False):
         """Checkpoints trial based off trial.last_result."""
         if trial.should_checkpoint() or force:
             # Save trial runtime if possible.
-            if trial.runner:
+            if trial.temporary_state.ray_actor:
                 self._schedule_trial_save(trial, storage=CheckpointStorage.PERSISTENT)
 
     ###
@@ -2017,7 +2020,7 @@ def _schedule_trial_restore(self, trial: Trial) -> bool:
                 "storage-based restoration"
             )
 
-        trial.restoring_from = checkpoint
+        trial.temporary_state.restoring_from = checkpoint
         self._schedule_trial_task(
             trial=trial,
             method_name=method_name,
@@ -2056,7 +2059,7 @@ def _try_recover(self, trial: Trial, exc: Union[TuneError, RayTaskError]):
         self._cached_trial_decisions.pop(trial.trial_id, None)
         # Resetting this, in case that the trial is in saving status when it crashes.
         if trial.is_saving:
-            trial.saving_to = None
+            trial.temporary_state.saving_to = None
         if trial.is_restoring and exc:
             exc = _TuneRestoreError(exc)
         self._schedule_trial_stop(trial, exception=exc)
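The tune_controller changes read as one rename plus one regrouping: trial.set_runner / trial.runner become trial.set_ray_actor / trial.temporary_state.ray_actor, and ephemeral fields such as saving_to, restoring_from, and location move under temporary_state, while error filenames and the last-result timestamp are read from run_metadata. A condensed sketch of what such a split on Trial could look like (the grouping is inferred from this diff; the actual class layout may differ):

import time
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class _TemporaryTrialState:
    # Ephemeral, per-run state that is not persisted with the experiment.
    ray_actor: Optional[Any] = None
    saving_to: Optional[Any] = None
    restoring_from: Optional[Any] = None
    location: Optional[Any] = None


@dataclass
class _TrialRunMetadata:
    # Bookkeeping persisted alongside the trial's JSON state.
    last_result_time: float = -float("inf")
    error_filename: Optional[str] = None
    pickled_error_filename: Optional[str] = None


class SketchTrial:
    """Illustrative stand-in for a Trial after this refactor, not Ray's class."""

    def __init__(self, trial_id: str):
        self.trial_id = trial_id
        self.temporary_state = _TemporaryTrialState()
        self.run_metadata = _TrialRunMetadata()

    def set_ray_actor(self, ray_actor: Optional[Any]) -> None:
        # Replaces the old set_runner(); the actor handle is ephemeral state.
        self.temporary_state.ray_actor = ray_actor


trial = SketchTrial("abc12345")
trial.set_ray_actor(object())           # actor attached for the duration of the run
trial.run_metadata.last_result_time = time.time()
trial.set_ray_actor(None)               # cleared when the actor stops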