Skip to content
This repository has been archived by the owner on Sep 28, 2022. It is now read-only.

Commit

Permalink
Merge branch 'develop' of github.com:brandicted/nefertari into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
chartpath committed Apr 17, 2015
2 parents cb0ed91 + 68fb7db commit a0ea2d7
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 193 deletions.
99 changes: 0 additions & 99 deletions nefertari/tests/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,109 +3,10 @@
import unittest
import mock
from nefertari import wrappers
from nefertari.view import BaseView
from nefertari.utils import dictset


class WrappersTest(unittest.TestCase):

def test_validator_decorator(self):
params = dictset(a=10, b='bbb', c=20, mixed=lambda: {})

req = mock.MagicMock(params=params)
res = mock.MagicMock(actions=['create', 'update', 'index'])

class MyView(BaseView):

__validation_schema__ = dict(
a=dict(type=int, required=True),
b=dict(type=str, required=False)
)

def __init__(self):
BaseView.__init__(self, res, req)

@wrappers.validator(c=dict(type=int, required=True),
a=dict(type=float, required=False))
def create(self):
pass

@wrappers.validator()
def update(self):
pass

@wrappers.validator(a=dict(type=int, required=False))
def index(self):
[]

def convert_ids2objects(self, *args, **kwargs):
pass

view = MyView()
self.assertEqual([wrappers.validate_types(),
wrappers.validate_required()],
view.create._before_calls)
self.assertIn('c', view.create._before_calls[0].kwargs)

self.assertEqual(dict(type=float, required=False),
view.create._before_calls[0].kwargs['a'])

def test_validate_types(self):
import datetime as dt

request = mock.MagicMock()
wrappers.validate_types()(request=request)

schema = dict(a=dict(type=int), b=dict(type=str),
c=dict(type=dt.datetime), d=dict(type=dt.date),
e=dict(type=None), f=dict(type='BadType'))

request.params = dict(a=1, b=2)
wrappers.validate_types(**schema)(request=request)

request.params = dict(c='2000-01-01T01:01:01')
wrappers.validate_types(**schema)(request=request)

request.params = dict(d='2000-01-01')
wrappers.validate_types(**schema)(request=request)

request.params = dict(c='bad_date')
with self.assertRaises(wrappers.ValidationError):
wrappers.validate_types(**schema)(request=request)

request.params = dict(d='bad_date')
with self.assertRaises(wrappers.ValidationError):
wrappers.validate_types(**schema)(request=request)

request.params = dict(e='unknown_type')
with mock.patch('nefertari.wrappers.log') as log:
wrappers.validate_types(**schema)(request=request)
self.assertTrue(log.debug.called)

request.params = dict(f='bad_type')
with self.assertRaises(wrappers.ValidationError):
wrappers.validate_types(**schema)(request=request)

def test_validate_required(self):
request = mock.MagicMock()
wrappers.validate_types()(request=request)

schema = dict(a=dict(type=int, required=True), b=dict(type=str,
required=False), c=dict(type=int))

request.params = dict(a=1, b=2, c=3)
wrappers.validate_required(**schema)(request=request)

request.params = dict(a=1, b=2)
wrappers.validate_required(**schema)(request=request)

request.params = dict(a=1, c=3)
wrappers.validate_required(**schema)(request=request)

request.params = dict(b=2, c=3)
with self.assertRaises(wrappers.ValidationError):
wrappers.validate_required(**schema)(request=request)

def test_obj2dict(self):
result = mock.MagicMock()
result.to_dict.return_value = dict(a=1)
Expand Down
1 change: 0 additions & 1 deletion nefertari/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def __init__(self, context, request, _params={}):
elif 'text/plain' in request.accept:
request.override_renderer = 'string'


self.setup_default_wrappers()
self.convert_ids2objects()

Expand Down
95 changes: 2 additions & 93 deletions nefertari/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import urllib
from datetime import datetime, date
from hashlib import md5

import logging
Expand Down Expand Up @@ -50,26 +49,6 @@ def __call__(self, meth):
return meth


class validator(wrap_me):
"""Decorator that validates the type and required fields in request params
against the supplied kwargs.
::
class MyView():
@validator(first_name={'type':int, 'required':True})
def index(self):
return response
"""

def __init__(self, **kwargs):
wrap_me.__init__(
self, before=[
validate_types(**kwargs),
validate_required(**kwargs)
])


class callable_base(object):
"""Base class for all before and after calls.
``__eq__`` method is overloaded in order to prevent duplicate callables
Expand All @@ -87,77 +66,6 @@ def __eq__(self, other):
return type(self) == type(other)


# Before calls

class validate_base(callable_base):
"""Base class for validation callables.
"""

def __call__(self, **kwargs):
self.request = kwargs['request']
self.params = self.request.params.copy()
# Tunneling internal param, no need to check.
self.params.pop('_method', None)


class validate_types(validate_base):
"""
Validates the field types in ``request.params`` match the types declared
in ``kwargs``. Raises ValidationError if there is mismatch.
"""

def __call__(self, **kwargs):
validate_base.__call__(self, **kwargs)
# checking the types
for name, value in self.params.items():
if value == 'None': # fix this properly.
continue
_type = self.kwargs.get(name, {}).get('type')
try:
if _type == datetime:
# must be in iso format
value = datetime.strptime(value, '%Y-%m-%dT%H:%M:%S')
elif _type == date:
# must be in iso format
value = datetime.strptime(value, '%Y-%m-%d')
elif _type == None:
log.debug('Incorrect or unsupported type for %s(%s)',
name, value)
continue
elif type(_type) is type:
_type(value)
else:
raise ValueError
except ValueError, e:
raise ValidationError(
'Bad type %s for %s=%s. Suppose to be %s' % (
type(value), name, value, _type))


class validate_required(validate_base):
"""Validates that fields in ``request.params`` are present
according to ``kwargs`` argument passed to ``__call__.
Raises ValidationError in case of the mismatch
"""

def __call__(self, **kwargs):
validate_base.__call__(self, **kwargs)
# Get parent resources' ids from matchdict, so there is no need
# to pass in the request.params
self.params.update(self.request.matchdict)

self.kwargs.pop('id', None)

required_fields = set([n for n in self.kwargs.keys()
if self.kwargs[n].get('required', False)])

if not required_fields.issubset(set(self.params.keys())):
raise ValidationError('Required fields: %s. Received: %s'
% (list(required_fields),
self.params.keys()))


# After calls.

class obj2dict(object):
Expand All @@ -182,7 +90,8 @@ def __call__(self, **kwargs):
#make sure its mutable, i.e list
result = list(result)
for ix, each in enumerate(result):
result[ix] = obj2dict(self.request)(_fields=_fields, result=each)
result[ix] = obj2dict(self.request)(
_fields=_fields, result=each)

return result

Expand Down

0 comments on commit a0ea2d7

Please sign in to comment.