Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Computed field serialization for TypedDict #1018

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
75 changes: 68 additions & 7 deletions src/serializers/computed_fields.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use pyo3::exceptions::PyAttributeError;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyString};
use pyo3::{intern, PyTraverseError, PyVisit};
Expand All @@ -12,6 +13,8 @@ use crate::serializers::shared::{BuildSerializer, CombinedSerializer, PydanticSe
use crate::tools::SchemaDict;

use super::errors::py_err_se_err;
use super::errors::PydanticSerializationError;
use super::ob_type::{ObType, ObTypeLookup};
use super::Extra;

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -52,9 +55,10 @@ impl ComputedFields {
// Do not serialize computed fields
return Ok(());
}
for computed_fields in &self.0 {
computed_fields.to_python(model, output_dict, filter, include, exclude, extra)?;
for computed_field in &self.0 {
computed_field.to_python(model, output_dict, filter, include, exclude, extra)?;
}

Ok(())
}

Expand Down Expand Up @@ -102,6 +106,7 @@ struct ComputedField {
serializer: CombinedSerializer,
alias: String,
alias_py: Py<PyString>,
has_ser_func: bool,
}

impl ComputedField {
Expand All @@ -123,6 +128,7 @@ impl ComputedField {
serializer,
alias: alias_py.extract()?,
alias_py: alias_py.into_py(py),
has_ser_func: has_ser_function(return_schema),
})
}

