diff --git a/filestore.go b/filestore.go index 976273754..46a83e6b4 100644 --- a/filestore.go +++ b/filestore.go @@ -301,6 +301,14 @@ func (store *fileStore) SaveMessage(seqNum int, msg []byte) error { return nil } +func (store *fileStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []byte) error { + err := store.SaveMessage(seqNum, msg) + if err != nil { + return err + } + return store.IncrNextSenderMsgSeqNum() +} + func (store *fileStore) getMessage(seqNum int) (msg []byte, found bool, err error) { msgInfo, found := store.offsets[seqNum] if !found { diff --git a/mongostore.go b/mongostore.go index 15c419347..e07cc5c7b 100644 --- a/mongostore.go +++ b/mongostore.go @@ -248,6 +248,15 @@ func (store *mongoStore) SaveMessage(seqNum int, msg []byte) (err error) { return } +func (store *mongoStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []byte) error { + // TODO add transaction + err := store.SaveMessage(seqNum, msg) + if err != nil { + return err + } + return store.IncrNextSenderMsgSeqNum() +} + func (store *mongoStore) GetMessages(beginSeqNum, endSeqNum int) (msgs [][]byte, err error) { msgFilter := generateMessageFilter(&store.sessionID) //Marshal into database form diff --git a/session.go b/session.go index 6ef143f2a..73557cf45 100644 --- a/session.go +++ b/session.go @@ -321,9 +321,7 @@ func (s *session) prepMessageForSend(msg *Message, inReplyTo *Message) (msgBytes func (s *session) persist(seqNum int, msgBytes []byte) error { if !s.DisableMessagePersist { - if err := s.store.SaveMessage(seqNum, msgBytes); err != nil { - return err - } + return s.store.SaveMessageAndIncrNextSenderMsgSeqNum(seqNum, msgBytes) } return s.store.IncrNextSenderMsgSeqNum() diff --git a/sqlstore.go b/sqlstore.go index a57281dd8..a176b7509 100644 --- a/sqlstore.go +++ b/sqlstore.go @@ -274,6 +274,54 @@ func (store *sqlStore) SaveMessage(seqNum int, msg []byte) error { return err } +func (store *sqlStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []byte) error { + s := store.sessionID + + tx, err := store.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.Exec(sqlString(`INSERT INTO messages ( + msgseqnum, message, + beginstring, session_qualifier, + sendercompid, sendersubid, senderlocid, + targetcompid, targetsubid, targetlocid) + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, store.placeholder), + seqNum, string(msg), + s.BeginString, s.Qualifier, + s.SenderCompID, s.SenderSubID, s.SenderLocationID, + s.TargetCompID, s.TargetSubID, s.TargetLocationID) + if err != nil { + return err + } + + next := store.cache.NextSenderMsgSeqNum() + 1 + _, err = tx.Exec(sqlString(`UPDATE sessions SET outgoing_seqnum = ? + WHERE beginstring=? AND session_qualifier=? + AND sendercompid=? AND sendersubid=? AND senderlocid=? + AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder), + next, s.BeginString, s.Qualifier, + s.SenderCompID, s.SenderSubID, s.SenderLocationID, + s.TargetCompID, s.TargetSubID, s.TargetLocationID) + if err != nil { + return err + } + + err = tx.Commit() + if err != nil { + return err + } + + err = store.cache.SetNextSenderMsgSeqNum(next) + if err != nil { + return err + } + + return nil +} + func (store *sqlStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) { s := store.sessionID var msgs [][]byte diff --git a/store.go b/store.go index 41b6bc0c0..b743a46cc 100644 --- a/store.go +++ b/store.go @@ -20,6 +20,7 @@ type MessageStore interface { CreationTime() time.Time SaveMessage(seqNum int, msg []byte) error + SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []byte) error GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) Refresh() error @@ -97,6 +98,14 @@ func (store *memoryStore) SaveMessage(seqNum int, msg []byte) error { return nil } +func (store *memoryStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []byte) error { + err := store.SaveMessage(seqNum, msg) + if err != nil { + return err + } + return store.IncrNextSenderMsgSeqNum() +} + func (store *memoryStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) { var msgs [][]byte for seqNum := beginSeqNum; seqNum <= endSeqNum; seqNum++ {