Skip to content

Commit

Permalink
Merge pull request #23 from joe-gordian-software/master
Browse files Browse the repository at this point in the history
using __name__ instead of the function itself
  • Loading branch information
shonenada committed Jan 2, 2019
2 parents c61621f + ed87a3f commit e085121
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 33 deletions.
44 changes: 14 additions & 30 deletions flask_rbac/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ def deny(self, role, method, resource, with_children=False):
if permission not in self._denied:
self._denied.append(permission)

def exempt(self, view_func):
def exempt(self, resource):
"""Exempt a view function from being checked permission
:param view_func: The view function exempt from checking.
:param resource: The view function exempt from checking.
"""
if view_func not in self._exempt:
self._exempt.append(view_func)
if resource not in self._exempt:
self._exempt.append(resource)

def is_allowed(self, role, method, resource):
"""Check whether role is allowed to access resource
Expand All @@ -106,12 +106,12 @@ def is_denied(self, role, method, resource):
"""
return (role, method, resource) in self._denied

def is_exempt(self, view_func):
"""Return whether view_func is exempted.
def is_exempt(self, resource):
"""Return whether resource is exempted.
:param view_func: View function to be checked.
:param resource: View function to be checked.
"""
return view_func in self._exempt
return resource in self._exempt


class _RBACState(object):
Expand Down Expand Up @@ -174,7 +174,7 @@ def init_app(self, app):
app.extensions = {}
app.extensions['rbac'] = _RBACState(self, app)

self.acl.allow(anonymous, 'GET', app.view_functions['static'])
self.acl.allow(anonymous, 'GET', 'static')
app.before_first_request(self._setup_acl)

app.before_request(self._authenticate)
Expand Down Expand Up @@ -254,22 +254,7 @@ def a_view_func():
roles = [anonymous]
else:
roles = _user.get_roles()
view_func = app.view_functions[endpoint]
return self._check_permission(roles, method, view_func)

def check_perm(self, role, method, callback=None):
def decorator(view_func):
if not self._check_permission([role], method, view_func):
if callable(callback):
callback()
else:
self._deny_hook()
return view_func
return decorator

def user_loader(self, loader):
self._user_loader = loader
return loader
return self._check_permission(roles, method, endpoint)

def allow(self, roles, methods, with_children=True):
"""This is a decorator function.
Expand All @@ -295,7 +280,7 @@ def website_setting():
"""
def decorator(view_func):
_methods = [m.upper() for m in methods]
for r, m, v in itertools.product(roles, _methods, [view_func]):
for r, m, v in itertools.product(roles, _methods, [view_func.__name__]):
self.before_acl['allow'].append((r, m, v, with_children))
return view_func
return decorator
Expand All @@ -320,7 +305,7 @@ def article_post():
"""
def decorator(view_func):
_methods = [m.upper() for m in methods]
for r, m, v in itertools.product(roles, _methods, [view_func]):
for r, m, v in itertools.product(roles, _methods, [view_func.__name__]):
self.before_acl['deny'].append((r, m, v, with_children))
return view_func
return decorator
Expand All @@ -338,7 +323,7 @@ def everyone_can_access():
:param view_func: The view function going to be exempted.
"""
self.acl.exempt(view_func)
self.acl.exempt(view_func.__name__)
return view_func

def get_app(self, reference_app=None):
Expand Down Expand Up @@ -375,8 +360,7 @@ def _authenticate(self):
"%s is not an instance of %s" %
(current_user, self._user_model.__class__))

endpoint = request.endpoint
resource = app.view_functions.get(endpoint, None)
resource = request.endpoint

if not resource:
abort(404)
Expand Down
31 changes: 28 additions & 3 deletions test_rbac.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from flask_rbac import RBAC, UserMixin, RoleMixin

import functools


class Role(RoleMixin):
def __repr__(self):
Expand All @@ -31,8 +33,13 @@ def __repr__(self):

current_user = anonymous

def rewrite_decorator(viewfunc):
@functools.wraps(viewfunc)
def newfunc(*args, **kwargs):
return viewfunc(*args, **kwargs)
return newfunc

def makeapp(with_factory=False, use_white=False):
def makeapp(with_factory, use_white, before_decorator, after_decorator):
global current_user
app = Flask(__name__)
app.debug = True
Expand All @@ -53,50 +60,68 @@ def makeapp(with_factory=False, use_white=False):
rbac.set_role_model(Role)

@app.route('/')
@after_decorator
@rbac.allow(roles=['everyone'], methods=['GET'])
@before_decorator
def index():
return Response('index')

@app.route('/a')
@after_decorator
@rbac.allow(roles=['special'], methods=['GET'])
@before_decorator
def a():
return Response('Hello')

@app.route('/b', methods=['GET', 'POST'])
@after_decorator
@rbac.allow(roles=['logged_role'], methods=['GET'])
@rbac.allow(roles=['staff_role', 'special'], methods=['POST'])
@before_decorator
def b():
return Response('Hello from /b')

@app.route('/c')
@after_decorator
@rbac.allow(roles=['everyone'], methods=['GET'])
@rbac.deny(roles=['logged_role'], methods=['GET'], with_children=False)
@rbac.allow(roles=['staff_role'], methods=['GET'])
@before_decorator
def c():
return Response('Hello from /c')

@app.route('/d')
@after_decorator
@rbac.deny(roles=['everyone'], methods=['GET'])
@before_decorator
def d():
return Response('Hello from /d')

@app.route('/e')
@after_decorator
@rbac.deny(roles=['everyone'], methods=['GET'], with_children=True)
@before_decorator
def e():
return Response('Hello from /e')

@app.route('/f', methods=['POST'])
@after_decorator
@rbac.deny(roles=['logged_role'], methods=['POST'])
@before_decorator
def f():
return Response('Hello from /f')

@app.route('/g', methods=['GET'])
@after_decorator
@rbac.exempt
@before_decorator
def g():
return Response('Hello from /g')

@app.route('/h', methods=['GET'])
@after_decorator
@rbac.allow(['anonymous'], methods=['GET'], with_children=False)
@before_decorator
def h():
return Response('Hello from /h')

Expand All @@ -106,7 +131,7 @@ def h():
class UseWhiteApplicationUnitTests(unittest.TestCase):

def setUp(self):
self.app = makeapp(use_white=True)
self.app = makeapp(with_factory=False, use_white=True, before_decorator=rewrite_decorator, after_decorator=rewrite_decorator)
self.client = self.app.test_client()
self.rbac = self.app.extensions['rbac'].rbac

Expand Down Expand Up @@ -221,7 +246,7 @@ def test_exempt(self):
class NoWhiteApplicationUnitTests(unittest.TestCase):

def setUp(self):
self.app = makeapp(use_white=False)
self.app = makeapp(with_factory=False, use_white=False, before_decorator=rewrite_decorator, after_decorator=rewrite_decorator)
self.client = self.app.test_client()
self.rbac = self.app.extensions['rbac'].rbac

Expand Down

0 comments on commit e085121

Please sign in to comment.