Skip to content

Commit

Permalink
[Serve] Use Async Handle for DAG Execution (ray-project#27411)
Browse files Browse the repository at this point in the history
Signed-off-by: simon-mo <simon.mo@hey.com>
  • Loading branch information
simon-mo committed Aug 7, 2022
1 parent 8882ae4 commit edab36a
Show file tree
Hide file tree
Showing 37 changed files with 556 additions and 235 deletions.
2 changes: 1 addition & 1 deletion dashboard/modules/serve/tests/test_serve_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_delete(ray_start_stop):
"runtime_env": {
"working_dir": (
"https://github.com/ray-project/test_dag/archive/"
"76a741f6de31df78411b1f302071cde46f098418.zip"
"40d61c141b9c37853a7014b8659fc7f23c1d04f6.zip"
)
},
"deployments": [
Expand Down
219 changes: 219 additions & 0 deletions doc/source/serve/doc_code/migration_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
import ray
from ray import serve
from ray.serve.drivers import DAGDriver
from ray.dag.input_node import InputNode
from ray.serve.handle import RayServeDeploymentHandle
from ray.serve.handle import RayServeSyncHandle

import requests
import starlette

# Start Serve before running the first example.
serve.start()


# __raw_handle_graph_start__


@serve.deployment
class Model:
    """Placeholder model deployment used by the raw-handle graph example."""

    def forward(self, input):
        """Stand in for an inference call on ``input``; always returns "done"."""
        # Real inference logic would replace this fixed value.
        outcome = "done"
        return outcome


@serve.deployment
class Preprocess:
    """Deployment that relays each input to a model through a stored handle."""

    def __init__(self, model_handle: RayServeSyncHandle):
        # Keep the handle so ``__call__`` can reach the model's ``forward``.
        self.model_handle = model_handle

    async def __call__(self, input):
        """Forward ``input`` to the model and return the awaited result."""
        # Any preprocessing of the input would happen before this call.
        pending = self.model_handle.forward.remote(input)
        return await pending


# Deploy Model first so its handle can be handed to Preprocess.
Model.deploy()
model_handle = Model.get_handle()

# The model handle is passed through deploy() for Preprocess to call into.
Preprocess.deploy(model_handle)
preprocess_handle = Preprocess.get_handle()
# Send one request through Preprocess and block until its result is ready.
ray.get(preprocess_handle.remote(1))

# __raw_handle_graph_end__

# Shut Serve down and start it again before the next example.
serve.shutdown()
serve.start()


# __single_deployment_old_api_start__
@serve.deployment
class Model:
    """Single callable deployment for the old-API example."""

    def __call__(self, input: int):
        """Accept one integer request; the placeholder body yields ``None``."""
        # Inference logic would go here.
        return None


# Deploy Model, fetch a handle to it, and send one request through the handle.
Model.deploy()
handle = Model.get_handle()
handle.remote(1)
# __single_deployment_old_api_end__

# Shut Serve down and start it again before the next example.
serve.shutdown()
serve.start()


# __multi_deployments_old_api_start__
@serve.deployment
class Model:
    """First of two deployments in the old-API multi-deployment example."""

    def forward(self, input: int):
        """Accept one integer input; the placeholder body yields ``None``."""
        # Inference logic would go here.
        return None


@serve.deployment
class Model2:
    """Second of two deployments in the old-API multi-deployment example."""

    def forward(self, input: int):
        """Accept one integer input; the placeholder body yields ``None``."""
        # Inference logic would go here.
        return None


# Deploy both models; each one is then reached through its own handle.
Model.deploy()
Model2.deploy()
handle = Model.get_handle()
handle.forward.remote(1)

handle2 = Model2.get_handle()
handle2.forward.remote(1)
# __multi_deployments_old_api_end__

# Shut Serve down and start it again before the next example.
serve.shutdown()
serve.start()


# __customized_route_old_api_start__
@serve.deployment(route_prefix="/my_model1")
class Model:
    """Deployment given the "/my_model1" route prefix via the decorator."""

    def __call__(self, req: starlette.requests.Request):
        """Take the incoming request object and reply with a fixed marker."""
        # Inference logic would go here.
        reply = "done"
        return reply


# Deploy Model, then issue an HTTP GET against the route prefix set on it.
Model.deploy()
resp = requests.get("http://localhost:8000/my_model1", data="321")
# __customized_route_old_api_end__

# Shut Serve down before the next example.
serve.shutdown()


# __single_deployment_new_api_start__
@serve.deployment
class Model:
    """Single callable deployment for the new-API example."""

    def __call__(self, input: int):
        """Accept one integer request; the placeholder body yields ``None``."""
        # Inference logic would go here.
        return None


# serve.run() takes the bound deployment; its return value is used as the handle.
handle = serve.run(Model.bind())
handle.remote(1)
# __single_deployment_new_api_end__

# Shut Serve down before the next example.
serve.shutdown()


# __multi_deployments_new_api_start__
@serve.deployment
class Model:
    """First of two deployments in the new-API multi-deployment example."""

    def forward(self, input: int):
        """Accept one integer input; the placeholder body yields ``None``."""
        # Inference logic would go here.
        return None


@serve.deployment
class Model2:
    """Second of two deployments in the new-API multi-deployment example."""

    def forward(self, input: int):
        """Accept one integer input; the placeholder body yields ``None``."""
        # Inference logic would go here.
        return None


# Bind each deployment once, then expose both under one DAGDriver keyed by route.
with InputNode() as dag_input:
    model = Model.bind()
    model2 = Model2.bind()
    d = DAGDriver.bind(
        {
            # Each route maps to that model's ``forward`` bound to the graph input.
            route: node.forward.bind(dag_input)
            for route, node in (("/model1", model), ("/model2", model2))
        }
    )
handle = serve.run(d)

# Through the handle, the route is given as the first argument.
handle.predict_with_route.remote("/model1", 1)
handle.predict_with_route.remote("/model2", 1)

# Over HTTP, the same routes are used as URL paths.
resp = requests.get("http://localhost:8000/model1", data="1")
resp = requests.get("http://localhost:8000/model2", data="1")
# __multi_deployments_new_api_end__

# Shut Serve down before the next example.
serve.shutdown()


# __customized_route_old_api_1_start__
@serve.deployment
class Model:
    """Deployment served behind a DAGDriver with a custom route prefix."""

    def __call__(self, req: starlette.requests.Request):
        """Take the incoming request object and reply with a fixed marker."""
        # Inference logic would go here.
        reply = "done"
        return reply


# Wrap the bound Model in a DAGDriver whose route prefix is "/my_model1".
d = DAGDriver.options(route_prefix="/my_model1").bind(Model.bind())
handle = serve.run(d)
# Issue an HTTP GET against that route prefix.
resp = requests.get("http://localhost:8000/my_model1", data="321")
# __customized_route_old_api_1_end__

# Shut Serve down before the next example.
serve.shutdown()


# __customized_route_old_api_2_start__
@serve.deployment
class Model:
    """Deployment bound under the "/my_model1" key of the driver's route dict."""

    def __call__(self, req: starlette.requests.Request):
        """Take the incoming request object and reply with a fixed marker."""
        # Inference logic would go here.
        reply = "done"
        return reply


@serve.deployment
class Model2:
    """Deployment bound under the "/my_model2" key of the driver's route dict."""

    def __call__(self, req: starlette.requests.Request):
        """Take the incoming request object and reply with a fixed marker."""
        # Inference logic would go here.
        reply = "done"
        return reply


# A single DAGDriver maps each route key to its own bound deployment.
d = DAGDriver.bind({"/my_model1": Model.bind(), "/my_model2": Model2.bind()})
handle = serve.run(d)
# Each model is then requested over HTTP at the path matching its key.
resp = requests.get("http://localhost:8000/my_model1", data="321")
resp = requests.get("http://localhost:8000/my_model2", data="321")
# __customized_route_old_api_2_end__

# Shut Serve down before the next example.
serve.shutdown()


# __graph_with_new_api_start__
@serve.deployment
class Model:
    """Placeholder model deployment used by the new-API graph example."""

    def forward(self, input):
        """Stand in for an inference call on ``input``; always returns "done"."""
        # Real inference logic would replace this fixed value.
        outcome = "done"
        return outcome


@serve.deployment
class Preprocess:
    """Deployment that relays each input to a model through a stored handle."""

    def __init__(self, model_handle: RayServeDeploymentHandle):
        # Keep the handle so ``__call__`` can reach the model's ``forward``.
        self.model_handle = model_handle

    async def __call__(self, input):
        """Forward ``input`` to the model and return the final result."""
        # Any preprocessing of the input would happen before this call.
        # Two awaits: the first yields a reference, the second its value.
        ref = await self.model_handle.forward.remote(input)
        return await ref


# Bind Model as the argument to Preprocess, run the resulting graph, and
# block on the result of one request sent through the returned handle.
handle = serve.run(Preprocess.bind(Model.bind()))
ray.get(handle.remote(1))
# __graph_with_new_api_end__
14 changes: 8 additions & 6 deletions doc/source/serve/doc_code/production_fruit_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,36 @@
from ray import serve
from ray.serve.drivers import DAGDriver
from ray.serve.deployment_graph import InputNode
from ray.serve.handle import RayServeDeploymentHandle
from ray.serve.http_adapters import json_request

# These imports are used only for type hints:
from typing import Dict, List
from starlette.requests import Request
from ray.serve.deployment_graph import ClassNode


@serve.deployment(num_replicas=2)
class FruitMarket:
def __init__(
self,
mango_stand: ClassNode,
orange_stand: ClassNode,
pear_stand: ClassNode,
mango_stand: RayServeDeploymentHandle,
orange_stand: RayServeDeploymentHandle,
pear_stand: RayServeDeploymentHandle,
):
self.directory = {
"MANGO": mango_stand,
"ORANGE": orange_stand,
"PEAR": pear_stand,
}

def check_price(self, fruit: str, amount: float) -> float:
async def check_price(self, fruit: str, amount: float) -> float:
if fruit not in self.directory:
return -1
else:
fruit_stand = self.directory[fruit]
return ray.get(fruit_stand.check_price.remote(amount))
ref: ray.ObjectRef = await fruit_stand.check_price.remote(amount)
result = await ref
return result


@serve.deployment(user_config={"price": 3})
Expand Down
8 changes: 8 additions & 0 deletions python/ray/dag/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,11 @@ py_test(
tags = ["exclusive", "team:core", "ray_dag_tests"],
deps = [":dag_lib"],
)

# Small test target named test_py_obj_scanner; sources come from dag_tests_srcs
# and it depends on :dag_lib.
py_test(
name = "test_py_obj_scanner",
size = "small",
srcs = dag_tests_srcs,
tags = ["exclusive", "team:core", "ray_dag_tests"],
deps = [":dag_lib"],
)

0 comments on commit edab36a

Please sign in to comment.