Skip to content

Commit

Permalink
Make ContextWithHeaders use net.Context field storage
Browse files Browse the repository at this point in the history
Support parent context in ContextBuilder
  • Loading branch information
Yuri Shkuro committed Mar 22, 2016
1 parent 226988c commit 8a0c71f
Show file tree
Hide file tree
Showing 9 changed files with 209 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ lint.log
*.swp
.DS_Store
.idea
tchannel-go.iml
5 changes: 4 additions & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ const defaultTimeout = time.Second

type contextKey int

const contextKeyTChannel = 1
const (
contextKeyTChannel contextKey = iota
contextKeyHeaders
)

type tchannelCtxParams struct {
span *Span
Expand Down
77 changes: 70 additions & 7 deletions context_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,27 @@ type ContextBuilder struct {
// RetryOptions are the retry options for this call.
RetryOptions *RetryOptions

// ParentContext to build the new context from. If empty, context.Background() is used.
// The new (child) context inherits a number of properties from the parent context:
// - the tracing Span, unless replaced via SetExternalSpan()
// - context fields, accessible via `ctx.Value(key)`
// - headers if parent is a ContextWithHeaders, unless replaced via SetHeaders()
ParentContext context.Context

// Hidden fields: we do not want users outside of tchannel to set these.
incomingCall IncomingCall
span *Span

// replaceParentHeaders is set to true when SetHeaders() method is called.
// It forces headers from ParentContext to be ignored. When false, parent
// headers will be merged with headers accumulated by the builder.
replaceParentHeaders bool
}

// NewContextBuilder returns a builder that can be used to create a Context.
func NewContextBuilder(timeout time.Duration) *ContextBuilder {
return &ContextBuilder{
Timeout: timeout,
span: NewRootSpan(),
}
}

Expand All @@ -74,8 +85,10 @@ func (cb *ContextBuilder) AddHeader(key, value string) *ContextBuilder {
}

// SetHeaders sets the application headers for this Context.
// If there is a ParentContext, its headers will be ignored after the call to this method.
func (cb *ContextBuilder) SetHeaders(headers map[string]string) *ContextBuilder {
cb.Headers = headers
cb.replaceParentHeaders = true
return cb
}

Expand Down Expand Up @@ -139,6 +152,20 @@ func (cb *ContextBuilder) SetTimeoutPerAttempt(timeoutPerAttempt time.Duration)
return cb
}

// SetParentContext sets the parent for the Context.
func (cb *ContextBuilder) SetParentContext(ctx context.Context) *ContextBuilder {
cb.ParentContext = ctx
return cb
}

// SetExternalSpan creates a new TChannel tracing Span from externally provided IDs
// and sets it as the current span for the context.
// Intended for integration with other Zipkin-like tracers.
func (cb *ContextBuilder) SetExternalSpan(traceID, spanID, parentID uint64, traced bool) *ContextBuilder {
span := newSpan(traceID, spanID, parentID, traced)
return cb.setSpan(span)
}

