Skip to content

Commit

Permalink
reloading ondemand cluster from sky db + cluster config
Browse files Browse the repository at this point in the history
  • Loading branch information
jlewitt1 committed Aug 12, 2024
1 parent c1f6563 commit 3f71c85
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 3 deletions.
Empty file added -H
Empty file.
37 changes: 37 additions & 0 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import importlib
import json
import logging
import os
import re
import subprocess
import threading
Expand Down Expand Up @@ -182,6 +183,7 @@ def from_name(
):
cluster = super().from_name(
name=name,
load_from_den=load_from_den,
dryrun=dryrun,
alt_options=alt_options,
_resolve_children=_resolve_children,
Expand All @@ -193,6 +195,41 @@ def from_name(
pass
return cluster

@classmethod
def _from_sky_db(cls, name, dryrun=False):
    """Rebuild an on-demand cluster from the local Sky DB and the cluster's own config.

    The head node IP is read from Sky's local state, then the cluster config is
    fetched from the ``/config`` endpoint of the HTTP server on that node.

    Args:
        name: Cluster name or rns address. Only its final path component is used
            for the Sky lookup.
        dryrun: Forwarded to the ``OnDemandCluster`` constructor.

    Returns:
        An ``OnDemandCluster`` built from the config reported by the cluster.

    Raises:
        ValueError: If Sky has no record of the cluster, or the record has no head IP.
        ConnectionError: If the server cannot be reached or answers with a non-200 status.
        TimeoutError: If the request to the server times out.
    """
    import sky

    from runhouse import OnDemandCluster

    sky_cluster_name = os.path.basename(name)
    sky_data: dict = sky.global_user_state.get_cluster_from_name(sky_cluster_name)
    if not sky_data:
        raise ValueError(f"Cluster {name} not found locally")

    # A record without a handle is treated like a missing IP instead of
    # surfacing as an unrelated AttributeError.
    handle = sky_data.get("handle")
    cluster_ip = getattr(handle, "head_ip", None)
    if not cluster_ip:
        raise ValueError(
            f"Failed to load cluster {name} locally, IP address not found"
        )

    try:
        resp = requests.get(
            f"http://{cluster_ip}:{DEFAULT_SERVER_PORT}/config",
            headers=rns_client.request_headers(),
            timeout=5,
        )
    except requests.exceptions.Timeout as e:
        # Server might not be up, port may not be open
        raise TimeoutError(f"Request to cluster {name} failed") from e
    except requests.exceptions.ConnectionError as e:
        # requests' ConnectionError is not the builtin one; translate it so callers
        # catching the builtin ConnectionError also see refused connections.
        raise ConnectionError(
            f"Failed to load cluster {name} from IP address: {cluster_ip}"
        ) from e

    if resp.status_code != 200:
        raise ConnectionError(
            f"Failed to load cluster {name} from IP address: {cluster_ip}"
        )

    # The /config endpoint returns the config wrapped in a response envelope under
    # "data"; unwrap it when present, otherwise use the payload as-is.
    payload = resp.json()
    cluster_config = payload
    if isinstance(payload, dict) and isinstance(payload.get("data"), dict):
        cluster_config = payload["data"]

    return OnDemandCluster(**cluster_config, dryrun=dryrun)

def save_config_to_cluster(
self,
node: str = None,
Expand Down
13 changes: 10 additions & 3 deletions runhouse/resources/hardware/cluster_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,11 @@ def cluster(
if den_auth:
c.save()
return c
except ValueError as e:
except (ValueError, ConnectionError, TimeoutError) as e:
if not alt_options:
if not load_from_den:
# Cluster might be of type on-demand with metadata saved in the local Sky DB
return OnDemandCluster._from_sky_db(name, dryrun)
raise e

if ssh_creds:
Expand Down Expand Up @@ -477,13 +480,17 @@ def ondemand_cluster(
if den_auth:
c.save()
return c
except ValueError as e:
except (ValueError, ConnectionError, TimeoutError) as e:
import sky

state = sky.status(cluster_names=[name], refresh=False)
if len(state) == 0 and not alt_options:
raise e

if not load_from_den:
# Try loading the cluster using data from the local Sky DB
return OnDemandCluster._from_sky_db(name, dryrun)

c = OnDemandCluster(
instance_type=instance_type,
provider=provider,
Expand Down Expand Up @@ -684,7 +691,7 @@ def sagemaker_cluster(
if c:
c.set_connection_defaults()
return c
except ValueError as e:
except (ValueError, ConnectionError, TimeoutError) as e:
if not alt_options:
raise e

Expand Down
4 changes: 4 additions & 0 deletions runhouse/servers/http/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,10 @@ def status(self, resource_address: str):
# Note: Resource address must be specified in order to construct the cluster subtoken
return self.request("status", req_type="get", resource_address=resource_address)

def config(self):
    """Fetch the remote cluster's config via a GET request to the ``config`` endpoint."""
    remote_config = self.request("config", req_type="get")
    return remote_config

def get_certificate(self):
cert: bytes = self.request(
"cert",
Expand Down
16 changes: 16 additions & 0 deletions runhouse/servers/http/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,22 @@ def check_server():
serialization=serialization,
)

@staticmethod
@app.get("/config")
@validate_cluster_access
def cluster_config(request: Request):
    """Serve the cluster config held by the object store.

    On ``AttributeError`` or ``ObjStoreError`` the failure is converted into an
    exception response instead of propagating.
    """
    try:
        stored_config = obj_store.get_cluster_config()
        response = Response(
            data=stored_config,
            output_type=OutputType.RESULT_SERIALIZED,
            serialization=None,
        )
        return response
    except (AttributeError, ObjStoreError) as err:
        formatted_tb = traceback.format_exc()
        return handle_exception_response(err, formatted_tb, from_http_server=True)

@staticmethod
@app.post("/settings")
@validate_cluster_access
Expand Down
7 changes: 7 additions & 0 deletions tests/test_resources/test_clusters/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,3 +896,10 @@ def test_switch_default_env(self, cluster):
# set it back
cluster.default_env = test_env
cluster.delete(new_env.name)

@pytest.mark.level("release")
@pytest.mark.clustertest
def test_load_cluster_from_sky_db(self, cluster):
    """A cluster reloaded with ``load_from_den=False`` reports the same IPs as the original."""
    # With Den bypassed, the reload relies on the IP stored in the Sky DB and the
    # config stored on the cluster.
    reload_kwargs = {"name": cluster.rns_address, "load_from_den": False}
    reloaded = rh.cluster(**reload_kwargs)
    assert reloaded.ips == cluster.ips
9 changes: 9 additions & 0 deletions tests/test_servers/test_http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ def test_check_server(self, http_client):
response = http_client.get("/check")
assert response.status_code == 200

@pytest.mark.level("local")
def test_check_cluster_config(self, http_client, cluster):
    """GET ``/config`` succeeds and reports the cluster's IPs under ``data``."""
    auth_headers = rns_client.request_headers(cluster.rns_address)
    response = http_client.get("/config", headers=auth_headers)
    assert response.status_code == 200

    payload = response.json()
    served_config = payload.get("data")
    assert served_config["ips"] == cluster.ips

@pytest.mark.level("local")
def test_put_resource(self, http_client, blob_data, cluster):
state = None
Expand Down

0 comments on commit 3f71c85

Please sign in to comment.