From 1f14f366fd4006b9f722bf6710111bb5ed36c35b Mon Sep 17 00:00:00 2001 From: Ian Rodney Date: Wed, 30 Jun 2021 11:48:55 -0700 Subject: [PATCH] [Client] Make `Client_Hook` per-thread (#16731) --- python/ray/_private/client_mode_hook.py | 49 +++++++++++++------ python/ray/_raylet.pyx | 8 +-- python/ray/tests/test_client.py | 29 +++++++++++ .../tests/test_client_library_integration.py | 28 +++++------ python/ray/util/client_connect.py | 4 +- python/ray/util/dask/BUILD | 18 +++---- 6 files changed, 90 insertions(+), 46 deletions(-) diff --git a/python/ray/_private/client_mode_hook.py b/python/ray/_private/client_mode_hook.py index 223f939cb9114..c71eda385c46e 100644 --- a/python/ray/_private/client_mode_hook.py +++ b/python/ray/_private/client_mode_hook.py @@ -1,36 +1,56 @@ import os from contextlib import contextmanager from functools import wraps +import threading # Attr set on func defs to mark they have been converted to client mode. RAY_CLIENT_MODE_ATTR = "__ray_client_mode_key__" -client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1" +# Global setting of whether client mode is enabled. This default to OFF, +# but is enabled upon ray.client(...).connect() or in tests. +is_client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1" os.environ.update({"RAY_CLIENT_MODE": "0"}) -_client_hook_enabled = True +# Local setting of whether to ignore client hook conversion. This defaults +# to TRUE and is disabled when the underlying 'real' Ray function is needed. +_client_hook_status_on_thread = threading.local() +_client_hook_status_on_thread.status = True -def _enable_client_hook(val: bool): - global _client_hook_enabled - _client_hook_enabled = val +def _get_client_hook_status_on_thread(): + """Get's the value of `_client_hook_status_on_thread`. + Since `_client_hook_status_on_thread` is a thread-local variable, we may + need to add and set the 'status' attribute. + """ + global _client_hook_status_on_thread + if not hasattr(_client_hook_status_on_thread, "status"): + _client_hook_status_on_thread.status = True + return _client_hook_status_on_thread.status + + +def _set_client_hook_status(val: bool): + global _client_hook_status_on_thread + _client_hook_status_on_thread.status = val def _disable_client_hook(): - global _client_hook_enabled - out = _client_hook_enabled - _client_hook_enabled = False + global _client_hook_status_on_thread + out = _get_client_hook_status_on_thread() + _client_hook_status_on_thread.status = False return out def _explicitly_enable_client_mode(): - global client_mode_enabled - client_mode_enabled = True + """Force client mode to be enabled. + NOTE: This should not be used in tests, use `enable_client_mode`. + """ + global is_client_mode_enabled + is_client_mode_enabled = True def _explicitly_disable_client_mode(): - global client_mode_enabled - client_mode_enabled = False + global is_client_mode_enabled + is_client_mode_enabled = False @contextmanager @@ -39,7 +59,7 @@ def disable_client_hook(): try: yield None finally: - _enable_client_hook(val) + _set_client_hook_status(val) @contextmanager @@ -65,8 +85,7 @@ def wrapper(*args, **kwargs): def client_mode_should_convert(): - global _client_hook_enabled - return client_mode_enabled and _client_hook_enabled + return is_client_mode_enabled and _get_client_hook_status_on_thread() def client_mode_wrap(func): diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index c9c9757ab9338..519cfaa8f9577 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -117,8 +117,7 @@ from ray.exceptions import ( ) from ray._private.utils import decode from ray._private.client_mode_hook import ( - _enable_client_hook, - _disable_client_hook, + disable_client_hook, ) import msgpack @@ -620,9 +619,8 @@ cdef CRayStatus task_execution_handler( const c_string debugger_breakpoint, c_vector[shared_ptr[CRayObject]] *returns, shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes) nogil: - with gil: + with gil, disable_client_hook(): try: - client_was_enabled = _disable_client_hook() try: # The call to execute_task should never raise an exception. If # it does, that indicates that there was an internal error. @@ -663,8 +661,6 @@ cdef CRayStatus task_execution_handler( else: logger.exception("SystemExit was raised from the worker") return CRayStatus.UnexpectedSystemExit() - finally: - _enable_client_hook(client_was_enabled) return CRayStatus.OK() diff --git a/python/ray/tests/test_client.py b/python/ray/tests/test_client.py index 6e6bea5ab1b79..940ec0ffd7b2a 100644 --- a/python/ray/tests/test_client.py +++ b/python/ray/tests/test_client.py @@ -3,6 +3,7 @@ import time import sys import logging +import queue import threading import _thread @@ -12,6 +13,7 @@ from ray.util.client.ray_client_helpers import connect_to_client_or_not from ray.util.client.ray_client_helpers import ray_start_client_server from ray._private.client_mode_hook import client_mode_should_convert +from ray._private.client_mode_hook import disable_client_hook from ray._private.client_mode_hook import enable_client_mode @@ -48,6 +50,33 @@ def run(self): assert ray.get(fast.remote(), timeout=5) == "ok" +# @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") +# @pytest.mark.skip() +def test_client_mode_hook_thread_safe(ray_start_regular_shared): + with ray_start_client_server(): + with enable_client_mode(): + assert client_mode_should_convert() + lock = threading.Lock() + lock.acquire() + q = queue.Queue() + + def disable(): + with disable_client_hook(): + q.put(client_mode_should_convert()) + lock.acquire() + q.put(client_mode_should_convert()) + + t = threading.Thread(target=disable) + t.start() + assert client_mode_should_convert() + lock.release() + t.join() + assert q.get( + ) is False, "Threaded disable_client_hook failed to disable" + assert q.get( + ) is True, "Threaded disable_client_hook failed to re-enable" + + @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") def test_interrupt_ray_get(call_ray_stop_only): import ray diff --git a/python/ray/tests/test_client_library_integration.py b/python/ray/tests/test_client_library_integration.py index 6b19844c0dc10..fd92ec0d2a477 100644 --- a/python/ray/tests/test_client_library_integration.py +++ b/python/ray/tests/test_client_library_integration.py @@ -4,7 +4,7 @@ from ray.rllib.examples import rock_paper_scissors_multiagent from ray.util.client.ray_client_helpers import ray_start_client_server -from ray._private.client_mode_hook import _explicitly_enable_client_mode,\ +from ray._private.client_mode_hook import enable_client_mode,\ client_mode_should_convert @@ -15,28 +15,28 @@ def test_rllib_integration(ray_start_regular_shared): # (Client mode hook not yet enabled.) assert not client_mode_should_convert() # Need to enable this for client APIs to be used. - _explicitly_enable_client_mode() - # Confirming mode hook is enabled. - assert client_mode_should_convert() + with enable_client_mode(): + # Confirming mode hook is enabled. + assert client_mode_should_convert() - rock_paper_scissors_multiagent.main() + rock_paper_scissors_multiagent.main() @pytest.mark.asyncio async def test_serve_handle(ray_start_regular_shared): with ray_start_client_server() as ray: from ray import serve - _explicitly_enable_client_mode() - serve.start() + with enable_client_mode(): + serve.start() - @serve.deployment - def hello(): - return "hello" + @serve.deployment + def hello(): + return "hello" - hello.deploy() - handle = hello.get_handle() - assert ray.get(handle.remote()) == "hello" - assert await handle.remote() == "hello" + hello.deploy() + handle = hello.get_handle() + assert ray.get(handle.remote()) == "hello" + assert await handle.remote() == "hello" if __name__ == "__main__": diff --git a/python/ray/util/client_connect.py b/python/ray/util/client_connect.py index 0ca3cbd2c2e76..49eca7cb4b465 100644 --- a/python/ray/util/client_connect.py +++ b/python/ray/util/client_connect.py @@ -1,6 +1,6 @@ from ray.util.client import ray from ray.job_config import JobConfig -from ray._private.client_mode_hook import _enable_client_hook +from ray._private.client_mode_hook import _set_client_hook_status from ray._private.client_mode_hook import _explicitly_enable_client_mode from typing import List, Tuple, Dict, Any @@ -20,7 +20,7 @@ def connect(conn_str: str, "accident?") # Enable the same hooks that RAY_CLIENT_MODE does, as # calling ray.util.connect() is specifically for using client mode. - _enable_client_hook(True) + _set_client_hook_status(True) _explicitly_enable_client_mode() # TODO(barakmich): https://github.com/ray-project/ray/issues/13274 diff --git a/python/ray/util/dask/BUILD b/python/ray/util/dask/BUILD index 3b1f5c936371d..c731b14a33ed3 100644 --- a/python/ray/util/dask/BUILD +++ b/python/ray/util/dask/BUILD @@ -101,15 +101,15 @@ py_test( deps = [":dask_lib"], ) -# This is currently failing. -#py_test( -# name = "dask_ray_shuffle_optimization_client_mode", -# size = "medium", -# main = "dask_ray_shuffle_optimization.py", -# srcs = ["examples/dask_ray_shuffle_optimization.py"], -# tags = ["exclusive", "client"], -# deps = [":dask_lib"], -#) + +py_test( + name = "dask_ray_shuffle_optimization_client_mode", + size = "medium", + main = "dask_ray_shuffle_optimization.py", + srcs = ["examples/dask_ray_shuffle_optimization.py"], + tags = ["exclusive", "client"], + deps = [":dask_lib"], +) # This is a dummy test dependency that causes the above tests to be # re-run if any of these files changes.