Skip to content

Commit

Permalink
refactor: Inner details of accessing API operations
Browse files Browse the repository at this point in the history
Signed-off-by: Dmitry Dygalo <dmitry@dygalo.dev>
  • Loading branch information
Stranger6667 committed May 16, 2024
1 parent db8f7ba commit bb266a1
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 72 deletions.
6 changes: 5 additions & 1 deletion docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ Changelog
:version:`Unreleased <v3.28.1...HEAD>` - TBD
--------------------------------------------

**Changed**:

- **INTERNAL**: Remove the ability to mutate components used in ``schema["/path"]["METHOD"]`` access patterns.

**Fixed**

- Not serializing shared parameters for an API operation.
Expand All @@ -13,7 +17,7 @@ Changelog

**Performance**

- Optimize ``schema["/path"]["methods"]`` access patterns and reduce memory usage.
- Optimize ``schema["/path"]["METHOD"]`` access patterns and reduce memory usage.
- Optimize ``get_operation_by_id`` method performance and reduce memory usage.
- Optimize ``get_operation_by_reference`` method performance.
- Less copying during schema traversal.
Expand Down
2 changes: 2 additions & 0 deletions src/schemathesis/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,8 @@ class OperationDefinition(Generic[D]):
resolved: D
scope: str

__slots__ = ("raw", "resolved", "scope")

def __contains__(self, item: str | int) -> bool:
return item in self.resolved

Expand Down
11 changes: 3 additions & 8 deletions src/schemathesis/specs/graphql/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
cast,
TYPE_CHECKING,
NoReturn,
MutableMapping,
Mapping,
Iterator,
)
from urllib.parse import urlsplit, urlunsplit
Expand Down Expand Up @@ -285,7 +285,7 @@ def get_tags(self, operation: APIOperation) -> list[str] | None:


