From ca4dfb449df224e89cef0aeea37c1dc70a603d86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20G=C3=A1lvez=20Mart=C3=ADnez?= Date: Wed, 6 Jan 2021 00:41:28 +0100 Subject: [PATCH] Add endpoint argument to decorators in order to allow to use it inside blueprints Improve performance of _check_permission method Automatically add deny rules when an allow decorator is added in order to deny access for all roles that not are in the decorator Fix tests --- flask_rbac/__init__.py | 29 +++++++++++++++++++++++----- flask_rbac/model.py | 4 ++++ test_rbac.py | 44 +++++++++++++++++++++--------------------- 3 files changed, 50 insertions(+), 27 deletions(-) diff --git a/flask_rbac/__init__.py b/flask_rbac/__init__.py index 9eafa96..e2f4a38 100644 --- a/flask_rbac/__init__.py +++ b/flask_rbac/__init__.py @@ -7,6 +7,7 @@ """ import itertools +from collections import defaultdict from flask import request, abort, _request_ctx_stack @@ -256,7 +257,7 @@ def a_view_func(): roles = _user.get_roles() return self._check_permission(roles, method, endpoint) - def allow(self, roles, methods, with_children=True): + def allow(self, roles, methods, with_children=True, endpoint=None): """This is a decorator function. You can allow roles to access the view func with it. @@ -280,12 +281,13 @@ def website_setting(): """ def decorator(view_func): _methods = [m.upper() for m in methods] - for r, m, v in itertools.product(roles, _methods, [view_func.__name__]): + resource = [endpoint or view_func.__name__] + for r, m, v in itertools.product(roles, _methods, resource): self.before_acl['allow'].append((r, m, v, with_children)) return view_func return decorator - def deny(self, roles, methods, with_children=False): + def deny(self, roles, methods, with_children=False, endpoint=None): """This is a decorator function. You can deny roles to access the view func with it. @@ -305,7 +307,8 @@ def article_post(): """ def decorator(view_func): _methods = [m.upper() for m in methods] - for r, m, v in itertools.product(roles, _methods, [view_func.__name__]): + resource = [endpoint or view_func.__name__] + for r, m, v in itertools.product(roles, _methods, resource): self.before_acl['deny'].append((r, m, v, with_children)) return view_func return decorator @@ -361,7 +364,6 @@ def _authenticate(self): (current_user, self._user_model.__class__)) resource = request.endpoint - if not resource: abort(404) @@ -378,6 +380,7 @@ def _authenticate(self): return self._deny_hook() def _check_permission(self, roles, method, resource): + if self.acl.is_exempt(resource): return True @@ -400,6 +403,7 @@ def _check_permission(self, roles, method, resource): if is_allowed != True and self.acl.is_allowed(r.get_name(), m, res): is_allowed = True + break if self.use_white: permit = (is_allowed is True) @@ -422,6 +426,21 @@ def _setup_acl(self): else: role = self._role_model.get_by_name(rn) self.acl.allow(role, method, resource, with_children) + + if not self.use_white: + to_deny_map = defaultdict(list) + all_roles = {x.get_name() if not isinstance(x, str) + else x for x in self._role_model.get_all()} + + for role, method, resource, with_children in self.before_acl['allow']: + to_deny_map[(resource, role, with_children)].append(method) + for k, methods in to_deny_map.items(): + v, role, with_children, = k + for r in all_roles - {role}: + for m in methods: + if (r, m, v, with_children) not in self.before_acl['allow']: + self.before_acl['deny'].append((r, m, v, with_children)) + for rn, method, resource, with_children in self.before_acl['deny']: role = self._role_model.get_by_name(rn) self.acl.deny(role, method, resource, with_children) diff --git a/flask_rbac/model.py b/flask_rbac/model.py index ba5d808..a8953d9 100644 --- a/flask_rbac/model.py +++ b/flask_rbac/model.py @@ -69,6 +69,10 @@ def get_by_name(name): """ return RoleMixin.roles[name] + @classmethod + def get_all(cls): + return cls.roles + class UserMixin(object): """This provides implementations for the methods that Flask-RBAC wants diff --git a/test_rbac.py b/test_rbac.py index 408dc54..6578d91 100644 --- a/test_rbac.py +++ b/test_rbac.py @@ -1,7 +1,7 @@ import unittest from flask import Flask, Response, make_response -from flask.ext.login import current_user as login_user +from flask_login import current_user as login_user from flask_rbac import RBAC, UserMixin, RoleMixin @@ -110,7 +110,7 @@ def e(): @before_decorator def f(): return Response('Hello from /f') - + @app.route('/g', methods=['GET']) @after_decorator @rbac.exempt @@ -146,18 +146,18 @@ def test_set_user_loader(self): def test_allow_get_view(self): global current_user current_user = anonymous - self.assertEqual(self.client.open('/').data, 'index') + self.assertEqual(self.client.open('/').data.decode('utf-8'), 'index') current_user = normal_user - self.assertEqual(self.client.open('/').data, 'index') - self.assertEqual(self.client.open('/b').data, 'Hello from /b') + self.assertEqual(self.client.open('/').data.decode('utf-8'), 'index') + self.assertEqual(self.client.open('/b').data.decode('utf-8'), 'Hello from /b') current_user = staff_role_user - self.assertEqual(self.client.open('/').data, 'index') - self.assertEqual(self.client.open('/b').data, 'Hello from /b') + self.assertEqual(self.client.open('/').data.decode('utf-8'), 'index') + self.assertEqual(self.client.open('/b').data.decode('utf-8'), 'Hello from /b') current_user = special_user - self.assertEqual(self.client.open('/a').data, 'Hello') + self.assertEqual(self.client.open('/a').data.decode('utf-8'), 'Hello') def test_deny_get_view(self): global current_user @@ -177,10 +177,10 @@ def test_deny_get_view(self): def test_allow_post_view(self): global current_user current_user = staff_role_user - self.assertEqual(self.client.post('/b').data, 'Hello from /b') + self.assertEqual(self.client.post('/b').data.decode('utf-8'), 'Hello from /b') current_user = special_user - self.assertEqual(self.client.post('/b').data, 'Hello from /b') + self.assertEqual(self.client.post('/b').data.decode('utf-8'), 'Hello from /b') def test_deny_post_view(self): global current_user @@ -193,20 +193,20 @@ def test_deny_post_view(self): def test_complicate_get_view(self): global current_user current_user = anonymous - self.assertEqual(self.client.open('/c').data, 'Hello from /c') + self.assertEqual(self.client.open('/c').data.decode('utf-8'), 'Hello from /c') current_user = normal_user self.assertEqual(self.client.open('/c').status_code, 403) current_user = staff_role_user - self.assertEqual(self.client.open('/c').data, 'Hello from /c') + self.assertEqual(self.client.open('/c').data.decode('utf-8'), 'Hello from /c') def test_hook(self): global current_user current_user = special_user self.rbac.set_hook(lambda: make_response('Permission Denied', 403)) self.assertEqual(self.client.open('/').status_code, 403) - self.assertEqual(self.client.open('/').data, 'Permission Denied') + self.assertEqual(self.client.open('/').data.decode('utf-8'), 'Permission Denied') def test_has_permission(self): global current_user @@ -227,20 +227,20 @@ def test_has_permission(self): current_user = None self.assertTrue(self.rbac.has_permission('GET', 'h')) - self.assertEqual(self.client.open('/h').data, 'Hello from /h') + self.assertEqual(self.client.open('/h').data.decode('utf-8'), 'Hello from /h') def test_exempt(self): global current_user current_user = anonymous - self.assertEqual(self.client.open('/g').data, 'Hello from /g') + self.assertEqual(self.client.open('/g').data.decode('utf-8'), 'Hello from /g') current_user = special_user - self.assertEqual(self.client.open('/g').data, 'Hello from /g') + self.assertEqual(self.client.open('/g').data.decode('utf-8'), 'Hello from /g') current_user = normal_user - self.assertEqual(self.client.open('/g').data, 'Hello from /g') + self.assertEqual(self.client.open('/g').data.decode('utf-8'), 'Hello from /g') class NoWhiteApplicationUnitTests(unittest.TestCase): @@ -253,10 +253,10 @@ def setUp(self): def test_allow_get_view(self): global current_user current_user = normal_user - self.assertEqual(self.client.open('/d').data, 'Hello from /d') + self.assertEqual(self.client.open('/d').data.decode('utf-8'), 'Hello from /d') current_user = staff_role_user - self.assertEqual(self.client.open('/d').data, 'Hello from /d') + self.assertEqual(self.client.open('/d').data.decode('utf-8'), 'Hello from /d') def test_deny_get_view(self): global current_user @@ -273,10 +273,10 @@ def test_deny_get_view(self): def test_allow_post_view(self): global current_user current_user = anonymous - self.assertEqual(self.client.post('/f').data, 'Hello from /f') + self.assertEqual(self.client.post('/f').data.decode('utf-8'), 'Hello from /f') current_user = staff_role_user - self.assertEqual(self.client.post('/f').data, 'Hello from /f') + self.assertEqual(self.client.post('/f').data.decode('utf-8'), 'Hello from /f') def test_deny_post_view(self): global current_user @@ -386,7 +386,7 @@ class DecoratorUnitTests(unittest.TestCase): def setUp(self): self.rbac = RBAC() - + @self.rbac.as_role_model class RoleModel(RoleMixin): pass