From 3e3940c3944aa8ffe55906ef65e4729e9337fb94 Mon Sep 17 00:00:00 2001 From: Nick Fitzgerald Date: Tue, 14 Jun 2022 11:29:53 -0700 Subject: [PATCH] derive: Protect against unbounded recursion when `u.is_empty()` Fixes https://github.com/rust-fuzz/arbitrary/issues/107 --- derive/src/lib.rs | 71 +++++++++++++++++++++++++++++++++++++++++------ tests/derive.rs | 42 ++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 8 deletions(-) diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 983bd68..2c2ce36 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -12,7 +12,13 @@ pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStr let (lifetime_without_bounds, lifetime_with_bounds) = build_arbitrary_lifetime(input.generics.clone()); - let arbitrary_method = gen_arbitrary_method(&input, lifetime_without_bounds.clone()); + let recursive_count = syn::Ident::new( + &format!("RECURSIVE_COUNT_{}", input.ident), + Span::call_site(), + ); + + let arbitrary_method = + gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count); let size_hint_method = gen_size_hint_method(&input); let name = input.ident; // Add a bound `T: Arbitrary` to every type parameter T. @@ -29,6 +35,10 @@ pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStr let (_, ty_generics, where_clause) = generics.split_for_impl(); (quote! { + thread_local! { + static #recursive_count: Cell = Cell::new(0); + } + impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause { #arbitrary_method #size_hint_method @@ -67,21 +77,51 @@ fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeDef) -> Generics { generics } -fn gen_arbitrary_method(input: &DeriveInput, lifetime: LifetimeDef) -> TokenStream { +fn with_recursive_count_guard( + recursive_count: &syn::Ident, + expr: impl quote::ToTokens, +) -> impl quote::ToTokens { + quote! { + #recursive_count.with(|count| { + if count.get() > 0 && u.is_empty() { + return Err(arbitrary::Error::NotEnoughData); + } + + count.set(count.get() + 1); + let result = { #expr }; + count.set(count.get() - 1); + + result + }) + } +} + +fn gen_arbitrary_method( + input: &DeriveInput, + lifetime: LifetimeDef, + recursive_count: &syn::Ident, +) -> TokenStream { let ident = &input.ident; + let arbitrary_structlike = |fields| { let arbitrary = construct(fields, |_, _| quote!(arbitrary::Arbitrary::arbitrary(u)?)); + let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) }); + let arbitrary_take_rest = construct_take_rest(fields); + let take_rest_body = + with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) }); + quote! { fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result { - Ok(#ident #arbitrary) + #body } fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result { - Ok(#ident #arbitrary_take_rest) + #take_rest_body } } }; + match &input.data { Data::Struct(data) => arbitrary_structlike(&data.fields), Data::Union(data) => arbitrary_structlike(&Fields::Named(data.fields.clone())), @@ -101,8 +141,10 @@ fn gen_arbitrary_method(input: &DeriveInput, lifetime: LifetimeDef) -> TokenStre quote! { #idx => #ident::#variant_name #ctor } }); let count = data.variants.len() as u64; - quote! { - fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result { + + let arbitrary = with_recursive_count_guard( + recursive_count, + quote! { // Use a multiply + shift to generate a ranged random number // with slight bias. For details, see: // https://lemire.me/blog/2016/06/30/fast-random-shuffling @@ -110,9 +152,12 @@ fn gen_arbitrary_method(input: &DeriveInput, lifetime: LifetimeDef) -> TokenStre #(#variants,)* _ => unreachable!() }) - } + }, + ); - fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result { + let arbitrary_take_rest = with_recursive_count_guard( + recursive_count, + quote! { // Use a multiply + shift to generate a ranged random number // with slight bias. For details, see: // https://lemire.me/blog/2016/06/30/fast-random-shuffling @@ -120,6 +165,16 @@ fn gen_arbitrary_method(input: &DeriveInput, lifetime: LifetimeDef) -> TokenStre #(#variants_take_rest,)* _ => unreachable!() }) + }, + ); + + quote! { + fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result { + #arbitrary + } + + fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result { + #arbitrary_take_rest } } } diff --git a/tests/derive.rs b/tests/derive.rs index f107688..af44e7c 100755 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -186,3 +186,45 @@ fn two_lifetimes() { assert_eq!(lower, 0); assert_eq!(upper, None); } + +#[test] +fn recursive_and_empty_input() { + // None of the following derives should result in a stack overflow. See + // https://github.com/rust-fuzz/arbitrary/issues/107 for details. + + #[derive(Debug, Arbitrary)] + enum Nat { + Succ(Box), + Zero, + } + + let _ = Nat::arbitrary(&mut Unstructured::new(&[])); + + #[derive(Debug, Arbitrary)] + enum Nat2 { + Zero, + Succ(Box), + } + + let _ = Nat2::arbitrary(&mut Unstructured::new(&[])); + + #[derive(Debug, Arbitrary)] + struct Nat3 { + f: Option>, + } + + let _ = Nat3::arbitrary(&mut Unstructured::new(&[])); + + #[derive(Debug, Arbitrary)] + struct Nat4(Option>); + + let _ = Nat4::arbitrary(&mut Unstructured::new(&[])); + + #[derive(Debug, Arbitrary)] + enum Nat5 { + Zero, + Succ { f: Box }, + } + + let _ = Nat5::arbitrary(&mut Unstructured::new(&[])); +}