Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: trait_variant::make supports rewriting of the original trait. #27

Merged
merged 2 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions trait-variant/examples/variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,18 @@ where
fn build<T: Display>(&self, items: impl Iterator<Item = T>) -> Self::B<T>;
}

#[trait_variant::make(Send + Sync)]
pub trait GenericTraitWithBounds<'x, S: Sync, Y, const X: usize>
where
Y: Sync,
{
const CONST: usize = 3;
type F;
type A<const ANOTHER_CONST: u8>;
type B<T: Display>: FromIterator<T>;

async fn take(&self, s: S);
fn build<T: Display>(&self, items: impl Iterator<Item = T>) -> Self::B<T>;
}

fn main() {}
19 changes: 16 additions & 3 deletions trait-variant/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ mod variant;
/// fn` and/or `-> impl Trait` return types.
///
/// ```
/// #[trait_variant::make(IntFactory: Send)]
/// trait LocalIntFactory {
/// #[trait_variant::make(Send)]
/// trait IntFactory {
/// async fn make(&self) -> i32;
/// fn stream(&self) -> impl Iterator<Item = i32>;
/// fn call(&self) -> u32;
/// }
/// ```
///
/// The above example causes a second trait called `IntFactory` to be created:
/// The above example causes the trait to be rewritten as:
///
/// ```
/// # use core::future::Future;
Expand All @@ -35,6 +35,19 @@ mod variant;
///
/// Note that ordinary methods such as `call` are not affected.
///
/// If you want to preserve an original trait untouched, `make` can be used to create a new trait with bounds on `async
/// fn` and/or `-> impl Trait` return types.
///
/// ```
/// #[trait_variant::make(IntFactory: Send)]
/// trait LocalIntFactory {
/// async fn make(&self) -> i32;
/// fn stream(&self) -> impl Iterator<Item = i32>;
/// fn call(&self) -> u32;
/// }
/// ```
///
/// The example causes a second trait called `IntFactory` to be created.
/// Implementers of the trait can choose to implement the variant instead of the
/// original trait. The macro creates a blanket impl which ensures that any type
/// which implements the variant also implements the original trait.
Expand Down
111 changes: 59 additions & 52 deletions trait-variant/src/variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,33 @@ impl Parse for Attrs {
}
}

struct MakeVariant {
name: Ident,
#[allow(unused)]
colon: Token![:],
bounds: Punctuated<TraitBound, Plus>,
enum MakeVariant {
// Creates a variant of a trait under a new name with additional bounds while preserving the original trait.
Create {
name: Ident,
_colon: Token![:],
bounds: Punctuated<TraitBound, Plus>,
},
// Rewrites the original trait into a new trait with additional bounds.
Rewrite {
bounds: Punctuated<TraitBound, Plus>,
},
}

impl Parse for MakeVariant {
fn parse(input: ParseStream) -> Result<Self> {
Ok(Self {
name: input.parse()?,
colon: input.parse()?,
bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
})
let variant = if input.peek(Ident) && input.peek2(Token![:]) {
MakeVariant::Create {
name: input.parse()?,
_colon: input.parse()?,
bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
}
} else {
MakeVariant::Rewrite {
bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
}
};
Ok(variant)
}
}

Expand All @@ -56,43 +69,51 @@ pub fn make(
let attrs = parse_macro_input!(attr as Attrs);
let item = parse_macro_input!(item as ItemTrait);

let maybe_allow_async_lint = if attrs
.variant
.bounds
.iter()
.any(|b| b.path.segments.last().unwrap().ident == "Send")
{
quote! { #[allow(async_fn_in_trait)] }
} else {
quote! {}
};
match attrs.variant {
MakeVariant::Create { name, bounds, .. } => {
let maybe_allow_async_lint = if bounds
.iter()
.any(|b| b.path.segments.last().unwrap().ident == "Send")
{
quote! { #[allow(async_fn_in_trait)] }
} else {
quote! {}
};

let variant = mk_variant(&attrs, &item);
let blanket_impl = mk_blanket_impl(&attrs, &item);
let variant = mk_variant(&name, bounds, &item);
let blanket_impl = mk_blanket_impl(&name, &item);

quote! {
#maybe_allow_async_lint
#item
quote! {
#maybe_allow_async_lint
#item

#variant
#variant

#blanket_impl
#blanket_impl
}
.into()
}
MakeVariant::Rewrite { bounds, .. } => {
let variant = mk_variant(&item.ident, bounds, &item);
quote! {
#variant
}
.into()
}
}
.into()
}

fn mk_variant(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
let MakeVariant {
ref name,
colon: _,
ref bounds,
} = attrs.variant;
let bounds: Vec<_> = bounds
fn mk_variant(
variant: &Ident,
with_bounds: Punctuated<TraitBound, Plus>,
tr: &ItemTrait,
) -> TokenStream {
let bounds: Vec<_> = with_bounds
.into_iter()
.map(|b| TypeParamBound::Trait(b.clone()))
.collect();
let variant = ItemTrait {
ident: name.clone(),
ident: variant.clone(),
supertraits: tr.supertraits.iter().chain(&bounds).cloned().collect(),
items: tr
.items
Expand All @@ -104,21 +125,8 @@ fn mk_variant(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
quote! { #variant }
}

// Transforms a one item declaration within the definition if it has `async fn` and/or `-> impl Trait` return types by adding new bounds.
tmandry marked this conversation as resolved.
Show resolved Hide resolved
fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
// #[make_variant(SendIntFactory: Send)]
// trait IntFactory {
// async fn make(&self, x: u32, y: &str) -> i32;
// fn stream(&self) -> impl Iterator<Item = i32>;
// fn call(&self) -> u32;
// }
//
// becomes:
//
// trait SendIntFactory: Send {
// fn make(&self, x: u32, y: &str) -> impl ::core::future::Future<Output = i32> + Send;
// fn stream(&self) -> impl Iterator<Item = i32> + Send;
// fn call(&self) -> u32;
// }
let TraitItem::Fn(fn_item @ TraitItemFn { sig, .. }) = item else {
return item.clone();
};
Expand Down Expand Up @@ -160,9 +168,8 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
})
}

fn mk_blanket_impl(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
fn mk_blanket_impl(variant: &Ident, tr: &ItemTrait) -> TokenStream {
let orig = &tr.ident;
let variant = &attrs.variant.name;
let (_impl, orig_ty_generics, _where) = &tr.generics.split_for_impl();
let items = tr
.items
Expand Down
Loading