Skip to content

Commit

Permalink
Merge ca4dfb4 into f3649c5
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgalvez-tiendeo committed Jan 6, 2021
2 parents f3649c5 + ca4dfb4 commit 91bcce6
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 27 deletions.
29 changes: 24 additions & 5 deletions flask_rbac/__init__.py
Expand Up @@ -7,6 +7,7 @@
"""

import itertools
from collections import defaultdict

from flask import request, abort, _request_ctx_stack

Expand Down Expand Up @@ -256,7 +257,7 @@ def a_view_func():
roles = _user.get_roles()
return self._check_permission(roles, method, endpoint)

def allow(self, roles, methods, with_children=True):
def allow(self, roles, methods, with_children=True, endpoint=None):
"""This is a decorator function.
You can allow roles to access the view func with it.
Expand All @@ -280,12 +281,13 @@ def website_setting():
"""
def decorator(view_func):
_methods = [m.upper() for m in methods]
for r, m, v in itertools.product(roles, _methods, [view_func.__name__]):
resource = [endpoint or view_func.__name__]
for r, m, v in itertools.product(roles, _methods, resource):
self.before_acl['allow'].append((r, m, v, with_children))
return view_func
return decorator

def deny(self, roles, methods, with_children=False):
def deny(self, roles, methods, with_children=False, endpoint=None):
"""This is a decorator function.
You can deny roles to access the view func with it.
Expand All @@ -305,7 +307,8 @@ def article_post():
"""
def decorator(view_func):
_methods = [m.upper() for m in methods]
for r, m, v in itertools.product(roles, _methods, [view_func.__name__]):
resource = [endpoint or view_func.__name__]
for r, m, v in itertools.product(roles, _methods, resource):
self.before_acl['deny'].append((r, m, v, with_children))
return view_func
return decorator
Expand Down Expand Up @@ -361,7 +364,6 @@ def _authenticate(self):
(current_user, self._user_model.__class__))

resource = request.endpoint

if not resource:
abort(404)

Expand All @@ -378,6 +380,7 @@ def _authenticate(self):
return self._deny_hook()

def _check_permission(self, roles, method, resource):

if self.acl.is_exempt(resource):
return True

Expand All @@ -400,6 +403,7 @@ def _check_permission(self, roles, method, resource):

if is_allowed != True and self.acl.is_allowed(r.get_name(), m, res):
is_allowed = True
break

if self.use_white:
permit = (is_allowed is True)
Expand All @@ -422,6 +426,21 @@ def _setup_acl(self):
else:
role = self._role_model.get_by_name(rn)
self.acl.allow(role, method, resource, with_children)

if not self.use_white:
to_deny_map = defaultdict(list)
all_roles = {x.get_name() if not isinstance(x, str)
else x for x in self._role_model.get_all()}

for role, method, resource, with_children in self.before_acl['allow']:
to_deny_map[(resource, role, with_children)].append(method)
for k, methods in to_deny_map.items():
v, role, with_children, = k
for r in all_roles - {role}:
for m in methods:
if (r, m, v, with_children) not in self.before_acl['allow']:
self.before_acl['deny'].append((r, m, v, with_children))

for rn, method, resource, with_children in self.before_acl['deny']:
role = self._role_model.get_by_name(rn)
self.acl.deny(role, method, resource, with_children)
Expand Down
4 changes: 4 additions & 0 deletions flask_rbac/model.py
Expand Up @@ -69,6 +69,10 @@ def get_by_name(name):
"""
return RoleMixin.roles[name]

@classmethod
def get_all(cls):
return cls.roles


