[Train] Skip incrementing failure counter on preemption node died failures #41285

Merged
26 commits merged on Dec 5, 2023

Changes from 19 commits

Commits (26):
2c80249
fix failure config error msg
justinvyu Nov 20, 2023
d676171
add logic for mocked preemption failure in tune error handling
justinvyu Nov 20, 2023
a89ec45
slight refactor of backend exec for mocking in test
justinvyu Nov 20, 2023
1b3b51f
add test
justinvyu Nov 20, 2023
72a84fc
move handle error logic out of schedule_trial_stop, and update should…
justinvyu Nov 21, 2023
dda8466
update test
justinvyu Nov 21, 2023
5050aef
Merge branch 'master' of https://github.com/ray-project/ray into hand…
justinvyu Nov 27, 2023
80eeb5f
fix failing test (and remove unneeded tune restore error)
justinvyu Nov 27, 2023
ba62f43
rework test try 1
justinvyu Nov 28, 2023
14cda17
Revert "rework test try 1"
justinvyu Nov 28, 2023
8fc0c1b
Merge branch 'master' of https://github.com/ray-project/ray into hand…
justinvyu Nov 29, 2023
2a75f6e
use public as_instanceof_cause
justinvyu Nov 29, 2023
9bdfb44
use correct core api
justinvyu Nov 29, 2023
dbfc3d3
add configuration env var
justinvyu Nov 29, 2023
ba3e298
add unit test
justinvyu Nov 29, 2023
6d42044
remove .run_metadata
justinvyu Nov 29, 2023
3205f05
revamp the integration test
justinvyu Nov 29, 2023
c5bd712
add todo to remove
justinvyu Nov 29, 2023
c4bdb21
Merge branch 'master' of https://github.com/ray-project/ray into hand…
justinvyu Nov 29, 2023
39501f4
fix ray init error
justinvyu Nov 29, 2023
3df81fa
rename env var
justinvyu Nov 29, 2023
546843d
Merge branch 'master' of https://github.com/ray-project/ray into hand…
justinvyu Nov 29, 2023
a83fae2
increase test timeouts
justinvyu Nov 29, 2023
df5c05a
Merge branch 'master' of https://github.com/ray-project/ray into hand…
justinvyu Dec 5, 2023
1a44f3d
remove the thing i need to remove
justinvyu Dec 5, 2023
714e1a8
fix test for preempted property
justinvyu Dec 5, 2023
21 changes: 10 additions & 11 deletions python/ray/air/config.py
@@ -382,29 +382,28 @@ class FailureConfig:
Will recover from the latest checkpoint if present.
Setting to -1 will lead to infinite recovery retries.
Setting to 0 will disable retries. Defaults to 0.
fail_fast: Whether to fail upon the first error. Only used for
Ray Tune - this does not apply
to single training runs (e.g. with ``Trainer.fit()``).
If fail_fast='raise' provided, Ray Tune will automatically
raise the exception received by the Trainable. fail_fast='raise'
can easily leak resources and should be used with caution (it
is best used with `ray.init(local_mode=True)`).
fail_fast: Whether to fail upon the first error.
If fail_fast='raise' provided, the original error during training will be
immediately raised. fail_fast='raise' can easily leak resources and
should be used with caution.
"""

max_failures: int = 0
fail_fast: Union[bool, str] = False

def __post_init__(self):
# Same check as in tune.run
if self.fail_fast and self.max_failures != 0:
raise ValueError("max_failures must be 0 if fail_fast=True.")

# Same check as in TuneController
if not (isinstance(self.fail_fast, bool) or self.fail_fast.upper() == "RAISE"):
raise ValueError(
"fail_fast must be one of {bool, 'raise'}. " f"Got {self.fail_fast}."
)

# Same check as in tune.run
if self.fail_fast and self.max_failures != 0:
raise ValueError(
f"max_failures must be 0 if fail_fast={repr(self.fail_fast)}."
)

def __repr__(self):
return _repr_dataclass(self)

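For reference, a minimal sketch of how the validation above behaves from the user's side (the error message matches the updated check in this diff; everything else is standard `FailureConfig` usage):

```python
from ray.train import FailureConfig

# Valid: retry a failed trial up to 3 times from the latest checkpoint.
FailureConfig(max_failures=3)

# Valid: immediately re-raise the original training error.
FailureConfig(fail_fast="raise")

# Invalid: fail_fast and retries are mutually exclusive, so __post_init__
# raises "max_failures must be 0 if fail_fast=True."
try:
    FailureConfig(fail_fast=True, max_failures=3)
except ValueError as e:
    print(e)
```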
126 changes: 124 additions & 2 deletions python/ray/air/tests/test_errors.py
@@ -1,5 +1,5 @@
"""
This test suite covers error handling and propagation in Ray AIR.
This test suite covers error handling and propagation in Ray Train/Tune.

There are two main error types to test:
1. Trainable errors: These happen in the remote actor itself.
@@ -16,16 +16,26 @@
- Assert how errors from the Tune driver get propagated to the user.
"""
import gc
import pytest
import threading
import time
from tempfile import TemporaryDirectory

import pytest

import ray
from ray import train, tune
from ray._private.test_utils import wait_for_condition
from ray._raylet import GcsClient
from ray.cluster_utils import Cluster
from ray.core.generated import autoscaler_pb2
from ray.train import Checkpoint, FailureConfig, RunConfig, ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.train.trainer import BaseTrainer, TrainingFailedError
from ray.tune import Tuner, TuneConfig, TuneError

from ray.tests.conftest import * # noqa
from ray.train.tests.util import create_dict_checkpoint, load_dict_checkpoint


@pytest.fixture(scope="module")
def ray_start_4_cpus():
Expand All @@ -43,6 +53,34 @@ def gc_collect():
gc.collect()


@pytest.fixture
def cluster_setup(ray_start_cluster_head: Cluster):
# Sets up a cluster with 4 nodes: head node + 3 workers
cluster = ray_start_cluster_head
nodes = []
nodes.append(cluster.add_node(resources={"worker1": 1, "coordinator": 1}))
nodes.append(cluster.add_node(resources={"worker2": 1, "cpu": 1}))
nodes.append(cluster.add_node(resources={"worker3": 1, "cpu": 1}))
cluster.wait_for_nodes()

@ray.remote
def get_node_id():
return ray.get_runtime_context().get_node_id()

worker1_node_id = ray.get(get_node_id.options(resources={"worker1": 1}).remote())
worker2_node_id = ray.get(get_node_id.options(resources={"worker2": 1}).remote())
worker3_node_id = ray.get(get_node_id.options(resources={"worker3": 1}).remote())
wait_for_condition(
lambda: len({node["NodeID"] for node in ray.nodes() if (node["Alive"])}) == 4
)

yield cluster, nodes, [
worker1_node_id,
worker2_node_id,
worker3_node_id,
]


class _TestSpecificError(RuntimeError):
pass

@@ -195,6 +233,90 @@ def test_driver_error_with_trainer(ray_start_4_cpus, tmp_path, error_on):
assert TrainingFailedError._FAILURE_CONFIG_MSG not in str(exc_info.value)


@pytest.mark.parametrize("error_at_level", ["worker", "coordinator"])
def test_preemption_handling(
cluster_setup,
tmp_path,
error_at_level: str,
):
"""Integration test for node preemption handling in Ray Train/Tune.
Even though `max_failures=0`, preemption errors should still be retried."""
cluster, nodes, node_ids = cluster_setup
# node 1 = coordinator, node 2 = worker, node 3 = worker
coordinator_node, worker_node, _ = nodes
coordinator_node_id, worker_node_id, _ = node_ids

num_workers = 2
tmp_path.joinpath("markers").mkdir()

def train_fn(config):
checkpoint = train.get_checkpoint()
start_iter = 0
if checkpoint:
start_iter = load_dict_checkpoint(checkpoint)["iter"] + 1
print(f"Restored at iter = {start_iter}")

for iter in range(start_iter, 6):
with create_dict_checkpoint({"iter": iter}) as checkpoint:
ray.train.report({"iter": iter}, checkpoint=checkpoint)

if iter == 2:
# Write a "done marker" to tell the driver to simulate a preemption.
tmp_path.joinpath(
"markers", str(ray.train.get_context().get_world_rank())
).touch()
# Await execution.
time.sleep(120)

def launch_training():
trainer = DataParallelTrainer(
train_loop_per_worker=train_fn,
scaling_config=ScalingConfig(
num_workers=num_workers,
trainer_resources={"coordinator": 1},
resources_per_worker={"cpu": 1}, # worker2 and worker3
),
run_config=RunConfig(
storage_path=str(tmp_path),
name="test_preemption_error",
failure_config=train.FailureConfig(fail_fast=False, max_failures=0),
),
)
result = trainer.fit()
assert result.metrics["iter"] == 5

t = threading.Thread(target=launch_training)
t.start()

# Wait until the workers are ready for preemption (after a few checkpoints).
while len(list(tmp_path.joinpath("markers").glob("*"))) < num_workers:
time.sleep(0.5)

if error_at_level == "coordinator":
node, node_id = coordinator_node, coordinator_node_id
elif error_at_level == "worker":
node, node_id = worker_node, worker_node_id
else:
raise NotImplementedError(f"Invalid error_at_level = {error_at_level}")

# Preempt a node.
gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address)
print("Draining node...")
is_accepted = gcs_client.drain_node(
node_id,
autoscaler_pb2.DrainNodeReason.Value("DRAIN_NODE_REASON_PREEMPTION"),
"preemption",
)
assert is_accepted
print("Killing node...")
cluster.remove_node(node, allow_graceful=True)
print("Adding new node..") # so that the job can be rescheduled
# New node can replace a preempted coordinator or worker
# NOTE: `cluster.add_node` only works in the main thread.
cluster.add_node(resources={"coordinator": 1, "cpu": 1})
t.join() # Assert no errors during training.


if __name__ == "__main__":
import sys

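As an aside, the preemption simulation used by the test above can be read as a small, self-contained helper; this is just the test's own calls repackaged, and the `simulate_preemption` name is illustrative:

```python
import ray
from ray._raylet import GcsClient
from ray.cluster_utils import Cluster
from ray.core.generated import autoscaler_pb2


def simulate_preemption(cluster: Cluster, node, node_id: str) -> None:
    """Mark `node_id` as preempted via the GCS, then remove the node from the
    test cluster so the drain takes effect (mirrors the test above)."""
    gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address)
    accepted = gcs_client.drain_node(
        node_id,
        autoscaler_pb2.DrainNodeReason.Value("DRAIN_NODE_REASON_PREEMPTION"),
        "preemption",  # human-readable drain reason
    )
    assert accepted
    cluster.remove_node(node, allow_graceful=True)
```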
4 changes: 4 additions & 0 deletions python/ray/train/constants.py
@@ -82,6 +82,10 @@ def _get_defaults_results_dir() -> str:
# or Train worker to the trial directory. Defaults to 1.
RAY_CHDIR_TO_TRIAL_DIR = "RAY_CHDIR_TO_TRIAL_DIR"

# Set this to 1 to count preemption errors toward `FailureConfig(max_failures)`.
# Defaults to 0, which always retries on node preemption failures.
RAY_TRAIN_COUNT_PREEMPTION_ERRORS = "RAY_TRAIN_COUNT_PREEMPTION_ERRORS"

# NOTE: When adding a new environment variable, please track it in this list.
TRAIN_ENV_VARS = {
ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
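To illustrate how the new environment variable is intended to be used, here is a hypothetical sketch of the opt-in and of the decision it controls; the `should_count_failure` helper and the `preempted` attribute are illustrative assumptions about the error-handling side of this PR, not the exact Ray internals:

```python
import os

from ray.train.constants import RAY_TRAIN_COUNT_PREEMPTION_ERRORS

# Opt back into the legacy behavior: count preemption-caused node failures
# toward FailureConfig(max_failures). Defaults to "0" (do not count them).
os.environ[RAY_TRAIN_COUNT_PREEMPTION_ERRORS] = "1"


def should_count_failure(error: Exception) -> bool:
    """Hypothetical sketch: decide whether a failure consumes one slot of the
    FailureConfig(max_failures) budget."""
    count_preemptions = os.environ.get(RAY_TRAIN_COUNT_PREEMPTION_ERRORS, "0") == "1"
    # Illustrative check: assume the raised error exposes whether the
    # underlying node death was a preemption (e.g. a `preempted` flag).
    node_was_preempted = bool(getattr(error, "preempted", False))
    if node_was_preempted and not count_preemptions:
        return False  # Retry without incrementing the failure counter.
    return True
```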
112 changes: 61 additions & 51 deletions python/ray/train/data_parallel_trainer.py
@@ -1,6 +1,6 @@
import inspect
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union

import ray
from ray._private.thirdparty.tabulate.tabulate import tabulate
@@ -357,61 +357,71 @@ def _validate_scaling_config(cls, scaling_config: ScalingConfig) -> ScalingConfig:

return scaling_config

def _report(self, training_iterator: TrainingIterator) -> None:
for results in training_iterator:
def _run_training(self, training_iterator: TrainingIterator) -> None:
"""This method loops over the `TrainingIterator`:
The actual iteration (for ... in ...) waits for the training function
on each worker to report a result and supplies it as a list of results.
Afterwards (in the body of the loop), it will report the result
to the Tune session.
The iterator ends after the training function on each worker has finished.
"""
for training_results in training_iterator:
# TODO(ml-team): add ability to report results from multiple workers.
first_worker_result = results[0]
assert all(isinstance(result, _TrainingResult) for result in results)

tune_session = get_session()

# Check if any workers reported a checkpoint.
# If so, report a checkpoint pointing to the persisted location
# to Tune for book-keeping.
# NOTE: This removes the restriction for any individual worker
# (ex: global rank 0 worker) from needing to report a checkpoint.
# All workers reported a checkpoint to the same fs path, so there's
# no need to report multiple checkpoints to Tune.
worker_checkpoints = [
result.checkpoint for result in results if result.checkpoint is not None
]
at_least_one_reported_checkpoint = len(worker_checkpoints) > 0

if at_least_one_reported_checkpoint:
# Update the coordinator's checkpoint index to the latest.
# This is what keeps the checkpoint index in line with the workers.
tune_session.storage._update_checkpoint_index(
first_worker_result.metrics
)

# Make sure that all workers uploaded to the same location.
assert all(
checkpoint.path == tune_session.storage.checkpoint_fs_path
for checkpoint in worker_checkpoints
)
self._propagate_results(training_results)

def _propagate_results(self, training_results: List[_TrainingResult]):
first_worker_result = training_results[0]
assert all(isinstance(result, _TrainingResult) for result in training_results)

tune_session = get_session()

# Check if any workers reported a checkpoint.
# If so, report a checkpoint pointing to the persisted location
# to Tune for book-keeping.
# NOTE: This removes the restriction for any individual worker
# (ex: global rank 0 worker) from needing to report a checkpoint.
# All workers reported a checkpoint to the same fs path, so there's
# no need to report multiple checkpoints to Tune.
worker_checkpoints = [
result.checkpoint
for result in training_results
if result.checkpoint is not None
]
at_least_one_reported_checkpoint = len(worker_checkpoints) > 0

if at_least_one_reported_checkpoint:
# Update the coordinator's checkpoint index to the latest.
# This is what keeps the checkpoint index in line with the workers.
tune_session.storage._update_checkpoint_index(first_worker_result.metrics)

# Make sure that all workers uploaded to the same location.
assert all(
checkpoint.path == tune_session.storage.checkpoint_fs_path
for checkpoint in worker_checkpoints
)

checkpoint = (
Checkpoint(
filesystem=tune_session.storage.storage_filesystem,
path=tune_session.storage.checkpoint_fs_path,
)
if at_least_one_reported_checkpoint
else None
checkpoint = (
Checkpoint(
filesystem=tune_session.storage.storage_filesystem,
path=tune_session.storage.checkpoint_fs_path,
)
if at_least_one_reported_checkpoint
else None
)

tracked_training_result = _TrainingResult(
checkpoint=checkpoint,
metrics=first_worker_result.metrics,
)
tracked_training_result = _TrainingResult(
checkpoint=checkpoint,
metrics=first_worker_result.metrics,
)

logger.debug(
"Report (metrics, checkpoint) to the Tune session:\n"
f" metrics={tracked_training_result.metrics}\n"
f" checkpoint={tracked_training_result.checkpoint}"
)
logger.debug(
"Report (metrics, checkpoint) to the Tune session:\n"
f" metrics={tracked_training_result.metrics}\n"
f" checkpoint={tracked_training_result.checkpoint}"
)

# Report the metrics and checkpoint to Tune.
tune_session._report_training_result(tracked_training_result)
# Report the metrics and checkpoint to Tune.
tune_session._report_training_result(tracked_training_result)

def training_loop(self) -> None:
scaling_config = self._validate_scaling_config(self.scaling_config)
@@ -457,7 +467,7 @@ def training_loop(self) -> None:
checkpoint=self.starting_checkpoint,
)

self._report(training_iterator)
self._run_training(training_iterator)

# Shutdown workers.
backend_executor.shutdown()
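For orientation, a minimal user-level sketch of the path the refactored coordinator code drives: each `train.report(...)` call from a worker yields one iteration of the `TrainingIterator` loop in `_run_training`, and `_propagate_results` forwards the result to the Tune session. The training function and run name here are illustrative:

```python
from ray import train
from ray.train import RunConfig, ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer


def train_fn(config):
    # Each report() surfaces as one iteration of the TrainingIterator loop
    # in `_run_training` on the coordinator.
    for i in range(3):
        train.report({"iter": i})


trainer = DataParallelTrainer(
    train_loop_per_worker=train_fn,
    scaling_config=ScalingConfig(num_workers=2),
    run_config=RunConfig(name="minimal_run_training_example"),
)
result = trainer.fit()
assert result.metrics["iter"] == 2  # Last metric reported by the workers.
```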
7 changes: 0 additions & 7 deletions python/ray/tune/error.py
@@ -46,10 +46,3 @@ class _TuneNoNextExecutorEventError(_SubCategoryTuneError):
this category. This category is for everything else."""

pass


class _TuneRestoreError(_SubCategoryTuneError):
"""Error that happens in restoring a remote trainable."""

def __init__(self, exc: Exception):
self.exc = exc