Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for SendBatchToTarget #599

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
42 changes: 42 additions & 0 deletions internal/testsuite/store_suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,48 @@ func (s *StoreTestSuite) TestMessageStoreSaveMessageAndIncrementGetMessage() {
s.Equal(expectedMsgsBySeqNum[3], string(actualMsgs[2]))
}

func (s *StoreTestSuite) TestMessageStoreSaveBatchAndIncrementGetMessage() {
s.Require().Nil(s.MsgStore.SetNextSenderMsgSeqNum(420))

// Given the following saved messages
expectedMsgsBySeqNum := map[int]string{
1: "In the frozen land of Nador",
2: "they were forced to eat Robin's minstrels",
3: "and there was much rejoicing",
}
var msgs [][]byte
for _, msg := range expectedMsgsBySeqNum {
msgs = append(msgs, []byte(msg))
}
s.Require().Nil(s.MsgStore.SaveBatchAndIncrNextSenderMsgSeqNum(1, msgs))
s.Equal(423, s.MsgStore.NextSenderMsgSeqNum())

// When the messages are retrieved from the MessageStore
actualMsgs, err := s.MsgStore.GetMessages(1, 3)
s.Require().Nil(err)

// Then the messages should be
s.Require().Len(actualMsgs, 3)
s.Equal(expectedMsgsBySeqNum[1], string(actualMsgs[0]))
s.Equal(expectedMsgsBySeqNum[2], string(actualMsgs[1]))
s.Equal(expectedMsgsBySeqNum[3], string(actualMsgs[2]))

// When the store is refreshed from its backing store
s.Require().Nil(s.MsgStore.Refresh())

// And the messages are retrieved from the MessageStore
actualMsgs, err = s.MsgStore.GetMessages(1, 3)
s.Require().Nil(err)

s.Equal(423, s.MsgStore.NextSenderMsgSeqNum())

// Then the messages should still be
s.Require().Len(actualMsgs, 3)
s.Equal(expectedMsgsBySeqNum[1], string(actualMsgs[0]))
s.Equal(expectedMsgsBySeqNum[2], string(actualMsgs[1]))
s.Equal(expectedMsgsBySeqNum[3], string(actualMsgs[2]))
}

