Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix BaseHTTPMiddleware when used without Starlette #1

Merged
merged 2 commits into from
May 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 1 addition & 8 deletions starlette/applications.py
@@ -1,8 +1,5 @@
import typing

import anyio
from anyio.abc import TaskGroup

from starlette.datastructures import State, URLPath
from starlette.exceptions import ExceptionMiddleware
from starlette.middleware import Middleware
Expand Down Expand Up @@ -39,8 +36,6 @@ class Starlette:
standard functions, or async functions.
"""

task_group: TaskGroup

def __init__(
self,
debug: bool = False,
Expand Down Expand Up @@ -114,9 +109,7 @@ def url_path_for(self, name: str, **path_params: str) -> URLPath:

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope["app"] = self
self.task_group = anyio.create_task_group()
async with self.task_group:
await self.middleware_stack(scope, receive, send)
await self.middleware_stack(scope, receive, send)

# The following usages are now discouraged in favour of configuration
#  during Starlette.__init__(...)
Expand Down
64 changes: 32 additions & 32 deletions starlette/middleware/base.py
Expand Up @@ -22,38 +22,38 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
return

request = Request(scope, receive=receive)
response = await self.dispatch_func(request, self.call_next)
await response(scope, receive, send)

async def call_next(self, request: Request) -> Response:
send_stream, recv_stream = anyio.create_memory_object_stream()
scope = request.scope

async def coro() -> None:
async with send_stream:
await self.app(scope, request.receive, send_stream.send)

scope["app"].task_group.start_soon(coro)

try:
message = await recv_stream.receive()
except anyio.EndOfStream:
raise RuntimeError("No response returned.")

assert message["type"] == "http.response.start"

async def body_stream() -> typing.AsyncGenerator[bytes, None]:
async with recv_stream:
async for message in recv_stream:
assert message["type"] == "http.response.body"
yield message.get("body", b"")

response = StreamingResponse(
status_code=message["status"], content=body_stream()
)
response.raw_headers = message["headers"]
return response
async def call_next(request: Request) -> Response:
send_stream, recv_stream = anyio.create_memory_object_stream()

async def coro() -> None:
async with send_stream:
await self.app(scope, request.receive, send_stream.send)

task_group.start_soon(coro)

try:
message = await recv_stream.receive()
except anyio.EndOfStream:
raise RuntimeError("No response returned.")

assert message["type"] == "http.response.start"

async def body_stream() -> typing.AsyncGenerator[bytes, None]:
async with recv_stream:
async for message in recv_stream:
assert message["type"] == "http.response.body"
yield message.get("body", b"")

response = StreamingResponse(
status_code=message["status"], content=body_stream()
)
response.raw_headers = message["headers"]
return response

async with anyio.create_task_group() as task_group:
request = Request(scope, receive=receive)
response = await self.dispatch_func(request, call_next)
await response(scope, receive, send)

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
Expand Down
5 changes: 4 additions & 1 deletion starlette/testclient.py
Expand Up @@ -8,6 +8,7 @@
import queue
import types
import typing
from concurrent.futures import Future
from urllib.parse import unquote, urljoin, urlsplit

import anyio
Expand Down Expand Up @@ -284,7 +285,7 @@ def __enter__(self) -> "WebSocketTestSession":
)

try:
self.portal.start_task_soon(self._run)
_: "Future[None]" = self.portal.start_task_soon(self._run)
self.send({"type": "websocket.connect"})
message = self.receive()
self._raise_on_close(message)
Expand Down Expand Up @@ -385,6 +386,8 @@ class TestClient(requests.Session):
"backend_options": {},
} # type: typing.Dict[str, typing.Any]

task: "Future[None]"

def __init__(
self,
app: typing.Union[ASGI2App, ASGI3App],
Expand Down