Skip to content

Commit

Permalink
Add support for defaulted methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mendess committed Jan 16, 2024
1 parent 90c80bd commit de24caa
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 15 deletions.
16 changes: 16 additions & 0 deletions trait-variant/examples/variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,25 @@ pub trait LocalIntFactory {
Self: 'a;

async fn make(&self, x: u32, y: &str) -> i32;
async fn make_mut(&mut self);
fn stream(&self) -> impl Iterator<Item = i32>;
fn call(&self) -> u32;
fn another_async(&self, input: Result<(), &str>) -> Self::MyFut<'_>;
async fn defaulted(&self) -> i32 {
self.make(10, "10").await
}
async fn defaulted_mut(&mut self) -> i32 {
self.make(10, "10").await
}
async fn defaulted_mut_2(&mut self) {
self.make_mut().await
}
async fn defaulted_move(self) -> i32
where
Self: Sized,
{
self.make(10, "10").await
}
}

#[allow(dead_code)]
Expand Down
112 changes: 97 additions & 15 deletions trait-variant/src/variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@

use std::iter;

use proc_macro2::TokenStream;
use proc_macro2::{Span, TokenStream};
use quote::{quote, ToTokens};
use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
parse_macro_input, parse_quote,
punctuated::Punctuated,
token::{Comma, Plus},
Error, FnArg, GenericParam, Generics, Ident, ItemTrait, Lifetime, Pat, PatType, Result,
ReturnType, Signature, Token, TraitBound, TraitItem, TraitItemConst, TraitItemFn,
TraitItemType, Type, TypeImplTrait, TypeParamBound,
Error, FnArg, GenericParam, Generics, Ident, ItemTrait, Lifetime, Pat, PatType, Receiver,
Result, ReturnType, Signature, Token, TraitBound, TraitItem, TraitItemConst, TraitItemFn,
TraitItemType, Type, TypeImplTrait, TypeParamBound, TypeReference, WhereClause,
};

struct Attrs {
Expand Down Expand Up @@ -119,10 +119,10 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
// fn stream(&self) -> impl Iterator<Item = i32> + Send;
// fn call(&self) -> u32;
// }
let TraitItem::Fn(fn_item @ TraitItemFn { sig, .. }) = item else {
let TraitItem::Fn(fn_item @ TraitItemFn { sig, default, .. }) = item else {
return item.clone();
};
let (arrow, output) = if sig.asyncness.is_some() {
let (sig, default) = if sig.asyncness.is_some() {
let orig = match &sig.output {
ReturnType::Default => quote! { () },
ReturnType::Type(_, ty) => quote! { #ty },
Expand All @@ -134,7 +134,22 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
.chain(bounds.iter().cloned())
.collect(),
});
(syn::parse2(quote! { -> }).unwrap(), ty)
let mut sig = sig.clone();
if default.is_some() {
add_receiver_bounds(&mut sig);
}

(
Signature {
asyncness: None,
output: ReturnType::Type(syn::parse2(quote! { -> }).unwrap(), Box::new(ty)),
..sig.clone()
},
fn_item
.default
.as_ref()
.map(|b| syn::parse2(quote! { { async move #b } }).unwrap()),
)
} else {
match &sig.output {
ReturnType::Type(arrow, ty) => match &**ty {
Expand All @@ -143,19 +158,22 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
impl_token: it.impl_token,
bounds: it.bounds.iter().chain(bounds).cloned().collect(),
});
(*arrow, ty)
(
Signature {
output: ReturnType::Type(*arrow, Box::new(ty)),
..sig.clone()
},
fn_item.default.clone(),
)
}
_ => return item.clone(),
},
ReturnType::Default => return item.clone(),
}
};
TraitItem::Fn(TraitItemFn {
sig: Signature {
asyncness: None,
output: ReturnType::Type(arrow, Box::new(output)),
..sig.clone()
},
sig,
default,
..fn_item.clone()
})
}
Expand Down Expand Up @@ -184,7 +202,26 @@ fn mk_blanket_impl(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
.items
.iter()
.map(|item| blanket_impl_item(item, variant, &generic_names));
let where_clauses = tr.generics.where_clause.as_ref().map(|wh| &wh.predicates);
let mut where_clauses = tr
.generics
.where_clause
.as_ref()
.map(|wh| wh.predicates.clone())
.unwrap_or_default();
let self_is_sync = tr.items.iter().any(|item| {
matches!(
item,
TraitItem::Fn(TraitItemFn {
default: Some(_),
..
})
)
});

if self_is_sync {
where_clauses.push(parse_quote! { for<'s> &'s Self: Send });
}

quote! {
impl<#generics #trailing_comma TraitVariantBlanketType> #orig<#generic_names>
for TraitVariantBlanketType
Expand Down Expand Up @@ -249,6 +286,7 @@ fn blanket_impl_item(
} else {
quote! {}
};

quote! {
#sig {
<Self as #variant<#generic_names>>::#ident(#(#args),*)#maybe_await
Expand All @@ -272,3 +310,47 @@ fn blanket_impl_item(
_ => Error::new_spanned(item, "unsupported item type").into_compile_error(),
}
}

fn add_receiver_bounds(sig: &mut Signature) {
let Some(FnArg::Receiver(Receiver { ty, reference, .. })) = sig.inputs.first_mut() else {
return;
};
let Type::Reference(
recv_ty @ TypeReference {
mutability: None, ..
},
) = &mut **ty
else {
return;
};
let Some((_and, lt)) = reference else {
return;
};

let lifetime = syn::Lifetime {
apostrophe: Span::mixed_site(),
ident: Ident::new("the_self_lt", Span::mixed_site()),
};
sig.generics.params.insert(
0,
syn::GenericParam::Lifetime(syn::LifetimeParam {
lifetime: lifetime.clone(),
colon_token: None,
bounds: Default::default(),
attrs: Default::default(),
}),
);
recv_ty.lifetime = Some(lifetime.clone());
*lt = Some(lifetime);
let predicate = parse_quote! { #recv_ty: Send };

if let Some(wh) = &mut sig.generics.where_clause {
wh.predicates.push(predicate);
} else {
let where_clause = WhereClause {
where_token: Token![where](Span::mixed_site()),
predicates: Punctuated::from_iter([predicate]),
};
sig.generics.where_clause = Some(where_clause);
}
}

0 comments on commit de24caa

Please sign in to comment.