Skip to content

Commit

Permalink
chore: GraphQL schema loaders now accept single `DataGenerationMethod…
Browse files Browse the repository at this point in the history
…` instances for the `data_generation_methods` argument
  • Loading branch information
Stranger6667 committed Oct 17, 2021
1 parent f485bcf commit d31b017
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 29 deletions.
22 changes: 11 additions & 11 deletions src/schemathesis/specs/graphql/loaders.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pathlib
from typing import IO, Any, Callable, Dict, Iterable, Optional, Union, cast
from typing import IO, Any, Callable, Dict, Optional, Union, cast

import graphql
import requests
Expand All @@ -9,11 +9,11 @@
from werkzeug import Client
from yarl import URL

from ...constants import DEFAULT_DATA_GENERATION_METHODS, CodeSampleStyle, DataGenerationMethod
from ...constants import DEFAULT_DATA_GENERATION_METHODS, CodeSampleStyle
from ...exceptions import HTTPError
from ...hooks import HookContext, dispatch
from ...types import PathLike
from ...utils import WSGIResponse, require_relative_url, setup_headers
from ...types import DataGenerationMethodInput, PathLike
from ...utils import WSGIResponse, prepare_data_generation_methods, require_relative_url, setup_headers
from .schemas import GraphQLSchema

