Skip to content

Commit

Permalink
[core] In Generator, distinguish value yields and exceptions not by t…
Browse files Browse the repository at this point in the history
…ype but by control flow. (#43413)

Previously we use isinstance(output_or_exception, Exception) to find out if this is a raised Exception or a yielded value. This missed one case: yielded Exception. Though this is not likely to happen, it should be fixed for the sake of completeness.

Signed-off-by: Ruiyang Wang <rywang014@gmail.com>
  • Loading branch information
rynewang committed Mar 8, 2024
1 parent 4f3f1f7 commit fc9e82c
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 85 deletions.
1 change: 0 additions & 1 deletion python/ray/_raylet.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ cdef class CoreWorker:
const CAddress &caller_address,
c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns,
CObjectID ref_generator_id=*)
cdef yield_current_fiber(self, CFiberEvent &fiber_event)
cdef make_actor_handle(self, ActorHandleSharedPtr c_actor_handle)
cdef c_function_descriptors_to_python(
self, const c_vector[CFunctionDescriptor] &c_function_descriptors)
Expand Down
186 changes: 102 additions & 84 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1234,19 +1234,14 @@ cdef class StreamingGeneratorExecutionContext:

cdef report_streaming_generator_output(
StreamingGeneratorExecutionContext context,
output_or_exception: Union[object, Exception],
output: object,
generator_index: int64_t
):
"""Report a given generator output to a caller.
If a generator produces an exception, it should be
passed as an output to report. The API will return
False if the generator should keep executing.
True otherwise.
Args:
context: Streaming generator's execution context.
output_or_exception: The output yielded from a
output: The output yielded from a
generator or raised as an exception.
generator_index: index of the output element in the
generated sequence
Expand All @@ -1257,43 +1252,81 @@ cdef report_streaming_generator_output(
# Ray Object created from an output.
c_pair[CObjectID, shared_ptr[CRayObject]] return_obj

if isinstance(output_or_exception, Exception):
create_generator_error_object(
output_or_exception,
worker,
context.task_type,
context.caller_address,
context.task_id,
context.serialized_retry_exception_allowlist,
context.function_name,
context.function_descriptor,
context.title,
context.actor,
context.actor_id,
context.return_size,
generator_index,
context.is_async,
context.should_retry_exceptions,
&return_obj,
context.is_retryable_error,
context.application_error
)
else:
# Report the intermediate result if there was no error.
create_generator_return_obj(
output_or_exception,
# Report the intermediate result if there was no error.
create_generator_return_obj(
output,
context.generator_id,
worker,
context.caller_address,
context.task_id,
context.return_size,
generator_index,
context.is_async,
&return_obj)

# Del output here so that we can GC the memory
# usage asap.
del output

context.streaming_generator_returns[0].push_back(
c_pair[CObjectID, c_bool](
return_obj.first,
is_plasma_object(return_obj.second)))

with nogil:
check_status(CCoreWorkerProcess.GetCoreWorker().ReportGeneratorItemReturns(
return_obj,
context.generator_id,
worker,
context.caller_address,
context.task_id,
context.return_size,
generator_index,
context.is_async,
&return_obj)
context.attempt_number,
context.waiter))

# Del output here so that we can GC the memory

cdef report_streaming_generator_exception(
StreamingGeneratorExecutionContext context,
e: Exception,
generator_index: int64_t
):
"""Report a given generator exception to a caller.
Args:
context: Streaming generator's execution context.
output_or_exception: The output yielded from a
generator or raised as an exception.
generator_index: index of the output element in the
generated sequence
"""
worker = ray._private.worker.global_worker

cdef:
# Ray Object created from an output.
c_pair[CObjectID, shared_ptr[CRayObject]] return_obj

create_generator_error_object(
e,
worker,
context.task_type,
context.caller_address,
context.task_id,
context.serialized_retry_exception_allowlist,
context.function_name,
context.function_descriptor,
context.title,
context.actor,
context.actor_id,
context.return_size,
generator_index,
context.is_async,
context.should_retry_exceptions,
&return_obj,
context.is_retryable_error,
context.application_error
)

# Del exception here so that we can GC the memory
# usage asap.
del output_or_exception
del e

context.streaming_generator_returns[0].push_back(
c_pair[CObjectID, c_bool](
Expand All @@ -1309,7 +1342,6 @@ cdef report_streaming_generator_output(
context.attempt_number,
context.waiter))


cdef execute_streaming_generator_sync(StreamingGeneratorExecutionContext context):
"""Execute a given generator and streaming-report the
result to the given caller_address.
Expand All @@ -1335,19 +1367,12 @@ cdef execute_streaming_generator_sync(StreamingGeneratorExecutionContext context

gen = context.generator

while True:
try:
output_or_exception = next(gen)
except StopIteration:
break
except Exception as e:
output_or_exception = e

report_streaming_generator_output(context, output_or_exception, gen_index)
gen_index += 1

if isinstance(output_or_exception, Exception):
break
try:
for output in gen:
report_streaming_generator_output(context, output, gen_index)
gen_index += 1
except Exception as e:
report_streaming_generator_exception(context, e, gen_index)


async def execute_streaming_generator_async(
Expand Down Expand Up @@ -1383,40 +1408,39 @@ async def execute_streaming_generator_async(
gen = context.generator

futures = []
while True:
try:
output_or_exception = await gen.__anext__()
except StopAsyncIteration:
break
except AsyncioActorExit:
# The execute_task will handle this case.
raise
except Exception as e:
output_or_exception = e

loop = asyncio.get_running_loop()
worker = ray._private.worker.global_worker
loop = asyncio.get_running_loop()
worker = ray._private.worker.global_worker

# NOTE: Reporting generator output in a streaming fashion,
# is done in a standalone thread-pool fully *asynchronously*
# to avoid blocking the event-loop and allow it to *concurrently*
# make progress, since serializing and actual RPC I/O is done
# with "nogil".
# NOTE: Reporting generator output in a streaming fashion,
# is done in a standalone thread-pool fully *asynchronously*
# to avoid blocking the event-loop and allow it to *concurrently*
# make progress, since serializing and actual RPC I/O is done
# with "nogil".
try:
async for output in gen:
# Report the output to the owner of the task.
futures.append(
loop.run_in_executor(
worker.core_worker.get_thread_pool_for_async_event_loop(),
report_streaming_generator_output,
context,
output,
cur_generator_index,
)
)
cur_generator_index += 1
except Exception as e:
# Report the exception to the owner of the task.
futures.append(
loop.run_in_executor(
worker.core_worker.get_thread_pool_for_async_event_loop(),
report_streaming_generator_output,
report_streaming_generator_exception,
context,
output_or_exception,
e,
cur_generator_index,
)
)

cur_generator_index += 1

if isinstance(output_or_exception, Exception):
break

# Make sure all RPC I/O completes before returning
await asyncio.gather(*futures)

Expand Down Expand Up @@ -1684,7 +1708,6 @@ cdef void execute_task(
JobID job_id = core_worker.get_current_job_id()
TaskID task_id = core_worker.get_current_task_id()
uint64_t attempt_number = core_worker.get_current_task_attempt_number()
CFiberEvent task_done_event
c_vector[shared_ptr[CRayObject]] dynamic_return_ptrs

# Helper method used to exit current asyncio actor.
Expand Down Expand Up @@ -2044,7 +2067,6 @@ cdef execute_task_with_cancellation_handler(
CoreWorker core_worker = worker.core_worker
JobID job_id = core_worker.get_current_job_id()
TaskID task_id = core_worker.get_current_task_id()
CFiberEvent task_done_event
c_vector[shared_ptr[CRayObject]] dynamic_return_ptrs

task_name = name.decode("utf-8")
Expand Down Expand Up @@ -4753,10 +4775,6 @@ cdef class CoreWorker:

return self.current_runtime_env

cdef yield_current_fiber(self, CFiberEvent &fiber_event):
with nogil:
CCoreWorkerProcess.GetCoreWorker().YieldCurrentFiber(fiber_event)

def get_pending_children_task_ids(self, parent_task_id: TaskID):
cdef:
CTaskID c_parent_task_id = parent_task_id.native()
Expand Down
73 changes: 73 additions & 0 deletions python/ray/tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,79 @@ def check(empty_generator):
assert_no_leak()


def test_yield_exception(ray_start_cluster):
@ray.remote
def f():
yield 1
yield 2
yield Exception("value")
yield 3
raise Exception("raise")
yield 5

gen = f.remote()
assert ray.get(next(gen)) == 1
assert ray.get(next(gen)) == 2
yield_exc = ray.get(next(gen))
assert isinstance(yield_exc, Exception)
assert str(yield_exc) == "value"
assert ray.get(next(gen)) == 3
with pytest.raises(Exception, match="raise"):
ray.get(next(gen))
with pytest.raises(StopIteration):
ray.get(next(gen))


def test_actor_yield_exception(ray_start_cluster):
@ray.remote
class A:
def f(self):
yield 1
yield 2
yield Exception("value")
yield 3
raise Exception("raise")
yield 5

a = A.remote()
gen = a.f.remote()
assert ray.get(next(gen)) == 1
assert ray.get(next(gen)) == 2
yield_exc = ray.get(next(gen))
assert isinstance(yield_exc, Exception)
assert str(yield_exc) == "value"
assert ray.get(next(gen)) == 3
with pytest.raises(Exception, match="raise"):
ray.get(next(gen))
with pytest.raises(StopIteration):
ray.get(next(gen))


def test_async_actor_yield_exception(ray_start_cluster):
@ray.remote
class A:
async def f(self):
yield 1
yield 2
yield Exception("value")
yield 3
raise Exception("raise")
yield 5

a = A.remote()
gen = a.f.remote()
assert ray.get(next(gen)) == 1
assert ray.get(next(gen)) == 2
yield_exc = ray.get(next(gen))
assert isinstance(yield_exc, Exception)
assert str(yield_exc) == "value"
assert ray.get(next(gen)) == 3
with pytest.raises(Exception, match="raise"):
ray.get(next(gen))
with pytest.raises(StopIteration):
ray.get(next(gen))


# Client server port of the shared Ray instance
SHARED_CLIENT_SERVER_PORT = 25555

Expand Down

0 comments on commit fc9e82c

Please sign in to comment.