diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 235b3da74..9c2c0ab91 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -180,6 +180,7 @@ impl<'a> CodeGenerator<'a> { self.append_doc(&fq_message_name, None); self.append_type_attributes(&fq_message_name); + self.append_message_attributes(&fq_message_name); self.push_indent(); self.buf .push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n"); @@ -270,6 +271,24 @@ impl<'a> CodeGenerator<'a> { } } + fn append_message_attributes(&mut self, fq_message_name: &str) { + assert_eq!(b'.', fq_message_name.as_bytes()[0]); + for attribute in self.config.message_attributes.get(fq_message_name) { + push_indent(self.buf, self.depth); + self.buf.push_str(attribute); + self.buf.push('\n'); + } + } + + fn append_enum_attributes(&mut self, fq_message_name: &str) { + assert_eq!(b'.', fq_message_name.as_bytes()[0]); + for attribute in self.config.enum_attributes.get(fq_message_name) { + push_indent(self.buf, self.depth); + self.buf.push_str(attribute); + self.buf.push('\n'); + } + } + fn append_field_attributes(&mut self, fq_message_name: &str, field_name: &str) { assert_eq!(b'.', fq_message_name.as_bytes()[0]); for attribute in self @@ -504,6 +523,7 @@ impl<'a> CodeGenerator<'a> { let oneof_name = format!("{}.{}", fq_message_name, oneof.name()); self.append_type_attributes(&oneof_name); + self.append_enum_attributes(&oneof_name); self.push_indent(); self.buf .push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n"); @@ -615,6 +635,7 @@ impl<'a> CodeGenerator<'a> { self.append_doc(&fq_proto_enum_name, None); self.append_type_attributes(&fq_proto_enum_name); + self.append_enum_attributes(&fq_proto_enum_name); self.push_indent(); self.buf.push_str( &format!("#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, {}::Enumeration)]\n",self.config.prost_path.as_deref().unwrap_or("::prost")), diff --git a/prost-build/src/fixtures/helloworld/_expected_helloworld.rs b/prost-build/src/fixtures/helloworld/_expected_helloworld.rs new file mode 100644 index 000000000..a64c4da3c --- /dev/null +++ b/prost-build/src/fixtures/helloworld/_expected_helloworld.rs @@ -0,0 +1,44 @@ +#[derive(derive_builder::Builder)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Message { + #[prost(string, tag = "1")] + pub say: ::prost::alloc::string::String, +} +#[derive(derive_builder::Builder)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Response { + #[prost(string, tag = "1")] + pub say: ::prost::alloc::string::String, +} +#[some_enum_attr(u8)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum ServingStatus { + Unknown = 0, + Serving = 1, + NotServing = 2, +} +impl ServingStatus { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + ServingStatus::Unknown => "UNKNOWN", + ServingStatus::Serving => "SERVING", + ServingStatus::NotServing => "NOT_SERVING", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "UNKNOWN" => Some(Self::Unknown), + "SERVING" => Some(Self::Serving), + "NOT_SERVING" => Some(Self::NotServing), + _ => None, + } + } +} diff --git a/prost-build/src/fixtures/helloworld/types.proto b/prost-build/src/fixtures/helloworld/types.proto index 4d9d5e0e2..5bf84aa67 100644 --- a/prost-build/src/fixtures/helloworld/types.proto +++ b/prost-build/src/fixtures/helloworld/types.proto @@ -9,3 +9,9 @@ message Message { message Response { string say = 1; } + +enum ServingStatus { + UNKNOWN = 0; + SERVING = 1; + NOT_SERVING = 2; +} diff --git a/prost-build/src/lib.rs b/prost-build/src/lib.rs index 189a5dc93..5502d4461 100644 --- a/prost-build/src/lib.rs +++ b/prost-build/src/lib.rs @@ -244,6 +244,8 @@ pub struct Config { map_type: PathMap, bytes_type: PathMap, type_attributes: PathMap, + message_attributes: PathMap, + enum_attributes: PathMap, field_attributes: PathMap, prost_types: bool, strip_enum_prefix: bool, @@ -468,6 +470,94 @@ impl Config { self } + /// Add additional attribute to matched messages. + /// + /// # Arguments + /// + /// **`paths`** - a path matching any number of types. It works the same way as in + /// [`btree_map`](#method.btree_map), just with the field name omitted. + /// + /// **`attribute`** - an arbitrary string to be placed before each matched type. The + /// expected usage are additional attributes, but anything is allowed. + /// + /// The calls to this method are cumulative. They don't overwrite previous calls and if a + /// type is matched by multiple calls of the method, all relevant attributes are added to + /// it. + /// + /// For things like serde it might be needed to combine with [field + /// attributes](#method.field_attribute). + /// + /// # Examples + /// + /// ```rust + /// # let mut config = prost_build::Config::new(); + /// // Nothing around uses floats, so we can derive real `Eq` in addition to `PartialEq`. + /// config.message_attribute(".", "#[derive(Eq)]"); + /// // Some messages want to be serializable with serde as well. + /// config.message_attribute("my_messages.MyMessageType", + /// "#[derive(Serialize)] #[serde(rename_all = \"snake_case\")]"); + /// config.message_attribute("my_messages.MyMessageType.MyNestedMessageType", + /// "#[derive(Serialize)] #[serde(rename_all = \"snake_case\")]"); + /// ``` + pub fn message_attribute(&mut self, path: P, attribute: A) -> &mut Self + where + P: AsRef, + A: AsRef, + { + self.message_attributes + .insert(path.as_ref().to_string(), attribute.as_ref().to_string()); + self + } + + /// Add additional attribute to matched enums and one-ofs. + /// + /// # Arguments + /// + /// **`paths`** - a path matching any number of types. It works the same way as in + /// [`btree_map`](#method.btree_map), just with the field name omitted. + /// + /// **`attribute`** - an arbitrary string to be placed before each matched type. The + /// expected usage are additional attributes, but anything is allowed. + /// + /// The calls to this method are cumulative. They don't overwrite previous calls and if a + /// type is matched by multiple calls of the method, all relevant attributes are added to + /// it. + /// + /// For things like serde it might be needed to combine with [field + /// attributes](#method.field_attribute). + /// + /// # Examples + /// + /// ```rust + /// # let mut config = prost_build::Config::new(); + /// // Nothing around uses floats, so we can derive real `Eq` in addition to `PartialEq`. + /// config.enum_attribute(".", "#[derive(Eq)]"); + /// // Some messages want to be serializable with serde as well. + /// config.enum_attribute("my_messages.MyEnumType", + /// "#[derive(Serialize)] #[serde(rename_all = \"snake_case\")]"); + /// config.enum_attribute("my_messages.MyMessageType.MyNestedEnumType", + /// "#[derive(Serialize)] #[serde(rename_all = \"snake_case\")]"); + /// ``` + /// + /// # Oneof fields + /// + /// The `oneof` fields don't have a type name of their own inside Protobuf. Therefore, the + /// field name can be used both with `enum_attribute` and `field_attribute` ‒ the first is + /// placed before the `enum` type definition, the other before the field inside corresponding + /// message `struct`. + /// + /// In other words, to place an attribute on the `enum` implementing the `oneof`, the match + /// would look like `my_messages.MyNestedMessageType.oneofname`. + pub fn enum_attribute(&mut self, path: P, attribute: A) -> &mut Self + where + P: AsRef, + A: AsRef, + { + self.enum_attributes + .insert(path.as_ref().to_string(), attribute.as_ref().to_string()); + self + } + /// Configures the code generator to use the provided service generator. pub fn service_generator(&mut self, service_generator: Box) -> &mut Self { self.service_generator = Some(service_generator); @@ -1099,6 +1189,8 @@ impl default::Default for Config { map_type: PathMap::default(), bytes_type: PathMap::default(), type_attributes: PathMap::default(), + message_attributes: PathMap::default(), + enum_attributes: PathMap::default(), field_attributes: PathMap::default(), prost_types: true, strip_enum_prefix: true, @@ -1425,6 +1517,37 @@ mod tests { assert_eq!(state.finalized, 3); } + #[test] + fn test_generate_message_attributes() { + let _ = env_logger::try_init(); + + let out_dir = std::env::temp_dir(); + + Config::new() + .out_dir(out_dir.clone()) + .message_attribute(".", "#[derive(derive_builder::Builder)]") + .enum_attribute(".", "#[some_enum_attr(u8)]") + .compile_protos( + &["src/fixtures/helloworld/hello.proto"], + &["src/fixtures/helloworld"], + ) + .unwrap(); + + let out_file = out_dir + .join("helloworld.rs") + .as_path() + .display() + .to_string(); + let expected_content = read_all_content("src/fixtures/helloworld/_expected_helloworld.rs") + .replace("\r\n", "\n"); + let content = read_all_content(&out_file).replace("\r\n", "\n"); + assert_eq!( + expected_content, content, + "Unexpected content: \n{}", + content + ); + } + #[test] fn test_generate_no_empty_outputs() { let _ = env_logger::try_init();