Skip to content

Commit

Permalink
Centralize error handling in ExceptionMiddleware (#1754)
Browse files Browse the repository at this point in the history
I was writing the documentation on exception handling, and I noticed
that it was very hard to explain our current behavior.

Error handlers can be registered either on the internal Flask app (not
the Starlette one) or on the Connexion app, which leads to some
undefined (actually just really hard to explain) behavior. Eg.
- Registering error handlers on a status code would capture
`starlette.HTTPException` errors on the Connexion app, and
`werkzeug.HTTPException` errors on the Flask App, which means that
registering an error handler on a status code doesn't catch all the
errors with that status code.
- Flask does some default error handling which leads to some exceptions
never reaching the error handlers registered on the Connexion app.

So I made the following changes:
- Replaced the default error handlers we registered on the Flask app
with a default handler on the `ExceptionMiddleware` that takes into
account other handlers registered on status codes.
- Configured Flask to propagate exceptions instead of catching them.
- Abstracted away the Starlette `Request` and `Response` types, so users
can and must now use `ConnexionRequest`
  and `ConnexionResponse` types in error handlers.
- Renamed the `ASGIRequest` class to `ConnexionRequest` since it is the
only Request class part of the high level
  Connexion interface.

We could also rename `ConnexionRequest` and `ConnexionResponse` to just
`Request` and `Response`. Wdyt?
  • Loading branch information
RobbeSneyders committed Oct 29, 2023
1 parent 1b72019 commit b9ba13c
Show file tree
Hide file tree
Showing 17 changed files with 179 additions and 110 deletions.
10 changes: 8 additions & 2 deletions connexion/apps/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
from starlette.types import ASGIApp, Receive, Scope, Send

from connexion.jsonifier import Jsonifier
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.middleware import ConnexionMiddleware, MiddlewarePosition, SpecMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.options import SwaggerUIOptions
from connexion.resolver import Resolver
from connexion.types import MaybeAwaitable
from connexion.uri_parsing import AbstractURIParser


Expand Down Expand Up @@ -250,14 +252,18 @@ def decorator(func: t.Callable) -> t.Callable:

@abc.abstractmethod
def add_error_handler(
self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable
self,
code_or_exception: t.Union[int, t.Type[Exception]],
function: t.Callable[
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
],
) -> None:
"""
Register a callable to handle application errors.
:param code_or_exception: An exception class or the status code of HTTP exceptions to
handle.
:param function: Callable that will handle exception.
:param function: Callable that will handle exception, may be async.
"""

def test_client(self, **kwargs):
Expand Down
12 changes: 9 additions & 3 deletions connexion/apps/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
from connexion.apps.abstract import AbstractApp
from connexion.decorators import StarletteDecorator
from connexion.jsonifier import Jsonifier
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.operations import AbstractOperation
from connexion.options import SwaggerUIOptions
from connexion.resolver import Resolver
from connexion.types import MaybeAwaitable
from connexion.uri_parsing import AbstractURIParser

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -88,7 +90,7 @@ def make_operation(self, operation: AbstractOperation) -> AsyncOperation:
)


class AsyncMiddlewareApp(RoutedMiddleware[AsyncApi]):
class AsyncASGIApp(RoutedMiddleware[AsyncApi]):

api_cls = AsyncApi

