Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Please make sure to add your changes to the appropriate categories:

### Added

- n/a
- Made `VariantDiscriminant` derive macro support variants with non-`'static` lifetimes.

### Changed

Expand Down
2 changes: 1 addition & 1 deletion macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ proc-macro = true
[dependencies]
proc-macro2 = { version = "1.0.81", features = ["span-locations"] }
quote = "1.0.36"
syn = { version = "2.0.60", features = ["full", "visit"] }
syn = { version = "2.0.60", features = ["full", "visit", "visit-mut"] }
26 changes: 21 additions & 5 deletions macros/src/enum_deriver.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{
parse_quote, parse_quote_spanned, spanned::Spanned, visit::Visit as _, Fields, Type, Variant,
parse_quote, parse_quote_spanned, spanned::Spanned, visit::Visit as _, visit_mut::VisitMut,
Fields, Type, Variant,
};

use crate::*;
use crate::{type_visitor_mut::TypeVisitorMut, *};

pub(crate) struct EnumDeriver {
item: syn::ItemEnum,
Expand Down Expand Up @@ -585,15 +586,23 @@ impl EnumDeriver {
match nested {
NestedDiscriminantType::Default => {
let (field, _) = field_selection.expect("no selected field found");
let field_type = &field.ty;

if self.uses_generic_const_or_type(field_type) {
let mut visitor = TypeVisitor::new(&self.item.generics);
visitor.visit_type(&field.ty);

if visitor.type_uses_const_or_type_param() {
return Err(syn::Error::new(
field.span(),
"generic fields require an explicit nested discriminant type",
));
}

let field_type = if visitor.type_uses_lifetime_param() {
self.type_replacing_lifetimes_with_static(&field.ty)
} else {
field.ty.clone()
};

let nested_type = parse_quote! {
<#field_type as ::enumcapsulate::VariantDiscriminant>::Discriminant
};
Expand Down Expand Up @@ -741,11 +750,18 @@ impl EnumDeriver {
}
}

fn type_replacing_lifetimes_with_static(&self, ty: &syn::Type) -> syn::Type {
let mut ty = ty.clone();
let mut visitor = TypeVisitorMut::default().replace_lifetimes_with_static();
visitor.visit_type_mut(&mut ty);
ty
}

fn uses_generic_const_or_type(&self, ty: &syn::Type) -> bool {
let mut visitor = TypeVisitor::new(&self.item.generics);

visitor.visit_type(ty);

visitor.uses_const_or_type_param()
visitor.type_uses_const_or_type_param()
}
}
1 change: 1 addition & 0 deletions macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::utils::tokenstream;
mod config;
mod enum_deriver;
mod type_visitor;
mod type_visitor_mut;
mod utils;

use self::{config::*, enum_deriver::*, type_visitor::*, utils::*};
Expand Down
45 changes: 37 additions & 8 deletions macros/src/type_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,48 @@ use std::collections::HashSet;
use syn::visit::Visit;

pub struct TypeVisitor<'ast> {
lifetime_param_idents: HashSet<&'ast syn::Ident>,
const_param_idents: HashSet<&'ast syn::Ident>,
type_param_idents: HashSet<&'ast syn::Ident>,

uses_const_param: bool,
uses_type_param: bool,
type_uses_lifetime_param: bool,
type_uses_const_param: bool,
type_uses_type_param: bool,
}

impl<'ast> TypeVisitor<'ast> {
pub fn new(generics: &'ast syn::Generics) -> Self {
Self {
lifetime_param_idents: generics
.lifetimes()
.map(|param| &param.lifetime.ident)
.collect(),
const_param_idents: generics.const_params().map(|param| &param.ident).collect(),
type_param_idents: generics.type_params().map(|param| &param.ident).collect(),
uses_const_param: false,
uses_type_param: false,
type_uses_lifetime_param: false,
type_uses_const_param: false,
type_uses_type_param: false,
}
}

pub fn uses_const_or_type_param(self) -> bool {
self.uses_const_param || self.uses_type_param
#[allow(dead_code)]
pub fn type_uses_lifetime_param(&self) -> bool {
self.type_uses_lifetime_param
}

#[allow(dead_code)]
pub fn type_uses_const_param(&self) -> bool {
self.type_uses_const_param
}

#[allow(dead_code)]
pub fn type_uses_type_param(&self) -> bool {
self.type_uses_type_param
}

#[allow(dead_code)]
pub fn type_uses_const_or_type_param(&self) -> bool {
self.type_uses_const_param || self.type_uses_type_param
}
}

Expand All @@ -39,12 +62,18 @@ impl<'ast> Visit<'ast> for TypeVisitor<'ast> {
let ident = &path_segment.ident;

if self.type_param_idents.contains(ident) {
self.uses_type_param = true;
self.type_uses_type_param = true;
} else if self.const_param_idents.contains(ident) {
self.uses_const_param = true;
self.type_uses_const_param = true;
}
}
}
syn::visit::visit_type_path(self, node);
}

