Skip to content

Commit

Permalink
refactor: Make get_case_strategy dependent on the used spec
Browse files Browse the repository at this point in the history
It will unify interfaces for creating Hypothesis strategies
  • Loading branch information
Stranger6667 committed Sep 22, 2020
1 parent e30dfc8 commit 3a177df
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 156 deletions.
148 changes: 1 addition & 147 deletions src/schemathesis/_hypothesis.py
@@ -1,31 +1,14 @@
"""Provide strategies for given endpoint(s) definition."""
import asyncio
import re
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
from urllib.parse import quote_plus
from typing import Any, Callable, List, Optional, Union

import hypothesis
from hypothesis import strategies as st
from hypothesis_jsonschema import from_schema

from . import utils
from .constants import DEFAULT_DEADLINE
from .exceptions import InvalidSchema
from .hooks import GLOBAL_HOOK_DISPATCHER, HookContext, HookDispatcher
from .models import Case, Endpoint
from .specs.openapi._hypothesis import STRING_FORMATS

PARAMETERS = frozenset(("path_parameters", "headers", "cookies", "query", "body", "form_data"))
LOCATION_TO_CONTAINER = {
"path": "path_parameters",
"query": "query",
"header": "headers",
"cookie": "cookies",
"body": "body",
"formData": "form_data",
}
SLASH = "/"


