Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support customization of fields on derive #129

Merged
merged 13 commits into from Oct 20, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 25 additions & 0 deletions README.md
Expand Up @@ -61,6 +61,31 @@ pub struct Rgb {
}
```

### Customizing single fields
fitzgen marked this conversation as resolved.
Show resolved Hide resolved

This can be particular handy if your structure uses a type that does not implement `Arbitrary` or you want to have more customization for particular fields.

```rust
#[derive(Arbitrary)]
pub struct Rgb {
// set `r` to Default::default()
#[arbitrary(default)]
pub r: u8,

// set `g` to 255
#[arbitrary(value = "255")]
fitzgen marked this conversation as resolved.
Show resolved Hide resolved
pub g: u8,

// generate `b` with a custom function
fitzgen marked this conversation as resolved.
Show resolved Hide resolved
#[arbitrary(with = "arbitrary_b")]
fitzgen marked this conversation as resolved.
Show resolved Hide resolved
pub b: u8,
}

fn arbitrary_b(u: &mut Unstructured) -> arbitrary::Result<u8> {
u.int_in_range(64..=128)
}
```

### Implementing `Arbitrary` By Hand

Alternatively, you can write an `Arbitrary` implementation by hand:
Expand Down
2 changes: 1 addition & 1 deletion derive/Cargo.toml
Expand Up @@ -9,7 +9,7 @@ authors = [
"Corey Farwell <coreyf@rwell.org>",
]
categories = ["development-tools::testing"]
edition = "2018"
edition = "2021"
keywords = ["arbitrary", "testing", "derive", "macro"]
readme = "README.md"
description = "Derives arbitrary traits"
Expand Down
110 changes: 110 additions & 0 deletions derive/src/field_attributes.rs
@@ -0,0 +1,110 @@
use proc_macro2::{Literal, TokenStream, TokenTree};
use quote::quote;
use syn::*;

/// Used to filter out necessary field attribute and in panics.
static ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";

/// Determines how a value for a field should be constructed.
#[cfg_attr(test, derive(Debug))]
pub enum FieldConstructor {
/// Assume that Arbitrary is defined for the type of this field and use it (default)
Arbitrary,

/// Places `Default::default()` as a field value.
Default,

/// Use custom function to generate a value for a field.
WithFunction(TokenStream),

/// Set a field always to the given value.
Value(TokenStream),
}

pub fn determine_field_constructor(field: &Field) -> FieldConstructor {
let opt_attr = fetch_attr_from_field(field);
match opt_attr {
Some(attr) => parse_attribute(attr),
None => FieldConstructor::Arbitrary,
}
}

fn fetch_attr_from_field(field: &Field) -> Option<&Attribute> {
field.attrs.iter().find(|a| {
let path = &a.path;
let name = quote!(#path).to_string();
name == ARBITRARY_ATTRIBUTE_NAME
})
}

fn parse_attribute(attr: &Attribute) -> FieldConstructor {
let group = {
let mut tokens_iter = attr.clone().tokens.into_iter();
let token = tokens_iter
.next()
.unwrap_or_else(|| panic!("{ARBITRARY_ATTRIBUTE_NAME} attribute cannot be empty."));
match token {
TokenTree::Group(g) => g,
t => panic!("{ARBITRARY_ATTRIBUTE_NAME} must contain a group, got: {t})"),
}
};
parse_attribute_internals(group.stream())
}

fn parse_attribute_internals(stream: TokenStream) -> FieldConstructor {
let mut tokens_iter = stream.into_iter();
let token = tokens_iter
.next()
.unwrap_or_else(|| panic!("{ARBITRARY_ATTRIBUTE_NAME} attribute cannot be empty."));
match token.to_string().as_ref() {
"default" => FieldConstructor::Default,
"with" => {
let func_path = parse_assigned_value(tokens_iter);
FieldConstructor::WithFunction(func_path)
}
"value" => {
let value = parse_assigned_value(tokens_iter);
FieldConstructor::Value(value)
}
_ => panic!("Unknown options for {ARBITRARY_ATTRIBUTE_NAME}: {token}"),
fitzgen marked this conversation as resolved.
Show resolved Hide resolved
}
}

// Input:
// = "2 + 2"
// Output:
// 2 + 2
fn parse_assigned_value(mut tokens_iter: impl Iterator<Item = TokenTree>) -> TokenStream {
let eq_sign = tokens_iter
.next()
.unwrap_or_else(|| panic!("Invalid syntax for {ARBITRARY_ATTRIBUTE_NAME}() attribute"));
if eq_sign.to_string() != "=" {
panic!("Invalid syntax for {ARBITRARY_ATTRIBUTE_NAME}() attribute");
}
let lit_token = tokens_iter
.next()
.unwrap_or_else(|| panic!("Invalid syntax for {ARBITRARY_ATTRIBUTE_NAME}() attribute"));
let value = unwrap_token_as_string_literal(lit_token);
value.parse().unwrap()
}

fn unwrap_token_as_string_literal(token: TokenTree) -> String {
let lit = unwrap_token_as_literal(token);
literal_to_string(lit)
}

fn literal_to_string(lit: Literal) -> String {
let value = lit.to_string();
if value.starts_with('"') && value.ends_with('"') {
// Trim the first and the last chars (double quotes)
return value[1..(value.len() - 1)].to_string();
}
panic!("{ARBITRARY_ATTRIBUTE_NAME}() expected an attribute to be a string, but got: {value}",);
}

fn unwrap_token_as_literal(token: TokenTree) -> Literal {
match token {
TokenTree::Literal(lit) => lit,
something => panic!("{ARBITRARY_ATTRIBUTE_NAME}() expected a literal, got: {something}"),
}
}
58 changes: 46 additions & 12 deletions derive/src/lib.rs
Expand Up @@ -4,9 +4,12 @@ use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::*;

mod field_attributes;
use field_attributes::{determine_field_constructor, FieldConstructor};

static ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";

#[proc_macro_derive(Arbitrary)]
#[proc_macro_derive(Arbitrary, attributes(arbitrary))]
pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = syn::parse_macro_input!(tokens as syn::DeriveInput);
let (lifetime_without_bounds, lifetime_with_bounds) =
Expand Down Expand Up @@ -116,7 +119,7 @@ fn gen_arbitrary_method(
let ident = &input.ident;

let arbitrary_structlike = |fields| {
let arbitrary = construct(fields, |_, _| quote!(arbitrary::Arbitrary::arbitrary(u)?));
let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field));
let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) });

let arbitrary_take_rest = construct_take_rest(fields);
Expand All @@ -140,9 +143,7 @@ fn gen_arbitrary_method(
Data::Enum(data) => {
let variants = data.variants.iter().enumerate().map(|(i, variant)| {
let idx = i as u64;
let ctor = construct(&variant.fields, |_, _| {
quote!(arbitrary::Arbitrary::arbitrary(u)?)
});
let ctor = construct(&variant.fields, |_, field| gen_constructor_for_field(field));
let variant_name = &variant.ident;
quote! { #idx => #ident::#variant_name #ctor }
});
Expand Down Expand Up @@ -215,21 +216,45 @@ fn construct(fields: &Fields, ctor: impl Fn(usize, &Field) -> TokenStream) -> To
}

fn construct_take_rest(fields: &Fields) -> TokenStream {
construct(fields, |idx, _| {
if idx + 1 == fields.len() {
quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? }
} else {
quote! { arbitrary::Arbitrary::arbitrary(&mut u)? }
construct(fields, |idx, field| {
match determine_field_constructor(field) {
FieldConstructor::Default => quote!(Default::default()),
FieldConstructor::Arbitrary => {
if idx + 1 == fields.len() {
quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? }
} else {
quote! { arbitrary::Arbitrary::arbitrary(&mut u)? }
}
}
FieldConstructor::WithFunction(func_path) => quote!(#func_path(&mut u)?),
FieldConstructor::Value(value) => quote!(#value),
}
})
}

fn gen_size_hint_method(input: &DeriveInput) -> TokenStream {
let size_hint_fields = |fields: &Fields| {
let tys = fields.iter().map(|f| &f.ty);
let hints = fields.iter().map(|f| {
let ty = &f.ty;
match determine_field_constructor(f) {
FieldConstructor::Default | FieldConstructor::Value(_) => {
quote!((0, Some(0)))
}
FieldConstructor::Arbitrary => {
quote! { <#ty as arbitrary::Arbitrary>::size_hint(depth) }
}

// Note that in this case it's hard to determine what size_hint must be, so size_of::<T>() is
// just an educated guess, although it's gonna be inaccurate for dynamically
// allocated types (Vec, HashMap, etc.).
FieldConstructor::WithFunction(_) => {
quote! { (::core::mem::size_of::<#ty>(), None) }
}
}
});
quote! {
arbitrary::size_hint::and_all(&[
#( <#tys as arbitrary::Arbitrary>::size_hint(depth) ),*
#( #hints ),*
])
}
};
Expand Down Expand Up @@ -261,3 +286,12 @@ fn gen_size_hint_method(input: &DeriveInput) -> TokenStream {
}
}
}

fn gen_constructor_for_field(field: &Field) -> TokenStream {
match determine_field_constructor(field) {
FieldConstructor::Default => quote!(Default::default()),
FieldConstructor::Arbitrary => quote!(arbitrary::Arbitrary::arbitrary(u)?),
FieldConstructor::WithFunction(func_path) => quote!(#func_path(u)?),
FieldConstructor::Value(value) => quote!(#value),
}
}
39 changes: 39 additions & 0 deletions tests/derive.rs
Expand Up @@ -231,3 +231,42 @@ fn recursive_and_empty_input() {

let _ = Nat5::arbitrary(&mut Unstructured::new(&[]));
}

#[test]
fn test_field_attributes() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also add some doc tests that use compile_fail to check that we return compiler errors when

  • #[arbitrary] is used multiple times on a single field
  • #[arbitrary(unknown_attr)] and #[arbitrary(unknown_attr = unknown_val)] are used
  • #[arbitrary(value)] and #[arbitrary(with)] are used without RHS values

Should be able to just do something like the following in this file:

/// Can only use `#[arbitrary]` once per field:
///
/// ```compile_fail
/// use arbitrary::*;
/// #[derive(Arbitrary)]
/// struct Foo {
///     #[arbitrary(with = foo)]
///     #[arbitrary(with = bar)]
///     field: u32,
/// }
/// fn foo(_: &mut Unstructured) -> u32 { todo!() }
/// fn bar(_: &mut Unstructured) -> u32 { todo!() }
/// ```
///
/// Etc...
pub struct CompileFailTests;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in 8913ec0

