From 7651f4d9f75d483554663a6d990fef1d6031ebc5 Mon Sep 17 00:00:00 2001 From: Akshay Shah Date: Wed, 2 Nov 2016 21:39:48 -0700 Subject: [PATCH] Check relay max timeout in NewChannel Simplify the lazy call req and relay code by validating the configured max timeout at channel creation time. --- channel.go | 28 +++++++++++++++++++--------- channel_test.go | 23 +++++++++++++++++++++++ relay.go | 11 +---------- relay_messages.go | 18 +++--------------- relay_messages_test.go | 14 +------------- relay_test.go | 6 ++++-- 6 files changed, 51 insertions(+), 49 deletions(-) diff --git a/channel.go b/channel.go index 5f4c4f84..d8525478 100644 --- a/channel.go +++ b/channel.go @@ -23,6 +23,7 @@ package tchannel import ( "errors" "fmt" + "math" "net" "os" "path/filepath" @@ -172,16 +173,20 @@ func NewChannel(serviceName string, opts *ChannelOptions) (*Channel, error) { opts = &ChannelOptions{} } - logger := opts.Logger - if logger == nil { - logger = NullLogger - } - processName := opts.ProcessName if processName == "" { processName = fmt.Sprintf("%s[%d]", filepath.Base(os.Args[0]), os.Getpid()) } + logger := opts.Logger + if logger == nil { + logger = NullLogger + } + logger = logger.WithFields( + LogField{"service", serviceName}, + LogField{"process", processName}, + ) + statsReporter := opts.StatsReporter if statsReporter == nil { statsReporter = NullStatsReporter @@ -197,15 +202,20 @@ func NewChannel(serviceName string, opts *ChannelOptions) (*Channel, error) { relayStats = opts.RelayStats } - if opts.RelayMaxTimeout <= 0 { + maxMillis := opts.RelayMaxTimeout / time.Millisecond + if opts.RelayMaxTimeout == 0 { + opts.RelayMaxTimeout = defaultRelayMaxTimeout + } else if opts.RelayMaxTimeout < 0 || maxMillis > math.MaxUint32 { + logger.WithFields( + LogField{"configuredMaxTimeout", opts.RelayMaxTimeout}, + LogField{"defaultMaxTimeout", defaultRelayMaxTimeout}, + ).Warn("Configured RelayMaxTimeout is invalid, using default instead.") opts.RelayMaxTimeout = defaultRelayMaxTimeout } ch := &Channel{ channelConnectionCommon: channelConnectionCommon{ - log: logger.WithFields( - LogField{"service", serviceName}, - LogField{"process", processName}), + log: logger, relayStats: relayStats, relayLocal: toStringSet(opts.RelayLocalHandlers), statsReporter: statsReporter, diff --git a/channel_test.go b/channel_test.go index 4a52a2cb..95b3d9d3 100644 --- a/channel_test.go +++ b/channel_test.go @@ -22,8 +22,10 @@ package tchannel import ( "io/ioutil" + "math" "os" "testing" + "time" "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/mocktracer" @@ -82,6 +84,27 @@ func TestStats(t *testing.T) { assert.Equal(t, "subch", subTags["subchannel"], "subchannel tag missing") } +func TestRelayMaxTTL(t *testing.T) { + tests := []struct { + max time.Duration + expected time.Duration + }{ + {time.Second, time.Second}, + {-time.Second, defaultRelayMaxTimeout}, + {0, defaultRelayMaxTimeout}, + {math.MaxUint32 * time.Millisecond, math.MaxUint32 * time.Millisecond}, + {(math.MaxUint32 + 1) * time.Millisecond, defaultRelayMaxTimeout}, + } + + for _, tt := range tests { + ch, err := NewChannel("svc", &ChannelOptions{ + RelayMaxTimeout: tt.max, + }) + assert.NoError(t, err, "Unexpected error when creating channel.") + assert.Equal(t, ch.relayMaxTimeout, tt.expected, "Unexpected max timeout on channel.") + } +} + func TestIsolatedSubChannelsDontSharePeers(t *testing.T) { ch, err := NewChannel("svc", &ChannelOptions{ Logger: NewLogger(ioutil.Discard), diff --git a/relay.go b/relay.go index 0af3fc43..0a19c27f 100644 --- a/relay.go +++ b/relay.go @@ -384,16 +384,7 @@ func (r *Relayer) handleCallReq(f lazyCallReq) error { ttl := f.TTL() if ttl > r.maxTimeout { ttl = r.maxTimeout - if err := f.SetTTL(r.maxTimeout); err != nil { - originalTTL := f.TTL() - r.logger.WithFields( - ErrField(err), - LogField{"maxTTL", r.maxTimeout}, - LogField{"originalTTL", originalTTL}, - ).Warn("Failed to clamp callreq TTL to max.") - // The max TTL is misconfigured, don't use it. - ttl = originalTTL - } + f.SetTTL(r.maxTimeout) } span := f.Span() // The remote side of the relay doesn't need to track stats. diff --git a/relay_messages.go b/relay_messages.go index a756bff5..e990cb55 100644 --- a/relay_messages.go +++ b/relay_messages.go @@ -23,9 +23,7 @@ package tchannel import ( "bytes" "encoding/binary" - "errors" "fmt" - "math" "time" ) @@ -33,9 +31,6 @@ var ( _callerNameKeyBytes = []byte(CallerName) _routingDelegateKeyBytes = []byte(RoutingDelegate) _routingKeyKeyBytes = []byte(RoutingKey) - - errTTLNegative = errors.New("can't set a negative TTL") - errTTLOverflow = errors.New("TTL overflows uint32") ) const ( @@ -172,17 +167,10 @@ func (f lazyCallReq) TTL() time.Duration { return time.Duration(ttl) * time.Millisecond } -func (f lazyCallReq) SetTTL(d time.Duration) error { - if d < 0 { - return errTTLNegative - } - millis := d / time.Millisecond - if millis > math.MaxUint32 { - return errTTLOverflow - } - ttl := uint32(millis) +// SetTTL overwrites the frame's TTL. +func (f lazyCallReq) SetTTL(d time.Duration) { + ttl := uint32(d / time.Millisecond) binary.BigEndian.PutUint32(f.Payload[_ttlIndex:_ttlIndex+_ttlLen], ttl) - return nil } // Span returns the Span diff --git a/relay_messages_test.go b/relay_messages_test.go index 52e6d2c0..64c52f7c 100644 --- a/relay_messages_test.go +++ b/relay_messages_test.go @@ -21,14 +21,12 @@ package tchannel import ( - "math" "testing" "time" "github.com/uber/tchannel-go/typed" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) type testCallReq int @@ -269,21 +267,11 @@ func TestLazyCallReqTTL(t *testing.T) { func TestLazyCallReqSetTTL(t *testing.T) { withLazyCallReqCombinations(func(crt testCallReq) { cr := crt.req() - require.NoError(t, cr.SetTTL(time.Second), "Unexpected error setting TTL.") + cr.SetTTL(time.Second) assert.Equal(t, time.Second, cr.TTL(), "Failed to write TTL to frame.") }) } -func TestLazyCallReqSetInvalidTTL(t *testing.T) { - tooBig := time.Duration(math.MaxUint32+1) * time.Millisecond - withLazyCallReqCombinations(func(crt testCallReq) { - cr := crt.req() - require.Error(t, cr.SetTTL(-1*time.Second), "Expected setting a negative TTL to be an error.") - require.Error(t, cr.SetTTL(tooBig), "Expected error when setting a TTL that overflows uint32.") - assert.Equal(t, 42*time.Millisecond, cr.TTL(), "Expected erroneous SetTTL calls to leave TTL unchanged.") - }) -} - func TestLazyCallResRejectsOtherFrames(t *testing.T) { assertWrappingPanics( t, diff --git a/relay_test.go b/relay_test.go index ccb519ee..dd81bf18 100644 --- a/relay_test.go +++ b/relay_test.go @@ -365,19 +365,21 @@ func TestLargeTimeoutsAreClamped(t *testing.T) { return &raw.Res{Arg2: args.Arg2, Arg3: args.Arg3}, nil }) - ctx, cancel := NewContext(longTTL) + done := make(chan struct{}) go func() { + ctx, cancel := NewContext(longTTL) defer cancel() _, _, _, err := raw.Call(ctx, client, ts.HostPort(), "echo-service", "echo", nil, nil) require.Error(t, err) code := GetSystemErrorCode(err) assert.Equal(t, ErrCodeTimeout, code) + close(done) }() select { case <-time.After(testutils.Timeout(10 * clampTTL)): t.Fatal("Failed to clamp timeout.") - case <-ctx.Done(): + case <-done: } }) }