Skip to content

Commit

Permalink
[Client] Add ray.client().disconnect() (#16021)
Browse files Browse the repository at this point in the history
  • Loading branch information
ijrsvt authored and DmitriGekhtman committed May 28, 2021
1 parent c74c503 commit f102eaf
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 63 deletions.
37 changes: 18 additions & 19 deletions doc/source/namespaces.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,34 +19,33 @@ Named actors are only accessible within their namespaces.

.. code-block:: python
import
import ray
@ray.remote
class Actor:
pass
# Job 1 creates two actors, "orange" and "purple" in the "colors" namespace.
ray.client().namespace("colors").connect()
Actor.options(name="orange", lifetime="detached")
Actor.options(name="purple", lifetime="detached")
ray.util.disconnect()
with ray.client().namespace("colors").connect():
Actor.options(name="orange", lifetime="detached")
Actor.options(name="purple", lifetime="detached")
# Job 2 is now connecting to a different namespace.
ray.client().namespace("fruits").connect()
# This fails because "orange" was defined in the "colors" namespace.
ray.get_actor("orange")
# This succceeds because the name "orange" is unused in this namespace.
Actor.options(name="orange", lifetime="detached")
Actor.options(name="watermelon", lifetime="detached")
ray.util.disconnect()
with ray.client().namespace("fruits").connect():
# This fails because "orange" was defined in the "colors" namespace.
ray.get_actor("orange")
# This succceeds because the name "orange" is unused in this namespace.
Actor.options(name="orange", lifetime="detached")
Actor.options(name="watermelon", lifetime="detached")
# Job 3 connects to the original "colors" namespace
ray.client().namespace("colors").connect()
context = ray.client().namespace("colors").connect()
# This fails because "watermelon" was in the fruits namespace.
ray.get_actor("watermelon")
# This returns the "orange" actor we created in the first job, not the second.
ray.get_actor("orange")
ray.util.disconnect()
context.disconnect()
# We are manually managing the scope of the connection in this example.
Anonymous namespaces
Expand All @@ -58,22 +57,22 @@ will not have access to actors in other namespaces.

.. code-block:: python
import
import ray
@ray.remote
class Actor:
pass
# Job 1 connects to an anonymous namespace by default
ray.client().connect()
ctx = ray.client().connect()
Actor.options(name="my_actor", lifetime="detached")
ray.util.disconnect()
ctx.disconnect()
# Job 2 connects to an _different_ anonymous namespace by default
ray.client().connect()
ctx = ray.client().connect()
# This succeeds because the second job is in its own namespace.
Actor.options(name="my_actor", lifetime="detached")
ray.util.disconnect()
ctx.disconnect()
.. note::

Expand Down
56 changes: 48 additions & 8 deletions python/ray/client_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
from dataclasses import dataclass
from urllib.parse import urlparse
import sys
from typing import Any, Dict, Optional, Tuple

from ray.ray_constants import RAY_ADDRESS_ENVIRONMENT_VARIABLE
Expand All @@ -13,17 +14,44 @@


@dataclass
class ClientInfo:
class ClientContext:
"""
Basic information of the remote server for a given Ray Client connection.
Basic context manager for a ClientBuilder connection.
"""
dashboard_url: Optional[str]
python_version: str
ray_version: str
ray_commit: str
protocol_version: str
protocol_version: Optional[str]
_num_clients: int

def __enter__(self) -> "ClientContext":
return self

def __exit__(self, *exc) -> None:
self.disconnect()

def disconnect(self) -> None:
"""
Disconnect Ray. This either disconnects from the remote Client Server
or shuts the current driver down.
"""
if ray.util.client.ray.is_connected():
# This is only a client connected to a server.
ray.util.client_connect.disconnect()
ray._private.client_mode_hook._explicitly_disable_client_mode()
elif ray.worker.global_worker.node is None:
# Already disconnected.
return
elif ray.worker.global_worker.node.is_head():
logger.debug(
"The current Ray Cluster is scoped to this process. "
"Disconnecting is not possible as it will shutdown the "
"cluster.")
else:
# This is only a driver connected to an existing cluster.
ray.shutdown()


class ClientBuilder:
"""
Expand All @@ -45,15 +73,15 @@ def namespace(self, namespace: str) -> "ClientBuilder":
self._job_config.set_ray_namespace(namespace)
return self

def connect(self) -> ClientInfo:
def connect(self) -> ClientContext:
"""
Begin a connection to the address passed in via ray.client(...).
"""
client_info_dict = ray.util.client_connect.connect(
self.address, job_config=self._job_config)
dashboard_url = ray.get(
ray.remote(ray.worker.get_dashboard_url).remote())
return ClientInfo(
return ClientContext(
dashboard_url=dashboard_url,
python_version=client_info_dict["python_version"],
ray_version=client_info_dict["ray_version"],
Expand All @@ -63,11 +91,19 @@ def connect(self) -> ClientInfo:


class _LocalClientBuilder(ClientBuilder):
def connect(self) -> ClientInfo:
def connect(self) -> ClientContext:
"""
Begin a connection to the address passed in via ray.client(...).
"""
return ray.init(address=self.address, job_config=self._job_config)
connection_dict = ray.init(
address=self.address, job_config=self._job_config)
return ClientContext(
dashboard_url=connection_dict["webui_url"],
python_version="{}.{}.{}".format(
sys.version_info[0], sys.version_info[1], sys.version_info[2]),
ray_version=ray.__version__,
ray_commit=ray.__commit__,
protocol_version=None)


def _split_address(address: str) -> Tuple[str, str]:
Expand All @@ -87,7 +123,11 @@ def _get_builder_from_address(address: Optional[str]) -> ClientBuilder:
return _LocalClientBuilder(None)
if address is None:
try:
with open("/tmp/ray/current_cluster", "r") as f:
# NOTE: This is not placed in `Node::get_temp_dir_path`, because
# this file is accessed before the `Node` object is created.
cluster_file = os.path.join(ray._private.utils.get_user_temp_dir(),
"ray_current_cluster")
with open(cluster_file, "r") as f:
address = f.read()
print(address)
except FileNotFoundError:
Expand Down
51 changes: 42 additions & 9 deletions python/ray/tests/test_client_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,14 @@ def ping(self):

def test_connect_to_cluster(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50055")
client_info = ray.client("localhost:50055").connect()

assert client_info.dashboard_url == ray.worker.get_dashboard_url()
python_version = ".".join([str(x) for x in list(sys.version_info)[:3]])
assert client_info.python_version == python_version
assert client_info.ray_version == ray.__version__
assert client_info.ray_commit == ray.__commit__
protocol_version = ray.util.client.CURRENT_PROTOCOL_VERSION
assert client_info.protocol_version == protocol_version
with ray.client("localhost:50055").connect() as client_context:
assert client_context.dashboard_url == ray.worker.get_dashboard_url()
python_version = ".".join([str(x) for x in list(sys.version_info)[:3]])
assert client_context.python_version == python_version
assert client_context.ray_version == ray.__version__
assert client_context.ray_commit == ray.__commit__
protocol_version = ray.util.client.CURRENT_PROTOCOL_VERSION
assert client_context.protocol_version == protocol_version

server.stop(0)
subprocess.check_output("ray stop --force", shell=True)
Expand Down Expand Up @@ -180,3 +179,37 @@ def ping(self):
retry_interval_ms=1000)
p1.kill()
subprocess.check_output("ray stop --force", shell=True)


def test_disconnect(call_ray_stop_only):
subprocess.check_output(
"ray start --head --ray-client-server-port=25555", shell=True)
with ray.client("localhost:25555").namespace("n1").connect():
# Connect via Ray Client
namespace = ray.get_runtime_context().namespace
assert namespace == "n1"
assert ray.util.client.ray.is_connected()

with pytest.raises(ray.exceptions.RaySystemError):
ray.put(300)

with ray.client(None).namespace("n1").connect():
# Connect Directly via Driver
namespace = ray.get_runtime_context().namespace
assert namespace == "n1"
assert not ray.util.client.ray.is_connected()

with pytest.raises(ray.exceptions.RaySystemError):
ray.put(300)

ctx = ray.client("localhost:25555").namespace("n1").connect()
# Connect via Ray Client
namespace = ray.get_runtime_context().namespace
assert namespace == "n1"
assert ray.util.client.ray.is_connected()
ctx.disconnect()
# Check idempotency
ctx.disconnect()

with pytest.raises(ray.exceptions.RaySystemError):
ray.put(300)
14 changes: 6 additions & 8 deletions python/ray/tests/test_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,12 @@ def test_multiple_clients_use_different_drivers(call_ray_start):
"""
Test that each client uses a separate JobIDs and namespaces.
"""
ray.client("localhost:25001").connect()
job_id_one = ray.get_runtime_context().job_id
namespace_one = ray.get_runtime_context().namespace
ray.util.disconnect()
ray.client("localhost:25001").connect()
job_id_two = ray.get_runtime_context().job_id
namespace_two = ray.get_runtime_context().namespace
ray.util.disconnect()
with ray.client("localhost:25001").connect():
job_id_one = ray.get_runtime_context().job_id
namespace_one = ray.get_runtime_context().namespace
with ray.client("localhost:25001").connect():
job_id_two = ray.get_runtime_context().job_id
namespace_two = ray.get_runtime_context().namespace

assert job_id_one != job_id_two
assert namespace_one != namespace_two
Expand Down
30 changes: 11 additions & 19 deletions python/ray/tests/test_runtime_env_complicated.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,15 @@ def remove_tf_env(tf_version: str):

check_remote_client_conda = """
import ray
ray.client("localhost:24001").env({{"conda" : "tf-{tf_version}"}}).connect()
context = ray.client("localhost:24001").env({{"conda" : "tf-{tf_version}"}}).\\
connect()
@ray.remote
def get_tf_version():
import tensorflow as tf
return tf.__version__
assert ray.get(get_tf_version.remote()) == "{tf_version}"
ray.util.disconnect()
context.disconnect()
"""


Expand All @@ -96,9 +97,8 @@ def get_tf_version(self):

tf_versions = ["2.2.0", "2.3.0"]
for i, tf_version in enumerate(tf_versions):
try:
runtime_env = {"conda": f"tf-{tf_version}"}
ray.client("localhost:24001").env(runtime_env).connect()
runtime_env = {"conda": f"tf-{tf_version}"}
with ray.client("localhost:24001").env(runtime_env).connect():
assert ray.get(get_tf_version.remote()) == tf_version
actor_handle = TfVersionActor.remote()
assert ray.get(actor_handle.get_tf_version.remote()) == tf_version
Expand All @@ -108,9 +108,6 @@ def get_tf_version(self):
other_tf_version = tf_versions[(i + 1) % 2]
run_string_as_driver(
check_remote_client_conda.format(tf_version=other_tf_version))
finally:
ray.util.disconnect()
ray._private.client_mode_hook._explicitly_disable_client_mode()


@pytest.mark.skipif(
Expand Down Expand Up @@ -387,28 +384,23 @@ def test_conda_create_ray_client(call_ray_start):
]
}
}
try:
ray.client("localhost:24001").env(runtime_env).connect()

@ray.remote
def f():
import pip_install_test # noqa
return True
@ray.remote
def f():
import pip_install_test # noqa
return True

with ray.client("localhost:24001").env(runtime_env).connect():
with pytest.raises(ModuleNotFoundError):
# Ensure pip-install-test is not installed on the test machine
import pip_install_test # noqa
assert ray.get(f.remote())

ray.util.disconnect()
ray.client("localhost:24001").connect()
with ray.client("localhost:24001").connect():
with pytest.raises(ModuleNotFoundError):
# Ensure pip-install-test is not installed in a client that doesn't
# use the runtime_env
ray.get(f.remote())
finally:
ray.util.disconnect()
ray._private.client_mode_hook._explicitly_disable_client_mode()


@pytest.mark.skipif(
Expand Down

0 comments on commit f102eaf

Please sign in to comment.