Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix broken parametrized bases with GenericModels #5052

Merged
merged 18 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,42 @@ jobs:
with:
name: docs
path: site

test-memray:
name: test memray
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3

- name: set up python
uses: actions/setup-python@v4
with:
python-version: '3.10'

- uses: actions/cache@v3
id: cache
with:
path: ${{ env.pythonLocation }}
key: >
test-memray
${{ runner.os }}
${{ env.pythonLocation }}
${{ hashFiles('setup.py') }}
${{ hashFiles('requirements.txt') }}
${{ hashFiles('tests/requirements-testing.txt') }}

- name: install
run: |
make install-testing
pip install pytest-memray==1.4.0

- name: compile
run: |
make build-trace
python -c "import sys, pydantic; print('compiled:', pydantic.compiled); sys.exit(0 if pydantic.compiled else 1)"

- name: test
run: pytest --ignore=tests/mypy/ --memray

test-linux-compiled:
name: test py${{ matrix.python-version }} on linux compiled
Expand Down
1 change: 1 addition & 0 deletions changes/5052-MarkusSintonen.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix broken parametrized bases handling with `GenericModel`s with complex sets of models.
22 changes: 18 additions & 4 deletions pydantic/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Union,
cast,
)
from weakref import WeakKeyDictionary, WeakValueDictionary

from typing_extensions import Annotated

Expand All @@ -25,23 +26,36 @@
from .main import BaseModel, create_model
from .types import JsonWrapper
from .typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base
from .utils import LimitedDict, all_identical, lenient_issubclass
from .utils import all_identical, lenient_issubclass

if sys.version_info >= (3, 10):
from typing import _UnionGenericAlias

GenericModelT = TypeVar('GenericModelT', bound='GenericModel')
TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type

CacheKey = Tuple[Type[Any], Any, Tuple[Any, ...]]
Parametrization = Mapping[TypeVarType, Type[Any]]

_generic_types_cache: LimitedDict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel]] = LimitedDict()
# weak dictionaries allow the dynamically created parametrized versions of generic models to get collected
# once they are no longer referenced by the caller.
if sys.version_info >= (3, 9): # Typing for weak dictionaries available at 3.9
GenericTypesCache = WeakValueDictionary[CacheKey, Type[BaseModel]]
AssignedParameters = WeakKeyDictionary[Type[BaseModel], Parametrization]
else:
GenericTypesCache = WeakValueDictionary
AssignedParameters = WeakKeyDictionary

# _generic_types_cache is a Mapping from __class_getitem__ arguments to the parametrized version of generic models.
# This ensures that multiple calls of e.g. `A[B]` always return the same class.
_generic_types_cache = GenericTypesCache()
samuelcolvin marked this conversation as resolved.
Show resolved Hide resolved

# _assigned_parameters is a Mapping from parametrized version of generic models to assigned types of parametrizations
# as captured during construction of the class (not instances).
# E.g., for generic model `Model[A, B]`, when parametrized model `Model[int, str]` is created,
# the entry `Model[int, str]: {A: int, B: str}` will be stored in `_assigned_parameters`.
# (This information is only otherwise available after creation from the class name string).
_assigned_parameters: LimitedDict[Type[Any], Parametrization] = LimitedDict()
_assigned_parameters = AssignedParameters()


class GenericModel(BaseModel):
Expand All @@ -67,7 +81,7 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T

