Skip to content

Commit

Permalink
Get receiverless_trait_function test to pass
Browse files Browse the repository at this point in the history
  • Loading branch information
smoelius committed Jan 18, 2024
1 parent 72f9579 commit 02249fe
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 26 deletions.
79 changes: 55 additions & 24 deletions macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ fn map_method_or_fn(
let impl_generic_args = opts_impl_generic_args.as_ref().map(args_as_turbofish);
let generic_args = opts_generic_args.as_ref().map(args_as_turbofish);
let combined_generic_args_base = combine_options(
opts_impl_generic_args,
opts_impl_generic_args.clone(),
opts_generic_args,
|mut left, right| {
left.extend(right);
Expand Down Expand Up @@ -371,7 +371,7 @@ fn map_method_or_fn(

let self_ty_base = self_ty.and_then(type_utils::type_base);

let (receiver, mut arg_tys, fmt_args, mut ser_args, de_args) = {
let (mut arg_tys, fmt_args, mut ser_args, de_args) = {
let mut candidates = BTreeSet::new();
let result = map_args(
&mut conversions,
Expand Down Expand Up @@ -530,19 +530,23 @@ fn map_method_or_fn(
let args_ret_ty: Type = parse_quote! {
<Args #combined_generic_args_with_dummy_lifetimes as HasRetTy>::RetTy
};
let call: Expr = if receiver {
let mut de_args = de_args.iter();
let self_arg = de_args
.next()
.expect("Should have at least one deserialized argument");
parse_quote! {
( #self_arg ). #target_ident #generic_args (
#(#de_args),*
)
}
} else if let Some(self_ty_base) = self_ty_base {
let call: Expr = if let Some(self_ty) = self_ty {
let opts_impl_generic_args = opts_impl_generic_args.unwrap_or_default();
let map = generic_params_map(generics, &opts_impl_generic_args);
let self_ty_with_generic_args =
type_utils::type_as_turbofish(&type_utils::map_type_generic_params(&map, self_ty));
let qualified_self = if let Some(trait_path) = trait_path {
let trait_path_with_generic_args = type_utils::path_as_turbofish(
&type_utils::map_path_generic_params(&map, trait_path),
);
quote! {
< #self_ty_with_generic_args as #trait_path_with_generic_args >
}
} else {
self_ty_with_generic_args
};
parse_quote! {
#self_ty_base #impl_generic_args :: #target_ident #generic_args (
#qualified_self :: #target_ident #generic_args (
#(#de_args),*
)
}
Expand Down Expand Up @@ -748,35 +752,62 @@ fn map_method_or_fn(
)
}

fn generic_params_map<'a, 'b>(
generics: &'a Generics,
impl_generic_args: &'b Punctuated<GenericArgument, token::Comma>,
) -> BTreeMap<&'a Ident, &'b GenericArgument> {
let n = generics
.params
.len()
.checked_sub(impl_generic_args.len())
.unwrap_or_else(|| {
panic!(
"{:?} is shorter than {:?}",
generics.params, impl_generic_args
);
});
generics
.params
.iter()
.skip(n)
.zip(impl_generic_args)
.filter_map(|(key, value)| {
if let GenericParam::Type(TypeParam { ident, .. }) = key {
Some((ident, value))
} else {
None
}
})
.collect()
}

fn map_args<'a, I>(
conversions: &mut Conversions,
candidates: &mut BTreeSet<OrdType>,
trait_path: &Option<Path>,
self_ty: Option<&Type>,
inputs: I,
) -> (bool, Vec<Type>, Vec<Stmt>, Vec<Expr>, Vec<Expr>)
) -> (Vec<Type>, Vec<Stmt>, Vec<Expr>, Vec<Expr>)
where
I: Iterator<Item = &'a FnArg>,
{
let (receiver, ty, fmt, ser, de): (Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>) = inputs
let (ty, fmt, ser, de): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) = inputs
.enumerate()
.map(map_arg(conversions, candidates, trait_path, self_ty))
.multiunzip();

let receiver = receiver.first().map_or(false, |&x| x);

(receiver, ty, fmt, ser, de)
(ty, fmt, ser, de)
}

fn map_arg<'a>(
conversions: &'a mut Conversions,
candidates: &'a mut BTreeSet<OrdType>,
trait_path: &'a Option<Path>,
self_ty: Option<&'a Type>,
) -> impl FnMut((usize, &FnArg)) -> (bool, Type, Stmt, Expr, Expr) + 'a {
) -> impl FnMut((usize, &FnArg)) -> (Type, Stmt, Expr, Expr) + 'a {
move |(i, arg)| {
let i = Literal::usize_unsuffixed(i);
let (receiver, expr, ty, fmt) = match arg {
let (expr, ty, fmt) = match arg {
FnArg::Receiver(Receiver {
reference,
mutability,
Expand All @@ -792,7 +823,7 @@ fn map_arg<'a>(
debug_struct.field("self", value);
});
};
(true, expr, ty, fmt)
(expr, ty, fmt)
}
FnArg::Typed(PatType { pat, ty, .. }) => {
let ident = match *pat_utils::pat_idents(pat).as_slice() {
Expand All @@ -810,11 +841,11 @@ fn map_arg<'a>(
debug_struct.field(#name, value);
});
};
(false, expr, ty, fmt)
(expr, ty, fmt)
}
};
let (ty, ser, de) = map_typed_arg(conversions, candidates, &i, &expr, &ty);
(receiver, ty, fmt, ser, de)
(ty, fmt, ser, de)
}
}

