diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index a916276319b6..8ff549381e74 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -806,8 +806,7 @@ def DecodeItem(buffer, pos, end, message, field_dict): if value is None: message_type = extension.message_type if not hasattr(message_type, '_concrete_class'): - # pylint: disable=protected-access - message._FACTORY.GetPrototype(message_type) + message_factory.GetMessageClass(message_type) value = field_dict.setdefault( extension, message_type._concrete_class()) if value._InternalParse(buffer, message_start,message_end) != message_end: diff --git a/python/google/protobuf/internal/extension_dict.py b/python/google/protobuf/internal/extension_dict.py index b346cf283e2c..83c4cb5dc656 100644 --- a/python/google/protobuf/internal/extension_dict.py +++ b/python/google/protobuf/internal/extension_dict.py @@ -89,8 +89,9 @@ def __getitem__(self, extension_handle): elif extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: message_type = extension_handle.message_type if not hasattr(message_type, '_concrete_class'): - # pylint: disable=protected-access - self._extended_message._FACTORY.GetPrototype(message_type) + # pylint: disable=g-import-not-at-top + from google.protobuf import message_factory + message_factory.GetMessageClass(message_type) assert getattr(extension_handle.message_type, '_concrete_class', None), ( 'Uninitialized concrete class found for field %r (message type %r)' % (extension_handle.full_name, diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py index 3ab3b8ed370c..0d132828bdbf 100644 --- a/python/google/protobuf/internal/message_factory_test.py +++ b/python/google/protobuf/internal/message_factory_test.py @@ -92,36 +92,17 @@ def testGetPrototype(self): pool = descriptor_pool.DescriptorPool(db) db.Add(self.factory_test1_fd) db.Add(self.factory_test2_fd) - factory = message_factory.MessageFactory() - cls = factory.GetPrototype(pool.FindMessageTypeByName( + cls = message_factory.GetMessageClass(pool.FindMessageTypeByName( 'google.protobuf.python.internal.Factory2Message')) self.assertFalse(cls is factory_test2_pb2.Factory2Message) self._ExerciseDynamicClass(cls) - cls2 = factory.GetPrototype(pool.FindMessageTypeByName( + cls2 = message_factory.GetMessageClass(pool.FindMessageTypeByName( 'google.protobuf.python.internal.Factory2Message')) self.assertTrue(cls is cls2) - def testCreatePrototypeOverride(self): - class MyMessageFactory(message_factory.MessageFactory): - - def CreatePrototype(self, descriptor): - cls = super(MyMessageFactory, self).CreatePrototype(descriptor) - cls.additional_field = 'Some value' - return cls - - db = descriptor_database.DescriptorDatabase() - pool = descriptor_pool.DescriptorPool(db) - db.Add(self.factory_test1_fd) - db.Add(self.factory_test2_fd) - factory = MyMessageFactory() - cls = factory.GetPrototype(pool.FindMessageTypeByName( - 'google.protobuf.python.internal.Factory2Message')) - self.assertTrue(hasattr(cls, 'additional_field')) - def testGetExistingPrototype(self): - factory = message_factory.MessageFactory() # Get Existing Prototype should not create a new class. - cls = factory.GetPrototype( + cls = message_factory.GetMessageClass( descriptor=factory_test2_pb2.Factory2Message.DESCRIPTOR) msg = factory_test2_pb2.Factory2Message() self.assertIsInstance(msg, cls) @@ -181,7 +162,6 @@ def testGetMessages(self): def testDuplicateExtensionNumber(self): pool = descriptor_pool.DescriptorPool() - factory = message_factory.MessageFactory(pool=pool) # Add Container message. f = descriptor_pb2.FileDescriptorProto( @@ -189,7 +169,7 @@ def testDuplicateExtensionNumber(self): package='google.protobuf.python.internal') f.message_type.add(name='Container').extension_range.add(start=1, end=10) pool.Add(f) - msgs = factory.GetMessages([f.name]) + msgs = message_factory.GetMessageClassesForFiles([f.name], pool) self.assertIn('google.protobuf.python.internal.Container', msgs) # Extend container. @@ -205,7 +185,7 @@ def testDuplicateExtensionNumber(self): type_name='Extension', extendee='Container') pool.Add(f) - msgs = factory.GetMessages([f.name]) + msgs = message_factory.GetMessageClassesForFiles([f.name], pool) self.assertIn('google.protobuf.python.internal.Extension', msgs) # Add Duplicate extending the same field number. @@ -223,7 +203,7 @@ def testDuplicateExtensionNumber(self): pool.Add(f) with self.assertRaises(Exception) as cm: - factory.GetMessages([f.name]) + message_factory.GetMessageClassesForFiles([f.name], pool) self.assertIn(str(cm.exception), ['Extensions ' @@ -281,8 +261,8 @@ def FindFileByName(self, name): db = SimpleDescriptorDB({f1.name: f1, f2.name: f2, f3.name: f3}) pool = descriptor_pool.DescriptorPool(db) - factory = message_factory.MessageFactory(pool=pool) - msgs = factory.GetMessages([f1.name, f3.name]) # Deliberately not f2. + msgs = message_factory.GetMessageClassesForFiles( + [f1.name, f3.name], pool) # Deliberately not f2. msg = msgs['google.protobuf.python.internal.Container'] desc = msgs['google.protobuf.python.internal.Extension'].DESCRIPTOR ext1 = desc.file.extensions_by_name['top_level_extension_field'] @@ -293,8 +273,8 @@ def FindFileByName(self, name): serialized = m.SerializeToString() pool = descriptor_pool.DescriptorPool(db) - factory = message_factory.MessageFactory(pool=pool) - msgs = factory.GetMessages([f1.name, f3.name]) # Deliberately not f2. + msgs = message_factory.GetMessageClassesForFiles( + [f1.name, f3.name], pool) # Deliberately not f2. msg = msgs['google.protobuf.python.internal.Container'] desc = msgs['google.protobuf.python.internal.Extension'].DESCRIPTOR ext1 = desc.file.extensions_by_name['top_level_extension_field'] diff --git a/python/google/protobuf/json_format.py b/python/google/protobuf/json_format.py index e2d058e1a6da..a04e8aef1331 100644 --- a/python/google/protobuf/json_format.py +++ b/python/google/protobuf/json_format.py @@ -53,6 +53,7 @@ from google.protobuf.internal import type_checkers from google.protobuf import descriptor +from google.protobuf import message_factory from google.protobuf import symbol_database @@ -409,7 +410,7 @@ def _CreateMessageFromTypeUrl(type_url, descriptor_pool): raise TypeError( 'Can not find message descriptor by type_url: {0}'.format(type_url) ) from e - message_class = db.GetPrototype(message_descriptor) + message_class = message_factory.GetMessageClass(message_descriptor) return message_class() diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py index ce5b5a7f65d5..fac1165c517b 100644 --- a/python/google/protobuf/message_factory.py +++ b/python/google/protobuf/message_factory.py @@ -39,6 +39,8 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' +import warnings + from google.protobuf.internal import api_implementation from google.protobuf import descriptor_pool from google.protobuf import message @@ -53,6 +55,95 @@ _GENERATED_PROTOCOL_MESSAGE_TYPE = message_impl.GeneratedProtocolMessageType +def GetMessageClass(descriptor): + """Obtains a proto2 message class based on the passed in descriptor. + + Passing a descriptor with a fully qualified name matching a previous + invocation will cause the same class to be returned. + + Args: + descriptor: The descriptor to build from. + + Returns: + A class describing the passed in descriptor. + """ + concrete_class = getattr(descriptor, '_concrete_class', None) + if concrete_class: + return concrete_class + return _InternalCreateMessageClass(descriptor) + + +def GetMessageClassesForFiles(files, pool): + """Gets all the messages from specified files. + + This will find and resolve dependencies, failing if the descriptor + pool cannot satisfy them. + + Args: + files: The file names to extract messages from. + pool: The descriptor pool to find the files including the dependent + files. + + Returns: + A dictionary mapping proto names to the message classes. + """ + result = {} + for file_name in files: + file_desc = pool.FindFileByName(file_name) + for desc in file_desc.message_types_by_name.values(): + result[desc.full_name] = GetMessageClass(desc) + + # While the extension FieldDescriptors are created by the descriptor pool, + # the python classes created in the factory need them to be registered + # explicitly, which is done below. + # + # The call to RegisterExtension will specifically check if the + # extension was already registered on the object and either + # ignore the registration if the original was the same, or raise + # an error if they were different. + + for extension in file_desc.extensions_by_name.values(): + extended_class = GetMessageClass(extension.containing_type) + extended_class.RegisterExtension(extension) + # Recursively load protos for extension field, in order to be able to + # fully represent the extension. This matches the behavior for regular + # fields too. + if extension.message_type: + GetMessageClass(extension.message_type) + return result + + +def _InternalCreateMessageClass(descriptor): + """Builds a proto2 message class based on the passed in descriptor. + + Args: + descriptor: The descriptor to build from. + + Returns: + A class describing the passed in descriptor. + """ + descriptor_name = descriptor.name + result_class = _GENERATED_PROTOCOL_MESSAGE_TYPE( + descriptor_name, + (message.Message,), + { + 'DESCRIPTOR': descriptor, + # If module not set, it wrongly points to message_factory module. + '__module__': None, + }) + for field in descriptor.fields: + if field.message_type: + GetMessageClass(field.message_type) + for extension in result_class.DESCRIPTOR.extensions: + extended_class = GetMessageClass(extension.containing_type) + extended_class.RegisterExtension(extension) + if extension.message_type: + GetMessageClass(extension.message_type) + return result_class + + +# Deprecated. Please use GetMessageClass() or GetMessageClassesForFiles() +# method above instead. class MessageFactory(object): """Factory for creating Proto2 messages from descriptors in a pool.""" @@ -72,18 +163,17 @@ def GetPrototype(self, descriptor): Returns: A class describing the passed in descriptor. """ - concrete_class = getattr(descriptor, '_concrete_class', None) - if concrete_class: - return concrete_class - result_class = self.CreatePrototype(descriptor) - return result_class + # TODO(b/258832141): add this warning + # warnings.warn('MessageFactory class is deprecated. Please use ' + # 'GetMessageClass() instead of MessageFactory.GetPrototype. ' + # 'MessageFactory class will be removed after 2024.') + return GetMessageClass(descriptor) def CreatePrototype(self, descriptor): """Builds a proto2 message class based on the passed in descriptor. Don't call this function directly, it always creates a new class. Call - GetPrototype() instead. This method is meant to be overridden in subblasses - to perform additional operations on the newly constructed class. + GetMessageClass() instead. Args: descriptor: The descriptor to build from. @@ -91,25 +181,11 @@ def CreatePrototype(self, descriptor): Returns: A class describing the passed in descriptor. """ - descriptor_name = descriptor.name - result_class = _GENERATED_PROTOCOL_MESSAGE_TYPE( - descriptor_name, - (message.Message,), - { - 'DESCRIPTOR': descriptor, - # If module not set, it wrongly points to message_factory module. - '__module__': None, - }) - result_class._FACTORY = self # pylint: disable=protected-access - for field in descriptor.fields: - if field.message_type: - self.GetPrototype(field.message_type) - for extension in result_class.DESCRIPTOR.extensions: - extended_class = self.GetPrototype(extension.containing_type) - extended_class.RegisterExtension(extension) - if extension.message_type: - self.GetPrototype(extension.message_type) - return result_class + # TODO(b/258832141): add this warning + # warnings.warn('Directly call CreatePrototype is wrong. Please use ' + # 'GetMessageClass() method instead. Directly use ' + # 'CreatePrototype will raise error after July 2023.') + return _InternalCreateMessageClass(descriptor) def GetMessages(self, files): """Gets all the messages from a specified file. @@ -125,37 +201,20 @@ def GetMessages(self, files): any dependent messages as well as any messages defined in the same file as a specified message. """ - result = {} - for file_name in files: - file_desc = self.pool.FindFileByName(file_name) - for desc in file_desc.message_types_by_name.values(): - result[desc.full_name] = self.GetPrototype(desc) - - # While the extension FieldDescriptors are created by the descriptor pool, - # the python classes created in the factory need them to be registered - # explicitly, which is done below. - # - # The call to RegisterExtension will specifically check if the - # extension was already registered on the object and either - # ignore the registration if the original was the same, or raise - # an error if they were different. - - for extension in file_desc.extensions_by_name.values(): - extended_class = self.GetPrototype(extension.containing_type) - extended_class.RegisterExtension(extension) - if extension.message_type: - self.GetPrototype(extension.message_type) - return result - - -_FACTORY = MessageFactory() + # TODO(b/258832141): add this warning + # warnings.warn('MessageFactory class is deprecated. Please use ' + # 'GetMessageClassesForFiles() instead of ' + # 'MessageFactory.GetMessages(). MessageFactory class ' + # 'will be removed after 2024.') + return GetMessageClassesForFiles(files, self.pool) -def GetMessages(file_protos): +def GetMessages(file_protos, pool=None): """Builds a dictionary of all the messages available in a set of files. Args: file_protos: Iterable of FileDescriptorProto to build messages out of. + pool: The descriptor pool to add the file protos. Returns: A dictionary mapping proto names to the message classes. This will include @@ -164,13 +223,15 @@ def GetMessages(file_protos): """ # The cpp implementation of the protocol buffer library requires to add the # message in topological order of the dependency graph. + des_pool = pool or descriptor_pool.DescriptorPool() file_by_name = {file_proto.name: file_proto for file_proto in file_protos} def _AddFile(file_proto): for dependency in file_proto.dependency: if dependency in file_by_name: # Remove from elements to be visited, in order to cut cycles. _AddFile(file_by_name.pop(dependency)) - _FACTORY.pool.Add(file_proto) + des_pool.Add(file_proto) while file_by_name: _AddFile(file_by_name.popitem()[1]) - return _FACTORY.GetMessages([file_proto.name for file_proto in file_protos]) + return GetMessageClassesForFiles( + [file_proto.name for file_proto in file_protos], des_pool) diff --git a/python/google/protobuf/proto_builder.py b/python/google/protobuf/proto_builder.py index a4667ce63ec3..8dab8b3ee084 100644 --- a/python/google/protobuf/proto_builder.py +++ b/python/google/protobuf/proto_builder.py @@ -36,22 +36,23 @@ from google.protobuf import descriptor_pb2 from google.protobuf import descriptor +from google.protobuf import descriptor_pool from google.protobuf import message_factory -def _GetMessageFromFactory(factory, full_name): +def _GetMessageFromFactory(pool, full_name): """Get a proto class from the MessageFactory by name. Args: - factory: a MessageFactory instance. + pool: a descriptor pool. full_name: str, the fully qualified name of the proto type. Returns: A class, for the type identified by full_name. Raises: KeyError, if the proto is not found in the factory's descriptor pool. """ - proto_descriptor = factory.pool.FindMessageTypeByName(full_name) - proto_cls = factory.GetPrototype(proto_descriptor) + proto_descriptor = pool.FindMessageTypeByName(full_name) + proto_cls = message_factory.GetMessageClass(proto_descriptor) return proto_cls @@ -69,11 +70,10 @@ def MakeSimpleProtoClass(fields, full_name=None, pool=None): Returns: a class, the new protobuf class with a FileDescriptor. """ - factory = message_factory.MessageFactory(pool=pool) - + pool_instance = pool or descriptor_pool.DescriptorPool() if full_name is not None: try: - proto_cls = _GetMessageFromFactory(factory, full_name) + proto_cls = _GetMessageFromFactory(pool_instance, full_name) return proto_cls except KeyError: # The factory's DescriptorPool doesn't know about this class yet. @@ -99,16 +99,16 @@ def MakeSimpleProtoClass(fields, full_name=None, pool=None): full_name = ('net.proto2.python.public.proto_builder.AnonymousProto_' + fields_hash.hexdigest()) try: - proto_cls = _GetMessageFromFactory(factory, full_name) + proto_cls = _GetMessageFromFactory(pool_instance, full_name) return proto_cls except KeyError: # The factory's DescriptorPool doesn't know about this class yet. pass # This is the first time we see this proto: add a new descriptor to the pool. - factory.pool.Add( + pool_instance.Add( _MakeFileDescriptorProto(proto_file_name, full_name, field_items)) - return _GetMessageFromFactory(factory, full_name) + return _GetMessageFromFactory(pool_instance, full_name) def _MakeFileDescriptorProto(proto_file_name, full_name, field_items): diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py index 81e18859a804..1627669b955d 100644 --- a/python/google/protobuf/reflection.py +++ b/python/google/protobuf/reflection.py @@ -92,4 +92,4 @@ def MakeClass(descriptor): # Original implementation leads to duplicate message classes, which won't play # well with extensions. Message factory info is also missing. # Redirect to message_factory. - return symbol_database.Default().GetPrototype(descriptor) + return message_factory.GetMessageClass(descriptor) diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py index ed5fce39e9f3..390c49810df9 100644 --- a/python/google/protobuf/symbol_database.py +++ b/python/google/protobuf/symbol_database.py @@ -57,18 +57,41 @@ my_message_instance = db.GetSymbol('MyMessage')() """ +import warnings from google.protobuf.internal import api_implementation from google.protobuf import descriptor_pool from google.protobuf import message_factory -class SymbolDatabase(message_factory.MessageFactory): +class SymbolDatabase(): """A database of Python generated symbols.""" # local cache of registered classes. _classes = {} + def __init__(self, pool=None): + """Initializes a new SymbolDatabase.""" + self.pool = pool or descriptor_pool.DescriptorPool() + + def GetPrototype(self, descriptor): + warnings.warn('SymbolDatabase.GetPrototype() is deprecated. Please ' + 'use message_factory.GetMessageClass() instead. ' + 'SymbolDatabase.GetPrototype() will be removed soon.') + return message_factory.GetMessageClass(descriptor) + + def CreatePrototype(self, descriptor): + warnings.warn('Directly call CreatePrototype() is wrong. Please use ' + 'message_factory.GetMessageClass() instead. ' + 'SymbolDatabase.CreatePrototype() will be removed soon.') + return message_factory._InternalCreateMessageClass(descriptor) + + def GetMessages(self, files): + warnings.warn('SymbolDatabase.GetMessages() is deprecated. Please use ' + 'message_factory.GetMessageClassedForFiles() instead. ' + 'SymbolDatabase.GetMessages() will be removed soon.') + return message_factory.GetMessageClassedForFiles(files, self.pool) + def RegisterMessage(self, message): """Registers the given message type in the local database. diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index e8f9b16e7294..e1a5ad544923 100644 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -330,13 +330,12 @@ def _BuildMessageFromTypeName(type_name, descriptor_pool): if descriptor_pool is None: from google.protobuf import descriptor_pool as pool_mod descriptor_pool = pool_mod.Default() - from google.protobuf import symbol_database - database = symbol_database.Default() + from google.protobuf import message_factory try: message_descriptor = descriptor_pool.FindMessageTypeByName(type_name) except KeyError: return None - message_type = database.GetPrototype(message_descriptor) + message_type = message_factory.GetMessageClass(message_descriptor) return message_type()