Skip to content

Commit

Permalink
Check relay max timeout in NewChannel
Browse files Browse the repository at this point in the history
Simplify the lazy call req and relay code by validating the configured
max timeout at channel creation time.
  • Loading branch information
Akshay Shah committed Nov 3, 2016
1 parent f8dd94e commit 7651f4d
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 49 deletions.
28 changes: 19 additions & 9 deletions channel.go
Expand Up @@ -23,6 +23,7 @@ package tchannel
import (
"errors"
"fmt"
"math"
"net"
"os"
"path/filepath"
Expand Down Expand Up @@ -172,16 +173,20 @@ func NewChannel(serviceName string, opts *ChannelOptions) (*Channel, error) {
opts = &ChannelOptions{}
}

logger := opts.Logger
if logger == nil {
logger = NullLogger
}

processName := opts.ProcessName
if processName == "" {
processName = fmt.Sprintf("%s[%d]", filepath.Base(os.Args[0]), os.Getpid())
}

logger := opts.Logger
if logger == nil {
logger = NullLogger
}
logger = logger.WithFields(
LogField{"service", serviceName},
LogField{"process", processName},
)

statsReporter := opts.StatsReporter
if statsReporter == nil {
statsReporter = NullStatsReporter
Expand All @@ -197,15 +202,20 @@ func NewChannel(serviceName string, opts *ChannelOptions) (*Channel, error) {
relayStats = opts.RelayStats
}

if opts.RelayMaxTimeout <= 0 {
maxMillis := opts.RelayMaxTimeout / time.Millisecond
if opts.RelayMaxTimeout == 0 {
opts.RelayMaxTimeout = defaultRelayMaxTimeout
} else if opts.RelayMaxTimeout < 0 || maxMillis > math.MaxUint32 {
logger.WithFields(
LogField{"configuredMaxTimeout", opts.RelayMaxTimeout},
LogField{"defaultMaxTimeout", defaultRelayMaxTimeout},
).Warn("Configured RelayMaxTimeout is invalid, using default instead.")
opts.RelayMaxTimeout = defaultRelayMaxTimeout
}

ch := &Channel{
channelConnectionCommon: channelConnectionCommon{
log: logger.WithFields(
LogField{"service", serviceName},
LogField{"process", processName}),
log: logger,
relayStats: relayStats,
relayLocal: toStringSet(opts.RelayLocalHandlers),
statsReporter: statsReporter,
Expand Down
23 changes: 23 additions & 0 deletions channel_test.go
Expand Up @@ -22,8 +22,10 @@ package tchannel

import (
"io/ioutil"
"math"
"os"
"testing"
"time"

"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/mocktracer"
Expand Down Expand Up @@ -82,6 +84,27 @@ func TestStats(t *testing.T) {
assert.Equal(t, "subch", subTags["subchannel"], "subchannel tag missing")
}

func TestRelayMaxTTL(t *testing.T) {
tests := []struct {
max time.Duration
expected time.Duration
}{
{time.Second, time.Second},
{-time.Second, defaultRelayMaxTimeout},
{0, defaultRelayMaxTimeout},
{math.MaxUint32 * time.Millisecond, math.MaxUint32 * time.Millisecond},
{(math.MaxUint32 + 1) * time.Millisecond, defaultRelayMaxTimeout},
}

for _, tt := range tests {
ch, err := NewChannel("svc", &ChannelOptions{
RelayMaxTimeout: tt.max,
})
assert.NoError(t, err, "Unexpected error when creating channel.")
assert.Equal(t, ch.relayMaxTimeout, tt.expected, "Unexpected max timeout on channel.")
}
}

func TestIsolatedSubChannelsDontSharePeers(t *testing.T) {
ch, err := NewChannel("svc", &ChannelOptions{
Logger: NewLogger(ioutil.Discard),
Expand Down
11 changes: 1 addition & 10 deletions relay.go
Expand Up @@ -384,16 +384,7 @@ func (r *Relayer) handleCallReq(f lazyCallReq) error {
ttl := f.TTL()
if ttl > r.maxTimeout {
ttl = r.maxTimeout
if err := f.SetTTL(r.maxTimeout); err != nil {
originalTTL := f.TTL()
r.logger.WithFields(
ErrField(err),
LogField{"maxTTL", r.maxTimeout},
LogField{"originalTTL", originalTTL},
).Warn("Failed to clamp callreq TTL to max.")
// The max TTL is misconfigured, don't use it.
ttl = originalTTL
}
f.SetTTL(r.maxTimeout)
}
span := f.Span()
// The remote side of the relay doesn't need to track stats.
Expand Down
18 changes: 3 additions & 15 deletions relay_messages.go
Expand Up @@ -23,19 +23,14 @@ package tchannel
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"math"
"time"
)

var (
_callerNameKeyBytes = []byte(CallerName)
_routingDelegateKeyBytes = []byte(RoutingDelegate)
_routingKeyKeyBytes = []byte(RoutingKey)

errTTLNegative = errors.New("can't set a negative TTL")
errTTLOverflow = errors.New("TTL overflows uint32")
)

const (
Expand Down Expand Up @@ -172,17 +167,10 @@ func (f lazyCallReq) TTL() time.Duration {
return time.Duration(ttl) * time.Millisecond
}

func (f lazyCallReq) SetTTL(d time.Duration) error {
if d < 0 {
return errTTLNegative
}
millis := d / time.Millisecond
if millis > math.MaxUint32 {
return errTTLOverflow
}
ttl := uint32(millis)
// SetTTL overwrites the frame's TTL.
func (f lazyCallReq) SetTTL(d time.Duration) {
ttl := uint32(d / time.Millisecond)
binary.BigEndian.PutUint32(f.Payload[_ttlIndex:_ttlIndex+_ttlLen], ttl)
return nil
}

// Span returns the Span
Expand Down
14 changes: 1 addition & 13 deletions relay_messages_test.go
Expand Up @@ -21,14 +21,12 @@
package tchannel

import (
"math"
"testing"
"time"

"github.com/uber/tchannel-go/typed"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type testCallReq int
Expand Down Expand Up @@ -269,21 +267,11 @@ func TestLazyCallReqTTL(t *testing.T) {
func TestLazyCallReqSetTTL(t *testing.T) {
withLazyCallReqCombinations(func(crt testCallReq) {
cr := crt.req()
require.NoError(t, cr.SetTTL(time.Second), "Unexpected error setting TTL.")
cr.SetTTL(time.Second)
assert.Equal(t, time.Second, cr.TTL(), "Failed to write TTL to frame.")
})
}

func TestLazyCallReqSetInvalidTTL(t *testing.T) {
tooBig := time.Duration(math.MaxUint32+1) * time.Millisecond
withLazyCallReqCombinations(func(crt testCallReq) {
cr := crt.req()
require.Error(t, cr.SetTTL(-1*time.Second), "Expected setting a negative TTL to be an error.")
require.Error(t, cr.SetTTL(tooBig), "Expected error when setting a TTL that overflows uint32.")
assert.Equal(t, 42*time.Millisecond, cr.TTL(), "Expected erroneous SetTTL calls to leave TTL unchanged.")
})
}

func TestLazyCallResRejectsOtherFrames(t *testing.T) {
assertWrappingPanics(
t,
Expand Down
6 changes: 4 additions & 2 deletions relay_test.go
Expand Up @@ -365,19 +365,21 @@ func TestLargeTimeoutsAreClamped(t *testing.T) {
return &raw.Res{Arg2: args.Arg2, Arg3: args.Arg3}, nil
})

ctx, cancel := NewContext(longTTL)
done := make(chan struct{})
go func() {
ctx, cancel := NewContext(longTTL)
defer cancel()
_, _, _, err := raw.Call(ctx, client, ts.HostPort(), "echo-service", "echo", nil, nil)
require.Error(t, err)
code := GetSystemErrorCode(err)
assert.Equal(t, ErrCodeTimeout, code)
close(done)
}()

select {
case <-time.After(testutils.Timeout(10 * clampTTL)):
t.Fatal("Failed to clamp timeout.")
case <-ctx.Done():
case <-done:
}
})
}
Expand Down

0 comments on commit 7651f4d

Please sign in to comment.