Skip to content

Commit

Permalink
Refactor derive functions return syn::Result<_>
Browse files Browse the repository at this point in the history
  • Loading branch information
greyblake committed Oct 20, 2022
1 parent 727ab88 commit f3225dd
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 105 deletions.
33 changes: 17 additions & 16 deletions derive/src/field_attributes.rs
Expand Up @@ -21,15 +21,16 @@ pub enum FieldConstructor {
Value(TokenStream),
}

pub fn determine_field_constructor(field: &Field) -> FieldConstructor {
let opt_attr = fetch_attr_from_field(field);
match opt_attr {
Some(attr) => parse_attribute(attr),
pub fn determine_field_constructor(field: &Field) -> Result<FieldConstructor> {
let opt_attr = fetch_attr_from_field(field)?;
let ctor = match opt_attr {
Some(attr) => parse_attribute(attr)?,
None => FieldConstructor::Arbitrary,
}
};
Ok(ctor)
}

fn fetch_attr_from_field(field: &Field) -> Option<&Attribute> {
fn fetch_attr_from_field(field: &Field) -> Result<Option<&Attribute>> {
let found_attributes: Vec<_> = field
.attrs
.iter()
Expand All @@ -45,10 +46,10 @@ fn fetch_attr_from_field(field: &Field) -> Option<&Attribute> {
"Multiple conflicting #[{ARBITRARY_ATTRIBUTE_NAME}] attributes found on field `{name}`"
);
}
found_attributes.into_iter().next()
Ok(found_attributes.into_iter().next())
}

fn parse_attribute(attr: &Attribute) -> FieldConstructor {
fn parse_attribute(attr: &Attribute) -> Result<FieldConstructor> {
let group = {
let mut tokens_iter = attr.clone().tokens.into_iter();
let token = tokens_iter
Expand All @@ -62,20 +63,20 @@ fn parse_attribute(attr: &Attribute) -> FieldConstructor {
parse_attribute_internals(group.stream())
}

fn parse_attribute_internals(stream: TokenStream) -> FieldConstructor {
fn parse_attribute_internals(stream: TokenStream) -> Result<FieldConstructor> {
let mut tokens_iter = stream.into_iter();
let token = tokens_iter
.next()
.unwrap_or_else(|| panic!("#[{ARBITRARY_ATTRIBUTE_NAME}] cannot be empty."));
match token.to_string().as_ref() {
"default" => FieldConstructor::Default,
"default" => Ok(FieldConstructor::Default),
"with" => {
let func_path = parse_assigned_value("with", tokens_iter);
FieldConstructor::With(func_path)
let func_path = parse_assigned_value("with", tokens_iter)?;
Ok(FieldConstructor::With(func_path))
}
"value" => {
let value = parse_assigned_value("value", tokens_iter);
FieldConstructor::Value(value)
let value = parse_assigned_value("value", tokens_iter)?;
Ok(FieldConstructor::Value(value))
}
_ => panic!("Unknown option for #[{ARBITRARY_ATTRIBUTE_NAME}]: `{token}`"),
}
Expand All @@ -88,12 +89,12 @@ fn parse_attribute_internals(stream: TokenStream) -> FieldConstructor {
fn parse_assigned_value(
opt_name: &str,
mut tokens_iter: impl Iterator<Item = TokenTree>,
) -> TokenStream {
) -> Result<TokenStream> {
let eq_sign = tokens_iter.next().unwrap_or_else(|| {
panic!("Invalid syntax for #[{ARBITRARY_ATTRIBUTE_NAME}], `{opt_name}` is missing RHS.")
});
if eq_sign.to_string() != "=" {
panic!("Invalid syntax for #[{ARBITRARY_ATTRIBUTE_NAME}], expected `=` after `{opt_name}`, got: `{eq_sign}`");
}
tokens_iter.collect()
Ok(tokens_iter.collect())
}
228 changes: 139 additions & 89 deletions derive/src/lib.rs
Expand Up @@ -12,6 +12,12 @@ static ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";
#[proc_macro_derive(Arbitrary, attributes(arbitrary))]
pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = syn::parse_macro_input!(tokens as syn::DeriveInput);
expand_derive_arbitrary(input)
.unwrap_or_else(syn::Error::into_compile_error)
.into()
}

fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
let (lifetime_without_bounds, lifetime_with_bounds) =
build_arbitrary_lifetime(input.generics.clone());

Expand All @@ -21,8 +27,8 @@ pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStr
);

let arbitrary_method =
gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count);
let size_hint_method = gen_size_hint_method(&input);
gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?;
let size_hint_method = gen_size_hint_method(&input)?;
let name = input.ident;
// Add a bound `T: Arbitrary` to every type parameter T.
let generics = add_trait_bounds(input.generics, lifetime_without_bounds.clone());
Expand All @@ -37,7 +43,7 @@ pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStr
// Build TypeGenerics and WhereClause without a lifetime
let (_, ty_generics, where_clause) = generics.split_for_impl();

(quote! {
Ok(quote! {
const _: () = {
thread_local! {
#[allow(non_upper_case_globals)]
Expand All @@ -50,7 +56,6 @@ pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStr
}
};
})
.into()
}

// Returns: (lifetime without bounds, lifetime with bounds)
Expand Down Expand Up @@ -115,44 +120,65 @@ fn gen_arbitrary_method(
input: &DeriveInput,
lifetime: LifetimeDef,
recursive_count: &syn::Ident,
) -> TokenStream {
let ident = &input.ident;

let arbitrary_structlike = |fields| {
let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field));
) -> Result<TokenStream> {
fn arbitrary_structlike(
fields: &Fields,
ident: &syn::Ident,
lifetime: LifetimeDef,
recursive_count: &syn::Ident,
) -> Result<TokenStream> {
let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?;
let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) });

let arbitrary_take_rest = construct_take_rest(fields);
let arbitrary_take_rest = construct_take_rest(fields)?;
let take_rest_body =
with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) });

