Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serve] Support manually terminating a replica #3179

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
46 changes: 33 additions & 13 deletions sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4632,9 +4632,15 @@ def serve_status(all: bool, endpoint: bool, service_names: List[str]):
default=False,
required=False,
help='Skip confirmation prompt.')
@click.option('--replica-id',
'-r',
default=None,
type=int,
help='Tear down a given replica')
# pylint: disable=redefined-builtin
def serve_down(service_names: List[str], all: bool, purge: bool, yes: bool):
"""Teardown service(s).
def serve_down(service_names: List[str], all: bool, purge: bool, yes: bool,
dtran24 marked this conversation as resolved.
Show resolved Hide resolved
replica_id: Optional[int]):
"""Teardown service(s) or a replica.

SERVICE_NAMES is the name of the service (or glob pattern) to tear down. If
dtran24 marked this conversation as resolved.
Show resolved Hide resolved
both SERVICE_NAMES and ``--all`` are supplied, the latter takes precedence.
Expand All @@ -4660,6 +4666,9 @@ def serve_down(service_names: List[str], all: bool, purge: bool, yes: bool):
\b
# Forcefully tear down a service in failed status.
sky serve down failed-service --purge
\b
# Tear down a specific replica
dtran24 marked this conversation as resolved.
Show resolved Hide resolved
sky serve down my-service --replica-id 1
"""
if sum([len(service_names) > 0, all]) != 1:
argument_str = f'SERVICE_NAMES={",".join(service_names)}' if len(
Expand All @@ -4668,6 +4677,11 @@ def serve_down(service_names: List[str], all: bool, purge: bool, yes: bool):
raise click.UsageError(
'Can only specify one of SERVICE_NAMES or --all. '
f'Provided {argument_str!r}.')
replica_id_is_defined = replica_id is not None
if replica_id_is_defined and len(service_names) != 1:
raise click.UsageError(
'Must specify one and only one service when replica ID is '
'specified.')

_, handle = backend_utils.is_controller_up(
controller_type=controller_utils.Controllers.SKY_SERVE_CONTROLLER,
Expand All @@ -4676,17 +4690,23 @@ def serve_down(service_names: List[str], all: bool, purge: bool, yes: bool):
# Hint messages already printed by the call above.
sys.exit(1)

if not yes:
quoted_service_names = [f'{name!r}' for name in service_names]
service_identity_str = f'service(s) {", ".join(quoted_service_names)}'
if all:
service_identity_str = 'all services'
click.confirm(f'Terminating {service_identity_str}. Proceed?',
default=True,
abort=True,
show_default=True)

serve_lib.down(service_names=service_names, all=all, purge=purge)
if replica_id_is_defined:
if not yes:
click.confirm(f'Terminating replica ID {replica_id} in '
f'{service_names[0]!r}. Proceed?')
serve_lib.terminate_replica(service_names[0], replica_id, purge)
else:
if not yes:
quoted_service_names = [f'{name!r}' for name in service_names]
service_identity_str = (f'service(s) '
f'{", ".join(quoted_service_names)}')
if all:
service_identity_str = 'all services'
click.confirm(f'Terminating {service_identity_str}. Proceed?',
default=True,
abort=True,
show_default=True)
serve_lib.down(service_names=service_names, all=all, purge=purge)


@serve.command('logs', cls=_DocumentedCodeCommand)
Expand Down
2 changes: 2 additions & 0 deletions sky/serve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sky.serve.core import down
from sky.serve.core import status
from sky.serve.core import tail_logs
from sky.serve.core import terminate_replica
from sky.serve.core import up
from sky.serve.core import update
from sky.serve.serve_state import ReplicaStatus
Expand Down Expand Up @@ -41,6 +42,7 @@
'SKY_SERVE_CONTROLLER_NAME',
'SKYSERVE_METADATA_DIR',
'status',
'terminate_replica',
'tail_logs',
'up',
'update',
Expand Down
54 changes: 54 additions & 0 deletions sky/serve/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
import threading
import time
import traceback
from typing import Dict

import colorama
import fastapi
import uvicorn

from sky import global_user_state
from sky import serve
from sky import sky_logging
from sky.serve import autoscalers
Expand Down Expand Up @@ -89,6 +92,37 @@ def _run_autoscaler(self):
logger.error(f' Traceback: {traceback.format_exc()}')
time.sleep(self._autoscaler.get_decision_interval())

def _purge_replica(self, replica_id: int) -> Dict[str, str]:
logger.info(f'Purging replica {replica_id}...')
replica_info = serve_state.get_replica_info_from_id(
self._service_name, replica_id)
assert replica_info is not None
replica_cluster_is_remaining = False
if replica_info.status in serve_state.ReplicaStatus.failed_statuses():
if global_user_state.get_cluster_from_name(
replica_info.cluster_name) is not None:
replica_cluster_is_remaining = True
serve_state.remove_replica(self._service_name, replica_id)
if replica_cluster_is_remaining:
return {
'message':
f'{colorama.Fore.YELLOW}Purged replica {replica_id} '
f'with failed status ({replica_info.status}). This may'
f' indicate a resource leak. Please check the following'
f' SkyPilot cluster on the controller: '
f'{replica_info.cluster_name}{colorama.Style.RESET_ALL}'
}
Comment on lines +108 to +114
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As said in the comment above, we should clean up this cluster as well. How about adding an argument `purge: bool` to `ReplicaManager.scale_down` and removing the replica record here?

if info.status_property.is_scale_down_succeeded(
self._get_initial_delay_seconds(info.version)):
# This means the cluster is deleted due to
# a scale down or the cluster is recovering
# from preemption. Delete the replica info
# so it won't count as a replica.
if info.status_property.preempted:
removal_reason = 'for preemption recovery'
else:
removal_reason = 'normally'
# Don't keep failed records for version-mismatched replicas,
# since the user should fix the error before updating.
elif info.version != self.latest_version:
removal_reason = 'for version outdated'
else:
logger.info(f'Termination of replica {replica_id} '
'finished. Replica info is kept since some '
'failure detected.')
serve_state.add_or_update_replica(self._service_name,
replica_id, info)
if removal_reason is not None:
serve_state.remove_replica(self._service_name, replica_id)
logger.info(f'Replica {replica_id} removed from the '
f'replica table {removal_reason}.')

else:
return {
'message': f'Successfully purged replica '
f'{replica_id}'
}
else:
return {
'message': f'No purging for replica {replica_id} since '
f'the replica does not have a failed status.'
}

def run(self) -> None:

@self._app.post('/controller/load_balancer_sync')
Expand Down Expand Up @@ -126,6 +160,26 @@ async def update_service(request: fastapi.Request):
f'{common_utils.format_exception(e)}')
return {'message': 'Error'}

@self._app.post('/controller/terminate_replica')
async def terminate_replica(request: fastapi.Request):
request_data = await request.json()
try:
replica_id = request_data['replica_id']
dtran24 marked this conversation as resolved.
Show resolved Hide resolved
purge = request_data['purge']
if purge:
return self._purge_replica(replica_id)
else:
logger.info(f'Terminating replica {replica_id}...')
self._replica_manager.scale_down(replica_id)
return {
'message': f'Success terminating replica {replica_id}.'
}
except Exception as e: # pylint: disable=broad-except
error_message = (f'Error in terminate_replica: '
f'{common_utils.format_exception(e)}')
logger.error(error_message)
return {'message': error_message}

@self._app.on_event('startup')
def configure_logger():
uvicorn_access_logger = logging.getLogger('uvicorn.access')
Expand Down
55 changes: 53 additions & 2 deletions sky/serve/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
if typing.TYPE_CHECKING:
from sky import clouds

logger = sky_logging.init_logger(__name__)


@usage_lib.entrypoint
def up(
Expand Down Expand Up @@ -272,8 +274,7 @@ def update(task: 'sky.Task', service_name: str) -> None:
'Service controller is stopped. There is no service to update. '
f'To spin up a new service, use {backend_utils.BOLD}'
f'sky serve up{backend_utils.RESET_BOLD}',
non_existent_message='Service does not exist. '
'To spin up a new service, '
non_existent_message='To spin up a new service, '
f'use {backend_utils.BOLD}sky serve up{backend_utils.RESET_BOLD}',
)

Expand Down Expand Up @@ -466,6 +467,56 @@ def down(
sky_logging.print(stdout)


@usage_lib.entrypoint
def terminate_replica(service_name: str, replica_id: int, purge: bool) -> None:
    """Tears down a single replica of a service.

    Args:
        service_name: Name of the service.
        replica_id: ID of the replica to terminate.
        purge: Whether to terminate replicas in a failed status. These
            replicas may lead to resource leaks.

    Raises:
        sky.exceptions.ClusterNotUpError: if the sky serve controller is
            not up.
        RuntimeError: if failed to terminate the replica.
    """
    stopped_hint = 'No service is running now. Please spin up a service first.'
    missing_hint = ('To spin up a new service, '
                    f'use {backend_utils.BOLD}sky serve up'
                    f'{backend_utils.RESET_BOLD}')
    controller_status, controller_handle = backend_utils.is_controller_up(
        controller_type=controller_utils.Controllers.SKY_SERVE_CONTROLLER,
        stopped_message=stopped_hint,
        non_existent_message=missing_hint,
    )

    controller_reachable = (controller_handle is not None and
                            controller_handle.head_ip is not None)
    if not controller_reachable:
        # The error message is already printed in
        # backend_utils.is_controller_up, hence the empty message here.
        # TODO(zhwu): Move the error message into the exception.
        with ux_utils.print_exception_no_traceback():
            raise exceptions.ClusterNotUpError(
                message='', cluster_status=controller_status)

    controller_backend = backend_utils.get_backend_from_handle(
        controller_handle)
    assert isinstance(controller_backend, backends.CloudVmRayBackend)

    # Generate the command that asks the controller to terminate the replica
    # and run it on the controller's head node.
    # NOTE(review): the generated command prints a message, but with
    # require_outputs=False only the return code is used here, so that
    # message is presumably not surfaced to the user -- TODO confirm.
    command = serve_utils.ServeCodeGen.terminate_replica(
        service_name, replica_id, purge)
    exit_code = controller_backend.run_on_head(controller_handle,
                                               command,
                                               require_outputs=False,
                                               stream_logs=False)

    failure_summary = f'Failed to terminate replica {replica_id}'
    try:
        subprocess_utils.handle_returncode(exit_code, command, failure_summary)
    except exceptions.CommandError as e:
        # Re-raise as RuntimeError, carrying over the command's error message.
        raise RuntimeError(e.error_msg) from e

    scheduled_notice = (
        f'{colorama.Fore.GREEN}Termination of replica {replica_id} for '
        f'{service_name!r} has been scheduled.{colorama.Style.RESET_ALL}\n'
        f'Please use {backend_utils.BOLD}sky serve status {service_name} '
        f'{backend_utils.RESET_BOLD}to check the latest status.')
    logger.info(scheduled_notice)


@usage_lib.entrypoint
def status(
service_names: Optional[Union[str,
Expand Down
40 changes: 34 additions & 6 deletions sky/serve/serve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,25 @@ def update_service_encoded(service_name: str, version: int) -> str:
return common_utils.encode_payload(service_msg)


def terminate_replica(service_name: str, replica_id: int, purge: bool) -> str:
    """Asks the service's controller to terminate one replica.

    Looks up the service's controller port and POSTs the replica ID and the
    purge flag to the controller's `/controller/terminate_replica` endpoint.

    Args:
        service_name: Name of the service.
        replica_id: ID of the replica to terminate.
        purge: Forwarded as-is to the controller in the request body.

    Returns:
        The `message` field of the controller's JSON response, encoded with
        `common_utils.encode_payload`.

    Raises:
        ValueError: if no status is found for the service, or if the
            controller responds with a status code other than 200.
    """
    status_record = _get_service_status(service_name)
    if status_record is None:
        raise ValueError(f'Service {service_name!r} does not exist.')

    port = status_record['controller_port']
    endpoint = (_CONTROLLER_URL.format(CONTROLLER_PORT=port) +
                '/controller/terminate_replica')
    request_body = {
        'replica_id': replica_id,
        'purge': purge,
    }
    response = requests.post(endpoint, json=request_body)
    if response.status_code != 200:
        raise ValueError(f'Failed to terminate replica {replica_id} '
                         f'in {service_name}')

    return common_utils.encode_payload(response.json()['message'])


def _get_service_status(
service_name: str,
with_replica_info: bool = True) -> Optional[Dict[str, Any]]:
Expand Down Expand Up @@ -844,6 +863,15 @@ def terminate_services(cls, service_names: Optional[List[str]],
]
return cls._build(code)

@classmethod
def terminate_replica(cls, service_name: str, replica_id: int,
                      purge: bool) -> str:
    """Generates the command that terminates a replica.

    The generated code calls `serve_utils.terminate_replica` with the given
    arguments and prints the returned message without a trailing newline.

    Args:
        service_name: Name of the service; embedded via its `repr`.
        replica_id: ID of the replica to terminate.
        purge: Purge flag passed through to `serve_utils.terminate_replica`.

    Returns:
        The command string produced by `cls._build`.
    """
    call_stmt = (f'msg = serve_utils.terminate_replica({service_name!r}, '
                 f'{replica_id}, {purge})')
    print_stmt = 'print(msg, end="", flush=True)'
    return cls._build([call_stmt, print_stmt])

@classmethod
def wait_service_registration(cls, service_name: str, job_id: int) -> str:
code = [
Expand Down Expand Up @@ -871,16 +899,16 @@ def stream_serve_process_logs(cls, service_name: str,
]
return cls._build(code)

@classmethod
def _build(cls, code: List[str]) -> str:
    """Wraps code statements into a single `python3 -u -c` command.

    Args:
        code: Python statements to run, appended after `cls._PREFIX`.

    Returns:
        A command string whose `-c` argument is the `'; '`-joined
        statements, quoted with `shlex.quote`.
    """
    statements = cls._PREFIX + code
    one_liner = '; '.join(statements)
    quoted = shlex.quote(one_liner)
    return f'python3 -u -c {quoted}'

@classmethod
def update_service(cls, service_name: str, version: int) -> str:
    """Generates the command that updates a service to a given version.

    The generated code calls `serve_utils.update_service_encoded` and prints
    the returned message without a trailing newline.

    Args:
        service_name: Name of the service; embedded via its `repr`.
        version: Version passed to `serve_utils.update_service_encoded`.

    Returns:
        The command string produced by `cls._build`.
    """
    call_stmt = (f'msg = serve_utils.update_service_encoded({service_name!r}, '
                 f'{version})')
    print_stmt = 'print(msg, end="", flush=True)'
    return cls._build([call_stmt, print_stmt])

@classmethod
def _build(cls, code: List[str]) -> str:
    """Wraps code statements into a single `python3 -u -c` command.

    Args:
        code: Python statements to run, appended after `cls._PREFIX`.

    Returns:
        A command string whose `-c` argument is the `'; '`-joined
        statements, quoted with `shlex.quote`.
    """
    statements = cls._PREFIX + code
    one_liner = '; '.join(statements)
    quoted = shlex.quote(one_liner)
    return f'python3 -u -c {quoted}'
Loading