diff --git a/staged-builder-internals/src/lib.rs b/staged-builder-internals/src/lib.rs index 1e18c4b..dee8770 100644 --- a/staged-builder-internals/src/lib.rs +++ b/staged-builder-internals/src/lib.rs @@ -28,6 +28,8 @@ use syn::{ /// crate. Defaults to `::staged_builder`. /// * `mod` - The name of the submodule that will contain the generated builder types. Defaults to the struct's name /// converted to `snake_case`. +/// * `inline` - Causes the generated builder types to be defined in the same module as the struct, rather than a +/// submodule. /// /// # Field options /// @@ -235,37 +237,55 @@ fn expand(input: DeriveInput) -> Result { let overrides = StructOverrides::new(&input.attrs)?; let fields = resolve_fields(&overrides, fields)?; - let vis = &input.vis; - let module_name = module_name(&overrides, &input); - let builder_impl = builder_impl(&input, &overrides, &fields); + let module = module(&input, &overrides, &fields); - let module_docs = format!("Builder types for [`{}`].", &input.ident); + let tokens = quote! { + #builder_impl + #module + }; - let builder = builder(&input); - let default = default_impl(&overrides, &fields); + Ok(tokens) +} + +fn module( + input: &DeriveInput, + overrides: &StructOverrides, + fields: &[ResolvedField<'_>], +) -> TokenStream { + let builder = builder(input, overrides); + let default = default_impl(overrides, fields); let stages = fields .iter() .enumerate() .filter(|(_, f)| f.default.is_none()) - .map(|(i, _)| stage(&input, i, &fields)); - let final_stage = final_stage(&input, &overrides, &fields); + .map(|(i, _)| stage(input, overrides, i, fields)); + let final_stage = final_stage(input, overrides, fields); + + let parts = quote! { + #builder + #default + #(#stages)* + #final_stage + }; - let tokens = quote! { - #builder_impl + if overrides.inline { + return parts; + } + + let vis = &input.vis; + let module_name = module_name(overrides, input); + + let module_docs = format!("Builder types for [`{}`].", &input.ident); + quote! { #[doc = #module_docs] #vis mod #module_name { use super::*; - #builder - #default - #(#stages)* - #final_stage + #parts } - }; - - Ok(tokens) + } } fn module_name(overrides: &StructOverrides, input: &DeriveInput) -> Ident { @@ -283,15 +303,20 @@ fn builder_impl( let name = &input.ident; let vis = &input.vis; - let module_name = module_name(overrides, input); - let builder_name = initial_stage(fields).unwrap_or_else(final_name); + let module_path = if overrides.inline { + quote!() + } else { + let module_name = module_name(overrides, input); + quote!(#module_name::) + }; + let stage_name = initial_stage(fields).unwrap_or_else(final_name); let private = overrides.private(); quote! { impl #name { /// Returns a new builder. #[inline] - #vis fn builder() -> #module_name::Builder<#module_name::#builder_name> { + #vis fn builder() -> #module_path Builder<#module_path #stage_name> { #private::Default::default() } } @@ -305,12 +330,20 @@ fn initial_stage(fields: &[ResolvedField<'_>]) -> Option { .map(|f| stage_name(f)) } -fn builder(input: &DeriveInput) -> TokenStream { - let docs = format!("A builder for [{0}](super::{0}).", input.ident); +fn builder(input: &DeriveInput, overrides: &StructOverrides) -> TokenStream { + let link = if overrides.inline { + format!("[{}]", input.ident) + } else { + format!("[{0}](super::{0})", input.ident) + }; + + let docs = format!("A builder for {link}"); + + let vis = stage_vis(&input.vis, overrides); quote! { #[doc = #docs] - pub struct Builder(T); + #vis struct Builder(T); } } @@ -345,8 +378,13 @@ fn default_field_initializers(fields: &[ResolvedField<'_>]) -> TokenStream { quote!(#(#fields,)*) } -fn stage(input: &DeriveInput, idx: usize, fields: &[ResolvedField<'_>]) -> TokenStream { - let vis = stage_vis(&input.vis); +fn stage( + input: &DeriveInput, + overrides: &StructOverrides, + idx: usize, + fields: &[ResolvedField<'_>], +) -> TokenStream { + let vis = stage_vis(&input.vis, overrides); let field = &fields[idx]; let name = field.field.ident.as_ref().unwrap(); @@ -397,7 +435,11 @@ fn stage(input: &DeriveInput, idx: usize, fields: &[ResolvedField<'_>]) -> Token } } -fn stage_vis(vis: &Visibility) -> TokenStream { +fn stage_vis(vis: &Visibility, overrides: &StructOverrides) -> TokenStream { + if overrides.inline { + return quote!(#vis); + } + match vis { Visibility::Public(_) => quote!(#vis), Visibility::Restricted(restricted) => { @@ -440,7 +482,7 @@ fn final_stage( overrides: &StructOverrides, fields: &[ResolvedField<'_>], ) -> TokenStream { - let vis = stage_vis(&input.vis); + let vis = stage_vis(&input.vis, overrides); let builder_name = final_name(); let struct_name = &input.ident; let names = fields.iter().map(|f| f.field.ident.as_ref().unwrap()); @@ -459,7 +501,7 @@ fn final_stage( let build = if overrides.validate { validated_build(input, overrides, fields) } else { - unvalidated_build(input, fields) + unvalidated_build(input, overrides, fields) }; quote! { @@ -610,6 +652,11 @@ fn validated_build( fields: &[ResolvedField<'_>], ) -> TokenStream { let struct_name = &input.ident; + let struct_path = if overrides.inline { + quote!(#struct_name) + } else { + quote!(super::#struct_name) + }; let names = fields .iter() .map(|f| f.field.ident.as_ref().unwrap()) @@ -623,10 +670,10 @@ fn validated_build( pub fn build( self, ) -> #private::Result< - super::#struct_name, - ::Error, + #struct_path, + <#struct_path as #crate_::Validate>::Error, > { - let value = super::#struct_name { + let value = #struct_path { #(#names: self.0.#names,)* }; #crate_::Validate::validate(&value)?; @@ -635,8 +682,17 @@ fn validated_build( } } -fn unvalidated_build(input: &DeriveInput, fields: &[ResolvedField<'_>]) -> TokenStream { +fn unvalidated_build( + input: &DeriveInput, + overrides: &StructOverrides, + fields: &[ResolvedField<'_>], +) -> TokenStream { let struct_name = &input.ident; + let struct_path = if overrides.inline { + quote!(#struct_name) + } else { + quote!(super::#struct_name) + }; let names = fields .iter() .map(|f| f.field.ident.as_ref().unwrap()) @@ -644,8 +700,8 @@ fn unvalidated_build(input: &DeriveInput, fields: &[ResolvedField<'_>]) -> Token quote! { #[inline] - pub fn build(self) -> super::#struct_name { - super::#struct_name { + pub fn build(self) -> #struct_path { + #struct_path { #(#names: self.0.#names,)* } } @@ -682,6 +738,7 @@ struct StructOverrides { crate_: Option, #[struct_meta(name = "mod")] mod_: Option, + inline: bool, } impl StructOverrides { diff --git a/staged-builder/tests/test.rs b/staged-builder/tests/test.rs index 6beaca9..793823c 100644 --- a/staged-builder/tests/test.rs +++ b/staged-builder/tests/test.rs @@ -188,3 +188,23 @@ fn closure_convert() { }; assert_eq!(actual, expected); } + +mod inline { + use staged_builder::staged_builder; + + #[derive(PartialEq, Debug)] + #[staged_builder] + #[builder(inline)] + struct Inline { + a: i32, + } + + #[test] + fn inline() { + let builder: Builder = Inline::builder(); + let stage: Builder = builder.a(1); + let actual: Inline = stage.build(); + let expected = Inline { a: 1 }; + assert_eq!(actual, expected); + } +}