diff --git a/casbin_sqlalchemy_adapter/adapter.py b/casbin_sqlalchemy_adapter/adapter.py index 718fca1..6dc3200 100644 --- a/casbin_sqlalchemy_adapter/adapter.py +++ b/casbin_sqlalchemy_adapter/adapter.py @@ -246,3 +246,45 @@ def update_policies( """ for i in range(len(old_rules)): self.update_policy(sec, ptype, old_rules[i], new_rules[i]) + + def update_filtered_policies( + self, sec, ptype, new_rules: [[str]], field_index, *field_values + ) -> [[str]]: + """update_filtered_policies updates all the policies on the basis of the filter.""" + + filter = Filter() + filter.ptype = ptype + + # Creating Filter from the field_index & field_values provided + for i in range(len(field_values)): + if field_index <= i and i < field_index + len(field_values): + setattr(filter, f"v{i}", field_values[i - field_index]) + else: + break + + self._update_filtered_policies(new_rules, filter) + + def _update_filtered_policies(self, new_rules, filter) -> [[str]]: + """_update_filtered_policies updates all the policies on the basis of the filter.""" + + with self._session_scope() as session: + + # Load old policies + + query = session.query(self._db_class).filter( + self._db_class.ptype == filter.ptype + ) + filtered_query = self.filter_query(query, filter) + old_rules = filtered_query.all() + + # Delete old policies + + self.remove_policies("p", filter.ptype, old_rules) + + # Insert new policies + + self.add_policies("p", filter.ptype, new_rules) + + # return deleted rules + + return old_rules diff --git a/tests/test_adapter.py b/tests/test_adapter.py index f3819f8..726932e 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -364,3 +364,23 @@ def test_update_policies(self): self.assertFalse(e.enforce("data2_admin", "data2", "write")) self.assertTrue(e.enforce("data2_admin", "data_test", "write")) + + def test_update_filtered_policies(self): + e = get_enforcer() + + e.update_filtered_policies( + [ + ["data2_admin", "data3", "read"], + ["data2_admin", "data3", "write"], + ], + 0, + "data2_admin", + ) + self.assertTrue(e.enforce("data2_admin", "data3", "write")) + self.assertTrue(e.enforce("data2_admin", "data3", "read")) + + e.update_filtered_policies([["alice", "data1", "write"]], 0, "alice") + self.assertTrue(e.enforce("alice", "data1", "write")) + + e.update_filtered_policies([["bob", "data2", "read"]], 0, "bob") + self.assertTrue(e.enforce("bob", "data2", "read"))