Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serve] Use Async Handle for DAG Execution #27411

Merged
merged 22 commits into from
Aug 7, 2022
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/ray/dag/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,11 @@ py_test(
tags = ["exclusive", "team:core", "ray_dag_tests"],
deps = [":dag_lib"],
)

# Unit tests for the generic Python-object scanner (py_obj_scanner.py),
# which walks arbitrary PyObj graphs via pickle to find/replace nodes.
py_test(
    name = "test_py_obj_scanner",
    size = "small",
    srcs = dag_tests_srcs,
    tags = ["exclusive", "team:core", "ray_dag_tests"],
    deps = [":dag_lib"],
)
62 changes: 47 additions & 15 deletions python/ray/dag/py_obj_scanner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from weakref import WeakValueDictionary
import ray

import io
Expand All @@ -12,16 +13,19 @@
else:
import pickle # noqa: F401

from typing import List, Dict, Any, TypeVar
from typing import Generic, List, Dict, Any, Type, TypeVar
from ray.dag.base import DAGNodeBase

T = TypeVar("T")

# Used in deserialization hooks to reference scanner instances.
_instances: Dict[int, "_PyObjScanner"] = {}

# Generic types for the scanner to transform from and to.
SourceType = TypeVar("SourceType")
TransformedType = TypeVar("TransformedType")

def _get_node(instance_id: int, node_index: int) -> DAGNodeBase:

def _get_node(instance_id: int, node_index: int) -> SourceType:
"""Get the node instance.

Note: This function should be static and globally importable,
Expand All @@ -30,50 +34,78 @@ def _get_node(instance_id: int, node_index: int) -> DAGNodeBase:
return _instances[instance_id]._replace_index(node_index)


class _PyObjScanner(ray.cloudpickle.CloudPickler):
"""Utility to find and replace DAGNodes in Python objects.
def _get_object(instance_id: int, node_index: int) -> Any:
    """Deserialization hook: fetch a stashed non-SourceType object.

    Looks up the scanner registered under ``instance_id`` and returns the
    object it recorded at position ``node_index``.

    Note: This function should be static and globally importable,
    otherwise the serialization overhead would be very significant.
    """
    scanner = _instances[instance_id]
    return scanner._objects[node_index]


class _PyObjScanner(ray.cloudpickle.CloudPickler, Generic[SourceType, TransformedType]):
    """Utility to find and replace the `source_type` in Python objects.

    This uses pickle to walk the PyObj graph and find first-level DAGNode
    instances on ``find_nodes()``. The caller can then compute a replacement
    table and then replace the nodes via ``replace_nodes()``.

    Args:
        source_type: the type of object to find and replace. Default to DAGNodeBase.
    """

    def __init__(self, source_type: Type = DAGNodeBase):
        # Type of object that find_nodes()/replace_nodes() operate on.
        self.source_type = source_type
        # Buffer to keep intermediate serialized state.
        self._buf = io.BytesIO()
        # List of top-level SourceType found during the serialization pass.
        self._found = None
        # List of other objects found during the serialization pass.
        # This is used to store references to objects so they won't be
        # serialized by cloudpickle.
        # NOTE(review): find_nodes() rebinds this attribute to a plain list,
        # so this WeakValueDictionary is never actually used as one and the
        # stored references end up strong — confirm intended behavior.
        self._objects = WeakValueDictionary()
        # Replacement table to consult during deserialization.
        self._replace_table: Dict[SourceType, TransformedType] = None
        # Register globally so the module-level _get_node/_get_object hooks
        # can resolve this instance by id during deserialization.
        _instances[id(self)] = self
        super().__init__(self._buf)

def reducer_override(self, obj):
    """Hook for reducing objects.

    Intercepts serialization of all objects and stores them in internal
    data structures, preventing them from actually being written to the
    buffer.
    """
    # Fix: the original checked `obj is _get_object` twice; the first
    # operand must be _get_node, otherwise _get_node itself would be
    # stashed and replaced below, breaking the replace protocol.
    if obj is _get_node or obj is _get_object:
        # Only fall back to cloudpickle for these two functions: they are
        # emitted in the reduce tuples below and must round-trip intact.
        return super().reducer_override(obj)
    elif isinstance(obj, self.source_type):
        # Record the source object and emit a deserialization-time
        # reference to it instead of serializing it.
        index = len(self._found)
        self._found.append(obj)
        return _get_node, (id(self), index)
    else:
        # Stash any other object so cloudpickle never serializes it.
        # NOTE(review): __init__ initializes _objects as a
        # WeakValueDictionary (no append()); this relies on find_nodes()
        # rebinding it to a list first — confirm intent.
        index = len(self._objects)
        self._objects.append(obj)
        return _get_object, (id(self), index)
    # (Removed unreachable trailing `return super().reducer_override(obj)`:
    # every branch above already returns.)

