Skip to content

Commit

Permalink
fix: nan inf float (#1062)
Browse files Browse the repository at this point in the history
Co-authored-by: JeanArhancet <jean.arhancetebehere@gmail.com>
  • Loading branch information
davidhewitt and JeanArhancet committed Nov 6, 2023
1 parent 5de6b75 commit d80c454
Show file tree
Hide file tree
Showing 13 changed files with 1,496 additions and 8 deletions.
3 changes: 3 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class CoreConfig(TypedDict, total=False):
allow_inf_nan: Whether to allow infinity and NaN values for float fields. Default is `True`.
ser_json_timedelta: The serialization option for `timedelta` values. Default is 'iso8601'.
ser_json_bytes: The serialization option for `bytes` values. Default is 'utf8'.
ser_json_inf_nan: The serialization option for infinity and NaN values
in float fields. Default is 'null'.
hide_input_in_errors: Whether to hide input data from `ValidationError` representation.
validation_error_cause: Whether to add user-python excs to the __cause__ of a ValidationError.
Requires exceptiongroup backport pre Python 3.11.
Expand Down Expand Up @@ -102,6 +104,7 @@ class CoreConfig(TypedDict, total=False):
# the config options are used to customise serialization to JSON
ser_json_timedelta: Literal['iso8601', 'float'] # default: 'iso8601'
ser_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8'
ser_json_inf_nan: Literal['null', 'constants'] # default: 'null'
# used to hide input data from ValidationError repr
hide_input_in_errors: bool
validation_error_cause: bool # default: False
Expand Down
4 changes: 2 additions & 2 deletions src/errors/validation_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,12 +324,12 @@ impl ValidationError {
Some(indent) => {
let indent = vec![b' '; indent];
let formatter = PrettyFormatter::with_indent(&indent);
let mut ser = serde_json::Serializer::with_formatter(writer, formatter);
let mut ser = crate::serializers::ser::PythonSerializer::with_formatter(writer, formatter);
serializer.serialize(&mut ser).map_err(json_py_err)?;
ser.into_inner()
}
None => {
let mut ser = serde_json::Serializer::new(writer);
let mut ser = crate::serializers::ser::PythonSerializer::new(writer);
serializer.serialize(&mut ser).map_err(json_py_err)?;
ser.into_inner()
}
Expand Down
29 changes: 29 additions & 0 deletions src/serializers/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,32 @@ pub fn utf8_py_error(py: Python, err: Utf8Error, data: &[u8]) -> PyErr {
Err(err) => err,
}
}

#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub(crate) enum InfNanMode {
#[default]
Null,
Constants,
}

impl FromStr for InfNanMode {
type Err = PyErr;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"null" => Ok(Self::Null),
"constants" => Ok(Self::Constants),
s => py_schema_err!(
"Invalid inf_nan serialization mode: `{}`, expected `null` or `constants`",
s
),
}
}
}

impl FromPyObject<'_> for InfNanMode {
fn extract(ob: &'_ PyAny) -> PyResult<Self> {
let s = ob.extract::<&str>()?;
Self::from_str(s)
}
}
27 changes: 26 additions & 1 deletion src/serializers/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,33 @@ pub(super) fn py_err_se_err<T: ser::Error, E: fmt::Display>(py_error: E) -> T {
T::custom(py_error.to_string())
}

#[pyclass(extends=PyValueError, module="pydantic_core._pydantic_core")]
#[derive(Debug, Clone)]
pub struct PythonSerializerError {
pub message: String,
}

impl fmt::Display for PythonSerializerError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.message)
}
}

impl std::error::Error for PythonSerializerError {}

impl serde::ser::Error for PythonSerializerError {
fn custom<T>(msg: T) -> Self
where
T: fmt::Display,
{
PythonSerializerError {
message: format!("{msg}"),
}
}
}

/// convert a serde serialization error into a `PyErr`
pub(super) fn se_err_py_err(error: serde_json::Error) -> PyErr {
pub(super) fn se_err_py_err(error: PythonSerializerError) -> PyErr {
let s = error.to_string();
if let Some(msg) = s.strip_prefix(UNEXPECTED_TYPE_SER_MARKER) {
if msg.is_empty() {
Expand Down
1 change: 1 addition & 0 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ mod fields;
mod filter;
mod infer;
mod ob_type;
pub mod ser;
mod shared;
mod type_serializers;

Expand Down
Loading

0 comments on commit d80c454

Please sign in to comment.