Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix form response file saving. Closes #7 #8

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
import os
import sys
from importlib.metadata import version as imp_version
sys.path.insert(0, os.path.abspath('.'))

sys.path.insert(0, os.path.abspath("."))


# -- Project information -----------------------------------------------------

project = 'Quart-Trio'
copyright = '2020, Philip Jones'
author = 'Philip Jones'
project = "Quart-Trio"
copyright = "2020, Philip Jones"
author = "Philip Jones"
version = imp_version("quart-trio")
release = version

Expand All @@ -29,14 +30,14 @@
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.napoleon']
extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon"]

source_suffix = '.rst'
source_suffix = ".rst"

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]


# -- Options for HTML output -------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions src/quart_trio/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def run( # type: ignore
self.run_task, host, port, debug, use_reloader, ca_certs, certfile, keyfile
)

def run_task(
def run_task( # type: ignore[override] # Has extra `use_reloader` param
self,
host: str = "127.0.0.1",
port: int = 5000,
Expand Down Expand Up @@ -168,7 +168,7 @@ async def handle_user_exception(
if isinstance(error, BaseExceptionGroup):
for exception in error.exceptions:
try:
return await self.handle_user_exception(exception) # type: ignore
return await self.handle_user_exception(exception)
except Exception:
pass # No handler for this error
# Not found a single handler, re-raise the error
Expand Down
6 changes: 4 additions & 2 deletions src/quart_trio/asgi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import AnyStr, cast, Optional, TYPE_CHECKING
from typing import cast, Optional, TYPE_CHECKING

import trio
from exceptiongroup import BaseExceptionGroup
Expand Down Expand Up @@ -58,7 +58,9 @@ def __init__(self, app: "QuartTrio", scope: WebsocketScope) -> None:
self.scope = scope
self._accepted = False
self._closed = False
self.send_channel, self.receive_channel = trio.open_memory_channel[AnyStr](10)
self.send_channel: trio.MemorySendChannel
self.receive_channel: trio.MemoryReceiveChannel
self.send_channel, self.receive_channel = trio.open_memory_channel(10)

async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
websocket = self._create_websocket_from_scope(send)
Expand Down
26 changes: 26 additions & 0 deletions src/quart_trio/datastructures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

from os import PathLike

from quart.datastructures import FileStorage
from trio import open_file, Path, wrap_file


class TrioFileStorage(FileStorage):
async def save(self, destination: PathLike, buffer_size: int = 16384) -> None: # type: ignore
wrapped_stream = wrap_file(self.stream)
async with await open_file(destination, "wb") as file_:
data = await wrapped_stream.read(buffer_size)
while data != b"":
await file_.write(data)
data = await wrapped_stream.read(buffer_size)

async def load(self, source: PathLike, buffer_size: int = 16384) -> None:
path = Path(source)
self.filename = path.name
wrapped_stream = wrap_file(self.stream)
async with await open_file(path, "rb") as file_:
data = await file_.read(buffer_size)
while data != b"":
await wrapped_stream.write(data)
data = await file_.read(buffer_size)
7 changes: 7 additions & 0 deletions src/quart_trio/formparser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from quart.formparser import FormDataParser

from quart_trio.datastructures import TrioFileStorage


class TrioFormDataParser(FormDataParser):
file_storage_class = TrioFileStorage
6 changes: 3 additions & 3 deletions src/quart_trio/testing/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ def __init__(self, app: Quart, scope: WebsocketScope) -> None:
self.scope = scope
self.status_code: Optional[int] = None
self._server_send, self._server_receive = trio.open_memory_channel[dict](10)
self._client_send, self._client_receive = trio.open_memory_channel[
Union[AnyStr, Exception]
](10)
self._client_send: trio.MemorySendChannel
self._client_receive: trio.MemoryReceiveChannel
self._client_send, self._client_receive = trio.open_memory_channel(10)
self._task: Awaitable[None] = None

async def __aenter__(self) -> TestWebsocketConnectionProtocol:
Expand Down
3 changes: 3 additions & 0 deletions src/quart_trio/wrappers/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from quart.wrappers.request import Body, Request
from werkzeug.exceptions import RequestEntityTooLarge, RequestTimeout

from ..formparser import TrioFormDataParser


class EventWrapper:
def __init__(self) -> None:
Expand Down Expand Up @@ -44,6 +46,7 @@ def __init__(

class TrioRequest(Request):
body_class = TrioBody
form_data_parser_class = TrioFormDataParser
lock_class = trio.Lock # type: ignore

async def get_data(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pathlib import Path

import py
import py # type: ignore[import] # Stubs do not currently exist
import pytest
from quart import abort, Quart, ResponseReturnValue, send_file, websocket
from quart.testing import WebsocketResponseError
Expand Down