From cf7478db39b9e71f23451325b06ed32e0e9b9972 Mon Sep 17 00:00:00 2001 From: Rohin Bhasin Date: Fri, 12 Apr 2024 10:37:02 -0400 Subject: [PATCH] Add animated spinners to long running operations. --- runhouse/resources/hardware/cluster.py | 13 +- runhouse/resources/hardware/sky_ssh_runner.py | 72 ++--- runhouse/resources/module.py | 14 +- runhouse/resources/packages/package.py | 16 +- runhouse/servers/http/http_client.py | 289 +++++++++--------- runhouse/utils.py | 28 ++ setup.py | 1 + 7 files changed, 240 insertions(+), 193 deletions(-) diff --git a/runhouse/resources/hardware/cluster.py b/runhouse/resources/hardware/cluster.py index 7a2d37bde..66aba66ad 100644 --- a/runhouse/resources/hardware/cluster.py +++ b/runhouse/resources/hardware/cluster.py @@ -44,6 +44,7 @@ from runhouse.resources.resource import Resource from runhouse.servers.http import HTTPClient +from runhouse.utils import alive_bar_spinner_only, success_emoji logger = logging.getLogger(__name__) @@ -651,9 +652,15 @@ def check_server(self, restart_server=True): if not self.client: try: self.connect_server_client() - logger.debug(f"Checking server {self.name}") - self.client.check_server() - logger.info(f"Server {self.name} is up.") + with alive_bar_spinner_only( + title=f"Checking Runhouse server on cluster {self.name} is up..." + ) as bar: + self.client.check_server() + bar.title( + success_emoji( + f"Confirmed Runhouse server on cluster {self.name} is up" + ) + ) except ( requests.exceptions.ConnectionError, requests.exceptions.ReadTimeout, diff --git a/runhouse/resources/hardware/sky_ssh_runner.py b/runhouse/resources/hardware/sky_ssh_runner.py index 36b1b17b9..0c8c7a6e7 100644 --- a/runhouse/resources/hardware/sky_ssh_runner.py +++ b/runhouse/resources/hardware/sky_ssh_runner.py @@ -22,7 +22,7 @@ SSHCommandRunner, SshMode, ) - +from runhouse.utils import alive_bar_spinner_only, success_emoji logger = logging.getLogger(__name__) @@ -90,7 +90,7 @@ def _ssh_base_command( local, remote = fwd, fwd else: local, remote = fwd - logger.info(f"Forwarding port {local} to port {remote} on localhost.") + logger.debug(f"Forwarding port {local} to port {remote} on localhost.") ssh += ["-L", f"{local}:localhost:{remote}"] if self._docker_ssh_proxy_command is not None: docker_ssh_proxy_command = self._docker_ssh_proxy_command(ssh) @@ -472,39 +472,43 @@ def ssh_tunnel( ) return tunnel - while is_port_in_use(local_port): - if num_ports_to_try < 0: - raise Exception( - f"Failed to create find open port after {num_ports_to_try} attempts" - ) + with alive_bar_spinner_only(title="Creating ssh tunnel to remote host...") as bar: + while is_port_in_use(local_port): + if num_ports_to_try < 0: + raise Exception( + f"Failed to create find open port after {num_ports_to_try} attempts" + ) - logger.info(f"Port {local_port} is already in use. Trying next port.") - local_port += 1 - num_ports_to_try -= 1 - - # Start a tunnel using self.run in a thread, instead of ssh_tunnel - ssh_credentials = copy.copy(ssh_creds) - - # Host could be a proxy specified in credentials or is the provided address - host = ssh_credentials.pop("ssh_host", address) - ssh_control_name = ssh_credentials.pop("ssh_control_name", f"{address}:{ssh_port}") - - runner = SkySSHRunner( - ip=host, - ssh_user=ssh_creds.get("ssh_user"), - ssh_private_key=ssh_creds.get("ssh_private_key"), - ssh_proxy_command=ssh_creds.get("ssh_proxy_command"), - ssh_control_name=ssh_control_name, - port=ssh_port, - ) - runner.tunnel(local_port, remote_port) - - logger.debug( - f"Successfully bound " - f"{LOCALHOST}:{remote_port} via ssh port {ssh_port} " - f"on remote server {address} " - f"to {LOCALHOST}:{local_port} on local machine." - ) + logger.info(f"Port {local_port} is already in use. Trying next port.") + local_port += 1 + num_ports_to_try -= 1 + + # Start a tunnel using self.run in a thread, instead of ssh_tunnel + ssh_credentials = copy.copy(ssh_creds) + + # Host could be a proxy specified in credentials or is the provided address + host = ssh_credentials.pop("ssh_host", address) + ssh_control_name = ssh_credentials.pop( + "ssh_control_name", f"{address}:{ssh_port}" + ) + + runner = SkySSHRunner( + ip=host, + ssh_user=ssh_creds.get("ssh_user"), + ssh_private_key=ssh_creds.get("ssh_private_key"), + ssh_proxy_command=ssh_creds.get("ssh_proxy_command"), + ssh_control_name=ssh_control_name, + port=ssh_port, + ) + runner.tunnel(local_port, remote_port) + + logger.info( + f"Successfully bound " + f"{LOCALHOST}:{remote_port} via ssh port {ssh_port} " + f"on remote server {address} " + f"to {LOCALHOST}:{local_port} on local machine." + ) + bar.title(success_emoji(f"SSH tunnel to {address} created successfully ")) cache_existing_sky_ssh_runner(address, ssh_port, runner) return runner diff --git a/runhouse/resources/module.py b/runhouse/resources/module.py index 1dc184aa6..5ec556591 100644 --- a/runhouse/resources/module.py +++ b/runhouse/resources/module.py @@ -28,6 +28,7 @@ from runhouse.rns.utils.names import _generate_default_name from runhouse.servers.http import HTTPClient from runhouse.servers.http.http_utils import CallParams +from runhouse.utils import alive_bar_spinner_only, success_emoji logger = logging.getLogger(__name__) @@ -497,10 +498,15 @@ def to( for attr, val in self.__dict__.items() if attr not in excluded_state_keys } - logger.info( - f"Sending module {new_module.name} of type {type(new_module)} to {system.name or 'local Runhouse daemon'}" - ) - system.put_resource(new_module, state, dryrun=True) + with alive_bar_spinner_only( + title=f"Sending module {new_module.name} of type {type(new_module)} to {system.name or 'local Runhouse daemon'}" + ) as bar: + system.put_resource(new_module, state, dryrun=True) + bar.title( + success_emoji( + f"Sending module {new_module.name} of type {type(new_module)} to {system.name or 'local Runhouse daemon'}" + ) + ) return new_module diff --git a/runhouse/resources/packages/package.py b/runhouse/resources/packages/package.py index 19c5d7e24..d90600837 100644 --- a/runhouse/resources/packages/package.py +++ b/runhouse/resources/packages/package.py @@ -10,6 +10,7 @@ from runhouse.resources.folders import Folder, folder from runhouse.resources.hardware.utils import _get_cluster_from from runhouse.resources.resource import Resource +from runhouse.utils import alive_bar_spinner_only, success_emoji INSTALL_METHODS = {"local", "reqs", "pip", "conda"} @@ -359,10 +360,17 @@ def to( # If we're on the target system, just make sure the package is in the Python path sys.path.append(self.install_target.local_path) return self - logger.info( - f"Copying package from {self.install_target.fsspec_url} to: {getattr(system, 'name', system)}" - ) - new_folder = self.install_target._to_cluster(system, path=path, mount=mount) + with alive_bar_spinner_only( + title=f"Copying package from {self.install_target.fsspec_url} to: {getattr(system, 'name', system)}" + ) as bar: + new_folder = self.install_target._to_cluster( + system, path=path, mount=mount + ) + bar.title( + success_emoji( + f"Package copied to: {getattr(system, 'name', system)}" + ) + ) else: # to fs new_folder = self.install_target.to(system, path=path) new_folder.system = system diff --git a/runhouse/servers/http/http_client.py b/runhouse/servers/http/http_client.py index 5ccf23b67..00c77cc82 100644 --- a/runhouse/servers/http/http_client.py +++ b/runhouse/servers/http/http_client.py @@ -28,6 +28,7 @@ RenameObjectParams, serialize_data, ) +from runhouse.utils import alive_bar_spinner_only, success_emoji logger = logging.getLogger(__name__) @@ -327,91 +328,87 @@ def call_module_method( Client function to call the rpc for call_module_method """ # Measure the time it takes to send the message - start = time.time() - logger.info( - f"{'Calling' if method_name else 'Getting'} {key}" - + (f".{method_name}" if method_name else "") - ) - serialization = serialization or "pickle" - res = retry_with_exponential_backoff(session.post)( - self._formatted_url(f"{key}/{method_name}"), - json=CallParams( - data=serialize_data(data, serialization), - serialization=serialization, - run_name=run_name, - stream_logs=stream_logs, - save=save, - remote=remote, - ).dict(), - stream=True, - headers=rns_client.request_headers(resource_address), - auth=self.auth, - verify=self.verify, - ) - - if res.status_code != 200: - raise ValueError( - f"Error calling {method_name} on server: {res.content.decode()}" + if method_name: + log_str = f"Calling {key}.{method_name}" + else: + log_str = f"Getting {key}" + with alive_bar_spinner_only(title=log_str) as bar: + serialization = serialization or "pickle" + res = retry_with_exponential_backoff(session.post)( + self._formatted_url(f"{key}/{method_name}"), + json=CallParams( + data=serialize_data(data, serialization), + serialization=serialization, + run_name=run_name, + stream_logs=stream_logs, + save=save, + remote=remote, + ).dict(), + stream=True, + headers=rns_client.request_headers(resource_address), + auth=self.auth, + verify=self.verify, ) - error_str = f"Error calling {method_name} on {key} on server" - # We get back a stream of intermingled log outputs and results (maybe None, maybe error, maybe single result, - # maybe a stream of results), so we need to separate these out. - result = None - res_iter = res.iter_lines(chunk_size=None) - # We need to manually iterate through res_iter so we can try/except to bypass a ChunkedEncodingError bug - while True: - try: - responses_json = next(res_iter) - except requests.exceptions.ChunkedEncodingError: - # Some silly bug in urllib3, see https://github.com/psf/requests/issues/4248 - continue - except StopIteration: - break - except StopAsyncIteration: - break - - resp = json.loads(responses_json) - output_type = resp["output_type"] - result = handle_response( - resp, output_type, error_str, log_formatter=self.log_formatter - ) - # If this was a `.remote` call, we don't need to recreate the system and connection, which can be - # slow, we can just set it explicitly. - from runhouse.resources.module import Module + if res.status_code != 200: + raise ValueError( + f"Error calling {method_name} on server: {res.content.decode()}" + ) + error_str = f"Error calling {method_name} on {key} on server" - if isinstance(result, Module): - if ( - system - and result.system - and system.rns_address == result.system.rns_address - ): - result.system = system - elif output_type == OutputType.CONFIG: - if ( - system - and "system" in result - and system.rns_address == result["system"] - ): - result["system"] = system - result = Resource.from_config(result, dryrun=True) + # We get back a stream of intermingled log outputs and results (maybe None, maybe error, maybe single result, + # maybe a stream of results), so we need to separate these out. + result = None + res_iter = res.iter_lines(chunk_size=None) + # We need to manually iterate through res_iter so we can try/except to bypass a ChunkedEncodingError bug + while True: + try: + responses_json = next(res_iter) + except requests.exceptions.ChunkedEncodingError: + # Some silly bug in urllib3, see https://github.com/psf/requests/issues/4248 + continue + except StopIteration: + break + except StopAsyncIteration: + break + + resp = json.loads(responses_json) + output_type = resp["output_type"] + result = handle_response( + resp, output_type, error_str, log_formatter=self.log_formatter + ) + # If this was a `.remote` call, we don't need to recreate the system and connection, which can be + # slow, we can just set it explicitly. + from runhouse.resources.module import Module - end = time.time() + if isinstance(result, Module): + if ( + system + and result.system + and system.rns_address == result.system.rns_address + ): + result.system = system + elif output_type == OutputType.CONFIG: + if ( + system + and "system" in result + and system.rns_address == result["system"] + ): + result["system"] = system + result = Resource.from_config(result, dryrun=True) - if ( - hasattr(result, "system") - and system is not None - and result.system.rns_address == system.rns_address - ): - result.system = system + if ( + hasattr(result, "system") + and system is not None + and result.system.rns_address == system.rns_address + ): + result.system = system - if method_name: - log_str = ( - f"Time to call {key}.{method_name}: {round(end - start, 2)} seconds" - ) - else: - log_str = f"Time to get {key}: {round(end - start, 2)} seconds" - logging.info(log_str) + if method_name: + log_str = f"Completed call {key}.{method_name}" + else: + log_str = f"Completed get {key}" + bar.title(success_emoji(log_str)) return result async def acall( @@ -460,77 +457,73 @@ async def acall_module_method( Client function to call the rpc for call_module_method """ # Measure the time it takes to send the message - start = time.time() - logger.info( - f"{'Calling' if method_name else 'Getting'} {key}" - + (f".{method_name}" if method_name else "") - ) - serialization = serialization or "pickle" - async with self.async_session.stream( - "POST", - self._formatted_url(f"{key}/{method_name}"), - json=CallParams( - data=serialize_data(data, serialization), - serialization=serialization, - run_name=run_name, - stream_logs=stream_logs, - save=save, - remote=remote, - run_async=run_async, - ).dict(), - headers=rns_client.request_headers(resource_address), - ) as res: - if res.status_code != 200: - raise ValueError( - f"Error calling {method_name} on server: {res.content.decode()}" - ) - error_str = f"Error calling {method_name} on {key} on server" - - # We get back a stream of intermingled log outputs and results (maybe None, maybe error, maybe single result, - # maybe a stream of results), so we need to separate these out. - result = None - async for response_json in res.aiter_lines(): - resp = json.loads(response_json) - output_type = resp["output_type"] - result = handle_response( - resp, output_type, error_str, log_formatter=self.log_formatter - ) - # If this was a `.remote` call, we don't need to recreate the system and connection, which can be - # slow, we can just set it explicitly. - from runhouse.resources.module import Module - - if isinstance(result, Module): - if ( - system - and result.system - and system.rns_address == result.system.rns_address - ): - result.system = system - elif output_type == OutputType.CONFIG: - if ( - system - and "system" in result - and system.rns_address == result["system"] - ): - result["system"] = system - result = Resource.from_config(result, dryrun=True) - - end = time.time() + if method_name: + log_str = f"Calling {key}.{method_name}" + else: + log_str = f"Getting {key}" + with alive_bar_spinner_only(title=log_str) as bar: + serialization = serialization or "pickle" + async with self.async_session.stream( + "POST", + self._formatted_url(f"{key}/{method_name}"), + json=CallParams( + data=serialize_data(data, serialization), + serialization=serialization, + run_name=run_name, + stream_logs=stream_logs, + save=save, + remote=remote, + run_async=run_async, + ).dict(), + headers=rns_client.request_headers(resource_address), + ) as res: + if res.status_code != 200: + raise ValueError( + f"Error calling {method_name} on server: {res.content.decode()}" + ) + error_str = f"Error calling {method_name} on {key} on server" + + # We get back a stream of intermingled log outputs and results (maybe None, maybe error, maybe single result, + # maybe a stream of results), so we need to separate these out. + result = None + async for response_json in res.aiter_lines(): + resp = json.loads(response_json) + output_type = resp["output_type"] + result = handle_response( + resp, output_type, error_str, log_formatter=self.log_formatter + ) + # If this was a `.remote` call, we don't need to recreate the system and connection, which can be + # slow, we can just set it explicitly. + from runhouse.resources.module import Module + + if isinstance(result, Module): + if ( + system + and result.system + and system.rns_address == result.system.rns_address + ): + result.system = system + elif output_type == OutputType.CONFIG: + if ( + system + and "system" in result + and system.rns_address == result["system"] + ): + result["system"] = system + result = Resource.from_config(result, dryrun=True) - if ( - hasattr(result, "system") - and system is not None - and result.system.rns_address == system.rns_address - ): - result.system = system + if ( + hasattr(result, "system") + and system is not None + and result.system.rns_address == system.rns_address + ): + result.system = system - if method_name: - log_str = ( - f"Time to call {key}.{method_name}: {round(end - start, 2)} seconds" - ) - else: - log_str = f"Time to get {key}: {round(end - start, 2)} seconds" - logging.info(log_str) + if method_name: + log_str = f"Completed call {key}.{method_name}" + else: + log_str = f"Completed get {key}" + bar.title(success_emoji(log_str)) return result def put_object(self, key: str, value: Any, env=None): diff --git a/runhouse/utils.py b/runhouse/utils.py index c7a022c62..83fe8bbca 100644 --- a/runhouse/utils.py +++ b/runhouse/utils.py @@ -3,6 +3,9 @@ from concurrent.futures import ThreadPoolExecutor from functools import wraps +from alive_progress import alive_bar +from rich.emoji import Emoji + def _thread_coroutine(coroutine, context): # Copy contextvars from the parent thread to the new thread @@ -44,3 +47,28 @@ def string_to_dict(dict_as_string): key = parts[0].strip() value = parts[1].strip() return key, value + + +#################################################################################################### +# Styling utils +#################################################################################################### +def success_emoji(text: str) -> str: + return f"{Emoji('white_check_mark')} {text}" + + +def failure_emoji(text: str) -> str: + return f"{Emoji('cross_mark')} {text}" + + +def alive_bar_spinner_only(*args, **kwargs): + return alive_bar( + bar=None, # No actual bar + enrich_print=False, # Print statements while the bar is running are unmodified + monitor=False, + stats=False, + monitor_end=True, + stats_end=False, + title_length=0, + *args, + **kwargs, + ) diff --git a/setup.py b/setup.py index 913c579ed..7934e19c5 100644 --- a/setup.py +++ b/setup.py @@ -82,6 +82,7 @@ def parse_readme(readme: str) -> str: "wheel", "apispec", "httpx", + "alive_progress", ] # NOTE: Change the templates/spot-controller.yaml.j2 file if any of the following