Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Begin middleware revamp #2550

Merged
merged 13 commits into from
Sep 21, 2022
163 changes: 83 additions & 80 deletions sanic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,10 @@ def url_for(self, view_name: str, **kwargs):
# -------------------------------------------------------------------- #

async def handle_exception(
self, request: Request, exception: BaseException
self,
request: Request,
exception: BaseException,
run_middleware: bool = True,
): # no cov
"""
A handler that catches specific exceptions and outputs a response.
Expand All @@ -718,6 +721,7 @@ async def handle_exception(
:param exception: The exception that was raised
:raises ServerError: response 500
"""
response = None
await self.dispatch(
"http.lifecycle.exception",
inline=True,
Expand Down Expand Up @@ -758,9 +762,11 @@ async def handle_exception(
# -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
response = await self._run_request_middleware(
request, request_name=None
)
if run_middleware:
middleware = (
request.route and request.route.extra.request_middleware
) or self.request_middleware
response = await self._run_request_middleware(request, middleware)
# No middleware results
if not response:
try:
Expand Down Expand Up @@ -840,7 +846,13 @@ async def handle_request(self, request: Request): # no cov

# Define `response` var here to remove warnings about
# allocation before assignment below.
response = None
response: Optional[
Union[
BaseHTTPResponse,
Coroutine[Any, Any, Optional[BaseHTTPResponse]],
]
] = None
run_middleware = True
prryplatypus marked this conversation as resolved.
Show resolved Hide resolved
try:

await self.dispatch(
Expand Down Expand Up @@ -885,9 +897,11 @@ async def handle_request(self, request: Request): # no cov
# -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
response = await self._run_request_middleware(
request, request_name=route.name
)
run_middleware = False
if request.route.extra.request_middleware:
response = await self._run_request_middleware(
request, request.route.extra.request_middleware
)

# No middleware results
if not response:
Expand Down Expand Up @@ -928,7 +942,7 @@ async def handle_request(self, request: Request): # no cov
if request.stream is not None:
response = request.stream.response
elif response is not None:
response = await request.respond(response)
response = await request.respond(response) # type: ignore
elif not hasattr(handler, "is_websocket"):
response = request.stream.response # type: ignore

Expand All @@ -946,7 +960,7 @@ async def handle_request(self, request: Request): # no cov
...
await response.send(end_stream=True)
elif isinstance(response, ResponseStream):
resp = await response(request)
resp = await response(request) # type: ignore
ahopkins marked this conversation as resolved.
Show resolved Hide resolved
await self.dispatch(
"http.lifecycle.response",
inline=True,
Expand All @@ -955,7 +969,7 @@ async def handle_request(self, request: Request): # no cov
"response": resp,
},
)
await response.eof()
await response.eof() # type: ignore
else:
if not hasattr(handler, "is_websocket"):
raise ServerError(
Expand All @@ -967,7 +981,9 @@ async def handle_request(self, request: Request): # no cov
raise
except Exception as e:
# Response Generation Failed
await self.handle_exception(request, e)
await self.handle_exception(
request, e, run_middleware=run_middleware
)

async def _websocket_handler(
self, handler, request, *args, subprotocols=None, **kwargs
Expand Down Expand Up @@ -1036,86 +1052,72 @@ def asgi_client(self): # noqa
# -------------------------------------------------------------------- #

async def _run_request_middleware(
self, request, request_name=None
self, request, middleware_collection
): # no cov
# The if improves speed. I don't know why
named_middleware = self.named_request_middleware.get(
request_name, deque()
)
applicable_middleware = self.request_middleware + named_middleware

# request.request_middleware_started is meant as a stop-gap solution
# until RFC 1630 is adopted
if applicable_middleware and not request.request_middleware_started:
request.request_middleware_started = True
request._request_middleware_started = True

for middleware in applicable_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
for middleware in middleware_collection:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)

response = middleware(request)
if isawaitable(response):
response = await response
response = middleware(request)
if isawaitable(response):
response = await response

await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)

if response:
return response
if response:
return response
return None

async def _run_response_middleware(
self, request, response, request_name=None
self, request, response, middleware_collection
): # no cov
named_middleware = self.named_response_middleware.get(
request_name, deque()
)
applicable_middleware = self.response_middleware + named_middleware
if applicable_middleware:
for middleware in applicable_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": response,
},
condition={"attach_to": "response"},
)
for middleware in middleware_collection:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": response,
},
condition={"attach_to": "response"},
)

_response = middleware(request, response)
if isawaitable(_response):
_response = await _response
_response = middleware(request, response)
if isawaitable(_response):
_response = await _response

await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": _response if _response else response,
},
condition={"attach_to": "response"},
)
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": _response if _response else response,
},
condition={"attach_to": "response"},
)

if _response:
response = _response
if isinstance(response, BaseHTTPResponse):
response = request.stream.respond(response)
break
if _response:
response = _response
if isinstance(response, BaseHTTPResponse):
response = request.stream.respond(response)
break
return response

def _build_endpoint_name(self, *parts):
Expand Down Expand Up @@ -1528,6 +1530,7 @@ def finalize(self):
except FinalizationError as e:
if not Sanic.test_mode:
raise e
self.finalize_middleware()

def signalize(self, allow_fail_builtin=True):
self.signal_router.allow_fail_builtin = allow_fail_builtin
Expand Down
5 changes: 4 additions & 1 deletion sanic/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,4 +234,7 @@ async def __call__(self) -> None:
self.stage = Stage.HANDLER
await self.sanic_app.handle_request(self.request)
except Exception as e:
await self.sanic_app.handle_exception(self.request, e)
try:
await self.sanic_app.handle_exception(self.request, e)
except Exception as exc:
await self.sanic_app.handle_exception(self.request, exc, False)
5 changes: 4 additions & 1 deletion sanic/http/http1.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,10 @@ async def error_response(self, exception: Exception) -> None:
if self.request is None:
self.create_empty_request()

await app.handle_exception(self.request, exception)
try:
await app.handle_exception(self.request, exception)
except Exception as e:
await app.handle_exception(self.request, e, False)

def create_empty_request(self) -> None:
"""
Expand Down
66 changes: 66 additions & 0 deletions sanic/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from __future__ import annotations

from collections import deque
from enum import IntEnum, auto
from itertools import count
from typing import Deque, Sequence, Union

from sanic.models.handler_types import MiddlewareType


class MiddlewareLocation(IntEnum):
REQUEST = auto()
RESPONSE = auto()


class Middleware:
_counter = count()

__slots__ = ("func", "priority", "location", "definition")

def __init__(
self,
func: MiddlewareType,
location: MiddlewareLocation,
priority: int = 0,
) -> None:
self.func = func
self.priority = priority
self.location = location
self.definition = next(Middleware._counter)

def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"func=<function {self.func.__name__}>, "
f"priority={self.priority}, "
f"location={self.location.name})"
)

@property
def order(self):
return (self.priority, -self.definition)

@classmethod
def convert(
cls,
*middleware_collections: Sequence[Union[Middleware, MiddlewareType]],
location: MiddlewareLocation,
) -> Deque[Middleware]:
return deque(
[
middleware
if isinstance(middleware, Middleware)
else Middleware(middleware, location)
for collection in middleware_collections
for middleware in collection
]
)

@classmethod
def reset_count(cls):
cls._counter = count()
cls.count = next(cls._counter)