Skip to content

Commit

Permalink
derive: Protect against unbounded recursion when u.is_empty()
Browse files Browse the repository at this point in the history
Fixes #107
  • Loading branch information
fitzgen committed Jun 14, 2022
1 parent 5934aa1 commit 3e3940c
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 8 deletions.
71 changes: 63 additions & 8 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStr
let (lifetime_without_bounds, lifetime_with_bounds) =
build_arbitrary_lifetime(input.generics.clone());

let arbitrary_method = gen_arbitrary_method(&input, lifetime_without_bounds.clone());
let recursive_count = syn::Ident::new(
&format!("RECURSIVE_COUNT_{}", input.ident),
Span::call_site(),
);

let arbitrary_method =
gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count);
let size_hint_method = gen_size_hint_method(&input);
let name = input.ident;
// Add a bound `T: Arbitrary` to every type parameter T.
Expand All @@ -29,6 +35,10 @@ pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStr
let (_, ty_generics, where_clause) = generics.split_for_impl();

(quote! {
thread_local! {
static #recursive_count: Cell<u32> = Cell::new(0);
}

impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
#arbitrary_method
#size_hint_method
Expand Down Expand Up @@ -67,21 +77,51 @@ fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeDef) -> Generics {
generics
}

fn gen_arbitrary_method(input: &DeriveInput, lifetime: LifetimeDef) -> TokenStream {
fn with_recursive_count_guard(
recursive_count: &syn::Ident,
expr: impl quote::ToTokens,
) -> impl quote::ToTokens {
quote! {
#recursive_count.with(|count| {
if count.get() > 0 && u.is_empty() {
return Err(arbitrary::Error::NotEnoughData);
}

count.set(count.get() + 1);
let result = { #expr };
count.set(count.get() - 1);

result
})
}
}

fn gen_arbitrary_method(
input: &DeriveInput,
lifetime: LifetimeDef,
recursive_count: &syn::Ident,
) -> TokenStream {
let ident = &input.ident;

let arbitrary_structlike = |fields| {
let arbitrary = construct(fields, |_, _| quote!(arbitrary::Arbitrary::arbitrary(u)?));
let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) });

let arbitrary_take_rest = construct_take_rest(fields);
let take_rest_body =
with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) });

quote! {
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
Ok(#ident #arbitrary)
#body
}

fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
Ok(#ident #arbitrary_take_rest)
#take_rest_body
}
}
};

match &input.data {
Data::Struct(data) => arbitrary_structlike(&data.fields),
Data::Union(data) => arbitrary_structlike(&Fields::Named(data.fields.clone())),
Expand All @@ -101,25 +141,40 @@ fn gen_arbitrary_method(input: &DeriveInput, lifetime: LifetimeDef) -> TokenStre
quote! { #idx => #ident::#variant_name #ctor }
});
let count = data.variants.len() as u64;
quote! {
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {

let arbitrary = with_recursive_count_guard(
recursive_count,
quote! {
// Use a multiply + shift to generate a ranged random number
// with slight bias. For details, see:
// https://lemire.me/blog/2016/06/30/fast-random-shuffling
Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(u)?) * #count) >> 32 {
#(#variants,)*
_ => unreachable!()
})
}
},
);

fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
let arbitrary_take_rest = with_recursive_count_guard(
recursive_count,
quote! {
// Use a multiply + shift to generate a ranged random number
// with slight bias. For details, see:
// https://lemire.me/blog/2016/06/30/fast-random-shuffling
Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(&mut u)?) * #count) >> 32 {
#(#variants_take_rest,)*
_ => unreachable!()
})
},
);

quote! {
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#arbitrary
}

fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#arbitrary_take_rest
}
}
}
Expand Down
42 changes: 42 additions & 0 deletions tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,45 @@ fn two_lifetimes() {
assert_eq!(lower, 0);
assert_eq!(upper, None);
}

#[test]
fn recursive_and_empty_input() {
// None of the following derives should result in a stack overflow. See
// https://github.com/rust-fuzz/arbitrary/issues/107 for details.

#[derive(Debug, Arbitrary)]
enum Nat {
Succ(Box<Nat>),
Zero,
}

let _ = Nat::arbitrary(&mut Unstructured::new(&[]));

#[derive(Debug, Arbitrary)]
enum Nat2 {
Zero,
Succ(Box<Nat2>),
}

let _ = Nat2::arbitrary(&mut Unstructured::new(&[]));

#[derive(Debug, Arbitrary)]
struct Nat3 {
f: Option<Box<Nat3>>,
}

let _ = Nat3::arbitrary(&mut Unstructured::new(&[]));

#[derive(Debug, Arbitrary)]
struct Nat4(Option<Box<Nat4>>);

let _ = Nat4::arbitrary(&mut Unstructured::new(&[]));

#[derive(Debug, Arbitrary)]
enum Nat5 {
Zero,
Succ { f: Box<Nat5> },
}

let _ = Nat5::arbitrary(&mut Unstructured::new(&[]));
}

0 comments on commit 3e3940c

Please sign in to comment.