Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions casbin_sqlalchemy_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ def __init__(self, engine, db_class=None, filtered=False):

if db_class is None:
db_class = CasbinRule
else:
for attr in ("ptype", "v0", "v1", "v2", "v3", "v4", "v5"):
if not hasattr(db_class, attr):
raise Exception(f"{attr} not found in custom DatabaseClass.")
Base.metadata = db_class.metadata

self._db_class = db_class
self.session_local = sessionmaker(bind=self._engine)

Expand Down
36 changes: 31 additions & 5 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os
from unittest import TestCase

import casbin
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.orm import sessionmaker

from casbin_sqlalchemy_adapter import Adapter
from casbin_sqlalchemy_adapter import Base
from casbin_sqlalchemy_adapter import CasbinRule
from casbin_sqlalchemy_adapter.adapter import Filter
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from unittest import TestCase
import casbin
import os


def get_fixture(path):
Expand Down Expand Up @@ -35,6 +37,30 @@ def get_enforcer():


class TestConfig(TestCase):
def test_custom_db_class(self):
class CustomRule(Base):
__tablename__ = "casbin_rule2"

id = Column(Integer, primary_key=True)
ptype = Column(String(255))
v0 = Column(String(255))
v1 = Column(String(255))
v2 = Column(String(255))
v3 = Column(String(255))
v4 = Column(String(255))
v5 = Column(String(255))
not_exist = Column(String(255))

engine = create_engine("sqlite://")
adapter = Adapter(engine, CustomRule)

session = sessionmaker(bind=engine)
Base.metadata.create_all(engine)
s = session()
s.add(CustomRule(not_exist="NotNone"))
s.commit()
self.assertEqual(s.query(CustomRule).all()[0].not_exist, "NotNone")

def test_enforcer_basic(self):
e = get_enforcer()

Expand Down