def find_nodes(self, obj: Any) -> List[SourceType]:
    """Serialize ``obj`` once and collect all top-level SourceType instances."""
    assert self._found is None, (
        "find_nodes cannot be called twice on the same PyObjScanner instance."
    )
    # Reset the collection state, then let reducer_override populate it
    # as the pickle pass walks the object graph.
    self._found = []
    self._objects = []
    self.dump(obj)
    return self._found

def replace_nodes(self, table: Dict[SourceType, TransformedType]) -> Any:
    """Rebuild the scanned object, substituting each found node per ``table``."""
    assert self._found is not None, "find_nodes must be called first"
    # Deserializing replays the buffer; the _get_node hook consults the
    # replacement table to swap each recorded node for its substitute.
    self._replace_table = table
    buf = self._buf
    buf.seek(0)
    return pickle.load(buf)

def _replace_index(self, i: int) -> SourceType:
    # Map the i-th found source object to its replacement from the table.
    original = self._found[i]
    return self._replace_table[original]

def __del__(self):
Expand Down
4 changes: 2 additions & 2 deletions python/ray/serve/_private/deployment_function_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ray.dag.format_utils import get_dag_node_str
from ray.serve.deployment import Deployment, schema_to_deployment
from ray.serve.config import DeploymentConfig
from ray.serve.handle import RayServeLazySyncHandle
from ray.serve.handle import RayServeLazyAsyncHandle
from ray.serve.schema import DeploymentSchema


Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(
_internal=True,
)
# TODO (jiaodong): Polish with async handle support later
self._deployment_handle = RayServeLazySyncHandle(self._deployment.name)
self._deployment_handle = RayServeLazyAsyncHandle(self._deployment.name)

def _copy_impl(
self,
Expand Down
4 changes: 2 additions & 2 deletions python/ray/serve/_private/deployment_graph_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
DeploymentFunctionExecutorNode,
)
from ray.serve._private.json_serde import DAGNodeEncoder
from ray.serve.handle import RayServeLazySyncHandle
from ray.serve.handle import RayServeLazyAsyncHandle
from ray.serve.schema import DeploymentSchema


