Skip to content

Commit

Permalink
Enforce required RequestBody (#1652)
Browse files Browse the repository at this point in the history
Fixes #878 
Fixes #1317
  • Loading branch information
RobbeSneyders committed Feb 25, 2023
1 parent 90a734d commit 969c146
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 38 deletions.
14 changes: 12 additions & 2 deletions connexion/middleware/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

from starlette.exceptions import ExceptionMiddleware as StarletteExceptionMiddleware
from starlette.exceptions import HTTPException
from starlette.requests import Request as StarletteRequest
Expand All @@ -6,6 +8,8 @@

from connexion.exceptions import InternalServerError, ProblemException, problem

logger = logging.getLogger(__name__)


class ExceptionMiddleware(StarletteExceptionMiddleware):
"""Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to
Expand All @@ -17,8 +21,10 @@ def __init__(self, *args, **kwargs):
self.add_exception_handler(Exception, self.common_error_handler)

@staticmethod
def problem_handler(_request: StarletteRequest, exception: ProblemException):
response = exception.to_problem()
def problem_handler(_request: StarletteRequest, exc: ProblemException):
logger.exception(exc)

response = exc.to_problem()

return Response(
content=response.body,
Expand All @@ -29,6 +35,8 @@ def problem_handler(_request: StarletteRequest, exception: ProblemException):

@staticmethod
def http_exception(_request: StarletteRequest, exc: HTTPException) -> Response:
logger.exception(exc)

headers = exc.headers

connexion_response = problem(
Expand All @@ -44,6 +52,8 @@ def http_exception(_request: StarletteRequest, exc: HTTPException) -> Response:

@staticmethod
def common_error_handler(_request: StarletteRequest, exc: Exception) -> Response:
logger.exception(exc)

response = InternalServerError().to_problem()

return Response(
Expand Down
1 change: 1 addition & 0 deletions connexion/middleware/request_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
scope,
receive,
schema=schema,
required=self._operation.request_body.get("required", False),
nullable=utils.is_nullable(
self._operation.body_definition(mime_type)
),
Expand Down
4 changes: 4 additions & 0 deletions connexion/operations/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ def method(self):
"""
return self._method

@property
def request_body(self):
"""The request body for this operation"""

@property
def path(self):
"""
Expand Down
17 changes: 7 additions & 10 deletions connexion/operations/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ def __init__(
uri_parser_class=uri_parser_class,
)

self._request_body = operation.get("requestBody", {})

self._parameters = operation.get("parameters", [])
if path_parameters:
self._parameters += path_parameters
Expand All @@ -97,9 +95,7 @@ def __init__(
for _, defn in self._responses.items():
response_content_types += defn.get("content", {}).keys()
self._produces = response_content_types or ["application/json"]

request_content = self._request_body.get("content", {})
self._consumes = list(request_content.keys()) or ["application/json"]
self._consumes = None

logger.debug("consumes: %s" % self.consumes)
logger.debug("produces: %s" % self.produces)
Expand All @@ -122,14 +118,17 @@ def from_spec(cls, spec, api, path, method, resolver, *args, **kwargs):

@property
def request_body(self):
return self._request_body
return self._operation.get("requestBody", {})

@property
def parameters(self):
return self._parameters

@property
def consumes(self):
if self._consumes is None:
request_content = self.request_body.get("content", {})
self._consumes = list(request_content.keys()) or ["application/json"]
return self._consumes

@property
Expand Down Expand Up @@ -247,10 +246,8 @@ def body_definition(self, content_type: str = None) -> dict:
The body complete definition for this operation.
**There can be one "body" parameter at most.**
:rtype: dict
"""
if self._request_body:
if self.request_body:
if content_type is None:
# TODO: make content type required
content_type = self.consumes[0]
Expand All @@ -259,7 +256,7 @@ def body_definition(self, content_type: str = None) -> dict:
"this operation accepts multiple content types, using %s",
content_type,
)
content_type_dict = MediaTypeDict(self._request_body.get("content", {}))
content_type_dict = MediaTypeDict(self.request_body.get("content", {}))
res = content_type_dict.get(content_type, {})
return self.with_definitions(res)
return {}
50 changes: 31 additions & 19 deletions connexion/operations/swagger2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import typing as t

