diff --git a/in_session.go b/in_session.go index 3cdac27a8..20ab6bde7 100644 --- a/in_session.go +++ b/in_session.go @@ -43,7 +43,7 @@ func (state inSession) FixMsgIn(session *session, msg Message) (nextState sessio case enum.MsgType_LOGOUT: session.log.OnEvent("Received logout request") session.log.OnEvent("Sending logout response") - state.generateLogout(session) + session.sendLogout("") return latentState{} case enum.MsgType_TEST_REQUEST: return state.handleTestRequest(session, msg) @@ -63,7 +63,8 @@ func (state inSession) FixMsgInRej(session *session, msg Message, rej MessageRej if err := msg.Header.GetField(tagMsgType, &msgType); err == nil { switch string(msgType) { case enum.MsgType_LOGON: - return state.initiateLogout(session, "") + session.initiateLogout("") + return logoutState{} case enum.MsgType_LOGOUT: return latentState{} } @@ -199,13 +200,15 @@ func (state inSession) processReject(session *session, msg Message, rej MessageR case targetTooLow: return state.doTargetTooLow(session, msg, TypedError) case incorrectBeginString: - return state.initiateLogout(session, rej.Error()) + session.initiateLogout(rej.Error()) + return logoutState{} } switch rej.RejectReason() { case rejectReasonCompIDProblem, rejectReasonSendingTimeAccuracyProblem: session.doReject(msg, rej) - return state.initiateLogout(session, "") + session.initiateLogout("") + return logoutState{} default: session.doReject(msg, rej) session.store.IncrNextTargetMsgSeqNum() @@ -228,28 +231,23 @@ func (state inSession) doTargetTooLow(session *session, msg Message, rej targetT if sendingTime.Before(origSendingTime.Time) { session.doReject(msg, sendingTimeAccuracyProblem()) - return state.initiateLogout(session, "") + session.initiateLogout("") + return logoutState{} } if appReject := session.fromCallback(msg); appReject != nil { session.doReject(msg, appReject) - return state.initiateLogout(session, "") + session.initiateLogout("") + return logoutState{} } } else { - return state.initiateLogout(session, rej.Error()) + session.initiateLogout(rej.Error()) + return logoutState{} } return state } -func (state *inSession) initiateLogout(session *session, reason string) (nextState logoutState) { - session.log.OnEvent("Inititated logout request") - state.generateLogoutWithReason(session, reason) - time.AfterFunc(time.Duration(2)*time.Second, func() { session.sessionEvent <- logoutTimeout }) - - return -} - func (state *inSession) generateSequenceReset(session *session, beginSeqNo int, endSeqNo int) { sequenceReset := NewMessage() session.fillDefaultHeader(sequenceReset) @@ -269,21 +267,3 @@ func (state *inSession) generateSequenceReset(session *session, beginSeqNo int, msgBytes, _ := sequenceReset.Build() session.sendBytes(msgBytes) } - -func (state *inSession) generateLogout(session *session) { - state.generateLogoutWithReason(session, "") -} - -func (state *inSession) generateLogoutWithReason(session *session, reason string) { - logout := NewMessage() - logout.Header.SetField(tagMsgType, FIXString("5")) - logout.Header.SetField(tagBeginString, FIXString(session.sessionID.BeginString)) - logout.Header.SetField(tagTargetCompID, FIXString(session.sessionID.TargetCompID)) - logout.Header.SetField(tagSenderCompID, FIXString(session.sessionID.SenderCompID)) - - if reason != "" { - logout.Body.SetField(tagText, FIXString(reason)) - } - - session.send(logout) -} diff --git a/session.go b/session.go index 09642e7f3..65843a22b 100644 --- a/session.go +++ b/session.go @@ -172,6 +172,18 @@ func (s *session) fillDefaultHeader(msg Message) { s.insertSendingTime(msg.Header) } +func (s *session) sendLogout(reason string) { + logout := NewMessage() + logout.Header.SetField(tagMsgType, FIXString("5")) + logout.Header.SetField(tagBeginString, FIXString(s.sessionID.BeginString)) + logout.Header.SetField(tagTargetCompID, FIXString(s.sessionID.TargetCompID)) + logout.Header.SetField(tagSenderCompID, FIXString(s.sessionID.SenderCompID)) + if reason != "" { + logout.Body.SetField(tagText, FIXString(reason)) + } + s.send(logout) +} + func (s *session) resend(msg Message) { msg.Header.SetField(tagPossDupFlag, FIXBoolean(true)) @@ -310,6 +322,12 @@ func (s *session) handleLogon(msg Message) error { return nil } +func (s *session) initiateLogout(reason string) { + s.log.OnEvent("Inititated logout request") + s.sendLogout(reason) + time.AfterFunc(time.Duration(2)*time.Second, func() { s.sessionEvent <- logoutTimeout }) +} + func (s *session) verify(msg Message) MessageRejectError { return s.verifySelect(msg, true, true) } @@ -595,8 +613,9 @@ func (s *session) run(msgIn chan fixIn, msgOut chan []byte, quit chan bool) { } case <-quit: quit = nil // prevent infinitly receiving on a closed channel - if state, ok := s.sessionState.(inSession); ok { - s.sessionState = state.initiateLogout(s, "") + if s.IsLoggedOn() { + s.initiateLogout("") + s.sessionState = logoutState{} } else { return }