Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

derive: Protect against unbounded recursion when u.is_empty() #109

Merged
merged 1 commit into from
Jun 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 63 additions & 8 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStr
let (lifetime_without_bounds, lifetime_with_bounds) =
build_arbitrary_lifetime(input.generics.clone());

let arbitrary_method = gen_arbitrary_method(&input, lifetime_without_bounds.clone());
let recursive_count = syn::Ident::new(
&format!("RECURSIVE_COUNT_{}", input.ident),
Span::call_site(),
);

let arbitrary_method =
gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count);
let size_hint_method = gen_size_hint_method(&input);
let name = input.ident;
// Add a bound `T: Arbitrary` to every type parameter T.
Expand All @@ -29,6 +35,10 @@ pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStr
let (_, ty_generics, where_clause) = generics.split_for_impl();

(quote! {
thread_local! {
static #recursive_count: std::cell::Cell<u32> = std::cell::Cell::new(0);
}

impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
#arbitrary_method
#size_hint_method
Expand Down Expand Up @@ -67,21 +77,51 @@ fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeDef) -> Generics {
generics
}

fn gen_arbitrary_method(input: &DeriveInput, lifetime: LifetimeDef) -> TokenStream {
fn with_recursive_count_guard(
recursive_count: &syn::Ident,
expr: impl quote::ToTokens,
) -> impl quote::ToTokens {
quote! {
#recursive_count.with(|count| {
if count.get() > 0 && u.is_empty() {
return Err(arbitrary::Error::NotEnoughData);
}

count.set(count.get() + 1);
let result = { #expr };
count.set(count.get() - 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work as intended, i.e. is it OK to skip decreasing `count` if `#expr` aborts through `?`?

Generated code for my struct, for consideration
fn arbitrary(u: &mut Unstructured) -> Result<Self> {
    RECURSIVE_COUNT_Nah.with(|count| {
        if count.get() > 0 && u.is_empty() {
            return Err(arbitrary::Error::NotEnoughData);
        }
        count.set(count.get() + 1);
        let result = {
            Ok(
                match (u64::from(<u32 as Arbitrary>::arbitrary(u)?) * 2u64) >> 32 {
                    0u64 => Nah::Foo(
                        Arbitrary::arbitrary(u)?,
                        Arbitrary::arbitrary(u)?,
                        Arbitrary::arbitrary(u)?,
                    ),
                    1u64 => Nah::Bar(Arbitrary::arbitrary(u)?, Arbitrary::arbitrary(u)?),
                    _ => panic!("internal error: entered unreachable code"),
                },
            )
        };
        count.set(count.get() - 1);
        result
    })
}

Copy link
Contributor

@jcaesar jcaesar Jun 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking about it further, wouldn't it be better to only modify count when the unstructured is already empty?

  • No need to access the thread local under normal conditions (performance micro-optimization, I know.)
  • It's entirely possible that a recursive struct is benign and does finish generating on zeros only (after recursing a few times from the Unstructured input). But this currently aborts before trying that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to review a pull request if you want to make one!


result
})
}
}

fn gen_arbitrary_method(
input: &DeriveInput,
lifetime: LifetimeDef,
recursive_count: &syn::Ident,
) -> TokenStream {
let ident = &input.ident;

let arbitrary_structlike = |fields| {
let arbitrary = construct(fields, |_, _| quote!(arbitrary::Arbitrary::arbitrary(u)?));
let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) });

let arbitrary_take_rest = construct_take_rest(fields);
let take_rest_body =
with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) });

quote! {
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
Ok(#ident #arbitrary)
#body
}

fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
Ok(#ident #arbitrary_take_rest)
#take_rest_body
}
}
};

match &input.data {
Data::Struct(data) => arbitrary_structlike(&data.fields),
Data::Union(data) => arbitrary_structlike(&Fields::Named(data.fields.clone())),
Expand All @@ -101,25 +141,40 @@ fn gen_arbitrary_method(input: &DeriveInput, lifetime: LifetimeDef) -> TokenStre
quote! { #idx => #ident::#variant_name #ctor }
});
let count = data.variants.len() as u64;
quote! {
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {

let arbitrary = with_recursive_count_guard(
recursive_count,
quote! {
// Use a multiply + shift to generate a ranged random number
// with slight bias. For details, see:
// https://lemire.me/blog/2016/06/30/fast-random-shuffling
Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(u)?) * #count) >> 32 {
#(#variants,)*
_ => unreachable!()
})
}
},
);

fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
let arbitrary_take_rest = with_recursive_count_guard(
recursive_count,
quote! {
// Use a multiply + shift to generate a ranged random number
// with slight bias. For details, see:
// https://lemire.me/blog/2016/06/30/fast-random-shuffling
Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(&mut u)?) * #count) >> 32 {
#(#variants_take_rest,)*
_ => unreachable!()
})
},
);

quote! {
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#arbitrary
}

fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#arbitrary_take_rest
}
}
}
Expand Down
42 changes: 42 additions & 0 deletions tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,45 @@ fn two_lifetimes() {
assert_eq!(lower, 0);
assert_eq!(upper, None);
}

// Regression test for self-referential types derived with `Arbitrary` when the
// `Unstructured` input is empty. Each case below declares a type that contains
// itself (through `Box` or `Option<Box<_>>`) and calls `arbitrary` on an empty
// byte slice. The returned `Result` is deliberately discarded with `let _ =`:
// the test makes no claim about `Ok` vs `Err`, only that the call returns.
#[test]
fn recursive_and_empty_input() {
// None of the following derives should result in a stack overflow. See
// https://github.com/rust-fuzz/arbitrary/issues/107 for details.

// Case 1: enum whose recursive tuple variant is declared first.
#[derive(Debug, Arbitrary)]
enum Nat {
Succ(Box<Nat>),
Zero,
}

let _ = Nat::arbitrary(&mut Unstructured::new(&[]));

// Case 2: same shape as `Nat`, but the recursive variant is declared last.
#[derive(Debug, Arbitrary)]
enum Nat2 {
Zero,
Succ(Box<Nat2>),
}

let _ = Nat2::arbitrary(&mut Unstructured::new(&[]));

// Case 3: struct with a named field that recursively holds itself.
#[derive(Debug, Arbitrary)]
struct Nat3 {
f: Option<Box<Nat3>>,
}

let _ = Nat3::arbitrary(&mut Unstructured::new(&[]));

// Case 4: tuple struct that recursively holds itself.
#[derive(Debug, Arbitrary)]
struct Nat4(Option<Box<Nat4>>);

let _ = Nat4::arbitrary(&mut Unstructured::new(&[]));

// Case 5: enum whose recursive variant uses named (struct-like) fields.
#[derive(Debug, Arbitrary)]
enum Nat5 {
Zero,
Succ { f: Box<Nat5> },
}

let _ = Nat5::arbitrary(&mut Unstructured::new(&[]));
}