Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions restfulpy/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,31 @@ def test_exact(self):
result.update(context.query_string)
return result

@json
@validate_form(
type_={
'typedParam1': float,
'typedParam2': float,
'typedParam3': float,
},
client={
'type_': {
'typedParam1': int,
'typedParam2': int
}
},
admin={
'type_': {
'typedParam1': complex,
'typedParam4': complex
}
}
)
def test_type(self):
result = copy.deepcopy(context.form)
result.update(context.query_string)
return result


class Root(RootController):
validation = ValidationController()
Expand Down Expand Up @@ -530,6 +555,83 @@ def test_validation_exact(self):
params={'exactParamForAdmin': 'param'}, expected_status=400
)

def test_validation_type_(self):
# Test `type`
# role -> All
self.wsgi_app.jwt_token = DummyIdentity().dump().decode()
result, ___ = self.request(
'All', 'TEST_TYPE', '/validation',
doc=False,
params={
'typedParam1': '1',
'typedParam2': '2',
'typedParam3': '3',
'typedParam4': '4'
}
)
self.assertEqual(type(result['typedParam1']), float)
self.assertEqual(type(result['typedParam2']), float)
self.assertEqual(type(result['typedParam3']), float)
self.assertEqual(type(result['typedParam4']), str)

self.request(
'All', 'TEST_TYPE', '/validation',
doc=False,
params={'typedParam1': 'not_convertible'},
expected_status=400
)

# -----------------------------
# role -> Client
self.wsgi_app.jwt_token = DummyIdentity('client').dump().decode()
result, ___ = self.request(
'Client', 'TEST_TYPE', '/validation',
doc=False,
params={
'typedParam1': '1',
'typedParam2': '2',
'typedParam3': '3',
'typedParam4': '4'
}
)
self.assertEqual(type(result['typedParam1']), int)
self.assertEqual(type(result['typedParam2']), int)
self.assertEqual(type(result['typedParam3']), float)
self.assertEqual(type(result['typedParam4']), str)

self.request(
'Client', 'TEST_TYPE', '/validation',
doc=False,
params={'typedParam1': 'not_convertible'},
expected_status=400
)

# -----------------------------
# role -> Admin
self.wsgi_app.jwt_token = DummyIdentity('admin').dump().decode()
result, ___ = self.request(
'Admin', 'TEST_TYPE', '/validation',
doc=False,
params={
'typedParam1': '1',
'typedParam2': '2',
'typedParam3': '3',
'typedParam4': '4'
}
)
# type complex is dict
self.assertEqual(type(result['typedParam1']), dict)
self.assertEqual(type(result['typedParam2']), float)
self.assertEqual(type(result['typedParam3']), float)
self.assertEqual(type(result['typedParam4']), dict)

self.request(
'Admin', 'TEST_TYPE', '/validation',
doc=False,
params={'typedParam1': 'not_convertible'},
expected_status=400
)


if __name__ == '__main__': # pragma: no cover
unittest.main()
41 changes: 37 additions & 4 deletions restfulpy/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,17 @@


class FormValidator:
def __init__(self, blacklist=None, exclude=None, filter_=None, whitelist=None, requires=None, exact=None,
**rules_per_role):
def __init__(
self,
blacklist=None,
exclude=None,
filter_=None,
whitelist=None,
requires=None,
exact=None,
type_=None,
**rules_per_role
):
self.default_rules = {}
if blacklist:
self.default_rules['blacklist'] = set(blacklist)
Expand All @@ -27,13 +36,22 @@ def __init__(self, blacklist=None, exclude=None, filter_=None, whitelist=None, r
if exact:
self.default_rules['exact'] = set(exact)

if type_:
self.default_rules['type_'] = type_

self._rules_per_role = rules_per_role

def extract_rule_fields(self, rule_name, user_rules):
return set(chain(
*[ruleset[rule_name] for ruleset in ([self.default_rules] + user_rules) if rule_name in ruleset]
))

def extract_rule_fields_with_values(self, rule_name, user_rules):
for user_rule in user_rules:
if rule_name in user_rule:
return {**self.default_rules.get(rule_name, {}), **user_rule[rule_name]}
return self.default_rules.get(rule_name, {})

def __call__(self, *args, **kwargs):
input_collections = [context.form, context.query_string]
all_input_fields = set(chain(*input_collections))
Expand Down Expand Up @@ -74,10 +92,22 @@ def __call__(self, *args, **kwargs):
if exact_fields != all_input_fields:
raise HttpBadRequest('Exactly these fields are allowed: [%s]' % ', '.join(whitelist_fields))

typed_fields = self.extract_rule_fields_with_values('type_', user_rules)
if typed_fields:
for collection in input_collections:
for field, callable_type in typed_fields.items():
if field in collection:
try:
collection[field] = callable_type(collection[field])
except ValueError:
raise HttpBadRequest(
'Cant cast %s type to %s type' % (type(collection[field]), callable_type.__name__)
)

return args, kwargs


def validate_form(blacklist=None, exclude=None, filter_=None, whitelist=None, requires=None, exact=None,
def validate_form(blacklist=None, exclude=None, filter_=None, whitelist=None, requires=None, exact=None, type_=None,
**rules_per_role):
"""Creates a validation decorator based on given rules:

Expand All @@ -90,14 +120,17 @@ def validate_form(blacklist=None, exclude=None, filter_=None, whitelist=None, re
in the request payload.
:param exact: A list of fields to raise :class:`nanohttp.exceptions.HttpBadRequest` if the given fields are not
exact match.
:param type_: A dictionary of fields and their expected types. Fields will be casted to expected types if possible.
Otherwise :class:`nanohttp.exceptions.HttpBadRequest` will be raised.

:param rules_per_role: A dictionary ``{ role: { ... } }``, which you can apply above rules to single role.

:return: A validation decorator.
"""

def decorator(func):
validator = FormValidator(blacklist=blacklist, exclude=exclude, filter_=filter_, whitelist=whitelist,
requires=requires, exact=exact, **rules_per_role)
requires=requires, exact=exact, type_=type_, **rules_per_role)

@functools.wraps(func)
def wrapper(*args, **kwargs):
Expand Down