diff --git a/src/serializers/type_serializers/tuple.rs b/src/serializers/type_serializers/tuple.rs index e5f225c92..5890288d6 100644 --- a/src/serializers/type_serializers/tuple.rs +++ b/src/serializers/type_serializers/tuple.rs @@ -7,8 +7,10 @@ use std::iter; use serde::ser::SerializeSeq; use crate::definitions::DefinitionsBuilder; +use crate::serializers::extra::SerCheck; use crate::serializers::type_serializers::any::AnySerializer; use crate::tools::SchemaDict; +use crate::PydanticSerializationUnexpectedValue; use super::{ infer_json_key, infer_serialize, infer_to_python, py_err_se_err, BuildSerializer, CombinedSerializer, Extra, @@ -70,52 +72,14 @@ impl TypeSerializer for TupleSerializer { let py = value.py(); let n_items = py_tuple.len(); - let mut py_tuple_iter = py_tuple.iter(); let mut items = Vec::with_capacity(n_items); - macro_rules! use_serializers { - ($serializers_iter:expr) => { - for (index, serializer) in $serializers_iter.enumerate() { - let element = match py_tuple_iter.next() { - Some(value) => value, - None => break, - }; - let op_next = self - .filter - .index_filter(index, include, exclude, Some(n_items))?; - if let Some((next_include, next_exclude)) = op_next { - items.push(serializer.to_python(element, next_include, next_exclude, extra)?); - } - } - }; - } - - if let Some(variadic_item_index) = self.variadic_item_index { - // Need `saturating_sub` to handle items with too few elements without panicking - let n_variadic_items = (n_items + 1).saturating_sub(self.serializers.len()); - let serializers_iter = self.serializers[..variadic_item_index] - .iter() - .chain(iter::repeat(&self.serializers[variadic_item_index]).take(n_variadic_items)) - .chain(self.serializers[variadic_item_index + 1..].iter()); - use_serializers!(serializers_iter); - } else { - use_serializers!(self.serializers.iter()); - let mut warned = false; - for (i, element) in py_tuple_iter.enumerate() { - if !warned { - extra - .warnings - .custom_warning("Unexpected extra items present in tuple".to_string()); - warned = true; - } - let op_next = - self.filter - .index_filter(i + self.serializers.len(), include, exclude, Some(n_items))?; - if let Some((next_include, next_exclude)) = op_next { - items.push(AnySerializer.to_python(element, next_include, next_exclude, extra)?); - } - } - }; + self.for_each_tuple_item_and_serializer(py_tuple, include, exclude, extra, |entry| { + entry + .serializer + .to_python(entry.item, entry.include, entry.exclude, extra) + .map(|item| items.push(item)) + })??; match extra.mode { SerMode::Json => Ok(PyList::new(py, items).into_py(py)), @@ -132,35 +96,14 @@ impl TypeSerializer for TupleSerializer { fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult> { match key.downcast::() { Ok(py_tuple) => { - let mut py_tuple_iter = py_tuple.iter(); - let mut key_builder = KeyBuilder::new(); - let n_items = py_tuple.len(); - - macro_rules! use_serializers { - ($serializers_iter:expr) => { - for serializer in $serializers_iter { - let element = match py_tuple_iter.next() { - Some(value) => value, - None => break, - }; - key_builder.push(&serializer.json_key(element, extra)?); - } - }; - } - - if let Some(variadic_item_index) = self.variadic_item_index { - // Need `saturating_sub` to handle items with too few elements without panicking - let n_variadic_items = (n_items + 1).saturating_sub(self.serializers.len()); - let serializers_iter = self.serializers[..variadic_item_index] - .iter() - .chain(iter::repeat(&self.serializers[variadic_item_index]).take(n_variadic_items)) - .chain(self.serializers[variadic_item_index + 1..].iter()); - use_serializers!(serializers_iter); - } else { - use_serializers!(self.serializers.iter()); - }; + self.for_each_tuple_item_and_serializer(py_tuple, None, None, extra, |entry| { + entry + .serializer + .json_key(entry.item, extra) + .map(|key| key_builder.push(&key)) + })??; Ok(Cow::Owned(key_builder.finish())) } @@ -184,63 +127,18 @@ impl TypeSerializer for TupleSerializer { let py_tuple: &PyTuple = py_tuple.downcast().map_err(py_err_se_err)?; let n_items = py_tuple.len(); - let mut py_tuple_iter = py_tuple.iter(); let mut seq = serializer.serialize_seq(Some(n_items))?; - macro_rules! use_serializers { - ($serializers_iter:expr) => { - for (index, serializer) in $serializers_iter.enumerate() { - let element = match py_tuple_iter.next() { - Some(value) => value, - None => break, - }; - let op_next = self - .filter - .index_filter(index, include, exclude, Some(n_items)) - .map_err(py_err_se_err)?; - if let Some((next_include, next_exclude)) = op_next { - let item_serialize = - PydanticSerializer::new(element, serializer, next_include, next_exclude, extra); - seq.serialize_element(&item_serialize)?; - } - } - }; - } - - if let Some(variadic_item_index) = self.variadic_item_index { - // Need `saturating_sub` to handle items with too few elements without panicking - let n_variadic_items = (n_items + 1).saturating_sub(self.serializers.len()); - let serializers_iter = self.serializers[..variadic_item_index] - .iter() - .chain(iter::repeat(&self.serializers[variadic_item_index]).take(n_variadic_items)) - .chain(self.serializers[variadic_item_index + 1..].iter()); - use_serializers!(serializers_iter); - } else { - use_serializers!(self.serializers.iter()); - let mut warned = false; - for (i, element) in py_tuple_iter.enumerate() { - if !warned { - extra - .warnings - .custom_warning("Unexpected extra items present in tuple".to_string()); - warned = true; - } - let op_next = self - .filter - .index_filter(i + self.serializers.len(), include, exclude, Some(n_items)) - .map_err(py_err_se_err)?; - if let Some((next_include, next_exclude)) = op_next { - let item_serialize = PydanticSerializer::new( - element, - &CombinedSerializer::Any(AnySerializer), - next_include, - next_exclude, - extra, - ); - seq.serialize_element(&item_serialize)?; - } - } - }; + self.for_each_tuple_item_and_serializer(py_tuple, include, exclude, extra, |entry| { + seq.serialize_element(&PydanticSerializer::new( + entry.item, + entry.serializer, + entry.include, + entry.exclude, + extra, + )) + }) + .map_err(py_err_se_err)??; seq.end() } @@ -254,6 +152,100 @@ impl TypeSerializer for TupleSerializer { fn get_name(&self) -> &str { &self.name } + + fn retry_with_lax_check(&self) -> bool { + true + } +} + +struct TupleSerializerEntry<'a, 'py> { + item: &'py PyAny, + include: Option<&'py PyAny>, + exclude: Option<&'py PyAny>, + serializer: &'a CombinedSerializer, +} + +impl TupleSerializer { + /// Try to serialize each item in the tuple with the corresponding serializer. + /// + /// If the tuple doesn't match the length of the serializer, in strict mode, an error is returned. + /// + /// The error type E is the type of the error returned by the closure, which is why there are two + /// levels of `Result`. + fn for_each_tuple_item_and_serializer( + &self, + tuple: &PyTuple, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + mut f: impl for<'a, 'py> FnMut(TupleSerializerEntry<'a, 'py>) -> Result<(), E>, + ) -> PyResult> { + let n_items = tuple.len(); + let mut py_tuple_iter = tuple.iter(); + + macro_rules! use_serializers { + ($serializers_iter:expr) => { + for (index, serializer) in $serializers_iter.enumerate() { + let element = match py_tuple_iter.next() { + Some(value) => value, + None => break, + }; + let op_next = self.filter.index_filter(index, include, exclude, Some(n_items))?; + if let Some((next_include, next_exclude)) = op_next { + if let Err(e) = f(TupleSerializerEntry { + item: element, + include: next_include, + exclude: next_exclude, + serializer, + }) { + return Ok(Err(e)); + }; + } + } + }; + } + + if let Some(variadic_item_index) = self.variadic_item_index { + // Need `saturating_sub` to handle items with too few elements without panicking + let n_variadic_items = (n_items + 1).saturating_sub(self.serializers.len()); + let serializers_iter = self.serializers[..variadic_item_index] + .iter() + .chain(iter::repeat(&self.serializers[variadic_item_index]).take(n_variadic_items)) + .chain(self.serializers[variadic_item_index + 1..].iter()); + use_serializers!(serializers_iter); + } else if extra.check == SerCheck::Strict && n_items != self.serializers.len() { + return Err(PydanticSerializationUnexpectedValue::new_err(Some(format!( + "Expected {} items, but got {}", + self.serializers.len(), + n_items + )))); + } else { + use_serializers!(self.serializers.iter()); + let mut warned = false; + for (i, element) in py_tuple_iter.enumerate() { + if !warned { + extra + .warnings + .custom_warning("Unexpected extra items present in tuple".to_string()); + warned = true; + } + let op_next = self + .filter + .index_filter(i + self.serializers.len(), include, exclude, Some(n_items))?; + if let Some((next_include, next_exclude)) = op_next { + if let Err(e) = f(TupleSerializerEntry { + item: element, + include: next_include, + exclude: next_exclude, + serializer: &CombinedSerializer::Any(AnySerializer), + }) { + return Ok(Err(e)); + }; + } + } + }; + Ok(Ok(())) + } } pub(crate) struct KeyBuilder { diff --git a/tests/serializers/test_list_tuple.py b/tests/serializers/test_list_tuple.py index df6aabd8a..256a342fe 100644 --- a/tests/serializers/test_list_tuple.py +++ b/tests/serializers/test_list_tuple.py @@ -411,3 +411,29 @@ def test_tuple_pos_dict_key(): assert s.to_python({(1, 'a', 2): 1}, mode='json') == {'1,a,2': 1} assert s.to_json({(1, 'a'): 1}) == b'{"1,a":1}' assert s.to_json({(1, 'a', 2): 1}) == b'{"1,a,2":1}' + + +def test_tuple_wrong_size_union(): + # See https://github.com/pydantic/pydantic/issues/8677 + + f = core_schema.float_schema() + s = SchemaSerializer( + core_schema.union_schema([core_schema.tuple_schema([f, f]), core_schema.tuple_schema([f, f, f])]) + ) + assert s.to_python((1.0, 2.0)) == (1.0, 2.0) + assert s.to_python((1.0, 2.0, 3.0)) == (1.0, 2.0, 3.0) + + with pytest.warns(UserWarning, match='Unexpected extra items present in tuple'): + s.to_python((1.0, 2.0, 3.0, 4.0)) + + assert s.to_python((1.0, 2.0), mode='json') == [1.0, 2.0] + assert s.to_python((1.0, 2.0, 3.0), mode='json') == [1.0, 2.0, 3.0] + + with pytest.warns(UserWarning, match='Unexpected extra items present in tuple'): + s.to_python((1.0, 2.0, 3.0, 4.0), mode='json') + + assert s.to_json((1.0, 2.0)) == b'[1.0,2.0]' + assert s.to_json((1.0, 2.0, 3.0)) == b'[1.0,2.0,3.0]' + + with pytest.warns(UserWarning, match='Unexpected extra items present in tuple'): + s.to_json((1.0, 2.0, 3.0, 4.0))