@dataclass
class FieldMap(MutableMapping):
class FieldMap(Mapping):
"""Container for accessing API operations.
Provides a more specific error message if API operation is not found.
Expand All @@ -295,12 +295,7 @@ class FieldMap(MutableMapping):
_root_type: RootType
_operation_type: graphql.GraphQLObjectType

def __setitem__(self, key: str, value: APIOperation) -> None:
schema = cast(GraphQLSchema, self._parent._schema)
schema._operation_cache.insert_operation(key, value)

def __delitem__(self, key: str) -> None:
del self._operation_type.fields[key]
__slots__ = ("_parent", "_root_type", "_operation_type")

def __len__(self) -> int:
return len(self._operation_type.fields)
Expand Down
10 changes: 5 additions & 5 deletions src/schemathesis/specs/openapi/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ class OperationCacheEntry:
method: str
# The resolution scope of the operation
scope: str
# Parameters shared among all operations in the path
shared_parameters: list[dict[str, Any]]
# Parent path item
path_item: dict[str, Any]
# Unresolved operation definition
operation: dict[str, Any]
__slots__ = ("path", "method", "scope", "shared_parameters", "operation")
__slots__ = ("path", "method", "scope", "path_item", "operation")


# During traversal, we need to keep track of the scope, path, and method
Expand Down Expand Up @@ -66,12 +66,12 @@ def insert_definition_by_id(
path: str,
method: str,
scope: str,
shared_parameters: list[dict[str, Any]],
path_item: dict[str, Any],
operation: dict[str, Any],
) -> None:
"""Insert a new operation definition into cache."""
self._id_to_definition[operation_id] = OperationCacheEntry(
path=path, method=method, scope=scope, shared_parameters=shared_parameters, operation=operation
path=path, method=method, scope=scope, path_item=path_item, operation=operation
)

def get_definition_by_id(self, operation_id: str) -> OperationCacheEntry:
Expand Down
105 changes: 47 additions & 58 deletions src/schemathesis/specs/openapi/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,19 @@ def _add_override(test: GenericTest) -> GenericTest:

return _add_override

def _resolve_shared_parameters(self, path_item: Mapping[str, Any]) -> list[dict[str, Any]]:
return self.resolver.resolve_all(path_item.get("parameters", []), RECURSION_DEPTH_LIMIT - 8)

def _resolve_operation(self, operation: dict[str, Any]) -> dict[str, Any]:
return self.resolver.resolve_all(operation, RECURSION_DEPTH_LIMIT - 8)

def _collect_operation_parameters(
self, path_item: Mapping[str, Any], operation: dict[str, Any]
) -> list[OpenAPIParameter]:
shared_parameters = self._resolve_shared_parameters(path_item)
parameters = operation.get("parameters", ())
return self.collect_parameters(itertools.chain(parameters, shared_parameters), operation)

def get_all_operations(
self, hooks: HookDispatcher | None = None
) -> Generator[Result[APIOperation, OperationSchemaError], None, None]:
Expand Down Expand Up @@ -254,29 +267,17 @@ def get_all_operations(
continue
self.dispatch_hook("before_process_path", context, path, path_item)
scope, path_item = self._resolve_path_item(path_item)
shared_parameters = self.resolver.resolve_all(
path_item.get("parameters", []), RECURSION_DEPTH_LIMIT - 8
)
for method, definition in path_item.items():
shared_parameters = self._resolve_shared_parameters(path_item)
for method, entry in path_item.items():
if method not in HTTP_METHODS:
continue
try:
# Setting a low recursion limit doesn't solve the problem with recursive references & inlining
# too much but decreases the number of cases when Schemathesis stuck on this step.
self.resolver.push_scope(scope)
try:
resolved_definition = self.resolver.resolve_all(definition, RECURSION_DEPTH_LIMIT - 8)
finally:
self.resolver.pop_scope()
# Only method definitions are parsed
if self._should_skip(method, resolved_definition):
resolved = self._resolve_operation(entry)
if self._should_skip(method, resolved):
continue
parameters = self.collect_parameters(
itertools.chain(resolved_definition.get("parameters", ()), shared_parameters),
resolved_definition,
)
# To prevent recursion errors we need to pass not resolved schema as well
# It could be used for response validation
raw_definition = OperationDefinition(path_item[method], resolved_definition, scope)
operation = self.make_operation(path, method, parameters, raw_definition)
parameters = resolved.get("parameters", ())
parameters = self.collect_parameters(itertools.chain(parameters, shared_parameters), resolved)
operation = self.make_operation(path, method, parameters, entry, resolved, scope)
context = HookContext(operation=operation)
if (
should_skip_operation(GLOBAL_HOOK_DISPATCHER, context)
Expand Down Expand Up @@ -348,15 +349,17 @@ def make_operation(
path: str,
method: str,
parameters: list[OpenAPIParameter],
raw_definition: OperationDefinition,
raw: dict[str, Any],
resolved: dict[str, Any],
scope: str,
) -> APIOperation:
"""Create JSON schemas for the query, body, etc from Swagger parameters definitions."""
__tracebackhide__ = True
base_url = self.get_base_url()
operation: APIOperation[OpenAPIParameter, Case] = APIOperation(
path=path,
method=method,
definition=raw_definition,
definition=OperationDefinition(raw, resolved, scope),
base_url=base_url,
app=self.app,
schema=self,
Expand Down Expand Up @@ -407,16 +410,9 @@ def get_operation_by_id(self, operation_id: str) -> APIOperation:
instance = cache.get_operation_by_traversal_key(entry.scope, entry.path, entry.method)
if instance is not None:
return instance
shared_parameters = self.resolver.resolve_all(entry.shared_parameters, RECURSION_DEPTH_LIMIT - 8)
self.resolver.push_scope(entry.scope)
try:
resolved = self.resolver.resolve_all(entry.operation, RECURSION_DEPTH_LIMIT - 8)
finally:
self.resolver.pop_scope()
raw_parameters = itertools.chain(resolved.get("parameters", ()), shared_parameters)
parameters = self.collect_parameters(raw_parameters, resolved)
definition = OperationDefinition(entry.operation, resolved, entry.scope)
initialized = self.make_operation(entry.path, entry.method, parameters, definition)
resolved = self._resolve_operation(entry.operation)
parameters = self._collect_operation_parameters(entry.path_item, resolved)
initialized = self.make_operation(entry.path, entry.method, parameters, entry.operation, resolved, entry.scope)
cache.insert_operation_by_traversal_key(entry.scope, entry.path, entry.method, initialized)
cache.insert_operation_by_id(operation_id, initialized)
return initialized
Expand All @@ -439,7 +435,7 @@ def _populate_operation_id_cache(self, cache: OperationCache) -> None:
path=path,
method=key,
scope=scope,
shared_parameters=path_item.get("parameters", []),
path_item=path_item,
operation=entry,
)

Expand All @@ -449,28 +445,24 @@ def get_operation_by_reference(self, reference: str) -> APIOperation:
Reference example: #/paths/~1users~1{user_id}/patch
"""
cache = self._operation_cache
operation = cache.get_operation_by_reference(reference)
if operation is not None:
return operation
scope, data = self.resolver.resolve(reference)
cached = cache.get_operation_by_reference(reference)
if cached is not None:
return cached
scope, operation = self.resolver.resolve(reference)
path, method = scope.rsplit("/", maxsplit=2)[-2:]
path = path.replace("~1", "/").replace("~0", "~")
# Check the traversal cache as it could've been populated in other places
traversal_key = (self.resolver.resolution_scope, path, method)
operation = cache.get_operation_by_traversal_key(*traversal_key)
if operation is not None:
return operation
resolved_definition = self.resolver.resolve_all(data)
cached = cache.get_operation_by_traversal_key(*traversal_key)
if cached is not None:
return cached
resolved = self._resolve_operation(operation)
parent_ref, _ = reference.rsplit("/", maxsplit=1)
_, methods = self.resolver.resolve(parent_ref)
common_parameters = self.resolver.resolve_all(methods.get("parameters", []), RECURSION_DEPTH_LIMIT - 8)
parameters = self.collect_parameters(
itertools.chain(resolved_definition.get("parameters", ()), common_parameters), resolved_definition
)
raw_definition = OperationDefinition(data, resolved_definition, scope)
initialized = self.make_operation(path, method, parameters, raw_definition)
cache.insert_operation_by_reference(reference, initialized)
_, path_item = self.resolver.resolve(parent_ref)
parameters = self._collect_operation_parameters(path_item, resolved)
initialized = self.make_operation(path, method, parameters, operation, resolved, scope)
cache.insert_operation_by_traversal_key(*traversal_key, initialized)
cache.insert_operation_by_reference(reference, initialized)
return initialized

