Skip to content

Commit

Permalink
Merge pull request #289 from 31z4/consumes
Browse files Browse the repository at this point in the history
Respect what the APIs consume
  • Loading branch information
hjacobs committed Sep 29, 2016
2 parents 6089f3d + 0d014d0 commit 92e4645
Show file tree
Hide file tree
Showing 12 changed files with 107 additions and 39 deletions.
5 changes: 5 additions & 0 deletions connexion/api.py
Expand Up @@ -113,6 +113,10 @@ def __init__(self, swagger_yaml_path, base_url=None, arguments=None,
# API calls.
self.produces = self.specification.get('produces', list()) # type: List[str]

# A list of MIME types the APIs can consume. This is global to all APIs but can be overridden on specific
# API calls.
self.consumes = self.specification.get('consumes', ['application/json']) # type: List[str]

self.security = self.specification.get('security')
self.security_definitions = self.specification.get('securityDefinitions', dict())
logger.debug('Security Definitions: %s', self.security_definitions)
Expand Down Expand Up @@ -167,6 +171,7 @@ def add_operation(self, method, path, swagger_operation, path_parameters):
path_parameters=path_parameters,
operation=swagger_operation,
app_produces=self.produces,
app_consumes=self.consumes,
app_security=self.security,
security_definitions=self.security_definitions,
definitions=self.definitions,
Expand Down
1 change: 1 addition & 0 deletions connexion/app.py
Expand Up @@ -155,6 +155,7 @@ def _resolver_error_handler(self, *args, **kwargs):
kwargs['operation'] = {
'operationId': 'connexion.handlers.ResolverErrorHandler',
}
kwargs.setdefault('app_consumes', ['application/json'])
return ResolverErrorHandler(self.resolver_error, *args, **kwargs)

def add_error_handler(self, error_code, function):
Expand Down
17 changes: 11 additions & 6 deletions connexion/decorators/parameter.py
Expand Up @@ -7,7 +7,7 @@
import six
import werkzeug.exceptions as exceptions

from ..utils import boolean, is_null, is_nullable
from ..utils import all_json, boolean, is_null, is_nullable

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,13 +58,15 @@ def get_val_from_param(value, query_param):
return make_type(value, query_param["type"])


def parameter_to_arg(parameters, function):
def parameter_to_arg(parameters, consumes, function):
"""
Pass query and body parameters as keyword arguments to handler function.
See (https://github.com/zalando/connexion/issues/59)
:param parameters: All the parameters of the handler functions
:type parameters: dict|None
:param consumes: The list of content types the operation consumes
:type consumes: list
:param function: The handler function for the REST endpoint.
:type function: function|None
"""
Expand All @@ -87,10 +89,13 @@ def parameter_to_arg(parameters, function):
def wrapper(*args, **kwargs):
logger.debug('Function Arguments: %s', arguments)

try:
request_body = flask.request.json
except exceptions.BadRequest:
request_body = None
if all_json(consumes):
try:
request_body = flask.request.json
except exceptions.BadRequest:
request_body = None
else:
request_body = flask.request.data

if default_body and not request_body:
request_body = default_body
Expand Down
4 changes: 2 additions & 2 deletions connexion/decorators/response.py
Expand Up @@ -21,7 +21,7 @@
from ..exceptions import (NonConformingResponseBody,
NonConformingResponseHeaders)
from ..problem import problem
from ..utils import produces_json
from ..utils import all_json
from .decorator import BaseDecorator
from .validation import ResponseBodyValidator

Expand Down Expand Up @@ -91,7 +91,7 @@ def is_json_schema_compatible(self, response_definition):
if not response_definition:
return False
return ('schema' in response_definition and
(produces_json([self.mimetype]) or self.mimetype == 'text/plain'))
(all_json([self.mimetype]) or self.mimetype == 'text/plain'))

def __call__(self, function):
"""
Expand Down
17 changes: 10 additions & 7 deletions connexion/decorators/validation.py
Expand Up @@ -24,7 +24,7 @@
from werkzeug import FileStorage

from ..problem import problem
from ..utils import boolean, is_null, is_nullable
from ..utils import all_json, boolean, is_null, is_nullable

logger = logging.getLogger('connexion.decorators.validation')

Expand Down Expand Up @@ -97,12 +97,14 @@ def validate_parameter_list(parameter_type, request_params, spec_params):


class RequestBodyValidator(object):
def __init__(self, schema, is_null_value_valid=False):
def __init__(self, schema, consumes, is_null_value_valid=False):
"""
:param schema: The schema of the request body
:param consumes: The list of content types the operation consumes
:param is_nullable: Flag to indicate if null is accepted as valid value.
"""
self.schema = schema
self.consumes = consumes
self.has_default = schema.get('default', False)
self.is_null_value_valid = is_null_value_valid

Expand All @@ -114,12 +116,13 @@ def __call__(self, function):

@functools.wraps(function)
def wrapper(*args, **kwargs):
data = flask.request.json
if all_json(self.consumes):
data = flask.request.json

logger.debug("%s validating schema...", flask.request.url)
error = self.validate_schema(data)
if error and not self.has_default:
return error
logger.debug("%s validating schema...", flask.request.url)
error = self.validate_schema(data)
if error and not self.has_default:
return error

response = function(*args, **kwargs)
return response
Expand Down
15 changes: 9 additions & 6 deletions connexion/operation.py
Expand Up @@ -27,7 +27,7 @@
from .decorators.validation import (ParameterValidator, RequestBodyValidator,
TypeValidationError)
from .exceptions import InvalidSpecification
from .utils import flaskify_endpoint, is_nullable, produces_json
from .utils import all_json, flaskify_endpoint, is_nullable

logger = logging.getLogger('connexion.operation')

Expand Down Expand Up @@ -105,7 +105,7 @@ class Operation(SecureOperation):
A single API operation on a path.
"""

