From 38cd7f70e363123f474646eb7906d956ced79451 Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Tue, 22 Sep 2020 13:48:31 -0400 Subject: [PATCH 1/2] feat: introduced filtered policy support --- casbin_sqlalchemy_adapter/adapter.py | 67 +++++++++++++++++++++++++--- 1 file changed, 60 insertions(+), 7 deletions(-) diff --git a/casbin_sqlalchemy_adapter/adapter.py b/casbin_sqlalchemy_adapter/adapter.py index 9260ce9..ab839d5 100644 --- a/casbin_sqlalchemy_adapter/adapter.py +++ b/casbin_sqlalchemy_adapter/adapter.py @@ -1,6 +1,6 @@ from casbin import persist from sqlalchemy import Column, Integer, String -from sqlalchemy import create_engine +from sqlalchemy import create_engine, and_, or_ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker @@ -8,7 +8,7 @@ class CasbinRule(Base): - __tablename__ = 'casbin_rule' + __tablename__ = "casbin_rule" id = Column(Integer, primary_key=True) ptype = Column(String(255)) @@ -25,16 +25,32 @@ def __str__(self): if v is None: break arr.append(v) - return ', '.join(arr) + return ", ".join(arr) def __repr__(self): return ''.format(self.id, str(self)) +class Filter: + def __init__(self, v0=None, v1=None, v2=None, v3=None, v4=None, v5=None): + self.v0 = v0 + self.v1 = v1 + self.v2 = v2 + self.v3 = v3 + self.v4 = v4 + self.v5 = v5 + + +class PolicyFilter: + def __init__(self, p=None, g=None): + self.P = p or () + self.G = g or () + + class Adapter(persist.Adapter): """the interface for Casbin adapters.""" - def __init__(self, engine): + def __init__(self, engine, filtered): if isinstance(engine, str): self._engine = create_engine(engine) else: @@ -44,6 +60,43 @@ def __init__(self, engine): self._session = session() Base.metadata.create_all(self._engine) + self._filtered = filtered + + def is_filtered(self): + return self._filtered + + def load_filtered_policy(self, model, filter) -> None: + """loads all policy rules from the storage.""" + + self._filtered = True + query = self._session.query(self._db_class) + filters = [] + for p in filter.P: + filters.append(and_(self._db_class.ptype == "p", *self.__build_rule_filter(p))) + for g in filter.G: + filters.append(and_(self._db_class.ptype == "g", *self.__build_rule_filter(g))) + + query = query.filter(or_(*filters)) + + for line in query.all(): + persist.load_policy_line(str(line), model) + + def __build_rule_filter(self, filter): + rules = [] + if filter.v0: + rules.append(self._db_class.v0 == filter.v0) + if filter.v1: + rules.append(self._db_class.v1 == filter.v1) + if filter.v2: + rules.append(self._db_class.v2 == filter.v2) + if filter.v3: + rules.append(self._db_class.v3 == filter.v3) + if filter.v4: + rules.append(self._db_class.v4 == filter.v4) + if filter.v5: + rules.append(self._db_class.v5 == filter.v5) + + return rules def load_policy(self, model): """loads all policy rules from the storage.""" @@ -54,7 +107,7 @@ def load_policy(self, model): def _save_policy_line(self, ptype, rule): line = CasbinRule(ptype=ptype) for i, v in enumerate(rule): - setattr(line, 'v{}'.format(i), v) + setattr(line, "v{}".format(i), v) self._session.add(line) def _commit(self): @@ -83,7 +136,7 @@ def remove_policy(self, sec, ptype, rule): query = self._session.query(CasbinRule) query = query.filter(CasbinRule.ptype == ptype) for i, v in enumerate(rule): - query = query.filter(getattr(CasbinRule, 'v{}'.format(i)) == v) + query = query.filter(getattr(CasbinRule, "v{}".format(i)) == v) r = query.delete() self._commit() @@ -100,7 +153,7 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values): if not (1 <= field_index + len(field_values) <= 6): return False for i, v in enumerate(field_values): - query = query.filter(getattr(CasbinRule, 'v{}'.format(field_index + i)) == v) + query = query.filter(getattr(CasbinRule, "v{}".format(field_index + i)) == v) r = query.delete() self._commit() From 78daab0805eec2cfc5196fe22b454184a7edbe88 Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Tue, 22 Sep 2020 13:58:08 -0400 Subject: [PATCH 2/2] fix: cleanup and unit tests --- casbin_sqlalchemy_adapter/adapter.py | 2 +- tests/test_adapter.py | 89 ++++++++++++++++------------ 2 files changed, 51 insertions(+), 40 deletions(-) diff --git a/casbin_sqlalchemy_adapter/adapter.py b/casbin_sqlalchemy_adapter/adapter.py index ab839d5..524fbfc 100644 --- a/casbin_sqlalchemy_adapter/adapter.py +++ b/casbin_sqlalchemy_adapter/adapter.py @@ -50,7 +50,7 @@ def __init__(self, p=None, g=None): class Adapter(persist.Adapter): """the interface for Casbin adapters.""" - def __init__(self, engine, filtered): + def __init__(self, engine, filtered=False): if isinstance(engine, str): self._engine = create_engine(engine) else: diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 5984a7f..93fcd6f 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -7,6 +7,8 @@ import casbin import os +from casbin_sqlalchemy_adapter.adapter import PolicyFilter, Filter + def get_fixture(path): dir_path = os.path.split(os.path.realpath(__file__))[0] + "/" @@ -22,79 +24,78 @@ def get_enforcer(): Base.metadata.create_all(engine) s = session() s.query(CasbinRule).delete() - s.add(CasbinRule(ptype='p', v0='alice', v1='data1', v2='read')) - s.add(CasbinRule(ptype='p', v0='bob', v1='data2', v2='write')) - s.add(CasbinRule(ptype='p', v0='data2_admin', v1='data2', v2='read')) - s.add(CasbinRule(ptype='p', v0='data2_admin', v1='data2', v2='write')) - s.add(CasbinRule(ptype='g', v0='alice', v1='data2_admin')) + s.add(CasbinRule(ptype="p", v0="alice", v1="data1", v2="read")) + s.add(CasbinRule(ptype="p", v0="bob", v1="data2", v2="write")) + s.add(CasbinRule(ptype="p", v0="data2_admin", v1="data2", v2="read")) + s.add(CasbinRule(ptype="p", v0="data2_admin", v1="data2", v2="write")) + s.add(CasbinRule(ptype="g", v0="alice", v1="data2_admin")) s.commit() s.close() - return casbin.Enforcer(get_fixture('rbac_model.conf'), adapter) + return casbin.Enforcer(get_fixture("rbac_model.conf"), adapter) class TestConfig(TestCase): - def test_enforcer_basic(self): e = get_enforcer() - self.assertTrue(e.enforce('alice', 'data1', 'read')) - self.assertFalse(e.enforce('bob', 'data1', 'read')) - self.assertTrue(e.enforce('bob', 'data2', 'write')) - self.assertTrue(e.enforce('alice', 'data2', 'read')) - self.assertTrue(e.enforce('alice', 'data2', 'write')) + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertFalse(e.enforce("bob", "data1", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertTrue(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) def test_add_policy(self): e = get_enforcer() - self.assertFalse(e.enforce('eve', 'data3', 'read')) - res = e.add_permission_for_user('eve', 'data3', 'read') + self.assertFalse(e.enforce("eve", "data3", "read")) + res = e.add_permission_for_user("eve", "data3", "read") self.assertTrue(res) - self.assertTrue(e.enforce('eve', 'data3', 'read')) + self.assertTrue(e.enforce("eve", "data3", "read")) def test_save_policy(self): e = get_enforcer() - self.assertFalse(e.enforce('alice', 'data4', 'read')) + self.assertFalse(e.enforce("alice", "data4", "read")) model = e.get_model() model.clear_policy() - model.add_policy('p', 'p', ['alice', 'data4', 'read']) + model.add_policy("p", "p", ["alice", "data4", "read"]) adapter = e.get_adapter() adapter.save_policy(model) - self.assertTrue(e.enforce('alice', 'data4', 'read')) + self.assertTrue(e.enforce("alice", "data4", "read")) def test_remove_policy(self): e = get_enforcer() - self.assertFalse(e.enforce('alice', 'data5', 'read')) - e.add_permission_for_user('alice', 'data5', 'read') - self.assertTrue(e.enforce('alice', 'data5', 'read')) - e.delete_permission_for_user('alice', 'data5', 'read') - self.assertFalse(e.enforce('alice', 'data5', 'read')) + self.assertFalse(e.enforce("alice", "data5", "read")) + e.add_permission_for_user("alice", "data5", "read") + self.assertTrue(e.enforce("alice", "data5", "read")) + e.delete_permission_for_user("alice", "data5", "read") + self.assertFalse(e.enforce("alice", "data5", "read")) def test_remove_filtered_policy(self): e = get_enforcer() - self.assertTrue(e.enforce('alice', 'data1', 'read')) - e.remove_filtered_policy(1, 'data1') - self.assertFalse(e.enforce('alice', 'data1', 'read')) + self.assertTrue(e.enforce("alice", "data1", "read")) + e.remove_filtered_policy(1, "data1") + self.assertFalse(e.enforce("alice", "data1", "read")) - self.assertTrue(e.enforce('bob', 'data2', 'write')) - self.assertTrue(e.enforce('alice', 'data2', 'read')) - self.assertTrue(e.enforce('alice', 'data2', 'write')) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertTrue(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) - e.remove_filtered_policy(1, 'data2', 'read') + e.remove_filtered_policy(1, "data2", "read") - self.assertTrue(e.enforce('bob', 'data2', 'write')) - self.assertFalse(e.enforce('alice', 'data2', 'read')) - self.assertTrue(e.enforce('alice', 'data2', 'write')) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertFalse(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) - e.remove_filtered_policy(2, 'write') + e.remove_filtered_policy(2, "write") - self.assertFalse(e.enforce('bob', 'data2', 'write')) - self.assertFalse(e.enforce('alice', 'data2', 'write')) + self.assertFalse(e.enforce("bob", "data2", "write")) + self.assertFalse(e.enforce("alice", "data2", "write")) # e.add_permission_for_user('alice', 'data6', 'delete') # e.add_permission_for_user('bob', 'data6', 'delete') @@ -109,11 +110,11 @@ def test_remove_filtered_policy(self): # self.assertFalse(e.enforce('eve', 'data6', 'delete')) def test_str(self): - rule = CasbinRule(ptype='p', v0='alice', v1='data1', v2='read') - self.assertEqual(str(rule), 'p, alice, data1, read') + rule = CasbinRule(ptype="p", v0="alice", v1="data1", v2="read") + self.assertEqual(str(rule), "p, alice, data1, read") def test_repr(self): - rule = CasbinRule(ptype='p', v0='alice', v1='data1', v2='read') + rule = CasbinRule(ptype="p", v0="alice", v1="data1", v2="read") self.assertEqual(repr(rule), '') engine = create_engine("sqlite://") @@ -125,3 +126,13 @@ def test_repr(self): s.commit() self.assertRegex(repr(rule), r'') s.close() + + def test_loads_filtered_policy(self): + enforcer = get_enforcer() + + enforcer.get_model().clear_policy() + + enforcer.load_filtered_policy(PolicyFilter(p=(Filter(v0="alice")), g=(Filter(v0="alice")))) + self.assertFalse(enforcer.enforce("alice", "data1", "read")) + self.assertFalse(enforcer.enforce("bob", "data2", "write")) + self.assertTrue(enforcer.enforce("alice", "data2", "write"))