Skip to content

Commit

Permalink
feat(connect): add GRPC header carrier to support trace context propa…
Browse files Browse the repository at this point in the history
…gation in GRPC requests

feat(trace): add support for W3C Trace Context format to propagate trace context in HTTP requests and responses
refactor(trace): rename canonicalMapCarrier to CanonicalMapCarrier and grpcHeaderCarrier to GRPCHeaderCarrier for better readability
refactor(trace): remove unused emptySpanContext variable and add constants for supported and max version of W3C Trace Context
refactor(trace): extract common code for encoding and decoding SpanContext to separate functions
refactor(trace): simplify InjectTraceContext and WithTraceContext functions to use TextMapCarrier interface instead of MapCarrier
refactor(trace): simplify SpanContextFromBinary function to use SpanContextFromBytes function
refactor(trace): simplify SpanContextToBinary function to use SpanContextToBytes function
refactor(trace): simplify SpanContextFromBytes function to use SpanContextFromW3CString function
refactor(trace): simplify SpanContextToBytes
  • Loading branch information
shumkovdenis committed May 16, 2023
1 parent 726970c commit 665df36
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 34 deletions.
6 changes: 5 additions & 1 deletion connect/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,18 @@ func NewInsecureClient() *http.Client {

func WithHandlerOptions(interceptors ...connect.Interceptor) connect.HandlerOption {
return connect.WithHandlerOptions(
connect.WithInterceptors(InjectTraceContext()),
connect.WithInterceptors(
InjectTraceContext(),
InjectTraceContextLogger(),
),
connect.WithInterceptors(interceptors...),
)
}

func WithClientOptions(interceptors ...connect.Interceptor) connect.ClientOption {
return connect.WithClientOptions(
connect.WithGRPC(),
connect.WithInterceptors(AddTraceContextHeader()),
connect.WithInterceptors(interceptors...),
)
}
50 changes: 43 additions & 7 deletions connect/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,44 @@ package connect

import (
"context"
"log"

"github.com/bufbuild/connect-go"
"github.com/shumkovdenis/bl/logger"
"github.com/shumkovdenis/bl/trace"
"go.opentelemetry.io/otel/propagation"
)

type ConnectHeaderCarrier propagation.HeaderCarrier

func (c ConnectHeaderCarrier) Get(key string) string {
value := propagation.HeaderCarrier(c).Get(key)
if key == trace.GrpcTraceBinHeader {
b, _ := connect.DecodeBinaryHeader(value)
return string(b)
}
return value
}

func (c ConnectHeaderCarrier) Set(key, value string) {
if key == trace.GrpcTraceBinHeader {
value = connect.EncodeBinaryHeader([]byte(value))
}
propagation.HeaderCarrier(c).Set(key, value)
}

func (c ConnectHeaderCarrier) Keys() []string {
return propagation.HeaderCarrier(c).Keys()
}

func InjectTraceContext() connect.UnaryInterceptorFunc {
interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(
ctx context.Context,
req connect.AnyRequest,
) (connect.AnyResponse, error) {
if !req.Spec().IsClient {
log.Println(req.Header())
log.Println(req.Header().Get(trace.GrpcTraceBinHeader))
t := req.Header().Get(trace.GrpcTraceBinHeader)
b, _ := connect.DecodeBinaryHeader(t)
sc, ok := trace.SpanContextFromBinary(b)
log.Println(sc.TraceID(), sc.SpanID(), ok)
carrier := trace.GRPCHeaderCarrier(ConnectHeaderCarrier(req.Header()))
ctx = trace.WithTraceContext(ctx, carrier)
}
return next(ctx, req)
})
Expand Down Expand Up @@ -62,3 +80,21 @@ func AddHeader(key, value string) connect.UnaryInterceptorFunc {
func AddDaprAppIDHeader(appID string) connect.UnaryInterceptorFunc {
return AddHeader("dapr-app-id", appID)
}

func AddTraceContextHeader() connect.UnaryInterceptorFunc {
interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(
ctx context.Context,
req connect.AnyRequest,
) (connect.AnyResponse, error) {
if req.Spec().IsClient {
header := req.Header()
carrier := trace.GRPCHeaderCarrier(
ConnectHeaderCarrier(header))
trace.InjectTraceContext(ctx, carrier)
}
return next(ctx, req)
})
}
return connect.UnaryInterceptorFunc(interceptor)
}
6 changes: 0 additions & 6 deletions connect_callee.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@ func NewConnectCallee(cfg Config) *connectCallee {
connectUtils.WithClientOptions(
connectUtils.AddDaprAppIDHeader(cfg.Callee.ServiceName),
),
// connect.WithGRPC(),
// connect.WithInterceptors(
// helpers.NewAppInterceptor("remote"),
// helpers.NewTraceInterceptor(cfg.GRPCTrace),
// helpers.NewLoggerInterceptor(),
// ),
)
return &connectCallee{client: client}
}
Expand Down
3 changes: 2 additions & 1 deletion http/client/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"net/http"

