Skip to content

Commit

Permalink
perf: Optimize get_by_operation_id
Browse files Browse the repository at this point in the history
Signed-off-by: Dmitry Dygalo <dmitry@dygalo.dev>
  • Loading branch information
Stranger6667 committed May 15, 2024
1 parent 0a075c6 commit 1af652e
Showing 1 changed file with 48 additions and 20 deletions.
68 changes: 48 additions & 20 deletions src/schemathesis/specs/openapi/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,26 @@
SCHEMA_PARSING_ERRORS = (KeyError, AttributeError, jsonschema.exceptions.RefResolutionError)


@dataclass
class OperationCacheEntry:
path: str
method: str
# The resolution scope of the operation
scope: str
# Parameters shared among all operations in the path
shared_parameters: list[dict[str, Any]]
# Unresolved operation definition
operation: dict[str, Any]
__slots__ = ("path", "method", "scope", "shared_parameters", "operation")


@dataclass(eq=False, repr=False)
class BaseOpenAPISchema(BaseSchema):
nullable_name: ClassVar[str] = ""
links_field: ClassVar[str] = ""
header_required_field: ClassVar[str] = ""
security: ClassVar[BaseSecurityProcessor] = None # type: ignore
_operations_by_id: dict[str, APIOperation] = field(init=False)
_operations_by_id: dict[str, OperationCacheEntry] = field(init=False)
_inline_reference_cache: dict[str, Any] = field(default_factory=dict)
# Inline references cache can be populated from multiple threads, therefore we need some synchronisation to avoid
# excessive resolving
Expand Down Expand Up @@ -375,30 +388,45 @@ def get_response_schema(self, definition: dict[str, Any], scope: str) -> tuple[l
def get_operation_by_id(self, operation_id: str) -> APIOperation:
"""Get an `APIOperation` instance by its `operationId`."""
if not hasattr(self, "_operations_by_id"):
self._operations_by_id = dict(self._group_operations_by_id())
self._operations_by_id = self._collect_operation_ids()
try:
return self._operations_by_id[operation_id]
entry = self._operations_by_id[operation_id]
except KeyError as exc:
matches = get_close_matches(operation_id, list(self._operations_by_id))
self._on_missing_operation(operation_id, exc, matches)

def _group_operations_by_id(self) -> Generator[tuple[str, APIOperation], None, None]:
for path, methods in self.raw_schema["paths"].items():
scope, raw_methods = self._resolve_methods(methods)
common_parameters = self.resolver.resolve_all(methods.get("parameters", []), RECURSION_DEPTH_LIMIT - 8)
for method, definition in methods.items():
if method not in HTTP_METHODS or "operationId" not in definition:
shared_parameters = self.resolver.resolve_all(entry.shared_parameters, RECURSION_DEPTH_LIMIT - 8)
self.resolver.push_scope(entry.scope)
try:
operation = self.resolver.resolve_all(entry.operation, RECURSION_DEPTH_LIMIT - 8)
finally:
self.resolver.pop_scope()
raw_parameters = itertools.chain(operation.get("parameters", ()), shared_parameters)
parameters = self.collect_parameters(raw_parameters, operation)
definition = OperationDefinition(entry.operation, operation, entry.scope)
return self.make_operation(entry.path, entry.method, parameters, definition)

def _collect_operation_ids(self) -> dict[str, OperationCacheEntry]:
"""Collect all operation IDs from the schema."""
ids = {}
for path, path_item in self.raw_schema.get("paths", {}).items():
# If the path is behind a reference we have to keep the scope
# The scope is used to resolve nested components later on
if "$ref" in path_item:
scope, path_item = self.resolver.resolve(path_item["$ref"])
else:
scope = self.resolver.resolution_scope
for key, entry in path_item.items():
if key not in HTTP_METHODS:
continue
self.resolver.push_scope(scope)
try:
resolved_definition = self.resolver.resolve_all(definition, RECURSION_DEPTH_LIMIT - 8)
finally:
self.resolver.pop_scope()
parameters = self.collect_parameters(
itertools.chain(resolved_definition.get("parameters", ()), common_parameters), resolved_definition
)
raw_definition = OperationDefinition(raw_methods[method], resolved_definition, scope)
yield resolved_definition["operationId"], self.make_operation(path, method, parameters, raw_definition)
if "operationId" in entry:
ids[entry["operationId"]] = OperationCacheEntry(
path=path,
method=key,
scope=scope,
shared_parameters=path_item.get("parameters", []),
operation=entry,
)
return ids

def get_operation_by_reference(self, reference: str) -> APIOperation:
"""Get local or external `APIOperation` instance by reference.
Expand Down

0 comments on commit 1af652e

Please sign in to comment.