Skip to content

Commit

Permalink
Support omitting the variant name
Browse files Browse the repository at this point in the history
Let's say that we want to create a `Send` variant of a trait but we
don't need a non-`Send` variant of that trait at all.  We could of
course just write that trait, but adding the `Send` bounds in all the
right places could be annoying.

In this commit, we allow simply omitting the new variant name from the
call to `make`.  When called that way, we use the name from the item
to emit only one variant with the bounds applied.  We don't emit the
original item.

For completeness and explicit disambiguation, we support prefixing the
bounds with a colon (but giving no variant name).  Similarly, for
completeness, we support giving the same name as the trait item.  In
both of these cases, we just emit the one variant.  Since these are
for completeness, we don't advertise these syntaxes in the
documentation.

That is, we now support:

- `make(NAME: BOUNDS)`
- `make(NAME:)`
- `make(:BOUNDS)`
- `make(BOUNDS)`

This resolves #18.
  • Loading branch information
traviscross committed Feb 10, 2024
1 parent f81bb09 commit c8a15fe
Show file tree
Hide file tree
Showing 8 changed files with 336 additions and 35 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ trait IntFactory: Send {

Implementers can choose to implement either `LocalIntFactory` or `IntFactory` as appropriate.

If a non-`Send` variant of the trait is not needed, the name of the new variant can simply be omitted. E.g., this generates a *single* (rather than an additional) trait whose definition matches that in the expansion above:

```rust
#[trait_variant::make(Send)]
trait IntFactory {
async fn make(&self) -> i32;
fn stream(&self) -> impl Iterator<Item = i32>;
fn call(&self) -> u32;
}
```

For more details, see the docs for [`trait_variant::make`].

[`trait_variant::make`]: https://docs.rs/trait-variant/latest/trait_variant/attr.make.html
Expand Down
5 changes: 5 additions & 0 deletions trait-variant/examples/variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ fn spawn_task(factory: impl IntFactory + 'static) {
});
}

#[trait_variant::make(Send)]
pub trait TupleFactory {
async fn new() -> Self;
}

