Skip to content

Commit

Permalink
[air] pyarrow.fs persistence: Some circular dependency cleanup (#38227)

Browse files Browse the repository at this point in the history

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
  • Loading branch information
justinvyu committed Aug 8, 2023
1 parent 4e9e891 commit ed41e86
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 42 deletions.
11 changes: 6 additions & 5 deletions python/ray/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@
GenDataset,
TrainingFailedError,
)
from ray.tune.trainable.util import TrainableUtil
from ray.util.annotations import DeveloperAPI
from ray.train._internal.storage import _use_storage_context, get_storage_context
from ray.train._internal.storage import (
_use_storage_context,
StorageContext,
get_storage_context,
)


T = TypeVar("T")
Expand Down Expand Up @@ -257,7 +260,6 @@ def _fetch_next_result(self) -> Optional[List[Dict]]:
f"{[type in TrainingResultType]}"
)

# TODO(justinvyu): Remove unused code
def _finish_checkpointing(self):
while True:
results = self._backend_executor.get_next_results()
Expand All @@ -272,7 +274,6 @@ def _finish_checkpointing(self):
# TODO: Is this needed? I don't think this is ever called...
self._send_next_checkpoint_path_to_workers()

# TODO(justinvyu): Remove unused code
def _finish_training(self):
"""Finish training and return final results. Propagate any exceptions.
Expand Down Expand Up @@ -335,7 +336,7 @@ def __get_cloud_checkpoint_dir(self):
path = Path(session.get_trial_dir())
trial_dir_name = path.name
exp_dir_name = path.parent.name
checkpoint_dir_name = TrainableUtil._make_checkpoint_dir_name(
checkpoint_dir_name = StorageContext._make_checkpoint_dir_name(
self._checkpoint_manager._latest_checkpoint_id
)

Expand Down
5 changes: 1 addition & 4 deletions python/ray/tune/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@
from ray.air import CheckpointConfig
from ray.air._internal.uri_utils import URI
from ray.exceptions import RpcError
from ray.train._internal.storage import (
_use_storage_context,
StorageContext,
)
from ray.train._internal.storage import _use_storage_context, StorageContext
from ray.tune.error import TuneError
from ray.tune.registry import register_trainable, is_function_trainable
from ray.tune.result import _get_defaults_results_dir
Expand Down
12 changes: 4 additions & 8 deletions python/ray/tune/experiment/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@

import ray.cloudpickle as cloudpickle
from ray.exceptions import RayActorError, RayTaskError
from ray.train._internal.checkpoint_manager import (
_TrainingResult,
_CheckpointManager as _NewCheckpointManager,
)
from ray.train._internal.storage import _use_storage_context, StorageContext
from ray.tune import TuneError
from ray.tune.error import _TuneRestoreError
Expand Down Expand Up @@ -531,10 +535,6 @@ def __init__(
self.checkpoint_config = checkpoint_config

if _use_storage_context():
from ray.train._internal.checkpoint_manager import (
_CheckpointManager as _NewCheckpointManager,
)

self.checkpoint_manager = _NewCheckpointManager(
checkpoint_config=self.checkpoint_config
)
Expand Down Expand Up @@ -1075,8 +1075,6 @@ def on_checkpoint(self, checkpoint: _TrackedCheckpoint):
checkpoint: Checkpoint taken.
"""
if _use_storage_context():
from ray.train._internal.checkpoint_manager import _TrainingResult

checkpoint_result = checkpoint
assert isinstance(checkpoint_result, _TrainingResult)
self.checkpoint_manager.register_checkpoint(checkpoint_result)
Expand All @@ -1093,8 +1091,6 @@ def on_restore(self):
assert self.is_restoring

if _use_storage_context():
from ray.train._internal.checkpoint_manager import _TrainingResult

assert isinstance(self.restoring_from, _TrainingResult)

self.last_result = self.restoring_from.metrics
Expand Down
25 changes: 4 additions & 21 deletions python/ray/tune/trainable/function_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import warnings
from functools import partial
from numbers import Number
from typing import Any, Callable, Dict, Optional, Type, TYPE_CHECKING
from typing import Any, Callable, Dict, Optional, Type

