diff --git a/casbin_sqlalchemy_adapter/adapter.py b/casbin_sqlalchemy_adapter/adapter.py index ab58a79..e40135d 100644 --- a/casbin_sqlalchemy_adapter/adapter.py +++ b/casbin_sqlalchemy_adapter/adapter.py @@ -43,7 +43,7 @@ class Filter: v5 = [] -class Adapter(persist.Adapter): +class Adapter(persist.Adapter, persist.adapters.UpdateAdapter): """the interface for Casbin adapters.""" def __init__(self, engine, db_class=None, filtered=False): @@ -152,7 +152,7 @@ def remove_policy(self, sec, ptype, rule): return True if r > 0 else False def remove_policies(self, sec, ptype, rules): - """removes a policy rules from the storage.""" + """remove policy rules from the storage.""" if not rules: return with self._session_scope() as session: @@ -168,15 +168,65 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values): This is part of the Auto-Save feature. """ with self._session_scope() as session: - query = session.query(self._db_class) - query = query.filter(self._db_class.ptype == ptype) + query = (session + .query(self._db_class) + .filter(self._db_class.ptype == ptype)) + if not (0 <= field_index <= 5): return False if not (1 <= field_index + len(field_values) <= 6): return False for i, v in enumerate(field_values): if v != '': - query = query.filter(getattr(self._db_class, "v{}".format(field_index + i)) == v) + v_value = getattr(self._db_class, "v{}".format(field_index + i)) + query = query.filter(v_value == v) r = query.delete() return True if r > 0 else False + + def update_policy(self, sec: str, ptype: str, old_rule: [str], new_rule: [str]) -> None: + """ + Update the old_rule with the new_rule in the database (storage). + + :param sec: section type + :param ptype: policy type + :param old_rule: the old rule that needs to be modified + :param new_rule: the new rule to replace the old rule + + :return: None + """ + + with self._session_scope() as session: + query = (session + .query(self._db_class) + .filter(self._db_class.ptype == ptype)) + + # locate the old rule + for index, value in enumerate(old_rule): + v_value = getattr(self._db_class, "v{}".format(index)) + query = query.filter(v_value == value) + + # need the length of the longest_rule to perform overwrite + longest_rule = old_rule if len(old_rule) > len(new_rule) else new_rule + old_rule_line = query.one() + + # overwrite the old rule with the new rule + for index in range(len(longest_rule)): + if index < len(new_rule): + exec(f"old_rule_line.v{index} = new_rule[{index}]") + else: + exec(f"old_rule_line.v{index} = None") + + def update_policies(self, sec: str, ptype: str, old_rules: [[str], ], new_rules: [[str], ]) -> None: + """ + Update the old_rules with the new_rules in the database (storage). + + :param sec: section type + :param ptype: policy type + :param old_rules: the old rules that need to be modified + :param new_rules: the new rules to replace the old rules + + :return: None + """ + for i in range(len(old_rules)): + self.update_policy(sec, ptype, old_rules[i], new_rules[i]) diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 6dfc2aa..e46cfe6 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -279,3 +279,62 @@ def test_filtered_policy(self): self.assertTrue(e.enforce('bob', 'data2', 'write')) self.assertFalse(e.enforce('data2_admin', 'data2', 'read')) self.assertTrue(e.enforce('data2_admin', 'data2', 'write')) + + def test_update_policy(self): + e = get_enforcer() + example_p = ['mike', 'cookie', 'eat'] + + self.assertTrue(e.enforce('alice', 'data1', 'read')) + e.update_policy(['alice', 'data1', 'read'], ['alice', 'data1', 'no_read']) + self.assertFalse(e.enforce('alice', 'data1', 'read')) + + self.assertFalse(e.enforce('bob', 'data1', 'read')) + e.add_policy(example_p) + e.update_policy(example_p, ['bob', 'data1', 'read']) + self.assertTrue(e.enforce('bob', 'data1', 'read')) + + self.assertFalse(e.enforce('bob', 'data1', 'write')) + e.update_policy(['bob', 'data1', 'read'], ['bob', 'data1', 'write']) + self.assertTrue(e.enforce('bob', 'data1', 'write')) + + self.assertTrue(e.enforce('bob', 'data2', 'write')) + e.update_policy(['bob', 'data2', 'write'], ['bob', 'data2', 'read']) + self.assertFalse(e.enforce('bob', 'data2', 'write')) + + self.assertTrue(e.enforce('bob', 'data2', 'read')) + e.update_policy(['bob', 'data2', 'read'], ['carl', 'data2', 'write']) + self.assertFalse(e.enforce('bob', 'data2', 'write')) + + self.assertTrue(e.enforce('carl', 'data2', 'write')) + e.update_policy(['carl', 'data2', 'write'], ['carl', 'data2', 'no_write']) + self.assertFalse(e.enforce('bob', 'data2', 'write')) + + def test_update_policies(self): + e = get_enforcer() + + old_rule_0 = ['alice', 'data1', 'read'] + old_rule_1 = ['bob', 'data2', 'write'] + old_rule_2 = ['data2_admin', 'data2', 'read'] + old_rule_3 = ['data2_admin', 'data2', 'write'] + + new_rule_0 = ['alice', 'data_test', 'read'] + new_rule_1 = ['bob', 'data_test', 'write'] + new_rule_2 = ['data2_admin', 'data_test', 'read'] + new_rule_3 = ['data2_admin', 'data_test', 'write'] + + old_rules = [old_rule_0, old_rule_1, old_rule_2, old_rule_3] + new_rules = [new_rule_0, new_rule_1, new_rule_2, new_rule_3] + + e.update_policies(old_rules, new_rules) + + self.assertFalse(e.enforce('alice', 'data1', 'read')) + self.assertTrue(e.enforce('alice', 'data_test', 'read')) + + self.assertFalse(e.enforce('bob', 'data2', 'write')) + self.assertTrue(e.enforce('bob', 'data_test', 'write')) + + self.assertFalse(e.enforce('data2_admin', 'data2', 'read')) + self.assertTrue(e.enforce('data2_admin', 'data_test', 'read')) + + self.assertFalse(e.enforce('data2_admin', 'data2', 'write')) + self.assertTrue(e.enforce('data2_admin', 'data_test', 'write'))