From 69240a509f314f12b91379527c971a785d559297 Mon Sep 17 00:00:00 2001 From: Hasan Ramezani Date: Thu, 21 Sep 2023 18:00:25 +0200 Subject: [PATCH] Add `BaseModel.model_validate_strings` and `TypeAdapter.validate_strings` --- pydantic/main.py | 25 ++++++++++++++++++++ pydantic/type_adapter.py | 13 +++++++++++ tests/test_dataclasses.py | 13 ++++++++++- tests/test_main.py | 48 ++++++++++++++++++++++++++++++++++++++ tests/test_type_adapter.py | 43 ++++++++++++++++++++++++++++++++++ 5 files changed, 141 insertions(+), 1 deletion(-) diff --git a/pydantic/main.py b/pydantic/main.py index 36cae5cfd01..9e7e2a97d8f 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -530,6 +530,31 @@ def model_validate_json( __tracebackhide__ = True return cls.__pydantic_validator__.validate_json(json_data, strict=strict, context=context) + @classmethod + def model_validate_strings( + cls: type[Model], + obj: Any, + *, + strict: bool | None = None, + context: dict[str, Any] | None = None, + ) -> Model: + """Validate the given object contains string data against the Pydantic model. + + Args: + obj: The object contains string data to validate. + strict: Whether to enforce types strictly. + context: Extra variables to pass to the validator. + + Returns: + The validated Pydantic model. + + Raises: + ValueError: If `json_data` is not a JSON string. + """ + # `__tracebackhide__` tells pytest and some other tools to omit this function from tracebacks + __tracebackhide__ = True + return cls.__pydantic_validator__.validate_strings(obj, strict=strict, context=context) + @classmethod def __get_pydantic_core_schema__( cls, __source: type[BaseModel], __handler: _annotated_handlers.GetCoreSchemaHandler diff --git a/pydantic/type_adapter.py b/pydantic/type_adapter.py index 85ee6d028d0..af6aa04bba1 100644 --- a/pydantic/type_adapter.py +++ b/pydantic/type_adapter.py @@ -221,6 +221,19 @@ def validate_json( """ return self.validator.validate_json(__data, strict=strict, context=context) + def validate_strings(self, __obj: Any, *, strict: bool | None = None, context: dict[str, Any] | None = None) -> T: + """Validate object contains string data against the model. + + Args: + __obj: The object contains string data to validate. + strict: Whether to strictly check types. + context: Additional context to use during validation. + + Returns: + The validated object. + """ + return self.validator.validate_strings(__obj, strict=strict, context=context) + def get_default_value(self, *, strict: bool | None = None, context: dict[str, Any] | None = None) -> Some[T] | None: """Get the default value for the wrapped type. diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index 02a3cb22a6b..b7e9d91e5d7 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -6,7 +6,7 @@ import traceback from collections.abc import Hashable from dataclasses import InitVar -from datetime import datetime +from datetime import date, datetime from pathlib import Path from typing import Any, Callable, ClassVar, Dict, FrozenSet, Generic, List, Optional, Set, TypeVar, Union @@ -2593,3 +2593,14 @@ class Foo: obj = Foo(**{'some-var': 'some_value'}) assert obj.some_var == 'some_value' + + +def test_validate_strings(): + @pydantic.dataclasses.dataclass + class Nested: + d: date + + class Model(BaseModel): + n: Nested + + assert Model.model_validate_strings({'n': {'d': '2017-01-01'}}).n.d == date(2017, 1, 1) diff --git a/tests/test_main.py b/tests/test_main.py index 79ece3171ba..63932ec05a3 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -5,6 +5,7 @@ from collections import defaultdict from copy import deepcopy from dataclasses import dataclass +from datetime import date, datetime from enum import Enum from typing import ( Any, @@ -2543,6 +2544,53 @@ class UnrelatedClass: assert res == ModelFromAttributesFalse(x=1) +@pytest.mark.parametrize( + 'field_type,input_value,expected,raises_match,strict', + [ + (bool, 'true', True, None, False), + (bool, 'true', True, None, True), + (bool, 'false', False, None, False), + (bool, 'e', ValidationError, 'type=bool_parsing', False), + (int, '1', 1, None, False), + (int, '1', 1, None, True), + (int, 'xxx', ValidationError, 'type=int_parsing', True), + (float, '1.1', 1.1, None, False), + (float, '1.10', 1.1, None, False), + (float, '1.1', 1.1, None, True), + (float, '1.10', 1.1, None, True), + (date, '2017-01-01', date(2017, 1, 1), None, False), + (date, '2017-01-01', date(2017, 1, 1), None, True), + (date, '2017-01-01T12:13:14.567', ValidationError, 'type=date_from_datetime_inexact', False), + (date, '2017-01-01T12:13:14.567', ValidationError, 'type=date_parsing', True), + (date, '2017-01-01T00:00:00', date(2017, 1, 1), None, False), + (date, '2017-01-01T00:00:00', ValidationError, 'type=date_parsing', True), + (datetime, '2017-01-01T12:13:14.567', datetime(2017, 1, 1, 12, 13, 14, 567_000), None, False), + (datetime, '2017-01-01T12:13:14.567', datetime(2017, 1, 1, 12, 13, 14, 567_000), None, True), + ], + ids=repr, +) +def test_model_validate_strings(field_type, input_value, expected, raises_match, strict): + class Model(BaseModel): + x: field_type + + if raises_match is not None: + with pytest.raises(expected, match=raises_match): + Model.model_validate_strings({'x': input_value}, strict=strict) + else: + Model.model_validate_strings({'x': input_value}, strict=strict).x == expected + + +@pytest.mark.parametrize('strict', [True, False]) +def test_model_validate_strings_dict(strict): + class Model(BaseModel): + x: Dict[int, date] + + assert Model.model_validate_strings({'x': {'1': '2017-01-01', '2': '2017-01-02'}}, strict=strict).x == { + 1: date(2017, 1, 1), + 2: date(2017, 1, 2), + } + + def test_model_signature_annotated() -> None: class Model(BaseModel): x: Annotated[int, 123] diff --git a/tests/test_type_adapter.py b/tests/test_type_adapter.py index 680f6763239..0667c666f7f 100644 --- a/tests/test_type_adapter.py +++ b/tests/test_type_adapter.py @@ -1,6 +1,7 @@ import json import sys from dataclasses import dataclass +from datetime import date, datetime from typing import Any, Dict, ForwardRef, Generic, List, NamedTuple, Tuple, TypeVar, Union import pytest @@ -266,3 +267,45 @@ class UnrelatedClass: res = ta.validate_python(UnrelatedClass(), from_attributes=True) assert res == ModelFromAttributesFalse(x=1) + + +@pytest.mark.parametrize( + 'field_type,input_value,expected,raises_match,strict', + [ + (bool, 'true', True, None, False), + (bool, 'true', True, None, True), + (bool, 'false', False, None, False), + (bool, 'e', ValidationError, 'type=bool_parsing', False), + (int, '1', 1, None, False), + (int, '1', 1, None, True), + (int, 'xxx', ValidationError, 'type=int_parsing', True), + (float, '1.1', 1.1, None, False), + (float, '1.10', 1.1, None, False), + (float, '1.1', 1.1, None, True), + (float, '1.10', 1.1, None, True), + (date, '2017-01-01', date(2017, 1, 1), None, False), + (date, '2017-01-01', date(2017, 1, 1), None, True), + (date, '2017-01-01T12:13:14.567', ValidationError, 'type=date_from_datetime_inexact', False), + (date, '2017-01-01T12:13:14.567', ValidationError, 'type=date_parsing', True), + (date, '2017-01-01T00:00:00', date(2017, 1, 1), None, False), + (date, '2017-01-01T00:00:00', ValidationError, 'type=date_parsing', True), + (datetime, '2017-01-01T12:13:14.567', datetime(2017, 1, 1, 12, 13, 14, 567_000), None, False), + (datetime, '2017-01-01T12:13:14.567', datetime(2017, 1, 1, 12, 13, 14, 567_000), None, True), + ], + ids=repr, +) +def test_validate_strings(field_type, input_value, expected, raises_match, strict): + if raises_match is not None: + print(TypeAdapter(field_type).core_schema) + with pytest.raises(expected, match=raises_match): + TypeAdapter(field_type).validate_strings(input_value, strict=strict) + else: + TypeAdapter(field_type).validate_strings(input_value, strict=strict) == expected + + +@pytest.mark.parametrize('strict', [True, False]) +def test_validate_strings_dict(strict): + assert TypeAdapter(Dict[int, date]).validate_strings({'1': '2017-01-01', '2': '2017-01-02'}, strict=strict) == { + 1: date(2017, 1, 1), + 2: date(2017, 1, 2), + }