Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly process flatten fields in enum variants #2567

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 62 additions & 31 deletions serde_derive/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,11 +283,21 @@ fn deserialize_body(cont: &Container, params: &Parameters) -> Fragment {
} else if let attr::Identifier::No = cont.attrs.identifier() {
match &cont.data {
Data::Enum(variants) => deserialize_enum(params, variants, &cont.attrs),
Data::Struct(Style::Struct, fields) => {
deserialize_struct(params, fields, &cont.attrs, StructForm::Struct)
}
Data::Struct(Style::Struct, fields) => deserialize_struct(
params,
fields,
&cont.attrs,
cont.attrs.has_flatten(),
StructForm::Struct,
),
Data::Struct(Style::Tuple, fields) | Data::Struct(Style::Newtype, fields) => {
deserialize_tuple(params, fields, &cont.attrs, TupleForm::Tuple)
deserialize_tuple(
params,
fields,
&cont.attrs,
cont.attrs.has_flatten(),
TupleForm::Tuple,
)
}
Data::Struct(Style::Unit, _) => deserialize_unit_struct(params, &cont.attrs),
}
Expand Down Expand Up @@ -466,9 +476,13 @@ fn deserialize_tuple(
params: &Parameters,
fields: &[Field],
cattrs: &attr::Container,
has_flatten: bool,
form: TupleForm,
) -> Fragment {
assert!(!cattrs.has_flatten());
assert!(
!has_flatten,
"tuples and tuple variants cannot have flatten fields"
);

let field_count = fields
.iter()
Expand Down Expand Up @@ -586,7 +600,10 @@ fn deserialize_tuple_in_place(
fields: &[Field],
cattrs: &attr::Container,
) -> Fragment {
assert!(!cattrs.has_flatten());
assert!(
!cattrs.has_flatten(),
"tuples and tuple variants cannot have flatten fields"
);

let field_count = fields
.iter()
Expand Down Expand Up @@ -917,6 +934,7 @@ fn deserialize_struct(
params: &Parameters,
fields: &[Field],
cattrs: &attr::Container,
has_flatten: bool,
form: StructForm,
) -> Fragment {
let this_type = &params.this_type;
Expand Down Expand Up @@ -965,13 +983,13 @@ fn deserialize_struct(
)
})
.collect();
let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs);
let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs, has_flatten);

// untagged struct variants do not get a visit_seq method. The same applies to
// structs that only have a map representation.
let visit_seq = match form {
StructForm::Untagged(..) => None,
_ if cattrs.has_flatten() => None,
_ if has_flatten => None,
_ => {
let mut_seq = if field_names_idents.is_empty() {
quote!(_)
Expand All @@ -994,10 +1012,16 @@ fn deserialize_struct(
})
}
};
let visit_map = Stmts(deserialize_map(&type_path, params, fields, cattrs));
let visit_map = Stmts(deserialize_map(
&type_path,
params,
fields,
cattrs,
has_flatten,
));

