[serve] safe draining
Implement safe draining.

Signed-off-by: Cindy Zhang <cindyzyx9@gmail.com>
zcin committed Feb 18, 2024
1 parent b879e2c commit 476cee1
Showing 5 changed files with 233 additions and 39 deletions.
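For context, the reconciliation this commit adds works roughly as follows: a replica on a draining node is first moved into the new TO_BE_STOPPED state, and graceful shutdown begins only once a replacement replica is running or the node's drain deadline (minus the replica's graceful shutdown timeout) is about to pass. A minimal sketch of that decision, using illustrative names that are not part of the Ray Serve API:

import time
from typing import Dict


def should_begin_graceful_stop(
    node_id: str,
    draining_nodes: Dict[str, int],  # node ID -> drain deadline (ms since epoch); 0 = no deadline
    graceful_shutdown_timeout_s: float,
    replacement_is_running: bool,
) -> bool:
    """Illustrative helper only; the real logic lives in
    DeploymentState._stop_replicas_on_draining_nodes."""
    # Simplified: the real code compares counts of RUNNING vs. TO_BE_STOPPED
    # replicas instead of tracking a single replacement flag.
    if replacement_is_running:
        return True

    deadline_ms = draining_nodes.get(node_id, 0)
    if deadline_ms <= 0:
        # No deadline: keep serving until a replacement comes up.
        return False

    # Begin stopping early enough that the full graceful shutdown
    # timeout can elapse before the node's drain deadline.
    return time.time() * 1000 >= deadline_ms - graceful_shutdown_timeout_s * 1000

This mirrors the two passes added to _stop_replicas_on_draining_nodes below: one keyed on the drain deadline and one on the count of running replicas.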
10 changes: 5 additions & 5 deletions python/ray/serve/_private/cluster_node_info_cache.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple

import ray
from ray._raylet import GcsClient
@@ -53,7 +53,7 @@ def get_alive_node_ids(self) -> Set[str]:
return {node_id for node_id, _ in self.get_alive_nodes()}

@abstractmethod
def get_draining_node_ids(self) -> Set[str]:
def get_draining_node_ids(self) -> Dict[str, int]:
"""Get IDs of all draining nodes in the cluster."""
raise NotImplementedError

@@ -67,15 +67,15 @@ def get_active_node_ids(self) -> Set[str]:
A node is active if it's schedulable for new tasks and actors.
"""
return self.get_alive_node_ids() - self.get_draining_node_ids()
return self.get_alive_node_ids() - set(self.get_draining_node_ids())


class DefaultClusterNodeInfoCache(ClusterNodeInfoCache):
def __init__(self, gcs_client: GcsClient):
super().__init__(gcs_client)

def get_draining_node_ids(self) -> Set[str]:
return set()
def get_draining_node_ids(self) -> Dict[str, int]:
return dict()

def get_node_az(self, node_id: str) -> Optional[str]:
"""Get availability zone of a node."""
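With this change the cache reports draining nodes as a mapping from node ID to drain deadline in milliseconds since the epoch (0 meaning no deadline), instead of a bare set of IDs. A small sketch of how callers consume the new shape; the values here are made up for illustration:

from typing import Dict, Set

# Hypothetical return value of get_draining_node_ids() under the new signature.
draining_node_ids: Dict[str, int] = {
    "node-2": 1_708_300_000_000,  # drain deadline in ms since epoch
    "node-3": 0,                  # draining with no deadline
}
alive_node_ids: Set[str] = {"node-1", "node-2", "node-3"}

# Callers that only need membership take the dict's keys, which is what
# get_active_node_ids() and the controller's proxy-node update now do.
active_node_ids = alive_node_ids - set(draining_node_ids)
assert active_node_ids == {"node-1"}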
1 change: 1 addition & 0 deletions python/ray/serve/_private/common.py
@@ -61,6 +61,7 @@ class ReplicaState(str, Enum):
RECOVERING = "RECOVERING"
RUNNING = "RUNNING"
STOPPING = "STOPPING"
TO_BE_STOPPED = "TO_BE_STOPPED"


class ApplicationStatus(str, Enum):
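The new TO_BE_STOPPED state marks a replica that sits on a draining node and is slated to stop, but whose graceful shutdown has not started yet. A rough sketch of the path such a replica takes under this change (illustrative, not the full state machine):

from enum import Enum


class ReplicaState(str, Enum):
    STARTING = "STARTING"
    UPDATING = "UPDATING"
    RECOVERING = "RECOVERING"
    RUNNING = "RUNNING"
    STOPPING = "STOPPING"
    TO_BE_STOPPED = "TO_BE_STOPPED"


# Typical path for a replica on a draining node under safe draining
# (other transitions omitted):
DRAINING_PATH = [
    ReplicaState.RUNNING,        # serving traffic on the draining node
    ReplicaState.TO_BE_STOPPED,  # marked to stop; waits for a replacement
                                 # replica or for the drain deadline to near
    ReplicaState.STOPPING,       # graceful shutdown in progress
]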
4 changes: 2 additions & 2 deletions python/ray/serve/_private/controller.py
@@ -356,8 +356,8 @@ def _update_proxy_nodes(self):
(head node and nodes with deployment replicas).
"""
new_proxy_nodes = self.deployment_state_manager.get_active_node_ids()
new_proxy_nodes = (
new_proxy_nodes - self.cluster_node_info_cache.get_draining_node_ids()
new_proxy_nodes = new_proxy_nodes - set(
self.cluster_node_info_cache.get_draining_node_ids()
)
new_proxy_nodes.add(self._controller_node_id)
self._proxy_nodes = new_proxy_nodes
34 changes: 32 additions & 2 deletions python/ray/serve/_private/deployment_state.py
@@ -2250,7 +2250,7 @@ def _check_and_update_replicas(self):
def _stop_replicas_on_draining_nodes(self):
draining_nodes = self._cluster_node_info_cache.get_draining_node_ids()
for replica in self._replicas.pop(
states=[ReplicaState.UPDATING, ReplicaState.RUNNING]
states=[ReplicaState.UPDATING, ReplicaState.RUNNING, ReplicaState.STARTING]
):
if replica.actor_node_id in draining_nodes:
state = replica._actor_details.state
@@ -2259,10 +2259,40 @@ def _stop_replicas_on_draining_nodes(self):
f"of deployment '{self.deployment_name}' in application "
f"'{self.app_name}' on draining node {replica.actor_node_id}."
)
self._stop_replica(replica, graceful_stop=True)
self._replicas.add(ReplicaState.TO_BE_STOPPED, replica)
else:
self._replicas.add(replica.actor_details.state, replica)

# Stop replicas whose drain deadline (minus the graceful shutdown
# timeout) has been reached.
for replica in self._replicas.pop(states=[ReplicaState.TO_BE_STOPPED]):
current_timestamp_ms = time.time() * 1000
timeout_ms = replica._actor.graceful_shutdown_timeout_s * 1000
if (
replica.actor_node_id in draining_nodes
and draining_nodes[replica.actor_node_id] > 0
and current_timestamp_ms
>= draining_nodes[replica.actor_node_id] - timeout_ms
):
self._stop_replica(replica, graceful_stop=True)
else:
self._replicas.add(ReplicaState.TO_BE_STOPPED, replica)

# Stop TO_BE_STOPPED replicas for which replacement replicas are already running
# TODO(zcin): greedily choose replicas with the earliest deadlines
num_running_or_towards_running = self._replicas.count(
states=[ReplicaState.RUNNING]
)
num_to_be_stopped = self._replicas.count(states=[ReplicaState.TO_BE_STOPPED])
num_excess = num_running_or_towards_running - num_to_be_stopped

for replica in self._replicas.pop(
states=[ReplicaState.TO_BE_STOPPED], max_replicas=num_excess
):
self._stop_replica(replica, graceful_stop=True)

def update(self) -> DeploymentStateUpdateResult:
"""Attempts to reconcile this deployment to match its goal state.
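The deadline check above begins a TO_BE_STOPPED replica's graceful shutdown once the current time passes the drain deadline minus the graceful shutdown timeout, so the whole shutdown window fits before the node is reclaimed. A small worked example assuming a deadline 60 s after t=0 and a 20 s timeout, matching the unit tests below:

# Drain deadline 60 s after t=0 and a 20 s graceful shutdown timeout
# (the values used by the tests below).
deadline_ms = 60 * 1000
graceful_shutdown_timeout_ms = 20 * 1000

# Graceful shutdown should begin at t = 40 s so the full 20 s timeout
# can elapse before the node's drain deadline at t = 60 s.
start_stop_at_ms = deadline_ms - graceful_shutdown_timeout_ms
assert start_stop_at_ms == 40 * 1000

for now_ms in (5_000, 35_000, 40_000, 55_000):
    should_stop = now_ms >= start_stop_at_ms
    print(f"t={now_ms / 1000:>4.0f}s -> begin graceful stop: {should_stop}")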
223 changes: 193 additions & 30 deletions python/ray/serve/tests/unit/test_deployment_state.py
@@ -213,11 +213,6 @@ def start(self, deployment_info: DeploymentInfo):
self.started = True

def _on_scheduled_stub(*args, **kwargs):
print(
f"ReplicaSchedulingRequest.on_scheduled was invoked with:\n"
f"args={args}\n"
f"kwargs={kwargs}"
)
pass

return ReplicaSchedulingRequest(
@@ -1308,29 +1303,103 @@ def test_deploy_new_config_new_version(mock_deployment_state_manager_full):
)


def test_stop_replicas_on_draining_nodes(mock_deployment_state_manager_full):
create_dsm, _, cluster_node_info_cache = mock_deployment_state_manager_full
def test_stop_replicas_on_draining_nodes_1(mock_deployment_state_manager_full):
create_dsm, timer, cluster_node_info_cache = mock_deployment_state_manager_full
dsm: DeploymentStateManager = create_dsm()
cluster_node_info_cache.draining_node_ids = {"node-1", "node-2"}
timer.reset(0)

b_info_1, v1 = deployment_info(num_replicas=2, version="1")
updated = dsm.deploy(TEST_DEPLOYMENT_ID, b_info_1)
assert updated
b_info_1, v1 = deployment_info(
num_replicas=2, graceful_shutdown_timeout_s=20, version="1"
)
assert dsm.deploy(TEST_DEPLOYMENT_ID, b_info_1)
ds = dsm._deployment_states[TEST_DEPLOYMENT_ID]

dsm.update()
check_counts(ds, total=2, by_state=[(ReplicaState.STARTING, 2, v1)])
assert ds.curr_status_info.status == DeploymentStatus.UPDATING
assert (
ds.curr_status_info.status_trigger
== DeploymentStatusTrigger.CONFIG_UPDATE_STARTED
)

# Drain node-2 with a deadline at t=60s. Since the replicas are still
# starting and we don't know the actor node ID yet, nothing happens.
cluster_node_info_cache.draining_node_ids = {"node-2": 60 * 1000}
dsm.update()
check_counts(ds, total=2, by_state=[(ReplicaState.STARTING, 2, v1)])

one_replica, another_replica = ds._replicas.get()

one_replica._actor.set_node_id("node-1")
one_replica._actor.set_ready()

another_replica._actor.set_node_id("node-2")
another_replica._actor.set_ready()

# Try to start a new replica before initiating the graceful stop
# process for the replica on the draining node
dsm.update()
check_counts(
ds,
total=3,
by_state=[
(ReplicaState.RUNNING, 1, v1),
(ReplicaState.TO_BE_STOPPED, 1, v1),
(ReplicaState.STARTING, 1, v1),
],
)

# Drain node-2.
cluster_node_info_cache.draining_node_ids = {"node-2"}
# 5 seconds later, the new replica hasn't started yet. The replica on
# the draining node should not start graceful termination yet.
timer.advance(5)
dsm.update()
check_counts(
ds,
total=3,
by_state=[
(ReplicaState.RUNNING, 1, v1),
(ReplicaState.TO_BE_STOPPED, 1, v1),
(ReplicaState.STARTING, 1, v1),
],
)

# Since the replicas are still starting and we don't know the actor node id
# yet so nothing happens
# Simulate that it took 5 more seconds for the new replica to start.
timer.advance(5)
ds._replicas.get([ReplicaState.STARTING])[0]._actor.set_ready()
dsm.update()
check_counts(
ds,
total=3,
by_state=[
(ReplicaState.RUNNING, 2, v1),
(ReplicaState.STOPPING, 1, v1),
],
)

# After the replica on the draining node stops, the deployment is
# healthy with 2 running replicas.
another_replica._actor.set_done_stopping()
dsm.update()
check_counts(
ds,
total=2,
by_state=[(ReplicaState.RUNNING, 2, v1)],
)
assert ds.curr_status_info.status == DeploymentStatus.HEALTHY


def test_stop_replicas_on_draining_nodes_2(mock_deployment_state_manager_full):
create_dsm, timer, cluster_node_info_cache = mock_deployment_state_manager_full
dsm: DeploymentStateManager = create_dsm()
timer.reset(0)

b_info_1, v1 = deployment_info(
num_replicas=2, graceful_shutdown_timeout_s=20, version="1"
)
assert dsm.deploy(TEST_DEPLOYMENT_ID, b_info_1)
ds = dsm._deployment_states[TEST_DEPLOYMENT_ID]

dsm.update()
check_counts(ds, total=2, by_state=[(ReplicaState.STARTING, 2, v1)])

# Drain node-2 with a deadline at t=60s. Since the replicas are still
# starting and we don't know the actor node ID yet, nothing happens.
cluster_node_info_cache.draining_node_ids = {"node-2": 60 * 1000}
dsm.update()
check_counts(ds, total=2, by_state=[(ReplicaState.STARTING, 2, v1)])

@@ -1342,9 +1411,8 @@ def test_stop_replicas_on_draining_nodes(mock_deployment_state_manager_full):
another_replica._actor.set_node_id("node-2")
another_replica._actor.set_ready()

# The replica running on node-2 will be drained.
# Simultaneously, a new replica will start to satisfy the target
# number of replicas.
# Try to start a new replica before initiating the graceful stop
# process for the replica on the draining node
dsm.update()
if RAY_SERVE_STOP_FULLY_THEN_START_REPLICAS:
check_counts(
@@ -1361,19 +1429,27 @@
total=3,
by_state=[
(ReplicaState.RUNNING, 1, v1),
(ReplicaState.STOPPING, 1, v1),
(ReplicaState.TO_BE_STOPPED, 1, v1),
(ReplicaState.STARTING, 1, v1),
],
)

# A new node is started.
cluster_node_info_cache.alive_node_ids = {
"node-1",
"node-2",
"node-3",
}
# Simulate that the new replica still hasn't started after 40 seconds.
# The replica on node-2 should start graceful termination even though
# a replacement hasn't come up yet (deadline 60s - timeout 20s = 40s).
timer.advance(40)
dsm.update()
check_counts(
ds,
total=3,
by_state=[
(ReplicaState.RUNNING, 1, v1),
(ReplicaState.STOPPING, 1, v1),
(ReplicaState.STARTING, 1, v1),
],
)

# The draining replica is stopped.
# Mark replica as finished stopping.
another_replica._actor.set_done_stopping()
dsm.update()
check_counts(
@@ -1382,6 +1458,93 @@
by_state=[(ReplicaState.STARTING, 1, v1), (ReplicaState.RUNNING, 1, v1)],
)

# 5 seconds later, the replica finally starts.
timer.advance(5)
ds._replicas.get([ReplicaState.STARTING])[0]._actor.set_ready()
dsm.update()
check_counts(ds, total=2, by_state=[(ReplicaState.RUNNING, 2, v1)])
assert ds.curr_status_info.status == DeploymentStatus.HEALTHY


def test_stop_replicas_on_draining_nodes_3(mock_deployment_state_manager_full):
create_dsm, timer, cluster_node_info_cache = mock_deployment_state_manager_full
dsm: DeploymentStateManager = create_dsm()
timer.reset(0)

b_info_1, v1 = deployment_info(
num_replicas=2, graceful_shutdown_timeout_s=20, version="1"
)
assert dsm.deploy(TEST_DEPLOYMENT_ID, b_info_1)
ds = dsm._deployment_states[TEST_DEPLOYMENT_ID]

dsm.update()
check_counts(ds, total=2, by_state=[(ReplicaState.STARTING, 2, v1)])

# Drain node-2 with no deadline (0). Since the replicas are still
# starting and we don't know the actor node ID yet, nothing happens.
cluster_node_info_cache.draining_node_ids = {"node-2": 0}
dsm.update()
check_counts(ds, total=2, by_state=[(ReplicaState.STARTING, 2, v1)])

one_replica, another_replica = ds._replicas.get()

one_replica._actor.set_node_id("node-1")
one_replica._actor.set_ready()

another_replica._actor.set_node_id("node-2")
another_replica._actor.set_ready()

# Try to start a new replica before initiating the graceful stop
# process for the replica on the draining node
dsm.update()
check_counts(
ds,
total=3,
by_state=[
(ReplicaState.RUNNING, 1, v1),
(ReplicaState.TO_BE_STOPPED, 1, v1),
(ReplicaState.STARTING, 1, v1),
],
)

# Simulate that the new replica still hasn't started after a very long
# time. Since node-2 has no drain deadline, the replica on node-2
# should not start graceful termination.
timer.advance(1000000)
dsm.update()
check_counts(
ds,
total=3,
by_state=[
(ReplicaState.RUNNING, 1, v1),
(ReplicaState.TO_BE_STOPPED, 1, v1),
(ReplicaState.STARTING, 1, v1),
],
)

# Simulate that it took 5 more seconds for the new replica to start.
timer.advance(5)
ds._replicas.get([ReplicaState.STARTING])[0]._actor.set_ready()
dsm.update()
check_counts(
ds,
total=3,
by_state=[
(ReplicaState.RUNNING, 2, v1),
(ReplicaState.STOPPING, 1, v1),
],
)

# After the replica on the draining node stops, the deployment is
# healthy with 2 running replicas.
another_replica._actor.set_done_stopping()
dsm.update()
check_counts(
ds,
total=2,
by_state=[(ReplicaState.RUNNING, 2, v1)],
)
assert ds.curr_status_info.status == DeploymentStatus.HEALTHY


def test_initial_deploy_no_throttling(mock_deployment_state_manager_full):
# All replicas should be started at once for a new deployment.
Expand Down

0 comments on commit 476cee1

Please sign in to comment.