diff --git a/gino/declarative.py b/gino/declarative.py index ecc9bc4c..080d1361 100644 --- a/gino/declarative.py +++ b/gino/declarative.py @@ -152,11 +152,12 @@ def _init_table(cls, sub_cls): for each_cls in sub_cls.__mro__[::-1]: for k, v in getattr(each_cls, '__namespace__', each_cls.__dict__).items(): - if callable(v) and getattr(v, '__declared_attr__', False): - if k == '__tablename__': - table_name = v(sub_cls) - continue + declared_callable_attr = callable(v) and \ + getattr(v, '__declared_attr__', False) + if k != '__tablename__' and declared_callable_attr: v = updates[k] = v(sub_cls) + elif k == '__tablename__': + table_name = v(sub_cls) if declared_callable_attr else v if isinstance(v, sa.Column): v = v.copy() if not v.name: @@ -166,8 +167,6 @@ def _init_table(cls, sub_cls): updates[k] = sub_cls.__attr_factory__(k, v) elif isinstance(v, (sa.Index, sa.Constraint)): inspected_args.append(v) - if table_name is None: - table_name = getattr(sub_cls, '__tablename__', None) if table_name is None: return sub_cls._column_name_map = column_name_map diff --git a/tests/test_declarative.py b/tests/test_declarative.py index 7403bb2a..eb70a9ec 100644 --- a/tests/test_declarative.py +++ b/tests/test_declarative.py @@ -260,3 +260,39 @@ class Model(db.Model): select_col = db.Column(name=db.quoted_name('select', False)) assert select_col.name == 'select' assert not select_col.name.quote + + +async def test_overwrite_declared_table_name(): + class MyTableNameMixin: + @db.declared_attr + def __tablename__(cls): + return cls.__name__.lower() + + class MyTableWithoutName(MyTableNameMixin, db.Model): + id = db.Column(db.Integer, primary_key=True) + + class MyTableWithName(MyTableNameMixin, db.Model): + __tablename__ = 'manually_overwritten_name' + id = db.Column(db.Integer, primary_key=True) + + assert MyTableWithoutName.__table__.name == 'mytablewithoutname' + assert MyTableWithName.__table__.name == 'manually_overwritten_name' + + +async def test_multiple_inheritance_overwrite_declared_table_name(): + class MyTableNameMixin: + @db.declared_attr + def __tablename__(cls): + return cls.__name__.lower() + + class AnotherTableNameMixin: + __tablename__ = "static_table_name" + + class MyTableWithoutName(AnotherTableNameMixin, MyTableNameMixin, db.Model): + id = db.Column(db.Integer, primary_key=True) + + class MyOtherTableWithoutName(MyTableNameMixin, AnotherTableNameMixin, db.Model): + id = db.Column(db.Integer, primary_key=True) + + assert MyTableWithoutName.__table__.name == 'static_table_name' + assert MyOtherTableWithoutName.__table__.name == 'myothertablewithoutname'