[serve] safe draining
Implement safe draining.

When we receive notification that a node is draining or will be terminated soon, start a replacement replica before gracefully terminating the replicas running on the draining node. The timing works as follows (a short code sketch follows the list):

1. If the replacement replica starts before the node's deadline minus `graceful_shutdown_timeout_s`, begin graceful termination of the old replica as soon as the replacement is running.
2. If the replacement replica takes longer to start, begin graceful termination of the old replica no later than the deadline minus `graceful_shutdown_timeout_s`.
3. If the node has no deadline, wait indefinitely for the replacement replica to start before gracefully terminating the old replica.
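The decision can be summarized in a small function. This is only an illustrative sketch of the timing rule described above, not code from this change: the function name and arguments are invented, and deadlines are assumed to be millisecond timestamps with 0 meaning "no deadline", matching how the controller code below compares them.

```python
import time


def should_begin_graceful_termination(
    replacement_running: bool,
    deadline_ms: int,  # drain deadline in ms since epoch; 0 = no deadline
    graceful_shutdown_timeout_s: float,
) -> bool:
    """Illustrative decision for one replica on a draining node."""
    if replacement_running:
        # Case 1: the replacement is up, so the old replica can drain now.
        return True
    if deadline_ms == 0:
        # Case 3: no deadline, keep waiting for the replacement.
        return False
    # Case 2: start stopping early enough that the graceful shutdown
    # window still fits before the node's deadline.
    now_ms = time.time() * 1000
    return now_ms >= deadline_ms - graceful_shutdown_timeout_s * 1000
```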

Signed-off-by: Cindy Zhang <cindyzyx9@gmail.com>
zcin committed Feb 22, 2024
1 parent 38106e3 commit 9ff9002
Showing 5 changed files with 691 additions and 211 deletions.
12 changes: 6 additions & 6 deletions python/ray/serve/_private/cluster_node_info_cache.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple

import ray
from ray._raylet import GcsClient
@@ -53,8 +53,8 @@ def get_alive_node_ids(self) -> Set[str]:
return {node_id for node_id, _ in self.get_alive_nodes()}

@abstractmethod
def get_draining_node_ids(self) -> Set[str]:
"""Get IDs of all draining nodes in the cluster."""
def get_draining_nodes(self) -> Dict[str, int]:
"""Get draining nodes in the cluster and their deadlines."""
raise NotImplementedError

@abstractmethod
@@ -67,15 +67,15 @@ def get_active_node_ids(self) -> Set[str]:
A node is active if it's schedulable for new tasks and actors.
"""
return self.get_alive_node_ids() - self.get_draining_node_ids()
return self.get_alive_node_ids() - set(self.get_draining_nodes())


class DefaultClusterNodeInfoCache(ClusterNodeInfoCache):
def __init__(self, gcs_client: GcsClient):
super().__init__(gcs_client)

def get_draining_node_ids(self) -> Set[str]:
return set()
def get_draining_nodes(self) -> Dict[str, int]:
return dict()

def get_node_az(self, node_id: str) -> Optional[str]:
"""Get availability zone of a node."""
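For context on the interface change above, here is a rough, self-contained sketch of how callers consume the new return type; the class name and node IDs are made up and are not part of this commit. `get_draining_nodes()` now returns a mapping from node ID to drain deadline, and code that only needs the IDs wraps it in `set(...)`, which keeps the keys.

```python
from typing import Dict, Set


class FakeNodeInfoCache:
    """Hypothetical stand-in mirroring the updated ClusterNodeInfoCache API."""

    def get_alive_node_ids(self) -> Set[str]:
        return {"head-node", "worker-1", "worker-2"}

    def get_draining_nodes(self) -> Dict[str, int]:
        # node_id -> drain deadline in ms since epoch (0 would mean no deadline).
        return {"worker-2": 1_708_500_000_000}


cache = FakeNodeInfoCache()
# Same pattern as get_active_node_ids() and _update_proxy_nodes() below:
active = cache.get_alive_node_ids() - set(cache.get_draining_nodes())
assert active == {"head-node", "worker-1"}
```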
1 change: 1 addition & 0 deletions python/ray/serve/_private/common.py
@@ -61,6 +61,7 @@ class ReplicaState(str, Enum):
RECOVERING = "RECOVERING"
RUNNING = "RUNNING"
STOPPING = "STOPPING"
PENDING_MIGRATION = "PENDING_MIGRATION"


class ApplicationStatus(str, Enum):
4 changes: 2 additions & 2 deletions python/ray/serve/_private/controller.py
@@ -356,8 +356,8 @@ def _update_proxy_nodes(self):
(head node and nodes with deployment replicas).
"""
new_proxy_nodes = self.deployment_state_manager.get_active_node_ids()
new_proxy_nodes = (
new_proxy_nodes - self.cluster_node_info_cache.get_draining_node_ids()
new_proxy_nodes = new_proxy_nodes - set(
self.cluster_node_info_cache.get_draining_nodes()
)
new_proxy_nodes.add(self._controller_node_id)
self._proxy_nodes = new_proxy_nodes
130 changes: 111 additions & 19 deletions python/ray/serve/_private/deployment_state.py
@@ -1377,7 +1377,9 @@ def app_name(self) -> str:
def get_running_replica_infos(self) -> List[RunningReplicaInfo]:
return [
replica.get_running_replica_info(self._cluster_node_info_cache)
for replica in self._replicas.get([ReplicaState.RUNNING])
for replica in self._replicas.get(
[ReplicaState.RUNNING, ReplicaState.PENDING_MIGRATION]
)
]

def get_active_node_ids(self) -> Set[str]:
@@ -1391,6 +1393,9 @@ def get_active_node_ids(self) -> Set[str]:
ReplicaState.UPDATING,
ReplicaState.RECOVERING,
ReplicaState.RUNNING,
# NOTE(zcin): We still want a proxy to run on a draining
# node until all of its replicas have been migrated.
ReplicaState.PENDING_MIGRATION,
]
return {
replica.actor_node_id
@@ -1629,7 +1634,9 @@ def get_total_num_requests(self) -> float:
"""

total_requests = 0
running_replicas = self._replicas.get([ReplicaState.RUNNING])
running_replicas = self._replicas.get(
[ReplicaState.RUNNING, ReplicaState.PENDING_MIGRATION]
)

if (
RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE
@@ -1745,7 +1752,11 @@ def _stop_or_update_outdated_version_replicas(self, max_to_stop=math.inf) -> bool:
"""
replicas_to_update = self._replicas.pop(
exclude_version=self._target_state.version,
states=[ReplicaState.STARTING, ReplicaState.RUNNING],
states=[
ReplicaState.STARTING,
ReplicaState.PENDING_MIGRATION,
ReplicaState.RUNNING,
],
)
replicas_changed = False
code_version_changes = 0
@@ -1781,7 +1792,7 @@ def _stop_or_update_outdated_version_replicas(self, max_to_stop=math.inf) -> bool:
f"{replica.replica_tag}, deployment_name: {self.deployment_name}, "
f"app_name: {self.app_name}"
)
# We don't allow going from STARTING to UPDATING.
# We don't allow going from STARTING or PENDING_MIGRATION to UPDATING.
else:
self._replicas.add(replica.actor_details.state, replica)

@@ -1821,7 +1832,11 @@ def _check_and_stop_outdated_version_replicas(self) -> bool:
# terminate them and start new version replicas instead.
old_running_replicas = self._replicas.count(
exclude_version=self._target_state.version,
states=[ReplicaState.STARTING, ReplicaState.UPDATING, ReplicaState.RUNNING],
states=[
ReplicaState.STARTING,
ReplicaState.UPDATING,
ReplicaState.RUNNING,
],
)
old_stopping_replicas = self._replicas.count(
exclude_version=self._target_state.version, states=[ReplicaState.STOPPING]
@@ -2133,9 +2148,11 @@ def _check_and_update_replicas(self):
transition happened.
"""

for replica in self._replicas.pop(states=[ReplicaState.RUNNING]):
for replica in self._replicas.pop(
states=[ReplicaState.RUNNING, ReplicaState.PENDING_MIGRATION]
):
if replica.check_health():
self._replicas.add(ReplicaState.RUNNING, replica)
self._replicas.add(replica.actor_details.state, replica)
self.health_check_gauge.set(
1,
tags={
@@ -2247,22 +2264,97 @@ def _check_and_update_replicas(self):
if replica.replica_tag in self.replica_average_ongoing_requests:
del self.replica_average_ongoing_requests[replica.replica_tag]

def _stop_replicas_on_draining_nodes(self):
draining_nodes = self._cluster_node_info_cache.get_draining_node_ids()
def _choose_pending_migration_replicas_to_stop(
self,
replicas: List[DeploymentReplica],
deadlines: Dict[str, int],
min_replicas_to_stop: int,
) -> Tuple[List[DeploymentReplica], List[DeploymentReplica]]:
"""Returns a partition of replicas to stop and to keep.
Args:
replicas: The current list of replicas pending migration.
deadlines: The current draining node deadlines.
min_replicas_to_stop: The minimum number of replicas to stop.
"""
to_stop = list()
remaining = list()

# Stop replicas whose deadline is up
for replica in replicas:
curr_timestamp_ms = time.time() * 1000
timeout_ms = replica._actor.graceful_shutdown_timeout_s * 1000
if (
replica.actor_node_id in deadlines
and curr_timestamp_ms >= deadlines[replica.actor_node_id] - timeout_ms
):
to_stop.append(replica)
else:
remaining.append(replica)

# Stop excess PENDING_MIGRATION replicas when new "replacement"
# replicas have transitioned to RUNNING. The replicas with the
# earliest deadlines should be chosen greedily.
def order(deadline: int):
# A deadline of 0 means the node has no deadline, so sort its
# replicas last (keep them the longest).
if deadline:
return deadline
else:
return float("inf")

remaining.sort(key=lambda r: order(deadlines[r.actor_node_id]))
num_excess = min_replicas_to_stop - len(to_stop)

if num_excess > 0:
to_stop.extend(remaining[:num_excess])
remaining = remaining[num_excess:]

return to_stop, remaining

def _migrate_replicas_on_draining_nodes(self):
draining_node_deadlines = self._cluster_node_info_cache.get_draining_nodes()
for replica in self._replicas.pop(
states=[ReplicaState.UPDATING, ReplicaState.RUNNING]
states=[ReplicaState.UPDATING, ReplicaState.RUNNING, ReplicaState.STARTING]
):
if replica.actor_node_id in draining_nodes:
state = replica._actor_details.state
logger.info(
f"Stopping replica {replica.replica_tag} (currently {state}) "
f"of deployment '{self.deployment_name}' in application "
f"'{self.app_name}' on draining node {replica.actor_node_id}."
)
self._stop_replica(replica, graceful_stop=True)
if replica.actor_node_id in draining_node_deadlines:
# For RUNNING replicas, migrate them safely by starting
# a replacement replica first.
if replica.actor_details.state == ReplicaState.RUNNING:
self._replicas.add(ReplicaState.PENDING_MIGRATION, replica)
# For replicas that are STARTING or UPDATING, might as
# well terminate them immediately to allow replacement
# replicas to start. Otherwise we need to wait for them
# to transition to RUNNING before starting migration.
else:
self._stop_replica(replica, graceful_stop=True)
else:
self._replicas.add(replica.actor_details.state, replica)

num_running = self._replicas.count(states=[ReplicaState.RUNNING])
num_draining = self._replicas.count(states=[ReplicaState.PENDING_MIGRATION])
num_pending_migration_replicas_to_stop = (
num_running + num_draining - self._target_state.target_num_replicas
)

(
replicas_to_stop,
replicas_to_keep,
) = self._choose_pending_migration_replicas_to_stop(
self._replicas.pop(states=[ReplicaState.PENDING_MIGRATION]),
draining_node_deadlines,
num_pending_migration_replicas_to_stop,
)
for replica in replicas_to_stop:
logger.info(
f"Stopping replica {replica.replica_tag} "
f"of deployment '{self.deployment_name}' in application "
f"'{self.app_name}' on draining node {replica.actor_node_id}."
)
self._stop_replica(replica, graceful_stop=True)

for replica in replicas_to_keep:
self._replicas.add(ReplicaState.PENDING_MIGRATION, replica)

def update(self) -> DeploymentStateUpdateResult:
"""Attempts to reconcile this deployment to match its goal state.
@@ -2283,7 +2375,7 @@ def update(self) -> DeploymentStateUpdateResult:
# Check the state of existing replicas and transition if necessary.
self._check_and_update_replicas()

self._stop_replicas_on_draining_nodes()
self._migrate_replicas_on_draining_nodes()

upscale, downscale = self._scale_deployment_replicas()

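As a rough worked example of the reconciliation in `_migrate_replicas_on_draining_nodes` (the numbers and variable names below are invented for illustration): suppose a deployment targets 3 replicas and one of them sits on a node that starts draining.

```python
target_num_replicas = 3

# Right after the node starts draining: the affected replica moves to
# PENDING_MIGRATION and a replacement is being scheduled.
num_running, num_pending_migration = 2, 1
num_excess = num_running + num_pending_migration - target_num_replicas
assert num_excess == 0  # nothing is stopped yet; the old replica keeps serving

# Once the replacement transitions to RUNNING:
num_running = 3
num_excess = num_running + num_pending_migration - target_num_replicas
assert num_excess == 1  # one PENDING_MIGRATION replica can now be stopped
# _choose_pending_migration_replicas_to_stop would return it in `to_stop`
# (it would also be stopped regardless of the replacement once the node's
# deadline minus graceful_shutdown_timeout_s is reached).
```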
