diff --git a/src/validators/enum_.rs b/src/validators/enum_.rs index 2c9ae4b28..087d0ca5f 100644 --- a/src/validators/enum_.rs +++ b/src/validators/enum_.rs @@ -4,7 +4,7 @@ use std::marker::PhantomData; use pyo3::exceptions::PyTypeError; use pyo3::intern; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList, PyType}; +use pyo3::types::{PyDict, PyFloat, PyInt, PyList, PyString, PyType}; use crate::build_tools::{is_strict, py_schema_err}; use crate::errors::{ErrorType, ValError, ValResult}; @@ -167,9 +167,27 @@ impl EnumValidateValue for PlainEnumValidator { py: Python<'py>, input: &I, lookup: &LiteralLookup, - _strict: bool, + strict: bool, ) -> ValResult> { - Ok(lookup.validate(py, input)?.map(|(_, v)| v.clone_ref(py))) + match lookup.validate(py, input)? { + Some((_, v)) => Ok(Some(v.clone_ref(py))), + None => { + if !strict { + if let Some(py_input) = input.as_python() { + // necessary for compatibility with 2.6, where str and int subclasses are allowed + if py_input.is_instance_of::() { + return Ok(lookup.validate_str(input, false)?.map(|v| v.clone_ref(py))); + } else if py_input.is_instance_of::() { + return Ok(lookup.validate_int(py, input, false)?.map(|v| v.clone_ref(py))); + // necessary for compatibility with 2.6, where float values are allowed for int enums in lax mode + } else if py_input.is_instance_of::() { + return Ok(lookup.validate_int(py, input, false)?.map(|v| v.clone_ref(py))); + } + } + } + Ok(None) + } + } } } diff --git a/tests/validators/test_enums.py b/tests/validators/test_enums.py index 49abc415b..91701ce74 100644 --- a/tests/validators/test_enums.py +++ b/tests/validators/test_enums.py @@ -269,6 +269,52 @@ class MyEnum(Enum): SchemaValidator(core_schema.enum_schema(MyEnum, [])) +def test_enum_with_str_subclass() -> None: + class MyEnum(Enum): + a = 'a' + b = 'b' + + v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()))) + + assert v.validate_python(MyEnum.a) is MyEnum.a + assert v.validate_python('a') is MyEnum.a + + class MyStr(str): + pass + + assert v.validate_python(MyStr('a')) is MyEnum.a + with pytest.raises(ValidationError): + v.validate_python(MyStr('a'), strict=True) + + +def test_enum_with_int_subclass() -> None: + class MyEnum(Enum): + a = 1 + b = 2 + + v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()))) + + assert v.validate_python(MyEnum.a) is MyEnum.a + assert v.validate_python(1) is MyEnum.a + + class MyInt(int): + pass + + assert v.validate_python(MyInt(1)) is MyEnum.a + with pytest.raises(ValidationError): + v.validate_python(MyInt(1), strict=True) + + +def test_validate_float_for_int_enum() -> None: + class MyEnum(int, Enum): + a = 1 + b = 2 + + v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()))) + + assert v.validate_python(1.0) is MyEnum.a + + def test_missing_error_converted_to_val_error() -> None: class MyFlags(IntFlag): OFF = 0