Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable deployment on Serve of functions that take no parameters #19708

Merged
merged 4 commits into from
Oct 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions python/ray/serve/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,15 @@ async def invoke_single(self, request_item: Query) -> Any:
start = time.time()
method_to_call = None
try:
method_to_call = sync_to_async(
self.get_runner_method(request_item))
result = await method_to_call(*args, **kwargs)
runner_method = self.get_runner_method(request_item)
method_to_call = sync_to_async(runner_method)
result = None
if len(inspect.signature(runner_method).parameters) > 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this check need to be > 1 for class methods because of self? Where does the self arg get injected? It may work as-is because self is already bound in method_to_call. @simon-mo would probably know.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's already bound:

In [1]: class A:
   ...:     def b(self):
   ...:         pass
   ...:

In [2]: runner_method = A().b

In [3]: import inspect

In [4]: inspect.signature(runner_method).parameters
Out[4]: mappingproxy({})

result = await method_to_call(*args, **kwargs)
else:
# The method doesn't take in anything, including the request
# information, so we pass nothing into it
result = await method_to_call()

result = await self.ensure_serializable_response(result)
self.request_counter.inc()
Expand Down
67 changes: 67 additions & 0 deletions python/ray/serve/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,73 @@ async def slow_numbers():
assert resp.status_code == 418


def test_deploy_sync_function_no_params(serve_instance):
@serve.deployment()
def sync_d():
return "sync!"

serve.start()

sync_d.deploy()
assert requests.get("http://localhost:8000/sync_d").text == "sync!"
assert ray.get(sync_d.get_handle().remote()) == "sync!"


def test_deploy_async_function_no_params(serve_instance):
@serve.deployment()
async def async_d():
await asyncio.sleep(5)
return "async!"

serve.start()

async_d.deploy()
assert requests.get("http://localhost:8000/async_d").text == "async!"
assert ray.get(async_d.get_handle().remote()) == "async!"


def test_deploy_sync_class_no_params(serve_instance):
@serve.deployment
class Counter:
def __init__(self):
self.count = 0

def __call__(self):
self.count += 1
return {"count": self.count}

serve.start()
Counter.deploy()

assert requests.get("http://127.0.0.1:8000/Counter").json() == {"count": 1}
assert requests.get("http://127.0.0.1:8000/Counter").json() == {"count": 2}
assert ray.get(Counter.get_handle().remote()) == {"count": 3}


def test_deploy_async_class_no_params(serve_instance):
@serve.deployment
class AsyncCounter:
async def __init__(self):
await asyncio.sleep(5)
self.count = 0

async def __call__(self):
self.count += 1
await asyncio.sleep(5)
return {"count": self.count}

serve.start()
AsyncCounter.deploy()

assert requests.get("http://127.0.0.1:8000/AsyncCounter").json() == {
"count": 1
}
assert requests.get("http://127.0.0.1:8000/AsyncCounter").json() == {
"count": 2
}
assert ray.get(AsyncCounter.get_handle().remote()) == {"count": 3}


def test_user_config(serve_instance):
@serve.deployment(
"counter", num_replicas=2, user_config={
Expand Down