Skip to content

Commit

Permalink
fix(otelgorm): restore original context after "after" callback
Browse files Browse the repository at this point in the history
  • Loading branch information
markhildreth-gravity committed Nov 21, 2022
1 parent 2cf35d7 commit 7cd8508
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
16 changes: 16 additions & 0 deletions otelgorm/internal/e2etest/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,22 @@ func TestEndToEnd(t *testing.T) {
require.Equal(t, 1, len(spans))
},
},
{
do: func(ctx context.Context, db *gorm.DB) {
var count int64
query := db.WithContext(ctx).Table("generate_series(1, 10)")
_, _ = query.Select("*").Rows()
_ = query.Count(&count)
},
require: func(t *testing.T, spans []sdktrace.ReadOnlySpan) {
require.Equal(t, 2, len(spans))
require.Equal(
t,
spans[0].Parent().SpanID().String(),
spans[1].Parent().SpanID().String(),
)
},
},
}

for i, test := range tests {
Expand Down
14 changes: 13 additions & 1 deletion otelgorm/otelgorm.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package otelgorm

import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
Expand Down Expand Up @@ -95,12 +96,18 @@ func (p otelPlugin) Initialize(db *gorm.DB) (err error) {
return firstErr
}

type originalCtxType int

const originalCtxValue originalCtxType = iota

func (p *otelPlugin) before(spanName string) gormHookFunc {
return func(tx *gorm.DB) {
if tx.DryRun && !p.includeDryRunSpans {
return
}
tx.Statement.Context, _ = p.tracer.Start(tx.Statement.Context, spanName, trace.WithSpanKind(trace.SpanKindClient))
originalCtx := tx.Statement.Context
newCtx := context.WithValue(originalCtx, originalCtxValue, originalCtx)
tx.Statement.Context, _ = p.tracer.Start(newCtx, spanName, trace.WithSpanKind(trace.SpanKindClient))
}
}

Expand Down Expand Up @@ -154,6 +161,11 @@ func (p *otelPlugin) after() gormHookFunc {
span.RecordError(tx.Error)
span.SetStatus(codes.Error, tx.Error.Error())
}

switch originalCtx := tx.Statement.Context.Value(originalCtxValue).(type) {
case context.Context:
tx.Statement.Context = originalCtx
}
}
}

Expand Down

0 comments on commit 7cd8508

Please sign in to comment.