Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serve] Decrement ray_serve_deployment_queued_queries when client disconnects #37965

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 40 additions & 27 deletions python/ray/serve/_private/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,35 +974,48 @@ async def assign_request(
) -> Union[ray.ObjectRef, "ray._raylet.StreamingObjectRefGenerator"]:
"""Assign a query to a replica and return the resulting object_ref."""

self.num_router_requests.inc(
tags={"route": request_meta.route, "application": request_meta.app_name}
)
self.num_queued_queries += 1
self.num_queued_queries_gauge.set(
self.num_queued_queries,
tags={
"application": request_meta.app_name,
},
)
incremented_queue_metric = False
try:
self.num_router_requests.inc(
tags={"route": request_meta.route, "application": request_meta.app_name}
)
self.num_queued_queries += 1
self.num_queued_queries_gauge.set(
self.num_queued_queries,
tags={
"application": request_meta.app_name,
},
)
incremented_queue_metric += True

query = Query(
args=list(request_args),
kwargs=request_kwargs,
metadata=request_meta,
)
await query.resolve_async_tasks()
await query.buffer_starlette_requests_and_warn()
result = await self._replica_scheduler.assign_replica(query)

self.num_queued_queries -= 1
self.num_queued_queries_gauge.set(
self.num_queued_queries,
tags={
"application": request_meta.app_name,
},
)
query = Query(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add try here, so that we don't need to have incremented_queue_metric.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, that makes the code simpler. I made the change.

args=list(request_args),
kwargs=request_kwargs,
metadata=request_meta,
)
await query.resolve_async_tasks()
await query.buffer_starlette_requests_and_warn()
result = await self._replica_scheduler.assign_replica(query)

self.num_queued_queries -= 1
self.num_queued_queries_gauge.set(
self.num_queued_queries,
tags={
"application": request_meta.app_name,
},
)

return result
return result
except asyncio.CancelledError:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be an issue if a different exception happens?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it would. I changed the code to a try-finally block, so the metric is decremented no matter what exception is raised.

if incremented_queue_metric:
self.num_queued_queries -= 1
self.num_queued_queries_gauge.set(
self.num_queued_queries,
tags={
"application": request_meta.app_name,
},
)
raise

def shutdown(self):
"""Shutdown router gracefully.
Expand Down
51 changes: 50 additions & 1 deletion python/ray/serve/tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
from functools import partial
from multiprocessing import Pool
from typing import List, Dict, DefaultDict

import requests
Expand Down Expand Up @@ -871,6 +873,53 @@ def verify_metrics():
)


def test_queued_queries_disconnected(serve_start_shutdown):
    """Check that queued_queries decrements when queued requests disconnect.

    A single hanging request occupies the deployment's only concurrency slot
    (max_concurrent_queries=1), several more requests are queued behind it,
    and then the queued clients are killed. The
    ray_serve_deployment_queued_queries gauge should drop back to 0.
    """

    signal = SignalActor.remote()

    @serve.deployment(
        max_concurrent_queries=1,
        graceful_shutdown_timeout_s=0.0001,
    )
    async def hang_on_first_request():
        # Blocks until the signal is sent, keeping the single concurrency
        # slot occupied so all subsequent requests stay queued.
        await signal.wait.remote()

    serve.run(hang_on_first_request.bind())

    def queue_size() -> float:
        """Scrape the queued-queries gauge from the local metrics endpoint."""
        metrics = requests.get("http://127.0.0.1:9999").text
        gauge_value = None
        for line in metrics.split("\n"):
            if "ray_serve_deployment_queued_queries" in line:
                # Prometheus text format: the sample value is the last token.
                gauge_value = line.split(" ")[-1]

        # Fail with a clear message instead of an UnboundLocalError if the
        # metric has not been emitted yet.
        assert (
            gauge_value is not None
        ), "ray_serve_deployment_queued_queries metric not found"
        return float(gauge_value)

    def first_request_executing(request_future) -> bool:
        """Return True once the first request is blocked inside the replica."""
        try:
            # The request is expected to hang, so this timeout should fire.
            request_future.get(timeout=0.1)
        except Exception:
            return ray.get(signal.cur_num_waiters.remote()) == 1
        # The request finished, so it is no longer executing.
        return False

    url = "http://localhost:8000/"

    pool = Pool()

    # Make a request to block the deployment from accepting other requests
    fut = pool.apply_async(partial(requests.get, url))
    wait_for_condition(lambda: first_request_executing(fut), timeout=5)

    num_requests = 5
    for _ in range(num_requests):
        pool.apply_async(partial(requests.get, url))

    # First request should be processing. All others should be queued.
    wait_for_condition(lambda: queue_size() == num_requests, timeout=15)

    # Kill the queued clients; the gauge should drop back to zero.
    pool.terminate()

    wait_for_condition(lambda: queue_size() == 0, timeout=15)


def test_actor_summary(serve_instance):
@serve.deployment
def f():
Expand All @@ -885,7 +934,7 @@ def f():


def get_metric_dictionaries(name: str, timeout: float = 20) -> List[Dict]:
"""Gets a list of metric's dictionaries from metrics' text output.
"""Gets a list of metric's tags from metrics' text output.

Return:
Example:
Expand Down
Loading