Skip to content

Commit

Permalink
Create abstract validator classes (#1653)
Browse files Browse the repository at this point in the history
  • Loading branch information
RobbeSneyders committed Mar 2, 2023
2 parents 969c146 + 4e8e57e commit 50cfc83
Show file tree
Hide file tree
Showing 12 changed files with 354 additions and 288 deletions.
7 changes: 7 additions & 0 deletions connexion/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,13 @@ def _do_resolve(node):
return res


def format_error_with_path(exception: ValidationError) -> str:
"""Format a `ValidationError` with path to error."""
error_path = ".".join(str(item) for item in exception.path)
error_path_msg = f" - '{error_path}'" if error_path else ""
return error_path_msg


def allow_nullable(validation_fn: t.Callable) -> t.Callable:
"""Extend an existing validation function, so it allows nullable values to be null."""

Expand Down
2 changes: 1 addition & 1 deletion connexion/middleware/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,5 @@ def common_error_handler(_request: StarletteRequest, exc: Exception) -> Response

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
# Needs to be set so starlette router throws exceptions instead of returning error responses
scope["app"] = self
scope["app"] = "connexion"
await super().__call__(scope, receive, send)
8 changes: 2 additions & 6 deletions connexion/middleware/request_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@ def validate_mime_type(self, mime_type: str) -> None:
)

async def __call__(self, scope: Scope, receive: Receive, send: Send):
receive_fn = receive

# Validate parameters & headers
uri_parser_class = self._operation._uri_parser_class
uri_parser = uri_parser_class(
Expand Down Expand Up @@ -100,8 +98,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
)
else:
validator = body_validator(
scope,
receive,
schema=schema,
required=self._operation.request_body.get("required", False),
nullable=utils.is_nullable(
Expand All @@ -113,9 +109,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
self._operation.parameters, self._operation.body_definition()
),
)
receive_fn = await validator.wrapped_receive()
receive = await validator.wrap_receive(receive, scope=scope)

await self.next_app(scope, receive_fn, send)
await self.next_app(scope, receive, send)


class RequestValidationAPI(RoutedAPI[RequestValidationOperation]):
Expand Down
10 changes: 3 additions & 7 deletions connexion/middleware/response_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,8 @@ def validate_required_headers(
raise NonConformingResponseHeaders(detail=msg)

async def __call__(self, scope: Scope, receive: Receive, send: Send):

send_fn = send

async def wrapped_send(message: t.MutableMapping[str, t.Any]) -> None:
nonlocal send_fn
nonlocal send

if message["type"] == "http.response.start":
status = str(message["status"])
Expand All @@ -107,16 +104,15 @@ async def wrapped_send(message: t.MutableMapping[str, t.Any]) -> None:
else:
validator = body_validator(
scope,
send,
schema=self._operation.response_schema(status, mime_type),
nullable=utils.is_nullable(
self._operation.response_definition(status, mime_type)
),
encoding=encoding,
)
send_fn = validator.send
send = validator.wrap_send(send)

return await send_fn(message)
return await send(message)

await self.next_app(scope, receive, wrapped_send)

Expand Down
5 changes: 3 additions & 2 deletions connexion/operations/swagger2.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def _transform_form(self, form_parameters: t.List[dict]) -> dict:

default = param.get("default")
if default is not None:
prop["default"] = default
defaults[param["name"]] = default

nullable = param.get("x-nullable")
Expand Down Expand Up @@ -320,11 +321,11 @@ def _transform_form(self, form_parameters: t.List[dict]) -> dict:
"schema": {
"type": "object",
"properties": properties,
"default": defaults,
"required": required,
}
}

if defaults:
definition["schema"]["default"] = defaults
if encoding:
definition["encoding"] = encoding

Expand Down
4 changes: 4 additions & 0 deletions connexion/validators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from connexion.datastructures import MediaTypeDict

