Skip to content

Commit

Permalink
add default env
Browse files Browse the repository at this point in the history
  • Loading branch information
carolineechen committed Apr 3, 2024
1 parent 7d10f72 commit b950de1
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 22 deletions.
5 changes: 4 additions & 1 deletion runhouse/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
import webbrowser
from pathlib import Path
from typing import List, Optional
from typing import Dict, List, Optional

import ray

Expand Down Expand Up @@ -140,6 +140,9 @@ def _print_status(config):

first_info_to_print = ["den_auth", "server_connection_type", "server_port"]

if config["default_env"] and isinstance(config["default_env"], Dict):
config["default_env"] = config["default_env"]["name"]

for info in config:
if info in first_info_to_print:
console.print(f"\u2022 {info}: {config[info]}")
Expand Down
5 changes: 3 additions & 2 deletions runhouse/resources/envs/conda_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,10 @@ def install(self, force=False, cluster=None):
on the cluster using SSH. (default: ``None``)
"""
if not any(["python" in dep for dep in self.conda_yaml["dependencies"]]):
base_python_version = run_setup_command(
status_codes = run_setup_command(
"python --version", cluster=cluster, stream_logs=False
)[1].split()[1]
)
base_python_version = status_codes[1].split()[1] if status_codes[0] == 0 else "3.10.9"
self.conda_yaml["dependencies"].append(f"python=={base_python_version}")
install_conda(cluster=cluster)
local_env_exists = (
Expand Down
85 changes: 71 additions & 14 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
DEFAULT_HTTPS_PORT,
DEFAULT_RAY_PORT,
DEFAULT_SERVER_PORT,
ENVS_DIR,
LOCALHOST,
RESERVED_SYSTEM_NAMES,
)
Expand All @@ -53,6 +54,7 @@ def __init__(
name: Optional[str] = None,
ips: List[str] = None,
creds: "Secret" = None,
default_env: "Env" = None,
server_host: str = None,
server_port: int = None,
ssh_port: int = None,
Expand Down Expand Up @@ -95,6 +97,8 @@ def __init__(
self.domain = domain
self.use_local_telemetry = use_local_telemetry

self._default_env = _get_env_from(default_env)

@property
def address(self):
return self.ips[0] if isinstance(self.ips, List) else None
Expand All @@ -110,6 +114,19 @@ def creds_values(self) -> Dict:
return {}

return self._creds.values

@property
def default_env(self):
return self._default_env

@default_env.setter
def default_env(self, env):
self._default_env = _get_env_from(env)
if self.is_up():
self.check_server()
self._sync_default_env_to_cluster()
self.put_resource(self._default_env)
self.save_config_to_cluster()

def save_config_to_cluster(self, node: str = None):
config = self.config(condensed=False)
Expand Down Expand Up @@ -147,11 +164,15 @@ def save(
return self

def _save_sub_resources(self):
from runhouse.resources.envs import Env
from runhouse.resources.secrets import Secret

if self._creds and isinstance(self._creds, Secret):
self._creds.save()

if self._default_env and isinstance(self._default_env, Env):
self._default_env.save()

@classmethod
def from_config(cls, config: dict, dryrun=False):
resource_subtype = config.get("resource_subtype")
Expand Down Expand Up @@ -196,6 +217,12 @@ def config(self, condensed=True):

config["creds"] = creds

if self._default_env:
default_env = self._resource_string_for_subconfig(
self._default_env, condensed
)
config["default_env"] = default_env

if self._use_custom_certs:
config["ssl_certfile"] = self.cert_config.cert_path
config["ssl_keyfile"] = self.cert_config.key_path
Expand Down Expand Up @@ -309,13 +336,22 @@ def keep_warm(self):
)
return self

def _sync_default_env_to_cluster(self):
if not self._default_env:
return

logging.info(f"Syncing default env {self._default_env.name} to cluster")
self._default_env.install(cluster=self)

def _sync_runhouse_to_cluster(self, _install_url=None, env=None):
if self.on_this_cluster():
return

if not self.address:
raise ValueError(f"No address set for cluster <{self.name}>. Is it up?")

env = env or self._default_env

local_rh_package_path = Path(importlib.util.find_spec("runhouse").origin).parent

# Check if runhouse is installed from source and has setup.py
Expand Down Expand Up @@ -351,7 +387,10 @@ def _sync_runhouse_to_cluster(self, _install_url=None, env=None):

for node in self.ips:
status_codes = self.run(
[rh_install_cmd], node=node, env=env, stream_logs=True
[rh_install_cmd],
node=node,
env=env,
stream_logs=True,
)

if status_codes[0][0] != 0:
Expand All @@ -376,7 +415,9 @@ def install_packages(
from runhouse.resources.envs.env import Env

self.check_server()
env = _get_env_from(env) or Env(name=env or Env.DEFAULT_NAME)
env = _get_env_from(env or self._default_env) or Env(
name=env or Env.DEFAULT_NAME
)
env.reqs = env._reqs + reqs
env.to(self)

Expand Down Expand Up @@ -410,11 +451,12 @@ def add_secrets(
):
"""Copy secrets from current environment onto the cluster"""
self.check_server()
self.sync_secrets(provider_secrets, env=env)
self.sync_secrets(provider_secrets, env=env or self._default_env)

def put(self, key: str, obj: Any, env=None):
"""Put the given object on the cluster's object store at the given key."""
self.check_server()
env = env or self._default_env
if self.on_this_cluster():
return obj_store.put(key, obj, env=env)
return self.client.put_object(key, obj, env=env)
Expand All @@ -431,7 +473,7 @@ def put_resource(
if hasattr(resource, "env")
else resource.name or resource.env_name
if resource.RESOURCE_TYPE == "env"
else None
else self._default_env
)

