From 15eccf3ec4636cab20a16b7570d3967d79ecfa95 Mon Sep 17 00:00:00 2001 From: Mike Kruskal Date: Tue, 7 Nov 2023 14:37:40 -0800 Subject: [PATCH] Implement Editions in Pure Python. This change only covers pure python, and follow-up changes will handle C++/upb variants and actually enable editions support. The C++ one works (as evident from the conformance tests), but needs some APIs added to allow for testing. PiperOrigin-RevId: 580304039 --- conformance/BUILD.bazel | 3 +- conformance/conformance_python.py | 22 +- conformance/failure_list_python.txt | 3 + conformance/failure_list_python_cpp.txt | 3 + .../text_format_failure_list_python.txt | 6 + .../text_format_failure_list_python_cpp.txt | 4 + python/build_targets.bzl | 70 ++- python/google/protobuf/descriptor.py | 422 ++++++++++++++---- python/google/protobuf/descriptor_pool.py | 116 ++++- .../protobuf/internal/descriptor_pool_test.py | 203 ++++++++- .../protobuf/internal/descriptor_test.py | 369 ++++++++++++++- .../protobuf/internal/legacy_features.proto | 18 + .../python_edition_defaults.py.template | 5 + src/google/protobuf/compiler/BUILD.bazel | 3 +- .../protobuf/compiler/python/generator.cc | 193 ++++---- .../protobuf/compiler/python/generator.h | 59 ++- src/google/protobuf/editions/BUILD | 26 ++ 17 files changed, 1242 insertions(+), 283 deletions(-) create mode 100644 python/google/protobuf/internal/legacy_features.proto create mode 100644 python/google/protobuf/internal/python_edition_defaults.py.template diff --git a/conformance/BUILD.bazel b/conformance/BUILD.bazel index e54eab02d2a2..d0640dc9a9e9 100644 --- a/conformance/BUILD.bazel +++ b/conformance/BUILD.bazel @@ -261,8 +261,7 @@ py_binary( deps = [ ":conformance_py_proto", "//:protobuf_python", - "//python:test_messages_proto2_py_proto", - "//python:test_messages_proto3_py_proto", + "//python:conformance_test_py_proto", ], ) diff --git a/conformance/conformance_python.py b/conformance/conformance_python.py index 5e2a99af820a..1709457a3744 100755 --- a/conformance/conformance_python.py +++ b/conformance/conformance_python.py @@ -19,6 +19,8 @@ from google.protobuf import test_messages_proto2_pb2 from google.protobuf import test_messages_proto3_pb2 from conformance import conformance_pb2 +from google.protobuf.editions.golden import test_messages_proto2_editions_pb2 +from google.protobuf.editions.golden import test_messages_proto3_editions_pb2 test_count = 0 verbose = False @@ -28,6 +30,18 @@ class ProtocolError(Exception): pass +def _create_test_message(type): + if type == "protobuf_test_messages.proto2.TestAllTypesProto2": + return test_messages_proto2_pb2.TestAllTypesProto2() + if type == "protobuf_test_messages.proto3.TestAllTypesProto3": + return test_messages_proto3_pb2.TestAllTypesProto3() + if type == "protobuf_test_messages.editions.proto2.TestAllTypesProto2": + return test_messages_proto2_editions_pb2.TestAllTypesProto2() + if type == "protobuf_test_messages.editions.proto3.TestAllTypesProto3": + return test_messages_proto3_editions_pb2.TestAllTypesProto3() + return None + + def do_test(request): response = conformance_pb2.ConformanceResponse() @@ -85,16 +99,12 @@ def do_test(request): response.protobuf_payload = failure_set.SerializeToString() return response - isProto3 = (request.message_type == "protobuf_test_messages.proto3.TestAllTypesProto3") isJson = (request.WhichOneof('payload') == 'json_payload') - isProto2 = (request.message_type == "protobuf_test_messages.proto2.TestAllTypesProto2") + test_message = _create_test_message(request.message_type) - if (not isProto3) and (not isJson) and (not isProto2): + if (not isJson) and (test_message is None): raise ProtocolError("Protobuf request doesn't have specific payload type") - test_message = test_messages_proto2_pb2.TestAllTypesProto2() if isProto2 else \ - test_messages_proto3_pb2.TestAllTypesProto3() - try: if request.WhichOneof('payload') == 'protobuf_payload': try: diff --git a/conformance/failure_list_python.txt b/conformance/failure_list_python.txt index 8bbf094293d4..b278006bcc55 100644 --- a/conformance/failure_list_python.txt +++ b/conformance/failure_list_python.txt @@ -1,3 +1,6 @@ Recommended.Proto3.JsonInput.IgnoreUnknownEnumStringValueInMapValue.ProtobufOutput Recommended.Proto3.JsonInput.IgnoreUnknownEnumStringValueInOptionalField.ProtobufOutput Recommended.Proto3.JsonInput.IgnoreUnknownEnumStringValueInRepeatedField.ProtobufOutput +Recommended.Editions_Proto3.JsonInput.IgnoreUnknownEnumStringValueInMapValue.ProtobufOutput +Recommended.Editions_Proto3.JsonInput.IgnoreUnknownEnumStringValueInOptionalField.ProtobufOutput +Recommended.Editions_Proto3.JsonInput.IgnoreUnknownEnumStringValueInRepeatedField.ProtobufOutput diff --git a/conformance/failure_list_python_cpp.txt b/conformance/failure_list_python_cpp.txt index a49939327bb2..9b0dea68648e 100644 --- a/conformance/failure_list_python_cpp.txt +++ b/conformance/failure_list_python_cpp.txt @@ -9,3 +9,6 @@ Recommended.Proto3.JsonInput.IgnoreUnknownEnumStringValueInMapValue.ProtobufOutput Recommended.Proto3.JsonInput.IgnoreUnknownEnumStringValueInOptionalField.ProtobufOutput Recommended.Proto3.JsonInput.IgnoreUnknownEnumStringValueInRepeatedField.ProtobufOutput +Recommended.Editions_Proto3.JsonInput.IgnoreUnknownEnumStringValueInMapValue.ProtobufOutput +Recommended.Editions_Proto3.JsonInput.IgnoreUnknownEnumStringValueInOptionalField.ProtobufOutput +Recommended.Editions_Proto3.JsonInput.IgnoreUnknownEnumStringValueInRepeatedField.ProtobufOutput diff --git a/conformance/text_format_failure_list_python.txt b/conformance/text_format_failure_list_python.txt index 2f7f22471cb4..6754aa4c4b96 100644 --- a/conformance/text_format_failure_list_python.txt +++ b/conformance/text_format_failure_list_python.txt @@ -7,3 +7,9 @@ Required.Proto3.TextFormatInput.StringLiteralBasicEscapesBytes.ProtobufOutput Required.Proto3.TextFormatInput.StringLiteralBasicEscapesBytes.TextFormatOutput Required.Proto3.TextFormatInput.StringLiteralBasicEscapesString.ProtobufOutput Required.Proto3.TextFormatInput.StringLiteralBasicEscapesString.TextFormatOutput +Required.Editions_Proto3.TextFormatInput.FloatFieldMaxValue.ProtobufOutput +Required.Editions_Proto3.TextFormatInput.FloatFieldMaxValue.TextFormatOutput +Required.Editions_Proto3.TextFormatInput.StringLiteralBasicEscapesBytes.ProtobufOutput +Required.Editions_Proto3.TextFormatInput.StringLiteralBasicEscapesBytes.TextFormatOutput +Required.Editions_Proto3.TextFormatInput.StringLiteralBasicEscapesString.ProtobufOutput +Required.Editions_Proto3.TextFormatInput.StringLiteralBasicEscapesString.TextFormatOutput diff --git a/conformance/text_format_failure_list_python_cpp.txt b/conformance/text_format_failure_list_python_cpp.txt index b9da32dab814..037ca00e134e 100644 --- a/conformance/text_format_failure_list_python_cpp.txt +++ b/conformance/text_format_failure_list_python_cpp.txt @@ -2,3 +2,7 @@ Required.Proto3.TextFormatInput.StringLiteralBasicEscapesBytes.ProtobufOutput Required.Proto3.TextFormatInput.StringLiteralBasicEscapesBytes.TextFormatOutput Required.Proto3.TextFormatInput.StringLiteralBasicEscapesString.ProtobufOutput Required.Proto3.TextFormatInput.StringLiteralBasicEscapesString.TextFormatOutput +Required.Editions_Proto3.TextFormatInput.StringLiteralBasicEscapesBytes.ProtobufOutput +Required.Editions_Proto3.TextFormatInput.StringLiteralBasicEscapesBytes.TextFormatOutput +Required.Editions_Proto3.TextFormatInput.StringLiteralBasicEscapesString.ProtobufOutput +Required.Editions_Proto3.TextFormatInput.StringLiteralBasicEscapesString.TextFormatOutput diff --git a/python/build_targets.bzl b/python/build_targets.bzl index fff32b4a5b6f..56ed77b4dd99 100644 --- a/python/build_targets.bzl +++ b/python/build_targets.bzl @@ -12,6 +12,7 @@ load("//:protobuf.bzl", "internal_py_proto_library") load("//build_defs:arch_tests.bzl", "aarch64_test", "x86_64_test") load("//build_defs:cpp_opts.bzl", "COPTS") load("//conformance:defs.bzl", "conformance_test") +load("//src/google/protobuf/editions:defaults.bzl", "compile_edition_defaults", "embed_edition_defaults") load(":internal.bzl", "internal_copy_files", "internal_py_test") def build_targets(name): @@ -143,8 +144,23 @@ def build_targets(name): ], ) - py_library( - name = "python_srcs", + compile_edition_defaults( + name = "python_edition_defaults", + srcs = ["//:descriptor_proto"], + maximum_edition = "2023", + minimum_edition = "PROTO2", + ) + + embed_edition_defaults( + name = "embedded_python_edition_defaults_generate", + defaults = "python_edition_defaults", + output = "google/protobuf/internal/python_edition_defaults.py", + placeholder = "DEFAULTS_VALUE", + template = "google/protobuf/internal/python_edition_defaults.py.template", + ) + + native.filegroup( + name = "python_src_files", srcs = native.glob( [ "google/protobuf/**/*.py", @@ -154,7 +170,12 @@ def build_targets(name): "google/protobuf/internal/test_util.py", "google/protobuf/internal/import_test_package/__init__.py", ], - ), + ) + ["google/protobuf/internal/python_edition_defaults.py"], + ) + + py_library( + name = "python_srcs", + srcs = [":python_src_files"], imports = ["python"], srcs_version = "PY2AND3", visibility = [ @@ -196,19 +217,13 @@ def build_targets(name): ) internal_copy_files( - name = "copied_test_messages_proto2_files", + name = "copied_conformance_test_files", testonly = 1, srcs = [ "//src/google/protobuf:test_messages_proto2.proto", - ], - strip_prefix = "src", - ) - - internal_copy_files( - name = "copied_test_messages_proto3_files", - testonly = 1, - srcs = [ "//src/google/protobuf:test_messages_proto3.proto", + "//src/google/protobuf/editions:golden/test_messages_proto2_editions.proto", + "//src/google/protobuf/editions:golden/test_messages_proto3_editions.proto", ], strip_prefix = "src", ) @@ -241,22 +256,9 @@ def build_targets(name): ) internal_py_proto_library( - name = "test_messages_proto2_py_proto", - testonly = 1, - srcs = [":copied_test_messages_proto2_files"], - include = ".", - default_runtime = "//:protobuf_python", - protoc = "//:protoc", - visibility = [ - "//conformance:__pkg__", - "//python:__subpackages__", - ], - ) - - internal_py_proto_library( - name = "test_messages_proto3_py_proto", + name = "conformance_test_py_proto", testonly = 1, - srcs = [":copied_test_messages_proto3_files"], + srcs = [":copied_conformance_test_files"], include = ".", default_runtime = "//:protobuf_python", protoc = "//:protoc", @@ -404,6 +406,7 @@ def build_targets(name): ":use_fast_cpp_protos": ["@platforms//:incompatible"], "//conditions:default": [], }), + maximum_edition = "2023", testee = "//conformance:conformance_python", text_format_failure_list = "//conformance:text_format_failure_list_python.txt", ) @@ -418,6 +421,7 @@ def build_targets(name): ":use_fast_cpp_protos": [], "//conditions:default": ["@platforms//:incompatible"], }), + maximum_edition = "2023", testee = "//conformance:conformance_python", text_format_failure_list = "//conformance:text_format_failure_list_python_cpp.txt", ) @@ -428,16 +432,8 @@ def build_targets(name): pkg_files( name = "python_source_files", - srcs = native.glob( - [ - "google/protobuf/**/*.py", - ], - exclude = [ - "google/protobuf/internal/*_test.py", - "google/protobuf/internal/test_util.py", - "google/protobuf/internal/import_test_package/__init__.py", - ], - ) + [ + srcs = [ + ":python_src_files", "README.md", "google/__init__.py", "setup.cfg", diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py index 5b32e5e1546d..fb16bb4fe979 100755 --- a/python/google/protobuf/descriptor.py +++ b/python/google/protobuf/descriptor.py @@ -11,6 +11,9 @@ __author__ = 'robinson@google.com (Will Robinson)' +import abc +import binascii +import os import threading import warnings @@ -18,9 +21,6 @@ _USE_C_DESCRIPTORS = False if api_implementation.Type() != 'python': - # Used by MakeDescriptor in cpp mode - import binascii - import os # pylint: disable=protected-access _message = api_implementation._c_module # TODO: Remove this import after fix api_implementation @@ -52,7 +52,7 @@ def __instancecheck__(cls, obj): return False else: # The standard metaclass; nothing changes. - DescriptorMetaclass = type + DescriptorMetaclass = abc.ABCMeta class _Lock(object): @@ -83,6 +83,14 @@ def _Deprecated(name): % name, category=DeprecationWarning, stacklevel=3) +# These must match the values in descriptor.proto, but we can't use them +# directly because we sometimes need to reference them in feature helpers +# below *during* the build of descriptor.proto. +_FEATURESET_MESSAGE_ENCODING_DELIMITED = 2 +_FEATURESET_FIELD_PRESENCE_IMPLICIT = 2 +_FEATURESET_FIELD_PRESENCE_LEGACY_REQUIRED = 3 +_FEATURESET_REPEATED_FIELD_ENCODING_PACKED = 1 +_FEATURESET_ENUM_TYPE_CLOSED = 2 # Deprecated warnings will print 100 times at most which should be enough for # users to notice and do not cause timeout. @@ -118,8 +126,10 @@ def __init__(self, file, options, serialized_options, options_class_name): class of the options message. The name of the class is required in case the options message is None and has to be created. """ + self._features = None self.file = file self._options = options + self._loaded_options = None self._options_class_name = options_class_name self._serialized_options = serialized_options @@ -128,44 +138,106 @@ class of the options message. The name of the class is required in case self._serialized_options is not None ) - def _SetOptions(self, options, options_class_name): - """Sets the descriptor's options + @property + @abc.abstractmethod + def _parent(self): + pass + + def _InferLegacyFeatures(self, edition, options, features): + """Infers features from proto2/proto3 syntax so that editions logic can be used everywhere. - This function is used in generated proto2 files to update descriptor - options. It must not be used outside proto2. + Args: + edition: The edition to infer features for. + options: The options for this descriptor that are being processed. + features: The feature set object to modify with inferred features. """ - self._options = options - self._options_class_name = options_class_name + pass - # Does this descriptor have non-default options? - self.has_options = options is not None + def _GetFeatures(self): + if not self._features: + self._LazyLoadOptions() + return self._features - def GetOptions(self): - """Retrieves descriptor options. + def _ResolveFeatures(self, edition, raw_options): + """Resolves features from the raw options of this descriptor. - This method returns the options set or creates the default options for the - descriptor. + Args: + edition: The edition to use for feature defaults. + raw_options: The options for this descriptor that are being processed. + + Returns: + A fully resolved feature set for making runtime decisions. """ - if self._options: - return self._options + # pylint: disable=g-import-not-at-top + from google.protobuf import descriptor_pb2 + + if self._parent: + features = descriptor_pb2.FeatureSet() + features.CopyFrom(self._parent._GetFeatures()) + else: + features = self.file.pool._CreateDefaultFeatures(edition) + unresolved = descriptor_pb2.FeatureSet() + unresolved.CopyFrom(raw_options.features) + self._InferLegacyFeatures(edition, raw_options, unresolved) + features.MergeFrom(unresolved) + + # Use the feature cache to reduce memory bloat. + return self.file.pool._InternFeatures(features) + + def _LazyLoadOptions(self): + """Lazily initializes descriptor options towards the end of the build.""" + if self._loaded_options: + return + # pylint: disable=g-import-not-at-top from google.protobuf import descriptor_pb2 - try: - options_class = getattr(descriptor_pb2, - self._options_class_name) - except AttributeError: - raise RuntimeError('Unknown options class name %s!' % - (self._options_class_name)) - if self._serialized_options is None: + if not hasattr(descriptor_pb2, self._options_class_name): + raise RuntimeError( + 'Unknown options class name %s!' % self._options_class_name + ) + options_class = getattr(descriptor_pb2, self._options_class_name) + features = None + edition = self.file._edition + + if not self.has_options: + if not self._features: + features = self._ResolveFeatures( + descriptor_pb2.Edition.Value(edition), options_class() + ) with _lock: - self._options = options_class() + self._loaded_options = options_class() + if not self._features: + self._features = features else: - options = _ParseOptions(options_class(), self._serialized_options) + if not self._serialized_options: + options = self._options + else: + options = _ParseOptions(options_class(), self._serialized_options) + + if not self._features: + features = self._ResolveFeatures( + descriptor_pb2.Edition.Value(edition), options + ) with _lock: - self._options = options + self._loaded_options = options + if not self._features: + self._features = features + if options.HasField('features'): + options.ClearField('features') + if not options.SerializeToString(): + self._loaded_options = options_class() + self.has_options = False + + def GetOptions(self): + """Retrieves descriptor options. - return self._options + Returns: + The options set on this descriptor. + """ + if not self._loaded_options: + self._LazyLoadOptions() + return self._loaded_options class _NestedDescriptorBase(DescriptorBase): @@ -327,6 +399,7 @@ def __init__(self, name, full_name, filename, containing_type, fields, self.fields = fields for field in self.fields: field.containing_type = self + field.file = file self.fields_by_number = dict((f.number, f) for f in fields) self.fields_by_name = dict((f.name, f) for f in fields) self._fields_by_camelcase_name = None @@ -353,17 +426,12 @@ def __init__(self, name, full_name, filename, containing_type, fields, self.oneofs_by_name = dict((o.name, o) for o in self.oneofs) for oneof in self.oneofs: oneof.containing_type = self - self._deprecated_syntax = syntax or "proto2" + oneof.file = file self._is_map_entry = is_map_entry @property - def syntax(self): - warnings.warn( - 'descriptor.syntax is deprecated. It will be removed' - ' soon. Most usages are checking field descriptors. Consider to use' - ' has_presence, is_packed on field descriptors.' - ) - return self._deprecated_syntax + def _parent(self): + return self.containing_type or self.file @property def fields_by_camelcase_name(self): @@ -584,9 +652,9 @@ def __init__(self, name, full_name, index, number, type, cpp_type, label, self.json_name = json_name self.index = index self.number = number - self.type = type + self._type = type self.cpp_type = cpp_type - self.label = label + self._label = label self.has_default_value = has_default_value self.default_value = default_value self.containing_type = containing_type @@ -603,6 +671,60 @@ def __init__(self, name, full_name, index, number, type, cpp_type, label, else: self._cdescriptor = _message.default_pool.FindFieldByName(full_name) + @property + def _parent(self): + if self.containing_oneof: + return self.containing_oneof + if self.is_extension: + return self.extension_scope or self.file + return self.containing_type + + def _InferLegacyFeatures(self, edition, options, features): + # pylint: disable=g-import-not-at-top + from google.protobuf import descriptor_pb2 + + if edition >= descriptor_pb2.Edition.EDITION_2023: + return + + if self._label == FieldDescriptor.LABEL_REQUIRED: + features.field_presence = ( + descriptor_pb2.FeatureSet.FieldPresence.LEGACY_REQUIRED + ) + + if self._type == FieldDescriptor.TYPE_GROUP: + features.message_encoding = ( + descriptor_pb2.FeatureSet.MessageEncoding.DELIMITED + ) + + if options.HasField('packed'): + features.repeated_field_encoding = ( + descriptor_pb2.FeatureSet.RepeatedFieldEncoding.PACKED + if options.packed + else descriptor_pb2.FeatureSet.RepeatedFieldEncoding.EXPANDED + ) + + @property + def type(self): + if ( + self._GetFeatures().message_encoding + == _FEATURESET_MESSAGE_ENCODING_DELIMITED + ): + return FieldDescriptor.TYPE_GROUP + return self._type + + @type.setter + def type(self, val): + self._type = val + + @property + def label(self): + if ( + self._GetFeatures().field_presence + == _FEATURESET_FIELD_PRESENCE_LEGACY_REQUIRED + ): + return FieldDescriptor.LABEL_REQUIRED + return self._label + @property def camelcase_name(self): """Camelcase name of this field. @@ -626,11 +748,11 @@ def has_presence(self): if (self.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE or self.containing_oneof): return True - # self.containing_type is used here instead of self.file for legacy - # compatibility. FieldDescriptor.file was added in cl/153110619 - # Some old/generated code didn't link file to FieldDescriptor. - # TODO: remove syntax usage b/240619313 - return self.containing_type._deprecated_syntax == 'proto2' + + return ( + self._GetFeatures().field_presence + != _FEATURESET_FIELD_PRESENCE_IMPLICIT + ) @property def is_packed(self): @@ -643,12 +765,11 @@ def is_packed(self): field_type == FieldDescriptor.TYPE_MESSAGE or field_type == FieldDescriptor.TYPE_BYTES): return False - if self.containing_type._deprecated_syntax == 'proto2': - return self.has_options and self.GetOptions().packed - else: - return (not self.has_options or - not self.GetOptions().HasField('packed') or - self.GetOptions().packed) + + return ( + self._GetFeatures().repeated_field_encoding + == _FEATURESET_REPEATED_FIELD_ENCODING_PACKED + ) @staticmethod def ProtoTypeToCppProtoType(proto_type): @@ -730,6 +851,10 @@ def __init__(self, name, full_name, filename, values, # Values are reversed to ensure that the first alias is retained. self.values_by_number = dict((v.number, v) for v in reversed(values)) + @property + def _parent(self): + return self.containing_type or self.file + @property def is_closed(self): """Returns true whether this is a "closed" enum. @@ -752,7 +877,7 @@ def is_closed(self): Care should be taken when using this function to respect the target runtime's enum handling quirks. """ - return self.file._deprecated_syntax == 'proto2' + return self._GetFeatures().enum_type == _FEATURESET_ENUM_TYPE_CLOSED def CopyToProto(self, proto): """Copies this to a descriptor_pb2.EnumDescriptorProto. @@ -811,6 +936,10 @@ def __init__(self, name, index, number, self.number = number self.type = type + @property + def _parent(self): + return self.type + class OneofDescriptor(DescriptorBase): """Descriptor for a oneof field. @@ -855,6 +984,10 @@ def __init__( self.containing_type = containing_type self.fields = fields + @property + def _parent(self): + return self.containing_type + class ServiceDescriptor(_NestedDescriptorBase): @@ -911,6 +1044,10 @@ def __init__(self, name, full_name, index, methods, options=None, method.file = self.file method.containing_service = self + @property + def _parent(self): + return self.file + def FindMethodByName(self, name): """Searches for the specified method, and returns its descriptor. @@ -1008,6 +1145,10 @@ def __init__(self, self.client_streaming = client_streaming self.server_streaming = server_streaming + @property + def _parent(self): + return self.containing_service + def CopyToProto(self, proto): """Copies this to a descriptor_pb2.MethodDescriptorProto. @@ -1061,10 +1202,20 @@ class FileDescriptor(DescriptorBase): if _USE_C_DESCRIPTORS: _C_DESCRIPTOR_CLASS = _message.FileDescriptor - def __new__(cls, name, package, options=None, - serialized_options=None, serialized_pb=None, - dependencies=None, public_dependencies=None, - syntax=None, pool=None, create_key=None): + def __new__( + cls, + name, + package, + options=None, + serialized_options=None, + serialized_pb=None, + dependencies=None, + public_dependencies=None, + syntax=None, + edition=None, + pool=None, + create_key=None, + ): # FileDescriptor() is called from various places, not only from generated # files, to register dynamic proto files and messages. # pylint: disable=g-explicit-bool-comparison @@ -1073,18 +1224,35 @@ def __new__(cls, name, package, options=None, else: return super(FileDescriptor, cls).__new__(cls) - def __init__(self, name, package, options=None, - serialized_options=None, serialized_pb=None, - dependencies=None, public_dependencies=None, - syntax=None, pool=None, create_key=None): + def __init__( + self, + name, + package, + options=None, + serialized_options=None, + serialized_pb=None, + dependencies=None, + public_dependencies=None, + syntax=None, + edition=None, + pool=None, + create_key=None, + ): """Constructor.""" if create_key is not _internal_create_key: _Deprecated('FileDescriptor') super(FileDescriptor, self).__init__( - None, options, serialized_options, 'FileOptions' + self, options, serialized_options, 'FileOptions' ) + if edition and edition != 'EDITION_UNKNOWN': + self._edition = edition + elif syntax == 'proto3': + self._edition = 'EDITION_PROTO3' + else: + self._edition = 'EDITION_PROTO2' + if pool is None: from google.protobuf import descriptor_pool pool = descriptor_pool.Default() @@ -1118,6 +1286,10 @@ def CopyToProto(self, proto): """ proto.ParseFromString(self.serialized_pb) + @property + def _parent(self): + return None + def _ParseOptions(message, string): """Parses serialized options. @@ -1175,8 +1347,14 @@ def _ToJsonName(name): return ''.join(result) -def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True, - syntax=None): +def MakeDescriptor( + desc_proto, + package='', + build_file_if_cpp=True, + syntax=None, + edition=None, + file_desc=None, +): """Make a protobuf Descriptor given a DescriptorProto protobuf. Handles nested descriptors. Note that this is limited to the scope of defining @@ -1186,34 +1364,41 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True, Args: desc_proto: The descriptor_pb2.DescriptorProto protobuf message. package: Optional package name for the new message Descriptor (string). - build_file_if_cpp: Update the C++ descriptor pool if api matches. - Set to False on recursion, so no duplicates are created. + build_file_if_cpp: Update the C++ descriptor pool if api matches. Set to + False on recursion, so no duplicates are created. syntax: The syntax/semantics that should be used. Set to "proto3" to get - proto3 field presence semantics. + proto3 field presence semantics. + edition: The edition that should be used if syntax is "edition". + file_desc: A FileDescriptor to place this descriptor into. + Returns: A Descriptor for protobuf messages. """ + # pylint: disable=g-import-not-at-top + from google.protobuf import descriptor_pb2 + + # Generate a random name for this proto file to prevent conflicts with any + # imported ones. We need to specify a file name so the descriptor pool + # accepts our FileDescriptorProto, but it is not important what that file + # name is actually set to. + proto_name = binascii.hexlify(os.urandom(16)).decode('ascii') + + if package: + file_name = os.path.join(package.replace('.', '/'), proto_name + '.proto') + else: + file_name = proto_name + '.proto' + if api_implementation.Type() != 'python' and build_file_if_cpp: # The C++ implementation requires all descriptors to be backed by the same # definition in the C++ descriptor pool. To do this, we build a # FileDescriptorProto with the same definition as this descriptor and build # it into the pool. - from google.protobuf import descriptor_pb2 file_descriptor_proto = descriptor_pb2.FileDescriptorProto() file_descriptor_proto.message_type.add().MergeFrom(desc_proto) - # Generate a random name for this proto file to prevent conflicts with any - # imported ones. We need to specify a file name so the descriptor pool - # accepts our FileDescriptorProto, but it is not important what that file - # name is actually set to. - proto_name = binascii.hexlify(os.urandom(16)).decode('ascii') - if package: - file_descriptor_proto.name = os.path.join(package.replace('.', '/'), - proto_name + '.proto') file_descriptor_proto.package = package - else: - file_descriptor_proto.name = proto_name + '.proto' + file_descriptor_proto.name = file_name _message.default_pool.Add(file_descriptor_proto) result = _message.default_pool.FindFileByName(file_descriptor_proto.name) @@ -1221,6 +1406,19 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True, if _USE_C_DESCRIPTORS: return result.message_types_by_name[desc_proto.name] + if file_desc is None: + file_desc = FileDescriptor( + pool=None, + name=file_name, + package=package, + syntax=syntax, + edition=edition, + options=None, + serialized_pb='', + dependencies=[], + public_dependencies=[], + create_key=_internal_create_key, + ) full_message_name = [desc_proto.name] if package: full_message_name.insert(0, package) @@ -1229,11 +1427,21 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True, for enum_proto in desc_proto.enum_type: full_name = '.'.join(full_message_name + [enum_proto.name]) enum_desc = EnumDescriptor( - enum_proto.name, full_name, None, [ - EnumValueDescriptor(enum_val.name, ii, enum_val.number, - create_key=_internal_create_key) - for ii, enum_val in enumerate(enum_proto.value)], - create_key=_internal_create_key) + enum_proto.name, + full_name, + None, + [ + EnumValueDescriptor( + enum_val.name, + ii, + enum_val.number, + create_key=_internal_create_key, + ) + for ii, enum_val in enumerate(enum_proto.value) + ], + file=file_desc, + create_key=_internal_create_key, + ) enum_types[full_name] = enum_desc # Create Descriptors for nested types @@ -1242,10 +1450,14 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True, full_name = '.'.join(full_message_name + [nested_proto.name]) # Nested types are just those defined inside of the message, not all types # used by fields in the message, so no loops are possible here. - nested_desc = MakeDescriptor(nested_proto, - package='.'.join(full_message_name), - build_file_if_cpp=False, - syntax=syntax) + nested_desc = MakeDescriptor( + nested_proto, + package='.'.join(full_message_name), + build_file_if_cpp=False, + syntax=syntax, + edition=edition, + file_desc=file_desc, + ) nested_types[full_name] = nested_desc fields = [] @@ -1267,16 +1479,38 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True, enum_desc = enum_types[full_type_name] # Else type_name references a non-local type, which isn't implemented field = FieldDescriptor( - field_proto.name, full_name, field_proto.number - 1, - field_proto.number, field_proto.type, + field_proto.name, + full_name, + field_proto.number - 1, + field_proto.number, + field_proto.type, FieldDescriptor.ProtoTypeToCppProtoType(field_proto.type), - field_proto.label, None, nested_desc, enum_desc, None, False, None, - options=_OptionsOrNone(field_proto), has_default_value=False, - json_name=json_name, create_key=_internal_create_key) + field_proto.label, + None, + nested_desc, + enum_desc, + None, + False, + None, + options=_OptionsOrNone(field_proto), + has_default_value=False, + json_name=json_name, + file=file_desc, + create_key=_internal_create_key, + ) fields.append(field) desc_name = '.'.join(full_message_name) - return Descriptor(desc_proto.name, desc_name, None, None, fields, - list(nested_types.values()), list(enum_types.values()), [], - options=_OptionsOrNone(desc_proto), - create_key=_internal_create_key) + return Descriptor( + desc_proto.name, + desc_name, + None, + None, + fields, + list(nested_types.values()), + list(enum_types.values()), + [], + options=_OptionsOrNone(desc_proto), + file=file_desc, + create_key=_internal_create_key, + ) diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index c2fe59fdcd4d..10ad72d3d2ae 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -35,11 +35,13 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' import collections +import threading import warnings from google.protobuf import descriptor from google.protobuf import descriptor_database from google.protobuf import text_encoding +from google.protobuf.internal import python_edition_defaults from google.protobuf.internal import python_message _USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS # pylint: disable=protected-access @@ -91,6 +93,8 @@ def _IsMessageSetExtension(field): field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL) +_edition_defaults_lock = threading.Lock() + class DescriptorPool(object): """A collection of protobufs dynamically constructed by descriptor protos.""" @@ -131,6 +135,11 @@ def __init__( # full name or its tag number. self._extensions_by_name = collections.defaultdict(dict) self._extensions_by_number = collections.defaultdict(dict) + self._serialized_edition_defaults = ( + python_edition_defaults._PROTOBUF_INTERNAL_PYTHON_EDITION_DEFAULTS + ) + self._edition_defaults = None + self._feature_cache = dict() def _CheckConflictRegister(self, desc, desc_name, file_name): """Check if the descriptor name conflicts with another of the same name. @@ -679,6 +688,102 @@ def FindMethodByName(self, full_name): service_descriptor = self.FindServiceByName(service_name) return service_descriptor.methods_by_name[method_name] + def SetFeatureSetDefaults(self, defaults): + """Sets the default feature mappings used during the build. + + Args: + defaults: a FeatureSetDefaults message containing the new mappings. + """ + if self._edition_defaults is not None: + raise ValueError( + "Feature set defaults can't be changed once the pool has started" + ' building!' + ) + + # pylint: disable=g-import-not-at-top + from google.protobuf import descriptor_pb2 + + if defaults.minimum_edition > defaults.maximum_edition: + raise ValueError( + 'Invalid edition range %s to %s' + % ( + descriptor_pb2.Edition.Name(defaults.minimum_edition), + descriptor_pb2.Edition.Name(defaults.maximum_edition), + ) + ) + + prev_edition = descriptor_pb2.Edition.EDITION_UNKNOWN + for d in defaults.defaults: + if d.edition == descriptor_pb2.Edition.EDITION_UNKNOWN: + raise ValueError('Invalid edition EDITION_UNKNOWN specified') + if prev_edition >= d.edition: + raise ValueError('Feature set defaults are not strictly increasing') + prev_edition = d.edition + self._edition_defaults = defaults + + def _CreateDefaultFeatures(self, edition): + """Creates a FeatureSet message with defaults for a specific edition. + + Args: + edition: the edition to generate defaults for. + + Returns: + A FeatureSet message with defaults for a specific edition. + """ + # pylint: disable=g-import-not-at-top + from google.protobuf import descriptor_pb2 + + with _edition_defaults_lock: + if not self._edition_defaults: + self._edition_defaults = descriptor_pb2.FeatureSetDefaults() + self._edition_defaults.ParseFromString( + self._serialized_edition_defaults + ) + + if edition < self._edition_defaults.minimum_edition: + raise TypeError( + 'Edition %s is earlier than the minimum supported edition %s!' + % ( + descriptor_pb2.Edition.Name(edition), + descriptor_pb2.Edition.Name( + self._edition_defaults.minimum_edition + ), + ) + ) + if edition > self._edition_defaults.maximum_edition: + raise TypeError( + 'Edition %s is later than the maximum supported edition %s!' + % ( + descriptor_pb2.Edition.Name(edition), + descriptor_pb2.Edition.Name( + self._edition_defaults.maximum_edition + ), + ) + ) + found = None + for d in self._edition_defaults.defaults: + if d.edition > edition: + break + found = d.features + if found is None: + raise TypeError( + 'No valid default found for edition %s!' + % descriptor_pb2.Edition.Name(edition) + ) + + defaults = descriptor_pb2.FeatureSet() + defaults.CopyFrom(found) + return defaults + + def _InternFeatures(self, features): + serialized = features.SerializeToString() + with _edition_defaults_lock: + cached = self._feature_cache.get(serialized) + if cached is None: + self._feature_cache[serialized] = features + cached = features + return cached + def _FindFileContainingSymbolInDb(self, symbol): """Finds the file in descriptor DB containing the specified symbol. @@ -719,17 +824,22 @@ def _ConvertFileProtoToFileDescriptor(self, file_proto): direct_deps = [self.FindFileByName(n) for n in file_proto.dependency] public_deps = [direct_deps[i] for i in file_proto.public_dependency] + # pylint: disable=g-import-not-at-top + from google.protobuf import descriptor_pb2 + file_descriptor = descriptor.FileDescriptor( pool=self, name=file_proto.name, package=file_proto.package, syntax=file_proto.syntax, + edition=descriptor_pb2.Edition.Name(file_proto.edition), options=_OptionsOrNone(file_proto), serialized_pb=file_proto.SerializeToString(), dependencies=direct_deps, public_dependencies=public_deps, # pylint: disable=protected-access - create_key=descriptor._internal_create_key) + create_key=descriptor._internal_create_key, + ) scope = {} # This loop extracts all the message and enum types from all the @@ -876,10 +986,10 @@ def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None, file=file_desc, serialized_start=None, serialized_end=None, - syntax=syntax, is_map_entry=desc_proto.options.map_entry, # pylint: disable=protected-access - create_key=descriptor._internal_create_key) + create_key=descriptor._internal_create_key, + ) for nested in desc.nested_types: nested.containing_type = desc for enum in desc.enum_types: diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index dd2dc324239d..69fab8fd99d1 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -13,7 +13,12 @@ import unittest import warnings +from google.protobuf import descriptor +from google.protobuf import descriptor_database from google.protobuf import descriptor_pb2 +from google.protobuf import descriptor_pool +from google.protobuf import message_factory +from google.protobuf import symbol_database from google.protobuf.internal import api_implementation from google.protobuf.internal import descriptor_pool_test1_pb2 from google.protobuf.internal import descriptor_pool_test2_pb2 @@ -23,15 +28,15 @@ from google.protobuf.internal import more_messages_pb2 from google.protobuf.internal import no_package_pb2 from google.protobuf.internal import testing_refleaks -from google.protobuf import descriptor -from google.protobuf import descriptor_database -from google.protobuf import descriptor_pool -from google.protobuf import message_factory -from google.protobuf import symbol_database + +from google.protobuf import unittest_features_pb2 from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_import_public_pb2 from google.protobuf import unittest_pb2 +# pyformat: disable +# pyformat: enable + warnings.simplefilter('error', DeprecationWarning) @@ -1070,6 +1075,194 @@ def testAddTypeError(self): pool._AddFileDescriptor(0) +# TODO Expand these tests to upb and C++ once the DescriptorPool +# API is unified. +@unittest.skipIf( + api_implementation.Type() != 'python', + 'Only pure python allows SetFeatureSetDefaults()', +) +@testing_refleaks.TestCase +class FeatureSetDefaults(unittest.TestCase): + + def testDefault(self): + pool = descriptor_pool.DescriptorPool() + file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto') + file = pool.AddSerializedFile(file_desc.SerializeToString()) + self.assertFalse( + file._GetFeatures().HasExtension(unittest_features_pb2.test) + ) + + def testOverride(self): + pool = descriptor_pool.DescriptorPool() + defaults = descriptor_pb2.FeatureSetDefaults( + defaults=[ + descriptor_pb2.FeatureSetDefaults.FeatureSetEditionDefault( + edition=descriptor_pb2.Edition.EDITION_PROTO2, + features=unittest_features_pb2.DESCRIPTOR._GetFeatures(), + ) + ], + minimum_edition=descriptor_pb2.Edition.EDITION_PROTO2, + maximum_edition=descriptor_pb2.Edition.EDITION_2023, + ) + defaults.defaults[0].features.Extensions[ + unittest_features_pb2.test + ].int_file_feature = 9 + pool.SetFeatureSetDefaults(defaults) + file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto') + file = pool.AddSerializedFile(file_desc.SerializeToString()) + self.assertTrue( + file._GetFeatures().HasExtension(unittest_features_pb2.test) + ) + + def testInvalidEditionRange(self): + pool = descriptor_pool.DescriptorPool() + with self.assertRaisesRegex(ValueError, 'Invalid edition range'): + pool.SetFeatureSetDefaults( + descriptor_pb2.FeatureSetDefaults( + defaults=[ + descriptor_pb2.FeatureSetDefaults.FeatureSetEditionDefault( + edition=descriptor_pb2.Edition.EDITION_PROTO2, + features=unittest_features_pb2.DESCRIPTOR._GetFeatures(), + ) + ], + minimum_edition=descriptor_pb2.Edition.EDITION_2023, + maximum_edition=descriptor_pb2.Edition.EDITION_PROTO2, + ) + ) + file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto') + file = pool.AddSerializedFile(file_desc.SerializeToString()) + + def testNotStrictlyIncreasing(self): + pool = descriptor_pool.DescriptorPool() + with self.assertRaisesRegex(ValueError, 'not strictly increasing'): + pool.SetFeatureSetDefaults( + descriptor_pb2.FeatureSetDefaults( + defaults=[ + descriptor_pb2.FeatureSetDefaults.FeatureSetEditionDefault( + edition=descriptor_pb2.Edition.EDITION_PROTO3, + features=unittest_features_pb2.DESCRIPTOR._GetFeatures(), + ), + descriptor_pb2.FeatureSetDefaults.FeatureSetEditionDefault( + edition=descriptor_pb2.Edition.EDITION_PROTO2, + features=unittest_features_pb2.DESCRIPTOR._GetFeatures(), + ), + ], + minimum_edition=descriptor_pb2.Edition.EDITION_PROTO2, + maximum_edition=descriptor_pb2.Edition.EDITION_2023, + ) + ) + + def testUnknownEdition(self): + pool = descriptor_pool.DescriptorPool() + with self.assertRaisesRegex(ValueError, 'Invalid edition'): + pool.SetFeatureSetDefaults( + descriptor_pb2.FeatureSetDefaults( + defaults=[ + descriptor_pb2.FeatureSetDefaults.FeatureSetEditionDefault( + edition=descriptor_pb2.Edition.EDITION_UNKNOWN, + features=unittest_features_pb2.DESCRIPTOR._GetFeatures(), + ), + descriptor_pb2.FeatureSetDefaults.FeatureSetEditionDefault( + edition=descriptor_pb2.Edition.EDITION_PROTO2, + features=unittest_features_pb2.DESCRIPTOR._GetFeatures(), + ), + ], + minimum_edition=descriptor_pb2.Edition.EDITION_PROTO2, + maximum_edition=descriptor_pb2.Edition.EDITION_2023, + ) + ) + + def testChangeAfterBuild(self): + pool = descriptor_pool.DescriptorPool() + file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto') + file = pool.AddSerializedFile(file_desc.SerializeToString()) + file._GetFeatures() + defaults = descriptor_pb2.FeatureSetDefaults( + defaults=[ + descriptor_pb2.FeatureSetDefaults.FeatureSetEditionDefault( + edition=descriptor_pb2.Edition.EDITION_PROTO2, + features=unittest_features_pb2.DESCRIPTOR._GetFeatures(), + ) + ], + minimum_edition=descriptor_pb2.Edition.EDITION_PROTO2, + maximum_edition=descriptor_pb2.Edition.EDITION_2023, + ) + with self.assertRaisesRegex(ValueError, "defaults can't be changed"): + pool.SetFeatureSetDefaults(defaults) + + def testChangeDefaultPool(self): + defaults = descriptor_pb2.FeatureSetDefaults( + defaults=[ + descriptor_pb2.FeatureSetDefaults.FeatureSetEditionDefault( + edition=descriptor_pb2.Edition.EDITION_PROTO2, + features=unittest_features_pb2.DESCRIPTOR._GetFeatures(), + ) + ], + minimum_edition=descriptor_pb2.Edition.EDITION_PROTO2, + maximum_edition=descriptor_pb2.Edition.EDITION_2023, + ) + with self.assertRaisesRegex(ValueError, "defaults can't be changed"): + descriptor_pool.Default().SetFeatureSetDefaults(defaults) + + def testNoValidFeatures(self): + pool = descriptor_pool.DescriptorPool() + defaults = descriptor_pb2.FeatureSetDefaults( + defaults=[ + descriptor_pb2.FeatureSetDefaults.FeatureSetEditionDefault( + edition=descriptor_pb2.Edition.EDITION_2023, + features=unittest_features_pb2.DESCRIPTOR._GetFeatures(), + ) + ], + minimum_edition=descriptor_pb2.Edition.EDITION_PROTO2, + maximum_edition=descriptor_pb2.Edition.EDITION_2023, + ) + pool.SetFeatureSetDefaults(defaults) + file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto') + with self.assertRaisesRegex(TypeError, 'No valid default found'): + file = pool.AddSerializedFile(file_desc.SerializeToString()) + file._GetFeatures() + + def testBelowMinimum(self): + pool = descriptor_pool.DescriptorPool() + defaults = descriptor_pb2.FeatureSetDefaults( + defaults=[ + descriptor_pb2.FeatureSetDefaults.FeatureSetEditionDefault( + edition=descriptor_pb2.Edition.EDITION_PROTO3, + features=unittest_features_pb2.DESCRIPTOR._GetFeatures(), + ) + ], + minimum_edition=descriptor_pb2.Edition.EDITION_PROTO3, + maximum_edition=descriptor_pb2.Edition.EDITION_2023, + ) + pool.SetFeatureSetDefaults(defaults) + file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto') + with self.assertRaisesRegex(TypeError, 'earlier than the minimum'): + file = pool.AddSerializedFile(file_desc.SerializeToString()) + file._GetFeatures() + + def testAboveMaximum(self): + pool = descriptor_pool.DescriptorPool() + defaults = descriptor_pb2.FeatureSetDefaults( + defaults=[ + descriptor_pb2.FeatureSetDefaults.FeatureSetEditionDefault( + edition=descriptor_pb2.Edition.EDITION_PROTO2, + features=unittest_features_pb2.DESCRIPTOR._GetFeatures(), + ) + ], + minimum_edition=descriptor_pb2.Edition.EDITION_PROTO2, + maximum_edition=descriptor_pb2.Edition.EDITION_PROTO3, + ) + pool.SetFeatureSetDefaults(defaults) + file_desc = descriptor_pb2.FileDescriptorProto( + name='some/file.proto', + syntax='editions', + edition=descriptor_pb2.Edition.EDITION_2023, + ) + with self.assertRaisesRegex(TypeError, 'later than the maximum'): + file = pool.AddSerializedFile(file_desc.SerializeToString()) + file._GetFeatures() + + TEST1_FILE = ProtoFile( 'google/protobuf/internal/descriptor_pool_test1.proto', 'google.protobuf.python.internal', diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index 56ca079f3401..2e90f0f51dad 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -18,12 +18,15 @@ from google.protobuf import symbol_database from google.protobuf import text_format from google.protobuf.internal import api_implementation +from google.protobuf.internal import legacy_features_pb2 from google.protobuf.internal import test_util from google.protobuf.internal import _parameterized from google.protobuf import unittest_custom_options_pb2 +from google.protobuf import unittest_features_pb2 from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 +from google.protobuf import unittest_proto3_pb2 TEST_EMPTY_MESSAGE_DESCRIPTOR_ASCII = """ @@ -1215,9 +1218,13 @@ def testJsonName(self): json_names[index]) +# TODO Add _GetFeatures for upb and C++. +@unittest.skipIf( + api_implementation.Type() != 'python', + 'Features field is only available with the pure python implementation', +) class FeaturesTest(_parameterized.TestCase): - # TODO Add _features for upb and C++. @_parameterized.named_parameters([ ('File', lambda: descriptor_pb2.DESCRIPTOR), ('Message', lambda: descriptor_pb2.FeatureSet.DESCRIPTOR), @@ -1232,46 +1239,376 @@ class FeaturesTest(_parameterized.TestCase): ], ), ]) - @unittest.skipIf( - api_implementation.Type() != 'python', - 'Features field is only available with the pure python implementation', - ) def testDescriptorProtoDefaultFeatures(self, desc): self.assertEqual( - desc()._features.field_presence, + desc()._GetFeatures().field_presence, descriptor_pb2.FeatureSet.FieldPresence.EXPLICIT, ) self.assertEqual( - desc()._features.enum_type, + desc()._GetFeatures().enum_type, descriptor_pb2.FeatureSet.EnumType.CLOSED, ) self.assertEqual( - desc()._features.repeated_field_encoding, + desc()._GetFeatures().repeated_field_encoding, descriptor_pb2.FeatureSet.RepeatedFieldEncoding.EXPANDED, ) - # TODO Add _features for upb and C++. - @unittest.skipIf( - api_implementation.Type() != 'python', - 'Features field is only available with the pure python implementation', - ) def testDescriptorProtoOverrideFeatures(self): desc = descriptor_pb2.SourceCodeInfo.Location.DESCRIPTOR.fields_by_name[ 'path' ] self.assertEqual( - desc._features.field_presence, + desc._GetFeatures().field_presence, descriptor_pb2.FeatureSet.FieldPresence.EXPLICIT, ) self.assertEqual( - desc._features.enum_type, + desc._GetFeatures().enum_type, descriptor_pb2.FeatureSet.EnumType.CLOSED, ) self.assertEqual( - desc._features.repeated_field_encoding, + desc._GetFeatures().repeated_field_encoding, + descriptor_pb2.FeatureSet.RepeatedFieldEncoding.PACKED, + ) + + def testFeaturesStripped(self): + desc = legacy_features_pb2.TestEditionsMessage.DESCRIPTOR.fields_by_name[ + 'required_field' + ] + self.assertFalse(desc.GetOptions().HasField('features')) + + def testLegacyRequiredTransform(self): + desc = legacy_features_pb2.TestEditionsMessage.DESCRIPTOR + self.assertEqual( + desc.fields_by_name['required_field'].label, + descriptor.FieldDescriptor.LABEL_REQUIRED, + ) + + def testLegacyGroupTransform(self): + desc = legacy_features_pb2.TestEditionsMessage.DESCRIPTOR + self.assertEqual( + desc.fields_by_name['delimited_field'].type, + descriptor.FieldDescriptor.TYPE_GROUP, + ) + + def testLegacyInferRequired(self): + desc = unittest_pb2.TestRequired.DESCRIPTOR.fields_by_name['a'] + self.assertEqual( + desc._GetFeatures().field_presence, + descriptor_pb2.FeatureSet.FieldPresence.LEGACY_REQUIRED, + ) + + def testLegacyInferGroup(self): + desc = unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optionalgroup'] + self.assertEqual( + desc._GetFeatures().message_encoding, + descriptor_pb2.FeatureSet.MessageEncoding.DELIMITED, + ) + + def testLegacyInferProto2Packed(self): + desc = unittest_pb2.TestPackedTypes.DESCRIPTOR.fields_by_name[ + 'packed_int32' + ] + self.assertEqual( + desc._GetFeatures().repeated_field_encoding, descriptor_pb2.FeatureSet.RepeatedFieldEncoding.PACKED, ) + def testLegacyInferProto3Expanded(self): + desc = unittest_proto3_pb2.TestUnpackedTypes.DESCRIPTOR.fields_by_name[ + 'repeated_int32' + ] + self.assertEqual( + desc._GetFeatures().repeated_field_encoding, + descriptor_pb2.FeatureSet.RepeatedFieldEncoding.EXPANDED, + ) + + def testProto2Defaults(self): + features = unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name[ + 'optional_int32' + ]._GetFeatures() + fs = descriptor_pb2.FeatureSet + self.assertEqual(features.field_presence, fs.FieldPresence.EXPLICIT) + self.assertEqual(features.enum_type, fs.EnumType.CLOSED) + self.assertEqual( + features.repeated_field_encoding, fs.RepeatedFieldEncoding.EXPANDED + ) + self.assertEqual(features.utf8_validation, fs.Utf8Validation.NONE) + self.assertEqual( + features.message_encoding, fs.MessageEncoding.LENGTH_PREFIXED + ) + self.assertEqual(features.json_format, fs.JsonFormat.LEGACY_BEST_EFFORT) + + def testProto3Defaults(self): + features = unittest_proto3_pb2.TestAllTypes.DESCRIPTOR.fields_by_name[ + 'optional_int32' + ]._GetFeatures() + fs = descriptor_pb2.FeatureSet + self.assertEqual(features.field_presence, fs.FieldPresence.IMPLICIT) + self.assertEqual(features.enum_type, fs.EnumType.OPEN) + self.assertEqual( + features.repeated_field_encoding, fs.RepeatedFieldEncoding.PACKED + ) + self.assertEqual(features.utf8_validation, fs.Utf8Validation.VERIFY) + self.assertEqual( + features.message_encoding, fs.MessageEncoding.LENGTH_PREFIXED + ) + self.assertEqual(features.json_format, fs.JsonFormat.ALLOW) + + +def GetTestFeature(desc): + return ( + desc._GetFeatures() + .Extensions[unittest_features_pb2.test] + .int_multiple_feature + ) + + +def SetTestFeature(proto, value): + proto.options.features.Extensions[ + unittest_features_pb2.test + ].int_multiple_feature = value + + +# TODO Add _GetFeatures for upb and C++. +@unittest.skipIf( + api_implementation.Type() != 'python', + 'Features field is only available with the pure python implementation', +) +class FeatureInheritanceTest(unittest.TestCase): + + def setUp(self): + super().setUp() + self.file_proto = descriptor_pb2.FileDescriptorProto( + name='some/filename/some.proto', + package='protobuf_unittest', + edition=descriptor_pb2.Edition.EDITION_2023, + syntax='editions', + ) + self.top_extension_proto = self.file_proto.extension.add( + name='top_extension', + number=10, + type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32, + label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL, + extendee='.protobuf_unittest.TopMessage', + ) + self.top_enum_proto = self.file_proto.enum_type.add(name='TopEnum') + self.enum_value_proto = self.top_enum_proto.value.add( + name='TOP_VALUE', number=0 + ) + self.top_message_proto = self.file_proto.message_type.add(name='TopMessage') + self.field_proto = self.top_message_proto.field.add( + name='field', + number=1, + type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32, + label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL, + ) + self.top_message_proto.extension_range.add(start=10, end=20) + self.nested_extension_proto = self.top_message_proto.extension.add( + name='nested_extension', + number=11, + type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32, + label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL, + extendee='.protobuf_unittest.TopMessage', + ) + self.nested_message_proto = self.top_message_proto.nested_type.add( + name='NestedMessage' + ) + self.nested_enum_proto = self.top_message_proto.enum_type.add( + name='NestedEnum' + ) + self.nested_enum_proto.value.add(name='NESTED_VALUE', number=0) + self.oneof_proto = self.top_message_proto.oneof_decl.add(name='Oneof') + self.oneof_field_proto = self.top_message_proto.field.add( + name='oneof_field', + number=2, + type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32, + label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL, + oneof_index=0, + ) + + self.service_proto = self.file_proto.service.add(name='TestService') + self.method_proto = self.service_proto.method.add( + name='CallMethod', + input_type='.protobuf_unittest.TopMessage', + output_type='.protobuf_unittest.TopMessage', + ) + + def BuildPool(self): + pool = descriptor_pool.DescriptorPool() + defaults = descriptor_pb2.FeatureSetDefaults( + defaults=[ + descriptor_pb2.FeatureSetDefaults.FeatureSetEditionDefault( + edition=descriptor_pb2.Edition.EDITION_PROTO2, + features=unittest_pb2.TestAllTypes.DESCRIPTOR._GetFeatures(), + ) + ], + minimum_edition=descriptor_pb2.Edition.EDITION_PROTO2, + maximum_edition=descriptor_pb2.Edition.EDITION_2023, + ) + defaults.defaults[0].features.Extensions[ + unittest_features_pb2.test + ].int_multiple_feature = 1 + pool.SetFeatureSetDefaults(defaults) + + self.file = pool.AddSerializedFile(self.file_proto.SerializeToString()) + self.top_message = pool.FindMessageTypeByName('protobuf_unittest.TopMessage') + self.top_enum = pool.FindEnumTypeByName('protobuf_unittest.TopEnum') + self.top_extension = pool.FindExtensionByName( + 'protobuf_unittest.top_extension' + ) + self.nested_message = self.top_message.nested_types_by_name['NestedMessage'] + self.nested_enum = self.top_message.enum_types_by_name['NestedEnum'] + self.nested_extension = self.top_message.extensions_by_name[ + 'nested_extension' + ] + self.field = self.top_message.fields_by_name['field'] + self.oneof = self.top_message.oneofs_by_name['Oneof'] + self.oneof_field = self.top_message.fields_by_name['oneof_field'] + self.enum_value = self.top_enum.values_by_name['TOP_VALUE'] + self.service = pool.FindServiceByName('protobuf_unittest.TestService') + self.method = self.service.methods_by_name['CallMethod'] + + def testFileDefaults(self): + self.BuildPool() + self.assertEqual(GetTestFeature(self.file), 1) + + def testFileOverride(self): + SetTestFeature(self.file_proto, 3) + self.BuildPool() + self.assertEqual(GetTestFeature(self.file), 3) + + def testFileMessageInherit(self): + SetTestFeature(self.file_proto, 3) + self.BuildPool() + self.assertEqual(GetTestFeature(self.top_message), 3) + + def testFileMessageOverride(self): + SetTestFeature(self.file_proto, 3) + SetTestFeature(self.top_message_proto, 5) + self.BuildPool() + self.assertEqual(GetTestFeature(self.top_message), 5) + + def testFileEnumInherit(self): + SetTestFeature(self.file_proto, 3) + self.BuildPool() + self.assertEqual(GetTestFeature(self.top_enum), 3) + + def testFileEnumOverride(self): + SetTestFeature(self.file_proto, 3) + SetTestFeature(self.top_enum_proto, 5) + self.BuildPool() + self.assertEqual(GetTestFeature(self.top_enum), 5) + + def testFileExtensionInherit(self): + SetTestFeature(self.file_proto, 3) + self.BuildPool() + self.assertEqual(GetTestFeature(self.top_extension), 3) + + def testFileExtensionOverride(self): + SetTestFeature(self.file_proto, 3) + SetTestFeature(self.top_extension_proto, 5) + self.BuildPool() + self.assertEqual(GetTestFeature(self.top_extension), 5) + + def testFileServiceInherit(self): + SetTestFeature(self.file_proto, 3) + self.BuildPool() + self.assertEqual(GetTestFeature(self.service), 3) + + def testFileServiceOverride(self): + SetTestFeature(self.file_proto, 3) + SetTestFeature(self.service_proto, 5) + self.BuildPool() + self.assertEqual(GetTestFeature(self.service), 5) + + def testMessageFieldInherit(self): + SetTestFeature(self.top_message_proto, 3) + self.BuildPool() + self.assertEqual(GetTestFeature(self.field), 3) + + def testMessageFieldOverride(self): + SetTestFeature(self.top_message_proto, 3) + SetTestFeature(self.field_proto, 5) + self.BuildPool() + self.assertEqual(GetTestFeature(self.field), 5) + + def testMessageEnumInherit(self): + SetTestFeature(self.top_message_proto, 3) + self.BuildPool() + self.assertEqual(GetTestFeature(self.nested_enum), 3) + + def testMessageEnumOverride(self): + SetTestFeature(self.top_message_proto, 3) + SetTestFeature(self.nested_enum_proto, 5) + self.BuildPool() + self.assertEqual(GetTestFeature(self.nested_enum), 5) + + def testMessageMessageInherit(self): + SetTestFeature(self.top_message_proto, 3) + self.BuildPool() + self.assertEqual(GetTestFeature(self.nested_message), 3) + + def testMessageMessageOverride(self): + SetTestFeature(self.top_message_proto, 3) + SetTestFeature(self.nested_message_proto, 5) + self.BuildPool() + self.assertEqual(GetTestFeature(self.nested_message), 5) + + def testMessageExtensionInherit(self): + SetTestFeature(self.top_message_proto, 3) + self.BuildPool() + self.assertEqual(GetTestFeature(self.nested_extension), 3) + + def testMessageExtensionOverride(self): + SetTestFeature(self.top_message_proto, 3) + SetTestFeature(self.nested_extension_proto, 5) + self.BuildPool() + self.assertEqual(GetTestFeature(self.nested_extension), 5) + + def testMessageOneofInherit(self): + SetTestFeature(self.top_message_proto, 3) + self.BuildPool() + self.assertEqual(GetTestFeature(self.oneof), 3) + + def testMessageOneofOverride(self): + SetTestFeature(self.top_message_proto, 3) + SetTestFeature(self.oneof_proto, 5) + self.BuildPool() + self.assertEqual(GetTestFeature(self.oneof), 5) + + def testOneofFieldInherit(self): + SetTestFeature(self.oneof_proto, 3) + self.BuildPool() + self.assertEqual(GetTestFeature(self.oneof_field), 3) + + def testOneofFieldOverride(self): + SetTestFeature(self.oneof_proto, 3) + SetTestFeature(self.oneof_field_proto, 5) + self.BuildPool() + self.assertEqual(GetTestFeature(self.oneof_field), 5) + + def testEnumValueInherit(self): + SetTestFeature(self.top_enum_proto, 3) + self.BuildPool() + self.assertEqual(GetTestFeature(self.enum_value), 3) + + def testEnumValueOverride(self): + SetTestFeature(self.top_enum_proto, 3) + SetTestFeature(self.enum_value_proto, 5) + self.BuildPool() + self.assertEqual(GetTestFeature(self.enum_value), 5) + + def testServiceMethodInherit(self): + SetTestFeature(self.service_proto, 3) + self.BuildPool() + self.assertEqual(GetTestFeature(self.method), 3) + + def testServiceMethodOverride(self): + SetTestFeature(self.service_proto, 3) + SetTestFeature(self.method_proto, 5) + self.BuildPool() + self.assertEqual(GetTestFeature(self.method), 5) + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/legacy_features.proto b/python/google/protobuf/internal/legacy_features.proto new file mode 100644 index 000000000000..ef803ddbade8 --- /dev/null +++ b/python/google/protobuf/internal/legacy_features.proto @@ -0,0 +1,18 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file or at +// https://developers.google.com/open-source/licenses/bsd + +// Test that features with legacy descriptor helpers get properly converted. + +edition = "2023"; + +package google.protobuf.internal; + +message TestEditionsMessage { + int32 required_field = 1 [features.field_presence = LEGACY_REQUIRED]; + TestEditionsMessage delimited_field = 2 + [features.message_encoding = DELIMITED]; +} diff --git a/python/google/protobuf/internal/python_edition_defaults.py.template b/python/google/protobuf/internal/python_edition_defaults.py.template new file mode 100644 index 000000000000..56bdf042e63f --- /dev/null +++ b/python/google/protobuf/internal/python_edition_defaults.py.template @@ -0,0 +1,5 @@ +""" +This file contains the serialized FeatureSetDefaults object corresponding to +the Pure Python runtime. This is used for feature resolution under Editions. +""" +_PROTOBUF_INTERNAL_PYTHON_EDITION_DEFAULTS = b"DEFAULTS_VALUE" diff --git a/src/google/protobuf/compiler/BUILD.bazel b/src/google/protobuf/compiler/BUILD.bazel index 45afef2a933a..a03b8171c4f6 100644 --- a/src/google/protobuf/compiler/BUILD.bazel +++ b/src/google/protobuf/compiler/BUILD.bazel @@ -163,8 +163,7 @@ cc_binary( ], copts = COPTS, visibility = [ - "//src/google/protobuf:__subpackages__", - "//upb:__subpackages__", + "//:__subpackages__", ], deps = [ ":command_line_interface", diff --git a/src/google/protobuf/compiler/python/generator.cc b/src/google/protobuf/compiler/python/generator.cc index 3f3410e3cd0f..e34a7be01569 100644 --- a/src/google/protobuf/compiler/python/generator.cc +++ b/src/google/protobuf/compiler/python/generator.cc @@ -21,9 +21,7 @@ #include "google/protobuf/compiler/python/generator.h" -#include #include -#include #include #include #include @@ -160,16 +158,21 @@ std::string StringifyDefaultValue(const FieldDescriptor& field) { return ""; } +// Returns a CEscaped string of serialized_options. +std::string OptionsValue(absl::string_view serialized_options) { + if (serialized_options.empty()) { + return "None"; + } else { + return absl::StrCat("b'", absl::CEscape(serialized_options), "'"); + } +} + } // namespace Generator::Generator() : file_(nullptr) {} Generator::~Generator() {} -uint64_t Generator::GetSupportedFeatures() const { - return CodeGenerator::Feature::FEATURE_PROTO3_OPTIONAL; -} - GeneratorOptions Generator::ParseParameter(absl::string_view parameter, std::string* error) const { GeneratorOptions options; @@ -231,8 +234,8 @@ bool Generator::Generate(const FileDescriptor* file, std::string filename = GetFileName(file, ".py"); - FileDescriptorProto fdp = StripSourceRetentionOptions(*file_); - fdp.SerializeToString(&file_descriptor_serialized_); + proto_ = StripSourceRetentionOptions(*file_); + proto_.SerializeToString(&file_descriptor_serialized_); if (!opensource_runtime_ && GeneratingDescriptorProto()) { std::string bootstrap_filename = @@ -320,7 +323,7 @@ bool Generator::Generate(const FileDescriptor* file, FixAllDescriptorOptions(); // Set serialized_start and serialized_end. - SetSerializedPbInterval(fdp); + SetSerializedPbInterval(proto_); printer_->Outdent(); if (HasGenericServices(file)) { @@ -533,8 +536,8 @@ void Generator::PrintFileDescriptor() const { m["package"] = file_->package(); m["syntax"] = std::string( FileDescriptorLegacy::SyntaxName(FileDescriptorLegacy(file_).syntax())); - m["options"] = OptionsValue( - StripLocalSourceRetentionOptions(*file_).SerializeAsString()); + m["edition"] = Edition_Name(file_->edition()); + m["options"] = OptionsValue(proto_.options().SerializeAsString()); m["serialized_descriptor"] = absl::CHexEscape(file_descriptor_serialized_); if (GeneratingDescriptorProto()) { printer_->Print("if _descriptor._USE_C_DESCRIPTORS == False:\n"); @@ -547,6 +550,7 @@ void Generator::PrintFileDescriptor() const { " name='$name$',\n" " package='$package$',\n" " syntax='$syntax$',\n" + " edition='$edition$',\n" " serialized_options=$options$,\n" " create_key=_descriptor._internal_create_key,\n"; printer_->Print(m, file_descriptor_template); @@ -596,17 +600,18 @@ void Generator::PrintFileDescriptor() const { // Prints all enums contained in all message types in |file|. void Generator::PrintAllEnumsInFile() const { for (int i = 0; i < file_->enum_type_count(); ++i) { - PrintEnum(*file_->enum_type(i)); + PrintEnum(*file_->enum_type(i), proto_.enum_type(i)); } for (int i = 0; i < file_->message_type_count(); ++i) { - PrintNestedEnums(*file_->message_type(i)); + PrintNestedEnums(*file_->message_type(i), proto_.message_type(i)); } } // Prints a Python statement assigning the appropriate module-level // enum name to a Python EnumDescriptor object equivalent to // enum_descriptor. -void Generator::PrintEnum(const EnumDescriptor& enum_descriptor) const { +void Generator::PrintEnum(const EnumDescriptor& enum_descriptor, + const EnumDescriptorProto& proto) const { absl::flat_hash_map m; std::string module_level_descriptor_name = ModuleLevelDescriptorName(enum_descriptor); @@ -623,14 +628,13 @@ void Generator::PrintEnum(const EnumDescriptor& enum_descriptor) const { " create_key=_descriptor._internal_create_key,\n" " values=[\n"; std::string options_string; - StripLocalSourceRetentionOptions(enum_descriptor) - .SerializeToString(&options_string); + proto.options().SerializeToString(&options_string); printer_->Print(m, enum_descriptor_template); printer_->Indent(); printer_->Indent(); for (int i = 0; i < enum_descriptor.value_count(); ++i) { - PrintEnumValueDescriptor(*enum_descriptor.value(i)); + PrintEnumValueDescriptor(*enum_descriptor.value(i), proto.value(i)); printer_->Print(",\n"); } @@ -649,20 +653,21 @@ void Generator::PrintEnum(const EnumDescriptor& enum_descriptor) const { // Recursively prints enums in nested types within descriptor, then // prints enums contained at the top level in descriptor. -void Generator::PrintNestedEnums(const Descriptor& descriptor) const { +void Generator::PrintNestedEnums(const Descriptor& descriptor, + const DescriptorProto& proto) const { for (int i = 0; i < descriptor.nested_type_count(); ++i) { - PrintNestedEnums(*descriptor.nested_type(i)); + PrintNestedEnums(*descriptor.nested_type(i), proto.nested_type(i)); } for (int i = 0; i < descriptor.enum_type_count(); ++i) { - PrintEnum(*descriptor.enum_type(i)); + PrintEnum(*descriptor.enum_type(i), proto.enum_type(i)); } } // Prints Python equivalents of all Descriptors in |file|. void Generator::PrintMessageDescriptors() const { for (int i = 0; i < file_->message_type_count(); ++i) { - PrintDescriptor(*file_->message_type(i)); + PrintDescriptor(*file_->message_type(i), proto_.message_type(i)); printer_->Print("\n"); } } @@ -732,13 +737,14 @@ void Generator::PrintServiceStub(const ServiceDescriptor& descriptor) const { // to a Python Descriptor object for message_descriptor. // // Mutually recursive with PrintNestedDescriptors(). -void Generator::PrintDescriptor(const Descriptor& message_descriptor) const { +void Generator::PrintDescriptor(const Descriptor& message_descriptor, + const DescriptorProto& proto) const { absl::flat_hash_map m; m["name"] = message_descriptor.name(); m["full_name"] = message_descriptor.full_name(); m["file"] = kDescriptorKey; - PrintNestedDescriptors(message_descriptor); + PrintNestedDescriptors(message_descriptor, proto); printer_->Print("\n"); printer_->Print("$descriptor_name$ = _descriptor.Descriptor(\n", @@ -753,8 +759,8 @@ void Generator::PrintDescriptor(const Descriptor& message_descriptor) const { "containing_type=None,\n" "create_key=_descriptor._internal_create_key,\n"; printer_->Print(m, required_function_arguments); - PrintFieldsInDescriptor(message_descriptor); - PrintExtensionsInDescriptor(message_descriptor); + PrintFieldsInDescriptor(message_descriptor, proto); + PrintExtensionsInDescriptor(message_descriptor, proto); // Nested types printer_->Print("nested_types=["); @@ -777,16 +783,12 @@ void Generator::PrintDescriptor(const Descriptor& message_descriptor) const { printer_->Outdent(); printer_->Print("],\n"); std::string options_string; - StripLocalSourceRetentionOptions(message_descriptor) - .SerializeToString(&options_string); + proto.options().SerializeToString(&options_string); printer_->Print( "serialized_options=$options_value$,\n" - "is_extendable=$extendable$,\n" - "syntax='$syntax$'", + "is_extendable=$extendable$", "options_value", OptionsValue(options_string), "extendable", - message_descriptor.extension_range_count() > 0 ? "True" : "False", - "syntax", - FileDescriptorLegacy::SyntaxName(FileDescriptorLegacy(file_).syntax())); + message_descriptor.extension_range_count() > 0 ? "True" : "False"); printer_->Print(",\n"); // Extension ranges @@ -807,8 +809,8 @@ void Generator::PrintDescriptor(const Descriptor& message_descriptor) const { m["name"] = desc->name(); m["full_name"] = desc->full_name(); m["index"] = absl::StrCat(desc->index()); - options_string = OptionsValue( - StripLocalSourceRetentionOptions(*desc).SerializeAsString()); + options_string = + OptionsValue(proto.oneof_decl(i).options().SerializeAsString()); if (options_string == "None") { m["serialized_options"] = ""; } else { @@ -833,10 +835,11 @@ void Generator::PrintDescriptor(const Descriptor& message_descriptor) const { // message_descriptor. // // Mutually recursive with PrintDescriptor(). -void Generator::PrintNestedDescriptors( - const Descriptor& containing_descriptor) const { +void Generator::PrintNestedDescriptors(const Descriptor& containing_descriptor, + const DescriptorProto& proto) const { for (int i = 0; i < containing_descriptor.nested_type_count(); ++i) { - PrintDescriptor(*containing_descriptor.nested_type(i)); + PrintDescriptor(*containing_descriptor.nested_type(i), + proto.nested_type(i)); } } @@ -1103,12 +1106,12 @@ void Generator::FixForeignFieldsInDescriptors() const { // Returns a Python expression that instantiates a Python EnumValueDescriptor // object for the given C++ descriptor. void Generator::PrintEnumValueDescriptor( - const EnumValueDescriptor& descriptor) const { + const EnumValueDescriptor& descriptor, + const EnumValueDescriptorProto& proto) const { // TODO: Fix up EnumValueDescriptor "type" fields. // More circular references. ::sigh:: std::string options_string; - StripLocalSourceRetentionOptions(descriptor) - .SerializeToString(&options_string); + proto.options().SerializeToString(&options_string); absl::flat_hash_map m; m["name"] = descriptor.name(); m["index"] = absl::StrCat(descriptor.index()); @@ -1122,21 +1125,11 @@ void Generator::PrintEnumValueDescriptor( " create_key=_descriptor._internal_create_key)"); } -// Returns a CEscaped string of serialized_options. -std::string Generator::OptionsValue( - absl::string_view serialized_options) const { - if (serialized_options.length() == 0) { - return "None"; - } else { - return absl::StrCat("b'", absl::CEscape(serialized_options), "'"); - } -} - // Prints an expression for a Python FieldDescriptor for |field|. void Generator::PrintFieldDescriptor(const FieldDescriptor& field, - bool is_extension) const { + const FieldDescriptorProto& proto) const { std::string options_string; - StripLocalSourceRetentionOptions(field).SerializeToString(&options_string); + proto.options().SerializeToString(&options_string); absl::flat_hash_map m; m["name"] = field.name(); m["full_name"] = field.full_name(); @@ -1147,7 +1140,7 @@ void Generator::PrintFieldDescriptor(const FieldDescriptor& field, m["label"] = absl::StrCat(field.label()); m["has_default_value"] = field.has_default_value() ? "True" : "False"; m["default_value"] = StringifyDefaultValue(field); - m["is_extension"] = is_extension ? "True" : "False"; + m["is_extension"] = field.is_extension() ? "True" : "False"; m["serialized_options"] = OptionsValue(options_string); m["json_name"] = field.has_json_name() ? absl::StrCat(", json_name='", field.json_name(), "'") @@ -1170,13 +1163,16 @@ void Generator::PrintFieldDescriptor(const FieldDescriptor& field, // Helper for Print{Fields,Extensions}InDescriptor(). void Generator::PrintFieldDescriptorsInDescriptor( - const Descriptor& message_descriptor, bool is_extension, - absl::string_view list_variable_name, int (Descriptor::*CountFn)() const, - const FieldDescriptor* (Descriptor::*GetterFn)(int) const) const { + const Descriptor& message_descriptor, const DescriptorProto& proto, + bool is_extension, absl::string_view list_variable_name) const { printer_->Print("$list$=[\n", "list", list_variable_name); printer_->Indent(); - for (int i = 0; i < (message_descriptor.*CountFn)(); ++i) { - PrintFieldDescriptor(*(message_descriptor.*GetterFn)(i), is_extension); + int count = is_extension ? message_descriptor.extension_count() + : message_descriptor.field_count(); + for (int i = 0; i < count; ++i) { + PrintFieldDescriptor(is_extension ? *message_descriptor.extension(i) + : *message_descriptor.field(i), + is_extension ? proto.extension(i) : proto.field(i)); printer_->Print(",\n"); } printer_->Outdent(); @@ -1185,22 +1181,20 @@ void Generator::PrintFieldDescriptorsInDescriptor( // Prints a statement assigning "fields" to a list of Python FieldDescriptors, // one for each field present in message_descriptor. -void Generator::PrintFieldsInDescriptor( - const Descriptor& message_descriptor) const { +void Generator::PrintFieldsInDescriptor(const Descriptor& message_descriptor, + const DescriptorProto& proto) const { const bool is_extension = false; - PrintFieldDescriptorsInDescriptor(message_descriptor, is_extension, "fields", - &Descriptor::field_count, - &Descriptor::field); + PrintFieldDescriptorsInDescriptor(message_descriptor, proto, is_extension, + "fields"); } // Prints a statement assigning "extensions" to a list of Python // FieldDescriptors, one for each extension present in message_descriptor. void Generator::PrintExtensionsInDescriptor( - const Descriptor& message_descriptor) const { + const Descriptor& message_descriptor, const DescriptorProto& proto) const { const bool is_extension = true; - PrintFieldDescriptorsInDescriptor(message_descriptor, is_extension, - "extensions", &Descriptor::extension_count, - &Descriptor::extension); + PrintFieldDescriptorsInDescriptor(message_descriptor, proto, is_extension, + "extensions"); } bool Generator::GeneratingDescriptorProto() const { @@ -1287,9 +1281,9 @@ void Generator::PrintSerializedPbInterval( template bool Generator::PrintDescriptorOptionsFixingCode( - const DescriptorT& descriptor, absl::string_view descriptor_str) const { - std::string options = OptionsValue( - StripLocalSourceRetentionOptions(descriptor).SerializeAsString()); + const DescriptorT& descriptor, const typename DescriptorT::Proto& proto, + absl::string_view descriptor_str) const { + std::string options = OptionsValue(proto.options().SerializeAsString()); // Reset the _options to None thus DescriptorBase.GetOptions() can // parse _options again after extensions are registered. @@ -1308,7 +1302,7 @@ bool Generator::PrintDescriptorOptionsFixingCode( } printer_->Print( - "$descriptor_name$._options = None\n" + "$descriptor_name$._loaded_options = None\n" "$descriptor_name$._serialized_options = $serialized_value$\n", "descriptor_name", descriptor_name, "serialized_value", options); return true; @@ -1362,46 +1356,46 @@ void Generator::SetMessagePbInterval(const DescriptorProto& message_proto, // Prints expressions that set the options field of all descriptors. void Generator::FixAllDescriptorOptions() const { // Prints an expression that sets the file descriptor's options. - if (!PrintDescriptorOptionsFixingCode(*file_, kDescriptorKey)) { - printer_->Print("DESCRIPTOR._options = None\n"); + if (!PrintDescriptorOptionsFixingCode(*file_, proto_, kDescriptorKey)) { + printer_->Print("DESCRIPTOR._loaded_options = None\n"); } // Prints expressions that set the options for all top level enums. for (int i = 0; i < file_->enum_type_count(); ++i) { - const EnumDescriptor& enum_descriptor = *file_->enum_type(i); - FixOptionsForEnum(enum_descriptor); + FixOptionsForEnum(*file_->enum_type(i), proto_.enum_type(i)); } // Prints expressions that set the options for all top level extensions. for (int i = 0; i < file_->extension_count(); ++i) { - const FieldDescriptor& field = *file_->extension(i); - FixOptionsForField(field); + FixOptionsForField(*file_->extension(i), proto_.extension(i)); } // Prints expressions that set the options for all messages, nested enums, // nested extensions and message fields. for (int i = 0; i < file_->message_type_count(); ++i) { - FixOptionsForMessage(*file_->message_type(i)); + FixOptionsForMessage(*file_->message_type(i), proto_.message_type(i)); } for (int i = 0; i < file_->service_count(); ++i) { - FixOptionsForService(*file_->service(i)); + FixOptionsForService(*file_->service(i), proto_.service(i)); } } -void Generator::FixOptionsForOneof(const OneofDescriptor& oneof) const { +void Generator::FixOptionsForOneof(const OneofDescriptor& oneof, + const OneofDescriptorProto& proto) const { std::string oneof_name = absl::Substitute( "$0.$1['$2']", ModuleLevelDescriptorName(*oneof.containing_type()), "oneofs_by_name", oneof.name()); - PrintDescriptorOptionsFixingCode(oneof, oneof_name); + PrintDescriptorOptionsFixingCode(oneof, proto, oneof_name); } // Prints expressions that set the options for an enum descriptor and its // value descriptors. -void Generator::FixOptionsForEnum(const EnumDescriptor& enum_descriptor) const { +void Generator::FixOptionsForEnum(const EnumDescriptor& enum_descriptor, + const EnumDescriptorProto& proto) const { std::string descriptor_name = ModuleLevelDescriptorName(enum_descriptor); - PrintDescriptorOptionsFixingCode(enum_descriptor, descriptor_name); + PrintDescriptorOptionsFixingCode(enum_descriptor, proto, descriptor_name); for (int i = 0; i < enum_descriptor.value_count(); ++i) { const EnumValueDescriptor& value_descriptor = *enum_descriptor.value(i); PrintDescriptorOptionsFixingCode( - value_descriptor, + value_descriptor, proto.value(i), absl::StrFormat("%s.values_by_name[\"%s\"]", descriptor_name.c_str(), value_descriptor.name().c_str())); } @@ -1410,22 +1404,24 @@ void Generator::FixOptionsForEnum(const EnumDescriptor& enum_descriptor) const { // Prints expressions that set the options for an service descriptor and its // value descriptors. void Generator::FixOptionsForService( - const ServiceDescriptor& service_descriptor) const { + const ServiceDescriptor& service_descriptor, + const ServiceDescriptorProto& proto) const { std::string descriptor_name = ModuleLevelServiceDescriptorName(service_descriptor); - PrintDescriptorOptionsFixingCode(service_descriptor, descriptor_name); + PrintDescriptorOptionsFixingCode(service_descriptor, proto, descriptor_name); for (int i = 0; i < service_descriptor.method_count(); ++i) { const MethodDescriptor* method = service_descriptor.method(i); - PrintDescriptorOptionsFixingCode( - *method, absl::StrCat(descriptor_name, ".methods_by_name['", - method->name(), "']")); + std::string method_name = absl::StrCat( + descriptor_name, ".methods_by_name['", method->name(), "']"); + PrintDescriptorOptionsFixingCode(*method, proto.method(i), method_name); } } // Prints expressions that set the options for field descriptors (including // extensions). -void Generator::FixOptionsForField(const FieldDescriptor& field) const { +void Generator::FixOptionsForField(const FieldDescriptor& field, + const FieldDescriptorProto& proto) const { std::string field_name; if (field.is_extension()) { if (field.extension_scope() == nullptr) { @@ -1439,36 +1435,37 @@ void Generator::FixOptionsForField(const FieldDescriptor& field) const { field_name = FieldReferencingExpression(field.containing_type(), field, "fields_by_name"); } - PrintDescriptorOptionsFixingCode(field, field_name); + PrintDescriptorOptionsFixingCode(field, proto, field_name); } // Prints expressions that set the options for a message and all its inner // types (nested messages, nested enums, extensions, fields). -void Generator::FixOptionsForMessage(const Descriptor& descriptor) const { +void Generator::FixOptionsForMessage(const Descriptor& descriptor, + const DescriptorProto& proto) const { // Nested messages. for (int i = 0; i < descriptor.nested_type_count(); ++i) { - FixOptionsForMessage(*descriptor.nested_type(i)); + FixOptionsForMessage(*descriptor.nested_type(i), proto.nested_type(i)); } // Oneofs. for (int i = 0; i < descriptor.oneof_decl_count(); ++i) { - FixOptionsForOneof(*descriptor.oneof_decl(i)); + FixOptionsForOneof(*descriptor.oneof_decl(i), proto.oneof_decl(i)); } // Enums. for (int i = 0; i < descriptor.enum_type_count(); ++i) { - FixOptionsForEnum(*descriptor.enum_type(i)); + FixOptionsForEnum(*descriptor.enum_type(i), proto.enum_type(i)); } // Fields. for (int i = 0; i < descriptor.field_count(); ++i) { const FieldDescriptor& field = *descriptor.field(i); - FixOptionsForField(field); + FixOptionsForField(field, proto.field(i)); } // Extensions. for (int i = 0; i < descriptor.extension_count(); ++i) { const FieldDescriptor& field = *descriptor.extension(i); - FixOptionsForField(field); + FixOptionsForField(field, proto.extension(i)); } // Message option for this message. - PrintDescriptorOptionsFixingCode(descriptor, + PrintDescriptorOptionsFixingCode(descriptor, proto, ModuleLevelDescriptorName(descriptor)); } diff --git a/src/google/protobuf/compiler/python/generator.h b/src/google/protobuf/compiler/python/generator.h index c5beeeea35af..531c44a3492b 100644 --- a/src/google/protobuf/compiler/python/generator.h +++ b/src/google/protobuf/compiler/python/generator.h @@ -65,7 +65,14 @@ class PROTOC_EXPORT Generator : public CodeGenerator { GeneratorContext* generator_context, std::string* error) const override; - uint64_t GetSupportedFeatures() const override; + uint64_t GetSupportedFeatures() const override { + return Feature::FEATURE_PROTO3_OPTIONAL; + } + Edition GetMinimumEdition() const override { return Edition::EDITION_PROTO2; } + Edition GetMaximumEdition() const override { return Edition::EDITION_2023; } + std::vector GetFeatureExtensions() const override { + return {}; + } void set_opensource_runtime(bool opensource) { opensource_runtime_ = opensource; @@ -80,20 +87,25 @@ class PROTOC_EXPORT Generator : public CodeGenerator { void PrintResolvedFeatures() const; void PrintFileDescriptor() const; void PrintAllEnumsInFile() const; - void PrintNestedEnums(const Descriptor& descriptor) const; - void PrintEnum(const EnumDescriptor& enum_descriptor) const; + void PrintNestedEnums(const Descriptor& descriptor, + const DescriptorProto& proto) const; + void PrintEnum(const EnumDescriptor& enum_descriptor, + const EnumDescriptorProto& proto) const; void PrintFieldDescriptor(const FieldDescriptor& field, - bool is_extension) const; + const FieldDescriptorProto& proto) const; void PrintFieldDescriptorsInDescriptor( - const Descriptor& message_descriptor, bool is_extension, - absl::string_view list_variable_name, int (Descriptor::*CountFn)() const, - const FieldDescriptor* (Descriptor::*GetterFn)(int) const) const; - void PrintFieldsInDescriptor(const Descriptor& message_descriptor) const; - void PrintExtensionsInDescriptor(const Descriptor& message_descriptor) const; + const Descriptor& message_descriptor, const DescriptorProto& proto, + bool is_extension, absl::string_view list_variable_name) const; + void PrintFieldsInDescriptor(const Descriptor& message_descriptor, + const DescriptorProto& proto) const; + void PrintExtensionsInDescriptor(const Descriptor& message_descriptor, + const DescriptorProto& proto) const; void PrintMessageDescriptors() const; - void PrintDescriptor(const Descriptor& message_descriptor) const; - void PrintNestedDescriptors(const Descriptor& containing_descriptor) const; + void PrintDescriptor(const Descriptor& message_descriptor, + const DescriptorProto& proto) const; + void PrintNestedDescriptors(const Descriptor& containing_descriptor, + const DescriptorProto& proto) const; void PrintMessages() const; void PrintMessage(const Descriptor& message_descriptor, @@ -132,8 +144,8 @@ class PROTOC_EXPORT Generator : public CodeGenerator { void PrintDescriptorKeyAndModuleName( const ServiceDescriptor& descriptor) const; - void PrintEnumValueDescriptor(const EnumValueDescriptor& descriptor) const; - std::string OptionsValue(absl::string_view serialized_options) const; + void PrintEnumValueDescriptor(const EnumValueDescriptor& descriptor, + const EnumValueDescriptorProto& proto) const; bool GeneratingDescriptorProto() const; template @@ -147,15 +159,21 @@ class PROTOC_EXPORT Generator : public CodeGenerator { absl::string_view name) const; template - bool PrintDescriptorOptionsFixingCode(const DescriptorT& descriptor, - absl::string_view descriptor_str) const; + bool PrintDescriptorOptionsFixingCode( + const DescriptorT& descriptor, const typename DescriptorT::Proto& proto, + absl::string_view descriptor_str) const; void FixAllDescriptorOptions() const; - void FixOptionsForField(const FieldDescriptor& field) const; - void FixOptionsForOneof(const OneofDescriptor& oneof) const; - void FixOptionsForEnum(const EnumDescriptor& descriptor) const; - void FixOptionsForService(const ServiceDescriptor& descriptor) const; - void FixOptionsForMessage(const Descriptor& descriptor) const; + void FixOptionsForField(const FieldDescriptor& field, + const FieldDescriptorProto& proto) const; + void FixOptionsForOneof(const OneofDescriptor& oneof, + const OneofDescriptorProto& proto) const; + void FixOptionsForEnum(const EnumDescriptor& descriptor, + const EnumDescriptorProto& proto) const; + void FixOptionsForService(const ServiceDescriptor& descriptor, + const ServiceDescriptorProto& proto) const; + void FixOptionsForMessage(const Descriptor& descriptor, + const DescriptorProto& proto) const; void SetSerializedPbInterval(const FileDescriptorProto& file) const; void SetMessagePbInterval(const DescriptorProto& message_proto, @@ -168,6 +186,7 @@ class PROTOC_EXPORT Generator : public CodeGenerator { // Guards file_, printer_ and file_descriptor_serialized_. mutable absl::Mutex mutex_; mutable const FileDescriptor* file_; // Set in Generate(). Under mutex_. + mutable FileDescriptorProto proto_; // Set in Generate(). Under mutex_. mutable std::string file_descriptor_serialized_; mutable io::Printer* printer_; // Set in Generate(). Under mutex_. diff --git a/src/google/protobuf/editions/BUILD b/src/google/protobuf/editions/BUILD index 4489ffdbf8e9..cba63638a0e9 100644 --- a/src/google/protobuf/editions/BUILD +++ b/src/google/protobuf/editions/BUILD @@ -1,3 +1,4 @@ +load("@rules_python//python:proto.bzl", "py_proto_library") load("@rules_cc//cc:defs.bzl", "cc_proto_library") load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load(":defaults.bzl", "compile_edition_defaults", "embed_edition_defaults") @@ -98,6 +99,13 @@ cc_proto_library( deps = [":test_messages_proto2_editions_proto"], ) +py_proto_library( + name = "test_messages_proto2_editions_py_pb2", + testonly = True, + visibility = ["//conformance:__pkg__"], + deps = [":test_messages_proto2_editions_proto"], +) + proto_library( name = "test_messages_proto3_editions_proto", testonly = True, @@ -120,6 +128,24 @@ cc_proto_library( deps = [":test_messages_proto3_editions_proto"], ) +py_proto_library( + name = "test_messages_proto3_editions_py_pb2", + testonly = True, + visibility = ["//conformance:__pkg__"], + deps = [":test_messages_proto3_editions_proto"], +) + +# Export these for conformance tests until we support py_proto_library. +exports_files( + [ + "golden/test_messages_proto2_editions.proto", + "golden/test_messages_proto3_editions.proto", + ], + visibility = [ + "//python:__pkg__", + ], +) + proto_library( name = "test_editions_default_features_proto", testonly = True,