Skip to content

Commit

Permalink
Implement ClientObjectRef and ClientActorID in cython
Browse files Browse the repository at this point in the history
  • Loading branch information
mwtian committed May 28, 2021
1 parent cd71d5e commit 3c8a1db
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 95 deletions.
5 changes: 5 additions & 0 deletions python/ray/_raylet.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ cdef class ObjectRef(BaseID):

cdef CObjectID native(self)


cdef class ClientObjectRef(ObjectRef):
pass


cdef class ActorID(BaseID):
cdef CActorID data

Expand Down
58 changes: 58 additions & 0 deletions python/ray/includes/object_ref.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import logging
from typing import Callable, Any, Union

import ray
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.util.client as client

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -133,3 +135,59 @@ cdef class ObjectRef(BaseID):
core_worker = ray.worker.global_worker.core_worker
core_worker.set_get_async_callback(self, py_callback)
return self


cdef class ClientObjectRef(ObjectRef):

def __init__(self, id: bytes):
check_id(id)
self.data = CObjectID.FromBinary(<c_string>id)
client.ray.call_retain(id)
self.in_core_worker = False

def __dealloc__(self):
if client.ray.is_connected() and not self.data.IsNil():
client.ray.call_release(self.id)

@property
def id(self):
return self.binary()

def future(self) -> concurrent.futures.Future:
fut = concurrent.futures.Future()

def set_value(data: Any) -> None:
"""Schedules a callback to set the exception or result
in the Future."""

if isinstance(data, Exception):
fut.set_exception(data)
else:
fut.set_result(data)

self._on_completed(set_value)

# Prevent this object ref from being released.
fut.object_ref = self
return fut

def _on_completed(self, py_callback: Callable[[Any], None]) -> None:
"""Register a callback that will be called after Object is ready.
If the ObjectRef is already ready, the callback will be called soon.
The callback should take the result as the only argument. The result
can be an exception object in case of task error.
"""
from ray.util.client.client_pickler import loads_from_server

def deserialize_obj(resp: ray_client_pb2.DataResponse) -> None:
"""Converts from a GetResponse proto to a python object."""
obj = resp.get
data = None
if not obj.valid:
data = loads_from_server(resp.get.error)
else:
data = loads_from_server(resp.get.data)

py_callback(data)

client.ray._register_callback(self, deserialize_obj)
17 changes: 17 additions & 0 deletions python/ray/includes/unique_ids.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ cdef class WorkerID(UniqueID):
return <CWorkerID>self.data

cdef class ActorID(BaseID):

def __init__(self, id):
check_id(id, CActorID.Size())
self.data = CActorID.FromBinary(<c_string>id)
Expand Down Expand Up @@ -302,6 +303,22 @@ cdef class ActorID(BaseID):
return self.data.Hash()


cdef class ClientActorID(ActorID):

def __init__(self, id: bytes):
check_id(id, CActorID.Size())
self.data = CActorID.FromBinary(<c_string>id)
client.ray.call_retain(id)

def __dealloc__(self):
if client.ray.is_connected() and not self.data.IsNil():
client.ray.call_release(self.id)

@property
def id(self):
return self.binary()


cdef class FunctionID(UniqueID):

def __init__(self, id):
Expand Down
3 changes: 3 additions & 0 deletions python/ray/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ray.util.client.ray_client_helpers import ray_start_client_server
from ray._private.client_mode_hook import client_mode_should_convert
from ray._private.client_mode_hook import enable_client_mode
from ray._raylet import ObjectRef


@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
Expand Down Expand Up @@ -129,6 +130,8 @@ def test_put_get(ray_start_regular_shared):
assert not objectref == 1
# Make sure it returns True when necessary as well.
assert objectref == ClientObjectRef(objectref.id)
# Make sure ClientObjectRef is a subclass of ObjectRef
assert isinstance(objectref, ObjectRef)


@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
Expand Down
1 change: 1 addition & 0 deletions python/ray/util/client/ARCHITECTURE.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ For many of the objects in the root `ray` namespace, there is an equivalent clie
These objects are client stand-ins for their server-side objects. For example:
```
ObjectRef <-> ClientObjectRef
ActorID <-> ClientActorRef
RemoteFunc <-> ClientRemoteFunc
```

Expand Down
88 changes: 4 additions & 84 deletions python/ray/util/client/common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import ray._raylet as raylet
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.util.client import ray
from ray.util.client.options import validate_options

import asyncio
import concurrent.futures
from dataclasses import dataclass
import grpc
import os
Expand All @@ -14,7 +13,6 @@
import json
import threading
from typing import Any
from typing import Callable
from typing import List
from typing import Dict
from typing import Optional
Expand Down Expand Up @@ -52,87 +50,9 @@
CLIENT_SERVER_MAX_THREADS = float(
os.getenv("RAY_CLIENT_SERVER_MAX_THREADS", 100))


