diff --git a/python/ray/_private/ray_constants.py b/python/ray/_private/ray_constants.py index 1cfcfd768767..e6a707cb9381 100644 --- a/python/ray/_private/ray_constants.py +++ b/python/ray/_private/ray_constants.py @@ -586,3 +586,8 @@ def gcs_actor_scheduling_enabled(): ) RAY_GC_MIN_COLLECT_INTERVAL = env_float("RAY_GC_MIN_COLLECT_INTERVAL_S", 5) + +# Worker exit type constants for signal handling and shutdown +WORKER_EXIT_TYPE_USER = "user" +WORKER_EXIT_TYPE_SYSTEM = "system" +WORKER_EXIT_TYPE_INTENTIONAL_SYSTEM = "intentional_system_exit" diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index 5bda84aad204..766697d634d9 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -17,6 +17,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Dict, List, Mapping, @@ -801,6 +802,97 @@ def set_sigterm_handler(sigterm_handler): signal.signal(signal.SIGTERM, sigterm_handler) +_signal_handler_installed = False +_graceful_shutdown_in_progress = False + + +def install_driver_signal_handler() -> None: + """Install SIGTERM handler for Ray driver processes. + + Implements graceful-then-forced shutdown semantics: + - First SIGTERM: trigger graceful shutdown via sys.exit() (allows atexit handlers) + - Second SIGTERM: escalate to immediate forced shutdown via os._exit(1) + + Must be called from the main thread (Python signal handlers requirement). + Refer to https://docs.python.org/3/library/signal.html#signals-and-threads for more details. + """ + global _signal_handler_installed, _graceful_shutdown_in_progress + if _signal_handler_installed: + return + + if threading.current_thread() is not threading.main_thread(): + logger.warning( + "Signal handlers not installed because current thread is not the main thread. Refer to https://docs.python.org/3/library/signal.html#signals-and-threads for more details." + ) + return + + def _handler(signum, _frame): + global _graceful_shutdown_in_progress + if not _graceful_shutdown_in_progress: + _graceful_shutdown_in_progress = True + sys.exit(signum) + else: + logger.warning( + "Received second SIGTERM signal; escalating to immediate forced shutdown." + ) + os._exit(1) + + set_sigterm_handler(_handler) + _signal_handler_installed = True + + +def install_worker_signal_handler(force_shutdown_fn: Callable[[str], None]) -> None: + """Install SIGTERM handler for Ray worker processes. + + Workers receive external SIGTERM as a forced shutdown signal to avoid hangs + during blocking operations like ray.get()/wait(). This is different from + driver semantics where the first signal is graceful. + + Must be called from the main thread (Python signal handlers requirement). + Refer to https://docs.python.org/3/library/signal.html#signals-and-threads for more details. + + Args: + force_shutdown_fn: Function to call for forced shutdown. Should accept a + single string argument (detail message). + + Raises: + AssertionError: If force_shutdown_fn is None. + + Only installs on the main thread; logs a warning otherwise. + """ + global _signal_handler_installed + assert ( + force_shutdown_fn is not None + ), "Worker signal handlers require force_shutdown_fn" + + if _signal_handler_installed: + return + + if threading.current_thread() is not threading.main_thread(): + logger.warning( + "Signal handlers not installed because current thread is not the main thread. Refer to https://docs.python.org/3/library/signal.html#signals-and-threads for more details." + ) + return + + def _handler(signum, _frame): + # Workers treat external SIGTERM as immediate forced exit to avoid hangs. + signal_name = signal.Signals(signum).name + force_shutdown_fn(signal_name) + + set_sigterm_handler(_handler) + _signal_handler_installed = True + + +def reset_signal_handler_state() -> None: + """Reset signal handler module flags for subsequent ray.init() in same process. + + Called during ray.shutdown() to allow re-initialization of signal handlers. + """ + global _signal_handler_installed, _graceful_shutdown_in_progress + _signal_handler_installed = False + _graceful_shutdown_in_progress = False + + def try_to_symlink(symlink_path, target_path): """Attempt to create a symlink. diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index 379fcc550af7..df82bb51da0a 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -87,7 +87,6 @@ from ray._raylet import ( ObjectRefGenerator, TaskID, - raise_sys_exit_with_custom_error_message, ) from ray.actor import ActorClass from ray.exceptions import ObjectStoreFullError, RayError, RaySystemError, RayTaskError @@ -1037,13 +1036,12 @@ def get_objects( def main_loop(self): """The main loop a worker runs to receive and execute tasks.""" - def sigterm_handler(signum, frame): - raise_sys_exit_with_custom_error_message( - "The process receives a SIGTERM.", exit_code=1 + def force_shutdown(detail: str): + self.core_worker.force_exit_worker( + ray_constants.WORKER_EXIT_TYPE_SYSTEM, detail.encode("utf-8") ) - # Note: shutdown() function is called from atexit handler. - ray._private.utils.set_sigterm_handler(sigterm_handler) + ray._private.utils.install_worker_signal_handler(force_shutdown) self.core_worker.run_task_loop() sys.exit(0) @@ -1676,17 +1674,7 @@ def init( system_reserved_memory=system_reserved_memory, ) - # terminate any signal before connecting driver - def sigterm_handler(signum, frame): - sys.exit(signum) - - if threading.current_thread() is threading.main_thread(): - ray._private.utils.set_sigterm_handler(sigterm_handler) - else: - logger.warning( - "SIGTERM handler is not set because current thread " - "is not the main thread." - ) + ray._private.utils.install_driver_signal_handler() # If available, use RAY_ADDRESS to override if the address was left # unspecified, or set to "auto" in the call to init @@ -2101,6 +2089,7 @@ def shutdown(_exiting_interpreter: bool = False): from ray.dag.compiled_dag_node import _shutdown_all_compiled_dags _shutdown_all_compiled_dags() + ray._private.utils.reset_signal_handler_state() global_worker.shutdown_gpu_object_manager() if _exiting_interpreter and global_worker.mode == SCRIPT_MODE: diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index f2ef051b0b6a..0684ecf6812b 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -130,6 +130,7 @@ from ray.includes.common cimport ( WORKER_EXIT_TYPE_USER_ERROR, WORKER_EXIT_TYPE_SYSTEM_ERROR, WORKER_EXIT_TYPE_INTENTIONAL_SYSTEM_ERROR, + WORKER_EXIT_TYPE_INTENDED_USER_EXIT, kResourceUnitScaling, kImplicitResourcePrefix, kWorkerSetupHookKeyName, @@ -3020,11 +3021,11 @@ cdef class CoreWorker: CWorkerExitType c_exit_type cdef const shared_ptr[LocalMemoryBuffer] null_ptr - if exit_type == "user": + if exit_type == ray_constants.WORKER_EXIT_TYPE_USER: c_exit_type = WORKER_EXIT_TYPE_USER_ERROR - elif exit_type == "system": + elif exit_type == ray_constants.WORKER_EXIT_TYPE_SYSTEM: c_exit_type = WORKER_EXIT_TYPE_SYSTEM_ERROR - elif exit_type == "intentional_system_exit": + elif exit_type == ray_constants.WORKER_EXIT_TYPE_INTENTIONAL_SYSTEM: c_exit_type = WORKER_EXIT_TYPE_INTENTIONAL_SYSTEM_ERROR else: raise ValueError(f"Invalid exit type: {exit_type}") @@ -3032,6 +3033,44 @@ cdef class CoreWorker: with nogil: CCoreWorkerProcess.GetCoreWorker().Exit(c_exit_type, detail, null_ptr) + def force_exit_worker(self, exit_type: str, c_string detail): + """Force exit the current worker process immediately without draining. + + Terminates the worker process via CoreWorker.ForceExit, bypassing graceful + shutdown (no task draining). Used for forced shutdowns triggered by signals + or other immediate termination scenarios. + + Args: + exit_type: Type of exit. Must be one of: + - "user": User-initiated forced exit (INTENDED_USER_EXIT) + - "system": System error forced exit (SYSTEM_ERROR) + - "intentional_system_exit": Intentional system-initiated exit + detail: Human-readable detail string describing the exit reason. + + Raises: + AssertionError: If called from a driver process (must be worker-only). + ValueError: If exit_type is not one of the valid options. + """ + assert not self.is_driver, ( + "force_exit_worker must only be called by workers, not drivers" + ) + cdef CWorkerExitType c_exit_type + if exit_type == ray_constants.WORKER_EXIT_TYPE_USER: + c_exit_type = WORKER_EXIT_TYPE_INTENDED_USER_EXIT + elif exit_type == ray_constants.WORKER_EXIT_TYPE_SYSTEM: + c_exit_type = WORKER_EXIT_TYPE_SYSTEM_ERROR + elif exit_type == ray_constants.WORKER_EXIT_TYPE_INTENTIONAL_SYSTEM: + c_exit_type = WORKER_EXIT_TYPE_INTENTIONAL_SYSTEM_ERROR + else: + raise ValueError( + f"Invalid exit_type '{exit_type}'; expected " + f"'{ray_constants.WORKER_EXIT_TYPE_USER}', " + f"'{ray_constants.WORKER_EXIT_TYPE_SYSTEM}', or " + f"'{ray_constants.WORKER_EXIT_TYPE_INTENTIONAL_SYSTEM}'" + ) + with nogil: + CCoreWorkerProcess.GetCoreWorker().ForceExit(c_exit_type, detail) + def get_current_task_name(self) -> str: """Return the current task name. diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index c406a82c5c6c..c99ffcf7485d 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -267,6 +267,7 @@ cdef extern from "src/ray/protobuf/common.pb.h" nogil: cdef CWorkerExitType WORKER_EXIT_TYPE_USER_ERROR "ray::rpc::WorkerExitType::USER_ERROR" # noqa: E501 cdef CWorkerExitType WORKER_EXIT_TYPE_SYSTEM_ERROR "ray::rpc::WorkerExitType::SYSTEM_ERROR" # noqa: E501 cdef CWorkerExitType WORKER_EXIT_TYPE_INTENTIONAL_SYSTEM_ERROR "ray::rpc::WorkerExitType::INTENDED_SYSTEM_EXIT" # noqa: E501 + cdef CWorkerExitType WORKER_EXIT_TYPE_INTENDED_USER_EXIT "ray::rpc::WorkerExitType::INTENDED_USER_EXIT" # noqa: E501 cdef extern from "src/ray/protobuf/common.pb.h" nogil: cdef CTaskType TASK_TYPE_NORMAL_TASK "ray::TaskType::NORMAL_TASK" diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 16aaa749ad28..b4187f0c47da 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -364,6 +364,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: const c_string &detail, const shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes) + void ForceExit(const CWorkerExitType exit_type, const c_string &detail) + unordered_map[CLineageReconstructionTask, uint64_t] \ GetLocalOngoingLineageReconstructionTasks() const diff --git a/python/ray/tests/BUILD.bazel b/python/ray/tests/BUILD.bazel index d72c9d21d188..fa7bfc0e3627 100644 --- a/python/ray/tests/BUILD.bazel +++ b/python/ray/tests/BUILD.bazel @@ -339,6 +339,7 @@ py_test_module_list( "test_scheduling.py", "test_serialization.py", "test_shuffle.py", + "test_signal_handler.py", "test_state_api_log.py", "test_streaming_generator.py", "test_streaming_generator_2.py", diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index 068bde2b294e..5d7f176e6224 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -591,6 +591,16 @@ def ray_start_no_cpu(request, maybe_setup_external_redis): yield res +# Simple fixture that starts and stops Ray with default settings. +@pytest.fixture +def ray_start(): + ray.init() + try: + yield + finally: + ray.shutdown() + + # The following fixture will start ray with 1 cpu. @pytest.fixture def ray_start_regular(request, maybe_setup_external_redis): diff --git a/python/ray/tests/test_signal_handler.py b/python/ray/tests/test_signal_handler.py new file mode 100644 index 000000000000..2336b4a80804 --- /dev/null +++ b/python/ray/tests/test_signal_handler.py @@ -0,0 +1,204 @@ +import asyncio +import os +import signal +import sys +import tempfile +import textwrap +import time + +import pytest + +import ray +from ray._common.test_utils import wait_for_condition +from ray._private.test_utils import run_string_as_driver_nonblocking + + +@ray.remote +class SimpleActor: + def pid(self) -> int: + return os.getpid() + + def ping(self) -> str: + return "ok" + + +@ray.remote +class AsyncioActor: + async def pid(self) -> int: + return os.getpid() + + async def ping(self) -> str: + return "ok" + + async def run(self): + while True: + await asyncio.sleep(0.1) + + +def _expect_actor_dies(actor: "ray.actor.ActorHandle", timeout_s: float = 30) -> None: + """Wait for an actor to die and raise if it doesn't die within timeout. + + Args: + actor: Ray actor handle to monitor. + timeout_s: Maximum time to wait for actor death in seconds. + + Raises: + AssertionError: If actor doesn't die within timeout_s. + """ + start = time.monotonic() + while time.monotonic() - start < timeout_s: + try: + ray.get(actor.ping.remote(), timeout=1) + except ray.exceptions.GetTimeoutError: + continue + except ray.exceptions.RayActorError: + return + except Exception: + return + raise AssertionError("Actor did not die within timeout") + + +def test_asyncio_actor_force_exit_is_immediate(ray_start): + """Calling force_exit_worker() inside an asyncio actor should exit immediately.""" + + @ray.remote + class A: + async def ping(self): + return "ok" + + async def force_quit(self): + from ray._private.worker import global_worker + + global_worker.core_worker.force_exit_worker("user", b"force from test") + return "unreachable" + + a = A.remote() + assert ray.get(a.ping.remote()) == "ok" + with pytest.raises(ray.exceptions.RayActorError): + ray.get(a.force_quit.remote()) + + +def test_worker_sigterm_terminates_immediately(ray_start): + """Worker should terminate immediately upon receiving SIGTERM.""" + a = SimpleActor.remote() + pid = ray.get(a.pid.remote()) + os.kill(pid, signal.SIGTERM) + _expect_actor_dies(a) + + +def test_worker_sigterm_during_blocking_get(ray_start): + """SIGTERM should force exit even when worker is blocked on ray.get().""" + + @ray.remote + class BlockedActor: + def pid(self) -> int: + return os.getpid() + + def block_on_get(self): + @ray.remote + def never_returns(): + time.sleep(10000) + + ray.get(never_returns.remote()) + + a = BlockedActor.remote() + pid = ray.get(a.pid.remote()) + a.block_on_get.remote() + time.sleep(0.1) + os.kill(pid, signal.SIGTERM) + _expect_actor_dies(a) + + +def test_asyncio_actor_sigterm_termination(ray_start): + """Asyncio actor should terminate upon receiving SIGTERM.""" + a = AsyncioActor.remote() + pid = ray.get(a.pid.remote()) + a.run.remote() + assert ray.get(a.ping.remote()) == "ok" + os.kill(pid, signal.SIGTERM) + _expect_actor_dies(a) + + +def test_driver_sigterm_graceful(): + """Driver should exit gracefully on first SIGTERM and atexit should run.""" + with tempfile.TemporaryDirectory() as td: + flag = os.path.join(td, "driver_atexit_flag.txt") + ready = os.path.join(td, "driver_ready.txt") + driver_code = textwrap.dedent( + f""" + import atexit, os, time, ray + + def on_exit(): + with open('{flag}', 'w') as f: + f.write('ok') + + atexit.register(on_exit) + ray.init() + with open('{ready}', 'w') as f: + f.write('ready') + time.sleep(1000) + """ + ) + p = run_string_as_driver_nonblocking(driver_code) + try: + wait_for_condition(lambda: os.path.exists(ready), timeout=10) + os.kill(p.pid, signal.SIGTERM) + _ = p.wait(timeout=10) + finally: + try: + p.kill() + except Exception: + pass + assert os.path.exists(flag) + with open(flag, "r") as f: + assert f.read() == "ok" + + +def test_driver_double_sigterm_forced(): + """Driver should force-exit on second SIGTERM if first is slow.""" + with tempfile.TemporaryDirectory() as td: + flag = os.path.join(td, "driver_atexit_flag.txt") + ready = os.path.join(td, "driver_ready.txt") + driver_code = textwrap.dedent( + f""" + import atexit, os, time, ray + + def on_exit(): + time.sleep(10) + with open('{flag}', 'w') as f: + f.write('ok') + + atexit.register(on_exit) + ray.init() + with open('{ready}', 'w') as f: + f.write('ready') + time.sleep(1000) + """ + ) + p = run_string_as_driver_nonblocking(driver_code) + try: + wait_for_condition(lambda: os.path.exists(ready), timeout=10) + os.kill(p.pid, signal.SIGTERM) + time.sleep(0.1) + os.kill(p.pid, signal.SIGTERM) + + start_wait = time.monotonic() + p.wait(timeout=2) + wait_time = time.monotonic() - start_wait + + assert ( + wait_time < 2 + ), f"Should exit quickly via forced path, but took {wait_time}s" + finally: + try: + p.kill() + except Exception: + pass + + assert not os.path.exists( + flag + ), "Slow atexit should not complete on forced exit" + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", __file__])) diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index bc6eecd3f752..b41e79e02786 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -1339,6 +1339,11 @@ class CoreWorker { const std::shared_ptr &creation_task_exception_pb_bytes = nullptr); + /// Forcefully exit the worker immediately without draining/cleanup. + /// \param exit_type The reason why this worker process is disconnected. + /// \param exit_detail The detailed reason for a given exit. + void ForceExit(const rpc::WorkerExitType exit_type, const std::string &detail); + void AsyncRetryTask(TaskSpecification &spec, uint32_t delay_ms); private: @@ -1398,12 +1403,6 @@ class CoreWorker { void SetActorId(const ActorID &actor_id); - /// Forcefully exit the worker. `Force` means it will exit actor without draining - /// or cleaning any resources. - /// \param exit_type The reason why this worker process is disconnected. - /// \param exit_detail The detailed reason for a given exit. - void ForceExit(const rpc::WorkerExitType exit_type, const std::string &detail); - /// Forcefully kill child processes. User code running in actors or tasks /// can spawn processes that don't get terminated. If those processes /// own resources (such as GPU memory), then those resources will become diff --git a/src/ray/core_worker/tests/shutdown_coordinator_test.cc b/src/ray/core_worker/tests/shutdown_coordinator_test.cc index 50a7d755bb41..dd446009307a 100644 --- a/src/ray/core_worker/tests/shutdown_coordinator_test.cc +++ b/src/ray/core_worker/tests/shutdown_coordinator_test.cc @@ -176,6 +176,46 @@ TEST_F(ShutdownCoordinatorTest, RequestShutdown_DelegatesToGraceful_OnlyFirstSuc EXPECT_EQ(coordinator->GetReason(), ShutdownReason::kUserError); // unchanged } +TEST_F(ShutdownCoordinatorTest, SingleSignal_IntentionalSystemExit_TriggersExitNotForce) { + auto fake = std::make_unique(); + auto *fake_ptr = fake.get(); + ShutdownCoordinator coordinator(std::move(fake), rpc::WorkerType::WORKER); + + const bool initiated = coordinator.RequestShutdown( + /*force_shutdown=*/false, + ShutdownReason::kIntentionalShutdown, + /*detail=*/"signal:INTENTIONAL", + ShutdownCoordinator::kInfiniteTimeout, + /*creation_task_exception_pb_bytes=*/nullptr); + + ASSERT_TRUE(initiated); + EXPECT_EQ(fake_ptr->force_calls.load(), 0); + EXPECT_EQ(fake_ptr->graceful_calls.load(), 0); + EXPECT_EQ(fake_ptr->worker_exit_calls.load(), 1); +} + +TEST_F(ShutdownCoordinatorTest, DoubleSignal_SecondForce_ExecutesForceShutdown) { + auto fake = std::make_unique(); + auto *fake_ptr = fake.get(); + ShutdownCoordinator coordinator(std::move(fake), rpc::WorkerType::WORKER); + + const bool first = coordinator.RequestShutdown( + /*force_shutdown=*/false, + ShutdownReason::kIntentionalShutdown, + /*detail=*/"first", + ShutdownCoordinator::kInfiniteTimeout, + /*creation_task_exception_pb_bytes=*/nullptr); + ASSERT_TRUE(first); + + (void)coordinator.RequestShutdown( + /*force_shutdown=*/true, + ShutdownReason::kForcedExit, + /*detail=*/"second", + std::chrono::milliseconds{0}, + /*creation_task_exception_pb_bytes=*/nullptr); + EXPECT_EQ(fake_ptr->force_calls.load(), 1); +} + TEST_F(ShutdownCoordinatorTest, RequestShutdown_Graceful_SetsDisconnecting_ThenTryTransitionToShutdown_Succeeds) { auto coordinator = std::make_unique(