Skip to content

Commit

Permalink
perf: Optimize direct access to Open API schemas
Browse files Browse the repository at this point in the history
Signed-off-by: Dmitry Dygalo <dmitry@dygalo.dev>
  • Loading branch information
Stranger6667 committed May 16, 2024
1 parent 3400a10 commit db8f7ba
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 147 deletions.
2 changes: 0 additions & 2 deletions benches/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,10 @@ def test_length(raw_schema, loader):
BBCI_OPERATION_KEY = ("/categories", "get")
BBCI_SCHEMA_WITH_OPERATIONS_CACHE = schemathesis.from_dict(BBCI)
BBCI_SCHEMA_WITH_OPERATIONS_CACHE.get_operation_by_id(BBCI_OPERATION_ID)
_ = BBCI_SCHEMA_WITH_OPERATIONS_CACHE.operations
VMWARE_OPERATION_ID = "listProblemEvents"
VMWARE_OPERATION_KEY = ("/entities/problems", "get")
VMWARE_SCHEMA_WITH_OPERATIONS_CACHE = schemathesis.from_dict(VMWARE)
VMWARE_SCHEMA_WITH_OPERATIONS_CACHE.get_operation_by_id(VMWARE_OPERATION_ID)
_ = VMWARE_SCHEMA_WITH_OPERATIONS_CACHE.operations
UNIVERSE_OPERATION_KEY = ("Query", "manageTickets")
UNIVERSE_SCHEMA_WITH_OPERATIONS_CACHE = schemathesis.graphql.from_dict(UNIVERSE)
UNIVERSE_SCHEMA_WITH_OPERATIONS_CACHE[UNIVERSE_OPERATION_KEY[0]][UNIVERSE_OPERATION_KEY[1]]
Expand Down
32 changes: 7 additions & 25 deletions src/schemathesis/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from __future__ import annotations

from collections.abc import Mapping, MutableMapping
from collections.abc import Mapping
from contextlib import nullcontext
from dataclasses import dataclass, field
from functools import lru_cache
Expand Down Expand Up @@ -99,7 +99,7 @@ class BaseSchema(Mapping):
sanitize_output: bool = True

def __iter__(self) -> Iterator[str]:
return iter(self.operations)
raise NotImplementedError

def __getitem__(self, item: str) -> APIOperationMap:
__tracebackhide__ = True
Expand Down Expand Up @@ -158,18 +158,6 @@ def get_base_url(self) -> str:
def validate(self) -> None:
raise NotImplementedError

@property
def operations(self) -> dict[str, APIOperationMap]:
if not hasattr(self, "_operations"):
operations = self.get_all_operations()
self._operations = self._store_operations(operations)
return self._operations

def _store_operations(
self, operations: Generator[Result[APIOperation, OperationSchemaError], None, None]
) -> dict[str, APIOperationMap]:
raise NotImplementedError

