relay: Fix race between connection close and new call
Currently, relay items are added without checking the connection state.
This can lead to new items being added while a connection is closing.

This change makes checking the connection state atomic with incrementing
the pending count. This ensures that only active connections can accept
new incoming calls.
prashantv committed May 25, 2016
1 parent 090c354 commit 00ffb0d
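
To make the fix concrete, here is a simplified, self-contained sketch of the pattern this change introduces. The types and names below (conn, canHandleNewCall, tryClose) are illustrative stand-ins, not tchannel-go's actual API: the point is that the "is the connection still active?" check and the pending-count increment happen while the connection's state lock is held, so a close cannot slip in between them.

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
)

// conn is a stand-in for a relay connection; it is not the real tchannel-go type.
type conn struct {
    stateMu sync.RWMutex
    active  bool
    pending atomic.Int64 // in-flight relayed calls
}

// canHandleNewCall mirrors the shape of the new helper in relay.go: pending is
// only incremented while the read lock proves the connection is still active.
func (c *conn) canHandleNewCall() bool {
    c.stateMu.RLock()
    defer c.stateMu.RUnlock()
    if !c.active {
        return false
    }
    c.pending.Add(1)
    return true
}

// tryClose takes the write lock, so it cannot interleave with the
// check-and-increment above: either the call was admitted first (pending > 0,
// so the close is refused) or the close won (active is false, so the call is
// rejected).
func (c *conn) tryClose() bool {
    c.stateMu.Lock()
    defer c.stateMu.Unlock()
    if c.pending.Load() > 0 {
        return false
    }
    c.active = false
    return true
}

func main() {
    c := &conn{active: true}
    fmt.Println(c.canHandleNewCall()) // true; one call is now pending
    fmt.Println(c.tryClose())         // false; a call is still in flight
}

The real change achieves the same effect with withStateRLock and the relay's atomic pending counter, as shown in the relay.go diff below.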
Showing 3 changed files with 101 additions and 8 deletions.
7 changes: 7 additions & 0 deletions connection.go
@@ -882,6 +882,13 @@ func (c *Connection) checkExchanges() {
}

if c.readState() == connectionInboundClosed {
// Safety check -- this should never happen since we already did the check
// when transitioning to connectionInboundClosed.
if !c.relay.canClose() {
c.relay.logger.Error("Relay can't close even though state is InboundClosed.")
return
}

if c.outbound.count() == 0 && moveState(connectionInboundClosed, connectionClosed) {
updated = connectionClosed
}
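
For context, canClose is pre-existing relay code that is not part of this diff. A plausible sketch of what it checks, assuming the relay may only close once its pending counter (incremented by canHandleNewCall, decremented when a relayed call completes or times out) has drained to zero:

// Hypothetical sketch -- the real canClose is not shown in this commit, and the
// body below is an assumption based on how the pending counter is used here.
func (r *Relayer) canClose() bool {
    if r == nil {
        return true
    }
    return r.pending.Load() == 0
}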
46 changes: 40 additions & 6 deletions relay.go
@@ -214,19 +214,31 @@ func (r *Relayer) Receive(f *Frame, fType frameType) {
}
}

func (r *Relayer) handleCallReq(f lazyCallReq) error {
func (r *Relayer) canHandleNewCall() bool {
var canHandle bool
r.conn.withStateRLock(func() error {
canHandle = r.conn.state == connectionActive
if canHandle {
r.pending.Inc()
}
return nil
})
return canHandle
}

func (r *Relayer) getDestination(f lazyCallReq) (*Connection, bool, error) {
if _, ok := r.outbound.Get(f.Header.ID); ok {
r.logger.WithFields(LogField{"id", f.Header.ID}).Warn("received duplicate callReq")
// TODO: this is a protocol error, kill the connection.
return errors.New("callReq with already active ID")
return nil, false, errors.New("callReq with already active ID")
}

// Get the destination
hostPort := r.hosts.Get(f)
if hostPort == "" {
// TODO: What is the span in the error frame actually used for, and do we need it?
r.conn.SendSystemError(f.Header.ID, nil, errUnknownGroup(f.Service()))
return nil
return nil, false, nil
}
peer := r.peers.GetOrAdd(hostPort)

@@ -239,10 +251,32 @@ func (r *Relayer) handleCallReq(f lazyCallReq) error {
).Warn("Failed to connect to relay host.")
// TODO: Same as above, do we need span here?
r.conn.SendSystemError(f.Header.ID, nil, NewWrappedSystemError(ErrCodeNetwork, err))
return nil
return nil, false, nil
}

return remoteConn, true, nil
}

func (r *Relayer) handleCallReq(f lazyCallReq) error {
if !r.canHandleNewCall() {
return ErrChannelClosed
}

// Get a remote connection and check whether it can handle this call.
remoteConn, ok, err := r.getDestination(f)
if err == nil && ok {
if !remoteConn.relay.canHandleNewCall() {
err = NewSystemError(ErrCodeNetwork, "selected closed connection, retry")
}
}
if err != nil || !ok {
// Failed to get a remote connection, or the connection is not in the right
// state to handle this call. Since we already incremented pending on
// the current relay, we need to decrement it.
r.pending.Dec()
return err
}

// TODO: Is there a race for adding the same ID twice?
destinationID := remoteConn.NextMessageID()
ttl := f.TTL()
remoteConn.relay.addRelayItem(false /* isOriginator */, destinationID, f.Header.ID, r, ttl)
@@ -290,7 +324,6 @@ func (r *Relayer) addRelayItem(isOriginator bool, id, remapID uint32, destinatio
remapID: remapID,
destination: destination,
}
r.pending.Inc()

items := r.inbound
if isOriginator {
@@ -310,6 +343,7 @@ func (r *Relayer) timeoutRelayItem(items *relayItems, id uint32, isOriginator bo
r.conn.SendSystemError(id, nil, ErrTimeout)
}
r.pending.Dec()

r.conn.checkExchanges()
}

56 changes: 54 additions & 2 deletions relay_test.go
@@ -6,11 +6,14 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
. "github.com/uber/tchannel-go"

"github.com/uber/tchannel-go/atomic"
"github.com/uber/tchannel-go/raw"
"github.com/uber/tchannel-go/testutils"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type relayTest struct {
@@ -160,3 +163,52 @@ func TestErrorFrameEndsRelay(t *testing.T) {
assert.Equal(t, ErrCodeBadRequest, se.Code(), "Expected BadRequest error")
})
}

// Trigger a race between receiving a new call and a connection closing
// by closing the relay while a lot of background calls are being made.
func TestRaceCloseWithNewCall(t *testing.T) {
opts := serviceNameOpts("s1").SetRelayOnly().DisableLogVerification()
testutils.WithTestServer(t, opts, func(ts *testutils.TestServer) {
s1 := ts.Server()
s2 := ts.NewServer(serviceNameOpts("s2").DisableLogVerification())
testutils.RegisterEcho(s1, nil)

// signal to start closing the relay.
var (
closeRelay sync.WaitGroup
stopCalling atomic.Int32
callers sync.WaitGroup
)

for i := 0; i < 20; i++ {
callers.Add(1)
closeRelay.Add(1)

go func() {
defer callers.Done()

calls := 0
for stopCalling.Load() == 0 {
testutils.CallEcho(s2, ts.HostPort(), "s1", nil)
calls++
if calls == 10 {
closeRelay.Done()
}
}
}()
}

closeRelay.Wait()

// Close the relay, wait for it to close.
ts.Relay().Close()
closed := testutils.WaitFor(time.Second, func() bool {
return ts.Relay().State() == ChannelClosed
})
assert.True(t, closed, "Relay did not close within timeout")

// Now stop all calls, and wait for the calling goroutine to end.
stopCalling.Inc()
callers.Wait()
})
}
