Skip to content

Commit

Permalink
perf: Minor performance improvements
Browse files Browse the repository at this point in the history
Signed-off-by: Dmitry Dygalo <dmitry@dygalo.dev>
  • Loading branch information
Stranger6667 committed May 17, 2024
1 parent 54a705f commit e9a3e89
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 32 deletions.
73 changes: 46 additions & 27 deletions src/schemathesis/specs/openapi/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,19 +158,20 @@ def _operation_iter(self) -> Generator[dict[str, Any], None, None]:
paths = self.raw_schema["paths"]
except KeyError:
return
get_full_path = self.get_full_path
endpoint = self.endpoint
resolve = self.resolver.resolve
for path, methods in paths.items():
full_path = self.get_full_path(path)
if should_skip_endpoint(full_path, self.endpoint):
should_skip = self._should_skip
for path, path_item in paths.items():
full_path = get_full_path(path)
if should_skip_endpoint(full_path, endpoint):
continue
try:
if "$ref" in methods:
_, resolved_methods = resolve(methods["$ref"])
else:
resolved_methods = methods
if "$ref" in path_item:
_, path_item = resolve(path_item["$ref"])
# Straightforward iteration is faster than converting to a set & calculating length.
for method, definition in resolved_methods.items():
if self._should_skip(method, definition):
for method, definition in path_item.items():
if should_skip(method, definition):
continue
yield definition
except SCHEMA_PARSING_ERRORS:
Expand All @@ -188,11 +189,13 @@ def operations_count(self) -> int:
@property
def links_count(self) -> int:
total = 0
resolve = self.resolver.resolve
links_field = self.links_field
for definition in self._operation_iter():
for response in definition.get("responses", {}).values():
if "$ref" in response:
_, response = self.resolver.resolve(response["$ref"])
defined_links = response.get(self.links_field)
_, response = resolve(response["$ref"])
defined_links = response.get(links_field)
if defined_links is not None:
total += len(defined_links)
return total
Expand All @@ -217,6 +220,11 @@ def _add_override(test: GenericTest) -> GenericTest:

return _add_override

def _resolve_until_no_references(self, value: dict[str, Any]) -> dict[str, Any]:
while "$ref" in value:
_, value = self.resolver.resolve(value["$ref"])
return value

def _resolve_shared_parameters(self, path_item: Mapping[str, Any]) -> list[dict[str, Any]]:
return self.resolver.resolve_all(path_item.get("parameters", []), RECURSION_DEPTH_LIMIT - 8)

Expand Down Expand Up @@ -259,29 +267,40 @@ def get_all_operations(
self._raise_invalid_schema(exc)

context = HookContext()
# Optimization: local variables are faster than attribute access
get_full_path = self.get_full_path
endpoint = self.endpoint
dispatch_hook = self.dispatch_hook
resolve_path_item = self._resolve_path_item
resolve_shared_parameters = self._resolve_shared_parameters
resolve_operation = self._resolve_operation
should_skip = self._should_skip
collect_parameters = self.collect_parameters
make_operation = self.make_operation
hooks = self.hooks
for path, path_item in paths.items():
method = None
try:
full_path = self.get_full_path(path) # Should be available for later use
if should_skip_endpoint(full_path, self.endpoint):
full_path = get_full_path(path) # Should be available for later use
if should_skip_endpoint(full_path, endpoint):
continue
self.dispatch_hook("before_process_path", context, path, path_item)
scope, path_item = self._resolve_path_item(path_item)
shared_parameters = self._resolve_shared_parameters(path_item)
dispatch_hook("before_process_path", context, path, path_item)
scope, path_item = resolve_path_item(path_item)
shared_parameters = resolve_shared_parameters(path_item)
for method, entry in path_item.items():
if method not in HTTP_METHODS:
continue
try:
resolved = self._resolve_operation(entry)
if self._should_skip(method, resolved):
resolved = resolve_operation(entry)
if should_skip(method, resolved):
continue
parameters = resolved.get("parameters", ())
parameters = self.collect_parameters(itertools.chain(parameters, shared_parameters), resolved)
operation = self.make_operation(path, method, parameters, entry, resolved, scope)
parameters = collect_parameters(itertools.chain(parameters, shared_parameters), resolved)
operation = make_operation(path, method, parameters, entry, resolved, scope)
context = HookContext(operation=operation)
if (
should_skip_operation(GLOBAL_HOOK_DISPATCHER, context)
or should_skip_operation(self.hooks, context)
or should_skip_operation(hooks, context)
or (hooks and should_skip_operation(hooks, context))
):
continue
Expand Down Expand Up @@ -419,13 +438,15 @@ def get_operation_by_id(self, operation_id: str) -> APIOperation:

def _populate_operation_id_cache(self, cache: OperationCache) -> None:
"""Collect all operation IDs from the schema."""
resolve = self.resolver.resolve
default_scope = self.resolver.resolution_scope
for path, path_item in self.raw_schema.get("paths", {}).items():
# If the path is behind a reference we have to keep the scope
# The scope is used to resolve nested components later on
if "$ref" in path_item:
scope, path_item = self.resolver.resolve(path_item["$ref"])
scope, path_item = resolve(path_item["$ref"])
else:
scope = self.resolver.resolution_scope
scope = default_scope
for key, entry in path_item.items():
if key not in HTTP_METHODS:
continue
Expand Down Expand Up @@ -1100,10 +1121,8 @@ def _get_parameter_serializer(self, definitions: list[dict[str, Any]]) -> Callab
return serialization.serialize_openapi3_parameters(definitions)

def get_request_payload_content_types(self, operation: APIOperation) -> list[str]:
request_body = operation.definition.raw["requestBody"]
while "$ref" in request_body:
_, request_body = self.resolver.resolve(request_body["$ref"])
return list(request_body["content"].keys())
request_body = self._resolve_until_no_references(operation.definition.raw["requestBody"])
return list(request_body["content"])

def prepare_multipart(
self, form_data: FormData, operation: APIOperation
Expand Down
8 changes: 3 additions & 5 deletions src/schemathesis/specs/openapi/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,10 @@ def get_security_definitions(self, schema: dict[str, Any], resolver: RefResolver
"""In Open API 3 security definitions are located in ``components`` and may have references inside."""
components = schema.get("components", {})
security_schemes = components.get("securitySchemes", {})
resolve = resolver.resolve
if "$ref" in security_schemes:
return resolver.resolve(security_schemes["$ref"])[1]
return {
key: resolver.resolve(value["$ref"])[1] if "$ref" in value else value
for key, value in security_schemes.items()
}
return resolve(security_schemes["$ref"])[1]
return {key: resolve(value["$ref"])[1] if "$ref" in value else value for key, value in security_schemes.items()}

def _make_http_auth_parameter(self, definition: dict[str, Any]) -> dict[str, Any]:
schema = make_auth_header_schema(definition)
Expand Down

0 comments on commit e9a3e89

Please sign in to comment.