Skip to content

Commit

Permalink
cluster init env
Browse files Browse the repository at this point in the history
  • Loading branch information
carolineechen committed Mar 28, 2024
1 parent e90ee86 commit b033875
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 61 deletions.
8 changes: 8 additions & 0 deletions runhouse/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
LOGS_DIR = ".rh/logs"
RH_LOGFILE_PATH = Path.home() / LOGS_DIR

ENVS_DIR = "~/.rh/envs"

MAX_MESSAGE_LENGTH = 1 * 1024 * 1024 * 1024 # 1 GB

CLI_RESTART_CMD = "runhouse restart"
Expand Down Expand Up @@ -43,3 +45,9 @@
# We need to use this instead of ray stop to make sure we don't stop the SkyPilot ray server,
# which runs on other ports but is required to preserve autostop and correct cluster status.
RAY_KILL_CMD = 'pkill -f ".*ray.*' + str(DEFAULT_RAY_PORT) + '.*"'

CONDA_INSTALL_CMDS = [
"wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh",
"bash ~/miniconda.sh -b -p ~/miniconda",
"source $HOME/miniconda3/bin/activate",
]
5 changes: 4 additions & 1 deletion runhouse/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
import webbrowser
from pathlib import Path
from typing import List, Optional
from typing import Dict, List, Optional

import ray

Expand Down Expand Up @@ -140,6 +140,9 @@ def _print_status(config):

first_info_to_print = ["den_auth", "server_connection_type", "server_port"]

if config["default_env"] and isinstance(config["default_env"], Dict):
config["defualt_env"] = config["default_env"]["name"]

for info in config:
if info in first_info_to_print:
console.print(f"\u2022 {info}: {config[info]}")
Expand Down
5 changes: 3 additions & 2 deletions runhouse/resources/envs/conda_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import Dict, List, Optional, Union

from runhouse.constants import ENVS_DIR
from runhouse.globals import obj_store

from runhouse.resources.packages import Package
Expand Down Expand Up @@ -63,12 +64,12 @@ def env_name(self):
return self.conda_yaml["name"]

def _create_conda_env(self, force=False):
path = "~/.rh/envs"
subprocess.run(f"mkdir -p {path}", shell=True)
path = ENVS_DIR

local_env_exists = f"\n{self.env_name} " in subprocess.check_output(
shlex.split("conda info --envs"), shell=False
).decode("utf-8")
subprocess.run(f"mkdir -p {path}", shell=True)
yaml_exists = (Path(path).expanduser() / f"{self.env_name}.yml").exists()

if force or not (yaml_exists and local_env_exists):
Expand Down
4 changes: 2 additions & 2 deletions runhouse/resources/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ def _get_conda_yaml(conda_env=None):
for dep in conda_yaml["dependencies"]
if isinstance(dep, Dict) and "pip" in dep
]:
conda_yaml["dependencies"].append({"pip": ["ray<=2.4.0,>=2.2.0"]})
conda_yaml["dependencies"].append({"pip": ["ray >= 2.2.0, <= 2.6.3, != 2.6.0"]})
else:
for dep in conda_yaml["dependencies"]:
if (
isinstance(dep, Dict)
and "pip" in dep
and not [pip for pip in dep["pip"] if "ray" in pip]
):
dep["pip"].append("ray<=2.4.0,>=2.2.0")
dep["pip"].append("ray >= 2.2.0, <= 2.6.3, != 2.6.0")
continue
return conda_yaml

