
Commit b09ae1d

wip
Signed-off-by: Cindy Zhang <cindyzyx9@gmail.com>
zcin committed Jan 22, 2024
1 parent 929dc6b commit b09ae1d
Showing 7 changed files with 183 additions and 27 deletions.
8 changes: 7 additions & 1 deletion .buildkite/serve.rayci.yml
@@ -39,9 +39,15 @@ steps:
instance_type: large
commands:
- bazel run //ci/ray_ci:test_in_docker -- //python/ray/serve/... //python/ray/tests/... serve
-      --except-tags post_wheel_build,gpu,worker-container,ha_integration
+      --except-tags post_wheel_build,gpu,worker-container,ha_integration,autoscaling
--workers "$${BUILDKITE_PARALLEL_JOB_COUNT}" --worker-id "$${BUILDKITE_PARALLEL_JOB}" --parallelism-per-worker 3
--build-name servebuild --test-env=EXPECTED_PYTHON_VERSION=3.9
+    - bazel run //ci/ray_ci:test_in_docker -- //python/ray/serve/... --only-tags autoscaling
+      --workers "$${BUILDKITE_PARALLEL_JOB_COUNT}" --worker-id "$${BUILDKITE_PARALLEL_JOB}"
+      --test-env=EXPECTED_PYTHON_VERSION=3.9 --test-env=RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE=0
+    - bazel run //ci/ray_ci:test_in_docker -- //python/ray/serve/... --only-tags autoscaling
+      --workers "$${BUILDKITE_PARALLEL_JOB_COUNT}" --worker-id "$${BUILDKITE_PARALLEL_JOB}"
+      --test-env=EXPECTED_PYTHON_VERSION=3.9 --test-env=RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE=1
depends_on: "servebuild"
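(Net effect: the main pass now skips autoscaling-tagged tests, and two extra passes run them with RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE=0 and =1, so both collection modes stay covered.)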

- label: ":ray-serve: serve: pydantic < 2.0 tests"
6 changes: 6 additions & 0 deletions python/ray/serve/_private/constants.py
@@ -268,3 +268,9 @@

# The default autoscaling policy to use if none is specified.
DEFAULT_AUTOSCALING_POLICY = "ray.serve.autoscaling_policy:default_autoscaling_policy"

+# Feature flag to enable collecting all queued and ongoing request
+# metrics at handles instead of replicas. OFF by default.
+RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE = (
+    os.environ.get("RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE", "0") == "1"
+)
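Note that the flag is evaluated once, at module import. A minimal sketch (assuming a Ray installation) of why CI passes the value via `--test-env` rather than setting it inside a test:

```python
import os

# Illustrative only: module-level flags are read at import time, so the
# variable must be set before the constants module is first imported.
# (If anything imported ray.serve earlier in this process, this is a no-op.)
os.environ["RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE"] = "1"

from ray.serve._private import constants

assert constants.RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE
```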
25 changes: 16 additions & 9 deletions python/ray/serve/_private/deployment_state.py
@@ -35,6 +35,7 @@
from ray.serve._private.config import DeploymentConfig
from ray.serve._private.constants import (
MAX_DEPLOYMENT_CONSTRUCTOR_RETRY_COUNT,
+    RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE,
RAY_SERVE_FORCE_STOP_UNHEALTHY_REPLICAS,
REPLICA_HEALTH_CHECK_UNHEALTHY_THRESHOLD,
SERVE_LOGGER_NAME,
@@ -1595,14 +1596,18 @@ def get_total_num_requests(self) -> int:

total_requests = 0
running_replicas = self._replicas.get([ReplicaState.RUNNING])
-        for replica in running_replicas:
-            replica_tag = replica.replica_tag
-            if replica_tag in self.replica_average_ongoing_requests:
-                total_requests += self.replica_average_ongoing_requests[replica_tag][1]
+        if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE:
+            for handle_metric in self.handle_requests.values():
+                total_requests += handle_metric[1]
+        else:
+            for replica in running_replicas:
+                replica_tag = replica.replica_tag
+                if replica_tag in self.replica_average_ongoing_requests:
+                    total_requests += self.replica_average_ongoing_requests[replica_tag][1]

-        if len(running_replicas) == 0:
-            for handle_metrics in self.handle_requests.values():
-                total_requests += handle_metrics[1]
+            if len(running_replicas) == 0:
+                for handle_metrics in self.handle_requests.values():
+                    total_requests += handle_metrics[1]

return total_requests

@@ -1613,11 +1618,12 @@ def autoscale(self) -> int:
return

total_num_requests = self.get_total_num_requests()
+        num_running_replicas = len(self.get_running_replica_infos())
autoscaling_policy_manager = self.autoscaling_policy_manager
decision_num_replicas = autoscaling_policy_manager.get_decision_num_replicas(
curr_target_num_replicas=self._target_state.target_num_replicas,
total_num_requests=total_num_requests,
-            num_running_replicas=len(self.get_running_replica_infos()),
+            num_running_replicas=num_running_replicas,
target_capacity=self._target_state.info.target_capacity,
target_capacity_direction=self._target_state.info.target_capacity_direction,
)
@@ -1631,7 +1637,8 @@
logger.info(
f"Autoscaling replicas for deployment '{self.deployment_name}' in "
f"application '{self.app_name}' to {decision_num_replicas}. "
f"Current number of requests: {total_num_requests}."
f"Current number of requests: {total_num_requests}. Current number of "
f"running replicas: {num_running_replicas}."
)

new_info = copy(self._target_state.info)
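For reference, a self-contained sketch of the aggregation rule this hunk introduces. The dictionaries are illustrative stand-ins for the controller's internal maps; as in the diff, each stored metric is a (timestamp, value) pair:

```python
from typing import Dict, List, Tuple

def total_num_requests(
    collect_on_handle: bool,
    handle_requests: Dict[str, Tuple[float, float]],
    replica_requests: Dict[str, Tuple[float, float]],
    running_replica_ids: List[str],
) -> float:
    total = 0.0
    if collect_on_handle:
        # Handle-side mode: each handle already reports queued + ongoing
        # requests, so summing over handles covers everything.
        for _, value in handle_requests.values():
            total += value
    else:
        # Replica-side mode: sum ongoing requests reported by running replicas.
        for replica_id in running_replica_ids:
            if replica_id in replica_requests:
                total += replica_requests[replica_id][1]
        # With no running replicas nothing is reporting, so fall back to
        # handle-side queue lengths (e.g. to scale up from zero).
        if not running_replica_ids:
            for _, value in handle_requests.values():
                total += value
    return total
```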
3 changes: 2 additions & 1 deletion python/ray/serve/_private/replica.py
@@ -33,6 +33,7 @@
DEFAULT_LATENCY_BUCKET_MS,
GRPC_CONTEXT_ARG_NAME,
HEALTH_CHECK_METHOD,
+    RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE,
RAY_SERVE_GAUGE_METRIC_SET_PERIOD_S,
RAY_SERVE_REPLICA_AUTOSCALING_METRIC_RECORD_PERIOD_S,
RECONFIGURE_METHOD,
@@ -180,7 +181,7 @@ def set_autoscaling_config(self, autoscaling_config: AutoscalingConfig):

self._autoscaling_config = autoscaling_config

-        if self._autoscaling_config:
+        if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE and self._autoscaling_config:
# Push autoscaling metrics to the controller periodically.
self._metrics_pusher.register_task(
self.PUSH_METRICS_TO_CONTROLLER_TASK_NAME,
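Both this file and router.py below lean on MetricsPusher.register_task(name, task_func, interval_s, process_func). A minimal threading-based sketch of that contract — illustrative, not the actual MetricsPusher implementation — to make the call sites easier to read:

```python
import threading
import time
from typing import Callable, Dict

class PeriodicPusher:
    """Every interval_s seconds, call task_func() and hand the result plus
    a timestamp to process_func. Stand-in for Ray Serve's MetricsPusher."""

    def __init__(self) -> None:
        self._threads: Dict[str, threading.Thread] = {}
        self._stopped = threading.Event()

    def register_task(
        self,
        name: str,
        task_func: Callable,
        interval_s: float,
        process_func: Callable,
    ) -> None:
        def loop() -> None:
            # Event.wait returns False on timeout, True once shutdown is set.
            while not self._stopped.wait(interval_s):
                process_func(task_func(), time.time())

        self._threads[name] = threading.Thread(target=loop, name=name, daemon=True)

    def start(self) -> None:
        for thread in self._threads.values():
            thread.start()

    def shutdown(self) -> None:
        self._stopped.set()
```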
101 changes: 86 additions & 15 deletions python/ray/serve/_private/router.py
@@ -4,10 +4,12 @@
import math
import pickle
import random
+import threading
import time
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from dataclasses import dataclass
+from functools import partial
from typing import (
Any,
AsyncGenerator,
@@ -26,9 +28,11 @@
from ray.actor import ActorHandle
from ray.dag.py_obj_scanner import _PyObjScanner
from ray.exceptions import RayActorError
+from ray.serve._private.autoscaling_metrics import InMemoryMetricsStore
from ray.serve._private.common import DeploymentID, RequestProtocol, RunningReplicaInfo
from ray.serve._private.constants import (
HANDLE_METRIC_PUSH_INTERVAL_S,
+    RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE,
RAY_SERVE_MAX_QUEUE_LENGTH_RESPONSE_DEADLINE_S,
RAY_SERVE_MULTIPLEXED_MODEL_ID_MATCHING_TIMEOUT_S,
RAY_SERVE_PROXY_PREFER_LOCAL_AZ_ROUTING,
@@ -37,12 +41,14 @@
)
from ray.serve._private.long_poll import LongPollClient, LongPollNamespace
from ray.serve._private.utils import JavaActorHandleProxy, MetricsPusher
+from ray.serve.config import AutoscalingConfig
from ray.serve.generated.serve_pb2 import RequestMetadata as RequestMetadataProto
from ray.serve.grpc_util import RayServegRPCContext
from ray.util import metrics

logger = logging.getLogger(SERVE_LOGGER_NAME)
PUSH_METRICS_TO_CONTROLLER_TASK_NAME = "push_metrics_to_controller"
+RECORD_METRICS_TASK_NAME = "record_metrics"


@dataclass
@@ -917,7 +923,7 @@ async def assign_replica(
request, so it's up to the caller to time out or cancel the request.
"""
replica = await self.choose_replica_for_query(query)
-        return replica.send_query(query)
+        return replica.send_query(query), replica.replica_id


class Router:
@@ -972,6 +978,8 @@ def __init__(
{"deployment": deployment_id.name, "application": deployment_id.app}
)

+        self.num_queries_sent_to_replicas = defaultdict(int)
+        self._queries_lock = threading.Lock()
self.num_queued_queries = 0
self.num_queued_queries_gauge = metrics.Gauge(
"serve_deployment_queued_queries",
@@ -991,7 +999,7 @@ def __init__(
(
LongPollNamespace.RUNNING_REPLICAS,
deployment_id,
-                ): self._replica_scheduler.update_running_replicas,
+                ): self.update_running_replicas,
(
LongPollNamespace.AUTOSCALING_CONFIG,
deployment_id,
@@ -1001,10 +1009,20 @@
)

self.metrics_pusher = MetricsPusher()
+        self.metrics_store = InMemoryMetricsStore()
+        self.autoscaling_config = None
self.push_metrics_to_controller = controller_handle.record_handle_metrics.remote

-    def update_autoscaling_config(self, autoscaling_config):
+    def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]):
+        # Prune replicas that are no longer running from the query counter.
+        # Iterate over a copy of the keys, since entries are deleted as we go.
+        running_replica_set = {replica.replica_tag for replica in running_replicas}
+        for replica_id in list(self.num_queries_sent_to_replicas):
+            if replica_id not in running_replica_set:
+                del self.num_queries_sent_to_replicas[replica_id]
+
+        self._replica_scheduler.update_running_replicas(running_replicas)
+
+    def update_autoscaling_config(self, autoscaling_config: AutoscalingConfig):
self.autoscaling_config = autoscaling_config

# Start the metrics pusher if autoscaling is enabled.
@@ -1018,23 +1036,65 @@ def update_autoscaling_config(self, autoscaling_config):
and self.num_queued_queries
):
self.push_metrics_to_controller(
-                    self._collect_handle_queue_metrics(), time.time()
+                    self._get_aggregated_requests(), time.time()
)

-            self.metrics_pusher.register_task(
-                PUSH_METRICS_TO_CONTROLLER_TASK_NAME,
-                self._collect_handle_queue_metrics,
-                HANDLE_METRIC_PUSH_INTERVAL_S,
-                self.push_metrics_to_controller,
-            )

+            if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE:
+                # Record number of queued + ongoing requests at regular
+                # intervals into the in-memory metrics store.
+                self.metrics_pusher.register_task(
+                    RECORD_METRICS_TASK_NAME,
+                    self._get_num_requests_for_autoscaling,
+                    min(0.5, self.autoscaling_config.metrics_interval_s),
+                    self._add_autoscaling_metrics_point,
+                )
+                # Regularly push aggregated metrics to the controller.
+                self.metrics_pusher.register_task(
+                    PUSH_METRICS_TO_CONTROLLER_TASK_NAME,
+                    self._get_aggregated_requests,
+                    self.autoscaling_config.metrics_interval_s,
+                    self.push_metrics_to_controller,
+                )
+            else:
+                self.metrics_pusher.register_task(
+                    PUSH_METRICS_TO_CONTROLLER_TASK_NAME,
+                    self._get_aggregated_requests,
+                    HANDLE_METRIC_PUSH_INTERVAL_S,
+                    self.push_metrics_to_controller,
+                )
self.metrics_pusher.start()
else:
if self.metrics_pusher:
self.metrics_pusher.shutdown()

-    def _collect_handle_queue_metrics(self) -> Dict[str, int]:
-        return (self.deployment_id, self.handle_id), self.num_queued_queries
+    def _get_num_requests_for_autoscaling(self) -> int:
+        if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE:
+            # Queued requests plus requests already sent out to replicas.
+            return self.num_queued_queries + sum(
+                self.num_queries_sent_to_replicas.values()
+            )
+        else:
+            return self.num_queued_queries
+
+    def _add_autoscaling_metrics_point(self, data: int, send_timestamp: float):
+        self.metrics_store.add_metrics_point({self.deployment_id: data}, send_timestamp)
+
+    def _get_aggregated_requests(self):
+        if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE:
+            look_back_period = self.autoscaling_config.look_back_period_s
+            window_avg = self.metrics_store.window_average(
+                self.deployment_id, time.time() - look_back_period
+            )
+            # Fall back to the instantaneous count if the window is empty.
+            data = window_avg or self.num_queued_queries + sum(
+                self.num_queries_sent_to_replicas.values()
+            )
+            return (self.deployment_id, self.handle_id), data
+        else:
+            return (self.deployment_id, self.handle_id), self.num_queued_queries
+
+    def process_finished_request(self, replica_tag, *args):
+        with self._queries_lock:
+            self.num_queries_sent_to_replicas[replica_tag] -= 1
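The `window_avg or ...` fallback in _get_aggregated_requests relies on InMemoryMetricsStore.window_average returning None when no points fall inside the look-back window. A minimal sketch of that behavior (illustrative, not the real store):

```python
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

class WindowedStore:
    """Records (timestamp, value) points per key; averages a recent window."""

    def __init__(self) -> None:
        self._points: Dict[str, List[Tuple[float, float]]] = defaultdict(list)

    def add_metrics_point(self, data: Dict[str, float], timestamp: float) -> None:
        for key, value in data.items():
            self._points[key].append((timestamp, value))

    def window_average(self, key: str, window_start: float) -> Optional[float]:
        recent = [v for t, v in self._points[key] if t >= window_start]
        if not recent:
            return None  # caller falls back to the instantaneous count
        return sum(recent) / len(recent)

store = WindowedStore()
store.add_metrics_point({"my_deployment": 3.0}, time.time())
assert store.window_average("my_deployment", time.time() - 10.0) == 3.0
```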

async def assign_request(
self,
@@ -1061,7 +1121,7 @@ async def assign_request(
and self.num_queued_queries == 1
):
self.push_metrics_to_controller(
-                    self._collect_handle_queue_metrics(), time.time()
+                    self._get_aggregated_requests(), time.time()
)

try:
Expand All @@ -1071,7 +1131,18 @@ async def assign_request(
metadata=request_meta,
)
await query.replace_known_types_in_args()
-            return await self._replica_scheduler.assign_replica(query)
+            ref, replica_tag = await self._replica_scheduler.assign_replica(query)
+
+            # Keep track of requests that have been sent out to replicas.
+            if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE:
+                self.num_queries_sent_to_replicas[replica_tag] += 1
+                callback = partial(self.process_finished_request, replica_tag)
+                if isinstance(ref, ray.ObjectRef):
+                    ref._on_completed(callback)
+                else:
+                    ref.completed()._on_completed(callback)
+
+            return ref
finally:
# If the query is disconnected before assignment, this coroutine
# gets cancelled by the caller and an asyncio.CancelledError is
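The bookkeeping in assign_request pairs an increment at send time with a decrement when the returned reference completes. Since `_on_completed` is a private Ray hook, this sketch substitutes asyncio's add_done_callback to show the same pattern in a self-contained way:

```python
import asyncio
from collections import defaultdict
from functools import partial

num_in_flight = defaultdict(int)

def _finished(replica_id: str, _task: asyncio.Task) -> None:
    # Counterpart of process_finished_request: decrement on completion.
    num_in_flight[replica_id] -= 1

def send(replica_id: str, coro) -> asyncio.Task:
    # Counterpart of assign_request's bookkeeping: increment on send and
    # register a callback that undoes it once the request finishes.
    num_in_flight[replica_id] += 1
    task = asyncio.ensure_future(coro)
    task.add_done_callback(partial(_finished, replica_id))
    return task

async def main() -> None:
    async def fake_request() -> None:
        await asyncio.sleep(0.01)

    task = send("replica-1", fake_request())
    assert num_in_flight["replica-1"] == 1
    await task
    await asyncio.sleep(0)  # give done-callbacks a chance to run
    assert num_in_flight["replica-1"] == 0

asyncio.run(main())
```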
10 changes: 9 additions & 1 deletion python/ray/serve/tests/BUILD
@@ -66,7 +66,6 @@ py_test_module_list(
"test_regression.py",
"test_request_timeout.py",
"test_cluster.py",
"test_autoscaling_policy.py",
"test_cancellation.py",
"test_streaming_response.py",
"test_controller_recovery.py",
@@ -85,6 +84,15 @@
deps = ["//python/ray/serve:serve_lib", ":conftest", ":common"],
)

+py_test_module_list(
+    files = [
+        "test_autoscaling_policy.py",
+    ],
+    size = "small",
+    tags = ["exclusive", "team:serve", "autoscaling"],
+    deps = ["//python/ray/serve:serve_lib", ":conftest", ":common"],
+)

py_test_module_list(
files = [
"test_gcs_failure.py",
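(test_autoscaling_policy.py moves out of the general test list into its own target tagged autoscaling, which is what lets the CI step above select it with --only-tags autoscaling and run it under both flag settings.)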
57 changes: 57 additions & 0 deletions python/ray/serve/tests/test_autoscaling_policy.py
@@ -753,6 +753,63 @@ def send_request():
assert existing_pid in pids


+@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
+@pytest.mark.skipif(
+    os.environ.get("RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE", "0") == "0",
+    reason="Only works when collecting request metrics at handles.",
+)
+def test_max_concurrent_queries_set_to_one(serve_instance):
+    signal = SignalActor.remote()
+
+    @serve.deployment(
+        max_concurrent_queries=1,
+        autoscaling_config=AutoscalingConfig(
+            min_replicas=1,
+            max_replicas=5,
+            upscale_delay_s=0.5,
+            downscale_delay_s=0.5,
+            metrics_interval_s=0.5,
+            look_back_period_s=2,
+        ),
+        ray_actor_options={"num_cpus": 0},
+    )
+    async def f():
+        await signal.wait.remote()
+        return os.getpid()
+
+    h = serve.run(f.bind())
+    check_num_replicas_eq("f", 1)
+
+    # Repeatedly (5 times):
+    # 1. Send a new request.
+    # 2. Wait for the number of waiters on the signal to increase by 1.
+    # 3. Assert the number of replicas has increased by 1.
+    refs = []
+    for i in range(5):
+        refs.append(h.remote())
+
+        def check_num_waiters(target: int):
+            assert ray.get(signal.cur_num_waiters.remote()) == target
+            return True
+
+        wait_for_condition(check_num_waiters, target=i + 1)
+        print(time.time(), f"Number of waiters on signal reached {i + 1}.")
+        check_num_replicas_eq("f", i + 1)
+        print(time.time(), f"Confirmed number of replicas is at {i + 1}.")
+
+    print(time.time(), "Releasing signal.")
+    signal.send.remote()
+
+    # Check that the returned pids are unique. This implies each replica
+    # served exactly one request, i.e. at most one request was ever
+    # "running" on a replica at a time, so the queued requests must have
+    # been taken into account for autoscaling.
+    pids = [ref.result() for ref in refs]
+    assert len(pids) == len(set(pids)), f"Pids {pids} are not unique."
+    print("Confirmed each replica only served one request.")


@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
def test_autoscaling_status_changes(serve_instance):
"""Test status changes when autoscaling deployments are deployed.
