Skip to content

Commit

Permalink
Injection
Browse files Browse the repository at this point in the history
  • Loading branch information
timmartin19 committed Jun 8, 2015
1 parent 11e2cfc commit 54f56e5
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 36 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.rst
@@ -1,7 +1,7 @@
0.1.5 (unreleased)
==================

- Nothing changed yet.
- Injected argument getter. No longer attempts to automatically load json from the data.


0.1.4 (2015-05-08)
Expand Down
58 changes: 32 additions & 26 deletions flask_ripozo/dispatcher.py
Expand Up @@ -14,7 +14,6 @@

from werkzeug.routing import Map

import json
import six


Expand All @@ -37,13 +36,28 @@ def exception_handler(dispatcher, accepted_mimetypes, exc):
raise exc


def get_request_query_body_args(request_obj):
"""
Gets the request query args and the
body arguments.
:param Request request_obj: A Flask request object.
:return: A tuple of the appropriately formatted query args and body args
:rtype: dict, dict
"""
query_args = dict(request_obj.args)
body = request_obj.get_json() or request_obj.form or {}
return query_args, body


class FlaskDispatcher(DispatcherBase):
"""
This is the actual dispatcher responsible for integrating
ripozo with flask. Pretty simple right?
"""

def __init__(self, app, url_prefix='', error_handler=exception_handler):
def __init__(self, app, url_prefix='', error_handler=exception_handler,
argument_getter=get_request_query_body_args):
"""
Eventually these will be able to be registed to a blueprint.
But for now it will probably break the routing by the adapters.
Expand All @@ -56,12 +70,16 @@ def __init__(self, app, url_prefix='', error_handler=exception_handler):
on the '/api' path.
:param function error_handler: A function that takes a dispatcher,
accepted_mimetypes, and exception that handles error responses.
:param function argument_getter: The function responsible for
getting the query/body arguments from the Flask Request as a
tuple.
"""
self.app = app
self.url_map = Map()
self.function_for_endpoint = {}
self.url_prefix = url_prefix
self.error_handler = error_handler
self.argument_getter = argument_getter

@property
def base_url(self):
Expand Down Expand Up @@ -106,14 +124,22 @@ def register_route(self, endpoint, endpoint_func=None, route=None, methods=None,
if key not in valid_flask_options:
options.pop(key, None)
self.app.add_url_rule(route, endpoint=endpoint,
view_func=flask_dispatch_wrapper(self, endpoint_func),
view_func=flask_dispatch_wrapper(self, endpoint_func, self.argument_getter),
methods=methods, **options)


def flask_dispatch_wrapper(dispatcher, f):

def flask_dispatch_wrapper(dispatcher, f, argument_getter=get_request_query_body_args):
"""
A decorator for wrapping the apimethods provided to the
dispatcher.
:param FlaskDispatcher dispatcher: The dispatcher that is
created this.
:param function f: The apimethod to wrap.
:param function argument_getter: The function that takes a flask
Request object and uses it to get the query arguments and the
body arguments as a tuple.
"""

@wraps(f)
Expand All @@ -136,7 +162,7 @@ def flask_dispatch(**urlparams):
:return: A response that the flask application can return.
:rtype: flask.Response
"""
request_args, body_args = _get_request_query_body_args(request)
request_args, body_args = argument_getter(request)
r = RequestContainer(url_params=urlparams, query_args=request_args, body_args=body_args,
headers=request.headers)
accepted_mimetypes = request.accept_mimetypes
Expand All @@ -147,24 +173,4 @@ def flask_dispatch(**urlparams):

return Response(response=adapter.formatted_body, headers=adapter.extra_headers,
content_type=adapter.extra_headers['Content-Type'], status=adapter.status_code)
return flask_dispatch


def _get_request_query_body_args(request_obj):
"""
Gets the request query args and the
body arguments.
:param Request request_obj: A Flask request object.
:return: A tuple of the appropriately formatted query args and body args
:rtype: dict, dict
"""
query_args = dict(request_obj.args)
# TODO What the fuck.
if request_obj.form:
return query_args, dict(request_obj.form)
elif request_obj.json:
return query_args, dict(request_obj.json)
elif request_obj.data:
return query_args, json.loads(request_obj.data)
return query_args, {}
return flask_dispatch
39 changes: 39 additions & 0 deletions flask_ripozo_tests/integration/dispatcher.py
@@ -0,0 +1,39 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from flask import Flask, request

from flask_ripozo.dispatcher import get_request_query_body_args

import json
import unittest


class TestDispatcherFlaskIntegration(unittest.TestCase):
def test_get_request_body_args(self):
"""
Tests getting the request body args
from a flask request object.
"""
app = Flask('myapp')
body = dict(x=1)
with app.test_request_context('/', data=json.dumps(body), content_type='application/json'):
q, b = get_request_query_body_args(request)
self.assertDictEqual(b, body)

with app.test_request_context('/', data=body): # Form encoded
q, b = get_request_query_body_args(request)
self.assertDictEqual(b, dict(x=['1']))

def test_get_request_body_args_nested(self):
"""
Tests getting nested body args which seems to
be handled slightly differnetly.
"""
app = Flask('myapp')
body = dict(x=1, y=dict(x=1))
with app.test_request_context('/', data=json.dumps(body), content_type='application/json'):
q, b = get_request_query_body_args(request)
self.assertDictEqual(b, body)
18 changes: 9 additions & 9 deletions flask_ripozo_tests/unit/dispatcher.py
Expand Up @@ -5,7 +5,7 @@

from flask import Flask, Blueprint

from flask_ripozo.dispatcher import FlaskDispatcher, flask_dispatch_wrapper, _get_request_query_body_args
from flask_ripozo.dispatcher import FlaskDispatcher, flask_dispatch_wrapper, get_request_query_body_args

from ripozo.exceptions import RestException
from ripozo.tests.python2base import TestBase
Expand Down Expand Up @@ -141,25 +141,25 @@ def test_blueprint_base_url(self):

def test_get_request_query_body_args(self):
"""
Tests the private _get_request_query_body_args
Tests the private get_request_query_body_args
method.
"""
query_args = dict(x=1)
form = dict(x=2)
mck = mock.Mock(args=query_args, form=form)
q, b = _get_request_query_body_args(mck)
mck = mock.Mock(args=query_args, form=form, get_json=mock.Mock(return_value=None))
q, b = get_request_query_body_args(mck)
self.assertDictEqual(query_args, q)
self.assertDictEqual(form, b)

mck = mock.MagicMock(args=query_args, form=None, json=form)
q, b = _get_request_query_body_args(mck)
mck = mock.MagicMock(args=query_args, form=None, get_json=mock.Mock(return_value=form))
q, b = get_request_query_body_args(mck)
self.assertDictEqual(query_args, q)
self.assertDictEqual(form, b)

mck = mock.MagicMock(args=query_args, form=None, json=None, data=json.dumps(form))
q, b = _get_request_query_body_args(mck)
mck = mock.MagicMock(args=query_args, form=None, get_json=mock.Mock(return_value=None), data=json.dumps(form))
q, b = get_request_query_body_args(mck)
self.assertDictEqual(query_args, q)
self.assertDictEqual(form, b)
self.assertDictEqual({}, b)

def test_register_route_invalid_options(self):
"""
Expand Down

0 comments on commit 54f56e5

Please sign in to comment.