quote! {
Ok(quote! {
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#body
}

fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#take_rest_body
}
}
};
})
}

match &input.data {
Data::Struct(data) => arbitrary_structlike(&data.fields),
Data::Union(data) => arbitrary_structlike(&Fields::Named(data.fields.clone())),
let ident = &input.ident;
let output = match &input.data {
Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count)?,
Data::Union(data) => arbitrary_structlike(
&Fields::Named(data.fields.clone()),
ident,
lifetime,
recursive_count,
)?,
Data::Enum(data) => {
let variants = data.variants.iter().enumerate().map(|(i, variant)| {
let idx = i as u64;
let ctor = construct(&variant.fields, |_, field| gen_constructor_for_field(field));
let variant_name = &variant.ident;
quote! { #idx => #ident::#variant_name #ctor }
});
let variants_take_rest = data.variants.iter().enumerate().map(|(i, variant)| {
let idx = i as u64;
let ctor = construct_take_rest(&variant.fields);
let variant_name = &variant.ident;
quote! { #idx => #ident::#variant_name #ctor }
});
let variants: Vec<TokenStream> = data
.variants
.iter()
.enumerate()
.map(|(i, variant)| {
let idx = i as u64;
let variant_name = &variant.ident;
construct(&variant.fields, |_, field| gen_constructor_for_field(field))
.map(|ctor| quote! { #idx => #ident::#variant_name #ctor })
})
.collect::<Result<_>>()?;

let variants_take_rest: Vec<TokenStream> = data
.variants
.iter()
.enumerate()
.map(|(i, variant)| {
let idx = i as u64;
let variant_name = &variant.ident;
construct_take_rest(&variant.fields)
.map(|ctor| quote! { #idx => #ident::#variant_name #ctor })
})
.collect::<Result<_>>()?;

let count = data.variants.len() as u64;

let arbitrary = with_recursive_count_guard(
Expand Down Expand Up @@ -191,33 +217,44 @@ fn gen_arbitrary_method(
}
}
}
}
};
Ok(output)
}

fn construct(fields: &Fields, ctor: impl Fn(usize, &Field) -> TokenStream) -> TokenStream {
match fields {
fn construct(
fields: &Fields,
ctor: impl Fn(usize, &Field) -> Result<TokenStream>,
) -> Result<TokenStream> {
let output = match fields {
Fields::Named(names) => {
let names = names.named.iter().enumerate().map(|(i, f)| {
let name = f.ident.as_ref().unwrap();
let ctor = ctor(i, f);
quote! { #name: #ctor }
});
let names: Vec<TokenStream> = names
.named
.iter()
.enumerate()
.map(|(i, f)| {
let name = f.ident.as_ref().unwrap();
ctor(i, f).map(|ctor| quote! { #name: #ctor })
})
.collect::<Result<_>>()?;
quote! { { #(#names,)* } }
}
Fields::Unnamed(names) => {
let names = names.unnamed.iter().enumerate().map(|(i, f)| {
let ctor = ctor(i, f);
quote! { #ctor }
});
let names: Vec<TokenStream> = names
.unnamed
.iter()
.enumerate()
.map(|(i, f)| ctor(i, f).map(|ctor| quote! { #ctor }))
.collect::<Result<_>>()?;
quote! { ( #(#names),* ) }
}
Fields::Unit => quote!(),
}
};
Ok(output)
}

fn construct_take_rest(fields: &Fields) -> TokenStream {
fn construct_take_rest(fields: &Fields) -> Result<TokenStream> {
construct(fields, |idx, field| {
match determine_field_constructor(field) {
determine_field_constructor(field).map(|field_constructor| match field_constructor {
FieldConstructor::Default => quote!(Default::default()),
FieldConstructor::Arbitrary => {
if idx + 1 == fields.len() {
Expand All @@ -228,70 +265,83 @@ fn construct_take_rest(fields: &Fields) -> TokenStream {
}
FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(&mut u)?),
FieldConstructor::Value(value) => quote!(#value),
}
})
})
}

fn gen_size_hint_method(input: &DeriveInput) -> TokenStream {
fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
let size_hint_fields = |fields: &Fields| {
let hints = fields.iter().map(|f| {
let ty = &f.ty;
match determine_field_constructor(f) {
FieldConstructor::Default | FieldConstructor::Value(_) => {
quote!((0, Some(0)))
}
FieldConstructor::Arbitrary => {
quote! { <#ty as arbitrary::Arbitrary>::size_hint(depth) }
}

// Note that in this case it's hard to determine what size_hint must be, so size_of::<T>() is
// just an educated guess, although it's gonna be inaccurate for dynamically
// allocated types (Vec, HashMap, etc.).
FieldConstructor::With(_) => {
quote! { (::core::mem::size_of::<#ty>(), None) }
fields
.iter()
.map(|f| {
let ty = &f.ty;
determine_field_constructor(f).map(|field_constructor| {
match field_constructor {
FieldConstructor::Default | FieldConstructor::Value(_) => {
quote!((0, Some(0)))
}
FieldConstructor::Arbitrary => {
quote! { <#ty as arbitrary::Arbitrary>::size_hint(depth) }
}

// Note that in this case it's hard to determine what size_hint must be, so size_of::<T>() is
// just an educated guess, although it's gonna be inaccurate for dynamically
// allocated types (Vec, HashMap, etc.).
FieldConstructor::With(_) => {
quote! { (::core::mem::size_of::<#ty>(), None) }
}
}
})
})
.collect::<Result<Vec<TokenStream>>>()
.map(|hints| {
quote! {
arbitrary::size_hint::and_all(&[
#( #hints ),*
])
}
}
});
quote! {
arbitrary::size_hint::and_all(&[
#( #hints ),*
])
}
})
};
let size_hint_structlike = |fields: &Fields| {
let hint = size_hint_fields(fields);
quote! {
#[inline]
fn size_hint(depth: usize) -> (usize, Option<usize>) {
arbitrary::size_hint::recursion_guard(depth, |depth| #hint)
size_hint_fields(fields).map(|hint| {
quote! {
#[inline]
fn size_hint(depth: usize) -> (usize, Option<usize>) {
arbitrary::size_hint::recursion_guard(depth, |depth| #hint)
}
}
}
})
};
match &input.data {
Data::Struct(data) => size_hint_structlike(&data.fields),
Data::Union(data) => size_hint_structlike(&Fields::Named(data.fields.clone())),
Data::Enum(data) => {
let variants = data.variants.iter().map(|v| size_hint_fields(&v.fields));
quote! {
#[inline]
fn size_hint(depth: usize) -> (usize, Option<usize>) {
arbitrary::size_hint::and(
<u32 as arbitrary::Arbitrary>::size_hint(depth),
arbitrary::size_hint::recursion_guard(depth, |depth| {
arbitrary::size_hint::or_all(&[ #( #variants ),* ])
}),
)
Data::Enum(data) => data
.variants
.iter()
.map(|v| size_hint_fields(&v.fields))
.collect::<Result<Vec<TokenStream>>>()
.map(|variants| {
quote! {
#[inline]
fn size_hint(depth: usize) -> (usize, Option<usize>) {
arbitrary::size_hint::and(
<u32 as arbitrary::Arbitrary>::size_hint(depth),
arbitrary::size_hint::recursion_guard(depth, |depth| {
arbitrary::size_hint::or_all(&[ #( #variants ),* ])
}),
)
}
}
}
}
}),
}
}

fn gen_constructor_for_field(field: &Field) -> TokenStream {
match determine_field_constructor(field) {
fn gen_constructor_for_field(field: &Field) -> Result<TokenStream> {
let ctor = match determine_field_constructor(field)? {
FieldConstructor::Default => quote!(Default::default()),
FieldConstructor::Arbitrary => quote!(arbitrary::Arbitrary::arbitrary(u)?),
FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(u)?),
FieldConstructor::Value(value) => quote!(#value),
}
};
Ok(ctor)
}

0 comments on commit f3225dd

Please sign in to comment.