Expand All @@ -139,8 +145,7 @@ impl ComputedField {
let property_name_py = self.property_name_py.as_ref(py);

if let Some((next_include, next_exclude)) = filter.key_filter(property_name_py, include, exclude)? {
let next_value = model.getattr(property_name_py)?;

let next_value = get_next_value(self, model, extra.ob_type_lookup)?;
let value = self
.serializer
.to_python(next_value, next_include, next_exclude, extra)?;
Expand Down Expand Up @@ -177,9 +182,8 @@ impl_py_gc_traverse!(ComputedFieldSerializer<'_> { computed_field });

impl<'py> Serialize for ComputedFieldSerializer<'py> {
fn serialize<S: serde::ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let py = self.model.py();
let property_name_py = self.computed_field.property_name_py.as_ref(py);
let next_value = self.model.getattr(property_name_py).map_err(py_err_se_err)?;
let next_value =
get_next_value(self.computed_field, self.model, self.extra.ob_type_lookup).map_err(py_err_se_err)?;
let s = PydanticSerializer::new(
next_value,
&self.computed_field.serializer,
Expand All @@ -190,3 +194,60 @@ impl<'py> Serialize for ComputedFieldSerializer<'py> {
s.serialize(serializer)
}
}

fn has_ser_function(schema: &PyDict) -> bool {
let py = schema.py();
let ser_schema = schema
.get_as::<&PyDict>(intern!(py, "serialization"))
.unwrap_or_default();
ser_schema.is_some_and(|s| s.contains(intern!(py, "function")).unwrap_or_default())
}

fn get_next_value<'a>(
field: &'a ComputedField,
input_value: &'a PyAny,
ob_type_lookup: &'a ObTypeLookup,
) -> PyResult<&'a PyAny> {
let py = input_value.py();
// Backwards compatiability.
let mut legacy_attr_error: Option<PyErr> = None;
let legacy_result = match ob_type_lookup.get_type(input_value) {
Comment on lines +212 to +214
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's fair to call this "backwards compatibility" when this is still expected to be the main code path for models and dataclasses.

I wonder if there might be a more unified way. For a model or dataclass A with computed field b, the analogous functionality really seems to be A.b.__get__(instance). For a TypedDict it looks like that also works:

>>> class Bar(TypedDict):
...      @property
...      def y(self):
...          return 434
...
>>> Bar.y.__get__({})
434

So maybe what we really want, in all cases, is

let property_value = input_value.get_type().getattr(field.property_name_py.as_ref(py))?.call_method1("__get__", (input_value,))?;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My thinking was that a serialization function should be provided to a computed field, doesn't matter if the input type is a Model, TypedDict, Dataclass, etc.

The default behavior is to compute the computed value from the function provided in the serialization schema and then it gets set in the output_dict:

let value = self
.serializer
.to_python(next_value, next_include, next_exclude, extra)?;
if extra.exclude_none && value.is_none(py) {
return Ok(());
}
let key = match extra.by_alias {
true => self.alias_py.as_ref(py),
false => property_name_py,
};
output_dict.set_item(key, value)?;

This seems more generalizable to all computed fields instead of relying on the computed field defined as an attribute on the input value.

However, I am probably missing some context on how computed fields are used by https://github.com/pydantic/pydantic.

Copy link
Member

@adriangb adriangb Oct 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not following all that closely but my 2c is that ideally we extract the function from the thing in pydantic and not in pydantic-core so that:

  1. We have more flexibility. It's easier to hack things (like rebuild __mro__ based on __orig_bases__ which we do for TypedDict)
  2. It ensures that we do this at schema build time and not runtime

The con of that last one is that in theory someone could want us to use the method on a subclass they pass in as a value, which doesn't apply to TypedDict but also is not what we do for BaseModel and no one has complained 😄

ObType::Dataclass | ObType::PydanticSerializable => {
match input_value.getattr(field.property_name_py.as_ref(py)) {
Ok(attr) => Ok(Some(attr)),
Err(err) => {
if err.get_type(py).is_subclass_of::<PyAttributeError>()? {
legacy_attr_error = Some(err);
Ok(None)
} else {
Err(err)
}
}
}
}
_ => Ok(None),
};
match legacy_result {
Ok(opt) => {
if let Some(legacy_next_value) = opt {
return Ok(legacy_next_value);
}
}
Err(err) => return Err(err),
};

// Default behavior: If custom serialization function provided, compute value based on input.
if field.has_ser_func {
return Ok(input_value);
}
// Fallback behavior: Check if computed field is a property of input object
// (i.e. in some cases input_value can be ObType::Unknown)
if let Ok(next_value_from_input) = input_value.getattr(field.property_name_py.as_ref(py)) {
return Ok(next_value_from_input);
}

Err(legacy_attr_error.unwrap_or(PydanticSerializationError::new_err(format!(
"No serialization function found for '{}'",
field.property_name
))))
}
110 changes: 108 additions & 2 deletions tests/serializers/test_typed_dict.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import json
from typing import Any, Dict
from typing import Any, Dict, List

import pytest
from dirty_equals import IsStrictDict
from typing_extensions import TypedDict

from pydantic_core import SchemaSerializer, core_schema
from pydantic_core import PydanticSerializationError, SchemaSerializer, core_schema


@pytest.mark.parametrize('extra_behavior_kw', [{}, {'extra_behavior': 'ignore'}, {'extra_behavior': None}])
Expand Down Expand Up @@ -333,3 +333,109 @@ def test_extra_custom_serializer():
m = {'extra': 'extra'}

assert s.to_python(m) == {'extra': 'extra bam!'}


def test_computed_fields_with_plain_serializer_function():
def ser_x(v: dict):
two = v['0'] + v['1'] + 1
return two

schema = core_schema.typed_dict_schema(
{
'0': core_schema.typed_dict_field(core_schema.int_schema()),
'1': core_schema.typed_dict_field(core_schema.int_schema()),
},
computed_fields=[
core_schema.computed_field(
'2', core_schema.int_schema(serialization=core_schema.plain_serializer_function_ser_schema(ser_x))
)
],
)
s = SchemaSerializer(schema)
value = {'0': 0, '1': 1}
assert s.to_python(value) == {'0': 0, '1': 1, '2': 2}
assert s.to_json(value) == b'{"0":0,"1":1,"2":2}'

def ser_foo(_v: dict):
return 'bar'

schema = core_schema.typed_dict_schema(
{},
computed_fields=[
core_schema.computed_field(
'foo', core_schema.str_schema(serialization=core_schema.plain_serializer_function_ser_schema(ser_foo))
)
],
)
s = SchemaSerializer(schema)
assert s.to_python({}) == {'foo': 'bar'}
assert s.to_json({}) == b'{"foo":"bar"}'


def test_computed_fields_with_warpped_serializer_function():
def ser_to_upper(string_arr: List[str]) -> List[str]:
return [s.upper() for s in string_arr]

def ser_columns(v: dict, serializer: core_schema.SerializerFunctionWrapHandler, _) -> str:
column_keys = serializer([key for key in v.keys()])
return column_keys

schema = core_schema.typed_dict_schema(
{
'one': core_schema.typed_dict_field(core_schema.int_schema()),
'two': core_schema.typed_dict_field(core_schema.int_schema()),
'three': core_schema.typed_dict_field(core_schema.int_schema()),
},
computed_fields=[
core_schema.computed_field(
'columns',
core_schema.int_schema(
serialization=core_schema.wrap_serializer_function_ser_schema(
ser_columns,
info_arg=True,
schema=core_schema.list_schema(
serialization=core_schema.plain_serializer_function_ser_schema(ser_to_upper)
),
)
),
)
],
)
s = SchemaSerializer(schema)
value = {'one': 1, 'two': 2, 'three': 3}
assert s.to_python(value) == {'one': 1, 'two': 2, 'three': 3, 'columns': ['ONE', 'TWO', 'THREE']}
assert s.to_json(value) == b'{"one":1,"two":2,"three":3,"columns":["ONE","TWO","THREE"]}'


def test_computed_fields_with_typed_dict_model():
class Model(TypedDict):
x: int

def ser_y(v: Any) -> str:
return f'{v["x"]}.00'

s = SchemaSerializer(
core_schema.typed_dict_schema(
{'x': core_schema.typed_dict_field(core_schema.int_schema())},
computed_fields=[
core_schema.computed_field(
'y', core_schema.str_schema(serialization=core_schema.plain_serializer_function_ser_schema(ser_y))
)
],
)
)
assert s.to_python(Model(x=1000)) == {'x': 1000, 'y': '1000.00'}


def test_computed_fields_without_ser_function():
class Model(TypedDict):
x: int

s = SchemaSerializer(
core_schema.typed_dict_schema(
{'x': core_schema.typed_dict_field(core_schema.int_schema())},
computed_fields=[core_schema.computed_field('y', core_schema.str_schema())],
)
)
with pytest.raises(PydanticSerializationError, match="^No serialization function found for 'y'$"):
s.to_python(Model(x=1000))