Skip to content

Commit

Permalink
refactor tornado + misc fixes to sanic/fastapi
Browse files Browse the repository at this point in the history
  • Loading branch information
rmorshea committed May 17, 2021
1 parent 04ae50a commit 16c9209
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 113 deletions.
13 changes: 8 additions & 5 deletions src/idom/server/fastapi.py
Expand Up @@ -8,7 +8,6 @@
import json
import logging
import sys
import time
from asyncio import Future
from threading import Event, Thread, current_thread
from typing import Any, Dict, Optional, Tuple, Union
Expand All @@ -35,6 +34,8 @@
)
from idom.core.layout import Layout, LayoutEvent, LayoutUpdate

from .utils import poll


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -152,10 +153,12 @@ def run_in_thread(self, host: str, port: int, *args: Any, **kwargs: Any) -> None
thread.start()

def wait_until_started(self, timeout: Optional[float] = 3.0) -> None:
while self._current_thread.is_alive() and (
not hasattr(self, "_server") or not self._server.started
):
time.sleep(0.01)
poll(
f"start {self.app}",
0.01,
timeout,
lambda: hasattr(self, "_server") and self._server.started,
)

def stop(self, timeout: Optional[float] = 3.0) -> None:
self._server.should_exit = True
Expand Down
33 changes: 12 additions & 21 deletions src/idom/server/sanic.py
Expand Up @@ -9,7 +9,7 @@
import logging
from asyncio import Future
from asyncio.events import AbstractEventLoop
from threading import Event, Thread
from threading import Event
from typing import Any, Dict, Optional, Tuple, Union

from mypy_extensions import TypedDict
Expand All @@ -28,6 +28,8 @@
)
from idom.core.layout import Layout, LayoutEvent, LayoutUpdate

from .utils import threaded, wait_on_event


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -103,27 +105,9 @@ def __init__(self, app: Sanic) -> None:
def run(self, host: str, port: int, *args: Any, **kwargs: Any) -> None:
self.app.run(host, port, *args, **kwargs) # pragma: no cover

@threaded
def run_in_thread(self, host: str, port: int, *args: Any, **kwargs: Any) -> None:
thread = Thread(
target=lambda: self._run_in_thread(host, port, *args, *kwargs), daemon=True
)
thread.start()

def wait_until_started(self, timeout: Optional[float] = 3.0) -> None:
self._did_start.wait(timeout)

def stop(self, timeout: Optional[float] = 3.0) -> None:
self._loop.call_soon_threadsafe(self.app.stop)
self._did_stop.wait(timeout)

def _run_in_thread(self, host: str, port: int, *args: Any, **kwargs: Any) -> None:
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
finally:
self._loop = loop
loop = asyncio.get_event_loop()

# what follows was copied from:
# https://github.com/sanic-org/sanic/blob/7028eae083b0da72d09111b9892ddcc00bce7df4/examples/run_async_advanced.py
Expand All @@ -150,6 +134,13 @@ def _run_in_thread(self, host: str, port: int, *args: Any, **kwargs: Any) -> Non
connection.close_if_idle()
server.after_stop()

def wait_until_started(self, timeout: Optional[float] = 3.0) -> None:
wait_on_event(f"start {self.app}", self._did_start, timeout)

def stop(self, timeout: Optional[float] = 3.0) -> None:
self._loop.call_soon_threadsafe(self.app.stop)
wait_on_event(f"stop {self.app}", self._did_stop, timeout)

async def _server_did_start(self, app: Sanic, loop: AbstractEventLoop) -> None:
self._loop = loop
self._did_start.set()
Expand Down
184 changes: 98 additions & 86 deletions src/idom/server/tornado.py
Expand Up @@ -10,7 +10,7 @@
from asyncio import Queue as AsyncQueue
from asyncio.futures import Future
from threading import Event as ThreadEvent
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, List, Optional, Tuple, Type, Union
from urllib.parse import urljoin

from tornado.platform.asyncio import AsyncIOMainLoop
Expand All @@ -23,7 +23,7 @@
from idom.core.dispatcher import dispatch_single_view
from idom.core.layout import Layout, LayoutEvent, LayoutUpdate

from .base import AbstractRenderServer
from .utils import threaded, wait_on_event


_RouteHandlerSpecs = List[Tuple[str, Type[RequestHandler], Any]]
Expand All @@ -32,17 +32,61 @@
class Config(TypedDict, total=False):
"""Render server config for :class:`TornadoRenderServer` subclasses"""

base_url: str
url_prefix: str
serve_static_files: bool
redirect_root_to_index: bool


class TornadoRenderServer(AbstractRenderServer[Application, Config]):
"""A base class for all Tornado render servers"""
def PerClientStateServer(
constructor: ComponentConstructor,
config: Optional[Config] = None,
app: Optional[Application] = None,
) -> TornadoServer:
"""Return a :class:`FastApiServer` where each client has its own state.
_model_stream_handler_type: Type[WebSocketHandler]
Implements the :class:`~idom.server.proto.ServerFactory` protocol
def stop(self, timeout: Optional[float] = None) -> None:
Parameters:
constructor: A component constructor
config: Options for configuring server behavior
app: An application instance (otherwise a default instance is created)
"""
config, app = _setup_config_and_app(config, app)
_add_handler(
app,
config,
_setup_common_routes(config) + _setup_single_view_dispatcher_route(constructor),
)
return TornadoServer(app)


class TornadoServer:
"""A thin wrapper for running a Tornado application
See :class:`idom.server.proto.Server` for more info
"""

_loop: asyncio.AbstractEventLoop

def __init__(self, app: Application) -> None:
self.app = app
self._did_start = ThreadEvent()

