Skip to content

Commit

Permalink
add flag to allow only finite float values (#228)
Browse files Browse the repository at this point in the history
* add flag to allow only finite float values

* switch to `allow_inf_nan` with default to True

* tweak error

* move check directly in float

* add test for constrained float

* use dirty_equals

* allow flag at config level

* add field at config level + test
  • Loading branch information
PrettyWood committed Aug 22, 2022
1 parent 59babff commit 72b9cad
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 11 deletions.
3 changes: 3 additions & 0 deletions pydantic_core/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class Config(TypedDict, total=False):
str_strip_whitespace: bool
str_to_lower: bool
str_to_upper: bool
# fields related to float fields only
allow_inf_nan: bool # default: True


class DictSchema(TypedDict, total=False):
Expand All @@ -59,6 +61,7 @@ class DictSchema(TypedDict, total=False):

class FloatSchema(TypedDict, total=False):
type: Required[Literal['float']]
allow_inf_nan: bool # whether 'NaN', '+inf', '-inf' should be forbidden. default: True
multiple_of: float
le: float
ge: float
Expand Down
2 changes: 2 additions & 0 deletions src/errors/kinds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ pub enum ErrorKind {
FloatType,
#[strum(message = "Input should be a valid number, unable to parse string as an number")]
FloatParsing,
#[strum(message = "Input should be a finite number")]
FloatFiniteNumber,
#[strum(serialize = "multiple_of", message = "Input should be a multiple of {multiple_of}")]
FloatMultipleOf {
multiple_of: f64,
Expand Down
2 changes: 1 addition & 1 deletion src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl<'a> Input<'a> for JsonInput {
},
JsonInput::Float(f) => Ok(*f),
JsonInput::Int(i) => Ok(*i as f64),
JsonInput::String(str) => match str.parse() {
JsonInput::String(str) => match str.parse::<f64>() {
Ok(i) => Ok(i),
Err(_) => Err(ValError::new(ErrorKind::FloatParsing, self)),
},
Expand Down
2 changes: 1 addition & 1 deletion src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ impl<'a> Input<'a> for PyAny {
if let Ok(float) = self.extract::<f64>() {
Ok(float)
} else if let Some(cow_str) = maybe_as_string(self, ErrorKind::FloatParsing)? {
match cow_str.as_ref().parse() {
match cow_str.as_ref().parse::<f64>() {
Ok(i) => Ok(i),
Err(_) => Err(ValError::new(ErrorKind::FloatParsing, self)),
}
Expand Down
15 changes: 13 additions & 2 deletions src/validators/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::PyDict;

use crate::build_tools::{is_strict, SchemaDict};
use crate::build_tools::{is_strict, schema_or_config_same, SchemaDict};
use crate::errors::{ErrorKind, ValError, ValResult};
use crate::input::Input;
use crate::recursion_guard::RecursionGuard;
Expand All @@ -12,6 +12,7 @@ use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator};
#[derive(Debug, Clone)]
pub struct FloatValidator {
strict: bool,
allow_inf_nan: bool,
}

impl BuildValidator for FloatValidator {
Expand All @@ -33,6 +34,7 @@ impl BuildValidator for FloatValidator {
} else {
Ok(Self {
strict: is_strict(schema, config)?,
allow_inf_nan: schema_or_config_same(schema, config, intern!(py, "allow_inf_nan"))?.unwrap_or(true),
}
.into())
}
Expand All @@ -48,7 +50,11 @@ impl Validator for FloatValidator {
_slots: &'data [CombinedValidator],
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
Ok(input.validate_float(extra.strict.unwrap_or(self.strict))?.into_py(py))
let float = input.validate_float(extra.strict.unwrap_or(self.strict))?;
if !self.allow_inf_nan && !float.is_finite() {
return Err(ValError::new(ErrorKind::FloatFiniteNumber, input));
}
Ok(float.into_py(py))
}

fn get_name(&self) -> &str {
Expand All @@ -59,6 +65,7 @@ impl Validator for FloatValidator {
#[derive(Debug, Clone)]
pub struct ConstrainedFloatValidator {
strict: bool,
allow_inf_nan: bool,
multiple_of: Option<f64>,
le: Option<f64>,
lt: Option<f64>,
Expand All @@ -76,6 +83,9 @@ impl Validator for ConstrainedFloatValidator {
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
let float = input.validate_float(extra.strict.unwrap_or(self.strict))?;
if !self.allow_inf_nan && !float.is_finite() {
return Err(ValError::new(ErrorKind::FloatFiniteNumber, input));
}
if let Some(multiple_of) = self.multiple_of {
if float % multiple_of != 0.0 {
return Err(ValError::new(ErrorKind::FloatMultipleOf { multiple_of }, input));
Expand Down Expand Up @@ -113,6 +123,7 @@ impl ConstrainedFloatValidator {
let py = schema.py();
Ok(Self {
strict: is_strict(schema, config)?,
allow_inf_nan: schema_or_config_same(schema, config, intern!(py, "allow_inf_nan"))?.unwrap_or(true),
multiple_of: schema.get_as(intern!(py, "multiple_of"))?,
le: schema.get_as(intern!(py, "le"))?,
lt: schema.get_as(intern!(py, "lt"))?,
Expand Down
58 changes: 55 additions & 3 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import math
import re

import pytest
from dirty_equals import HasAttributes, IsInstance
from dirty_equals import FunctionCheck, HasAttributes, IsInstance

from pydantic_core import SchemaValidator, ValidationError
from pydantic_core import Config, SchemaValidator, ValidationError

from .conftest import plain_repr
from .conftest import Err, plain_repr


def test_on_field():
Expand Down Expand Up @@ -190,3 +193,52 @@ def test_sub_model_merge():
'context': {'min_length': 1},
},
]


@pytest.mark.parametrize(
'config,float_field_schema,input_value,expected',
[
({}, {'type': 'float'}, {'x': 'nan'}, IsInstance(MyModel) & HasAttributes(x=FunctionCheck(math.isnan))),
(
{'allow_inf_nan': True},
{'type': 'float'},
{'x': 'nan'},
IsInstance(MyModel) & HasAttributes(x=FunctionCheck(math.isnan)),
),
(
{'allow_inf_nan': False},
{'type': 'float'},
{'x': 'nan'},
Err('Input should be a finite number [kind=float_finite_number,'),
),
# field `allow_inf_nan` (if set) should have priority over global config
(
{'allow_inf_nan': True},
{'type': 'float', 'allow_inf_nan': False},
{'x': 'nan'},
Err('Input should be a finite number [kind=float_finite_number,'),
),
(
{'allow_inf_nan': False},
{'type': 'float', 'allow_inf_nan': True},
{'x': 'nan'},
IsInstance(MyModel) & HasAttributes(x=FunctionCheck(math.isnan)),
),
],
ids=repr,
)
def test_allow_inf_nan(config: Config, float_field_schema, input_value, expected):
v = SchemaValidator(
{
'type': 'new-class',
'class_type': MyModel,
'schema': {'type': 'typed-dict', 'fields': {'x': {'schema': float_field_schema}}},
'config': config,
}
)
if isinstance(expected, Err):
with pytest.raises(ValidationError, match=re.escape(expected.message)):
v.validate_python(input_value)
else:
output_dict = v.validate_python(input_value)
assert output_dict == expected
123 changes: 121 additions & 2 deletions tests/validators/test_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Dict

import pytest
from dirty_equals import FunctionCheck

from pydantic_core import SchemaValidator, ValidationError

Expand Down Expand Up @@ -141,9 +142,14 @@ def test_union_float_simple(py_and_json: PyAndJson):

def test_float_repr():
v = SchemaValidator({'type': 'float'})
assert plain_repr(v) == 'SchemaValidator(name="float",validator=Float(FloatValidator{strict:false}))'
assert (
plain_repr(v)
== 'SchemaValidator(name="float",validator=Float(FloatValidator{strict:false,allow_inf_nan:true}))'
)
v = SchemaValidator({'type': 'float', 'strict': True})
assert plain_repr(v) == 'SchemaValidator(name="float",validator=Float(FloatValidator{strict:true}))'
assert (
plain_repr(v) == 'SchemaValidator(name="float",validator=Float(FloatValidator{strict:true,allow_inf_nan:true}))'
)
v = SchemaValidator({'type': 'float', 'multiple_of': 7})
assert plain_repr(v).startswith('SchemaValidator(name="constrained-float",validator=ConstrainedFloat(')

Expand Down Expand Up @@ -174,3 +180,116 @@ def test_float_key(py_and_json: PyAndJson):
assert v.validate_test({'1.5': 1, '2.4': 2}) == {1.5: 1, 2.4: 2}
with pytest.raises(ValidationError, match='Input should be a valid number'):
v.validate_test({'1.5': 1, '2.5': 2}, strict=True)


@pytest.mark.parametrize(
'input_value,allow_inf_nan,expected',
[
('NaN', True, FunctionCheck(math.isnan)),
(
'NaN',
False,
Err("Input should be a finite number [kind=float_finite_number, input_value='NaN', input_type=str]"),
),
('+inf', True, FunctionCheck(lambda x: math.isinf(x) and x > 0)),
(
'+inf',
False,
Err("Input should be a finite number [kind=float_finite_number, input_value='+inf', input_type=str]"),
),
('+infinity', True, FunctionCheck(lambda x: math.isinf(x) and x > 0)),
(
'+infinity',
False,
Err("Input should be a finite number [kind=float_finite_number, input_value='+infinity', input_type=str]"),
),
('-inf', True, FunctionCheck(lambda x: math.isinf(x) and x < 0)),
(
'-inf',
False,
Err("Input should be a finite number [kind=float_finite_number, input_value='-inf', input_type=str]"),
),
('-infinity', True, FunctionCheck(lambda x: math.isinf(x) and x < 0)),
(
'-infinity',
False,
Err("Input should be a finite number [kind=float_finite_number, input_value='-infinity', input_type=str]"),
),
('0.7', True, 0.7),
('0.7', False, 0.7),
(
'pika',
True,
Err(
'Input should be a valid number, unable to parse string as an number '
"[kind=float_parsing, input_value='pika', input_type=str]"
),
),
(
'pika',
False,
Err(
'Input should be a valid number, unable to parse string as an number '
"[kind=float_parsing, input_value='pika', input_type=str]"
),
),
],
)
def test_non_finite_json_values(py_and_json: PyAndJson, input_value, allow_inf_nan, expected):
v = py_and_json({'type': 'float', 'allow_inf_nan': allow_inf_nan})
if isinstance(expected, Err):
with pytest.raises(ValidationError, match=re.escape(expected.message)):
v.validate_test(input_value)
else:
assert v.validate_test(input_value) == expected


@pytest.mark.parametrize('strict', (True, False))
@pytest.mark.parametrize(
'input_value,allow_inf_nan,expected',
[
(float('nan'), True, FunctionCheck(math.isnan)),
(
float('nan'),
False,
Err('Input should be a finite number [kind=float_finite_number, input_value=nan, input_type=float]'),
),
],
)
def test_non_finite_float_values(strict, input_value, allow_inf_nan, expected):
v = SchemaValidator({'type': 'float', 'allow_inf_nan': allow_inf_nan, 'strict': strict})
if isinstance(expected, Err):
with pytest.raises(ValidationError, match=re.escape(expected.message)):
v.validate_python(input_value)
else:
assert v.validate_python(input_value) == expected


@pytest.mark.parametrize(
'input_value,allow_inf_nan,expected',
[
(float('+inf'), True, FunctionCheck(lambda x: math.isinf(x) and x > 0)),
(
float('+inf'),
False,
Err('Input should be a finite number [kind=float_finite_number, input_value=inf, input_type=float]'),
),
(
float('-inf'),
True,
Err('Input should be greater than 0 [kind=greater_than, input_value=-inf, input_type=float]'),
),
(
float('-inf'),
False,
Err('Input should be a finite number [kind=float_finite_number, input_value=-inf, input_type=float]'),
),
],
)
def test_non_finite_constrained_float_values(input_value, allow_inf_nan, expected):
v = SchemaValidator({'type': 'float', 'allow_inf_nan': allow_inf_nan, 'gt': 0})
if isinstance(expected, Err):
with pytest.raises(ValidationError, match=re.escape(expected.message)):
v.validate_python(input_value)
else:
assert v.validate_python(input_value) == expected
12 changes: 10 additions & 2 deletions tests/validators/test_typed_dict.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import math
import re
import sys
from dataclasses import dataclass
from datetime import datetime
from typing import Mapping

import pytest
from dirty_equals import HasRepr, IsStr
from dirty_equals import FunctionCheck, HasRepr, IsStr

from pydantic_core import Config, SchemaError, SchemaValidator, ValidationError

Expand Down Expand Up @@ -116,12 +117,19 @@ def test_missing_error():
Err('Keys should be strings [kind=invalid_key,'),
),
({'strict': True}, Map(a=123), Err('Input should be a valid dictionary [kind=dict_type,')),
({}, {'a': '123', 'b': '4.7'}, {'a': 123, 'b': 4.7}),
({}, {'a': '123', 'b': 'nan'}, {'a': 123, 'b': FunctionCheck(math.isnan)}),
(
{'allow_inf_nan': False},
{'a': '123', 'b': 'nan'},
Err('Input should be a finite number [kind=float_finite_number,'),
),
],
ids=repr,
)
def test_config(config: Config, input_value, expected):
v = SchemaValidator(
{'type': 'typed-dict', 'fields': {'a': {'schema': 'int'}, 'b': {'schema': 'int', 'required': False}}}, config
{'type': 'typed-dict', 'fields': {'a': {'schema': 'int'}, 'b': {'schema': 'float', 'required': False}}}, config
)
if isinstance(expected, Err):
with pytest.raises(ValidationError, match=re.escape(expected.message)):
Expand Down

0 comments on commit 72b9cad

Please sign in to comment.