From dadf555aea98a195949123a6276226b477c10e59 Mon Sep 17 00:00:00 2001 From: Alexander Valinurov Date: Tue, 23 Oct 2018 23:59:41 +0300 Subject: [PATCH] Proper close connections with context and wait group pattern Proper stop queue - should leave all storage-working gorutines Proper stop consumer - using locks for watch consumer status Fix issue with queue's call consumers Add more tests for server, queue and qos system --- consumer/consumer.go | 32 ++++-- interfaces/interfaces.go | 2 +- qos/qos.go | 1 + qos/qos_test.go | 16 +++ queue/queue.go | 17 +++- queue/queue_consumer_test.go | 9 +- queue/queue_storage_test.go | 2 +- queue/queue_test.go | 188 ++++++++++++++++++++++++++++++++++- server/channel.go | 67 +++++++------ server/connection.go | 146 +++++++++++++++------------ server/connectionMethods.go | 8 +- server/server.go | 11 +- server/server_test.go | 25 +++++ 13 files changed, 404 insertions(+), 120 deletions(-) diff --git a/consumer/consumer.go b/consumer/consumer.go index bed2e5b..77db9ac 100644 --- a/consumer/consumer.go +++ b/consumer/consumer.go @@ -28,10 +28,10 @@ type Consumer struct { noAck bool channel interfaces.Channel queue *queue.Queue + statusLock sync.RWMutex status int qos []*qos.AmqpQos consume chan bool - stopLock sync.RWMutex } // NewConsumer returns new instance of Consumer @@ -73,8 +73,8 @@ func (consumer *Consumer) startConsume() { func (consumer *Consumer) retrieveAndSendMessage() { var message *amqp.Message - consumer.stopLock.RLock() - defer consumer.stopLock.RUnlock() + consumer.statusLock.RLock() + defer consumer.statusLock.RUnlock() if consumer.status == stopped { return } @@ -117,39 +117,55 @@ func (consumer *Consumer) retrieveAndSendMessage() { consumer.queue.GetMetrics().Deliver.Counter.Inc(1) consumer.queue.GetMetrics().ServerDeliver.Counter.Inc(1) - consumer.Consume() + consumer.consumeMsg() + + return } // Pause pause consumer, used by channel.flow change func (consumer *Consumer) Pause() { + consumer.statusLock.Lock() + defer consumer.statusLock.Unlock() consumer.status = paused } // UnPause unpause consumer, used by channel.flow change func (consumer *Consumer) UnPause() { + consumer.statusLock.Lock() + defer consumer.statusLock.Unlock() consumer.status = started } // Consume send signal into consumer channel, than consumer can try to pop message from queue -func (consumer *Consumer) Consume() { +func (consumer *Consumer) Consume() bool { + consumer.statusLock.RLock() + defer consumer.statusLock.RUnlock() + + return consumer.consumeMsg() +} + +func (consumer *Consumer) consumeMsg() bool { if consumer.status == stopped || consumer.status == paused { - return + return false } select { case consumer.consume <- true: + return true default: + return false } } // Stop stops consumer and remove it from queue consumers list func (consumer *Consumer) Stop() { - consumer.stopLock.Lock() - defer consumer.stopLock.Unlock() + consumer.statusLock.Lock() if consumer.status == stopped { + consumer.statusLock.Unlock() return } consumer.status = stopped + consumer.statusLock.Unlock() consumer.queue.RemoveConsumer(consumer.ConsumerTag) close(consumer.consume) } diff --git a/interfaces/interfaces.go b/interfaces/interfaces.go index 0ee4396..262f68a 100644 --- a/interfaces/interfaces.go +++ b/interfaces/interfaces.go @@ -14,7 +14,7 @@ type Channel interface { // Consumer represents base consumer public interface type Consumer interface { - Consume() + Consume() bool Tag() string Cancel() } diff --git a/qos/qos.go b/qos/qos.go index afd2cc6..9b87b95 100644 --- a/qos/qos.go +++ b/qos/qos.go @@ -88,6 +88,7 @@ func (qos *AmqpQos) Release() { qos.currentSize = 0 } +// Copy safe copy current qos instance to new one func (qos *AmqpQos) Copy() *AmqpQos { qos.Lock() defer qos.Unlock() diff --git a/qos/qos_test.go b/qos/qos_test.go index 411984d..3ec263d 100644 --- a/qos/qos_test.go +++ b/qos/qos_test.go @@ -95,3 +95,19 @@ func TestAmqpQos_Release(t *testing.T) { t.Fatalf("Release: Expected currentSize %d, actual %d", 0, q.currentCount) } } + +func TestAmqpQos_Copy(t *testing.T) { + q := NewAmqpQos(5, 10) + q.Inc(1, 6) + + q2 := q.Copy() + q.Release() + + if q2.currentCount != 1 { + t.Fatalf("Expected currentCount %d, actual %d", 0, q.currentCount) + } + + if q2.currentSize != 6 { + t.Fatalf("Expected currentSize %d, actual %d", 0, q.currentCount) + } +} diff --git a/queue/queue.go b/queue/queue.go index 3a6b77d..398044e 100644 --- a/queue/queue.go +++ b/queue/queue.go @@ -63,6 +63,7 @@ type Queue struct { lastMemMsgId uint64 swappedToDisk bool maybeLoadFromStorageCh chan bool + wg *sync.WaitGroup } // NewQueue returns new instance of Queue @@ -85,6 +86,7 @@ func NewQueue(name string, connID uint64, exclusive bool, autoDelete bool, durab currentConsumer: 0, autoDeleteQueue: autoDeleteQueue, swappedToDisk: false, + wg: &sync.WaitGroup{}, metrics: &MetricsState{ Ready: metrics.NewTrackCounter(0, true), Unacked: metrics.NewTrackCounter(0, true), @@ -110,7 +112,9 @@ func (queue *Queue) Start() { defer queue.actLock.Unlock() queue.active = true + queue.wg.Add(1) go func() { + defer queue.wg.Done() for range queue.call { func() { queue.cmrLock.RLock() @@ -122,13 +126,17 @@ func (queue *Queue) Start() { } queue.currentConsumer = (queue.currentConsumer + 1) % cmrCount cmr := queue.consumers[queue.currentConsumer] - cmr.Consume() + if cmr.Consume() { + return + } } }() } }() + queue.wg.Add(1) go func() { + defer queue.wg.Done() for range queue.maybeLoadFromStorageCh { queue.mayBeLoadFromStorage() } @@ -142,6 +150,9 @@ func (queue *Queue) Stop() error { defer queue.actLock.Unlock() queue.active = false + close(queue.maybeLoadFromStorageCh) + close(queue.call) + queue.wg.Wait() return nil } @@ -468,10 +479,6 @@ func (queue *Queue) Delete(ifUnused bool, ifEmpty bool) (uint64, error) { queue.metrics.ServerTotal.Counter.Dec(int64(length)) queue.metrics.ServerReady.Counter.Dec(int64(length)) - // TODO Proper close channels - //close(queue.maybeLoadFromStorageCh) - //close(queue.call) - return length, nil } diff --git a/queue/queue_consumer_test.go b/queue/queue_consumer_test.go index d2fd79d..7e396d4 100644 --- a/queue/queue_consumer_test.go +++ b/queue/queue_consumer_test.go @@ -6,12 +6,13 @@ import ( // ConsumerMock implements AMQP consumer mock type ConsumerMock struct { - tag string + tag string + cancel bool } // Consume send signal into consumer channel, than consumer can try to pop message from queue -func (consumer *ConsumerMock) Consume() { - +func (consumer *ConsumerMock) Consume() bool { + return true } // Stop stops consumer and remove it from queue consumers list @@ -21,7 +22,7 @@ func (consumer *ConsumerMock) Stop() { // Cancel stops consumer and send basic.cancel method to the client func (consumer *ConsumerMock) Cancel() { - + consumer.cancel = true } // Tag returns consumer tag diff --git a/queue/queue_storage_test.go b/queue/queue_storage_test.go index 3d52581..011825b 100644 --- a/queue/queue_storage_test.go +++ b/queue/queue_storage_test.go @@ -54,7 +54,7 @@ func (storage *MsgStorageMock) PurgeQueue(queue string) { } func (storage *MsgStorageMock) GetQueueLength(queue string) uint64 { - return 0 + return uint64(len(storage.messages)) } func (storage *MsgStorageMock) IterateByQueueFromMsgID(queue string, msgId uint64, limit uint64, fn func(message *amqp.Message)) uint64 { diff --git a/queue/queue_test.go b/queue/queue_test.go index 006508e..345f2de 100644 --- a/queue/queue_test.go +++ b/queue/queue_test.go @@ -46,7 +46,7 @@ func TestQueue_Property(t *testing.T) { } } -func TestQueue_PushPop(t *testing.T) { +func TestQueue_PushPop_Inactive(t *testing.T) { queue := NewQueue("test", 0, false, false, false, baseConfig, nil, nil, nil) queue.Start() queueLength := SIZE * 8 @@ -54,12 +54,21 @@ func TestQueue_PushPop(t *testing.T) { message := &amqp.Message{ID: uint64(item + 1)} queue.Push(message) } - queue.Stop() + if queue.Pop() != nil { t.Fatal("Expected nil from non-active queue") } +} + +func TestQueue_PushPop(t *testing.T) { + queue := NewQueue("test", 0, false, false, false, baseConfig, nil, nil, nil) queue.Start() + queueLength := SIZE * 8 + for item := 0; item < queueLength; item++ { + message := &amqp.Message{ID: uint64(item + 1)} + queue.Push(message) + } if queue.Length() != uint64(queueLength) { t.Fatalf("expected %d elements, have %d", queueLength, queue.Length()) @@ -136,7 +145,7 @@ func TestQueue_PopQos_Empty(t *testing.T) { } } -func TestQueue_PopQos_Single(t *testing.T) { +func TestQueue_PopQos_Single_Inactive(t *testing.T) { prefetchCount := 10 qosRule := qos.NewAmqpQos(uint16(prefetchCount), 0) @@ -153,8 +162,21 @@ func TestQueue_PopQos_Single(t *testing.T) { if queue.PopQos([]*qos.AmqpQos{qosRule}) != nil { t.Fatal("Expected nil from non-active queue") } +} + +func TestQueue_PopQos_Single(t *testing.T) { + prefetchCount := 10 + qosRule := qos.NewAmqpQos(uint16(prefetchCount), 0) + + queue := NewQueue("test", 0, false, false, false, baseConfig, nil, nil, nil) queue.Start() + queueLength := SIZE * 8 + for item := 0; item < queueLength; item++ { + message := &amqp.Message{ID: uint64(item)} + queue.Push(message) + } + rcvCount := 0 for item := 0; item < queueLength; item++ { message := queue.PopQos([]*qos.AmqpQos{qosRule}) @@ -168,7 +190,7 @@ func TestQueue_PopQos_Single(t *testing.T) { } } -func TestQueue_PopQos_Multiple(t *testing.T) { +func TestQueue_PopQos_Multiple_Inactive(t *testing.T) { prefetchCount := 28 qosRules := []*qos.AmqpQos{ qos.NewAmqpQos(0, 0), @@ -188,7 +210,23 @@ func TestQueue_PopQos_Multiple(t *testing.T) { if queue.PopQos(qosRules) != nil { t.Fatal("Expected nil from non-active queue") } +} + +func TestQueue_PopQos_Multiple(t *testing.T) { + prefetchCount := 28 + qosRules := []*qos.AmqpQos{ + qos.NewAmqpQos(0, 0), + qos.NewAmqpQos(uint16(prefetchCount), 0), + qos.NewAmqpQos(uint16(prefetchCount*2), 0), + } + + queue := NewQueue("test", 0, false, false, false, baseConfig, nil, nil, nil) queue.Start() + queueLength := SIZE * 8 + for item := 0; item < queueLength; item++ { + message := &amqp.Message{ID: uint64(item)} + queue.Push(message) + } rcvCount := 0 for item := 0; item < queueLength; item++ { @@ -245,6 +283,27 @@ func TestQueue_AddConsumer(t *testing.T) { } } +func TestQueue_AddConsumer_Exclusive(t *testing.T) { + queue := NewQueue("test", 0, false, false, false, baseConfig, nil, nil, nil) + queue.Start() + + if err := queue.AddConsumer(&ConsumerMock{}, true); err != nil { + t.Fatal(err) + } + + if queue.wasConsumed == false { + t.Fatalf("Expected wasConsumed true") + } + + if queue.ConsumersCount() != 1 { + t.Fatalf("Expected %d consumers, actual %d", 1, queue.ConsumersCount()) + } + + if queue.AddConsumer(&ConsumerMock{}, false) == nil { + t.Fatalf("Expected error, queue is busy") + } +} + func TestQueue_RemoveConsumer(t *testing.T) { queue := NewQueue("test", 0, false, false, false, baseConfig, nil, nil, nil) queue.Start() @@ -568,3 +627,124 @@ func TestQueue_LoadFromStorage_Swap(t *testing.T) { t.Fatalf("Expected %d messages from queue, actual %d", count, popCount) } } + +func TestQueue_LoadFromMsgStorage_LessMaxMessages(t *testing.T) { + var baseConfig = config.Queue{ShardSize: SIZE, MaxMessagesInRam: 1000} + count := baseConfig.MaxMessagesInRam / 5 + + storagePersisted := NewStorageMock(int(count)) + storageTransient := NewStorageMock(int(count)) + queue := NewQueue("test", 0, false, false, true, baseConfig, storagePersisted, storageTransient, nil) + + var dMode byte = 2 + + var idx uint64 + for i := 0; i < int(count); i++ { + idx++ + // persisted + messageP := &amqp.Message{ + ID: idx, + Header: &amqp.ContentHeader{ + PropertyList: &amqp.BasicPropertyList{ + DeliveryMode: &dMode, + }, + }, + } + idx++ + // transient + messageT := &amqp.Message{ + ID: idx, + Header: &amqp.ContentHeader{ + PropertyList: &amqp.BasicPropertyList{}, + }, + } + storagePersisted.Add(messageP, "test") + storageTransient.Add(messageT, "test") + } + queue.LoadFromMsgStorage() + + if queue.Length() != count { + t.Fatalf("Expected %d messages into the queue, actual %d", count, queue.Length()) + } +} + +func TestQueue_LoadFromMsgStorage_OverMaxMessages(t *testing.T) { + var baseConfig = config.Queue{ShardSize: SIZE, MaxMessagesInRam: 10} + count := baseConfig.MaxMessagesInRam * 5 + + storagePersisted := NewStorageMock(int(count)) + storageTransient := NewStorageMock(int(count)) + queue := NewQueue("test", 0, false, false, true, baseConfig, storagePersisted, storageTransient, nil) + + var dMode byte = 2 + + var idx uint64 + for i := 0; i < int(count); i++ { + idx++ + // persisted + messageP := &amqp.Message{ + ID: idx, + Header: &amqp.ContentHeader{ + PropertyList: &amqp.BasicPropertyList{ + DeliveryMode: &dMode, + }, + }, + } + idx++ + // transient + messageT := &amqp.Message{ + ID: idx, + Header: &amqp.ContentHeader{ + PropertyList: &amqp.BasicPropertyList{}, + }, + } + storagePersisted.Add(messageP, "test") + storageTransient.Add(messageT, "test") + } + queue.LoadFromMsgStorage() + + if queue.Length() != count { + t.Fatalf("Expected %d messages into the queue, actual %d", count, queue.Length()) + } +} + +func TestQueue_AutoDelete(t *testing.T) { + autoDeleteCh := make(chan string, 1) + + queue := NewQueue("test", 0, false, true, false, baseConfig, nil, nil, autoDeleteCh) + queue.Start() + + cmr := &ConsumerMock{} + if err := queue.AddConsumer(cmr, false); err != nil { + t.Fatal(err) + } + + queue.RemoveConsumer(cmr.Tag()) + + tick := time.After(100 * time.Millisecond) + + select { + case <-tick: + t.Fatalf("Expected message to remove queue") + case q := <-autoDeleteCh: + if q != queue.GetName() { + t.Fatalf("Expected %s, actual %s", queue.GetName(), q) + } + } +} + +func TestQueue_CancelConsumers(t *testing.T) { + queue := NewQueue("test", 0, false, false, false, baseConfig, nil, nil, nil) + queue.Start() + + cmr := &ConsumerMock{} + if err := queue.AddConsumer(cmr, false); err != nil { + t.Fatal(err) + } + + queue.Delete(false, false) + + if !cmr.cancel { + t.Fatalf("Expected call consumer.Cancel()") + } +} diff --git a/server/channel.go b/server/channel.go index e866ed4..c585b87 100644 --- a/server/channel.go +++ b/server/channel.go @@ -122,27 +122,38 @@ func (channel *Channel) start() { func (channel *Channel) handleIncoming() { buffer := bytes.NewReader([]byte{}) - for frame := range channel.incoming { - switch frame.Type { - case amqp.FrameMethod: - buffer.Reset(frame.Payload) - method, err := amqp.ReadMethod(buffer, channel.protoVersion) - channel.logger.Debug("Incoming method <- " + method.Name()) - if err != nil { - channel.logger.WithError(err).Error("Error on handling frame") - channel.sendError(amqp.NewConnectionError(amqp.FrameError, err.Error(), 0, 0)) - } - if err := channel.handleMethod(method); err != nil { - channel.sendError(err) - } - case amqp.FrameHeader: - if err := channel.handleContentHeader(frame); err != nil { - channel.sendError(err) + for { + select { + case <-channel.conn.ctx.Done(): + return + case frame := <-channel.incoming: + if frame == nil { + // channel.incoming closed by connection + return } - case amqp.FrameBody: - if err := channel.handleContentBody(frame); err != nil { - channel.sendError(err) + + switch frame.Type { + case amqp.FrameMethod: + buffer.Reset(frame.Payload) + method, err := amqp.ReadMethod(buffer, channel.protoVersion) + channel.logger.Debug("Incoming method <- " + method.Name()) + if err != nil { + channel.logger.WithError(err).Error("Error on handling frame") + channel.sendError(amqp.NewConnectionError(amqp.FrameError, err.Error(), 0, 0)) + } + + if err := channel.handleMethod(method); err != nil { + channel.sendError(err) + } + case amqp.FrameHeader: + if err := channel.handleContentHeader(frame); err != nil { + channel.sendError(err) + } + case amqp.FrameBody: + if err := channel.handleContentBody(frame); err != nil { + channel.sendError(err) + } } } } @@ -160,8 +171,9 @@ func (channel *Channel) sendError(err *amqp.Error) { MethodId: err.MethodID, }) case amqp.ErrorOnConnection: - if channel, ok := channel.conn.channels[0]; ok { - channel.SendMethod(&amqp.ConnectionClose{ + ch := channel.conn.getChannel(0) + if ch != nil { + ch.SendMethod(&amqp.ConnectionClose{ ReplyCode: err.ReplyCode, ReplyText: err.ReplyText, ClassId: err.ClassID, @@ -301,17 +313,14 @@ func (channel *Channel) SendMethod(method amqp.Method) { } func (channel *Channel) sendOutgoing(frame *amqp.Frame) { - defer func() { - if recover() != nil { - // it is possible to send close channel here, cause outgoing channel can be closed by conn.close - // looks like as bad design of frames flow, but at the moment is better to fix goroutine leaks + select { + case <-channel.conn.ctx.Done(): + if channel.id == 0 { + close(channel.outgoing) } - }() - - if channel.status == channelDelete { return + case channel.outgoing <- frame: } - channel.outgoing <- frame } // SendContent send message to consumers or returns to publishers diff --git a/server/connection.go b/server/connection.go index 3ee8ab2..90e1114 100644 --- a/server/connection.go +++ b/server/connection.go @@ -3,9 +3,11 @@ package server import ( "bufio" "bytes" + "context" "fmt" "net" "sort" + "strings" "sync" "sync/atomic" "time" @@ -27,7 +29,6 @@ const ( ConnOpen ConnOpenOK ConnCloseOK - ConnClosing ConnClosed ) @@ -52,6 +53,7 @@ type Connection struct { server *Server netConn *net.TCPConn logger *log.Entry + channelsLock sync.RWMutex channels map[uint16]*Channel outgoing chan *amqp.Frame clientProperties *amqp.Table @@ -66,6 +68,10 @@ type Connection struct { srvMetrics *SrvMetricsState metrics *ConnMetricsState userName string + + wg *sync.WaitGroup + ctx context.Context + cancelCtx context.CancelFunc } // NewConnection returns new instance of amqp Connection @@ -79,8 +85,9 @@ func NewConnection(server *Server, netConn *net.TCPConn) (connection *Connection maxChannels: server.config.Connection.ChannelsMax, maxFrameSize: server.config.Connection.FrameMaxSize, qos: qos.NewAmqpQos(0, 0), - closeCh: make(chan bool, 1), + closeCh: make(chan bool, 2), srvMetrics: server.metrics, + wg: &sync.WaitGroup{}, } connection.logger = log.WithFields(log.Fields{ @@ -101,16 +108,22 @@ func (conn *Connection) initMetrics() { func (conn *Connection) close() { conn.statusLock.Lock() - defer conn.statusLock.Unlock() if conn.status == ConnClosed { + conn.statusLock.Unlock() return } conn.status = ConnClosed + conn.statusLock.Unlock() + + conn.netConn.SetLinger(0) conn.netConn.Close() - close(conn.outgoing) + + conn.cancelCtx() + conn.wg.Wait() // channel0 should we be closed at the end channelIds := make([]int, 0) + conn.channelsLock.Lock() for chID := range conn.channels { channelIds = append(channelIds, int(chID)) } @@ -118,9 +131,10 @@ func (conn *Connection) close() { for _, chID := range channelIds { channel := conn.channels[uint16(chID)] channel.delete() + delete(conn.channels, uint16(chID)) } + conn.channelsLock.Unlock() conn.clearQueues() - //conn.netConn.Close() conn.logger.WithFields(log.Fields{ "vhost": conn.vhostName, @@ -129,13 +143,6 @@ func (conn *Connection) close() { conn.server.removeConnection(conn.id) conn.closeCh <- true - - // now we close incoming channel - for _, chID := range channelIds { - channel := conn.channels[uint16(chID)] - delete(conn.channels, uint16(chID)) - close(channel.incoming) - } } func (conn *Connection) getChannel(id uint16) *Channel { @@ -181,22 +188,6 @@ func (conn *Connection) clearQueues() { } } -func (conn *Connection) setStatus(status int) { - conn.statusLock.Lock() - defer conn.statusLock.Unlock() - - if conn.status == ConnClosed { - return - } - conn.status = status -} - -func (conn *Connection) getStatus() int { - conn.statusLock.RLock() - defer conn.statusLock.RUnlock() - return conn.status -} - func (conn *Connection) handleConnection() { buf := make([]byte, 8) _, err := conn.netConn.Read(buf) @@ -219,39 +210,52 @@ func (conn *Connection) handleConnection() { return } - conn.channels[0] = NewChannel(0, conn) - conn.channels[0].start() + conn.ctx, conn.cancelCtx = context.WithCancel(context.Background()) + + channel := NewChannel(0, conn) + conn.channelsLock.Lock() + conn.channels[channel.id] = channel + conn.channelsLock.Unlock() + + channel.start() + conn.wg.Add(1) go conn.handleOutgoing() + conn.wg.Add(1) go conn.handleIncoming() } func (conn *Connection) handleOutgoing() { - buffer := bufio.NewWriter(conn.netConn) - for frame := range conn.outgoing { - if conn.getStatus() >= ConnClosing { - continue - } + defer func() { + conn.wg.Done() + conn.close() + }() - if err := amqp.WriteFrame(buffer, frame); err != nil { - conn.logger.WithError(err).Warn("writing frame") - conn.setStatus(ConnClosing) - conn.close() - continue - } + buffer := bufio.NewWriter(conn.netConn) + for { + select { + case <-conn.ctx.Done(): + return + case frame := <-conn.outgoing: + if frame == nil { + return + } + if err := amqp.WriteFrame(buffer, frame); err != nil && !conn.isClosedError(err) { + conn.logger.WithError(err).Warn("writing frame") + return + } - if frame.CloseAfter { - conn.setStatus(ConnClosing) - buffer.Flush() - conn.close() - continue - } + if frame.CloseAfter { + buffer.Flush() + return + } - if frame.Sync { - conn.srvMetrics.TrafficOut.Counter.Inc(int64(buffer.Buffered())) - conn.metrics.TrafficOut.Counter.Inc(int64(buffer.Buffered())) - buffer.Flush() - } else { - conn.mayBeFlushBuffer(buffer) + if frame.Sync { + conn.srvMetrics.TrafficOut.Counter.Inc(int64(buffer.Buffered())) + conn.metrics.TrafficOut.Counter.Inc(int64(buffer.Buffered())) + buffer.Flush() + } else { + conn.mayBeFlushBuffer(buffer) + } } } } @@ -273,41 +277,57 @@ func (conn *Connection) mayBeFlushBuffer(buffer *bufio.Writer) { } func (conn *Connection) handleIncoming() { + defer func() { + conn.wg.Done() + conn.close() + }() + buffer := bufio.NewReader(conn.netConn) for { - if conn.getStatus() >= ConnClosing { - return - } - frame, err := amqp.ReadFrame(buffer) if err != nil { - if err.Error() != "EOF" && conn.getStatus() < ConnClosing { + if err.Error() != "EOF" && !conn.isClosedError(err) { conn.logger.WithError(err).Warn("reading frame") } - conn.close() return } - if frame.ChannelID != 0 && conn.getStatus() < ConnOpen { + if conn.status < ConnOpen && frame.ChannelID != 0 { conn.logger.WithError(err).Error("Frame not allowed for unopened connection") - conn.close() return } conn.srvMetrics.TrafficIn.Counter.Inc(int64(len(frame.Payload))) conn.metrics.TrafficIn.Counter.Inc(int64(len(frame.Payload))) + conn.channelsLock.RLock() channel, ok := conn.channels[frame.ChannelID] + conn.channelsLock.RUnlock() + if !ok { channel = NewChannel(frame.ChannelID, conn) + + conn.channelsLock.Lock() conn.channels[frame.ChannelID] = channel - conn.channels[frame.ChannelID].start() + conn.channelsLock.Unlock() + + channel.start() } - channel.incoming <- frame + select { + case <-conn.ctx.Done(): + close(channel.incoming) + return + case channel.incoming <- frame: + } } } +func (conn *Connection) isClosedError(err error) bool { + // See: https://github.com/golang/go/issues/4373 + return err != nil && strings.Contains(err.Error(), "use of closed network connection") +} + func (conn *Connection) GetVirtualHost() *VirtualHost { return conn.virtualHost } diff --git a/server/connectionMethods.go b/server/connectionMethods.go index 879436d..fa1f5cc 100644 --- a/server/connectionMethods.go +++ b/server/connectionMethods.go @@ -92,9 +92,9 @@ func (channel *Channel) connectionTuneOk(method *amqp.ConnectionTuneOk) *amqp.Er channel.conn.maxChannels = method.ChannelMax channel.conn.maxFrameSize = method.FrameMax - if method.Heartbeat > 0 { - channel.conn.close() - } + //if method.Heartbeat > 0 { + // channel.conn.close() + //} return nil } @@ -121,6 +121,6 @@ func (channel *Channel) connectionClose(method *amqp.ConnectionClose) *amqp.Erro } func (channel *Channel) connectionCloseOk(method *amqp.ConnectionCloseOk) *amqp.Error { - channel.conn.close() + go channel.conn.close() return nil } diff --git a/server/server.go b/server/server.go index d76e1e8..25f191a 100644 --- a/server/server.go +++ b/server/server.go @@ -73,7 +73,7 @@ func NewServer(host string, port string, protoVersion string, config *config.Con config: config, users: make(map[string]string), vhosts: make(map[string]*VirtualHost), - connSeq: 1, + connSeq: 0, } server.initMetrics() @@ -135,6 +135,7 @@ func (srv *Server) Stop() { go conn.safeClose(&wg) } wg.Wait() + log.Info("All connections safe closed") // stop exchanges and queues for _, virtualHost := range srv.vhosts { @@ -321,6 +322,14 @@ func (srv *Server) onSignal(sig os.Signal) { } } +// Special method for calling in tests without os.Exit(0) +func (srv *Server) testOnSignal(sig os.Signal) { + switch sig { + case syscall.SIGTERM, syscall.SIGINT: + srv.Stop() + } +} + func (srv *Server) hookSignals() { c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) diff --git a/server/server_test.go b/server/server_test.go index 6814be7..55d0d70 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -191,3 +191,28 @@ func TestServer_GetVhosts(t *testing.T) { } } } + +func TestServer_RealStart(t *testing.T) { + defer (&ServerClient{}).clean() + cfg := getDefaultTestConfig() + metrics.NewTrackRegistry(15, time.Second, true) + server := NewServer("localhost", "55672", proto, &cfg.srvConfig) + go server.Start() + time.Sleep(2 * time.Second) + defer server.Stop() + + conn, err := amqpclient.Dial("amqp://guest:guest@localhost:55672/") + if err != nil { + t.Fatal("Could not connect to real server", err) + return + } + + if len(server.connections) == 0 { + t.Fatal("Expected connected client") + } + + if len(server.connections[1].channels) == 0 { + t.Fatal("Expected channels on connections") + } + defer conn.Close() +}