Skip to content
This repository has been archived by the owner on Sep 28, 2022. It is now read-only.

Commit

Permalink
Merge pull request #71 from brandicted/acl-refactor
Browse files Browse the repository at this point in the history
Acl refactor
  • Loading branch information
postatum committed Aug 28, 2015
2 parents a58092c + 4fe11fb commit a87f379
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 83 deletions.
73 changes: 33 additions & 40 deletions ramses/acl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
Allow, Deny,
Everyone, Authenticated,
ALL_PERMISSIONS)
from nefertari.acl import SelfParamMixin
from nefertari.acl import CollectionACL
from nefertari.resource import PERMISSIONS
from nefertari.elasticsearch import ES

from .views import collection_methods, item_methods
from .utils import resolve_to_callable, is_callable_tag
Expand All @@ -30,8 +32,7 @@ def methods_to_perms(perms, methods_map):
the keyword 'all' into a set of valid Pyramid permissions.
:param perms: List or comma-separated string of HTTP methods, or 'all'
:param methods_map: Map of HTTP methods to permission names (nefertari view
methods)
:param methods_map: Map of HTTP methods to nefertari view methods
"""
if isinstance(perms, six.string_types):
perms = perms.split(',')
Expand All @@ -40,7 +41,7 @@ def methods_to_perms(perms, methods_map):
return ALL_PERMISSIONS
else:
try:
return [methods_map[p] for p in perms]
return [PERMISSIONS[methods_map[p]] for p in perms]
except KeyError:
raise ValueError(
'Unknown method name in permissions: {}. Valid methods: '
Expand Down Expand Up @@ -96,15 +97,12 @@ def parse_acl(acl_string, methods_map):
return result_acl


class BaseACL(SelfParamMixin):
class BaseACL(CollectionACL):
""" ACL Base class. """
__context_class__ = None
collection_acl = None
item_acl = None

def __init__(self, request):
super(BaseACL, self).__init__()
self.request = request
es_based = False
_collection_acl = (ALLOW_ALL, )
_item_acl = (ALLOW_ALL, )

def _apply_callables(self, acl, methods_map, obj=None):
""" Iterate over ACEs from :acl: and apply callable principals if any.
Expand Down Expand Up @@ -138,42 +136,37 @@ def _apply_callables(self, acl, methods_map, obj=None):
return new_acl

def __acl__(self):
""" Apply callables to `self.collection_acl` and return result. """
""" Apply callables to `self._collection_acl` and return result. """
return self._apply_callables(
acl=self.collection_acl,
acl=self._collection_acl,
methods_map=collection_methods)

def context_acl(self, obj):
""" Apply callables to `self.item_acl` and return result. """
def item_acl(self, item):
""" Apply callables to `self._item_acl` and return result. """
return self._apply_callables(
acl=self.item_acl,
acl=self._item_acl,
methods_map=item_methods,
obj=obj)
obj=item)

def item_db_id(self, key):
# ``self`` can be used for current authenticated user key
if key != 'self':
return key
user = getattr(self.request, 'user', None)
if user is None or not isinstance(user, self.item_model):
return key
return getattr(user, user.pk_field())

def __getitem__(self, key):
""" Get item using method depending on value of `self.es_based` """
key = self.resolve_self_key(key)
if self.es_based:
return self.getitem_es(key=key)
else:
return self.getitem_db(key=key)

def getitem_db(self, key):
""" Get item with ID of :key: from database """
pk_field = self.__context_class__.pk_field()
obj = self.__context_class__.get_resource(
**{pk_field: key})
obj.__acl__ = self.context_acl(obj)
obj.__parent__ = self
obj.__name__ = key
return obj
if not self.es_based:
return super(BaseACL, self).__getitem__(key)
return self.getitem_es(self.item_db_id(key))