func (cb *ContextBuilder) setSpan(span *Span) *ContextBuilder {
cb.span = span
return cb
Expand All @@ -149,22 +176,58 @@ func (cb *ContextBuilder) setIncomingCall(call IncomingCall) *ContextBuilder {
return cb
}

func (cb *ContextBuilder) getSpan() *Span {
if cb.span != nil {
return cb.span
}
if cb.ParentContext != nil {
if span := CurrentSpan(cb.ParentContext); span != nil {
return span
}
}
return NewRootSpan()
}

func (cb *ContextBuilder) getHeaders() map[string]string {
if cb.ParentContext == nil || cb.replaceParentHeaders {
return cb.Headers
}

parent, ok := cb.ParentContext.Value(contextKeyHeaders).(*headersContainer)
if !ok || len(parent.reqHeaders) == 0 {
return cb.Headers
}

mergedHeaders := make(map[string]string, len(cb.Headers)+len(parent.reqHeaders))
for k, v := range parent.reqHeaders {
mergedHeaders[k] = v
}
for k, v := range cb.Headers {
mergedHeaders[k] = v
}
return mergedHeaders
}

// Build returns a ContextWithHeaders that can be used to make calls.
func (cb *ContextBuilder) Build() (ContextWithHeaders, context.CancelFunc) {
timeout := cb.Timeout

span := cb.getSpan()
if cb.TracingDisabled {
cb.span.EnableTracing(false)
span.EnableTracing(false)
}

params := &tchannelCtxParams{
options: cb.CallOptions,
span: cb.span,
span: span,
call: cb.incomingCall,
retryOptions: cb.RetryOptions,
}

ctx, cancel := context.WithTimeout(context.Background(), timeout)
parent := cb.ParentContext
if parent == nil {
parent = context.Background()
}
ctx, cancel := context.WithTimeout(parent, cb.Timeout)

ctx = context.WithValue(ctx, contextKeyTChannel, params)
return WrapWithHeaders(ctx, cb.Headers), cancel
return WrapWithHeaders(ctx, cb.getHeaders()), cancel
}
34 changes: 29 additions & 5 deletions context_header.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,29 +38,53 @@ type ContextWithHeaders interface {

type headerCtx struct {
context.Context
}

// headersContainer stores the headers, and is itself stored in the context under `contextKeyHeaders`
type headersContainer struct {
reqHeaders map[string]string
respHeaders map[string]string
}

func (c *headerCtx) headers() *headersContainer {
if h, ok := c.Value(contextKeyHeaders).(*headersContainer); ok {
return h
}
return nil
}

// Headers gets application headers out of the context.
func (c *headerCtx) Headers() map[string]string {
return c.reqHeaders
if h := c.headers(); h != nil {
return h.reqHeaders
}
return nil
}

// ResponseHeaders returns the response headers.
func (c *headerCtx) ResponseHeaders() map[string]string {
return c.respHeaders
if h := c.headers(); h != nil {
return h.respHeaders
}
return nil
}

// SetResponseHeaders sets the response headers.
func (c *headerCtx) SetResponseHeaders(headers map[string]string) {
c.respHeaders = headers
if h := c.headers(); h != nil {
h.respHeaders = headers
return
}
panic("SetResponseHeaders called on ContextWithHeaders not created via WrapWithHeaders")
}

// WrapWithHeaders returns a Context that can be used to make a call with request headers.
// If the parent `ctx` is already an instance of ContextWithHeaders, its existing headers
// will be ignored. In order to merge new headers with parent headers, use ContextBuilder.
func WrapWithHeaders(ctx context.Context, headers map[string]string) ContextWithHeaders {
return &headerCtx{
Context: ctx,
h := &headersContainer{
reqHeaders: headers,
}
newCtx := context.WithValue(ctx, contextKeyHeaders, h)
return &headerCtx{Context: newCtx}
}
83 changes: 83 additions & 0 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,86 @@ func TestCurrentCallWithNilResult(t *testing.T) {
call := CurrentCall(ctx)
assert.Nil(t, call, "Should return nil.")
}

func getParentContext(t *testing.T) ContextWithHeaders {
ctx := context.WithValue(context.Background(), "some key", "some value")

ctx1, _ := NewContextBuilder(time.Second).
SetParentContext(ctx).
AddHeader("header key", "header value").
Build()
return ctx1
}

func TestContextBuilderParentContextNoHeaders(t *testing.T) {
ctx := getParentContext(t)
assert.EqualValues(t, "header value", ctx.Headers()["header key"])
assert.EqualValues(t, "some value", ctx.Value("some key"), "inherited from parent ctx")
}

func TestContextBuilderParentContextMergeHeaders(t *testing.T) {
ctx := getParentContext(t)
ctx.Headers()["fixed header"] = "fixed value"

// append header to parent
ctx2, _ := NewContextBuilder(time.Second).
SetParentContext(ctx).
AddHeader("header key 2", "header value 2").
Build()
assert.EqualValues(t, "header value", ctx2.Headers()["header key"], "inherited")
assert.EqualValues(t, "fixed value", ctx2.Headers()["fixed header"], "inherited")
assert.EqualValues(t, "header value 2", ctx2.Headers()["header key 2"], "appended")
assert.Equal(t, 3, len(ctx2.Headers()))

// override parent header
ctx3, _ := NewContextBuilder(time.Second).
SetParentContext(ctx).
AddHeader("header key", "header value 2"). // override
Build()

assert.EqualValues(t, "header value 2", ctx3.Headers()["header key"], "overwritten")
assert.EqualValues(t, "fixed value", ctx2.Headers()["fixed header"], "inherited")
assert.Equal(t, 2, len(ctx3.Headers()))
}

func TestContextBuilderParentContextReplaceHeaders(t *testing.T) {
ctx := getParentContext(t)
ctx.Headers()["fixed header"] = "fixed value"
assert.Equal(t, 2, len(ctx.Headers()))

// replace headers with a new map
ctx2, _ := NewContextBuilder(time.Second).
SetParentContext(ctx).
SetHeaders(map[string]string{"header key": "header value 2"}).
Build()
assert.EqualValues(t, "header value 2", ctx2.Headers()["header key"], "replaced")
assert.Equal(t, 1, len(ctx2.Headers()), "size drops to 1")
}

func TestContextWithHeadersAsContext(t *testing.T) {
var ctx context.Context = getParentContext(t)
assert.EqualValues(t, "some value", ctx.Value("some key"), "inherited from parent ctx")
}

func TestContextBuilderParentContextSpan(t *testing.T) {
ctx := getParentContext(t)
span := NewSpan(5, 4, 3)

ctx2, _ := NewContextBuilder(time.Second).
SetParentContext(ctx).
SetSpanForTest(&span).
Build()
assert.Equal(t, &span, CurrentSpan(ctx2), "explicitly provided span used")

ctx3, _ := NewContextBuilder(time.Second).
SetParentContext(ctx2).
Build()
assert.Equal(t, &span, CurrentSpan(ctx3), "span inherited from parent")

ctx4, _ := NewContextBuilder(time.Second).
SetParentContext(ctx2).
SetExternalSpan(3, 2, 1, true).
Build()
span4 := NewSpan(3, 2, 1)
assert.Equal(t, &span4, CurrentSpan(ctx4), "span inherited from parent")
}
7 changes: 6 additions & 1 deletion retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,12 @@ var requestStatePool = sync.Pool{
}

func getRetryOptions(ctx context.Context) *RetryOptions {
opts := getTChannelParams(ctx).retryOptions
params := getTChannelParams(ctx)
if params == nil {
return defaultRetryOptions
}

opts := params.retryOptions
if opts == nil {
return defaultRetryOptions
}
Expand Down
13 changes: 13 additions & 0 deletions tracing.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,16 @@ func (s *Span) sampleRootSpan(sampleRate float64) {
s.EnableTracing(false)
}
}

func newSpan(traceID, spanID, parentID uint64, tracingEnabled bool) *Span {
flags := byte(0)
if tracingEnabled {
flags = tracingFlagEnabled
}
return &Span{
traceID: traceID,
spanID: spanID,
parentID: parentID,
flags: flags,
}
}
2 changes: 1 addition & 1 deletion tracing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ func TestTraceReportingEnabled(t *testing.T) {
expected := TraceData{Annotations: tt.expected, BinaryAnnotations: binaryAnnotations, Source: source, Target: target, Method: "echo"}
assert.Equal(t, expected, state.call, "%v: Report args mismatch", tt.name)
curSpan := CurrentSpan(ctx)
assert.Equal(t, NewSpan(curSpan.TraceID(), 0, curSpan.TraceID()), state.span, "Span mismatch")
assert.Equal(t, NewSpan(curSpan.TraceID(), curSpan.TraceID(), 0), state.span, "Span mismatch")
})
}
}
Expand Down
4 changes: 2 additions & 2 deletions utils_for_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,6 @@ func InboundConnection(call IncomingCall) (*Connection, net.Conn) {
}

// NewSpan returns a Span for testing.
func NewSpan(traceID uint64, parentID uint64, spanID uint64) Span {
return Span{traceID: traceID, parentID: parentID, spanID: spanID, flags: defaultTracingFlags}
func NewSpan(traceID uint64, spanID uint64, parentID uint64) Span {
return Span{traceID: traceID, spanID: spanID, parentID: parentID, flags: defaultTracingFlags}
}

0 comments on commit 8a0c71f

Please sign in to comment.