Skip to content

Commit

Permalink
Disallow config specification when initializing a TypeAdapter whe…
Browse files Browse the repository at this point in the history
…n the annotated type has config already (#8365)
  • Loading branch information
sydney-runkle committed Dec 14, 2023
1 parent 2155e4d commit 17faa91
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 9 deletions.
23 changes: 15 additions & 8 deletions pydantic/type_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Set, TypeVar, Union, cast, overload

from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator, Some
from typing_extensions import Literal, is_typeddict
from typing_extensions import Literal, get_args, is_typeddict

from pydantic.errors import PydanticUserError
from pydantic.main import BaseModel
Expand Down Expand Up @@ -98,6 +98,15 @@ def _getattr_no_parents(obj: Any, attribute: str) -> Any:
raise AttributeError(attribute)


def _type_has_config(type_: Any) -> bool:
"""Returns whether the type has config."""
try:
return issubclass(type_, BaseModel) or is_dataclass(type_) or is_typeddict(type_)
except TypeError:
# type is not a class
return False


class TypeAdapter(Generic[T]):
"""Type adapters provide a flexible way to perform validation and serialization based on a Python type.
Expand Down Expand Up @@ -168,13 +177,9 @@ def __init__(
Returns:
A type adapter configured for the specified `type`.
"""
config_wrapper = _config.ConfigWrapper(config)

try:
type_has_config = issubclass(type, BaseModel) or is_dataclass(type) or is_typeddict(type)
except TypeError:
# type is not a class
type_has_config = False
type_is_annotated: bool = _typing_extra.is_annotated(type)
annotated_type: Any = get_args(type)[0] if type_is_annotated else None
type_has_config: bool = _type_has_config(annotated_type if type_is_annotated else type)

if type_has_config and config is not None:
raise PydanticUserError(
Expand All @@ -185,6 +190,8 @@ def __init__(
code='type-adapter-config-unused',
)

config_wrapper = _config.ConfigWrapper(config)

core_schema: CoreSchema
try:
core_schema = _getattr_no_parents(type, '__pydantic_core_schema__')
Expand Down
37 changes: 36 additions & 1 deletion tests/test_type_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

import pytest
from pydantic_core import ValidationError
from typing_extensions import TypeAlias, TypedDict
from typing_extensions import Annotated, TypeAlias, TypedDict

from pydantic import BaseModel, TypeAdapter, ValidationInfo, field_validator
from pydantic.config import ConfigDict
from pydantic.errors import PydanticUserError

ItemType = TypeVar('ItemType')

Expand Down Expand Up @@ -309,3 +310,37 @@ def test_validate_strings_dict(strict):
1: date(2017, 1, 1),
2: date(2017, 1, 2),
}


def test_annotated_type_disallows_config() -> None:
class Model(BaseModel):
x: int

with pytest.raises(PydanticUserError, match='Cannot use `config`'):
TypeAdapter(Annotated[Model, ...], config=ConfigDict(strict=False))


@pytest.mark.xfail(reason='waiting for fix in core for ser_json_bytes application')
def test_ta_config_with_annotated_type() -> None:
class TestValidator(BaseModel):
x: str

model_config = ConfigDict(str_to_lower=True)

assert TestValidator(x='ABC').x == 'abc'
assert TypeAdapter(TestValidator).validate_python({'x': 'ABC'}).x == 'abc'
assert TypeAdapter(Annotated[TestValidator, ...]).validate_python({'x': 'ABC'}).x == 'abc'

class TestSerializer(BaseModel):
some_bytes: bytes
model_config = ConfigDict(ser_json_bytes='base64')

result = TestSerializer(some_bytes=b'\xaa')
assert result.model_dump(mode='json') == {'some_bytes': 'qg=='}
assert TypeAdapter(TestSerializer).dump_python(result, mode='json') == {'some_bytes': 'qg=='}

# cases where SchemaSerializer is constructed within TypeAdapter's __init__
assert TypeAdapter(Annotated[TestSerializer, ...]).dump_python(result, mode='json') == {'some_bytes': 'qg=='}
assert TypeAdapter(Annotated[list[TestSerializer], ...]).dump_python([result], mode='json') == [
{'some_bytes': 'qg=='}
]

0 comments on commit 17faa91

Please sign in to comment.