Skip to content

Commit

Permalink
Add LifespanMiddleware
Browse files Browse the repository at this point in the history
  • Loading branch information
RobbeSneyders committed Mar 13, 2023
1 parent 21766f1 commit eb0b8a9
Show file tree
Hide file tree
Showing 11 changed files with 93 additions and 28 deletions.
3 changes: 3 additions & 0 deletions connexion/apps/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from connexion.jsonifier import Jsonifier
from connexion.middleware import ConnexionMiddleware, SpecMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.resolver import Resolver
from connexion.uri_parsing import AbstractURIParser

Expand All @@ -32,6 +33,7 @@ def __init__(
self,
import_name: str,
*,
lifespan: t.Optional[Lifespan] = None,
middlewares: t.Optional[list] = None,
specification_dir: t.Union[pathlib.Path, str] = "",
arguments: t.Optional[dict] = None,
Expand Down Expand Up @@ -79,6 +81,7 @@ def __init__(
self.middleware = ConnexionMiddleware(
self.middleware_app,
import_name=import_name,
lifespan=lifespan,
middlewares=middlewares,
specification_dir=specification_dir,
arguments=arguments,
Expand Down
5 changes: 4 additions & 1 deletion connexion/apps/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from connexion.decorators import StarletteDecorator
from connexion.jsonifier import Jsonifier
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.operations import AbstractOperation
from connexion.resolver import Resolver
from connexion.uri_parsing import AbstractURIParser
Expand Down Expand Up @@ -120,6 +121,7 @@ def __init__(
self,
import_name: str,
*,
lifespan: t.Optional[Lifespan] = None,
middlewares: t.Optional[list] = None,
specification_dir: t.Union[pathlib.Path, str] = "",
arguments: t.Optional[dict] = None,
Expand Down Expand Up @@ -168,8 +170,9 @@ def __init__(

super().__init__(
import_name,
specification_dir=specification_dir,
lifespan=lifespan,
middlewares=middlewares,
specification_dir=specification_dir,
arguments=arguments,
auth_all_paths=auth_all_paths,
jsonifier=jsonifier,
Expand Down
5 changes: 5 additions & 0 deletions connexion/apps/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from connexion.frameworks import flask as flask_utils
from connexion.jsonifier import Jsonifier
from connexion.middleware.abstract import AbstractRoutingAPI, SpecMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.operations import AbstractOperation
from connexion.problem import problem
from connexion.resolver import Resolver
Expand Down Expand Up @@ -176,6 +177,7 @@ def __init__(
self,
import_name: str,
*,
lifespan: t.Optional[Lifespan] = None,
middlewares: t.Optional[list] = None,
server_args: t.Optional[dict] = None,
specification_dir: t.Union[pathlib.Path, str] = "",
Expand All @@ -195,6 +197,8 @@ def __init__(
:param import_name: The name of the package or module that this object belongs to. If you
are using a single module, __name__ is always the correct value. If you however are
using a package, it’s usually recommended to hardcode the name of your package there.
:param lifespan: A lifespan context function, which can be used to perform startup and
shutdown tasks.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param server_args: Arguments to pass to the Flask application.
Expand Down Expand Up @@ -226,6 +230,7 @@ def __init__(
self.app = self.middleware_app.app
super().__init__(
import_name,
lifespan=lifespan,
middlewares=middlewares,
specification_dir=specification_dir,
arguments=arguments,
Expand Down
2 changes: 1 addition & 1 deletion connexion/middleware/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ class RoutedMiddleware(SpecMiddleware, t.Generic[API]):
api_cls: t.Type[API]
"""The subclass of RoutedAPI this middleware uses."""

def __init__(self, app: ASGIApp) -> None:
def __init__(self, app: ASGIApp, **kwargs) -> None:
self.app = app
self.apis: t.Dict[str, API] = {}

Expand Down
6 changes: 3 additions & 3 deletions connexion/middleware/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from starlette.exceptions import HTTPException
from starlette.requests import Request as StarletteRequest
from starlette.responses import Response
from starlette.types import Receive, Scope, Send
from starlette.types import ASGIApp, Receive, Scope, Send

from connexion.exceptions import InternalServerError, ProblemException, problem

Expand All @@ -15,8 +15,8 @@ class ExceptionMiddleware(StarletteExceptionMiddleware):
"""Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to
existing connexion behavior."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, next_app: ASGIApp, *args, **kwargs):
super().__init__(next_app)
self.add_exception_handler(ProblemException, self.problem_handler)
self.add_exception_handler(Exception, self.common_error_handler)

Expand Down
28 changes: 28 additions & 0 deletions connexion/middleware/lifespan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import typing as t

from starlette.routing import Router
from starlette.types import ASGIApp, Receive, Scope, Send

Lifespan = t.Callable[[t.Any], t.AsyncContextManager]


class LifespanMiddleware:
"""
Middleware that adds support for Starlette lifespan handlers
(https://www.starlette.io/lifespan/).
"""

def __init__(
self, next_app: ASGIApp, *, lifespan: t.Optional[Lifespan], **kwargs
) -> None:
self.next_app = next_app
self._lifespan = lifespan
# Leverage a Starlette Router for lifespan handling only
self.router = Router(lifespan=lifespan)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
# If no lifespan is registered, pass to next app so it can be handled downstream.
if scope["type"] == "lifespan" and self._lifespan:
await self.router(scope, receive, send)
else:
await self.next_app(scope, receive, send)
12 changes: 8 additions & 4 deletions connexion/middleware/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from connexion.middleware.abstract import SpecMiddleware
from connexion.middleware.context import ContextMiddleware
from connexion.middleware.exceptions import ExceptionMiddleware
from connexion.middleware.lifespan import Lifespan, LifespanMiddleware
from connexion.middleware.request_validation import RequestValidationMiddleware
from connexion.middleware.response_validation import ResponseValidationMiddleware
from connexion.middleware.routing import RoutingMiddleware
Expand Down Expand Up @@ -92,13 +93,15 @@ class ConnexionMiddleware:
RequestValidationMiddleware,
ResponseValidationMiddleware,
ContextMiddleware,
LifespanMiddleware,
]

def __init__(
self,
app: ASGIApp,
*,
import_name: t.Optional[str] = None,
lifespan: t.Optional[Lifespan] = None,
middlewares: t.Optional[list] = None,
specification_dir: t.Union[pathlib.Path, str] = "",
arguments: t.Optional[dict] = None,
Expand Down Expand Up @@ -150,7 +153,9 @@ def __init__(

if middlewares is None:
middlewares = self.default_middlewares
self.app, self.apps = self._apply_middlewares(app, middlewares)
self.app, self.apps = self._apply_middlewares(
app, middlewares, lifespan=lifespan
)

self.options = _Options(
arguments=arguments,
Expand All @@ -177,9 +182,8 @@ def ensure_absolute(self, path: t.Union[str, pathlib.Path]) -> pathlib.Path:
else:
return self.root_path / path

@staticmethod
def _apply_middlewares(
app: ASGIApp, middlewares: t.List[t.Type[ASGIApp]]
self, app: ASGIApp, middlewares: t.List[t.Type[ASGIApp]], **kwargs
) -> t.Tuple[ASGIApp, t.Iterable[ASGIApp]]:
"""Apply all middlewares to the provided app.
Expand All @@ -193,7 +197,7 @@ def _apply_middlewares(
# Include the wrapped application in the returned list.
apps = [app]
for middleware in reversed(middlewares):
app = middleware(app) # type: ignore
app = middleware(app, **kwargs) # type: ignore
apps.append(app)
return app, list(reversed(apps))

Expand Down
2 changes: 1 addition & 1 deletion connexion/middleware/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _add_operation_internal(


class RoutingMiddleware(SpecMiddleware):
def __init__(self, app: ASGIApp) -> None:
def __init__(self, app: ASGIApp, **kwargs) -> None:
"""Middleware that resolves the Operation for an incoming request and attaches it to the
scope.
Expand Down
2 changes: 1 addition & 1 deletion connexion/middleware/swagger_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ async def _get_swagger_ui_config(self, request):


class SwaggerUIMiddleware(SpecMiddleware):
def __init__(self, app: ASGIApp) -> None:
def __init__(self, app: ASGIApp, **kwargs) -> None:
"""Middleware that hosts a swagger UI.
:param app: app to wrap in middleware.
Expand Down
38 changes: 38 additions & 0 deletions tests/test_lifespan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import contextlib
import sys
from unittest import mock

import pytest
from connexion import AsyncApp, ConnexionMiddleware


def test_lifespan_handler(app_class):
m = mock.MagicMock()

@contextlib.asynccontextmanager
async def lifespan(app):
m.startup()
yield
m.shutdown()

app = AsyncApp(__name__, lifespan=lifespan)
with app.test_client():
m.startup.assert_called()
m.shutdown.assert_not_called()
m.shutdown.assert_called()


@pytest.mark.skipif(
sys.version_info < (3, 8), reason="AsyncMock only available from 3.8."
)
async def test_lifespan():
"""Test that lifespan events are passed through if no handler is registered."""
lifecycle_handler = mock.Mock()

async def check_lifecycle(scope, receive, send):
if scope["type"] == "lifespan":
lifecycle_handler.handle()

test_app = ConnexionMiddleware(check_lifecycle)
await test_app({"type": "lifespan"}, mock.AsyncMock(), mock.AsyncMock())
lifecycle_handler.handle.assert_called()
18 changes: 1 addition & 17 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class TestMiddleware:

__test__ = False

def __init__(self, app):
def __init__(self, app, **kwargs):
self.app = app

async def __call__(self, scope, receive, send):
Expand Down Expand Up @@ -49,19 +49,3 @@ def test_routing_middleware(middleware_app):
assert (
response.headers.get("operation_id") == "fakeapi.hello.post_greeting"
), response.status_code


@pytest.mark.skipif(
sys.version_info < (3, 8), reason="AsyncMock only available from 3.8."
)
async def test_lifecycle():
"""Test that lifecycle events are passed correctly."""
lifecycle_handler = mock.Mock()

async def check_lifecycle(scope, receive, send):
if scope["type"] == "lifecycle":
lifecycle_handler.handle()

test_app = ConnexionMiddleware(check_lifecycle)
await test_app({"type": "lifecycle"}, mock.AsyncMock, mock.AsyncMock)
lifecycle_handler.handle.assert_called()

0 comments on commit eb0b8a9

Please sign in to comment.