diff --git a/logon_state_test.go b/logon_state_test.go index 96afc2e90..ee47a2be6 100644 --- a/logon_state_test.go +++ b/logon_state_test.go @@ -333,7 +333,7 @@ func (s *LogonStateTestSuite) TestFixMsgInLogonSeqNumTooHigh() { s.Require().Nil(err) s.MessageType(string(msgTypeLogon), sentMessage) - s.session.sendQueued() + s.session.sendQueued(true) s.MessageType(string(msgTypeResendRequest), s.MockApp.lastToAdmin) s.FieldEquals(tagBeginSeqNo, 1, s.MockApp.lastToAdmin.Body) @@ -373,7 +373,7 @@ func (s *LogonStateTestSuite) TestFixMsgInLogonSeqNumTooLow() { s.Require().Nil(err) s.MessageType(string(msgTypeLogout), sentMessage) - s.session.sendQueued() + s.session.sendQueued(true) s.MessageType(string(msgTypeLogout), s.MockApp.lastToAdmin) s.FieldEquals(tagText, "MsgSeqNum too low, expecting 2 but received 1", s.MockApp.lastToAdmin.Body) } diff --git a/session.go b/session.go index b359245e2..1bf60b122 100644 --- a/session.go +++ b/session.go @@ -235,12 +235,16 @@ func (s *session) queueForSend(msg *Message) error { s.toSend = append(s.toSend, msgBytes) + s.notifyMessageOut() + + return nil +} + +func (s *session) notifyMessageOut() { select { case s.messageEvent <- true: default: } - - return nil } // send will validate, persist, queue the message. If the session is logged on, send all messages in the queue. @@ -261,7 +265,7 @@ func (s *session) sendInReplyTo(msg *Message, inReplyTo *Message) error { } s.toSend = append(s.toSend, msgBytes) - s.sendQueued() + s.sendQueued(true) return nil } @@ -290,7 +294,7 @@ func (s *session) dropAndSendInReplyTo(msg *Message, inReplyTo *Message) error { s.dropQueued() s.toSend = append(s.toSend, msgBytes) - s.sendQueued() + s.sendQueued(true) return nil } @@ -346,9 +350,13 @@ func (s *session) persist(seqNum int, msgBytes []byte) error { return s.store.IncrNextSenderMsgSeqNum() } -func (s *session) sendQueued() { - for _, msgBytes := range s.toSend { - s.sendBytes(msgBytes) +func (s *session) sendQueued(blockUntilSent bool) { + for i, msgBytes := range s.toSend { + if !s.sendBytes(msgBytes, blockUntilSent) { + s.toSend = s.toSend[i:] + s.notifyMessageOut() + return + } } s.dropQueued() @@ -363,18 +371,30 @@ func (s *session) EnqueueBytesAndSend(msg []byte) { defer s.sendMutex.Unlock() s.toSend = append(s.toSend, msg) - s.sendQueued() + s.sendQueued(true) } -func (s *session) sendBytes(msg []byte) { +func (s *session) sendBytes(msg []byte, blockUntilSent bool) bool { if s.messageOut == nil { s.log.OnEventf("Failed to send: disconnected") - return + return false + } + + if blockUntilSent { + s.messageOut <- msg + s.log.OnOutgoing(msg) + s.stateTimer.Reset(s.HeartBtInt) + return true } - s.log.OnOutgoing(msg) - s.messageOut <- msg - s.stateTimer.Reset(s.HeartBtInt) + select { + case s.messageOut <- msg: + s.log.OnOutgoing(msg) + s.stateTimer.Reset(s.HeartBtInt) + return true + default: + return false + } } func (s *session) doTargetTooHigh(reject targetTooHigh) (nextState resendState, err error) { diff --git a/session_state.go b/session_state.go index 230ac8613..527556209 100644 --- a/session_state.go +++ b/session_state.go @@ -105,7 +105,7 @@ func (sm *stateMachine) SendAppMessages(session *session) { defer session.sendMutex.Unlock() if session.IsLoggedOn() { - session.sendQueued() + session.sendQueued(false) } else { session.dropQueued() }