[Core] Fixing segfaults in Async Streaming Generator (#43775)
This change makes sure that no side effects outlive the invocation of the execute_task method:

Previously, tasks scheduled onto the Core Worker's ThreadPoolExecutor could continue executing even after the request had been cancelled (cancelling the future does not cancel an already running task), leading to a SIGSEGV when a task running in the executor tried to access data structures that had already been cleaned up after this method returned.

With this change:

Upon encountering any failure, we set an interrupt_signal_event that interrupts already scheduled, but not yet executed, tasks, preventing them from modifying externally passed-in data structures.
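
Below is a minimal, self-contained sketch of this interruption pattern. The names used here (drive_generator, report_output, shared_state) are illustrative stand-ins, not Ray's actual internals:

import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor

def report_output(output, shared_state, interrupt_signal_event):
    # Stand-in for the real reporting routine: once the caller has set the
    # interrupt event, it must not touch caller-owned state anymore.
    if interrupt_signal_event.is_set():
        return
    shared_state.append(output)

async def drive_generator(gen, shared_state, executor: ThreadPoolExecutor):
    loop = asyncio.get_running_loop()
    interrupt_signal_event = threading.Event()
    futures = []
    try:
        async for output in gen:
            # Off-load reporting to the executor so the event loop keeps
            # making progress concurrently.
            futures.append(
                loop.run_in_executor(
                    executor,
                    report_output,
                    output,
                    shared_state,
                    interrupt_signal_event,
                )
            )
        await asyncio.gather(*futures)
    except BaseException:
        # Already-queued tasks may still start running; the event makes them
        # return before touching shared_state again.
        interrupt_signal_event.set()
        raise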

Signed-off-by: Alexey Kudinkin <ak@anyscale.com>
alexeykudinkin committed Mar 9, 2024
1 parent d180d5c commit cfebe14
Showing 5 changed files with 221 additions and 42 deletions.
2 changes: 1 addition & 1 deletion python/ray/_raylet.pxd
@@ -126,7 +126,7 @@ cdef class CoreWorker:
object fd_to_cgname_dict
object _task_id_to_future_lock
dict _task_id_to_future
object thread_pool_for_async_event_loop
object event_loop_executor

cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata,
size_t data_size, ObjectRef object_ref,
121 changes: 81 additions & 40 deletions python/ray/_raylet.pyx
@@ -1235,7 +1235,8 @@ cdef class StreamingGeneratorExecutionContext:
cdef report_streaming_generator_output(
StreamingGeneratorExecutionContext context,
output: object,
generator_index: int64_t
generator_index: int64_t,
interrupt_signal_event: Optional[threading.Event],
):
"""Report a given generator output to a caller.
Expand Down Expand Up @@ -1268,6 +1269,11 @@ cdef report_streaming_generator_output(
# usage asap.
del output

# NOTE: Once interrupting event is set by the caller, we can NOT access
# externally provided data-structures, and have to interrupt the execution
if interrupt_signal_event is not None and interrupt_signal_event.is_set():
return

context.streaming_generator_returns[0].push_back(
c_pair[CObjectID, c_bool](
return_obj.first,
@@ -1286,7 +1292,8 @@
cdef report_streaming_generator_exception(
StreamingGeneratorExecutionContext context,
e: Exception,
generator_index: int64_t
generator_index: int64_t,
interrupt_signal_event: Optional[threading.Event],
):
"""Report a given generator exception to a caller.
Expand Down Expand Up @@ -1328,6 +1335,11 @@ cdef report_streaming_generator_exception(
# usage asap.
del e

# NOTE: Once interrupting event is set by the caller, we can NOT access
# externally provided data-structures, and have to interrupt the execution
if interrupt_signal_event is not None and interrupt_signal_event.is_set():
return

context.streaming_generator_returns[0].push_back(
c_pair[CObjectID, c_bool](
return_obj.first,
@@ -1369,10 +1381,10 @@ cdef execute_streaming_generator_sync(StreamingGeneratorExecutionContext context

try:
for output in gen:
report_streaming_generator_output(context, output, gen_index)
report_streaming_generator_output(context, output, gen_index, None)
gen_index += 1
except Exception as e:
report_streaming_generator_exception(context, e, gen_index)
report_streaming_generator_exception(context, e, gen_index, None)


async def execute_streaming_generator_async(
@@ -1407,42 +1419,67 @@ async def execute_streaming_generator_async(

gen = context.generator

futures = []

loop = asyncio.get_running_loop()
worker = ray._private.worker.global_worker

# NOTE: Reporting generator output in a streaming fashion,
# is done in a standalone thread-pool fully *asynchronously*
# to avoid blocking the event-loop and allow it to *concurrently*
# make progress, since serializing and actual RPC I/O is done
# with "nogil".
executor = worker.core_worker.get_event_loop_executor()
interrupt_signal_event = threading.Event()

futures = []
try:
async for output in gen:
# Report the output to the owner of the task.
try:
async for output in gen:
# NOTE: Reporting generator output in a streaming fashion,
# is done in a standalone thread-pool fully *asynchronously*
# to avoid blocking the event-loop and allow it to *concurrently*
# make progress, since serializing and actual RPC I/O is done
# with "nogil".
futures.append(
loop.run_in_executor(
executor,
report_streaming_generator_output,
context,
output,
cur_generator_index,
interrupt_signal_event,
)
)
cur_generator_index += 1
except Exception as e:
# Report the exception to the owner of the task.
futures.append(
loop.run_in_executor(
worker.core_worker.get_thread_pool_for_async_event_loop(),
report_streaming_generator_output,
executor,
report_streaming_generator_exception,
context,
output,
e,
cur_generator_index,
interrupt_signal_event,
)
)
cur_generator_index += 1
except Exception as e:
# Report the exception to the owner of the task.
futures.append(
loop.run_in_executor(
worker.core_worker.get_thread_pool_for_async_event_loop(),
report_streaming_generator_exception,
context,
e,
cur_generator_index,
)
)
# Make sure all RPC I/O completes before returning
await asyncio.gather(*futures)

# Make sure all RPC I/O completes before returning
await asyncio.gather(*futures)

except BaseException as be:
# NOTE: PLEASE READ CAREFULLY BEFORE CHANGING
#
# Upon encountering any failures in reporting generator's output we have to
# make sure that any already scheduled (onto thread-pool executor), but not
# finished tasks are canceled before re-throwing the exception to avoid
# use-after-free failures where tasks could potentially access data-structures
# that are already cleaned by the caller.
#
# For that we set an event to interrupt already scheduled tasks (that have
# not finished executing), therefore interrupting their execution and
# making sure that externally provided data-structures are not
# accessed after this point
#
# For more details, please check out
# https://github.com/ray-project/ray/issues/43771
interrupt_signal_event.set()

raise


cdef create_generator_return_obj(
@@ -3305,7 +3342,7 @@ cdef class CoreWorker:
self.current_runtime_env = None
self._task_id_to_future_lock = threading.Lock()
self._task_id_to_future = {}
self.thread_pool_for_async_event_loop = None
self.event_loop_executor = None

def shutdown_driver(self):
# If it's a worker, the core worker process should have been
@@ -4614,12 +4651,16 @@ cdef class CoreWorker:
for fd in function_descriptors:
self.fd_to_cgname_dict[fd] = cg_name

def get_thread_pool_for_async_event_loop(self):
if self.thread_pool_for_async_event_loop is None:
# Theoretically, we can use multiple threads,
self.thread_pool_for_async_event_loop = ThreadPoolExecutor(
max_workers=1)
return self.thread_pool_for_async_event_loop
def get_event_loop_executor(self) -> ThreadPoolExecutor:
if self.event_loop_executor is None:
# NOTE: We're deliberately allocating thread-pool executor with
# a single thread, provided that many of its use-cases are
# not thread-safe yet (for ex, reporting streaming generator output)
self.event_loop_executor = ThreadPoolExecutor(max_workers=1)
return self.event_loop_executor

def reset_event_loop_executor(self, executor: ThreadPoolExecutor):
self.event_loop_executor = executor

def get_event_loop(self, function_descriptor, specified_cgname):
# __init__ will be invoked in default eventloop
@@ -4729,9 +4770,9 @@ cdef class CoreWorker:
def stop_and_join_asyncio_threads_if_exist(self):
event_loops = []
threads = []
if self.thread_pool_for_async_event_loop:
self.thread_pool_for_async_event_loop.shutdown(
wait=False, cancel_futures=True)
if self.event_loop_executor:
self.event_loop_executor.shutdown(
wait=True, cancel_futures=True)
if self.eventloop_for_default_cg is not None:
event_loops.append(self.eventloop_for_default_cg)
if self.thread_for_default_cg is not None:
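
For context on why the interrupt event above is needed at all, here is a small self-contained example (independent of Ray's code) showing that cancelling a ThreadPoolExecutor future does not stop a task that has already started; only work still waiting in the queue can be cancelled:

import time
from concurrent.futures import ThreadPoolExecutor

def slow_task():
    time.sleep(1)
    return "finished"

with ThreadPoolExecutor(max_workers=1) as pool:
    running = pool.submit(slow_task)  # starts executing immediately
    queued = pool.submit(slow_task)   # waits behind the running task
    time.sleep(0.1)                   # give the first task time to start

    print(running.cancel())  # False: a task that is already running cannot be cancelled
    print(queued.cancel())   # True: the still-queued task is dropped
    print(running.result())  # "finished": the running task completed anyway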
1 change: 1 addition & 0 deletions python/ray/tests/BUILD
@@ -199,6 +199,7 @@ py_test_module_list(
"test_streaming_generator_3.py",
"test_streaming_generator_4.py",
"test_streaming_generator_backpressure.py",
"test_streaming_generator_regression.py",
"test_scheduling_performance.py",
"test_implicit_resource.py",
],
2 changes: 1 addition & 1 deletion python/ray/tests/test_streaming_generator_4.py
@@ -51,7 +51,7 @@ def test_ray_datasetlike_mini_stress_test(
with monkeypatch.context() as m:
m.setenv(
"RAY_testing_asio_delay_us",
"CoreWorkerService.grpc_server." "ReportGeneratorItemReturns=10000:1000000",
"CoreWorkerService.grpc_server.ReportGeneratorItemReturns=10000:1000000",
)
cluster = ray_start_cluster
cluster.add_node(
137 changes: 137 additions & 0 deletions python/ray/tests/test_streaming_generator_regression.py
@@ -0,0 +1,137 @@
import sys
import time
from concurrent.futures import ThreadPoolExecutor

import pytest

import ray
from ray.actor import ActorHandle
from ray.exceptions import RayTaskError, TaskCancelledError
from ray.util.state import list_workers


@ray.remote(num_cpus=1)
class EndpointActor:
def __init__(self, *, injected_executor_delay_s: float, tokens_per_request: int):
self._tokens_per_request = tokens_per_request
# In this test we simulate conditions leading to use-after-free,
# by injecting delays into the worker's thread-pool executor.
self._inject_delay_in_core_worker_executor(
target_delay_s=injected_executor_delay_s,
max_workers=1,
)

async def aio_stream(self):
for i in range(self._tokens_per_request):
yield i

@classmethod
def _inject_delay_in_core_worker_executor(
cls, target_delay_s: float, max_workers: int
):
if target_delay_s > 0:

class DelayedThreadPoolExecutor(ThreadPoolExecutor):
def submit(self, fn, /, *args, **kwargs):
def __slowed_fn():
print(
f">>> [DelayedThreadPoolExecutor] Starting executing "
f"function with delay {target_delay_s}s"
)

time.sleep(target_delay_s)
fn(*args, **kwargs)

return super().submit(__slowed_fn)

executor = DelayedThreadPoolExecutor(max_workers=max_workers)
ray._private.worker.global_worker.core_worker.reset_event_loop_executor(
executor
)


@ray.remote(num_cpus=1)
class CallerActor:
def __init__(
self,
downstream: ActorHandle,
):
self._h = downstream

async def run(self):
print(">>> [Caller] Starting consuming stream")

async_obj_ref_gen = self._h.aio_stream.options(num_returns="streaming").remote()
async for ref in async_obj_ref_gen:
r = await ref
if r == 1:
print(">>> [Caller] Cancelling generator")
ray.cancel(async_obj_ref_gen, recursive=False)

# NOTE: This delay is crucial to let the already scheduled task report
# its generated item (report_streaming_generator_output) before we
# tear down this stream.
delay_after_cancellation_s = 2

print(f">>> [Caller] **Sleeping** {delay_after_cancellation_s}s")
time.sleep(delay_after_cancellation_s)
else:
print(f">>> [Caller] Received {r}")

print(">>> [Caller] Completed consuming stream")


@pytest.mark.parametrize("injected_executor_delay_s", [0, 2])
@pytest.mark.parametrize(
"ray_start_cluster",
[
{
"num_nodes": 2,
"num_cpus": 1,
}
],
indirect=True,
)
def test_segfault_report_streaming_generator_output(
ray_start_cluster, injected_executor_delay_s: float
):
"""
This is a "smoke" test attempting to emulate condition, when using Ray's async
streaming generator, that leads to worker crashing with SIGSEGV.
For more details summarizing these conditions, please refer to
https://github.com/ray-project/ray/issues/43771#issuecomment-1982301654
"""

caller = CallerActor.remote(
EndpointActor.remote(
injected_executor_delay_s=injected_executor_delay_s,
tokens_per_request=100,
),
)

worker_state_before = [(a.worker_id, a.exit_type) for a in list_workers()]
print(">>> Workers state before: ", worker_state_before)

with pytest.raises(RayTaskError) as exc_info:
ray.get(caller.run.remote())

assert isinstance(exc_info.value.cause, TaskCancelledError)

worker_state_after = [(a.worker_id, a.exit_type) for a in list_workers()]
print(">>> Workers state after: ", worker_state_after)

worker_ids, worker_exit_types = zip(*worker_state_after)
# Make sure no workers crashed
assert (
"SYSTEM_ERROR" not in worker_exit_types
), f"Unexpected crashed worker(s) in {worker_ids}"


if __name__ == "__main__":
import os

if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
else:
sys.exit(pytest.main(["-sv", __file__]))
