Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lifespan middleware #1676

Merged
merged 2 commits into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions connexion/apps/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from connexion.jsonifier import Jsonifier
from connexion.middleware import ConnexionMiddleware, SpecMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.resolver import Resolver
from connexion.uri_parsing import AbstractURIParser

Expand All @@ -32,8 +33,9 @@ def __init__(
self,
import_name: str,
*,
specification_dir: t.Union[pathlib.Path, str] = "",
lifespan: t.Optional[Lifespan] = None,
middlewares: t.Optional[list] = None,
specification_dir: t.Union[pathlib.Path, str] = "",
arguments: t.Optional[dict] = None,
auth_all_paths: t.Optional[bool] = None,
jsonifier: t.Optional[Jsonifier] = None,
Expand All @@ -50,11 +52,11 @@ def __init__(
:param import_name: The name of the package or module that this object belongs to. If you
are using a single module, __name__ is always the correct value. If you however are
using a package, it’s usually recommended to hardcode the name of your package there.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param specification_dir: The directory holding the specification(s). The provided path
should either be absolute or relative to the root path of the application. Defaults to
the root path.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param arguments: Arguments to substitute the specification using Jinja.
:param auth_all_paths: whether to authenticate not paths not defined in the specification.
Defaults to False.
Expand All @@ -79,8 +81,9 @@ def __init__(
self.middleware = ConnexionMiddleware(
self.middleware_app,
import_name=import_name,
specification_dir=specification_dir,
lifespan=lifespan,
middlewares=middlewares,
specification_dir=specification_dir,
arguments=arguments,
auth_all_paths=auth_all_paths,
jsonifier=jsonifier,
Expand Down
11 changes: 7 additions & 4 deletions connexion/apps/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from connexion.decorators import StarletteDecorator
from connexion.jsonifier import Jsonifier
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.operations import AbstractOperation
from connexion.resolver import Resolver
from connexion.uri_parsing import AbstractURIParser
Expand Down Expand Up @@ -120,8 +121,9 @@ def __init__(
self,
import_name: str,
*,
specification_dir: t.Union[pathlib.Path, str] = "",
lifespan: t.Optional[Lifespan] = None,
middlewares: t.Optional[list] = None,
specification_dir: t.Union[pathlib.Path, str] = "",
arguments: t.Optional[dict] = None,
auth_all_paths: t.Optional[bool] = None,
jsonifier: t.Optional[Jsonifier] = None,
Expand All @@ -138,11 +140,11 @@ def __init__(
:param import_name: The name of the package or module that this object belongs to. If you
are using a single module, __name__ is always the correct value. If you however are
using a package, it’s usually recommended to hardcode the name of your package there.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param specification_dir: The directory holding the specification(s). The provided path
should either be absolute or relative to the root path of the application. Defaults to
the root path.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param arguments: Arguments to substitute the specification using Jinja.
:param auth_all_paths: whether to authenticate not paths not defined in the specification.
Defaults to False.
Expand All @@ -168,8 +170,9 @@ def __init__(

super().__init__(
import_name,
specification_dir=specification_dir,
lifespan=lifespan,
middlewares=middlewares,
specification_dir=specification_dir,
arguments=arguments,
auth_all_paths=auth_all_paths,
jsonifier=jsonifier,
Expand Down
13 changes: 9 additions & 4 deletions connexion/apps/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from connexion.frameworks import flask as flask_utils
from connexion.jsonifier import Jsonifier
from connexion.middleware.abstract import AbstractRoutingAPI, SpecMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.operations import AbstractOperation
from connexion.problem import problem
from connexion.resolver import Resolver
Expand Down Expand Up @@ -176,9 +177,10 @@ def __init__(
self,
import_name: str,
*,
lifespan: t.Optional[Lifespan] = None,
middlewares: t.Optional[list] = None,
server_args: t.Optional[dict] = None,
specification_dir: t.Union[pathlib.Path, str] = "",
middlewares: t.Optional[list] = None,
arguments: t.Optional[dict] = None,
auth_all_paths: t.Optional[bool] = None,
jsonifier: t.Optional[Jsonifier] = None,
Expand All @@ -195,12 +197,14 @@ def __init__(
:param import_name: The name of the package or module that this object belongs to. If you
are using a single module, __name__ is always the correct value. If you however are
using a package, it’s usually recommended to hardcode the name of your package there.
:param lifespan: A lifespan context function, which can be used to perform startup and
shutdown tasks.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param server_args: Arguments to pass to the Flask application.
:param specification_dir: The directory holding the specification(s). The provided path
should either be absolute or relative to the root path of the application. Defaults to
the root path.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param arguments: Arguments to substitute the specification using Jinja.
:param auth_all_paths: whether to authenticate not paths not defined in the specification.
Defaults to False.
Expand All @@ -226,8 +230,9 @@ def __init__(
self.app = self.middleware_app.app
super().__init__(
import_name,
specification_dir=specification_dir,
lifespan=lifespan,
middlewares=middlewares,
specification_dir=specification_dir,
arguments=arguments,
auth_all_paths=auth_all_paths,
jsonifier=jsonifier,
Expand Down
2 changes: 1 addition & 1 deletion connexion/middleware/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ class RoutedMiddleware(SpecMiddleware, t.Generic[API]):
api_cls: t.Type[API]
"""The subclass of RoutedAPI this middleware uses."""

def __init__(self, app: ASGIApp) -> None:
def __init__(self, app: ASGIApp, **kwargs) -> None:
self.app = app
self.apis: t.Dict[str, API] = {}

Expand Down
6 changes: 3 additions & 3 deletions connexion/middleware/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from starlette.exceptions import HTTPException
from starlette.requests import Request as StarletteRequest
from starlette.responses import Response
from starlette.types import Receive, Scope, Send
from starlette.types import ASGIApp, Receive, Scope, Send

from connexion.exceptions import InternalServerError, ProblemException, problem

Expand All @@ -15,8 +15,8 @@ class ExceptionMiddleware(StarletteExceptionMiddleware):
"""Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to
existing connexion behavior."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, next_app: ASGIApp, *args, **kwargs):
super().__init__(next_app)
self.add_exception_handler(ProblemException, self.problem_handler)
self.add_exception_handler(Exception, self.common_error_handler)

Expand Down
28 changes: 28 additions & 0 deletions connexion/middleware/lifespan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import typing as t

from starlette.routing import Router
from starlette.types import ASGIApp, Receive, Scope, Send

Lifespan = t.Callable[[t.Any], t.AsyncContextManager]


class LifespanMiddleware:
"""
Middleware that adds support for Starlette lifespan handlers
(https://www.starlette.io/lifespan/).
"""

def __init__(
self, next_app: ASGIApp, *, lifespan: t.Optional[Lifespan], **kwargs
) -> None:
self.next_app = next_app
self._lifespan = lifespan
# Leverage a Starlette Router for lifespan handling only
self.router = Router(lifespan=lifespan)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
# If no lifespan is registered, pass to next app so it can be handled downstream.
if scope["type"] == "lifespan" and self._lifespan:
await self.router(scope, receive, send)
else:
await self.next_app(scope, receive, send)
19 changes: 12 additions & 7 deletions connexion/middleware/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from connexion.middleware.abstract import SpecMiddleware
from connexion.middleware.context import ContextMiddleware
from connexion.middleware.exceptions import ExceptionMiddleware
from connexion.middleware.lifespan import Lifespan, LifespanMiddleware
from connexion.middleware.request_validation import RequestValidationMiddleware
from connexion.middleware.response_validation import ResponseValidationMiddleware
from connexion.middleware.routing import RoutingMiddleware
Expand Down Expand Up @@ -92,15 +93,17 @@ class ConnexionMiddleware:
RequestValidationMiddleware,
ResponseValidationMiddleware,
ContextMiddleware,
LifespanMiddleware,
]

def __init__(
self,
app: ASGIApp,
*,
import_name: t.Optional[str] = None,
specification_dir: t.Union[pathlib.Path, str] = "",
lifespan: t.Optional[Lifespan] = None,
middlewares: t.Optional[list] = None,
specification_dir: t.Union[pathlib.Path, str] = "",
arguments: t.Optional[dict] = None,
auth_all_paths: t.Optional[bool] = None,
jsonifier: t.Optional[Jsonifier] = None,
Expand All @@ -117,11 +120,11 @@ def __init__(
:param import_name: The name of the package or module that this object belongs to. If you
are using a single module, __name__ is always the correct value. If you however are
using a package, it’s usually recommended to hardcode the name of your package there.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param specification_dir: The directory holding the specification(s). The provided path
should either be absolute or relative to the root path of the application. Defaults to
the root path.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param arguments: Arguments to substitute the specification using Jinja.
:param auth_all_paths: whether to authenticate not paths not defined in the specification.
Defaults to False.
Expand Down Expand Up @@ -150,7 +153,9 @@ def __init__(

if middlewares is None:
middlewares = self.default_middlewares
self.app, self.apps = self._apply_middlewares(app, middlewares)
self.app, self.apps = self._apply_middlewares(
app, middlewares, lifespan=lifespan
)

self.options = _Options(
arguments=arguments,
Expand All @@ -177,9 +182,8 @@ def ensure_absolute(self, path: t.Union[str, pathlib.Path]) -> pathlib.Path:
else:
return self.root_path / path

@staticmethod
def _apply_middlewares(
app: ASGIApp, middlewares: t.List[t.Type[ASGIApp]]
self, app: ASGIApp, middlewares: t.List[t.Type[ASGIApp]], **kwargs
) -> t.Tuple[ASGIApp, t.Iterable[ASGIApp]]:
"""Apply all middlewares to the provided app.

Expand All @@ -193,13 +197,14 @@ def _apply_middlewares(
# Include the wrapped application in the returned list.
apps = [app]
for middleware in reversed(middlewares):
app = middleware(app) # type: ignore
app = middleware(app, **kwargs) # type: ignore
apps.append(app)
return app, list(reversed(apps))

def add_api(
self,
specification: t.Union[pathlib.Path, str, dict],
*,
base_path: t.Optional[str] = None,
arguments: t.Optional[dict] = None,
auth_all_paths: t.Optional[bool] = None,
Expand Down
2 changes: 1 addition & 1 deletion connexion/middleware/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _add_operation_internal(


class RoutingMiddleware(SpecMiddleware):
def __init__(self, app: ASGIApp) -> None:
def __init__(self, app: ASGIApp, **kwargs) -> None:
"""Middleware that resolves the Operation for an incoming request and attaches it to the
scope.

Expand Down
2 changes: 1 addition & 1 deletion connexion/middleware/swagger_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ async def _get_swagger_ui_config(self, request):


class SwaggerUIMiddleware(SpecMiddleware):
def __init__(self, app: ASGIApp) -> None:
def __init__(self, app: ASGIApp, **kwargs) -> None:
"""Middleware that hosts a swagger UI.

:param app: app to wrap in middleware.
Expand Down
38 changes: 38 additions & 0 deletions tests/test_lifespan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import contextlib
import sys
from unittest import mock

import pytest
from connexion import AsyncApp, ConnexionMiddleware


def test_lifespan_handler(app_class):
m = mock.MagicMock()

@contextlib.asynccontextmanager
async def lifespan(app):
m.startup()
yield
m.shutdown()

app = AsyncApp(__name__, lifespan=lifespan)
with app.test_client():
m.startup.assert_called()
m.shutdown.assert_not_called()
m.shutdown.assert_called()


@pytest.mark.skipif(
sys.version_info < (3, 8), reason="AsyncMock only available from 3.8."
)
async def test_lifespan():
"""Test that lifespan events are passed through if no handler is registered."""
lifecycle_handler = mock.Mock()

async def check_lifecycle(scope, receive, send):
if scope["type"] == "lifespan":
lifecycle_handler.handle()

test_app = ConnexionMiddleware(check_lifecycle)
await test_app({"type": "lifespan"}, mock.AsyncMock(), mock.AsyncMock())
lifecycle_handler.handle.assert_called()
18 changes: 1 addition & 17 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class TestMiddleware:

__test__ = False

def __init__(self, app):
def __init__(self, app, **kwargs):
self.app = app

async def __call__(self, scope, receive, send):
Expand Down Expand Up @@ -49,19 +49,3 @@ def test_routing_middleware(middleware_app):
assert (
response.headers.get("operation_id") == "fakeapi.hello.post_greeting"
), response.status_code


@pytest.mark.skipif(
sys.version_info < (3, 8), reason="AsyncMock only available from 3.8."
)
async def test_lifecycle():
"""Test that lifecycle events are passed correctly."""
lifecycle_handler = mock.Mock()

async def check_lifecycle(scope, receive, send):
if scope["type"] == "lifecycle":
lifecycle_handler.handle()

test_app = ConnexionMiddleware(check_lifecycle)
await test_app({"type": "lifecycle"}, mock.AsyncMock, mock.AsyncMock)
lifecycle_handler.handle.assert_called()