diff --git a/trait-variant/examples/variant.rs b/trait-variant/examples/variant.rs index 04bcc6a..dcd7189 100644 --- a/trait-variant/examples/variant.rs +++ b/trait-variant/examples/variant.rs @@ -43,4 +43,18 @@ where fn build(&self, items: impl Iterator) -> Self::B; } +#[trait_variant::make(Send + Sync)] +pub trait GenericTraitWithBounds<'x, S: Sync, Y, const X: usize> +where + Y: Sync, +{ + const CONST: usize = 3; + type F; + type A; + type B: FromIterator; + + async fn take(&self, s: S); + fn build(&self, items: impl Iterator) -> Self::B; +} + fn main() {} diff --git a/trait-variant/src/lib.rs b/trait-variant/src/lib.rs index d3286d9..ca03a4e 100644 --- a/trait-variant/src/lib.rs +++ b/trait-variant/src/lib.rs @@ -14,15 +14,15 @@ mod variant; /// fn` and/or `-> impl Trait` return types. /// /// ``` -/// #[trait_variant::make(IntFactory: Send)] -/// trait LocalIntFactory { +/// #[trait_variant::make(Send)] +/// trait IntFactory { /// async fn make(&self) -> i32; /// fn stream(&self) -> impl Iterator; /// fn call(&self) -> u32; /// } /// ``` /// -/// The above example causes a second trait called `IntFactory` to be created: +/// The above example causes the trait to be rewritten as: /// /// ``` /// # use core::future::Future; @@ -35,6 +35,19 @@ mod variant; /// /// Note that ordinary methods such as `call` are not affected. /// +/// If you want to preserve an original trait untouched, `make` can be used to create a new trait with bounds on `async +/// fn` and/or `-> impl Trait` return types. +/// +/// ``` +/// #[trait_variant::make(IntFactory: Send)] +/// trait LocalIntFactory { +/// async fn make(&self) -> i32; +/// fn stream(&self) -> impl Iterator; +/// fn call(&self) -> u32; +/// } +/// ``` +/// +/// The example causes a second trait called `IntFactory` to be created. /// Implementers of the trait can choose to implement the variant instead of the /// original trait. The macro creates a blanket impl which ensures that any type /// which implements the variant also implements the original trait. diff --git a/trait-variant/src/variant.rs b/trait-variant/src/variant.rs index f7f0d27..0a61f33 100644 --- a/trait-variant/src/variant.rs +++ b/trait-variant/src/variant.rs @@ -32,20 +32,33 @@ impl Parse for Attrs { } } -struct MakeVariant { - name: Ident, - #[allow(unused)] - colon: Token![:], - bounds: Punctuated, +enum MakeVariant { + // Creates a variant of a trait under a new name with additional bounds while preserving the original trait. + Create { + name: Ident, + _colon: Token![:], + bounds: Punctuated, + }, + // Rewrites the original trait into a new trait with additional bounds. + Rewrite { + bounds: Punctuated, + }, } impl Parse for MakeVariant { fn parse(input: ParseStream) -> Result { - Ok(Self { - name: input.parse()?, - colon: input.parse()?, - bounds: input.parse_terminated(TraitBound::parse, Token![+])?, - }) + let variant = if input.peek(Ident) && input.peek2(Token![:]) { + MakeVariant::Create { + name: input.parse()?, + _colon: input.parse()?, + bounds: input.parse_terminated(TraitBound::parse, Token![+])?, + } + } else { + MakeVariant::Rewrite { + bounds: input.parse_terminated(TraitBound::parse, Token![+])?, + } + }; + Ok(variant) } } @@ -56,43 +69,51 @@ pub fn make( let attrs = parse_macro_input!(attr as Attrs); let item = parse_macro_input!(item as ItemTrait); - let maybe_allow_async_lint = if attrs - .variant - .bounds - .iter() - .any(|b| b.path.segments.last().unwrap().ident == "Send") - { - quote! { #[allow(async_fn_in_trait)] } - } else { - quote! {} - }; + match attrs.variant { + MakeVariant::Create { name, bounds, .. } => { + let maybe_allow_async_lint = if bounds + .iter() + .any(|b| b.path.segments.last().unwrap().ident == "Send") + { + quote! { #[allow(async_fn_in_trait)] } + } else { + quote! {} + }; - let variant = mk_variant(&attrs, &item); - let blanket_impl = mk_blanket_impl(&attrs, &item); + let variant = mk_variant(&name, bounds, &item); + let blanket_impl = mk_blanket_impl(&name, &item); - quote! { - #maybe_allow_async_lint - #item + quote! { + #maybe_allow_async_lint + #item - #variant + #variant - #blanket_impl + #blanket_impl + } + .into() + } + MakeVariant::Rewrite { bounds, .. } => { + let variant = mk_variant(&item.ident, bounds, &item); + quote! { + #variant + } + .into() + } } - .into() } -fn mk_variant(attrs: &Attrs, tr: &ItemTrait) -> TokenStream { - let MakeVariant { - ref name, - colon: _, - ref bounds, - } = attrs.variant; - let bounds: Vec<_> = bounds +fn mk_variant( + variant: &Ident, + with_bounds: Punctuated, + tr: &ItemTrait, +) -> TokenStream { + let bounds: Vec<_> = with_bounds .into_iter() .map(|b| TypeParamBound::Trait(b.clone())) .collect(); let variant = ItemTrait { - ident: name.clone(), + ident: variant.clone(), supertraits: tr.supertraits.iter().chain(&bounds).cloned().collect(), items: tr .items @@ -104,21 +125,8 @@ fn mk_variant(attrs: &Attrs, tr: &ItemTrait) -> TokenStream { quote! { #variant } } +// Transforms one item declaration within the definition if it has `async fn` and/or `-> impl Trait` return types by adding new bounds. fn transform_item(item: &TraitItem, bounds: &Vec) -> TraitItem { - // #[make_variant(SendIntFactory: Send)] - // trait IntFactory { - // async fn make(&self, x: u32, y: &str) -> i32; - // fn stream(&self) -> impl Iterator; - // fn call(&self) -> u32; - // } - // - // becomes: - // - // trait SendIntFactory: Send { - // fn make(&self, x: u32, y: &str) -> impl ::core::future::Future + Send; - // fn stream(&self) -> impl Iterator + Send; - // fn call(&self) -> u32; - // } let TraitItem::Fn(fn_item @ TraitItemFn { sig, .. }) = item else { return item.clone(); }; @@ -160,9 +168,8 @@ fn transform_item(item: &TraitItem, bounds: &Vec) -> TraitItem { }) } -fn mk_blanket_impl(attrs: &Attrs, tr: &ItemTrait) -> TokenStream { +fn mk_blanket_impl(variant: &Ident, tr: &ItemTrait) -> TokenStream { let orig = &tr.ident; - let variant = &attrs.variant.name; let (_impl, orig_ty_generics, _where) = &tr.generics.split_for_impl(); let items = tr .items