INTROSPECTION_QUERY = graphql.get_introspection_query()
Expand All @@ -25,7 +25,7 @@ def from_path(
*,
app: Any = None,
base_url: Optional[str] = None,
data_generation_methods: Iterable[DataGenerationMethod] = DEFAULT_DATA_GENERATION_METHODS,
data_generation_methods: DataGenerationMethodInput = DEFAULT_DATA_GENERATION_METHODS,
code_sample_style: str = CodeSampleStyle.default().name,
encoding: str = "utf8",
) -> GraphQLSchema:
Expand All @@ -51,7 +51,7 @@ def from_url(
app: Any = None,
base_url: Optional[str] = None,
port: Optional[int] = None,
data_generation_methods: Iterable[DataGenerationMethod] = DEFAULT_DATA_GENERATION_METHODS,
data_generation_methods: DataGenerationMethodInput = DEFAULT_DATA_GENERATION_METHODS,
code_sample_style: str = CodeSampleStyle.default().name,
**kwargs: Any,
) -> GraphQLSchema:
Expand Down Expand Up @@ -86,7 +86,7 @@ def from_file(
*,
app: Any = None,
base_url: Optional[str] = None,
data_generation_methods: Iterable[DataGenerationMethod] = DEFAULT_DATA_GENERATION_METHODS,
data_generation_methods: DataGenerationMethodInput = DEFAULT_DATA_GENERATION_METHODS,
code_sample_style: str = CodeSampleStyle.default().name,
location: Optional[str] = None,
) -> GraphQLSchema:
Expand Down Expand Up @@ -123,7 +123,7 @@ def from_dict(
app: Any = None,
base_url: Optional[str] = None,
location: Optional[str] = None,
data_generation_methods: Iterable[DataGenerationMethod] = DEFAULT_DATA_GENERATION_METHODS,
data_generation_methods: DataGenerationMethodInput = DEFAULT_DATA_GENERATION_METHODS,
code_sample_style: str = CodeSampleStyle.default().name,
) -> GraphQLSchema:
"""Load GraphQL schema from a Python dictionary.
Expand All @@ -141,7 +141,7 @@ def from_dict(
location=location,
base_url=base_url,
app=app,
data_generation_methods=data_generation_methods,
data_generation_methods=prepare_data_generation_methods(data_generation_methods),
code_sample_style=_code_sample_style,
) # type: ignore

Expand All @@ -151,7 +151,7 @@ def from_wsgi(
app: Any,
*,
base_url: Optional[str] = None,
data_generation_methods: Iterable[DataGenerationMethod] = DEFAULT_DATA_GENERATION_METHODS,
data_generation_methods: DataGenerationMethodInput = DEFAULT_DATA_GENERATION_METHODS,
code_sample_style: str = CodeSampleStyle.default().name,
**kwargs: Any,
) -> GraphQLSchema:
Expand Down Expand Up @@ -183,7 +183,7 @@ def from_asgi(
app: Any,
*,
base_url: Optional[str] = None,
data_generation_methods: Iterable[DataGenerationMethod] = DEFAULT_DATA_GENERATION_METHODS,
data_generation_methods: DataGenerationMethodInput = DEFAULT_DATA_GENERATION_METHODS,
code_sample_style: str = CodeSampleStyle.default().name,
**kwargs: Any,
) -> GraphQLSchema:
Expand Down
29 changes: 14 additions & 15 deletions src/schemathesis/specs/openapi/loaders.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import io
import pathlib
from contextlib import suppress
from typing import IO, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import IO, Any, Callable, Dict, List, Optional, Tuple, Union
from urllib.parse import urljoin

import jsonschema
Expand All @@ -13,17 +13,22 @@
from werkzeug.test import Client
from yarl import URL

from ...constants import DEFAULT_DATA_GENERATION_METHODS, CodeSampleStyle, DataGenerationMethod
from ...constants import DEFAULT_DATA_GENERATION_METHODS, CodeSampleStyle
from ...exceptions import HTTPError, SchemaLoadingError
from ...hooks import HookContext, dispatch
from ...lazy import LazySchema
from ...types import Filter, NotSet, PathLike
from ...utils import NOT_SET, StringDatesYAMLLoader, WSGIResponse, require_relative_url, setup_headers
from ...types import DataGenerationMethodInput, Filter, NotSet, PathLike
from ...utils import (
NOT_SET,
StringDatesYAMLLoader,
WSGIResponse,
prepare_data_generation_methods,
require_relative_url,
setup_headers,
)
from . import definitions, validation
from .schemas import BaseOpenAPISchema, OpenApi30, SwaggerV20

DataGenerationMethodInput = Union[DataGenerationMethod, Iterable[DataGenerationMethod]]


def from_path(
path: PathLike,
Expand Down Expand Up @@ -196,7 +201,7 @@ def init_openapi_2() -> SwaggerV20:
operation_id=operation_id,
skip_deprecated_operations=skip_deprecated_operations,
validate_schema=validate_schema,
data_generation_methods=_prepare_data_generation_methods(data_generation_methods),
data_generation_methods=prepare_data_generation_methods(data_generation_methods),
code_sample_style=_code_sample_style,
location=location,
)
Expand All @@ -213,7 +218,7 @@ def init_openapi_3() -> OpenApi30:
operation_id=operation_id,
skip_deprecated_operations=skip_deprecated_operations,
validate_schema=validate_schema,
data_generation_methods=_prepare_data_generation_methods(data_generation_methods),
data_generation_methods=prepare_data_generation_methods(data_generation_methods),
code_sample_style=_code_sample_style,
location=location,
)
Expand All @@ -229,12 +234,6 @@ def init_openapi_3() -> OpenApi30:
raise SchemaLoadingError("Unsupported schema type")


def _prepare_data_generation_methods(data_generation_methods: DataGenerationMethodInput) -> List[DataGenerationMethod]:
if isinstance(data_generation_methods, DataGenerationMethod):
return [data_generation_methods]
return list(data_generation_methods)


# It is a common case when API schemas are stored in the YAML format and HTTP status codes are numbers
# The Open API spec requires HTTP status codes as strings
DOC_ENTRY = "https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.3.md#patterned-fields-1"
Expand Down Expand Up @@ -310,7 +309,7 @@ def from_pytest_fixture(
operation_id=operation_id,
skip_deprecated_operations=skip_deprecated_operations,
validate_schema=validate_schema,
data_generation_methods=_prepare_data_generation_methods(data_generation_methods),
data_generation_methods=prepare_data_generation_methods(data_generation_methods),
code_sample_style=_code_sample_style,
)

Expand Down
4 changes: 3 additions & 1 deletion src/schemathesis/types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Set, Tuple, Union

from hypothesis.strategies import SearchStrategy

if TYPE_CHECKING:
from . import DataGenerationMethod
from .hooks import HookContext

PathLike = Union[Path, str] # pragma: no mutate
Expand Down Expand Up @@ -31,3 +32,4 @@ class NotSet:
RawAuth = Tuple[str, str] # pragma: no mutate
# Generic test with any arguments and no return
GenericTest = Callable[..., None] # pragma: no mutate
DataGenerationMethodInput = Union["DataGenerationMethod", Iterable["DataGenerationMethod"]]
10 changes: 8 additions & 2 deletions src/schemathesis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@
from werkzeug.wrappers import Response as BaseResponse
from werkzeug.wrappers.json import JSONMixin

from .constants import USER_AGENT
from .constants import USER_AGENT, DataGenerationMethod
from .exceptions import UsageError
from .types import Filter, GenericTest, NotSet, RawAuth
from .types import DataGenerationMethodInput, Filter, GenericTest, NotSet, RawAuth

try:
from yaml import CSafeLoader as SafeLoader
Expand Down Expand Up @@ -423,3 +423,9 @@ def maybe_set_assertion_message(exc: AssertionError, check_name: str) -> str:
message = f"Check '{check_name}' failed"
exc.args = (message,)
return message


def prepare_data_generation_methods(data_generation_methods: DataGenerationMethodInput) -> List[DataGenerationMethod]:
if isinstance(data_generation_methods, DataGenerationMethod):
return [data_generation_methods]
return list(data_generation_methods)

0 comments on commit d31b017

Please sign in to comment.