#[trait_variant::make(GenericTrait: Send)]
pub trait LocalGenericTrait<'x, S: Sync, Y, const X: usize>
where
Expand Down
13 changes: 13 additions & 0 deletions trait-variant/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ mod variant;
/// Implementers of the trait can choose to implement the variant instead of the
/// original trait. The macro creates a blanket impl which ensures that any type
/// which implements the variant also implements the original trait.
///
/// If a non-`Send` variant of the trait is not needed, the name of
/// new variant can simply be omitted. E.g., this generates a
/// *single* (rather than an additional) trait whose definition
/// matches that in the expansion above:
///
/// #[trait_variant::make(Send)]
/// trait IntFactory {
/// async fn make(&self) -> i32;
/// fn stream(&self) -> impl Iterator<Item = i32>;
/// fn call(&self) -> u32;
/// }
/// ```
#[proc_macro_attribute]
pub fn make(
attr: proc_macro::TokenStream,
Expand Down
81 changes: 46 additions & 35 deletions trait-variant/src/variant.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// Copyright (c) 2023 Google LLC
// Copyright (c) 2023 Various contributors (see git history)
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
Expand All @@ -11,7 +12,7 @@ use std::iter;
use proc_macro2::TokenStream;
use quote::quote;
use syn::{
parse::{Parse, ParseStream},
parse::{discouraged::Speculative as _, Parse, ParseStream},
parse_macro_input, parse_quote,
punctuated::Punctuated,
token::Plus,
Expand All @@ -20,44 +21,57 @@ use syn::{
TypeImplTrait, TypeParam, TypeParamBound,
};

struct Attrs {
variant: MakeVariant,
#[derive(Clone)]
struct Variant {
name: Option<Ident>,
_colon: Option<Token![:]>,
bounds: Punctuated<TraitBound, Plus>,
}

impl Parse for Attrs {
fn parse(input: ParseStream) -> Result<Self> {
Ok(Self {
variant: MakeVariant::parse(input)?,
})
}
fn parse_bounds_only(input: ParseStream) -> Result<Option<Variant>> {
let fork = input.fork();
let colon: Option<Token![:]> = fork.parse()?;
let bounds = match fork.parse_terminated(TraitBound::parse, Token![+]) {
Ok(x) => Ok(x),
Err(e) if colon.is_some() => Err(e),
Err(_) => return Ok(None),
};
input.advance_to(&fork);
Ok(Some(Variant {
name: None,
_colon: colon,
bounds: bounds?,
}))
}

struct MakeVariant {
name: Ident,
#[allow(unused)]
colon: Token![:],
bounds: Punctuated<TraitBound, Plus>,
fn parse_fallback(input: ParseStream) -> Result<Variant> {
let name: Ident = input.parse()?;
let colon: Token![:] = input.parse()?;
let bounds = input.parse_terminated(TraitBound::parse, Token![+])?;
Ok(Variant {
name: Some(name),
_colon: Some(colon),
bounds,
})
}

impl Parse for MakeVariant {
impl Parse for Variant {
fn parse(input: ParseStream) -> Result<Self> {
Ok(Self {
name: input.parse()?,
colon: input.parse()?,
bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
})
match parse_bounds_only(input)? {
Some(x) => Ok(x),
None => parse_fallback(input),
}
}
}

pub fn make(
attr: proc_macro::TokenStream,
item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let attrs = parse_macro_input!(attr as Attrs);
let variant = parse_macro_input!(attr as Variant);
let item = parse_macro_input!(item as ItemTrait);

let maybe_allow_async_lint = if attrs
.variant
let maybe_allow_async_lint = if variant
.bounds
.iter()
.any(|b| b.path.segments.last().unwrap().ident == "Send")
Expand All @@ -67,26 +81,24 @@ pub fn make(
quote! {}
};

let variant = mk_variant(&attrs, &item);
let blanket_impl = mk_blanket_impl(&attrs, &item);

let variant_name = variant.clone().name.unwrap_or(item.clone().ident);
let variant_def = mk_variant(&variant_name, &variant.bounds, &item);
if variant_name == item.ident {
return variant_def.into();
}
let blanket_impl = Some(mk_blanket_impl(&variant_name, &item));
quote! {
#maybe_allow_async_lint
#item

#variant
#variant_def

#blanket_impl
}
.into()
}

fn mk_variant(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
let MakeVariant {
ref name,
colon: _,
ref bounds,
} = attrs.variant;
fn mk_variant(name: &Ident, bounds: &Punctuated<TraitBound, Plus>, tr: &ItemTrait) -> TokenStream {
let bounds: Vec<_> = bounds
.into_iter()
.map(|b| TypeParamBound::Trait(b.clone()))
Expand Down Expand Up @@ -160,9 +172,8 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
})
}

fn mk_blanket_impl(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
fn mk_blanket_impl(variant: &Ident, tr: &ItemTrait) -> TokenStream {
let orig = &tr.ident;
let variant = &attrs.variant.name;
let (_impl, orig_ty_generics, _where) = &tr.generics.split_for_impl();
let items = tr
.items
Expand Down
65 changes: 65 additions & 0 deletions trait-variant/tests/bounds.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (c) 2023 Google LLC
// Copyright (c) 2023 Various contributors (see git history)
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

#[trait_variant::make(Send + Sync)]
pub trait Trait {
const CONST: &'static ();
type Gat<'a>
where
Self: 'a;
async fn assoc_async_fn_no_ret(a: (), b: ());
async fn assoc_async_method_no_ret(&self, a: (), b: ());
async fn assoc_async_fn(a: (), b: ()) -> ();
async fn assoc_async_method(&self, a: (), b: ()) -> ();
fn assoc_sync_fn_no_ret(a: (), b: ());
fn assoc_sync_method_no_ret(&self, a: (), b: ());
fn assoc_sync_fn(a: (), b: ()) -> ();
fn assoc_sync_method(&self, a: (), b: ()) -> ();
// FIXME: See #17.
//async fn dft_assoc_async_fn_no_ret(_a: (), _b: ()) {}
//async fn dft_assoc_async_method_no_ret(&self, _a: (), _b: ()) {}
//async fn dft_assoc_async_fn(_a: (), _b: ()) -> () {}
//async fn dft_assoc_async_method(&self, _a: (), _b: ()) -> () {}
fn dft_assoc_sync_fn_no_ret(_a: (), _b: ()) {}
fn dft_assoc_sync_method_no_ret(&self, _a: (), _b: ()) {}
fn dft_assoc_sync_fn(_a: (), _b: ()) -> () {}
fn dft_assoc_sync_method(&self, _a: (), _b: ()) -> () {}
}

impl Trait for () {
const CONST: &'static () = &();
type Gat<'a> = ();
async fn assoc_async_fn_no_ret(_a: (), _b: ()) {}
async fn assoc_async_method_no_ret(&self, _a: (), _b: ()) {}
async fn assoc_async_fn(_a: (), _b: ()) -> () {}
async fn assoc_async_method(&self, _a: (), _b: ()) -> () {}
fn assoc_sync_fn_no_ret(_a: (), _b: ()) {}
fn assoc_sync_method_no_ret(&self, _a: (), _b: ()) {}
fn assoc_sync_fn(_a: (), _b: ()) -> () {}
fn assoc_sync_method(&self, _a: (), _b: ()) -> () {}
}

fn is_bounded<T: Send + Sync>(_: T) {}

#[test]
fn test() {
fn inner<T: Trait>(x: T) {
let (a, b) = ((), ());
is_bounded(<T as Trait>::assoc_async_fn_no_ret(a, b));
is_bounded(<T as Trait>::assoc_async_method_no_ret(&x, a, b));
is_bounded(<T as Trait>::assoc_async_fn(a, b));
is_bounded(<T as Trait>::assoc_async_method(&x, a, b));
// FIXME: See #17.
//is_bounded(<T as Trait>::dft_assoc_async_fn_no_ret(a, b));
//is_bounded(<T as Trait>::dft_assoc_async_method_no_ret(&x, a, b));
//is_bounded(<T as Trait>::dft_assoc_async_fn(a, b));
//is_bounded(<T as Trait>::dft_assoc_async_method(&x, a, b));
}
inner(());
}
65 changes: 65 additions & 0 deletions trait-variant/tests/colon-bounds.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (c) 2023 Google LLC
// Copyright (c) 2023 Various contributors (see git history)
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

#[trait_variant::make(: Send + Sync)]
pub trait Trait {
const CONST: &'static ();
type Gat<'a>
where
Self: 'a;
async fn assoc_async_fn_no_ret(a: (), b: ());
async fn assoc_async_method_no_ret(&self, a: (), b: ());
async fn assoc_async_fn(a: (), b: ()) -> ();
async fn assoc_async_method(&self, a: (), b: ()) -> ();
fn assoc_sync_fn_no_ret(a: (), b: ());
fn assoc_sync_method_no_ret(&self, a: (), b: ());
fn assoc_sync_fn(a: (), b: ()) -> ();
fn assoc_sync_method(&self, a: (), b: ()) -> ();
// FIXME: See #17.
//async fn dft_assoc_async_fn_no_ret(_a: (), _b: ()) {}
//async fn dft_assoc_async_method_no_ret(&self, _a: (), _b: ()) {}
//async fn dft_assoc_async_fn(_a: (), _b: ()) -> () {}
//async fn dft_assoc_async_method(&self, _a: (), _b: ()) -> () {}
fn dft_assoc_sync_fn_no_ret(_a: (), _b: ()) {}
fn dft_assoc_sync_method_no_ret(&self, _a: (), _b: ()) {}
fn dft_assoc_sync_fn(_a: (), _b: ()) -> () {}
fn dft_assoc_sync_method(&self, _a: (), _b: ()) -> () {}
}

impl Trait for () {
const CONST: &'static () = &();
type Gat<'a> = ();
async fn assoc_async_fn_no_ret(_a: (), _b: ()) {}
async fn assoc_async_method_no_ret(&self, _a: (), _b: ()) {}
async fn assoc_async_fn(_a: (), _b: ()) -> () {}
async fn assoc_async_method(&self, _a: (), _b: ()) -> () {}
fn assoc_sync_fn_no_ret(_a: (), _b: ()) {}
fn assoc_sync_method_no_ret(&self, _a: (), _b: ()) {}
fn assoc_sync_fn(_a: (), _b: ()) -> () {}
fn assoc_sync_method(&self, _a: (), _b: ()) -> () {}
}

fn is_bounded<T: Send + Sync>(_: T) {}

#[test]
fn test() {
fn inner<T: Trait>(x: T) {
let (a, b) = ((), ());
is_bounded(<T as Trait>::assoc_async_fn_no_ret(a, b));
is_bounded(<T as Trait>::assoc_async_method_no_ret(&x, a, b));
is_bounded(<T as Trait>::assoc_async_fn(a, b));
is_bounded(<T as Trait>::assoc_async_method(&x, a, b));
// FIXME: See #17.
//is_bounded(<T as Trait>::dft_assoc_async_fn_no_ret(a, b));
//is_bounded(<T as Trait>::dft_assoc_async_method_no_ret(&x, a, b));
//is_bounded(<T as Trait>::dft_assoc_async_fn(a, b));
//is_bounded(<T as Trait>::dft_assoc_async_method(&x, a, b));
}
inner(());
}
66 changes: 66 additions & 0 deletions trait-variant/tests/name-colon.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright (c) 2023 Google LLC
// Copyright (c) 2023 Various contributors (see git history)
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

#[trait_variant::make(Trait:)]
#[allow(async_fn_in_trait)]
pub trait LocalTrait {
const CONST: &'static ();
type Gat<'a>
where
Self: 'a;
async fn assoc_async_fn_no_ret(a: (), b: ());
async fn assoc_async_method_no_ret(&self, a: (), b: ());
async fn assoc_async_fn(a: (), b: ()) -> ();
async fn assoc_async_method(&self, a: (), b: ()) -> ();
fn assoc_sync_fn_no_ret(a: (), b: ());
fn assoc_sync_method_no_ret(&self, a: (), b: ());
fn assoc_sync_fn(a: (), b: ()) -> ();
fn assoc_sync_method(&self, a: (), b: ()) -> ();
// FIXME: See #17.
//async fn dft_assoc_async_fn_no_ret(_a: (), _b: ()) {}
//async fn dft_assoc_async_method_no_ret(&self, _a: (), _b: ()) {}
//async fn dft_assoc_async_fn(_a: (), _b: ()) -> () {}
//async fn dft_assoc_async_method(&self, _a: (), _b: ()) -> () {}
fn dft_assoc_sync_fn_no_ret(_a: (), _b: ()) {}
fn dft_assoc_sync_method_no_ret(&self, _a: (), _b: ()) {}
fn dft_assoc_sync_fn(_a: (), _b: ()) -> () {}
fn dft_assoc_sync_method(&self, _a: (), _b: ()) -> () {}
}

impl Trait for () {
const CONST: &'static () = &();
type Gat<'a> = ();
async fn assoc_async_fn_no_ret(_a: (), _b: ()) {}
async fn assoc_async_method_no_ret(&self, _a: (), _b: ()) {}
async fn assoc_async_fn(_a: (), _b: ()) -> () {}
async fn assoc_async_method(&self, _a: (), _b: ()) -> () {}
fn assoc_sync_fn_no_ret(_a: (), _b: ()) {}
fn assoc_sync_method_no_ret(&self, _a: (), _b: ()) {}
fn assoc_sync_fn(_a: (), _b: ()) -> () {}
fn assoc_sync_method(&self, _a: (), _b: ()) -> () {}
}

fn is_bounded<T>(_: T) {}

#[test]
fn test() {
fn inner<T: Trait>(x: T) {
let (a, b) = ((), ());
is_bounded(<T as Trait>::assoc_async_fn_no_ret(a, b));
is_bounded(<T as Trait>::assoc_async_method_no_ret(&x, a, b));
is_bounded(<T as Trait>::assoc_async_fn(a, b));
is_bounded(<T as Trait>::assoc_async_method(&x, a, b));
// FIXME: See #17.
//is_bounded(<T as Trait>::dft_assoc_async_fn_no_ret(a, b));
//is_bounded(<T as Trait>::dft_assoc_async_method_no_ret(&x, a, b));
//is_bounded(<T as Trait>::dft_assoc_async_fn(a, b));
//is_bounded(<T as Trait>::dft_assoc_async_method(&x, a, b));
}
inner(());
}
Loading

0 comments on commit c8a15fe

Please sign in to comment.