Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for extensions in CRuby, JRuby, and FFI Ruby (#14703) #14756

Merged
merged 1 commit into from Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 8 additions & 2 deletions ruby/ext/google/protobuf_c/defs.c
Expand Up @@ -144,20 +144,26 @@ VALUE DescriptorPool_add_serialized_file(VALUE _self,
* call-seq:
* DescriptorPool.lookup(name) => descriptor
*
* Finds a Descriptor or EnumDescriptor by name and returns it, or nil if none
* exists with the given name.
* Finds a Descriptor, EnumDescriptor or FieldDescriptor by name and returns it,
* or nil if none exists with the given name.
*/
static VALUE DescriptorPool_lookup(VALUE _self, VALUE name) {
DescriptorPool* self = ruby_to_DescriptorPool(_self);
const char* name_str = get_str(name);
const upb_MessageDef* msgdef;
const upb_EnumDef* enumdef;
const upb_FieldDef* fielddef;

msgdef = upb_DefPool_FindMessageByName(self->symtab, name_str);
if (msgdef) {
return get_msgdef_obj(_self, msgdef);
}

fielddef = upb_DefPool_FindExtensionByName(self->symtab, name_str);
if (fielddef) {
return get_fielddef_obj(_self, fielddef);
}

enumdef = upb_DefPool_FindEnumByName(self->symtab, name_str);
if (enumdef) {
return get_enumdef_obj(_self, enumdef);
Expand Down
10 changes: 8 additions & 2 deletions ruby/ext/google/protobuf_c/message.c
Expand Up @@ -977,9 +977,12 @@ VALUE Message_decode_bytes(int size, const char* bytes, int options,
VALUE msg_rb = initialize_rb_class_with_no_args(klass);
Message* msg = ruby_to_Message(msg_rb);

const upb_FileDef* file = upb_MessageDef_File(msg->msgdef);
const upb_ExtensionRegistry* extreg =
upb_DefPool_ExtensionRegistry(upb_FileDef_Pool(file));
upb_DecodeStatus status = upb_Decode(bytes, size, (upb_Message*)msg->msg,
upb_MessageDef_MiniTable(msg->msgdef),
NULL, options, Arena_get(msg->arena));
extreg, options, Arena_get(msg->arena));
if (status != kUpb_DecodeStatus_Ok) {
rb_raise(cParseError, "Error occurred during parsing");
}
Expand Down Expand Up @@ -1303,9 +1306,12 @@ upb_Message* Message_deep_copy(const upb_Message* msg, const upb_MessageDef* m,
upb_Message* new_msg = upb_Message_New(layout, arena);
char* data;

const upb_FileDef* file = upb_MessageDef_File(m);
const upb_ExtensionRegistry* extreg =
upb_DefPool_ExtensionRegistry(upb_FileDef_Pool(file));
if (upb_Encode(msg, layout, 0, tmp_arena, &data, &size) !=
kUpb_EncodeStatus_Ok ||
upb_Decode(data, size, new_msg, layout, NULL, 0, arena) !=
upb_Decode(data, size, new_msg, layout, extreg, 0, arena) !=
kUpb_DecodeStatus_Ok) {
upb_Arena_Free(tmp_arena);
rb_raise(cParseError, "Error occurred copying proto");
Expand Down
20 changes: 12 additions & 8 deletions ruby/lib/google/protobuf/ffi/descriptor_pool.rb
Expand Up @@ -9,13 +9,16 @@ module Google
module Protobuf
class FFI
# DefPool
attach_function :add_serialized_file, :upb_DefPool_AddFile, [:DefPool, :FileDescriptorProto, Status.by_ref], :FileDef
attach_function :free_descriptor_pool, :upb_DefPool_Free, [:DefPool], :void
attach_function :create_descriptor_pool,:upb_DefPool_New, [], :DefPool
attach_function :lookup_enum, :upb_DefPool_FindEnumByName, [:DefPool, :string], EnumDescriptor
attach_function :lookup_msg, :upb_DefPool_FindMessageByName, [:DefPool, :string], Descriptor
# FileDescriptorProto
attach_function :parse, :FileDescriptorProto_parse, [:binary_string, :size_t, Internal::Arena], :FileDescriptorProto
attach_function :add_serialized_file, :upb_DefPool_AddFile, [:DefPool, :FileDescriptorProto, Status.by_ref], :FileDef
attach_function :free_descriptor_pool, :upb_DefPool_Free, [:DefPool], :void
attach_function :create_descriptor_pool,:upb_DefPool_New, [], :DefPool
attach_function :get_extension_registry,:upb_DefPool_ExtensionRegistry, [:DefPool], :ExtensionRegistry
attach_function :lookup_enum, :upb_DefPool_FindEnumByName, [:DefPool, :string], EnumDescriptor
attach_function :lookup_extension, :upb_DefPool_FindExtensionByName,[:DefPool, :string], FieldDescriptor
attach_function :lookup_msg, :upb_DefPool_FindMessageByName, [:DefPool, :string], Descriptor

# FileDescriptorProto
attach_function :parse, :FileDescriptorProto_parse, [:binary_string, :size_t, Internal::Arena], :FileDescriptorProto
end
class DescriptorPool
attr :descriptor_pool
Expand Down Expand Up @@ -50,7 +53,8 @@ def add_serialized_file(file_contents)

def lookup name
Google::Protobuf::FFI.lookup_msg(@descriptor_pool, name) ||
Google::Protobuf::FFI.lookup_enum(@descriptor_pool, name)
Google::Protobuf::FFI.lookup_enum(@descriptor_pool, name) ||
Google::Protobuf::FFI.lookup_extension(@descriptor_pool, name)
end

def self.generated_pool
Expand Down
10 changes: 9 additions & 1 deletion ruby/lib/google/protobuf/ffi/message.rb
Expand Up @@ -170,7 +170,15 @@ def self.decode(data, options = {})

message = new
mini_table_ptr = Google::Protobuf::FFI.get_mini_table(message.class.descriptor)
status = Google::Protobuf::FFI.decode_message(data, data.bytesize, message.instance_variable_get(:@msg), mini_table_ptr, nil, decoding_options, message.instance_variable_get(:@arena))
status = Google::Protobuf::FFI.decode_message(
data,
data.bytesize,
message.instance_variable_get(:@msg),
mini_table_ptr,
Google::Protobuf::FFI.get_extension_registry(message.class.descriptor.send(:pool).descriptor_pool),
decoding_options,
message.instance_variable_get(:@arena)
)
raise ParseError.new "Error occurred during parsing" unless status == :Ok
message
end
Expand Down
Expand Up @@ -36,7 +36,9 @@
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.DescriptorValidationException;
import com.google.protobuf.Descriptors.EnumDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.InvalidProtocolBufferException;
import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -70,6 +72,7 @@ public IRubyObject allocate(Ruby runtime, RubyClass klazz) {
cDescriptorPool.newInstance(runtime.getCurrentContext(), Block.NULL_BLOCK);
cDescriptor = (RubyClass) runtime.getClassFromPath("Google::Protobuf::Descriptor");
cEnumDescriptor = (RubyClass) runtime.getClassFromPath("Google::Protobuf::EnumDescriptor");
cFieldDescriptor = (RubyClass) runtime.getClassFromPath("Google::Protobuf::FieldDescriptor");
}

public RubyDescriptorPool(Ruby runtime, RubyClass klazz) {
Expand All @@ -92,7 +95,7 @@ public IRubyObject build(ThreadContext context, Block block) {
* call-seq:
* DescriptorPool.lookup(name) => descriptor
*
* Finds a Descriptor or EnumDescriptor by name and returns it, or nil if none
* Finds a Descriptor, EnumDescriptor or FieldDescriptor by name and returns it, or nil if none
* exists with the given name.
*
* This currently lazy loads the ruby descriptor objects as they are requested.
Expand Down Expand Up @@ -121,7 +124,8 @@ public static IRubyObject generatedPool(ThreadContext context, IRubyObject recv)
public IRubyObject add_serialized_file(ThreadContext context, IRubyObject data) {
byte[] bin = data.convertToString().getBytes();
try {
FileDescriptorProto.Builder builder = FileDescriptorProto.newBuilder().mergeFrom(bin);
FileDescriptorProto.Builder builder =
FileDescriptorProto.newBuilder().mergeFrom(bin, registry);
registerFileDescriptor(context, builder);
} catch (InvalidProtocolBufferException e) {
throw RaiseException.from(
Expand Down Expand Up @@ -150,6 +154,8 @@ protected void registerFileDescriptor(
for (EnumDescriptor ed : fd.getEnumTypes()) registerEnumDescriptor(context, ed, packageName);
for (Descriptor message : fd.getMessageTypes())
registerDescriptor(context, message, packageName);
for (FieldDescriptor fieldDescriptor : fd.getExtensions())
registerExtension(context, fieldDescriptor, packageName);

// Mark this as a loaded file
fileDescriptors.add(fd);
Expand All @@ -170,6 +176,24 @@ private void registerDescriptor(ThreadContext context, Descriptor descriptor, St
registerEnumDescriptor(context, ed, fullPath);
for (Descriptor message : descriptor.getNestedTypes())
registerDescriptor(context, message, fullPath);
for (FieldDescriptor fieldDescriptor : descriptor.getExtensions())
registerExtension(context, fieldDescriptor, fullPath);
}

private void registerExtension(
ThreadContext context, FieldDescriptor descriptor, String parentPath) {
if (descriptor.getJavaType() == FieldDescriptor.JavaType.MESSAGE) {
registry.add(descriptor, descriptor.toProto());
} else {
registry.add(descriptor);
}
RubyString name = context.runtime.newString(parentPath + descriptor.getName());
RubyFieldDescriptor des =
(RubyFieldDescriptor) cFieldDescriptor.newInstance(context, Block.NULL_BLOCK);
des.setName(name);
des.setDescriptor(context, descriptor, this);
// For MessageSet extensions, there is the possibility of a name conflict. Prefer the Message.
symtab.putIfAbsent(name, des);
}

private void registerEnumDescriptor(
Expand All @@ -188,8 +212,10 @@ private FileDescriptor[] existingFileDescriptors() {

private static RubyClass cDescriptor;
private static RubyClass cEnumDescriptor;
private static RubyClass cFieldDescriptor;
private static RubyDescriptorPool descriptorPool;

private List<FileDescriptor> fileDescriptors;
private Map<IRubyObject, IRubyObject> symtab;
protected static final ExtensionRegistry registry = ExtensionRegistry.newInstance();
}
Expand Up @@ -103,6 +103,10 @@ public IRubyObject getName(ThreadContext context) {
return this.name;
}

protected void setName(IRubyObject name) {
this.name = name;
}

/*
* call-seq:
* FieldDescriptor.subtype => message_or_enum_descriptor
Expand Down Expand Up @@ -229,7 +233,7 @@ public IRubyObject has(ThreadContext context, IRubyObject message) {
*/
@JRubyMethod(name = "set")
public IRubyObject setValue(ThreadContext context, IRubyObject message, IRubyObject value) {
((RubyMessage) message).setField(context, descriptor, value);
((RubyMessage) message).setField(context, this, value);
return context.nil;
}

Expand Down Expand Up @@ -263,6 +267,10 @@ protected void setDescriptor(
this.pool = pool;
}

protected FieldDescriptor getDescriptor() {
return descriptor;
}

private void calculateLabel(ThreadContext context) {
if (descriptor.isRepeated()) {
this.label = context.runtime.newSymbol("repeated");
Expand Down
24 changes: 21 additions & 3 deletions ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java
Expand Up @@ -634,7 +634,7 @@ public static IRubyObject decode(ThreadContext context, IRubyObject recv, IRubyO
public static IRubyObject decodeBytes(
ThreadContext context, RubyMessage ret, CodedInputStream input, boolean freeze) {
try {
ret.builder.mergeFrom(input);
ret.builder.mergeFrom(input, RubyDescriptorPool.registry);
} catch (Exception e) {
throw RaiseException.from(
context.runtime,
Expand Down Expand Up @@ -965,6 +965,12 @@ protected IRubyObject setField(
return setFieldInternal(context, fieldDescriptor, value);
}

protected IRubyObject setField(
ThreadContext context, RubyFieldDescriptor fieldDescriptor, IRubyObject value) {
validateMessageType(context, fieldDescriptor.getDescriptor(), "set");
return setFieldInternal(context, fieldDescriptor.getDescriptor(), fieldDescriptor, value);
}

private RubyRepeatedField getRepeatedField(
ThreadContext context, FieldDescriptor fieldDescriptor) {
if (fields.containsKey(fieldDescriptor)) {
Expand Down Expand Up @@ -1275,6 +1281,14 @@ private IRubyObject getFieldInternal(

private IRubyObject setFieldInternal(
ThreadContext context, FieldDescriptor fieldDescriptor, IRubyObject value) {
return setFieldInternal(context, fieldDescriptor, null, value);
}

private IRubyObject setFieldInternal(
ThreadContext context,
FieldDescriptor fieldDescriptor,
RubyFieldDescriptor rubyFieldDescriptor,
IRubyObject value) {
testFrozen("can't modify frozen " + getMetaClass());

if (fieldDescriptor.isMapField()) {
Expand All @@ -1299,8 +1313,12 @@ private IRubyObject setFieldInternal(
// Determine the typeclass, if any
IRubyObject typeClass = context.runtime.getObject();
if (fieldType == FieldDescriptor.Type.MESSAGE) {
typeClass =
((RubyDescriptor) getDescriptorForField(context, fieldDescriptor)).msgclass(context);
if (rubyFieldDescriptor != null) {
typeClass = ((RubyDescriptor) rubyFieldDescriptor.getSubtype(context)).msgclass(context);
} else {
typeClass =
((RubyDescriptor) getDescriptorForField(context, fieldDescriptor)).msgclass(context);
}
if (value.isNil()) {
addValue = false;
}
Expand Down
32 changes: 32 additions & 0 deletions ruby/tests/basic.rb
Expand Up @@ -729,6 +729,19 @@ def test_oneof_descriptor_options
oneof_descriptor = descriptor.lookup_oneof("test_deprecated_message_oneof")

assert_instance_of Google::Protobuf::OneofOptions, oneof_descriptor.options
test_top_level_option = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test.test_top_level_option'
assert_instance_of Google::Protobuf::FieldDescriptor, test_top_level_option
assert_equal "Custom option value", test_top_level_option.get(oneof_descriptor.options)
end

def test_nested_extension
descriptor = TestDeprecatedMessage.descriptor
oneof_descriptor = descriptor.lookup_oneof("test_deprecated_message_oneof")

assert_instance_of Google::Protobuf::OneofOptions, oneof_descriptor.options
test_nested_option = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test.TestDeprecatedMessage.test_nested_option'
assert_instance_of Google::Protobuf::FieldDescriptor, test_nested_option
assert_equal "Another custom option value", test_nested_option.get(oneof_descriptor.options)
end

def test_options_deep_freeze
Expand All @@ -739,6 +752,25 @@ def test_options_deep_freeze
Google::Protobuf::UninterpretedOption.new
end
end

def test_message_deep_freeze
message = TestDeprecatedMessage.new
omit(":internal_deep_freeze only exists under FFI") unless message.respond_to? :internal_deep_freeze, true
nested_message_2 = TestMessage2.new

message.map_string_msg["message"] = TestMessage2.new
message.repeated_msg.push(TestMessage2.new)

message.send(:internal_deep_freeze)

assert_raise FrozenError do
message.map_string_msg["message"].foo = "bar"
end

assert_raise FrozenError do
message.repeated_msg[0].foo = "bar"
end
end
end

def test_oneof_fields_respond_to? # regression test for issue 9202
Expand Down
52 changes: 52 additions & 0 deletions ruby/tests/basic_proto2.rb
Expand Up @@ -269,5 +269,57 @@ def test_oneof_fields_respond_to? # regression test for issue 9202
assert msg.respond_to? :has_d?
refute msg.has_d?
end

def test_extension
message = TestExtensions.new
extension = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.optional_int32_extension'
assert_instance_of Google::Protobuf::FieldDescriptor, extension
assert_equal 0, extension.get(message)
extension.set message, 42
assert_equal 42, extension.get(message)
end

def test_nested_extension
message = TestExtensions.new
extension = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.TestNestedExtension.test'
assert_instance_of Google::Protobuf::FieldDescriptor, extension
assert_equal 'test', extension.get(message)
extension.set message, 'another test'
assert_equal 'another test', extension.get(message)
end

def test_message_set_extension_json_roundtrip
omit "Java Protobuf JsonFormat does not handle Proto2 extensions" if defined? JRUBY_VERSION and :NATIVE == Google::Protobuf::IMPLEMENTATION
message = TestMessageSet.new
ext1 = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.TestMessageSetExtension1.message_set_extension'
assert_instance_of Google::Protobuf::FieldDescriptor, ext1
ext2 = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.TestMessageSetExtension2.message_set_extension'
assert_instance_of Google::Protobuf::FieldDescriptor, ext2
ext3 = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.message_set_extension3'
assert_instance_of Google::Protobuf::FieldDescriptor, ext3
ext1.set(message, ext1.subtype.msgclass.new(i: 42))
ext2.set(message, ext2.subtype.msgclass.new(str: 'foo'))
ext3.set(message, ext3.subtype.msgclass.new(text: 'bar'))
message_text = message.to_json
parsed_message = TestMessageSet.decode_json message_text
assert_equal message, parsed_message
end


def test_message_set_extension_roundtrip
message = TestMessageSet.new
ext1 = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.TestMessageSetExtension1.message_set_extension'
assert_instance_of Google::Protobuf::FieldDescriptor, ext1
ext2 = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.TestMessageSetExtension2.message_set_extension'
assert_instance_of Google::Protobuf::FieldDescriptor, ext2
ext3 = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.message_set_extension3'
assert_instance_of Google::Protobuf::FieldDescriptor, ext3
ext1.set(message, ext1.subtype.msgclass.new(i: 42))
ext2.set(message, ext2.subtype.msgclass.new(str: 'foo'))
ext3.set(message, ext3.subtype.msgclass.new(text: 'bar'))
encoded_message = TestMessageSet.encode message
decoded_message = TestMessageSet.decode encoded_message
assert_equal message, decoded_message
end
end
end