Skip to content

Commit 16c9209

Browse files
committed
refactor tornado + misc fixes to sanic/fastapi
1 parent 04ae50a commit 16c9209

File tree

4 files changed

+162
-113
lines changed

4 files changed

+162
-113
lines changed

src/idom/server/fastapi.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import json
99
import logging
1010
import sys
11-
import time
1211
from asyncio import Future
1312
from threading import Event, Thread, current_thread
1413
from typing import Any, Dict, Optional, Tuple, Union
@@ -35,6 +34,8 @@
3534
)
3635
from idom.core.layout import Layout, LayoutEvent, LayoutUpdate
3736

37+
from .utils import poll
38+
3839

3940
logger = logging.getLogger(__name__)
4041

@@ -152,10 +153,12 @@ def run_in_thread(self, host: str, port: int, *args: Any, **kwargs: Any) -> None
152153
thread.start()
153154

154155
def wait_until_started(self, timeout: Optional[float] = 3.0) -> None:
155-
while self._current_thread.is_alive() and (
156-
not hasattr(self, "_server") or not self._server.started
157-
):
158-
time.sleep(0.01)
156+
poll(
157+
f"start {self.app}",
158+
0.01,
159+
timeout,
160+
lambda: hasattr(self, "_server") and self._server.started,
161+
)
159162

160163
def stop(self, timeout: Optional[float] = 3.0) -> None:
161164
self._server.should_exit = True

src/idom/server/sanic.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import logging
1010
from asyncio import Future
1111
from asyncio.events import AbstractEventLoop
12-
from threading import Event, Thread
12+
from threading import Event
1313
from typing import Any, Dict, Optional, Tuple, Union
1414

1515
from mypy_extensions import TypedDict
@@ -28,6 +28,8 @@
2828
)
2929
from idom.core.layout import Layout, LayoutEvent, LayoutUpdate
3030

31+
from .utils import threaded, wait_on_event
32+
3133

3234
logger = logging.getLogger(__name__)
3335

@@ -103,27 +105,9 @@ def __init__(self, app: Sanic) -> None:
103105
def run(self, host: str, port: int, *args: Any, **kwargs: Any) -> None:
104106
self.app.run(host, port, *args, **kwargs) # pragma: no cover
105107

108+
@threaded
106109
def run_in_thread(self, host: str, port: int, *args: Any, **kwargs: Any) -> None:
107-
thread = Thread(
108-
target=lambda: self._run_in_thread(host, port, *args, *kwargs), daemon=True
109-
)
110-
thread.start()
111-
112-
def wait_until_started(self, timeout: Optional[float] = 3.0) -> None:
113-
self._did_start.wait(timeout)
114-
115-
def stop(self, timeout: Optional[float] = 3.0) -> None:
116-
self._loop.call_soon_threadsafe(self.app.stop)
117-
self._did_stop.wait(timeout)
118-
119-
def _run_in_thread(self, host: str, port: int, *args: Any, **kwargs: Any) -> None:
120-
try:
121-
loop = asyncio.get_event_loop()
122-
except RuntimeError:
123-
loop = asyncio.new_event_loop()
124-
asyncio.set_event_loop(loop)
125-
finally:
126-
self._loop = loop
110+
loop = asyncio.get_event_loop()
127111

128112
# what follows was copied from:
129113
# https://github.com/sanic-org/sanic/blob/7028eae083b0da72d09111b9892ddcc00bce7df4/examples/run_async_advanced.py
@@ -150,6 +134,13 @@ def _run_in_thread(self, host: str, port: int, *args: Any, **kwargs: Any) -> Non
150134
connection.close_if_idle()
151135
server.after_stop()
152136

137+
def wait_until_started(self, timeout: Optional[float] = 3.0) -> None:
138+
wait_on_event(f"start {self.app}", self._did_start, timeout)
139+
140+
def stop(self, timeout: Optional[float] = 3.0) -> None:
141+
self._loop.call_soon_threadsafe(self.app.stop)
142+
wait_on_event(f"stop {self.app}", self._did_stop, timeout)
143+
153144
async def _server_did_start(self, app: Sanic, loop: AbstractEventLoop) -> None:
154145
self._loop = loop
155146
self._did_start.set()

src/idom/server/tornado.py

Lines changed: 98 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from asyncio import Queue as AsyncQueue
1111
from asyncio.futures import Future
1212
from threading import Event as ThreadEvent
13-
from typing import Any, Dict, List, Optional, Tuple, Type, Union
13+
from typing import Any, List, Optional, Tuple, Type, Union
1414
from urllib.parse import urljoin
1515

1616
from tornado.platform.asyncio import AsyncIOMainLoop
@@ -23,7 +23,7 @@
2323
from idom.core.dispatcher import dispatch_single_view
2424
from idom.core.layout import Layout, LayoutEvent, LayoutUpdate
2525

