Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 60 additions & 7 deletions casbin_sqlalchemy_adapter/adapter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from casbin import persist
from sqlalchemy import Column, Integer, String
from sqlalchemy import create_engine
from sqlalchemy import create_engine, and_, or_
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

Base = declarative_base()


class CasbinRule(Base):
__tablename__ = 'casbin_rule'
__tablename__ = "casbin_rule"

id = Column(Integer, primary_key=True)
ptype = Column(String(255))
Expand All @@ -25,16 +25,32 @@ def __str__(self):
if v is None:
break
arr.append(v)
return ', '.join(arr)
return ", ".join(arr)

def __repr__(self):
return '<CasbinRule {}: "{}">'.format(self.id, str(self))


class Filter:
def __init__(self, v0=None, v1=None, v2=None, v3=None, v4=None, v5=None):
self.v0 = v0
self.v1 = v1
self.v2 = v2
self.v3 = v3
self.v4 = v4
self.v5 = v5


class PolicyFilter:
def __init__(self, p=None, g=None):
self.P = p or ()
self.G = g or ()


class Adapter(persist.Adapter):
"""the interface for Casbin adapters."""

def __init__(self, engine):
def __init__(self, engine, filtered=False):
if isinstance(engine, str):
self._engine = create_engine(engine)
else:
Expand All @@ -44,6 +60,43 @@ def __init__(self, engine):
self._session = session()

Base.metadata.create_all(self._engine)
self._filtered = filtered

def is_filtered(self):
return self._filtered

def load_filtered_policy(self, model, filter) -> None:
"""loads all policy rules from the storage."""

self._filtered = True
query = self._session.query(self._db_class)
filters = []
for p in filter.P:
filters.append(and_(self._db_class.ptype == "p", *self.__build_rule_filter(p)))
for g in filter.G:
filters.append(and_(self._db_class.ptype == "g", *self.__build_rule_filter(g)))

query = query.filter(or_(*filters))

for line in query.all():
persist.load_policy_line(str(line), model)

def __build_rule_filter(self, filter):
rules = []
if filter.v0:
rules.append(self._db_class.v0 == filter.v0)
if filter.v1:
rules.append(self._db_class.v1 == filter.v1)
if filter.v2:
rules.append(self._db_class.v2 == filter.v2)
if filter.v3:
rules.append(self._db_class.v3 == filter.v3)
if filter.v4:
rules.append(self._db_class.v4 == filter.v4)
if filter.v5:
rules.append(self._db_class.v5 == filter.v5)

return rules

def load_policy(self, model):
"""loads all policy rules from the storage."""
Expand All @@ -54,7 +107,7 @@ def load_policy(self, model):
def _save_policy_line(self, ptype, rule):
line = CasbinRule(ptype=ptype)
for i, v in enumerate(rule):
setattr(line, 'v{}'.format(i), v)
setattr(line, "v{}".format(i), v)
self._session.add(line)

def _commit(self):
Expand Down Expand Up @@ -83,7 +136,7 @@ def remove_policy(self, sec, ptype, rule):
query = self._session.query(CasbinRule)
query = query.filter(CasbinRule.ptype == ptype)
for i, v in enumerate(rule):
query = query.filter(getattr(CasbinRule, 'v{}'.format(i)) == v)
query = query.filter(getattr(CasbinRule, "v{}".format(i)) == v)
r = query.delete()
self._commit()

Expand All @@ -100,7 +153,7 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
if not (1 <= field_index + len(field_values) <= 6):
return False
for i, v in enumerate(field_values):
query = query.filter(getattr(CasbinRule, 'v{}'.format(field_index + i)) == v)
query = query.filter(getattr(CasbinRule, "v{}".format(field_index + i)) == v)
r = query.delete()
self._commit()

Expand Down
89 changes: 50 additions & 39 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import casbin
import os

from casbin_sqlalchemy_adapter.adapter import PolicyFilter, Filter


def get_fixture(path):
dir_path = os.path.split(os.path.realpath(__file__))[0] + "/"
Expand All @@ -22,79 +24,78 @@ def get_enforcer():
Base.metadata.create_all(engine)
s = session()
s.query(CasbinRule).delete()
s.add(CasbinRule(ptype='p', v0='alice', v1='data1', v2='read'))
s.add(CasbinRule(ptype='p', v0='bob', v1='data2', v2='write'))
s.add(CasbinRule(ptype='p', v0='data2_admin', v1='data2', v2='read'))
s.add(CasbinRule(ptype='p', v0='data2_admin', v1='data2', v2='write'))
s.add(CasbinRule(ptype='g', v0='alice', v1='data2_admin'))
s.add(CasbinRule(ptype="p", v0="alice", v1="data1", v2="read"))
s.add(CasbinRule(ptype="p", v0="bob", v1="data2", v2="write"))
s.add(CasbinRule(ptype="p", v0="data2_admin", v1="data2", v2="read"))
s.add(CasbinRule(ptype="p", v0="data2_admin", v1="data2", v2="write"))
s.add(CasbinRule(ptype="g", v0="alice", v1="data2_admin"))
s.commit()
s.close()

return casbin.Enforcer(get_fixture('rbac_model.conf'), adapter)
return casbin.Enforcer(get_fixture("rbac_model.conf"), adapter)


class TestConfig(TestCase):

def test_enforcer_basic(self):
e = get_enforcer()

