From 0cdcd9ab81609c42c8dc58dc6b4335058c5759f6 Mon Sep 17 00:00:00 2001 From: Vincent Esche Date: Thu, 22 May 2025 17:01:02 +0200 Subject: [PATCH] Make `VariantDiscriminant` derive macro support variants with non-`'static` lifetimes --- CHANGELOG.md | 2 +- macros/Cargo.toml | 2 +- macros/src/enum_deriver.rs | 26 ++- macros/src/lib.rs | 1 + macros/src/type_visitor.rs | 45 ++++- macros/src/type_visitor_mut.rs | 30 +++ .../pass/enum/lifetimes.out.rs | 187 ++++++++++++++++++ .../pass/enum/lifetimes.rs | 27 +++ 8 files changed, 305 insertions(+), 15 deletions(-) create mode 100644 macros/src/type_visitor_mut.rs create mode 100644 tests/derive-tests/variant_discriminant/pass/enum/lifetimes.out.rs create mode 100644 tests/derive-tests/variant_discriminant/pass/enum/lifetimes.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ac2a7b..149f394 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,7 @@ Please make sure to add your changes to the appropriate categories: ### Added -- n/a +- Made `VariantDiscriminant` derive macro support variants with non-`'static` lifetimes. ### Changed diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 43b7a94..75b9be9 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -20,4 +20,4 @@ proc-macro = true [dependencies] proc-macro2 = { version = "1.0.81", features = ["span-locations"] } quote = "1.0.36" -syn = { version = "2.0.60", features = ["full", "visit"] } +syn = { version = "2.0.60", features = ["full", "visit", "visit-mut"] } diff --git a/macros/src/enum_deriver.rs b/macros/src/enum_deriver.rs index 66914fa..9f6430e 100644 --- a/macros/src/enum_deriver.rs +++ b/macros/src/enum_deriver.rs @@ -1,10 +1,11 @@ use proc_macro2::TokenStream as TokenStream2; use quote::quote; use syn::{ - parse_quote, parse_quote_spanned, spanned::Spanned, visit::Visit as _, Fields, Type, Variant, + parse_quote, parse_quote_spanned, spanned::Spanned, visit::Visit as _, visit_mut::VisitMut, + Fields, Type, Variant, }; -use crate::*; +use crate::{type_visitor_mut::TypeVisitorMut, *}; pub(crate) struct EnumDeriver { item: syn::ItemEnum, @@ -585,15 +586,23 @@ impl EnumDeriver { match nested { NestedDiscriminantType::Default => { let (field, _) = field_selection.expect("no selected field found"); - let field_type = &field.ty; - if self.uses_generic_const_or_type(field_type) { + let mut visitor = TypeVisitor::new(&self.item.generics); + visitor.visit_type(&field.ty); + + if visitor.type_uses_const_or_type_param() { return Err(syn::Error::new( field.span(), "generic fields require an explicit nested discriminant type", )); } + let field_type = if visitor.type_uses_lifetime_param() { + self.type_replacing_lifetimes_with_static(&field.ty) + } else { + field.ty.clone() + }; + let nested_type = parse_quote! { <#field_type as ::enumcapsulate::VariantDiscriminant>::Discriminant }; @@ -741,11 +750,18 @@ impl EnumDeriver { } } + fn type_replacing_lifetimes_with_static(&self, ty: &syn::Type) -> syn::Type { + let mut ty = ty.clone(); + let mut visitor = TypeVisitorMut::default().replace_lifetimes_with_static(); + visitor.visit_type_mut(&mut ty); + ty + } + fn uses_generic_const_or_type(&self, ty: &syn::Type) -> bool { let mut visitor = TypeVisitor::new(&self.item.generics); visitor.visit_type(ty); - visitor.uses_const_or_type_param() + visitor.type_uses_const_or_type_param() } } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index c8bb028..c40d4e3 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -6,6 +6,7 @@ use crate::utils::tokenstream; mod config; mod enum_deriver; mod type_visitor; +mod type_visitor_mut; mod utils; use self::{config::*, enum_deriver::*, type_visitor::*, utils::*}; diff --git a/macros/src/type_visitor.rs b/macros/src/type_visitor.rs index c7e77a3..7476187 100644 --- a/macros/src/type_visitor.rs +++ b/macros/src/type_visitor.rs @@ -3,25 +3,48 @@ use std::collections::HashSet; use syn::visit::Visit; pub struct TypeVisitor<'ast> { + lifetime_param_idents: HashSet<&'ast syn::Ident>, const_param_idents: HashSet<&'ast syn::Ident>, type_param_idents: HashSet<&'ast syn::Ident>, - uses_const_param: bool, - uses_type_param: bool, + type_uses_lifetime_param: bool, + type_uses_const_param: bool, + type_uses_type_param: bool, } impl<'ast> TypeVisitor<'ast> { pub fn new(generics: &'ast syn::Generics) -> Self { Self { + lifetime_param_idents: generics + .lifetimes() + .map(|param| ¶m.lifetime.ident) + .collect(), const_param_idents: generics.const_params().map(|param| ¶m.ident).collect(), type_param_idents: generics.type_params().map(|param| ¶m.ident).collect(), - uses_const_param: false, - uses_type_param: false, + type_uses_lifetime_param: false, + type_uses_const_param: false, + type_uses_type_param: false, } } - pub fn uses_const_or_type_param(self) -> bool { - self.uses_const_param || self.uses_type_param + #[allow(dead_code)] + pub fn type_uses_lifetime_param(&self) -> bool { + self.type_uses_lifetime_param + } + + #[allow(dead_code)] + pub fn type_uses_const_param(&self) -> bool { + self.type_uses_const_param + } + + #[allow(dead_code)] + pub fn type_uses_type_param(&self) -> bool { + self.type_uses_type_param + } + + #[allow(dead_code)] + pub fn type_uses_const_or_type_param(&self) -> bool { + self.type_uses_const_param || self.type_uses_type_param } } @@ -39,12 +62,18 @@ impl<'ast> Visit<'ast> for TypeVisitor<'ast> { let ident = &path_segment.ident; if self.type_param_idents.contains(ident) { - self.uses_type_param = true; + self.type_uses_type_param = true; } else if self.const_param_idents.contains(ident) { - self.uses_const_param = true; + self.type_uses_const_param = true; } } } syn::visit::visit_type_path(self, node); } + + fn visit_lifetime(&mut self, lifetime: &'ast syn::Lifetime) { + if self.lifetime_param_idents.contains(&lifetime.ident) { + self.type_uses_lifetime_param = true; + } + } } diff --git a/macros/src/type_visitor_mut.rs b/macros/src/type_visitor_mut.rs new file mode 100644 index 0000000..6175eec --- /dev/null +++ b/macros/src/type_visitor_mut.rs @@ -0,0 +1,30 @@ +use syn::visit_mut::VisitMut; + +pub struct TypeVisitorMut { + replace_lifetimes_with_static: bool, +} + +impl TypeVisitorMut { + pub fn new() -> Self { + Self { + replace_lifetimes_with_static: false, + } + } + + pub fn replace_lifetimes_with_static(mut self) -> Self { + self.replace_lifetimes_with_static = true; + self + } +} + +impl Default for TypeVisitorMut { + fn default() -> Self { + Self::new() + } +} + +impl VisitMut for TypeVisitorMut { + fn visit_lifetime_mut(&mut self, lifetime: &mut syn::Lifetime) { + lifetime.ident = quote::format_ident!("static"); + } +} diff --git a/tests/derive-tests/variant_discriminant/pass/enum/lifetimes.out.rs b/tests/derive-tests/variant_discriminant/pass/enum/lifetimes.out.rs new file mode 100644 index 0000000..77e4675 --- /dev/null +++ b/tests/derive-tests/variant_discriminant/pass/enum/lifetimes.out.rs @@ -0,0 +1,187 @@ +use enumcapsulate::VariantDiscriminant; +pub enum VariantWithLifetime<'a> { + Variant(&'a ()), +} +pub enum VariantWithLifetimeDiscriminant { + Variant, +} +#[automatically_derived] +impl ::core::marker::Copy for VariantWithLifetimeDiscriminant {} +#[automatically_derived] +impl ::core::clone::Clone for VariantWithLifetimeDiscriminant { + #[inline] + fn clone(&self) -> VariantWithLifetimeDiscriminant { + *self + } +} +#[automatically_derived] +impl ::core::cmp::Ord for VariantWithLifetimeDiscriminant { + #[inline] + fn cmp(&self, other: &VariantWithLifetimeDiscriminant) -> ::core::cmp::Ordering { + ::core::cmp::Ordering::Equal + } +} +#[automatically_derived] +impl ::core::cmp::PartialOrd for VariantWithLifetimeDiscriminant { + #[inline] + fn partial_cmp( + &self, + other: &VariantWithLifetimeDiscriminant, + ) -> ::core::option::Option<::core::cmp::Ordering> { + ::core::option::Option::Some(::core::cmp::Ordering::Equal) + } +} +#[automatically_derived] +impl ::core::cmp::Eq for VariantWithLifetimeDiscriminant { + #[inline] + #[doc(hidden)] + #[coverage(off)] + fn assert_receiver_is_total_eq(&self) -> () {} +} +#[automatically_derived] +impl ::core::marker::StructuralPartialEq for VariantWithLifetimeDiscriminant {} +#[automatically_derived] +impl ::core::cmp::PartialEq for VariantWithLifetimeDiscriminant { + #[inline] + fn eq(&self, other: &VariantWithLifetimeDiscriminant) -> bool { + true + } +} +#[automatically_derived] +impl ::core::hash::Hash for VariantWithLifetimeDiscriminant { + #[inline] + fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) -> () {} +} +#[automatically_derived] +impl ::core::fmt::Debug for VariantWithLifetimeDiscriminant { + #[inline] + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + ::core::fmt::Formatter::write_str(f, "Variant") + } +} +impl<'a> ::enumcapsulate::VariantDiscriminant for VariantWithLifetime<'a> { + type Discriminant = VariantWithLifetimeDiscriminant; + fn variant_discriminant(&self) -> Self::Discriminant { + match self { + VariantWithLifetime::Variant(..) => VariantWithLifetimeDiscriminant::Variant, + _ => ::core::panicking::panic("internal error: entered unreachable code"), + } + } +} +pub enum EnumWithLifetime<'a> { + #[enumcapsulate(discriminant(nested))] + VariantA(VariantWithLifetime<'a>), +} +pub enum EnumWithLifetimeDiscriminant { + VariantA( + as ::enumcapsulate::VariantDiscriminant>::Discriminant, + ), +} +#[automatically_derived] +impl ::core::marker::Copy for EnumWithLifetimeDiscriminant {} +#[automatically_derived] +impl ::core::clone::Clone for EnumWithLifetimeDiscriminant { + #[inline] + fn clone(&self) -> EnumWithLifetimeDiscriminant { + let _: ::core::clone::AssertParamIsClone< + as ::enumcapsulate::VariantDiscriminant>::Discriminant, + >; + *self + } +} +#[automatically_derived] +impl ::core::cmp::Ord for EnumWithLifetimeDiscriminant { + #[inline] + fn cmp(&self, other: &EnumWithLifetimeDiscriminant) -> ::core::cmp::Ordering { + match (self, other) { + ( + EnumWithLifetimeDiscriminant::VariantA(__self_0), + EnumWithLifetimeDiscriminant::VariantA(__arg1_0), + ) => ::core::cmp::Ord::cmp(__self_0, __arg1_0), + } + } +} +#[automatically_derived] +impl ::core::cmp::PartialOrd for EnumWithLifetimeDiscriminant { + #[inline] + fn partial_cmp( + &self, + other: &EnumWithLifetimeDiscriminant, + ) -> ::core::option::Option<::core::cmp::Ordering> { + match (self, other) { + ( + EnumWithLifetimeDiscriminant::VariantA(__self_0), + EnumWithLifetimeDiscriminant::VariantA(__arg1_0), + ) => ::core::cmp::PartialOrd::partial_cmp(__self_0, __arg1_0), + } + } +} +#[automatically_derived] +impl ::core::cmp::Eq for EnumWithLifetimeDiscriminant { + #[inline] + #[doc(hidden)] + #[coverage(off)] + fn assert_receiver_is_total_eq(&self) -> () { + let _: ::core::cmp::AssertParamIsEq< + as ::enumcapsulate::VariantDiscriminant>::Discriminant, + >; + } +} +#[automatically_derived] +impl ::core::marker::StructuralPartialEq for EnumWithLifetimeDiscriminant {} +#[automatically_derived] +impl ::core::cmp::PartialEq for EnumWithLifetimeDiscriminant { + #[inline] + fn eq(&self, other: &EnumWithLifetimeDiscriminant) -> bool { + match (self, other) { + ( + EnumWithLifetimeDiscriminant::VariantA(__self_0), + EnumWithLifetimeDiscriminant::VariantA(__arg1_0), + ) => *__self_0 == *__arg1_0, + } + } +} +#[automatically_derived] +impl ::core::hash::Hash for EnumWithLifetimeDiscriminant { + #[inline] + fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) -> () { + match self { + EnumWithLifetimeDiscriminant::VariantA(__self_0) => { + ::core::hash::Hash::hash(__self_0, state) + } + } + } +} +#[automatically_derived] +impl ::core::fmt::Debug for EnumWithLifetimeDiscriminant { + #[inline] + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + match self { + EnumWithLifetimeDiscriminant::VariantA(__self_0) => { + ::core::fmt::Formatter::debug_tuple_field1_finish( + f, + "VariantA", + &__self_0, + ) + } + } + } +} +impl<'a> ::enumcapsulate::VariantDiscriminant for EnumWithLifetime<'a> { + type Discriminant = EnumWithLifetimeDiscriminant; + fn variant_discriminant(&self) -> Self::Discriminant { + match self { + EnumWithLifetime::VariantA(inner, ..) => { + EnumWithLifetimeDiscriminant::VariantA(inner.variant_discriminant()) + } + _ => ::core::panicking::panic("internal error: entered unreachable code"), + } + } +} +fn main() {} diff --git a/tests/derive-tests/variant_discriminant/pass/enum/lifetimes.rs b/tests/derive-tests/variant_discriminant/pass/enum/lifetimes.rs new file mode 100644 index 0000000..eeb6232 --- /dev/null +++ b/tests/derive-tests/variant_discriminant/pass/enum/lifetimes.rs @@ -0,0 +1,27 @@ +use enumcapsulate::VariantDiscriminant; + +#[derive(VariantDiscriminant)] +pub enum VariantWithLifetime<'a> { + Variant(&'a ()), +} + +// #[derive(VariantDiscriminant)] +// pub enum GenericVariantWithLifetime<'a, T> { +// Variant(&'a T), +// } + +#[derive(VariantDiscriminant)] +pub enum EnumWithLifetime<'a> { + #[enumcapsulate(discriminant(nested))] + VariantA(VariantWithLifetime<'a>), + // #[enumcapsulate(discriminant(nested))] + // VariantB { b: VariantWithLifetime<'a> }, + // #[enumcapsulate(field = 0, discriminant(nested = GenericVariantWithLifetimeDiscriminant))] + // VariantC(GenericVariantWithLifetime<'a, T>), + // #[enumcapsulate(field = "d", discriminant(nested = GenericVariantWithLifetimeDiscriminant))] + // VariantD { + // d: GenericVariantWithLifetime<'a, T>, + // }, +} + +fn main() {}