26-
from .base import AbstractRenderServer
26+
from .utils import threaded, wait_on_event
2727

2828

2929
_RouteHandlerSpecs = List[Tuple[str, Type[RequestHandler], Any]]
@@ -32,17 +32,61 @@
3232
class Config(TypedDict, total=False):
3333
"""Render server config for :class:`TornadoRenderServer` subclasses"""
3434

35-
base_url: str
35+
url_prefix: str
3636
serve_static_files: bool
3737
redirect_root_to_index: bool
3838

3939

40-
class TornadoRenderServer(AbstractRenderServer[Application, Config]):
41-
"""A base class for all Tornado render servers"""
40+
def PerClientStateServer(
41+
constructor: ComponentConstructor,
42+
config: Optional[Config] = None,
43+
app: Optional[Application] = None,
44+
) -> TornadoServer:
45+
"""Return a :class:`FastApiServer` where each client has its own state.
4246
43-
_model_stream_handler_type: Type[WebSocketHandler]
47+
Implements the :class:`~idom.server.proto.ServerFactory` protocol
4448
45-
def stop(self, timeout: Optional[float] = None) -> None:
49+
Parameters:
50+
constructor: A component constructor
51+
config: Options for configuring server behavior
52+
app: An application instance (otherwise a default instance is created)
53+
"""
54+
config, app = _setup_config_and_app(config, app)
55+
_add_handler(
56+
app,
57+
config,
58+
_setup_common_routes(config) + _setup_single_view_dispatcher_route(constructor),
59+
)
60+
return TornadoServer(app)
61+
62+
63+
class TornadoServer:
64+
"""A thin wrapper for running a Tornado application
65+
66+
See :class:`idom.server.proto.Server` for more info
67+
"""
68+
69+
_loop: asyncio.AbstractEventLoop
70+
71+
def __init__(self, app: Application) -> None:
72+
self.app = app
73+
self._did_start = ThreadEvent()
74+
75+
def run(self, host: str, port: int, *args: Any, **kwargs: Any) -> None:
76+
self._loop = asyncio.get_event_loop()
77+
AsyncIOMainLoop().install()
78+
self.app.listen(port, host, *args, **kwargs)
79+
self._did_start.set()
80+
asyncio.get_event_loop().run_forever()
81+
82+
@threaded
83+
def run_in_thread(self, host: str, port: int, *args: Any, **kwargs: Any) -> None:
84+
self.run(host, port, *args, **kwargs)
85+
86+
def wait_until_started(self, timeout: Optional[float] = 3.0) -> None:
87+
self._did_start.wait(timeout)
88+
89+
def stop(self, timeout: Optional[float] = 3.0) -> None:
4690
try:
4791
loop = self._loop
4892
except AttributeError: # pragma: no cover
@@ -57,87 +101,61 @@ def stop() -> None:
57101
did_stop.set()
58102

59103
loop.call_soon_threadsafe(stop)
60-
did_stop.wait(timeout)
61104

62-
def _create_config(self, config: Optional[Config]) -> Config:
63-
new_config: Config = {
64-
"base_url": "",
105+
wait_on_event(f"stop {self.app}", did_stop, timeout)
106+
107+
108+
def _setup_config_and_app(
109+
config: Optional[Config], app: Optional[Application]
110+
) -> Tuple[Config, Application]:
111+
return (
112+
{
113+
"url_prefix": "",
65114
"serve_static_files": True,
66115
"redirect_root_to_index": True,
67116
**(config or {}), # type: ignore
68-
}
69-
return new_config
70-
71-
def _default_application(self, config: Config) -> Application:
72-
return Application()
73-
74-
def _setup_application(
75-
self,
76-
config: Config,
77-
app: Application,
78-
) -> None:
79-
base_url = config["base_url"]
80-
app.add_handlers(
81-
r".*",
82-
[
83-
(urljoin(base_url, route_pattern),) + tuple(handler_info) # type: ignore
84-
for route_pattern, *handler_info in self._create_route_handlers(config)
85-
],
86-
)
117+
},
118+
app or Application(),
119+
)
87120

88-
def _setup_application_did_start_event(
89-
self, config: Config, app: Application, event: ThreadEvent
90-
) -> None:
91-
pass
92121

