Skip to content
87 changes: 73 additions & 14 deletions casbin_sqlalchemy_adapter/adapter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from casbin import persist
from sqlalchemy import Column, Integer, String
from sqlalchemy import create_engine
from sqlalchemy import create_engine, and_, or_
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

Base = declarative_base()


class CasbinRule(Base):
__tablename__ = 'casbin_rule'
__tablename__ = "casbin_rule"

id = Column(Integer, primary_key=True)
ptype = Column(String(255))
Expand All @@ -25,44 +25,103 @@ def __str__(self):
if v is None:
break
arr.append(v)
return ', '.join(arr)
return ", ".join(arr)

def __repr__(self):
return '<CasbinRule {}: "{}">'.format(self.id, str(self))


class Filter:
def __init__(self, v0=None, v1=None, v2=None, v3=None, v4=None, v5=None):
self.v0 = v0
self.v1 = v1
self.v2 = v2
self.v3 = v3
self.v4 = v4
self.v5 = v5


class PolicyFilter:
def __init__(self, p=None, g=None):
self.P = p or ()
self.G = g or ()


class Adapter(persist.Adapter):
"""the interface for Casbin adapters."""

def __init__(self, engine):
def __init__(self, engine, db_class=None, filtered=False):
if isinstance(engine, str):
self._engine = create_engine(engine)
else:
self._engine = engine

if db_class is None:
db_class = CasbinRule
self._db_class = db_class
session = sessionmaker(bind=self._engine)
self._session = session()

Base.metadata.create_all(self._engine)
self._filtered = filtered

def is_filtered(self):
return self._filtered

def load_filtered_policy(self, model, filter) -> None:
"""loads all policy rules from the storage."""
self._commit() # Commit transaction, so you can see the insert/update/delete from other transaction when use multi processes(eg. Nginx reverse proxy)
self._filtered = True
query = self._session.query(self._db_class)
filters = []
for p in filter.P:
filters.append(and_(self._db_class.ptype == "p", *self.__build_rule_filter(p)))
for g in filter.G:
filters.append(and_(self._db_class.ptype == "g", *self.__build_rule_filter(g)))

query = query.filter(or_(*filters))

for line in query.all():
persist.load_policy_line(str(line), model)
self._commit()

def __build_rule_filter(self, filter):
rules = []
if filter.v0:
rules.append(self._db_class.v0 == filter.v0)
if filter.v1:
rules.append(self._db_class.v1 == filter.v1)
if filter.v2:
rules.append(self._db_class.v2 == filter.v2)
if filter.v3:
rules.append(self._db_class.v3 == filter.v3)
if filter.v4:
rules.append(self._db_class.v4 == filter.v4)
if filter.v5:
rules.append(self._db_class.v5 == filter.v5)

return rules

def load_policy(self, model):
"""loads all policy rules from the storage."""
lines = self._session.query(CasbinRule).all()
self._commit() # Commit transaction, so you can see the insert/update/delete from other transaction when use multi processes(eg. Nginx reverse proxy)
lines = self._session.query(self._db_class).all()
for line in lines:
persist.load_policy_line(str(line), model)
self._commit()

def _save_policy_line(self, ptype, rule):
line = CasbinRule(ptype=ptype)
line = self._db_class(ptype=ptype)
for i, v in enumerate(rule):
setattr(line, 'v{}'.format(i), v)
setattr(line, "v{}".format(i), v)
self._session.add(line)

def _commit(self):
self._session.commit()

def save_policy(self, model):
"""saves all policy rules to the storage."""
query = self._session.query(CasbinRule)
query = self._session.query(self._db_class)
query.delete()
for sec in ["p", "g"]:
if sec not in model.model.keys():
Expand All @@ -80,10 +139,10 @@ def add_policy(self, sec, ptype, rule):

def remove_policy(self, sec, ptype, rule):
"""removes a policy rule from the storage."""
query = self._session.query(CasbinRule)
query = query.filter(CasbinRule.ptype == ptype)
query = self._session.query(self._db_class)
query = query.filter(self._db_class.ptype == ptype)
for i, v in enumerate(rule):
query = query.filter(getattr(CasbinRule, 'v{}'.format(i)) == v)
query = query.filter(getattr(self._db_class, "v{}".format(i)) == v)
r = query.delete()
self._commit()

