Skip to content
Merged
44 changes: 39 additions & 5 deletions extra/redisotel/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ type config struct {
tp trace.TracerProvider
tracer trace.Tracer

dbStmtEnabled bool
callerEnabled bool
filter func(cmd redis.Cmder) bool
dbStmtEnabled bool
callerEnabled bool
filterDial bool
filterProcessPipeline func(cmds []redis.Cmder) bool
filterProcess func(cmd redis.Cmder) bool

// Metrics options.

Expand Down Expand Up @@ -65,6 +67,15 @@ func newConfig(opts ...baseOption) *config {
mp: otel.GetMeterProvider(),
dbStmtEnabled: true,
callerEnabled: true,
filterProcess: DefaultCommandFilter,
filterProcessPipeline: func(cmds []redis.Cmder) bool {
for _, cmd := range cmds {
if DefaultCommandFilter(cmd) {
return true
}
}
return false
},
}

for _, opt := range opts {
Expand Down Expand Up @@ -132,11 +143,28 @@ func WithCallerEnabled(on bool) TracingOption {
// passwords.
func WithCommandFilter(filter func(cmd redis.Cmder) bool) TracingOption {
return tracingOption(func(conf *config) {
conf.filter = filter
conf.filterProcess = filter
})
}

func BasicCommandFilter(cmd redis.Cmder) bool {
// WithCommandsFilter allows filtering of pipeline commands
// when tracing to omit commands that may have sensitive details like
// passwords in a pipeline.
func WithCommandsFilter(filter func(cmds []redis.Cmder) bool) TracingOption {
return tracingOption(func(conf *config) {
conf.filterProcessPipeline = filter
})
}

// WithDialFilter enables or disables filtering of dial commands.
func WithDialFilter(on bool) TracingOption {
return tracingOption(func(conf *config) {
conf.filterDial = on
})
}

// DefaultCommandFilter filters out AUTH commands from tracing.
func DefaultCommandFilter(cmd redis.Cmder) bool {
if strings.ToLower(cmd.Name()) == "auth" {
return true
}
Expand All @@ -159,6 +187,12 @@ func BasicCommandFilter(cmd redis.Cmder) bool {
return false
}

// BasicCommandFilter filters out AUTH commands from tracing.
// Deprecated: use DefaultCommandFilter instead.
func BasicCommandFilter(cmd redis.Cmder) bool {
return DefaultCommandFilter(cmd)
}

//------------------------------------------------------------------------------

type MetricsOption interface {
Expand Down
12 changes: 11 additions & 1 deletion extra/redisotel/tracing.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ func newTracingHook(connString string, opts ...TracingOption) *tracingHook {

func (th *tracingHook) DialHook(hook redis.DialHook) redis.DialHook {
return func(ctx context.Context, network, addr string) (net.Conn, error) {

if th.conf.filterDial {
return hook(ctx, network, addr)
}

ctx, span := th.conf.tracer.Start(ctx, "redis.dial", th.spanOpts...)
defer span.End()

Expand All @@ -103,7 +108,7 @@ func (th *tracingHook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook {
return func(ctx context.Context, cmd redis.Cmder) error {

// Check if the command should be filtered out
if th.conf.filter != nil && th.conf.filter(cmd) {
if th.conf.filterProcess != nil && th.conf.filterProcess(cmd) {
// If so, just call the next hook
return hook(ctx, cmd)
}
Expand Down Expand Up @@ -141,6 +146,11 @@ func (th *tracingHook) ProcessPipelineHook(
hook redis.ProcessPipelineHook,
) redis.ProcessPipelineHook {
return func(ctx context.Context, cmds []redis.Cmder) error {

if th.conf.filterProcessPipeline != nil && th.conf.filterProcessPipeline(cmds) {
return hook(ctx, cmds)
}

attrs := make([]attribute.KeyValue, 0, 8)
attrs = append(attrs,
attribute.Int("db.redis.num_cmd", len(cmds)),
Expand Down
120 changes: 117 additions & 3 deletions extra/redisotel/tracing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func TestWithCommandFilter(t *testing.T) {
hook := newTracingHook(
"",
WithTracerProvider(provider),
WithCommandFilter(BasicCommandFilter),
WithCommandFilter(DefaultCommandFilter),
)
ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test")
cmd := redis.NewCmd(ctx, "auth", "test-password")
Expand All @@ -181,7 +181,7 @@ func TestWithCommandFilter(t *testing.T) {
hook := newTracingHook(
"",
WithTracerProvider(provider),
WithCommandFilter(BasicCommandFilter),
WithCommandFilter(DefaultCommandFilter),
)
ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test")
cmd := redis.NewCmd(ctx, "hello", 3, "AUTH", "test-user", "test-password")
Expand All @@ -206,7 +206,7 @@ func TestWithCommandFilter(t *testing.T) {
hook := newTracingHook(
"",
WithTracerProvider(provider),
WithCommandFilter(BasicCommandFilter),
WithCommandFilter(DefaultCommandFilter),
)
ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test")
cmd := redis.NewCmd(ctx, "hello", 3)
Expand All @@ -227,6 +227,120 @@ func TestWithCommandFilter(t *testing.T) {
})
}

func TestWithCommandsFilter(t *testing.T) {
t.Run("filter out ping and info commands", func(t *testing.T) {
provider := sdktrace.NewTracerProvider()
hook := newTracingHook(
"",
WithTracerProvider(provider),
WithCommandsFilter(func(cmds []redis.Cmder) bool {
for _, cmd := range cmds {
if cmd.Name() == "ping" || cmd.Name() == "info" {
return true
}
}
return false
}),
)

ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test")
cmds := []redis.Cmder{
redis.NewCmd(ctx, "ping"),
redis.NewCmd(ctx, "info"),
}
defer span.End()

processPipelineHook := hook.ProcessPipelineHook(func(ctx context.Context, cmds []redis.Cmder) error {
innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan)
if innerSpan.Name() != "redis-test" || innerSpan.Name() == "redis.pipeline ping\ninfo" {
t.Fatalf("ping and info commands should not be traced")
}
return nil
})
err := processPipelineHook(ctx, cmds)
if err != nil {
t.Fatal(err)
}
})

t.Run("do not filter ping and info commands", func(t *testing.T) {
provider := sdktrace.NewTracerProvider()
hook := newTracingHook(
"",
WithTracerProvider(provider),
WithCommandsFilter(func(cmds []redis.Cmder) bool {
return false // never filter
}),
)
ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test")
cmds := []redis.Cmder{
redis.NewCmd(ctx, "ping"),
redis.NewCmd(ctx, "info"),
}
defer span.End()
processPipelineHook := hook.ProcessPipelineHook(func(ctx context.Context, cmds []redis.Cmder) error {
innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan)
if innerSpan.Name() != "redis.pipeline ping info" {
t.Fatalf("ping and info commands should be traced")
}

return nil
})

err := processPipelineHook(ctx, cmds)
if err != nil {
t.Fatal(err)
}
})
}

func TestWithDialFilter(t *testing.T) {
t.Run("filter out dial", func(t *testing.T) {
provider := sdktrace.NewTracerProvider()
hook := newTracingHook(
"",
WithTracerProvider(provider),
WithDialFilter(true),
)
ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test")
defer span.End()
dialHook := hook.DialHook(func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan)
if innerSpan.Name() == "redis.dial" {
t.Fatalf("dial should not be traced")
}
return nil, nil
})

_, err := dialHook(ctx, "tcp", "localhost:6379")
if err != nil {
t.Fatal(err)
}
})

t.Run("do not filter dial", func(t *testing.T) {
provider := sdktrace.NewTracerProvider()
hook := newTracingHook(
"",
WithTracerProvider(provider),
WithDialFilter(false),
)
ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test")
defer span.End()
dialHook := hook.DialHook(func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan)
if innerSpan.Name() != "redis.dial" {
t.Fatalf("dial should be traced")
}
return nil, nil
})
_, err := dialHook(ctx, "tcp", "localhost:6379")
if err != nil {
t.Fatal(err)
}
})
}

func TestTracingHook_DialHook(t *testing.T) {
imsb := tracetest.NewInMemoryExporter()
provider := sdktrace.NewTracerProvider(sdktrace.WithSyncer(imsb))
Expand Down
Loading