Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Data] Allow to specify application-level error to retry for actor task #42492

Merged
merged 4 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def __init__(
ray_remote_args,
)
self._ray_remote_args = self._apply_default_remote_args(self._ray_remote_args)
self._ray_actor_task_remote_args = {}
actor_task_errors = DataContext.get_current().actor_task_retry_on_errors
if len(actor_task_errors) > 0:
self._ray_actor_task_remote_args["retry_exceptions"] = actor_task_errors
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC retry_exceptions can either be True/False or a list of exception types in ray core. and defaults to false.
Maybe let's keep the behavior same in data.
also remember to update the comments in DataContext.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, updated.

self._min_rows_per_bundle = min_rows_per_bundle

# Create autoscaling policy from compute strategy.
Expand Down Expand Up @@ -194,9 +198,11 @@ def _dispatch_tasks(self):
task_idx=self._next_data_task_idx,
target_max_block_size=self.actual_target_max_block_size,
)
gen = actor.submit.options(num_returns="streaming", name=self.name).remote(
DataContext.get_current(), ctx, *input_blocks
)
gen = actor.submit.options(
num_returns="streaming",
name=self.name,
**self._ray_actor_task_remote_args,
).remote(DataContext.get_current(), ctx, *input_blocks)

def _task_done_callback(actor_to_return):
# Return the actor that was running the task to the pool.
Expand Down
7 changes: 7 additions & 0 deletions python/ray/data/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@
"AWS Error SLOW_DOWN",
]

# The application-level errors that actor task would retry.
# Default to empty list to not retry on any errors.
DEFAULT_ACTOR_TASK_RETRY_ON_ERRORS = []


@DeveloperAPI
class DataContext:
Expand Down Expand Up @@ -201,6 +205,7 @@ def __init__(
enable_get_object_locations_for_metrics: bool,
use_runtime_metrics_scheduling: bool,
write_file_retry_on_errors: List[str],
actor_task_retry_on_errors: List[str],
):
"""Private constructor (use get_current() instead)."""
self.target_max_block_size = target_max_block_size
Expand Down Expand Up @@ -239,6 +244,7 @@ def __init__(
)
self.use_runtime_metrics_scheduling = use_runtime_metrics_scheduling
self.write_file_retry_on_errors = write_file_retry_on_errors
self.actor_task_retry_on_errors = actor_task_retry_on_errors
# The additonal ray remote args that should be added to
# the task-pool-based data tasks.
self._task_pool_data_task_remote_args: Dict[str, Any] = {}
Expand Down Expand Up @@ -309,6 +315,7 @@ def get_current() -> "DataContext":
enable_get_object_locations_for_metrics=DEFAULT_ENABLE_GET_OBJECT_LOCATIONS_FOR_METRICS, # noqa E501
use_runtime_metrics_scheduling=DEFAULT_USE_RUNTIME_METRICS_SCHEDULING, # noqa: E501
write_file_retry_on_errors=DEFAULT_WRITE_FILE_RETRY_ON_ERRORS,
actor_task_retry_on_errors=DEFAULT_ACTOR_TASK_RETRY_ON_ERRORS,
)

return _default_context
Expand Down
21 changes: 21 additions & 0 deletions python/ray/data/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,27 @@ def mapper(x):
ds.map(mapper).materialize()


def test_actor_task_failure(shutdown_only, restore_data_context):
ray.init(num_cpus=2)

ctx = DataContext.get_current()
ctx.actor_task_retry_on_errors = [ValueError]

ds = ray.data.from_items([0, 10], parallelism=2)

class Mapper:
def __init__(self):
self._counter = 0

def __call__(self, x):
if self._counter < 2:
self._counter += 1
raise ValueError("oops")
return x

ds.map_batches(Mapper, concurrency=1).materialize()


def test_concurrency(shutdown_only):
ray.init(num_cpus=6)
ds = ray.data.range(10, parallelism=10)
Expand Down