Skip to content

Commit

Permalink
Support Quart 0.19
Browse files Browse the repository at this point in the history
This has required mostly small changes throughout.
  • Loading branch information
pgjones committed Oct 7, 2023
1 parent 1ab5125 commit 4c59acf
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 28 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ ignore_missing_imports = true
pydata_sphinx_theme = { version = "*", optional = true }
python = ">=3.8"
hypercorn = { version = ">=0.12.0", extras = ["trio"] }
quart = ">=0.18"
quart = ">=0.19"
trio = ">=0.19.0"
exceptiongroup = ">=1.0.0"

Expand Down
71 changes: 52 additions & 19 deletions src/quart_trio/app.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import sys
import warnings
from typing import Any, Awaitable, Callable, Coroutine, Optional, Union
from typing import Any, Awaitable, Callable, Coroutine, Optional, TypeVar, Union

import trio
from exceptiongroup import BaseExceptionGroup
from hypercorn.config import Config as HyperConfig
from hypercorn.trio import serve
from quart import Quart, request_started, websocket_started
from quart.ctx import copy_current_app_context, RequestContext, WebsocketContext
from quart.ctx import RequestContext, WebsocketContext
from quart.signals import got_serving_exception
from quart.typing import FilePath, ResponseReturnValue
from quart.utils import file_path_to_path
from quart.wrappers import Request, Response, Websocket
Expand All @@ -18,6 +20,14 @@
from .utils import run_sync
from .wrappers import TrioRequest, TrioResponse, TrioWebsocket

try:
from typing import ParamSpec
except ImportError:
from typing_extensions import ParamSpec # type: ignore

T = TypeVar("T")
P = ParamSpec("P")


class QuartTrio(Quart):
nursery: trio.Nursery
Expand All @@ -28,7 +38,7 @@ class QuartTrio(Quart):
request_class = TrioRequest
response_class = TrioResponse
test_app_class = TrioTestApp
test_client_class = TrioClient
test_client_class = TrioClient # type: ignore
websocket_class = TrioWebsocket

