
[air] pyarrow.fs persistence (10/n): Unify Tune and Train sessions to support new persistence path in FunctionTrainable #38284

Merged Aug 12, 2023 · 34 commits

Commits
1d8f49b
Move _TrainingResult to session.py
justinvyu Aug 8, 2023
c88df86
Prototype unified session
justinvyu Aug 8, 2023
7bca696
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 9, 2023
8871231
Fix incorrect merge
justinvyu Aug 9, 2023
d87100d
Implement reset with unified session (for actor reuse)
justinvyu Aug 10, 2023
5455c2c
Eager mode in session (difference in tune/train behavior)
justinvyu Aug 10, 2023
e1bd135
Working for train again
justinvyu Aug 10, 2023
60ea1ee
Remove unused ckpt index code
justinvyu Aug 10, 2023
3d23988
Fix lint
justinvyu Aug 10, 2023
4d209e7
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 10, 2023
67a0721
Add dict checkpoint utils for tests
justinvyu Aug 9, 2023
2a35b3e
Add env var as a constant
justinvyu Aug 9, 2023
7915999
Remove prints
justinvyu Aug 11, 2023
74a2f19
Fix tune.run sync config = None issue
justinvyu Aug 11, 2023
dfff147
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 11, 2023
35aaf2c
Rename eager_mode -> synchronous_result_reporting
justinvyu Aug 11, 2023
d08e6e4
Improve some comments + some cleanups
justinvyu Aug 11, 2023
c2e33cc
Some more cleanups (remove unused code)
justinvyu Aug 11, 2023
9b475ec
Remove reference to global session
justinvyu Aug 11, 2023
bf32992
Remove shared storage context (store in global session instead)
justinvyu Aug 11, 2023
b59e0b6
Update tuner e2e test to test restoration / checkpointing
justinvyu Aug 11, 2023
86b5520
Fix lint
justinvyu Aug 11, 2023
6c3b4e5
More cleanups + docstrings
justinvyu Aug 11, 2023
ce9920b
synch result reporting logic is flipped...
justinvyu Aug 11, 2023
a97025f
Handle trainable outputs correctly
justinvyu Aug 11, 2023
b14bce5
Propagate storage on trial.reset (for restarting upon restore)
justinvyu Aug 11, 2023
0792283
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 11, 2023
90480aa
Guard the tune session assertion
justinvyu Aug 11, 2023
e9ff63c
thread join timeout = 0 on cleanup + report any errors left in the queue
justinvyu Aug 11, 2023
c3246dd
Add back saving_to for now
justinvyu Aug 11, 2023
7f68633
Convert path to str for env var
justinvyu Aug 11, 2023
1498f05
TIL local variables override imports even if they're set conditionally
justinvyu Aug 11, 2023
5b40250
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 11, 2023
5d72ca8
Clarify comment a bit
justinvyu Aug 11, 2023
python/ray/train/_internal/backend_executor.py (42 changes: 17 additions & 25 deletions)

@@ -11,13 +11,14 @@
from ray.train import DataConfig
from ray.air.checkpoint import Checkpoint
from ray.train._internal.session import (
_TrainSession,
TrainingResult,
TrialInfo,
get_session,
init_session,
shutdown_session,
)
from ray.train._internal.storage import _use_storage_context, get_storage_context
from ray.train._internal.storage import _use_storage_context
from ray.train._internal.utils import check_for_failure
from ray.train._internal.worker_group import WorkerGroup
from ray.train.backend import BackendConfig
@@ -424,6 +425,9 @@ def initialize_session(
node_rank_map,
) = self._create_rank_world_size_mappings()

tune_session: _TrainSession = get_session()
assert tune_session, "`start_training` should only be called from within Tune"
Review comment (Contributor Author):
Rllib's LearnerGroup uses BackendExecutor, and they may not be inside a Tune session (??). But they never call this code and only use it to start and stop a WorkerGroup. Maybe they should just use the WorkerGroup abstraction directly 😅


futures = []
for index in range(len(self.worker_group)):
futures.append(
@@ -444,8 +448,7 @@
checkpoint_upload_from_workers=(
self._checkpoint_upload_from_workers
),
# Pass the Trainable's shared storage context to the Train workers
storage=get_storage_context() if _use_storage_context() else None,
storage=tune_session.storage if _use_storage_context() else None,
)
)

@@ -506,30 +509,19 @@ def get_next():
else:
# Return None if all results are None.
return None
first_result = results[0]
result_type = first_result.type
if any(r.type != result_type for r in results):
raise RuntimeError(
"Some workers returned results with "
"different types. Make sure that "
"`session.report()` are called the "
"same number of times on all workers."
)

return results

def _set_checkpoint_index(self, checkpoint_index: int):
"""Update the checkpoint id in the StorageContext of all workers.

This determines the path that the next checkpoint will be saved to.
"""

def set_checkpoint_index():
session = _get_session("_set_checkpoint_index")
session.storage.current_checkpoint_index = checkpoint_index
if not _use_storage_context():
first_result = results[0]
result_type = first_result.type
if any(r.type != result_type for r in results):
raise RuntimeError(
"Some workers returned results with "
"different types. Make sure that "
"`session.report()` are called the "
"same number of times on all workers."
)

futures = self.worker_group.execute_async(set_checkpoint_index)
self.get_with_failure_handling(futures)
return results

def _set_legacy_checkpoint_uri(self, uri: str):
"""Tell remote sessions where to upload the chekcpoint."""
python/ray/train/_internal/checkpoint_manager.py (12 changes: 2 additions & 10 deletions)

@@ -1,26 +1,18 @@
import logging
import numbers
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple

from ray._private.dict import flatten_dict
from ray.air.config import MAX
from ray.air._internal.util import is_nan
from ray.train import CheckpointConfig
from ray.train._internal.storage import _delete_fs_path
from ray.train._checkpoint import Checkpoint
from ray.train._internal.session import _TrainingResult


logger = logging.getLogger(__name__)


class _TrainingResult:
"""A (checkpoint, metrics) result reported by the user."""

def __init__(self, checkpoint: Checkpoint, metrics: Dict[str, Any]):
self.checkpoint = checkpoint
self.metrics = metrics


def _insert_into_sorted_list(list: List[Any], item: Any, key: Callable[[Any], Any]):
"""Insert an item into a sorted list with a custom key function.
