From c572de671750a4135892916965c2b90c37bb37de Mon Sep 17 00:00:00 2001 From: Kush Date: Thu, 9 Dec 2021 11:47:30 +0530 Subject: [PATCH] fix: grpc payload deserialization should only happen for requested fields Signed-off-by: Kush --- middleware/grpc_payload.go | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/middleware/grpc_payload.go b/middleware/grpc_payload.go index 0c1bb0d64..f41ba46ac 100644 --- a/middleware/grpc_payload.go +++ b/middleware/grpc_payload.go @@ -3,12 +3,13 @@ package middleware import ( "bytes" "encoding/binary" - "errors" "fmt" "io" "io/ioutil" "net/http" + "github.com/pkg/errors" + "github.com/golang/protobuf/protoc-gen-go/descriptor" "github.com/jhump/protoreflect/desc" "github.com/jhump/protoreflect/desc/builder" @@ -106,9 +107,9 @@ func (p *grpcRequestParser) Parse() (pf GRPCPayloadCompressionFormat, msg []byte } func fieldFromProtoMessage(msg []byte, tagIndex int) (string, error) { - desc, err := buildPayloadGenericProto() + desc, err := buildPayloadGenericProto(tagIndex) if err != nil { - return "", err + return "", errors.Wrap(err, "failed to build generic proto") } // populate message @@ -123,23 +124,25 @@ func fieldFromProtoMessage(msg []byte, tagIndex int) (string, error) { return val.(string), nil } -// should only be built once -var genericProtoCache *desc.MessageDescriptor +// should only be built once for one index +var genericProtoCache = make(map[int]*desc.MessageDescriptor) -func buildPayloadGenericProto() (*desc.MessageDescriptor, error) { - if genericProtoCache != nil { - return genericProtoCache, nil +// currently we build a generic message with just one field of string type +// should be able to support more primitive types if needed +func buildPayloadGenericProto(idx int) (*desc.MessageDescriptor, error) { + if val, ok := genericProtoCache[idx]; ok { + return val, nil } - builderMsg := builder.NewMessage("message") for i := 1; i < 100; i++ { + builderMsg := builder.NewMessage("message") builderMsg.AddField(builder.NewField(fmt.Sprintf("field_%d", i), builder.FieldTypeScalar(descriptor.FieldDescriptorProto_TYPE_STRING)).SetNumber(int32(i))) + desc, err := builderMsg.Build() + if err != nil { + return nil, err + } + genericProtoCache[i] = desc } - desc, err := builderMsg.Build() - if err != nil { - return nil, err - } - genericProtoCache = desc - return genericProtoCache, nil + return genericProtoCache[idx], nil }