Skip to content

Commit

Permalink
Only instantiate specification once (#1819)
Browse files Browse the repository at this point in the history
Fixes #1801 

I had to make quite a few additional changes to satisfy mypy.
  • Loading branch information
RobbeSneyders committed Nov 30, 2023
1 parent bbd085b commit 0857710
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 49 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ repos:
args: ["tests"]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.961
rev: v0.981
hooks:
- id: mypy
files: "^connexion/"
Expand Down
16 changes: 6 additions & 10 deletions connexion/middleware/abstract.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import abc
import logging
import pathlib
import typing as t
from collections import defaultdict

Expand All @@ -22,9 +21,7 @@ class SpecMiddleware(abc.ABC):
base class"""

@abc.abstractmethod
def add_api(
self, specification: t.Union[pathlib.Path, str, dict], **kwargs
) -> t.Any:
def add_api(self, specification: Specification, **kwargs) -> t.Any:
"""
Register an API represented by a single OpenAPI specification on this middleware.
Multiple APIs can be registered on a single middleware.
Expand All @@ -40,15 +37,14 @@ class AbstractSpecAPI:

def __init__(
self,
specification: t.Union[pathlib.Path, str, dict],
specification: Specification,
base_path: t.Optional[str] = None,
resolver: t.Optional[Resolver] = None,
arguments: t.Optional[dict] = None,
uri_parser_class=None,
*args,
**kwargs,
):
self.specification = Specification.load(specification, arguments=arguments)
self.specification = specification
self.uri_parser_class = uri_parser_class

self._set_base_path(base_path)
Expand Down Expand Up @@ -88,7 +84,7 @@ def add_paths(self, paths: t.Optional[dict] = None) -> None:
"""
Adds the paths defined in the specification as operations.
"""
paths = paths or self.specification.get("paths", dict())
paths = t.cast(dict, paths or self.specification.get("paths", dict()))
for path, methods in paths.items():
logger.debug("Adding %s%s...", self.base_path, path)

Expand Down Expand Up @@ -176,7 +172,7 @@ def _handle_add_operation_error(
class RoutedAPI(AbstractSpecAPI, t.Generic[OP]):
def __init__(
self,
specification: t.Union[pathlib.Path, str, dict],
specification: Specification,
*args,
next_app: ASGIApp,
**kwargs,
Expand Down Expand Up @@ -235,7 +231,7 @@ def __init__(self, app: ASGIApp) -> None:
self.app = app
self.apis: t.Dict[str, t.List[API]] = defaultdict(list)

def add_api(self, specification: t.Union[pathlib.Path, str, dict], **kwargs) -> API:
def add_api(self, specification: Specification, **kwargs) -> API:
api = self.api_cls(specification, next_app=self.app, **kwargs)
self.apis[api.base_path].append(api)
return api
Expand Down
9 changes: 5 additions & 4 deletions connexion/middleware/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
from connexion.options import SwaggerUIOptions
from connexion.resolver import Resolver
from connexion.spec import Specification
from connexion.types import MaybeAwaitable
from connexion.uri_parsing import AbstractURIParser
from connexion.utils import inspect_function_arguments
Expand Down Expand Up @@ -390,18 +391,18 @@ def add_api(
if self.middleware_stack is not None:
raise RuntimeError("Cannot add api after an application has started")

if isinstance(specification, dict):
specification = specification
else:
if isinstance(specification, (pathlib.Path, str)):
specification = t.cast(pathlib.Path, self.specification_dir / specification)

# Add specification as file to watch for reloading
if pathlib.Path.cwd() in specification.parents:
self.extra_files.append(
str(specification.relative_to(pathlib.Path.cwd()))
)

specification = Specification.load(specification, arguments=arguments)

options = self.options.replace(
arguments=arguments,
auth_all_paths=auth_all_paths,
jsonifier=jsonifier,
swagger_ui_options=swagger_ui_options,
Expand Down
8 changes: 4 additions & 4 deletions connexion/middleware/routing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pathlib
import typing as t
from contextvars import ContextVar

Expand All @@ -14,6 +13,7 @@
)
from connexion.operations import AbstractOperation
from connexion.resolver import Resolver
from connexion.spec import Specification

_scope: ContextVar[dict] = ContextVar("SCOPE")

Expand Down Expand Up @@ -50,7 +50,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
class RoutingAPI(AbstractRoutingAPI):
def __init__(
self,
specification: t.Union[pathlib.Path, str, dict],
specification: Specification,
*,
next_app: ASGIApp,
base_path: t.Optional[str] = None,
Expand Down Expand Up @@ -110,14 +110,14 @@ def __init__(self, app: ASGIApp) -> None:

def add_api(
self,
specification: t.Union[pathlib.Path, str, dict],
specification: Specification,
base_path: t.Optional[str] = None,
arguments: t.Optional[dict] = None,
**kwargs,
) -> None:
"""Add an API to the router based on a OpenAPI spec.
:param specification: OpenAPI spec as dict or path to file.
:param specification: OpenAPI spec.
:param base_path: Base path where to add this API.
:param arguments: Jinja arguments to replace in the spec.
"""
Expand Down
24 changes: 22 additions & 2 deletions connexion/middleware/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
from connexion.operations import AbstractOperation
from connexion.security import SecurityHandlerFactory
from connexion.spec import Specification

logger = logging.getLogger("connexion.middleware.security")

Expand All @@ -31,11 +32,21 @@ def __init__(
@classmethod
def from_operation(
cls,
operation: AbstractOperation,
operation: t.Union[AbstractOperation, Specification],
*,
next_app: ASGIApp,
security_handler_factory: SecurityHandlerFactory,
) -> "SecurityOperation":
"""Create a SecurityOperation from an Operation of Specification instance
:param operation: The operation can be both an Operation or Specification instance here
since security is defined at both levels in the OpenAPI spec. Creating a
SecurityOperation based on a Specification can be used to create a SecurityOperation
for routes not explicitly defined in the specification.
:param next_app: The next ASGI app to call.
:param security_handler_factory: The factory to be used to generate security handlers for
the different security schemes.
"""
return cls(
next_app=next_app,
security_handler_factory=security_handler_factory,
Expand Down Expand Up @@ -120,7 +131,16 @@ def add_auth_on_not_found(self) -> None:
default_operation = self.make_operation(self.specification)
self.operations = defaultdict(lambda: default_operation)

def make_operation(self, operation: AbstractOperation) -> SecurityOperation:
def make_operation(
self, operation: t.Union[AbstractOperation, Specification]
) -> SecurityOperation:
"""Create a SecurityOperation from an Operation of Specification instance
:param operation: The operation can be both an Operation or Specification instance here
since security is defined at both levels in the OpenAPI spec. Creating a
SecurityOperation based on a Specification can be used to create a SecurityOperation
for routes not explicitly defined in the specification.
"""
return SecurityOperation.from_operation(
operation,
next_app=self.next_app,
Expand Down
6 changes: 3 additions & 3 deletions connexion/middleware/swagger_ui.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import logging
import pathlib
import re
import typing as t
from contextvars import ContextVar
Expand All @@ -17,6 +16,7 @@
from connexion.middleware import SpecMiddleware
from connexion.middleware.abstract import AbstractSpecAPI
from connexion.options import SwaggerUIConfig, SwaggerUIOptions
from connexion.spec import Specification
from connexion.utils import yamldumper

logger = logging.getLogger("connexion.middleware.swagger_ui")
Expand Down Expand Up @@ -191,14 +191,14 @@ def __init__(self, app: ASGIApp) -> None:

def add_api(
self,
specification: t.Union[pathlib.Path, str, dict],
specification: Specification,
base_path: t.Optional[str] = None,
arguments: t.Optional[dict] = None,
**kwargs
) -> None:
"""Add an API to the router based on a OpenAPI spec.
:param specification: OpenAPI spec as dict or path to file.
:param specification: OpenAPI spec.
:param base_path: Base path where to add this API.
:param arguments: Jinja arguments to replace in the spec.
"""
Expand Down
5 changes: 5 additions & 0 deletions connexion/operations/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ def __init__(

self._responses = self._operation.get("responses", {})

@classmethod
@abc.abstractmethod
def from_spec(cls, spec, *args, path, method, resolver, **kwargs):
pass

@property
def method(self):
"""
Expand Down
16 changes: 15 additions & 1 deletion connexion/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import pathlib
import pkgutil
import typing as t
from collections.abc import Mapping
from urllib.parse import urlsplit

Expand All @@ -19,7 +20,7 @@

from .exceptions import InvalidSpecification
from .json_schema import NullableTypeValidator, resolve_refs
from .operations import OpenAPIOperation, Swagger2Operation
from .operations import AbstractOperation, OpenAPIOperation, Swagger2Operation
from .utils import deep_get

validate_properties = Draft4Validator.VALIDATORS["properties"]
Expand Down Expand Up @@ -72,6 +73,9 @@ def canonical_base_path(base_path):


class Specification(Mapping):

operation_cls: t.Type[AbstractOperation]

def __init__(self, raw_spec, *, base_uri=""):
self._raw_spec = copy.deepcopy(raw_spec)
self._set_defaults(raw_spec)
Expand Down Expand Up @@ -206,6 +210,16 @@ def with_base_path(self, base_path):
new_spec.base_path = base_path
return new_spec

@property
@abc.abstractmethod
def base_path(self):
pass

@base_path.setter
@abc.abstractmethod
def base_path(self, base_path):
pass


class Swagger2Specification(Specification):
"""Python interface for a Swagger 2 specification."""
Expand Down
Loading

0 comments on commit 0857710

Please sign in to comment.