if env and not isinstance(env, str):
Expand Down Expand Up @@ -460,7 +502,7 @@ def keys(self, env=None):
self.check_server()
if self.on_this_cluster():
return obj_store.keys()
res = self.client.keys(env=env)
res = self.client.keys(env=env or self._default_env)
return res

def delete(self, keys: Union[None, str, List[str]]):
Expand Down Expand Up @@ -656,7 +698,7 @@ def _use_custom_certs(self):
and a domain is provided."""
return self._use_https and not (self._use_caddy and self.domain is not None)

def _start_ray_workers(self, ray_port):
def _start_ray_workers(self, ray_port, env):
for host in self.ips:
if host == self.address:
# This is the master node, skip
Expand All @@ -669,14 +711,14 @@ def _start_ray_workers(self, ray_port):
f"ray start --address={self.address}:{ray_port}",
],
node=host,
env=env,
)

def restart_server(
self,
_rh_install_url: str = None,
resync_rh: bool = True,
restart_ray: bool = False,
env: Union[str, "Env"] = None,
restart_proxy: bool = False,
):
"""Restart the RPC server.
Expand All @@ -692,8 +734,14 @@ def restart_server(
"""
logger.info(f"Restarting Runhouse API server on {self.name}.")

default_env = _get_env_from(self._default_env) if self._default_env else None
if default_env:
self._sync_default_env_to_cluster()

if resync_rh:
self._sync_runhouse_to_cluster(_install_url=_rh_install_url)
self._sync_runhouse_to_cluster(
_install_url=_rh_install_url, env=default_env
)
logger.debug("Finished syncing Runhouse to cluster.")

https_flag = self._use_https
Expand Down Expand Up @@ -749,7 +797,7 @@ def restart_server(
+ f" --port {self.server_port}"
)

status_codes = self.run(commands=[cmd], env=env)
status_codes = self.run(commands=[cmd], env=self._default_env, node=self.address)
if not status_codes[0][0] == 0:
raise ValueError(f"Failed to restart server {self.name}.")

Expand All @@ -766,7 +814,10 @@ def restart_server(
self.client.use_https = https_flag

if restart_ray and len(self.ips) > 1:
self._start_ray_workers(DEFAULT_RAY_PORT)
self._start_ray_workers(DEFAULT_RAY_PORT, env=self._default_env)

if default_env:
self.put_resource(default_env)

return status_codes

Expand All @@ -779,7 +830,7 @@ def stop_server(self, stop_ray: bool = True, env: Union[str, "Env"] = None):
"""
cmd = CLI_STOP_CMD if stop_ray else f"{CLI_STOP_CMD} --no-stop-ray"

status_codes = self.run([cmd], env=env, stream_logs=False)
status_codes = self.run([cmd], env=env or self._default_env, stream_logs=False)
assert status_codes[0][0] == 1

@contextlib.contextmanager
Expand Down Expand Up @@ -1077,6 +1128,11 @@ def run(
>>> cpu.run(["python script.py"], run_name="my_exp")
>>> cpu.run(["python script.py"], node="3.89.174.234")
"""
if isinstance(commands, str):
commands = [commands]

env = env or self._default_env

if node == "all":
res_list = []
for node in self.ips:
Expand Down Expand Up @@ -1167,8 +1223,8 @@ def run(
def _run_commands_with_ssh(
self,
commands: list,
cmd_prefix: str,
stream_logs: bool,
cmd_prefix: str = None,
stream_logs: bool = True,
node: str = None,
port_forward: int = None,
require_outputs: bool = True,
Expand Down Expand Up @@ -1271,7 +1327,7 @@ def run_python(
# If invoking a run as part of the python commands also return the Run object
return_codes = self.run(
[formatted_command],
env=env,
env=env or self._default_env,
stream_logs=stream_logs,
node=node,
port_forward=port_forward,
Expand All @@ -1297,6 +1353,7 @@ def sync_secrets(
self.check_server()
from runhouse.resources.secrets import Secret

env = env or self._default_env
if isinstance(env, str):
from runhouse.resources.envs import Env

Expand Down
10 changes: 10 additions & 0 deletions runhouse/resources/hardware/cluster_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def cluster(
ssl_certfile: str = None,
domain: str = None,
den_auth: bool = False,
default_env: Union["Env", str] = None,
dryrun: bool = False,
**kwargs,
) -> Union[Cluster, OnDemandCluster, SageMakerCluster]:
Expand Down Expand Up @@ -55,6 +56,7 @@ def cluster(
den_auth (bool, optional): Whether to use Den authorization on the server. If ``True``, will validate incoming
requests with a Runhouse token provided in the auth headers of the request with the format:
``{"Authorization": "Bearer <token>"}``. (Default: ``False``).
# TODO [CC]: Add default_env docstrings
dryrun (bool): Whether to create the Cluster if it doesn't exist, or load a Cluster object as a dryrun.
(Default: ``False``)
Expand Down Expand Up @@ -95,6 +97,7 @@ def cluster(
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
domain=domain,
default_env=default_env,
kwargs=kwargs if len(kwargs) > 0 else None,
)
# Filter out None/default values
Expand Down Expand Up @@ -125,6 +128,7 @@ def cluster(
ssl_certfile=ssl_certfile,
domain=domain,
den_auth=den_auth,
default_env=default_env,
dryrun=dryrun,
**kwargs,
)
Expand Down Expand Up @@ -153,6 +157,7 @@ def cluster(
ssl_certfile=ssl_certfile,
domain=domain,
den_auth=den_auth,
default_env=default_env,
dryrun=dryrun,
**kwargs,
)
Expand All @@ -176,6 +181,7 @@ def cluster(
ssl_certfile=ssl_certfile,
domain=domain,
den_auth=den_auth,
default_env=default_env,
dryrun=dryrun,
**kwargs,
)
Expand Down Expand Up @@ -308,6 +314,7 @@ def ondemand_cluster(
ssl_certfile: str = None,
domain: str = None,
den_auth: bool = None,
default_env: Union["Env", str] = None,
dryrun: bool = False,
**kwargs,
) -> OnDemandCluster:
Expand Down Expand Up @@ -467,6 +474,7 @@ def sagemaker_cluster(
ssl_certfile: str = None,
domain: str = None,
den_auth: bool = False,
default_env: Union["Env", str] = None,
dryrun: bool = False,
**kwargs,
) -> SageMakerCluster:
Expand Down Expand Up @@ -595,6 +603,7 @@ def sagemaker_cluster(
ssl_certfile=ssl_certfile,
domain=domain,
den_auth=den_auth,
default_env=default_env,
)
# Filter out None/default values
alt_options = {k: v for k, v in alt_options.items() if v is not None}
Expand Down Expand Up @@ -632,6 +641,7 @@ def sagemaker_cluster(
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
domain=domain,
default_env=default_env,
dryrun=dryrun,
**kwargs,
)
Expand Down
2 changes: 2 additions & 0 deletions runhouse/resources/hardware/on_demand_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
instance_type: str = None,
num_instances: int = None,
provider: str = None,
default_env: "Env" = None,
dryrun=False,
autostop_mins=None,
use_spot=False,
Expand All @@ -66,6 +67,7 @@ def __init__(
"""
super().__init__(
name=name,
default_env=default_env,
server_host=server_host,
server_port=server_port,
server_connection_type=server_connection_type,
Expand Down
Loading

0 comments on commit b950de1

Please sign in to comment.