"github.com/shumkovdenis/bl/trace"
"go.opentelemetry.io/otel/propagation"
)

// https://jonfriesen.ca/articles/go-http-client-middleware
Expand Down Expand Up @@ -64,7 +65,7 @@ func AddTraceContextHeader() Middleware {
header = make(http.Header)
}

trace.InjectTraceContext(ctx, header)
trace.InjectTraceContext(ctx, propagation.HeaderCarrier(header))

return rt.RoundTrip(req)
})
Expand Down
3 changes: 2 additions & 1 deletion http/server/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import (
func InjectTraceContext() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
ctx := c.UserContext()
ctx = trace.WithTraceContextFromMap(ctx, c.GetReqHeaders())
headers := c.GetReqHeaders()
ctx = trace.WithTraceContext(ctx, trace.CanonicalMapCarrier(headers))
c.SetUserContext(ctx)
return c.Next()
}
Expand Down
45 changes: 28 additions & 17 deletions trace/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,49 +11,60 @@ import (
const (
TraceparentHeader = "traceparent"
TracestateHeader = "tracestate"
GrpcTraceBinHeader = "grpc-trace-bin"
GrpcTraceBinHeader = "Grpc-Trace-Bin"
)

var (
traceContext propagation.TraceContext
)

type canonicalMapCarrier propagation.MapCarrier
type CanonicalMapCarrier propagation.MapCarrier

func (c canonicalMapCarrier) Get(key string) string {
func (c CanonicalMapCarrier) Get(key string) string {
return propagation.MapCarrier(c).Get(http.CanonicalHeaderKey(key))
}

func (c canonicalMapCarrier) Set(key, value string) {
func (c CanonicalMapCarrier) Set(key, value string) {
propagation.MapCarrier(c).Set(http.CanonicalHeaderKey(key), value)
}

func (c canonicalMapCarrier) Keys() []string {
func (c CanonicalMapCarrier) Keys() []string {
return propagation.MapCarrier(c).Keys()
}

type grpcHeaderCarrier propagation.HeaderCarrier
type GRPCHeaderCarrier propagation.HeaderCarrier

func (c grpcHeaderCarrier) Get(key string) string {
return propagation.HeaderCarrier(c).Get(http.CanonicalHeaderKey(key))
func (c GRPCHeaderCarrier) Get(key string) string {
if key == TraceparentHeader {
grpcTraceBin := propagation.HeaderCarrier(c).Get(GrpcTraceBinHeader)
sc, _ := SpanContextFromBinary([]byte(grpcTraceBin))
return SpanContextToW3CString(sc)
}
return propagation.HeaderCarrier(c).Get(key)
}

func (c grpcHeaderCarrier) Set(key, value string) {
propagation.HeaderCarrier(c).Set(http.CanonicalHeaderKey(key), value)
func (c GRPCHeaderCarrier) Set(key, value string) {
if key == TraceparentHeader {
sc, _ := SpanContextFromW3CString(value)
grpcTraceBin := BinaryFromSpanContext(sc)
propagation.HeaderCarrier(c).Set(GrpcTraceBinHeader, string(grpcTraceBin))
} else {
propagation.HeaderCarrier(c).Set(key, value)
}
}

func (c grpcHeaderCarrier) Keys() []string {
func (c GRPCHeaderCarrier) Keys() []string {
return propagation.HeaderCarrier(c).Keys()
}

func WithTraceContextFromMap(ctx context.Context, headers map[string]string) context.Context {
return traceContext.Extract(ctx, canonicalMapCarrier(headers))
func InjectTraceContext(ctx context.Context, carrier propagation.TextMapCarrier) {
traceContext.Inject(ctx, carrier)
}

func TraceContextFromContext(ctx context.Context) trace.SpanContext {
return trace.SpanContextFromContext(ctx)
func WithTraceContext(ctx context.Context, carrier propagation.TextMapCarrier) context.Context {
return traceContext.Extract(ctx, carrier)
}

func InjectTraceContext(ctx context.Context, header http.Header) {
traceContext.Inject(ctx, propagation.HeaderCarrier(header))
func TraceContextFromContext(ctx context.Context) trace.SpanContext {
return trace.SpanContextFromContext(ctx)
}
83 changes: 82 additions & 1 deletion trace/utils.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
package trace

import "go.opentelemetry.io/otel/trace"
import (
"encoding/hex"
"fmt"
"strings"

"go.opentelemetry.io/otel/trace"
)

const (
supportedVersion = 0
maxVersion = 254
)

var emptySpanContext trace.SpanContext

Expand Down Expand Up @@ -48,3 +59,73 @@ func SpanContextFromBinary(b []byte) (sc trace.SpanContext, ok bool) {
sc = trace.NewSpanContext(scConfig)
return sc, true
}

// SpanContextToW3CString returns the SpanContext string representation.
func SpanContextToW3CString(sc trace.SpanContext) string {
traceID := sc.TraceID()
spanID := sc.SpanID()
traceFlags := sc.TraceFlags()
return fmt.Sprintf("%x-%x-%x-%x",
[]byte{supportedVersion},
traceID[:],
spanID[:],
[]byte{byte(traceFlags)})
}

// SpanContextFromW3CString extracts a span context from given string which got earlier from SpanContextToW3CString format.
func SpanContextFromW3CString(h string) (sc trace.SpanContext, ok bool) {
if h == "" {
return trace.SpanContext{}, false
}
sections := strings.Split(h, "-")
if len(sections) < 4 {
return trace.SpanContext{}, false
}

if len(sections[0]) != 2 {
return trace.SpanContext{}, false
}
ver, err := hex.DecodeString(sections[0])
if err != nil {
return trace.SpanContext{}, false
}
version := int(ver[0])
if version > maxVersion {
return trace.SpanContext{}, false
}

if version == 0 && len(sections) != 4 {
return trace.SpanContext{}, false
}

if len(sections[1]) != 32 {
return trace.SpanContext{}, false
}
tid, err := trace.TraceIDFromHex(sections[1])
if err != nil {
return trace.SpanContext{}, false
}
sc = sc.WithTraceID(tid)

if len(sections[2]) != 16 {
return trace.SpanContext{}, false
}
sid, err := trace.SpanIDFromHex(sections[2])
if err != nil {
return trace.SpanContext{}, false
}
sc = sc.WithSpanID(sid)

opts, err := hex.DecodeString(sections[3])
if err != nil || len(opts) < 1 {
return trace.SpanContext{}, false
}
sc = sc.WithTraceFlags(trace.TraceFlags(opts[0]))

// Don't allow all zero trace or span ID.
if sc.TraceID() == [16]byte{} || sc.SpanID() == [8]byte{} {
return trace.SpanContext{}, false
}

return sc, true
}

0 comments on commit 665df36

Please sign in to comment.