"""

def _cache_key(_params: Any) -> Tuple[Type[GenericModelT], Any, Tuple[Any, ...]]:
def _cache_key(_params: Any) -> CacheKey:
args = get_args(_params)
# python returns a list for Callables, which is not hashable
if len(args) == 2 and isinstance(args[0], list):
Expand Down
38 changes: 0 additions & 38 deletions pydantic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
Iterator,
List,
Mapping,
MutableMapping,
NoReturn,
Optional,
Set,
Expand Down Expand Up @@ -80,7 +79,6 @@
'get_unique_discriminator_alias',
'get_discriminator_alias_and_values',
'DUNDER_ATTRIBUTES',
'LimitedDict',
)

ROOT_KEY = '__root__'
Expand Down Expand Up @@ -803,39 +801,3 @@ def _get_union_alias_and_all_values(
# unzip: [('alias_a',('v1', 'v2')), ('alias_b', ('v3',))] => [('alias_a', 'alias_b'), (('v1', 'v2'), ('v3',))]
all_aliases, all_values = zip(*zipped_aliases_values)
return get_unique_discriminator_alias(all_aliases, discriminator_key), all_values


KT = TypeVar('KT')
VT = TypeVar('VT')
if TYPE_CHECKING:
    # Type checkers see a generic `MutableMapping`; the runtime class below only subclasses `dict`,
    # because inheriting from both `MutableMapping` and `dict` breaks cython.
    class LimitedDict(dict, MutableMapping[KT, VT]):  # type: ignore[type-arg]
        def __init__(self, size_limit: int = 1000):
            ...

else:

    class LimitedDict(dict):
        """
        A dict for caching whose length is capped, so memory usage cannot grow without bound.

        Entries are evicted from the front of the dict; as dicts keep insertion order, the oldest
        entries go first, which makes this a FIFO cache.

        `MutableMapping` is deliberately not a base class here, as inheriting from it breaks cython.
        """

        def __init__(self, size_limit: int = 1000):
            super().__init__()
            self.size_limit = size_limit

        def __setitem__(self, __key: Any, __value: Any) -> None:
            super().__setitem__(__key, __value)
            overflow = len(self) - self.size_limit
            if overflow <= 0:
                return
            # evict the overflow plus a further 10% of the limit, so that a full cache
            # does not have to evict again on every following insert
            stale_keys = list(self)[: overflow + self.size_limit // 10]
            for stale_key in stale_keys:
                del self[stale_key]

        def __class_getitem__(cls, *args: Any) -> Any:
            # subscripting the class is accepted and yields None - to avoid errors with 3.7
            return None
8 changes: 4 additions & 4 deletions tests/requirements-testing.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ hypothesis==6.54.4
# pin importlib-metadata as upper versions need typing-extensions to work if on Python < 3.8
importlib-metadata==3.1.0;python_version<"3.8"
mypy==0.971
pytest==7.1.2
pytest-cov==3.0.0
pytest-mock==3.8.2
pytest-sugar==0.9.5
pytest==7.2.1
pytest-cov==4.0.0
pytest-mock==3.10.0
pytest-sugar==0.9.6
# pin typing-extensions to minimum requirement - see #4885
typing-extensions==4.2.0
2 changes: 1 addition & 1 deletion tests/test_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -1894,7 +1894,7 @@ def get_double_a(self) -> float:
model = Model(a=10.2)
assert model.a == 10.2
assert model.b == 10
return model.get_double_a() == 20.2
assert model.get_double_a() == 20.2


def test_resolve_annotations_module_missing(tmp_path):
Expand Down
133 changes: 126 additions & 7 deletions tests/test_generics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import gc
import itertools
import json
import sys
from enum import Enum
Expand All @@ -6,12 +8,14 @@
Callable,
ClassVar,
Dict,
FrozenSet,
Generic,
Iterable,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Expand All @@ -21,8 +25,19 @@
import pytest
from typing_extensions import Annotated, Literal

from pydantic import BaseModel, Field, Json, ValidationError, root_validator, validator
from pydantic.generics import GenericModel, _generic_types_cache, iter_contained_typevars, replace_types
from pydantic import BaseModel, Field, Json, ValidationError, create_model, root_validator, validator
from pydantic.generics import (
GenericModel,
_assigned_parameters,
_generic_types_cache,
iter_contained_typevars,
replace_types,
)


@pytest.fixture(autouse=True)
def clean_cache():
    """Autouse fixture that forces a garbage collection pass as part of test setup."""
    gc.collect()  # cleans up _generic_types_cache for checking item counts in the cache


def test_generic_name():
Expand Down Expand Up @@ -229,10 +244,13 @@ def test_cover_cache():
class Model(GenericModel, Generic[T]):
x: T

Model[int] # adds both with-tuple and without-tuple version to cache
models = [] # keep references to models to get cache size

models.append(Model[int]) # adds both with-tuple and without-tuple version to cache
assert len(_generic_types_cache) == cache_size + 2
Model[int] # uses the cache
models.append(Model[int]) # uses the cache
assert len(_generic_types_cache) == cache_size + 2
del models


def test_cache_keys_are_hashable():
Expand All @@ -246,19 +264,66 @@ class MyGenericModel(GenericModel, Generic[T]):
# Callable's first params get converted to a list, which is not hashable.
# Make sure we can handle that special case
Simple = MyGenericModel[Callable[[int], str]]
models = [] # keep references to models to get cache size
models.append(Simple)
assert len(_generic_types_cache) == cache_size + 2
# Nested Callables
MyGenericModel[Callable[[C], Iterable[str]]]
models.append(MyGenericModel[Callable[[C], Iterable[str]]])
assert len(_generic_types_cache) == cache_size + 4
MyGenericModel[Callable[[Simple], Iterable[int]]]
models.append(MyGenericModel[Callable[[Simple], Iterable[int]]])
assert len(_generic_types_cache) == cache_size + 6
MyGenericModel[Callable[[MyGenericModel[C]], Iterable[int]]]
models.append(MyGenericModel[Callable[[MyGenericModel[C]], Iterable[int]]])
assert len(_generic_types_cache) == cache_size + 10

class Model(BaseModel):
x: MyGenericModel[Callable[[C], Iterable[str]]] = Field(...)

models.append(Model)
assert len(_generic_types_cache) == cache_size + 10
del models


def test_caches_get_cleaned_up():
    """Check that `_generic_types_cache` and `_assigned_parameters` shrink back once a parametrized model is dropped."""
    # baseline sizes, so the assertions below are relative to whatever the caches already hold
    types_cache_size = len(_generic_types_cache)
    params_cache_size = len(_assigned_parameters)
    T = TypeVar('T')

    class MyGenericModel(GenericModel, Generic[T]):
        x: T

    Model = MyGenericModel[int]
    # while `Model` is referenced: two more entries in the types cache, one more in the assigned parameters
    assert len(_generic_types_cache) == types_cache_size + 2
    assert len(_assigned_parameters) == params_cache_size + 1
    # drop the only local reference and force a collection; both caches must return to their baseline
    del Model
    gc.collect()
    assert len(_generic_types_cache) == types_cache_size
    assert len(_assigned_parameters) == params_cache_size


def test_generics_work_with_many_parametrized_base_models():
    """
    Parametrize a generic model that has a partially parametrized base with many distinct models.

    The assertions expect five `_generic_types_cache` entries and three `_assigned_parameters` entries
    per created model, plus one extra entry in each.
    """
    initial_cache_len = len(_generic_types_cache)
    initial_params_len = len(_assigned_parameters)
    model_count = 1000
    T = TypeVar('T')
    C = TypeVar('C')

    class A(GenericModel, Generic[T, C]):
        x: T
        y: C

    class B(A[int, C], GenericModel, Generic[C]):
        pass

    models = [create_model(f'M{i}') for i in range(model_count)]
    # keep a reference to every parametrized class so the entry counts can be checked
    generics = [B[model] for model in models]

    assert len(_generic_types_cache) == initial_cache_len + model_count * 5 + 1
    assert len(_assigned_parameters) == initial_params_len + model_count * 3 + 1
    del models, generics


def test_generic_config():
Expand Down Expand Up @@ -1379,3 +1444,57 @@ class Payload(BaseModel):
'properties': {'payload_field': {'title': 'Payload Field', 'type': 'string'}},
'required': ['payload_field'],
}


def memray_limit_memory(limit):
    """
    Build the pytest mark to decorate a memory-limited test with.

    Returns `pytest.mark.limit_memory(limit)` when `--memray` is among the command line arguments,
    otherwise a skip mark with the reason 'memray not enabled'.
    """
    memray_enabled = '--memray' in sys.argv
    if not memray_enabled:
        return pytest.mark.skip(reason='memray not enabled')
    return pytest.mark.limit_memory(limit)


@memray_limit_memory('100 MB')
def test_generics_memory_use():
    """Create a large number of parametrized generic models under the memory limit set by the decorator.

    See:
    - https://github.com/pydantic/pydantic/issues/3829
    - https://github.com/pydantic/pydantic/pull/4083
    - https://github.com/pydantic/pydantic/pull/5052
    """

    T = TypeVar('T')
    U = TypeVar('U')
    V = TypeVar('V')

    class MyModel(GenericModel, Generic[T, U, V]):
        message: Json[T]
        field: Dict[U, V]

    class Outer(GenericModel, Generic[T]):
        inner: T

    types = [
        int,
        str,
        float,
        bool,
        bytes,
    ]

    containers = [
        List,
        Tuple,
        Set,
        FrozenSet,
    ]

    # the plain types plus every container parametrized with every plain type
    # (named `all_types` rather than `all`, so the builtin `all` is not shadowed)
    all_types = [*types, *[container[tp] for container in containers for tp in types]]

    # every ordered triple of those types
    total = list(itertools.product(all_types, all_types, all_types))

    for t1, t2, t3 in total:

        # each iteration parametrizes `MyModel` and `Outer` anew and rebinds `Foo` and `_`
        class Foo(MyModel[t1, t2, t3]):
            pass

        class _(Outer[Foo]):
            pass
Loading