Skip to content

Commit

Permalink
Move uvicorn start into ClusterServlet to increase process locality.
Browse files Browse the repository at this point in the history
  • Loading branch information
rohinb2 committed Mar 27, 2024
1 parent e81fd2f commit f589387
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 23 deletions.
19 changes: 19 additions & 0 deletions runhouse/servers/cluster_servlet.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import logging
from typing import Any, Dict, List, Optional, Set, Union

import uvicorn

from runhouse.globals import configs, rns_client
from runhouse.resources.hardware import load_cluster_config_from_file
from runhouse.rns.utils.api import ResourceAccess
from runhouse.servers.http.auth import AuthCache

from runhouse.servers.http.http_server import app

logger = logging.getLogger(__name__)


Expand All @@ -27,6 +31,21 @@ async def __init__(
self._key_to_env_servlet_name: Dict[Any, str] = {}
self._auth_cache: AuthCache = AuthCache()

async def start_http_server(self, *args, **kwargs):
    """Serve the FastAPI ``app`` with uvicorn from inside this servlet.

    Positional and keyword arguments are forwarded unchanged to
    ``uvicorn.Config`` after ``app``. The created ``uvicorn.Server`` is
    stored on ``self._uvicorn_server`` and this coroutine awaits its
    ``serve()`` call.
    """
    # Function-scope imports, kept local as in the original
    # (presumably to avoid an import cycle -- TODO confirm).
    from runhouse.globals import obj_store
    from runhouse.servers.obj_store import ClusterServletSetupOption

    # Hack: initialize the obj_store without any cluster servlet setup,
    # then register this object itself as the obj_store's cluster servlet.
    await obj_store.ainitialize(
        "base",
        setup_cluster_servlet=ClusterServletSetupOption.NONE,
    )
    obj_store.cluster_servlet = self

    # Keep a handle to the server on the instance before serving.
    self._uvicorn_server = uvicorn.Server(uvicorn.Config(app, *args, **kwargs))
    await self._uvicorn_server.serve()

##############################################
# Cluster config state storage methods
##############################################
Expand Down
24 changes: 9 additions & 15 deletions runhouse/servers/http/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@
logger = logging.getLogger(__name__)

app = FastAPI(docs_url=None, redoc_url=None)

suspend_autostop = False
app.suspend_autostop = False
app.memory_exporter = None


def validate_cluster_access(func):
Expand Down Expand Up @@ -116,7 +116,6 @@ async def wrapper(*args, **kwargs):

class HTTPServer:
SKY_YAML = str(Path("~/.sky/sky_ray.yml").expanduser())
memory_exporter = None

@classmethod
async def ainitialize(
Expand Down Expand Up @@ -148,10 +147,9 @@ async def ainitialize(
)
)
)
global memory_exporter
memory_exporter = InMemorySpanExporter()
app.memory_exporter = InMemorySpanExporter()
trace.get_tracer_provider().add_span_processor(
SimpleSpanProcessor(memory_exporter)
SimpleSpanProcessor(app.memory_exporter)
)
# Instrument the app object
FastAPIInstrumentor.instrument_app(app)
Expand All @@ -164,7 +162,8 @@ async def ainitialize(
def get_spans(request: Request):
return {
"spans": [
span.to_json() for span in memory_exporter.get_finished_spans()
span.to_json()
for span in app.memory_exporter.get_finished_spans()
]
}

Expand Down Expand Up @@ -251,7 +250,7 @@ async def disable_den_auth(cls):

@staticmethod
def register_activity():
if suspend_autostop:
if app.suspend_autostop:
try:
from sky.skylet.autostop_lib import set_last_active_time_to_now

Expand Down Expand Up @@ -853,7 +852,6 @@ def _log_cluster_data(data: dict, labels: dict):


async def main():
import uvicorn

parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -950,8 +948,7 @@ async def main():
)
else:
logger.info("Loaded cluster config from Ray.")
global suspend_autostop
suspend_autostop = cluster_config.get("autostop_mins", -1) > 0
app.suspend_autostop = cluster_config.get("autostop_mins", -1) > 0