from connexion.exceptions import InvalidSpecification
from connexion.http_facts import FORM_CONTENT_TYPES
from connexion.operations.abstract import AbstractOperation
from connexion.uri_parsing import Swagger2URIParser
from connexion.utils import deep_get
Expand Down Expand Up @@ -107,8 +106,6 @@ def __init__(

self._responses = operation.get("responses", {})

self._body_definitions = {}

@classmethod
def from_spec(cls, spec, api, path, method, resolver, *args, **kwargs):
return cls(
Expand All @@ -127,6 +124,36 @@ def from_spec(cls, spec, api, path, method, resolver, *args, **kwargs):
**kwargs,
)

@property
def request_body(self) -> dict:
if not hasattr(self, "_request_body"):
body_params = []
form_params = []
for parameter in self.parameters:
if parameter["in"] == "body":
body_params.append(parameter)
elif parameter["in"] == "formData":
form_params.append(parameter)

if len(body_params) > 1:
raise InvalidSpecification(
f"{self.method} {self.path}: There can be one 'body' parameter at most"
)

if body_params and form_params:
raise InvalidSpecification(
f"{self.method} {self.path}: 'body' and 'formData' parameters are mutually exclusive"
)

if body_params:
self._request_body = self._transform_json(body_params[0])
elif form_params:
self._request_body = self._transform_form(form_params)
else:
self._request_body = {}

return self._request_body

@property
def parameters(self):
return self._parameters
Expand Down Expand Up @@ -229,22 +256,7 @@ def body_definition(self, content_type: str = None) -> dict:
**There can be one "body" parameter at most.**
"""
if self._body_definitions.get(content_type) is None:
if content_type in FORM_CONTENT_TYPES:
form_parameters = [p for p in self.parameters if p["in"] == "formData"]
_body_definition = self._transform_form(form_parameters)
else:
body_parameters = [p for p in self.parameters if p["in"] == "body"]
if len(body_parameters) > 1:
raise InvalidSpecification(
"{method} {path} There can be one 'body' parameter at most".format(
method=self.method, path=self.path
)
)
body_parameter = body_parameters[0] if body_parameters else {}
_body_definition = self._transform_json(body_parameter)
self._body_definitions[content_type] = _body_definition
return self._body_definitions[content_type]
return self.request_body

def _transform_json(self, body_parameter: dict) -> dict:
"""Translate Swagger2 json parameters into OpenAPI 3 jsonschema spec."""
Expand Down
16 changes: 11 additions & 5 deletions connexion/validators/form_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,18 @@ def __init__(
*,
schema: dict,
validator: t.Type[Draft4Validator] = None,
uri_parser: t.Optional[AbstractURIParser] = None,
required=False,
nullable=False,
encoding: str,
uri_parser: t.Optional[AbstractURIParser] = None,
strict_validation: bool,
) -> None:
self._scope = scope
self._receive = receive
self.schema = schema
self.has_default = schema.get("default", False)
self.nullable = nullable
self.required = required
validator_cls = validator or Draft4RequestValidator
self.validator = validator_cls(schema, format_checker=draft4_format_checker)
self.uri_parser = uri_parser
Expand All @@ -50,10 +52,14 @@ def form_parser_cls(self):
def check_empty(self):
"""`receive` is never called if body is empty, so we need to check this case at
initialization."""
if not int(self.headers.get("content-length", 0)) and self.schema.get(
"required", []
):
self._validate({})
if not int(self.headers.get("content-length", 0)):
# TODO: default should be passed along and content-length updated
if self.schema.get("default"):
self.validate(self.schema.get("default"))
elif self.required: # RequestBody itself is required
raise BadRequestProblem("RequestBody is required")
elif self.schema.get("required", []): # Required top level properties
self._validate({})

@classmethod
def _error_path_message(cls, exception):
Expand Down
17 changes: 17 additions & 0 deletions connexion/validators/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jsonschema
from jsonschema import Draft4Validator, ValidationError, draft4_format_checker
from starlette.datastructures import Headers
from starlette.types import Receive, Scope, Send

from connexion.exceptions import BadRequestProblem, NonConformingResponseBody
Expand All @@ -23,6 +24,7 @@ def __init__(
*,
schema: dict,
validator: t.Type[Draft4Validator] = Draft4RequestValidator,
required=False,
nullable=False,
encoding: str,
**kwargs,
Expand All @@ -32,8 +34,23 @@ def __init__(
self.schema = schema
self.has_default = schema.get("default", False)
self.nullable = nullable
self.required = required
self.validator = validator(schema, format_checker=draft4_format_checker)
self.encoding = encoding
self.headers = Headers(scope=scope)
self.check_empty()

def check_empty(self):
"""receive` is never called if body is empty, so we need to check this case at
initialization."""
if not int(self.headers.get("content-length", 0)):
# TODO: default should be passed along and content-length updated
if self.schema.get("default"):
self.validate(self.schema.get("default"))
elif self.required: # RequestBody itself is required
raise BadRequestProblem("RequestBody is required")
elif self.schema.get("required", []): # Required top level properties
self.validate({})

@classmethod
def _error_path_message(cls, exception):
Expand Down
11 changes: 11 additions & 0 deletions tests/api/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,17 @@ def test_default_object_body(simple_app):
assert response == 1


def test_required_body(simple_app):
app_client = simple_app.test_client()
resp = app_client.post(
"/v1.0/test-required-body", headers={"content-type": "application/json"}
)
assert resp.status_code == 400

resp = app_client.post("/v1.0/test-required-body", json={"foo": "bar"})
assert resp.status_code == 200


def test_empty_object_body(simple_app):
app_client = simple_app.test_client()
resp = app_client.post(
Expand Down
4 changes: 4 additions & 0 deletions tests/fakeapi/hello/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,10 @@ def test_default_object_body(stack):
return {"stack": stack}


def test_required_body(body):
return body


def test_nested_additional_properties(body):
return body

Expand Down
13 changes: 13 additions & 0 deletions tests/fixtures/simple/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,19 @@ paths:
$ref: '#/components/schemas/new_stack'
default:
image_version: default_image
/test-required-body:
post:
summary: Test if a required RequestBody is enforced.
operationId: fakeapi.hello.test_required_body
responses:
'200':
description: OK
requestBody:
required: true
content:
application/json:
schema:
type: object
/test-nested-additional-properties:
post:
summary: Test if nested additionalProperties are cast
Expand Down
14 changes: 14 additions & 0 deletions tests/fixtures/simple/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,20 @@ paths:
200:
description: OK

/test-required-body:
post:
summary: Test if a required RequestBody is enforced.
operationId: fakeapi.hello.test_required_body
parameters:
- name: body
in: body
required: true
schema:
type: object
responses:
200:
description: OK

/test-default-integer-body:
post:
summary: Test if default integer body param is passed to handler.
Expand Down
4 changes: 2 additions & 2 deletions tests/test_operation2.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,10 +546,10 @@ def test_multi_body(api):
operation.body_schema()

exception = exc_info.value
assert str(exception) == "GET endpoint There can be one 'body' parameter at most"
assert str(exception) == "GET endpoint: There can be one 'body' parameter at most"
assert (
repr(exception)
== """<InvalidSpecification: "GET endpoint There can be one 'body' parameter at most">"""
== """<InvalidSpecification: "GET endpoint: There can be one 'body' parameter at most">"""
)


Expand Down

0 comments on commit 969c146

Please sign in to comment.