Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Int extraction #1155

Merged
merged 8 commits into from Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Expand Up @@ -36,3 +36,6 @@ node_modules/
/foobar.py
/python/pydantic_core/*.so
/src/self_schema.py

# samply
/profile.json
2 changes: 1 addition & 1 deletion src/errors/types.rs
Expand Up @@ -786,7 +786,7 @@ impl From<Int> for Number {

impl FromPyObject<'_> for Number {
fn extract(obj: &PyAny) -> PyResult<Self> {
if let Ok(int) = extract_i64(obj) {
if let Some(int) = extract_i64(obj) {
Ok(Number::Int(int))
} else if let Ok(float) = obj.extract::<f64>() {
Ok(Number::Float(float))
Expand Down
2 changes: 1 addition & 1 deletion src/errors/value_exception.rs
Expand Up @@ -122,7 +122,7 @@ impl PydanticCustomError {
let key: &PyString = key.downcast()?;
if let Ok(py_str) = value.downcast::<PyString>() {
message = message.replace(&format!("{{{}}}", key.to_str()?), py_str.to_str()?);
} else if let Ok(value_int) = extract_i64(value) {
} else if let Some(value_int) = extract_i64(value) {
message = message.replace(&format!("{{{}}}", key.to_str()?), &value_int.to_string());
} else {
// fallback for anything else just in case
Expand Down
10 changes: 5 additions & 5 deletions src/input/input_python.rs
Expand Up @@ -96,7 +96,7 @@ impl AsLocItem for PyAny {
fn as_loc_item(&self) -> LocItem {
if let Ok(py_str) = self.downcast::<PyString>() {
py_str.to_string_lossy().as_ref().into()
} else if let Ok(key_int) = extract_i64(self) {
} else if let Some(key_int) = extract_i64(self) {
key_int.into()
} else {
safe_repr(self).to_string().into()
Expand Down Expand Up @@ -292,7 +292,7 @@ impl<'a> Input<'a> for PyAny {
if !strict {
if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::BoolParsing)? {
return str_as_bool(self, &cow_str).map(ValidationMatch::lax);
} else if let Ok(int) = extract_i64(self) {
} else if let Some(int) = extract_i64(self) {
return int_as_bool(self, int).map(ValidationMatch::lax);
} else if let Ok(float) = self.extract::<f64>() {
if let Ok(int) = float_as_int(self, float) {
Expand Down Expand Up @@ -635,7 +635,7 @@ impl<'a> Input<'a> for PyAny {
bytes_as_time(self, py_bytes.as_bytes(), microseconds_overflow_behavior)
} else if PyBool::is_exact_type_of(self) {
Err(ValError::new(ErrorTypeDefaults::TimeType, self))
} else if let Ok(int) = extract_i64(self) {
} else if let Some(int) = extract_i64(self) {
int_as_time(self, int, 0)
} else if let Ok(float) = self.extract::<f64>() {
float_as_time(self, float)
Expand Down Expand Up @@ -669,7 +669,7 @@ impl<'a> Input<'a> for PyAny {
bytes_as_datetime(self, py_bytes.as_bytes(), microseconds_overflow_behavior)
} else if PyBool::is_exact_type_of(self) {
Err(ValError::new(ErrorTypeDefaults::DatetimeType, self))
} else if let Ok(int) = extract_i64(self) {
} else if let Some(int) = extract_i64(self) {
int_as_datetime(self, int, 0)
} else if let Ok(float) = self.extract::<f64>() {
float_as_datetime(self, float)
Expand Down Expand Up @@ -706,7 +706,7 @@ impl<'a> Input<'a> for PyAny {
bytes_as_timedelta(self, str.as_bytes(), microseconds_overflow_behavior)
} else if let Ok(py_bytes) = self.downcast::<PyBytes>() {
bytes_as_timedelta(self, py_bytes.as_bytes(), microseconds_overflow_behavior)
} else if let Ok(int) = extract_i64(self) {
} else if let Some(int) = extract_i64(self) {
Ok(int_as_duration(self, int)?.into())
} else if let Ok(float) = self.extract::<f64>() {
Ok(float_as_duration(self, float)?.into())
Expand Down
6 changes: 3 additions & 3 deletions src/input/return_enums.rs
Expand Up @@ -23,7 +23,7 @@ use pyo3::PyTypeInfo;
use serde::{ser::Error, Serialize, Serializer};

use crate::errors::{py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ValError, ValLineError, ValResult};
use crate::tools::py_err;
use crate::tools::{extract_i64, py_err};
use crate::validators::{CombinedValidator, Exactness, ValidationState, Validator};

use super::input_string::StringMapping;
Expand Down Expand Up @@ -863,7 +863,7 @@ pub enum EitherInt<'a> {
impl<'a> EitherInt<'a> {
pub fn upcast(py_any: &'a PyAny) -> ValResult<Self> {
// Safety: we know that py_any is a python int
if let Ok(int_64) = py_any.extract::<i64>() {
if let Some(int_64) = extract_i64(py_any) {
Ok(Self::I64(int_64))
} else {
let big_int: BigInt = py_any.extract()?;
Expand Down Expand Up @@ -1021,7 +1021,7 @@ impl<'a> Rem for &'a Int {

impl<'a> FromPyObject<'a> for Int {
fn extract(obj: &'a PyAny) -> PyResult<Self> {
if let Ok(i) = obj.extract::<i64>() {
if let Some(i) = extract_i64(obj) {
Ok(Int::I64(i))
} else if let Ok(b) = obj.extract::<BigInt>() {
Ok(Int::Big(b))
Expand Down
2 changes: 1 addition & 1 deletion src/lookup_key.rs
Expand Up @@ -429,7 +429,7 @@ impl PathItem {
} else {
Ok(Self::Pos(usize_key))
}
} else if let Ok(int_key) = extract_i64(obj) {
} else if let Some(int_key) = extract_i64(obj) {
if index == 0 {
py_err!(PyTypeError; "The first item in an alias path should be a string")
} else {
Expand Down
5 changes: 4 additions & 1 deletion src/serializers/infer.rs
Expand Up @@ -123,7 +123,10 @@ pub(crate) fn infer_to_python_known(
// `bool` and `None` can't be subclasses, `ObType::Int`, `ObType::Float`, `ObType::Str` refer to exact types
ObType::None | ObType::Bool | ObType::Int | ObType::Str => value.into_py(py),
// have to do this to make sure subclasses of for example str are upcast to `str`
ObType::IntSubclass => extract_i64(value)?.into_py(py),
ObType::IntSubclass => match extract_i64(value) {
Some(v) => v.into_py(py),
None => return py_err!(PyTypeError; "expected int, got {}", safe_repr(value)),
},
ObType::Float | ObType::FloatSubclass => {
let v = value.extract::<f64>()?;
if (v.is_nan() || v.is_infinite()) && extra.config.inf_nan_mode == InfNanMode::Null {
Expand Down
4 changes: 2 additions & 2 deletions src/serializers/type_serializers/literal.rs
Expand Up @@ -46,7 +46,7 @@ impl BuildSerializer for LiteralSerializer {
repr_args.push(item.repr()?.extract()?);
if let Ok(bool) = item.downcast::<PyBool>() {
expected_py.append(bool)?;
} else if let Ok(int) = extract_i64(item) {
} else if let Some(int) = extract_i64(item) {
expected_int.insert(int);
} else if let Ok(py_str) = item.downcast::<PyString>() {
expected_str.insert(py_str.to_str()?.to_string());
Expand Down Expand Up @@ -79,7 +79,7 @@ impl LiteralSerializer {
fn check<'a>(&self, value: &'a PyAny, extra: &Extra) -> PyResult<OutputValue<'a>> {
if extra.check.enabled() {
if !self.expected_int.is_empty() && !PyBool::is_type_of(value) {
if let Ok(int) = extract_i64(value) {
if let Some(int) = extract_i64(value) {
if self.expected_int.contains(&int) {
return Ok(OutputValue::OkInt(int));
}
Expand Down
28 changes: 21 additions & 7 deletions src/tools.rs
@@ -1,9 +1,9 @@
use std::borrow::Cow;

use pyo3::exceptions::{PyKeyError, PyTypeError};
use pyo3::exceptions::PyKeyError;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyInt, PyString};
use pyo3::{intern, FromPyObject, PyTypeInfo};
use pyo3::types::{PyDict, PyString};
use pyo3::{ffi, intern, FromPyObject};

pub trait SchemaDict<'py> {
fn get_as<T>(&'py self, key: &PyString) -> PyResult<Option<T>>
Expand Down Expand Up @@ -99,10 +99,24 @@ pub fn safe_repr(v: &PyAny) -> Cow<str> {
}
}

pub fn extract_i64(v: &PyAny) -> PyResult<i64> {
if PyInt::is_type_of(v) {
v.extract()
/// Extract an i64 from a python object more quickly, see
/// https://github.com/PyO3/pyo3/pull/3742#discussion_r1451763928
#[cfg(not(any(target_pointer_width = "32", windows, PyPy)))]
pub fn extract_i64(obj: &PyAny) -> Option<i64> {
let val = unsafe { ffi::PyLong_AsLong(obj.as_ptr()) };
if val == -1 && PyErr::occurred(obj.py()) {
unsafe { ffi::PyErr_Clear() };
None
} else {
py_err!(PyTypeError; "expected int, got {}", safe_repr(v))
Some(val)
}
}

#[cfg(any(target_pointer_width = "32", windows, PyPy))]
pub fn extract_i64(v: &PyAny) -> Option<i64> {
if v.is_instance_of::<pyo3::types::PyInt>() {
v.extract().ok()
} else {
None
}
}
12 changes: 12 additions & 0 deletions tests/benchmarks/test_micro_benchmarks.py
Expand Up @@ -1232,6 +1232,18 @@ def test_strict_int(benchmark):
benchmark(v.validate_python, 42)


@pytest.mark.benchmark(group='strict_int')
def test_strict_int_fails(benchmark):
v = SchemaValidator(core_schema.int_schema(strict=True))

@benchmark
def t():
try:
v.validate_python(())
except ValidationError:
pass


@pytest.mark.benchmark(group='int_range')
def test_int_range(benchmark):
v = SchemaValidator(core_schema.int_schema(gt=0, lt=100))
Expand Down
15 changes: 11 additions & 4 deletions tests/validators/test_int.py
Expand Up @@ -29,6 +29,8 @@
('123456789123456.00001', Err('Input should be a valid integer, unable to parse string as an integer')),
(int(1e10), int(1e10)),
(i64_max, i64_max),
(i64_max + 1, i64_max + 1),
(i64_max * 2, i64_max * 2),
pytest.param(
12.5,
Err('Input should be a valid integer, got a number with a fractional part [type=int_from_float'),
Expand Down Expand Up @@ -106,10 +108,15 @@ def test_int(input_value, expected):
@pytest.mark.parametrize(
'input_value,expected',
[
(Decimal('1'), 1),
(Decimal('1.0'), 1),
(i64_max, i64_max),
(i64_max + 1, i64_max + 1),
pytest.param(Decimal('1'), 1),
pytest.param(Decimal('1.0'), 1),
pytest.param(i64_max, i64_max, id='i64_max'),
pytest.param(i64_max + 1, i64_max + 1, id='i64_max+1'),
pytest.param(
-1,
Err('Input should be greater than 0 [type=greater_than, input_value=-1, input_type=int]'),
id='-1',
),
(
-i64_max + 1,
Err('Input should be greater than 0 [type=greater_than, input_value=-9223372036854775806, input_type=int]'),
Expand Down