Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions pending_timeout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,23 @@ import (
)

func TestPendingTimeout_SessionTimeout(t *testing.T) {
session := &session{
log: nullLog{},
}

tests := []pendingTimeout{
pendingTimeout{inSession{}},
pendingTimeout{resendState{}},
}

for _, state := range tests {
session := &session{
log: nullLog{},
sessionState: state,
}

nextState := state.Timeout(session, peerTimeout)
assert.IsType(t, latentState{}, nextState)
}
}

func TestPendingTimeout_TimeoutUnchangedState(t *testing.T) {
session := &session{
log: nullLog{},
}

tests := []pendingTimeout{
pendingTimeout{inSession{}},
pendingTimeout{resendState{}},
Expand All @@ -35,6 +32,11 @@ func TestPendingTimeout_TimeoutUnchangedState(t *testing.T) {
testEvents := []event{needHeartbeat, logonTimeout, logoutTimeout}

for _, state := range tests {
session := &session{
log: nullLog{},
sessionState: state,
}

for _, event := range testEvents {
nextState := state.Timeout(session, event)
assert.Equal(t, state, nextState)
Expand Down
6 changes: 1 addition & 5 deletions registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,9 @@ func SendToTarget(m Messagable, sessionID SessionID) error {
session, err := lookupSession(sessionID)
if err != nil {
return err
} else if session.toSend == nil {
return fmt.Errorf("Not logged on")
}

request := sendRequest{msg, make(chan error)}
session.toSend <- request
return <-request.err
return session.queueForSend(msg)
}

type sessionActivate struct {
Expand Down
23 changes: 13 additions & 10 deletions resend_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ func TestResendState_TimeoutPeerTimeout(t *testing.T) {
<-otherEnd
}()

state := resendState{}
session := &session{
store: new(memoryStore),
application: new(TestClient),
messageOut: otherEnd,
log: nullLog{},
store: new(memoryStore),
application: new(TestClient),
messageOut: otherEnd,
log: nullLog{},
sessionState: state,
}
state := resendState{}

nextState := state.Timeout(session, peerTimeout)
assert.Equal(t, pendingTimeout{state}, nextState)
}
Expand All @@ -33,13 +35,14 @@ func TestResendState_TimeoutUnchanged(t *testing.T) {
<-otherEnd
}()

state := resendState{}
session := &session{
store: new(memoryStore),
application: new(TestClient),
messageOut: otherEnd,
log: nullLog{},
store: new(memoryStore),
application: new(TestClient),
messageOut: otherEnd,
log: nullLog{},
sessionState: state,
}
state := resendState{}

tests := []event{needHeartbeat, logonTimeout, logoutTimeout}

Expand Down
125 changes: 98 additions & 27 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package quickfix

import (
"fmt"
"sync"
"time"

"github.com/quickfixgo/quickfix/config"
Expand All @@ -18,9 +19,14 @@ type session struct {

messageOut chan []byte
messageIn chan fixIn
toSend chan sendRequest
resendIn chan Message

//application messages are queued up for send here
toSend []Message

//mutex for access to toSend
sendMutex sync.Mutex

sessionEvent chan event
messageEvent chan bool
application Application
Expand Down Expand Up @@ -212,31 +218,109 @@ func (s *session) resend(msg Message) error {
return nil
}

//send should NOT be called outside of the run loop
//queueForSend will validate, persist, and queue the message for send
func (s *session) queueForSend(msg Message) (err error) {
s.sendMutex.Lock()
defer s.sendMutex.Unlock()

if err = s.prepMessageForSend(&msg); err != nil {
return
}

s.toSend = append(s.toSend, msg)

select {
case s.messageEvent <- true:
default:
}

return
}

//send will validate, persist, queue the message and send all messages in the queue
func (s *session) send(msg Message) (err error) {
s.fillDefaultHeader(msg)
s.sendMutex.Lock()
defer s.sendMutex.Unlock()

if err = s.prepMessageForSend(&msg); err != nil {
return
}

s.toSend = append(s.toSend, msg)
s.sendQueued()

return
}

//dropAndSend will validate and persist the message, then drops the send queue and sends the message
func (s *session) dropAndSend(msg Message) (err error) {

s.sendMutex.Lock()
defer s.sendMutex.Unlock()
if err = s.prepMessageForSend(&msg); err != nil {
return
}

s.dropQueued()
s.toSend = append(s.toSend, msg)
s.sendQueued()

return
}

func (s *session) prepMessageForSend(msg *Message) (err error) {
s.fillDefaultHeader(*msg)
seqNum := s.store.NextSenderMsgSeqNum()
msg.Header.SetField(tagMsgSeqNum, FIXInt(seqNum))

var msgType FIXString
if msg.Header.GetField(tagMsgType, &msgType); isAdminMessageType(string(msgType)) {
s.application.ToAdmin(msg, s.sessionID)
if err = msg.Header.GetField(tagMsgType, &msgType); err != nil {
return err
}

if isAdminMessageType(string(msgType)) {
s.application.ToAdmin(*msg, s.sessionID)
} else {
s.application.ToApp(msg, s.sessionID)
s.application.ToApp(*msg, s.sessionID)
}

var msgBytes []byte
if msgBytes, err = msg.Build(); err != nil {
return
}

if err = s.store.SaveMessage(seqNum, msgBytes); err == nil {
s.sendBytes(msgBytes)
err = s.store.IncrNextSenderMsgSeqNum()
return s.persist(seqNum, msgBytes)
}

func (s *session) persist(seqNum int, msgBytes []byte) error {
if err := s.store.SaveMessage(seqNum, msgBytes); err != nil {
return err
}

return
return s.store.IncrNextSenderMsgSeqNum()
}

func (s *session) sendQueued() {
for _, msg := range s.toSend {
s.sendBytes(msg.rawMessage)
}

s.dropQueued()
}

func (s *session) dropQueued() {
s.toSend = s.toSend[:0]
}

func (s *session) sendOrDropAppMessages() {
s.sendMutex.Lock()
defer s.sendMutex.Unlock()

if s.IsLoggedOn() {
s.sendQueued()
} else {
s.dropQueued()
}
}

func (s *session) sendBytes(msg []byte) {
Expand Down Expand Up @@ -325,7 +409,7 @@ func (s *session) handleLogon(msg Message) error {
}

s.log.OnEvent("Responding to logon request")
if err := s.send(reply); err != nil {
if err := s.dropAndSend(reply); err != nil {
return err
}
} else {
Expand Down Expand Up @@ -538,15 +622,9 @@ type fixIn struct {
receiveTime time.Time
}

type sendRequest struct {
msg Message
err chan error
}

func (s *session) run(msgIn chan fixIn, msgOut chan []byte, quit chan bool) {
s.messageIn = msgIn
s.messageOut = msgOut
s.toSend = make(chan sendRequest)
s.resendIn = make(chan Message, 1)

type fromCallback struct {
Expand All @@ -557,8 +635,6 @@ func (s *session) run(msgIn chan fixIn, msgOut chan []byte, quit chan bool) {

defer func() {
close(s.messageOut)
close(s.toSend)
s.toSend = nil
s.stateTimer.Stop()
s.peerTimer.Stop()
s.onDisconnect()
Expand Down Expand Up @@ -587,7 +663,7 @@ func (s *session) run(msgIn chan fixIn, msgOut chan []byte, quit chan bool) {
}

s.log.OnEvent("Sending logon request")
if err := s.send(logon); err != nil {
if err := s.dropAndSend(logon); err != nil {
s.logError(err)
return
}
Expand All @@ -607,19 +683,12 @@ func (s *session) run(msgIn chan fixIn, msgOut chan []byte, quit chan bool) {
}

for {

switch s.sessionState.(type) {
case latentState:
return
}

select {
case request := <-s.toSend:
if s.IsLoggedOn() {
request.err <- s.send(request.msg)
} else {
request.err <- fmt.Errorf("Not logged on")
}
case fixIn, ok := <-msgIn:
if ok {
s.log.OnIncoming(string(fixIn.bytes))
Expand Down Expand Up @@ -658,6 +727,8 @@ func (s *session) run(msgIn chan fixIn, msgOut chan []byte, quit chan bool) {
}
case evt := <-s.sessionEvent:
s.sessionState = s.Timeout(s, evt)
case <-s.messageEvent:
s.sendOrDropAppMessages()
}
}
}
Loading