def __init__(self, method, path, operation, resolver, app_produces,
def __init__(self, method, path, operation, resolver, app_produces, app_consumes,
path_parameters=None, app_security=None, security_definitions=None,
definitions=None, parameter_definitions=None, response_definitions=None,
validate_responses=False, strict_validation=False, randomize_endpoint=None):
Expand All @@ -128,6 +128,8 @@ def __init__(self, method, path, operation, resolver, app_produces,
:param resolver: Callable that maps operationID to a function
:param app_produces: list of content types the application can return by default
:type app_produces: list
:param app_consumes: list of content types the application consumes by default
:type app_consumes: list
:param path_parameters: Parameters defined in the path level
:type path_parameters: list
:param app_security: list of security rules the application uses by default
Expand Down Expand Up @@ -172,6 +174,7 @@ def __init__(self, method, path, operation, resolver, app_produces,

self.security = operation.get('security', app_security)
self.produces = operation.get('produces', app_produces)
self.consumes = operation.get('consumes', app_consumes)

resolution = resolver.resolve(self)
self.operation_id = resolution.operation_id
Expand Down Expand Up @@ -276,7 +279,7 @@ def get_mimetype(self):
:rtype str
"""
if produces_json(self.produces):
if all_json(self.produces):
try:
return self.produces[0]
except IndexError:
Expand Down Expand Up @@ -325,7 +328,7 @@ def function(self):
:rtype: types.FunctionType
"""

function = parameter_to_arg(self.parameters, self.__undecorated_function)
function = parameter_to_arg(self.parameters, self.consumes, self.__undecorated_function)

if self.validate_responses:
logger.debug('... Response validation enabled.')
Expand Down Expand Up @@ -371,7 +374,7 @@ def __content_type_decorator(self):
logger.debug('... Produces: %s', self.produces, extra=vars(self))

mimetype = self.get_mimetype()
if produces_json(self.produces): # endpoint will return json
if all_json(self.produces): # endpoint will return json
logger.debug('... Produces json', extra=vars(self))
jsonify = Jsonifier(mimetype)
return jsonify
Expand All @@ -390,7 +393,7 @@ def __validation_decorators(self):
if self.parameters:
yield ParameterValidator(self.parameters, strict_validation=self.strict_validation)
if self.body_schema:
yield RequestBodyValidator(self.body_schema,
yield RequestBodyValidator(self.body_schema, self.consumes,
is_nullable(self.body_definition))

@property
Expand Down
22 changes: 11 additions & 11 deletions connexion/utils.py
Expand Up @@ -143,29 +143,29 @@ def is_json_mimetype(mimetype):
return maintype == 'application' and (subtype == 'json' or subtype.endswith('+json'))


def produces_json(produces):
def all_json(mimetypes):
"""
Returns True if all mimetypes in produces are serialized with json
Returns True if all mimetypes are serialized with json
:type produces: list
:type mimetypes: list
:rtype: bool
>>> produces_json(['application/json'])
>>> all_json(['application/json'])
True
>>> produces_json(['application/x.custom+json'])
>>> all_json(['application/x.custom+json'])
True
>>> produces_json([])
>>> all_json([])
True
>>> produces_json(['application/xml'])
>>> all_json(['application/xml'])
False
>>> produces_json(['text/json'])
>>> all_json(['text/json'])
False
>>> produces_json(['application/json', 'other/type'])
>>> all_json(['application/json', 'other/type'])
False
>>> produces_json(['application/json', 'application/x.custom+json'])
>>> all_json(['application/json', 'application/x.custom+json'])
True
"""
return all(is_json_mimetype(mimetype) for mimetype in produces)
return all(is_json_mimetype(mimetype) for mimetype in mimetypes)


def boolean(s):
Expand Down
7 changes: 7 additions & 0 deletions tests/api/test_responses.py
Expand Up @@ -166,3 +166,10 @@ def test_bad_operations(bad_operations_app):

resp = app_client.post('/v1.0/welcome')
assert resp.status_code == 501


def test_text_request(simple_app):
app_client = simple_app.app.test_client()

resp = app_client.post('/v1.0/text-request', data='text')
assert resp.status_code == 200
4 changes: 4 additions & 0 deletions tests/fakeapi/hello.py
Expand Up @@ -330,6 +330,10 @@ def get_data_as_binary():
return get_blob_data(), 200, {'Content-Type': 'application/octet-stream'}


def get_data_as_text(post_param):
return ''


def get_invalid_response():
return {"simple": object()}

Expand Down
16 changes: 16 additions & 0 deletions tests/fixtures/simple/swagger.yaml
Expand Up @@ -630,6 +630,22 @@ paths:
schema:
type: object

/text-request:
post:
operationId: fakeapi.hello.get_data_as_text
consumes:
- "text/plain"
parameters:
- name: post_param
description: Just a testing parameter.
in: body
required: true
schema:
type: string
responses:
200:
description: OK


definitions:
new_stack:
Expand Down

0 comments on commit 92e4645

Please sign in to comment.