Expand All @@ -93,14 +152,14 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
"""removes policy rules that match the filter from the storage.
This is part of the Auto-Save feature.
"""
query = self._session.query(CasbinRule)
query = query.filter(CasbinRule.ptype == ptype)
query = self._session.query(self._db_class)
query = query.filter(self._db_class.ptype == ptype)
if not (0 <= field_index <= 5):
return False
if not (1 <= field_index + len(field_values) <= 6):
return False
for i, v in enumerate(field_values):
query = query.filter(getattr(CasbinRule, 'v{}'.format(field_index + i)) == v)
query = query.filter(getattr(self._db_class, "v{}".format(field_index + i)) == v)
r = query.delete()
self._commit()

Expand Down
93 changes: 54 additions & 39 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import casbin
import os

from casbin_sqlalchemy_adapter.adapter import PolicyFilter, Filter


def get_fixture(path):
dir_path = os.path.split(os.path.realpath(__file__))[0] + "/"
Expand All @@ -22,79 +24,78 @@ def get_enforcer():
Base.metadata.create_all(engine)
s = session()
s.query(CasbinRule).delete()
s.add(CasbinRule(ptype='p', v0='alice', v1='data1', v2='read'))
s.add(CasbinRule(ptype='p', v0='bob', v1='data2', v2='write'))
s.add(CasbinRule(ptype='p', v0='data2_admin', v1='data2', v2='read'))
s.add(CasbinRule(ptype='p', v0='data2_admin', v1='data2', v2='write'))
s.add(CasbinRule(ptype='g', v0='alice', v1='data2_admin'))
s.add(CasbinRule(ptype="p", v0="alice", v1="data1", v2="read"))
s.add(CasbinRule(ptype="p", v0="bob", v1="data2", v2="write"))
s.add(CasbinRule(ptype="p", v0="data2_admin", v1="data2", v2="read"))
s.add(CasbinRule(ptype="p", v0="data2_admin", v1="data2", v2="write"))
s.add(CasbinRule(ptype="g", v0="alice", v1="data2_admin"))
s.commit()
s.close()

return casbin.Enforcer(get_fixture('rbac_model.conf'), adapter)
return casbin.Enforcer(get_fixture("rbac_model.conf"), adapter)


class TestConfig(TestCase):

def test_enforcer_basic(self):
e = get_enforcer()

self.assertTrue(e.enforce('alice', 'data1', 'read'))
self.assertFalse(e.enforce('bob', 'data1', 'read'))
self.assertTrue(e.enforce('bob', 'data2', 'write'))
self.assertTrue(e.enforce('alice', 'data2', 'read'))
self.assertTrue(e.enforce('alice', 'data2', 'write'))
self.assertTrue(e.enforce("alice", "data1", "read"))
self.assertFalse(e.enforce("bob", "data1", "read"))
self.assertTrue(e.enforce("bob", "data2", "write"))
self.assertTrue(e.enforce("alice", "data2", "read"))
self.assertTrue(e.enforce("alice", "data2", "write"))

def test_add_policy(self):
e = get_enforcer()

self.assertFalse(e.enforce('eve', 'data3', 'read'))
res = e.add_permission_for_user('eve', 'data3', 'read')
self.assertFalse(e.enforce("eve", "data3", "read"))
res = e.add_permission_for_user("eve", "data3", "read")
self.assertTrue(res)
self.assertTrue(e.enforce('eve', 'data3', 'read'))
self.assertTrue(e.enforce("eve", "data3", "read"))

def test_save_policy(self):
e = get_enforcer()
self.assertFalse(e.enforce('alice', 'data4', 'read'))
self.assertFalse(e.enforce("alice", "data4", "read"))

model = e.get_model()
model.clear_policy()

model.add_policy('p', 'p', ['alice', 'data4', 'read'])
model.add_policy("p", "p", ["alice", "data4", "read"])

adapter = e.get_adapter()
adapter.save_policy(model)
self.assertTrue(e.enforce('alice', 'data4', 'read'))
self.assertTrue(e.enforce("alice", "data4", "read"))

