diff --git a/openapi/rest.py b/openapi/rest.py index b2381e8..12f74f6 100644 --- a/openapi/rest.py +++ b/openapi/rest.py @@ -1,5 +1,5 @@ -import typing as t from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Sequence from aiohttp.web import Application @@ -11,13 +11,14 @@ def rest( - openapi: t.Dict = None, - setup_app: t.Callable[[Application], None] = None, + openapi: Optional[Dict] = None, + setup_app: Callable[[Application], None] = None, base_path: str = "", - commands: t.Optional[t.List] = None, - allowed_tags: t.Optional[t.Set[str]] = None, + commands: Optional[List] = None, + allowed_tags: Sequence[str] = (), validate_docs: bool = False, - servers: t.Optional[t.List[str]] = None, + servers: Optional[List[str]] = None, + security: Optional[Dict[str, Dict]] = None, OpenApiSpecClass: type = OpenApiSpec, **kwargs, ) -> OpenApiClient: @@ -29,6 +30,7 @@ def rest( allowed_tags=allowed_tags, validate_docs=validate_docs, servers=servers, + security=security, ), base_path=base_path, commands=commands, diff --git a/openapi/spec/spec.py b/openapi/spec/spec.py index 4c913ee..1055437 100644 --- a/openapi/spec/spec.py +++ b/openapi/spec/spec.py @@ -1,5 +1,5 @@ from collections import OrderedDict -from dataclasses import Field, asdict, dataclass, field +from dataclasses import Field, asdict, dataclass from dataclasses import fields as get_fields from dataclasses import is_dataclass from datetime import date, datetime @@ -43,7 +43,6 @@ class OpenApi: description: str = "" version: str = "0.1.0" termsOfService: str = "" - security: Dict[str, Dict] = field(default_factory=dict) contact: Contact = Contact() license: License = License() @@ -188,6 +187,7 @@ def __init__( allowed_tags: Iterable = None, validate_docs: bool = False, servers: Optional[List] = None, + security: Optional[Dict[str, Dict]] = None, ) -> None: super().__init__(validate_docs=validate_docs) self.parameters: Dict = {} @@ -197,6 +197,7 @@ def __init__( self.servers: List = servers or [] self.default_content_type = default_content_type or "application/json" self.default_responses = default_responses or {} + self.security = security self.doc = dict( openapi=OPENAPI, info=asdict(info or OpenApi()), paths=OrderedDict() ) @@ -224,9 +225,10 @@ def build( self.add_schema_to_parse(ValidationErrors) self.add_schema_to_parse(ErrorMessage) self.add_schema_to_parse(FieldError) - security = self.doc["info"].pop("security", None) or {} - if security: - self.doc["info"]["security"] = list(security) + security = self.security or {} + self.doc["security"] = [ + {name: value.pop("scopes", [])} for name, value in security.items() + ] # Build paths self._build_paths(app, public, private) s = self.parsed_schemas() diff --git a/tests/example/main.py b/tests/example/main.py index bda473e..9ce0887 100644 --- a/tests/example/main.py +++ b/tests/example/main.py @@ -17,7 +17,20 @@ def create_app(): - return rest(setup_app=setup_app) + return rest( + security=dict( + auth_key={ + "type": "apiKey", + "name": "X-Meta-Api-Key", + "description": ( + "The authentication key is required to access most " + "endpoints of the API" + ), + "in": "header", + } + ), + setup_app=setup_app, + ) def setup_app(app: web.Application) -> None: diff --git a/tests/spec/test_spec.py b/tests/spec/test_spec.py index 85597fe..a0972f0 100644 --- a/tests/spec/test_spec.py +++ b/tests/spec/test_spec.py @@ -1,13 +1,10 @@ import pytest +from openapi.exc import InvalidSpecException from openapi.rest import rest from openapi.spec import OpenApi, OpenApiSpec -from openapi.exc import InvalidSpecException - from tests.example import endpoints, endpoints_additional -# from openapi_spec_validator import validate_spec - def create_spec_app(routes): def setup_app(app): @@ -29,23 +26,6 @@ async def test_spec_validation(test_app): # validate_spec(spec.doc) -async def test_spec_security(test_app): - open_api = OpenApi( - security=dict( - auth_key={ - "type": "apiKey", - "name": "X-Api-Key", - "description": "The authentication key", - "in": "header", - } - ) - ) - spec = OpenApiSpec(open_api) - spec.build(test_app) - assert spec.doc["info"]["security"] == ["auth_key"] - assert spec.doc["components"]["securitySchemes"] - - async def test_spec_422(test_app): spec = OpenApiSpec() spec.build(test_app)