diff --git a/sanic/app.py b/sanic/app.py index 85d5fbcedd..26f1b9d91f 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -709,7 +709,10 @@ def url_for(self, view_name: str, **kwargs): # -------------------------------------------------------------------- # async def handle_exception( - self, request: Request, exception: BaseException + self, + request: Request, + exception: BaseException, + run_middleware: bool = True, ): # no cov """ A handler that catches specific exceptions and outputs a response. @@ -718,6 +721,7 @@ async def handle_exception( :param exception: The exception that was raised :raises ServerError: response 500 """ + response = None await self.dispatch( "http.lifecycle.exception", inline=True, @@ -758,9 +762,11 @@ async def handle_exception( # -------------------------------------------- # # Request Middleware # -------------------------------------------- # - response = await self._run_request_middleware( - request, request_name=None - ) + if run_middleware: + middleware = ( + request.route and request.route.extra.request_middleware + ) or self.request_middleware + response = await self._run_request_middleware(request, middleware) # No middleware results if not response: try: @@ -840,7 +846,13 @@ async def handle_request(self, request: Request): # no cov # Define `response` var here to remove warnings about # allocation before assignment below. - response = None + response: Optional[ + Union[ + BaseHTTPResponse, + Coroutine[Any, Any, Optional[BaseHTTPResponse]], + ] + ] = None + run_middleware = True try: await self.dispatch( @@ -885,9 +897,11 @@ async def handle_request(self, request: Request): # no cov # -------------------------------------------- # # Request Middleware # -------------------------------------------- # - response = await self._run_request_middleware( - request, request_name=route.name - ) + run_middleware = False + if request.route.extra.request_middleware: + response = await self._run_request_middleware( + request, request.route.extra.request_middleware + ) # No middleware results if not response: @@ -928,7 +942,7 @@ async def handle_request(self, request: Request): # no cov if request.stream is not None: response = request.stream.response elif response is not None: - response = await request.respond(response) + response = await request.respond(response) # type: ignore elif not hasattr(handler, "is_websocket"): response = request.stream.response # type: ignore @@ -946,7 +960,7 @@ async def handle_request(self, request: Request): # no cov ... await response.send(end_stream=True) elif isinstance(response, ResponseStream): - resp = await response(request) + resp = await response(request) # type: ignore await self.dispatch( "http.lifecycle.response", inline=True, @@ -955,7 +969,7 @@ async def handle_request(self, request: Request): # no cov "response": resp, }, ) - await response.eof() + await response.eof() # type: ignore else: if not hasattr(handler, "is_websocket"): raise ServerError( @@ -967,7 +981,9 @@ async def handle_request(self, request: Request): # no cov raise except Exception as e: # Response Generation Failed - await self.handle_exception(request, e) + await self.handle_exception( + request, e, run_middleware=run_middleware + ) async def _websocket_handler( self, handler, request, *args, subprotocols=None, **kwargs @@ -1036,86 +1052,72 @@ def asgi_client(self): # noqa # -------------------------------------------------------------------- # async def _run_request_middleware( - self, request, request_name=None + self, request, middleware_collection ): # no cov - # The if improves speed. I don't know why - named_middleware = self.named_request_middleware.get( - request_name, deque() - ) - applicable_middleware = self.request_middleware + named_middleware - - # request.request_middleware_started is meant as a stop-gap solution - # until RFC 1630 is adopted - if applicable_middleware and not request.request_middleware_started: - request.request_middleware_started = True + request._request_middleware_started = True - for middleware in applicable_middleware: - await self.dispatch( - "http.middleware.before", - inline=True, - context={ - "request": request, - "response": None, - }, - condition={"attach_to": "request"}, - ) + for middleware in middleware_collection: + await self.dispatch( + "http.middleware.before", + inline=True, + context={ + "request": request, + "response": None, + }, + condition={"attach_to": "request"}, + ) - response = middleware(request) - if isawaitable(response): - response = await response + response = middleware(request) + if isawaitable(response): + response = await response - await self.dispatch( - "http.middleware.after", - inline=True, - context={ - "request": request, - "response": None, - }, - condition={"attach_to": "request"}, - ) + await self.dispatch( + "http.middleware.after", + inline=True, + context={ + "request": request, + "response": None, + }, + condition={"attach_to": "request"}, + ) - if response: - return response + if response: + return response return None async def _run_response_middleware( - self, request, response, request_name=None + self, request, response, middleware_collection ): # no cov - named_middleware = self.named_response_middleware.get( - request_name, deque() - ) - applicable_middleware = self.response_middleware + named_middleware - if applicable_middleware: - for middleware in applicable_middleware: - await self.dispatch( - "http.middleware.before", - inline=True, - context={ - "request": request, - "response": response, - }, - condition={"attach_to": "response"}, - ) + for middleware in middleware_collection: + await self.dispatch( + "http.middleware.before", + inline=True, + context={ + "request": request, + "response": response, + }, + condition={"attach_to": "response"}, + ) - _response = middleware(request, response) - if isawaitable(_response): - _response = await _response + _response = middleware(request, response) + if isawaitable(_response): + _response = await _response - await self.dispatch( - "http.middleware.after", - inline=True, - context={ - "request": request, - "response": _response if _response else response, - }, - condition={"attach_to": "response"}, - ) + await self.dispatch( + "http.middleware.after", + inline=True, + context={ + "request": request, + "response": _response if _response else response, + }, + condition={"attach_to": "response"}, + ) - if _response: - response = _response - if isinstance(response, BaseHTTPResponse): - response = request.stream.respond(response) - break + if _response: + response = _response + if isinstance(response, BaseHTTPResponse): + response = request.stream.respond(response) + break return response def _build_endpoint_name(self, *parts): @@ -1528,6 +1530,7 @@ def finalize(self): except FinalizationError as e: if not Sanic.test_mode: raise e + self.finalize_middleware() def signalize(self, allow_fail_builtin=True): self.signal_router.allow_fail_builtin = allow_fail_builtin diff --git a/sanic/asgi.py b/sanic/asgi.py index 10357ae876..de61ed68a5 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -234,4 +234,7 @@ async def __call__(self) -> None: self.stage = Stage.HANDLER await self.sanic_app.handle_request(self.request) except Exception as e: - await self.sanic_app.handle_exception(self.request, e) + try: + await self.sanic_app.handle_exception(self.request, e) + except Exception as exc: + await self.sanic_app.handle_exception(self.request, exc, False) diff --git a/sanic/http/http1.py b/sanic/http/http1.py index 1f2423ad07..ccfae75de6 100644 --- a/sanic/http/http1.py +++ b/sanic/http/http1.py @@ -428,7 +428,10 @@ async def error_response(self, exception: Exception) -> None: if self.request is None: self.create_empty_request() - await app.handle_exception(self.request, exception) + try: + await app.handle_exception(self.request, exception) + except Exception as e: + await app.handle_exception(self.request, e, False) def create_empty_request(self) -> None: """ diff --git a/sanic/middleware.py b/sanic/middleware.py new file mode 100644 index 0000000000..5bbd777b6c --- /dev/null +++ b/sanic/middleware.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from collections import deque +from enum import IntEnum, auto +from itertools import count +from typing import Deque, Sequence, Union + +from sanic.models.handler_types import MiddlewareType + + +class MiddlewareLocation(IntEnum): + REQUEST = auto() + RESPONSE = auto() + + +class Middleware: + _counter = count() + + __slots__ = ("func", "priority", "location", "definition") + + def __init__( + self, + func: MiddlewareType, + location: MiddlewareLocation, + priority: int = 0, + ) -> None: + self.func = func + self.priority = priority + self.location = location + self.definition = next(Middleware._counter) + + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"func=, " + f"priority={self.priority}, " + f"location={self.location.name})" + ) + + @property + def order(self): + return (self.priority, -self.definition) + + @classmethod + def convert( + cls, + *middleware_collections: Sequence[Union[Middleware, MiddlewareType]], + location: MiddlewareLocation, + ) -> Deque[Middleware]: + return deque( + [ + middleware + if isinstance(middleware, Middleware) + else Middleware(middleware, location) + for collection in middleware_collections + for middleware in collection + ] + ) + + @classmethod + def reset_count(cls): + cls._counter = count() + cls.count = next(cls._counter) diff --git a/sanic/mixins/middleware.py b/sanic/mixins/middleware.py index 5ef9dc77a1..bea3976f26 100644 --- a/sanic/mixins/middleware.py +++ b/sanic/mixins/middleware.py @@ -1,11 +1,17 @@ +from collections import deque from functools import partial +from operator import attrgetter from typing import List from sanic.base.meta import SanicMeta +from sanic.middleware import Middleware, MiddlewareLocation from sanic.models.futures import FutureMiddleware +from sanic.router import Router class MiddlewareMixin(metaclass=SanicMeta): + router: Router + def __init__(self, *args, **kwargs) -> None: self._future_middleware: List[FutureMiddleware] = [] @@ -13,7 +19,12 @@ def _apply_middleware(self, middleware: FutureMiddleware): raise NotImplementedError # noqa def middleware( - self, middleware_or_request, attach_to="request", apply=True + self, + middleware_or_request, + attach_to="request", + apply=True, + *, + priority=0 ): """ Decorate and register middleware to be called before a request @@ -30,6 +41,12 @@ def middleware( def register_middleware(middleware, attach_to="request"): nonlocal apply + location = ( + MiddlewareLocation.REQUEST + if attach_to == "request" + else MiddlewareLocation.RESPONSE + ) + middleware = Middleware(middleware, location, priority=priority) future_middleware = FutureMiddleware(middleware, attach_to) self._future_middleware.append(future_middleware) if apply: @@ -46,7 +63,7 @@ def register_middleware(middleware, attach_to="request"): register_middleware, attach_to=middleware_or_request ) - def on_request(self, middleware=None): + def on_request(self, middleware=None, *, priority=0): """Register a middleware to be called before a request is handled. This is the same as *@app.middleware('request')*. @@ -54,11 +71,13 @@ def on_request(self, middleware=None): :param: middleware: A callable that takes in request. """ if callable(middleware): - return self.middleware(middleware, "request") + return self.middleware(middleware, "request", priority=priority) else: - return partial(self.middleware, attach_to="request") + return partial( + self.middleware, attach_to="request", priority=priority + ) - def on_response(self, middleware=None): + def on_response(self, middleware=None, *, priority=0): """Register a middleware to be called after a response is created. This is the same as *@app.middleware('response')*. @@ -67,6 +86,57 @@ def on_response(self, middleware=None): A callable that takes in a request and its response. """ if callable(middleware): - return self.middleware(middleware, "response") + return self.middleware(middleware, "response", priority=priority) else: - return partial(self.middleware, attach_to="response") + return partial( + self.middleware, attach_to="response", priority=priority + ) + + def finalize_middleware(self): + for route in self.router.routes: + request_middleware = Middleware.convert( + self.request_middleware, + self.named_request_middleware.get(route.name, deque()), + location=MiddlewareLocation.REQUEST, + ) + response_middleware = Middleware.convert( + self.response_middleware, + self.named_response_middleware.get(route.name, deque()), + location=MiddlewareLocation.RESPONSE, + ) + route.extra.request_middleware = deque( + sorted( + request_middleware, + key=attrgetter("order"), + reverse=True, + ) + ) + route.extra.response_middleware = deque( + sorted( + response_middleware, + key=attrgetter("order"), + reverse=True, + )[::-1] + ) + request_middleware = Middleware.convert( + self.request_middleware, + location=MiddlewareLocation.REQUEST, + ) + response_middleware = Middleware.convert( + self.response_middleware, + location=MiddlewareLocation.RESPONSE, + ) + self.request_middleware = deque( + sorted( + request_middleware, + key=attrgetter("order"), + reverse=True, + ) + ) + self.response_middleware = deque( + sorted( + response_middleware, + key=attrgetter("order"), + reverse=True, + )[::-1] + ) diff --git a/sanic/request.py b/sanic/request.py index 0a8ca59eab..f7b3b9993e 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -56,7 +56,7 @@ parse_xforwarded, ) from sanic.http import Stage -from sanic.log import error_logger, logger +from sanic.log import deprecation, error_logger, logger from sanic.models.protocol_types import TransportProtocol from sanic.response import BaseHTTPResponse, HTTPResponse @@ -103,6 +103,7 @@ class Request: "_port", "_protocol", "_remote_addr", + "_request_middleware_started", "_scheme", "_socket", "_stream_id", @@ -126,7 +127,6 @@ class Request: "parsed_token", "raw_url", "responded", - "request_middleware_started", "route", "stream", "transport", @@ -178,7 +178,7 @@ def __init__( self.parsed_not_grouped_args: DefaultDict[ Tuple[bool, bool, str, str], List[Tuple[str, str]] ] = defaultdict(list) - self.request_middleware_started = False + self._request_middleware_started = False self.responded: bool = False self.route: Optional[Route] = None self.stream: Optional[Stream] = None @@ -219,6 +219,16 @@ def get_current(cls) -> Request: def generate_id(*_): return uuid.uuid4() + @property + def request_middleware_started(self): + deprecation( + "Request.request_middleware_started has been deprecated and will" + "be removed. You should set a flag on the request context using" + "either middleware or signals if you need this feature.", + 23.3, + ) + return self._request_middleware_started + @property def stream_id(self): """ @@ -324,9 +334,13 @@ async def add_header(_, response: HTTPResponse): response = await response # type: ignore # Run response middleware try: - response = await self.app._run_response_middleware( - self, response, request_name=self.name - ) + middleware = ( + self.route and self.route.extra.response_middleware + ) or self.app.response_middleware + if middleware: + response = await self.app._run_response_middleware( + self, response, middleware + ) except CancelledErrors: raise except Exception: diff --git a/sanic/worker/inspector.py b/sanic/worker/inspector.py index 6c9869ee18..e49b6851f4 100644 --- a/sanic/worker/inspector.py +++ b/sanic/worker/inspector.py @@ -73,7 +73,7 @@ def stop(self, *_): def state_to_json(self): output = {"info": self.app_info} - output["workers"] = self._make_safe(dict(self.worker_state)) + output["workers"] = self.make_safe(dict(self.worker_state)) return output def reload(self): @@ -84,10 +84,11 @@ def shutdown(self): message = "__TERMINATE__" self._publisher.send(message) - def _make_safe(self, obj: Dict[str, Any]) -> Dict[str, Any]: + @staticmethod + def make_safe(obj: Dict[str, Any]) -> Dict[str, Any]: for key, value in obj.items(): if isinstance(value, dict): - obj[key] = self._make_safe(value) + obj[key] = Inspector.make_safe(value) elif isinstance(value, datetime): obj[key] = value.isoformat() return obj diff --git a/setup.py b/setup.py index 3bc5ea4cc4..c7d2c0e267 100644 --- a/setup.py +++ b/setup.py @@ -84,7 +84,7 @@ def open_local(paths, mode="r", encoding="utf8"): uvloop = "uvloop>=0.5.3" + env_dependency types_ujson = "types-ujson" + env_dependency requirements = [ - "sanic-routing>=22.3.0,<22.6.0", + "sanic-routing>=22.8.0", "httptools>=0.0.10", uvloop, ujson, @@ -94,7 +94,7 @@ def open_local(paths, mode="r", encoding="utf8"): ] tests_require = [ - "sanic-testing>=22.9.0b1", + "sanic-testing>=22.9.0b2", "pytest", "coverage", "beautifulsoup4", diff --git a/tests/test_app.py b/tests/test_app.py index df7f238ff6..333ea51a3b 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -18,6 +18,7 @@ from sanic.helpers import _default from sanic.log import LOGGING_CONFIG_DEFAULTS from sanic.response import text +from sanic.router import Route @pytest.fixture(autouse=True) @@ -152,8 +153,13 @@ async def handler(): def test_app_handle_request_handler_is_none(app: Sanic, monkeypatch): + app.config.TOUCHUP = False + route = Mock(spec=Route) + route.extra.request_middleware = [] + route.extra.response_middleware = [] + def mockreturn(*args, **kwargs): - return Mock(), None, {} + return route, None, {} monkeypatch.setattr(app.router, "get", mockreturn) diff --git a/tests/test_middleware_priority.py b/tests/test_middleware_priority.py new file mode 100644 index 0000000000..9646f6d0f9 --- /dev/null +++ b/tests/test_middleware_priority.py @@ -0,0 +1,90 @@ +from functools import partial + +import pytest + +from sanic import Sanic +from sanic.middleware import Middleware +from sanic.response import json + + +PRIORITY_TEST_CASES = ( + ([0, 1, 2], [1, 1, 1]), + ([0, 1, 2], [1, 1, None]), + ([0, 1, 2], [1, None, None]), + ([0, 1, 2], [2, 1, None]), + ([0, 1, 2], [2, 2, None]), + ([0, 1, 2], [3, 2, 1]), + ([0, 1, 2], [None, None, None]), + ([0, 2, 1], [1, None, 1]), + ([0, 2, 1], [2, None, 1]), + ([0, 2, 1], [2, None, 2]), + ([0, 2, 1], [3, 1, 2]), + ([1, 0, 2], [1, 2, None]), + ([1, 0, 2], [2, 3, 1]), + ([1, 0, 2], [None, 1, None]), + ([1, 2, 0], [1, 3, 2]), + ([1, 2, 0], [None, 1, 1]), + ([1, 2, 0], [None, 2, 1]), + ([1, 2, 0], [None, 2, 2]), + ([2, 0, 1], [1, None, 2]), + ([2, 0, 1], [2, 1, 3]), + ([2, 0, 1], [None, None, 1]), + ([2, 1, 0], [1, 2, 3]), + ([2, 1, 0], [None, 1, 2]), +) + + +@pytest.fixture(autouse=True) +def reset_middleware(): + yield + Middleware.reset_count() + + +@pytest.mark.parametrize( + "expected,priorities", + PRIORITY_TEST_CASES, +) +def test_request_middleware_order_priority(app: Sanic, expected, priorities): + order = [] + + def add_ident(request, ident): + order.append(ident) + + @app.get("/") + def handler(request): + return json(None) + + for ident, priority in enumerate(priorities): + kwargs = {} + if priority is not None: + kwargs["priority"] = priority + app.on_request(partial(add_ident, ident=ident), **kwargs) + + app.test_client.get("/") + + assert order == expected + + +@pytest.mark.parametrize( + "expected,priorities", + PRIORITY_TEST_CASES, +) +def test_response_middleware_order_priority(app: Sanic, expected, priorities): + order = [] + + def add_ident(request, response, ident): + order.append(ident) + + @app.get("/") + def handler(request): + return json(None) + + for ident, priority in enumerate(priorities): + kwargs = {} + if priority is not None: + kwargs["priority"] = priority + app.on_response(partial(add_ident, ident=ident), **kwargs) + + app.test_client.get("/") + + assert order[::-1] == expected