@property
def operations_count(self) -> int:
raise NotImplementedError
Expand Down Expand Up @@ -451,33 +439,27 @@ def as_strategy(
"""Build a strategy for generating test cases for all defined API operations."""
assert len(self) > 0, "No API operations found"
strategies = [
operation.as_strategy(
operation.ok().as_strategy(
hooks=hooks,
auth_storage=auth_storage,
data_generation_method=data_generation_method,
generation_config=generation_config,
**kwargs,
)
for operations in self.operations.values()
for operation in operations.values()
for operation in self.get_all_operations(hooks=hooks)
if isinstance(operation, Ok)
]
return combine_strategies(strategies)


@dataclass
class APIOperationMap(MutableMapping):
class APIOperationMap(Mapping):
_schema: BaseSchema
_data: MutableMapping

def __setitem__(self, key: str, value: APIOperation) -> None:
self._data[key] = value
_data: Mapping

def __getitem__(self, item: str) -> APIOperation:
return self._data[item]

def __delitem__(self, key: str) -> None:
del self._data[key]

def __len__(self) -> int:
return len(self._data)

Expand Down
20 changes: 0 additions & 20 deletions src/schemathesis/specs/graphql/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,26 +143,6 @@ def on_missing_operation(self, item: str, exc: KeyError) -> NoReturn:
message += f". Did you mean `{matches[0]}`?"
raise OperationNotFound(message=message, item=item) from exc

def _store_operations(
self, operations: Generator[Result[APIOperation, OperationSchemaError], None, None]
) -> dict[str, APIOperationMap]:
output: dict[str, APIOperationMap] = {}
for result in operations:
if isinstance(result, Ok):
operation = result.ok()
definition = cast(GraphQLOperationDefinition, operation.definition)
type_name = (
definition.type_.name if isinstance(definition.type_, graphql.GraphQLNamedType) else "Unknown"
)
if type_name not in output:
map = APIOperationMap(self, {})
map._data = FieldMap(map, definition.root_type, definition.type_)
output[type_name] = map
else:
map = output[type_name]
map[definition.field_name] = operation
return output

def get_full_path(self, path: str) -> str:
return self.base_path

Expand Down
9 changes: 9 additions & 0 deletions src/schemathesis/specs/openapi/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

if TYPE_CHECKING:
from ...models import APIOperation
from ...schemas import APIOperationMap


@dataclass
Expand Down Expand Up @@ -43,6 +44,8 @@ class OperationCache:
_reference_to_operation: dict[Reference, int] = field(default_factory=dict)
# The actual operations
_operations: list[APIOperation] = field(default_factory=list)
# Cache for operation maps
_maps: dict[str, APIOperationMap] = field(default_factory=dict)

@property
def known_operation_ids(self) -> list[str]:
Expand Down Expand Up @@ -108,3 +111,9 @@ def get_operation_by_traversal_key(self, scope: str, path: str, method: str) ->
if idx is not None:
return self._operations[idx]
return None

# Look up a previously stored operation map by its key.
# Returns `None` when nothing has been stored under `key`.
def get_map(self, key: str) -> APIOperationMap | None:
return self._maps.get(key)

# Store an operation map under `key`, replacing any map already stored under the same key.
def insert_map(self, key: str, value: APIOperationMap) -> None:
self._maps[key] = value
163 changes: 92 additions & 71 deletions src/schemathesis/specs/openapi/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
ClassVar,
Generator,
Iterable,
Iterator,
Mapping,
NoReturn,
Sequence,
TypeVar,
cast,
)
from urllib.parse import urlsplit

Expand All @@ -28,33 +31,34 @@

from ... import experimental, failures
from ..._compat import MultipleFailures
from ..._override import CaseOverride, set_override_mark, check_no_override_mark
from ..._override import CaseOverride, check_no_override_mark, set_override_mark
from ...auths import AuthStorage
from ...generation import DataGenerationMethod, GenerationConfig
from ...constants import HTTP_METHODS, NOT_SET
from ...exceptions import (
InternalError,
OperationNotFound,
OperationSchemaError,
SchemaError,
SchemaErrorType,
UsageError,
get_missing_content_type_error,
get_response_parsing_error,
get_schema_validation_error,
SchemaError,
SchemaErrorType,
OperationNotFound,
)
from ...generation import DataGenerationMethod, GenerationConfig
from ...hooks import GLOBAL_HOOK_DISPATCHER, HookContext, HookDispatcher, should_skip_operation
from ...internal.copy import fast_deepcopy
from ...internal.jsonschema import traverse_schema
from ...internal.result import Err, Ok, Result
from ...models import APIOperation, Case, OperationDefinition
from ...schemas import BaseSchema, APIOperationMap
from ...schemas import APIOperationMap, BaseSchema
from ...stateful import Stateful, StatefulTest
from ...stateful.state_machine import APIStateMachine
from ...transports.content_types import is_json_media_type, parse_content_type
from ...transports.responses import get_json
from ...types import Body, Cookies, FormData, Headers, NotSet, PathParameters, Query, GenericTest
from ...types import Body, Cookies, FormData, GenericTest, Headers, NotSet, PathParameters, Query
from . import links, serialization
from ._cache import OperationCache
from ._hypothesis import get_case_strategy
from .converter import to_json_schema, to_json_schema_recursive
from .definitions import OPENAPI_30_VALIDATOR, OPENAPI_31_VALIDATOR, SWAGGER_20_VALIDATOR
Expand All @@ -74,8 +78,7 @@
OpenAPI30Parameter,
OpenAPIParameter,
)
from ._cache import OperationCache
from .references import RECURSION_DEPTH_LIMIT, ConvertingResolver, InliningResolver, resolve_pointer, UNRESOLVABLE
from .references import RECURSION_DEPTH_LIMIT, UNRESOLVABLE, ConvertingResolver, InliningResolver, resolve_pointer
from .security import BaseSecurityProcessor, OpenAPISecurityProcessor, SwaggerSecurityProcessor
from .stateful import create_state_machine