def run( # type: ignore
Expand Down Expand Up @@ -114,7 +124,7 @@ def run_task(

return serve(self, config, shutdown_trigger=shutdown_trigger)

def sync_to_async(self, func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
def sync_to_async(self, func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
"""Return a async function that will run the synchronous function *func*.
This can be used as so,::
Expand All @@ -140,6 +150,9 @@ async def handle_request(self, request: Request) -> Union[Response, WerkzeugResp
return await self.handle_exception(error) # type: ignore
except Exception as error:
return await self.handle_exception(error)
finally:
if request.scope.get("_quart._preserve_context", False):
self._preserved_context = request_context.copy()

async def full_dispatch_request(
self, request_context: Optional[RequestContext] = None
Expand All @@ -150,9 +163,10 @@ async def full_dispatch_request(
request_context: The request context, optional as Flask
omits this argument.
"""
await self.try_trigger_before_first_request_functions()
await request_started.send(self)
try:
await request_started.send_async(self, _sync_wrapper=self.ensure_async)

result: ResponseReturnValue | HTTPException | None
result = await self.preprocess_request(request_context)
if result is None:
result = await self.dispatch_request(request_context)
Expand Down Expand Up @@ -190,6 +204,9 @@ async def handle_websocket(
return await self.handle_websocket_exception(error) # type: ignore
except Exception as error:
return await self.handle_websocket_exception(error)
finally:
if websocket.scope.get("_quart._preserve_context", False):
self._preserved_context = websocket_context.copy()

async def full_dispatch_websocket(
self, websocket_context: Optional[WebsocketContext] = None
Expand All @@ -200,9 +217,10 @@ async def full_dispatch_websocket(
websocket_context: The websocket context, optional to match
the Flask convention.
"""
await self.try_trigger_before_first_request_functions()
await websocket_started.send(self)
try:
await websocket_started.send_async(self, _sync_wrapper=self.ensure_async)

result: ResponseReturnValue | HTTPException | None
result = await self.preprocess_websocket(websocket_context)
if result is None:
result = await self.dispatch_websocket(websocket_context)
Expand Down Expand Up @@ -244,20 +262,35 @@ async def open_resource(
def add_background_task(self, func: Callable, *args: Any, **kwargs: Any) -> None:
async def _wrapper() -> None:
try:
await copy_current_app_context(func)(*args, **kwargs)
async with self.app_context():
await self.ensure_async(func)(*args, **kwargs)
except (BaseExceptionGroup, Exception) as error:
await self.handle_background_exception(error) # type: ignore

self.nursery.start_soon(_wrapper)

async def shutdown(self) -> None:
async with self.app_context():
for func in self.after_serving_funcs:
await self.ensure_async(func)()
for gen in self.while_serving_gens:
try:
await gen.__anext__()
except StopAsyncIteration:
pass
else:
raise RuntimeError("While serving generator didn't terminate")
if self.config["BACKGROUND_TASK_SHUTDOWN_TIMEOUT"] is not None:
self.nursery.cancel_scope.deadline = (
trio.current_time() + self.config["BACKGROUND_TASK_SHUTDOWN_TIMEOUT"]
)
else:
self.nursery.cancel_scope.cancel()

try:
async with self.app_context():
for func in self.after_serving_funcs:
await self.ensure_async(func)()
for gen in self.while_serving_gens:
try:
await gen.__anext__()
except StopAsyncIteration:
pass
else:
raise RuntimeError("While serving generator didn't terminate")
except Exception as error:
await got_serving_exception.send_async(
self, _sync_wrapper=self.ensure_async, exception=error
)
self.log_exception(sys.exc_info())
raise
25 changes: 18 additions & 7 deletions src/quart_trio/asgi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import partial
from typing import cast, Optional, TYPE_CHECKING, Union
from urllib.parse import urlparse

import trio
from exceptiongroup import BaseExceptionGroup
Expand Down Expand Up @@ -73,7 +74,7 @@ async def handle_messages( # type: ignore
event = await receive()
if event["type"] == "websocket.receive":
message = event.get("bytes") or event["text"]
await websocket_received.send(message)
await websocket_received.send_async(message, _sync_wrapper=self.app.ensure_async)
await self.send_channel.send(message)
elif event["type"] == "websocket.disconnect":
break
Expand All @@ -83,10 +84,20 @@ def _create_websocket_from_scope(self, send: ASGISendCallable) -> Websocket:
headers = Headers()
headers["Remote-Addr"] = (self.scope.get("client") or ["<local>"])[0]
for name, value in self.scope["headers"]:
headers.add(name.decode().title(), value.decode())
headers.add(name.decode("latin1").title(), value.decode("latin1"))

path = self.scope["path"]
path = path if path[0] == "/" else urlparse(path).path
root_path = self.scope.get("root_path", "")
if root_path != "":
try:
path = path.split(root_path, 1)[1]
path = " " if path == "" else path
except IndexError:
path = " " # Invalid in paths, hence will result in 404

return self.app.websocket_class(
self.scope["path"],
path,
self.scope["query_string"],
self.scope["scheme"],
headers,
Expand All @@ -97,7 +108,7 @@ def _create_websocket_from_scope(self, send: ASGISendCallable) -> Websocket:
partial(self.send_data, send),
partial(self.accept_connection, send),
partial(self.close_connection, send),
self.scope,
scope=self.scope,
)

async def handle_websocket( # type: ignore
Expand All @@ -124,7 +135,7 @@ async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -
cast(
LifespanStartupFailedEvent,
{"type": "lifespan.startup.failed", "message": str(error)},
)
),
)
else:
await send(
Expand All @@ -140,13 +151,13 @@ async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -
cast(
LifespanShutdownFailedEvent,
{"type": "lifespan.shutdown.failed", "message": str(error)},
)
),
)
else:
await send(
cast(
LifespanShutdownCompleteEvent,
{"type": "lifespan.shutdown.complete"},
)
),
)
break
2 changes: 1 addition & 1 deletion src/quart_trio/wrappers/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ async def get_data(
self.body.clear()

if as_text:
return raw_data.decode(self.charset, self.encoding_errors)
return raw_data.decode()
else:
return raw_data

Expand Down

0 comments on commit 4c59acf

Please sign in to comment.