Auto stop for cluster #653

Merged: 40 commits, Mar 31, 2022

Commits (40)

f9d43c9  refactorize skylet (Michaelvll, Mar 25, 2022)
af060a4  implement autostop event without cluster stopping (Michaelvll, Mar 25, 2022)
dfce233  wip (Michaelvll, Mar 26, 2022)
ba3af68  Remove autostop from yaml file (Michaelvll, Mar 26, 2022)
fbdd296  fix naming (Michaelvll, Mar 26, 2022)
b33ba85  fix config (Michaelvll, Mar 26, 2022)
9b6cde2  fix skylet (Michaelvll, Mar 27, 2022)
8793f57  add autostop to status (Michaelvll, Mar 27, 2022)
d483c76  Merge branch 'master' of github.com:concretevitamin/sky-experiments i… (Michaelvll, Mar 27, 2022)
dc22d74  fix state and name match (Michaelvll, Mar 27, 2022)
a8d32b4  Replace min_workers/max_workers for gcp (Michaelvll, Mar 27, 2022)
5f883be  using ray up / ray down process (Michaelvll, Mar 27, 2022)
ddf0ba1  fix stopping (Michaelvll, Mar 27, 2022)
3524f04  set autostop in globle user state (Michaelvll, Mar 27, 2022)
b481ebe  update sky status (Michaelvll, Mar 28, 2022)
a66913c  format (Michaelvll, Mar 28, 2022)
f968bf2  Add refresh to sky status (Michaelvll, Mar 28, 2022)
dec91cd  Merge branch 'master' of github.com:concretevitamin/sky-experiments i… (Michaelvll, Mar 28, 2022)
42844d9  Merge branch 'master' of github.com:concretevitamin/sky-experiments i… (Michaelvll, Mar 29, 2022)
492c39d  Merge branch 'master' of github.com:concretevitamin/sky-experiments i… (Michaelvll, Mar 29, 2022)
9a207ae  address comments (Michaelvll, Mar 29, 2022)
6d3afba  comment (Michaelvll, Mar 29, 2022)
642f0a8  address comments (Michaelvll, Mar 29, 2022)
70a3aeb  Merge branch 'master' of github.com:concretevitamin/sky-experiments i… (Michaelvll, Mar 29, 2022)
905ba07  Fix logging (Michaelvll, Mar 29, 2022)
a34a810  update help (Michaelvll, Mar 29, 2022)
13f757a  remove ssh config and bring cursor back (Michaelvll, Mar 30, 2022)
af40dc4  Fix exec on stopped instance (Michaelvll, Mar 30, 2022)
15466be  address comment (Michaelvll, Mar 30, 2022)
5d20eef  format (Michaelvll, Mar 30, 2022)
af404a6  fix (Michaelvll, Mar 30, 2022)
0331aac  Add test for autostop (Michaelvll, Mar 30, 2022)
15b434b  Fix cancel (Michaelvll, Mar 30, 2022)
dfaea0f  Merge branch 'master' of github.com:concretevitamin/sky-experiments i… (Michaelvll, Mar 31, 2022)
c653006  address comment (Michaelvll, Mar 31, 2022)
e5227b6  address comment (Michaelvll, Mar 31, 2022)
0fb5a4d  Fix sky launch will change autostop to -1 (Michaelvll, Mar 31, 2022)
1d1f9bd  format (Michaelvll, Mar 31, 2022)
d70bfbc  Add docs (Michaelvll, Mar 31, 2022)
b62c7dc  update (Michaelvll, Mar 31, 2022)
156 changes: 69 additions & 87 deletions sky/backends/backend_utils.py
@@ -26,6 +26,7 @@
from sky import backends
from sky import check as sky_check
from sky import clouds
from sky import global_user_state
from sky import exceptions
from sky import sky_logging
from sky.adaptors import azure
@@ -40,6 +41,7 @@
# NOTE: keep in sync with the cluster template 'file_mounts'.
SKY_REMOTE_WORKDIR = log_lib.SKY_REMOTE_WORKDIR
SKY_REMOTE_APP_DIR = '~/.sky/sky_app'
SKY_RAY_YAML_REMOTE_PATH = '~/.sky/sky_ray.yml'
IP_ADDR_REGEX = r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}'
SKY_REMOTE_RAY_VERSION = '1.10.0'
SKY_REMOTE_PATH = '~/.sky/sky_wheels'
@@ -71,8 +73,6 @@ def _fill_template(template_name: str,
raise FileNotFoundError(f'Template "{template_name}" does not exist.')
with open(template_path) as fin:
template = fin.read()
template = jinja2.Template(template)
content = template.render(**variables)
if output_path is None:
assert 'cluster_name' in variables, 'cluster_name is required.'
cluster_name = variables['cluster_name']
@@ -81,6 +81,12 @@
os.makedirs(output_path.parents[0], exist_ok=True)
output_path = str(output_path)
output_path = os.path.abspath(output_path)

# Add the yaml file paths to the template variables.
variables['sky_ray_yaml_remote_path'] = SKY_RAY_YAML_REMOTE_PATH
variables['sky_ray_yaml_local_path'] = output_path
template = jinja2.Template(template)
content = template.render(**variables)
with open(output_path, 'w') as fout:
fout.write(content)
return output_path
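
The reordering in this hunk is the substance of the change: output_path must be resolved before rendering so the template can receive its own local path. A standalone sketch of the fixed ordering (hypothetical names, simplified from _fill_template above):

import os

import jinja2

def fill_template(template_str: str, variables: dict, output_path: str) -> str:
    # Resolve the output path first...
    output_path = os.path.abspath(os.path.expanduser(output_path))
    # ...so both yaml paths can be injected as template variables
    # before rendering.
    variables['sky_ray_yaml_local_path'] = output_path
    variables['sky_ray_yaml_remote_path'] = '~/.sky/sky_ray.yml'
    content = jinja2.Template(template_str).render(**variables)
    with open(output_path, 'w') as fout:
        fout.write(content)
    return output_path
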
@@ -1041,17 +1047,30 @@ def get_node_ips(
handle is not None and handle.head_ip is not None):
return [handle.head_ip]

out = run(f'ray get-head-ip {yaml_handle}',
stdout=subprocess.PIPE).stdout.decode().strip()
head_ip = re.findall(IP_ADDR_REGEX, out)
assert 1 == len(head_ip), out
try:
proc = run(f'ray get-head-ip {yaml_handle}',
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
out = proc.stdout.decode().strip()
head_ip = re.findall(IP_ADDR_REGEX, out)
except subprocess.CalledProcessError as e:
raise exceptions.FetchIPError(
exceptions.FetchIPError.Reason.HEAD) from e
if len(head_ip) != 1:
raise exceptions.FetchIPError(exceptions.FetchIPError.Reason.HEAD)

if expected_num_nodes > 1:
out = run(f'ray get-worker-ips {yaml_handle}',
stdout=subprocess.PIPE).stdout.decode()
worker_ips = re.findall(IP_ADDR_REGEX, out)
assert expected_num_nodes - 1 == len(worker_ips), (expected_num_nodes -
1, out)
try:
proc = run(f'ray get-worker-ips {yaml_handle}',
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
out = proc.stdout.decode()
worker_ips = re.findall(IP_ADDR_REGEX, out)
except subprocess.CalledProcessError as e:
raise exceptions.FetchIPError(
exceptions.FetchIPError.Reason.WORKER) from e
if len(worker_ips) != expected_num_nodes - 1:
raise exceptions.FetchIPError(exceptions.FetchIPError.Reason.WORKER)
else:
worker_ips = []
if return_private_ips:
@@ -1080,6 +1099,45 @@ def get_head_ip(
return head_ip


def _ping_cluster_or_set_to_stopped(
record: Dict[str, Any]) -> Dict[str, Any]:
handle = record['handle']
if not isinstance(handle, backends.CloudVmRayBackend.ResourceHandle):
return record
# Autostop is disabled for the cluster
if record['autostop'] < 0:
return record
cluster_name = handle.cluster_name
try:
get_node_ips(handle.cluster_yaml, handle.launched_nodes)
return record
except exceptions.FetchIPError as e:
# Set the cluster status to STOPPED, even if the head node is still
# alive, since it will be stopped as soon as the workers are stopped.
logger.debug(f'Failed to get IPs from cluster {cluster_name}: {e}, '
'setting it to STOPPED')
global_user_state.remove_cluster(cluster_name, terminate=False)
auth_config = read_yaml(handle.cluster_yaml)['auth']
SSHConfigHelper.remove_cluster(cluster_name, handle.head_ip, auth_config)
return global_user_state.get_cluster_from_name(cluster_name)


def get_status_from_cluster_name(
cluster_name: str) -> Optional[global_user_state.ClusterStatus]:
record = global_user_state.get_cluster_from_name(cluster_name)
if record is None:
return None
record = _ping_cluster_or_set_to_stopped(record)
return record['status']


def get_clusters(refresh: bool) -> List[Dict[str, Any]]:
records = global_user_state.get_clusters()
if not refresh:
return records
return [_ping_cluster_or_set_to_stopped(record) for record in records]
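
Taken together, these helpers replace the old assert-based contract of get_node_ips with exceptions the caller can act on. A hedged caller-side sketch (hypothetical function; the handle fields follow the diff above):

from sky import exceptions
from sky.backends import backend_utils

def cluster_is_reachable(handle) -> bool:
    # get_node_ips now raises FetchIPError (Reason.HEAD or Reason.WORKER)
    # instead of asserting, so an unreachable (likely autostopped) cluster
    # can be handled gracefully rather than crashing the caller.
    try:
        backend_utils.get_node_ips(handle.cluster_yaml, handle.launched_nodes)
        return True
    except exceptions.FetchIPError:
        return False
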


def query_head_ip_with_retries(cluster_yaml: str, retry_count: int = 1) -> str:
"""Returns the ip of the head node from yaml file."""
for i in range(retry_count):
@@ -1115,82 +1173,6 @@ def get_backend_from_handle(
return backend


class JobLibCodeGen(object):
"""Code generator for job utility functions.

Usage:

>> codegen = JobLibCodeGen.add_job(...)
"""

_PREFIX = ['from sky.skylet import job_lib, log_lib']

@classmethod
def add_job(cls, job_name: str, username: str, run_timestamp: str) -> str:
if job_name is None:
job_name = '-'
code = [
'job_id = job_lib.add_job('
f'{job_name!r}, {username!r}, {run_timestamp!r})',
'print(job_id, flush=True)',
]
return cls._build(code)

@classmethod
def update_status(cls) -> str:
code = [
'job_lib.update_status()',
]
return cls._build(code)

@classmethod
def show_jobs(cls, username: Optional[str], all_jobs: bool) -> str:
code = [f'job_lib.show_jobs({username!r}, {all_jobs})']
return cls._build(code)

@classmethod
def cancel_jobs(cls, job_ids: Optional[List[int]]) -> str:
code = [f'job_lib.cancel_jobs({job_ids!r})']
return cls._build(code)

@classmethod
def fail_all_jobs_in_progress(cls) -> str:
# Used only for restarting a cluster.
code = ['job_lib.fail_all_jobs_in_progress()']
return cls._build(code)

@classmethod
def tail_logs(cls, job_id: int) -> str:
code = [
f'log_dir = job_lib.log_dir({job_id})',
f'log_lib.tail_logs({job_id}, log_dir)',
]
return cls._build(code)

@classmethod
def get_job_status(cls, job_id: str) -> str:
# Prints "Job <id> <status>" for UX; caller should parse the last token.
code = [
f'job_status = job_lib.get_status({job_id})',
f'print("Job", {job_id}, job_status.value, flush=True)',
]
return cls._build(code)

@classmethod
def get_log_path(cls, job_id: int) -> str:
code = [
f'log_dir = job_lib.log_dir({job_id})',
'print(log_dir, flush=True)',
]
return cls._build(code)

@classmethod
def _build(cls, code: List[str]) -> str:
code = cls._PREFIX + code
code = ';'.join(code)
return f'python3 -u -c {code!r}'
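
This JobLibCodeGen class is moved out of backend_utils into sky/skylet/job_lib.py by this PR; the call sites in cloud_vm_ray_backend.py below are updated accordingly. The new autostop_lib.AutostopCodeGen used by set_autostop is not shown in this excerpt, but presumably follows the same pattern of building a small Python snippet to run on the head node. A hypothetical sketch:

# Hypothetical reconstruction of autostop_lib.AutostopCodeGen, modeled on
# the JobLibCodeGen._build pattern above; the real class is not in this
# excerpt.
from typing import List

class AutostopCodeGen:
    _PREFIX = ['from sky.skylet import autostop_lib']

    @classmethod
    def set_autostop(cls, idle_minutes: int, backend_name: str) -> str:
        code = [f'autostop_lib.set_autostop({idle_minutes}, {backend_name!r})']
        return cls._build(code)

    @classmethod
    def _build(cls, code: List[str]) -> str:
        joined = ';'.join(cls._PREFIX + code)
        return f'python3 -u -c {joined!r}'
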


class NoOpConsole:
"""An empty class for multi-threaded console.status."""

65 changes: 36 additions & 29 deletions sky/backends/cloud_vm_ray_backend.py
@@ -32,7 +32,7 @@
from sky import task as task_lib
from sky.backends import backend_utils
from sky.backends import wheel_utils
from sky.skylet import job_lib, log_lib
from sky.skylet import autostop_lib, job_lib, log_lib

if typing.TYPE_CHECKING:
from sky import dag
@@ -172,7 +172,7 @@ def __init__(self):
# It is an int automatically generated by the DB on the cluster
# and monotonically increasing starting from 1.
# To generate the job ID, we use the following logic:
# code = backend_utils.JobLibCodeGen.add_job(username,
# code = job_lib.JobLibCodeGen.add_job(username,
# run_timestamp)
# job_id = get_output(run_on_cluster(code))
self.job_id = None
@@ -635,7 +635,7 @@ def _yield_region_zones(self, to_provision: 'resources_lib.Resources',
zones = [clouds.Zone(name=zone) for zone in zones.split(',')]
region.set_zones(zones)
# Get the *previous* cluster status.
cluster_status = global_user_state.get_status_from_cluster_name(
cluster_status = backend_utils.get_status_from_cluster_name(
cluster_name)
if cluster_status != global_user_state.ClusterStatus.UP:
logger.info(
@@ -787,7 +787,7 @@ def _retry_region_zones(self,
f'{style.BRIGHT}{tail_cmd}{style.RESET_ALL}')

# Get previous cluster status
prev_cluster_status = global_user_state.get_status_from_cluster_name(
prev_cluster_status = backend_utils.get_status_from_cluster_name(
cluster_name)

self._clear_blocklist()
@@ -951,13 +951,6 @@ def ray_up(start_streaming_at):
require_outputs=True)
return returncode, stdout, stderr

config = backend_utils.read_yaml(cluster_config_file)
file_mounts = config['file_mounts']
if 'ssh_public_key' in config['auth']:
# For Azure, we need to add the ssh public key to the VM via file mounts.
public_key_path = config['auth']['ssh_public_key']
file_mounts[public_key_path] = public_key_path

region_name = logging_info['region_name']
zone_str = logging_info['zone_str']

@@ -1257,8 +1250,7 @@ def provision(self,
to_provision_config = self._check_existing_cluster(
task, to_provision, cluster_name)
prev_cluster_status = (
global_user_state.get_status_from_cluster_name(cluster_name)
)
backend_utils.get_status_from_cluster_name(cluster_name))
assert to_provision_config.resources is not None, (
'to_provision should not be None', to_provision_config)
# TODO(suquark): once we have sky on PYPI, we should directly
@@ -1315,7 +1307,7 @@ def provision(self,
# update_status will query the ray job status for all INIT /
# PENDING / RUNNING jobs for the real status, since we do not
# know the actual previous status of the cluster.
cmd = backend_utils.JobLibCodeGen.update_status()
cmd = job_lib.JobLibCodeGen.update_status()
with backend_utils.safe_console_status(
'[bold cyan]Preparing Job Queue'):
returncode, _, stderr = self.run_on_head(
@@ -1330,7 +1322,7 @@
# 1. A job finishes RUNNING, but right before it update itself
# to SUCCEEDED, the cluster is STOPPED by `sky stop`.
# 2. On next `sky start`, it gets reset to FAILED.
cmd = backend_utils.JobLibCodeGen.fail_all_jobs_in_progress()
cmd = job_lib.JobLibCodeGen.fail_all_jobs_in_progress()
returncode, _, stderr = self.run_on_head(handle,
cmd,
require_outputs=True)
@@ -1349,6 +1341,21 @@
os.remove(lock_path)
return handle

def set_autostop(self, handle: ResourceHandle,
idle_minutes_to_autostop: Optional[int]) -> None:
if idle_minutes_to_autostop is not None:
code = autostop_lib.AutostopCodeGen.set_autostop(
idle_minutes_to_autostop, self.NAME)
returncode, _, stderr = self.run_on_head(handle,
code,
require_outputs=True)
backend_utils.handle_returncode(returncode,
code,
'Failed to set autostop',
stderr=stderr)
global_user_state.set_cluster_autostop_value(
handle.cluster_name, idle_minutes_to_autostop)
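
A hedged usage sketch for the new method (caller wiring assumed; the reading of negative values as "disabled" comes from _ping_cluster_or_set_to_stopped in backend_utils.py above):

# Assumed: `backend` is a CloudVmRayBackend and `handle` is the
# ResourceHandle of an existing cluster.
backend.set_autostop(handle, idle_minutes_to_autostop=10)    # stop after 10 idle minutes
backend.set_autostop(handle, idle_minutes_to_autostop=-1)    # presumably disables autostop
backend.set_autostop(handle, idle_minutes_to_autostop=None)  # no-op: the guard above skips it
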

def sync_workdir(self, handle: ResourceHandle, workdir: Path) -> None:
# Even though provision() takes care of it, there may be cases where
# this function is called in isolation, without calling provision(),
@@ -1642,7 +1649,7 @@ def _setup_node(ip: int) -> int:

def get_job_status(self, handle: ResourceHandle,
job_id: int) -> Optional[job_lib.JobStatus]:
code = backend_utils.JobLibCodeGen.get_job_status(job_id)
code = job_lib.JobLibCodeGen.get_job_status(job_id)
returncode, stdout, stderr = self.run_on_head(handle,
code,
stream_logs=True,
@@ -1655,7 +1662,7 @@ def get_job_status(self, handle: ResourceHandle,
return job_lib.JobStatus(result.split(' ')[-1])

def sync_down_logs(self, handle: ResourceHandle, job_id: int) -> None:
code = backend_utils.JobLibCodeGen.get_log_path(job_id)
code = job_lib.JobLibCodeGen.get_log_path(job_id)
returncode, log_dir, stderr = self.run_on_head(handle,
code,
stream_logs=False,
@@ -1770,7 +1777,7 @@ def _exec_code_on_head(
f'{backend_utils.RESET_BOLD}')

def tail_logs(self, handle: ResourceHandle, job_id: int) -> None:
code = backend_utils.JobLibCodeGen.tail_logs(job_id)
code = job_lib.JobLibCodeGen.tail_logs(job_id)
logger.info(f'{colorama.Fore.YELLOW}Start streaming logs...'
f'{colorama.Style.RESET_ALL}')

@@ -1795,8 +1802,8 @@ def tail_logs(self, handle: ResourceHandle, job_id: int) -> None:

def _add_job(self, handle: ResourceHandle, job_name: str) -> int:
username = getpass.getuser()
code = backend_utils.JobLibCodeGen.add_job(job_name, username,
self.run_timestamp)
code = job_lib.JobLibCodeGen.add_job(job_name, username,
self.run_timestamp)
returncode, job_id_str, stderr = self.run_on_head(handle,
code,
stream_logs=False,
@@ -1830,13 +1837,13 @@ def execute(

job_id = self._add_job(handle, task.name)

# Case: Task(run, num_nodes=1)
# Case: task_lib.Task(run, num_nodes=1)
if task.num_nodes == 1:
return self._execute_task_one_node(handle, task, job_id, detach_run)

# Case: Task(run, num_nodes=N)
assert task.num_nodes > 1, task.num_nodes
return self._execute_task_n_nodes(handle, task, job_id, detach_run)
self._execute_task_one_node(handle, task, job_id, detach_run)
else:
# Case: task_lib.Task(run, num_nodes=N)
assert task.num_nodes > 1, task.num_nodes
self._execute_task_n_nodes(handle, task, job_id, detach_run)

def _execute_task_one_node(self, handle: ResourceHandle,
task: task_lib.Task, job_id: int,
@@ -1982,7 +1989,7 @@ def teardown_no_lock(self,
log_abs_path = os.path.abspath(log_path)
cloud = handle.launched_resources.cloud
config = backend_utils.read_yaml(handle.cluster_yaml)
prev_status = global_user_state.get_status_from_cluster_name(
prev_status = backend_utils.get_status_from_cluster_name(
handle.cluster_name)
cluster_name = handle.cluster_name
if terminate and isinstance(cloud, clouds.Azure):
@@ -2103,8 +2110,8 @@ def teardown_no_lock(self,
backend_utils.SSHConfigHelper.remove_cluster(cluster_name,
handle.head_ip,
auth_config)
name = handle.cluster_name
global_user_state.remove_cluster(name, terminate=terminate)
global_user_state.remove_cluster(handle.cluster_name,
terminate=terminate)

if terminate:
# Clean up generated config