93-
def _create_route_handlers(self, config: Config) -> _RouteHandlerSpecs:
94-
handlers: _RouteHandlerSpecs = [
122+
def _setup_common_routes(config: Config) -> _RouteHandlerSpecs:
123+
handlers: _RouteHandlerSpecs = []
124+
if config["serve_static_files"]:
125+
handlers.append(
95126
(
96-
"/stream",
97-
self._model_stream_handler_type,
98-
{"component_constructor": self._root_component_constructor},
99-
)
100-
]
101-
102-
if config["serve_static_files"]:
103-
handlers.append(
104-
(
105-
r"/client/(.*)",
106-
StaticFileHandler,
107-
{"path": str(IDOM_CLIENT_BUILD_DIR.current)},
108-
)
127+
r"/client/(.*)",
128+
StaticFileHandler,
129+
{"path": str(IDOM_CLIENT_BUILD_DIR.current)},
109130
)
110-
if config["redirect_root_to_index"]:
111-
handlers.append(("/", RedirectHandler, {"url": "./client/index.html"}))
112-
113-
return handlers
114-
115-
def _run_application(
116-
self,
117-
config: Config,
118-
app: Application,
119-
host: str,
120-
port: int,
121-
args: Tuple[Any, ...],
122-
kwargs: Dict[str, Any],
123-
) -> None:
124-
self._loop = asyncio.get_event_loop()
125-
AsyncIOMainLoop().install()
126-
app.listen(port, host, *args, **kwargs)
127-
self._server_did_start.set()
128-
asyncio.get_event_loop().run_forever()
129-
130-
def _run_application_in_thread(
131-
self,
132-
config: Config,
133-
app: Application,
134-
host: str,
135-
port: int,
136-
args: Tuple[Any, ...],
137-
kwargs: Dict[str, Any],
138-
) -> None:
139-
asyncio.set_event_loop(asyncio.new_event_loop())
140-
self._run_application(config, app, host, port, args, kwargs)
131+
)
132+
if config["redirect_root_to_index"]:
133+
handlers.append(("/", RedirectHandler, {"url": "./client/index.html"}))
134+
return handlers
135+
136+
137+
def _add_handler(
138+
app: Application, config: Config, handlers: _RouteHandlerSpecs
139+
) -> None:
140+
app.add_handlers(
141+
r".*",
142+
[
143+
(urljoin(config["url_prefix"], route_pattern),) + tuple(handler_info)
144+
for route_pattern, *handler_info in handlers
145+
],
146+
)
147+
148+
149+
def _setup_single_view_dispatcher_route(
150+
constructor: ComponentConstructor,
151+
) -> _RouteHandlerSpecs:
152+
return [
153+
(
154+
"/stream",
155+
PerClientStateModelStreamHandler,
156+
{"component_constructor": constructor},
157+
)
158+
]
141159

142160

143161
class PerClientStateModelStreamHandler(WebSocketHandler):
@@ -176,9 +194,3 @@ async def on_message(self, message: Union[str, bytes]) -> None:
176194
def on_close(self) -> None:
177195
if not self._dispatch_future.done():
178196
self._dispatch_future.cancel()
179-
180-
181-
class PerClientStateServer(TornadoRenderServer):
182-
"""Each client view will have its own state."""
183-
184-
_model_stream_handler_type = PerClientStateModelStreamHandler

src/idom/server/utils.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
import asyncio
2+
import time
13
from contextlib import closing
4+
from functools import wraps
25
from importlib import import_module
36
from socket import socket
4-
from typing import Any, List
7+
from threading import Event, Thread
8+
from typing import Any, Callable, List, Optional, TypeVar, cast
59

610
from .proto import ServerFactory
711

@@ -14,6 +18,45 @@
1418
]
1519

1620

21+
_Func = TypeVar("_Func", bound=Callable[..., None])
22+
23+
24+
def threaded(function: _Func) -> _Func:
25+
@wraps(function)
26+
def wrapper(*args: Any, **kwargs: Any) -> None:
27+
def target() -> None:
28+
asyncio.set_event_loop(asyncio.new_event_loop())
29+
function(*args, **kwargs)
30+
31+
Thread(target=target, daemon=True).start()
32+
33+
return None
34+
35+
return cast(_Func, wrapper)
36+
37+
38+
def wait_on_event(description: str, event: Event, timeout: float) -> None:
39+
if not event.wait(timeout):
40+
raise TimeoutError(f"Did not {description} within {timeout} seconds")
41+
42+
43+
def poll(
44+
description: str,
45+
frequency: float,
46+
timeout: Optional[float],
47+
function: Callable[[], bool],
48+
) -> None:
49+
if timeout is not None:
50+
expiry = time.time() + timeout
51+
while not function():
52+
if time.time() > expiry:
53+
raise TimeoutError(f"Did not {description} within {timeout} seconds")
54+
time.sleep(frequency)
55+
else:
56+
while not function():
57+
time.sleep(frequency)
58+
59+
1760
def find_builtin_server_type(type_name: str) -> ServerFactory[Any, Any]:
1861
"""Find first installed server implementation"""
1962
installed_builtins: List[str] = []

0 commit comments

Comments
 (0)