[train] New persistence mode cleanup: Train internals (ray-project#39921)

This PR cleans up some train internals to remove the old codepath:
* Removes old train checkpoint managers
* Removes lazy checkpointing remnants
* Cleans up `TrainingIterator`, `BackendExecutor`, parts of `_TrainSession`, and `BaseTrainer`
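
For reference, the retained "new persistence mode" path that these internals now exclusively serve is the `train.report(..., checkpoint=...)` flow. A minimal sketch, assuming Ray >= 2.7 with the `ray.train.Checkpoint` API (the trainer choice, metric names, and checkpoint payload below are illustrative, not part of this PR):

```python
import os
import tempfile

from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_func():
    for epoch in range(2):
        # ... real training step elided ...
        with tempfile.TemporaryDirectory() as tmpdir:
            # Write worker state to a local directory, then hand it to
            # train.report(), which persists it through the storage context.
            with open(os.path.join(tmpdir, "model.pt"), "wb") as f:
                f.write(b"placeholder-weights")  # stand-in for torch.save(...)
            train.report(
                {"epoch": epoch, "loss": 1.0 / (epoch + 1)},
                checkpoint=train.Checkpoint.from_directory(tmpdir),
            )


trainer = TorchTrainer(train_func, scaling_config=ScalingConfig(num_workers=2))
result = trainer.fit()
print(result.metrics, result.checkpoint)
```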

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Victor <vctr.y.m@example.com>
justinvyu authored and Victor committed Oct 11, 2023
1 parent 8b16b07 · commit 0cd37fe
Showing 11 changed files with 103 additions and 907 deletions.
6 changes: 0 additions & 6 deletions python/ray/air/constants.py
```diff
@@ -88,16 +88,10 @@
     "TRAIN_COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING"
 )
 
-# Integer value which if set will disable lazy checkpointing
-# (avoiding unnecessary serialization if worker is on the same node
-# as Trainable)
-DISABLE_LAZY_CHECKPOINTING_ENV = "TRAIN_DISABLE_LAZY_CHECKPOINTING"
-
 # NOTE: When adding a new environment variable, please track it in this list.
 # TODO(ml-team): Most env var constants should get moved here.
 AIR_ENV_VARS = {
     COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV,
-    DISABLE_LAZY_CHECKPOINTING_ENV,
     "RAY_AIR_FULL_TRACEBACKS",
     "RAY_AIR_NEW_OUTPUT",
     "RAY_AIR_RICH_LAYOUT",
```
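
As an aside, the convention these constants follow looks roughly like this (a sketch inferred from the diff, not an official guide; `MY_FEATURE_ENV` is a hypothetical name):

```python
# Sketch of the env-var constant pattern: define the name, track it in
# AIR_ENV_VARS, and read it at the use site with env_integer.
from ray._private.ray_constants import env_integer

# Hypothetical constant, for illustration only.
MY_FEATURE_ENV = "TRAIN_ENABLE_MY_FEATURE"

# env_integer returns the integer value of the env var, or the default (0)
# if it is unset.
if env_integer(MY_FEATURE_ENV, 0):
    print("feature enabled")
```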
56 changes: 8 additions & 48 deletions python/ray/train/_internal/backend_executor.py
```diff
@@ -8,18 +8,17 @@
 import ray._private.ray_constants as ray_constants
 from ray.data import Dataset
 from ray._private.ray_constants import env_integer
-from ray.air.config import CheckpointConfig
 from ray.exceptions import RayActorError
 from ray.train import DataConfig
-from ray.air.checkpoint import Checkpoint
 from ray.train._internal.session import (
-    TrainingResult,
+    _TrainingResult,
     TrialInfo,
     get_session,
     init_session,
     shutdown_session,
 )
-from ray.train._internal.storage import _use_storage_context, StorageContext
+from ray.train._internal.storage import StorageContext
 from ray.train._internal.utils import check_for_failure
 from ray.train._internal.worker_group import WorkerGroup
 from ray.train.backend import BackendConfig
```
```diff
@@ -28,7 +27,6 @@
     ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
     TRAIN_ENABLE_WORKER_SPREAD_ENV,
     TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
-    DISABLE_LAZY_CHECKPOINTING_ENV,
     ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV,
 )
 from ray.util.placement_group import get_current_placement_group, remove_placement_group
```
```diff
@@ -96,7 +94,6 @@ def __init__(
         num_gpus_per_worker: float = 0,
         additional_resources_per_worker: Optional[Dict[str, float]] = None,
         max_retries: int = 3,
-        checkpoint_config: Optional[CheckpointConfig] = None,
     ):
         self._backend_config = backend_config
         self._backend = backend_config.backend_cls()
```
```diff
@@ -117,12 +114,6 @@
         self.worker_group = InactiveWorkerGroup()
         self.dataset_shards = None
 
-        self._checkpoint_keep_all_ranks = (
-            checkpoint_config and checkpoint_config._checkpoint_keep_all_ranks
-        )
-        self._checkpoint_upload_from_workers = (
-            checkpoint_config and checkpoint_config._checkpoint_upload_from_workers
-        )
         self._resource_configs = [
             ResourceConfig(
                 ray_constants.NEURON_CORES,
```
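
After this change, `BackendExecutor` is constructed without any checkpoint configuration. A rough sketch of a call site (internal API; anything beyond the parameters visible in this diff is an assumption):

```python
# Internal-API sketch only: parameters other than those visible in this
# diff (num_gpus_per_worker, additional_resources_per_worker, max_retries)
# are assumptions. Note there is no checkpoint_config argument anymore;
# checkpoint handling now lives with the storage context instead.
from ray.train._internal.backend_executor import BackendExecutor
from ray.train.torch import TorchConfig

executor = BackendExecutor(
    backend_config=TorchConfig(),  # any BackendConfig subclass
    max_retries=3,                 # retries on worker failure
)
executor.start()  # assumed setup call before training begins
```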
```diff
@@ -456,7 +447,6 @@ def start_training(
         use_detailed_autofilled_metrics = env_integer(
             ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, 0
         )
-        use_lazy_checkpointing = not env_integer(DISABLE_LAZY_CHECKPOINTING_ENV, 0)
 
         # First initialize the session.
         def initialize_session(
```
```diff
@@ -470,8 +460,6 @@ def initialize_session(
             checkpoint,
             dataset_shard,
             metadata,
-            checkpoint_keep_all_ranks,
-            checkpoint_upload_from_workers,
             storage,
         ):
             try:
```
```diff
@@ -487,9 +475,6 @@ def initialize_session(
                     metadata=metadata,
                     checkpoint=checkpoint,
                     detailed_autofilled_metrics=use_detailed_autofilled_metrics,
-                    enable_lazy_checkpointing=use_lazy_checkpointing,
-                    checkpoint_keep_all_ranks=checkpoint_keep_all_ranks,
-                    checkpoint_upload_from_workers=(checkpoint_upload_from_workers),
                     storage=storage,
                 )
             except ValueError:
```
```diff
@@ -532,10 +517,6 @@ def initialize_session(
                     dataset_shard=self.dataset_shards[index],
                     metadata=metadata,
                     checkpoint=checkpoint,
-                    checkpoint_keep_all_ranks=self._checkpoint_keep_all_ranks,
-                    checkpoint_upload_from_workers=(
-                        self._checkpoint_upload_from_workers
-                    ),
                     storage=storage,
                 )
             )
```
```diff
@@ -554,15 +535,15 @@ def train_async():
 
         self.worker_group.execute_async(train_async)
 
-    def get_next_results(self) -> Optional[List[TrainingResult]]:
-        """Fetches the next ``TrainingResult`` from each worker.
+    def get_next_results(self) -> Optional[List[_TrainingResult]]:
+        """Fetches the next ``_TrainingResult`` from each worker.
 
-        Each ``TrainingResult`` is expected to correspond to the same step from
-        each worker (e.g. the same call to ``session.report()``).
+        Each ``_TrainingResult`` is expected to correspond to the same step from
+        each worker (e.g. the same call to ``train.report()``).
 
         Returns:
-            A list of ``TrainingResult``s with the same
-            ``TrainingResultType``, or ``None`` if there are no more results.
+            A list of ``_TrainingResult``s or ``None`` if there are no more results
+            since the training function has exited on all workers.
         """
 
         def get_next():
```
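
To illustrate the contract in the updated docstring, a driver-side consumer might look like this (a sketch only; `backend_executor` is a hypothetical already-started `BackendExecutor`, and the `.metrics`/`.checkpoint` fields follow `_TrainingResult` as used in this PR):

```python
# Fetch-until-None loop over the internal results queue.
while True:
    results = backend_executor.get_next_results()
    if results is None:
        break  # the training function has exited on all workers
    # One _TrainingResult per worker, all from the same train.report() call.
    rank0_result = results[0]
    print(rank0_result.metrics, rank0_result.checkpoint)
```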
```diff
@@ -598,29 +579,8 @@ def get_next():
             # Return None if all results are None.
             return None
 
-        if not _use_storage_context():
-            first_result = results[0]
-            result_type = first_result.type
-            if any(r.type != result_type for r in results):
-                raise RuntimeError(
-                    "Some workers returned results with "
-                    "different types. Make sure that "
-                    "`session.report()` are called the "
-                    "same number of times on all workers."
-                )
-
         return results
 
-    def _set_legacy_checkpoint_uri(self, uri: str):
-        """Tell remote sessions where to upload the checkpoint."""
-
-        def set_uri():
-            session = _get_session("_set_legacy_checkpoint_uri")
-            session._set_legacy_checkpoint_uri(uri)
-
-        futures = self.worker_group.execute_async(set_uri)
-        self.get_with_failure_handling(futures)
-
     def pause_reporting(self):
         """Disable workers from enqueuing results from ``session.report()``.
```
[Diffs for the remaining 9 changed files are not shown.]