from .abstract import ( # NOQA
AbstractRequestBodyValidator,
AbstractResponseBodyValidator,
)
from .form_data import FormDataValidator, MultiPartFormDataValidator
from .json import DefaultsJSONRequestBodyValidator # NOQA
from .json import (
Expand Down
205 changes: 205 additions & 0 deletions connexion/validators/abstract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
"""
This module defines a Validator interface with base functionality that can be subclassed
for custom validators provided to the RequestValidationMiddleware.
"""
import copy
import json
import typing as t

from starlette.datastructures import Headers, MutableHeaders
from starlette.types import Receive, Scope, Send

from connexion.exceptions import BadRequestProblem


class AbstractRequestBodyValidator:
"""
Validator interface with base functionality that can be subclassed for custom validators.
.. note: Validators load the whole body into memory, which can be a problem for large payloads.
"""

MUTABLE_VALIDATION = False
"""
Whether mutations to the body during validation should be transmitted via the receive channel.
Note that this does not apply to the substitution of a missing body with the default body, which always
updates the receive channel.
"""
MAX_MESSAGE_LENGTH = 256000
"""Maximum message length that will be sent via the receive channel for mutated bodies."""

def __init__(
self,
*,
schema: dict,
required: bool = False,
nullable: bool = False,
encoding: str,
strict_validation: bool,
**kwargs,
):
"""
:param schema: Schema of operation to validate
:param required: Whether RequestBody is required
:param nullable: Whether RequestBody is nullable
:param encoding: Encoding of body (passed via Content-Type header)
:param kwargs: Additional arguments for subclasses
:param strict_validation: Whether to allow parameters not defined in the spec
"""
self._schema = schema
self._nullable = nullable
self._required = required
self._encoding = encoding
self._strict_validation = strict_validation

async def _parse(
self, stream: t.AsyncGenerator[bytes, None], scope: Scope
) -> t.Any:
"""Parse the incoming stream."""

def _validate(self, body: t.Any) -> t.Optional[dict]:
"""
Validate the parsed body.
:raises: :class:`connexion.exceptions.BadRequestProblem`
"""

def _insert_body(self, receive: Receive, *, body: t.Any, scope: Scope) -> Receive:
"""
Insert messages transmitting the body at the start of the `receive` channel.
This method updates the provided `scope` in place with the right `Content-Length` header.
"""
if body is None:
return receive

bytes_body = json.dumps(body).encode(self._encoding)

# Update the content-length header
new_scope = copy.deepcopy(scope)
headers = MutableHeaders(scope=new_scope)
headers["content-length"] = str(len(bytes_body))

# Wrap in new receive channel
messages = (
{
"type": "http.request",
"body": bytes_body[i : i + self.MAX_MESSAGE_LENGTH],
"more_body": i + self.MAX_MESSAGE_LENGTH < len(bytes_body),
}
for i in range(0, len(bytes_body), self.MAX_MESSAGE_LENGTH)
)

receive = self._insert_messages(receive, messages=messages)

return receive

@staticmethod
def _insert_messages(
receive: Receive, *, messages: t.Iterable[t.MutableMapping[str, t.Any]]
) -> Receive:
"""Insert messages at the start of the `receive` channel."""

async def receive_() -> t.MutableMapping[str, t.Any]:
for message in messages:
return message
return await receive()

return receive_

async def wrap_receive(self, receive: Receive, *, scope: Scope) -> Receive:
"""
Wrap the provided `receive` channel with request body validation.
This method updates the provided `scope` in place with the right `Content-Length` header.
"""
# Handle missing bodies
headers = Headers(scope=scope)
if not int(headers.get("content-length", 0)):
body = self._schema.get("default")
if body is None and self._required:
raise BadRequestProblem("RequestBody is required")
# The default body is encoded as a `receive` channel to mimic an incoming body
receive = self._insert_body(receive, body=body, scope=scope)

# The receive channel is converted to a stream for convenient access
messages = []

async def stream() -> t.AsyncGenerator[bytes, None]:
more_body = True
while more_body:
message = await receive()
messages.append(message)
more_body = message.get("more_body", False)
yield message.get("body", b"")
yield b""

# The body is parsed and validated
body = await self._parse(stream(), scope=scope)
if not (body is None and self._nullable):
self._validate(body)

# If MUTABLE_VALIDATION is enabled, include any changes made during validation in the messages to send
if self.MUTABLE_VALIDATION:
# Include changes made during validation
receive = self._insert_body(receive, body=body, scope=scope)
else:
# Serialize original messages
receive = self._insert_messages(receive, messages=messages)

return receive


class AbstractResponseBodyValidator:
"""
Validator interface with base functionality that can be subclassed for custom validators.
.. note: Validators load the whole body into memory, which can be a problem for large payloads.
"""

def __init__(
self,
scope: Scope,
*,
schema: dict,
nullable: bool = False,
encoding: str,
) -> None:
self._scope = scope
self._schema = schema
self._nullable = nullable
self._encoding = encoding

def _parse(self, stream: t.Generator[bytes, None, None]) -> t.Any:
"""Parse the incoming stream."""

def _validate(self, body: t.Any) -> t.Optional[dict]:
"""
Validate the body.
:raises: :class:`connexion.exceptions.NonConformingResponse`
"""

def wrap_send(self, send: Send) -> Send:
"""Wrap the provided send channel with response body validation"""

messages = []

async def send_(message: t.MutableMapping[str, t.Any]) -> None:
messages.append(message)

if message["type"] == "http.response.start" or message.get(
"more_body", False
):
return

stream = (message.get("body", b"") for message in messages)
body = self._parse(stream)

if not (body is None and self._nullable):
self._validate(body)

while messages:
await send(messages.pop(0))

return send_

0 comments on commit 50cfc83

Please sign in to comment.