Expand Down Expand Up @@ -176,7 +178,7 @@ def __init__(
:param security_map: A dictionary of security handlers to use. Defaults to
:obj:`security.SECURITY_HANDLERS`
"""
self._middleware_app: AsyncMiddlewareApp = AsyncMiddlewareApp()
self._middleware_app: AsyncASGIApp = AsyncASGIApp()

super().__init__(
import_name,
Expand Down Expand Up @@ -205,6 +207,10 @@ def add_url_rule(
)

def add_error_handler(
self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable
self,
code_or_exception: t.Union[int, t.Type[Exception]],
function: t.Callable[
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
],
) -> None:
self.middleware.add_error_handler(code_or_exception, function)
51 changes: 15 additions & 36 deletions connexion/apps/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,22 @@
import typing as t

import flask
import werkzeug.exceptions
from a2wsgi import WSGIMiddleware
from flask import Response as FlaskResponse
from flask import signals
from starlette.types import Receive, Scope, Send

from connexion.apps.abstract import AbstractApp
from connexion.decorators import FlaskDecorator
from connexion.exceptions import InternalServerError, ProblemException, ResolverError
from connexion.exceptions import ResolverError
from connexion.frameworks import flask as flask_utils
from connexion.jsonifier import Jsonifier
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.middleware.abstract import AbstractRoutingAPI, SpecMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.operations import AbstractOperation
from connexion.options import SwaggerUIOptions
from connexion.problem import problem
from connexion.resolver import Resolver
from connexion.types import MaybeAwaitable
from connexion.uri_parsing import AbstractURIParser


Expand Down Expand Up @@ -117,44 +116,20 @@ def add_url_rule(
return self.blueprint.add_url_rule(rule, endpoint, view_func, **options)


class FlaskMiddlewareApp(SpecMiddleware):
class FlaskASGIApp(SpecMiddleware):
def __init__(self, import_name, server_args: dict, **kwargs):
self.app = flask.Flask(import_name, **server_args)
self.app.json = flask_utils.FlaskJSONProvider(self.app)
self.app.url_map.converters["float"] = flask_utils.NumberConverter
self.app.url_map.converters["int"] = flask_utils.IntegerConverter

self.set_errors_handlers()
# Propagate Errors so we can handle them in the middleware
self.app.config["PROPAGATE_EXCEPTIONS"] = True
self.app.config["TRAP_BAD_REQUEST_ERRORS"] = True
self.app.config["TRAP_HTTP_EXCEPTIONS"] = True

self.asgi_app = WSGIMiddleware(self.app.wsgi_app)

def set_errors_handlers(self):
for error_code in werkzeug.exceptions.default_exceptions:
self.app.register_error_handler(error_code, self.common_error_handler)

self.app.register_error_handler(ProblemException, self.common_error_handler)

def common_error_handler(self, exception: Exception) -> FlaskResponse:
"""Default error handler."""
if isinstance(exception, ProblemException):
response = exception.to_problem()
else:
if not isinstance(exception, werkzeug.exceptions.HTTPException):
exception = InternalServerError()

response = problem(
title=exception.name,
detail=exception.description,
status=exception.code,
)

if response.status_code >= 500:
signals.got_request_exception.send(self.app, exception=exception)

return flask.make_response(
(response.body, response.status_code, response.headers)
)

def add_api(self, specification, *, name: str = None, **kwargs):
api = FlaskApi(specification, **kwargs)

Expand All @@ -177,7 +152,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
class FlaskApp(AbstractApp):
"""Connexion Application based on ConnexionMiddleware wrapping a Flask application."""

_middleware_app: FlaskMiddlewareApp
_middleware_app: FlaskASGIApp

def __init__(
self,
Expand Down Expand Up @@ -237,7 +212,7 @@ def __init__(
:param security_map: A dictionary of security handlers to use. Defaults to
:obj:`security.SECURITY_HANDLERS`
"""
self._middleware_app = FlaskMiddlewareApp(import_name, server_args or {})
self._middleware_app = FlaskASGIApp(import_name, server_args or {})
self.app = self._middleware_app.app
super().__init__(
import_name,
Expand Down Expand Up @@ -266,6 +241,10 @@ def add_url_rule(
)

def add_error_handler(
self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable
self,
code_or_exception: t.Union[int, t.Type[Exception]],
function: t.Callable[
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
],
) -> None:
self.app.register_error_handler(code_or_exception, function)
4 changes: 2 additions & 2 deletions connexion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from starlette.types import Receive, Scope
from werkzeug.local import LocalProxy

from connexion.lifecycle import ASGIRequest
from connexion.lifecycle import ConnexionRequest
from connexion.operations import AbstractOperation

UNBOUND_MESSAGE = (
Expand All @@ -25,5 +25,5 @@
scope = LocalProxy(_scope, unbound_message=UNBOUND_MESSAGE)

request = LocalProxy(
lambda: ASGIRequest(scope, receive), unbound_message=UNBOUND_MESSAGE
lambda: ConnexionRequest(scope, receive), unbound_message=UNBOUND_MESSAGE
)
8 changes: 4 additions & 4 deletions connexion/decorators/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from connexion.context import context, operation
from connexion.frameworks.abstract import Framework
from connexion.http_facts import FORM_CONTENT_TYPES
from connexion.lifecycle import ASGIRequest, WSGIRequest
from connexion.lifecycle import ConnexionRequest, WSGIRequest
from connexion.operations import AbstractOperation, Swagger2Operation
from connexion.utils import (
deep_merge,
Expand All @@ -43,7 +43,7 @@ def __init__(

def _maybe_get_body(
self,
request: t.Union[WSGIRequest, ASGIRequest],
request: t.Union[WSGIRequest, ConnexionRequest],
*,
arguments: t.List[str],
has_kwargs: bool,
Expand Down Expand Up @@ -95,7 +95,7 @@ def __call__(self, function: t.Callable) -> t.Callable:
arguments, has_kwargs = inspect_function_arguments(unwrapped_function)

@functools.wraps(function)
async def wrapper(request: ASGIRequest) -> t.Any:
async def wrapper(request: ConnexionRequest) -> t.Any:
request_body = self._maybe_get_body(
request, arguments=arguments, has_kwargs=has_kwargs
)
Expand All @@ -118,7 +118,7 @@ async def wrapper(request: ASGIRequest) -> t.Any:


def prep_kwargs(
request: t.Union[WSGIRequest, ASGIRequest],
request: t.Union[WSGIRequest, ConnexionRequest],
*,
request_body: t.Any,
files: t.Dict[str, t.Any],
Expand Down
6 changes: 3 additions & 3 deletions connexion/frameworks/starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from starlette.types import Receive, Scope

from connexion.frameworks.abstract import Framework
from connexion.lifecycle import ASGIRequest
from connexion.lifecycle import ConnexionRequest
from connexion.uri_parsing import AbstractURIParser


Expand Down Expand Up @@ -48,8 +48,8 @@ def build_response(
)

@staticmethod
def get_request(*, scope: Scope, receive: Receive, uri_parser: AbstractURIParser, **kwargs) -> ASGIRequest: # type: ignore
return ASGIRequest(scope, receive, uri_parser=uri_parser)
def get_request(*, scope: Scope, receive: Receive, uri_parser: AbstractURIParser, **kwargs) -> ConnexionRequest: # type: ignore
return ConnexionRequest(scope, receive, uri_parser=uri_parser)


PATH_PARAMETER = re.compile(r"\{([^}]*)\}")
Expand Down
19 changes: 16 additions & 3 deletions connexion/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __getattr__(self, item):
return getattr(self._werkzeug_request, item)


class ASGIRequest(_RequestInterface):
class ConnexionRequest(_RequestInterface):
"""
Implementation of the Connexion :code:`_RequestInterface` representing an ASGI request.
Expand All @@ -142,7 +142,9 @@ class ASGIRequest(_RequestInterface):
"""

def __init__(self, *args, uri_parser=None, **kwargs):
self._starlette_request = StarletteRequest(*args, **kwargs)
# Might be set in `from_starlette_request` class method
if not hasattr(self, "_starlette_request"):
self._starlette_request = StarletteRequest(*args, **kwargs)
self.uri_parser = uri_parser

self._context = None
Expand All @@ -152,6 +154,16 @@ def __init__(self, *args, uri_parser=None, **kwargs):
self._form = None
self._files = None

@classmethod
def from_starlette_request(
cls, request: StarletteRequest, uri_parser=None
) -> "ConnexionRequest":
# Instantiate the class, and set the `_starlette_request` property before initializing.
self = cls.__new__(cls)
self._starlette_request = request
self.__init__(uri_parser=uri_parser) # type: ignore
return self

@property
def context(self):
if self._context is None:
Expand Down Expand Up @@ -226,7 +238,8 @@ async def get_body(self):
return await self.body() or None

def __getattr__(self, item):
return getattr(self._starlette_request, item)
if self.__getattribute__("_starlette_request"):
return getattr(self._starlette_request, item)


class ConnexionResponse:
Expand Down

0 comments on commit b9ba13c

Please sign in to comment.