diff --git a/conformance/internal/conformance/clonevt_test.go b/conformance/internal/conformance/clonevt_test.go index 6f186e5..64eeb71 100644 --- a/conformance/internal/conformance/clonevt_test.go +++ b/conformance/internal/conformance/clonevt_test.go @@ -4,7 +4,9 @@ import ( "fmt" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/wrapperspb" @@ -97,3 +99,29 @@ func TestCloneVT3(t *testing.T) { require.False(t, clone.EqualVT(msg), "cloned message unchanged after mutation") require.True(t, orig.EqualVT(msg), "mutating cloned %T mutated original:\nmsg = %+v\nafter clone = %+v\n", msg, msg, orig) } + +func TestCloneVT_UnknownFields(t *testing.T) { + msg := &TestAllTypesProto3{ + OptionalInt32: 42, + } + + const unknownFieldNumber = 1337 + require.Nil(t, msg.ProtoReflect().Descriptor().Fields().ByNumber(unknownFieldNumber), + "if this assertion fails, please change the above constant to a field number not used in the proto") + data, err := msg.MarshalVT() + require.NoError(t, err) + + data = protowire.AppendTag(data, unknownFieldNumber, protowire.BytesType) + data = protowire.AppendString(data, "foo bar baz") + + unmarshaled := new(TestAllTypesProto3) + require.NoError(t, unmarshaled.UnmarshalVT(data), "unmarshaling should succeed") + + cloned := unmarshaled.CloneVT() + assert.Truef(t, proto.Equal(unmarshaled, cloned), "expected %T to be equal:\nunmarshaled = %+v\ncloned = %+v\n", unmarshaled, unmarshaled, cloned) + + protoCloned := proto.Clone(unmarshaled).(*TestAllTypesProto3) + require.True(t, proto.Equal(unmarshaled, protoCloned), "proto.Clone is misbehaving") + + assert.Truef(t, proto.Equal(cloned, protoCloned), "expected %T to be equal:\ncloned = %+v\nprotoCloned = %+v\n", cloned, cloned, protoCloned) +} diff --git a/conformance/internal/conformance/conformance_vtproto.pb.go b/conformance/internal/conformance/conformance_vtproto.pb.go index 0bd1ca2..f14fdc0 100644 --- a/conformance/internal/conformance/conformance_vtproto.pb.go +++ b/conformance/internal/conformance/conformance_vtproto.pb.go @@ -28,6 +28,10 @@ func (m *FailureSet) CloneVT() *FailureSet { copy(tmpContainer, rhs) r.Failure = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -51,6 +55,10 @@ func (m *ConformanceRequest) CloneVT() *ConformanceRequest { CloneVT() isConformanceRequest_Payload }).CloneVT() } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -111,6 +119,10 @@ func (m *ConformanceResponse) CloneVT() *ConformanceResponse { CloneVT() isConformanceResponse_Result }).CloneVT() } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -208,6 +220,10 @@ func (m *JspbEncodingConfig) CloneVT() *JspbEncodingConfig { r := &JspbEncodingConfig{ UseJspbArrayAnyFormat: m.UseJspbArrayAnyFormat, } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } diff --git a/conformance/internal/conformance/test_messages_proto2_vtproto.pb.go b/conformance/internal/conformance/test_messages_proto2_vtproto.pb.go index a81caae..df9024d 100644 --- a/conformance/internal/conformance/test_messages_proto2_vtproto.pb.go +++ b/conformance/internal/conformance/test_messages_proto2_vtproto.pb.go @@ -32,6 +32,10 @@ func (m *TestAllTypesProto2_NestedMessage) CloneVT() *TestAllTypesProto2_NestedM tmpVal := *rhs r.A = &tmpVal } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -52,6 +56,10 @@ func (m *TestAllTypesProto2_Data) CloneVT() *TestAllTypesProto2_Data { tmpVal := *rhs r.GroupUint32 = &tmpVal } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -64,6 +72,10 @@ func (m *TestAllTypesProto2_MessageSetCorrect) CloneVT() *TestAllTypesProto2_Mes return (*TestAllTypesProto2_MessageSetCorrect)(nil) } r := &TestAllTypesProto2_MessageSetCorrect{} + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -80,6 +92,10 @@ func (m *TestAllTypesProto2_MessageSetCorrectExtension1) CloneVT() *TestAllTypes tmpVal := *rhs r.Str = &tmpVal } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -96,6 +112,10 @@ func (m *TestAllTypesProto2_MessageSetCorrectExtension2) CloneVT() *TestAllTypes tmpVal := *rhs r.I = &tmpVal } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -716,6 +736,10 @@ func (m *TestAllTypesProto2) CloneVT() *TestAllTypesProto2 { tmpVal := *rhs r.FieldName18__ = &tmpVal } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -825,6 +849,10 @@ func (m *ForeignMessageProto2) CloneVT() *ForeignMessageProto2 { tmpVal := *rhs r.C = &tmpVal } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -841,6 +869,10 @@ func (m *UnknownToTestAllTypes_OptionalGroup) CloneVT() *UnknownToTestAllTypes_O tmpVal := *rhs r.A = &tmpVal } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -873,6 +905,10 @@ func (m *UnknownToTestAllTypes) CloneVT() *UnknownToTestAllTypes { copy(tmpContainer, rhs) r.RepeatedInt32 = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -885,6 +921,10 @@ func (m *NullHypothesisProto2) CloneVT() *NullHypothesisProto2 { return (*NullHypothesisProto2)(nil) } r := &NullHypothesisProto2{} + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -897,6 +937,10 @@ func (m *EnumOnlyProto2) CloneVT() *EnumOnlyProto2 { return (*EnumOnlyProto2)(nil) } r := &EnumOnlyProto2{} + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -913,6 +957,10 @@ func (m *OneStringProto2) CloneVT() *OneStringProto2 { tmpVal := *rhs r.Data = &tmpVal } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } diff --git a/conformance/internal/conformance/test_messages_proto3_vtproto.pb.go b/conformance/internal/conformance/test_messages_proto3_vtproto.pb.go index bd6d150..b29b049 100644 --- a/conformance/internal/conformance/test_messages_proto3_vtproto.pb.go +++ b/conformance/internal/conformance/test_messages_proto3_vtproto.pb.go @@ -34,6 +34,10 @@ func (m *TestAllTypesProto3_NestedMessage) CloneVT() *TestAllTypesProto3_NestedM A: m.A, Corecursive: m.Corecursive.CloneVT(), } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -783,6 +787,10 @@ func (m *TestAllTypesProto3) CloneVT() *TestAllTypesProto3 { } r.RepeatedListValue = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -900,6 +908,10 @@ func (m *ForeignMessage) CloneVT() *ForeignMessage { r := &ForeignMessage{ C: m.C, } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -912,6 +924,10 @@ func (m *NullHypothesisProto3) CloneVT() *NullHypothesisProto3 { return (*NullHypothesisProto3)(nil) } r := &NullHypothesisProto3{} + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -924,6 +940,10 @@ func (m *EnumOnlyProto3) CloneVT() *EnumOnlyProto3 { return (*EnumOnlyProto3)(nil) } r := &EnumOnlyProto3{} + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } diff --git a/features/clone/clone.go b/features/clone/clone.go index 693a491..6782a26 100644 --- a/features/clone/clone.go +++ b/features/clone/clone.go @@ -144,7 +144,7 @@ func (p *clone) cloneField(lhsBase, rhsBase string, allFieldsNullable bool, fiel func (p *clone) generateCloneMethodsForMessage(proto3 bool, message *protogen.Message) { ccTypeName := message.GoIdent.GoName p.P(`func (m *`, ccTypeName, `) `, cloneName, `() *`, ccTypeName, ` {`) - p.body(!proto3, ccTypeName, message.Fields) + p.body(!proto3, ccTypeName, message.Fields, true) p.P(`}`) p.P() p.P(`func (m *`, ccTypeName, `) `, cloneGenericName, `() `, protoPkg.Ident("Message"), ` {`) @@ -156,7 +156,7 @@ func (p *clone) generateCloneMethodsForMessage(proto3 bool, message *protogen.Me // body generates the code for the actual cloning logic of a structure containing the given fields. // In practice, those can be the fields of a message, or of a oneof struct. // The object to be cloned is assumed to be called "m". -func (p *clone) body(allFieldsNullable bool, ccTypeName string, fields []*protogen.Field) { +func (p *clone) body(allFieldsNullable bool, ccTypeName string, fields []*protogen.Field, cloneUnknownFields bool) { // The method body for a message or a oneof wrapper always starts with a nil check. p.P(`if m == nil {`) // We use an explicitly typed nil to avoid returning the nil interface in the oneof wrapper @@ -200,6 +200,14 @@ func (p *clone) body(allFieldsNullable bool, ccTypeName string, fields []*protog p.cloneField("r", "m", allFieldsNullable, field) } + if cloneUnknownFields { + // Clone unknown fields, if any + p.P(`if len(m.unknownFields) > 0 {`) + p.P(`r.unknownFields = make([]byte, len(m.unknownFields))`) + p.P(`copy(r.unknownFields, m.unknownFields)`) + p.P(`}`) + } + p.P(`return r`) } @@ -214,7 +222,7 @@ func (p *clone) generateCloneMethodsForOneof(field *protogen.Field) { fieldInOneof := *field fieldInOneof.Oneof = nil // If we have a scalar field in a oneof, that field is never nullable, even when using proto2 - p.body(false, ccTypeName, []*protogen.Field{&fieldInOneof}) + p.body(false, ccTypeName, []*protogen.Field{&fieldInOneof}, false) p.P(`}`) p.P() } diff --git a/testproto/pool/pool_vtproto.pb.go b/testproto/pool/pool_vtproto.pb.go index 3d17eea..3224fd9 100644 --- a/testproto/pool/pool_vtproto.pb.go +++ b/testproto/pool/pool_vtproto.pb.go @@ -28,6 +28,10 @@ func (m *MemoryPoolExtension) CloneVT() *MemoryPoolExtension { Foo1: m.Foo1, Foo2: m.Foo2, } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } diff --git a/testproto/pool/pool_with_slice_reuse_vtproto.pb.go b/testproto/pool/pool_with_slice_reuse_vtproto.pb.go index da11a95..931aedb 100644 --- a/testproto/pool/pool_with_slice_reuse_vtproto.pb.go +++ b/testproto/pool/pool_with_slice_reuse_vtproto.pb.go @@ -29,6 +29,10 @@ func (m *Test1) CloneVT() *Test1 { copy(tmpContainer, rhs) r.Sl = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -48,6 +52,10 @@ func (m *Test2) CloneVT() *Test2 { } r.Sl = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -80,6 +88,10 @@ func (m *Slice2) CloneVT() *Slice2 { copy(tmpContainer, rhs) r.C = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -94,6 +106,10 @@ func (m *Element2) CloneVT() *Element2 { r := &Element2{ A: m.A, } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } diff --git a/testproto/proto2/scalars_vtproto.pb.go b/testproto/proto2/scalars_vtproto.pb.go index 1c97a1a..021313c 100644 --- a/testproto/proto2/scalars_vtproto.pb.go +++ b/testproto/proto2/scalars_vtproto.pb.go @@ -44,6 +44,10 @@ func (m *DoubleMessage) CloneVT() *DoubleMessage { copy(tmpContainer, rhs) r.PackedField = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -74,6 +78,10 @@ func (m *FloatMessage) CloneVT() *FloatMessage { copy(tmpContainer, rhs) r.PackedField = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -104,6 +112,10 @@ func (m *Int32Message) CloneVT() *Int32Message { copy(tmpContainer, rhs) r.PackedField = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -134,6 +146,10 @@ func (m *Int64Message) CloneVT() *Int64Message { copy(tmpContainer, rhs) r.PackedField = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -164,6 +180,10 @@ func (m *Uint32Message) CloneVT() *Uint32Message { copy(tmpContainer, rhs) r.PackedField = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -194,6 +214,10 @@ func (m *Uint64Message) CloneVT() *Uint64Message { copy(tmpContainer, rhs) r.PackedField = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -224,6 +248,10 @@ func (m *Sint32Message) CloneVT() *Sint32Message { copy(tmpContainer, rhs) r.PackedField = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -254,6 +282,10 @@ func (m *Sint64Message) CloneVT() *Sint64Message { copy(tmpContainer, rhs) r.PackedField = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -284,6 +316,10 @@ func (m *Fixed32Message) CloneVT() *Fixed32Message { copy(tmpContainer, rhs) r.PackedField = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -314,6 +350,10 @@ func (m *Fixed64Message) CloneVT() *Fixed64Message { copy(tmpContainer, rhs) r.PackedField = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -344,6 +384,10 @@ func (m *Sfixed32Message) CloneVT() *Sfixed32Message { copy(tmpContainer, rhs) r.PackedField = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -374,6 +418,10 @@ func (m *Sfixed64Message) CloneVT() *Sfixed64Message { copy(tmpContainer, rhs) r.PackedField = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -404,6 +452,10 @@ func (m *BoolMessage) CloneVT() *BoolMessage { copy(tmpContainer, rhs) r.PackedField = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -429,6 +481,10 @@ func (m *StringMessage) CloneVT() *StringMessage { copy(tmpContainer, rhs) r.RepeatedField = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -460,6 +516,10 @@ func (m *BytesMessage) CloneVT() *BytesMessage { } r.RepeatedField = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } @@ -490,6 +550,10 @@ func (m *EnumMessage) CloneVT() *EnumMessage { copy(tmpContainer, rhs) r.PackedField = tmpContainer } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r } diff --git a/testproto/proto3opt/opt_vtproto.pb.go b/testproto/proto3opt/opt_vtproto.pb.go index a30abc2..f04dae0 100644 --- a/testproto/proto3opt/opt_vtproto.pb.go +++ b/testproto/proto3opt/opt_vtproto.pb.go @@ -91,6 +91,10 @@ func (m *OptionalFieldInProto3) CloneVT() *OptionalFieldInProto3 { tmpVal := *rhs r.OptionalEnum = &tmpVal } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } return r }