diff --git a/components/salsa-macros/src/interned.rs b/components/salsa-macros/src/interned.rs index 8b72b3510..6a0748dc2 100644 --- a/components/salsa-macros/src/interned.rs +++ b/components/salsa-macros/src/interned.rs @@ -1,4 +1,8 @@ use proc_macro2::TokenStream; +use quote::ToTokens; +use syn::spanned::Spanned; +use syn::visit::Visit; +use syn::visit_mut::VisitMut; use crate::hygiene::Hygiene; use crate::options::{AllowedOptions, AllowedPersistOptions, Options}; @@ -16,13 +20,31 @@ pub(crate) fn interned( ) -> proc_macro::TokenStream { let args = syn::parse_macro_input!(args as InternedArgs); let hygiene = Hygiene::from1(&input); - let struct_item = parse_macro_input!(input as syn::ItemStruct); - let m = Macro { - hygiene, - args, - struct_item, + let item = match syn::parse::(input.clone()) { + Ok(item) => item, + Err(err) => return token_stream_with_error(input, err), }; - match m.try_macro() { + + let lowered = match item { + syn::Item::Struct(struct_item) => Ok(InternedInput::from_struct(struct_item)), + syn::Item::Enum(enum_item) => InternedInput::from_enum(enum_item, &args), + other => Err(syn::Error::new( + other.span(), + "interned can only be applied to structs and enums", + )), + }; + + match lowered.and_then(|input| { + Macro { + hygiene, + args, + struct_item: input.struct_item, + struct_data_ident: input.struct_data_ident, + skip_conflict_rename: input.skip_conflict_rename, + additional_items: input.additional_items, + } + .try_macro() + }) { Ok(v) => v.into(), Err(e) => token_stream_with_error(input, e), } @@ -90,6 +112,9 @@ struct Macro { hygiene: Hygiene, args: InternedArgs, struct_item: syn::ItemStruct, + struct_data_ident: syn::Ident, + skip_conflict_rename: bool, + additional_items: Vec, } impl Macro { @@ -100,7 +125,6 @@ impl Macro { let attrs = &self.struct_item.attrs; let vis = &self.struct_item.vis; let struct_ident = &self.struct_item.ident; - let struct_data_ident = format_ident!("{}Data", struct_ident); let db_lt = db_lifetime::db_lifetime(&self.struct_item.generics); let new_fn = salsa_struct.constructor_name(); let field_ids = salsa_struct.field_ids(); @@ -116,6 +140,13 @@ impl Macro { let has_lifetime = salsa_struct.generate_lifetime(); let id = salsa_struct.id(); let revisions = salsa_struct.revisions(); + let mut struct_data_ident = self.struct_data_ident.clone(); + if !self.skip_conflict_rename + && self.args.data.is_none() + && struct_data_ident_conflicts(&struct_data_ident, &field_tys) + { + struct_data_ident = self.hygiene.scoped_ident(struct_ident, "Fields"); + } let (db_lt_arg, cfg, interior_lt) = if has_lifetime { ( @@ -144,10 +175,12 @@ impl Macro { let Configuration = self.hygiene.ident("Configuration"); let CACHE = self.hygiene.ident("CACHE"); let Db = self.hygiene.ident("Db"); + let additional_items = &self.additional_items; Ok(crate::debug::dump_tokens( struct_ident, quote! { + #(#additional_items)* salsa::plumbing::setup_interned_struct!( attrs: [#(#attrs),*], vis: #vis, @@ -185,3 +218,127 @@ impl Macro { )) } } + +struct InternedInput { + struct_item: syn::ItemStruct, + struct_data_ident: syn::Ident, + skip_conflict_rename: bool, + additional_items: Vec, +} + +impl InternedInput { + fn from_struct(struct_item: syn::ItemStruct) -> Self { + let struct_data_ident = format_ident!("{}Data", struct_item.ident); + Self { + struct_item, + struct_data_ident, + skip_conflict_rename: false, + additional_items: Vec::new(), + } + } + + fn from_enum(enum_item: syn::ItemEnum, args: &InternedArgs) -> syn::Result { + let struct_ident = enum_item.ident.clone(); + if let Some(data_ident) = args.data.clone() { + if data_ident == struct_ident { + return Err(syn::Error::new( + data_ident.span(), + "data name conflicts with a generated identifier; please choose a different `data` name", + )); + } + } + + let data_ident = args + .data + .clone() + .unwrap_or_else(|| format_ident!("{}Data", struct_ident)); + + let mut data_enum = enum_item; + rename_type_idents(&mut data_enum, &struct_ident, &data_ident); + data_enum.ident = data_ident.clone(); + + let generics = data_enum.generics.clone(); + let (_, ty_generics, _) = generics.split_for_impl(); + let field_ty: syn::Type = parse_quote!(#data_ident #ty_generics); + let struct_attrs = data_enum + .attrs + .iter() + .filter(|attr| !attr.path().is_ident("derive")) + .cloned() + .collect::>(); + let struct_vis = data_enum.vis.clone(); + let struct_item: syn::ItemStruct = parse_quote! { + #(#struct_attrs)* + #struct_vis struct #struct_ident #generics { + value: #field_ty, + } + }; + + Ok(Self { + struct_item, + // Use a distinct alias for the macro-internal tuple to avoid cycling with the data enum. + struct_data_ident: format_ident!("{}Fields", struct_ident), + // Allow conflict renaming to kick in if this identifier is already in use. + skip_conflict_rename: false, + additional_items: vec![data_enum.into_token_stream()], + }) + } +} + +fn rename_type_idents(enum_item: &mut syn::ItemEnum, from: &syn::Ident, to: &syn::Ident) { + struct Renamer<'a> { + from: &'a syn::Ident, + to: &'a syn::Ident, + } + + impl syn::visit_mut::VisitMut for Renamer<'_> { + fn visit_type_path_mut(&mut self, node: &mut syn::TypePath) { + if node.qself.is_none() + && node.path.leading_colon.is_none() + && node.path.segments.len() == 1 + && node.path.segments.first().map(|s| &s.ident) == Some(self.from) + { + node.path.segments[0].ident = self.to.clone(); + } + syn::visit_mut::visit_type_path_mut(self, node); + } + } + + let mut renamer = Renamer { from, to }; + renamer.visit_item_enum_mut(enum_item); +} + +fn struct_data_ident_conflicts(ident: &syn::Ident, field_tys: &[&syn::Type]) -> bool { + field_tys + .iter() + .copied() + .any(|ty| type_contains_ident(ty, ident)) +} + +fn type_contains_ident(ty: &syn::Type, ident: &syn::Ident) -> bool { + struct Finder<'a> { + ident: &'a syn::Ident, + found: bool, + } + + impl<'ast> syn::visit::Visit<'ast> for Finder<'_> { + fn visit_type_path(&mut self, node: &'ast syn::TypePath) { + if node.qself.is_none() + && node.path.leading_colon.is_none() + && node.path.segments.len() == 1 + && node.path.segments.first().map(|s| &s.ident) == Some(self.ident) + { + self.found = true; + } + syn::visit::visit_type_path(self, node); + } + } + + let mut finder = Finder { + ident, + found: false, + }; + + finder.visit_type(ty); + finder.found +} diff --git a/tests/compile-fail/interned-enum_data_name_conflict.rs b/tests/compile-fail/interned-enum_data_name_conflict.rs new file mode 100644 index 000000000..5ccc3eb88 --- /dev/null +++ b/tests/compile-fail/interned-enum_data_name_conflict.rs @@ -0,0 +1,9 @@ +//@compile-fail +#![deny(warnings)] + +#[salsa::interned(data = ConflictingData)] +enum ConflictingData<'db> { + Variant(&'db ()), +} + +fn main() {} diff --git a/tests/compile-fail/interned-enum_data_name_conflict.stderr b/tests/compile-fail/interned-enum_data_name_conflict.stderr new file mode 100644 index 000000000..515de54bd --- /dev/null +++ b/tests/compile-fail/interned-enum_data_name_conflict.stderr @@ -0,0 +1,5 @@ +error: data name conflicts with a generated identifier; please choose a different `data` name + --> tests/compile-fail/interned-enum_data_name_conflict.rs:4:26 + | +4 | #[salsa::interned(data = ConflictingData)] + | ^^^^^^^^^^^^^^^ diff --git a/tests/interned-enum.rs b/tests/interned-enum.rs new file mode 100644 index 000000000..d02f3d912 --- /dev/null +++ b/tests/interned-enum.rs @@ -0,0 +1,82 @@ +#![cfg(feature = "inventory")] + +#[salsa::interned(debug)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[allow(dead_code)] +enum InternedEnum<'db> { + Unit, + Tuple(u8, u8), + Wrap(Box), + Ref(&'db ()), +} + +#[salsa::interned(debug, data = CustomPayload)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +enum CustomDataEnum { + One(u32), + Two(String), +} + +#[salsa::interned(no_lifetime, debug)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +enum NoLifetimeInterned { + Item(&'static str), +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +struct CollisionData<'db>(std::marker::PhantomData<&'db ()>); + +// Field type intentionally matches the auto data name (`CollisionData`) to +// exercise the hygiene fallback for generated data identifiers. +#[salsa::interned(debug)] +struct Collision<'db> { + payload: CollisionData<'db>, +} + +#[test] +fn supports_enums() { + let db = salsa::DatabaseImpl::new(); + + let unit1 = InternedEnum::new(&db, InternedEnumData::Unit); + let unit2 = InternedEnum::new(&db, InternedEnumData::Unit); + assert_eq!(unit1, unit2); + + let wrapped = InternedEnum::new( + &db, + InternedEnumData::Wrap(Box::new(InternedEnumData::Tuple(1, 2))), + ); + assert_eq!( + wrapped.value(&db), + InternedEnumData::Wrap(Box::new(InternedEnumData::Tuple(1, 2))) + ); +} + +#[test] +fn respects_custom_data_name() { + let db = salsa::DatabaseImpl::new(); + + let v = CustomDataEnum::new(&db, CustomPayload::Two("hi".into())); + assert_eq!(v.value(&db), CustomPayload::Two("hi".into())); + + let v2 = CustomDataEnum::new(&db, CustomPayload::One(1)); + assert_eq!(v2.value(&db), CustomPayload::One(1)); +} + +#[test] +fn supports_no_lifetime_enum() { + let db = salsa::DatabaseImpl::new(); + + let v = NoLifetimeInterned::new(&db, NoLifetimeInternedData::Item("static")); + assert_eq!(v.value(&db), NoLifetimeInternedData::Item("static")); +} + +#[test] +fn auto_data_name_conflict_is_renamed() { + let db = salsa::DatabaseImpl::new(); + + let collision = Collision::new(&db, CollisionData(std::marker::PhantomData)); + assert_eq!( + collision.payload(&db), + CollisionData(std::marker::PhantomData) + ); +}