Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add flag to allow only finite float values #228

Merged
merged 8 commits into from
Aug 22, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions pydantic_core/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class DictSchema(TypedDict, total=False):

class FloatSchema(TypedDict, total=False):
type: Required[Literal['float']]
allow_inf_nan: bool # whether 'NaN', '+inf', '-inf' should be forbidden. default: True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to add allow_inf_nan to Config to otherwise it can't be set, might be good to add a test to confirm setting allow_inf_nan via config works.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes I added the implementation but not the interfaces. I'll do that soon(ish)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

multiple_of: float
le: float
ge: float
Expand Down
2 changes: 2 additions & 0 deletions src/errors/kinds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ pub enum ErrorKind {
FloatType,
#[strum(message = "Input should be a valid number, unable to parse string as an number")]
FloatParsing,
#[strum(message = "Input should be a finite number")]
FloatFiniteNumber,
#[strum(serialize = "multiple_of", message = "Input should be a multiple of {multiple_of}")]
FloatMultipleOf {
multiple_of: f64,
Expand Down
2 changes: 1 addition & 1 deletion src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl<'a> Input<'a> for JsonInput {
},
JsonInput::Float(f) => Ok(*f),
JsonInput::Int(i) => Ok(*i as f64),
JsonInput::String(str) => match str.parse() {
JsonInput::String(str) => match str.parse::<f64>() {
Ok(i) => Ok(i),
Err(_) => Err(ValError::new(ErrorKind::FloatParsing, self)),
},
Expand Down
2 changes: 1 addition & 1 deletion src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ impl<'a> Input<'a> for PyAny {
if let Ok(float) = self.extract::<f64>() {
Ok(float)
} else if let Some(cow_str) = maybe_as_string(self, ErrorKind::FloatParsing)? {
match cow_str.as_ref().parse() {
match cow_str.as_ref().parse::<f64>() {
Ok(i) => Ok(i),
Err(_) => Err(ValError::new(ErrorKind::FloatParsing, self)),
}
Expand Down
15 changes: 13 additions & 2 deletions src/validators/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::PyDict;

use crate::build_tools::{is_strict, SchemaDict};
use crate::build_tools::{is_strict, schema_or_config_same, SchemaDict};
use crate::errors::{ErrorKind, ValError, ValResult};
use crate::input::Input;
use crate::recursion_guard::RecursionGuard;
Expand All @@ -12,6 +12,7 @@ use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator};
#[derive(Debug, Clone)]
pub struct FloatValidator {
strict: bool,
allow_inf_nan: bool,
}

impl BuildValidator for FloatValidator {
Expand All @@ -33,6 +34,7 @@ impl BuildValidator for FloatValidator {
} else {
Ok(Self {
strict: is_strict(schema, config)?,
allow_inf_nan: schema_or_config_same(schema, config, intern!(py, "allow_inf_nan"))?.unwrap_or(true),
}
.into())
}
Expand All @@ -48,7 +50,11 @@ impl Validator for FloatValidator {
_slots: &'data [CombinedValidator],
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
Ok(input.validate_float(extra.strict.unwrap_or(self.strict))?.into_py(py))
let float = input.validate_float(extra.strict.unwrap_or(self.strict))?;
if !self.allow_inf_nan && !float.is_finite() {
return Err(ValError::new(ErrorKind::FloatFiniteNumber, input));
}
Ok(float.into_py(py))
}

fn get_name(&self) -> &str {
Expand All @@ -59,6 +65,7 @@ impl Validator for FloatValidator {
#[derive(Debug, Clone)]
pub struct ConstrainedFloatValidator {
strict: bool,
allow_inf_nan: bool,
multiple_of: Option<f64>,
le: Option<f64>,
lt: Option<f64>,
Expand All @@ -76,6 +83,9 @@ impl Validator for ConstrainedFloatValidator {
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
let float = input.validate_float(extra.strict.unwrap_or(self.strict))?;
if !self.allow_inf_nan && !float.is_finite() {
return Err(ValError::new(ErrorKind::FloatFiniteNumber, input));
}
if let Some(multiple_of) = self.multiple_of {
if float % multiple_of != 0.0 {
return Err(ValError::new(ErrorKind::FloatMultipleOf { multiple_of }, input));
Expand Down Expand Up @@ -113,6 +123,7 @@ impl ConstrainedFloatValidator {
let py = schema.py();
Ok(Self {
strict: is_strict(schema, config)?,
allow_inf_nan: schema_or_config_same(schema, config, intern!(py, "allow_inf_nan"))?.unwrap_or(true),
multiple_of: schema.get_as(intern!(py, "multiple_of"))?,
le: schema.get_as(intern!(py, "le"))?,
lt: schema.get_as(intern!(py, "lt"))?,
Expand Down
123 changes: 121 additions & 2 deletions tests/validators/test_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Dict

import pytest
from dirty_equals import FunctionCheck

from pydantic_core import SchemaValidator, ValidationError

Expand Down Expand Up @@ -141,9 +142,14 @@ def test_union_float_simple(py_and_json: PyAndJson):

def test_float_repr():
v = SchemaValidator({'type': 'float'})
assert plain_repr(v) == 'SchemaValidator(name="float",validator=Float(FloatValidator{strict:false}))'
assert (
plain_repr(v)
== 'SchemaValidator(name="float",validator=Float(FloatValidator{strict:false,allow_inf_nan:true}))'
)
v = SchemaValidator({'type': 'float', 'strict': True})
assert plain_repr(v) == 'SchemaValidator(name="float",validator=Float(FloatValidator{strict:true}))'
assert (
plain_repr(v) == 'SchemaValidator(name="float",validator=Float(FloatValidator{strict:true,allow_inf_nan:true}))'
)
v = SchemaValidator({'type': 'float', 'multiple_of': 7})
assert plain_repr(v).startswith('SchemaValidator(name="constrained-float",validator=ConstrainedFloat(')

Expand Down Expand Up @@ -174,3 +180,116 @@ def test_float_key(py_and_json: PyAndJson):
assert v.validate_test({'1.5': 1, '2.4': 2}) == {1.5: 1, 2.4: 2}
with pytest.raises(ValidationError, match='Input should be a valid number'):
v.validate_test({'1.5': 1, '2.5': 2}, strict=True)


@pytest.mark.parametrize(
'input_value,allow_inf_nan,expected',
[
('NaN', True, FunctionCheck(math.isnan)),
(
'NaN',
False,
Err("Input should be a finite number [kind=float_finite_number, input_value='NaN', input_type=str]"),
),
('+inf', True, FunctionCheck(lambda x: math.isinf(x) and x > 0)),
(
'+inf',
False,
Err("Input should be a finite number [kind=float_finite_number, input_value='+inf', input_type=str]"),
),
('+infinity', True, FunctionCheck(lambda x: math.isinf(x) and x > 0)),
(
'+infinity',
False,
Err("Input should be a finite number [kind=float_finite_number, input_value='+infinity', input_type=str]"),
),
('-inf', True, FunctionCheck(lambda x: math.isinf(x) and x < 0)),
(
'-inf',
False,
Err("Input should be a finite number [kind=float_finite_number, input_value='-inf', input_type=str]"),
),
('-infinity', True, FunctionCheck(lambda x: math.isinf(x) and x < 0)),
(
'-infinity',
False,
Err("Input should be a finite number [kind=float_finite_number, input_value='-infinity', input_type=str]"),
),
('0.7', True, 0.7),
('0.7', False, 0.7),
(
'pika',
True,
Err(
'Input should be a valid number, unable to parse string as an number '
"[kind=float_parsing, input_value='pika', input_type=str]"
),
),
(
'pika',
False,
Err(
'Input should be a valid number, unable to parse string as an number '
"[kind=float_parsing, input_value='pika', input_type=str]"
),
),
],
)
def test_non_finite_json_values(py_and_json: PyAndJson, input_value, allow_inf_nan, expected):
v = py_and_json({'type': 'float', 'allow_inf_nan': allow_inf_nan})
if isinstance(expected, Err):
with pytest.raises(ValidationError, match=re.escape(expected.message)):
v.validate_test(input_value)
else:
assert v.validate_test(input_value) == expected


@pytest.mark.parametrize('strict', (True, False))
@pytest.mark.parametrize(
'input_value,allow_inf_nan,expected',
[
(float('nan'), True, FunctionCheck(math.isnan)),
(
float('nan'),
False,
Err('Input should be a finite number [kind=float_finite_number, input_value=nan, input_type=float]'),
),
],
)
def test_non_finite_float_values(strict, input_value, allow_inf_nan, expected):
v = SchemaValidator({'type': 'float', 'allow_inf_nan': allow_inf_nan, 'strict': strict})
if isinstance(expected, Err):
with pytest.raises(ValidationError, match=re.escape(expected.message)):
v.validate_python(input_value)
else:
assert v.validate_python(input_value) == expected


@pytest.mark.parametrize(
'input_value,allow_inf_nan,expected',
[
(float('+inf'), True, FunctionCheck(lambda x: math.isinf(x) and x > 0)),
(
float('+inf'),
False,
Err('Input should be a finite number [kind=float_finite_number, input_value=inf, input_type=float]'),
),
(
float('-inf'),
True,
Err('Input should be greater than 0 [kind=greater_than, input_value=-inf, input_type=float]'),
),
(
float('-inf'),
False,
Err('Input should be a finite number [kind=float_finite_number, input_value=-inf, input_type=float]'),
),
],
)
def test_non_finite_constrained_float_values(input_value, allow_inf_nan, expected):
v = SchemaValidator({'type': 'float', 'allow_inf_nan': allow_inf_nan, 'gt': 0})
if isinstance(expected, Err):
with pytest.raises(ValidationError, match=re.escape(expected.message)):
v.validate_python(input_value)
else:
assert v.validate_python(input_value) == expected
12 changes: 10 additions & 2 deletions tests/validators/test_typed_dict.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import math
import re
import sys
from dataclasses import dataclass
from datetime import datetime
from typing import Mapping

import pytest
from dirty_equals import HasRepr, IsStr
from dirty_equals import FunctionCheck, HasRepr, IsStr

from pydantic_core import Config, SchemaError, SchemaValidator, ValidationError

Expand Down Expand Up @@ -116,12 +117,19 @@ def test_missing_error():
Err('Keys should be strings [kind=invalid_key,'),
),
({'strict': True}, Map(a=123), Err('Input should be a valid dictionary [kind=dict_type,')),
({}, {'a': '123', 'b': '4.7'}, {'a': 123, 'b': 4.7}),
({}, {'a': '123', 'b': 'nan'}, {'a': 123, 'b': FunctionCheck(math.isnan)}),
(
{'allow_inf_nan': False},
{'a': '123', 'b': 'nan'},
Err('Input should be a finite number [kind=float_finite_number,'),
),
],
ids=repr,
)
def test_config(config: Config, input_value, expected):
v = SchemaValidator(
{'type': 'typed-dict', 'fields': {'a': {'schema': 'int'}, 'b': {'schema': 'int', 'required': False}}}, config
{'type': 'typed-dict', 'fields': {'a': {'schema': 'int'}, 'b': {'schema': 'float', 'required': False}}}, config
)
if isinstance(expected, Err):
with pytest.raises(ValidationError, match=re.escape(expected.message)):
Expand Down