[serve] safe draining
Implement safe draining.

Signed-off-by: Cindy Zhang <cindyzyx9@gmail.com>
zcin committed Feb 21, 2024
1 parent 38106e3 commit 483d90f
Showing 5 changed files with 414 additions and 63 deletions.
10 changes: 5 additions & 5 deletions python/ray/serve/_private/cluster_node_info_cache.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import List, Optional, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple
 
 import ray
 from ray._raylet import GcsClient
@@ -53,7 +53,7 @@ def get_alive_node_ids(self) -> Set[str]:
         return {node_id for node_id, _ in self.get_alive_nodes()}
 
     @abstractmethod
-    def get_draining_node_ids(self) -> Set[str]:
+    def get_draining_node_ids(self) -> Dict[str, int]:
         """Get IDs of all draining nodes in the cluster."""
         raise NotImplementedError
 
@@ -67,15 +67,15 @@ def get_active_node_ids(self) -> Set[str]:
         A node is active if it's schedulable for new tasks and actors.
         """
-        return self.get_alive_node_ids() - self.get_draining_node_ids()
+        return self.get_alive_node_ids() - set(self.get_draining_node_ids())
 
 
 class DefaultClusterNodeInfoCache(ClusterNodeInfoCache):
     def __init__(self, gcs_client: GcsClient):
         super().__init__(gcs_client)
 
-    def get_draining_node_ids(self) -> Set[str]:
-        return set()
+    def get_draining_node_ids(self) -> Dict[str, int]:
+        return dict()
 
     def get_node_az(self, node_id: str) -> Optional[str]:
         """Get availability zone of a node."""
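With this change, get_draining_node_ids() returns a mapping from each draining node's ID to a drain deadline (a millisecond timestamp, with 0 apparently meaning "no deadline", as used in deployment_state.py below) instead of a bare set of IDs. A minimal sketch of how a caller keeps the old set-based behavior; the node IDs and deadlines here are made up for illustration:

import time

# Hypothetical data: node ID -> drain deadline in ms since epoch (0 = no deadline).
draining_node_deadlines = {
    "node-abc": int((time.time() + 60) * 1000),  # must finish draining in ~60s
    "node-def": 0,                               # draining, but no deadline
}
alive_node_ids = {"node-abc", "node-def", "node-ghi"}

# Casting a dict to a set yields its keys, so subtracting set(...) reproduces
# the previous set-difference semantics of get_active_node_ids().
active_node_ids = alive_node_ids - set(draining_node_deadlines)
assert active_node_ids == {"node-ghi"}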
1 change: 1 addition & 0 deletions python/ray/serve/_private/common.py
@@ -61,6 +61,7 @@ class ReplicaState(str, Enum):
     RECOVERING = "RECOVERING"
     RUNNING = "RUNNING"
     STOPPING = "STOPPING"
+    DRAINING = "DRAINING"
 
 
 class ApplicationStatus(str, Enum):
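The new DRAINING member lets replicas that sit on a draining node be tracked as their own state rather than being folded into RUNNING or STOPPING. A small illustration with a hypothetical replica list and a local mirror of the enum:

from enum import Enum

# Mirrors the ReplicaState members relevant here (see the diff above).
class ReplicaState(str, Enum):
    RUNNING = "RUNNING"
    STOPPING = "STOPPING"
    DRAINING = "DRAINING"

replica_states = [ReplicaState.RUNNING, ReplicaState.DRAINING, ReplicaState.RUNNING]
num_draining = sum(s is ReplicaState.DRAINING for s in replica_states)
assert num_draining == 1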
4 changes: 2 additions & 2 deletions python/ray/serve/_private/controller.py
@@ -356,8 +356,8 @@ def _update_proxy_nodes(self):
         (head node and nodes with deployment replicas).
         """
         new_proxy_nodes = self.deployment_state_manager.get_active_node_ids()
-        new_proxy_nodes = (
-            new_proxy_nodes - self.cluster_node_info_cache.get_draining_node_ids()
+        new_proxy_nodes = new_proxy_nodes - set(
+            self.cluster_node_info_cache.get_draining_node_ids()
         )
         new_proxy_nodes.add(self._controller_node_id)
         self._proxy_nodes = new_proxy_nodes
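The controller change mirrors the cache change: the draining-node dict is cast to a set of node IDs before the subtraction, and the controller's own node is added back afterwards (so it would keep a proxy even if its node were draining). A rough illustration with made-up node IDs:

active_node_ids = {"node-head", "node-w1", "node-w2"}     # head + replica nodes
draining_node_deadlines = {"node-w2": 1_700_000_000_000}  # node-w2 is draining

new_proxy_nodes = active_node_ids - set(draining_node_deadlines)
new_proxy_nodes.add("node-head")  # the controller's node always runs a proxy
assert new_proxy_nodes == {"node-head", "node-w1"}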
56 changes: 50 additions & 6 deletions python/ray/serve/_private/deployment_state.py
@@ -2248,20 +2248,64 @@ def _check_and_update_replicas(self):
                 del self.replica_average_ongoing_requests[replica.replica_tag]
 
     def _stop_replicas_on_draining_nodes(self):
-        draining_nodes = self._cluster_node_info_cache.get_draining_node_ids()
+        draining_node_deadlines = self._cluster_node_info_cache.get_draining_node_ids()
         for replica in self._replicas.pop(
-            states=[ReplicaState.UPDATING, ReplicaState.RUNNING]
+            states=[ReplicaState.UPDATING, ReplicaState.RUNNING, ReplicaState.STARTING]
         ):
-            if replica.actor_node_id in draining_nodes:
-                state = replica._actor_details.state
+            if replica.actor_node_id in draining_node_deadlines:
+                self._replicas.add(ReplicaState.DRAINING, replica)
+            else:
+                self._replicas.add(replica.actor_details.state, replica)
+
+        # Stop replicas whose deadline is up
+        for replica in self._replicas.pop(states=[ReplicaState.DRAINING]):
+            current_timestamp_ms = time.time() * 1000
+            timeout_ms = replica._actor.graceful_shutdown_timeout_s * 1000
+            if (
+                replica.actor_node_id in draining_node_deadlines
+                and draining_node_deadlines[replica.actor_node_id] > 0
+                and current_timestamp_ms
+                >= draining_node_deadlines[replica.actor_node_id] - timeout_ms
+            ):
                 logger.info(
-                    f"Stopping replica {replica.replica_tag} (currently {state}) "
+                    f"Stopping replica {replica.replica_tag} "
                     f"of deployment '{self.deployment_name}' in application "
                     f"'{self.app_name}' on draining node {replica.actor_node_id}."
                 )
                 self._stop_replica(replica, graceful_stop=True)
             else:
-                self._replicas.add(replica.actor_details.state, replica)
+                self._replicas.add(ReplicaState.DRAINING, replica)
+
+        # Stop excess DRAINING replicas when new "replacement" replicas
+        # have transitioned to RUNNING.
+        num_running = self._replicas.count(states=[ReplicaState.RUNNING])
+        num_draining = self._replicas.count(states=[ReplicaState.DRAINING])
+        num_excess = num_running + num_draining - self._target_state.target_num_replicas
+
+        draining_replicas = self._replicas.pop(states=[ReplicaState.DRAINING])
+
+        # Greedily choose replicas that have the earliest deadline
+        def order(deadline: int):
+            if deadline:
+                return deadline
+            else:
+                return float("inf")
+
+        draining_replicas.sort(
+            key=lambda r: order(draining_node_deadlines[r.actor_node_id])
+        )
+        for replica in draining_replicas:
+            if num_excess <= 0:
+                self._replicas.add(ReplicaState.DRAINING, replica)
+                continue
+
+            logger.info(
+                f"Stopping replica {replica.replica_tag} "
+                f"of deployment '{self.deployment_name}' in application "
+                f"'{self.app_name}' on draining node {replica.actor_node_id}."
+            )
+            self._stop_replica(replica, graceful_stop=True)
+            num_excess -= 1
 
     def update(self) -> DeploymentStateUpdateResult:
         """Attempts to reconcile this deployment to match its goal state.
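In plain terms, the new reconciliation step appears to work as follows: replicas on draining nodes are parked in DRAINING rather than stopped immediately; a replica is force-stopped only once the node's drain deadline is close enough that waiting longer would cut into the graceful shutdown timeout; and once replacement replicas are RUNNING, excess DRAINING replicas are stopped earliest-deadline-first, with a deadline of 0 (no deadline) sorting last. A standalone sketch of those two decisions, using hypothetical values:

import time
from typing import List

def should_stop_now(deadline_ms: int, graceful_shutdown_timeout_s: float) -> bool:
    # Stop once only the graceful-shutdown window remains before the deadline;
    # a deadline of 0 means the node has no deadline, so never force a stop here.
    if deadline_ms <= 0:
        return False
    return time.time() * 1000 >= deadline_ms - graceful_shutdown_timeout_s * 1000

def drain_order(deadlines_ms: List[int]) -> List[int]:
    # Earliest deadline first; 0 (no deadline) is treated as infinitely far away.
    return sorted(deadlines_ms, key=lambda d: d if d else float("inf"))

# The node must be drained ~10s from now, but replicas need 20s to shut down
# gracefully, so this replica should be stopped immediately.
soon = int((time.time() + 10) * 1000)
assert should_stop_now(soon, graceful_shutdown_timeout_s=20.0)

# Excess draining replicas are stopped in this order: concrete deadlines first.
assert drain_order([0, 300_000, 100_000]) == [100_000, 300_000, 0]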
