Skip to content

Commit

Permalink
refactor: Move security definitions processing to a separate module
Browse files Browse the repository at this point in the history
  • Loading branch information
Stranger6667 committed May 8, 2020
1 parent 4def224 commit 129e897
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 66 deletions.
70 changes: 4 additions & 66 deletions src/schemathesis/schemas.py
Expand Up @@ -27,6 +27,7 @@
from .filters import should_skip_by_operation_id, should_skip_by_tag, should_skip_endpoint, should_skip_method
from .hooks import HookContext, HookDispatcher, HookLocation, HookScope, dispatch, warn_deprecated_hook
from .models import Endpoint, EndpointDefinition, empty_object
from .specs.openapi.security import OpenAPISecurityProcessor, SwaggerSecurityProcessor
from .types import Filter, GenericTest, Hook, NotSet
from .utils import NOT_SET, GenericResponse, StringDatesYAMLLoader, deprecated, traverse_schema

Expand Down Expand Up @@ -234,6 +235,7 @@ class SwaggerV20(BaseSchema): # pylint: disable=too-many-public-methods
nullable_name = "x-nullable"
example_field = "x-example"
operations: Tuple[str, ...] = ("get", "put", "post", "delete", "options", "head", "patch")
security = SwaggerSecurityProcessor()

def __repr__(self) -> str:
info = self.raw_schema["info"]
Expand Down Expand Up @@ -316,60 +318,9 @@ def make_endpoint( # pylint: disable=too-many-arguments
)
for parameter in parameters:
self.process_parameter(endpoint, parameter)
self.process_security_definitions(endpoint)
self.security.process_definitions(self.raw_schema, endpoint, self.resolver)
return endpoint

def get_security_definitions(self) -> Dict[str, Any]:
"""Extract security definitions from the schema."""
return self.raw_schema.get("securityDefinitions", {})

def get_security_requirements(self, endpoint: Endpoint) -> List[str]:
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/2.0.md#operation-object
# > This definition overrides any declared top-level security.
# > To remove a top-level security declaration, an empty array can be used.
global_requirements = self.raw_schema.get("security", [])
local_requirements = endpoint.definition.raw.get("security", None)
if local_requirements is not None:
requirements = local_requirements
else:
requirements = global_requirements
return [key for requirement in requirements for key in requirement]

def process_security_definitions(self, endpoint: Endpoint) -> None:
"""Add relevant security parameters to data generation."""
definitions = self.get_security_definitions()
requirements = self.get_security_requirements(endpoint)
for name, definition in definitions.items():
if name in requirements:
if definition["type"] == "apiKey":
self.process_api_key_security_definition(definition, endpoint)
self.process_http_security_definition(definition, endpoint)

def process_api_key_security_definition(self, definition: Dict[str, Any], endpoint: Endpoint) -> None:
if definition["in"] == "query":
endpoint.query = self.add_security_definition(endpoint.query, definition)
elif definition["in"] == "header":
endpoint.headers = self.add_security_definition(endpoint.headers, definition)

def process_http_security_definition(self, definition: Dict[str, Any], endpoint: Endpoint) -> None:
if definition["type"] == "basic":
endpoint.headers = self.add_http_auth_definition(endpoint.headers)

def add_security_definition(
self, container: Optional[Dict[str, Any]], definition: Dict[str, Any]
) -> Dict[str, Any]:
name = definition["name"]
container = container or empty_object()
container["properties"][name] = {"name": name, "type": "string"}
container["required"].append(name)
return container

def add_http_auth_definition(self, container: Optional[Dict[str, Any]], scheme: str = "basic") -> Dict[str, Any]:
container = container or empty_object()
container["properties"]["Authorization"] = {"type": "string", "format": f"_{scheme}_auth"}
container["required"].append("Authorization")
return container

def process_parameter(self, endpoint: Endpoint, parameter: Dict[str, Any]) -> None:
"""Convert each Parameter object to a JSON schema."""
parameter = deepcopy(parameter)
Expand Down Expand Up @@ -492,6 +443,7 @@ class OpenApi30(SwaggerV20): # pylint: disable=too-many-ancestors
nullable_name = "nullable"
example_field = "example"
operations = SwaggerV20.operations + ("trace",)
security = OpenAPISecurityProcessor()

