diff --git a/internal/testsuite/store_suite.go b/internal/testsuite/store_suite.go index 0563a5861..c623b6861 100644 --- a/internal/testsuite/store_suite.go +++ b/internal/testsuite/store_suite.go @@ -152,6 +152,48 @@ func (s *StoreTestSuite) TestMessageStoreSaveMessageAndIncrementGetMessage() { s.Equal(expectedMsgsBySeqNum[3], string(actualMsgs[2])) } +func (s *StoreTestSuite) TestMessageStoreSaveBatchAndIncrementGetMessage() { + s.Require().Nil(s.MsgStore.SetNextSenderMsgSeqNum(420)) + + // Given the following saved messages + expectedMsgsBySeqNum := map[int]string{ + 1: "In the frozen land of Nador", + 2: "they were forced to eat Robin's minstrels", + 3: "and there was much rejoicing", + } + var msgs [][]byte + for _, msg := range expectedMsgsBySeqNum { + msgs = append(msgs, []byte(msg)) + } + s.Require().Nil(s.MsgStore.SaveBatchAndIncrNextSenderMsgSeqNum(1, msgs)) + s.Equal(423, s.MsgStore.NextSenderMsgSeqNum()) + + // When the messages are retrieved from the MessageStore + actualMsgs, err := s.MsgStore.GetMessages(1, 3) + s.Require().Nil(err) + + // Then the messages should be + s.Require().Len(actualMsgs, 3) + s.Equal(expectedMsgsBySeqNum[1], string(actualMsgs[0])) + s.Equal(expectedMsgsBySeqNum[2], string(actualMsgs[1])) + s.Equal(expectedMsgsBySeqNum[3], string(actualMsgs[2])) + + // When the store is refreshed from its backing store + s.Require().Nil(s.MsgStore.Refresh()) + + // And the messages are retrieved from the MessageStore + actualMsgs, err = s.MsgStore.GetMessages(1, 3) + s.Require().Nil(err) + + s.Equal(423, s.MsgStore.NextSenderMsgSeqNum()) + + // Then the messages should still be + s.Require().Len(actualMsgs, 3) + s.Equal(expectedMsgsBySeqNum[1], string(actualMsgs[0])) + s.Equal(expectedMsgsBySeqNum[2], string(actualMsgs[1])) + s.Equal(expectedMsgsBySeqNum[3], string(actualMsgs[2])) +} + func (s *StoreTestSuite) TestMessageStoreGetMessagesEmptyStore() { // When messages are retrieved from an empty store messages, err := s.MsgStore.GetMessages(1, 2) diff --git a/memorystore.go b/memorystore.go index 5773f09b5..a3be11a75 100644 --- a/memorystore.go +++ b/memorystore.go @@ -97,6 +97,15 @@ func (store *memoryStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg return store.IncrNextSenderMsgSeqNum() } +func (store *memoryStore) SaveBatchAndIncrNextSenderMsgSeqNum(seqNum int, msg [][]byte) error { + for offset, m := range msg { + if err := store.SaveMessageAndIncrNextSenderMsgSeqNum(seqNum+offset, m); err != nil { + return err + } + } + return nil +} + func (store *memoryStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) { var msgs [][]byte for seqNum := beginSeqNum; seqNum <= endSeqNum; seqNum++ { diff --git a/registry.go b/registry.go index 236c795b6..91179bb62 100644 --- a/registry.go +++ b/registry.go @@ -64,6 +64,25 @@ func SendToTarget(m Messagable, sessionID SessionID) error { return session.queueForSend(msg) } +// SendBatchToTarget is similar to SendToTarget, but it sends application messages in batch to the sessionID. +// The entire batch would fail if: +// - any message in the batch fails ToApp() validation +// - any message in the batch is an admin message +// This is more efficient compare to SendToTarget in the case of sending a burst of application messages, +// especially when using a persistent store like SQLStore, because it allows batching at the storage layer. +func SendBatchToTarget(m []Messagable, sessionID SessionID) error { + session, ok := lookupSession(sessionID) + if !ok { + return errUnknownSession + } + msg := make([]*Message, len(m)) + for i, v := range m { + msg[i] = v.ToMessage() + } + + return session.queueBatchAppsForSend(msg) +} + // ResetSession resets session's sequence numbers. func ResetSession(sessionID SessionID) error { session, ok := lookupSession(sessionID) diff --git a/session.go b/session.go index b359245e2..794315f88 100644 --- a/session.go +++ b/session.go @@ -346,6 +346,59 @@ func (s *session) persist(seqNum int, msgBytes []byte) error { return s.store.IncrNextSenderMsgSeqNum() } +// queueBatchAppsForSend will validate, persist, and queue the messages for send. +func (s *session) queueBatchAppsForSend(msg []*Message) error { + s.sendMutex.Lock() + defer s.sendMutex.Unlock() + + msgBytes, err := s.prepBatchAppMessagesForSend(msg) + if err != nil { + return err + } + + for _, mb := range msgBytes { + s.toSend = append(s.toSend, mb) + select { + case s.messageEvent <- true: + default: + } + } + + return nil +} + +func (s *session) prepBatchAppMessagesForSend(msg []*Message) (msgBytes [][]byte, err error) { + seqNum := s.store.NextSenderMsgSeqNum() + for i, m := range msg { + s.fillDefaultHeader(m, nil) + m.Header.SetField(tagMsgSeqNum, FIXInt(seqNum+i)) + msgType, err := m.Header.GetBytes(tagMsgType) + if err != nil { + return nil, err + } + if isAdminMessageType(msgType) { + return nil, fmt.Errorf("cannot send admin messages in batch") + } + if errToApp := s.application.ToApp(m, s.sessionID); errToApp != nil { + return nil, errToApp + } + msgBytes = append(msgBytes, m.build()) + } + err = s.persistBatch(seqNum, msgBytes) + if err != nil { + return nil, err + } + return msgBytes, nil +} + +func (s *session) persistBatch(seqNum int, msgBytes [][]byte) error { + if !s.DisableMessagePersist { + return s.store.SaveBatchAndIncrNextSenderMsgSeqNum(seqNum, msgBytes) + } + + return s.store.SetNextSenderMsgSeqNum(seqNum + len(msgBytes)) +} + func (s *session) sendQueued() { for _, msgBytes := range s.toSend { s.sendBytes(msgBytes) diff --git a/store.go b/store.go index 34e2570e4..59f7b3689 100644 --- a/store.go +++ b/store.go @@ -35,6 +35,7 @@ type MessageStore interface { SaveMessage(seqNum int, msg []byte) error SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []byte) error + SaveBatchAndIncrNextSenderMsgSeqNum(seqNum int, msg [][]byte) error GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) Refresh() error diff --git a/store/file/filestore.go b/store/file/filestore.go index 4a3dadb5d..b0bd6506e 100644 --- a/store/file/filestore.go +++ b/store/file/filestore.go @@ -323,7 +323,7 @@ func (store *fileStore) CreationTime() time.Time { func (store *fileStore) SetCreationTime(_ time.Time) { } -func (store *fileStore) SaveMessage(seqNum int, msg []byte) error { +func (store *fileStore) saveMessages(seqNum int, messages [][]byte) error { offset, err := store.bodyFile.Seek(0, io.SeekEnd) if err != nil { return fmt.Errorf("unable to seek to end of file: %s: %s", store.bodyFname, err.Error()) @@ -331,12 +331,16 @@ func (store *fileStore) SaveMessage(seqNum int, msg []byte) error { if _, err := store.headerFile.Seek(0, io.SeekEnd); err != nil { return fmt.Errorf("unable to seek to end of file: %s: %s", store.headerFname, err.Error()) } - if _, err := fmt.Fprintf(store.headerFile, "%d,%d,%d\n", seqNum, offset, len(msg)); err != nil { - return fmt.Errorf("unable to write to file: %s: %s", store.headerFname, err.Error()) - } + msgOffset := offset + for seqOffset, msg := range messages { + if _, err := fmt.Fprintf(store.headerFile, "%d,%d,%d\n", seqNum+seqOffset, msgOffset, len(msg)); err != nil { + return fmt.Errorf("unable to write to file: %s: %s", store.headerFname, err.Error()) + } - if _, err := store.bodyFile.Write(msg); err != nil { - return fmt.Errorf("unable to write to file: %s: %s", store.bodyFname, err.Error()) + if _, err := store.bodyFile.Write(msg); err != nil { + return fmt.Errorf("unable to write to file: %s: %s", store.bodyFname, err.Error()) + } + msgOffset = msgOffset + int64(len(msg)) } if store.fileSync { if err := store.bodyFile.Sync(); err != nil { @@ -347,10 +351,18 @@ func (store *fileStore) SaveMessage(seqNum int, msg []byte) error { } } - store.offsets[seqNum] = msgDef{offset: offset, size: len(msg)} + msgOffset = offset + for seqOffset, msg := range messages { + store.offsets[seqNum+seqOffset] = msgDef{offset: msgOffset, size: len(msg)} + msgOffset = msgOffset + int64(len(msg)) + } return nil } +func (store *fileStore) SaveMessage(seqNum int, msg []byte) error { + return store.saveMessages(seqNum, [][]byte{msg}) +} + func (store *fileStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []byte) error { err := store.SaveMessage(seqNum, msg) if err != nil { @@ -359,6 +371,14 @@ func (store *fileStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg [] return store.IncrNextSenderMsgSeqNum() } +func (store *fileStore) SaveBatchAndIncrNextSenderMsgSeqNum(seqNum int, msg [][]byte) error { + err := store.saveMessages(seqNum, msg) + if err != nil { + return err + } + return store.SetNextSenderMsgSeqNum(store.cache.NextSenderMsgSeqNum() + len(msg)) +} + func (store *fileStore) getMessage(seqNum int) (msg []byte, found bool, err error) { msgInfo, found := store.offsets[seqNum] if !found { diff --git a/store/mongo/mongostore.go b/store/mongo/mongostore.go index 5af278e10..93e1853d5 100644 --- a/store/mongo/mongostore.go +++ b/store/mongo/mongostore.go @@ -338,6 +338,56 @@ func (store *mongoStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg [ return store.cache.SetNextSenderMsgSeqNum(next) } +func (store *mongoStore) SaveBatchAndIncrNextSenderMsgSeqNum(seqNum int, messages [][]byte) error { + + if !store.allowTransactions { + for _, msg := range messages { + if err := store.SaveMessageAndIncrNextSenderMsgSeqNum(seqNum, msg); err != nil { + return err + } + } + return nil + } + + // If the mongodb supports replicasets, perform this operation as a transaction instead- + var next int + err := store.db.UseSession(context.Background(), func(sessionCtx mongo.SessionContext) error { + if err := sessionCtx.StartTransaction(); err != nil { + return err + } + + entries := make([]interface{}, 0, len(messages)) + for _, msg := range messages { + msgFilter := generateMessageFilter(&store.sessionID) + msgFilter.Msgseq = seqNum + msgFilter.Message = msg + } + _, err := store.db.Database(store.mongoDatabase).Collection(store.messagesCollection).InsertMany(sessionCtx, entries) + if err != nil { + return err + } + + next = store.cache.NextSenderMsgSeqNum() + len(messages) + + msgFilter := generateMessageFilter(&store.sessionID) + sessionUpdate := generateMessageFilter(&store.sessionID) + sessionUpdate.IncomingSeqNum = store.cache.NextTargetMsgSeqNum() + sessionUpdate.OutgoingSeqNum = next + sessionUpdate.CreationTime = store.cache.CreationTime() + _, err = store.db.Database(store.mongoDatabase).Collection(store.sessionsCollection).UpdateOne(sessionCtx, msgFilter, bson.M{"$set": sessionUpdate}) + if err != nil { + return err + } + + return sessionCtx.CommitTransaction(context.Background()) + }) + if err != nil { + return err + } + + return store.cache.SetNextSenderMsgSeqNum(next) +} + func (store *mongoStore) GetMessages(beginSeqNum, endSeqNum int) (msgs [][]byte, err error) { msgFilter := generateMessageFilter(&store.sessionID) // Marshal into database form. diff --git a/store/sql/sqlstore.go b/store/sql/sqlstore.go index c54ac517f..1f2dc5612 100644 --- a/store/sql/sqlstore.go +++ b/store/sql/sqlstore.go @@ -19,6 +19,7 @@ import ( "database/sql" "fmt" "regexp" + "strings" "time" "github.com/pkg/errors" @@ -352,6 +353,56 @@ func (store *sqlStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []b return store.cache.SetNextSenderMsgSeqNum(next) } +func (store *sqlStore) SaveBatchAndIncrNextSenderMsgSeqNum(seqNum int, msg [][]byte) error { + s := store.sessionID + + tx, err := store.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + const values = "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" + placeholders := make([]string, 0, len(msg)) + params := make([]interface{}, 0, len(msg)*10) + for offset, m := range msg { + placeholders = append(placeholders, values) + params = append(params, seqNum+offset, string(m), + s.BeginString, s.Qualifier, + s.SenderCompID, s.SenderSubID, s.SenderLocationID, + s.TargetCompID, s.TargetSubID, s.TargetLocationID) + } + _, err = tx.Exec(sqlString(`INSERT INTO messages ( + msgseqnum, message, + beginstring, session_qualifier, + sendercompid, sendersubid, senderlocid, + targetcompid, targetsubid, targetlocid) + VALUES`+strings.Join(placeholders, ","), store.placeholder), + params...) + if err != nil { + return err + } + + next := store.cache.NextSenderMsgSeqNum() + len(msg) + _, err = tx.Exec(sqlString(`UPDATE sessions SET outgoing_seqnum = ? + WHERE beginstring=? AND session_qualifier=? + AND sendercompid=? AND sendersubid=? AND senderlocid=? + AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder), + next, s.BeginString, s.Qualifier, + s.SenderCompID, s.SenderSubID, s.SenderLocationID, + s.TargetCompID, s.TargetSubID, s.TargetLocationID) + if err != nil { + return err + } + + err = tx.Commit() + if err != nil { + return err + } + + return store.cache.SetNextSenderMsgSeqNum(next) +} + func (store *sqlStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) { s := store.sessionID var msgs [][]byte