diff --git a/restfulpy/tests/test_validation.py b/restfulpy/tests/test_validation.py index 8fa930e..079b434 100644 --- a/restfulpy/tests/test_validation.py +++ b/restfulpy/tests/test_validation.py @@ -67,6 +67,31 @@ def test_exact(self): result.update(context.query_string) return result + @json + @validate_form( + type_={ + 'typedParam1': float, + 'typedParam2': float, + 'typedParam3': float, + }, + client={ + 'type_': { + 'typedParam1': int, + 'typedParam2': int + } + }, + admin={ + 'type_': { + 'typedParam1': complex, + 'typedParam4': complex + } + } + ) + def test_type(self): + result = copy.deepcopy(context.form) + result.update(context.query_string) + return result + class Root(RootController): validation = ValidationController() @@ -530,6 +555,83 @@ def test_validation_exact(self): params={'exactParamForAdmin': 'param'}, expected_status=400 ) + def test_validation_type_(self): + # Test `type` + # role -> All + self.wsgi_app.jwt_token = DummyIdentity().dump().decode() + result, ___ = self.request( + 'All', 'TEST_TYPE', '/validation', + doc=False, + params={ + 'typedParam1': '1', + 'typedParam2': '2', + 'typedParam3': '3', + 'typedParam4': '4' + } + ) + self.assertEqual(type(result['typedParam1']), float) + self.assertEqual(type(result['typedParam2']), float) + self.assertEqual(type(result['typedParam3']), float) + self.assertEqual(type(result['typedParam4']), str) + + self.request( + 'All', 'TEST_TYPE', '/validation', + doc=False, + params={'typedParam1': 'not_convertible'}, + expected_status=400 + ) + + # ----------------------------- + # role -> Client + self.wsgi_app.jwt_token = DummyIdentity('client').dump().decode() + result, ___ = self.request( + 'Client', 'TEST_TYPE', '/validation', + doc=False, + params={ + 'typedParam1': '1', + 'typedParam2': '2', + 'typedParam3': '3', + 'typedParam4': '4' + } + ) + self.assertEqual(type(result['typedParam1']), int) + self.assertEqual(type(result['typedParam2']), int) + self.assertEqual(type(result['typedParam3']), float) + self.assertEqual(type(result['typedParam4']), str) + + self.request( + 'Client', 'TEST_TYPE', '/validation', + doc=False, + params={'typedParam1': 'not_convertible'}, + expected_status=400 + ) + + # ----------------------------- + # role -> Admin + self.wsgi_app.jwt_token = DummyIdentity('admin').dump().decode() + result, ___ = self.request( + 'Admin', 'TEST_TYPE', '/validation', + doc=False, + params={ + 'typedParam1': '1', + 'typedParam2': '2', + 'typedParam3': '3', + 'typedParam4': '4' + } + ) + # type complex is dict + self.assertEqual(type(result['typedParam1']), dict) + self.assertEqual(type(result['typedParam2']), float) + self.assertEqual(type(result['typedParam3']), float) + self.assertEqual(type(result['typedParam4']), dict) + + self.request( + 'Admin', 'TEST_TYPE', '/validation', + doc=False, + params={'typedParam1': 'not_convertible'}, + expected_status=400 + ) + if __name__ == '__main__': # pragma: no cover unittest.main() diff --git a/restfulpy/validation.py b/restfulpy/validation.py index 33bc914..960f2b2 100644 --- a/restfulpy/validation.py +++ b/restfulpy/validation.py @@ -6,8 +6,17 @@ class FormValidator: - def __init__(self, blacklist=None, exclude=None, filter_=None, whitelist=None, requires=None, exact=None, - **rules_per_role): + def __init__( + self, + blacklist=None, + exclude=None, + filter_=None, + whitelist=None, + requires=None, + exact=None, + type_=None, + **rules_per_role + ): self.default_rules = {} if blacklist: self.default_rules['blacklist'] = set(blacklist) @@ -27,6 +36,9 @@ def __init__(self, blacklist=None, exclude=None, filter_=None, whitelist=None, r if exact: self.default_rules['exact'] = set(exact) + if type_: + self.default_rules['type_'] = type_ + self._rules_per_role = rules_per_role def extract_rule_fields(self, rule_name, user_rules): @@ -34,6 +46,12 @@ def extract_rule_fields(self, rule_name, user_rules): *[ruleset[rule_name] for ruleset in ([self.default_rules] + user_rules) if rule_name in ruleset] )) + def extract_rule_fields_with_values(self, rule_name, user_rules): + for user_rule in user_rules: + if rule_name in user_rule: + return {**self.default_rules.get(rule_name, {}), **user_rule[rule_name]} + return self.default_rules.get(rule_name, {}) + def __call__(self, *args, **kwargs): input_collections = [context.form, context.query_string] all_input_fields = set(chain(*input_collections)) @@ -74,10 +92,22 @@ def __call__(self, *args, **kwargs): if exact_fields != all_input_fields: raise HttpBadRequest('Exactly these fields are allowed: [%s]' % ', '.join(whitelist_fields)) + typed_fields = self.extract_rule_fields_with_values('type_', user_rules) + if typed_fields: + for collection in input_collections: + for field, callable_type in typed_fields.items(): + if field in collection: + try: + collection[field] = callable_type(collection[field]) + except ValueError: + raise HttpBadRequest( + 'Cant cast %s type to %s type' % (type(collection[field]), callable_type.__name__) + ) + return args, kwargs -def validate_form(blacklist=None, exclude=None, filter_=None, whitelist=None, requires=None, exact=None, +def validate_form(blacklist=None, exclude=None, filter_=None, whitelist=None, requires=None, exact=None, type_=None, **rules_per_role): """Creates a validation decorator based on given rules: @@ -90,6 +120,9 @@ def validate_form(blacklist=None, exclude=None, filter_=None, whitelist=None, re in the request payload. :param exact: A list of fields to raise :class:`nanohttp.exceptions.HttpBadRequest` if the given fields are not exact match. + :param type_: A dictionary of fields and their expected types. Fields will be casted to expected types if possible. + Otherwise :class:`nanohttp.exceptions.HttpBadRequest` will be raised. + :param rules_per_role: A dictionary ``{ role: { ... } }``, which you can apply above rules to single role. :return: A validation decorator. @@ -97,7 +130,7 @@ def validate_form(blacklist=None, exclude=None, filter_=None, whitelist=None, re def decorator(func): validator = FormValidator(blacklist=blacklist, exclude=exclude, filter_=filter_, whitelist=whitelist, - requires=requires, exact=exact, **rules_per_role) + requires=requires, exact=exact, type_=type_, **rules_per_role) @functools.wraps(func) def wrapper(*args, **kwargs):