[Core][1/3] Make streaming generator public API #38784

Merged
merged 17 commits on Nov 27, 2023
12 changes: 9 additions & 3 deletions python/ray/_private/ray_option_utils.py
@@ -169,9 +169,15 @@ def issubclass_safe(obj: Any, cls_: type) -> bool:
(int, str, type(None)),
lambda x: None
if (x is None or x == "dynamic" or x == "streaming" or x >= 0)
else "The keyword 'num_returns' only accepts None, a non-negative integer, or "
'"dynamic" (for generators)',
default_value=1,
else "Default None. When None is passed, "
"The default value is 1 for a task and actor task, and "
"'streaming' for generator tasks and generator actor tasks. "
"The keyword 'num_returns' only accepts None, "
"a non-negative integer, "
"'streaming' (for generators), or 'dynamic'. 'dynamic' flag "
"will be deprecated in the future, and it is recommended to use "
"'streaming' instead.",
default_value=None,
),
"object_store_memory": Option( # override "_common_options"
(int, type(None)),
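For context on the num_returns change above, here is a minimal illustrative sketch (not part of this diff) of what the new default means for generator tasks: when num_returns is left as None it resolves to "streaming", and the call returns a StreamingObjectRefGenerator rather than a list of ObjectRefs. The task name and values below are hypothetical.

import ray

@ray.remote(num_returns="streaming")  # explicit here; per the option change, None now resolves to this for generator tasks
def stream_values(n):
    for i in range(n):
        yield i

gen = stream_values.remote(3)   # a StreamingObjectRefGenerator, not a list of ObjectRefs
for ref in gen:                 # each yielded value arrives as its own ObjectRef
    print(ray.get(ref))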
115 changes: 69 additions & 46 deletions python/ray/_private/worker.py
Expand Up @@ -107,6 +107,7 @@
logger = logging.getLogger(__name__)


T = TypeVar("T")
T0 = TypeVar("T0")
T1 = TypeVar("T1")
T2 = TypeVar("T2")
@@ -2510,7 +2511,7 @@ def get(object_refs: "ObjectRef[R]", *, timeout: Optional[float] = None) -> R:
@PublicAPI
@client_mode_hook
def get(
object_refs: Union[ray.ObjectRef, Sequence[ray.ObjectRef]],
object_refs: Union["ObjectRef[Any]", Sequence["ObjectRef[Any]"]],
*,
timeout: Optional[float] = None,
) -> Union[Any, List[Any]]:
@@ -2530,6 +2531,8 @@ def get(
you can use ``await object_ref`` instead of ``ray.get(object_ref)``. For
a list of object refs, you can use ``await asyncio.gather(*object_refs)``.

Passing :class:`~StreamingObjectRefGenerator` is not allowed.

Related patterns and anti-patterns:

- :doc:`/ray-core/patterns/ray-get-loop`
@@ -2582,7 +2585,8 @@ def get(

if not isinstance(object_refs, list):
raise ValueError(
"'object_refs' must either be an ObjectRef or a list of ObjectRefs."
f"Invalid type of object refs, {type(object_refs)}, is given. "
"'object_refs' must either be an ObjectRef or a list of ObjectRefs. "
)

# TODO(ujvl): Consider how to allow user to retrieve the ready objects.
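A minimal sketch (not part of this diff) of the get() behavior documented above: ray.get() accepts ObjectRefs only, so a StreamingObjectRefGenerator has to be consumed reference by reference. The task names below are hypothetical.

import ray

@ray.remote
def single():
    return 1

@ray.remote(num_returns="streaming")
def streaming():
    yield 1
    yield 2

print(ray.get(single.remote()))          # a plain ObjectRef works as before
gen = streaming.remote()
# ray.get(gen) is disallowed per the docstring change above; iterate the generator instead.
results = [ray.get(ref) for ref in gen]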
@@ -2682,50 +2686,65 @@ def put(
@PublicAPI
@client_mode_hook
def wait(
object_refs: List["ray.ObjectRef"],
ray_waitables: Union["ObjectRef[R]", "StreamingObjectRefGenerator[R]"],
*,
num_returns: int = 1,
timeout: Optional[float] = None,
fetch_local: bool = True,
) -> Tuple[List["ray.ObjectRef"], List["ray.ObjectRef"]]:
) -> Tuple[
List[Union["ObjectRef[R]", "StreamingObjectRefGenerator[R]"]],
List[Union["ObjectRef[R]", "StreamingObjectRefGenerator[R]"]],
]:
"""Return a list of IDs that are ready and a list of IDs that are not.

If timeout is set, the function returns either when the requested number of
IDs are ready or when the timeout is reached, whichever occurs first. If it
is not set, the function simply waits until that number of objects is ready
and returns that exact number of object refs.

This method returns two lists. The first list consists of object refs that
correspond to objects that are available in the object store. The second
list corresponds to the rest of the object refs (which may or may not be
ready).
`ray_waitables` is a list of :class:`~ObjectRef` and
:class:`~StreamingObjectRefGenerator`.

The method returns two lists: ready and unready `ray_waitables`.

ObjectRef:
object refs that correspond to objects that are available
in the object store are in the first list.
The rest of the object refs are in the second list.

Ordering of the input list of object refs is preserved. That is, if A
StreamingObjectRefGenerator:
Generators whose next reference (that will be obtained
via `next(generator)`) has a corresponding object available
in the object store are in the first list.
All other generators are placed in the second list.

Ordering of the input list of ray_waitables is preserved. That is, if A
precedes B in the input list, and both are in the ready list, then A will
precede B in the ready list. This also holds true if A and B are both in
the remaining list.

This method will issue a warning if it's running inside an async context.
Instead of ``ray.wait(object_refs)``, you can use
``await asyncio.wait(object_refs)``.
Instead of ``ray.wait(ray_waitables)``, you can use
``await asyncio.wait(ray_waitables)``.

Related patterns and anti-patterns:

- :doc:`/ray-core/patterns/limit-pending-tasks`
- :doc:`/ray-core/patterns/ray-get-submission-order`

Args:
object_refs: List of :class:`~ObjectRefs` or
:class:`~StreamingObjectRefGenerators` for objects that may or may
ray_waitables: List of :class:`~ObjectRef` or
:class:`~StreamingObjectRefGenerator` for objects that may or may
not be ready. Note that these must be unique.
num_returns: The number of object refs that should be returned.
num_returns: The number of ray_waitables that should be returned.
timeout: The maximum amount of time in seconds to wait before
returning.
fetch_local: If True, wait for the object to be downloaded onto
the local node before returning it as ready. If False, ray.wait()
will not trigger fetching of objects to the local node and will
return immediately once the object is available anywhere in the
cluster.
the local node before returning it as ready. If the `ray_waitable`
is a generator, it will wait until the next object in the generator
is downloaded. If False, ray.wait() will not trigger fetching of
objects to the local node and will return immediately once the
object is available anywhere in the cluster.

Returns:
A list of object refs that are ready and a list of the remaining object
@@ -2748,35 +2767,35 @@ def wait(
)
blocking_wait_inside_async_warned = True

if isinstance(object_refs, ObjectRef) or isinstance(
object_refs, StreamingObjectRefGenerator
if isinstance(ray_waitables, ObjectRef) or isinstance(
ray_waitables, StreamingObjectRefGenerator
):
raise TypeError(
"wait() expected a list of ray.ObjectRef or ray.StreamingObjectRefGenerator"
", got a single ray.ObjectRef or ray.StreamingObjectRefGenerator "
f"{object_refs}"
f"{ray_waitables}"
)

if not isinstance(object_refs, list):
if not isinstance(ray_waitables, list):
raise TypeError(
"wait() expected a list of ray.ObjectRef or "
"ray.StreamingObjectRefGenerator, "
f"got {type(object_refs)}"
f"got {type(ray_waitables)}"
)

if timeout is not None and timeout < 0:
raise ValueError(
"The 'timeout' argument must be nonnegative. " f"Received {timeout}"
)

for object_ref in object_refs:
if not isinstance(object_ref, ObjectRef) and not isinstance(
object_ref, StreamingObjectRefGenerator
for ray_waitable in ray_waitables:
if not isinstance(ray_waitable, ObjectRef) and not isinstance(
ray_waitable, StreamingObjectRefGenerator
):
raise TypeError(
"wait() expected a list of ray.ObjectRef or "
"ray.StreamingObjectRefGenerator, "
f"got list containing {type(object_ref)}"
f"got list containing {type(ray_waitable)}"
)
worker.check_connected()

@@ -2785,23 +2804,23 @@ def wait(
# TODO(rkn): This is a temporary workaround for
# https://github.com/ray-project/ray/issues/997. However, it should be
# fixed in Arrow instead of here.
if len(object_refs) == 0:
if len(ray_waitables) == 0:
return [], []

if len(object_refs) != len(set(object_refs)):
raise ValueError("Wait requires a list of unique object refs.")
if len(ray_waitables) != len(set(ray_waitables)):
raise ValueError("Wait requires a list of unique ray_waitables.")
if num_returns <= 0:
raise ValueError("Invalid number of objects to return %d." % num_returns)
if num_returns > len(object_refs):
if num_returns > len(ray_waitables):
raise ValueError(
"num_returns cannot be greater than the number "
"of objects provided to ray.wait."
"of ray_waitables provided to ray.wait."
)

timeout = timeout if timeout is not None else 10**6
timeout_milliseconds = int(timeout * 1000)
ready_ids, remaining_ids = worker.core_worker.wait(
object_refs,
ray_waitables,
num_returns,
timeout_milliseconds,
worker.current_task_id,
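A minimal sketch (not part of this diff) of the ray_waitables semantics of wait() introduced above, mixing an ObjectRef with a StreamingObjectRefGenerator. The task names are hypothetical.

import ray

@ray.remote
def plain_task():
    return "done"

@ray.remote(num_returns="streaming")
def generator_task():
    for i in range(3):
        yield i

waitables = [plain_task.remote(), generator_task.remote()]
ready, not_ready = ray.wait(waitables, num_returns=1)
for waitable in ready:
    if isinstance(waitable, ray._raylet.StreamingObjectRefGenerator):
        # The generator's next reference is available in the object store; fetch it with next().
        print(ray.get(next(waitable)))
    else:
        print(ray.get(waitable))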
@@ -2879,7 +2898,10 @@ def kill(actor: "ray.actor.ActorHandle", *, no_restart: bool = True):
@PublicAPI
@client_mode_hook
def cancel(
object_ref: "ray.ObjectRef", *, force: bool = False, recursive: bool = True
ray_waitable: Union["ObjectRef[R]", "StreamingObjectRefGenerator[R]"],
*,
force: bool = False,
recursive: bool = True,
) -> None:
"""Cancels a task.

@@ -2927,26 +2949,27 @@ def cancel(
are cancelled.

Args:
object_ref: ObjectRef returned by the Task
that should be cancelled.
force: Whether to force-kill a running Task by killing
the worker that is running the Task.
recursive: Whether to try to cancel Tasks submitted by the
Task specified.
ray_waitable: :class:`~ObjectRef` or
:class:`~StreamingObjectRefGenerator`
returned by the task that should be canceled.
force: Whether to force-kill a running task by killing
the worker that is running the task.
recursive: Whether to try to cancel tasks submitted by the
task specified.
"""
worker = ray._private.worker.global_worker
worker.check_connected()

if isinstance(object_ref, ray._raylet.StreamingObjectRefGenerator):
assert hasattr(object_ref, "_generator_ref")
object_ref = object_ref._generator_ref
if isinstance(ray_waitable, ray._raylet.StreamingObjectRefGenerator):
assert hasattr(ray_waitable, "_generator_ref")
ray_waitable = ray_waitable._generator_ref

if not isinstance(object_ref, ray.ObjectRef):
if not isinstance(ray_waitable, ray.ObjectRef):
raise TypeError(
"ray.cancel() only supported for object refs. "
f"For actors, try ray.kill(). Got: {type(object_ref)}."
f"For actors, try ray.kill(). Got: {type(ray_waitable)}."
)
return worker.core_worker.cancel_task(object_ref, force, recursive)
return worker.core_worker.cancel_task(ray_waitable, force, recursive)


def _mode(worker=global_worker):
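Finally, a minimal sketch (not part of this diff) of the cancel() change above: a StreamingObjectRefGenerator can be passed to ray.cancel(), which forwards its _generator_ref as shown in the diff. The task name and the expectation that later fetches raise TaskCancelledError are assumptions for illustration, not guarantees made by this diff.

import time

import ray

@ray.remote(num_returns="streaming")
def slow_stream():
    for i in range(100):
        time.sleep(1)
        yield i

gen = slow_stream.remote()
ray.cancel(gen)  # cancels the underlying task via gen._generator_ref, per the diff above
try:
    # Assumption: consuming the generator after cancellation surfaces the cancellation error.
    for ref in gen:
        print(ray.get(ref))
except ray.exceptions.TaskCancelledError:
    print("stream was cancelled")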