diff --git a/connection.go b/connection.go
index d3a5e137..4d732eb3 100644
--- a/connection.go
+++ b/connection.go
@@ -882,6 +882,13 @@ func (c *Connection) checkExchanges() {
 	}
 
 	if c.readState() == connectionInboundClosed {
+		// Safety check -- this should never happen since we already did the check
+		// when transitioning to connectionInboundClosed.
+		if !c.relay.canClose() {
+			c.relay.logger.Error("Relay can't close even though state is InboundClosed.")
+			return
+		}
+
 		if c.outbound.count() == 0 && moveState(connectionInboundClosed, connectionClosed) {
 			updated = connectionClosed
 		}
diff --git a/relay.go b/relay.go
index edf5cd6a..b788308f 100644
--- a/relay.go
+++ b/relay.go
@@ -214,11 +214,23 @@ func (r *Relayer) Receive(f *Frame, fType frameType) {
 	}
 }
 
-func (r *Relayer) handleCallReq(f lazyCallReq) error {
+func (r *Relayer) canHandleNewCall() bool {
+	var canHandle bool
+	r.conn.withStateRLock(func() error {
+		canHandle = r.conn.state == connectionActive
+		if canHandle {
+			r.pending.Inc()
+		}
+		return nil
+	})
+	return canHandle
+}
+
+func (r *Relayer) getDestination(f lazyCallReq) (*Connection, bool, error) {
 	if _, ok := r.outbound.Get(f.Header.ID); ok {
 		r.logger.WithFields(LogField{"id", f.Header.ID}).Warn("received duplicate callReq")
 		// TODO: this is a protocol error, kill the connection.
-		return errors.New("callReq with already active ID")
+		return nil, false, errors.New("callReq with already active ID")
 	}
 
 	// Get the destination
@@ -226,7 +238,7 @@ func (r *Relayer) handleCallReq(f lazyCallReq) error {
 	if hostPort == "" {
 		// TODO: What is the span in the error frame actually used for, and do we need it?
 		r.conn.SendSystemError(f.Header.ID, nil, errUnknownGroup(f.Service()))
-		return nil
+		return nil, false, nil
 	}
 	peer := r.peers.GetOrAdd(hostPort)
 
@@ -239,10 +251,32 @@ func (r *Relayer) handleCallReq(f lazyCallReq) error {
 		).Warn("Failed to connect to relay host.")
 		// TODO: Same as above, do we need span here?
 		r.conn.SendSystemError(f.Header.ID, nil, NewWrappedSystemError(ErrCodeNetwork, err))
-		return nil
+		return nil, false, nil
+	}
+
+	return remoteConn, true, nil
+}
+
+func (r *Relayer) handleCallReq(f lazyCallReq) error {
+	if !r.canHandleNewCall() {
+		return ErrChannelClosed
+	}
+
+	// Get a remote connection and check whether it can handle this call.
+	remoteConn, ok, err := r.getDestination(f)
+	if err == nil && ok {
+		if !remoteConn.relay.canHandleNewCall() {
+			err = NewSystemError(ErrCodeNetwork, "selected closed connection, retry")
+		}
+	}
+	if err != nil || !ok {
+		// Failed to get a remote connection, or the connection is not in the right
+		// state to handle this call. Since we already incremented pending on
+		// the current relay, we need to decrement it.
+		r.pending.Dec()
+		return err
 	}
 
-	// TODO: Is there a race for adding the same ID twice?
 	destinationID := remoteConn.NextMessageID()
 	ttl := f.TTL()
 	remoteConn.relay.addRelayItem(false /* isOriginator */, destinationID, f.Header.ID, r, ttl)
@@ -290,7 +324,6 @@ func (r *Relayer) addRelayItem(isOriginator bool, id, remapID uint32, destinatio
 		remapID:     remapID,
 		destination: destination,
 	}
-	r.pending.Inc()
 
 	items := r.inbound
 	if isOriginator {
@@ -310,6 +343,7 @@ func (r *Relayer) timeoutRelayItem(items *relayItems, id uint32, isOriginator bo
 		r.conn.SendSystemError(id, nil, ErrTimeout)
 	}
 	r.pending.Dec()
+	r.conn.checkExchanges()
 }
 
diff --git a/relay_test.go b/relay_test.go
index 3db76e0f..5fba6465 100644
--- a/relay_test.go
+++ b/relay_test.go
@@ -6,11 +6,14 @@ import (
 	"testing"
 	"time"
 
-	"github.com/stretchr/testify/assert"
-	"github.com/stretchr/testify/require"
 	. "github.com/uber/tchannel-go"
+
+	"github.com/uber/tchannel-go/atomic"
 	"github.com/uber/tchannel-go/raw"
 	"github.com/uber/tchannel-go/testutils"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 type relayTest struct {
@@ -160,3 +163,52 @@ func TestErrorFrameEndsRelay(t *testing.T) {
 		assert.Equal(t, ErrCodeBadRequest, se.Code(), "Expected BadRequest error")
 	})
 }
+
+// Trigger a race between receiving a new call and a connection closing
+// by closing the relay while a lot of background calls are being made.
+func TestRaceCloseWithNewCall(t *testing.T) {
+	opts := serviceNameOpts("s1").SetRelayOnly().DisableLogVerification()
+	testutils.WithTestServer(t, opts, func(ts *testutils.TestServer) {
+		s1 := ts.Server()
+		s2 := ts.NewServer(serviceNameOpts("s2").DisableLogVerification())
+		testutils.RegisterEcho(s1, nil)
+
+		// Signals used to coordinate closing the relay and stopping the background callers.
+		var (
+			closeRelay  sync.WaitGroup
+			stopCalling atomic.Int32
+			callers     sync.WaitGroup
+		)
+
+		for i := 0; i < 20; i++ {
+			callers.Add(1)
+			closeRelay.Add(1)
+
+			go func() {
+				defer callers.Done()
+
+				calls := 0
+				for stopCalling.Load() == 0 {
+					testutils.CallEcho(s2, ts.HostPort(), "s1", nil)
+					calls++
+					if calls == 10 {
+						closeRelay.Done()
+					}
+				}
+			}()
+		}
+
+		closeRelay.Wait()
+
+		// Close the relay, wait for it to close.
+		ts.Relay().Close()
+		closed := testutils.WaitFor(time.Second, func() bool {
+			return ts.Relay().State() == ChannelClosed
+		})
+		assert.True(t, closed, "Relay did not close within timeout")
+
+		// Now stop all calls, and wait for the calling goroutines to end.
+		stopCalling.Inc()
+		callers.Wait()
+	})
+}