[serve] Change stopping behavior #43187

Merged
1 commit merged on Feb 21, 2024

6 changes: 6 additions & 0 deletions python/ray/serve/_private/constants.py
@@ -306,3 +306,9 @@
("grpc.max_send_message_length", RAY_SERVE_GRPC_MAX_MESSAGE_SIZE),
("grpc.max_receive_message_length", RAY_SERVE_GRPC_MAX_MESSAGE_SIZE),
]

# Feature flag to eagerly start replacement replicas. This means new
# replicas will start before waiting for old replicas to fully stop.
RAY_SERVE_EAGERLY_START_REPLACEMENT_REPLICAS = (
os.environ.get("RAY_SERVE_EAGERLY_START_REPLACEMENT_REPLICAS", "1") == "1"
)
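
As a hedged illustration (not part of this diff): the flag follows Serve's usual environment-variable feature-flag pattern and is read once when constants.py is imported, so it has to be set in the environment of whichever process imports that module. A minimal sketch of opting back into the old stop-fully-then-start behavior, assuming the variable is set before the import happens:

    import os

    # Assumption: this runs before the constants module is imported by the
    # relevant process; "0" opts back into waiting for old replicas to fully
    # stop before starting replacements.
    os.environ["RAY_SERVE_EAGERLY_START_REPLACEMENT_REPLICAS"] = "0"

    from ray.serve._private.constants import (
        RAY_SERVE_EAGERLY_START_REPLACEMENT_REPLICAS,
    )

    assert RAY_SERVE_EAGERLY_START_REPLACEMENT_REPLICAS is False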
14 changes: 7 additions & 7 deletions python/ray/serve/_private/deployment_state.py
@@ -36,6 +36,7 @@
from ray.serve._private.constants import (
MAX_DEPLOYMENT_CONSTRUCTOR_RETRY_COUNT,
RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE,
RAY_SERVE_EAGERLY_START_REPLACEMENT_REPLICAS,
RAY_SERVE_FORCE_STOP_UNHEALTHY_REPLICAS,
REPLICA_HEALTH_CHECK_UNHEALTHY_THRESHOLD,
SERVE_LOGGER_NAME,
@@ -1879,13 +1880,12 @@ def _scale_deployment_replicas(
return (upscale, downscale)

elif delta_replicas > 0:
# Don't ever exceed self._target_state.target_num_replicas.
stopping_replicas = self._replicas.count(
states=[
ReplicaState.STOPPING,
]
)
to_add = max(delta_replicas - stopping_replicas, 0)
to_add = delta_replicas
if not RAY_SERVE_EAGERLY_START_REPLACEMENT_REPLICAS:
# Don't ever exceed target_num_replicas.
stopping_replicas = self._replicas.count(states=[ReplicaState.STOPPING])
to_add = max(delta_replicas - stopping_replicas, 0)

if to_add > 0:
# Exponential backoff
failed_to_start_threshold = min(
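To make the difference concrete, here is a small standalone sketch of the two to_add computations from the hunk above, with made-up numbers (this is not Ray code):

    # Hypothetical situation: the target calls for 3 more replicas while 2 old
    # replicas are still STOPPING.
    delta_replicas = 3
    stopping_replicas = 2

    # Eager (new default) behavior: start every replacement immediately, so the
    # deployment may temporarily exceed target_num_replicas.
    to_add_eager = delta_replicas

    # Old behavior: discount replicas that are still stopping so the total
    # never exceeds target_num_replicas.
    to_add_old = max(delta_replicas - stopping_replicas, 0)

    assert to_add_eager == 3
    assert to_add_old == 1
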
52 changes: 48 additions & 4 deletions python/ray/serve/_private/test_utils.py
@@ -1,7 +1,7 @@
import asyncio
import threading
import time
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple

import grpc
import pytest
@@ -14,6 +14,7 @@
from ray.actor import ActorHandle
from ray.serve._private.common import DeploymentID, DeploymentStatus
from ray.serve._private.constants import SERVE_DEFAULT_APP_NAME, SERVE_NAMESPACE
from ray.serve._private.deployment_state import ALL_REPLICA_STATES, ReplicaState
from ray.serve._private.proxy import DRAINING_MESSAGE
from ray.serve._private.usage import ServeUsageTag
from ray.serve._private.utils import TimerBase
@@ -125,7 +126,7 @@ def check_num_replicas_gte(
) -> int:
"""Check if num replicas is >= target."""

assert get_num_running_replicas(name) >= target
assert get_num_running_replicas(name, app_name) >= target
return True


@@ -134,7 +135,7 @@ def check_num_replicas_eq(
) -> int:
"""Check if num replicas is == target."""

assert get_num_running_replicas(name) == target
assert get_num_running_replicas(name, app_name) == target
return True


@@ -143,7 +144,50 @@ def check_num_replicas_lte(
) -> int:
"""Check if num replicas is <= target."""

assert get_num_running_replicas(name) <= target
assert get_num_running_replicas(name, app_name) <= target
return True


def check_replica_counts(
controller: ActorHandle,
deployment_id: DeploymentID,
total: Optional[int] = None,
by_state: Optional[List[Tuple[ReplicaState, int, Callable]]] = None,
):
"""Uses _dump_replica_states_for_testing to check replica counts.

Args:
controller: A handle to the Serve controller.
deployment_id: The deployment to check replica counts for.
total: The total number of expected replicas for the deployment.
by_state: A list of tuples of the form
(replica state, number of replicas, filter function).
Used for more fine grained checks.
"""
replicas = ray.get(
controller._dump_replica_states_for_testing.remote(deployment_id)
)

if total is not None:
replica_counts = {
state: len(replicas.get([state]))
for state in ALL_REPLICA_STATES
if replicas.get([state])
}
assert replicas.count() == total, replica_counts

if by_state is not None:
for state, count, check in by_state:
assert isinstance(state, ReplicaState)
assert isinstance(count, int) and count >= 0
if check:
filtered = {r for r in replicas.get(states=[state]) if check(r)}
curr_count = len(filtered)
else:
curr_count = replicas.count(states=[state])
msg = f"Expected {count} for state {state} but got {curr_count}."
assert curr_count == count, msg

return True


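A hedged usage sketch for the new check_replica_counts helper, mirroring how the recovery test below uses it; the deployment name, app name, expected counts, and the controller handle are placeholders:

    from ray._private.test_utils import wait_for_condition
    from ray.serve._private.common import DeploymentID, ReplicaState
    from ray.serve._private.test_utils import check_replica_counts

    # `controller` would be a Serve controller actor handle, e.g.
    # serve_instance._controller in these tests.
    wait_for_condition(
        check_replica_counts,
        controller=controller,
        deployment_id=DeploymentID("test", "app"),
        total=2,
        by_state=[(ReplicaState.RUNNING, 2, None)],
    )
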
16 changes: 16 additions & 0 deletions python/ray/serve/tests/BUILD
@@ -202,3 +202,19 @@ py_test_module_list(
deps = ["//python/ray/serve:serve_lib", ":conftest", ":common"],
data = glob(["test_config_files/**/*"]),
)

# Test old stop-fully-then-start behavior.
# TODO(zcin): remove this after the old behavior is completely removed
py_test_module_list(
name_suffix="_with_stop_fully_then_start_behavior",
env={"RAY_SERVE_EAGERLY_START_REPLACEMENT_REPLICAS": "0"},
files = [
"test_controller_recovery.py",
"test_deploy.py",
"test_max_replicas_per_node.py",
],
size = "medium",
tags = ["exclusive", "no_windows", "team:serve"],
deps = ["//python/ray/serve:serve_lib", ":conftest", ":common"],
data = glob(["test_config_files/**/*"]),
)
2 changes: 2 additions & 0 deletions python/ray/serve/tests/test_autoscaling_policy.py
@@ -206,6 +206,7 @@ def __call__(self):
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
@pytest.mark.parametrize("smoothing_factor", [1, 0.2])
@pytest.mark.parametrize("use_upscale_downscale_config", [True, False])
@mock.patch("ray.serve._private.router.HANDLE_METRIC_PUSH_INTERVAL_S", 1)
Contributor commented:
why's this needed?

@zcin (Contributor Author) replied on Feb 20, 2024:
In this test (for RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE=0) when the deployment scales back down to 0, the last metric report pushed from the handle is often a non-zero number (because the push interval is set to 10 seconds). So for the next 10 seconds the # replicas keeps oscillating between 0 and 1 because of the outdated metric from the handle. I'm not sure why we never ran into this before, but this seems like a totally reasonable scenario so I set the handle push interval lower to avoid making the test wait longer.

def test_e2e_scale_up_down_with_0_replica(
serve_instance, smoothing_factor, use_upscale_downscale_config
):
@@ -395,6 +396,7 @@ def __call__(self):


@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
@mock.patch("ray.serve._private.router.HANDLE_METRIC_PUSH_INTERVAL_S", 1)
def test_e2e_intermediate_downscaling(serve_instance):
"""
Scales up, then down, and up again.
144 changes: 72 additions & 72 deletions python/ray/serve/tests/test_controller_recovery.py
@@ -4,7 +4,6 @@
import re
import sys
import time
from collections import defaultdict

import pytest
import requests
@@ -15,12 +14,13 @@
from ray.exceptions import RayTaskError
from ray.serve._private.common import DeploymentID, ReplicaState
from ray.serve._private.constants import (
RAY_SERVE_EAGERLY_START_REPLACEMENT_REPLICAS,
SERVE_CONTROLLER_NAME,
SERVE_DEFAULT_APP_NAME,
SERVE_NAMESPACE,
SERVE_PROXY_NAME,
)
from ray.serve._private.utils import get_random_string
from ray.serve._private.test_utils import check_replica_counts
from ray.serve.schema import LoggingConfig
from ray.serve.tests.test_failure import request_with_retries
from ray.util.state import list_actors
@@ -115,99 +115,99 @@ def __call__(self, *args):


def test_recover_rolling_update_from_replica_actor_names(serve_instance):
"""Test controller is able to recover starting -> updating -> running
"""Test controller can recover replicas during rolling update.

Replicas starting -> updating -> running
replicas from actor names, with right replica versions during rolling
update.
"""
client = serve_instance

name = "test"

@ray.remote(num_cpus=0)
def call(block=False):
handle = serve.get_deployment_handle(name, "app")
ret = handle.handler.remote(block).result()

return ret.split("|")[0], ret.split("|")[1]

signal_name = f"signal#{get_random_string()}"
signal = SignalActor.options(name=signal_name).remote()
signal = SignalActor.remote()

@serve.deployment(name=name, version="1", num_replicas=2)
@serve.deployment(name="test", num_replicas=2)
class V1:
async def handler(self, block: bool):
if block:
signal = ray.get_actor(signal_name)
await signal.wait.remote()

return f"1|{os.getpid()}"

async def __call__(self, request):
return await self.handler(request.query_params["block"] == "True")
async def __call__(self):
await signal.wait.remote()
return "1", os.getpid()

@serve.deployment(name="test", num_replicas=2)
class V2:
async def handler(self, *args):
return f"2|{os.getpid()}"

async def __call__(self, request):
return await self.handler()

def make_nonblocking_calls(expected, expect_blocking=False, num_returns=1):
# Returns dict[val, set(pid)].
blocking = []
responses = defaultdict(set)
start = time.time()
timeout_value = 60 if sys.platform == "win32" else 30
while time.time() - start < timeout_value:
refs = [call.remote(block=False) for _ in range(10)]
ready, not_ready = ray.wait(refs, timeout=5, num_returns=num_returns)
for ref in ready:
val, pid = ray.get(ref)
responses[val].add(pid)
for ref in not_ready:
blocking.extend(not_ready)

if all(len(responses[val]) >= num for val, num in expected.items()) and (
expect_blocking is False or len(blocking) > 0
):
break
else:
assert False, f"Timed out, responses: {responses}."
async def __call__(self):
return "2", os.getpid()

return responses, blocking
h = serve.run(V1.bind(), name="app")

serve.run(V1.bind(), name="app")
responses1, _ = make_nonblocking_calls({"1": 2}, num_returns=2)
pids1 = responses1["1"]
# Send requests to get pids of initial 2 replicas
signal.send.remote()
refs = [h.remote() for _ in range(10)]
versions, pids = zip(*[ref.result() for ref in refs])
assert versions.count("1") == 10
initial_pids = set(pids)
assert len(initial_pids) == 2

# ref2 will block a single replica until the signal is sent. Check that
# some requests are now blocking.
ref2 = call.remote(block=True)
responses2, blocking2 = make_nonblocking_calls({"1": 1}, expect_blocking=True)
assert list(responses2["1"])[0] in pids1
# blocked_ref will block a single replica until the signal is sent.
signal.send.remote(clear=True)
blocked_ref = h.remote()

# Kill the controller
ray.kill(serve.context._global_client._controller, no_restart=False)

# Redeploy new version. Since there is one replica blocking, only one new
# replica should be started up.
V2 = V1.options(func_or_class=V2, version="2")
# Redeploy new version.
serve._run(V2.bind(), _blocking=False, name="app")
with pytest.raises(TimeoutError):
client._wait_for_application_running("app", timeout_s=0.1)
responses3, blocking3 = make_nonblocking_calls({"1": 1}, expect_blocking=True)

# One replica of the old version should be stuck in stopping because
# of the blocked request.
if RAY_SERVE_EAGERLY_START_REPLACEMENT_REPLICAS:
# Two replicas of the new version should be brought up without
# waiting for the old replica to stop.
wait_for_condition(
check_replica_counts,
controller=serve_instance._controller,
deployment_id=DeploymentID("test", "app"),
total=3,
by_state=[
(ReplicaState.STOPPING, 1, lambda r: r._actor.pid in initial_pids),
(ReplicaState.RUNNING, 2, lambda r: r._actor.pid not in initial_pids),
],
)

# All new requests should be sent to the new running replicas
refs = [h.remote() for _ in range(10)]
versions, pids = zip(*[ref.result(timeout_s=5) for ref in refs])
assert versions.count("2") == 10
pids2 = set(pids)
assert len(pids2 & initial_pids) == 0
else:
with pytest.raises(TimeoutError):
serve_instance._wait_for_application_running("app", timeout_s=0.1)

refs = [h.remote() for _ in range(10)]

# Kill the controller
ray.kill(serve.context._global_client._controller, no_restart=False)

# Signal the original call to exit.
# Release the signal so that the old replica can shutdown
ray.get(signal.send.remote())
val, pid = ray.get(ref2)
val, pid = blocked_ref.result()
assert val == "1"
assert pid in responses1["1"]
assert pid in initial_pids

if not RAY_SERVE_EAGERLY_START_REPLACEMENT_REPLICAS:
versions, pids = zip(*[ref.result(timeout_s=5) for ref in refs])
assert versions.count("1") == 10
assert len(set(pids) & initial_pids)

# Now the goal and requests to the new version should complete.
# We should have two running replicas of the new version.
client._wait_for_application_running("app")
make_nonblocking_calls({"2": 2}, num_returns=2)
serve_instance._wait_for_application_running("app")
check_replica_counts(
controller=serve_instance._controller,
deployment_id=DeploymentID("test", "app"),
total=2,
by_state=(
[(ReplicaState.RUNNING, 2, lambda r: r._actor.pid not in initial_pids)]
),
)


def test_controller_recover_initializing_actor(serve_instance):