diff --git a/python/ray/_private/ray_option_utils.py b/python/ray/_private/ray_option_utils.py index 56d998ace406d..de27a125b3241 100644 --- a/python/ray/_private/ray_option_utils.py +++ b/python/ray/_private/ray_option_utils.py @@ -169,9 +169,15 @@ def issubclass_safe(obj: Any, cls_: type) -> bool: (int, str, type(None)), lambda x: None if (x is None or x == "dynamic" or x == "streaming" or x >= 0) - else "The keyword 'num_returns' only accepts None, a non-negative integer, or " - '"dynamic" (for generators)', - default_value=1, + else "Default None. When None is passed, " + "The default value is 1 for a task and actor task, and " + "'streaming' for generator tasks and generator actor tasks. " + "The keyword 'num_returns' only accepts None, " + "a non-negative integer, " + "'streaming' (for generators), or 'dynamic'. 'dynamic' flag " + "will be deprecated in the future, and it is recommended to use " + "'streaming' instead.", + default_value=None, ), "object_store_memory": Option( # override "_common_options" (int, type(None)), diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index 47e5b1b6ee782..1d3324fa58fe8 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -107,6 +107,7 @@ logger = logging.getLogger(__name__) +T = TypeVar("T") T0 = TypeVar("T0") T1 = TypeVar("T1") T2 = TypeVar("T2") @@ -2510,7 +2511,7 @@ def get(object_refs: "ObjectRef[R]", *, timeout: Optional[float] = None) -> R: @PublicAPI @client_mode_hook def get( - object_refs: Union[ray.ObjectRef, Sequence[ray.ObjectRef]], + object_refs: Union["ObjectRef[Any]", Sequence["ObjectRef[Any]"]], *, timeout: Optional[float] = None, ) -> Union[Any, List[Any]]: @@ -2530,6 +2531,8 @@ def get( you can use ``await object_ref`` instead of ``ray.get(object_ref)``. For a list of object refs, you can use ``await asyncio.gather(*object_refs)``. + Passing :class:`~StreamingObjectRefGenerator` is not allowed. 
+ Related patterns and anti-patterns: - :doc:`/ray-core/patterns/ray-get-loop` @@ -2582,7 +2585,8 @@ def get( if not isinstance(object_refs, list): raise ValueError( - "'object_refs' must either be an ObjectRef or a list of ObjectRefs." + f"Invalid type of object refs, {type(object_refs)}, is given. " + "'object_refs' must either be an ObjectRef or a list of ObjectRefs. " ) # TODO(ujvl): Consider how to allow user to retrieve the ready objects. @@ -2682,12 +2686,15 @@ def put( @PublicAPI @client_mode_hook def wait( - object_refs: List["ray.ObjectRef"], + ray_waitables: Union["ObjectRef[R]", "StreamingObjectRefGenerator[R]"], *, num_returns: int = 1, timeout: Optional[float] = None, fetch_local: bool = True, -) -> Tuple[List["ray.ObjectRef"], List["ray.ObjectRef"]]: +) -> Tuple[ + List[Union["ObjectRef[R]", "StreamingObjectRefGenerator[R]"]], + List[Union["ObjectRef[R]", "StreamingObjectRefGenerator[R]"]], +]: """Return a list of IDs that are ready and a list of IDs that are not. If timeout is set, the function returns either when the requested number of @@ -2695,19 +2702,30 @@ def wait( is not set, the function simply waits until that number of objects is ready and returns that exact number of object refs. - This method returns two lists. The first list consists of object refs that - correspond to objects that are available in the object store. The second - list corresponds to the rest of the object refs (which may or may not be - ready). + `ray_waitables` is a list of :class:`~ObjectRef` and + :class:`~StreamingObjectRefGenerator`. + + The method returns two lists, ready and unready `ray_waitables`. + + ObjectRef: + object refs that correspond to objects that are available + in the object store are in the first list. + The rest of the object refs are in the second list. - Ordering of the input list of object refs is preserved. 
That is, if A + StreamingObjectRefGenerator: + Generators whose next reference (that will be obtained + via `next(generator)`) has a corresponding object available + in the object store are in the first list. + All other generators are placed in the second list. + + Ordering of the input list of ray_waitables is preserved. That is, if A precedes B in the input list, and both are in the ready list, then A will precede B in the ready list. This also holds true if A and B are both in the remaining list. This method will issue a warning if it's running inside an async context. - Instead of ``ray.wait(object_refs)``, you can use - ``await asyncio.wait(object_refs)``. + Instead of ``ray.wait(ray_waitables)``, you can use + ``await asyncio.wait(ray_waitables)``. Related patterns and anti-patterns: @@ -2715,17 +2733,18 @@ def wait( - :doc:`/ray-core/patterns/ray-get-submission-order` Args: - object_refs: List of :class:`~ObjectRefs` or - :class:`~StreamingObjectRefGenerators` for objects that may or may + ray_waitables: List of :class:`~ObjectRef` or + :class:`~StreamingObjectRefGenerator` for objects that may or may not be ready. Note that these must be unique. - num_returns: The number of object refs that should be returned. + num_returns: The number of ray_waitables that should be returned. timeout: The maximum amount of time in seconds to wait before returning. fetch_local: If True, wait for the object to be downloaded onto - the local node before returning it as ready. If False, ray.wait() - will not trigger fetching of objects to the local node and will - return immediately once the object is available anywhere in the - cluster. + the local node before returning it as ready. If the `ray_waitable` + is a generator, it will wait until the next object in the generator + is downloaded. If False, ray.wait() will not trigger fetching of + objects to the local node and will return immediately once the + object is available anywhere in the cluster. 
Returns: A list of object refs that are ready and a list of the remaining object @@ -2748,20 +2767,20 @@ def wait( ) blocking_wait_inside_async_warned = True - if isinstance(object_refs, ObjectRef) or isinstance( - object_refs, StreamingObjectRefGenerator + if isinstance(ray_waitables, ObjectRef) or isinstance( + ray_waitables, StreamingObjectRefGenerator ): raise TypeError( "wait() expected a list of ray.ObjectRef or ray.StreamingObjectRefGenerator" ", got a single ray.ObjectRef or ray.StreamingObjectRefGenerator " - f"{object_refs}" + f"{ray_waitables}" ) - if not isinstance(object_refs, list): + if not isinstance(ray_waitables, list): raise TypeError( "wait() expected a list of ray.ObjectRef or " "ray.StreamingObjectRefGenerator, " - f"got {type(object_refs)}" + f"got {type(ray_waitables)}" ) if timeout is not None and timeout < 0: @@ -2769,14 +2788,14 @@ def wait( "The 'timeout' argument must be nonnegative. " f"Received {timeout}" ) - for object_ref in object_refs: - if not isinstance(object_ref, ObjectRef) and not isinstance( - object_ref, StreamingObjectRefGenerator + for ray_waitable in ray_waitables: + if not isinstance(ray_waitable, ObjectRef) and not isinstance( + ray_waitable, StreamingObjectRefGenerator ): raise TypeError( "wait() expected a list of ray.ObjectRef or " "ray.StreamingObjectRefGenerator, " - f"got list containing {type(object_ref)}" + f"got list containing {type(ray_waitable)}" ) worker.check_connected() @@ -2785,23 +2804,23 @@ def wait( # TODO(rkn): This is a temporary workaround for # https://github.com/ray-project/ray/issues/997. However, it should be # fixed in Arrow instead of here. 
- if len(object_refs) == 0: + if len(ray_waitables) == 0: return [], [] - if len(object_refs) != len(set(object_refs)): - raise ValueError("Wait requires a list of unique object refs.") + if len(ray_waitables) != len(set(ray_waitables)): + raise ValueError("Wait requires a list of unique ray_waitables.") if num_returns <= 0: raise ValueError("Invalid number of objects to return %d." % num_returns) - if num_returns > len(object_refs): + if num_returns > len(ray_waitables): raise ValueError( "num_returns cannot be greater than the number " - "of objects provided to ray.wait." + "of ray_waitables provided to ray.wait." ) timeout = timeout if timeout is not None else 10**6 timeout_milliseconds = int(timeout * 1000) ready_ids, remaining_ids = worker.core_worker.wait( - object_refs, + ray_waitables, num_returns, timeout_milliseconds, worker.current_task_id, @@ -2879,7 +2898,10 @@ def kill(actor: "ray.actor.ActorHandle", *, no_restart: bool = True): @PublicAPI @client_mode_hook def cancel( - object_ref: "ray.ObjectRef", *, force: bool = False, recursive: bool = True + ray_waitable: Union["ObjectRef[R]", "StreamingObjectRefGenerator[R]"], + *, + force: bool = False, + recursive: bool = True, ) -> None: """Cancels a task. @@ -2927,26 +2949,27 @@ def cancel( are cancelled. Args: - object_ref: ObjectRef returned by the Task - that should be cancelled. - force: Whether to force-kill a running Task by killing - the worker that is running the Task. - recursive: Whether to try to cancel Tasks submitted by the - Task specified. + ray_waitable: :class:`~ObjectRef` and + :class:`~StreamingObjectRefGenerator` + returned by the task that should be canceled. + force: Whether to force-kill a running task by killing + the worker that is running the task. + recursive: Whether to try to cancel tasks submitted by the + task specified. 
""" worker = ray._private.worker.global_worker worker.check_connected() - if isinstance(object_ref, ray._raylet.StreamingObjectRefGenerator): - assert hasattr(object_ref, "_generator_ref") - object_ref = object_ref._generator_ref + if isinstance(ray_waitable, ray._raylet.StreamingObjectRefGenerator): + assert hasattr(ray_waitable, "_generator_ref") + ray_waitable = ray_waitable._generator_ref - if not isinstance(object_ref, ray.ObjectRef): + if not isinstance(ray_waitable, ray.ObjectRef): raise TypeError( "ray.cancel() only supported for object refs. " - f"For actors, try ray.kill(). Got: {type(object_ref)}." + f"For actors, try ray.kill(). Got: {type(ray_waitable)}." ) - return worker.core_worker.cancel_task(object_ref, force, recursive) + return worker.core_worker.cancel_task(ray_waitable, force, recursive) def _mode(worker=global_worker): diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 48ce74677be2e..a0a75a44ff4bf 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -30,7 +30,7 @@ from typing import ( Any, Optional, Generator, - AsyncGenerator + AsyncGenerator, ) import contextvars @@ -267,11 +267,21 @@ class ObjectRefGenerator: class StreamingObjectRefGenerator: + """A generator to obtain object references + from a task in a streaming manner. + + The class is compatible with generator and + async generator interface. + + The class is not thread-safe. + + >>> gen = generator_task.remote() + >>> next(gen) + >>> await gen.__anext__() + """ def __init__(self, generator_ref: ObjectRef, worker: "Worker"): # The reference to a generator task. self._generator_ref = generator_ref - # The last time generator task has completed. - self._generator_task_completed_time = None # The exception raised from a generator task. self._generator_task_exception = None # Ray's worker class. 
ray._private.worker.global_worker @@ -279,11 +289,9 @@ class StreamingObjectRefGenerator: self.worker.check_connected() assert hasattr(worker, "core_worker") - def get_next_ref(self) -> ObjectRef: - self.worker.check_connected() - core_worker = self.worker.core_worker - return core_worker.peek_object_ref_stream( - self._generator_ref)[0] + """ + Public APIs + """ def __iter__(self) -> "StreamingObjectRefGenerator": return self @@ -301,12 +309,115 @@ class StreamingObjectRefGenerator: """ return self._next_sync() - def __aiter__(self): + def send(self, value): + raise NotImplementedError("`gen.send` is not supported.") + + def throw(self, value): + raise NotImplementedError("`gen.throw` is not supported.") + + def close(self): + raise NotImplementedError("`gen.close` is not supported.") + + def __aiter__(self) -> "StreamingObjectRefGenerator": return self async def __anext__(self): return await self._next_async() + async def asend(self, value): + raise NotImplementedError("`gen.asend` is not supported.") + + async def athrow(self, value): + raise NotImplementedError("`gen.athrow` is not supported.") + + async def aclose(self): + raise NotImplementedError("`gen.aclose` is not supported.") + + def completed(self) -> ObjectRef: + """Returns an object ref that is ready when + a generator task completes. + + If the task fails unexpectedly (e.g., worker failure), + `ray.get(gen.completed())` raises an exception. + + The function returns immediately. + + >>> ray.get(gen.completed()) + """ + return self._generator_ref + + def next_ready(self) -> bool: + """If True, it means the output of next(gen) is ready and + ray.get(next(gen)) returns immediately. False otherwise. + + It returns False when next(gen) raises a StopIteration + (this condition should be checked using is_finished). + + The function returns immediately. 
+ """ + self.worker.check_connected() + core_worker = self.worker.core_worker + + if self.is_finished(): + return False + + expected_ref, is_ready = core_worker.peek_object_ref_stream( + self._generator_ref) + + if is_ready: + return True + + ready, _ = ray.wait( + [expected_ref], timeout=0, fetch_local=False) + return len(ready) > 0 + + def is_finished(self) -> bool: + """If True, it means the generator is finished + and all output is taken. False otherwise. + + When True, if next(gen) is called, it will raise StopIteration + or StopAsyncIteration + + The function returns immediately. + """ + self.worker.check_connected() + core_worker = self.worker.core_worker + + finished = core_worker.is_object_ref_stream_finished( + self._generator_ref) + + if finished: + if self._generator_task_exception: + return True + else: + # We should try ray.get on a generator ref. + # If it raises an exception and + # _generator_task_exception is not set, + # this means the last ref is not taken yet. + try: + ray.get(self._generator_ref) + except Exception: + # The exception from _generator_ref + # hasn't been taken yet. + return False + else: + return True + + """ + Private APIs + """ + + def _get_next_ref(self) -> ObjectRef: + """Return the next reference from a generator. + + Note that the ObjectID generated from a generator + is always deterministic. + """ + self.worker.check_connected() + core_worker = self.worker.core_worker + return core_worker.peek_object_ref_stream( + self._generator_ref)[0] + def _next_sync( self, timeout_s: Optional[float] = None @@ -366,7 +477,7 @@ class StreamingObjectRefGenerator: raise StopIteration return ref - async def suppress_exceptions(self, ref: ObjectRef): + async def _suppress_exceptions(self, ref: ObjectRef) -> None: # Wrap a streamed ref to avoid asyncio warnings about not retrieving # the exception when we are just waiting for the ref to become ready. 
# The exception will get returned (or warned) to the user once they @@ -388,7 +499,7 @@ class StreamingObjectRefGenerator: if not is_ready: # TODO(swang): Avoid fetching the value. ready, unready = await asyncio.wait( - [asyncio.create_task(self.suppress_exceptions(ref))], + [asyncio.create_task(self._suppress_exceptions(ref))], timeout=timeout_s ) if len(unready) > 0: @@ -3434,7 +3545,7 @@ cdef class CoreWorker: if isinstance(ref_or_generator, StreamingObjectRefGenerator): # Before calling wait, # get the next reference from a generator. - object_refs.append(ref_or_generator.get_next_ref()) + object_refs.append(ref_or_generator._get_next_ref()) else: object_refs.append(ref_or_generator) @@ -3984,6 +4095,7 @@ cdef class CoreWorker: method_meta = ray.actor._ActorClassMethodMetadata.create( actor_class, actor_creation_function_descriptor) return ray.actor.ActorHandle(language, actor_id, + method_meta.method_is_generator, method_meta.decorators, method_meta.signatures, method_meta.num_returns, @@ -3993,6 +4105,7 @@ cdef class CoreWorker: worker.current_session_and_job) else: return ray.actor.ActorHandle(language, actor_id, + {}, # method is_generator {}, # method decorators {}, # method signatures {}, # method num_returns @@ -4662,6 +4775,16 @@ cdef class CoreWorker: # Already added when the ref is updated. 
skip_adding_local_ref=True) + def is_object_ref_stream_finished(self, ObjectRef generator_id): + cdef: + CObjectID c_generator_id = generator_id.native() + c_bool finished + + with nogil: + finished = CCoreWorkerProcess.GetCoreWorker().IsFinished( + c_generator_id) + return finished + def peek_object_ref_stream(self, ObjectRef generator_id): cdef: CObjectID c_generator_id = generator_id.native() diff --git a/python/ray/actor.py b/python/ray/actor.py index a6adfd5e862a7..f42da3208f90e 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -1,7 +1,7 @@ import inspect import logging import weakref -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import ray._private.ray_constants as ray_constants import ray._private.signature as signature @@ -114,7 +114,10 @@ class ActorMethod: _actor_ref: A weakref handle to the actor. _method_name: The name of the actor method. _num_returns: The default number of return values that the method - invocation should return. + invocation should return. If None is given, it uses + DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS for a normal actor task + and "streaming" for a generator task (when `is_generator` is True). + _is_generator: True if a given method is a Python generator. _generator_backpressure_num_objects: Generator-only config. If a number of unconsumed objects reach this threshold, a actor task stop pausing. @@ -129,9 +132,10 @@ class ActorMethod: def __init__( self, - actor: object, - method_name: str, - num_returns: int, + actor, + method_name, + num_returns: Optional[Union[int, str]], + is_generator: bool, generator_backpressure_num_objects: int, decorator=None, hardref=False, @@ -139,6 +143,15 @@ def __init__( self._actor_ref = weakref.ref(actor) self._method_name = method_name self._num_returns = num_returns + self._is_generator = is_generator + + # Default case. 
+ if self._num_returns is None: + if is_generator: + self._num_returns = "streaming" + else: + self._num_returns = ray_constants.DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS + self._generator_backpressure_num_objects = generator_backpressure_num_objects # This is a decorator that is used to wrap the function invocation (as # opposed to the function execution). The decorator must return a @@ -230,6 +243,8 @@ def __getstate__(self): "method_name": self._method_name, "num_returns": self._num_returns, "decorator": self._decorator, + "is_generator": self._is_generator, + "generator_backpressure_num_objects": self._generator_backpressure_num_objects, # noqa } def __setstate__(self, state): @@ -237,6 +252,8 @@ def __setstate__(self, state): state["actor"], state["method_name"], state["num_returns"], + state["is_generator"], + state["generator_backpressure_num_objects"], state["decorator"], hardref=True, ) @@ -289,6 +306,7 @@ def create(cls, modified_class, actor_creation_function_descriptor): self.decorators = {} self.signatures = {} self.num_returns = {} + self.method_is_generator = {} self.generator_backpressure_num_objects = {} self.concurrency_group_for_methods = {} @@ -312,9 +330,7 @@ def create(cls, modified_class, actor_creation_function_descriptor): if hasattr(method, "__ray_num_returns__"): self.num_returns[method_name] = method.__ray_num_returns__ else: - self.num_returns[ - method_name - ] = ray_constants.DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS + self.num_returns[method_name] = None if hasattr(method, "__ray_invocation_decorator__"): self.decorators[method_name] = method.__ray_invocation_decorator__ @@ -324,6 +340,11 @@ def create(cls, modified_class, actor_creation_function_descriptor): method_name ] = method.__ray_concurrency_group__ + is_generator = inspect.isgeneratorfunction( + method + ) or inspect.isasyncgenfunction(method) + self.method_is_generator[method_name] = is_generator + if hasattr(method, "__ray_generator_backpressure_num_objects__"): 
self.generator_backpressure_num_objects[ method_name @@ -1038,6 +1059,7 @@ def _remote(self, args=None, kwargs=None, **actor_options): actor_handle = ActorHandle( meta.language, actor_id, + meta.method_meta.method_is_generator, meta.method_meta.decorators, meta.method_meta.signatures, meta.method_meta.num_returns, @@ -1078,6 +1100,8 @@ class ActorHandle: Attributes: _ray_actor_language: The actor language. _ray_actor_id: Actor ID. + _ray_method_is_generator: Map of method name -> if it is a generator + method. _ray_method_decorators: Optional decorators for the function invocation. This can be used to change the behavior on the invocation side, whereas a regular decorator can be used to change @@ -1101,6 +1125,7 @@ def __init__( self, language, actor_id, + method_is_generator: Dict[str, bool], method_decorators, method_signatures, method_num_returns: Dict[str, int], @@ -1113,6 +1138,7 @@ def __init__( self._ray_actor_language = language self._ray_actor_id = actor_id self._ray_original_handle = original_handle + self._ray_method_is_generator = method_is_generator self._ray_method_decorators = method_decorators self._ray_method_signatures = method_signatures self._ray_method_num_returns = method_num_returns @@ -1142,6 +1168,7 @@ def __init__( self, method_name, self._ray_method_num_returns[method_name], + self._ray_method_is_generator[method_name], self._ray_method_generator_backpressure_num_objects.get( method_name ), # noqa @@ -1280,9 +1307,8 @@ def remote(self, *args, **kwargs): return ActorMethod( self, item, - ray_constants. 
- # Currently, we use default num returns - DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS, + ray_constants.DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS, + False, self._ray_method_generator_backpressure_num_objects.get(item, -1), # Currently, cross-lang actor method not support decorator decorator=None, @@ -1321,6 +1347,7 @@ def _serialization_helper(self): { "actor_language": self._ray_actor_language, "actor_id": self._ray_actor_id, + "method_is_generator": self._ray_method_is_generator, "method_decorators": self._ray_method_decorators, "method_signatures": self._ray_method_signatures, "method_num_returns": self._ray_method_num_returns, @@ -1361,6 +1388,7 @@ def _deserialization_helper(cls, state, outer_object_ref=None): # thread-safe. state["actor_language"], state["actor_id"], + state["method_is_generator"], state["method_decorators"], state["method_signatures"], state["method_num_returns"], diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index bccecb03596df..0f9d158cca352 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -162,6 +162,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: CRayStatus TryReadObjectRefStream( const CObjectID &generator_id, CObjectReference *object_ref_out) + c_bool IsFinished(const CObjectID &generator_id) const pair[CObjectReference, c_bool] PeekObjectRefStream( const CObjectID &generator_id) CObjectID AllocateDynamicReturnId( diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 77212413ad150..317a28bc6a3b7 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -121,6 +121,7 @@ def __init__( self._default_options["runtime_env"] = self._runtime_env self._language = language + self._is_generator = inspect.isgeneratorfunction(function) self._function = function self._function_signature = None # Guards trace injection to enforce exactly once semantics @@ -331,7 +332,14 @@ def _remote(self, args=None, 
kwargs=None, **task_options): "placement_group_capture_child_tasks" ] scheduling_strategy = task_options["scheduling_strategy"] + num_returns = task_options["num_returns"] + if num_returns is None: + if self._is_generator: + num_returns = "streaming" + else: + num_returns = 1 + if num_returns == "dynamic": num_returns = -1 elif num_returns == "streaming": diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 87e422bc22a93..58001ccafd67d 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -335,7 +335,8 @@ class A: with pytest.raises( ValueError, match=f"The keyword '{keyword}' only accepts None, " - 'a non-negative integer, or "dynamic"', + "a non-negative integer, " + "'streaming' \(for generators\), or 'dynamic'", ): ray.remote(**{keyword: v})(f) diff --git a/python/ray/tests/test_generators.py b/python/ray/tests/test_generators.py index 7d7192f3d4e41..44ba175dcf0b6 100644 --- a/python/ray/tests/test_generators.py +++ b/python/ray/tests/test_generators.py @@ -132,7 +132,7 @@ def generator(num_returns, store_in_plasma): @pytest.mark.parametrize("use_actors", [False, True]) @pytest.mark.parametrize("store_in_plasma", [False, True]) -@pytest.mark.parametrize("num_returns_type", ["dynamic", "streaming"]) +@pytest.mark.parametrize("num_returns_type", ["dynamic", None]) def test_generator_errors( ray_start_regular, use_actors, store_in_plasma, num_returns_type ): @@ -186,7 +186,7 @@ def generator(num_returns, store_in_plasma): @pytest.mark.parametrize("store_in_plasma", [False, True]) -@pytest.mark.parametrize("num_returns_type", ["dynamic", "streaming"]) +@pytest.mark.parametrize("num_returns_type", ["dynamic", None]) def test_dynamic_generator_retry_exception( ray_start_regular, store_in_plasma, num_returns_type ): @@ -239,7 +239,7 @@ def generator(num_returns, store_in_plasma, counter): @pytest.mark.parametrize("use_actors", [False, True]) @pytest.mark.parametrize("store_in_plasma", [False, True]) 
-@pytest.mark.parametrize("num_returns_type", ["dynamic", "streaming"]) +@pytest.mark.parametrize("num_returns_type", ["dynamic", None]) def test_dynamic_generator( ray_start_regular, use_actors, store_in_plasma, num_returns_type ): @@ -315,18 +315,19 @@ def read(gen): ) ) - # Normal remote functions don't work with num_returns="dynamic". - @ray.remote(num_returns=num_returns_type) - def static(num_returns): - return list(range(num_returns)) + if num_returns_type == "dynamic": + # Normal remote functions don't work with num_returns="dynamic". + @ray.remote(num_returns=num_returns_type) + def static(num_returns): + return list(range(num_returns)) - with pytest.raises(ray.exceptions.RayTaskError): - gen = ray.get(static.remote(3)) - for ref in gen: - ray.get(ref) + with pytest.raises(ray.exceptions.RayTaskError): + gen = ray.get(static.remote(3)) + for ref in gen: + ray.get(ref) -@pytest.mark.parametrize("num_returns_type", ["dynamic", "streaming"]) +@pytest.mark.parametrize("num_returns_type", ["dynamic", None]) def test_dynamic_generator_distributed(ray_start_cluster, num_returns_type): cluster = ray_start_cluster # Head node with no resources. 
@@ -347,7 +348,7 @@ def dynamic_generator(num_returns): assert ray.get(ref)[0] == i -@pytest.mark.parametrize("num_returns_type", ["dynamic", "streaming"]) +@pytest.mark.parametrize("num_returns_type", ["dynamic", None]) def test_dynamic_generator_reconstruction(ray_start_cluster, num_returns_type): config = { "health_check_failure_threshold": 10, @@ -407,7 +408,7 @@ def fetch(x): @pytest.mark.parametrize("too_many_returns", [False, True]) -@pytest.mark.parametrize("num_returns_type", ["dynamic", "streaming"]) +@pytest.mark.parametrize("num_returns_type", ["dynamic", None]) def test_dynamic_generator_reconstruction_nondeterministic( ray_start_cluster, too_many_returns, num_returns_type ): @@ -491,14 +492,14 @@ def fetch(x): # ray.get(ref) del gen del refs - if num_returns_type == "streaming": + if num_returns_type is None: # TODO(sang): For some reasons, it fails when "dynamic" # is used. We don't fix the issue because we will # remove this flag soon anyway. assert_no_leak() -@pytest.mark.parametrize("num_returns_type", ["dynamic", "streaming"]) +@pytest.mark.parametrize("num_returns_type", ["dynamic", None]) def test_dynamic_generator_reconstruction_fails(ray_start_cluster, num_returns_type): config = { "health_check_failure_threshold": 10, @@ -566,7 +567,7 @@ def fetch(*refs): assert_no_leak() -@pytest.mark.parametrize("num_returns_type", ["dynamic", "streaming"]) +@pytest.mark.parametrize("num_returns_type", ["dynamic", None]) def test_dynamic_empty_generator_reconstruction_nondeterministic( ray_start_cluster, num_returns_type ): diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index c8bffe940fc3b..d4798377d62b9 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -141,7 +141,7 @@ def f(): for i in range(5): yield i - gen = f.options(num_returns="streaming").remote() + gen = f.remote() i = 0 for ref in gen: print(ray.get(ref)) @@ -159,7 +159,7 @@ 
def f(): raise ValueError yield i - gen = f.options(num_returns="streaming").remote() + gen = f.remote() print(ray.get(next(gen))) print(ray.get(next(gen))) with pytest.raises(ray.exceptions.RayTaskError) as e: @@ -185,7 +185,7 @@ def f(self): yield i a = A.remote() - gen = a.f.options(num_returns="streaming").remote() + gen = a.f.remote() i = 0 for ref in gen: if i == 2: @@ -228,7 +228,7 @@ def f(a): yield i a = Actor.remote() - gen = f.options(num_returns="streaming").remote(a) + gen = f.remote(a) assert ray.get(next(gen)) == 0 assert ray.get(next(gen)) == 1 assert ray.get(next(gen)) == 2 @@ -247,7 +247,7 @@ def f(): time.sleep(5) yield i - gen = f.options(num_returns="streaming").remote() + gen = f.remote() assert ray.get(next(gen)) == 0 ray.cancel(gen) with pytest.raises(ray.exceptions.RayTaskError) as e: @@ -275,7 +275,7 @@ def f(): raise UnserializableException yield 1 # noqa - for ref in f.options(num_returns="streaming").remote(): + for ref in f.remote(): with pytest.raises(ray.exceptions.RayTaskError): ray.get(ref) captured = capsys.readouterr() @@ -301,7 +301,7 @@ def test_generator_streaming_no_leak_upon_failures( @ray.remote def g(): try: - gen = f.options(num_returns="streaming").remote() + gen = f.remote() for ref in gen: print(ref) ray.get(ref) @@ -370,9 +370,7 @@ def generator(num_returns, store_in_plasma): remote_generator_fn = generator """Verify num_returns="streaming" is streaming""" - gen = remote_generator_fn.options(num_returns="streaming").remote( - 3, store_in_plasma - ) + gen = remote_generator_fn.remote(3, store_in_plasma) i = 0 for ref in gen: id = ref.hex() @@ -412,9 +410,7 @@ def get_data(self): time.sleep(0.1) yield np.ones(5 * 1024 * 1024) else: - for data in self.child.get_data.options( - num_returns="streaming" - ).remote(): + for data in self.child.get_data.remote(): yield ray.get(data) chain_actor = ChainActor.remote() @@ -422,7 +418,7 @@ def get_data(self): chain_actor_3 = ChainActor.remote(chain_actor_2) chain_actor_4 = 
ChainActor.remote(chain_actor_3) - for ref in chain_actor_4.get_data.options(num_returns="streaming").remote(): + for ref in chain_actor_4.get_data.remote(): assert np.array_equal(np.ones(5 * 1024 * 1024), ray.get(ref)) print("getting the next data") del ref @@ -443,7 +439,7 @@ def test_generator_slow_pinning_requests(monkeypatch, shutdown_only): def f(): yield np.ones(5 * 1024 * 1024) - for ref in f.options(num_returns="streaming").remote(): + for ref in f.remote(): del ref print(list_objects()) @@ -475,7 +471,7 @@ def g(self): arr = 3 def verify_sync_task_executor(): - generator = a.f.options(num_returns="streaming").remote(ray.put(arr)) + generator = a.f.remote(ray.put(arr)) # Verify it works with next. assert isinstance(generator, StreamingObjectRefGenerator) assert ray.get(next(generator)) == 0 @@ -485,26 +481,26 @@ def verify_sync_task_executor(): ray.get(next(generator)) # Verify it works with for. - generator = a.f.options(num_returns="streaming").remote(ray.put(3)) + generator = a.f.remote(ray.put(3)) for index, ref in enumerate(generator): assert index == ray.get(ref) def verify_async_task_executor(): # Verify it works with next. - generator = a.async_f.options(num_returns="streaming").remote(ray.put(arr)) + generator = a.async_f.remote(ray.put(arr)) assert isinstance(generator, StreamingObjectRefGenerator) assert ray.get(next(generator)) == 0 assert ray.get(next(generator)) == 1 assert ray.get(next(generator)) == 2 # Verify it works with for. 
- generator = a.f.options(num_returns="streaming").remote(ray.put(3)) + generator = a.f.remote(ray.put(3)) for index, ref in enumerate(generator): assert index == ray.get(ref) async def verify_sync_task_async_generator(): # Verify anext - async_generator = a.f.options(num_returns="streaming").remote(ray.put(arr)) + async_generator = a.f.remote(ray.put(arr)) assert isinstance(async_generator, StreamingObjectRefGenerator) for expected in range(3): ref = await async_generator.__anext__() @@ -513,7 +509,7 @@ async def verify_sync_task_async_generator(): await async_generator.__anext__() # Verify async for. - async_generator = a.f.options(num_returns="streaming").remote(ray.put(arr)) + async_generator = a.f.remote(ray.put(arr)) expected = 0 async for ref in async_generator: value = await ref @@ -521,9 +517,7 @@ async def verify_sync_task_async_generator(): expected += 1 async def verify_async_task_async_generator(): - async_generator = a.async_f.options(num_returns="streaming").remote( - ray.put(arr) - ) + async_generator = a.async_f.remote(ray.put(arr)) assert isinstance(async_generator, StreamingObjectRefGenerator) for expected in range(3): ref = await async_generator.__anext__() @@ -532,9 +526,7 @@ async def verify_async_task_async_generator(): await async_generator.__anext__() # Verify async for. 
- async_generator = a.async_f.options(num_returns="streaming").remote( - ray.put(arr) - ) + async_generator = a.async_f.remote(ray.put(arr)) expected = 0 async for ref in async_generator: value = await ref @@ -563,7 +555,7 @@ async def async_f(self): yield 1 # noqa a = Actor.remote() - g = a.f.options(num_returns="streaming").remote() + g = a.f.remote() with pytest.raises(ValueError): ray.get(next(g)) @@ -573,7 +565,7 @@ async def async_f(self): with pytest.raises(StopIteration): ray.get(next(g)) - g = a.async_f.options(num_returns="streaming").remote() + g = a.async_f.remote() with pytest.raises(ValueError): ray.get(next(g)) diff --git a/python/ray/tests/test_streaming_generator_2.py b/python/ray/tests/test_streaming_generator_2.py index 68f56b3fb84d4..8ddce574b3f05 100644 --- a/python/ray/tests/test_streaming_generator_2.py +++ b/python/ray/tests/test_streaming_generator_2.py @@ -52,7 +52,7 @@ def test_reconstruction(monkeypatch, ray_start_cluster, delay): node_to_kill = cluster.add_node(num_cpus=1, object_store_memory=10**8) cluster.wait_for_nodes() - @ray.remote(num_returns="streaming", max_retries=2) + @ray.remote(max_retries=2) def dynamic_generator(num_returns): for i in range(num_returns): yield np.ones(1_000_000, dtype=np.int8) * i @@ -62,7 +62,7 @@ def fetch(x): return x[0] # Test recovery of all dynamic objects through re-execution. 
- gen = ray.get(dynamic_generator.remote(10)) + gen = dynamic_generator.remote(10) refs = [] for i in range(5): @@ -133,7 +133,7 @@ def get(self): node_to_kill = cluster.add_node(num_cpus=1, object_store_memory=10**8) cluster.wait_for_nodes() - @ray.remote(num_returns="streaming") + @ray.remote def dynamic_generator(num_returns, signal_actor): for i in range(num_returns): if i == 3: @@ -150,7 +150,7 @@ def dynamic_generator(num_returns, signal_actor): def fetch(x): return x[0] - gen = ray.get(dynamic_generator.remote(10, signal)) + gen = dynamic_generator.remote(10, signal) refs = [] for i in range(5): @@ -197,7 +197,7 @@ def test_generator_max_returns(monkeypatch, shutdown_only): "2", ) - @ray.remote(num_returns="streaming") + @ray.remote def generator_task(): for _ in range(3): yield 1 @@ -224,7 +224,7 @@ def g(): yield i return - generator = g.options(num_returns="streaming").remote() + generator = g.remote() result = [] for ref in generator: result.append(ray.get(ref)) @@ -252,7 +252,7 @@ async def gen(self): assert task_name == asyncio.current_task().get_name() a = A.remote() - for obj_ref in a.gen.options(num_returns="streaming").remote(): + for obj_ref in a.gen.remote(): print(ray.get(obj_ref)) @@ -269,7 +269,7 @@ async def gen(self): a = A.remote() async def co(): - async for ref in a.gen.options(num_returns="streaming").remote(): + async for ref in a.gen.remote(): print(await ref) async def main(): @@ -294,7 +294,7 @@ def f(): yield 1 for _ in range(10): - for ref in f.options(num_returns="streaming").remote(): + for ref in f.remote(): del ref time.sleep(0.2) @@ -304,7 +304,7 @@ def f(): assert_no_leak() for _ in range(10): - for ref in f.options(num_returns="streaming").remote(): + for ref in f.remote(): break time.sleep(0.2) @@ -388,7 +388,7 @@ def verify_regular(actor, fail): def verify_generator(actor, fail): for _ in range(100): - for ref in actor.gen.options(num_returns="streaming").remote(fail=fail): + for ref in actor.gen.remote(fail=fail): try: 
ray.get(ref) except Exception: diff --git a/python/ray/tests/test_streaming_generator_3.py b/python/ray/tests/test_streaming_generator_3.py index e71cff25b499d..d69b9e63139d4 100644 --- a/python/ray/tests/test_streaming_generator_3.py +++ b/python/ray/tests/test_streaming_generator_3.py @@ -13,6 +13,7 @@ from pydantic import BaseModel import ray from ray._raylet import StreamingObjectRefGenerator +from ray.exceptions import WorkerCrashedError from ray._private.test_utils import run_string_as_driver_nonblocking from ray.util.state import list_actors @@ -250,6 +251,61 @@ def g(sleep_time): assert result[10] == 4 +def test_completed_next_ready_is_finished(shutdown_only): + @ray.remote + def f(): + for _ in range(3): + time.sleep(1) + yield 1 + + gen = f.remote() + assert not gen.is_finished() + assert not gen.next_ready() + r, _ = ray.wait([gen]) + gen = r[0] + assert gen.next_ready() + _, ur = ray.wait([gen.completed()], timeout=0) + assert len(ur) == 1 + + # Consume object refs + next(gen) + assert not gen.is_finished() + _, ur = ray.wait([gen.completed()], timeout=0) + assert len(ur) == 1 + + next(gen) + assert not gen.is_finished() + _, ur = ray.wait([gen.completed()], timeout=0) + assert len(ur) == 1 + + next(gen) + with pytest.raises(StopIteration): + next(gen) + + assert gen.is_finished() + # Since the next should raise StopIteration, + # it should be False. + assert not gen.next_ready() + r, _ = ray.wait([gen.completed()], timeout=0) + assert len(r) == 1 + + # Test the failed case. + gen = f.remote() + next(gen) + ray.cancel(gen, force=True) + r, _ = ray.wait([gen]) + assert len(r) == 1 + # The last exception is not taken yet. 
+ assert gen.next_ready() + assert not gen.is_finished() + with pytest.raises(WorkerCrashedError): + ray.get(gen.completed()) + with pytest.raises(WorkerCrashedError): + ray.get(next(gen)) + assert not gen.next_ready() + assert gen.is_finished() + + def test_streaming_generator_load(shutdown_only): app = FastAPI() @@ -263,32 +319,24 @@ def __init__(self, handle) -> None: @app.get("/") def stream_hi(self, request: Request) -> StreamingResponse: async def consume_obj_ref_gen(): - obj_ref_gen = await self._h.hi_gen.remote() - start = time.time() + obj_ref_gen = self._h.hi_gen.remote() num_recieved = 0 async for chunk in obj_ref_gen: - chunk = await chunk num_recieved += 1 yield str(chunk.json()) - delta = time.time() - start - print(f"**request throughput: {num_recieved / delta}") return StreamingResponse(consume_obj_ref_gen(), media_type="text/plain") @serve.deployment(max_concurrent_queries=1000) class SimpleGenerator: async def hi_gen(self): - start = time.time() for i in range(100): - # await asyncio.sleep(0.001) time.sleep(0.001) # if change to async sleep, i don't see crash. class Model(BaseModel): msg = "a" * 56 yield Model() - delta = time.time() - start - print(f"**model throughput: {100 / delta}") serve.run(Router.bind(SimpleGenerator.bind())) @@ -308,10 +356,10 @@ def send_serve_requests(): "exception": None, } start_perf_counter = time.perf_counter() - #r = self.client.get("/", stream=True) r = requests.get("http://localhost:8000", stream=True) + print("status code: ", r.status_code) if r.status_code != 200: - print(r) + assert False else: for i, chunk in enumerate(r.iter_content(chunk_size=None, decode_unicode=True)): pass @@ -333,6 +381,9 @@ def send_serve_requests(): # Wait sufficient time. 
time.sleep(5) proc.terminate() + out_str = proc.stdout.read().decode("ascii") + err_str = proc.stderr.read().decode("ascii") + print(out_str, err_str) for actor in list_actors(): assert actor.state != "DEAD" diff --git a/python/ray/tests/test_typing.py b/python/ray/tests/test_typing.py index ee03f04fa2323..d98f34ad96997 100644 --- a/python/ray/tests/test_typing.py +++ b/python/ray/tests/test_typing.py @@ -14,7 +14,8 @@ def test_typing_good(): typing_good_tmp_path = create_tmp_copy(TYPING_GOOD_PATH) - _, msg, status_code = mypy_api.run([typing_good_tmp_path]) + out, msg, status_code = mypy_api.run([typing_good_tmp_path]) + print(out) assert status_code == 0, msg diff --git a/python/ray/tests/typing_files/check_typing_good.py b/python/ray/tests/typing_files/check_typing_good.py index b81f8527ddb3f..389d56c75a733 100644 --- a/python/ray/tests/typing_files/check_typing_good.py +++ b/python/ray/tests/typing_files/check_typing_good.py @@ -1,8 +1,15 @@ import ray +from typing import Generator +from ray import ObjectRef ray.init() +@ray.remote +def int_task() -> int: + return 1 + + @ray.remote def f(a: int) -> str: return "a = {}".format(a + 1) @@ -18,15 +25,88 @@ def h(a: str, b: int) -> str: return a +def func(a: "ObjectRef[str]"): + pass + + # Make sure the function arg is check print(f.remote(1)) object_ref_str = f.remote(1) +object_ref_int = int_task.remote() # Make sure the ObjectRef[T] variant of function arg is checked print(g.remote(object_ref_str)) +# Make sure it is backward compatible after +# introducing generator types. +func(object_ref_str) + # Make sure there can be mixed T0 and ObjectRef[T1] for args print(h.remote(object_ref_str, 100)) +ready, unready = ray.wait([object_ref_str, object_ref_int]) + # Make sure the return type is checked. xy = ray.get(object_ref_str) + "y" + + +# Right now, we only check if it doesn't raise errors. 
+@ray.remote +def generator_1() -> Generator[int, None, None]: + yield 1 + + +gen = generator_1.remote() + + +""" +TODO(sang): Enable it. +Test generator + +Generator can have 4 different output +per generator and async generator. See +https://docs.python.org/3/library/typing.html#typing.Generator +for more details. +""" + +# @ray.remote +# def generator_1() -> Generator[int, None, None]: +# yield 1 + + +# @ray.remote +# def generator_2() -> Iterator[int]: +# yield 1 + + +# @ray.remote +# def generator_3() -> Iterable[int]: +# yield 1 + + +# gen: StreamingObjectRefGeneratorType[int] = generator_1.remote() +# gen2: StreamingObjectRefGeneratorType[int] = generator_2.remote() +# gen3: StreamingObjectRefGeneratorType[int] = generator_3.remote() + + +# next_item: ObjectRef[int] = gen.__next__() + + +# @ray.remote +# async def async_generator_1() -> AsyncGenerator[int, None]: +# yield 1 + + +# @ray.remote +# async def async_generator_2() -> AsyncIterator[int]: +# yield 1 + + +# @ray.remote +# async def async_generator_3() -> AsyncIterable[int]: +# yield 1 + + +# gen4: StreamingObjectRefGeneratorType[int] = async_generator_1.remote() +# gen5: StreamingObjectRefGeneratorType[int] = async_generator_2.remote() +# gen6: StreamingObjectRefGeneratorType[int] = async_generator_3.remote() diff --git a/python/ray/types.py b/python/ray/types.py index 79110f5425574..a4733f44e98eb 100644 --- a/python/ray/types.py +++ b/python/ray/types.py @@ -7,6 +7,8 @@ # TODO(ekl) this is a dummy generic ref type for documentation purposes only. # We should try to make the Cython ray.ObjectRef properly generic. +# NOTE(sang): Looks like using Generic in Cython is not currently possible. +# We should update Cython > 3.0 for this. 
@PublicAPI class ObjectRef(Generic[T]): pass diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 07ae2a5d883d7..a1aceac7da3ba 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -2850,6 +2850,10 @@ Status CoreWorker::TryReadObjectRefStream(const ObjectID &generator_id, return status; } +bool CoreWorker::IsFinished(const ObjectID &generator_id) const { + return task_manager_->IsFinished(generator_id); +} + std::pair CoreWorker::PeekObjectRefStream( const ObjectID &generator_id) { auto [object_id, ready] = task_manager_->PeekObjectRefStream(generator_id); diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index a8997630f9a50..e8c7e666fa98b 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -384,6 +384,9 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { Status TryReadObjectRefStream(const ObjectID &generator_id, rpc::ObjectReference *object_ref_out); + /// Return True if there's no more object to read. False otherwise. + bool IsFinished(const ObjectID &generator_id) const; + /// Read the next index of a ObjectRefStream of generator_id without /// consuming an index. /// \param[in] generator_id The object ref id of the streaming diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index aa50ee9481ded..cbebb0c9b80c3 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -63,8 +63,7 @@ bool ObjectRefStream::IsObjectConsumed(int64_t item_index) { Status ObjectRefStream::TryReadNextItem(ObjectID *object_id_out) { *object_id_out = GetObjectRefAtIndex(next_index_); - bool is_eof_set = end_of_stream_index_ != -1; - if (is_eof_set && next_index_ >= end_of_stream_index_) { + if (IsFinished()) { // next_index_ cannot be bigger than end_of_stream_index_. 
RAY_CHECK(next_index_ == end_of_stream_index_); RAY_LOG(DEBUG) << "ObjectRefStream of an id " << generator_id_ @@ -90,6 +89,11 @@ Status ObjectRefStream::TryReadNextItem(ObjectID *object_id_out) { return Status::OK(); } +bool ObjectRefStream::IsFinished() const { + bool is_eof_set = end_of_stream_index_ != -1; + return is_eof_set && next_index_ >= end_of_stream_index_; +} + std::pair ObjectRefStream::PeekNextItem() { const auto &object_id = GetObjectRefAtIndex(next_index_); if (refs_written_to_stream_.find(object_id) == refs_written_to_stream_.end()) { @@ -531,6 +535,16 @@ Status TaskManager::TryReadObjectRefStream(const ObjectID &generator_id, return status; } +bool TaskManager::IsFinished(const ObjectID &generator_id) const { + absl::MutexLock lock(&objet_ref_stream_ops_mu_); + auto stream_it = object_ref_streams_.find(generator_id); + RAY_CHECK(stream_it != object_ref_streams_.end()) + << "IsFinished API can be used only when the stream has been " + "created " + "and not removed."; + return stream_it->second.IsFinished(); +} + std::pair TaskManager::PeekObjectRefStream(const ObjectID &generator_id) { ObjectID next_object_id; absl::MutexLock lock(&objet_ref_stream_ops_mu_); diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index e0d310afdff68..c6b761db61f0d 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -110,6 +110,9 @@ class ObjectRefStream { /// \return KeyError if it reaches to EoF. Ok otherwise. Status TryReadNextItem(ObjectID *object_id_out); + /// Return True if there's no more object to read. False otherwise. + bool IsFinished() const; + std::pair PeekNextItem(); /// Return True if the item_index is already consumed. @@ -406,6 +409,9 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa Status TryReadObjectRefStream(const ObjectID &generator_id, ObjectID *object_id_out) ABSL_LOCKS_EXCLUDED(mu_); + /// Return True if there's no more object to read. 
False otherwise. + bool IsFinished(const ObjectID &generator_id) const ABSL_LOCKS_EXCLUDED(mu_); + /// Read the next index of a ObjectRefStream of generator_id without /// consuming an index. ///