From 8d80d55f1940cf63b23584d6ef1cf54c1027f1a9 Mon Sep 17 00:00:00 2001 From: Alexandra Belousov Date: Mon, 15 Apr 2024 11:41:30 +0300 Subject: [PATCH] minor change to docker cluster fixture --- runhouse/resources/hardware/cluster.py | 6 ++- .../resources/hardware/cluster_factory.py | 2 + tests/fixtures/docker_cluster_fixtures.py | 48 +++++++++++++++++-- 3 files changed, 50 insertions(+), 6 deletions(-) diff --git a/runhouse/resources/hardware/cluster.py b/runhouse/resources/hardware/cluster.py index 1275018cb..cb72c2aec 100644 --- a/runhouse/resources/hardware/cluster.py +++ b/runhouse/resources/hardware/cluster.py @@ -65,6 +65,7 @@ def __init__( den_auth: bool = False, use_local_telemetry: bool = False, dryrun=False, + api_server_url=None, **kwargs, # We have this here to ignore extra arguments when calling from from_config ): """ @@ -95,6 +96,7 @@ def __init__( self.server_host = server_host self.domain = domain self.use_local_telemetry = use_local_telemetry + self.api_server_url = api_server_url @property def address(self): @@ -186,6 +188,7 @@ def config(self, condensed=True): "use_local_telemetry", "ssh_port", "client_port", + "api_server_url", ], ) creds = self._resource_string_for_subconfig(self._creds, condensed) @@ -197,7 +200,8 @@ def config(self, condensed=True): creds = creds.replace("loaded_secret_", "") config["creds"] = creds - config["api_server_url"] = rns_client.api_server_url + if config.get("api_server_url") is None: + config["api_server_url"] = rns_client.api_server_url if self._use_custom_certs: config["ssl_certfile"] = self.cert_config.cert_path diff --git a/runhouse/resources/hardware/cluster_factory.py b/runhouse/resources/hardware/cluster_factory.py index 5236c9ff2..e2b3dd36b 100644 --- a/runhouse/resources/hardware/cluster_factory.py +++ b/runhouse/resources/hardware/cluster_factory.py @@ -29,6 +29,7 @@ def cluster( domain: str = None, den_auth: bool = False, dryrun: bool = False, + api_server_url=None, **kwargs, ) -> Union[Cluster, OnDemandCluster, SageMakerCluster]: """ @@ -177,6 +178,7 @@ def cluster( domain=domain, den_auth=den_auth, dryrun=dryrun, + api_server_url=api_server_url, **kwargs, ) c.set_connection_defaults(**kwargs) diff --git a/tests/fixtures/docker_cluster_fixtures.py b/tests/fixtures/docker_cluster_fixtures.py index 472be0b4f..22d247b7f 100644 --- a/tests/fixtures/docker_cluster_fixtures.py +++ b/tests/fixtures/docker_cluster_fixtures.py @@ -23,6 +23,8 @@ LOCAL_HTTPS_SERVER_PORT = 8443 LOCAL_HTTP_SERVER_PORT = 8080 DEFAULT_KEYPAIR_KEYPATH = "~/.ssh/sky-key" +DEN_DOCKER_API_SERVER_URL = "http://den_rns" +DEN_DOCKER_NETWORK_NAME = "rns_server_den_oss_network" def get_rh_parent_path(): @@ -80,6 +82,7 @@ def build_and_run_image( container_name: str, reuse_existing_container: bool, dir_name: str, + api_server_url: str, pwd_file=None, keypath=None, force_rebuild=False, @@ -167,11 +170,36 @@ def build_and_run_image( port_fwds = ( "".join([f"-p {port_fwd} " for port_fwd in port_fwds]).strip().split(" ") ) - run_cmd = ( - ["docker", "run", "--name", container_name, "-d", "--rm", "--shm-size=5.04gb"] - + port_fwds - + [f"runhouse:{image_name}"] - ) + if api_server_url == DEN_DOCKER_API_SERVER_URL: + run_cmd = ( + [ + "docker", + "run", + "--name", + container_name, + f"--network={DEN_DOCKER_NETWORK_NAME}", + "-d", + "--rm", + "--shm-size=5.04gb", + ] + + port_fwds + + [f"runhouse:{image_name}"] + ) + else: + run_cmd = ( + [ + "docker", + "run", + "--name", + container_name, + "-d", + "--rm", + "--shm-size=5.04gb", + ] + + port_fwds + + [f"runhouse:{image_name}"] + ) + print(shlex.join(run_cmd)) res = popen_shell_command(subprocess, run_cmd, cwd=str(rh_parent_path.parent)) stdout, stderr = res.communicate() @@ -227,6 +255,7 @@ def set_up_local_cluster( port_fwds: List[str], local_ssh_port: int, additional_cluster_init_args: Dict[str, Any], + api_server_url: str, logged_in: bool = False, keypath: str = None, pwd_file: str = None, @@ -240,6 +269,7 @@ def set_up_local_cluster( pwd_file=pwd_file, force_rebuild=force_rebuild, port_fwds=port_fwds, + api_server_url=api_server_url, ) cluster_init_args = dict( @@ -250,6 +280,7 @@ def set_up_local_cluster( "ssh_user": SSH_USER, "ssh_private_key": keypath, }, + api_server_url=api_server_url, ) for k, v in additional_cluster_init_args.items(): @@ -270,6 +301,8 @@ def set_up_local_cluster( config["token"] = rh.configs.token config["username"] = rh.configs.username + # if restarting a server when running the tests in a den docker env (aka api_server_url = http://den_rns:8000), + # putting resources in clusters with caddy causing errors. if rh_cluster._use_https: # If re-using fixtures make sure the crt file gets copied on to the cluster rh_cluster.restart_server() @@ -343,6 +376,7 @@ def docker_cluster_pk_tls_exposed(request, test_rns_folder): "den_auth": True, "use_local_telemetry": True, }, + api_server_url=api_server_url, ) # Yield the cluster @@ -392,6 +426,7 @@ def docker_cluster_pk_ssh(request, test_rns_folder): "server_connection_type": "ssh", "use_local_telemetry": True, }, + api_server_url=api_server_url, ) # Yield the cluster @@ -489,6 +524,7 @@ def docker_cluster_pk_http_exposed(request, test_rns_folder): "den_auth": False, "use_local_telemetry": True, }, + api_server_url=api_server_url, ) # Yield the cluster yield local_cluster @@ -537,6 +573,7 @@ def docker_cluster_pwd_ssh_no_auth(request, test_rns_folder): "server_connection_type": "ssh", "ssh_creds": {"ssh_user": SSH_USER, "password": pwd}, }, + api_server_url=api_server_url, ) # Yield the cluster yield local_cluster @@ -585,6 +622,7 @@ def friend_account_logged_in_docker_cluster_pk_ssh(request, test_rns_folder): "den_auth": "den_auth" in request.keywords, }, logged_in=True, + api_server_url=api_server_url, ) yield local_cluster