########################################
# Handling args that could be specified in the
Expand Down Expand Up @@ -1186,15 +1183,12 @@ async def main():
uvicorn_cert = parsed_ssl_certfile if not use_caddy and use_https else None
uvicorn_key = parsed_ssl_keyfile if not use_caddy and use_https else None

config = uvicorn.Config(
app,
await obj_store.astart_http_server_in_cluster_servlet(
host=host,
port=daemon_port,
ssl_certfile=uvicorn_cert,
ssl_keyfile=uvicorn_key,
)
server = uvicorn.Server(config)
await server.serve()


if __name__ == "__main__":
Expand Down
30 changes: 22 additions & 8 deletions runhouse/servers/obj_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class ClusterServletSetupOption(str, Enum):
GET_OR_CREATE = "get_or_create"
GET_OR_FAIL = "get_or_fail"
FORCE_CREATE = "force_create"
NONE = "none"


class ObjStoreError(Exception):
Expand Down Expand Up @@ -170,15 +171,16 @@ async def ainitialize(
)

# Now, we expect to be connected to an initialized Ray instance.
if setup_cluster_servlet == ClusterServletSetupOption.FORCE_CREATE:
kill_actors(namespace="runhouse", gracefully=False)
if setup_cluster_servlet != ClusterServletSetupOption.NONE:
if setup_cluster_servlet == ClusterServletSetupOption.FORCE_CREATE:
kill_actors(namespace="runhouse", gracefully=False)

create_if_not_exists = (
setup_cluster_servlet != ClusterServletSetupOption.GET_OR_FAIL
)
self.cluster_servlet = get_cluster_servlet(
create_if_not_exists=create_if_not_exists
)
create_if_not_exists = (
setup_cluster_servlet != ClusterServletSetupOption.GET_OR_FAIL
)
self.cluster_servlet = get_cluster_servlet(
create_if_not_exists=create_if_not_exists
)
if self.cluster_servlet is None:
# TODO: logger.<method> is not printing correctly here when doing `runhouse start`.
# Fix this and general logging.
Expand Down Expand Up @@ -231,6 +233,14 @@ def initialize(
setup_cluster_servlet,
)

async def astart_http_server_in_cluster_servlet(self, *args, **kwargs):
    """Start the HTTP server inside the cluster servlet.

    This is a coroutine, matching its ``a`` prefix and the ``await`` at its
    call site. It delegates to ``acall_actor_method`` rather than the
    synchronous ``call_actor_method``, so the call is awaited instead of
    being made synchronously from async code.

    Args:
        *args: Positional arguments forwarded to the servlet's
            ``start_http_server`` method.
        **kwargs: Keyword arguments forwarded to the servlet's
            ``start_http_server`` method (e.g. host, port, SSL files).

    Returns:
        Whatever the servlet's ``start_http_server`` call returns.

    Raises:
        ObjStoreError: If no cluster servlet has been set on this store.
    """
    if self.cluster_servlet is None:
        raise ObjStoreError("Cluster servlet not initialized.")

    return await self.acall_actor_method(
        self.cluster_servlet, "start_http_server", *args, **kwargs
    )

##############################################
# Generic helpers
##############################################
Expand All @@ -240,6 +250,10 @@ async def acall_actor_method(
):
if actor is None:
raise ObjStoreError("Attempting to call an actor method on a None actor.")

if not isinstance(actor, ray.actor.ActorHandle):
# This is likely the cluster servlet set within itself, hack for now.
return await getattr(actor, method)(*args, **kwargs)
return await getattr(actor, method).remote(*args, **kwargs)

@staticmethod
Expand Down

0 comments on commit f589387

Please sign in to comment.