def test_remove_policy(self):
e = get_enforcer()

self.assertFalse(e.enforce('alice', 'data5', 'read'))
e.add_permission_for_user('alice', 'data5', 'read')
self.assertTrue(e.enforce('alice', 'data5', 'read'))
e.delete_permission_for_user('alice', 'data5', 'read')
self.assertFalse(e.enforce('alice', 'data5', 'read'))
self.assertFalse(e.enforce("alice", "data5", "read"))
e.add_permission_for_user("alice", "data5", "read")
self.assertTrue(e.enforce("alice", "data5", "read"))
e.delete_permission_for_user("alice", "data5", "read")
self.assertFalse(e.enforce("alice", "data5", "read"))

def test_remove_filtered_policy(self):
e = get_enforcer()

self.assertTrue(e.enforce('alice', 'data1', 'read'))
e.remove_filtered_policy(1, 'data1')
self.assertFalse(e.enforce('alice', 'data1', 'read'))
self.assertTrue(e.enforce("alice", "data1", "read"))
e.remove_filtered_policy(1, "data1")
self.assertFalse(e.enforce("alice", "data1", "read"))

self.assertTrue(e.enforce('bob', 'data2', 'write'))
self.assertTrue(e.enforce('alice', 'data2', 'read'))
self.assertTrue(e.enforce('alice', 'data2', 'write'))
self.assertTrue(e.enforce("bob", "data2", "write"))
self.assertTrue(e.enforce("alice", "data2", "read"))
self.assertTrue(e.enforce("alice", "data2", "write"))

e.remove_filtered_policy(1, 'data2', 'read')
e.remove_filtered_policy(1, "data2", "read")

self.assertTrue(e.enforce('bob', 'data2', 'write'))
self.assertFalse(e.enforce('alice', 'data2', 'read'))
self.assertTrue(e.enforce('alice', 'data2', 'write'))
self.assertTrue(e.enforce("bob", "data2", "write"))
self.assertFalse(e.enforce("alice", "data2", "read"))
self.assertTrue(e.enforce("alice", "data2", "write"))

e.remove_filtered_policy(2, 'write')
e.remove_filtered_policy(2, "write")

self.assertFalse(e.enforce('bob', 'data2', 'write'))
self.assertFalse(e.enforce('alice', 'data2', 'write'))
self.assertFalse(e.enforce("bob", "data2", "write"))
self.assertFalse(e.enforce("alice", "data2", "write"))

# e.add_permission_for_user('alice', 'data6', 'delete')
# e.add_permission_for_user('bob', 'data6', 'delete')
Expand All @@ -109,11 +110,11 @@ def test_remove_filtered_policy(self):
# self.assertFalse(e.enforce('eve', 'data6', 'delete'))

def test_str(self):
rule = CasbinRule(ptype='p', v0='alice', v1='data1', v2='read')
self.assertEqual(str(rule), 'p, alice, data1, read')
rule = CasbinRule(ptype="p", v0="alice", v1="data1", v2="read")
self.assertEqual(str(rule), "p, alice, data1, read")

def test_repr(self):
rule = CasbinRule(ptype='p', v0='alice', v1='data1', v2='read')
rule = CasbinRule(ptype="p", v0="alice", v1="data1", v2="read")
self.assertEqual(repr(rule), '<CasbinRule None: "p, alice, data1, read">')
engine = create_engine("sqlite://")

Expand All @@ -125,3 +126,17 @@ def test_repr(self):
s.commit()
self.assertRegex(repr(rule), r'<CasbinRule \d+: "p, alice, data1, read">')
s.close()

def test_loads_filtered_policy(self):
enforcer = get_enforcer()

enforcer.get_model().clear_policy()

enforcer.load_filtered_policy(PolicyFilter(p=(Filter(v0="alice"),), g=(Filter(v0="alice"),))) # PolicyFilter.p and PolicyFilter.g need to be iterable(eg. set(), Notice the comma at the end)
"""
p, alice, data1, read
g, alice, data2_admin
"""
self.assertTrue(enforcer.enforce("alice", "data1", "read"))
self.assertFalse(enforcer.enforce("bob", "data2", "write"))
self.assertFalse(enforcer.enforce("alice", "data2", "write"))