diff --git a/extra/redisotel/config.go b/extra/redisotel/config.go index 62b3c9bc28..b9311beafa 100644 --- a/extra/redisotel/config.go +++ b/extra/redisotel/config.go @@ -22,9 +22,11 @@ type config struct { tp trace.TracerProvider tracer trace.Tracer - dbStmtEnabled bool - callerEnabled bool - filter func(cmd redis.Cmder) bool + dbStmtEnabled bool + callerEnabled bool + filterDial bool + filterProcessPipeline func(cmds []redis.Cmder) bool + filterProcess func(cmd redis.Cmder) bool // Metrics options. @@ -65,6 +67,15 @@ func newConfig(opts ...baseOption) *config { mp: otel.GetMeterProvider(), dbStmtEnabled: true, callerEnabled: true, + filterProcess: DefaultCommandFilter, + filterProcessPipeline: func(cmds []redis.Cmder) bool { + for _, cmd := range cmds { + if DefaultCommandFilter(cmd) { + return true + } + } + return false + }, } for _, opt := range opts { @@ -132,11 +143,28 @@ func WithCallerEnabled(on bool) TracingOption { // passwords. func WithCommandFilter(filter func(cmd redis.Cmder) bool) TracingOption { return tracingOption(func(conf *config) { - conf.filter = filter + conf.filterProcess = filter }) } -func BasicCommandFilter(cmd redis.Cmder) bool { +// WithCommandsFilter allows filtering of pipeline commands +// when tracing to omit commands that may have sensitive details like +// passwords in a pipeline. +func WithCommandsFilter(filter func(cmds []redis.Cmder) bool) TracingOption { + return tracingOption(func(conf *config) { + conf.filterProcessPipeline = filter + }) +} + +// WithDialFilter enables or disables filtering of dial commands. +func WithDialFilter(on bool) TracingOption { + return tracingOption(func(conf *config) { + conf.filterDial = on + }) +} + +// DefaultCommandFilter filters out AUTH commands from tracing. +func DefaultCommandFilter(cmd redis.Cmder) bool { if strings.ToLower(cmd.Name()) == "auth" { return true } @@ -159,6 +187,12 @@ func BasicCommandFilter(cmd redis.Cmder) bool { return false } +// BasicCommandFilter filters out AUTH commands from tracing. +// Deprecated: use DefaultCommandFilter instead. +func BasicCommandFilter(cmd redis.Cmder) bool { + return DefaultCommandFilter(cmd) +} + //------------------------------------------------------------------------------ type MetricsOption interface { diff --git a/extra/redisotel/tracing.go b/extra/redisotel/tracing.go index 5c91710c60..a6f361b06e 100644 --- a/extra/redisotel/tracing.go +++ b/extra/redisotel/tracing.go @@ -87,6 +87,11 @@ func newTracingHook(connString string, opts ...TracingOption) *tracingHook { func (th *tracingHook) DialHook(hook redis.DialHook) redis.DialHook { return func(ctx context.Context, network, addr string) (net.Conn, error) { + + if th.conf.filterDial { + return hook(ctx, network, addr) + } + ctx, span := th.conf.tracer.Start(ctx, "redis.dial", th.spanOpts...) defer span.End() @@ -103,7 +108,7 @@ func (th *tracingHook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook { return func(ctx context.Context, cmd redis.Cmder) error { // Check if the command should be filtered out - if th.conf.filter != nil && th.conf.filter(cmd) { + if th.conf.filterProcess != nil && th.conf.filterProcess(cmd) { // If so, just call the next hook return hook(ctx, cmd) } @@ -141,6 +146,11 @@ func (th *tracingHook) ProcessPipelineHook( hook redis.ProcessPipelineHook, ) redis.ProcessPipelineHook { return func(ctx context.Context, cmds []redis.Cmder) error { + + if th.conf.filterProcessPipeline != nil && th.conf.filterProcessPipeline(cmds) { + return hook(ctx, cmds) + } + attrs := make([]attribute.KeyValue, 0, 8) attrs = append(attrs, attribute.Int("db.redis.num_cmd", len(cmds)), diff --git a/extra/redisotel/tracing_test.go b/extra/redisotel/tracing_test.go index 0ae70c2d8b..96c2aff835 100644 --- a/extra/redisotel/tracing_test.go +++ b/extra/redisotel/tracing_test.go @@ -156,7 +156,7 @@ func TestWithCommandFilter(t *testing.T) { hook := newTracingHook( "", WithTracerProvider(provider), - WithCommandFilter(BasicCommandFilter), + WithCommandFilter(DefaultCommandFilter), ) ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") cmd := redis.NewCmd(ctx, "auth", "test-password") @@ -181,7 +181,7 @@ func TestWithCommandFilter(t *testing.T) { hook := newTracingHook( "", WithTracerProvider(provider), - WithCommandFilter(BasicCommandFilter), + WithCommandFilter(DefaultCommandFilter), ) ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") cmd := redis.NewCmd(ctx, "hello", 3, "AUTH", "test-user", "test-password") @@ -206,7 +206,7 @@ func TestWithCommandFilter(t *testing.T) { hook := newTracingHook( "", WithTracerProvider(provider), - WithCommandFilter(BasicCommandFilter), + WithCommandFilter(DefaultCommandFilter), ) ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") cmd := redis.NewCmd(ctx, "hello", 3) @@ -227,6 +227,120 @@ func TestWithCommandFilter(t *testing.T) { }) } +func TestWithCommandsFilter(t *testing.T) { + t.Run("filter out ping and info commands", func(t *testing.T) { + provider := sdktrace.NewTracerProvider() + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithCommandsFilter(func(cmds []redis.Cmder) bool { + for _, cmd := range cmds { + if cmd.Name() == "ping" || cmd.Name() == "info" { + return true + } + } + return false + }), + ) + + ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") + cmds := []redis.Cmder{ + redis.NewCmd(ctx, "ping"), + redis.NewCmd(ctx, "info"), + } + defer span.End() + + processPipelineHook := hook.ProcessPipelineHook(func(ctx context.Context, cmds []redis.Cmder) error { + innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan) + if innerSpan.Name() != "redis-test" || innerSpan.Name() == "redis.pipeline ping\ninfo" { + t.Fatalf("ping and info commands should not be traced") + } + return nil + }) + err := processPipelineHook(ctx, cmds) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("do not filter ping and info commands", func(t *testing.T) { + provider := sdktrace.NewTracerProvider() + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithCommandsFilter(func(cmds []redis.Cmder) bool { + return false // never filter + }), + ) + ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") + cmds := []redis.Cmder{ + redis.NewCmd(ctx, "ping"), + redis.NewCmd(ctx, "info"), + } + defer span.End() + processPipelineHook := hook.ProcessPipelineHook(func(ctx context.Context, cmds []redis.Cmder) error { + innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan) + if innerSpan.Name() != "redis.pipeline ping info" { + t.Fatalf("ping and info commands should be traced") + } + + return nil + }) + + err := processPipelineHook(ctx, cmds) + if err != nil { + t.Fatal(err) + } + }) +} + +func TestWithDialFilter(t *testing.T) { + t.Run("filter out dial", func(t *testing.T) { + provider := sdktrace.NewTracerProvider() + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithDialFilter(true), + ) + ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") + defer span.End() + dialHook := hook.DialHook(func(ctx context.Context, network, addr string) (conn net.Conn, err error) { + innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan) + if innerSpan.Name() == "redis.dial" { + t.Fatalf("dial should not be traced") + } + return nil, nil + }) + + _, err := dialHook(ctx, "tcp", "localhost:6379") + if err != nil { + t.Fatal(err) + } + }) + + t.Run("do not filter dial", func(t *testing.T) { + provider := sdktrace.NewTracerProvider() + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithDialFilter(false), + ) + ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") + defer span.End() + dialHook := hook.DialHook(func(ctx context.Context, network, addr string) (conn net.Conn, err error) { + innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan) + if innerSpan.Name() != "redis.dial" { + t.Fatalf("dial should be traced") + } + return nil, nil + }) + _, err := dialHook(ctx, "tcp", "localhost:6379") + if err != nil { + t.Fatal(err) + } + }) +} + func TestTracingHook_DialHook(t *testing.T) { imsb := tracetest.NewInMemoryExporter() provider := sdktrace.NewTracerProvider(sdktrace.WithSyncer(imsb))