Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix getter/setter generated for optional enum fields should use option #1060

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 57 additions & 52 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ impl Field {
}

fn rust_name(&self) -> String {
to_snake(self.descriptor.name())
to_snake(self.descriptor.name_fallback())
}
}

Expand All @@ -83,7 +83,7 @@ impl OneofField {
}

fn rust_name(&self) -> String {
to_snake(self.descriptor.name())
to_snake(self.descriptor.name_fallback())
}
}

Expand Down Expand Up @@ -156,9 +156,9 @@ impl<'a> CodeGenerator<'a> {
}

fn append_message(&mut self, message: DescriptorProto) {
debug!(" message: {:?}", message.name());
debug!(" message: {:?}", message.name_fallback());

let message_name = message.name().to_string();
let message_name = message.name_fallback().to_string();
let fq_message_name = self.fq_name(&message_name);

// Skip external types.
Expand All @@ -184,10 +184,10 @@ impl<'a> CodeGenerator<'a> {
{
let key = nested_type.field[0].clone();
let value = nested_type.field[1].clone();
assert_eq!("key", key.name());
assert_eq!("value", value.name());
assert_eq!("key", key.name_fallback());
assert_eq!("value", value.name_fallback());

let name = format!("{}.{}", &fq_message_name, nested_type.name());
let name = format!("{}.{}", &fq_message_name, nested_type.name_fallback());
Either::Right((name, (key, value)))
} else {
Either::Left((nested_type, idx))
Expand Down Expand Up @@ -398,7 +398,7 @@ impl<'a> CodeGenerator<'a> {
}

fn append_field(&mut self, fq_message_name: &str, field: &Field) {
let type_ = field.descriptor.r#type();
let type_ = field.descriptor.r#type_fallback();
let repeated = field.descriptor.label == Some(Label::Repeated as i32);
let deprecated = self.deprecated(&field.descriptor);
let optional = self.optional(&field.descriptor);
Expand All @@ -412,7 +412,7 @@ impl<'a> CodeGenerator<'a> {
boxed
);

self.append_doc(fq_message_name, Some(field.descriptor.name()));
self.append_doc(fq_message_name, Some(field.descriptor.name_fallback()));

if deprecated {
self.push_indent();
Expand All @@ -428,14 +428,14 @@ impl<'a> CodeGenerator<'a> {
let bytes_type = self
.config
.bytes_type
.get_first_field(fq_message_name, field.descriptor.name())
.get_first_field(fq_message_name, field.descriptor.name_fallback())
.copied()
.unwrap_or_default();
self.buf
.push_str(&format!("={:?}", bytes_type.annotation()));
}

match field.descriptor.label() {
match field.descriptor.label_fallback() {
Label::Optional => {
if optional {
self.buf.push_str(", optional");
Expand All @@ -449,7 +449,9 @@ impl<'a> CodeGenerator<'a> {
.descriptor
.options
.as_ref()
.map_or(self.syntax == Syntax::Proto3, |options| options.packed())
.map_or(self.syntax == Syntax::Proto3, |options| {
options.packed_fallback()
})
{
self.buf.push_str(", packed=\"false\"");
}
Expand All @@ -460,7 +462,8 @@ impl<'a> CodeGenerator<'a> {
self.buf.push_str(", boxed");
}
self.buf.push_str(", tag=\"");
self.buf.push_str(&field.descriptor.number().to_string());
self.buf
.push_str(&field.descriptor.number_fallback().to_string());

if let Some(ref default) = field.descriptor.default_value {
self.buf.push_str("\", default=\"");
Expand Down Expand Up @@ -494,7 +497,7 @@ impl<'a> CodeGenerator<'a> {
}

self.buf.push_str("\")]\n");
self.append_field_attributes(fq_message_name, field.descriptor.name());
self.append_field_attributes(fq_message_name, field.descriptor.name_fallback());
self.push_indent();
self.buf.push_str("pub ");
self.buf.push_str(&field.rust_name());
Expand Down Expand Up @@ -539,13 +542,13 @@ impl<'a> CodeGenerator<'a> {
value_ty
);

self.append_doc(fq_message_name, Some(field.descriptor.name()));
self.append_doc(fq_message_name, Some(field.descriptor.name_fallback()));
self.push_indent();

let map_type = self
.config
.map_type
.get_first_field(fq_message_name, field.descriptor.name())
.get_first_field(fq_message_name, field.descriptor.name_fallback())
.copied()
.unwrap_or_default();
let key_tag = self.field_type_tag(key);
Expand All @@ -556,9 +559,9 @@ impl<'a> CodeGenerator<'a> {
map_type.annotation(),
key_tag,
value_tag,
field.descriptor.number()
field.descriptor.number_fallback()
));
self.append_field_attributes(fq_message_name, field.descriptor.name());
self.append_field_attributes(fq_message_name, field.descriptor.name_fallback());
self.push_indent();
self.buf.push_str(&format!(
"pub {}: {}<{}, {}>,\n",
Expand All @@ -578,7 +581,7 @@ impl<'a> CodeGenerator<'a> {
let type_name = format!(
"{}::{}",
to_snake(message_name),
to_upper_camel(oneof.descriptor.name())
to_upper_camel(oneof.descriptor.name_fallback())
);
self.append_doc(fq_message_name, None);
self.push_indent();
Expand All @@ -588,10 +591,10 @@ impl<'a> CodeGenerator<'a> {
oneof
.fields
.iter()
.map(|field| field.descriptor.number())
.map(|field| field.descriptor.number_fallback())
.join(", "),
));
self.append_field_attributes(fq_message_name, oneof.descriptor.name());
self.append_field_attributes(fq_message_name, oneof.descriptor.name_fallback());
self.push_indent();
self.buf.push_str(&format!(
"pub {}: ::core::option::Option<{}>,\n",
Expand All @@ -607,7 +610,7 @@ impl<'a> CodeGenerator<'a> {
self.path.pop();
self.path.pop();

let oneof_name = format!("{}.{}", fq_message_name, oneof.descriptor.name());
let oneof_name = format!("{}.{}", fq_message_name, oneof.descriptor.name_fallback());
self.append_type_attributes(&oneof_name);
self.append_enum_attributes(&oneof_name);
self.push_indent();
Expand All @@ -620,32 +623,33 @@ impl<'a> CodeGenerator<'a> {
self.append_skip_debug(fq_message_name);
self.push_indent();
self.buf.push_str("pub enum ");
self.buf.push_str(&to_upper_camel(oneof.descriptor.name()));
self.buf
.push_str(&to_upper_camel(oneof.descriptor.name_fallback()));
self.buf.push_str(" {\n");

self.path.push(2);
self.depth += 1;
for field in &oneof.fields {
self.path.push(field.path_index);
self.append_doc(fq_message_name, Some(field.descriptor.name()));
self.append_doc(fq_message_name, Some(field.descriptor.name_fallback()));
self.path.pop();

self.push_indent();
let ty_tag = self.field_type_tag(&field.descriptor);
self.buf.push_str(&format!(
"#[prost({}, tag=\"{}\")]\n",
ty_tag,
field.descriptor.number()
field.descriptor.number_fallback()
));
self.append_field_attributes(&oneof_name, field.descriptor.name());
self.append_field_attributes(&oneof_name, field.descriptor.name_fallback());

self.push_indent();
let ty = self.resolve_type(&field.descriptor, fq_message_name);

let boxed = self.boxed(
&field.descriptor,
fq_message_name,
Some(oneof.descriptor.name()),
Some(oneof.descriptor.name_fallback()),
);

debug!(
Expand All @@ -658,13 +662,13 @@ impl<'a> CodeGenerator<'a> {
if boxed {
self.buf.push_str(&format!(
"{}(::prost::alloc::boxed::Box<{}>),\n",
to_upper_camel(field.descriptor.name()),
to_upper_camel(field.descriptor.name_fallback()),
ty
));
} else {
self.buf.push_str(&format!(
"{}({}),\n",
to_upper_camel(field.descriptor.name()),
to_upper_camel(field.descriptor.name_fallback()),
ty
));
}
Expand Down Expand Up @@ -704,7 +708,7 @@ impl<'a> CodeGenerator<'a> {
fn append_enum(&mut self, desc: EnumDescriptorProto) {
debug!(" enum: {:?}", desc.name());

let proto_enum_name = desc.name();
let proto_enum_name = desc.name_fallback();
let enum_name = to_upper_camel(proto_enum_name);

let enum_values = &desc.value;
Expand Down Expand Up @@ -851,7 +855,7 @@ impl<'a> CodeGenerator<'a> {
}

fn push_service(&mut self, service: ServiceDescriptorProto) {
let name = service.name().to_owned();
let name = service.name_fallback().to_owned();
debug!(" service: {:?}", name);

let comments = self
Expand Down Expand Up @@ -879,8 +883,8 @@ impl<'a> CodeGenerator<'a> {
let output_proto_type = method.output_type.take().unwrap();
let input_type = self.resolve_ident(&input_proto_type);
let output_type = self.resolve_ident(&output_proto_type);
let client_streaming = method.client_streaming();
let server_streaming = method.server_streaming();
let client_streaming = method.client_streaming_fallback();
let server_streaming = method.server_streaming_fallback();

Method {
name: to_snake(&name),
Expand Down Expand Up @@ -942,7 +946,7 @@ impl<'a> CodeGenerator<'a> {
}

fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String {
match field.r#type() {
match field.r#type_fallback() {
Type::Float => String::from("f32"),
Type::Double => String::from("f64"),
Type::Uint32 | Type::Fixed32 => String::from("u32"),
Expand All @@ -954,12 +958,12 @@ impl<'a> CodeGenerator<'a> {
Type::Bytes => self
.config
.bytes_type
.get_first_field(fq_message_name, field.name())
.get_first_field(fq_message_name, field.name_fallback())
.copied()
.unwrap_or_default()
.rust_type()
.to_owned(),
Type::Group | Type::Message => self.resolve_ident(field.type_name()),
Type::Group | Type::Message => self.resolve_ident(field.type_name_fallback()),
}
}

Expand Down Expand Up @@ -1002,7 +1006,7 @@ impl<'a> CodeGenerator<'a> {
}

fn field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> {
match field.r#type() {
match field.r#type_fallback() {
Type::Float => Cow::Borrowed("float"),
Type::Double => Cow::Borrowed("double"),
Type::Int32 => Cow::Borrowed("int32"),
Expand All @@ -1022,16 +1026,16 @@ impl<'a> CodeGenerator<'a> {
Type::Message => Cow::Borrowed("message"),
Type::Enum => Cow::Owned(format!(
"enumeration={:?}",
self.resolve_ident(field.type_name())
self.resolve_ident(field.type_name_fallback())
)),
}
}

fn map_value_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> {
match field.r#type() {
match field.r#type_fallback() {
Type::Enum => Cow::Owned(format!(
"enumeration({})",
self.resolve_ident(field.type_name())
self.resolve_ident(field.type_name_fallback())
)),
_ => self.field_type_tag(field),
}
Expand All @@ -1042,11 +1046,11 @@ impl<'a> CodeGenerator<'a> {
return true;
}

if field.label() != Label::Optional {
if field.label_fallback() != Label::Optional {
return false;
}

match field.r#type() {
match field.r#type_fallback() {
Type::Message => true,
_ => self.syntax == Syntax::Proto2,
}
Expand All @@ -1064,12 +1068,12 @@ impl<'a> CodeGenerator<'a> {
oneof: Option<&str>,
) -> bool {
let repeated = field.label == Some(Label::Repeated as i32);
let fd_type = field.r#type();
let fd_type = field.r#type_fallback();
if !repeated
&& (fd_type == Type::Message || fd_type == Type::Group)
&& self
.message_graph
.is_nested(field.type_name(), fq_message_name)
.is_nested(field.type_name_fallback(), fq_message_name)
{
return true;
}
Expand All @@ -1080,7 +1084,7 @@ impl<'a> CodeGenerator<'a> {
if self
.config
.boxed
.get_first_field(&config_path, field.name())
.get_first_field(&config_path, field.name_fallback())
.is_some()
{
if repeated {
Expand All @@ -1100,7 +1104,7 @@ impl<'a> CodeGenerator<'a> {
field
.options
.as_ref()
.map_or(false, FieldOptions::deprecated)
.map_or(false, FieldOptions::deprecated_fallback)
}

/// Returns the fully-qualified name, starting with a dot
Expand All @@ -1119,7 +1123,7 @@ impl<'a> CodeGenerator<'a> {
/// Returns `true` if the repeated field type can be packed.
fn can_pack(field: &FieldDescriptorProto) -> bool {
matches!(
field.r#type(),
field.r#type_fallback(),
Type::Float
| Type::Double
| Type::Int32
Expand Down Expand Up @@ -1160,22 +1164,23 @@ fn build_enum_value_mappings<'a>(
continue;
}

let mut generated_variant_name = to_upper_camel(value.name());
let mut generated_variant_name = to_upper_camel(value.name_fallback());
if do_strip_enum_prefix {
generated_variant_name =
strip_enum_prefix(generated_enum_name, &generated_variant_name);
}

if let Some(old_v) = generated_names.insert(generated_variant_name.to_owned(), value.name())
if let Some(old_v) =
generated_names.insert(generated_variant_name.to_owned(), value.name_fallback())
{
panic!("Generated enum variant names overlap: `{}` variant name to be used both by `{}` and `{}` ProtoBuf enum values",
generated_variant_name, old_v, value.name());
generated_variant_name, old_v, value.name_fallback());
}

mappings.push(EnumVariantMapping {
path_idx: idx,
proto_name: value.name(),
proto_number: value.number(),
proto_name: value.name_fallback(),
proto_number: value.number_fallback(),
generated_variant_name,
})
}
Expand Down
Loading
Loading