Expand Down Expand Up @@ -115,16 +118,24 @@ def __repr__(self) -> str:
info = self.raw_schema["info"]
return f"<{self.__class__.__name__} for {info['title']} {info['version']}>"

def _get_operation_map(self, key: str) -> APIOperationMap:
return self.operations[key]
# Iterate over the keys of the `paths` object in the raw schema.
# Yields nothing when the schema has no `paths` entry.
def __iter__(self) -> Iterator[str]:
return iter(self.raw_schema.get("paths", {}))

def _store_operations(
self, operations: Generator[Result[APIOperation, OperationSchemaError], None, None]
) -> dict[str, APIOperationMap]:
return operations_to_dict(operations)
def _get_operation_map(self, path: str) -> APIOperationMap:
    """Return the operation map for ``path``, building and caching it on first access.

    Raises ``KeyError`` when ``path`` is not a key of ``paths`` in the raw schema.
    """
    operation_cache = self._operation_cache
    operation_map = operation_cache.get_map(path)
    if operation_map is None:
        raw_path_item = self.raw_schema.get("paths", {})[path]
        scope, path_item = self._resolve_path_item(raw_path_item)
        self.dispatch_hook("before_process_path", HookContext(), path, path_item)
        # The method map keeps a back-reference to its parent map, so the parent is
        # created with placeholder storage first and receives the real storage after.
        operation_map = APIOperationMap(self, {})
        operation_map._data = MethodMap(operation_map, scope, path, CaseInsensitiveDict(path_item))
        operation_cache.insert_map(path, operation_map)
    return operation_map

def on_missing_operation(self, item: str, exc: KeyError) -> NoReturn:
matches = get_close_matches(item, list(self.operations))
matches = get_close_matches(item, list(self))
self._on_missing_operation(item, exc, matches)

