Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 55 additions & 5 deletions casbin_sqlalchemy_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Filter:
v5 = []


class Adapter(persist.Adapter):
class Adapter(persist.Adapter, persist.adapters.UpdateAdapter):
"""the interface for Casbin adapters."""

def __init__(self, engine, db_class=None, filtered=False):
Expand Down Expand Up @@ -152,7 +152,7 @@ def remove_policy(self, sec, ptype, rule):
return True if r > 0 else False

def remove_policies(self, sec, ptype, rules):
"""removes a policy rules from the storage."""
"""remove policy rules from the storage."""
if not rules:
return
with self._session_scope() as session:
Expand All @@ -168,15 +168,65 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
This is part of the Auto-Save feature.
"""
with self._session_scope() as session:
query = session.query(self._db_class)
query = query.filter(self._db_class.ptype == ptype)
query = (session
.query(self._db_class)
.filter(self._db_class.ptype == ptype))

if not (0 <= field_index <= 5):
return False
if not (1 <= field_index + len(field_values) <= 6):
return False
for i, v in enumerate(field_values):
if v != '':
query = query.filter(getattr(self._db_class, "v{}".format(field_index + i)) == v)
v_value = getattr(self._db_class, "v{}".format(field_index + i))
query = query.filter(v_value == v)
r = query.delete()

return True if r > 0 else False

def update_policy(self, sec: str, ptype: str, old_rule: [str], new_rule: [str]) -> None:
"""
Update the old_rule with the new_rule in the database (storage).

:param sec: section type
:param ptype: policy type
:param old_rule: the old rule that needs to be modified
:param new_rule: the new rule to replace the old rule

:return: None
"""

with self._session_scope() as session:
query = (session
.query(self._db_class)
.filter(self._db_class.ptype == ptype))

# locate the old rule
for index, value in enumerate(old_rule):
v_value = getattr(self._db_class, "v{}".format(index))
query = query.filter(v_value == value)

# need the length of the longest_rule to perform overwrite
longest_rule = old_rule if len(old_rule) > len(new_rule) else new_rule
old_rule_line = query.one()

# overwrite the old rule with the new rule
for index in range(len(longest_rule)):
if index < len(new_rule):
exec(f"old_rule_line.v{index} = new_rule[{index}]")
else:
exec(f"old_rule_line.v{index} = None")

def update_policies(self, sec: str, ptype: str, old_rules: [[str], ], new_rules: [[str], ]) -> None:
"""
Update the old_rules with the new_rules in the database (storage).

:param sec: section type
:param ptype: policy type
:param old_rules: the old rules that need to be modified
:param new_rules: the new rules to replace the old rules

:return: None
"""
for i in range(len(old_rules)):
self.update_policy(sec, ptype, old_rules[i], new_rules[i])
59 changes: 59 additions & 0 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,62 @@ def test_filtered_policy(self):
self.assertTrue(e.enforce('bob', 'data2', 'write'))
self.assertFalse(e.enforce('data2_admin', 'data2', 'read'))
self.assertTrue(e.enforce('data2_admin', 'data2', 'write'))

def test_update_policy(self):
e = get_enforcer()
example_p = ['mike', 'cookie', 'eat']

self.assertTrue(e.enforce('alice', 'data1', 'read'))
e.update_policy(['alice', 'data1', 'read'], ['alice', 'data1', 'no_read'])
self.assertFalse(e.enforce('alice', 'data1', 'read'))

self.assertFalse(e.enforce('bob', 'data1', 'read'))
e.add_policy(example_p)
e.update_policy(example_p, ['bob', 'data1', 'read'])
self.assertTrue(e.enforce('bob', 'data1', 'read'))

self.assertFalse(e.enforce('bob', 'data1', 'write'))
e.update_policy(['bob', 'data1', 'read'], ['bob', 'data1', 'write'])
self.assertTrue(e.enforce('bob', 'data1', 'write'))

self.assertTrue(e.enforce('bob', 'data2', 'write'))
e.update_policy(['bob', 'data2', 'write'], ['bob', 'data2', 'read'])
self.assertFalse(e.enforce('bob', 'data2', 'write'))

self.assertTrue(e.enforce('bob', 'data2', 'read'))
e.update_policy(['bob', 'data2', 'read'], ['carl', 'data2', 'write'])
self.assertFalse(e.enforce('bob', 'data2', 'write'))

self.assertTrue(e.enforce('carl', 'data2', 'write'))
e.update_policy(['carl', 'data2', 'write'], ['carl', 'data2', 'no_write'])
self.assertFalse(e.enforce('bob', 'data2', 'write'))

def test_update_policies(self):
e = get_enforcer()

old_rule_0 = ['alice', 'data1', 'read']
old_rule_1 = ['bob', 'data2', 'write']
old_rule_2 = ['data2_admin', 'data2', 'read']
old_rule_3 = ['data2_admin', 'data2', 'write']

new_rule_0 = ['alice', 'data_test', 'read']
new_rule_1 = ['bob', 'data_test', 'write']
new_rule_2 = ['data2_admin', 'data_test', 'read']
new_rule_3 = ['data2_admin', 'data_test', 'write']

old_rules = [old_rule_0, old_rule_1, old_rule_2, old_rule_3]
new_rules = [new_rule_0, new_rule_1, new_rule_2, new_rule_3]

e.update_policies(old_rules, new_rules)

self.assertFalse(e.enforce('alice', 'data1', 'read'))
self.assertTrue(e.enforce('alice', 'data_test', 'read'))

self.assertFalse(e.enforce('bob', 'data2', 'write'))
self.assertTrue(e.enforce('bob', 'data_test', 'write'))

self.assertFalse(e.enforce('data2_admin', 'data2', 'read'))
self.assertTrue(e.enforce('data2_admin', 'data_test', 'read'))

self.assertFalse(e.enforce('data2_admin', 'data2', 'write'))
self.assertTrue(e.enforce('data2_admin', 'data_test', 'write'))