def get_case_strategy(
Expand Down Expand Up @@ -814,6 +806,8 @@ class MethodMap(Mapping):
# Storage for definitions
_path_item: CaseInsensitiveDict

__slots__ = ("_parent", "_scope", "_path", "_path_item")

def __len__(self) -> int:
return len(self._path_item)

Expand All @@ -822,7 +816,6 @@ def __iter__(self) -> Iterator[str]:

def _init_operation(self, method: str) -> APIOperation:
method = method.lower()
# TODO: Prevent accessing something that is not a method
operation = self._path_item[method]
schema = cast(BaseOpenAPISchema, self._parent._schema)
cache = schema._operation_cache
Expand All @@ -831,17 +824,13 @@ def _init_operation(self, method: str) -> APIOperation:
instance = cache.get_operation_by_traversal_key(scope, path, method)
if instance is not None:
return instance
shared_parameters = self._path_item.get("parameters", [])
shared_parameters = schema.resolver.resolve_all(shared_parameters, RECURSION_DEPTH_LIMIT - 8)
schema.resolver.push_scope(scope)
try:
resolved = schema.resolver.resolve_all(operation, RECURSION_DEPTH_LIMIT - 8)
resolved = schema._resolve_operation(operation)
finally:
schema.resolver.pop_scope()
raw_parameters = itertools.chain(resolved.get("parameters", ()), shared_parameters)
parameters = schema.collect_parameters(raw_parameters, resolved)
definition = OperationDefinition(operation, resolved, scope)
initialized = schema.make_operation(path, method, parameters, definition)
parameters = schema._collect_operation_parameters(self._path_item, resolved)
initialized = schema.make_operation(path, method, parameters, operation, resolved, scope)
cache.insert_operation_by_traversal_key(scope, path, method, initialized)
return initialized

Expand Down

0 comments on commit bb266a1

Please sign in to comment.