From f589387e97a197f530a8f33ac466ea330c788230 Mon Sep 17 00:00:00 2001 From: Rohin Bhasin Date: Tue, 26 Mar 2024 17:29:58 -0400 Subject: [PATCH] Move uvicorn start into `ClusterServlet` to increase process locality. --- runhouse/servers/cluster_servlet.py | 19 ++++++++++++++++++ runhouse/servers/http/http_server.py | 24 +++++++++------------- runhouse/servers/obj_store.py | 30 ++++++++++++++++++++-------- 3 files changed, 50 insertions(+), 23 deletions(-) diff --git a/runhouse/servers/cluster_servlet.py b/runhouse/servers/cluster_servlet.py index 48167e4d2..7a414ebf6 100644 --- a/runhouse/servers/cluster_servlet.py +++ b/runhouse/servers/cluster_servlet.py @@ -1,11 +1,15 @@ import logging from typing import Any, Dict, List, Optional, Set, Union +import uvicorn + from runhouse.globals import configs, rns_client from runhouse.resources.hardware import load_cluster_config_from_file from runhouse.rns.utils.api import ResourceAccess from runhouse.servers.http.auth import AuthCache +from runhouse.servers.http.http_server import app + logger = logging.getLogger(__name__) @@ -27,6 +31,21 @@ async def __init__( self._key_to_env_servlet_name: Dict[Any, str] = {} self._auth_cache: AuthCache = AuthCache() + async def start_http_server(self, *args, **kwargs): + # Hack to set the cluster servlet in the obj_store, don't initialize it + from runhouse.globals import obj_store + from runhouse.servers.obj_store import ClusterServletSetupOption + + await obj_store.ainitialize( + "base", setup_cluster_servlet=ClusterServletSetupOption.NONE + ) + obj_store.cluster_servlet = self + + config = uvicorn.Config(app, *args, **kwargs) + self._uvicorn_server = uvicorn.Server(config) + + await self._uvicorn_server.serve() + ############################################## # Cluster config state storage methods ############################################## diff --git a/runhouse/servers/http/http_server.py b/runhouse/servers/http/http_server.py index 20d8ea901..bcde865f0 100644 --- 
a/runhouse/servers/http/http_server.py +++ b/runhouse/servers/http/http_server.py @@ -55,8 +55,8 @@ logger = logging.getLogger(__name__) app = FastAPI(docs_url=None, redoc_url=None) - -suspend_autostop = False +app.suspend_autostop = False +app.memory_exporter = None def validate_cluster_access(func): @@ -116,7 +116,6 @@ async def wrapper(*args, **kwargs): class HTTPServer: SKY_YAML = str(Path("~/.sky/sky_ray.yml").expanduser()) - memory_exporter = None @classmethod async def ainitialize( @@ -148,10 +147,9 @@ async def ainitialize( ) ) ) - global memory_exporter - memory_exporter = InMemorySpanExporter() + app.memory_exporter = InMemorySpanExporter() trace.get_tracer_provider().add_span_processor( - SimpleSpanProcessor(memory_exporter) + SimpleSpanProcessor(app.memory_exporter) ) # Instrument the app object FastAPIInstrumentor.instrument_app(app) @@ -164,7 +162,8 @@ async def ainitialize( def get_spans(request: Request): return { "spans": [ - span.to_json() for span in memory_exporter.get_finished_spans() + span.to_json() + for span in app.memory_exporter.get_finished_spans() ] } @@ -251,7 +250,7 @@ async def disable_den_auth(cls): @staticmethod def register_activity(): - if suspend_autostop: + if app.suspend_autostop: try: from sky.skylet.autostop_lib import set_last_active_time_to_now @@ -853,7 +852,6 @@ def _log_cluster_data(data: dict, labels: dict): async def main(): - import uvicorn parser = argparse.ArgumentParser() parser.add_argument( @@ -950,8 +948,7 @@ async def main(): ) else: logger.info("Loaded cluster config from Ray.") - global suspend_autostop - suspend_autostop = cluster_config.get("autostop_mins", -1) > 0 + app.suspend_autostop = cluster_config.get("autostop_mins", -1) > 0 ######################################## # Handling args that could be specified in the @@ -1186,15 +1183,12 @@ async def main(): uvicorn_cert = parsed_ssl_certfile if not use_caddy and use_https else None uvicorn_key = parsed_ssl_keyfile if not use_caddy and use_https else 
None - config = uvicorn.Config( - app, + await obj_store.astart_http_server_in_cluster_servlet( host=host, port=daemon_port, ssl_certfile=uvicorn_cert, ssl_keyfile=uvicorn_key, ) - server = uvicorn.Server(config) - await server.serve() if __name__ == "__main__": diff --git a/runhouse/servers/obj_store.py b/runhouse/servers/obj_store.py index 982c7bce9..96fc3b60b 100644 --- a/runhouse/servers/obj_store.py +++ b/runhouse/servers/obj_store.py @@ -28,6 +28,7 @@ class ClusterServletSetupOption(str, Enum): GET_OR_CREATE = "get_or_create" GET_OR_FAIL = "get_or_fail" FORCE_CREATE = "force_create" + NONE = "none" class ObjStoreError(Exception): @@ -170,15 +171,16 @@ async def ainitialize( ) # Now, we expect to be connected to an initialized Ray instance. - if setup_cluster_servlet == ClusterServletSetupOption.FORCE_CREATE: - kill_actors(namespace="runhouse", gracefully=False) + if setup_cluster_servlet != ClusterServletSetupOption.NONE: + if setup_cluster_servlet == ClusterServletSetupOption.FORCE_CREATE: + kill_actors(namespace="runhouse", gracefully=False) - create_if_not_exists = ( - setup_cluster_servlet != ClusterServletSetupOption.GET_OR_FAIL - ) - self.cluster_servlet = get_cluster_servlet( - create_if_not_exists=create_if_not_exists - ) + create_if_not_exists = ( + setup_cluster_servlet != ClusterServletSetupOption.GET_OR_FAIL + ) + self.cluster_servlet = get_cluster_servlet( + create_if_not_exists=create_if_not_exists + ) if self.cluster_servlet is None: # TODO: logger. is not printing correctly here when doing `runhouse start`. # Fix this and general logging. 
@@ -231,6 +233,24 @@ def initialize(
             setup_cluster_servlet,
         )
 
+    async def astart_http_server_in_cluster_servlet(self, *args, **kwargs):
+        """Invoke ``start_http_server`` on the cluster servlet.
+
+        All positional and keyword arguments are forwarded unchanged to the
+        servlet's ``start_http_server`` method.
+
+        Raises:
+            ObjStoreError: If no cluster servlet has been set on this store.
+        """
+        if self.cluster_servlet is None:
+            raise ObjStoreError("Cluster servlet not initialized.")
+
+        # This method is awaited by its caller, so it must be a coroutine and
+        # go through the awaitable actor-call helper, not the synchronous one.
+        return await self.acall_actor_method(
+            self.cluster_servlet, "start_http_server", *args, **kwargs
+        )
+
     ##############################################
     # Generic helpers
     ##############################################
@@ -240,6 +260,10 @@ async def acall_actor_method(
     ):
         if actor is None:
             raise ObjStoreError("Attempting to call an actor method on a None actor.")
+
+        if not isinstance(actor, ray.actor.ActorHandle):
+            # This is likely the cluster servlet set within itself, hack for now.
+            return await getattr(actor, method)(*args, **kwargs)
         return await getattr(actor, method).remote(*args, **kwargs)
 
     @staticmethod