class ClientBaseRef:
def __init__(self, id: bytes):
self.id = None
if not isinstance(id, bytes):
raise TypeError("ClientRefs must be created with bytes IDs")
self.id: bytes = id
ray.call_retain(id)

def binary(self):
return self.id

def hex(self):
return self.id.hex()

def __eq__(self, other):
return isinstance(other, ClientBaseRef) and self.id == other.id

def __repr__(self):
return "%s(%s)" % (
type(self).__name__,
self.id.hex(),
)

def __hash__(self):
return hash(self.id)

def __del__(self):
if ray.is_connected() and self.id is not None:
ray.call_release(self.id)


class ClientObjectRef(ClientBaseRef):
def __await__(self):
return self.as_future().__await__()

def as_future(self) -> asyncio.Future:
return asyncio.wrap_future(self.future())

def future(self) -> concurrent.futures.Future:
fut = concurrent.futures.Future()

def set_value(data: Any) -> None:
"""Schedules a callback to set the exception or result
in the Future."""

if isinstance(data, Exception):
fut.set_exception(data)
else:
fut.set_result(data)

self._on_completed(set_value)

# Prevent this object ref from being released.
fut.object_ref = self
return fut

def _on_completed(self, py_callback: Callable[[Any], None]) -> None:
"""Register a callback that will be called after Object is ready.
If the ObjectRef is already ready, the callback will be called soon.
The callback should take the result as the only argument. The result
can be an exception object in case of task error.
"""
from ray.util.client.client_pickler import loads_from_server

def deserialize_obj(resp: ray_client_pb2.DataResponse) -> None:
"""Converts from a GetResponse proto to a python object."""
obj = resp.get
data = None
if not obj.valid:
data = loads_from_server(resp.get.error)
else:
data = loads_from_server(resp.get.data)

py_callback(data)

ray._register_callback(self, deserialize_obj)


class ClientActorRef(ClientBaseRef):
pass
# Aliases for compatibility.
ClientObjectRef = raylet.ClientObjectRef
ClientActorRef = raylet.ClientActorID


class ClientStub:
Expand Down
3 changes: 1 addition & 2 deletions python/ray/util/dask/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import uuid

import ray
from ray.util.client.common import ClientObjectRef

from dask.base import quote
from dask.core import get as get_sync
Expand Down Expand Up @@ -47,7 +46,7 @@ def unpack_object_refs(*args):
object_refs_token = uuid.uuid4().hex

def _unpack(expr):
if isinstance(expr, (ray.ObjectRef, ClientObjectRef)):
if isinstance(expr, ray.ObjectRef):
token = expr.hex()
repack_dsk[token] = (getitem, object_refs_token, len(object_refs))
object_refs.append(expr)
Expand Down
6 changes: 2 additions & 4 deletions python/ray/util/dask/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from multiprocessing.pool import ThreadPool

import ray
from ray.util.client.common import ClientObjectRef

from dask.core import istask, ishashable, _execute_task
from dask.system import CPU_COUNT
Expand Down Expand Up @@ -370,9 +369,8 @@ def ray_get_unpack(object_refs):
if isinstance(object_refs, tuple):
object_refs = list(object_refs)

if isinstance(object_refs, list) and any(
not isinstance(x, (ray.ObjectRef, ClientObjectRef))
for x in object_refs):
if isinstance(object_refs, list) and any(not isinstance(x, ray.ObjectRef)
for x in object_refs):
# We flatten the object references before calling ray.get(), since Dask
# loves to nest collections in nested tuples and Ray expects a flat
# list of object references. We repack the results after ray.get()
Expand Down
6 changes: 1 addition & 5 deletions python/ray/util/placement_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import List
from typing import Optional
from typing import Union
from typing import TYPE_CHECKING

import ray
from ray._raylet import ObjectRef
Expand All @@ -13,9 +12,6 @@
from ray._private.client_mode_hook import client_mode_should_convert
from ray._private.client_mode_hook import client_mode_wrap

if TYPE_CHECKING:
from ray.util.common import ClientObjectRef # noqa

bundle_reservation_check = None


Expand Down Expand Up @@ -49,7 +45,7 @@ def __init__(self,
self.id = id
self.bundle_cache = bundle_cache

def ready(self) -> Union[ObjectRef, "ClientObjectRef"]:
def ready(self) -> ObjectRef:
"""Returns an ObjectRef to check ready status.
This API runs a small dummy task to wait for placement group creation.
Expand Down

0 comments on commit 3c8a1db

Please sign in to comment.