Skip to content

Commit

Permalink
chg: Accept bare class for schema arguments (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rick Riensche committed Sep 17, 2019
1 parent 8b30e5a commit 367bfa8
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 5 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,5 @@ man/
.pytest_cache

#vscode
.vscode/
.vscode/
pip-wheel-metadata
21 changes: 19 additions & 2 deletions flask_rebar/rebar.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,16 @@
from flask_rebar.utils.request_utils import get_header_params_or_400
from flask_rebar.utils.request_utils import get_json_body_params_or_400
from flask_rebar.utils.request_utils import get_query_string_params_or_400
from flask_rebar.utils.request_utils import normalize_schema
from flask_rebar.utils.deprecation import deprecated, deprecated_parameters
from flask_rebar.swagger_generation import SwaggerV2Generator
from flask_rebar.swagger_ui import create_swagger_ui_blueprint

# Deal with maintaining (for now at least) support for 2.7+:
try:
from collections.abc import Mapping # 3.3+
except ImportError:
from collections import Mapping # 2.7+

# To catch redirection exceptions, app.errorhandler expects 301 in versions
# below 0.11.0 but the exception itself in versions greater than 0.11.0.
Expand Down Expand Up @@ -459,8 +465,19 @@ def add_handler(
:param Type[USE_DEFAULT]|None|str mimetype:
Content-Type header to add to the response schema
"""
if isinstance(response_body_schema, marshmallow.Schema):
response_body_schema = {200: response_body_schema}
# Fix #115: if we were passed bare classes we'll go ahead and instantiate
headers_schema = normalize_schema(headers_schema)
request_body_schema = normalize_schema(request_body_schema)
query_string_schema = normalize_schema(query_string_schema)
if response_body_schema:
# Ensure we wrap in appropriate default (200) dict if we were passed a single Schema or class:
if not isinstance(response_body_schema, Mapping):
response_body_schema = {200: response_body_schema}
# use normalize_schema to convert any class reference(s) to instantiated schema(s):
response_body_schema = {
code: normalize_schema(schema)
for (code, schema) in response_body_schema.items()
}

# authenticators can be a list of Authenticators, a single Authenticator, USE_DEFAULT, or None
if isinstance(authenticators, Authenticator) or authenticators is USE_DEFAULT:
Expand Down
3 changes: 2 additions & 1 deletion flask_rebar/utils/request_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from flask_rebar import compat
from flask_rebar import errors
from flask_rebar import messages
from flask_rebar.utils.defaults import USE_DEFAULT


class HeadersProxy(compat.Mapping):
Expand Down Expand Up @@ -92,7 +93,7 @@ def normalize_schema(schema):
This allows for either an instance of a marshmallow.Schema or the class
itself to be passed to functions.
"""
if not isinstance(schema, marshmallow.Schema):
if schema not in (None, USE_DEFAULT) and not isinstance(schema, marshmallow.Schema):
schema = schema()
return schema

Expand Down
78 changes: 77 additions & 1 deletion tests/test_rebar.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import marshmallow as m
from flask import Flask
from werkzeug.routing import RequestRedirect

from flask_rebar import messages
from flask_rebar import HeaderApiKeyAuthenticator, SwaggerV3Generator
Expand Down Expand Up @@ -677,3 +676,80 @@ def test_redirects_for_missing_trailing_slash(self):
resp = app.test_client().get(path="/with_trailing_slash")
self.assertIn(resp.status_code, (301, 308))
self.assertTrue(resp.headers["Location"].endswith("/with_trailing_slash/"))

def test_bare_class_schemas_handled(self):
rebar = Rebar()
registry = rebar.create_handler_registry()

expected_foo = FooSchema().load({"uid": "some_uid", "name": "Namey McNamerton"})
expected_headers = {"x-name": "Header Name"}

def get_foo(*args, **kwargs):
return expected_foo

def post_foo(*args, **kwargs):
return expected_foo

register_endpoint(
registry=registry,
method="GET",
path="/my_get_endpoint",
headers_schema=HeadersSchema,
response_body_schema={200: FooSchema},
query_string_schema=FooListSchema,
func=get_foo,
)

register_endpoint(
registry=registry,
method="POST",
path="/my_post_endpoint",
request_body_schema=FooListSchema,
response_body_schema=FooSchema,
func=post_foo,
)

app = create_rebar_app(rebar)
# violate headers schema:
resp = app.test_client().get(path="/my_get_endpoint?name=QuerystringName")
self.assertEqual(resp.status_code, 400)
self.assertEqual(
get_json_from_resp(resp)["message"], messages.header_validation_failed
)
# violate querystring schema:
resp = app.test_client().get(path="/my_get_endpoint", headers=expected_headers)
self.assertEqual(resp.status_code, 400)
self.assertEqual(
get_json_from_resp(resp)["message"], messages.query_string_validation_failed
)
# valid request:
resp = app.test_client().get(
path="/my_get_endpoint?name=QuerystringName", headers=expected_headers
)
self.assertEqual(resp.status_code, 200)
self.assertEqual(get_json_from_resp(resp), expected_foo.data)

resp = app.test_client().post(
path="/my_post_endpoint",
data='{"wrong": "Posted Name"}',
content_type="application/json",
)
self.assertEqual(resp.status_code, 400)
self.assertEqual(
get_json_from_resp(resp)["message"], messages.body_validation_failed
)

resp = app.test_client().post(
path="/my_post_endpoint",
data='{"name": "Posted Name"}',
content_type="application/json",
)
self.assertEqual(resp.status_code, 200)

# ensure Swagger generation doesn't break (Issue #115)
from flask_rebar import SwaggerV2Generator, SwaggerV3Generator

swagger = SwaggerV2Generator().generate(registry)
self.assertIsNotNone(swagger) # really only care that it didn't barf
swagger = SwaggerV3Generator().generate(registry)
self.assertIsNotNone(swagger)

0 comments on commit 367bfa8

Please sign in to comment.