Skip to content
This repository has been archived by the owner on Dec 14, 2020. It is now read-only.

Commit

Permalink
Merge 79c3e52 into a3e580d
Browse files Browse the repository at this point in the history
  • Loading branch information
vcabbage committed Feb 14, 2018
2 parents a3e580d + 79c3e52 commit a77efe0
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 31 deletions.
23 changes: 23 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
"math"
"math/rand"
Expand All @@ -14,6 +15,16 @@ import (
"time"
)

var (
// ErrSessionClosed is propagated to Sender/Receivers
// when Session.Close() is called.
ErrSessionClosed = errors.New("amqp: session closed")

// ErrLinkClosed returned by send and receive operations when
// Sender.Close() or Receiver.Close() are called.
ErrLinkClosed = errors.New("amqp: link closed")
)

// maxSliceLen is equal to math.MaxInt32 or math.MaxInt64, depending on platform
const maxSliceLen = uint64(^uint(0) >> 1)

Expand Down Expand Up @@ -183,6 +194,9 @@ func newSession(c *conn, channel uint16) *Session {
func (s *Session) Close() error {
s.closeOnce.Do(func() { close(s.close) })
<-s.done
if s.err == ErrSessionClosed {
return nil
}
return s.err
}

Expand Down Expand Up @@ -426,6 +440,7 @@ func (s *Session) mux(remoteBegin *performBegin) {
// release session
select {
case s.conn.delSession <- s:
s.err = ErrSessionClosed
case <-s.conn.done:
s.err = s.conn.getErr()
}
Expand Down Expand Up @@ -928,6 +943,7 @@ func (l *link) mux() {
return
}
case <-l.close:
l.err = ErrLinkClosed
return
case <-l.session.done:
l.err = l.session.err
Expand All @@ -939,6 +955,8 @@ func (l *link) mux() {
l.linkCredit = l.receiver.maxCredit
}

// TODO: Look into avoiding the select statement duplication.

select {
// send data
case tr := <-outgoingTransfers:
Expand All @@ -954,6 +972,7 @@ func (l *link) mux() {
return
}
case <-l.close:
l.err = ErrLinkClosed
return
case <-l.session.done:
l.err = l.session.err
Expand All @@ -969,6 +988,7 @@ func (l *link) mux() {
return
}
case <-l.close:
l.err = ErrLinkClosed
return
case <-l.session.done:
l.err = l.session.err
Expand All @@ -983,6 +1003,9 @@ func (l *link) mux() {
func (l *link) Close() error {
l.closeOnce.Do(func() { close(l.close) })
<-l.done
if l.err == ErrLinkClosed {
return nil
}
return l.err
}

Expand Down
62 changes: 36 additions & 26 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package amqp
import (
"bytes"
"crypto/tls"
"errors"
"io"
"math"
"net"
Expand All @@ -19,7 +20,11 @@ const (

// Errors
var (
ErrTimeout = errorNew("timeout waiting for response")
ErrTimeout = errors.New("amqp: timeout waiting for response")

// ErrConnClosed is propagated to Session and Senders/Receivers
// when Client.Close() is called.
ErrConnClosed = errors.New("amqp: connection closed")
)

// ConnOption is an function for configuring an AMQP connection.
Expand Down Expand Up @@ -153,16 +158,16 @@ type conn struct {
peerMaxFrameSize uint32 // maximum frame size peer will accept

// conn state
errMu sync.Mutex // mux holds errMu from start until shutdown completes; operations are sequential before mux is started
err error // error to be returned to client
doneOnce sync.Once // only close done once
done chan struct{} // indicates the connection is done
closeOnce sync.Once
errMu sync.Mutex // mux holds errMu from start until shutdown completes; operations are sequential before mux is started
err error // error to be returned to client
done chan struct{} // indicates the connection is done

// mux
newSession chan newSessionResp // new Sessions are requested from mux by reading off this channel
delSession chan *Session // session completion is indicated to mux by sending the Session on this channel
connErr chan error // connReader/Writer notifications of an error
newSession chan newSessionResp // new Sessions are requested from mux by reading off this channel
delSession chan *Session // session completion is indicated to mux by sending the Session on this channel
connErr chan error // connReader/Writer notifications of an error
closeMux chan struct{} // indicates that the mux should stop
closeMuxOnce sync.Once

// connReader
rxProto chan protoHeader // protoHeaders received by connReader
Expand Down Expand Up @@ -190,6 +195,7 @@ func newConn(netConn net.Conn, opts ...ConnOption) (*conn, error) {
idleTimeout: DefaultIdleTimeout,
done: make(chan struct{}),
connErr: make(chan error, 2), // buffered to ensure connReader/Writer won't leak
closeMux: make(chan struct{}),
rxProto: make(chan protoHeader),
rxFrame: make(chan frame),
rxDone: make(chan struct{}),
Expand Down Expand Up @@ -245,36 +251,40 @@ func (c *conn) start() error {
}

func (c *conn) Close() error {
c.closeOnce.Do(func() { c.close() })
return c.err
c.closeMuxOnce.Do(func() { close(c.closeMux) })
err := c.getErr()
if err == ErrConnClosed {
return nil
}
return err
}

// close should only be called by conn.mux.
func (c *conn) close() {
c.closeDone() // notify goroutines and blocked functions to exit

// Client.mux holds err lock until shutdown, block until
// shutdown completes, then return the error (if any)
c.errMu.Lock()
defer c.errMu.Unlock()
close(c.done) // notify goroutines and blocked functions to exit

// wait for writing to stop, allows it to send the final close frame
<-c.txDone

err := c.net.Close()
if c.err == nil {
switch {
// conn.err already set
case c.err != nil:

// conn.err not set and c.net.Close() returned a non-nil error
case err != nil:
c.err = err

// no errors
default:
c.err = ErrConnClosed
}

// check rxDone after closing net, otherwise may block
// for up to c.idleTimeout
<-c.rxDone
}

// closeDone closes Client.done if it has not already been closed
func (c *conn) closeDone() {
c.doneOnce.Do(func() { close(c.done) })
}

// getErr returns conn.err.
//
// Must only be called after conn.done is closed.
Expand All @@ -299,11 +309,11 @@ func (c *conn) mux() {
// hold the errMu lock until error or done
c.errMu.Lock()
defer c.errMu.Unlock()
defer c.close() // defer order is important. c.errMu unlock indicates that connection is finally complete

for {
// check if last loop returned an error
if c.err != nil {
c.closeDone()
return
}

Expand Down Expand Up @@ -340,7 +350,7 @@ func (c *conn) mux() {

select {
case session.rx <- fr:
case <-c.done:
case <-c.closeMux:
return
}

Expand Down Expand Up @@ -375,7 +385,7 @@ func (c *conn) mux() {
delete(sessionsByRemoteChannel, s.remoteChannel)

// connection is complete
case <-c.done:
case <-c.closeMux:
return
}
}
Expand Down
133 changes: 128 additions & 5 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/arm/servicebus"
"github.com/Azure/azure-sdk-for-go/services/servicebus/mgmt/2017-04-01/servicebus"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
Expand Down Expand Up @@ -182,7 +182,7 @@ func TestIntegrationRoundTrip(t *testing.T) {
// Wait for Azure to update stats
time.Sleep(1 * time.Second)

q, err := queuesClient.Get(resourceGroup, namespace, queueName)
q, err := queuesClient.Get(context.Background(), resourceGroup, namespace, queueName)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -257,7 +257,7 @@ func TestIntegrationSend(t *testing.T) {
// Wait for Azure to update stats
time.Sleep(1 * time.Second)

q, err := queuesClient.Get(resourceGroup, namespace, queueName)
q, err := queuesClient.Get(context.Background(), resourceGroup, namespace, queueName)
if err != nil {
t.Fatal(err)
}
Expand All @@ -273,6 +273,129 @@ func TestIntegrationSend(t *testing.T) {
}
}

func TestIntegrationClose(t *testing.T) {
queueName, _, cleanup := newTestQueue(t, "close")
defer cleanup()

label := "link"
t.Run(label, func(t *testing.T) {
checkLeaks := leaktest.CheckTimeout(t, 60*time.Second)

// Create client
client := newClient(t, label)
defer client.Close()

// Open a session
session, err := client.NewSession()
if err != nil {
t.Fatal(err)
}

// Create a sender
receiver, err := session.NewReceiver(
amqp.LinkTargetAddress(queueName),
)
if err != nil {
t.Fatal(err)
}

go func() {
err := receiver.Close()
if err != nil {
t.Fatalf("Expected nil error from receiver.Close(), got: %+v", err)
}
}()

_, err = receiver.Receive(context.Background())
if err != amqp.ErrLinkClosed {
t.Fatalf("Expected ErrLinkClosed from receiver.Receiver, got: %+v", err)
return
}

client.Close() // close before leak check

checkLeaks()
})

label = "session"
t.Run(label, func(t *testing.T) {
checkLeaks := leaktest.CheckTimeout(t, 60*time.Second)

// Create client
client := newClient(t, label)
defer client.Close()

// Open a session
session, err := client.NewSession()
if err != nil {
t.Fatal(err)
}

// Create a sender
receiver, err := session.NewReceiver(
amqp.LinkTargetAddress(queueName),
)
if err != nil {
t.Fatal(err)
}

go func() {
err := session.Close()
if err != nil {
t.Fatalf("Expected nil error from session.Close(), got: %+v", err)
}
}()

_, err = receiver.Receive(context.Background())
if err != amqp.ErrSessionClosed {
t.Fatalf("Expected ErrSessionClosed from receiver.Receiver, got: %+v", err)
return
}

client.Close() // close before leak check

checkLeaks()
})

label = "conn"
t.Run(label, func(t *testing.T) {
checkLeaks := leaktest.CheckTimeout(t, 60*time.Second)

// Create client
client := newClient(t, label)
defer client.Close()

// Open a session
session, err := client.NewSession()
if err != nil {
t.Fatal(err)
}

// Create a sender
receiver, err := session.NewReceiver(
amqp.LinkTargetAddress(queueName),
)
if err != nil {
t.Fatal(err)
}

go func() {
err := client.Close()
if err != nil {
t.Fatalf("Expected nil error from client.Close(), got: %+v", err)
}
}()

_, err = receiver.Receive(context.Background())
if err != amqp.ErrConnClosed {
t.Fatalf("Expected ErrConnClosed from receiver.Receiver, got: %+v", err)
return
}

checkLeaks()
})
}

func dump(i interface{}) {
enc := json.NewEncoder(os.Stdout)
enc.SetIndent("", "\t")
Expand Down Expand Up @@ -349,13 +472,13 @@ func newTestQueue(tb testing.TB, suffix string) (string, servicebus.QueuesClient
queuesClient.Authorizer = autorest.NewBearerAuthorizer(token)

params := servicebus.SBQueue{}
_, err = queuesClient.CreateOrUpdate(resourceGroup, namespace, queueName, params)
_, err = queuesClient.CreateOrUpdate(context.Background(), resourceGroup, namespace, queueName, params)
if err != nil {
tb.Fatal(err)
}

cleanup := func() {
_, err = queuesClient.Delete(resourceGroup, namespace, queueName)
_, err = queuesClient.Delete(context.Background(), resourceGroup, namespace, queueName)
if err != nil {
tb.Logf("Unable to remove queue: %s - %v", queueName, err)
}
Expand Down

0 comments on commit a77efe0

Please sign in to comment.