From 76ecbf563b575d3c2901660b9f0b6df8a3dcb933 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Tue, 18 Feb 2025 11:06:20 -0500 Subject: [PATCH 1/3] fix strict behavior for unions --- src/validators/union.rs | 22 ++-------------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/src/validators/union.rs b/src/validators/union.rs index bfe744212..dccc09ce7 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -8,7 +8,7 @@ use pyo3::{intern, PyTraverseError, PyVisit}; use smallvec::SmallVec; use crate::build_tools::py_schema_err; -use crate::build_tools::{is_strict, schema_or_config}; +use crate::build_tools::schema_or_config; use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD}; use crate::errors::{ErrorType, ToErrorValue, ValError, ValLineError, ValResult}; use crate::input::{BorrowInput, Input, ValidatedDict}; @@ -43,7 +43,6 @@ pub struct UnionValidator { mode: UnionMode, choices: Vec<(CombinedValidator, Option)>, custom_error: Option, - strict: bool, name: String, } @@ -91,7 +90,6 @@ impl BuildValidator for UnionValidator { mode, choices, custom_error: CustomError::build(schema, config, definitions)?, - strict: is_strict(schema, config)?, name: format!("{}[{descr}]", Self::EXPECTED_TYPE), } .into()) @@ -110,17 +108,11 @@ impl UnionValidator { let old_exactness = state.exactness; let old_fields_set_count = state.fields_set_count; - let strict = state.strict_or(self.strict); let mut errors = MaybeErrors::new(self.custom_error.as_ref()); let mut best_match: Option<(Py, Exactness, Option)> = None; for (choice, label) in &self.choices { - let state = &mut state.rebind_extra(|extra| { - if strict { - extra.strict = Some(strict); - } - }); state.exactness = Some(Exactness::Exact); state.fields_set_count = None; let result = choice.validate(py, input, state); @@ -197,14 +189,6 @@ impl UnionValidator { ) -> ValResult { let mut errors = MaybeErrors::new(self.custom_error.as_ref()); - let mut rebound_state; - let state = if state.strict_or(self.strict) { - rebound_state = state.rebind_extra(|extra| extra.strict = Some(true)); - &mut rebound_state - } else { - state - }; - for (validator, label) in &self.choices { match validator.validate(py, input, state) { Err(ValError::LineErrors(lines)) => errors.push(validator, label.as_deref(), lines), @@ -300,7 +284,6 @@ pub struct TaggedUnionValidator { discriminator: Discriminator, lookup: LiteralLookup, from_attributes: bool, - strict: bool, custom_error: Option, tags_repr: String, discriminator_repr: String, @@ -349,7 +332,6 @@ impl BuildValidator for TaggedUnionValidator { discriminator, lookup, from_attributes, - strict: is_strict(schema, config)?, custom_error: CustomError::build(schema, config, definitions)?, tags_repr, discriminator_repr, @@ -371,7 +353,7 @@ impl Validator for TaggedUnionValidator { match &self.discriminator { Discriminator::LookupKey(lookup_key) => { let from_attributes = state.extra().from_attributes.unwrap_or(self.from_attributes); - let dict = input.validate_model_fields(self.strict, from_attributes)?; + let dict = input.validate_model_fields(state.strict_or(false), from_attributes)?; // note this methods returns PyResult>, the outer Err is just for // errors when getting attributes which should be "raised" let tag = match dict.get_item(lookup_key)? { From 044ff1c97c24e266061ed547b5391d464068af17 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 24 Feb 2025 08:50:07 -0500 Subject: [PATCH 2/3] fixing and adding tests --- python/pydantic_core/core_schema.py | 3 -- tests/benchmarks/test_micro_benchmarks.py | 8 ++-- tests/validators/test_bytes.py | 2 +- .../validators/test_definitions_recursive.py | 8 ++-- tests/validators/test_union.py | 48 ++++++++++++++++--- 5 files changed, 52 insertions(+), 17 deletions(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index e999bdcfc..9bf0d2349 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -2525,7 +2525,6 @@ def union_schema( custom_error_message: str | None = None, custom_error_context: dict[str, str | int] | None = None, mode: Literal['smart', 'left_to_right'] | None = None, - strict: bool | None = None, ref: str | None = None, metadata: dict[str, Any] | None = None, serialization: SerSchema | None = None, @@ -2551,7 +2550,6 @@ def union_schema( mode: How to select which choice to return * `smart` (default) will try to return the choice which is the closest match to the input value * `left_to_right` will return the first choice in `choices` which succeeds validation - strict: Whether the underlying schemas should be validated with strict mode ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema @@ -2564,7 +2562,6 @@ def union_schema( custom_error_message=custom_error_message, custom_error_context=custom_error_context, mode=mode, - strict=strict, ref=ref, metadata=metadata, serialization=serialization, diff --git a/tests/benchmarks/test_micro_benchmarks.py b/tests/benchmarks/test_micro_benchmarks.py index 3435174a5..5640431ae 100644 --- a/tests/benchmarks/test_micro_benchmarks.py +++ b/tests/benchmarks/test_micro_benchmarks.py @@ -686,8 +686,9 @@ def test_smart_union_coerce_core(self, benchmark): def test_strict_union_core(self, benchmark): v = SchemaValidator( schema=core_schema.union_schema( - strict=True, choices=[core_schema.bool_schema(), core_schema.int_schema(), core_schema.str_schema()] - ) + choices=[core_schema.bool_schema(), core_schema.int_schema(), core_schema.str_schema()] + ), + config=CoreConfig(strict=True), ) benchmark(v.validate_python, 1) @@ -695,7 +696,8 @@ def test_strict_union_core(self, benchmark): @pytest.mark.benchmark(group='strict-union-error') def test_strict_union_error_core(self, benchmark): v = SchemaValidator( - schema=core_schema.union_schema(strict=True, choices=[core_schema.bool_schema(), core_schema.str_schema()]) + schema=core_schema.union_schema(choices=[core_schema.bool_schema(), core_schema.str_schema()]), + config=CoreConfig(strict=True), ) def validate_with_expected_error(): diff --git a/tests/validators/test_bytes.py b/tests/validators/test_bytes.py index 1110a34d3..0506957ec 100644 --- a/tests/validators/test_bytes.py +++ b/tests/validators/test_bytes.py @@ -91,7 +91,7 @@ def test_constrained_bytes(py_and_json: PyAndJson, opts: dict[str, Any], input, def test_union(): - v = SchemaValidator(cs.union_schema(choices=[cs.str_schema(), cs.bytes_schema()], strict=True)) + v = SchemaValidator(cs.union_schema(choices=[cs.str_schema(strict=True), cs.bytes_schema(strict=True)])) assert v.validate_python('oh, a string') == 'oh, a string' assert v.validate_python(b'oh, bytes') == b'oh, bytes' diff --git a/tests/validators/test_definitions_recursive.py b/tests/validators/test_definitions_recursive.py index 4c2577408..34b8ca445 100644 --- a/tests/validators/test_definitions_recursive.py +++ b/tests/validators/test_definitions_recursive.py @@ -611,11 +611,11 @@ def test_union_cycle(strict: bool): 'foobar': core_schema.typed_dict_field( core_schema.list_schema(core_schema.definition_reference_schema('root-schema')) ) - } + }, + strict=strict, ) ], auto_collapse=False, - strict=strict, ref='root-schema', ) ], @@ -700,11 +700,11 @@ def f(input_value, info): ) ], auto_collapse=False, - strict=strict, ref='root-schema', ) ], - ) + ), + config=CoreConfig(strict=strict), ) with pytest.raises(ValidationError) as exc_info: diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index 70acd6cfd..f35ebfec0 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -8,7 +8,7 @@ import pytest from dirty_equals import IsFloat, IsInt -from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema, validate_core_schema +from pydantic_core import CoreConfig, SchemaError, SchemaValidator, ValidationError, core_schema, validate_core_schema from ..conftest import plain_repr @@ -262,16 +262,47 @@ def test_one_choice(): assert v.validate_python('hello') == 'hello' -def test_strict_union(): +def test_strict_union_flag() -> None: + v = SchemaValidator(core_schema.union_schema(choices=[core_schema.bool_schema(), core_schema.int_schema()])) + assert v.validate_python(1, strict=True) == 1 + assert v.validate_python(123, strict=True) == 123 + + with pytest.raises(ValidationError) as exc_info: + v.validate_python('123', strict=True) + + assert exc_info.value.errors(include_url=False) == [ + {'type': 'bool_type', 'loc': ('bool',), 'msg': 'Input should be a valid boolean', 'input': '123'}, + {'type': 'int_type', 'loc': ('int',), 'msg': 'Input should be a valid integer', 'input': '123'}, + ] + + +def test_strict_union_config_level() -> None: v = SchemaValidator( - core_schema.union_schema(strict=True, choices=[core_schema.bool_schema(), core_schema.int_schema()]) + core_schema.union_schema(choices=[core_schema.bool_schema(), core_schema.int_schema()]), + config=CoreConfig(strict=True), ) + assert v.validate_python(1) == 1 assert v.validate_python(123) == 123 with pytest.raises(ValidationError) as exc_info: v.validate_python('123') + assert exc_info.value.errors(include_url=False) == [ + {'type': 'bool_type', 'loc': ('bool',), 'msg': 'Input should be a valid boolean', 'input': '123'}, + {'type': 'int_type', 'loc': ('int',), 'msg': 'Input should be a valid integer', 'input': '123'}, + ] + +def test_strict_union_member_level() -> None: + v = SchemaValidator( + core_schema.union_schema(choices=[core_schema.bool_schema(strict=True), core_schema.int_schema(strict=True)]) + ) + + assert v.validate_python(1) == 1 + assert v.validate_python(123) == 123 + + with pytest.raises(ValidationError) as exc_info: + v.validate_python('123') assert exc_info.value.errors(include_url=False) == [ {'type': 'bool_type', 'loc': ('bool',), 'msg': 'Input should be a valid boolean', 'input': '123'}, {'type': 'int_type', 'loc': ('int',), 'msg': 'Input should be a valid integer', 'input': '123'}, @@ -469,10 +500,10 @@ def test_left_to_right_union(): def test_left_to_right_union_strict(): - choices = [core_schema.int_schema(), core_schema.float_schema()] + choices = [core_schema.int_schema(strict=True), core_schema.float_schema(strict=True)] # left_to_right union will select not cast if int first (strict int will not accept float) - v = SchemaValidator(core_schema.union_schema(choices, mode='left_to_right', strict=True)) + v = SchemaValidator(core_schema.union_schema(choices, mode='left_to_right')) out = v.validate_python(1) assert out == 1 assert isinstance(out, int) @@ -482,7 +513,12 @@ def test_left_to_right_union_strict(): assert isinstance(out, float) # reversing union will select float always (as strict float will accept int) - v = SchemaValidator(core_schema.union_schema(list(reversed(choices)), mode='left_to_right', strict=True)) + v = SchemaValidator( + core_schema.union_schema( + list(reversed(choices)), + mode='left_to_right', + ) + ) out = v.validate_python(1.0) assert out == 1.0 assert isinstance(out, float) From 5bcaaf84946a759962d2863cc248731c02dfe4fc Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 24 Feb 2025 08:54:29 -0500 Subject: [PATCH 3/3] linting fix --- src/url.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/url.rs b/src/url.rs index 816227bd4..692ac80c2 100644 --- a/src/url.rs +++ b/src/url.rs @@ -400,7 +400,7 @@ impl PyMultiHostUrl { username: username.map(Into::into), password: password.map(Into::into), host: host.map(Into::into), - port: port.map(Into::into), + port, }; format!("{scheme}://{url_host}") } else {