Skip to content

Commit

Permalink
Merge pull request #20 from sfackler/inline
Browse files Browse the repository at this point in the history
Allow inlined types
  • Loading branch information
sfackler committed Jan 31, 2024
2 parents fef49c5 + a94ae6b commit 466ed56
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 34 deletions.
125 changes: 91 additions & 34 deletions staged-builder-internals/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ use syn::{
/// crate. Defaults to `::staged_builder`.
/// * `mod` - The name of the submodule that will contain the generated builder types. Defaults to the struct's name
/// converted to `snake_case`.
/// * `inline` - Causes the generated builder types to be defined in the same module as the struct, rather than a
/// submodule.
///
/// # Field options
///
Expand Down Expand Up @@ -235,37 +237,55 @@ fn expand(input: DeriveInput) -> Result<TokenStream, Error> {
let overrides = StructOverrides::new(&input.attrs)?;
let fields = resolve_fields(&overrides, fields)?;

let vis = &input.vis;
let module_name = module_name(&overrides, &input);

let builder_impl = builder_impl(&input, &overrides, &fields);
let module = module(&input, &overrides, &fields);

let module_docs = format!("Builder types for [`{}`].", &input.ident);
let tokens = quote! {
#builder_impl
#module
};

let builder = builder(&input);
let default = default_impl(&overrides, &fields);
Ok(tokens)
}

fn module(
input: &DeriveInput,
overrides: &StructOverrides,
fields: &[ResolvedField<'_>],
) -> TokenStream {
let builder = builder(input, overrides);
let default = default_impl(overrides, fields);
let stages = fields
.iter()
.enumerate()
.filter(|(_, f)| f.default.is_none())
.map(|(i, _)| stage(&input, i, &fields));
let final_stage = final_stage(&input, &overrides, &fields);
.map(|(i, _)| stage(input, overrides, i, fields));
let final_stage = final_stage(input, overrides, fields);

let parts = quote! {
#builder
#default
#(#stages)*
#final_stage
};

let tokens = quote! {
#builder_impl
if overrides.inline {
return parts;
}

let vis = &input.vis;
let module_name = module_name(overrides, input);

let module_docs = format!("Builder types for [`{}`].", &input.ident);

quote! {
#[doc = #module_docs]
#vis mod #module_name {
use super::*;

#builder
#default
#(#stages)*
#final_stage
#parts
}
};

Ok(tokens)
}
}

fn module_name(overrides: &StructOverrides, input: &DeriveInput) -> Ident {
Expand All @@ -283,15 +303,20 @@ fn builder_impl(
let name = &input.ident;
let vis = &input.vis;

let module_name = module_name(overrides, input);
let builder_name = initial_stage(fields).unwrap_or_else(final_name);
let module_path = if overrides.inline {
quote!()
} else {
let module_name = module_name(overrides, input);
quote!(#module_name::)
};
let stage_name = initial_stage(fields).unwrap_or_else(final_name);
let private = overrides.private();

quote! {
impl #name {
/// Returns a new builder.
#[inline]
#vis fn builder() -> #module_name::Builder<#module_name::#builder_name> {
#vis fn builder() -> #module_path Builder<#module_path #stage_name> {
#private::Default::default()
}
}
Expand All @@ -305,12 +330,20 @@ fn initial_stage(fields: &[ResolvedField<'_>]) -> Option<Ident> {
.map(|f| stage_name(f))
}

fn builder(input: &DeriveInput) -> TokenStream {
let docs = format!("A builder for [{0}](super::{0}).", input.ident);
fn builder(input: &DeriveInput, overrides: &StructOverrides) -> TokenStream {
let link = if overrides.inline {
format!("[{}]", input.ident)
} else {
format!("[{0}](super::{0})", input.ident)
};

let docs = format!("A builder for {link}");

let vis = stage_vis(&input.vis, overrides);

quote! {
#[doc = #docs]
pub struct Builder<T>(T);
#vis struct Builder<T>(T);
}
}

Expand Down Expand Up @@ -345,8 +378,13 @@ fn default_field_initializers(fields: &[ResolvedField<'_>]) -> TokenStream {
quote!(#(#fields,)*)
}

fn stage(input: &DeriveInput, idx: usize, fields: &[ResolvedField<'_>]) -> TokenStream {
let vis = stage_vis(&input.vis);
fn stage(
input: &DeriveInput,
overrides: &StructOverrides,
idx: usize,
fields: &[ResolvedField<'_>],
) -> TokenStream {
let vis = stage_vis(&input.vis, overrides);
let field = &fields[idx];
let name = field.field.ident.as_ref().unwrap();

Expand Down Expand Up @@ -397,7 +435,11 @@ fn stage(input: &DeriveInput, idx: usize, fields: &[ResolvedField<'_>]) -> Token
}
}

fn stage_vis(vis: &Visibility) -> TokenStream {
fn stage_vis(vis: &Visibility, overrides: &StructOverrides) -> TokenStream {
if overrides.inline {
return quote!(#vis);
}

match vis {
Visibility::Public(_) => quote!(#vis),
Visibility::Restricted(restricted) => {
Expand Down Expand Up @@ -440,7 +482,7 @@ fn final_stage(
overrides: &StructOverrides,
fields: &[ResolvedField<'_>],
) -> TokenStream {
let vis = stage_vis(&input.vis);
let vis = stage_vis(&input.vis, overrides);
let builder_name = final_name();
let struct_name = &input.ident;
let names = fields.iter().map(|f| f.field.ident.as_ref().unwrap());
Expand All @@ -459,7 +501,7 @@ fn final_stage(
let build = if overrides.validate {
validated_build(input, overrides, fields)
} else {
unvalidated_build(input, fields)
unvalidated_build(input, overrides, fields)
};

quote! {
Expand Down Expand Up @@ -610,6 +652,11 @@ fn validated_build(
fields: &[ResolvedField<'_>],
) -> TokenStream {
let struct_name = &input.ident;
let struct_path = if overrides.inline {
quote!(#struct_name)
} else {
quote!(super::#struct_name)
};
let names = fields
.iter()
.map(|f| f.field.ident.as_ref().unwrap())
Expand All @@ -623,10 +670,10 @@ fn validated_build(
pub fn build(
self,
) -> #private::Result<
super::#struct_name,
<super::#struct_name as #crate_::Validate>::Error,
#struct_path,
<#struct_path as #crate_::Validate>::Error,
> {
let value = super::#struct_name {
let value = #struct_path {
#(#names: self.0.#names,)*
};
#crate_::Validate::validate(&value)?;
Expand All @@ -635,17 +682,26 @@ fn validated_build(
}
}

fn unvalidated_build(input: &DeriveInput, fields: &[ResolvedField<'_>]) -> TokenStream {
fn unvalidated_build(
input: &DeriveInput,
overrides: &StructOverrides,
fields: &[ResolvedField<'_>],
) -> TokenStream {
let struct_name = &input.ident;
let struct_path = if overrides.inline {
quote!(#struct_name)
} else {
quote!(super::#struct_name)
};
let names = fields
.iter()
.map(|f| f.field.ident.as_ref().unwrap())
.collect::<Vec<_>>();

quote! {
#[inline]
pub fn build(self) -> super::#struct_name {
super::#struct_name {
pub fn build(self) -> #struct_path {
#struct_path {
#(#names: self.0.#names,)*
}
}
Expand Down Expand Up @@ -682,6 +738,7 @@ struct StructOverrides {
crate_: Option<Path>,
#[struct_meta(name = "mod")]
mod_: Option<Ident>,
inline: bool,
}

impl StructOverrides {
Expand Down
20 changes: 20 additions & 0 deletions staged-builder/tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,23 @@ fn closure_convert() {
};
assert_eq!(actual, expected);
}

mod inline {
use staged_builder::staged_builder;

#[derive(PartialEq, Debug)]
#[staged_builder]
#[builder(inline)]
struct Inline {
a: i32,
}

#[test]
fn inline() {
let builder: Builder<AStage> = Inline::builder();
let stage: Builder<Complete> = builder.a(1);
let actual: Inline = stage.build();
let expected = Inline { a: 1 };
assert_eq!(actual, expected);
}
}

0 comments on commit 466ed56

Please sign in to comment.