def run(self, host: str, port: int, *args: Any, **kwargs: Any) -> None:
self._loop = asyncio.get_event_loop()
AsyncIOMainLoop().install()
self.app.listen(port, host, *args, **kwargs)
self._did_start.set()
asyncio.get_event_loop().run_forever()

@threaded
def run_in_thread(self, host: str, port: int, *args: Any, **kwargs: Any) -> None:
self.run(host, port, *args, **kwargs)

def wait_until_started(self, timeout: Optional[float] = 3.0) -> None:
self._did_start.wait(timeout)

def stop(self, timeout: Optional[float] = 3.0) -> None:
try:
loop = self._loop
except AttributeError: # pragma: no cover
Expand All @@ -57,87 +101,61 @@ def stop() -> None:
did_stop.set()

loop.call_soon_threadsafe(stop)
did_stop.wait(timeout)

def _create_config(self, config: Optional[Config]) -> Config:
new_config: Config = {
"base_url": "",
wait_on_event(f"stop {self.app}", did_stop, timeout)


def _setup_config_and_app(
config: Optional[Config], app: Optional[Application]
) -> Tuple[Config, Application]:
return (
{
"url_prefix": "",
"serve_static_files": True,
"redirect_root_to_index": True,
**(config or {}), # type: ignore
}
return new_config

def _default_application(self, config: Config) -> Application:
return Application()

def _setup_application(
self,
config: Config,
app: Application,
) -> None:
base_url = config["base_url"]
app.add_handlers(
r".*",
[
(urljoin(base_url, route_pattern),) + tuple(handler_info) # type: ignore
for route_pattern, *handler_info in self._create_route_handlers(config)
],
)
},
app or Application(),
)

def _setup_application_did_start_event(
self, config: Config, app: Application, event: ThreadEvent
) -> None:
pass

def _create_route_handlers(self, config: Config) -> _RouteHandlerSpecs:
handlers: _RouteHandlerSpecs = [
def _setup_common_routes(config: Config) -> _RouteHandlerSpecs:
handlers: _RouteHandlerSpecs = []
if config["serve_static_files"]:
handlers.append(
(
"/stream",
self._model_stream_handler_type,
{"component_constructor": self._root_component_constructor},
)
]

if config["serve_static_files"]:
handlers.append(
(
r"/client/(.*)",
StaticFileHandler,
{"path": str(IDOM_CLIENT_BUILD_DIR.current)},
)
r"/client/(.*)",
StaticFileHandler,
{"path": str(IDOM_CLIENT_BUILD_DIR.current)},
)
if config["redirect_root_to_index"]:
handlers.append(("/", RedirectHandler, {"url": "./client/index.html"}))

return handlers

def _run_application(
self,
config: Config,
app: Application,
host: str,
port: int,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> None:
self._loop = asyncio.get_event_loop()
AsyncIOMainLoop().install()
app.listen(port, host, *args, **kwargs)
self._server_did_start.set()
asyncio.get_event_loop().run_forever()

def _run_application_in_thread(
self,
config: Config,
app: Application,
host: str,
port: int,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> None:
asyncio.set_event_loop(asyncio.new_event_loop())
self._run_application(config, app, host, port, args, kwargs)
)
if config["redirect_root_to_index"]:
handlers.append(("/", RedirectHandler, {"url": "./client/index.html"}))
return handlers


def _add_handler(
app: Application, config: Config, handlers: _RouteHandlerSpecs
) -> None:
app.add_handlers(
r".*",
[
(urljoin(config["url_prefix"], route_pattern),) + tuple(handler_info)
for route_pattern, *handler_info in handlers
],
)


def _setup_single_view_dispatcher_route(
constructor: ComponentConstructor,
) -> _RouteHandlerSpecs:
return [
(
"/stream",
PerClientStateModelStreamHandler,
{"component_constructor": constructor},
)
]


class PerClientStateModelStreamHandler(WebSocketHandler):
Expand Down Expand Up @@ -176,9 +194,3 @@ async def on_message(self, message: Union[str, bytes]) -> None:
def on_close(self) -> None:
if not self._dispatch_future.done():
self._dispatch_future.cancel()


class PerClientStateServer(TornadoRenderServer):
"""Each client view will have its own state."""

_model_stream_handler_type = PerClientStateModelStreamHandler
45 changes: 44 additions & 1 deletion src/idom/server/utils.py
@@ -1,7 +1,11 @@
import asyncio
import time
from contextlib import closing
from functools import wraps
from importlib import import_module
from socket import socket
from typing import Any, List
from threading import Event, Thread
from typing import Any, Callable, List, Optional, TypeVar, cast

from .proto import ServerFactory

Expand All @@ -14,6 +18,45 @@
]


_Func = TypeVar("_Func", bound=Callable[..., None])


def threaded(function: _Func) -> _Func:
@wraps(function)
def wrapper(*args: Any, **kwargs: Any) -> None:
def target() -> None:
asyncio.set_event_loop(asyncio.new_event_loop())
function(*args, **kwargs)

Thread(target=target, daemon=True).start()

return None

return cast(_Func, wrapper)


def wait_on_event(description: str, event: Event, timeout: float) -> None:
if not event.wait(timeout):
raise TimeoutError(f"Did not {description} within {timeout} seconds")


def poll(
description: str,
frequency: float,
timeout: Optional[float],
function: Callable[[], bool],
) -> None:
if timeout is not None:
expiry = time.time() + timeout
while not function():
if time.time() > expiry:
raise TimeoutError(f"Did not {description} within {timeout} seconds")
time.sleep(frequency)
else:
while not function():
time.sleep(frequency)


def find_builtin_server_type(type_name: str) -> ServerFactory[Any, Any]:
"""Find first installed server implementation"""
installed_builtins: List[str] = []
Expand Down

0 comments on commit 16c9209

Please sign in to comment.