Skip to content

Commit

Permalink
read from cls.__dict__ so init_subclass works
Browse files Browse the repository at this point in the history
Modified the :class:`.DeclarativeMeta` metaclass to pass ``cls.__dict__``
into the declarative scanning process to look for attributes, rather than
the separate dictionary passed to the type's ``__init__()`` method. This
allows user-defined base classes that add attributes within an
``__init_subclass__()`` to work as expected, as ``__init_subclass__()`` can
only affect the ``cls.__dict__`` itself and not the other dictionary. This
is technically a regression from 1.3 where ``__dict__`` was being used.

Additionally makes the reference between ClassManager and
the declarative configuration object a weak reference, so that it
can be discarded after mappers are set up.

Fixes: #7900
Change-Id: I3c2fd4e227cc1891aa4bb3d7d5b43d5686f9f27c
  • Loading branch information
zzzeek committed Apr 12, 2022
1 parent a45e228 commit 428ea01
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 7 deletions.
14 changes: 14 additions & 0 deletions doc/build/changelog/unreleased_14/7900.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
.. change::
:tags: bug, orm, declarative
:tickets: 7900

Modified the :class:`.DeclarativeMeta` metaclass to pass ``cls.__dict__``
into the declarative scanning process to look for attributes, rather than
the separate dictionary passed to the type's ``__init__()`` method. This
allows user-defined base classes that add attributes within an
``__init_subclass__()`` to work as expected, as ``__init_subclass__()`` can
only affect the ``cls.__dict__`` itself and not the other dictionary. This
is technically a regression from 1.3 where ``__dict__`` was being used.



7 changes: 6 additions & 1 deletion lib/sqlalchemy/orm/decl_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ class DeclarativeMeta(
def __init__(
cls, classname: Any, bases: Any, dict_: Any, **kw: Any
) -> None:
# use cls.__dict__, which can be modified by an
# __init_subclass__() method (#7900)
dict_ = cls.__dict__

# early-consume registry from the initial declarative base,
# assign privately to not conflict with subclass attributes named
# "registry"
Expand Down Expand Up @@ -293,7 +297,8 @@ def __get__(self, instance, owner) -> InstrumentedAttribute[_T]:

# here, we are inside of the declarative scan. use the registry
# that is tracking the values of these attributes.
declarative_scan = manager.declarative_scan
declarative_scan = manager.declarative_scan()
assert declarative_scan is not None
reg = declarative_scan.declared_attr_reg

if self in reg:
Expand Down
16 changes: 12 additions & 4 deletions lib/sqlalchemy/orm/decl_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,13 @@ def _check_declared_props_nocascade(obj, name, cls):


class _MapperConfig:
__slots__ = ("cls", "classname", "properties", "declared_attr_reg")
__slots__ = (
"cls",
"classname",
"properties",
"declared_attr_reg",
"__weakref__",
)

@classmethod
def setup_mapping(cls, registry, cls_, dict_, table, mapper_kw):
Expand Down Expand Up @@ -311,13 +317,15 @@ def __init__(
mapper_kw,
):

# grab class dict before the instrumentation manager has been added.
# reduces cycles
self.clsdict_view = (
util.immutabledict(dict_) if dict_ else util.EMPTY_DICT
)
super(_ClassScanMapperConfig, self).__init__(registry, cls_, mapper_kw)
self.registry = registry
self.persist_selectable = None

self.clsdict_view = (
util.immutabledict(dict_) if dict_ else util.EMPTY_DICT
)
self.collected_attributes = {}
self.collected_annotations: Dict[str, Tuple[Any, bool]] = {}
self.declared_columns = util.OrderedSet()
Expand Down
3 changes: 2 additions & 1 deletion lib/sqlalchemy/orm/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from typing import Set
from typing import TYPE_CHECKING
from typing import TypeVar
import weakref

from . import base
from . import collections
Expand Down Expand Up @@ -167,7 +168,7 @@ def _update_state(
if registry:
registry._add_manager(self)
if declarative_scan:
self.declarative_scan = declarative_scan
self.declarative_scan = weakref.ref(declarative_scan)
if expired_attribute_loader:
self.expired_attribute_loader = expired_attribute_loader

Expand Down
19 changes: 18 additions & 1 deletion test/orm/declarative/test_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@
mapper_registry = None


class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
class DeclarativeTestBase(
testing.AssertsCompiledSQL,
fixtures.TestBase,
testing.AssertsExecutionResults,
):
def setup_test(self):
global Base, mapper_registry

Expand All @@ -58,6 +62,19 @@ def teardown_test(self):


class DeclarativeMixinTest(DeclarativeTestBase):
def test_init_subclass_works(self, registry):
class Base:
def __init_subclass__(cls):
cls.id = Column(Integer, primary_key=True)

Base = registry.generate_base(cls=Base)

class Foo(Base):
__tablename__ = "foo"
name = Column(String)

self.assert_compile(select(Foo), "SELECT foo.name, foo.id FROM foo")

def test_simple_wbase(self):
class MyMixin:

Expand Down

0 comments on commit 428ea01

Please sign in to comment.