201 changes: 172 additions & 29 deletions src/recursion_guard.rs
@@ -1,4 +1,5 @@
use ahash::AHashSet;
use std::mem::MaybeUninit;

type RecursionKey = (
// Identifier for the input object, e.g. the id() of a Python dict
@@ -11,58 +12,200 @@ type RecursionKey = (

/// This is used to avoid cyclic references in input data causing recursive validation and a nasty segmentation fault.
/// It's used in `validators/definition` to detect when a reference is reused within itself.
pub(crate) struct RecursionGuard<'a, S: ContainsRecursionState> {
state: &'a mut S,
obj_id: usize,
node_id: usize,
}

pub(crate) enum RecursionError {
/// Cyclic reference detected
Cyclic,
/// Recursion limit exceeded
Depth,
}

impl<S: ContainsRecursionState> RecursionGuard<'_, S> {
/// Creates a recursion guard for the given object and node id.
///
/// When dropped, this will release the recursion for the given object and node id.
pub fn new(state: &'_ mut S, obj_id: usize, node_id: usize) -> Result<RecursionGuard<'_, S>, RecursionError> {
state.access_recursion_state(|state| {
if !state.insert(obj_id, node_id) {
return Err(RecursionError::Cyclic);
}
if state.incr_depth() {
return Err(RecursionError::Depth);
}
Ok(())
})?;
Ok(RecursionGuard { state, obj_id, node_id })
}

/// Retrieves the underlying state for further use.
pub fn state(&mut self) -> &mut S {
self.state
}
}

impl<S: ContainsRecursionState> Drop for RecursionGuard<'_, S> {
fn drop(&mut self) {
self.state.access_recursion_state(|state| {
state.decr_depth();
state.remove(self.obj_id, self.node_id);
});
}
}

/// This trait is used to retrieve the recursion state from some other type
pub(crate) trait ContainsRecursionState {
fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R;
}
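
Review note: for orientation, a minimal sketch of how a caller might wire this trait up and use the RAII guard. `MyState` and `validate_node` are illustrative names, not part of this diff:

```rust
// Sketch only: `MyState` and `validate_node` are hypothetical.
struct MyState {
    recursion: RecursionState,
}

impl ContainsRecursionState for MyState {
    fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R {
        f(&mut self.recursion)
    }
}

fn validate_node(state: &mut MyState, obj_id: usize, node_id: usize) -> Result<(), RecursionError> {
    // Inserts (obj_id, node_id) and bumps the depth; both are undone when `guard`
    // drops, including on early returns via `?`.
    let mut guard = RecursionGuard::new(state, obj_id, node_id)?;
    let _state = guard.state(); // recurse into child nodes here
    Ok(())
}
```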

/// State for the RecursionGuard. Can also be used directly to increase / decrease depth.
#[derive(Debug, Clone, Default)]
pub struct RecursionGuard {
ids: Option<AHashSet<RecursionKey>>,
pub struct RecursionState {
ids: RecursionStack,
// depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
// use one number for all validators
depth: u16,
depth: u8,
}

// A hard limit to avoid stack overflows when rampant recursion occurs
pub const RECURSION_GUARD_LIMIT: u16 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) {
pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) {
// wasm and windows PyPy have very limited stack sizes
50
49
} else if cfg!(any(PyPy, windows)) {
// PyPy and Windows in general have more restricted stack space
100
99
} else {
255
};

impl RecursionGuard {
// insert a new id into the set, return whether the set already had the id in it
pub fn contains_or_insert(&mut self, obj_id: usize, node_id: usize) -> bool {
match self.ids {
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
// "If the set did not have this value present, `true` is returned."
Some(ref mut set) => !set.insert((obj_id, node_id)),
None => {
let mut set: AHashSet<RecursionKey> = AHashSet::with_capacity(10);
set.insert((obj_id, node_id));
self.ids = Some(set);
false
}
}
impl RecursionState {
// insert a new value
// * return `false` if the stack already had it in it
// * return `true` if the stack didn't have it in it and it was inserted
fn insert(&mut self, obj_id: usize, node_id: usize) -> bool {
self.ids.insert((obj_id, node_id))
}

// see #143 this is used as a backup in case the identity check recursion guard fails
#[must_use]
#[cfg(any(target_family = "wasm", windows, PyPy))]
pub fn incr_depth(&mut self) -> bool {
// use saturating_add as it's faster (since there's no error path)
// and the RECURSION_GUARD_LIMIT check will be hit before it overflows
debug_assert!(RECURSION_GUARD_LIMIT < 255);
self.depth = self.depth.saturating_add(1);
self.depth > RECURSION_GUARD_LIMIT
}

#[must_use]
#[cfg(not(any(target_family = "wasm", windows, PyPy)))]
pub fn incr_depth(&mut self) -> bool {
self.depth += 1;
self.depth >= RECURSION_GUARD_LIMIT
debug_assert_eq!(RECURSION_GUARD_LIMIT, 255);
// use checked_add to check if we've hit the limit
if let Some(depth) = self.depth.checked_add(1) {
self.depth = depth;
false
} else {
true
}
}

pub fn decr_depth(&mut self) {
self.depth -= 1;
// for the same reason as incr_depth, use saturating_sub
self.depth = self.depth.saturating_sub(1);
}

pub fn remove(&mut self, obj_id: usize, node_id: usize) {
match self.ids {
Some(ref mut set) => {
set.remove(&(obj_id, node_id));
fn remove(&mut self, obj_id: usize, node_id: usize) {
self.ids.remove(&(obj_id, node_id));
}
}
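
Review note: the split between the two `incr_depth` variants is worth spelling out. On the restricted platforms the limit (49/99) sits below `u8::MAX`, so a saturating add plus a comparison is needed; elsewhere the limit is exactly 255, so `checked_add` overflowing *is* the limit check. A standalone sketch of the arithmetic:

```rust
// Standalone illustration of the two depth-counting strategies on a u8 counter.
fn main() {
    // Restricted platforms: limit (e.g. 99) is below u8::MAX, so saturate and compare.
    let mut depth: u8 = 99;
    depth = depth.saturating_add(1);
    assert!(depth > 99); // limit exceeded

    // Other platforms: limit is exactly 255, so overflow itself signals the limit.
    let depth: u8 = 255;
    assert!(depth.checked_add(1).is_none()); // would overflow -> depth error
}
```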

// trial and error suggests this is a good value, going higher causes array lookups to get significantly slower
const ARRAY_SIZE: usize = 16;

#[derive(Debug, Clone)]
enum RecursionStack {
Array {
data: [MaybeUninit<RecursionKey>; ARRAY_SIZE],
len: usize,
},
Set(AHashSet<RecursionKey>),
}

impl Default for RecursionStack {
fn default() -> Self {
Self::Array {
data: std::array::from_fn(|_| MaybeUninit::uninit()),
len: 0,
}
}
}

impl RecursionStack {
// insert a new value
// * return `false` if the stack already had it in it
// * return `true` if the stack didn't have it in it and it was inserted
fn insert(&mut self, v: RecursionKey) -> bool {
match self {
Self::Array { data, len } => {
if *len < ARRAY_SIZE {
for value in data.iter().take(*len) {
// Safety: reading values within bounds
if unsafe { value.assume_init() } == v {
return false;
}
}

data[*len].write(v);
*len += 1;
true
} else {
let mut set = AHashSet::with_capacity(ARRAY_SIZE + 1);
for existing in data.iter() {
// Safety: the array is fully initialized
set.insert(unsafe { existing.assume_init() });
}
let inserted = set.insert(v);
*self = Self::Set(set);
inserted
}
}
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
// "If the set did not have this value present, `true` is returned."
Self::Set(set) => set.insert(v),
}
}

fn remove(&mut self, v: &RecursionKey) {
match self {
Self::Array { data, len } => {
*len = len.checked_sub(1).expect("remove from empty recursion guard");
// Safety: this is reading what was the back of the initialized array
let removed = unsafe { data.get_unchecked_mut(*len) };
assert!(unsafe { removed.assume_init_ref() } == v, "remove did not match insert");
// this should compile away to a noop
unsafe { std::ptr::drop_in_place(removed.as_mut_ptr()) }
}
Self::Set(set) => {
set.remove(v);
}
None => unreachable!(),
};
}
}
}

impl Drop for RecursionStack {
fn drop(&mut self) {
// This should compile away to a noop as RecursionKey doesn't implement Drop, but it seemed
// desirable to leave this in for safety in case that should change in the future
if let Self::Array { data, len } = self {
for value in data.iter_mut().take(*len) {
// Safety: reading values within bounds
unsafe { std::ptr::drop_in_place(value.as_mut_ptr()) };
}
}
}
}
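
Review note: `RecursionStack` is essentially a small-buffer optimisation — a fixed inline array with linear scans while small, promoted to a hash set once it spills. The `MaybeUninit` bookkeeping is what makes it `unsafe`; for comparison, the same strategy in safe Rust (illustrative names, heap-backed `Vec` instead of an inline array):

```rust
use std::collections::HashSet;

const CAP: usize = 16;

enum SmallIdSet {
    Small(Vec<(usize, usize)>), // linear scan while len < CAP
    Spilled(HashSet<(usize, usize)>),
}

impl SmallIdSet {
    // Returns true if the value was newly inserted, mirroring `RecursionStack::insert`.
    fn insert(&mut self, v: (usize, usize)) -> bool {
        match self {
            SmallIdSet::Small(items) => {
                if items.contains(&v) {
                    return false;
                }
                if items.len() < CAP {
                    items.push(v);
                    true
                } else {
                    // Spill: promote everything into a hash set once scans get long.
                    let mut set: HashSet<_> = items.drain(..).collect();
                    let inserted = set.insert(v);
                    *self = SmallIdSet::Spilled(set);
                    inserted
                }
            }
            SmallIdSet::Spilled(set) => set.insert(v),
        }
    }
}
```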
4 changes: 4 additions & 0 deletions src/serializers/computed_fields.rs
@@ -73,6 +73,10 @@ impl ComputedFields {
}
for computed_field in &self.0 {
let property_name_py = computed_field.property_name_py.as_ref(model.py());
let value = model.getattr(property_name_py).map_err(py_err_se_err)?;
if extra.exclude_none && value.is_none() {
return Ok(());
}
if let Some((next_include, next_exclude)) = filter
.key_filter(property_name_py, include, exclude)
.map_err(py_err_se_err)?
144 changes: 64 additions & 80 deletions src/serializers/config.rs
@@ -15,60 +15,98 @@ use crate::tools::SchemaDict;
use super::errors::py_err_se_err;

#[derive(Debug, Clone)]
#[allow(clippy::struct_field_names)]
pub(crate) struct SerializationConfig {
pub timedelta_mode: TimedeltaMode,
pub bytes_mode: BytesMode,
pub inf_nan_mode: InfNanMode,
}

impl SerializationConfig {
pub fn from_config(config: Option<&PyDict>) -> PyResult<Self> {
let timedelta_mode = TimedeltaMode::from_config(config)?;
let bytes_mode = BytesMode::from_config(config)?;
let inf_nan_mode = InfNanMode::from_config(config)?;
Ok(Self {
timedelta_mode,
bytes_mode,
inf_nan_mode,
})
}

pub fn from_args(timedelta_mode: &str, bytes_mode: &str) -> PyResult<Self> {
pub fn from_args(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult<Self> {
Ok(Self {
timedelta_mode: TimedeltaMode::from_str(timedelta_mode)?,
bytes_mode: BytesMode::from_str(bytes_mode)?,
inf_nan_mode: InfNanMode::from_str(inf_nan_mode)?,
})
}
}

#[derive(Default, Debug, Clone)]
pub(crate) enum TimedeltaMode {
#[default]
Iso8601,
Float,
pub trait FromConfig {
fn from_config(config: Option<&PyDict>) -> PyResult<Self>
where
Self: Sized;
}

impl FromStr for TimedeltaMode {
type Err = PyErr;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"iso8601" => Ok(Self::Iso8601),
"float" => Ok(Self::Float),
s => py_schema_err!(
"Invalid timedelta serialization mode: `{}`, expected `iso8601` or `float`",
s
),
macro_rules! serialization_mode {
($name:ident, $config_key:expr, $($variant:ident => $value:expr),* $(,)?) => {
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub(crate) enum $name {
#[default]
$($variant,)*
}
}

impl FromStr for $name {
type Err = PyErr;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
$($value => Ok(Self::$variant),)*
s => py_schema_err!(
concat!("Invalid ", stringify!($name), " serialization mode: `{}`, expected ", $($value, " or "),*),
s
),
}
}
}

impl FromConfig for $name {
fn from_config(config: Option<&PyDict>) -> PyResult<Self> {
let Some(config_dict) = config else {
return Ok(Self::default());
};
let raw_mode = config_dict.get_as::<&str>(intern!(config_dict.py(), $config_key))?;
raw_mode.map_or_else(|| Ok(Self::default()), Self::from_str)
}
}

};
}

impl TimedeltaMode {
pub fn from_config(config: Option<&PyDict>) -> PyResult<Self> {
let Some(config_dict) = config else {
return Ok(Self::default());
};
let raw_mode = config_dict.get_as::<&str>(intern!(config_dict.py(), "ser_json_timedelta"))?;
raw_mode.map_or_else(|| Ok(Self::default()), Self::from_str)
}
serialization_mode! {
TimedeltaMode,
"ser_json_timedelta",
Iso8601 => "iso8601",
Float => "float",
}

serialization_mode! {
BytesMode,
"ser_json_bytes",
Utf8 => "utf8",
Base64 => "base64",
Hex => "hex",
}

serialization_mode! {
InfNanMode,
"ser_json_inf_nan",
Null => "null",
Constants => "constants",
}
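
Review note: for readers who don't want to expand the macro mentally, the `InfNanMode` invocation generates roughly the following. This is a hand-expanded sketch relying on the same imports as this file; the real error string is built with `concat!` and differs slightly:

```rust
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub(crate) enum InfNanMode {
    #[default]
    Null,
    Constants,
}

impl FromStr for InfNanMode {
    type Err = PyErr;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "null" => Ok(Self::Null),
            "constants" => Ok(Self::Constants),
            s => py_schema_err!("Invalid InfNanMode serialization mode: `{}`, expected null or constants", s),
        }
    }
}

impl FromConfig for InfNanMode {
    fn from_config(config: Option<&PyDict>) -> PyResult<Self> {
        let Some(config_dict) = config else {
            return Ok(Self::default());
        };
        let raw_mode = config_dict.get_as::<&str>(intern!(config_dict.py(), "ser_json_inf_nan"))?;
        raw_mode.map_or_else(|| Ok(Self::default()), Self::from_str)
    }
}
```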

impl TimedeltaMode {
fn total_seconds(py_timedelta: &PyDelta) -> PyResult<&PyAny> {
py_timedelta.call_method0(intern!(py_timedelta.py(), "total_seconds"))
}
@@ -124,39 +162,7 @@ impl TimedeltaMode {
}
}

#[derive(Default, Debug, Clone)]
pub(crate) enum BytesMode {
#[default]
Utf8,
Base64,
Hex,
}

impl FromStr for BytesMode {
type Err = PyErr;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"utf8" => Ok(Self::Utf8),
"base64" => Ok(Self::Base64),
"hex" => Ok(Self::Hex),
s => py_schema_err!(
"Invalid bytes serialization mode: `{}`, expected `utf8`, `base64` or `hex`",
s
),
}
}
}

impl BytesMode {
pub fn from_config(config: Option<&PyDict>) -> PyResult<Self> {
let Some(config_dict) = config else {
return Ok(Self::default());
};
let raw_mode = config_dict.get_as::<&str>(intern!(config_dict.py(), "ser_json_bytes"))?;
raw_mode.map_or_else(|| Ok(Self::default()), Self::from_str)
}

pub fn bytes_to_string<'py>(&self, py: Python, bytes: &'py [u8]) -> PyResult<Cow<'py, str>> {
match self {
Self::Utf8 => from_utf8(bytes)
@@ -190,28 +196,6 @@ pub fn utf8_py_error(py: Python, err: Utf8Error, data: &[u8]) -> PyErr {
}
}

#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub(crate) enum InfNanMode {
#[default]
Null,
Constants,
}

impl FromStr for InfNanMode {
type Err = PyErr;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"null" => Ok(Self::Null),
"constants" => Ok(Self::Constants),
s => py_schema_err!(
"Invalid inf_nan serialization mode: `{}`, expected `null` or `constants`",
s
),
}
}
}

impl FromPyObject<'_> for InfNanMode {
fn extract(ob: &'_ PyAny) -> PyResult<Self> {
let s = ob.extract::<&str>()?;
60 changes: 31 additions & 29 deletions src/serializers/extra.rs
@@ -10,21 +10,24 @@ use serde::ser::Error;
use super::config::SerializationConfig;
use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER};
use super::ob_type::ObTypeLookup;
use crate::recursion_guard::ContainsRecursionState;
use crate::recursion_guard::RecursionError;
use crate::recursion_guard::RecursionGuard;
use crate::recursion_guard::RecursionState;

/// this is ugly, would be much better if extra could be stored in `SerializationState`
/// then `SerializationState` got a `serialize_infer` method, but I couldn't get it to work
pub(crate) struct SerializationState {
warnings: CollectWarnings,
rec_guard: SerRecursionGuard,
rec_guard: SerRecursionState,
config: SerializationConfig,
}

impl SerializationState {
pub fn new(timedelta_mode: &str, bytes_mode: &str) -> PyResult<Self> {
pub fn new(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult<Self> {
let warnings = CollectWarnings::new(false);
let rec_guard = SerRecursionGuard::default();
let config = SerializationConfig::from_args(timedelta_mode, bytes_mode)?;
let rec_guard = SerRecursionState::default();
let config = SerializationConfig::from_args(timedelta_mode, bytes_mode, inf_nan_mode)?;
Ok(Self {
warnings,
rec_guard,
@@ -77,7 +80,7 @@ pub(crate) struct Extra<'a> {
pub exclude_none: bool,
pub round_trip: bool,
pub config: &'a SerializationConfig,
pub rec_guard: &'a SerRecursionGuard,
pub rec_guard: &'a SerRecursionState,
// the next two are used for union logic
pub check: SerCheck,
// data representing the current model field
@@ -101,7 +104,7 @@ impl<'a> Extra<'a> {
exclude_none: bool,
round_trip: bool,
config: &'a SerializationConfig,
rec_guard: &'a SerRecursionGuard,
rec_guard: &'a SerRecursionState,
serialize_unknown: bool,
fallback: Option<&'a PyAny>,
) -> Self {
@@ -124,6 +127,22 @@ impl<'a> Extra<'a> {
}
}

pub fn recursion_guard<'x, 'y>(
// TODO: this double reference is a bit of a hack, but it's necessary because the recursion
// guard is not passed around with &mut reference
//
// See how validation has &mut ValidationState passed around; we should aim to refactor
// to match that.
self: &'x mut &'y Self,
value: &PyAny,
def_ref_id: usize,
) -> PyResult<RecursionGuard<'x, &'y Self>> {
RecursionGuard::new(self, value.as_ptr() as usize, def_ref_id).map_err(|e| match e {
RecursionError::Depth => PyValueError::new_err("Circular reference detected (depth exceeded)"),
RecursionError::Cyclic => PyValueError::new_err("Circular reference detected (id repeated)"),
})
}

pub fn serialize_infer<'py>(&'py self, value: &'py PyAny) -> super::infer::SerializeInfer<'py> {
super::infer::SerializeInfer::new(value, None, None, self)
}
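
Review note: the typical call shape for the new guard, as a sketch; `serialize_inner` is a stand-in for the recursive call, not a real function in this diff:

```rust
// Sketch: how a serializer is expected to use `recursion_guard`.
fn serialize_checked(mut extra: &Extra, value: &PyAny, def_ref_id: usize) -> PyResult<PyObject> {
    let mut guard = extra.recursion_guard(value, def_ref_id)?; // maps to "Circular reference detected" errors
    let extra = guard.state(); // &mut &Extra for the nested call
    serialize_inner(extra, value)
} // guard drops here: depth decremented and id pair removed, even on error paths
```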
@@ -157,7 +176,7 @@ pub(crate) struct ExtraOwned {
exclude_none: bool,
round_trip: bool,
config: SerializationConfig,
rec_guard: SerRecursionGuard,
rec_guard: SerRecursionState,
check: SerCheck,
model: Option<PyObject>,
field_name: Option<String>,
@@ -340,29 +359,12 @@ impl CollectWarnings {

#[derive(Default, Clone)]
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct SerRecursionGuard {
guard: RefCell<RecursionGuard>,
pub struct SerRecursionState {
guard: RefCell<RecursionState>,
}

impl SerRecursionGuard {
pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> {
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
// "If the set did not have this value present, `true` is returned."
let id = value.as_ptr() as usize;
let mut guard = self.guard.borrow_mut();

if guard.contains_or_insert(id, def_ref_id) {
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
} else if guard.incr_depth() {
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
} else {
Ok(id)
}
}

pub fn pop(&self, id: usize, def_ref_id: usize) {
let mut guard = self.guard.borrow_mut();
guard.decr_depth();
guard.remove(id, def_ref_id);
impl ContainsRecursionState for &'_ Extra<'_> {
fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R {
f(&mut self.rec_guard.guard.borrow_mut())
}
}
257 changes: 162 additions & 95 deletions src/serializers/fields.rs
@@ -100,6 +100,15 @@ pub struct GeneralFieldsSerializer {
required_fields: usize,
}

macro_rules! option_length {
($op_has_len:expr) => {
match $op_has_len {
Some(ref has_len) => has_len.len(),
None => 0,
}
};
}

impl GeneralFieldsSerializer {
pub(super) fn new(
fields: AHashMap<String, SerField>,
@@ -136,50 +145,21 @@ impl GeneralFieldsSerializer {
}
}
}
}

macro_rules! option_length {
($op_has_len:expr) => {
match $op_has_len {
Some(ref has_len) => has_len.len(),
None => 0,
}
};
}

impl_py_gc_traverse!(GeneralFieldsSerializer {
fields,
computed_fields
});

impl TypeSerializer for GeneralFieldsSerializer {
fn to_python(
pub fn main_to_python<'py>(
&self,
value: &PyAny,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
) -> PyResult<PyObject> {
let py = value.py();
// If there is already a model registered (from a dataclass, BaseModel)
// then do not touch it
// If there is no model, we (a TypedDict) are the model
let td_extra = Extra {
model: extra.model.map_or_else(|| Some(value), Some),
..*extra
};
let (main_dict, extra_dict) = if let Some(main_extra_dict) = self.extract_dicts(value) {
main_extra_dict
} else {
td_extra.warnings.on_fallback_py(self.get_name(), value, &td_extra)?;
return infer_to_python(value, include, exclude, &td_extra);
};

py: Python<'py>,
main_iter: impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>>,
include: Option<&'py PyAny>,
exclude: Option<&'py PyAny>,
extra: Extra,
) -> PyResult<&'py PyDict> {
let output_dict = PyDict::new(py);
let mut used_req_fields: usize = 0;

// NOTE! we maintain the order of the input dict assuming that's right
for (key, value) in main_dict {
for result in main_iter {
let (key, value) = result?;
let key_str = key_str(key)?;
let op_field = self.fields.get(key_str);
if extra.exclude_none && value.is_none() {
@@ -190,16 +170,16 @@ impl TypeSerializer for GeneralFieldsSerializer {
}
continue;
}
let extra = Extra {
let field_extra = Extra {
field_name: Some(key_str),
..td_extra
..extra
};
if let Some((next_include, next_exclude)) = self.filter.key_filter(key, include, exclude)? {
if let Some(field) = op_field {
if let Some(ref serializer) = field.serializer {
if !exclude_default(value, &extra, serializer)? {
let value = serializer.to_python(value, next_include, next_exclude, &extra)?;
let output_key = field.get_key_py(output_dict.py(), &extra);
if !exclude_default(value, &field_extra, serializer)? {
let value = serializer.to_python(value, next_include, next_exclude, &field_extra)?;
let output_key = field.get_key_py(output_dict.py(), &field_extra);
output_dict.set_item(output_key, value)?;
}
}
@@ -209,23 +189,140 @@ impl TypeSerializer for GeneralFieldsSerializer {
}
} else if self.mode == FieldsMode::TypedDictAllow {
let value = match &self.extra_serializer {
Some(serializer) => serializer.to_python(value, next_include, next_exclude, &extra)?,
None => infer_to_python(value, next_include, next_exclude, &extra)?,
Some(serializer) => serializer.to_python(value, next_include, next_exclude, &field_extra)?,
None => infer_to_python(value, next_include, next_exclude, &field_extra)?,
};
output_dict.set_item(key, value)?;
} else if extra.check == SerCheck::Strict {
} else if field_extra.check == SerCheck::Strict {
return Err(PydanticSerializationUnexpectedValue::new_err(None));
}
}
}
if td_extra.check.enabled()

if extra.check.enabled()
// If any of these are true we can't count fields
&& !(extra.exclude_defaults || extra.exclude_unset || extra.exclude_none)
// Check for missing fields, we can't have extra fields here
&& self.required_fields > used_req_fields
{
return Err(PydanticSerializationUnexpectedValue::new_err(None));
Err(PydanticSerializationUnexpectedValue::new_err(None))
} else {
Ok(output_dict)
}
}

pub fn main_serde_serialize<'py, S: serde::ser::Serializer>(
&self,
main_iter: impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>>,
expected_len: usize,
serializer: S,
include: Option<&'py PyAny>,
exclude: Option<&'py PyAny>,
extra: Extra,
) -> Result<S::SerializeMap, S::Error> {
// NOTE! As above, we maintain the order of the input dict assuming that's right
// we don't bother with `used_fields` here because on unions, `to_python(..., mode='json')` is used
let mut map = serializer.serialize_map(Some(expected_len))?;

for result in main_iter {
let (key, value) = result.map_err(py_err_se_err)?;
if extra.exclude_none && value.is_none() {
continue;
}
let key_str = key_str(key).map_err(py_err_se_err)?;
let field_extra = Extra {
field_name: Some(key_str),
..extra
};

let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?;
if let Some((next_include, next_exclude)) = filter {
if let Some(field) = self.fields.get(key_str) {
if let Some(ref serializer) = field.serializer {
if !exclude_default(value, &field_extra, serializer).map_err(py_err_se_err)? {
let s =
PydanticSerializer::new(value, serializer, next_include, next_exclude, &field_extra);
let output_key = field.get_key_json(key_str, &field_extra);
map.serialize_entry(&output_key, &s)?;
}
}
} else if self.mode == FieldsMode::TypedDictAllow {
let output_key = infer_json_key(key, &field_extra).map_err(py_err_se_err)?;
let s = SerializeInfer::new(value, next_include, next_exclude, &field_extra);
map.serialize_entry(&output_key, &s)?;
}
// no error case here since unions (which need the error case) use `to_python(..., mode='json')`
}
}
Ok(map)
}

pub fn add_computed_fields_python(
&self,
model: Option<&PyAny>,
output_dict: &PyDict,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
) -> PyResult<()> {
if let Some(ref computed_fields) = self.computed_fields {
if let Some(model_value) = model {
let cf_extra = Extra { model, ..*extra };
computed_fields.to_python(model_value, output_dict, &self.filter, include, exclude, &cf_extra)?;
}
}
Ok(())
}

pub fn add_computed_fields_json<S: serde::ser::Serializer>(
&self,
model: Option<&PyAny>,
map: &mut S::SerializeMap,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
) -> Result<(), S::Error> {
if let Some(ref computed_fields) = self.computed_fields {
if let Some(model) = model {
computed_fields.serde_serialize::<S>(model, map, &self.filter, include, exclude, extra)?;
}
}
Ok(())
}

pub fn computed_field_count(&self) -> usize {
option_length!(self.computed_fields)
}
}

impl_py_gc_traverse!(GeneralFieldsSerializer {
fields,
computed_fields
});

impl TypeSerializer for GeneralFieldsSerializer {
fn to_python(
&self,
value: &PyAny,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
) -> PyResult<PyObject> {
let py = value.py();
// If there is already a model registered (from a dataclass, BaseModel)
// then do not touch it
// If there is no model, we (a TypedDict) are the model
let model = extra.model.map_or_else(|| Some(value), Some);
let td_extra = Extra { model, ..*extra };
let (main_dict, extra_dict) = if let Some(main_extra_dict) = self.extract_dicts(value) {
main_extra_dict
} else {
td_extra.warnings.on_fallback_py(self.get_name(), value, &td_extra)?;
return infer_to_python(value, include, exclude, &td_extra);
};

let output_dict = self.main_to_python(py, main_dict.iter().map(Ok), include, exclude, td_extra)?;

// this is used to include `__pydantic_extra__` in serialization on models
if let Some(extra_dict) = extra_dict {
for (key, value) in extra_dict {
@@ -241,11 +338,7 @@ impl TypeSerializer for GeneralFieldsSerializer {
}
}
}
if let Some(ref computed_fields) = self.computed_fields {
if let Some(model) = td_extra.model {
computed_fields.to_python(model, output_dict, &self.filter, include, exclude, &td_extra)?;
}
}
self.add_computed_fields_python(model, output_dict, include, exclude, extra)?;
Ok(output_dict.into_py(py))
}

@@ -271,46 +364,23 @@ impl TypeSerializer for GeneralFieldsSerializer {
// If there is already a model registered (from a dataclass, BaseModel)
// then do not touch it
// If there is no model, we (a TypedDict) are the model
let td_extra = Extra {
model: extra.model.map_or_else(|| Some(value), Some),
..*extra
};
let model = extra.model.map_or_else(|| Some(value), Some);
let td_extra = Extra { model, ..*extra };
let expected_len = match self.mode {
FieldsMode::TypedDictAllow => main_dict.len() + option_length!(self.computed_fields),
_ => self.fields.len() + option_length!(extra_dict) + option_length!(self.computed_fields),
FieldsMode::TypedDictAllow => main_dict.len() + self.computed_field_count(),
_ => self.fields.len() + option_length!(extra_dict) + self.computed_field_count(),
};
// NOTE! As above, we maintain the order of the input dict assuming that's right
// we don't bother with `used_fields` here because on unions, `to_python(..., mode='json')` is used
let mut map = serializer.serialize_map(Some(expected_len))?;

for (key, value) in main_dict {
if extra.exclude_none && value.is_none() {
continue;
}
let key_str = key_str(key).map_err(py_err_se_err)?;
let extra = Extra {
field_name: Some(key_str),
..td_extra
};
let mut map = self.main_serde_serialize(
main_dict.iter().map(Ok),
expected_len,
serializer,
include,
exclude,
td_extra,
)?;

let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?;
if let Some((next_include, next_exclude)) = filter {
if let Some(field) = self.fields.get(key_str) {
if let Some(ref serializer) = field.serializer {
if !exclude_default(value, &extra, serializer).map_err(py_err_se_err)? {
let s = PydanticSerializer::new(value, serializer, next_include, next_exclude, &extra);
let output_key = field.get_key_json(key_str, &extra);
map.serialize_entry(&output_key, &s)?;
}
}
} else if self.mode == FieldsMode::TypedDictAllow {
let output_key = infer_json_key(key, &extra).map_err(py_err_se_err)?;
let s = SerializeInfer::new(value, next_include, next_exclude, &extra);
map.serialize_entry(&output_key, &s)?;
}
// no error case here since unions (which need the error case) use `to_python(..., mode='json')`
}
}
// this is used to include `__pydantic_extra__` in serialization on models
if let Some(extra_dict) = extra_dict {
for (key, value) in extra_dict {
@@ -319,17 +389,14 @@ impl TypeSerializer for GeneralFieldsSerializer {
}
let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?;
if let Some((next_include, next_exclude)) = filter {
let output_key = infer_json_key(key, &td_extra).map_err(py_err_se_err)?;
let s = SerializeInfer::new(value, next_include, next_exclude, &td_extra);
let output_key = infer_json_key(key, extra).map_err(py_err_se_err)?;
let s = SerializeInfer::new(value, next_include, next_exclude, extra);
map.serialize_entry(&output_key, &s)?;
}
}
}
if let Some(ref computed_fields) = self.computed_fields {
if let Some(model) = td_extra.model {
computed_fields.serde_serialize::<S>(model, &mut map, &self.filter, include, exclude, &td_extra)?;
}
}

self.add_computed_fields_json::<S>(model, &mut map, include, exclude, extra)?;
map.end()
}

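Review note: one detail in `main_serde_serialize` worth flagging — the `expected_len` passed to `serialize_map` is only a size hint, letting the serde backend preallocate. A minimal standalone version of the pattern:

```rust
use serde::ser::{SerializeMap, Serializer};

// Sketch: the size hint lets backends like serde_json preallocate the map.
fn write_pairs<S: Serializer>(serializer: S, pairs: &[(&str, i64)]) -> Result<S::Ok, S::Error> {
    let mut map = serializer.serialize_map(Some(pairs.len()))?; // cf. expected_len above
    for (k, v) in pairs {
        map.serialize_entry(k, v)?;
    }
    map.end()
}
```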
173 changes: 100 additions & 73 deletions src/serializers/infer.rs
@@ -10,18 +10,17 @@ use pyo3::types::{
use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer};

use crate::input::{EitherTimedelta, Int};
use crate::serializers::errors::SERIALIZATION_ERR_MARKER;
use crate::serializers::filter::SchemaFilter;
use crate::serializers::shared::{PydanticSerializer, TypeSerializer};
use crate::serializers::SchemaSerializer;
use crate::tools::{extract_i64, py_err, safe_repr};
use crate::url::{PyMultiHostUrl, PyUrl};

use super::config::InfNanMode;
use super::errors::SERIALIZATION_ERR_MARKER;
use super::errors::{py_err_se_err, PydanticSerializationError};
use super::extra::{Extra, SerMode};
use super::filter::AnyFilter;
use super::filter::{AnyFilter, SchemaFilter};
use super::ob_type::ObType;
use super::shared::dataclass_to_dict;
use super::shared::{any_dataclass_iter, PydanticSerializer, TypeSerializer};
use super::SchemaSerializer;

pub(crate) fn infer_to_python(
value: &PyAny,
@@ -41,19 +40,22 @@ pub(crate) fn infer_to_python_known(
value: &PyAny,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
mut extra: &Extra,
) -> PyResult<PyObject> {
let py = value.py();
let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID) {
Ok(id) => id,

let mode = extra.mode;
let mut guard = match extra.recursion_guard(value, INFER_DEF_REF_ID) {
Ok(v) => v,
Err(e) => {
return match extra.mode {
return match mode {
SerMode::Json => Err(e),
// if recursion is detected but we're serializing to python, we just return the value
_ => Ok(value.into_py(py)),
};
}
};
let extra = guard.state();

macro_rules! serialize_seq {
($t:ty) => {
@@ -82,22 +84,6 @@ pub(crate) fn infer_to_python_known(
}};
}

let serialize_dict = |dict: &PyDict| {
let new_dict = PyDict::new(py);
let filter = AnyFilter::new();

for (k, v) in dict {
let op_next = filter.key_filter(k, include, exclude)?;
if let Some((next_include, next_exclude)) = op_next {
let k_str = infer_json_key(k, extra)?;
let k = PyString::new(py, &k_str);
let v = infer_to_python(v, next_include, next_exclude, extra)?;
new_dict.set_item(k, v)?;
}
}
Ok::<PyObject, PyErr>(new_dict.into_py(py))
};

let serialize_with_serializer = || {
let py_serializer = value.getattr(intern!(py, "__pydantic_serializer__"))?;
let serializer: PyRef<SchemaSerializer> = py_serializer.extract()?;
@@ -120,10 +106,19 @@ pub(crate) fn infer_to_python_known(
let value = match extra.mode {
SerMode::Json => match ob_type {
// `bool` and `None` can't be subclasses, `ObType::Int`, `ObType::Float`, `ObType::Str` refer to exact types
ObType::None | ObType::Bool | ObType::Int | ObType::Float | ObType::Str => value.into_py(py),
ObType::None | ObType::Bool | ObType::Int | ObType::Str => value.into_py(py),
// have to do this to make sure subclasses of for example str are upcast to `str`
ObType::IntSubclass => extract_i64(value)?.into_py(py),
ObType::FloatSubclass => value.extract::<f64>()?.into_py(py),
ObType::IntSubclass => match extract_i64(value) {
Some(v) => v.into_py(py),
None => return py_err!(PyTypeError; "expected int, got {}", safe_repr(value)),
},
ObType::Float | ObType::FloatSubclass => {
let v = value.extract::<f64>()?;
if (v.is_nan() || v.is_infinite()) && extra.config.inf_nan_mode == InfNanMode::Null {
return Ok(py.None());
}
v.into_py(py)
}
ObType::Decimal => value.to_string().into_py(py),
ObType::StrSubclass => value.extract::<&str>()?.into_py(py),
ObType::Bytes => extra
Expand Down Expand Up @@ -158,7 +153,12 @@ pub(crate) fn infer_to_python_known(
let elements = serialize_seq!(PyFrozenSet);
PyList::new(py, elements).into_py(py)
}
ObType::Dict => serialize_dict(value.downcast()?)?,
ObType::Dict => {
let dict: &PyDict = value.downcast()?;
serialize_pairs_python(py, dict.iter().map(Ok), include, exclude, extra, |k| {
Ok(PyString::new(py, &infer_json_key(k, extra)?))
})?
}
ObType::Datetime => {
let py_dt: &PyDateTime = value.downcast()?;
let iso_dt = super::type_serializers::datetime_etc::datetime_to_string(py_dt)?;
@@ -195,7 +195,11 @@ pub(crate) fn infer_to_python_known(
uuid.into_py(py)
}
ObType::PydanticSerializable => serialize_with_serializer()?,
ObType::Dataclass => serialize_dict(dataclass_to_dict(value)?)?,
ObType::Dataclass => {
serialize_pairs_python(py, any_dataclass_iter(value)?.0, include, exclude, extra, |k| {
Ok(PyString::new(py, &infer_json_key(k, extra)?))
})?
}
ObType::Enum => {
let v = value.getattr(intern!(py, "value"))?;
infer_to_python(v, include, exclude, extra)?.into_py(py)
Expand All @@ -219,7 +223,6 @@ pub(crate) fn infer_to_python_known(
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,))?;
let next_result = infer_to_python(next_value, include, exclude, extra);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
return next_result;
} else if extra.serialize_unknown {
serialize_unknown(value).into_py(py)
@@ -246,22 +249,11 @@ pub(crate) fn infer_to_python_known(
PyFrozenSet::new(py, &elements)?.into_py(py)
}
ObType::Dict => {
// different logic for keys from above
let dict: &PyDict = value.downcast()?;
let new_dict = PyDict::new(py);
let filter = AnyFilter::new();

for (k, v) in dict {
let op_next = filter.key_filter(k, include, exclude)?;
if let Some((next_include, next_exclude)) = op_next {
let v = infer_to_python(v, next_include, next_exclude, extra)?;
new_dict.set_item(k, v)?;
}
}
new_dict.into_py(py)
serialize_pairs_python(py, dict.iter().map(Ok), include, exclude, extra, Ok)?
}
ObType::PydanticSerializable => serialize_with_serializer()?,
ObType::Dataclass => serialize_dict(dataclass_to_dict(value)?)?,
ObType::Dataclass => serialize_pairs_python(py, any_dataclass_iter(value)?.0, include, exclude, extra, Ok)?,
ObType::Generator => {
let iter = super::type_serializers::generator::SerializationIterator::new(
value.downcast()?,
@@ -277,15 +269,13 @@ pub(crate) fn infer_to_python_known(
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,))?;
let next_result = infer_to_python(next_value, include, exclude, extra);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
return next_result;
}
value.into_py(py)
}
_ => value.into_py(py),
},
};
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
Ok(value)
}

@@ -342,18 +332,21 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
serializer: S,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
mut extra: &Extra,
) -> Result<S::Ok, S::Error> {
let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) {
let extra_serialize_unknown = extra.serialize_unknown;
let mut guard = match extra.recursion_guard(value, INFER_DEF_REF_ID) {
Ok(v) => v,
Err(e) => {
return if extra.serialize_unknown {
return if extra_serialize_unknown {
serializer.serialize_str("...")
} else {
Err(e)
}
Err(py_err_se_err(e))
};
}
};
let extra = guard.state();

macro_rules! serialize {
($t:ty) => {
match value.extract::<$t>() {
@@ -395,23 +388,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
}};
}

macro_rules! serialize_dict {
($py_dict:expr) => {{
let mut map = serializer.serialize_map(Some($py_dict.len()))?;
let filter = AnyFilter::new();

for (key, value) in $py_dict {
let op_next = filter.key_filter(key, include, exclude).map_err(py_err_se_err)?;
if let Some((next_include, next_exclude)) = op_next {
let key = infer_json_key(key, extra).map_err(py_err_se_err)?;
let value_serializer = SerializeInfer::new(value, next_include, next_exclude, extra);
map.serialize_entry(&key, &value_serializer)?;
}
}
map.end()
}};
}

let ser_result = match ob_type {
ObType::None => serializer.serialize_none(),
ObType::Int | ObType::IntSubclass => serialize!(Int),
@@ -435,7 +411,10 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
.bytes_mode
.serialize_bytes(unsafe { py_byte_array.as_bytes() }, serializer)
}
ObType::Dict => serialize_dict!(value.downcast::<PyDict>().map_err(py_err_se_err)?),
ObType::Dict => {
let dict = value.downcast::<PyDict>().map_err(py_err_se_err)?;
serialize_pairs_json(dict.iter().map(Ok), dict.len(), serializer, include, exclude, extra)
}
ObType::List => serialize_seq_filter!(PyList),
ObType::Tuple => serialize_seq_filter!(PyTuple),
ObType::Set => serialize_seq!(PySet),
@@ -493,7 +472,10 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra);
pydantic_serializer.serialize(serializer)
}
ObType::Dataclass => serialize_dict!(dataclass_to_dict(value).map_err(py_err_se_err)?),
ObType::Dataclass => {
let (pairs_iter, fields_dict) = any_dataclass_iter(value).map_err(py_err_se_err)?;
serialize_pairs_json(pairs_iter, fields_dict.len(), serializer, include, exclude, extra)
}
ObType::Uuid => {
let py_uuid: &PyAny = value.downcast().map_err(py_err_se_err)?;
let uuid = super::type_serializers::uuid::uuid_to_string(py_uuid).map_err(py_err_se_err)?;
@@ -527,7 +509,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,)).map_err(py_err_se_err)?;
let next_result = infer_serialize(next_value, serializer, include, exclude, extra);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
return next_result;
} else if extra.serialize_unknown {
serializer.serialize_str(&serialize_unknown(value))
Expand All @@ -541,7 +522,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
}
}
};
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
ser_result
}

@@ -662,3 +642,50 @@ pub(crate) fn infer_json_key_known<'py>(ob_type: ObType, key: &'py PyAny, extra:
}
}
}

fn serialize_pairs_python<'py>(
py: Python,
pairs_iter: impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>>,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
key_transform: impl Fn(&'py PyAny) -> PyResult<&'py PyAny>,
) -> PyResult<PyObject> {
let new_dict = PyDict::new(py);
let filter = AnyFilter::new();

for result in pairs_iter {
let (k, v) = result?;
let op_next = filter.key_filter(k, include, exclude)?;
if let Some((next_include, next_exclude)) = op_next {
let k = key_transform(k)?;
let v = infer_to_python(v, next_include, next_exclude, extra)?;
new_dict.set_item(k, v)?;
}
}
Ok(new_dict.into_py(py))
}

fn serialize_pairs_json<'py, S: Serializer>(
pairs_iter: impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>>,
iter_size: usize,
serializer: S,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
let mut map = serializer.serialize_map(Some(iter_size))?;
let filter = AnyFilter::new();

for result in pairs_iter {
let (key, value) = result.map_err(py_err_se_err)?;

let op_next = filter.key_filter(key, include, exclude).map_err(py_err_se_err)?;
if let Some((next_include, next_exclude)) = op_next {
let key = infer_json_key(key, extra).map_err(py_err_se_err)?;
let value_serializer = SerializeInfer::new(value, next_include, next_exclude, extra);
map.serialize_entry(&key, &value_serializer)?;
}
}
map.end()
}
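
Review note: the net effect of this refactor is that dicts and dataclasses feed the same loop as iterators of fallible pairs; the `Result`-of-pair item type is the glue. A plain-Rust sketch of the shape (no PyO3, illustrative types):

```rust
// Sketch: one serialization loop over any fallible pair iterator.
fn collect_pairs(
    pairs: impl Iterator<Item = Result<(String, String), String>>,
) -> Result<Vec<(String, String)>, String> {
    let mut out = Vec::new();
    for pair in pairs {
        let (k, v) = pair?; // per-pair errors propagate, as with py_err_se_err above
        out.push((k, v));
    }
    Ok(out)
}
```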
18 changes: 10 additions & 8 deletions src/serializers/mod.rs
@@ -10,7 +10,7 @@ use crate::py_gc::PyGcTraverse;

use config::SerializationConfig;
pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue};
use extra::{CollectWarnings, SerRecursionGuard};
use extra::{CollectWarnings, SerRecursionState};
pub(crate) use extra::{Extra, SerMode, SerializationState};
pub use shared::CombinedSerializer;
use shared::{to_json_bytes, BuildSerializer, TypeSerializer};
@@ -52,7 +52,7 @@ impl SchemaSerializer {
exclude_defaults: bool,
exclude_none: bool,
round_trip: bool,
rec_guard: &'a SerRecursionGuard,
rec_guard: &'a SerRecursionState,
serialize_unknown: bool,
fallback: Option<&'a PyAny>,
) -> Extra<'b> {
@@ -113,7 +113,7 @@ impl SchemaSerializer {
) -> PyResult<PyObject> {
let mode: SerMode = mode.into();
let warnings = CollectWarnings::new(warnings);
let rec_guard = SerRecursionGuard::default();
let rec_guard = SerRecursionState::default();
let extra = self.build_extra(
py,
&mode,
@@ -152,7 +152,7 @@ impl SchemaSerializer {
fallback: Option<&PyAny>,
) -> PyResult<PyObject> {
let warnings = CollectWarnings::new(warnings);
let rec_guard = SerRecursionGuard::default();
let rec_guard = SerRecursionState::default();
let extra = self.build_extra(
py,
&SerMode::Json,
@@ -213,7 +213,7 @@ impl SchemaSerializer {
#[pyfunction]
#[pyo3(signature = (value, *, indent = None, include = None, exclude = None, by_alias = true,
exclude_none = false, round_trip = false, timedelta_mode = "iso8601", bytes_mode = "utf8",
serialize_unknown = false, fallback = None))]
inf_nan_mode = "constants", serialize_unknown = false, fallback = None))]
pub fn to_json(
py: Python,
value: &PyAny,
@@ -225,10 +225,11 @@ pub fn to_json(
round_trip: bool,
timedelta_mode: &str,
bytes_mode: &str,
inf_nan_mode: &str,
serialize_unknown: bool,
fallback: Option<&PyAny>,
) -> PyResult<PyObject> {
let state = SerializationState::new(timedelta_mode, bytes_mode)?;
let state = SerializationState::new(timedelta_mode, bytes_mode, inf_nan_mode)?;
let extra = state.extra(
py,
&SerMode::Json,
@@ -248,7 +249,7 @@ pub fn to_json(
#[allow(clippy::too_many_arguments)]
#[pyfunction]
#[pyo3(signature = (value, *, include = None, exclude = None, by_alias = true, exclude_none = false, round_trip = false,
timedelta_mode = "iso8601", bytes_mode = "utf8", serialize_unknown = false, fallback = None))]
timedelta_mode = "iso8601", bytes_mode = "utf8", inf_nan_mode = "constants", serialize_unknown = false, fallback = None))]
pub fn to_jsonable_python(
py: Python,
value: &PyAny,
@@ -259,10 +260,11 @@ pub fn to_jsonable_python(
round_trip: bool,
timedelta_mode: &str,
bytes_mode: &str,
inf_nan_mode: &str,
serialize_unknown: bool,
fallback: Option<&PyAny>,
) -> PyResult<PyObject> {
let state = SerializationState::new(timedelta_mode, bytes_mode)?;
let state = SerializationState::new(timedelta_mode, bytes_mode, inf_nan_mode)?;
let extra = state.extra(
py,
&SerMode::Json,
37 changes: 17 additions & 20 deletions src/serializers/ob_type.rs
@@ -59,29 +59,26 @@ pub enum IsType {

impl ObTypeLookup {
fn new(py: Python) -> Self {
let lib_url = url::Url::parse("https://example.com").unwrap();
Self {
none: py.None().as_ref(py).get_type_ptr() as usize,
int: 0i32.into_py(py).as_ref(py).get_type_ptr() as usize,
bool: true.into_py(py).as_ref(py).get_type_ptr() as usize,
float: 0f32.into_py(py).as_ref(py).get_type_ptr() as usize,
list: PyList::empty(py).get_type_ptr() as usize,
dict: PyDict::new(py).get_type_ptr() as usize,
int: PyInt::type_object_raw(py) as usize,
bool: PyBool::type_object_raw(py) as usize,
float: PyFloat::type_object_raw(py) as usize,
list: PyList::type_object_raw(py) as usize,
dict: PyDict::type_object_raw(py) as usize,
decimal_object: py.import("decimal").unwrap().getattr("Decimal").unwrap().to_object(py),
string: PyString::new(py, "s").get_type_ptr() as usize,
bytes: PyBytes::new(py, b"s").get_type_ptr() as usize,
bytearray: PyByteArray::new(py, b"s").get_type_ptr() as usize,
tuple: PyTuple::empty(py).get_type_ptr() as usize,
set: PySet::empty(py).unwrap().get_type_ptr() as usize,
frozenset: PyFrozenSet::empty(py).unwrap().get_type_ptr() as usize,
datetime: PyDateTime::new(py, 2000, 1, 1, 0, 0, 0, 0, None)
.unwrap()
.get_type_ptr() as usize,
date: PyDate::new(py, 2000, 1, 1).unwrap().get_type_ptr() as usize,
time: PyTime::new(py, 0, 0, 0, 0, None).unwrap().get_type_ptr() as usize,
timedelta: PyDelta::new(py, 0, 0, 0, false).unwrap().get_type_ptr() as usize,
url: PyUrl::new(lib_url.clone()).into_py(py).as_ref(py).get_type_ptr() as usize,
multi_host_url: PyMultiHostUrl::new(lib_url, None).into_py(py).as_ref(py).get_type_ptr() as usize,
string: PyString::type_object_raw(py) as usize,
bytes: PyBytes::type_object_raw(py) as usize,
bytearray: PyByteArray::type_object_raw(py) as usize,
tuple: PyTuple::type_object_raw(py) as usize,
set: PySet::type_object_raw(py) as usize,
frozenset: PyFrozenSet::type_object_raw(py) as usize,
datetime: PyDateTime::type_object_raw(py) as usize,
date: PyDate::type_object_raw(py) as usize,
time: PyTime::type_object_raw(py) as usize,
timedelta: PyDelta::type_object_raw(py) as usize,
url: PyUrl::type_object_raw(py) as usize,
multi_host_url: PyMultiHostUrl::type_object_raw(py) as usize,
enum_object: py.import("enum").unwrap().getattr("Enum").unwrap().to_object(py),
generator_object: py
.import("types")
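Review note: the `ObTypeLookup` change above swaps "build a throwaway Python value, then read its type pointer" for asking the type object directly. A sketch of the difference, using the pyo3 0.x API this repo is on:

```rust
use pyo3::prelude::*;
use pyo3::types::PyString;
use pyo3::PyTypeInfo;

fn string_type_ptr(py: Python<'_>) -> usize {
    // before (roughly): PyString::new(py, "s").get_type_ptr() as usize
    // after: no temporary Python object is created
    PyString::type_object_raw(py) as usize
}
```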
48 changes: 25 additions & 23 deletions src/serializers/shared.rs
@@ -140,8 +140,7 @@ combined_serializer! {
Union: super::type_serializers::union::UnionSerializer;
Literal: super::type_serializers::literal::LiteralSerializer;
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
TuplePositional: super::type_serializers::tuple::TuplePositionalSerializer;
TupleVariable: super::type_serializers::tuple::TupleVariableSerializer;
Tuple: super::type_serializers::tuple::TupleSerializer;
}
}

@@ -248,8 +247,7 @@ impl PyGcTraverse for CombinedSerializer {
CombinedSerializer::Union(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Literal(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Recursive(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::TuplePositional(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::TupleVariable(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit),
}
}
@@ -366,29 +364,33 @@ pub(crate) fn to_json_bytes(
Ok(bytes)
}

pub(super) fn any_dataclass_iter<'py>(
dataclass: &'py PyAny,
) -> PyResult<(impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>> + 'py, &PyDict)> {
let py = dataclass.py();
let fields: &PyDict = dataclass.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?;
let field_type_marker = get_field_marker(py)?;

let next = move |(field_name, field): (&'py PyAny, &'py PyAny)| -> PyResult<Option<(&'py PyAny, &'py PyAny)>> {
let field_type = field.getattr(intern!(py, "_field_type"))?;
if field_type.is(field_type_marker) {
let field_name: &PyString = field_name.downcast()?;
let value = dataclass.getattr(field_name)?;
Ok(Some((field_name, value)))
} else {
Ok(None)
}
};

Ok((fields.iter().filter_map(move |field| next(field).transpose()), fields))
}

static DC_FIELD_MARKER: GILOnceCell<PyObject> = GILOnceCell::new();

/// needed to match the logic from dataclasses.fields `tuple(f for f in fields.values() if f._field_type is _FIELD)`
pub(super) fn get_field_marker(py: Python<'_>) -> PyResult<&PyAny> {
fn get_field_marker(py: Python<'_>) -> PyResult<&PyAny> {
let field_type_marker_obj = DC_FIELD_MARKER.get_or_try_init(py, || {
let field_ = py.import("dataclasses")?.getattr("_FIELD")?;
Ok::<PyObject, PyErr>(field_.into_py(py))
py.import("dataclasses")?.getattr("_FIELD").map(|f| f.into_py(py))
})?;
Ok(field_type_marker_obj.as_ref(py))
}

pub(super) fn dataclass_to_dict(dc: &PyAny) -> PyResult<&PyDict> {
let py = dc.py();
let dc_fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?;
let dict = PyDict::new(py);

let field_type_marker = get_field_marker(py)?;
for (field_name, field) in dc_fields {
let field_type = field.getattr(intern!(py, "_field_type"))?;
if field_type.is(field_type_marker) {
let field_name: &PyString = field_name.downcast()?;
dict.set_item(field_name, dc.getattr(field_name)?)?;
}
}
Ok(dict)
}
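Review note: `any_dataclass_iter` leans on a small but handy trick — `filter_map(... .transpose())` turns a per-item `PyResult<Option<T>>` filter into an iterator of `PyResult<T>`, dropping filtered items while keeping errors. Standalone illustration:

```rust
// Each item is either kept (Ok(Some)), filtered out (Ok(None)), or an error (Err).
fn keep_even(x: i32) -> Result<Option<i32>, String> {
    if x % 2 == 0 {
        Ok(Some(x))
    } else {
        Ok(None)
    }
}

fn main() {
    let items: Vec<Result<i32, String>> = (0..5)
        .map(keep_even)
        .filter_map(Result::transpose) // Ok(None) disappears; Ok(Some)/Err survive
        .collect();
    assert_eq!(items, vec![Ok(0), Ok(2), Ok(4)]);
}
```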
17 changes: 10 additions & 7 deletions src/serializers/type_serializers/bytes.rs
@@ -4,24 +4,28 @@ use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict};

use crate::definitions::DefinitionsBuilder;
use crate::serializers::config::{BytesMode, FromConfig};

use super::{
infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerMode,
TypeSerializer,
};

#[derive(Debug, Clone)]
pub struct BytesSerializer;
pub struct BytesSerializer {
bytes_mode: BytesMode,
}

impl BuildSerializer for BytesSerializer {
const EXPECTED_TYPE: &'static str = "bytes";

fn build(
_schema: &PyDict,
_config: Option<&PyDict>,
config: Option<&PyDict>,
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
) -> PyResult<CombinedSerializer> {
Ok(Self {}.into())
let bytes_mode = BytesMode::from_config(config)?;
Ok(Self { bytes_mode }.into())
}
}

@@ -38,8 +42,7 @@ impl TypeSerializer for BytesSerializer {
let py = value.py();
match value.downcast::<PyBytes>() {
Ok(py_bytes) => match extra.mode {
SerMode::Json => extra
.config
SerMode::Json => self
.bytes_mode
.bytes_to_string(py, py_bytes.as_bytes())
.map(|s| s.into_py(py)),
@@ -54,7 +57,7 @@ impl TypeSerializer for BytesSerializer {

fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult<Cow<'py, str>> {
match key.downcast::<PyBytes>() {
Ok(py_bytes) => extra.config.bytes_mode.bytes_to_string(key.py(), py_bytes.as_bytes()),
Ok(py_bytes) => self.bytes_mode.bytes_to_string(key.py(), py_bytes.as_bytes()),
Err(_) => {
extra.warnings.on_fallback_py(self.get_name(), key, extra)?;
infer_json_key(key, extra)
@@ -71,7 +74,7 @@ impl TypeSerializer for BytesSerializer {
extra: &Extra,
) -> Result<S::Ok, S::Error> {
match value.downcast::<PyBytes>() {
Ok(py_bytes) => extra.config.bytes_mode.serialize_bytes(py_bytes.as_bytes(), serializer),
Ok(py_bytes) => self.bytes_mode.serialize_bytes(py_bytes.as_bytes(), serializer),
Err(_) => {
extra.warnings.on_fallback_ser::<S>(self.get_name(), value, extra)?;
infer_serialize(value, serializer, include, exclude, extra)
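Review note: the pattern in this file (repeated for timedelta below) is to resolve serialization config once in `build()` and store it on the serializer, rather than reading `extra.config` on every call. A self-contained sketch of the shape (names illustrative, placeholder encodings):

```rust
#[derive(Debug, Clone, Default)]
enum Mode {
    #[default]
    Utf8,
    Base64,
    Hex,
}

struct BytesSer {
    mode: Mode, // resolved once at build time
}

impl BytesSer {
    fn build(config: Option<&str>) -> Self {
        let mode = match config {
            Some("base64") => Mode::Base64,
            Some("hex") => Mode::Hex,
            _ => Mode::Utf8,
        };
        Self { mode }
    }

    fn serialize(&self, bytes: &[u8]) -> String {
        match self.mode {
            Mode::Utf8 => String::from_utf8_lossy(bytes).into_owned(),
            Mode::Hex => bytes.iter().map(|b| format!("{b:02x}")).collect(),
            Mode::Base64 => format!("<base64 of {} bytes>", bytes.len()), // placeholder, not a real encoder
        }
    }
}
```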
76 changes: 59 additions & 17 deletions src/serializers/type_serializers/dataclass.rs
@@ -4,6 +4,7 @@ use pyo3::types::{PyDict, PyList, PyString, PyType};
use std::borrow::Cow;

use ahash::AHashMap;
use serde::ser::SerializeMap;

use crate::build_tools::{py_schema_error_type, ExtraBehavior};
use crate::definitions::DefinitionsBuilder;
@@ -39,7 +40,7 @@ impl BuildSerializer for DataclassArgsBuilder {
let field_info: &PyDict = item.downcast()?;
let name: String = field_info.get_as_req(intern!(py, "name"))?;

let key_py: Py<PyString> = PyString::intern(py, &name).into_py(py);
let key_py: Py<PyString> = PyString::new(py, &name).into_py(py);

if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) {
fields.insert(name, SerField::new(py, key_py, None, None, true));
@@ -131,16 +132,30 @@ impl TypeSerializer for DataclassSerializer {
exclude: Option<&PyAny>,
extra: &Extra,
) -> PyResult<PyObject> {
let extra = Extra {
let dc_extra = Extra {
model: Some(value),
..*extra
};
if self.allow_value(value, &extra)? {
let inner_value = self.get_inner_value(value)?;
self.serializer.to_python(inner_value, include, exclude, &extra)
if self.allow_value(value, extra)? {
let py = value.py();
if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer {
let output_dict = fields_serializer.main_to_python(
py,
known_dataclass_iter(&self.fields, value),
include,
exclude,
dc_extra,
)?;

fields_serializer.add_computed_fields_python(Some(value), output_dict, include, exclude, extra)?;
Ok(output_dict.into_py(py))
} else {
let inner_value = self.get_inner_value(value)?;
self.serializer.to_python(inner_value, include, exclude, &dc_extra)
}
} else {
extra.warnings.on_fallback_py(self.get_name(), value, &extra)?;
infer_to_python(value, include, exclude, &extra)
extra.warnings.on_fallback_py(self.get_name(), value, &dc_extra)?;
infer_to_python(value, include, exclude, &dc_extra)
}
}

@@ -161,17 +176,29 @@ impl TypeSerializer for DataclassSerializer {
exclude: Option<&PyAny>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
let extra = Extra {
model: Some(value),
..*extra
};
if self.allow_value(value, &extra).map_err(py_err_se_err)? {
let inner_value = self.get_inner_value(value).map_err(py_err_se_err)?;
self.serializer
.serde_serialize(inner_value, serializer, include, exclude, &extra)
let model = Some(value);
let dc_extra = Extra { model, ..*extra };
if self.allow_value(value, extra).map_err(py_err_se_err)? {
if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer {
let expected_len = self.fields.len() + fields_serializer.computed_field_count();
let mut map = fields_serializer.main_serde_serialize(
known_dataclass_iter(&self.fields, value),
expected_len,
serializer,
include,
exclude,
dc_extra,
)?;
fields_serializer.add_computed_fields_json::<S>(model, &mut map, include, exclude, extra)?;
map.end()
} else {
let inner_value = self.get_inner_value(value).map_err(py_err_se_err)?;
self.serializer
.serde_serialize(inner_value, serializer, include, exclude, extra)
}
} else {
extra.warnings.on_fallback_ser::<S>(self.get_name(), value, &extra)?;
infer_serialize(value, serializer, include, exclude, &extra)
extra.warnings.on_fallback_ser::<S>(self.get_name(), value, extra)?;
infer_serialize(value, serializer, include, exclude, extra)
}
}

@@ -183,3 +210,18 @@ impl TypeSerializer for DataclassSerializer {
true
}
}

fn known_dataclass_iter<'a, 'py>(
fields: &'a [Py<PyString>],
dataclass: &'py PyAny,
) -> impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>> + 'a
where
'py: 'a,
{
let py = dataclass.py();
fields.iter().map(move |field| {
let field_ref = field.clone_ref(py).into_ref(py);
let value = dataclass.getattr(field_ref)?;
Ok((field_ref as &PyAny, value))
})
}
33 changes: 16 additions & 17 deletions src/serializers/type_serializers/definitions.rs
@@ -66,17 +66,17 @@ impl TypeSerializer for DefinitionRefSerializer {
value: &PyAny,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
mut extra: &Extra,
) -> PyResult<PyObject> {
let comb_serializer = self.definition.get().unwrap();
let value_id = extra.rec_guard.add(value, self.definition.id())?;
let r = comb_serializer.to_python(value, include, exclude, extra);
extra.rec_guard.pop(value_id, self.definition.id());
r
self.definition.read(|comb_serializer| {
let comb_serializer = comb_serializer.unwrap();
let mut guard = extra.recursion_guard(value, self.definition.id())?;
comb_serializer.to_python(value, include, exclude, guard.state())
})
}

fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult<Cow<'py, str>> {
self.definition.get().unwrap().json_key(key, extra)
self.definition.read(|s| s.unwrap().json_key(key, extra))
}

fn serde_serialize<S: serde::ser::Serializer>(
Expand All @@ -85,23 +85,22 @@ impl TypeSerializer for DefinitionRefSerializer {
serializer: S,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
mut extra: &Extra,
) -> Result<S::Ok, S::Error> {
let comb_serializer = self.definition.get().unwrap();
let value_id = extra
.rec_guard
.add(value, self.definition.id())
.map_err(py_err_se_err)?;
let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra);
extra.rec_guard.pop(value_id, self.definition.id());
r
self.definition.read(|comb_serializer| {
let comb_serializer = comb_serializer.unwrap();
let mut guard = extra
.recursion_guard(value, self.definition.id())
.map_err(py_err_se_err)?;
comb_serializer.serde_serialize(value, serializer, include, exclude, guard.state())
})
}

fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}

fn retry_with_lax_check(&self) -> bool {
self.definition.get().unwrap().retry_with_lax_check()
self.definition.read(|s| s.unwrap().retry_with_lax_check())
}
}
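Review note: the switch from `self.definition.get().unwrap()` to `self.definition.read(|...| ...)` changes the access model — the container runs your closure while it holds its own borrow/lock, instead of handing a reference out. Whether the real `Definitions` type locks or merely borrows is an implementation detail; the pattern itself, minimally:

```rust
use std::sync::RwLock;

struct Definition<T>(RwLock<Option<T>>);

impl<T> Definition<T> {
    fn read<R>(&self, f: impl FnOnce(Option<&T>) -> R) -> R {
        let guard = self.0.read().unwrap();
        f(guard.as_ref()) // the borrow cannot escape the closure
    }
}

fn main() {
    let def = Definition(RwLock::new(Some(21)));
    assert_eq!(def.read(|v| v.unwrap() * 2), 42);
}
```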
4 changes: 2 additions & 2 deletions src/serializers/type_serializers/literal.rs
@@ -46,7 +46,7 @@ impl BuildSerializer for LiteralSerializer {
repr_args.push(item.repr()?.extract()?);
if let Ok(bool) = item.downcast::<PyBool>() {
expected_py.append(bool)?;
} else if let Ok(int) = extract_i64(item) {
} else if let Some(int) = extract_i64(item) {
expected_int.insert(int);
} else if let Ok(py_str) = item.downcast::<PyString>() {
expected_str.insert(py_str.to_str()?.to_string());
@@ -79,7 +79,7 @@ impl LiteralSerializer {
fn check<'a>(&self, value: &'a PyAny, extra: &Extra) -> PyResult<OutputValue<'a>> {
if extra.check.enabled() {
if !self.expected_int.is_empty() && !PyBool::is_type_of(value) {
if let Ok(int) = extract_i64(value) {
if let Some(int) = extract_i64(value) {
if self.expected_int.contains(&int) {
return Ok(OutputValue::OkInt(int));
}
28 changes: 18 additions & 10 deletions src/serializers/type_serializers/simple.rs
@@ -5,11 +5,12 @@ use std::borrow::Cow;

use serde::Serialize;

use crate::PydanticSerializationUnexpectedValue;
use crate::{definitions::DefinitionsBuilder, input::Int};

use super::{
infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, IsType, ObType,
SerMode, TypeSerializer,
SerCheck, SerMode, TypeSerializer,
};

#[derive(Debug, Clone)]
@@ -85,7 +86,7 @@ impl TypeSerializer for NoneSerializer {
}

macro_rules! build_simple_serializer {
($struct_name:ident, $expected_type:literal, $rust_type:ty, $ob_type:expr, $key_method:ident) => {
($struct_name:ident, $expected_type:literal, $rust_type:ty, $ob_type:expr, $key_method:ident, $subtypes_allowed:expr) => {
#[derive(Debug, Clone)]
pub struct $struct_name;

@@ -114,12 +115,15 @@ macro_rules! build_simple_serializer {
let py = value.py();
match extra.ob_type_lookup.is_type(value, $ob_type) {
IsType::Exact => Ok(value.into_py(py)),
IsType::Subclass => match extra.mode {
SerMode::Json => {
let rust_value = value.extract::<$rust_type>()?;
Ok(rust_value.to_object(py))
}
_ => infer_to_python(value, include, exclude, extra),
IsType::Subclass => match extra.check {
SerCheck::Strict => Err(PydanticSerializationUnexpectedValue::new_err(None)),
SerCheck::Lax | SerCheck::None => match extra.mode {
SerMode::Json => {
let rust_value = value.extract::<$rust_type>()?;
Ok(rust_value.to_object(py))
}
_ => infer_to_python(value, include, exclude, extra),
},
},
IsType::False => {
extra.warnings.on_fallback_py(self.get_name(), value, extra)?;
@@ -160,6 +164,10 @@ macro_rules! build_simple_serializer {
fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}

fn retry_with_lax_check(&self) -> bool {
$subtypes_allowed
}
}
};
}
@@ -168,7 +176,7 @@ pub(crate) fn to_str_json_key(key: &PyAny) -> PyResult<Cow<str>> {
Ok(key.str()?.to_string_lossy())
}

build_simple_serializer!(IntSerializer, "int", Int, ObType::Int, to_str_json_key);
build_simple_serializer!(IntSerializer, "int", Int, ObType::Int, to_str_json_key, true);

pub(crate) fn bool_json_key(key: &PyAny) -> PyResult<Cow<str>> {
let v = if key.is_true().unwrap_or(false) {
@@ -179,4 +187,4 @@ pub(crate) fn bool_json_key(key: &PyAny) -> PyResult<Cow<str>> {
Ok(Cow::Borrowed(v))
}

build_simple_serializer!(BoolSerializer, "bool", bool, ObType::Bool, bool_json_key);
build_simple_serializer!(BoolSerializer, "bool", bool, ObType::Bool, bool_json_key, false);
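
The new macro argument threads straight through to `retry_with_lax_check`: under `SerCheck::Strict` a subclass of the expected type is now rejected, and the flag tells the caller (a smart union, for instance) whether retrying this member in lax mode could still succeed. Roughly what the two instantiations above expand to for that one method (an illustrative reduction, not the full macro expansion):

// IntSerializer: int subclasses (e.g. enums deriving from int) can still be
// serialized laxly, so a strict failure is worth retrying
impl IntSerializer {
    fn retry_with_lax_check(&self) -> bool {
        true
    }
}

// BoolSerializer: bool has no useful subclasses, so a strict failure is final
impl BoolSerializer {
    fn retry_with_lax_check(&self) -> bool {
        false
    }
}
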
26 changes: 12 additions & 14 deletions src/serializers/type_serializers/timedelta.rs
@@ -5,24 +5,28 @@ use pyo3::types::PyDict;

use crate::definitions::DefinitionsBuilder;
use crate::input::EitherTimedelta;
use crate::serializers::config::{FromConfig, TimedeltaMode};

use super::{
infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerMode,
TypeSerializer,
};

#[derive(Debug, Clone)]
pub struct TimeDeltaSerializer;
pub struct TimeDeltaSerializer {
timedelta_mode: TimedeltaMode,
}

impl BuildSerializer for TimeDeltaSerializer {
const EXPECTED_TYPE: &'static str = "timedelta";

fn build(
_schema: &PyDict,
_config: Option<&PyDict>,
config: Option<&PyDict>,
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
) -> PyResult<CombinedSerializer> {
Ok(Self {}.into())
let timedelta_mode = TimedeltaMode::from_config(config)?;
Ok(Self { timedelta_mode }.into())
}
}

@@ -38,10 +42,7 @@ impl TypeSerializer for TimeDeltaSerializer {
) -> PyResult<PyObject> {
match extra.mode {
SerMode::Json => match EitherTimedelta::try_from(value) {
Ok(either_timedelta) => extra
.config
.timedelta_mode
.either_delta_to_json(value.py(), &either_timedelta),
Ok(either_timedelta) => self.timedelta_mode.either_delta_to_json(value.py(), &either_timedelta),
Err(_) => {
extra.warnings.on_fallback_py(self.get_name(), value, extra)?;
infer_to_python(value, include, exclude, extra)
@@ -53,7 +54,7 @@

fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult<Cow<'py, str>> {
match EitherTimedelta::try_from(key) {
Ok(either_timedelta) => extra.config.timedelta_mode.json_key(key.py(), &either_timedelta),
Ok(either_timedelta) => self.timedelta_mode.json_key(key.py(), &either_timedelta),
Err(_) => {
extra.warnings.on_fallback_py(self.get_name(), key, extra)?;
infer_json_key(key, extra)
@@ -70,12 +71,9 @@
extra: &Extra,
) -> Result<S::Ok, S::Error> {
match EitherTimedelta::try_from(value) {
Ok(either_timedelta) => {
extra
.config
.timedelta_mode
.timedelta_serialize(value.py(), &either_timedelta, serializer)
}
Ok(either_timedelta) => self
.timedelta_mode
.timedelta_serialize(value.py(), &either_timedelta, serializer),
Err(_) => {
extra.warnings.on_fallback_ser::<S>(self.get_name(), value, extra)?;
infer_serialize(value, serializer, include, exclude, extra)
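
The effect of this change is to resolve the timedelta mode once, when the schema is built, instead of reading it from `extra.config` on every call. A condensed sketch of the build step under that assumption (the schema and definitions plumbing is elided):

fn build_sketch(config: Option<&PyDict>) -> PyResult<TimeDeltaSerializer> {
    // resolved a single time from the config...
    let timedelta_mode = TimedeltaMode::from_config(config)?;
    // ...so to_python / json_key / serde_serialize read a plain field
    // rather than chasing extra.config for every value
    Ok(TimeDeltaSerializer { timedelta_mode })
}
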
349 changes: 141 additions & 208 deletions src/serializers/type_serializers/tuple.rs

Large diffs are not rendered by default.

28 changes: 21 additions & 7 deletions src/tools.rs
@@ -1,9 +1,9 @@
use std::borrow::Cow;

use pyo3::exceptions::{PyKeyError, PyTypeError};
use pyo3::exceptions::PyKeyError;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyInt, PyString};
use pyo3::{intern, FromPyObject, PyTypeInfo};
use pyo3::types::{PyDict, PyString};
use pyo3::{ffi, intern, FromPyObject};

pub trait SchemaDict<'py> {
fn get_as<T>(&'py self, key: &PyString) -> PyResult<Option<T>>
@@ -99,10 +99,24 @@ pub fn safe_repr(v: &PyAny) -> Cow<str> {
}
}

pub fn extract_i64(v: &PyAny) -> PyResult<i64> {
if PyInt::is_type_of(v) {
v.extract()
/// Extract an i64 from a Python object more quickly than the generic `extract::<i64>()`, see
/// https://github.com/PyO3/pyo3/pull/3742#discussion_r1451763928
#[cfg(not(any(target_pointer_width = "32", windows, PyPy)))]
pub fn extract_i64(obj: &PyAny) -> Option<i64> {
let val = unsafe { ffi::PyLong_AsLong(obj.as_ptr()) };
if val == -1 && PyErr::occurred(obj.py()) {
unsafe { ffi::PyErr_Clear() };
None
} else {
py_err!(PyTypeError; "expected int, got {}", safe_repr(v))
Some(val)
}
}

#[cfg(any(target_pointer_width = "32", windows, PyPy))]
pub fn extract_i64(v: &PyAny) -> Option<i64> {
if v.is_instance_of::<pyo3::types::PyInt>() {
v.extract().ok()
} else {
None
}
}
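
Call sites change from `if let Ok(int)` to `if let Some(int)`, as in literal.rs above. One subtlety worth keeping in mind: bool is a subclass of int in Python, so callers that care about the distinction must test for `PyBool` first, which is exactly the ordering literal.rs uses. A hypothetical call site illustrating this (the function name is invented):

use pyo3::prelude::*;
use pyo3::types::PyBool;

fn classify(item: &PyAny) -> &'static str {
    // test bool first: extract_i64 would otherwise return Some(1) / Some(0)
    // for True / False
    if item.downcast::<PyBool>().is_ok() {
        "bool"
    } else if extract_i64(item).is_some() {
        "int"
    } else {
        // non-integers, and ints that overflow the fast path, come back as
        // None instead of raising
        "other"
    }
}
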
2 changes: 1 addition & 1 deletion src/validators/any.rs
@@ -32,7 +32,7 @@ impl Validator for AnyValidator {
py: Python<'data>,
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
) -> ValResult<PyObject> {
// in a union, Any should be preferred to doing lax coercions
state.floor_exactness(Exactness::Strict);
Ok(input.to_object(py))
23 changes: 13 additions & 10 deletions src/validators/arguments.rs
@@ -5,7 +5,7 @@ use pyo3::types::{PyDict, PyList, PyString, PyTuple};
use ahash::AHashSet;

use crate::build_tools::py_schema_err;
use crate::build_tools::schema_or_config_same;
use crate::build_tools::{schema_or_config_same, ExtraBehavior};
use crate::errors::{AsLocItem, ErrorTypeDefaults, ValError, ValLineError, ValResult};
use crate::input::{GenericArguments, Input, ValidationMatch};
use crate::lookup_key::LookupKey;
@@ -31,6 +31,7 @@ pub struct ArgumentsValidator {
var_args_validator: Option<Box<CombinedValidator>>,
var_kwargs_validator: Option<Box<CombinedValidator>>,
loc_by_alias: bool,
extra: ExtraBehavior,
}

impl BuildValidator for ArgumentsValidator {
@@ -73,7 +74,7 @@ impl BuildValidator for ArgumentsValidator {
}
None => Some(LookupKey::from_string(py, &name)),
};
kwarg_key = Some(PyString::intern(py, &name).into());
kwarg_key = Some(PyString::new(py, &name).into());
}

let schema: &PyAny = arg.get_as_req(intern!(py, "schema"))?;
@@ -119,6 +120,7 @@ impl BuildValidator for ArgumentsValidator {
None => None,
},
loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true),
extra: ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Forbid)?,
}
.into())
}
@@ -166,7 +168,7 @@ impl Validator for ArgumentsValidator {
py: Python<'data>,
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
) -> ValResult<PyObject> {
let args = input.validate_args()?;

let mut output_args: Vec<PyObject> = Vec::with_capacity(self.positional_params_count);
@@ -307,15 +309,16 @@ impl Validator for ArgumentsValidator {
Err(err) => return Err(err),
},
None => {
errors.push(ValLineError::new_with_loc(
ErrorTypeDefaults::UnexpectedKeywordArgument,
value,
raw_key.as_loc_item(),
));
if let ExtraBehavior::Forbid = self.extra {
errors.push(ValLineError::new_with_loc(
ErrorTypeDefaults::UnexpectedKeywordArgument,
value,
raw_key.as_loc_item(),
));
}
}
}
}
}
}}
}
}
}};
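
Worth noting: the default passed to `from_schema_or_config` is `ExtraBehavior::Forbid`, so schemas that set nothing keep the old behaviour of rejecting unknown keyword arguments; only an explicit extra behaviour in the schema or config relaxes it. The new guard reduces to something like this (a sketch; the surrounding kwargs loop is elided):

match self.extra {
    // forbid (the default): unknown kwargs are collected as line errors
    ExtraBehavior::Forbid => errors.push(ValLineError::new_with_loc(
        ErrorTypeDefaults::UnexpectedKeywordArgument,
        value,
        raw_key.as_loc_item(),
    )),
    // other behaviours simply skip the unknown key here; their handling,
    // if any, is outside this hunk
    _ => {}
}
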
2 changes: 1 addition & 1 deletion src/validators/bool.rs
@@ -35,7 +35,7 @@ impl Validator for BoolValidator {
py: Python<'data>,
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
) -> ValResult<PyObject> {
// TODO in theory this could be quicker if we used PyBool rather than going to a bool
// and back again, might be worth profiling?
input
4 changes: 2 additions & 2 deletions src/validators/bytes.rs
@@ -45,7 +45,7 @@ impl Validator for BytesValidator {
py: Python<'data>,
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
) -> ValResult<PyObject> {
input
.validate_bytes(state.strict_or(self.strict))
.map(|m| m.unpack(state).into_py(py))
@@ -71,7 +71,7 @@ impl Validator for BytesConstrainedValidator {
py: Python<'data>,
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
) -> ValResult<PyObject> {
let either_bytes = input.validate_bytes(state.strict_or(self.strict))?.unpack(state);
let len = either_bytes.len()?;

3 changes: 2 additions & 1 deletion src/validators/call.rs
@@ -67,6 +67,7 @@ impl BuildValidator for CallValidator {
}

impl_py_gc_traverse!(CallValidator {
function,
arguments_validator,
return_validator
});
@@ -77,7 +78,7 @@ impl Validator for CallValidator {
py: Python<'data>,
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
) -> ValResult<PyObject> {
let args = self.arguments_validator.validate(py, input, state)?;

let return_value = if let Ok((args, kwargs)) = args.extract::<(&PyTuple, &PyDict)>(py) {
2 changes: 1 addition & 1 deletion src/validators/callable.rs
@@ -30,7 +30,7 @@ impl Validator for CallableValidator {
py: Python<'data>,
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
) -> ValResult<PyObject> {
state.floor_exactness(Exactness::Lax);
match input.callable() {
true => Ok(input.to_object(py)),
2 changes: 1 addition & 1 deletion src/validators/chain.rs
@@ -75,7 +75,7 @@ impl Validator for ChainValidator {
py: Python<'data>,
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
) -> ValResult<PyObject> {
let mut steps_iter = self.steps.iter();
let first_step = steps_iter.next().unwrap();
let value = first_step.validate(py, input, state)?;
5 changes: 3 additions & 2 deletions src/validators/custom_error.rs
@@ -3,6 +3,7 @@ use pyo3::prelude::*;
use pyo3::types::PyDict;

use crate::build_tools::py_schema_err;
use crate::errors::AsErrorValue;
use crate::errors::{ErrorType, PydanticCustomError, PydanticKnownError, ValError, ValResult};
use crate::input::Input;
use crate::tools::SchemaDict;
@@ -49,7 +50,7 @@ impl CustomError {
}
}

pub fn as_val_error<'a>(&self, input: &'a impl Input<'a>) -> ValError<'a> {
pub fn as_val_error(&self, input: &impl AsErrorValue) -> ValError {
match self {
CustomError::KnownError(ref known_error) => known_error.clone().into_val_error(input),
CustomError::Custom(ref custom_error) => custom_error.clone().into_val_error(input),
@@ -93,7 +94,7 @@ impl Validator for CustomErrorValidator {
py: Python<'data>,
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
) -> ValResult<PyObject> {
self.validator
.validate(py, input, state)
.map_err(|_| self.custom_error.as_val_error(input))
45 changes: 31 additions & 14 deletions src/validators/dataclass.rs
@@ -8,6 +8,7 @@ use ahash::AHashSet;
use crate::build_tools::py_schema_err;
use crate::build_tools::{is_strict, schema_or_config_same, ExtraBehavior};
use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult};
use crate::input::InputType;
use crate::input::{BorrowInput, GenericArguments, Input, ValidationMatch};
use crate::lookup_key::LookupKey;
use crate::tools::SchemaDict;
Expand All @@ -25,6 +26,7 @@ struct Field {
kw_only: bool,
name: String,
py_name: Py<PyString>,
init: bool,
init_only: bool,
lookup_key: LookupKey,
validator: CombinedValidator,
@@ -106,6 +108,7 @@ impl BuildValidator for DataclassArgsValidator {
py_name: py_name.into(),
lookup_key,
validator,
init: field.get_as(intern!(py, "init"))?.unwrap_or(true),
init_only: field.get_as(intern!(py, "init_only"))?.unwrap_or(false),
frozen: field.get_as::<bool>(intern!(py, "frozen"))?.unwrap_or(false),
});
@@ -143,7 +146,7 @@ impl Validator for DataclassArgsValidator {
py: Python<'data>,
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
) -> ValResult<PyObject> {
let args = input.validate_dataclass_args(&self.dataclass_name)?;

let output_dict = PyDict::new(py);
@@ -175,6 +178,23 @@ impl Validator for DataclassArgsValidator {
($args:ident, $get_method:ident, $get_macro:ident, $slice_macro:ident) => {{
// go through fields getting the value from args or kwargs and validating it
for (index, field) in self.fields.iter().enumerate() {
if !field.init {
match field.validator.default_value(py, Some(field.name.as_str()), state) {
Ok(Some(value)) => {
// Default value exists, and passed validation if required
set_item!(field, value);
},
Ok(None) | Err(ValError::Omit) => continue,
// Note: this will always use the field name even if there is an alias
// However, we don't mind so much because this error can only happen if the
// default value fails validation, which is arguably a developer error.
// We could try to "fix" this in the future if desired.
Err(ValError::LineErrors(line_errors)) => errors.extend(line_errors),
Err(err) => return Err(err),
};
continue;
}

let mut pos_value = None;
if let Some(args) = $args.args {
if !field.kw_only {
@@ -201,8 +221,7 @@ impl Validator for DataclassArgsValidator {
ErrorTypeDefaults::MultipleArgumentValues,
kw_value,
field.name.clone(),
)
.into_owned(py),
),
);
}
// found a positional argument, validate it
@@ -225,10 +244,9 @@ impl Validator for DataclassArgsValidator {
errors.extend(line_errors.into_iter().map(|err| {
lookup_path
.apply_error_loc(err, self.loc_by_alias, &field.name)
.into_owned(py)
}));
}
Err(err) => return Err(err.into_owned(py)),
Err(err) => return Err(err),
}
}
// found neither, check if there is a default value, otherwise error
@@ -293,8 +311,7 @@ impl Validator for DataclassArgsValidator {
ErrorTypeDefaults::UnexpectedKeywordArgument,
value,
raw_key.as_loc_item(),
)
.into_owned(py),
),
);
}
ExtraBehavior::Ignore => {}
@@ -374,7 +391,7 @@ impl Validator for DataclassArgsValidator {
field_name: &'data str,
field_value: &'data PyAny,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
) -> ValResult<PyObject> {
let dict: &PyDict = obj.downcast()?;

let ok = |output: PyObject| {
@@ -479,7 +496,7 @@ impl BuildValidator for DataclassValidator {
let validator = build_validator(sub_schema, config, definitions)?;

let post_init = if schema.get_as::<bool>(intern!(py, "post_init"))?.unwrap_or(false) {
Some(PyString::intern(py, "__post_init__").into_py(py))
Some(PyString::new(py, "__post_init__").into_py(py))
} else {
None
};
@@ -517,7 +534,7 @@ impl Validator for DataclassValidator {
py: Python<'data>,
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
) -> ValResult<PyObject> {
if let Some(self_instance) = state.extra().self_instance {
// in the case that self_instance is Some, we're calling validation from within `BaseModel.__init__`
return self.validate_init(py, self_instance, input, state);
@@ -535,7 +552,7 @@ impl Validator for DataclassValidator {
} else {
Ok(input.to_object(py))
}
} else if state.strict_or(self.strict) && input.is_python() {
} else if state.strict_or(self.strict) && state.extra().input_type == InputType::Python {
Err(ValError::new(
ErrorType::DataclassExactType {
class_name: self.get_name().to_string(),
@@ -559,7 +576,7 @@ impl Validator for DataclassValidator {
field_name: &'data str,
field_value: &'data PyAny,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
) -> ValResult<PyObject> {
if self.frozen {
return Err(ValError::new(ErrorTypeDefaults::FrozenInstance, field_value));
}
@@ -599,7 +616,7 @@ impl DataclassValidator {
self_instance: &'s PyAny,
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
) -> ValResult<PyObject> {
// we need to set `self_instance` to None for nested validators as we don't want to operate on the self_instance
// instance anymore
let state = &mut state.rebind_extra(|extra| extra.self_instance = None);
@@ -626,7 +643,7 @@ impl DataclassValidator {
dc: &PyAny,
val_output: PyObject,
input: &'data impl Input<'data>,
) -> ValResult<'data, ()> {
) -> ValResult<()> {
let (dc_dict, post_init_kwargs): (&PyAny, &PyAny) = val_output.extract(py)?;
if self.slots {
let dc_dict: &PyDict = dc_dict.downcast()?;
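
Two separate things are happening in this file. The dropped `.into_owned(py)` calls fall out of `ValError` no longer borrowing from the input, the same lifetime removal seen in the `ValResult<PyObject>` signatures throughout this PR. The more interesting addition is the `init` flag: a field built with `init=False` is never read from the input at all, it only ever takes its validator's default. Condensed from the hunk above, with the flow annotated (`set_item!` is presumably defined earlier in this file, outside the diff):

if !field.init {
    match field.validator.default_value(py, Some(field.name.as_str()), state) {
        // a default exists (validated if the schema demands it): use it
        Ok(Some(value)) => set_item!(field, value),
        // no default, or the default is explicitly omitted: leave the field unset
        Ok(None) | Err(ValError::Omit) => {}
        // a default that fails validation surfaces as ordinary line errors
        Err(ValError::LineErrors(line_errors)) => errors.extend(line_errors),
        Err(err) => return Err(err),
    }
    // either way, positional and keyword input is never consulted
    continue;
}
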
4 changes: 2 additions & 2 deletions src/validators/date.rs
@@ -44,7 +44,7 @@ impl Validator for DateValidator {
py: Python<'data>,
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
) -> ValResult<PyObject> {
let strict = state.strict_or(self.strict);
let date = match input.validate_date(strict) {
Ok(val_match) => val_match.unpack(state),
@@ -109,7 +109,7 @@ impl Validator for DateValidator {
/// "exact date", e.g. has a zero time component.
///
/// Ok(None) means that this is not relevant to dates (the input was not a datetime nor a string)
fn date_from_datetime<'data>(input: &'data impl Input<'data>) -> Result<Option<EitherDate<'data>>, ValError<'data>> {
fn date_from_datetime<'data>(input: &'data impl Input<'data>) -> Result<Option<EitherDate<'data>>, ValError> {
let either_dt = match input.validate_datetime(false, speedate::MicrosecondsPrecisionOverflowBehavior::Truncate) {
Ok(val_match) => val_match.into_inner(),
// if the error was a parsing error, update the error type from DatetimeParsing to DateFromDatetimeParsing
61 changes: 55 additions & 6 deletions src/validators/datetime.rs
@@ -2,7 +2,7 @@ use pyo3::intern;
use pyo3::once_cell::GILOnceCell;
use pyo3::prelude::*;
use pyo3::types::{PyDateTime, PyDict, PyString};
use speedate::DateTime;
use speedate::{DateTime, Time};
use std::cmp::Ordering;
use strum::EnumMessage;

@@ -13,6 +13,7 @@ use crate::input::{EitherDateTime, Input};

use crate::tools::SchemaDict;

use super::Exactness;
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};

#[derive(Debug, Clone)]
@@ -63,11 +64,17 @@ impl Validator for DateTimeValidator {
py: Python<'data>,
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
) -> ValResult<PyObject> {
let strict = state.strict_or(self.strict);
let datetime = input
.validate_datetime(strict, self.microseconds_precision)?
.unpack(state);
let datetime = match input.validate_datetime(strict, self.microseconds_precision) {
Ok(val_match) => val_match.unpack(state),
// if the error was a parsing error, in lax mode we allow dates and add the time 00:00:00
Err(line_errors @ ValError::LineErrors(..)) if !strict => {
state.floor_exactness(Exactness::Lax);
datetime_from_date(input)?.ok_or(line_errors)?
}
Err(otherwise) => return Err(otherwise),
};
if let Some(constraints) = &self.constraints {
// if we get an error from as_speedate, it's probably because the input datetime was invalid
// specifically had an invalid tzinfo, hence here we return a validation error
@@ -132,6 +139,48 @@ impl Validator for DateTimeValidator {
}
}

/// In lax mode, if the input is not a datetime, we try parsing the input as a date and add the "00:00:00" time.
///
/// Ok(None) means that this is not relevant to datetimes (the input was not a date nor a string)
fn datetime_from_date<'data>(input: &'data impl Input<'data>) -> Result<Option<EitherDateTime<'data>>, ValError> {
let either_date = match input.validate_date(false) {
Ok(val_match) => val_match.into_inner(),
// if the error was a parsing error, update the error type from DateParsing to DatetimeFromDateParsing
Err(ValError::LineErrors(mut line_errors)) => {
if line_errors.iter_mut().fold(false, |has_parsing_error, line_error| {
if let ErrorType::DateParsing { error, .. } = &mut line_error.error_type {
line_error.error_type = ErrorType::DatetimeFromDateParsing {
error: std::mem::take(error),
context: None,
};
true
} else {
has_parsing_error
}
}) {
return Err(ValError::LineErrors(line_errors));
}
return Ok(None);
}
// for any other error, don't return it
Err(_) => return Ok(None),
};

let zero_time = Time {
hour: 0,
minute: 0,
second: 0,
microsecond: 0,
tz_offset: Some(0),
};

let datetime = DateTime {
date: either_date.as_raw()?,
time: zero_time,
};
Ok(Some(EitherDateTime::Raw(datetime)))
}
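
So in lax mode a bare date, or a date-shaped string, validates as midnight of that day; as written above, the zero time carries `tz_offset: Some(0)`, i.e. UTC midnight. A small illustration using speedate's public types directly (this assumes the `Date::parse_str` constructor and that speedate's Display renders a zero offset as `Z`; it is not part of the diff):

use speedate::{Date, DateTime, Time};

fn demo() {
    // "2020-01-01" parsed as a bare date...
    let date = Date::parse_str("2020-01-01").unwrap();
    // ...combined with the zero time used above
    let dt = DateTime {
        date,
        time: Time { hour: 0, minute: 0, second: 0, microsecond: 0, tz_offset: Some(0) },
    };
    assert_eq!(dt.to_string(), "2020-01-01T00:00:00Z");
}
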

#[derive(Debug, Clone)]
struct DateTimeConstraints {
le: Option<DateTime>,
@@ -263,7 +312,7 @@ impl TZConstraint {
}
}

pub(super) fn tz_check<'d>(&self, tz_offset: Option<i32>, input: &'d impl Input<'d>) -> ValResult<'d, ()> {
pub(super) fn tz_check<'d>(&self, tz_offset: Option<i32>, input: &'d impl Input<'d>) -> ValResult<()> {
match (self, tz_offset) {
(TZConstraint::Aware(_), None) => return Err(ValError::new(ErrorTypeDefaults::TimezoneAware, input)),
(TZConstraint::Aware(Some(tz_expected)), Some(tz_actual)) => {
18 changes: 5 additions & 13 deletions src/validators/decimal.rs
@@ -83,11 +83,7 @@ impl_py_gc_traverse!(DecimalValidator {
gt
});

fn extract_decimal_digits_info<'data>(
decimal: &PyAny,
normalized: bool,
py: Python<'data>,
) -> ValResult<'data, (u64, u64)> {
fn extract_decimal_digits_info(decimal: &PyAny, normalized: bool, py: Python<'_>) -> ValResult<(u64, u64)> {
let mut normalized_decimal: Option<&PyAny> = None;
if normalized {
normalized_decimal = Some(decimal.call_method0(intern!(py, "normalize")).unwrap_or(decimal));
@@ -124,7 +120,7 @@ impl Validator for DecimalValidator {
py: Python<'data>,
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
) -> ValResult<PyObject> {
let decimal = input.validate_decimal(state.strict_or(self.strict), py)?;

if !self.allow_inf_nan || self.check_digits {
@@ -269,11 +265,7 @@ impl Validator for DecimalValidator {
}
}

pub(crate) fn create_decimal<'a>(
arg: &'a PyAny,
input: &'a impl Input<'a>,
py: Python<'a>,
) -> ValResult<'a, &'a PyAny> {
pub(crate) fn create_decimal<'a>(arg: &'a PyAny, input: &'a impl Input<'a>, py: Python<'a>) -> ValResult<&'a PyAny> {
let decimal_type_obj: Py<PyType> = get_decimal_type(py);
decimal_type_obj
.call1(py, (arg,))
@@ -293,10 +285,10 @@ pub(crate) fn create_decimal<'a>(

fn handle_decimal_new_error<'a>(
py: Python<'a>,
input: InputValue<'a>,
input: InputValue,
error: PyErr,
decimal_exception: &'a PyAny,
) -> ValError<'a> {
) -> ValError {
if error.matches(py, decimal_exception) {
ValError::new_custom_input(ErrorTypeDefaults::DecimalParsing, input)
} else if error.matches(py, PyTypeError::type_object(py)) {
53 changes: 21 additions & 32 deletions src/validators/definitions.rs
@@ -6,6 +6,7 @@ use crate::definitions::DefinitionRef;
use crate::errors::{ErrorTypeDefaults, ValError, ValResult};
use crate::input::Input;

use crate::recursion_guard::RecursionGuard;
use crate::tools::SchemaDict;

use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};
@@ -72,24 +73,18 @@ impl Validator for DefinitionRefValidator {
py: Python<'data>,
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
let validator = self.definition.get().unwrap();
if let Some(id) = input.identity() {
if state.recursion_guard.contains_or_insert(id, self.definition.id()) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input))
} else {
if state.recursion_guard.incr_depth() {
) -> ValResult<PyObject> {
self.definition.read(|validator| {
let validator = validator.unwrap();
if let Some(id) = input.identity() {
let Ok(mut guard) = RecursionGuard::new(state, id, self.definition.id()) else {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
}
let output = validator.validate(py, input, state);
state.recursion_guard.remove(id, self.definition.id());
state.recursion_guard.decr_depth();
output
};
validator.validate(py, input, guard.state())
} else {
validator.validate(py, input, state)
}
} else {
validator.validate(py, input, state)
}
})
}

fn validate_assignment<'data>(
Expand All @@ -99,24 +94,18 @@ impl Validator for DefinitionRefValidator {
field_name: &'data str,
field_value: &'data PyAny,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
let validator = self.definition.get().unwrap();
if let Some(id) = obj.identity() {
if state.recursion_guard.contains_or_insert(id, self.definition.id()) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj))
} else {
if state.recursion_guard.incr_depth() {
) -> ValResult<PyObject> {
self.definition.read(|validator| {
let validator = validator.unwrap();
if let Some(id) = obj.identity() {
let Ok(mut guard) = RecursionGuard::new(state, id, self.definition.id()) else {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
}
let output = validator.validate_assignment(py, obj, field_name, field_value, state);
state.recursion_guard.remove(id, self.definition.id());
state.recursion_guard.decr_depth();
output
};
validator.validate_assignment(py, obj, field_name, field_value, guard.state())
} else {
validator.validate_assignment(py, obj, field_name, field_value, state)
}
} else {
validator.validate_assignment(py, obj, field_name, field_value, state)
}
})
}
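
The `let Ok(...) else` acquisition collapses both failure modes of the guard, a detected cycle and the depth limit, into a single `recursion_loop` error. Spelled out as a plain match, it is equivalent to the following sketch, using only the names visible in this hunk:

let mut guard = match RecursionGuard::new(state, id, self.definition.id()) {
    Ok(guard) => guard,
    // cyclic input and exhausted depth are reported identically
    Err(_) => return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)),
};
validator.validate(py, input, guard.state())
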

fn get_name(&self) -> &str {