diff --git a/prost-build/src/config.rs b/prost-build/src/config.rs index a696e404b..4d5dcce44 100644 --- a/prost-build/src/config.rs +++ b/prost-build/src/config.rs @@ -835,37 +835,42 @@ impl Config { Ok(()) } - /// Compile `.proto` files into Rust files during a Cargo build with additional code generator - /// configuration options. - /// - /// This method is like the `prost_build::compile_protos` function, with the added ability to - /// specify non-default code generation options. See that function for more information about - /// the arguments and generated outputs. - /// - /// The `protos` and `includes` arguments are ignored if `skip_protoc_run` is specified. + /// Loads `.proto` files as a [`FileDescriptorSet`]. This allows inspection of the descriptors + /// before calling [`Config::compile_fds`]. This could be used to change [`Config`] + /// attributes after introspecting what is actually present in the `.proto` files. /// /// # Example `build.rs` /// /// ```rust,no_run - /// # use std::io::Result; - /// fn main() -> Result<()> { - /// let mut prost_build = prost_build::Config::new(); - /// prost_build.btree_map(&["."]); - /// prost_build.compile_protos(&["src/frontend.proto", "src/backend.proto"], &["src"])?; - /// Ok(()) + /// # use prost_types::FileDescriptorSet; + /// # use prost_build::Config; + /// fn main() -> std::io::Result<()> { + /// let mut config = Config::new(); + /// let file_descriptor_set = config.load_fds(&["src/frontend.proto", "src/backend.proto"], &["src"])?; + /// + /// // Add custom attributes to messages that are service inputs or outputs. + /// for file in &file_descriptor_set.file { + /// for service in &file.service { + /// for method in &service.method { + /// if let Some(input) = &method.input_type { + /// config.message_attribute(input, "#[derive(custom_proto::Input)]"); + /// } + /// if let Some(output) = &method.output_type { + /// config.message_attribute(output, "#[derive(custom_proto::Output)]"); + /// } + /// } + /// } + /// } + /// + /// config.compile_fds(file_descriptor_set) /// } /// ``` - pub fn compile_protos( + + pub fn load_fds( &mut self, protos: &[impl AsRef], includes: &[impl AsRef], - ) -> Result<()> { - // TODO: This should probably emit 'rerun-if-changed=PATH' directives for cargo, however - // according to [1] if any are output then those paths replace the default crate root, - // which is undesirable. Figure out how to do it in an additive way; perhaps gcc-rs has - // this figured out. - // [1]: http://doc.crates.io/build-script.html#outputs-of-the-build-script - + ) -> Result { let tmp; let file_descriptor_set_path = if let Some(path) = &self.file_descriptor_set_path { path.clone() @@ -952,6 +957,42 @@ impl Config { ) })?; + Ok(file_descriptor_set) + } + + /// Compile `.proto` files into Rust files during a Cargo build with additional code generator + /// configuration options. + /// + /// This method is like the `prost_build::compile_protos` function, with the added ability to + /// specify non-default code generation options. See that function for more information about + /// the arguments and generated outputs. + /// + /// The `protos` and `includes` arguments are ignored if `skip_protoc_run` is specified. + /// + /// # Example `build.rs` + /// + /// ```rust,no_run + /// # use std::io::Result; + /// fn main() -> Result<()> { + /// let mut prost_build = prost_build::Config::new(); + /// prost_build.btree_map(&["."]); + /// prost_build.compile_protos(&["src/frontend.proto", "src/backend.proto"], &["src"])?; + /// Ok(()) + /// } + /// ``` + pub fn compile_protos( + &mut self, + protos: &[impl AsRef], + includes: &[impl AsRef], + ) -> Result<()> { + // TODO: This should probably emit 'rerun-if-changed=PATH' directives for cargo, however + // according to [1] if any are output then those paths replace the default crate root, + // which is undesirable. Figure out how to do it in an additive way; perhaps gcc-rs has + // this figured out. + // [1]: http://doc.crates.io/build-script.html#outputs-of-the-build-script + + let file_descriptor_set = self.load_fds(protos, includes)?; + self.compile_fds(file_descriptor_set) } diff --git a/prost-build/src/fixtures/helloworld/_expected_helloworld.rs b/prost-build/src/fixtures/helloworld/_expected_helloworld.rs index 401ee90cd..6936a82d0 100644 --- a/prost-build/src/fixtures/helloworld/_expected_helloworld.rs +++ b/prost-build/src/fixtures/helloworld/_expected_helloworld.rs @@ -1,5 +1,6 @@ // This file is @generated by prost-build. #[derive(derive_builder::Builder)] +#[derive(custom_proto::Input)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Message { @@ -7,6 +8,7 @@ pub struct Message { pub say: ::prost::alloc::string::String, } #[derive(derive_builder::Builder)] +#[derive(custom_proto::Output)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Response { diff --git a/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs b/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs index 3f688c7e0..95a95fe65 100644 --- a/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs +++ b/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs @@ -1,5 +1,6 @@ // This file is @generated by prost-build. #[derive(derive_builder::Builder)] +#[derive(custom_proto::Input)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Message { @@ -7,6 +8,7 @@ pub struct Message { pub say: ::prost::alloc::string::String, } #[derive(derive_builder::Builder)] +#[derive(custom_proto::Output)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Response { diff --git a/prost-build/src/lib.rs b/prost-build/src/lib.rs index e3659a2c5..15b687af2 100644 --- a/prost-build/src/lib.rs +++ b/prost-build/src/lib.rs @@ -390,16 +390,36 @@ mod tests { let _ = env_logger::try_init(); let tempdir = tempfile::tempdir().unwrap(); - Config::new() + let mut config = Config::new(); + config .out_dir(tempdir.path()) + // Add attributes to all messages and enums .message_attribute(".", "#[derive(derive_builder::Builder)]") - .enum_attribute(".", "#[some_enum_attr(u8)]") - .compile_protos( + .enum_attribute(".", "#[some_enum_attr(u8)]"); + + let fds = config + .load_fds( &["src/fixtures/helloworld/hello.proto"], &["src/fixtures/helloworld"], ) .unwrap(); + // Add custom attributes to messages that are service inputs or outputs. + for file in &fds.file { + for service in &file.service { + for method in &service.method { + if let Some(input) = &method.input_type { + config.message_attribute(input, "#[derive(custom_proto::Input)]"); + } + if let Some(output) = &method.output_type { + config.message_attribute(output, "#[derive(custom_proto::Output)]"); + } + } + } + } + + config.compile_fds(fds).unwrap(); + let out_file = tempdir.path().join("helloworld.rs"); #[cfg(feature = "format")] let expected_content =