def create_test(
Expand Down Expand Up @@ -103,132 +86,3 @@ def example_generating_inner_function(ex: Case) -> None:
examples: List[Case] = []
example_generating_inner_function()
return examples[0]


def is_valid_header(headers: Dict[str, Any]) -> bool:
"""Verify if the generated headers are valid."""
for name, value in headers.items():
if not isinstance(value, str):
return False
if not utils.is_latin_1_encodable(value):
return False
if utils.has_invalid_characters(name, value):
return False
return True


def is_illegal_surrogate(item: Any) -> bool:
return isinstance(item, str) and bool(re.search(r"[\ud800-\udfff]", item))


def is_valid_query(query: Dict[str, Any]) -> bool:
"""Surrogates are not allowed in a query string.
`requests` and `werkzeug` will fail to send it to the application.
"""
for name, value in query.items():
if is_illegal_surrogate(name) or is_illegal_surrogate(value):
return False
return True


def get_case_strategy(endpoint: Endpoint, hooks: Optional[HookDispatcher] = None) -> st.SearchStrategy:
"""Create a strategy for a complete test case.
Path & endpoint are static, the others are JSON schemas.
"""
strategies = {}
static_kwargs: Dict[str, Any] = {}
for parameter in PARAMETERS:
value = getattr(endpoint, parameter)
if value is not None:
location = {"headers": "header", "cookies": "cookie", "path_parameters": "path"}.get(parameter, parameter)
strategies[parameter] = prepare_strategy(parameter, value, endpoint.get_hypothesis_conversions(location))
else:
static_kwargs[parameter] = None
return _get_case_strategy(endpoint, static_kwargs, strategies, hooks)


def to_bytes(value: Union[str, bytes, int, bool, float]) -> bytes:
return str(value).encode(errors="ignore")


def is_valid_form_data(form_data: Any) -> bool:
return isinstance(form_data, dict)


def prepare_form_data(form_data: Dict[str, Any]) -> Dict[str, Any]:
for name, value in form_data.items():
if isinstance(value, list):
form_data[name] = [to_bytes(item) if not isinstance(item, (bytes, str, int)) else item for item in value]
elif not isinstance(value, (bytes, str, int)):
form_data[name] = to_bytes(value)
return form_data


def prepare_strategy(parameter: str, value: Dict[str, Any], map_func: Optional[Callable]) -> st.SearchStrategy:
"""Create a strategy for a schema and add location-specific filters & maps."""
strategy = from_schema(value, custom_formats=STRING_FORMATS)
if map_func is not None:
strategy = strategy.map(map_func)
if parameter == "path_parameters":
strategy = strategy.filter(filter_path_parameters).map(quote_all) # type: ignore
elif parameter in ("headers", "cookies"):
strategy = strategy.filter(is_valid_header) # type: ignore
elif parameter == "query":
strategy = strategy.filter(is_valid_query) # type: ignore
elif parameter == "form_data":
strategy = strategy.filter(is_valid_form_data).map(prepare_form_data) # type: ignore
return strategy


def filter_path_parameters(parameters: Dict[str, Any]) -> bool:
"""Single "." chars and empty strings "" are excluded from path by urllib3.
A path containing to "/" or "%2F" will lead to ambiguous path resolution in
many frameworks and libraries, such behaviour have been observed in both
WSGI and ASGI applications.
In this case one variable in the path template will be empty, which will lead to 404 in most of the cases.
Because of it this case doesn't bring much value and might lead to false positives results of Schemathesis runs.
"""

path_parameter_blacklist = (".", SLASH, "")

return not any(
(value in path_parameter_blacklist or is_illegal_surrogate(value) or isinstance(value, str) and SLASH in value)
for value in parameters.values()
)


def quote_all(parameters: Dict[str, Any]) -> Dict[str, Any]:
"""Apply URL quotation for all values in a dictionary."""
return {key: quote_plus(value) if isinstance(value, str) else value for key, value in parameters.items()}


def _get_case_strategy(
endpoint: Endpoint,
extra_static_parameters: Dict[str, Any],
strategies: Dict[str, st.SearchStrategy],
hook_dispatcher: Optional[HookDispatcher] = None,
) -> st.SearchStrategy[Case]:
static_parameters: Dict[str, Any] = {"endpoint": endpoint, **extra_static_parameters}
if endpoint.schema.validate_schema and endpoint.method == "GET":
if endpoint.body is not None:
raise InvalidSchema("Body parameters are defined for GET request.")
static_parameters["body"] = None
strategies.pop("body", None)
context = HookContext(endpoint)
_apply_hooks(strategies, GLOBAL_HOOK_DISPATCHER, context)
_apply_hooks(strategies, endpoint.schema.hooks, context)
if hook_dispatcher is not None:
_apply_hooks(strategies, hook_dispatcher, context)
return st.builds(partial(Case, **static_parameters), **strategies)


def _apply_hooks(strategies: Dict[str, st.SearchStrategy], dispatcher: HookDispatcher, context: HookContext) -> None:
for key in strategies:
for hook in dispatcher.get_all_by_name(f"before_generate_{key}"):
# Get the strategy on each hook to pass the first hook output as an input to the next one
strategy = strategies[key]
strategies[key] = hook(context, strategy)
4 changes: 1 addition & 3 deletions src/schemathesis/models.py
Expand Up @@ -321,9 +321,7 @@ def full_path(self) -> str:
return self.schema.get_full_path(self.path)

def as_strategy(self, hooks: Optional["HookDispatcher"] = None) -> SearchStrategy:
from ._hypothesis import get_case_strategy # pylint: disable=import-outside-toplevel

return get_case_strategy(self, hooks)
return self.schema.get_case_strategy(self, hooks)

def get_strategies_from_examples(self) -> List[SearchStrategy[Case]]:
"""Get examples from endpoint."""
Expand Down
3 changes: 3 additions & 0 deletions src/schemathesis/schemas.py
Expand Up @@ -205,3 +205,6 @@ def prepare_multipart(

def get_request_payload_content_types(self, endpoint: Endpoint) -> List[str]:
raise NotImplementedError

def get_case_strategy(self, endpoint: Endpoint, hooks: Optional[HookDispatcher] = None) -> SearchStrategy:
raise NotImplementedError
142 changes: 141 additions & 1 deletion src/schemathesis/specs/openapi/_hypothesis.py
@@ -1,9 +1,20 @@
import re
from base64 import b64encode
from typing import Tuple
from functools import partial
from typing import Any, Callable, Dict, Optional, Tuple, Union
from urllib.parse import quote_plus

from hypothesis import strategies as st
from hypothesis_jsonschema import from_schema
from requests.auth import _basic_auth_str

from ... import utils
from ...exceptions import InvalidSchema
from ...hooks import GLOBAL_HOOK_DISPATCHER, HookContext, HookDispatcher
from ...models import Case, Endpoint

PARAMETERS = frozenset(("path_parameters", "headers", "cookies", "query", "body", "form_data"))
SLASH = "/"
STRING_FORMATS = {}


Expand All @@ -29,3 +40,132 @@ def make_basic_auth_str(item: Tuple[str, str]) -> str:

register_string_format("_basic_auth", st.tuples(latin1_text, latin1_text).map(make_basic_auth_str)) # type: ignore
register_string_format("_bearer_auth", st.text().map("Bearer {}".format))


def is_valid_header(headers: Dict[str, Any]) -> bool:
"""Verify if the generated headers are valid."""
for name, value in headers.items():
if not isinstance(value, str):
return False
if not utils.is_latin_1_encodable(value):
return False
if utils.has_invalid_characters(name, value):
return False
return True


def is_illegal_surrogate(item: Any) -> bool:
return isinstance(item, str) and bool(re.search(r"[\ud800-\udfff]", item))


def is_valid_query(query: Dict[str, Any]) -> bool:
"""Surrogates are not allowed in a query string.
`requests` and `werkzeug` will fail to send it to the application.
"""
for name, value in query.items():
if is_illegal_surrogate(name) or is_illegal_surrogate(value):
return False
return True


def get_case_strategy(endpoint: Endpoint, hooks: Optional[HookDispatcher] = None) -> st.SearchStrategy:
"""Create a strategy for a complete test case.
Path & endpoint are static, the others are JSON schemas.
"""
strategies = {}
static_kwargs: Dict[str, Any] = {}
for parameter in PARAMETERS:
value = getattr(endpoint, parameter)
if value is not None:
location = {"headers": "header", "cookies": "cookie", "path_parameters": "path"}.get(parameter, parameter)
strategies[parameter] = prepare_strategy(parameter, value, endpoint.get_hypothesis_conversions(location))
else:
static_kwargs[parameter] = None
return _get_case_strategy(endpoint, static_kwargs, strategies, hooks)


def to_bytes(value: Union[str, bytes, int, bool, float]) -> bytes:
return str(value).encode(errors="ignore")


def is_valid_form_data(form_data: Any) -> bool:
return isinstance(form_data, dict)


def prepare_form_data(form_data: Dict[str, Any]) -> Dict[str, Any]:
for name, value in form_data.items():
if isinstance(value, list):
form_data[name] = [to_bytes(item) if not isinstance(item, (bytes, str, int)) else item for item in value]
elif not isinstance(value, (bytes, str, int)):
form_data[name] = to_bytes(value)
return form_data


def prepare_strategy(parameter: str, value: Dict[str, Any], map_func: Optional[Callable]) -> st.SearchStrategy:
"""Create a strategy for a schema and add location-specific filters & maps."""
strategy = from_schema(value, custom_formats=STRING_FORMATS)
if map_func is not None:
strategy = strategy.map(map_func)
if parameter == "path_parameters":
strategy = strategy.filter(filter_path_parameters).map(quote_all) # type: ignore
elif parameter in ("headers", "cookies"):
strategy = strategy.filter(is_valid_header) # type: ignore
elif parameter == "query":
strategy = strategy.filter(is_valid_query) # type: ignore
elif parameter == "form_data":
strategy = strategy.filter(is_valid_form_data).map(prepare_form_data) # type: ignore
return strategy


def filter_path_parameters(parameters: Dict[str, Any]) -> bool:
"""Single "." chars and empty strings "" are excluded from path by urllib3.
A path containing to "/" or "%2F" will lead to ambiguous path resolution in
many frameworks and libraries, such behaviour have been observed in both
WSGI and ASGI applications.
In this case one variable in the path template will be empty, which will lead to 404 in most of the cases.
Because of it this case doesn't bring much value and might lead to false positives results of Schemathesis runs.
"""

path_parameter_blacklist = (".", SLASH, "")

return not any(
(value in path_parameter_blacklist or is_illegal_surrogate(value) or isinstance(value, str) and SLASH in value)
for value in parameters.values()
)


def quote_all(parameters: Dict[str, Any]) -> Dict[str, Any]:
"""Apply URL quotation for all values in a dictionary."""
return {key: quote_plus(value) if isinstance(value, str) else value for key, value in parameters.items()}


def _get_case_strategy(
endpoint: Endpoint,
extra_static_parameters: Dict[str, Any],
strategies: Dict[str, st.SearchStrategy],
hook_dispatcher: Optional[HookDispatcher] = None,
) -> st.SearchStrategy[Case]:
static_parameters: Dict[str, Any] = {"endpoint": endpoint, **extra_static_parameters}
if endpoint.schema.validate_schema and endpoint.method == "GET":
if endpoint.body is not None:
raise InvalidSchema("Body parameters are defined for GET request.")
static_parameters["body"] = None
strategies.pop("body", None)
context = HookContext(endpoint)
_apply_hooks(strategies, GLOBAL_HOOK_DISPATCHER, context)
_apply_hooks(strategies, endpoint.schema.hooks, context)
if hook_dispatcher is not None:
_apply_hooks(strategies, hook_dispatcher, context)
return st.builds(partial(Case, **static_parameters), **strategies)


def _apply_hooks(strategies: Dict[str, st.SearchStrategy], dispatcher: HookDispatcher, context: HookContext) -> None:
for key in strategies:
for hook in dispatcher.get_all_by_name(f"before_generate_{key}"):
# Get the strategy on each hook to pass the first hook output as an input to the next one
strategy = strategies[key]
strategies[key] = hook(context, strategy)
8 changes: 8 additions & 0 deletions src/schemathesis/specs/openapi/constants.py
@@ -0,0 +1,8 @@
LOCATION_TO_CONTAINER = {
"path": "path_parameters",
"query": "query",
"header": "headers",
"cookie": "cookies",
"body": "body",
"formData": "form_data",
}
3 changes: 2 additions & 1 deletion src/schemathesis/specs/openapi/examples.py
Expand Up @@ -2,8 +2,9 @@

from hypothesis.strategies import SearchStrategy

from ..._hypothesis import LOCATION_TO_CONTAINER, PARAMETERS, _get_case_strategy, prepare_strategy
from ...models import Case, Endpoint
from ._hypothesis import PARAMETERS, _get_case_strategy, prepare_strategy
from .constants import LOCATION_TO_CONTAINER


def get_object_example_from_properties(object_schema: Dict[str, Any]) -> Dict[str, Any]:
Expand Down
2 changes: 1 addition & 1 deletion src/schemathesis/specs/openapi/links.py
Expand Up @@ -7,11 +7,11 @@

import attr

from ..._hypothesis import LOCATION_TO_CONTAINER
from ...models import Case, Endpoint
from ...stateful import ParsedData, StatefulTest
from ...utils import NOT_SET, GenericResponse
from . import expressions
from .constants import LOCATION_TO_CONTAINER


@attr.s(slots=True) # pragma: no mutate
Expand Down
6 changes: 5 additions & 1 deletion src/schemathesis/specs/openapi/schemas.py
Expand Up @@ -9,13 +9,14 @@
from requests.structures import CaseInsensitiveDict

from ...exceptions import InvalidSchema
from ...hooks import HookContext
from ...hooks import HookContext, HookDispatcher
from ...models import Case, Endpoint, EndpointDefinition, empty_object
from ...schemas import BaseSchema
from ...stateful import StatefulTest
from ...types import FormData
from ...utils import GenericResponse
from . import links, serialization
from ._hypothesis import get_case_strategy
from .converter import to_json_schema_recursive
from .examples import get_strategies_from_examples
from .filters import should_skip_by_operation_id, should_skip_by_tag, should_skip_endpoint, should_skip_method
Expand Down Expand Up @@ -177,6 +178,9 @@ def get_endpoint_by_reference(self, reference: str) -> Endpoint:
raw_definition = EndpointDefinition(data, resolved_definition, scope)
return self.make_endpoint(path, method, parameters, resolved_definition, raw_definition)

def get_case_strategy(self, endpoint: Endpoint, hooks: Optional[HookDispatcher] = None) -> SearchStrategy:
return get_case_strategy(endpoint, hooks)


class SwaggerV20(BaseOpenAPISchema):
nullable_name = "x-nullable"
Expand Down

0 comments on commit 3a177df

Please sign in to comment.