From c6e2f89f36805aef03194c366afb30faa798d188 Mon Sep 17 00:00:00 2001 From: sydney-runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Wed, 13 Sep 2023 22:47:58 -0500 Subject: [PATCH] Fix Generic Dataclass Fields Mutation Bug (when using TypeAdapter) (#7435) Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com> --- pydantic/_internal/_core_utils.py | 13 ++++++++-- pydantic/_internal/_generate_schema.py | 4 +-- tests/test_dataclasses.py | 35 ++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 4 deletions(-) diff --git a/pydantic/_internal/_core_utils.py b/pydantic/_internal/_core_utils.py index 6d9687b427..a012086cc9 100644 --- a/pydantic/_internal/_core_utils.py +++ b/pydantic/_internal/_core_utils.py @@ -1,7 +1,16 @@ from __future__ import annotations from collections import defaultdict -from typing import Any, Callable, Hashable, Iterable, TypeVar, Union, cast +from typing import ( + Any, + Callable, + Hashable, + Iterable, + TypeVar, + Union, + _GenericAlias, # type: ignore + cast, +) from pydantic_core import CoreSchema, core_schema from typing_extensions import TypeAliasType, TypeGuard, get_args @@ -65,7 +74,7 @@ def get_type_ref(type_: type[Any], args_override: tuple[type[Any], ...] | None = when creating generic models without needing to create a concrete class. """ origin = type_ - args = args_override or () + args = get_args(type_) if isinstance(type_, _GenericAlias) else (args_override or ()) generic_metadata = getattr(type_, '__pydantic_generic_metadata__', None) if generic_metadata: origin = generic_metadata['origin'] or origin diff --git a/pydantic/_internal/_generate_schema.py b/pydantic/_internal/_generate_schema.py index c1aa6f4474..299b63de3c 100644 --- a/pydantic/_internal/_generate_schema.py +++ b/pydantic/_internal/_generate_schema.py @@ -9,7 +9,7 @@ import typing import warnings from contextlib import contextmanager -from copy import copy +from copy import copy, deepcopy from enum import Enum from functools import partial from inspect import Parameter, _ParameterKind, signature @@ -1275,7 +1275,7 @@ def _dataclass_schema( from ..dataclasses import is_pydantic_dataclass if is_pydantic_dataclass(dataclass): - fields = dataclass.__pydantic_fields__ + fields = deepcopy(dataclass.__pydantic_fields__) if typevars_map: for field in fields.values(): field.apply_typevars_map(typevars_map, self._types_namespace) diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index d7bb62efe0..e6f32fa0fb 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -2058,6 +2058,41 @@ class GenericDataclass(Generic[T]): assert exc_info.value.errors(include_url=False) == output_value +def test_multiple_parametrized_generic_dataclasses(): + T = TypeVar('T') + + @pydantic.dataclasses.dataclass + class GenericDataclass(Generic[T]): + x: T + + validator1 = pydantic.TypeAdapter(GenericDataclass[int]) + validator2 = pydantic.TypeAdapter(GenericDataclass[str]) + + # verify that generic parameters are showing up in the type ref for generic dataclasses + # this can probably be removed if the schema changes in some way that makes this part of the test fail + assert '[int:' in validator1.core_schema['schema']['schema_ref'] + assert '[str:' in validator2.core_schema['schema']['schema_ref'] + + assert validator1.validate_python({'x': 1}).x == 1 + assert validator2.validate_python({'x': 'hello world'}).x == 'hello world' + + with pytest.raises(ValidationError) as exc_info: + validator2.validate_python({'x': 1}) + assert exc_info.value.errors(include_url=False) == [ + {'input': 1, 'loc': ('x',), 'msg': 'Input should be a valid string', 'type': 'string_type'} + ] + with pytest.raises(ValidationError) as exc_info: + validator1.validate_python({'x': 'hello world'}) + assert exc_info.value.errors(include_url=False) == [ + { + 'input': 'hello world', + 'loc': ('x',), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'type': 'int_parsing', + } + ] + + @pytest.mark.parametrize('dataclass_decorator', **dataclass_decorators(include_identity=True)) def test_pydantic_dataclass_preserves_metadata(dataclass_decorator: Callable[[Any], Any]) -> None: @dataclass_decorator