diff --git a/casbin_sqlalchemy_adapter/adapter.py b/casbin_sqlalchemy_adapter/adapter.py index c9551d1..62df9a5 100644 --- a/casbin_sqlalchemy_adapter/adapter.py +++ b/casbin_sqlalchemy_adapter/adapter.py @@ -127,12 +127,15 @@ def filter_query(self, querydb, filter): ) return querydb.order_by(self._db_class.id) - def _save_policy_line(self, ptype, rule): - with self._session_scope() as session: - line = self._db_class(ptype=ptype) - for i, v in enumerate(rule): - setattr(line, "v{}".format(i), v) + def _save_policy_line(self, ptype, rule, session=None): + line = self._db_class(ptype=ptype) + for i, v in enumerate(rule): + setattr(line, "v{}".format(i), v) + if session: session.add(line) + else: + with self._session_scope() as session: + session.add(line) def save_policy(self, model): """saves all policy rules to the storage.""" @@ -144,7 +147,7 @@ def save_policy(self, model): continue for ptype, ast in model.model[sec].items(): for rule in ast.policy: - self._save_policy_line(ptype, rule) + self._save_policy_line(ptype, rule, session=session) return True def add_policy(self, sec, ptype, rule):