I was puzzled where to put such doc tests (I don't think users really want to see them in the docs).
Eventually I look into serde again, they're using trybuild to test compile failures.
So I've decided to introduce trybuild to this project too, I think it's a nice tool for this purpose.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was doing doc tests on a pub struct CompileFailTests in tests/derive.rs not getting picked up by rustdoc?

If not, then we could maybe avoid a new dependency by doing

// src/lib.rs

#[cfg(all(test, feature = "derive"))]
/// ...
pub struct CompileFailTests {}

So that it never shows up in the docs and only is built when we are running tests (with the derive enabled).

Copy link
Contributor Author

@greyblake greyblake Oct 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fitzgen I just tried doc tests within derive/src/lib.rs, they are not picked up.

If not, then we could maybe avoid a new dependency by doing

It's dev-dependency, shouldn't it be fine? I definitely like trybuild, it generates really helpful error messages and allows to structure tests in more easy to understand / maintainable way.

For example, this is what a reported failure looks like:
image

This way we actually test, the compiler fails in the way we expect it to fail. Doc tests may give false positives if something in the code gets mistyped.

I would see the cost of using trybuild is rather potential failures if rustc changes the way it prints errors.
But still I think pros of trybuild outweights its cons.

Please let me know, if you still insist on using doc tests.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if the test is in src/lib.rs rather than derive/src/lib.rs it should work? I'd really want to avoid failing tests on nightly but not stable unless as an absolute last resort.

// A type that DOES NOT implement Arbitrary
#[derive(Debug)]
struct Weight(u8);

#[derive(Debug, Arbitrary)]
struct Parcel {
#[arbitrary(with = "arbitrary_weight")]
fitzgen marked this conversation as resolved.
Show resolved Hide resolved
weight: Weight,

#[arbitrary(default)]
width: u8,

#[arbitrary(value = "2 + 2")]
length: u8,

height: u8,
}

fn arbitrary_weight(u: &mut Unstructured) -> arbitrary::Result<Weight> {
u.int_in_range(45..=56).map(Weight)
}

let parcel: Parcel = arbitrary_from(&[6, 199]);

// 45 + 6 = 51
assert_eq!(parcel.weight.0, 51);

// u8::default()
assert_eq!(parcel.width, 0);

// 2 + 2 = 4
assert_eq!(parcel.length, 4);

// 199 is the second byte, used by arbitrary
assert_eq!(parcel.height, 199);
}