From 4756e28b863bc24ff89ba4cc7256262b1c6a1c1e Mon Sep 17 00:00:00 2001 From: Denis Mulin Date: Mon, 22 Apr 2024 18:53:57 +0300 Subject: [PATCH] grpc reflection metadata --- components/guns/grpc/core.go | 18 ++--- components/guns/grpc/scenario/core.go | 22 +++--- tests/acceptance/config_model.go | 13 ++-- tests/acceptance/grpc_test.go | 107 ++++++++++++++++++++++++++ 4 files changed, 136 insertions(+), 24 deletions(-) diff --git a/components/guns/grpc/core.go b/components/guns/grpc/core.go index 251f8958b..c030a7949 100644 --- a/components/guns/grpc/core.go +++ b/components/guns/grpc/core.go @@ -43,13 +43,14 @@ type GrpcDialOptions struct { } type GunConfig struct { - Target string `validate:"required"` - ReflectPort int64 `config:"reflect_port"` - Timeout time.Duration `config:"timeout"` // grpc request timeout - TLS bool `config:"tls"` - DialOptions GrpcDialOptions `config:"dial_options"` - AnswLog AnswLogConfig `config:"answlog"` - SharedClient struct { + Target string `validate:"required"` + ReflectPort int64 `config:"reflect_port"` + ReflectMetadata metadata.MD `config:"reflect_metadata"` + Timeout time.Duration `config:"timeout"` // grpc request timeout + TLS bool `config:"tls"` + DialOptions GrpcDialOptions `config:"dial_options"` + AnswLog AnswLogConfig `config:"answlog"` + SharedClient struct { ClientNumber int `config:"client-number,omitempty"` Enabled bool `config:"enabled"` } `config:"shared-client,omitempty"` @@ -110,8 +111,7 @@ func (g *Gun) prepareMethodList(opts *warmup.Options) (map[string]desc.MethodDes } defer conn.Close() - meta := make(metadata.MD) - refCtx := metadata.NewOutgoingContext(context.Background(), meta) + refCtx := metadata.NewOutgoingContext(context.Background(), g.Conf.ReflectMetadata) refClient := grpcreflect.NewClientAuto(refCtx, conn) listServices, err := refClient.ListServices() if err != nil { diff --git a/components/guns/grpc/scenario/core.go b/components/guns/grpc/scenario/core.go index 68e078ac0..018c31f89 100644 --- a/components/guns/grpc/scenario/core.go +++ b/components/guns/grpc/scenario/core.go @@ -21,12 +21,13 @@ import ( const defaultTimeout = time.Second * 15 type GunConfig struct { - Target string `validate:"required"` - ReflectPort int64 `config:"reflect_port"` - Timeout time.Duration `config:"timeout"` // grpc request timeout - TLS bool `config:"tls"` - DialOptions GrpcDialOptions `config:"dial_options"` - AnswLog AnswLogConfig `config:"answlog"` + Target string `validate:"required"` + ReflectPort int64 `config:"reflect_port"` + ReflectMetadata metadata.MD `config:"reflect_metadata"` + Timeout time.Duration `config:"timeout"` // grpc request timeout + TLS bool `config:"tls"` + DialOptions GrpcDialOptions `config:"dial_options"` + AnswLog AnswLogConfig `config:"answlog"` } type GrpcDialOptions struct { @@ -57,10 +58,11 @@ func NewGun(conf GunConfig) *Gun { return &Gun{ templ: NewTextTemplater(), gun: &grpcgun.Gun{Conf: grpcgun.GunConfig{ - Target: conf.Target, - ReflectPort: conf.ReflectPort, - Timeout: conf.Timeout, - TLS: conf.TLS, + Target: conf.Target, + ReflectPort: conf.ReflectPort, + ReflectMetadata: conf.ReflectMetadata, + Timeout: conf.Timeout, + TLS: conf.TLS, DialOptions: grpcgun.GrpcDialOptions{ Authority: conf.DialOptions.Authority, Timeout: conf.DialOptions.Timeout, diff --git a/tests/acceptance/config_model.go b/tests/acceptance/config_model.go index a0ccd9e15..b16b3b6ef 100644 --- a/tests/acceptance/config_model.go +++ b/tests/acceptance/config_model.go @@ -1,5 +1,7 @@ package acceptance +import "google.golang.org/grpc/metadata" + type PandoraConfigLog struct { Level string `yaml:"level"` } @@ -11,11 +13,12 @@ type PandoraConfigMonitoring struct { ExpVar PandoraConfigMonitoringExpVar `yaml:"expvar"` } type PandoraConfigGRPCGun struct { - Type string `yaml:"type"` - Target string `yaml:"target"` - TLS bool `yaml:"tls"` - ReflectPort *int64 `yaml:"reflect_port,omitempty"` - SharedClient struct { + Type string `yaml:"type"` + Target string `yaml:"target"` + TLS bool `yaml:"tls"` + ReflectPort *int64 `yaml:"reflect_port,omitempty"` + ReflectMetadata *metadata.MD `yaml:"reflect_metadata,omitempty"` + SharedClient struct { ClientNumber int `yaml:"client-number,omitempty"` Enabled bool `yaml:"enabled"` } `yaml:"shared-client,omitempty"` diff --git a/tests/acceptance/grpc_test.go b/tests/acceptance/grpc_test.go index 491b327c9..6b538fc19 100644 --- a/tests/acceptance/grpc_test.go +++ b/tests/acceptance/grpc_test.go @@ -2,12 +2,14 @@ package acceptance import ( "context" + "fmt" "log/slog" "net" "os" "testing" "time" + "github.com/pkg/errors" "github.com/spf13/afero" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -21,6 +23,7 @@ import ( "github.com/yandex/pandora/lib/testutil" "go.uber.org/zap" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/reflection" "gopkg.in/yaml.v2" ) @@ -116,6 +119,89 @@ func TestCheckGRPCReflectServer(t *testing.T) { require.NoError(t, err) require.Equal(t, int64(8), st.Hello) }) + + t.Run("reflect with custom metadata", func(t *testing.T) { + metadataKey := "testKey" + metadataValue := "testValue" + wrongMDValuesLengthError := errors.New("wrong metadata values length") + wrongMDValueError := errors.New("wrong metadata value") + metadataChecker := func(ctx context.Context) (context.Context, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, wrongMDValuesLengthError + } + vals := md.Get(metadataKey) + if len(vals) != 1 { + return nil, wrongMDValuesLengthError + } + if vals[0] != metadataValue { + return nil, wrongMDValueError + } + return ctx, nil + } + grpcServer := grpc.NewServer( + grpc.UnaryInterceptor(MetadataServerInterceptor(metadataChecker)), + grpc.StreamInterceptor(MetadataServerStreamInterceptor(metadataChecker))) + srv := server.NewServer(logger, time.Now().UnixNano()) + server.RegisterTargetServiceServer(grpcServer, srv) + grpcAddress := "localhost:18888" + reflection.Register(grpcServer) + l, err := net.Listen("tcp", grpcAddress) + require.NoError(t, err) + go func() { + err = grpcServer.Serve(l) + require.NoError(t, err) + }() + + defer func() { + grpcServer.Stop() + }() + + cases := []struct { + name string + conf *cli.CliConfig + err error + }{ + { + name: "success", + conf: parseFileContentToCliConfig(t, baseFile, func(c *PandoraConfigGRPC) { + md := metadata.New(map[string]string{metadataKey: metadataValue}) + c.Pools[0].Gun.ReflectMetadata = &md + }), + }, + { + name: "no metadata", + conf: parseFileContentToCliConfig(t, baseFile, nil), + err: wrongMDValuesLengthError, + }, + { + name: "wrong metadata value", + conf: parseFileContentToCliConfig(t, baseFile, func(c *PandoraConfigGRPC) { + md := metadata.New(map[string]string{metadataKey: "wrong-value"}) + c.Pools[0].Gun.ReflectMetadata = &md + }), + err: wrongMDValueError, + }, + } + + for _, cc := range cases { + t.Run(cc.name, func(t *testing.T) { + require.Equal(t, 1, len(cc.conf.Engine.Pools)) + aggr := &aggregator{} + cc.conf.Engine.Pools[0].Aggregator = aggr + + pandora := engine.New(pandoraLogger, pandoraMetrics, cc.conf.Engine) + err = pandora.Run(context.Background()) + + if cc.err == nil { + require.NoError(t, err) + } else { + require.Error(t, err) + require.Contains(t, err.Error(), cc.err.Error()) + } + }) + } + }) } func TestGrpcGunSuite(t *testing.T) { @@ -211,3 +297,24 @@ func parseFileContentToCliConfig(t *testing.T, baseFile []byte, overwrite func(c return decodeConfig(t, mapCfg) } + +func MetadataServerInterceptor(metadataChecker func(ctx context.Context) (context.Context, error)) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + ctx, err = metadataChecker(ctx) + if err != nil { + return nil, fmt.Errorf("metadata checker: %w", err) + } + return handler(ctx, req) + } +} + +func MetadataServerStreamInterceptor(metadataChecker func(ctx context.Context) (context.Context, error)) grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { + ctx := ss.Context() + ctx, err = metadataChecker(ctx) + if err != nil { + return fmt.Errorf("metadata checker: %w", err) + } + return handler(srv, ss) + } +}