self.assertTrue(e.enforce('alice', 'data1', 'read'))
self.assertFalse(e.enforce('bob', 'data1', 'read'))
self.assertTrue(e.enforce('bob', 'data2', 'write'))
self.assertTrue(e.enforce('alice', 'data2', 'read'))
self.assertTrue(e.enforce('alice', 'data2', 'write'))
self.assertTrue(e.enforce("alice", "data1", "read"))
self.assertFalse(e.enforce("bob", "data1", "read"))
self.assertTrue(e.enforce("bob", "data2", "write"))
self.assertTrue(e.enforce("alice", "data2", "read"))
self.assertTrue(e.enforce("alice", "data2", "write"))

def test_add_policy(self):
e = get_enforcer()

self.assertFalse(e.enforce('eve', 'data3', 'read'))
res = e.add_permission_for_user('eve', 'data3', 'read')
self.assertFalse(e.enforce("eve", "data3", "read"))
res = e.add_permission_for_user("eve", "data3", "read")
self.assertTrue(res)
self.assertTrue(e.enforce('eve', 'data3', 'read'))
self.assertTrue(e.enforce("eve", "data3", "read"))

def test_save_policy(self):
e = get_enforcer()
self.assertFalse(e.enforce('alice', 'data4', 'read'))
self.assertFalse(e.enforce("alice", "data4", "read"))

model = e.get_model()
model.clear_policy()

model.add_policy('p', 'p', ['alice', 'data4', 'read'])
model.add_policy("p", "p", ["alice", "data4", "read"])

adapter = e.get_adapter()
adapter.save_policy(model)
self.assertTrue(e.enforce('alice', 'data4', 'read'))
self.assertTrue(e.enforce("alice", "data4", "read"))

def test_remove_policy(self):
e = get_enforcer()

self.assertFalse(e.enforce('alice', 'data5', 'read'))
e.add_permission_for_user('alice', 'data5', 'read')
self.assertTrue(e.enforce('alice', 'data5', 'read'))
e.delete_permission_for_user('alice', 'data5', 'read')
self.assertFalse(e.enforce('alice', 'data5', 'read'))
self.assertFalse(e.enforce("alice", "data5", "read"))
e.add_permission_for_user("alice", "data5", "read")
self.assertTrue(e.enforce("alice", "data5", "read"))
e.delete_permission_for_user("alice", "data5", "read")
self.assertFalse(e.enforce("alice", "data5", "read"))

def test_remove_filtered_policy(self):
e = get_enforcer()

self.assertTrue(e.enforce('alice', 'data1', 'read'))
e.remove_filtered_policy(1, 'data1')
self.assertFalse(e.enforce('alice', 'data1', 'read'))
self.assertTrue(e.enforce("alice", "data1", "read"))
e.remove_filtered_policy(1, "data1")
self.assertFalse(e.enforce("alice", "data1", "read"))

self.assertTrue(e.enforce('bob', 'data2', 'write'))
self.assertTrue(e.enforce('alice', 'data2', 'read'))
self.assertTrue(e.enforce('alice', 'data2', 'write'))
self.assertTrue(e.enforce("bob", "data2", "write"))
self.assertTrue(e.enforce("alice", "data2", "read"))
self.assertTrue(e.enforce("alice", "data2", "write"))

e.remove_filtered_policy(1, 'data2', 'read')
e.remove_filtered_policy(1, "data2", "read")

self.assertTrue(e.enforce('bob', 'data2', 'write'))
self.assertFalse(e.enforce('alice', 'data2', 'read'))
self.assertTrue(e.enforce('alice', 'data2', 'write'))
self.assertTrue(e.enforce("bob", "data2", "write"))
self.assertFalse(e.enforce("alice", "data2", "read"))
self.assertTrue(e.enforce("alice", "data2", "write"))

e.remove_filtered_policy(2, 'write')
e.remove_filtered_policy(2, "write")

self.assertFalse(e.enforce('bob', 'data2', 'write'))
self.assertFalse(e.enforce('alice', 'data2', 'write'))
self.assertFalse(e.enforce("bob", "data2", "write"))
self.assertFalse(e.enforce("alice", "data2", "write"))

# e.add_permission_for_user('alice', 'data6', 'delete')
# e.add_permission_for_user('bob', 'data6', 'delete')
Expand All @@ -109,11 +110,11 @@ def test_remove_filtered_policy(self):
# self.assertFalse(e.enforce('eve', 'data6', 'delete'))

def test_str(self):
rule = CasbinRule(ptype='p', v0='alice', v1='data1', v2='read')
self.assertEqual(str(rule), 'p, alice, data1, read')
rule = CasbinRule(ptype="p", v0="alice", v1="data1", v2="read")
self.assertEqual(str(rule), "p, alice, data1, read")

def test_repr(self):
rule = CasbinRule(ptype='p', v0='alice', v1='data1', v2='read')
rule = CasbinRule(ptype="p", v0="alice", v1="data1", v2="read")
self.assertEqual(repr(rule), '<CasbinRule None: "p, alice, data1, read">')
engine = create_engine("sqlite://")

Expand All @@ -125,3 +126,13 @@ def test_repr(self):
s.commit()
self.assertRegex(repr(rule), r'<CasbinRule \d+: "p, alice, data1, read">')
s.close()

def test_loads_filtered_policy(self):
enforcer = get_enforcer()

enforcer.get_model().clear_policy()

enforcer.load_filtered_policy(PolicyFilter(p=(Filter(v0="alice")), g=(Filter(v0="alice"))))
self.assertFalse(enforcer.enforce("alice", "data1", "read"))
self.assertFalse(enforcer.enforce("bob", "data2", "write"))
self.assertTrue(enforcer.enforce("alice", "data2", "write"))