from ray.air._internal.util import StartTraceback, RunnerThread
import queue
Expand All @@ -20,6 +20,9 @@
_RESULT_FETCH_TIMEOUT,
TIME_THIS_ITER_S,
)
from ray.train._checkpoint import Checkpoint as NewCheckpoint
from ray.train._internal.storage import _use_storage_context
from ray.train._internal.checkpoint_manager import _TrainingResult
from ray.tune import TuneError
from ray.tune.execution.placement_groups import PlacementGroupFactory
from ray.tune.trainable import session
Expand All @@ -37,9 +40,6 @@
from ray.util.annotations import DeveloperAPI
from ray.util.debug import log_once

if TYPE_CHECKING:
from ray.train._internal.checkpoint_manager import _TrainingResult


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -249,8 +249,6 @@ def has_new_checkpoint(self):
return self._fresh_checkpoint

def get_checkpoint_result(self) -> Optional["_TrainingResult"]:
from ray.train._internal.storage import _use_storage_context

assert _use_storage_context()
# The checkpoint is no longer fresh after it's been handed off to Tune.
self._fresh_checkpoint = False
Expand All @@ -267,10 +265,6 @@ def _start(self):
self._last_report_time = time.time()

def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
from ray.train._internal.storage import _use_storage_context
from ray.train._internal.checkpoint_manager import _TrainingResult
from ray.train._checkpoint import Checkpoint as NewCheckpoint

# TODO(xwjiang): Tons of optimizations.
self._air_session_has_reported = True

Expand Down Expand Up @@ -301,9 +295,6 @@ def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> N

@property
def loaded_checkpoint(self) -> Optional[Checkpoint]:
from ray.train._internal.storage import _use_storage_context
from ray.train._internal.checkpoint_manager import _TrainingResult

if _use_storage_context():
if not self._latest_checkpoint_result:
return None
Expand Down Expand Up @@ -502,8 +493,6 @@ def execute(self, fn):
def get_state(self):
state = super().get_state()

from ray.train._internal.storage import _use_storage_context

if _use_storage_context():
# TODO(justinvyu): This is only used to populate the tune metadata
# file within the checkpoint, so can be removed after if remove
Expand All @@ -519,9 +508,6 @@ def save_checkpoint(self, checkpoint_dir: str = ""):
if checkpoint_dir:
raise ValueError("Checkpoint dir should not be used with function API.")

from ray.train._internal.storage import _use_storage_context
from ray.train._internal.checkpoint_manager import _TrainingResult

if _use_storage_context():
checkpoint_result = self._status_reporter.get_checkpoint_result()
assert isinstance(checkpoint_result, _TrainingResult)
Expand Down Expand Up @@ -566,9 +552,6 @@ def save_to_object(self):
return checkpoint.to_bytes()

def load_checkpoint(self, checkpoint):
from ray.train._internal.storage import _use_storage_context
from ray.train._internal.checkpoint_manager import _TrainingResult

if _use_storage_context():
checkpoint_result = checkpoint
assert isinstance(checkpoint_result, _TrainingResult)
Expand Down
5 changes: 1 addition & 4 deletions python/ray/tune/trainable/trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
TIME_THIS_ITER_S,
TRAINING_ITERATION,
)
from ray.train._internal.checkpoint_manager import _TrainingResult
from ray.train._internal.storage import (
_use_storage_context,
StorageContext,
Expand Down Expand Up @@ -501,8 +502,6 @@ def save(
# User saves checkpoint
checkpoint_dict_or_path = self.save_checkpoint(checkpoint_dir)

from ray.train._internal.checkpoint_manager import _TrainingResult

if _use_storage_context() and isinstance(
checkpoint_dict_or_path, _TrainingResult
):
Expand Down Expand Up @@ -869,8 +868,6 @@ def restore(
"""
if _use_storage_context():
from ray.train._internal.checkpoint_manager import _TrainingResult

checkpoint_result = checkpoint_path
assert isinstance(checkpoint_result, _TrainingResult)

Expand Down

0 comments on commit ed41e86

Please sign in to comment.