[serve] safe draining #43228

Merged 1 commit on Feb 23, 2024
12 changes: 6 additions & 6 deletions python/ray/serve/_private/cluster_node_info_cache.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple

import ray
from ray._raylet import GcsClient
@@ -53,8 +53,8 @@ def get_alive_node_ids(self) -> Set[str]:
return {node_id for node_id, _ in self.get_alive_nodes()}

@abstractmethod
def get_draining_node_ids(self) -> Set[str]:
"""Get IDs of all draining nodes in the cluster."""
def get_draining_nodes(self) -> Dict[str, int]:
"""Get draining nodes in the cluster and their deadlines."""
raise NotImplementedError

@abstractmethod
@@ -67,15 +67,15 @@ def get_active_node_ids(self) -> Set[str]:

A node is active if it's schedulable for new tasks and actors.
"""
return self.get_alive_node_ids() - self.get_draining_node_ids()
return self.get_alive_node_ids() - set(self.get_draining_nodes())


class DefaultClusterNodeInfoCache(ClusterNodeInfoCache):
def __init__(self, gcs_client: GcsClient):
super().__init__(gcs_client)

def get_draining_node_ids(self) -> Set[str]:
return set()
def get_draining_nodes(self) -> Dict[str, int]:
return dict()

def get_node_az(self, node_id: str) -> Optional[str]:
"""Get availability zone of a node."""
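Reviewer note: to make the new return type concrete, here is a minimal, self-contained sketch of how a caller consumes `get_draining_nodes()`. The `FakeClusterNodeInfoCache` class and its constructor arguments are made up for illustration; only the method names and the set arithmetic mirror this diff, and treating the dict values as drain deadlines in Unix milliseconds is an assumption based on how `deployment_state.py` uses them below.

```python
from typing import Dict, Set


class FakeClusterNodeInfoCache:
    """Illustrative stand-in for ClusterNodeInfoCache; not Ray's real class."""

    def __init__(self, alive: Set[str], draining: Dict[str, int]):
        self._alive = alive
        # node_id -> drain deadline in Unix ms (assumed convention).
        self._draining = draining

    def get_alive_node_ids(self) -> Set[str]:
        return set(self._alive)

    def get_draining_nodes(self) -> Dict[str, int]:
        return dict(self._draining)

    def get_active_node_ids(self) -> Set[str]:
        # Same arithmetic as the diff: iterating a dict yields its keys, so
        # set(get_draining_nodes()) recovers the old set-of-node-ids behavior.
        return self.get_alive_node_ids() - set(self.get_draining_nodes())


cache = FakeClusterNodeInfoCache(
    alive={"node-1", "node-2", "node-3"},
    draining={"node-2": 1_700_000_000_000},
)
assert cache.get_active_node_ids() == {"node-1", "node-3"}
```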
1 change: 1 addition & 0 deletions python/ray/serve/_private/common.py
@@ -61,6 +61,7 @@ class ReplicaState(str, Enum):
RECOVERING = "RECOVERING"
RUNNING = "RUNNING"
STOPPING = "STOPPING"
PENDING_MIGRATION = "PENDING_MIGRATION"


class ApplicationStatus(str, Enum):
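Reviewer note: a quick sketch of where the new `PENDING_MIGRATION` value sits and the lifecycle this PR implies. The enum below copies only the members visible in this diff plus `STARTING`/`UPDATING`, which are referenced elsewhere in the PR; the transition table is my reading of the diff, not Serve's actual state machine.

```python
from enum import Enum


class ReplicaState(str, Enum):
    STARTING = "STARTING"
    UPDATING = "UPDATING"
    RECOVERING = "RECOVERING"
    RUNNING = "RUNNING"
    STOPPING = "STOPPING"
    PENDING_MIGRATION = "PENDING_MIGRATION"


# Implied by this PR: a RUNNING replica on a draining node is parked in
# PENDING_MIGRATION (still serving traffic and health-checked), then moved to
# STOPPING once a replacement replica is RUNNING or the node's deadline is near.
IMPLIED_TRANSITIONS = {
    ReplicaState.RUNNING: {ReplicaState.PENDING_MIGRATION, ReplicaState.STOPPING},
    ReplicaState.PENDING_MIGRATION: {ReplicaState.STOPPING},
}
```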
4 changes: 2 additions & 2 deletions python/ray/serve/_private/controller.py
@@ -356,8 +356,8 @@ def _update_proxy_nodes(self):
(head node and nodes with deployment replicas).
"""
new_proxy_nodes = self.deployment_state_manager.get_active_node_ids()
new_proxy_nodes = (
new_proxy_nodes - self.cluster_node_info_cache.get_draining_node_ids()
new_proxy_nodes = new_proxy_nodes - set(
self.cluster_node_info_cache.get_draining_nodes()
)
new_proxy_nodes.add(self._controller_node_id)
self._proxy_nodes = new_proxy_nodes
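Reviewer note: the proxy-node computation is unchanged in spirit; only the type of the draining lookup changed. A standalone rendering of the arithmetic, where the function name `compute_proxy_nodes` and the node IDs are purely illustrative:

```python
from typing import Dict, Set


def compute_proxy_nodes(
    active_replica_nodes: Set[str],
    draining_nodes: Dict[str, int],
    controller_node_id: str,
) -> Set[str]:
    # Mirrors _update_proxy_nodes(): drop draining nodes, always keep the
    # node running the controller (the head node).
    proxy_nodes = active_replica_nodes - set(draining_nodes)
    proxy_nodes.add(controller_node_id)
    return proxy_nodes


assert compute_proxy_nodes(
    active_replica_nodes={"head", "worker-1", "worker-2"},
    draining_nodes={"worker-2": 1_700_000_000_000},
    controller_node_id="head",
) == {"head", "worker-1"}
```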
130 changes: 111 additions & 19 deletions python/ray/serve/_private/deployment_state.py
@@ -1377,7 +1377,9 @@ def app_name(self) -> str:
def get_running_replica_infos(self) -> List[RunningReplicaInfo]:
return [
replica.get_running_replica_info(self._cluster_node_info_cache)
for replica in self._replicas.get([ReplicaState.RUNNING])
for replica in self._replicas.get(
[ReplicaState.RUNNING, ReplicaState.PENDING_MIGRATION]
)
]

def get_active_node_ids(self) -> Set[str]:
@@ -1391,6 +1393,9 @@ def get_active_node_ids(self) -> Set[str]:
ReplicaState.UPDATING,
ReplicaState.RECOVERING,
ReplicaState.RUNNING,
# NOTE(zcin): We still want a proxy to run on a draining
# node before all the replicas are migrated.
ReplicaState.PENDING_MIGRATION,
]
return {
replica.actor_node_id
@@ -1629,7 +1634,9 @@ def get_total_num_requests(self) -> float:
"""

total_requests = 0
running_replicas = self._replicas.get([ReplicaState.RUNNING])
running_replicas = self._replicas.get(
[ReplicaState.RUNNING, ReplicaState.PENDING_MIGRATION]
)

if (
RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE
@@ -1745,7 +1752,11 @@ def _stop_or_update_outdated_version_replicas(self, max_to_stop=math.inf) -> bool:
"""
replicas_to_update = self._replicas.pop(
exclude_version=self._target_state.version,
states=[ReplicaState.STARTING, ReplicaState.RUNNING],
states=[
ReplicaState.STARTING,
ReplicaState.PENDING_MIGRATION,
ReplicaState.RUNNING,
],
)
replicas_changed = False
code_version_changes = 0
@@ -1781,7 +1792,7 @@ def _stop_or_update_outdated_version_replicas(self, max_to_stop=math.inf) -> bool:
f"{replica.replica_tag}, deployment_name: {self.deployment_name}, "
f"app_name: {self.app_name}"
)
# We don't allow going from STARTING to UPDATING.
# We don't allow going from STARTING or PENDING_MIGRATION to UPDATING.
else:
self._replicas.add(replica.actor_details.state, replica)

@@ -1821,7 +1832,11 @@ def _check_and_stop_outdated_version_replicas(self) -> bool:
# terminate them and start new version replicas instead.
old_running_replicas = self._replicas.count(
exclude_version=self._target_state.version,
states=[ReplicaState.STARTING, ReplicaState.UPDATING, ReplicaState.RUNNING],
states=[
ReplicaState.STARTING,
ReplicaState.UPDATING,
ReplicaState.RUNNING,
],
)
old_stopping_replicas = self._replicas.count(
exclude_version=self._target_state.version, states=[ReplicaState.STOPPING]
@@ -2133,9 +2148,11 @@ def _check_and_update_replicas(self):
transition happened.
"""

for replica in self._replicas.pop(states=[ReplicaState.RUNNING]):
for replica in self._replicas.pop(
states=[ReplicaState.RUNNING, ReplicaState.PENDING_MIGRATION]
):
if replica.check_health():
self._replicas.add(ReplicaState.RUNNING, replica)
self._replicas.add(replica.actor_details.state, replica)
self.health_check_gauge.set(
1,
tags={
@@ -2247,22 +2264,97 @@ def _check_and_update_replicas(self):
if replica.replica_tag in self.replica_average_ongoing_requests:
del self.replica_average_ongoing_requests[replica.replica_tag]

def _stop_replicas_on_draining_nodes(self):
draining_nodes = self._cluster_node_info_cache.get_draining_node_ids()
def _choose_pending_migration_replicas_to_stop(
self,
replicas: List[DeploymentReplica],
deadlines: Dict[str, int],
min_replicas_to_stop: int,
) -> Tuple[List[DeploymentReplica], List[DeploymentReplica]]:
"""Returns a partition of replicas to stop and to keep.

Args:
replicas: The current list of replicas pending migration.
deadlines: The drain deadlines of the draining nodes, keyed by node ID, in Unix milliseconds.
min_replicas_to_stop: The minimum number of replicas to stop.
"""
to_stop = list()
remaining = list()

# Stop replicas whose node's drain deadline is within one graceful shutdown window
for replica in replicas:
curr_timestamp_ms = time.time() * 1000
timeout_ms = replica._actor.graceful_shutdown_timeout_s * 1000
if (
replica.actor_node_id in deadlines
and curr_timestamp_ms >= deadlines[replica.actor_node_id] - timeout_ms
):
to_stop.append(replica)
else:
remaining.append(replica)

# Stop excess PENDING_MIGRATION replicas when new "replacement"
# replicas have transitioned to RUNNING. The replicas with the
# earliest deadlines should be chosen greedily.
def order(deadline: int) -> float:
    # Treat a missing or zero deadline as infinitely far in the future
    # so replicas on nodes without a deadline are stopped last.
    return deadline if deadline else float("inf")

remaining.sort(key=lambda r: order(deadlines[r.actor_node_id]))
num_excess = min_replicas_to_stop - len(to_stop)

if num_excess > 0:
to_stop.extend(remaining[:num_excess])
remaining = remaining[num_excess:]

return to_stop, remaining
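Reviewer note: to sanity-check the selection rule, here is a simplified, runnable mirror of this helper over plain tuples. The `Replica` tuples, node IDs, and timings are invented; the two-step rule (stop anything whose node deadline is within one graceful-shutdown window, then stop the earliest-deadline replicas among the rest until `min_replicas_to_stop` is met) follows the code above.

```python
import time
from typing import Dict, List, Tuple

# Plain-data stand-in for DeploymentReplica:
# (replica_tag, node_id, graceful_shutdown_timeout_s).
Replica = Tuple[str, str, float]


def choose_pending_migration_to_stop(
    replicas: List[Replica],
    deadlines: Dict[str, int],  # node_id -> drain deadline in Unix ms
    min_replicas_to_stop: int,
) -> Tuple[List[Replica], List[Replica]]:
    now_ms = time.time() * 1000
    to_stop: List[Replica] = []
    remaining: List[Replica] = []

    # Step 1: stop replicas whose node deadline is within one graceful
    # shutdown window, so they can still drain requests before the node goes.
    for replica in replicas:
        _, node_id, timeout_s = replica
        if node_id in deadlines and now_ms >= deadlines[node_id] - timeout_s * 1000:
            to_stop.append(replica)
        else:
            remaining.append(replica)

    # Step 2: if replacements are already RUNNING (min_replicas_to_stop > 0),
    # stop the earliest-deadline replicas first; a zero deadline sorts last.
    remaining.sort(key=lambda r: deadlines.get(r[1], 0) or float("inf"))
    num_excess = min_replicas_to_stop - len(to_stop)
    if num_excess > 0:
        to_stop.extend(remaining[:num_excess])
        remaining = remaining[num_excess:]
    return to_stop, remaining


# Example: node n1 drains in 5 s, n2 in 10 min; graceful shutdown is 20 s.
now = int(time.time() * 1000)
stop, keep = choose_pending_migration_to_stop(
    [("r1", "n1", 20.0), ("r2", "n2", 20.0)],
    {"n1": now + 5_000, "n2": now + 600_000},
    min_replicas_to_stop=1,
)
assert [r[0] for r in stop] == ["r1"] and [r[0] for r in keep] == ["r2"]
```

If `min_replicas_to_stop` were 2 in this example, r2 would also be stopped despite its distant deadline, which is what allows excess PENDING_MIGRATION replicas to be retired once replacements are up.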

def _migrate_replicas_on_draining_nodes(self):
draining_node_deadlines = self._cluster_node_info_cache.get_draining_nodes()
for replica in self._replicas.pop(
states=[ReplicaState.UPDATING, ReplicaState.RUNNING]
states=[ReplicaState.UPDATING, ReplicaState.RUNNING, ReplicaState.STARTING]
):
if replica.actor_node_id in draining_nodes:
state = replica._actor_details.state
logger.info(
f"Stopping replica {replica.replica_tag} (currently {state}) "
f"of deployment '{self.deployment_name}' in application "
f"'{self.app_name}' on draining node {replica.actor_node_id}."
)
self._stop_replica(replica, graceful_stop=True)
if replica.actor_node_id in draining_node_deadlines:
# For RUNNING replicas, migrate them safely by starting
# a replacement replica first.
if replica.actor_details.state == ReplicaState.RUNNING:
self._replicas.add(ReplicaState.PENDING_MIGRATION, replica)
# For replicas that are STARTING or UPDATING, might as
# well terminate them immediately to allow replacement
# replicas to start. Otherwise we need to wait for them
# to transition to RUNNING before starting migration.
else:
self._stop_replica(replica, graceful_stop=True)
else:
self._replicas.add(replica.actor_details.state, replica)

num_running = self._replicas.count(states=[ReplicaState.RUNNING])
num_draining = self._replicas.count(states=[ReplicaState.PENDING_MIGRATION])
num_pending_migration_replicas_to_stop = (
num_running + num_draining - self._target_state.target_num_replicas
)

(
replicas_to_stop,
replicas_to_keep,
) = self._choose_pending_migration_replicas_to_stop(
self._replicas.pop(states=[ReplicaState.PENDING_MIGRATION]),
draining_node_deadlines,
num_pending_migration_replicas_to_stop,
)
for replica in replicas_to_stop:
logger.info(
f"Stopping replica {replica.replica_tag} "
f"of deployment '{self.deployment_name}' in application "
f"'{self.app_name}' on draining node {replica.actor_node_id}."
)
self._stop_replica(replica, graceful_stop=True)

for replica in replicas_to_keep:
self._replicas.add(ReplicaState.PENDING_MIGRATION, replica)

def update(self) -> DeploymentStateUpdateResult:
"""Attempts to reconcile this deployment to match its goal state.

@@ -2283,7 +2375,7 @@ def update(self) -> DeploymentStateUpdateResult:
# Check the state of existing replicas and transition if necessary.
self._check_and_update_replicas()

self._stop_replicas_on_draining_nodes()
self._migrate_replicas_on_draining_nodes()

upscale, downscale = self._scale_deployment_replicas()

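Reviewer note: the only change to `update()` itself is swapping in `_migrate_replicas_on_draining_nodes()`, but the ordering visible in this hunk is the point: health checks run first, then draining migration, then scaling. A toy, self-contained model of that ordering (everything here is a stub with made-up log strings; only the three step names come from the diff):

```python
from typing import List


class ToyDeploymentState:
    """Toy model of the reconciliation ordering in update(); not Ray Serve code."""

    def __init__(self) -> None:
        self.log: List[str] = []

    def _check_and_update_replicas(self) -> None:
        self.log.append("health-check RUNNING and PENDING_MIGRATION replicas")

    def _migrate_replicas_on_draining_nodes(self) -> None:
        self.log.append("park replicas on draining nodes; stop those near their deadline")

    def _scale_deployment_replicas(self) -> None:
        self.log.append("start replacement replicas / stop excess ones")

    def update(self) -> List[str]:
        self._check_and_update_replicas()
        self._migrate_replicas_on_draining_nodes()
        self._scale_deployment_replicas()
        return self.log


print(ToyDeploymentState().update())
```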