func (s *StoreTestSuite) TestMessageStoreGetMessagesEmptyStore() {
// When messages are retrieved from an empty store
messages, err := s.MsgStore.GetMessages(1, 2)
Expand Down
9 changes: 9 additions & 0 deletions memorystore.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,15 @@ func (store *memoryStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg
return store.IncrNextSenderMsgSeqNum()
}

func (store *memoryStore) SaveBatchAndIncrNextSenderMsgSeqNum(seqNum int, msg [][]byte) error {
for offset, m := range msg {
if err := store.SaveMessageAndIncrNextSenderMsgSeqNum(seqNum+offset, m); err != nil {
return err
}
}
return nil
}

func (store *memoryStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) {
var msgs [][]byte
for seqNum := beginSeqNum; seqNum <= endSeqNum; seqNum++ {
Expand Down
19 changes: 19 additions & 0 deletions registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,25 @@ func SendToTarget(m Messagable, sessionID SessionID) error {
return session.queueForSend(msg)
}

// SendBatchToTarget is similar to SendToTarget, but it sends application messages in batch to the sessionID.
// The entire batch would fail if:
// - any message in the batch fails ToApp() validation
// - any message in the batch is an admin message
// This is more efficient compare to SendToTarget in the case of sending a burst of application messages,
// especially when using a persistent store like SQLStore, because it allows batching at the storage layer.
func SendBatchToTarget(m []Messagable, sessionID SessionID) error {
session, ok := lookupSession(sessionID)
if !ok {
return errUnknownSession
}
msg := make([]*Message, len(m))
for i, v := range m {
msg[i] = v.ToMessage()
}

return session.queueBatchAppsForSend(msg)
}

// ResetSession resets session's sequence numbers.
func ResetSession(sessionID SessionID) error {
session, ok := lookupSession(sessionID)
Expand Down
53 changes: 53 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,59 @@ func (s *session) persist(seqNum int, msgBytes []byte) error {
return s.store.IncrNextSenderMsgSeqNum()
}

// queueBatchAppsForSend will validate, persist, and queue the messages for send.
func (s *session) queueBatchAppsForSend(msg []*Message) error {
s.sendMutex.Lock()
defer s.sendMutex.Unlock()

msgBytes, err := s.prepBatchAppMessagesForSend(msg)
if err != nil {
return err
}

for _, mb := range msgBytes {
s.toSend = append(s.toSend, mb)
select {
case s.messageEvent <- true:
default:
}
}

return nil
}

func (s *session) prepBatchAppMessagesForSend(msg []*Message) (msgBytes [][]byte, err error) {
seqNum := s.store.NextSenderMsgSeqNum()
for i, m := range msg {
s.fillDefaultHeader(m, nil)
m.Header.SetField(tagMsgSeqNum, FIXInt(seqNum+i))
msgType, err := m.Header.GetBytes(tagMsgType)
if err != nil {
return nil, err
}
if isAdminMessageType(msgType) {
return nil, fmt.Errorf("cannot send admin messages in batch")
}
if errToApp := s.application.ToApp(m, s.sessionID); errToApp != nil {
return nil, errToApp
}
msgBytes = append(msgBytes, m.build())
}
err = s.persistBatch(seqNum, msgBytes)
if err != nil {
return nil, err
}
return msgBytes, nil
}

func (s *session) persistBatch(seqNum int, msgBytes [][]byte) error {
if !s.DisableMessagePersist {
return s.store.SaveBatchAndIncrNextSenderMsgSeqNum(seqNum, msgBytes)
}

return s.store.SetNextSenderMsgSeqNum(seqNum + len(msgBytes))
}

func (s *session) sendQueued() {
for _, msgBytes := range s.toSend {
s.sendBytes(msgBytes)
Expand Down
1 change: 1 addition & 0 deletions store.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type MessageStore interface {

SaveMessage(seqNum int, msg []byte) error
SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []byte) error
SaveBatchAndIncrNextSenderMsgSeqNum(seqNum int, msg [][]byte) error
GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error)

Refresh() error
Expand Down
34 changes: 27 additions & 7 deletions store/file/filestore.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,20 +323,24 @@ func (store *fileStore) CreationTime() time.Time {
func (store *fileStore) SetCreationTime(_ time.Time) {
}

func (store *fileStore) SaveMessage(seqNum int, msg []byte) error {
func (store *fileStore) saveMessages(seqNum int, messages [][]byte) error {
offset, err := store.bodyFile.Seek(0, io.SeekEnd)
if err != nil {
return fmt.Errorf("unable to seek to end of file: %s: %s", store.bodyFname, err.Error())
}
if _, err := store.headerFile.Seek(0, io.SeekEnd); err != nil {
return fmt.Errorf("unable to seek to end of file: %s: %s", store.headerFname, err.Error())
}
if _, err := fmt.Fprintf(store.headerFile, "%d,%d,%d\n", seqNum, offset, len(msg)); err != nil {
return fmt.Errorf("unable to write to file: %s: %s", store.headerFname, err.Error())
}
msgOffset := offset
for seqOffset, msg := range messages {
if _, err := fmt.Fprintf(store.headerFile, "%d,%d,%d\n", seqNum+seqOffset, msgOffset, len(msg)); err != nil {
return fmt.Errorf("unable to write to file: %s: %s", store.headerFname, err.Error())
}

if _, err := store.bodyFile.Write(msg); err != nil {
return fmt.Errorf("unable to write to file: %s: %s", store.bodyFname, err.Error())
if _, err := store.bodyFile.Write(msg); err != nil {
return fmt.Errorf("unable to write to file: %s: %s", store.bodyFname, err.Error())
}
msgOffset = msgOffset + int64(len(msg))
}
if store.fileSync {
if err := store.bodyFile.Sync(); err != nil {
Expand All @@ -347,10 +351,18 @@ func (store *fileStore) SaveMessage(seqNum int, msg []byte) error {
}
}

store.offsets[seqNum] = msgDef{offset: offset, size: len(msg)}
msgOffset = offset
for seqOffset, msg := range messages {
store.offsets[seqNum+seqOffset] = msgDef{offset: msgOffset, size: len(msg)}
msgOffset = msgOffset + int64(len(msg))
}
return nil
}

func (store *fileStore) SaveMessage(seqNum int, msg []byte) error {
return store.saveMessages(seqNum, [][]byte{msg})
}

func (store *fileStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []byte) error {
err := store.SaveMessage(seqNum, msg)
if err != nil {
Expand All @@ -359,6 +371,14 @@ func (store *fileStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []
return store.IncrNextSenderMsgSeqNum()
}

func (store *fileStore) SaveBatchAndIncrNextSenderMsgSeqNum(seqNum int, msg [][]byte) error {
err := store.saveMessages(seqNum, msg)
if err != nil {
return err
}
return store.SetNextSenderMsgSeqNum(store.cache.NextSenderMsgSeqNum() + len(msg))
}

func (store *fileStore) getMessage(seqNum int) (msg []byte, found bool, err error) {
msgInfo, found := store.offsets[seqNum]
if !found {
Expand Down
50 changes: 50 additions & 0 deletions store/mongo/mongostore.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,56 @@ func (store *mongoStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg [
return store.cache.SetNextSenderMsgSeqNum(next)
}

func (store *mongoStore) SaveBatchAndIncrNextSenderMsgSeqNum(seqNum int, messages [][]byte) error {

if !store.allowTransactions {
for _, msg := range messages {
if err := store.SaveMessageAndIncrNextSenderMsgSeqNum(seqNum, msg); err != nil {
return err
}
}
return nil
}

// If the mongodb supports replicasets, perform this operation as a transaction instead-
var next int
err := store.db.UseSession(context.Background(), func(sessionCtx mongo.SessionContext) error {
if err := sessionCtx.StartTransaction(); err != nil {
return err
}

entries := make([]interface{}, 0, len(messages))
for _, msg := range messages {
msgFilter := generateMessageFilter(&store.sessionID)
msgFilter.Msgseq = seqNum
msgFilter.Message = msg
}
_, err := store.db.Database(store.mongoDatabase).Collection(store.messagesCollection).InsertMany(sessionCtx, entries)
if err != nil {
return err
}

next = store.cache.NextSenderMsgSeqNum() + len(messages)

msgFilter := generateMessageFilter(&store.sessionID)
sessionUpdate := generateMessageFilter(&store.sessionID)
sessionUpdate.IncomingSeqNum = store.cache.NextTargetMsgSeqNum()
sessionUpdate.OutgoingSeqNum = next
sessionUpdate.CreationTime = store.cache.CreationTime()
_, err = store.db.Database(store.mongoDatabase).Collection(store.sessionsCollection).UpdateOne(sessionCtx, msgFilter, bson.M{"$set": sessionUpdate})
if err != nil {
return err
}

return sessionCtx.CommitTransaction(context.Background())
})
if err != nil {
return err
}

return store.cache.SetNextSenderMsgSeqNum(next)
}

func (store *mongoStore) GetMessages(beginSeqNum, endSeqNum int) (msgs [][]byte, err error) {
msgFilter := generateMessageFilter(&store.sessionID)
// Marshal into database form.
Expand Down
51 changes: 51 additions & 0 deletions store/sql/sqlstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"database/sql"
"fmt"
"regexp"
"strings"
"time"

"github.com/pkg/errors"
Expand Down Expand Up @@ -352,6 +353,56 @@ func (store *sqlStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []b
return store.cache.SetNextSenderMsgSeqNum(next)
}

func (store *sqlStore) SaveBatchAndIncrNextSenderMsgSeqNum(seqNum int, msg [][]byte) error {
s := store.sessionID

tx, err := store.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()

const values = "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
placeholders := make([]string, 0, len(msg))
params := make([]interface{}, 0, len(msg)*10)
for offset, m := range msg {
placeholders = append(placeholders, values)
params = append(params, seqNum+offset, string(m),
s.BeginString, s.Qualifier,
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
}
_, err = tx.Exec(sqlString(`INSERT INTO messages (
msgseqnum, message,
beginstring, session_qualifier,
sendercompid, sendersubid, senderlocid,
targetcompid, targetsubid, targetlocid)
VALUES`+strings.Join(placeholders, ","), store.placeholder),
params...)
if err != nil {
return err
}

next := store.cache.NextSenderMsgSeqNum() + len(msg)
_, err = tx.Exec(sqlString(`UPDATE sessions SET outgoing_seqnum = ?
WHERE beginstring=? AND session_qualifier=?
AND sendercompid=? AND sendersubid=? AND senderlocid=?
AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder),
next, s.BeginString, s.Qualifier,
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
if err != nil {
return err
}

err = tx.Commit()
if err != nil {
return err
}

return store.cache.SetNextSenderMsgSeqNum(next)
}

func (store *sqlStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) {
s := store.sessionID
var msgs [][]byte
Expand Down