Skip to content

Commit

Permalink
Fix hash function generation for frozen models with unusual MRO (#7274)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Aug 29, 2023
1 parent dad6665 commit 2b8bea3
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 22 deletions.
30 changes: 19 additions & 11 deletions pydantic/_internal/_decorators.py
Expand Up @@ -309,6 +309,11 @@ def mro(tp: type[Any]) -> tuple[type[Any], ...]:
# GenericAlias and some other cases
pass

bases = get_bases(tp)
return (tp,) + mro_for_bases(bases)


def mro_for_bases(bases: tuple[type[Any], ...]) -> tuple[type[Any], ...]:
def merge_seqs(seqs: list[deque[type[Any]]]) -> Iterable[type[Any]]:
while True:
non_empty = [seq for seq in seqs if seq]
Expand All @@ -332,14 +337,11 @@ def merge_seqs(seqs: list[deque[type[Any]]]) -> Iterable[type[Any]]:
if seq[0] == candidate:
seq.popleft()

bases = get_bases(tp)
seqs = [deque(mro(base)) for base in bases] + [deque(bases)]
res = tuple(merge_seqs(seqs))

return (tp,) + res
return tuple(merge_seqs(seqs))


def get_attribute_from_bases(tp: type[Any], name: str) -> Any:
def get_attribute_from_bases(tp: type[Any] | tuple[type[Any], ...], name: str) -> Any:
"""Get the attribute from the next class in the MRO that has it,
aiming to simulate calling the method on the actual class.
Expand All @@ -349,7 +351,7 @@ def get_attribute_from_bases(tp: type[Any], name: str) -> Any:
from its bases (as done here).
Args:
tp: The type or class to search for the attribute.
tp: The type or class to search for the attribute. If a tuple, this is treated as a set of base classes.
name: The name of the attribute to retrieve.
Returns:
Expand All @@ -358,13 +360,19 @@ def get_attribute_from_bases(tp: type[Any], name: str) -> Any:
Raises:
AttributeError: If the attribute is not found in any class in the MRO.
"""
try:
return getattr(tp, name)
except AttributeError as e:
for base in reversed(mro(tp)):
if isinstance(tp, tuple):
for base in mro_for_bases(tp):
if hasattr(base, name):
return getattr(base, name)
raise e
raise AttributeError(f'{name} not found in {tp}')
else:
try:
return getattr(tp, name)
except AttributeError as e:
for base in mro(tp):
if hasattr(base, name):
return getattr(base, name)
raise e


def get_attribute_from_base_dicts(tp: type[Any], name: str) -> Any:
Expand Down
22 changes: 11 additions & 11 deletions pydantic/_internal/_model_construction.py
Expand Up @@ -17,7 +17,12 @@
from ..warnings import PydanticDeprecatedSince20
from ._config import ConfigWrapper
from ._core_utils import collect_invalid_schemas, flatten_schema_defs, inline_schema_defs
from ._decorators import ComputedFieldInfo, DecoratorInfos, PydanticDescriptorProxy
from ._decorators import (
ComputedFieldInfo,
DecoratorInfos,
PydanticDescriptorProxy,
get_attribute_from_bases,
)
from ._discriminated_union import apply_discriminators
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
from ._generate_schema import GenerateSchema
Expand Down Expand Up @@ -374,16 +379,11 @@ def set_default_hash_func(namespace: dict[str, Any], bases: tuple[type[Any], ...
if '__hash__' in namespace:
return

base_hash_func = None
for base in bases:
base_hash_func = getattr(base, '__hash__', PydanticUndefined)
if base_hash_func is not PydanticUndefined:
break

if base_hash_func is None:
# This will be the case for `BaseModel` since it defines `__eq__` but not `__hash__`.
# In this case, we generate a standard hash function, generally for use with frozen models.

base_hash_func = get_attribute_from_bases(bases, '__hash__')
if base_hash_func in {None, object.__hash__}:
# If `__hash__` is None _or_ `object.__hash__`, we generate a hash function.
# It will be `None` if not overridden from BaseModel, but may be `object.__hash__` if there is another
# parent class earlier in the bases which doesn't override `__hash__` (e.g. `typing.Generic`).
def hash_func(self: Any) -> int:
return hash(self.__class__) + hash(tuple(self.__dict__.values()))

Expand Down
12 changes: 12 additions & 0 deletions tests/test_generics.py
Expand Up @@ -2613,3 +2613,15 @@ class C(B[int]):
'type': 'int_parsing',
},
]


def test_reverse_order_generic_hashability():
T = TypeVar('T')

class Model(Generic[T], BaseModel):
x: T
model_config = dict(frozen=True)

m1 = Model[int](x=1)
m2 = Model[int](x=1)
assert len({m1, m2}) == 1
14 changes: 14 additions & 0 deletions tests/test_types_typeddict.py
Expand Up @@ -21,6 +21,7 @@
PydanticUserError,
ValidationError,
)
from pydantic._internal._decorators import get_attribute_from_bases
from pydantic.functional_serializers import field_serializer, model_serializer
from pydantic.functional_validators import field_validator, model_validator
from pydantic.type_adapter import TypeAdapter
Expand Down Expand Up @@ -905,3 +906,16 @@ class MySubTypedDict(MyMiddleTypedDict):

validated_data = TypeAdapter(MySubTypedDict).validate_python({'x': 'ABC', 'y': 'DEF', 'z': 'GHI'})
assert validated_data == {'x': 'abc', 'y': 'def', 'z': 'ghi'}


def test_typeddict_mro():
class A(TypedDict):
x = 1

class B(A):
x = 2

class C(B):
pass

assert get_attribute_from_bases(C, 'x') == 2

0 comments on commit 2b8bea3

Please sign in to comment.