Expand Down
124 changes: 111 additions & 13 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
CLI_RESTART_CMD,
CLI_STOP_CMD,
CLUSTER_CONFIG_PATH,
CONDA_INSTALL_CMDS,
DEFAULT_HTTP_PORT,
DEFAULT_HTTPS_PORT,
DEFAULT_RAY_PORT,
DEFAULT_SERVER_PORT,
ENVS_DIR,
LOCALHOST,
RESERVED_SYSTEM_NAMES,
)
Expand All @@ -53,6 +55,7 @@ def __init__(
name: Optional[str] = None,
ips: List[str] = None,
creds: "Secret" = None,
default_env: "Env" = None,
server_host: str = None,
server_port: int = None,
ssh_port: int = None,
Expand Down Expand Up @@ -95,6 +98,8 @@ def __init__(
self.domain = domain
self.use_local_telemetry = use_local_telemetry

self.default_env = _get_env_from(default_env)

@property
def address(self):
return self.ips[0] if isinstance(self.ips, List) else None
Expand Down Expand Up @@ -147,11 +152,15 @@ def save(
return self

def _save_sub_resources(self):
from runhouse.resources.envs import Env
from runhouse.resources.secrets import Secret

if self._creds and isinstance(self._creds, Secret):
self._creds.save()

if self.default_env and isinstance(self.default_env, Env):
self.default_env.save()

@classmethod
def from_config(cls, config: dict, dryrun=False):
resource_subtype = config.get("resource_subtype")
Expand Down Expand Up @@ -196,6 +205,12 @@ def config(self, condensed=True):

config["creds"] = creds

if self.default_env:
default_env = self._resource_string_for_subconfig(
self.default_env, condensed
)
config["default_env"] = default_env

if self._use_custom_certs:
config["ssl_certfile"] = self.cert_config.cert_path
config["ssl_keyfile"] = self.cert_config.key_path
Expand Down Expand Up @@ -309,13 +324,78 @@ def keep_warm(self):
)
return self

def _sync_default_env_to_cluster(self):
if not self.default_env:
return

from runhouse.resources.packages.package import Package

env = _get_env_from(self.default_env)
logging.info(f"Syncing default env {self.default_env.name} to cluster.")

if hasattr(env, "conda_yaml"):
status_codes = self._run_commands_with_ssh(["conda --version"])
if status_codes[0][0] != 0:
logging.info("Conda is not installed. Installing...")
for cmd in CONDA_INSTALL_CMDS:
self._run_commands_with_ssh([cmd])
if self._run_commands_with_ssh(["conda --version"])[0][0] != 0:
raise RuntimeError("Could not install Conda")

local_env_exists = (
f"\n{self.env_name} "
in self._run_commands_with_ssh(["conda info --envs"])[0][0]
)

self._run_commands_with_ssh([f"mkdir -p {ENVS_DIR}"])
yaml_path = Path(ENVS_DIR / f"{env.env_name}.yml")
yaml_exists = (
self._run_commands_with_ssh([f"mkdir -p {yaml_path}"])[0][0] == 0
)

if not (yaml_exists and local_env_exists):
import yaml

contents = yaml.dump(env.conda_yaml)
try:
self._run_commands_with_ssh(
[
f"echo $'{contents}' > {yaml_path}",
f"conda env create -f {yaml_path}.yml",
]
)
local_env_exists = (
f"\n{self.env_name} "
in self._run_commands_with_ssh(["conda info --envs"])[0][0]
)
except:
local_env_exists = False
if not local_env_exists:
raise RuntimeError(
f"conda env {self.env_name} not created properly."
)

cmd_prefix = env._run_cmd
for package in env.reqs:
package = (
Package.from_string(package) if isinstance(package, str) else package
)
self._run_commands_with_ssh(
[package._install_cmd()], cmd_prefix=cmd_prefix, stream_logs=True
)

if env.setup_cmds:
self._run_commands_with_ssh(env.setup_cmds, cmd_prefix=cmd_prefix)

def _sync_runhouse_to_cluster(self, _install_url=None, env=None):
if self.on_this_cluster():
return

if not self.address:
raise ValueError(f"No address set for cluster <{self.name}>. Is it up?")

env = env or self.default_env

local_rh_package_path = Path(importlib.util.find_spec("runhouse").origin).parent

# Check if runhouse is installed from source and has setup.py
Expand Down Expand Up @@ -351,7 +431,10 @@ def _sync_runhouse_to_cluster(self, _install_url=None, env=None):

for node in self.ips:
status_codes = self.run(
[rh_install_cmd], node=node, env=env, stream_logs=True
[rh_install_cmd],
node=node,
stream_logs=True,
env=env,
)

if status_codes[0][0] != 0:
Expand All @@ -376,7 +459,9 @@ def install_packages(
from runhouse.resources.envs.env import Env

self.check_server()
env = _get_env_from(env) or Env(name=env or Env.DEFAULT_NAME)
env = _get_env_from(env or self.default_env) or Env(
name=env or Env.DEFAULT_NAME
)
env.reqs = env._reqs + reqs
env.to(self)

Expand Down Expand Up @@ -410,11 +495,12 @@ def add_secrets(
):
"""Copy secrets from current environment onto the cluster"""
self.check_server()
self.sync_secrets(provider_secrets, env=env)
self.sync_secrets(provider_secrets, env=env or self.default_env)

def put(self, key: str, obj: Any, env=None):
"""Put the given object on the cluster's object store at the given key."""
self.check_server()
env = env or self.default_env
if self.on_this_cluster():
return obj_store.put(key, obj, env=env)
return self.client.put_object(key, obj, env=env)
Expand All @@ -431,7 +517,7 @@ def put_resource(
if hasattr(resource, "env")
else resource.name or resource.env_name
if resource.RESOURCE_TYPE == "env"
else None
else self.default_env
)

if env and not isinstance(env, str):
Expand Down Expand Up @@ -460,7 +546,7 @@ def keys(self, env=None):
self.check_server()
if self.on_this_cluster():
return obj_store.keys()
res = self.client.keys(env=env)
res = self.client.keys(env=env or self.default_env)
return res

def delete(self, keys: Union[None, str, List[str]]):
Expand Down Expand Up @@ -676,7 +762,6 @@ def restart_server(
_rh_install_url: str = None,
resync_rh: bool = True,
restart_ray: bool = False,
env: Union[str, "Env"] = None,
restart_proxy: bool = False,
):
"""Restart the RPC server.
Expand All @@ -692,8 +777,14 @@ def restart_server(
"""
logger.info(f"Restarting Runhouse API server on {self.name}.")

default_env = _get_env_from(self.default_env) if self.default_env else None
if default_env:
self._sync_default_env_to_cluster()

if resync_rh:
self._sync_runhouse_to_cluster(_install_url=_rh_install_url)
self._sync_runhouse_to_cluster(
_install_url=_rh_install_url, env=default_env
)
logger.debug("Finished syncing Runhouse to cluster.")

https_flag = self._use_https
Expand Down Expand Up @@ -749,7 +840,7 @@ def restart_server(
+ f" --port {self.server_port}"
)

status_codes = self.run(commands=[cmd], env=env)
status_codes = self.run(commands=[cmd], env=self.default_env, node=self.address)
if not status_codes[0][0] == 0:
raise ValueError(f"Failed to restart server {self.name}.")

Expand All @@ -768,6 +859,9 @@ def restart_server(
if restart_ray and len(self.ips) > 1:
self._start_ray_workers(DEFAULT_RAY_PORT)

if default_env:
self.put_resource(default_env)

return status_codes

def stop_server(self, stop_ray: bool = True, env: Union[str, "Env"] = None):
Expand All @@ -779,7 +873,7 @@ def stop_server(self, stop_ray: bool = True, env: Union[str, "Env"] = None):
"""
cmd = CLI_STOP_CMD if stop_ray else f"{CLI_STOP_CMD} --no-stop-ray"

status_codes = self.run([cmd], env=env, stream_logs=False)
status_codes = self.run([cmd], env=env or self.default_env, stream_logs=False)
assert status_codes[0][0] == 1

@contextlib.contextmanager
Expand Down Expand Up @@ -1077,12 +1171,15 @@ def run(
>>> cpu.run(["python script.py"], run_name="my_exp")
>>> cpu.run(["python script.py"], node="3.89.174.234")
"""
if isinstance(commands, str):
commands = [commands]

if node == "all":
res_list = []
for node in self.ips:
res = self.run(
commands=commands,
env=env,
env=env or self.default_env,
stream_logs=stream_logs,
port_forward=port_forward,
require_outputs=require_outputs,
Expand Down Expand Up @@ -1167,8 +1264,8 @@ def run(
def _run_commands_with_ssh(
self,
commands: list,
cmd_prefix: str,
stream_logs: bool,
cmd_prefix: str = None,
stream_logs: bool = True,
node: str = None,
port_forward: int = None,
require_outputs: bool = True,
Expand Down Expand Up @@ -1271,7 +1368,7 @@ def run_python(
# If invoking a run as part of the python commands also return the Run object
return_codes = self.run(
[formatted_command],
env=env,
env=env or self.default_env,
stream_logs=stream_logs,
node=node,
port_forward=port_forward,
Expand All @@ -1297,6 +1394,7 @@ def sync_secrets(
self.check_server()
from runhouse.resources.secrets import Secret

env = env or self.default_env
if isinstance(env, str):
from runhouse.resources.envs import Env

Expand Down
Loading

0 comments on commit b033875

Please sign in to comment.