let visitor_seed = match form {
StructForm::ExternallyTagged(..) if cattrs.has_flatten() => Some(quote! {
StructForm::ExternallyTagged(..) if has_flatten => Some(quote! {
impl #de_impl_generics _serde::de::DeserializeSeed<#delife> for __Visitor #de_ty_generics #where_clause {
type Value = #this_type #ty_generics;

Expand All @@ -1012,7 +1036,7 @@ fn deserialize_struct(
_ => None,
};

let fields_stmt = if cattrs.has_flatten() {
let fields_stmt = if has_flatten {
None
} else {
let field_names = field_names_idents
Expand All @@ -1032,7 +1056,7 @@ fn deserialize_struct(
}
};
let dispatch = match form {
StructForm::Struct if cattrs.has_flatten() => quote! {
StructForm::Struct if has_flatten => quote! {
_serde::Deserializer::deserialize_map(__deserializer, #visitor_expr)
},
StructForm::Struct => {
Expand All @@ -1041,7 +1065,7 @@ fn deserialize_struct(
_serde::Deserializer::deserialize_struct(__deserializer, #type_name, FIELDS, #visitor_expr)
}
}
StructForm::ExternallyTagged(_) if cattrs.has_flatten() => quote! {
StructForm::ExternallyTagged(_) if has_flatten => quote! {
_serde::de::VariantAccess::newtype_variant_seed(__variant, #visitor_expr)
},
StructForm::ExternallyTagged(_) => quote! {
Expand Down Expand Up @@ -1123,7 +1147,7 @@ fn deserialize_struct_in_place(
})
.collect();

let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs);
let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs, false);

let mut_seq = if field_names_idents.is_empty() {
quote!(_)
Expand Down Expand Up @@ -1217,10 +1241,7 @@ fn deserialize_homogeneous_enum(
}
}

fn prepare_enum_variant_enum(
variants: &[Variant],
cattrs: &attr::Container,
) -> (TokenStream, Stmts) {
fn prepare_enum_variant_enum(variants: &[Variant]) -> (TokenStream, Stmts) {
let mut deserialized_variants = variants
.iter()
.enumerate()
Expand Down Expand Up @@ -1254,7 +1275,7 @@ fn prepare_enum_variant_enum(

let variant_visitor = Stmts(deserialize_generated_identifier(
&variant_names_idents,
cattrs,
false, // variant identifiers does not depend on the presence of flatten fields
true,
None,
fallthrough,
Expand All @@ -1277,7 +1298,7 @@ fn deserialize_externally_tagged_enum(
let expecting = format!("enum {}", params.type_name());
let expecting = cattrs.expecting().unwrap_or(&expecting);

let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants, cattrs);
let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants);

// Match arms to extract a variant from a string
let variant_arms = variants
Expand Down Expand Up @@ -1362,7 +1383,7 @@ fn deserialize_internally_tagged_enum(
cattrs: &attr::Container,
tag: &str,
) -> Fragment {
let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants, cattrs);
let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants);

// Match arms to extract a variant from a string
let variant_arms = variants
Expand Down Expand Up @@ -1416,7 +1437,7 @@ fn deserialize_adjacently_tagged_enum(
split_with_de_lifetime(params);
let delife = params.borrowed.de_lifetime();

let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants, cattrs);
let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants);

let variant_arms: &Vec<_> = &variants
.iter()
Expand Down Expand Up @@ -1805,12 +1826,14 @@ fn deserialize_externally_tagged_variant(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
TupleForm::ExternallyTagged(variant_ident),
),
Style::Struct => deserialize_struct(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
StructForm::ExternallyTagged(variant_ident),
),
}
Expand Down Expand Up @@ -1854,6 +1877,7 @@ fn deserialize_internally_tagged_variant(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
StructForm::InternallyTagged(variant_ident, deserializer),
),
Style::Tuple => unreachable!("checked in serde_derive_internals"),
Expand Down Expand Up @@ -1904,12 +1928,14 @@ fn deserialize_untagged_variant(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
TupleForm::Untagged(variant_ident, deserializer),
),
Style::Struct => deserialize_struct(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
StructForm::Untagged(variant_ident, deserializer),
),
}
Expand Down Expand Up @@ -1980,7 +2006,7 @@ fn deserialize_untagged_newtype_variant(

fn deserialize_generated_identifier(
fields: &[(&str, Ident, &BTreeSet<String>)],
cattrs: &attr::Container,
has_flatten: bool,
is_variant: bool,
ignore_variant: Option<TokenStream>,
fallthrough: Option<TokenStream>,
Expand All @@ -1994,11 +2020,11 @@ fn deserialize_generated_identifier(
is_variant,
fallthrough,
None,
!is_variant && cattrs.has_flatten(),
!is_variant && has_flatten,
None,
));

let lifetime = if !is_variant && cattrs.has_flatten() {
let lifetime = if !is_variant && has_flatten {
Some(quote!(<'de>))
} else {
None
Expand Down Expand Up @@ -2038,8 +2064,9 @@ fn deserialize_generated_identifier(
fn deserialize_field_identifier(
fields: &[(&str, Ident, &BTreeSet<String>)],
cattrs: &attr::Container,
has_flatten: bool,
) -> Stmts {
let (ignore_variant, fallthrough) = if cattrs.has_flatten() {
let (ignore_variant, fallthrough) = if has_flatten {
let ignore_variant = quote!(__other(_serde::__private::de::Content<'de>),);
let fallthrough = quote!(_serde::__private::Ok(__Field::__other(__value)));
(Some(ignore_variant), Some(fallthrough))
Expand All @@ -2053,7 +2080,7 @@ fn deserialize_field_identifier(

Stmts(deserialize_generated_identifier(
fields,
cattrs,
has_flatten,
false,
ignore_variant,
fallthrough,
Expand Down Expand Up @@ -2455,6 +2482,7 @@ fn deserialize_map(
params: &Parameters,
fields: &[Field],
cattrs: &attr::Container,
has_flatten: bool,
) -> Fragment {
// Create the field names for the fields.
let fields_names: Vec<_> = fields
Expand All @@ -2475,7 +2503,7 @@ fn deserialize_map(
});

// Collect contents for flatten fields into a buffer
let let_collect = if cattrs.has_flatten() {
let let_collect = if has_flatten {
Some(quote! {
let mut __collect = _serde::__private::Vec::<_serde::__private::Option<(
_serde::__private::de::Content,
Expand Down Expand Up @@ -2527,7 +2555,7 @@ fn deserialize_map(
});

// Visit ignored values to consume them
let ignored_arm = if cattrs.has_flatten() {
let ignored_arm = if has_flatten {
Some(quote! {
__Field::__other(__name) => {
__collect.push(_serde::__private::Some((
Expand Down Expand Up @@ -2597,7 +2625,7 @@ fn deserialize_map(
}
});

let collected_deny_unknown_fields = if cattrs.has_flatten() && cattrs.deny_unknown_fields() {
let collected_deny_unknown_fields = if has_flatten && cattrs.deny_unknown_fields() {
Some(quote! {
if let _serde::__private::Some(_serde::__private::Some((__key, _))) =
__collect.into_iter().filter(_serde::__private::Option::is_some).next()
Expand Down Expand Up @@ -2673,7 +2701,10 @@ fn deserialize_map_in_place(
fields: &[Field],
cattrs: &attr::Container,
) -> Fragment {
assert!(!cattrs.has_flatten());
assert!(
!cattrs.has_flatten(),
"inplace deserialization of maps doesn't support flatten fields"
);

// Create the field names for the fields.
let fields_names: Vec<_> = fields
Expand Down
1 change: 1 addition & 0 deletions serde_derive/src/internals/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ impl<'a> Container<'a> {
for field in &mut variant.fields {
if field.attrs.flatten() {
has_flatten = true;
variant.attrs.mark_has_flatten();
}
field.attrs.rename_by_rules(
variant
Expand Down
37 changes: 37 additions & 0 deletions serde_derive/src/internals/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,22 @@ pub struct Container {
type_into: Option<syn::Type>,
remote: Option<syn::Path>,
identifier: Identifier,
/// `true` if container is a `struct` and it has a field with `#[serde(flatten)]`
/// attribute or it is an `enum` with a struct variant which has a field with
/// `#[serde(flatten)]` attribute. Examples:
///
/// ```ignore
/// struct Container {
/// #[serde(flatten)]
/// some_field: (),
/// }
/// enum Container {
/// Variant {
/// #[serde(flatten)]
/// some_field: (),
/// },
/// }
/// ```
has_flatten: bool,
serde_path: Option<syn::Path>,
is_packed: bool,
Expand Down Expand Up @@ -785,6 +801,18 @@ pub struct Variant {
rename_all_rules: RenameAllRules,
ser_bound: Option<Vec<syn::WherePredicate>>,
de_bound: Option<Vec<syn::WherePredicate>>,
/// `true` if variant is a struct variant which contains a field with `#[serde(flatten)]`
/// attribute. Examples:
///
/// ```ignore
/// enum Enum {
/// Variant {
/// #[serde(flatten)]
/// some_field: (),
/// },
/// }
/// ```
has_flatten: bool,
skip_deserializing: bool,
skip_serializing: bool,
other: bool,
Expand Down Expand Up @@ -954,6 +982,7 @@ impl Variant {
},
ser_bound: ser_bound.get(),
de_bound: de_bound.get(),
has_flatten: false,
skip_deserializing: skip_deserializing.get(),
skip_serializing: skip_serializing.get(),
other: other.get(),
Expand Down Expand Up @@ -996,6 +1025,14 @@ impl Variant {
self.de_bound.as_ref().map(|vec| &vec[..])
}

pub fn has_flatten(&self) -> bool {
self.has_flatten
}

pub fn mark_has_flatten(&mut self) {
self.has_flatten = true;
}

pub fn skip_deserializing(&self) -> bool {
self.skip_deserializing
}
Expand Down
Loading