Expand Down Expand Up @@ -153,7 +153,7 @@ def transform_ray_dag_to_serve_dag(
# serve DAG end to end executable.
def replace_with_handle(node):
if isinstance(node, DeploymentNode):
return RayServeLazySyncHandle(node._deployment.name)
return RayServeLazyAsyncHandle(node._deployment.name)
elif isinstance(node, DeploymentExecutorNode):
return node._deployment_handle

Expand Down
4 changes: 2 additions & 2 deletions python/ray/serve/_private/deployment_node.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict, Optional, List, Tuple

from ray.dag import DAGNode
from ray.serve.handle import RayServeLazySyncHandle
from ray.serve.handle import RayServeLazyAsyncHandle

from ray.dag.constants import PARENT_CLASS_NODE_KEY
from ray.dag.format_utils import get_dag_node_str
Expand Down Expand Up @@ -30,7 +30,7 @@ def __init__(
other_args_to_resolve=other_args_to_resolve,
)
self._deployment = deployment
self._deployment_handle = RayServeLazySyncHandle(self._deployment.name)
self._deployment_handle = RayServeLazyAsyncHandle(self._deployment.name)

def _copy_impl(
self,
Expand Down
2 changes: 1 addition & 1 deletion python/ray/serve/_private/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def _send_request_to_handle(handle, scope, receive, send) -> str:
# call might never arrive; if it does, it can only be `http.disconnect`.
client_disconnection_task = loop.create_task(receive())
while retries < MAX_REPLICA_FAILURE_RETRIES:
assignment_task = loop.create_task(handle.remote(request))
assignment_task: asyncio.Task = handle.remote(request)
done, _ = await asyncio.wait(
[assignment_task, client_disconnection_task], return_when=FIRST_COMPLETED
)
Expand Down
10 changes: 5 additions & 5 deletions python/ray/serve/_private/json_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ray.serve.handle import (
HandleOptions,
RayServeHandle,
RayServeLazySyncHandle,
RayServeLazyAsyncHandle,
_serve_handle_to_json_dict,
_serve_handle_from_json_dict,
)
Expand Down Expand Up @@ -98,9 +98,9 @@ def default(self, obj):
DAGNODE_TYPE_KEY: RayServeDAGHandle.__name__,
"dag_node_json": obj.dag_node_json,
}
elif isinstance(obj, RayServeLazySyncHandle):
elif isinstance(obj, RayServeLazyAsyncHandle):
return {
DAGNODE_TYPE_KEY: RayServeLazySyncHandle.__name__,
DAGNODE_TYPE_KEY: RayServeLazyAsyncHandle.__name__,
"deployment_name": obj.deployment_name,
"handle_options_method_name": obj.handle_options.method_name,
}
Expand Down Expand Up @@ -148,8 +148,8 @@ def dagnode_from_json(input_json: Any) -> Union[DAGNode, RayServeHandle, Any]:
return RayServeDAGHandle(input_json["dag_node_json"])
elif input_json[DAGNODE_TYPE_KEY] == "DeploymentSchema":
return DeploymentSchema.parse_obj(input_json["schema"])
elif input_json[DAGNODE_TYPE_KEY] == RayServeLazySyncHandle.__name__:
return RayServeLazySyncHandle(
elif input_json[DAGNODE_TYPE_KEY] == RayServeLazyAsyncHandle.__name__:
return RayServeLazyAsyncHandle(
input_json["deployment_name"],
HandleOptions(input_json["handle_options_method_name"]),
)
Expand Down
12 changes: 12 additions & 0 deletions python/ray/serve/_private/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import ray
from ray.actor import ActorHandle
from ray.dag.py_obj_scanner import _PyObjScanner
from ray.exceptions import RayActorError, RayTaskError
from ray.util import metrics

Expand Down Expand Up @@ -45,6 +46,16 @@ class Query:
kwargs: Dict[Any, Any]
metadata: RequestMetadata

async def resolve_coroutines(self):
    """Find all unresolved asyncio.Task and gather them all at once."""
    # Scan args/kwargs for embedded asyncio.Task objects.
    scanner = _PyObjScanner(source_type=asyncio.Task)
    pending = scanner.find_nodes((self.args, self.kwargs))
    if not pending:
        return
    # Await everything concurrently, then splice the results back in
    # where the tasks were found.
    results = await asyncio.gather(*pending)
    self.args, self.kwargs = scanner.replace_nodes(dict(zip(pending, results)))


class ReplicaSet:
"""Data structure representing a set of replica actor handles"""
Expand Down Expand Up @@ -216,6 +227,7 @@ async def assign_replica(self, query: Query) -> ray.ObjectRef:
self.num_queued_queries_gauge.set(
self.num_queued_queries, tags={"endpoint": endpoint}
)
await query.resolve_coroutines()
assigned_ref = self._try_assign_replica(query)
while assigned_ref is None: # Can't assign a replica right now.
logger.debug(
Expand Down
6 changes: 4 additions & 2 deletions python/ray/serve/deployment_graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import ray

from ray.dag.class_node import ClassNode # noqa: F401
from ray.dag.function_node import FunctionNode # noqa: F401
Expand Down Expand Up @@ -31,11 +32,12 @@ def _deserialize(cls, *args):
def __reduce__(self):
return RayServeDAGHandle._deserialize, (self.dag_node_json,)

async def remote(self, *args, **kwargs) -> ray.ObjectRef:
    """Execute the request; returns an ObjectRef representing the final result."""
    if self.dag_node is None:
        # Deserialize the DAG lazily on first use (import here to avoid
        # a module-level circular dependency).
        from ray.serve._private.json_serde import dagnode_from_json

        self.dag_node = json.loads(self.dag_node_json, object_hook=dagnode_from_json)
    return await self.dag_node.execute(*args, **kwargs)
20 changes: 12 additions & 8 deletions python/ray/serve/drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ray._private.utils import import_attr
from ray.serve.deployment_graph import RayServeDAGHandle
from ray.serve._private.http_util import ASGIHTTPSender
from ray.serve.handle import RayServeLazySyncHandle
from ray.serve.handle import RayServeLazyAsyncHandle
from ray.serve.exceptions import RayServeException
from ray import serve

Expand Down Expand Up @@ -101,20 +101,22 @@ def __init__(

if isinstance(dags, dict):
self.dags = dags
for route, handle in dags.items():
for route in dags.keys():

def endpoint_create(handle):
def endpoint_create(route):
@self.app.get(f"{route}")
@self.app.post(f"{route}")
async def handle_request(inp=Depends(http_adapter)):
return await handle.remote(inp)
return await self.predict_with_route(
route, inp # noqa: B023 function redefinition
)

# bind current handle with endpoint creation function
endpoint_create_func = functools.partial(endpoint_create, handle)
endpoint_create_func = functools.partial(endpoint_create, route)
endpoint_create_func()

else:
assert isinstance(dags, (RayServeDAGHandle, RayServeLazySyncHandle))
assert isinstance(dags, (RayServeDAGHandle, RayServeLazyAsyncHandle))
self.dags = {self.MATCH_ALL_ROUTE_PREFIX: dags}

# Single dag case, we will receive all prefix route
Expand All @@ -132,10 +134,12 @@ async def __call__(self, request: starlette.requests.Request):

async def predict(self, *args, **kwargs):
    """Perform inference directly without HTTP."""
    # remote() is async and itself returns an awaitable ObjectRef,
    # hence the two awaits.
    handle = self.dags[self.MATCH_ALL_ROUTE_PREFIX]
    object_ref = await handle.remote(*args, **kwargs)
    return await object_ref

async def predict_with_route(self, route_path, *args, **kwargs):
    """Perform inference directly without HTTP for multi dags."""
    if route_path not in self.dags:
        raise RayServeException(f"{route_path} does not exist in dags routes")
    # remote() is async and returns an awaitable ObjectRef, so await twice.
    object_ref = await self.dags[route_path].remote(*args, **kwargs)
    return await object_ref