Skip to content

Commit

Permalink
Make ContextWithHeaders use net.Context field storage
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuri Shkuro committed Mar 22, 2016
1 parent ea8d37b commit d4cffca
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ lint.log
*.swp
.DS_Store
.idea
tchannel-go.iml
5 changes: 4 additions & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ const defaultTimeout = time.Second

type contextKey int

const contextKeyTChannel = 1
const (
contextKeyTChannel contextKey = iota
contextKeyHeaders
)

type tchannelCtxParams struct {
span *Span
Expand Down
20 changes: 17 additions & 3 deletions context_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ type ContextBuilder struct {
// RetryOptions are the retry options for this call.
RetryOptions *RetryOptions

// ParentContext to build the new context from. If empty, context.Background() is used.
// If ParentContext is of type ContextWithHeaders, then parent headers will be carried
// over to the child context, unless overwritten via `AddHeader()`.
ParentContext context.Context

// Hidden fields: we do not want users outside of tchannel to set these.
incomingCall IncomingCall
span *Span
Expand Down Expand Up @@ -139,6 +144,12 @@ func (cb *ContextBuilder) SetTimeoutPerAttempt(timeoutPerAttempt time.Duration)
return cb
}

// SetParentContext sets the parent for the Context.
func (cb *ContextBuilder) SetParentContext(ctx context.Context) *ContextBuilder {
cb.ParentContext = ctx
return cb
}

func (cb *ContextBuilder) setSpan(span *Span) *ContextBuilder {
cb.span = span
return cb
Expand All @@ -151,8 +162,6 @@ func (cb *ContextBuilder) setIncomingCall(call IncomingCall) *ContextBuilder {

// Build returns a ContextWithHeaders that can be used to make calls.
func (cb *ContextBuilder) Build() (ContextWithHeaders, context.CancelFunc) {
timeout := cb.Timeout

if cb.TracingDisabled {
cb.span.EnableTracing(false)
}
Expand All @@ -164,7 +173,12 @@ func (cb *ContextBuilder) Build() (ContextWithHeaders, context.CancelFunc) {
retryOptions: cb.RetryOptions,
}

ctx, cancel := context.WithTimeout(context.Background(), timeout)
parent := cb.ParentContext
if parent == nil {
parent = context.Background()
}
ctx, cancel := context.WithTimeout(parent, cb.Timeout)

ctx = context.WithValue(ctx, contextKeyTChannel, params)
return WrapWithHeaders(ctx, cb.Headers), cancel
}
46 changes: 41 additions & 5 deletions context_header.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,29 +38,65 @@ type ContextWithHeaders interface {

type headerCtx struct {
context.Context
}

// headersContainer stores the headers, and is itself stored in the context under `contextKeyHeaders`
type headersContainer struct {
reqHeaders map[string]string
respHeaders map[string]string
}

func (c *headerCtx) headers() *headersContainer {
if h, ok := c.Value(contextKeyHeaders).(*headersContainer); ok {
return h
}
return nil
}

// Headers gets application headers out of the context.
func (c *headerCtx) Headers() map[string]string {
return c.reqHeaders
if h := c.headers(); h != nil {
return h.reqHeaders
}
return nil
}

// ResponseHeaders returns the response headers.
func (c *headerCtx) ResponseHeaders() map[string]string {
return c.respHeaders
if h := c.headers(); h != nil {
return h.respHeaders
}
return nil
}

// SetResponseHeaders sets the response headers.
func (c *headerCtx) SetResponseHeaders(headers map[string]string) {
c.respHeaders = headers
if h := c.headers(); h != nil {
h.respHeaders = headers
return
}
panic("SetResponseHeaders called on ContextWithHeaders not created via WrapWithHeaders")
}

// WrapWithHeaders returns a Context that can be used to make a call with request headers.
// If the parent `ctx` is already an instance of ContextWithHeaders, its Headers will be
// carried over, but the new `headers` will take precedence. ResponseHeaders are not copied.
func WrapWithHeaders(ctx context.Context, headers map[string]string) ContextWithHeaders {
return &headerCtx{
Context: ctx,
h := &headersContainer{
reqHeaders: headers,
}
if parent_h, ok := ctx.Value(contextKeyHeaders).(*headersContainer); ok {
if pl := len(parent_h.reqHeaders); pl > 0 {
mergedHeaders := make(map[string]string, len(headers) + pl)
for k, v := range parent_h.reqHeaders {
mergedHeaders[k] = v
}
for k, v := range headers {
mergedHeaders[k] = v
}
h.reqHeaders = mergedHeaders
}
}
newCtx := context.WithValue(ctx, contextKeyHeaders, h)
return &headerCtx{Context: newCtx}
}
38 changes: 38 additions & 0 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,41 @@ func TestCurrentCallWithNilResult(t *testing.T) {
call := CurrentCall(ctx)
assert.Nil(t, call, "Should return nil.")
}

func TestContextWithHeaders(t *testing.T) {
ctx := context.WithValue(context.Background(), "some key", "some value")

t_ctx1, _ := NewContextBuilder(time.Second).
AddHeader("header key", "header value").
Build()
assert.EqualValues(t, "header value", t_ctx1.Headers()["header key"])
assert.Nil(t, t_ctx1.Value("some key"), "no inheritance of parent context")

t_ctx2, _ := NewContextBuilder(time.Second).
SetParentContext(ctx).
AddHeader("header key", "header value").
Build()
assert.EqualValues(t, "header value", t_ctx2.Headers()["header key"])
assert.EqualValues(t, "some value", t_ctx2.Value("some key"), "inherited from parent ctx")

t_ctx2.SetResponseHeaders(map[string]string{"resp key": "resp value"})
assert.EqualValues(t, "resp value", t_ctx2.ResponseHeaders()["resp key"])

ctx = t_ctx2 // test as regular context
assert.EqualValues(t, "some value", ctx.Value("some key"))

t_ctx3, _ := NewContextBuilder(time.Second).
SetParentContext(t_ctx2).
AddHeader("header key2", "header value2").
Build()
assert.EqualValues(t, "header value", t_ctx3.Headers()["header key"], "headers merged")
assert.EqualValues(t, "header value2", t_ctx3.Headers()["header key2"], "headers merged")
assert.EqualValues(t, "some value", t_ctx3.Value("some key"), "inherited from parent ctx")

t_ctx4, _ := NewContextBuilder(time.Second).
SetParentContext(t_ctx2).
AddHeader("header key", "header value2").
Build()
assert.EqualValues(t, "header value2", t_ctx4.Headers()["header key"], "header overwriten")
assert.EqualValues(t, "some value", t_ctx4.Value("some key"), "inherited from parent ctx")
}
7 changes: 6 additions & 1 deletion retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,12 @@ var requestStatePool = sync.Pool{
}

func getRetryOptions(ctx context.Context) *RetryOptions {
opts := getTChannelParams(ctx).retryOptions
params := getTChannelParams(ctx)
if params == nil {
return defaultRetryOptions
}

opts := params.retryOptions
if opts == nil {
return defaultRetryOptions
}
Expand Down

0 comments on commit d4cffca

Please sign in to comment.