class UserMixin(object):
"""This provides implementations for the methods that Flask-RBAC wants
Expand Down
44 changes: 22 additions & 22 deletions test_rbac.py
@@ -1,7 +1,7 @@
import unittest

from flask import Flask, Response, make_response
from flask.ext.login import current_user as login_user
from flask_login import current_user as login_user

from flask_rbac import RBAC, UserMixin, RoleMixin

Expand Down Expand Up @@ -110,7 +110,7 @@ def e():
@before_decorator
def f():
return Response('Hello from /f')

@app.route('/g', methods=['GET'])
@after_decorator
@rbac.exempt
Expand Down Expand Up @@ -146,18 +146,18 @@ def test_set_user_loader(self):
def test_allow_get_view(self):
global current_user
current_user = anonymous
self.assertEqual(self.client.open('/').data, 'index')
self.assertEqual(self.client.open('/').data.decode('utf-8'), 'index')

current_user = normal_user
self.assertEqual(self.client.open('/').data, 'index')
self.assertEqual(self.client.open('/b').data, 'Hello from /b')
self.assertEqual(self.client.open('/').data.decode('utf-8'), 'index')
self.assertEqual(self.client.open('/b').data.decode('utf-8'), 'Hello from /b')

current_user = staff_role_user
self.assertEqual(self.client.open('/').data, 'index')
self.assertEqual(self.client.open('/b').data, 'Hello from /b')
self.assertEqual(self.client.open('/').data.decode('utf-8'), 'index')
self.assertEqual(self.client.open('/b').data.decode('utf-8'), 'Hello from /b')

current_user = special_user
self.assertEqual(self.client.open('/a').data, 'Hello')
self.assertEqual(self.client.open('/a').data.decode('utf-8'), 'Hello')

def test_deny_get_view(self):
global current_user
Expand All @@ -177,10 +177,10 @@ def test_deny_get_view(self):
def test_allow_post_view(self):
global current_user
current_user = staff_role_user
self.assertEqual(self.client.post('/b').data, 'Hello from /b')
self.assertEqual(self.client.post('/b').data.decode('utf-8'), 'Hello from /b')

current_user = special_user
self.assertEqual(self.client.post('/b').data, 'Hello from /b')
self.assertEqual(self.client.post('/b').data.decode('utf-8'), 'Hello from /b')

def test_deny_post_view(self):
global current_user
Expand All @@ -193,20 +193,20 @@ def test_deny_post_view(self):
def test_complicate_get_view(self):
global current_user
current_user = anonymous
self.assertEqual(self.client.open('/c').data, 'Hello from /c')
self.assertEqual(self.client.open('/c').data.decode('utf-8'), 'Hello from /c')

current_user = normal_user
self.assertEqual(self.client.open('/c').status_code, 403)

current_user = staff_role_user
self.assertEqual(self.client.open('/c').data, 'Hello from /c')
self.assertEqual(self.client.open('/c').data.decode('utf-8'), 'Hello from /c')

def test_hook(self):
global current_user
current_user = special_user
self.rbac.set_hook(lambda: make_response('Permission Denied', 403))
self.assertEqual(self.client.open('/').status_code, 403)
self.assertEqual(self.client.open('/').data, 'Permission Denied')
self.assertEqual(self.client.open('/').data.decode('utf-8'), 'Permission Denied')

def test_has_permission(self):
global current_user
Expand All @@ -227,20 +227,20 @@ def test_has_permission(self):

current_user = None
self.assertTrue(self.rbac.has_permission('GET', 'h'))
self.assertEqual(self.client.open('/h').data, 'Hello from /h')
self.assertEqual(self.client.open('/h').data.decode('utf-8'), 'Hello from /h')


def test_exempt(self):
global current_user

current_user = anonymous
self.assertEqual(self.client.open('/g').data, 'Hello from /g')
self.assertEqual(self.client.open('/g').data.decode('utf-8'), 'Hello from /g')

current_user = special_user
self.assertEqual(self.client.open('/g').data, 'Hello from /g')
self.assertEqual(self.client.open('/g').data.decode('utf-8'), 'Hello from /g')

current_user = normal_user
self.assertEqual(self.client.open('/g').data, 'Hello from /g')
self.assertEqual(self.client.open('/g').data.decode('utf-8'), 'Hello from /g')


class NoWhiteApplicationUnitTests(unittest.TestCase):
Expand All @@ -253,10 +253,10 @@ def setUp(self):
def test_allow_get_view(self):
global current_user
current_user = normal_user
self.assertEqual(self.client.open('/d').data, 'Hello from /d')
self.assertEqual(self.client.open('/d').data.decode('utf-8'), 'Hello from /d')

current_user = staff_role_user
self.assertEqual(self.client.open('/d').data, 'Hello from /d')
self.assertEqual(self.client.open('/d').data.decode('utf-8'), 'Hello from /d')

def test_deny_get_view(self):
global current_user
Expand All @@ -273,10 +273,10 @@ def test_deny_get_view(self):
def test_allow_post_view(self):
global current_user
current_user = anonymous
self.assertEqual(self.client.post('/f').data, 'Hello from /f')
self.assertEqual(self.client.post('/f').data.decode('utf-8'), 'Hello from /f')

current_user = staff_role_user
self.assertEqual(self.client.post('/f').data, 'Hello from /f')
self.assertEqual(self.client.post('/f').data.decode('utf-8'), 'Hello from /f')

def test_deny_post_view(self):
global current_user
Expand Down Expand Up @@ -386,7 +386,7 @@ class DecoratorUnitTests(unittest.TestCase):

def setUp(self):
self.rbac = RBAC()

@self.rbac.as_role_model
class RoleModel(RoleMixin):
pass
Expand Down

0 comments on commit 91bcce6

Please sign in to comment.