Skip to content

Commit

Permalink
[Data] Add function to dynamically generate ray_remote_args for Map…
Browse files Browse the repository at this point in the history
… APIs (ray-project#45143)

Adds a new parameter `ray_remote_args_fn` to Map APIs (`map()`, `map_batches()`, `flat_map()`, `filter()`), which allows the user to specify a function which returns a dict of Ray remote args to be passed to an actor initialized from ActorPoolMapOperator. This function is called each time a worker is initialized, allowing the user to specify the parameters for every worker (e.g. setting the scheduling strategy at runtime).

Currently, Ray Data only allows passing static ray remote args, which has the limitation of sharing the placement group for all actors. This feature allows users to create different placement groups for each actor. For example, this will enable users to use Ray Data with vLLM with tensor parallel size > 1.

Signed-off-by: Scott Lee <sjl@anyscale.com>
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
  • Loading branch information
scottjlee authored and ryanaoleary committed Jun 6, 2024
1 parent 17f4b9a commit 3da99fb
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
compute_strategy: ActorPoolStrategy,
name: str = "ActorPoolMap",
min_rows_per_bundle: Optional[int] = None,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
):
"""Create an ActorPoolMapOperator instance.
Expand All @@ -71,6 +72,12 @@ def __init__(
transform_fn, or None to use the block size. Setting the batch size is
important for the performance of GPU-accelerated transform functions.
The actual rows passed may be less if the dataset is small.
ray_remote_args_fn: A function that returns a dictionary of remote args
passed to each map worker. The purpose of this argument is to generate
dynamic arguments for each actor/task, and will be called each time
prior to initializing the worker. Args returned from this dict will
always override the args in ``ray_remote_args``. Note: this is an
advanced, experimental feature.
ray_remote_args: Customize the ray remote args for this op's tasks.
"""
super().__init__(
Expand All @@ -79,9 +86,9 @@ def __init__(
name,
target_max_block_size,
min_rows_per_bundle,
ray_remote_args_fn,
ray_remote_args,
)
self._ray_remote_args = self._apply_default_remote_args(self._ray_remote_args)
self._ray_actor_task_remote_args = {}
actor_task_errors = DataContext.get_current().actor_task_retry_on_errors
if actor_task_errors:
Expand All @@ -96,6 +103,8 @@ def __init__(
2 * data_context._max_num_blocks_in_streaming_gen_buffer
)
self._min_rows_per_bundle = min_rows_per_bundle
self._ray_remote_args_fn = ray_remote_args_fn
self._ray_remote_args = self._apply_default_remote_args(self._ray_remote_args)

self._actor_pool = _ActorPool(compute_strategy, self._start_actor)
# A queue of bundles awaiting dispatch to actors.
Expand Down Expand Up @@ -138,6 +147,8 @@ def _start_actor(self):
"""Start a new actor and add it to the actor pool as a pending actor."""
assert self._cls is not None
ctx = DataContext.get_current()
if self._ray_remote_args_fn:
self._refresh_actor_cls()
actor = self._cls.remote(
ctx,
src_fn_name=self.name,
Expand Down Expand Up @@ -213,6 +224,24 @@ def _task_done_callback(actor_to_return):
lambda: _task_done_callback(actor_to_return),
)

def _refresh_actor_cls(self):
    """Rebuild ``self._cls`` using freshly generated remote args.

    Intended to be called right before a new worker is started when
    ``self._ray_remote_args_fn`` is set. The function is invoked once, its
    result is layered on top of a copy of ``self._ray_remote_args`` (values
    from the function win on key collisions), and ``self._cls`` is replaced
    with ``_MapWorker`` wrapped by ``ray.remote`` using the merged args.

    Returns:
        A new dict holding only the args produced by
        ``self._ray_remote_args_fn`` (i.e. the new and overridden entries).
    """
    assert self._ray_remote_args_fn, "_ray_remote_args_fn must be provided"
    # Work on a copy so the operator's static remote args are never mutated.
    merged_args = self._ray_remote_args.copy()
    # Snapshot the user-generated args into a fresh dict; this is also what
    # gets handed back to the caller.
    overrides = dict(self._ray_remote_args_fn().items())

    # User-generated args take precedence over the static ones.
    merged_args.update(overrides)
    self._cls = ray.remote(**merged_args)(_MapWorker)
    return overrides

def all_inputs_done(self):
# Call base implementation to handle any leftover bundles. This may or may not
# trigger task dispatch.
Expand Down
25 changes: 21 additions & 4 deletions python/ray/data/_internal/execution/operators/map_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
name: str,
target_max_block_size: Optional[int],
min_rows_per_bundle: Optional[int],
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]],
ray_remote_args: Optional[Dict[str, Any]],
):
# NOTE: This constructor should not be called directly; use MapOperator.create()
Expand All @@ -60,7 +61,8 @@ def __init__(

self._map_transformer = map_transformer
self._ray_remote_args = _canonicalize_ray_remote_args(ray_remote_args or {})
self._ray_remote_args_factory = None
self._ray_remote_args_fn = ray_remote_args_fn
self._ray_remote_args_factory_actor_locality = None
self._remote_args_for_metrics = copy.deepcopy(self._ray_remote_args)

# Bundles block references up to the min_rows_per_bundle target.
Expand Down Expand Up @@ -111,6 +113,7 @@ def create(
# config and not contain implementation code.
compute_strategy: Optional[ComputeStrategy] = None,
min_rows_per_bundle: Optional[int] = None,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
) -> "MapOperator":
"""Create a MapOperator.
Expand All @@ -132,6 +135,12 @@ def create(
transform_fn, or None to use the block size. Setting the batch size is
important for the performance of GPU-accelerated transform functions.
The actual rows passed may be less if the dataset is small.
ray_remote_args_fn: A function that returns a dictionary of remote args
passed to each map worker. The purpose of this argument is to generate
dynamic arguments for each actor/task, and will be called each time
prior to initializing the worker. Args returned from this dict will
always override the args in ``ray_remote_args``. Note: this is an
advanced, experimental feature.
ray_remote_args: Customize the ray remote args for this op's tasks.
"""
if compute_strategy is None:
Expand All @@ -149,6 +158,7 @@ def create(
target_max_block_size=target_max_block_size,
min_rows_per_bundle=min_rows_per_bundle,
concurrency=compute_strategy.size,
ray_remote_args_fn=ray_remote_args_fn,
ray_remote_args=ray_remote_args,
)
elif isinstance(compute_strategy, ActorPoolStrategy):
Expand All @@ -163,6 +173,7 @@ def create(
compute_strategy=compute_strategy,
name=name,
min_rows_per_bundle=min_rows_per_bundle,
ray_remote_args_fn=ray_remote_args_fn,
ray_remote_args=ray_remote_args,
)
else:
Expand Down Expand Up @@ -198,7 +209,7 @@ def __call__(self, args):
self.i %= len(self.locs)
return args

self._ray_remote_args_factory = RoundRobinAssign(locs)
self._ray_remote_args_factory_actor_locality = RoundRobinAssign(locs)

map_transformer = self._map_transformer
# Apply additional block split if needed.
Expand Down Expand Up @@ -227,6 +238,12 @@ def _get_runtime_ray_remote_args(
self, input_bundle: Optional[RefBundle] = None
) -> Dict[str, Any]:
ray_remote_args = copy.deepcopy(self._ray_remote_args)

# Override parameters from user provided remote args function.
if self._ray_remote_args_fn:
new_remote_args = self._ray_remote_args_fn()
for k, v in new_remote_args.items():
ray_remote_args[k] = v
# For tasks with small args, we will use SPREAD by default to optimize for
# compute load-balancing. For tasks with large args, we will use DEFAULT to
# allow the Ray locality scheduler a chance to optimize task placement.
Expand All @@ -246,8 +263,8 @@ def _get_runtime_ray_remote_args(
self._remote_args_for_metrics = copy.deepcopy(ray_remote_args)
# This should take precedence over previously set scheduling strategy, as it
# implements actor-based locality overrides.
if self._ray_remote_args_factory:
return self._ray_remote_args_factory(ray_remote_args)
if self._ray_remote_args_factory_actor_locality:
return self._ray_remote_args_factory_actor_locality(ray_remote_args)
return ray_remote_args

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional

import ray
from ray.data._internal.execution.interfaces import (
Expand All @@ -24,6 +24,7 @@ def __init__(
name: str = "TaskPoolMap",
min_rows_per_bundle: Optional[int] = None,
concurrency: Optional[int] = None,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
):
"""Create an TaskPoolMapOperator instance.
Expand All @@ -40,6 +41,12 @@ def __init__(
The actual rows passed may be less if the dataset is small.
concurrency: The maximum number of Ray tasks to use concurrently,
or None to use as many tasks as possible.
ray_remote_args_fn: A function that returns a dictionary of remote args
passed to each map worker. The purpose of this argument is to generate
dynamic arguments for each actor/task, and will be called each time
prior to initializing the worker. Args returned from this dict will
always override the args in ``ray_remote_args``. Note: this is an
advanced, experimental feature.
ray_remote_args: Customize the ray remote args for this op's tasks.
"""
super().__init__(
Expand All @@ -48,6 +55,7 @@ def __init__(
name,
target_max_block_size,
min_rows_per_bundle,
ray_remote_args_fn,
ray_remote_args,
)
self._concurrency = concurrency
Expand Down
26 changes: 25 additions & 1 deletion python/ray/data/_internal/logical/operators/map_operator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
import logging
from typing import Any, Dict, Iterable, Optional, Union
from typing import Any, Callable, Dict, Iterable, Optional, Union

from ray.data._internal.compute import ComputeStrategy, TaskPoolStrategy
from ray.data._internal.logical.interfaces import LogicalOperator
Expand All @@ -25,6 +25,7 @@ def __init__(
*,
min_rows_per_bundled_input: Optional[int] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
):
"""
Args:
Expand All @@ -35,10 +36,17 @@ def __init__(
min_rows_per_bundled_input: The target number of rows to pass to
``MapOperator._add_bundled_input()``.
ray_remote_args: Args to provide to ray.remote.
ray_remote_args_fn: A function that returns a dictionary of remote args
passed to each map worker. The purpose of this argument is to generate
dynamic arguments for each actor/task, and will be called each time
prior to initializing the worker. Args returned from this dict will
always override the args in ``ray_remote_args``. Note: this is an
advanced, experimental feature.
"""
super().__init__(name, input_op, num_outputs)
self._min_rows_per_bundled_input = min_rows_per_bundled_input
self._ray_remote_args = ray_remote_args or {}
self._ray_remote_args_fn = ray_remote_args_fn


class AbstractUDFMap(AbstractMap):
Expand All @@ -57,6 +65,7 @@ def __init__(
fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
min_rows_per_bundled_input: Optional[int] = None,
compute: Optional[Union[str, ComputeStrategy]] = None,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
):
"""
Expand All @@ -76,6 +85,12 @@ def __init__(
``MapOperator._add_bundled_input()``.
compute: The compute strategy, either ``"tasks"`` (default) to use Ray
tasks, or ``"actors"`` to use an autoscaling actor pool.
ray_remote_args_fn: A function that returns a dictionary of remote args
passed to each map worker. The purpose of this argument is to generate
dynamic arguments for each actor/task, and will be called each time
prior to initializing the worker. Args returned from this dict will
always override the args in ``ray_remote_args``. Note: this is an
advanced, experimental feature.
ray_remote_args: Args to provide to ray.remote.
"""
name = self._get_operator_name(name, fn)
Expand All @@ -91,6 +106,7 @@ def __init__(
self._fn_constructor_args = fn_constructor_args
self._fn_constructor_kwargs = fn_constructor_kwargs
self._compute = compute or TaskPoolStrategy()
self._ray_remote_args_fn = ray_remote_args_fn

def _get_operator_name(self, op_name: str, fn: UserDefinedFunction):
"""Gets the Operator name including the map `fn` UDF name."""
Expand Down Expand Up @@ -135,6 +151,7 @@ def __init__(
fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
min_rows_per_bundled_input: Optional[int] = None,
compute: Optional[Union[str, ComputeStrategy]] = None,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
):
super().__init__(
Expand All @@ -147,6 +164,7 @@ def __init__(
fn_constructor_kwargs=fn_constructor_kwargs,
min_rows_per_bundled_input=min_rows_per_bundled_input,
compute=compute,
ray_remote_args_fn=ray_remote_args_fn,
ray_remote_args=ray_remote_args,
)
self._batch_size = batch_size
Expand All @@ -170,6 +188,7 @@ def __init__(
fn_constructor_args: Optional[Iterable[Any]] = None,
fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
compute: Optional[Union[str, ComputeStrategy]] = None,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
):
super().__init__(
Expand All @@ -181,6 +200,7 @@ def __init__(
fn_constructor_args=fn_constructor_args,
fn_constructor_kwargs=fn_constructor_kwargs,
compute=compute,
ray_remote_args_fn=ray_remote_args_fn,
ray_remote_args=ray_remote_args,
)

Expand All @@ -197,13 +217,15 @@ def __init__(
input_op: LogicalOperator,
fn: UserDefinedFunction,
compute: Optional[Union[str, ComputeStrategy]] = None,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
):
super().__init__(
"Filter",
input_op,
fn,
compute=compute,
ray_remote_args_fn=ray_remote_args_fn,
ray_remote_args=ray_remote_args,
)

Expand All @@ -224,6 +246,7 @@ def __init__(
fn_constructor_args: Optional[Iterable[Any]] = None,
fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
compute: Optional[Union[str, ComputeStrategy]] = None,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
):
super().__init__(
Expand All @@ -235,6 +258,7 @@ def __init__(
fn_constructor_args=fn_constructor_args,
fn_constructor_kwargs=fn_constructor_kwargs,
compute=compute,
ray_remote_args_fn=ray_remote_args_fn,
ray_remote_args=ray_remote_args,
)

Expand Down
13 changes: 13 additions & 0 deletions python/ray/data/_internal/logical/rules/operator_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,13 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
):
return False

# Do not fuse if either op specifies a `_ray_remote_args_fn`,
# since it is not known whether the generated args will be compatible.
if getattr(up_logical_op, "_ray_remote_args_fn", None) or getattr(
down_logical_op, "_ray_remote_args_fn", None
):
return False

if not self._can_merge_target_max_block_size(
up_op.target_max_block_size, down_op.target_max_block_size
):
Expand Down Expand Up @@ -296,6 +303,9 @@ def _get_fused_map_operator(
if isinstance(down_logical_op, AbstractUDFMap):
compute = get_compute(down_logical_op._compute)
ray_remote_args = up_logical_op._ray_remote_args
ray_remote_args_fn = (
up_logical_op._ray_remote_args_fn or down_logical_op._ray_remote_args_fn
)
# Make the upstream operator's inputs the new, fused operator's inputs.
input_deps = up_op.input_dependencies
assert len(input_deps) == 1
Expand All @@ -310,6 +320,7 @@ def _get_fused_map_operator(
compute_strategy=compute,
min_rows_per_bundle=min_rows_per_bundled_input,
ray_remote_args=ray_remote_args,
ray_remote_args_fn=ray_remote_args_fn,
)

# Build a map logical operator to be used as a reference for further fusion.
Expand All @@ -331,6 +342,7 @@ def _get_fused_map_operator(
down_logical_op._fn_constructor_kwargs,
min_rows_per_bundled_input,
compute,
ray_remote_args_fn,
ray_remote_args,
)
else:
Expand All @@ -341,6 +353,7 @@ def _get_fused_map_operator(
name,
input_op,
min_rows_per_bundled_input=min_rows_per_bundled_input,
ray_remote_args_fn=ray_remote_args_fn,
ray_remote_args=ray_remote_args,
)
self._op_map[op] = logical_op
Expand Down
1 change: 1 addition & 0 deletions python/ray/data/_internal/planner/plan_udf_map_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def plan_udf_map_op(
target_max_block_size=None,
compute_strategy=compute,
min_rows_per_bundle=op._min_rows_per_bundled_input,
ray_remote_args_fn=op._ray_remote_args_fn,
ray_remote_args=op._ray_remote_args,
)

Expand Down
Loading

0 comments on commit 3da99fb

Please sign in to comment.