Skip to content

Commit

Permalink
[Client] Make Client_Hook per-thread (#16731)
Browse files Browse the repository at this point in the history
  • Loading branch information
ijrsvt committed Jun 30, 2021
1 parent bf4fcb2 commit 1f14f36
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 46 deletions.
49 changes: 34 additions & 15 deletions python/ray/_private/client_mode_hook.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,56 @@
import os
from contextlib import contextmanager
from functools import wraps
import threading

# Attr set on func defs to mark they have been converted to client mode.
RAY_CLIENT_MODE_ATTR = "__ray_client_mode_key__"

client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1"
# Global setting of whether client mode is enabled. This default to OFF,
# but is enabled upon ray.client(...).connect() or in tests.
is_client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1"
os.environ.update({"RAY_CLIENT_MODE": "0"})

_client_hook_enabled = True
# Local setting of whether to ignore client hook conversion. This defaults
# to TRUE and is disabled when the underlying 'real' Ray function is needed.
_client_hook_status_on_thread = threading.local()
_client_hook_status_on_thread.status = True


def _enable_client_hook(val: bool):
global _client_hook_enabled
_client_hook_enabled = val
def _get_client_hook_status_on_thread():
"""Get's the value of `_client_hook_status_on_thread`.
Since `_client_hook_status_on_thread` is a thread-local variable, we may
need to add and set the 'status' attribute.
"""
global _client_hook_status_on_thread
if not hasattr(_client_hook_status_on_thread, "status"):
_client_hook_status_on_thread.status = True
return _client_hook_status_on_thread.status


def _set_client_hook_status(val: bool):
global _client_hook_status_on_thread
_client_hook_status_on_thread.status = val


def _disable_client_hook():
global _client_hook_enabled
out = _client_hook_enabled
_client_hook_enabled = False
global _client_hook_status_on_thread
out = _get_client_hook_status_on_thread()
_client_hook_status_on_thread.status = False
return out


def _explicitly_enable_client_mode():
global client_mode_enabled
client_mode_enabled = True
"""Force client mode to be enabled.
NOTE: This should not be used in tests, use `enable_client_mode`.
"""
global is_client_mode_enabled
is_client_mode_enabled = True


def _explicitly_disable_client_mode():
global client_mode_enabled
client_mode_enabled = False
global is_client_mode_enabled
is_client_mode_enabled = False


@contextmanager
Expand All @@ -39,7 +59,7 @@ def disable_client_hook():
try:
yield None
finally:
_enable_client_hook(val)
_set_client_hook_status(val)


@contextmanager
Expand All @@ -65,8 +85,7 @@ def wrapper(*args, **kwargs):


def client_mode_should_convert():
global _client_hook_enabled
return client_mode_enabled and _client_hook_enabled
return is_client_mode_enabled and _get_client_hook_status_on_thread()


def client_mode_wrap(func):
Expand Down
8 changes: 2 additions & 6 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,7 @@ from ray.exceptions import (
)
from ray._private.utils import decode
from ray._private.client_mode_hook import (
_enable_client_hook,
_disable_client_hook,
disable_client_hook,
)
import msgpack

Expand Down Expand Up @@ -620,9 +619,8 @@ cdef CRayStatus task_execution_handler(
const c_string debugger_breakpoint,
c_vector[shared_ptr[CRayObject]] *returns,
shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes) nogil:
with gil:
with gil, disable_client_hook():
try:
client_was_enabled = _disable_client_hook()
try:
# The call to execute_task should never raise an exception. If
# it does, that indicates that there was an internal error.
Expand Down Expand Up @@ -663,8 +661,6 @@ cdef CRayStatus task_execution_handler(
else:
logger.exception("SystemExit was raised from the worker")
return CRayStatus.UnexpectedSystemExit()
finally:
_enable_client_hook(client_was_enabled)

return CRayStatus.OK()

Expand Down
29 changes: 29 additions & 0 deletions python/ray/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time
import sys
import logging
import queue
import threading
import _thread

Expand All @@ -12,6 +13,7 @@
from ray.util.client.ray_client_helpers import connect_to_client_or_not
from ray.util.client.ray_client_helpers import ray_start_client_server
from ray._private.client_mode_hook import client_mode_should_convert
from ray._private.client_mode_hook import disable_client_hook
from ray._private.client_mode_hook import enable_client_mode


Expand Down Expand Up @@ -48,6 +50,33 @@ def run(self):
assert ray.get(fast.remote(), timeout=5) == "ok"


# @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
# @pytest.mark.skip()
def test_client_mode_hook_thread_safe(ray_start_regular_shared):
with ray_start_client_server():
with enable_client_mode():
assert client_mode_should_convert()
lock = threading.Lock()
lock.acquire()
q = queue.Queue()

def disable():
with disable_client_hook():
q.put(client_mode_should_convert())
lock.acquire()
q.put(client_mode_should_convert())

t = threading.Thread(target=disable)
t.start()
assert client_mode_should_convert()
lock.release()
t.join()
assert q.get(
) is False, "Threaded disable_client_hook failed to disable"
assert q.get(
) is True, "Threaded disable_client_hook failed to re-enable"


@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
def test_interrupt_ray_get(call_ray_stop_only):
import ray
Expand Down
28 changes: 14 additions & 14 deletions python/ray/tests/test_client_library_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ray.rllib.examples import rock_paper_scissors_multiagent
from ray.util.client.ray_client_helpers import ray_start_client_server
from ray._private.client_mode_hook import _explicitly_enable_client_mode,\
from ray._private.client_mode_hook import enable_client_mode,\
client_mode_should_convert


Expand All @@ -15,28 +15,28 @@ def test_rllib_integration(ray_start_regular_shared):
# (Client mode hook not yet enabled.)
assert not client_mode_should_convert()
# Need to enable this for client APIs to be used.
_explicitly_enable_client_mode()
# Confirming mode hook is enabled.
assert client_mode_should_convert()
with enable_client_mode():
# Confirming mode hook is enabled.
assert client_mode_should_convert()

rock_paper_scissors_multiagent.main()
rock_paper_scissors_multiagent.main()


@pytest.mark.asyncio
async def test_serve_handle(ray_start_regular_shared):
with ray_start_client_server() as ray:
from ray import serve
_explicitly_enable_client_mode()
serve.start()
with enable_client_mode():
serve.start()

@serve.deployment
def hello():
return "hello"
@serve.deployment
def hello():
return "hello"

hello.deploy()
handle = hello.get_handle()
assert ray.get(handle.remote()) == "hello"
assert await handle.remote() == "hello"
hello.deploy()
handle = hello.get_handle()
assert ray.get(handle.remote()) == "hello"
assert await handle.remote() == "hello"


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions python/ray/util/client_connect.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ray.util.client import ray
from ray.job_config import JobConfig
from ray._private.client_mode_hook import _enable_client_hook
from ray._private.client_mode_hook import _set_client_hook_status
from ray._private.client_mode_hook import _explicitly_enable_client_mode

from typing import List, Tuple, Dict, Any
Expand All @@ -20,7 +20,7 @@ def connect(conn_str: str,
"accident?")
# Enable the same hooks that RAY_CLIENT_MODE does, as
# calling ray.util.connect() is specifically for using client mode.
_enable_client_hook(True)
_set_client_hook_status(True)
_explicitly_enable_client_mode()

# TODO(barakmich): https://github.com/ray-project/ray/issues/13274
Expand Down
18 changes: 9 additions & 9 deletions python/ray/util/dask/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ py_test(
deps = [":dask_lib"],
)

# This is currently failing.
#py_test(
# name = "dask_ray_shuffle_optimization_client_mode",
# size = "medium",
# main = "dask_ray_shuffle_optimization.py",
# srcs = ["examples/dask_ray_shuffle_optimization.py"],
# tags = ["exclusive", "client"],
# deps = [":dask_lib"],
#)

py_test(
name = "dask_ray_shuffle_optimization_client_mode",
size = "medium",
main = "dask_ray_shuffle_optimization.py",
srcs = ["examples/dask_ray_shuffle_optimization.py"],
tags = ["exclusive", "client"],
deps = [":dask_lib"],
)

# This is a dummy test dependency that causes the above tests to be
# re-run if any of these files changes.
Expand Down

0 comments on commit 1f14f36

Please sign in to comment.