diff --git a/docs/changelog.rst b/docs/changelog.rst index 57d6ac94bd0..d1e4d847b44 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -7,6 +7,8 @@ Changelog **Added** - ``DataGenerationMethod.all`` shortcut to get all possible enum variants. +- A flexible way to select API operations for testing. It is now possible to exclude or include them by arbitrary + predicates. `#703`_, `#819`_, `#1006`_ `3.10.0`_ - 2021-09-13 ---------------------- @@ -2188,6 +2190,7 @@ Deprecated .. _#1013: https://github.com/schemathesis/schemathesis/issues/1013 .. _#1010: https://github.com/schemathesis/schemathesis/issues/1010 .. _#1007: https://github.com/schemathesis/schemathesis/issues/1007 +.. _#1006: https://github.com/schemathesis/schemathesis/issues/1006 .. _#1003: https://github.com/schemathesis/schemathesis/issues/1003 .. _#999: https://github.com/schemathesis/schemathesis/issues/999 .. _#994: https://github.com/schemathesis/schemathesis/issues/994 @@ -2243,6 +2246,7 @@ Deprecated .. _#830: https://github.com/schemathesis/schemathesis/issues/830 .. _#824: https://github.com/schemathesis/schemathesis/issues/824 .. _#822: https://github.com/schemathesis/schemathesis/issues/822 +.. _#819: https://github.com/schemathesis/schemathesis/issues/819 .. _#816: https://github.com/schemathesis/schemathesis/issues/816 .. _#814: https://github.com/schemathesis/schemathesis/issues/814 .. _#812: https://github.com/schemathesis/schemathesis/issues/812 @@ -2269,6 +2273,7 @@ Deprecated .. _#708: https://github.com/schemathesis/schemathesis/issues/708 .. _#706: https://github.com/schemathesis/schemathesis/issues/706 .. _#705: https://github.com/schemathesis/schemathesis/issues/705 +.. _#703: https://github.com/schemathesis/schemathesis/issues/703 .. _#702: https://github.com/schemathesis/schemathesis/issues/702 .. _#700: https://github.com/schemathesis/schemathesis/issues/700 .. _#695: https://github.com/schemathesis/schemathesis/issues/695 diff --git a/src/schemathesis/filters.py b/src/schemathesis/filters.py new file mode 100644 index 00000000000..890a87eccaf --- /dev/null +++ b/src/schemathesis/filters.py @@ -0,0 +1,70 @@ +import enum +from typing import Any, Callable, List, Optional + +import attr + +# Indicates the most common place where filters are applied. +# Other scopes are spec-specific and their filters may be applied earlier to avoid expensive computations +DEFAULT_SCOPE = None + + +class FilterResult(enum.Enum): + """The result of a single filter call. + + This functionality is implemented as a separate enum and not a simple boolean to provide a more descriptive API. + """ + + INCLUDED = enum.auto() + EXCLUDED = enum.auto() + + @property + def is_included(self) -> bool: + return self == FilterResult.INCLUDED + + @property + def is_excluded(self) -> bool: + return self == FilterResult.EXCLUDED + + def __bool__(self) -> bool: + return self.is_included + + def __and__(self, other: "FilterResult") -> "FilterResult": + if self.is_excluded or other.is_excluded: + return FilterResult.EXCLUDED + return self + + +@attr.s(slots=True) +class BaseFilter: + func: Callable[..., bool] = attr.ib() + scope: Optional[str] = attr.ib(default=DEFAULT_SCOPE) + + def apply(self, item: Any) -> FilterResult: + raise NotImplementedError + + +@attr.s(slots=True) +class Include(BaseFilter): + def apply(self, item: Any) -> FilterResult: + if self.func(item): + return FilterResult.INCLUDED + return FilterResult.EXCLUDED + + +@attr.s(slots=True) +class Exclude(BaseFilter): + def apply(self, item: Any) -> FilterResult: + if self.func(item): + return FilterResult.EXCLUDED + return FilterResult.INCLUDED + + +def evaluate_filters(filters: List[BaseFilter], item: Any, scope: Optional[str] = DEFAULT_SCOPE) -> FilterResult: + """Decide whether the given item passes the filters.""" + # Lazily apply filters that match the given scope + matching_filters = filter(lambda f: f.scope == scope, filters) + outcomes = map(lambda f: f.apply(item), matching_filters) + # If any filter will exclude the item, then the process short-circuits without evaluating all filters + if all(outcomes): + return FilterResult.INCLUDED + return FilterResult.EXCLUDED diff --git a/src/schemathesis/hooks.py b/src/schemathesis/hooks.py index 2325840b9e7..211a30a38c9 100644 --- a/src/schemathesis/hooks.py +++ b/src/schemathesis/hooks.py @@ -36,7 +36,7 @@ class HookContext: operation: Optional["APIOperation"] = attr.ib(default=None) # pragma: no mutate - @deprecated_property(removed_in="4.0", replacement="operation") + @deprecated_property(removed_in="4.0", replacement="`operation`") def endpoint(self) -> Optional["APIOperation"]: return self.operation diff --git a/src/schemathesis/lazy.py b/src/schemathesis/lazy.py index aa4732df113..31b824e07b1 100644 --- a/src/schemathesis/lazy.py +++ b/src/schemathesis/lazy.py @@ -22,6 +22,7 @@ is_given_applied, merge_given_args, validate_given_args, + warn_filtration_arguments, ) @@ -51,6 +52,11 @@ def parametrize( data_generation_methods: Union[Iterable[DataGenerationMethod], NotSet] = NOT_SET, code_sample_style: Union[str, NotSet] = NOT_SET, ) -> Callable: + # pylint: disable=too-many-statements + for name in ("method", "endpoint", "tag", "operation_id", "skip_deprecated_operations"): + value = locals()[name] + if value is not NOT_SET: + warn_filtration_arguments(name) if method is NOT_SET: method = self.method if endpoint is NOT_SET: diff --git a/src/schemathesis/models.py b/src/schemathesis/models.py index 28e92500883..7ca16b3993b 100644 --- a/src/schemathesis/models.py +++ b/src/schemathesis/models.py @@ -118,7 +118,7 @@ def __repr__(self) -> str: parts.extend((name, "=", repr(value))) return "".join(parts) + ")" - @deprecated_property(removed_in="4.0", replacement="operation") + @deprecated_property(removed_in="4.0", replacement="`operation`") def endpoint(self) -> "APIOperation": return self.operation diff --git a/src/schemathesis/runner/__init__.py b/src/schemathesis/runner/__init__.py index ed77a59c74c..d4616247b2a 100644 --- a/src/schemathesis/runner/__init__.py +++ b/src/schemathesis/runner/__init__.py @@ -31,7 +31,7 @@ ) -@deprecated(removed_in="4.0", replacement="schemathesis.runner.from_schema") +@deprecated(removed_in="4.0", replacement="`schemathesis.runner.from_schema`") def prepare( schema_uri: Union[str, Dict[str, Any]], *, diff --git a/src/schemathesis/schemas.py b/src/schemathesis/schemas.py index c7aea9e86c6..2e310436fc6 100644 --- a/src/schemathesis/schemas.py +++ b/src/schemathesis/schemas.py @@ -23,6 +23,7 @@ Type, TypeVar, Union, + cast, ) from urllib.parse import quote, unquote, urljoin, urlsplit, urlunsplit @@ -34,11 +35,22 @@ from ._hypothesis import create_test from .constants import DEFAULT_DATA_GENERATION_METHODS, CodeSampleStyle, DataGenerationMethod from .exceptions import InvalidSchema, UsageError +from .filters import BaseFilter, Exclude, Include from .hooks import HookContext, HookDispatcher, HookScope, dispatch from .models import APIOperation, Case from .stateful import APIStateMachine, Stateful, StatefulTest from .types import Body, Cookies, Filter, FormData, GenericTest, Headers, NotSet, PathParameters, Query -from .utils import NOT_SET, PARAMETRIZE_MARKER, Err, GenericResponse, GivenInput, Ok, Result, given_proxy +from .utils import ( + NOT_SET, + PARAMETRIZE_MARKER, + Err, + GenericResponse, + GivenInput, + Ok, + Result, + given_proxy, + warn_filtration_arguments, +) class MethodsDict(CaseInsensitiveDict): @@ -57,6 +69,7 @@ def __getitem__(self, item: Any) -> Any: C = TypeVar("C", bound=Case) +S = TypeVar("S", bound="BaseSchema") @attr.s(eq=False) # pragma: no mutate @@ -77,6 +90,7 @@ class BaseSchema(Mapping): default=DEFAULT_DATA_GENERATION_METHODS ) # pragma: no mutate code_sample_style: CodeSampleStyle = attr.ib(default=CodeSampleStyle.default()) # pragma: no mutate + filters: List[BaseFilter] = attr.ib(factory=list) def __iter__(self) -> Iterator[str]: return iter(self.operations) @@ -203,6 +217,10 @@ def parametrize( _code_sample_style = ( CodeSampleStyle.from_str(code_sample_style) if isinstance(code_sample_style, str) else code_sample_style ) + for name in ("method", "endpoint", "tag", "operation_id", "skip_deprecated_operations"): + value = locals()[name] + if value is not NOT_SET: + warn_filtration_arguments(name) def wrapper(func: GenericTest) -> GenericTest: if hasattr(func, PARAMETRIZE_MARKER): @@ -216,7 +234,7 @@ def wrapped_test(*_: Any, **__: Any) -> NoReturn: return wrapped_test HookDispatcher.add_dispatcher(func) - cloned = self.clone( + cloned: BaseSchema = self.clone( test_function=func, method=method, endpoint=endpoint, @@ -251,7 +269,8 @@ def clone( skip_deprecated_operations: Union[bool, NotSet] = NOT_SET, data_generation_methods: Union[Iterable[DataGenerationMethod], NotSet] = NOT_SET, code_sample_style: Union[CodeSampleStyle, NotSet] = NOT_SET, - ) -> "BaseSchema": + filters: Union[List[BaseFilter], NotSet] = NOT_SET, + ) -> S: if base_url is NOT_SET: base_url = self.base_url if method is NOT_SET: @@ -274,8 +293,13 @@ def clone( data_generation_methods = self.data_generation_methods if code_sample_style is NOT_SET: code_sample_style = self.code_sample_style + new_filters = self._construct_filters( + endpoint, method, tag, operation_id, cast(bool, skip_deprecated_operations), Include + ) + if filters is not NOT_SET: + new_filters += cast(List[BaseFilter], filters) - return self.__class__( + return self.__class__( # type: ignore self.raw_schema, location=self.location, base_url=base_url, # type: ignore @@ -290,8 +314,20 @@ def clone( skip_deprecated_operations=skip_deprecated_operations, # type: ignore data_generation_methods=data_generation_methods, # type: ignore code_sample_style=code_sample_style, # type: ignore + filters=new_filters, ) + def _construct_filters( + self, + path: Optional[Filter], + method: Optional[Filter], + tag: Optional[Filter], + operation_id: Optional[Filter], + skip_deprecated_operations: Optional[bool], + cls: Type[BaseFilter], + ) -> List[BaseFilter]: + return [] + def get_local_hook_dispatcher(self) -> Optional[HookDispatcher]: """Get a HookDispatcher instance bound to the test if present.""" # It might be not present when it is used without pytest via `APIOperation.as_strategy()` @@ -358,6 +394,18 @@ def validate_response(self, operation: APIOperation, response: GenericResponse) def prepare_schema(self, schema: Any) -> Any: raise NotImplementedError + def include_by(self, predicate: Callable) -> "BaseSchema": + """Get a new schema that includes API operations that pass the given predicate.""" + return self._filter_by(Include(predicate)) + + def exclude_by(self, predicate: Callable) -> "BaseSchema": + """Get a new schema that excludes API operations that pass the given predicate.""" + return self._filter_by(Exclude(predicate)) + + def _filter_by(self, *predicates: BaseFilter) -> S: + filters = self.filters + list(predicates) + return self.clone(filters=filters) + def operations_to_dict( operations: Generator[Result[APIOperation, InvalidSchema], None, None] diff --git a/src/schemathesis/specs/openapi/loaders.py b/src/schemathesis/specs/openapi/loaders.py index 336b64da06d..a970200fa30 100644 --- a/src/schemathesis/specs/openapi/loaders.py +++ b/src/schemathesis/specs/openapi/loaders.py @@ -16,7 +16,14 @@ from ...hooks import HookContext, dispatch from ...lazy import LazySchema from ...types import Filter, NotSet, PathLike -from ...utils import NOT_SET, StringDatesYAMLLoader, WSGIResponse, require_relative_url, setup_headers +from ...utils import ( + NOT_SET, + StringDatesYAMLLoader, + WSGIResponse, + require_relative_url, + setup_headers, + warn_filtration_arguments, +) from . import definitions from .schemas import BaseOpenAPISchema, OpenApi30, SwaggerV20 @@ -179,12 +186,18 @@ def from_dict( :param dict raw_schema: A schema to load. """ + for name in ("method", "endpoint", "tag", "operation_id"): + value = locals()[name] + if value is not None: + warn_filtration_arguments(name) + if skip_deprecated_operations is True: + warn_filtration_arguments("skip_deprecated_operations") _code_sample_style = CodeSampleStyle.from_str(code_sample_style) dispatch("before_load_schema", HookContext(), raw_schema) def init_openapi_2() -> SwaggerV20: _maybe_validate_schema(raw_schema, definitions.SWAGGER_20_VALIDATOR, validate_schema) - return SwaggerV20( + schema = SwaggerV20( raw_schema, app=app, base_url=base_url, @@ -198,10 +211,11 @@ def init_openapi_2() -> SwaggerV20: code_sample_style=_code_sample_style, location=location, ) + return schema.include(endpoint, method) def init_openapi_3() -> OpenApi30: _maybe_validate_schema(raw_schema, definitions.OPENAPI_30_VALIDATOR, validate_schema) - return OpenApi30( + schema = OpenApi30( raw_schema, app=app, base_url=base_url, @@ -215,6 +229,7 @@ def init_openapi_3() -> OpenApi30: code_sample_style=_code_sample_style, location=location, ) + return schema.include(endpoint, method) if force_schema_version == "20": return init_openapi_2() diff --git a/src/schemathesis/specs/openapi/schemas.py b/src/schemathesis/specs/openapi/schemas.py index a22a4525044..724c051b0fe 100644 --- a/src/schemathesis/specs/openapi/schemas.py +++ b/src/schemathesis/specs/openapi/schemas.py @@ -21,6 +21,7 @@ Type, TypeVar, Union, + cast, ) from urllib.parse import urlsplit @@ -38,11 +39,12 @@ get_response_parsing_error, get_schema_validation_error, ) +from ...filters import BaseFilter, Exclude, Include, evaluate_filters from ...hooks import HookContext, HookDispatcher from ...models import APIOperation, Case, OperationDefinition from ...schemas import BaseSchema from ...stateful import APIStateMachine, Stateful, StatefulTest -from ...types import Body, Cookies, FormData, Headers, NotSet, PathParameters, Query +from ...types import Body, Cookies, Filter, FormData, Headers, NotSet, PathParameters, Query from ...utils import ( NOT_SET, Err, @@ -78,6 +80,7 @@ SCHEMA_ERROR_MESSAGE = "Schema parsing failed. Please check your schema." SCHEMA_PARSING_ERRORS = (KeyError, AttributeError, jsonschema.exceptions.RefResolutionError) +S = TypeVar("S", bound="BaseOpenAPISchema") @attr.s(eq=False, repr=False) @@ -112,6 +115,54 @@ def __repr__(self) -> str: info = self.raw_schema["info"] return f"{self.__class__.__name__} for {info['title']} ({info['version']})" + def include( + self, + path: Optional[Filter] = None, + method: Optional[Filter] = None, + tag: Optional[Filter] = None, + operation_id: Optional[Filter] = None, + ) -> S: + predicates = self._construct_filters(path, method, tag, operation_id, None, Include) + return self._filter_by(*predicates) + + def exclude( + self, + path: Optional[Filter] = None, + method: Optional[Filter] = None, + tag: Optional[Filter] = None, + operation_id: Optional[Filter] = None, + ) -> S: + predicates = self._construct_filters(path, method, tag, operation_id, None, Exclude) + return self._filter_by(*predicates) + + def _construct_filters( + self, + path: Optional[Filter], + method: Optional[Filter], + tag: Optional[Filter], + operation_id: Optional[Filter], + skip_deprecated_operations: Optional[bool], + cls: Type[BaseFilter], + ) -> List[BaseFilter]: + predicates: List[BaseFilter] = [] + if path is not None: + predicates.append(cls(lambda i: not should_skip_endpoint(i[0], path), "path")) + if method is not None: + predicates.append(cls(lambda i: not should_skip_method(i[1], method))) + if tag is not None: + predicates.append(cls(lambda i: not should_skip_by_tag(i[2].get("tags"), tag))) + if operation_id is not None: + predicates.append(cls(lambda i: not should_skip_by_operation_id(i[2].get("operationId"), operation_id))) + if skip_deprecated_operations is not None: + predicates.append( + cls( + lambda i: not should_skip_deprecated( + i[2].get("deprecated", False), cast(bool, skip_deprecated_operations) + ) + ) + ) + return predicates + def get_all_operations(self) -> Generator[Result[APIOperation, InvalidSchema], None, None]: """Iterate over all operations defined in the API. @@ -139,7 +190,7 @@ def get_all_operations(self) -> Generator[Result[APIOperation, InvalidSchema], N method = None try: full_path = self.get_full_path(path) # Should be available for later use - if should_skip_endpoint(full_path, self.endpoint): + if evaluate_filters(self.filters, (full_path, None, None), "path").is_excluded: continue self.dispatch_hook("before_process_path", context, path, methods) scope, raw_methods = self._resolve_methods(methods) @@ -150,15 +201,9 @@ def get_all_operations(self) -> Generator[Result[APIOperation, InvalidSchema], N # too much but decreases the number of cases when Schemathesis stuck on this step. with self.resolver.in_scope(scope): resolved_definition = self.resolver.resolve_all(definition, RECURSION_DEPTH_LIMIT - 5) - # Only method definitions are parsed if ( method not in self.allowed_http_methods - or should_skip_method(method, self.method) - or should_skip_deprecated( - resolved_definition.get("deprecated", False), self.skip_deprecated_operations - ) - or should_skip_by_tag(resolved_definition.get("tags"), self.tag) - or should_skip_by_operation_id(resolved_definition.get("operationId"), self.operation_id) + or evaluate_filters(self.filters, (full_path, method, resolved_definition)).is_excluded ): continue parameters = self.collect_parameters( diff --git a/src/schemathesis/utils.py b/src/schemathesis/utils.py index b6a0237062d..bf491a2da10 100644 --- a/src/schemathesis/utils.py +++ b/src/schemathesis/utils.py @@ -284,10 +284,13 @@ def traverse_schema(schema: Schema, callback: Callable[..., Dict[str, Any]], *ar return schema -def _warn_deprecation(*, thing: str, removed_in: str, replacement: str) -> None: +def warn_filtration_arguments(name: str) -> None: + warn_deprecation(thing=f"Argument `{name}`", removed_in="4.0", replacement="`include` or `exclude` methods") + + +def warn_deprecation(*, thing: str, removed_in: str, replacement: str) -> None: warnings.warn( - f"Property `{thing}` is deprecated and will be removed in Schemathesis {removed_in}. " - f"Use `{replacement}` instead.", + f"{thing} is deprecated and will be removed in Schemathesis {removed_in}. " f"Use {replacement} instead.", DeprecationWarning, ) @@ -296,7 +299,7 @@ def deprecated_property(*, removed_in: str, replacement: str) -> Callable: def wrapper(prop: Callable) -> Callable: @property # type: ignore def inner(self: Any) -> Any: - _warn_deprecation(thing=prop.__name__, removed_in=removed_in, replacement=replacement) + warn_deprecation(thing=f"Property `{prop.__name__}`", removed_in=removed_in, replacement=replacement) return prop(self) return inner @@ -307,7 +310,7 @@ def inner(self: Any) -> Any: def deprecated(*, removed_in: str, replacement: str) -> Callable: def wrapper(func: Callable) -> Callable: def inner(*args: Any, **kwargs: Any) -> Any: - _warn_deprecation(thing=func.__name__, removed_in=removed_in, replacement=replacement) + warn_deprecation(thing=f"Function `{func.__name__}`", removed_in=removed_in, replacement=replacement) return func(*args, **kwargs) return inner diff --git a/test/test_filters.py b/test/test_filters.py index 87612c4947a..eec62b18288 100644 --- a/test/test_filters.py +++ b/test/test_filters.py @@ -1,5 +1,7 @@ import pytest +from schemathesis.filters import Exclude, FilterResult, Include, evaluate_filters + from .utils import integer @@ -190,3 +192,31 @@ def test_(request, case): result.stdout.re_match_lines([r"test_operation_id_list_filter.py::test_[GET /v1/foo][P] PASSED"]) result.stdout.re_match_lines([r"test_operation_id_list_filter.py::test_[POST /v1/foo][P] PASSED"]) + + +def predicate(x): + return "foo" in x + + +@pytest.mark.parametrize( + "filters, expected", + ( + ([], FilterResult.INCLUDED), + ([Include(predicate)], FilterResult.INCLUDED), + ([Include(predicate), Include(predicate)], FilterResult.INCLUDED), + ([Exclude(predicate)], FilterResult.EXCLUDED), + ([Exclude(predicate), Include(predicate)], FilterResult.EXCLUDED), + ([Include(predicate), Exclude(predicate)], FilterResult.EXCLUDED), + ), +) +def test_evaluate_filters(filters, expected): + assert evaluate_filters(filters, {"foo": 42}) == expected + + +def test_evaluate_filters_scoped(): + # When filters are evaluated in some particular scope + # Then filters not matching the scope should be ignored + assert ( + evaluate_filters([Exclude(predicate), Include(lambda x: x.startswith("f"), scope="path")], "foo", "path") + == FilterResult.INCLUDED + )