fn visit_lifetime(&mut self, lifetime: &'ast syn::Lifetime) {
if self.lifetime_param_idents.contains(&lifetime.ident) {
self.type_uses_lifetime_param = true;
}
}
}
30 changes: 30 additions & 0 deletions macros/src/type_visitor_mut.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use syn::visit_mut::VisitMut;

pub struct TypeVisitorMut {
replace_lifetimes_with_static: bool,
}

impl TypeVisitorMut {
pub fn new() -> Self {
Self {
replace_lifetimes_with_static: false,
}
}

pub fn replace_lifetimes_with_static(mut self) -> Self {
self.replace_lifetimes_with_static = true;
self
}
}

impl Default for TypeVisitorMut {
fn default() -> Self {
Self::new()
}
}

impl VisitMut for TypeVisitorMut {
fn visit_lifetime_mut(&mut self, lifetime: &mut syn::Lifetime) {
lifetime.ident = quote::format_ident!("static");
}
}
187 changes: 187 additions & 0 deletions tests/derive-tests/variant_discriminant/pass/enum/lifetimes.out.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
use enumcapsulate::VariantDiscriminant;
pub enum VariantWithLifetime<'a> {
Variant(&'a ()),
}
pub enum VariantWithLifetimeDiscriminant {
Variant,
}
#[automatically_derived]
impl ::core::marker::Copy for VariantWithLifetimeDiscriminant {}
#[automatically_derived]
impl ::core::clone::Clone for VariantWithLifetimeDiscriminant {
#[inline]
fn clone(&self) -> VariantWithLifetimeDiscriminant {
*self
}
}
#[automatically_derived]
impl ::core::cmp::Ord for VariantWithLifetimeDiscriminant {
#[inline]
fn cmp(&self, other: &VariantWithLifetimeDiscriminant) -> ::core::cmp::Ordering {
::core::cmp::Ordering::Equal
}
}
#[automatically_derived]
impl ::core::cmp::PartialOrd for VariantWithLifetimeDiscriminant {
#[inline]
fn partial_cmp(
&self,
other: &VariantWithLifetimeDiscriminant,
) -> ::core::option::Option<::core::cmp::Ordering> {
::core::option::Option::Some(::core::cmp::Ordering::Equal)
}
}
#[automatically_derived]
impl ::core::cmp::Eq for VariantWithLifetimeDiscriminant {
#[inline]
#[doc(hidden)]
#[coverage(off)]
fn assert_receiver_is_total_eq(&self) -> () {}
}
#[automatically_derived]
impl ::core::marker::StructuralPartialEq for VariantWithLifetimeDiscriminant {}
#[automatically_derived]
impl ::core::cmp::PartialEq for VariantWithLifetimeDiscriminant {
#[inline]
fn eq(&self, other: &VariantWithLifetimeDiscriminant) -> bool {
true
}
}
#[automatically_derived]
impl ::core::hash::Hash for VariantWithLifetimeDiscriminant {
#[inline]
fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) -> () {}
}
#[automatically_derived]
impl ::core::fmt::Debug for VariantWithLifetimeDiscriminant {
#[inline]
fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
::core::fmt::Formatter::write_str(f, "Variant")
}
}
impl<'a> ::enumcapsulate::VariantDiscriminant for VariantWithLifetime<'a> {
type Discriminant = VariantWithLifetimeDiscriminant;
fn variant_discriminant(&self) -> Self::Discriminant {
match self {
VariantWithLifetime::Variant(..) => VariantWithLifetimeDiscriminant::Variant,
_ => ::core::panicking::panic("internal error: entered unreachable code"),
}
}
}
pub enum EnumWithLifetime<'a> {
#[enumcapsulate(discriminant(nested))]
VariantA(VariantWithLifetime<'a>),
}
pub enum EnumWithLifetimeDiscriminant {
VariantA(
<VariantWithLifetime<
'static,
> as ::enumcapsulate::VariantDiscriminant>::Discriminant,
),
}
#[automatically_derived]
impl ::core::marker::Copy for EnumWithLifetimeDiscriminant {}
#[automatically_derived]
impl ::core::clone::Clone for EnumWithLifetimeDiscriminant {
#[inline]
fn clone(&self) -> EnumWithLifetimeDiscriminant {
let _: ::core::clone::AssertParamIsClone<
<VariantWithLifetime<
'static,
> as ::enumcapsulate::VariantDiscriminant>::Discriminant,
>;
*self
}
}
#[automatically_derived]
impl ::core::cmp::Ord for EnumWithLifetimeDiscriminant {
#[inline]
fn cmp(&self, other: &EnumWithLifetimeDiscriminant) -> ::core::cmp::Ordering {
match (self, other) {
(
EnumWithLifetimeDiscriminant::VariantA(__self_0),
EnumWithLifetimeDiscriminant::VariantA(__arg1_0),
) => ::core::cmp::Ord::cmp(__self_0, __arg1_0),
}
}
}
#[automatically_derived]
impl ::core::cmp::PartialOrd for EnumWithLifetimeDiscriminant {
#[inline]
fn partial_cmp(
&self,
other: &EnumWithLifetimeDiscriminant,
) -> ::core::option::Option<::core::cmp::Ordering> {
match (self, other) {
(
EnumWithLifetimeDiscriminant::VariantA(__self_0),
EnumWithLifetimeDiscriminant::VariantA(__arg1_0),
) => ::core::cmp::PartialOrd::partial_cmp(__self_0, __arg1_0),
}
}
}
#[automatically_derived]
impl ::core::cmp::Eq for EnumWithLifetimeDiscriminant {
#[inline]
#[doc(hidden)]
#[coverage(off)]
fn assert_receiver_is_total_eq(&self) -> () {
let _: ::core::cmp::AssertParamIsEq<
<VariantWithLifetime<
'static,
> as ::enumcapsulate::VariantDiscriminant>::Discriminant,
>;
}
}
#[automatically_derived]
impl ::core::marker::StructuralPartialEq for EnumWithLifetimeDiscriminant {}
#[automatically_derived]
impl ::core::cmp::PartialEq for EnumWithLifetimeDiscriminant {
#[inline]
fn eq(&self, other: &EnumWithLifetimeDiscriminant) -> bool {
match (self, other) {
(
EnumWithLifetimeDiscriminant::VariantA(__self_0),
EnumWithLifetimeDiscriminant::VariantA(__arg1_0),
) => *__self_0 == *__arg1_0,
}
}
}
#[automatically_derived]
impl ::core::hash::Hash for EnumWithLifetimeDiscriminant {
#[inline]
fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) -> () {
match self {
EnumWithLifetimeDiscriminant::VariantA(__self_0) => {
::core::hash::Hash::hash(__self_0, state)
}
}
}
}
#[automatically_derived]
impl ::core::fmt::Debug for EnumWithLifetimeDiscriminant {
#[inline]
fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
match self {
EnumWithLifetimeDiscriminant::VariantA(__self_0) => {
::core::fmt::Formatter::debug_tuple_field1_finish(
f,
"VariantA",
&__self_0,
)
}
}
}
}
impl<'a> ::enumcapsulate::VariantDiscriminant for EnumWithLifetime<'a> {
type Discriminant = EnumWithLifetimeDiscriminant;
fn variant_discriminant(&self) -> Self::Discriminant {
match self {
EnumWithLifetime::VariantA(inner, ..) => {
EnumWithLifetimeDiscriminant::VariantA(inner.variant_discriminant())
}
_ => ::core::panicking::panic("internal error: entered unreachable code"),
}
}
}
fn main() {}
Loading
Loading