diff --git a/README.rst b/README.rst index 07b3ca6057..673896d182 100644 --- a/README.rst +++ b/README.rst @@ -396,59 +396,6 @@ For convenience you can explore the schemas and strategies manually: Schema instances implement ``Mapping`` protocol. -Changing data generation behavior -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -If you want to customize how data is generated, then you can use hooks of three types: - -- Global, which are applied to all schemas; -- Schema-local, which are applied only for specific schema instance; -- Test function specific, they are applied only for a specific test function; - -Each hook accepts a Hypothesis strategy and a hook context. Hook context provides additional info that might be helpful to -construct a new strategy, for example ``context.endpoint`` attribute is a reference to the currently tested endpoint. -For more information look at ``schemathesis.hooks.HookContext`` class. - -Hooks should return a Hypothesis strategy: - -.. code:: python - - import schemathesis - - def global_hook(strategy, context): - return strategy.filter(lambda x: x["id"].isdigit()) - - schemathesis.hooks.register("query", hook) - - schema = schemathesis.from_uri("http://0.0.0.0:8080/swagger.json") - - def schema_hook(strategy, context): - return strategy.filter(lambda x: int(x["id"]) % 2 == 0) - - schema.register_hook("query", schema_hook) - - def function_hook(strategy, context): - return strategy.filter(lambda x: len(x["id"]) > 5) - - @schema.with_hook("query", function_hook) - @schema.parametrize() - def test_api(case): - ... - -There are 6 places, where hooks can be applied and you need to pass it as the first argument to ``schemathesis.hooks.register`` or ``schema.register_hook``: - -- path_parameters -- headers -- cookies -- query -- body -- form_data - -It might be useful if you want to exclude certain cases that you don't want to test, or modify the generated data, so it -will be more meaningful for the application - add existing IDs from the database, custom auth header, etc. - -**NOTE**. Global hooks are applied first. - Lazy loading ~~~~~~~~~~~~ diff --git a/docs/changelog.rst b/docs/changelog.rst index e12a4d4455..b6a00665fb 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -11,11 +11,14 @@ Added - ``context`` argument for hook functions to provide an additional context for hooks. A deprecation warning is emitted for hook functions that do not accept this argument. +- A new hook system that allows generic hook dispatching. It comes with new hook locations. For more details see "Customization" section in our documentation. Deprecated ~~~~~~~~~~ -- Hook functions that do not accept ``context`` as their second argument. They will become not supported in Schemathesis 2.0. +- Hook functions that do not accept ``context`` as their first argument. They will become not supported in Schemathesis 2.0. +- Registering hooks by name and function. Use ``register`` decorators instead, for more details see "Customization" section in our documentation. +- ``BaseSchema.with_hook`` and ``BaseSchema.register_hook``. Use ``BaseSchema.hooks.apply`` and ``BaseSchema.hooks.register`` instead. Fixed ~~~~~ diff --git a/docs/customization.rst b/docs/customization.rst new file mode 100644 index 0000000000..82192706ec --- /dev/null +++ b/docs/customization.rst @@ -0,0 +1,81 @@ +.. customization: + +Customization +============= + +Often you need to modify certain aspects of Schemathesis behavior, adjust data generation, modify requests before +sending, and so on. Schemathesis offers a hook mechanism which is similar to the pytest's one. + +Basing on the scope of the changes there are three levels of hooks: + +- Global. These hooks applied to all schemas in the test run; +- Schema-local. Applied only for specific schema instance; +- Test function specific. Applied only for a specific test function; + +To register a new hook function you need to use special decorators - ``register`` for global and schema-local hooks and ``apply`` for test-specific ones: + +.. code:: python + + import schemathesis + + @schemathesis.hooks.register + def before_generate_query(context, strategy): + return strategy.filter(lambda x: x["id"].isdigit()) + + schema = schemathesis.from_uri("http://0.0.0.0:8080/swagger.json") + + @schema.hooks.register("before_generate_query") + def schema_hook(context, strategy): + return strategy.filter(lambda x: int(x["id"]) % 2 == 0) + + def function_hook(context, strategy): + return strategy.filter(lambda x: len(x["id"]) > 5) + + @schema.hooks.apply("before_generate_query", function_hook) + @schema.parametrize() + def test_api(case): + ... + +By default ``register`` functions will check the registered hook name to determine when to run it +(see all hook specifications in the section below), but to avoid name collisions you can provide a hook name as an argument to ``register``. + +Also, these decorators will check the signature of your hook function to match the specification. +Each hook should accept ``context`` as the first argument, that provides additional context for hook execution. + +Hooks registered on the same level will be applied in the order of registration. When there are multiple hooks in the same hook location, then the global ones will be applied first. + +Common hooks +------------ + +These hooks can be applied both in CLI and in-code use cases. + +``before_generate_*`` +~~~~~~~~~~~~~~~~~~~~~ + +This is a group of six hooks that share the same purpose - adjust data generation for specific request's part. + +- ``before_generate_path_parameters`` +- ``before_generate_headers`` +- ``before_generate_cookies`` +- ``before_generate_query`` +- ``before_generate_body`` +- ``before_generate_form_data`` + +They have the same signature that looks like this: + +.. code:: python + + def before_generate_query( + context: schemathesis.hooks.HookContext, + strategy: hypothesis.strategies.SearchStrategy, + ) -> hypothesis.strategies.SearchStrategy: + pass + +``strategy`` is a Hypothesis strategy that will generate a certain request part. For example, your endpoint under test +expects ``id`` query parameter that is a number and you'd like to have only values that have at least three occurrences of "1". +Then your hook might look like this: + +.. code:: python + + def before_generate_query(context, strategy): + return strategy.filter(lambda x: str(x["id"]).count("1") >= 3) diff --git a/docs/index.rst b/docs/index.rst index 5d18642240..e267d8e831 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,8 +1,3 @@ -.. schemathesis documentation master file, created by - sphinx-quickstart on Fri Jan 13 12:59:16 2017. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - Welcome to schemathesis's documentation! ======================================== @@ -14,6 +9,7 @@ Welcome to schemathesis's documentation! :caption: Contents: usage + customization targeted faq changelog diff --git a/src/schemathesis/_hypothesis.py b/src/schemathesis/_hypothesis.py index 41ac1ca5c5..664a1461de 100644 --- a/src/schemathesis/_hypothesis.py +++ b/src/schemathesis/_hypothesis.py @@ -14,7 +14,7 @@ from . import utils from ._compat import handle_warnings from .exceptions import InvalidSchema -from .hooks import HookContext, get_hook +from .hooks import GLOBAL_HOOK_DISPATCHER, HookContext, HookDispatcher from .models import Case, Endpoint from .types import Hook @@ -23,11 +23,11 @@ def create_test( - endpoint: Endpoint, test: Callable, settings: Optional[hypothesis.settings] = None, seed: Optional[int] = None, + endpoint: Endpoint, test: Callable, settings: Optional[hypothesis.settings] = None, seed: Optional[int] = None ) -> Callable: """Create a Hypothesis test.""" - hooks = getattr(test, "_schemathesis_hooks", None) - strategy = endpoint.as_strategy(hooks=hooks) + hook_dispatcher = getattr(test, "_schemathesis_hooks", None) + strategy = endpoint.as_strategy(hooks=hook_dispatcher) wrapped_test = hypothesis.given(case=strategy)(test) if seed is not None: wrapped_test = hypothesis.seed(seed)(wrapped_test) @@ -121,7 +121,7 @@ def is_valid_query(query: Dict[str, Any]) -> bool: return True -def get_case_strategy(endpoint: Endpoint, hooks: Optional[Dict[str, Hook]] = None) -> st.SearchStrategy: +def get_case_strategy(endpoint: Endpoint, hooks: Optional[HookDispatcher] = None) -> st.SearchStrategy: """Create a strategy for a complete test case. Path & endpoint are static, the others are JSON schemas. @@ -160,11 +160,7 @@ def filter_path_parameters(parameters: Dict[str, Any]) -> bool: Because of it this case doesn't bring much value and might lead to false positives results of Schemathesis runs. """ - path_parameter_blacklist = ( - ".", - SLASH, - "", - ) + path_parameter_blacklist = (".", SLASH, "") return not any( (value in path_parameter_blacklist or isinstance(value, str) and SLASH in value) @@ -180,7 +176,7 @@ def _get_case_strategy( endpoint: Endpoint, extra_static_parameters: Dict[str, Any], strategies: Dict[str, st.SearchStrategy], - hooks: Optional[Dict[str, Hook]] = None, + hook_dispatcher: Optional[HookDispatcher] = None, ) -> st.SearchStrategy: static_parameters: Dict[str, Any] = {"endpoint": endpoint, **extra_static_parameters} if endpoint.schema.validate_schema and endpoint.method == "GET": @@ -189,29 +185,28 @@ def _get_case_strategy( static_parameters["body"] = None strategies.pop("body", None) context = HookContext(endpoint) - _apply_hooks(strategies, get_hook, context) - _apply_hooks(strategies, endpoint.schema.get_hook, context) - if hooks is not None: - _apply_hooks(strategies, hooks.get, context) + _apply_hooks(strategies, GLOBAL_HOOK_DISPATCHER, context) + _apply_hooks(strategies, endpoint.schema.hooks, context) + if hook_dispatcher is not None: + _apply_hooks(strategies, hook_dispatcher, context) return st.builds(partial(Case, **static_parameters), **strategies) -def _apply_hooks( - strategies: Dict[str, st.SearchStrategy], getter: Callable[[str], Optional[Hook]], context: HookContext -) -> None: - for key, strategy in strategies.items(): - hook = getter(key) - if hook is not None: - args: Union[Tuple[st.SearchStrategy], Tuple[st.SearchStrategy, HookContext]] +def _apply_hooks(strategies: Dict[str, st.SearchStrategy], dispatcher: HookDispatcher, context: HookContext) -> None: + for key in strategies: + for hook in dispatcher.get_hooks(f"before_generate_{key}"): + # Get the strategy on each hook to pass the first hook output as an input to the next one + strategy = strategies[key] + args: Union[Tuple[st.SearchStrategy], Tuple[HookContext, st.SearchStrategy]] if _accepts_context(hook): - args = (strategy, context) + args = (context, strategy) else: args = (strategy,) strategies[key] = hook(*args) def _accepts_context(hook: Hook) -> bool: - # There are no restrictions on the second argument's name and we don't check its name here. + # There are no restrictions on the first argument's name and we don't check its name here. return len(inspect.signature(hook).parameters) == 2 diff --git a/src/schemathesis/hooks.py b/src/schemathesis/hooks.py index 0a9a92eb18..56ff856380 100644 --- a/src/schemathesis/hooks.py +++ b/src/schemathesis/hooks.py @@ -1,21 +1,14 @@ import inspect import warnings -from typing import Optional +from collections import defaultdict +from typing import Callable, DefaultDict, Dict, List, Union, cast import attr +from hypothesis import strategies as st from .constants import HookLocation from .models import Endpoint -from .types import Hook - -GLOBAL_HOOKS = {} - - -@attr.s(slots=True) -class HookContext: - """A context that is passed to hook functions.""" - - endpoint: Endpoint = attr.ib() +from .types import GenericTest, Hook def warn_deprecated_hook(hook: Hook) -> None: @@ -28,16 +21,169 @@ def warn_deprecated_hook(hook: Hook) -> None: ) -def register(place: str, hook: Hook) -> None: - warn_deprecated_hook(hook) - key = HookLocation[place] - GLOBAL_HOOKS[key] = hook +@attr.s(slots=True) # pragma: no mutate +class HookContext: + """A context that is passed to some hook functions.""" + + endpoint: Endpoint = attr.ib() # pragma: no mutate + + +@attr.s(slots=True) # pragma: no mutate +class HookDispatcher: + """Generic hook dispatcher. + + Provides a mechanism to extend Schemathesis in registered hook points. + """ + + _hooks: DefaultDict[str, List[Callable]] = attr.ib(factory=lambda: defaultdict(list)) # pragma: no mutate + _specs: Dict[str, inspect.Signature] = {} # pragma: no mutate + + def register(self, hook: Union[str, Callable]) -> Callable: + """Register a new hook. + + Can be used as a decorator in two forms. + Without arguments for registering hooks and autodetecting their names: + + @schema.hooks.register + def before_generate_query(strategy, context): + ... + + With a hook name as the first argument: + + @schema.hooks.register("before_generate_query") + def hook(strategy, context): + ... + """ + if isinstance(hook, str): + + def decorator(func: Callable) -> Callable: + hook_name = cast(str, hook) + return self.register_hook_with_name(hook_name, func) + + return decorator + return self.register_hook_with_name(hook.__name__, hook) + + def apply(self, name: str, hook: Callable, skip_validation: bool = False) -> Callable[[Callable], Callable]: + """Register hook to run only on one test function. + + Example: + def hook(strategy, context): + ... + + @schema.hooks.apply("before_generate_query", hook) + @schema.parametrize() + def test_api(case): + ... + + """ + + def decorator(func: GenericTest) -> GenericTest: + dispatcher = self.add_dispatcher(func) + dispatcher.register_hook_with_name(name, hook, skip_validation) + return func + + return decorator + @classmethod + def add_dispatcher(cls, func: GenericTest) -> "HookDispatcher": + """Attach a new dispatcher instance to the test if it is not already present.""" + if not hasattr(func, "_schemathesis_hooks"): + func._schemathesis_hooks = cls() # type: ignore + return func._schemathesis_hooks # type: ignore -def get_hook(place: str) -> Optional[Hook]: - key = HookLocation[place] - return GLOBAL_HOOKS.get(key) + def register_hook_with_name(self, name: str, hook: Callable, skip_validation: bool = False) -> Callable: + """A helper for hooks registration. + Besides its use in this class internally it is used to keep backward compatibility with the old hooks system. + """ + # Validation is skipped only for backward compatibility with the old hooks system + if not skip_validation: + self._validate_hook(name, hook) + self._hooks[name].append(hook) + return hook -def unregister_all() -> None: - GLOBAL_HOOKS.clear() + @classmethod + def register_spec(cls, spec: Callable) -> Callable: + """Register hook specification. + + All hooks, registered with `register` should comply with corresponding registered specs. + """ + cls._specs[spec.__name__] = inspect.signature(spec) + return spec + + def _validate_hook(self, name: str, hook: Callable) -> None: + """Basic validation for hooks being registered.""" + spec = self._specs.get(name) + if spec is None: + raise TypeError(f"There is no hook with name '{name}'") + signature = inspect.signature(hook) + if len(signature.parameters) != len(spec.parameters): + raise TypeError( + f"Hook '{name}' takes {len(spec.parameters)} arguments but {len(signature.parameters)} is defined" + ) + + def get_hooks(self, name: str) -> List[Callable]: + """Get a list of hooks registered for name.""" + return self._hooks.get(name, []) + + def unregister_all(self) -> None: + """Remove all registered hooks. + + Useful in tests. + """ + self._hooks = defaultdict(list) + + +@HookDispatcher.register_spec +def before_generate_path_parameters(strategy: st.SearchStrategy, context: HookContext) -> st.SearchStrategy: + pass + + +@HookDispatcher.register_spec +def before_generate_headers(strategy: st.SearchStrategy, context: HookContext) -> st.SearchStrategy: + pass + + +@HookDispatcher.register_spec +def before_generate_cookies(strategy: st.SearchStrategy, context: HookContext) -> st.SearchStrategy: + pass + + +@HookDispatcher.register_spec +def before_generate_query(strategy: st.SearchStrategy, context: HookContext) -> st.SearchStrategy: + pass + + +@HookDispatcher.register_spec +def before_generate_body(strategy: st.SearchStrategy, context: HookContext) -> st.SearchStrategy: + pass + + +@HookDispatcher.register_spec +def before_generate_form_data(strategy: st.SearchStrategy, context: HookContext) -> st.SearchStrategy: + pass + + +GLOBAL_HOOK_DISPATCHER = HookDispatcher() +get_hooks = GLOBAL_HOOK_DISPATCHER.get_hooks +unregister_all = GLOBAL_HOOK_DISPATCHER.unregister_all + + +def register(*args: Union[str, Callable]) -> Callable: + # This code suppose to support backward compatibility with the old hook system. + # In Schemathesis 2.0 this function can be replaced with `register = GLOBAL_HOOK_DISPATCHER.register` + if len(args) == 1: + return GLOBAL_HOOK_DISPATCHER.register(args[0]) + if len(args) == 2: + warnings.warn( + "Calling `schemathesis.register` with two arguments is deprecated, use it as a decorator instead.", + DeprecationWarning, + ) + place, hook = args + hook = cast(Callable, hook) + warn_deprecated_hook(hook) + if place not in HookLocation.__members__: + raise KeyError(place) + return GLOBAL_HOOK_DISPATCHER.register_hook_with_name(f"before_generate_{place}", hook, skip_validation=True) + # This approach is quite naive, but it should be enough for the common use case + raise TypeError("Invalid number of arguments. Please, use `register` as a decorator.") diff --git a/src/schemathesis/models.py b/src/schemathesis/models.py index 7b4a0bf8ce..6efa936a88 100644 --- a/src/schemathesis/models.py +++ b/src/schemathesis/models.py @@ -16,11 +16,12 @@ from .checks import ALL_CHECKS from .exceptions import InvalidSchema -from .types import Body, Cookies, FormData, Headers, Hook, PathParameters, Query +from .types import Body, Cookies, FormData, Headers, PathParameters, Query from .utils import GenericResponse, WSGIResponse if TYPE_CHECKING: from .schemas import BaseSchema + from .hooks import HookDispatcher @attr.s(slots=True) # pragma: no mutate @@ -233,7 +234,7 @@ class Endpoint: body: Optional[Body] = attr.ib(default=None) # pragma: no mutate form_data: Optional[FormData] = attr.ib(default=None) # pragma: no mutate - def as_strategy(self, hooks: Optional[Dict[str, Hook]] = None) -> SearchStrategy: + def as_strategy(self, hooks: Optional["HookDispatcher"] = None) -> SearchStrategy: from ._hypothesis import get_case_strategy # pylint: disable=import-outside-toplevel return get_case_strategy(self, hooks) diff --git a/src/schemathesis/schemas.py b/src/schemathesis/schemas.py index a9ed6d91b6..2348ef41a8 100644 --- a/src/schemathesis/schemas.py +++ b/src/schemathesis/schemas.py @@ -26,15 +26,13 @@ from .converter import to_json_schema, to_json_schema_recursive from .exceptions import InvalidSchema from .filters import should_skip_by_tag, should_skip_endpoint, should_skip_method -from .hooks import warn_deprecated_hook +from .hooks import HookDispatcher, warn_deprecated_hook from .models import Endpoint, empty_object -from .types import Filter, Hook, NotSet -from .utils import NOT_SET, GenericResponse, StringDatesYAMLLoader +from .types import Filter, GenericTest, Hook, NotSet +from .utils import NOT_SET, GenericResponse, StringDatesYAMLLoader, deprecated # Reference resolving will stop after this depth RECURSION_DEPTH_LIMIT = 100 -# Generic test with any arguments and no return -GenericTest = Callable[..., None] # pragma: no mutate def load_file_impl(location: str, opener: Callable) -> Dict[str, Any]: @@ -64,7 +62,8 @@ class BaseSchema(Mapping): endpoint: Optional[Filter] = attr.ib(default=None) # pragma: no mutate tag: Optional[Filter] = attr.ib(default=None) # pragma: no mutate app: Any = attr.ib(default=None) # pragma: no mutate - hooks: Dict[HookLocation, Hook] = attr.ib(factory=dict) # pragma: no mutate + hooks: HookDispatcher = attr.ib(factory=HookDispatcher) # pragma: no mutate + test_function: Optional[GenericTest] = attr.ib(default=None) # pragma: no mutate validate_schema: bool = attr.ib(default=True) # pragma: no mutate def __iter__(self) -> Iterator[str]: @@ -126,14 +125,16 @@ def parametrize( ) -> Callable: """Mark a test function as a parametrized one.""" - def wrapper(func: Callable) -> Callable: - func._schemathesis_test = self.clone(method, endpoint, tag, validate_schema) # type: ignore + def wrapper(func: GenericTest) -> GenericTest: + HookDispatcher.add_dispatcher(func) + func._schemathesis_test = self.clone(func, method, endpoint, tag, validate_schema) # type: ignore return func return wrapper - def clone( + def clone( # pylint: disable=too-many-arguments self, + test_function: Optional[GenericTest] = None, method: Optional[Filter] = NOT_SET, endpoint: Optional[Filter] = NOT_SET, tag: Optional[Filter] = NOT_SET, @@ -157,6 +158,7 @@ def clone( tag=tag, app=self.app, hooks=self.hooks, + test_function=test_function, validate_schema=validate_schema, # type: ignore ) @@ -164,29 +166,21 @@ def _get_response_schema(self, definition: Dict[str, Any]) -> Optional[Dict[str, """Extract response schema from `responses`.""" raise NotImplementedError + @deprecated("'register_hook` is deprecated, use `hooks.register' instead") def register_hook(self, place: str, hook: Hook) -> None: warn_deprecated_hook(hook) - key = HookLocation[place] - self.hooks[key] = hook + if place not in HookLocation.__members__: + raise KeyError(place) + self.hooks.register_hook_with_name(f"before_generate_{place}", hook, skip_validation=True) + @deprecated("'with_hook` is deprecated, use `hooks.apply' instead") def with_hook(self, place: str, hook: Hook) -> Callable[[GenericTest], GenericTest]: """Register a hook for a specific test.""" warn_deprecated_hook(hook) if place not in HookLocation.__members__: raise KeyError(place) - def wrapper(func: GenericTest) -> GenericTest: - if not hasattr(func, "_schemathesis_hooks"): - func._schemathesis_hooks = {} # type: ignore - # a string key is simpler to use later - func._schemathesis_hooks[place] = hook # type: ignore - return func - - return wrapper - - def get_hook(self, place: str) -> Optional[Hook]: - key = HookLocation[place] - return self.hooks.get(key) + return self.hooks.apply(f"before_generate_{place}", hook, skip_validation=True) def get_content_types(self, endpoint: Endpoint, response: GenericResponse) -> List[str]: """Content types available for this endpoint.""" diff --git a/src/schemathesis/types.py b/src/schemathesis/types.py index 09d09d880b..4eda7a10e2 100644 --- a/src/schemathesis/types.py +++ b/src/schemathesis/types.py @@ -30,3 +30,5 @@ class NotSet: ] # pragma: no mutate RawAuth = Tuple[str, str] # pragma: no mutate +# Generic test with any arguments and no return +GenericTest = Callable[..., None] # pragma: no mutate diff --git a/src/schemathesis/utils.py b/src/schemathesis/utils.py index ae37df6407..85cf12ff01 100644 --- a/src/schemathesis/utils.py +++ b/src/schemathesis/utils.py @@ -3,7 +3,9 @@ import re import sys import traceback +import warnings from contextlib import contextmanager +from functools import wraps from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, Union from urllib.parse import urlsplit, urlunsplit @@ -28,6 +30,20 @@ NOT_SET = NotSet() +def deprecated(message: str) -> Callable: + """Emit a warning if the given function is used.""" + + def wrapper(func: Callable) -> Callable: + @wraps(func) # pragma: no mutate + def inner(*args: Any, **kwargs: Any) -> Any: + warnings.warn(message, DeprecationWarning) + return func(*args, **kwargs) + + return inner + + return wrapper + + def file_exists(path: str) -> bool: try: return pathlib.Path(path).is_file() diff --git a/test/conftest.py b/test/conftest.py index cbb8b66fb0..5fc5d0a95e 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -124,6 +124,25 @@ def simple_schema(): } +@pytest.fixture(scope="session") +def simple_openapi(): + return { + "openapi": "3.0.2", + "info": {"title": "Test", "description": "Test", "version": "0.1.0"}, + "paths": { + "/query": { + "get": { + "parameters": [ + {"name": "id", "in": "query", "required": True, "schema": {"type": "string", "minLength": 1}}, + {"name": "value", "in": "header", "required": True, "schema": {"type": "string"}}, + ], + "responses": {"200": {"description": "OK"}}, + } + } + }, + } + + @pytest.fixture(scope="session") def swagger_20(simple_schema): return schemathesis.from_dict(simple_schema) diff --git a/test/test_hooks.py b/test/hooks/test_deprecated.py similarity index 76% rename from test/test_hooks.py rename to test/hooks/test_deprecated.py index 36fb4c57da..52de57cb4a 100644 --- a/test/test_hooks.py +++ b/test/hooks/test_deprecated.py @@ -1,10 +1,11 @@ +"""These tests ensure backward compatibility with the old hooks system.""" import pytest from hypothesis import given, settings import schemathesis -def hook(strategy, context): +def hook(context, strategy): return strategy.filter(lambda x: x["id"].isdigit()) @@ -52,7 +53,7 @@ def test(case): @pytest.mark.usefixtures("query_hook") @pytest.mark.endpoints("custom_format") def test_hooks_combination(schema, schema_url): - def extra(st, context): + def extra(context, st): assert context.endpoint == schema.endpoints["/api/custom_format"]["GET"] return st.filter(lambda x: int(x["id"]) % 2 == 0) @@ -68,29 +69,12 @@ def test(case): test() -SIMPLE_SCHEMA = { - "openapi": "3.0.2", - "info": {"title": "Test", "description": "Test", "version": "0.1.0"}, - "paths": { - "/query": { - "get": { - "parameters": [ - {"name": "id", "in": "query", "required": True, "schema": {"type": "string", "minLength": 1}}, - {"name": "value", "in": "header", "required": True, "schema": {"type": "string"}}, - ], - "responses": {"200": {"description": "OK"}}, - } - } - }, -} - - -def test_per_test_hooks(testdir): +def test_per_test_hooks(testdir, simple_openapi): testdir.make_test( """ from hypothesis import strategies as st -def replacement(strategy, context): +def replacement(context, strategy): return st.just({"id": "foobar"}) @schema.with_hook("query", replacement) @@ -105,10 +89,10 @@ def test_a(case): def test_b(case): assert case.query["id"] == "foobar" -def another_replacement(strategy, context): +def another_replacement(context, strategy): return st.just({"id": "foobaz"}) -def third_replacement(strategy, context): +def third_replacement(context, strategy): return st.just({"value": "spam"}) @schema.parametrize() @@ -125,14 +109,24 @@ def test_c(case): def test_d(case): assert case.query["id"] != "foobar" """, - schema=SIMPLE_SCHEMA, + schema=simple_openapi, ) result = testdir.runpytest() result.assert_outcomes(passed=4) -def test_invalid_hook(schema): - def foo(strategy, context): +def test_invalid_global_hook(): + with pytest.raises(KeyError, match="wrong"): + schemathesis.hooks.register("wrong", lambda x: x) + + +def test_invalid_schema_hook(schema): + with pytest.raises(KeyError, match="wrong"): + schema.register_hook("wrong", lambda x: x) + + +def test_invalid_local_hook(schema): + def foo(context, strategy): pass with pytest.raises(KeyError, match="wrong"): @@ -142,10 +136,10 @@ def test(case): pass -def test_hooks_via_parametrize(testdir): +def test_hooks_via_parametrize(testdir, simple_openapi): testdir.make_test( """ -def extra(st, context): +def extra(context, st): return st.filter(lambda x: x["id"].isdigit() and int(x["id"]) % 2 == 0) schema.register_hook("query", extra) @@ -153,10 +147,10 @@ def extra(st, context): @schema.parametrize() @settings(max_examples=1) def test(case): - assert case.endpoint.schema.get_hook("query") is extra + assert case.endpoint.schema.hooks.get_hooks("before_generate_query")[0] is extra assert int(case.query["id"]) % 2 == 0 """, - schema=SIMPLE_SCHEMA, + schema=simple_openapi, ) result = testdir.runpytest() result.assert_outcomes(passed=1) @@ -170,7 +164,7 @@ def deprecated_hook(strategy): schema.register_hook("query", deprecated_hook) assert ( - str(recwarn.list[0].message) == "Hook functions that do not accept `context` argument are deprecated and " + str(recwarn.list[1].message) == "Hook functions that do not accept `context` argument are deprecated and " "support will be removed in Schemathesis 2.0." ) @@ -180,6 +174,10 @@ def deprecated_hook(strategy): @settings(max_examples=3) def test(case): assert case.query["id"].isdigit() - assert int(case.query["id"]) % 2 == 0 test() + + +def test_register_wrong_number_of_argument(): + with pytest.raises(TypeError, match="Invalid number of arguments. Please, use `register` as a decorator."): + schemathesis.hooks.register("a", "b", "c") diff --git a/test/hooks/test_hooks.py b/test/hooks/test_hooks.py new file mode 100644 index 0000000000..5ab311789a --- /dev/null +++ b/test/hooks/test_hooks.py @@ -0,0 +1,232 @@ +import pytest +from hypothesis import given, settings + +import schemathesis +from schemathesis.hooks import HookDispatcher + + +@pytest.fixture(params=["direct", "named"]) +def global_hook(request): + if request.param == "direct": + + @schemathesis.hooks.register + def before_generate_query(context, strategy): + return strategy.filter(lambda x: x["id"].isdigit()) + + if request.param == "named": + + @schemathesis.hooks.register("before_generate_query") + def hook(context, strategy): + return strategy.filter(lambda x: x["id"].isdigit()) + + yield + schemathesis.hooks.unregister_all() + + +@pytest.fixture +def schema(flask_app): + return schemathesis.from_wsgi("/swagger.yaml", flask_app) + + +@pytest.fixture() +def dispatcher(): + return HookDispatcher() + + +@pytest.mark.hypothesis_nested +@pytest.mark.endpoints("custom_format") +@pytest.mark.usefixtures("global_hook") +def test_global_query_hook(schema, schema_url): + strategy = schema.endpoints["/api/custom_format"]["GET"].as_strategy() + + @given(case=strategy) + @settings(max_examples=3) + def test(case): + assert case.query["id"].isdigit() + + test() + + +@pytest.mark.hypothesis_nested +@pytest.mark.endpoints("custom_format") +def test_schema_query_hook(schema, schema_url): + @schema.hooks.register + def before_generate_query(context, strategy): + return strategy.filter(lambda x: x["id"].isdigit()) + + strategy = schema.endpoints["/api/custom_format"]["GET"].as_strategy() + + @given(case=strategy) + @settings(max_examples=3) + def test(case): + assert case.query["id"].isdigit() + + test() + + +@pytest.mark.hypothesis_nested +@pytest.mark.usefixtures("global_hook") +@pytest.mark.endpoints("custom_format") +def test_hooks_combination(schema, schema_url): + @schema.hooks.register("before_generate_query") + def extra(context, st): + assert context.endpoint == schema.endpoints["/api/custom_format"]["GET"] + return st.filter(lambda x: int(x["id"]) % 2 == 0) + + strategy = schema.endpoints["/api/custom_format"]["GET"].as_strategy() + + @given(case=strategy) + @settings(max_examples=3) + def test(case): + assert case.query["id"].isdigit() + assert int(case.query["id"]) % 2 == 0 + + test() + + +def test_per_test_hooks(testdir, simple_openapi): + testdir.make_test( + """ +from hypothesis import strategies as st + +def replacement(context, strategy): + return st.just({"id": "foobar"}) + +@schema.hooks.apply("before_generate_query", replacement) +@schema.parametrize() +@settings(max_examples=1) +def test_a(case): + assert case.query["id"] == "foobar" + +@schema.parametrize() +@schema.hooks.apply("before_generate_query", replacement) +@settings(max_examples=1) +def test_b(case): + assert case.query["id"] == "foobar" + +def another_replacement(context, strategy): + return st.just({"id": "foobaz"}) + +def third_replacement(context, strategy): + return st.just({"value": "spam"}) + +@schema.parametrize() +@schema.hooks.apply("before_generate_query", another_replacement) # Higher priority +@schema.hooks.apply("before_generate_query", replacement) +@schema.hooks.apply("before_generate_headers", third_replacement) +@settings(max_examples=1) +def test_c(case): + assert case.query["id"] == "foobaz" + assert case.headers["value"] == "spam" + +@schema.parametrize() +@settings(max_examples=1) +def test_d(case): + assert case.query["id"] != "foobar" + """, + schema=simple_openapi, + ) + result = testdir.runpytest() + result.assert_outcomes(passed=4) + + +def test_hooks_via_parametrize(testdir, simple_openapi): + testdir.make_test( + """ +@schema.hooks.register("before_generate_query") +def extra(context, st): + return st.filter(lambda x: x["id"].isdigit() and int(x["id"]) % 2 == 0) + +@schema.parametrize() +@settings(max_examples=1) +def test(case): + assert case.endpoint.schema.hooks.get_hooks("before_generate_query")[0] is extra + assert int(case.query["id"]) % 2 == 0 + """, + schema=simple_openapi, + ) + result = testdir.runpytest() + result.assert_outcomes(passed=1) + + +def test_register_invalid_hook_name(dispatcher): + with pytest.raises(TypeError, match="There is no hook with name 'hook'"): + + @dispatcher.register + def hook(): + pass + + +def test_register_invalid_hook_spec(dispatcher): + with pytest.raises(TypeError, match="Hook 'before_generate_query' takes 2 arguments but 3 is defined"): + + @dispatcher.register + def before_generate_query(a, b, c): + pass + + +def test_save_test_function(schema): + assert schema.test_function is None + + @schema.parametrize() + def test(case): + pass + + assert test._schemathesis_test.test_function is test + + +@pytest.mark.parametrize("apply_first", (True, False)) +def test_local_dispatcher(schema, apply_first): + # When there are schema-level hooks + @schema.hooks.register("before_generate_query") + def schema_hook(context, strategy): + return strategy + + # And per-test hooks are applied + def local_hook(context, strategy): + return strategy + + # And order of decorators is any + apply = schema.hooks.apply("before_generate_cookies", local_hook) + parametrize = schema.parametrize() + if apply_first: + wrap = lambda x: parametrize(apply(x)) + else: + wrap = lambda x: apply(parametrize(x)) + + @wrap + def test(case): + pass + + # Then a hook dispatcher instance is attached to the test function + assert isinstance(test._schemathesis_hooks, HookDispatcher) + # And this dispatcher contains only local hooks + assert test._schemathesis_hooks.get_hooks("before_generate_cookies") == [local_hook] + assert test._schemathesis_hooks.get_hooks("before_generate_query") == [] + # And the schema-level dispatcher still contains only schema-level hooks + assert test._schemathesis_test.hooks.get_hooks("before_generate_query") == [schema_hook] + assert test._schemathesis_test.hooks.get_hooks("before_generate_cookies") == [] + + +@pytest.mark.hypothesis_nested +@pytest.mark.endpoints("custom_format") +def test_multiple_hooks_per_spec(schema): + @schema.hooks.register("before_generate_query") + def first_hook(context, strategy): + return strategy.filter(lambda x: x["id"].isdigit()) + + @schema.hooks.register("before_generate_query") + def second_hook(context, strategy): + return strategy.filter(lambda x: int(x["id"]) % 2 == 0) + + assert schema.hooks.get_hooks("before_generate_query") == [first_hook, second_hook] + + strategy = schema.endpoints["/api/custom_format"]["GET"].as_strategy() + + @given(case=strategy) + @settings(max_examples=3) + def test(case): + assert case.query["id"].isdigit() + assert int(case.query["id"]) % 2 == 0 + + test()