Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 85 additions & 27 deletions in_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,17 @@ func (state inSession) FixMsgIn(session *session, msg Message) (nextState sessio
if err := msg.Header.GetField(tagMsgType, &msgType); err == nil {
switch string(msgType) {
case enum.MsgType_LOGON:
session.handleLogon(msg)
if err := session.handleLogon(msg); err != nil {
return session.handleError(err)
}

return
case enum.MsgType_LOGOUT:
session.log.OnEvent("Received logout request")
session.log.OnEvent("Sending logout response")
session.sendLogout("")
if err := session.sendLogout(""); err != nil {
return session.handleError(err)
}
nextState = latentState{}
case enum.MsgType_TEST_REQUEST:
return state.handleTestRequest(session, msg)
Expand All @@ -56,7 +61,9 @@ func (state inSession) FixMsgIn(session *session, msg Message) (nextState sessio
}
}

session.store.IncrNextTargetMsgSeqNum()
if err := session.store.IncrNextTargetMsgSeqNum(); err != nil {
return session.handleError(err)
}
return
}

Expand All @@ -65,7 +72,9 @@ func (state inSession) FixMsgInRej(session *session, msg Message, rej MessageRej
if err := msg.Header.GetField(tagMsgType, &msgType); err == nil {
switch string(msgType) {
case enum.MsgType_LOGON:
session.initiateLogout("")
if err := session.initiateLogout(""); err != nil {
return session.handleError(err)
}
return logoutState{}
case enum.MsgType_LOGOUT:
return latentState{}
Expand All @@ -79,12 +88,16 @@ func (state inSession) Timeout(session *session, event event) (nextState session
case needHeartbeat:
heartBt := NewMessage()
heartBt.Header.SetField(tagMsgType, FIXString("0"))
session.send(heartBt)
if err := session.send(heartBt); err != nil {
return session.handleError(err)
}
case peerTimeout:
testReq := NewMessage()
testReq.Header.SetField(tagMsgType, FIXString("1"))
testReq.Body.SetField(tagTestReqID, FIXString("TEST"))
session.send(testReq)
if err := session.send(testReq); err != nil {
return session.handleError(err)
}
session.log.OnEvent("Sent test request TEST")
session.peerTimer.Reset(time.Duration(int64(1.2 * float64(session.heartBeatTimeout))))
return pendingTimeout{}
Expand All @@ -100,10 +113,14 @@ func (state inSession) handleTestRequest(session *session, msg Message) (nextSta
heartBt := NewMessage()
heartBt.Header.SetField(tagMsgType, FIXString("0"))
heartBt.Body.SetField(tagTestReqID, testReq)
session.send(heartBt)
if err := session.send(heartBt); err != nil {
return session.handleError(err)
}
}

session.store.IncrNextTargetMsgSeqNum()
if err := session.store.IncrNextTargetMsgSeqNum(); err != nil {
return session.handleError(err)
}
return state
}

Expand All @@ -115,10 +132,14 @@ func (state inSession) handleSequenceReset(session *session, msg Message) (nextS

switch {
case newSeqNo > expectedSeqNum:
session.store.SetNextTargetMsgSeqNum(int(newSeqNo))
if err := session.store.SetNextTargetMsgSeqNum(int(newSeqNo)); err != nil {
return session.handleError(err)
}
case newSeqNo < expectedSeqNum:
//FIXME: to be compliant with legacy tests, do not include tag in reftagid? (11c_NewSeqNoLess)
session.doReject(msg, valueIsIncorrectNoTag())
if err := session.doReject(msg, valueIsIncorrectNoTag()); err != nil {
return session.handleError(err)
}
}
}
return state
Expand Down Expand Up @@ -149,16 +170,21 @@ func (state inSession) handleResendRequest(session *session, msg Message) (nextS
endSeqNo = expectedSeqNum - 1
}

state.resendMessages(session, int(beginSeqNo), endSeqNo)
session.store.IncrNextTargetMsgSeqNum()
if err := state.resendMessages(session, int(beginSeqNo), endSeqNo); err != nil {
return session.handleError(err)
}

if err := session.store.IncrNextTargetMsgSeqNum(); err != nil {
return session.handleError(err)
}
return state
}

func (state inSession) resendMessages(session *session, beginSeqNo, endSeqNo int) {
func (state inSession) resendMessages(session *session, beginSeqNo, endSeqNo int) (err error) {
msgs, err := session.store.GetMessages(beginSeqNo, endSeqNo)
if err != nil {
session.log.OnEventf("error retrieving messages from store: %s", err.Error())
panic(err)
return
}

seqNum := beginSeqNo
Expand All @@ -177,8 +203,10 @@ func (state inSession) resendMessages(session *session, beginSeqNo, endSeqNo int
state.generateSequenceReset(session, seqNum, sentMessageSeqNum)
}

session.resend(msg)
session.log.OnEventf("Resending Message: %v", sentMessageSeqNum)
if err = session.resend(msg); err != nil {
return
}

seqNum = sentMessageSeqNum + 1
nextSeqNum = seqNum
Expand All @@ -187,6 +215,8 @@ func (state inSession) resendMessages(session *session, beginSeqNo, endSeqNo int
if seqNum != nextSeqNum { // gapfill for catch-up
state.generateSequenceReset(session, seqNum, nextSeqNum)
}

return
}

func (state inSession) processReject(session *session, msg Message, rej MessageRejectError) (nextState sessionState) {
Expand All @@ -195,7 +225,9 @@ func (state inSession) processReject(session *session, msg Message, rej MessageR

switch session.sessionState.(type) {
default:
session.doTargetTooHigh(TypedError)
if err := session.doTargetTooHigh(TypedError); err != nil {
return session.handleError(err)
}
case resendState:
//assumes target too high reject already sent
}
Expand All @@ -206,18 +238,30 @@ func (state inSession) processReject(session *session, msg Message, rej MessageR
case targetTooLow:
return state.doTargetTooLow(session, msg, TypedError)
case incorrectBeginString:
session.initiateLogout(rej.Error())
if err := session.initiateLogout(rej.Error()); err != nil {
session.handleError(err)
}
return logoutState{}
}

switch rej.RejectReason() {
case rejectReasonCompIDProblem, rejectReasonSendingTimeAccuracyProblem:
session.doReject(msg, rej)
session.initiateLogout("")
if err := session.doReject(msg, rej); err != nil {
return session.handleError(err)
}

if err := session.initiateLogout(""); err != nil {
return session.handleError(err)
}
return logoutState{}
default:
session.doReject(msg, rej)
session.store.IncrNextTargetMsgSeqNum()
if err := session.doReject(msg, rej); err != nil {
return session.handleError(err)
}

if err := session.store.IncrNextTargetMsgSeqNum(); err != nil {
return session.handleError(err)
}
return state
}
}
Expand All @@ -228,26 +272,40 @@ func (state inSession) doTargetTooLow(session *session, msg Message, rej targetT

origSendingTime := new(FIXUTCTimestamp)
if err = msg.Header.GetField(tagOrigSendingTime, origSendingTime); err != nil {
session.doReject(msg, RequiredTagMissing(tagOrigSendingTime))
if rejErr := session.doReject(msg, RequiredTagMissing(tagOrigSendingTime)); rejErr != nil {
return session.handleError(rejErr)
}
return state
}

sendingTime := new(FIXUTCTimestamp)
msg.Header.GetField(tagSendingTime, sendingTime)

if sendingTime.Before(origSendingTime.Time) {
session.doReject(msg, sendingTimeAccuracyProblem())
session.initiateLogout("")
if err := session.doReject(msg, sendingTimeAccuracyProblem()); err != nil {
return session.handleError(err)
}

if err := session.initiateLogout(""); err != nil {
return session.handleError(err)
}
return logoutState{}
}

if appReject := session.fromCallback(msg); appReject != nil {
session.doReject(msg, appReject)
session.initiateLogout("")
if err := session.doReject(msg, appReject); err != nil {
return session.handleError(err)
}

if err := session.initiateLogout(""); err != nil {
return session.handleError(err)
}
return logoutState{}
}
} else {
session.initiateLogout(rej.Error())
if err := session.initiateLogout(rej.Error()); err != nil {
return session.handleError(err)
}
return logoutState{}
}

Expand Down
2 changes: 1 addition & 1 deletion initiator.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (i *Initiator) Start() error {
return fmt.Errorf("error on SocketConnectPort: %v", err)
}

var reconnectInterval int = 30 // Default configuration (in seconds)
reconnectInterval := 30 // Default configuration (in seconds)
if s.HasSetting(config.ReconnectInterval) {
if reconnectInterval, err = s.IntSetting(config.ReconnectInterval); err != nil {
return fmt.Errorf("error on ReconnectInterval: %v", err)
Expand Down
Loading