@property
def spec_version(self) -> str:
Expand Down Expand Up @@ -530,20 +482,6 @@ def make_endpoint( # pylint: disable=too-many-arguments
self.process_body(endpoint, resolved_definition["requestBody"])
return endpoint

def get_security_definitions(self) -> Dict[str, Any]:
components = self.raw_schema.get("components", {})
security_schemes = components.get("securitySchemes", {})
return self.resolve(security_schemes, RECURSION_DEPTH_LIMIT)

def process_api_key_security_definition(self, definition: Dict[str, Any], endpoint: Endpoint) -> None:
if definition["in"] == "cookie":
endpoint.cookies = self.add_security_definition(endpoint.cookies, definition)
super().process_api_key_security_definition(definition, endpoint)

def process_http_security_definition(self, definition: Dict[str, Any], endpoint: Endpoint) -> None:
if definition["type"] == "http":
endpoint.headers = self.add_http_auth_definition(endpoint.headers, scheme=definition["scheme"].lower())

def process_by_type(self, endpoint: Endpoint, parameter: Dict[str, Any]) -> None:
if parameter["in"] == "cookie":
self.process_cookie(endpoint, parameter)
Expand Down
Empty file.
Empty file.
87 changes: 87 additions & 0 deletions src/schemathesis/specs/openapi/security.py
@@ -0,0 +1,87 @@
"""Processing of ``securityDefinitions`` or ``securitySchemes`` keywords."""
from typing import Any, Dict, List, Optional

import attr
from jsonschema import RefResolver

from ...models import Endpoint, empty_object


@attr.s(slots=True) # pragma: no mutate
class BaseSecurityProcessor:
def process_definitions(self, schema: Dict[str, Any], endpoint: Endpoint, resolver: RefResolver) -> None:
"""Add relevant security parameters to data generation."""
definitions = self.get_security_definitions(schema, resolver)
requirements = get_security_requirements(schema, endpoint)
for name, definition in definitions.items():
if name in requirements:
if definition["type"] == "apiKey":
self.process_api_key_security_definition(definition, endpoint)
self.process_http_security_definition(definition, endpoint)

def get_security_definitions(self, schema: Dict[str, Any], resolver: RefResolver) -> Dict[str, Any]:
return schema.get("securityDefinitions", {})

def process_api_key_security_definition(self, definition: Dict[str, Any], endpoint: Endpoint) -> None:
if definition["in"] == "query":
endpoint.query = add_security_definition(endpoint.query, definition)
elif definition["in"] == "header":
endpoint.headers = add_security_definition(endpoint.headers, definition)

def process_http_security_definition(self, definition: Dict[str, Any], endpoint: Endpoint) -> None:
if definition["type"] == "basic":
endpoint.headers = add_http_auth_definition(endpoint.headers)


SwaggerSecurityProcessor = BaseSecurityProcessor


@attr.s(slots=True) # pragma: no mutate
class OpenAPISecurityProcessor(BaseSecurityProcessor):
def get_security_definitions(self, schema: Dict[str, Any], resolver: RefResolver) -> Dict[str, Any]:
"""In Open API 3 security definitions are located in ``components`` and may have references inside."""
components = schema.get("components", {})
security_schemes = components.get("securitySchemes", {})
if "$ref" in security_schemes:
return resolver.resolve(security_schemes["$ref"])[1]
return security_schemes

def process_api_key_security_definition(self, definition: Dict[str, Any], endpoint: Endpoint) -> None:
if definition["in"] == "cookie":
endpoint.cookies = add_security_definition(endpoint.cookies, definition)
super().process_api_key_security_definition(definition, endpoint)

def process_http_security_definition(self, definition: Dict[str, Any], endpoint: Endpoint) -> None:
if definition["type"] == "http":
endpoint.headers = add_http_auth_definition(endpoint.headers, scheme=definition["scheme"].lower())


def add_security_definition(container: Optional[Dict[str, Any]], definition: Dict[str, Any]) -> Dict[str, Any]:
"""Create a JSON schema for the provided security definition."""
name = definition["name"]
container = container or empty_object()
container["properties"][name] = {"name": name, "type": "string"}
container["required"].append(name)
return container


def add_http_auth_definition(container: Optional[Dict[str, Any]], scheme: str = "basic") -> Dict[str, Any]:
"""HTTP auth is handled via a custom `format` that is registered by Schemathesis during the import time."""
container = container or empty_object()
container["properties"]["Authorization"] = {"type": "string", "format": f"_{scheme}_auth"}
container["required"].append("Authorization")
return container


def get_security_requirements(schema: Dict[str, Any], endpoint: Endpoint) -> List[str]:
"""Get applied security requirements for the given endpoint."""
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/2.0.md#operation-object
# > This definition overrides any declared top-level security.
# > To remove a top-level security declaration, an empty array can be used.
global_requirements = schema.get("security", [])
local_requirements = endpoint.definition.raw.get("security", None)
if local_requirements is not None:
requirements = local_requirements
else:
requirements = global_requirements
return [key for requirement in requirements for key in requirement]

0 comments on commit 129e897

Please sign in to comment.