Expand Down
98 changes: 96 additions & 2 deletions macro/src/type_utils.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,104 @@
use proc_macro2::Span;
use proc_macro2::{Punct, Spacing, Span, TokenStream, TokenTree};
use quote::ToTokens;
use std::collections::BTreeMap;
use syn::{
parse_quote,
visit::{visit_path_arguments, Visit},
visit_mut::{visit_type_mut, VisitMut},
Ident, Path, PathArguments, PathSegment, Type, TypePath,
GenericArgument, Ident, Path, PathArguments, PathSegment, Type, TypePath,
};

pub fn map_path_generic_params(map: &BTreeMap<&Ident, &GenericArgument>, path: &Path) -> Path {
let mut path = path.clone();
let mut visitor = GenericParamVisitor { map };
visitor.visit_path_mut(&mut path);
path
}

pub fn map_type_generic_params(map: &BTreeMap<&Ident, &GenericArgument>, ty: &Type) -> Type {
let mut ty = ty.clone();
let mut visitor = GenericParamVisitor { map };
visitor.visit_type_mut(&mut ty);
ty
}

struct GenericParamVisitor<'a> {
map: &'a BTreeMap<&'a Ident, &'a GenericArgument>,
}

impl<'a> VisitMut for GenericParamVisitor<'a> {
fn visit_type_mut(&mut self, ty: &mut Type) {
if let Type::Path(TypePath { qself: None, path }) = ty {
if let Some(ident) = path.get_ident() {
if let Some(generic_arg) = self.map.get(ident) {
let GenericArgument::Type(ty_new) = generic_arg else {
panic!(
"Unexpected generic argument: {}",
generic_arg.to_token_stream()
);
};
*ty = ty_new.clone();
return;
}
}
}
visit_type_mut(self, ty);
}
}

pub fn path_as_turbofish(path: &Path) -> TokenStream {
let tokens = path.to_token_stream().into_iter().collect::<Vec<_>>();
let mut visitor = TurbofishVisitor { tokens };
visitor.visit_path(path);
visitor.tokens.into_iter().collect()
}

pub fn type_as_turbofish(ty: &Type) -> TokenStream {
let tokens = ty.to_token_stream().into_iter().collect::<Vec<_>>();
let mut visitor = TurbofishVisitor { tokens };
visitor.visit_type(ty);
visitor.tokens.into_iter().collect()
}

struct TurbofishVisitor {
tokens: Vec<TokenTree>,
}

impl<'a> Visit<'a> for TurbofishVisitor {
fn visit_path_arguments(&mut self, path_args: &PathArguments) {
if !path_args.is_none() {
let mut visitor_token_strings = token_strings(&self.tokens);
let path_args_tokens = path_args.to_token_stream().into_iter().collect::<Vec<_>>();
let path_args_token_strings = token_strings(&path_args_tokens);
let n = path_args_tokens.len();
let mut i: usize = 0;
while i + n <= self.tokens.len() {
if visitor_token_strings[i..i + n] == path_args_token_strings
&& (i < 2 || visitor_token_strings[i - 2..i] != [":", ":"])
{
self.tokens = [
&self.tokens[..i],
&[
TokenTree::Punct(Punct::new(':', Spacing::Joint)),
TokenTree::Punct(Punct::new(':', Spacing::Alone)),
],
&self.tokens[i..],
]
.concat();
visitor_token_strings = token_strings(&self.tokens);
i += 2;
}
i += 1;
}
}
visit_path_arguments(self, path_args);
}
}

fn token_strings(tokens: &[TokenTree]) -> Vec<String> {
tokens.iter().map(ToString::to_string).collect::<Vec<_>>()
}

pub fn expand_self(trait_path: &Option<Path>, self_ty: &Type, ty: &Type) -> Type {
let mut ty = ty.clone();
let mut visitor = ExpandSelfVisitor {
Expand Down

0 comments on commit 02249fe

Please sign in to comment.