diff --git a/trait-variant/Cargo.toml b/trait-variant/Cargo.toml index 3307ece..8d5e68c 100644 --- a/trait-variant/Cargo.toml +++ b/trait-variant/Cargo.toml @@ -25,7 +25,7 @@ proc-macro = true [dependencies] proc-macro2 = "1.0" quote = "1.0" -syn = { version = "2.0", features = ["full"] } +syn = { version = "2.0", features = ["full", "visit-mut"] } [dev-dependencies] tokio = { version = "1", features = ["rt"] } diff --git a/trait-variant/examples/variant.rs b/trait-variant/examples/variant.rs index 88faece..dac3048 100644 --- a/trait-variant/examples/variant.rs +++ b/trait-variant/examples/variant.rs @@ -21,8 +21,8 @@ pub trait LocalIntFactory { fn stream(&self) -> impl Iterator; fn call(&self) -> u32; fn another_async(&self, input: Result<(), &str>) -> Self::MyFut<'_>; - async fn defaulted(&self) -> i32 { - self.make(10, "10").await + async fn defaulted(&self, x: u32) -> i32 { + self.make(x, "10").await } async fn defaulted_mut(&mut self) -> i32 { self.make(10, "10").await diff --git a/trait-variant/src/variant.rs b/trait-variant/src/variant.rs index 04e9058..a13891f 100644 --- a/trait-variant/src/variant.rs +++ b/trait-variant/src/variant.rs @@ -15,9 +15,9 @@ use syn::{ parse_macro_input, parse_quote, punctuated::Punctuated, token::{Comma, Plus}, - Error, FnArg, GenericParam, Generics, Ident, ItemTrait, Lifetime, Pat, PatType, Receiver, - Result, ReturnType, Signature, Token, TraitBound, TraitItem, TraitItemConst, TraitItemFn, - TraitItemType, Type, TypeImplTrait, TypeParamBound, WhereClause, + Error, FnArg, GenericParam, Generics, Ident, ItemTrait, Lifetime, Pat, PatIdent, PatType, + Receiver, Result, ReturnType, Signature, Token, TraitBound, TraitItem, TraitItemConst, + TraitItemFn, TraitItemType, Type, TypeImplTrait, TypeParamBound, WhereClause, }; struct Attrs { @@ -145,10 +145,32 @@ fn transform_item(item: &TraitItem, bounds: &Vec) -> TraitItem { output: ReturnType::Type(syn::parse2(quote! { -> }).unwrap(), Box::new(ty)), ..sig.clone() }, - fn_item - .default - .as_ref() - .map(|b| syn::parse2(quote! { { async move #b } }).unwrap()), + fn_item.default.as_ref().map(|b| { + let items = sig.inputs.iter().map(|i| match i { + FnArg::Receiver(Receiver { self_token, .. }) => { + quote! { let __self = #self_token; } + } + FnArg::Typed(PatType { pat, .. }) => match pat.as_ref() { + Pat::Ident(PatIdent { ident, .. }) => quote! { let #ident = #ident; }, + _ => todo!(), + }, + }); + + struct ReplaceSelfVisitor; + impl syn::visit_mut::VisitMut for ReplaceSelfVisitor { + fn visit_ident_mut(&mut self, ident: &mut syn::Ident) { + if ident == "self" { + *ident = syn::Ident::new("__self", ident.span()); + } + syn::visit_mut::visit_ident_mut(self, ident); + } + } + + let mut block = b.clone(); + syn::visit_mut::visit_block_mut(&mut ReplaceSelfVisitor, &mut block); + + parse_quote! { { async move { #(#items)* #block} } } + }), ) } else { match &sig.output {