def getitem_es(self, key):
""" Get item with ID of :key: from elasticsearch """
from nefertari.elasticsearch import ES
es = ES(self.__context_class__.__name__)
es = ES(self.item_model.__name__)
obj = es.get_resource(id=key)
obj.__acl__ = self.context_acl(obj)
obj.__acl__ = self.item_acl(obj)
obj.__parent__ = self
obj.__name__ = key
return obj
Expand All @@ -182,7 +175,7 @@ def getitem_es(self, key):
def generate_acl(model_cls, raml_resource, es_based=True):
""" Generate an ACL.
Generated ACL class has a `__context_class__` attribute set to
Generated ACL class has a `item_model` attribute set to
:model_cls:.
ACLs used for collection and item access control are generated from a
Expand Down Expand Up @@ -216,12 +209,12 @@ def generate_acl(model_cls, raml_resource, es_based=True):
methods_map=item_methods)

class GeneratedACL(BaseACL):
__context_class__ = model_cls
item_model = model_cls

def __init__(self, request, es_based=es_based):
super(GeneratedACL, self).__init__(request=request)
self.es_based = es_based
self.collection_acl = collection_acl
self.item_acl = item_acl
self._collection_acl = collection_acl
self._item_acl = item_acl

return GeneratedACL
4 changes: 2 additions & 2 deletions ramses/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def reload_context(self, es_based, **kwargs):
kwargs['es_based'] = es_based

acl = self._factory(**kwargs)
if acl.__context_class__ is None:
acl.__context_class__ = self.Model
if acl.item_model is None:
acl.item_model = self.Model

self.context = acl[key]

Expand Down
80 changes: 40 additions & 40 deletions tests/test_acl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def test_methods_to_perms_invalid_perm_name(self):

def test_methods_to_perms(self):
perms = acl.methods_to_perms('get', self.methods_map)
assert perms == ['index']
assert perms == ['view']
perms = acl.methods_to_perms('get,post', self.methods_map)
assert sorted(perms) == ['create', 'index']
assert sorted(perms) == ['create', 'view']

def test_parse_acl_no_string(self):
perms = acl.parse_acl('', self.methods_map)
Expand Down Expand Up @@ -85,12 +85,12 @@ def test_no_security(self, mock_parse):
model_cls='Foo',
raml_resource=Mock(security_schemes=[]),
es_based=True)
assert acl_cls.__context_class__ == 'Foo'
assert acl_cls.item_model == 'Foo'
assert issubclass(acl_cls, acl.BaseACL)
instance = acl_cls(request=None)
assert instance.es_based
assert instance.collection_acl == [acl.ALLOW_ALL]
assert instance.item_acl == [acl.ALLOW_ALL]
assert instance._collection_acl == [acl.ALLOW_ALL]
assert instance._item_acl == [acl.ALLOW_ALL]
assert not mock_parse.called

def test_wrong_security_scheme_type(self, mock_parse):
Expand All @@ -102,12 +102,12 @@ def test_wrong_security_scheme_type(self, mock_parse):
raml_resource=raml_resource,
es_based=False)
assert not mock_parse.called
assert acl_cls.__context_class__ == 'Foo'
assert acl_cls.item_model == 'Foo'
assert issubclass(acl_cls, acl.BaseACL)
instance = acl_cls(request=None)
assert not instance.es_based
assert instance.collection_acl == [acl.ALLOW_ALL]
assert instance.item_acl == [acl.ALLOW_ALL]
assert instance._collection_acl == [acl.ALLOW_ALL]
assert instance._item_acl == [acl.ALLOW_ALL]

def test_correct_security_scheme(self, mock_parse):
raml_resource = Mock(security_schemes=[
Expand All @@ -122,18 +122,18 @@ def test_correct_security_scheme(self, mock_parse):
call(acl_string=7, methods_map=acl.item_methods),
])
instance = acl_cls(request=None)
assert instance.collection_acl == mock_parse()
assert instance.item_acl == mock_parse()
assert instance._collection_acl == mock_parse()
assert instance._item_acl == mock_parse()
assert not instance.es_based


class TestBaseACL(object):

def test_init(self):
obj = acl.BaseACL(request='Foo')
assert obj.__context_class__ is None
assert obj.collection_acl is None
assert obj.item_acl is None
assert obj.item_model is None
assert obj._collection_acl == (acl.ALLOW_ALL,)
assert obj._item_acl == (acl.ALLOW_ALL,)
assert obj.request == 'Foo'

def test_apply_callables_no_callables(self):
Expand Down Expand Up @@ -199,11 +199,11 @@ def test_apply_callables_functional(self):
acl=[(Deny, principal, ALL_PERMISSIONS)],
methods_map=acl.item_methods,
)
assert new_acl == [(Allow, Everyone, ['show'])]
assert new_acl == [(Allow, Everyone, ['view'])]

def test_magic_acl(self):
obj = acl.BaseACL('req')
obj.collection_acl = [(1, 2, 3)]
obj._collection_acl = [(1, 2, 3)]
obj._apply_callables = Mock()
result = obj.__acl__()
obj._apply_callables.assert_called_once_with(
Expand All @@ -212,11 +212,11 @@ def test_magic_acl(self):
)
assert result == obj._apply_callables()

def test_context_acl(self):
def test_item_acl(self):
obj = acl.BaseACL('req')
obj.item_acl = [(1, 2, 3)]
obj._item_acl = [(1, 2, 3)]
obj._apply_callables = Mock()
result = obj.context_acl(obj='foobar')
result = obj.item_acl('foobar')
obj._apply_callables.assert_called_once_with(
acl=[(1, 2, 3)],
methods_map=acl.item_methods,
Expand All @@ -226,50 +226,50 @@ def test_context_acl(self):

def test_magic_getitem_es_based(self):
obj = acl.BaseACL('req')
obj.resolve_self_key = Mock()
obj.item_db_id = Mock(return_value=42)
obj.getitem_es = Mock()
obj.es_based = True
obj.__getitem__(1)
obj.resolve_self_key.assert_called_once_with(1)
obj.getitem_es.assert_called_once_with(key=obj.resolve_self_key())
obj.item_db_id.assert_called_once_with(1)
obj.getitem_es.assert_called_once_with(42)

def test_magic_getitem_db_based(self):
obj = acl.BaseACL('req')
obj.resolve_self_key = Mock()
obj.getitem_db = Mock()
obj.item_db_id = Mock(return_value = 42)
obj.item_model = Mock()
obj.item_model.pk_field.return_value = 'id'
obj.es_based = False
obj.__getitem__(1)
obj.resolve_self_key.assert_called_once_with(1)
obj.getitem_db.assert_called_once_with(key=obj.resolve_self_key())
obj.item_db_id.assert_called_once_with(1)

def test_getitem_db(self):
obj = acl.BaseACL('req')
obj.__context_class__ = Mock()
obj.__context_class__.pk_field.return_value = 'myname'
obj.context_acl = Mock()
value = obj.getitem_db(key='varvar')
obj.__context_class__.get_resource.assert_called_once_with(
myname='varvar')
obj.context_acl.assert_called_once_with(
obj.__context_class__.get_resource())
assert value.__acl__ == obj.context_acl()
obj.item_model = Mock()
obj.item_model.pk_field.return_value = 'myname'
obj.item_acl = Mock()
value = obj['varvar']
obj.item_model.get.assert_called_once_with(
__raise=True, myname='varvar')
obj.item_acl.assert_called_once_with(
obj.item_model.get())
assert value.__acl__ == obj.item_acl()
assert value.__parent__ is obj
assert value.__name__ == 'varvar'

@patch('nefertari.elasticsearch.ES')
@patch('ramses.acl.ES')
def test_getitem_es(self, mock_es):
found_obj = Mock()
es_obj = Mock()
es_obj.get_resource.return_value = found_obj
mock_es.return_value = es_obj
obj = acl.BaseACL('req')
obj.__context_class__ = Mock(__name__='Foo')
obj.__context_class__.pk_field.return_value = 'myname'
obj.context_acl = Mock()
obj.item_model = Mock(__name__='Foo')
obj.item_model.pk_field.return_value = 'myname'
obj.item_acl = Mock()
value = obj.getitem_es(key='varvar')
mock_es.assert_called_with('Foo')
es_obj.get_resource.assert_called_once_with(id='varvar')
obj.context_acl.assert_called_once_with(found_obj)
assert value.__acl__ == obj.context_acl()
obj.item_acl.assert_called_once_with(found_obj)
assert value.__acl__ == obj.item_acl()
assert value.__parent__ is obj
assert value.__name__ == 'varvar'
2 changes: 1 addition & 1 deletion tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def test_parent_queryset(self):

def test_reload_context(self):
class Factory(dict):
__context_class__ = None
item_model = None

def __getitem__(self, key):
return key
Expand Down

0 comments on commit a87f379

Please sign in to comment.