def _on_missing_operation(self, item: str, exc: KeyError, matches: list[str]) -> NoReturn:
Expand Down Expand Up @@ -560,46 +571,26 @@ def add_link(
"""
if parameters is None and request_body is None:
raise ValueError("You need to provide `parameters` or `request_body`.")
if hasattr(self, "_operations"):
delattr(self, "_operations")
for operation, methods in self.raw_schema["paths"].items():
if operation == source.path:
# Methods should be completely resolved now, otherwise they might miss a resolving scope when
# they will be fully resolved later
methods = self.resolver.resolve_all(methods, RECURSION_DEPTH_LIMIT - 8)
found = False
for method, definition in methods.items():
if method.upper() == source.method.upper():
found = True
links.add_link(
responses=definition["responses"],
links_field=self.links_field,
parameters=parameters,
request_body=request_body,
status_code=status_code,
target=target,
name=name,
)
# If methods are behind a reference, then on the next resolving they will miss the new link
# Therefore we need to modify it this way
self.raw_schema["paths"][operation][method] = definition
# The reference should be removed completely, otherwise new keys in this dictionary will be ignored
# due to the `$ref` keyword behavior
self.raw_schema["paths"][operation].pop("$ref", None)
if found:
return
name = f"{source.method.upper()} {source.path}"
# Use a name without basePath, as the user doesn't use it.
# E.g. `source=schema["/users/"]["POST"]` without a prefix
message = f"No such API operation: `{name}`."
possibilities = [
f"{op.ok().method.upper()} {op.ok().path}" for op in self.get_all_operations() if isinstance(op, Ok)
]
matches = get_close_matches(name, possibilities)
if matches:
message += f" Did you mean `{matches[0]}`?"
message += " Check if the requested API operation passes the filters in the schema."
raise ValueError(message)
# TODO: Avoid adding the link twice (once to the resolved definition and once to the raw one)
definition = self[source.path][source.method].definition
links.add_link(
responses=definition.resolved["responses"],
links_field=self.links_field,
parameters=parameters,
request_body=request_body,
status_code=status_code,
target=target,
name=name,
)
links.add_link(
responses=definition.raw["responses"],
links_field=self.links_field,
parameters=parameters,
request_body=request_body,
status_code=status_code,
target=target,
name=name,
)

def get_links(self, operation: APIOperation) -> dict[str, dict[str, Any]]:
result: dict[str, dict[str, Any]] = defaultdict(dict)
Expand Down Expand Up @@ -808,30 +799,60 @@ def in_scopes(resolver: jsonschema.RefResolver, scopes: list[str]) -> Generator[
yield


def operations_to_dict(
operations: Generator[Result[APIOperation, OperationSchemaError], None, None],
) -> dict[str, APIOperationMap]:
output: dict[str, APIOperationMap] = {}
for result in operations:
if isinstance(result, Ok):
operation = result.ok()
output.setdefault(operation.path, APIOperationMap(operation.schema, MethodMap()))
output[operation.path][operation.method] = operation
return output


class MethodMap(CaseInsensitiveDict):
@dataclass
class MethodMap(Mapping):
"""Container for accessing API operations.
Provides a more specific error message if API operation is not found.
"""

_parent: APIOperationMap
# Reference resolution scope
_scope: str
# Path whose methods are stored in this map
_path: str
# Storage for definitions
_path_item: CaseInsensitiveDict

# Number of keys in the underlying path item.
def __len__(self) -> int:
return len(self._path_item)

# Iterate over the keys of the underlying path item.
def __iter__(self) -> Iterator[str]:
return iter(self._path_item)

# Build the `APIOperation` for `method` on this path, reusing the schema's operation cache when possible.
# Raises `KeyError` when the path item has no entry under the lower-cased `method`.
def _init_operation(self, method: str) -> APIOperation:
method = method.lower()
# TODO: Prevent accessing something that is not a method
# NOTE(review): any key present in the path item passes this lookup, e.g. `parameters`,
# which is read from the same mapping further down.
operation = self._path_item[method]
schema = cast(BaseOpenAPISchema, self._parent._schema)
cache = schema._operation_cache
path = self._path
scope = self._scope
# Return the operation already cached for this (scope, path, method), if any
instance = cache.get_operation_by_traversal_key(scope, path, method)
if instance is not None:
return instance
# Parameters declared on the path item itself; they are combined with the operation's own parameters below
shared_parameters = self._path_item.get("parameters", [])
shared_parameters = schema.resolver.resolve_all(shared_parameters, RECURSION_DEPTH_LIMIT - 8)
# Resolve the operation while this path item's scope is pushed onto the resolver;
# `finally` pops the scope even if resolving raises
schema.resolver.push_scope(scope)
try:
resolved = schema.resolver.resolve_all(operation, RECURSION_DEPTH_LIMIT - 8)
finally:
schema.resolver.pop_scope()
# Operation-level parameters come first, followed by the shared ones
raw_parameters = itertools.chain(resolved.get("parameters", ()), shared_parameters)
parameters = schema.collect_parameters(raw_parameters, resolved)
definition = OperationDefinition(operation, resolved, scope)
initialized = schema.make_operation(path, method, parameters, definition)
# Store the result so the cache lookup above succeeds on the next access
cache.insert_operation_by_traversal_key(scope, path, method, initialized)
return initialized

def __getitem__(self, item: str) -> APIOperation:
try:
return super().__getitem__(item)
return self._init_operation(item)
except KeyError as exc:
available_methods = ", ".join(map(str.upper, self))
message = f"Method `{item}` not found. Available methods: {available_methods}"
message = f"Method `{item.upper()}` not found."
if available_methods:
message += f" Available methods: {available_methods}"
raise KeyError(message) from exc


Expand Down
2 changes: 0 additions & 2 deletions test/specs/graphql/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,6 @@ def test_unknown_field_name(graphql_schema, name, expected):
def test_field_map_operations(graphql_schema):
assert len(graphql_schema["Query"]) == 2
assert list(iter(graphql_schema["Query"])) == ["getBooks", "getAuthors"]
del graphql_schema["Query"]["getBooks"]
assert len(graphql_schema["Query"]) == 1


def test_repr(graphql_schema):
Expand Down

0 comments on commit db8f7ba

Please sign in to comment.