Skip to content

Commit

Permalink
Add QueueV2.CreateQueue and implement it for Cassandra (#4904)
Browse files Browse the repository at this point in the history
<!-- Describe what has changed in this PR -->
**What changed?**
I added the `CreateQueue` method to the `QueueV2` interface. I also
implemented it for Cassandra. Finally, I added some validation for the
read and write methods of queue messages to check that their
corresponding queue exists.

<!-- Tell your future self why you have made these changes -->
**Why?**
We need to create queues so that we can track the minimum message ID of
each queue. This prevents us from scanning over the tombstones left by
deleted messages in Cassandra when reading from the front of the queue.

<!-- How have you verified this change? Tested locally? Added a unit
test? Checked in staging env? -->
**How did you test it?**
There is 100% test coverage for everything.

<!-- Assuming the worst case, what can be broken when deploying this
change to production? -->
**Potential risks**


<!-- Is this PR a hotfix candidate, or does it require that a notification
be sent to the broader community? (Yes/No) -->
**Is hotfix candidate?**
  • Loading branch information
MichaelSnowden committed Oct 7, 2023
1 parent d54f4d8 commit b536313
Show file tree
Hide file tree
Showing 28 changed files with 1,351 additions and 151 deletions.
642 changes: 604 additions & 38 deletions api/persistence/v1/queues.pb.go

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions client/history/historytest/clienttest.go
Expand Up @@ -73,6 +73,15 @@ func TestClientGetDLQTasks(t *testing.T, historyTaskQueueManager persistence.His
targetCluster := "target-cluster-" + t.Name()
numTasks := 2

_, err := historyTaskQueueManager.CreateQueue(context.Background(), &persistence.CreateQueueRequest{
QueueKey: persistence.QueueKey{
QueueType: persistence.QueueTypeHistoryDLQ,
Category: tasks.CategoryTransfer,
SourceCluster: sourceCluster,
TargetCluster: targetCluster,
},
})
require.NoError(t, err)
enqueueTasks(t, historyTaskQueueManager, numTasks, sourceCluster, targetCluster)

listener := nettest.NewListener(nettest.NewPipe())
Expand Down
2 changes: 1 addition & 1 deletion common/persistence/cassandra/factory.go
Expand Up @@ -119,7 +119,7 @@ func (f *Factory) NewQueue(queueType p.QueueType) (p.Queue, error) {
// NewQueueV2 returns a new data-access object for queues and messages stored in Cassandra. It will never return an
// error.
//
// The store is built from the factory's session and logger; NewQueueV2Store requires both.
func (f *Factory) NewQueueV2() (p.QueueV2, error) {
	return NewQueueV2Store(f.session, f.logger), nil
}

// Close closes the factory
Expand Down
135 changes: 121 additions & 14 deletions common/persistence/cassandra/queue_v2_store.go
Expand Up @@ -26,14 +26,17 @@ package cassandra

import (
"context"
"errors"
"fmt"

commonpb "go.temporal.io/api/common/v1"
"go.temporal.io/api/enums/v1"
"go.temporal.io/api/serviceerror"

persistencespb "go.temporal.io/server/api/persistence/v1"
"go.temporal.io/server/common/log"
"go.temporal.io/server/common/persistence"
"go.temporal.io/server/common/persistence/nosql/nosqlplugin/cassandra/gocql"
"go.temporal.io/server/common/persistence/serialization"
)

type (
Expand All @@ -42,13 +45,16 @@ type (
// schema/cassandra/temporal/versioned/v1.9/queues.cql
queueV2Store struct {
session gocql.Session
logger log.Logger
}
)

const (
TemplateEnqueueMessageQuery = `INSERT INTO queue_messages (queue_type, queue_name, queue_partition, message_id, message_payload, message_encoding) VALUES (?, ?, ?, ?, ?, ?) IF NOT EXISTS`
TemplateGetMessagesQuery = `SELECT message_id, message_payload, message_encoding FROM queue_messages WHERE queue_type = ? AND queue_name = ? AND queue_partition = ? AND message_id >= ? ORDER BY message_id ASC LIMIT ?`
TemplateGetMaxMessageIDQuery = `SELECT message_id FROM queue_messages WHERE queue_type = ? AND queue_name = ? AND queue_partition = ? ORDER BY message_id DESC LIMIT 1`
TemplateCreateQueueQuery = `INSERT INTO queues (queue_type, queue_name, metadata_payload, metadata_encoding, version) VALUES (?, ?, ?, ?, ?) IF NOT EXISTS`
TemplateGetQueueQuery = `SELECT metadata_payload, metadata_encoding, version FROM queues WHERE queue_type = ? AND queue_name = ?`

// QueueMessageIDConflict will be part of the error message when a message with the same ID already exists in the
// queue. This is possible when there are concurrent writes to the queue because we enqueue a message using two
Expand Down Expand Up @@ -87,13 +93,10 @@ const (
pageTokenPrefixByte = 0
)

var (
ErrInvalidQueueMessageEncodingType = errors.New("invalid encoding type for queue message")
)

func NewQueueV2Store(session gocql.Session) persistence.QueueV2 {
func NewQueueV2Store(session gocql.Session, logger log.Logger) persistence.QueueV2 {
return &queueV2Store{
session: session,
logger: logger,
}
}

Expand All @@ -102,6 +105,11 @@ func (q *queueV2Store) EnqueueMessage(
request *persistence.InternalEnqueueMessageRequest,
) (*persistence.InternalEnqueueMessageResponse, error) {
// TODO: add concurrency control around this method to avoid things like QueueMessageIDConflict.
// TODO: cache the queue in memory to avoid querying the database every time.
_, err := q.getQueue(ctx, request.QueueType, request.QueueName)
if err != nil {
return nil, err
}
messageID, err := q.getNextMessageID(ctx, request.QueueType, request.QueueName)
if err != nil {
return nil, err
Expand All @@ -121,10 +129,14 @@ func (q *queueV2Store) ReadMessages(
ctx context.Context,
request *persistence.InternalReadMessagesRequest,
) (*persistence.InternalReadMessagesResponse, error) {
queue, err := q.getQueue(ctx, request.QueueType, request.QueueName)
if err != nil {
return nil, err
}
if request.PageSize <= 0 {
return nil, persistence.ErrNonPositiveReadQueueMessagesPageSize
}
minMessageID, err := q.getMinMessageID(request)
minMessageID, err := q.getMinMessageID(request.QueueType, request.QueueName, request.NextPageToken, queue)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -154,7 +166,7 @@ func (q *queueV2Store) ReadMessages(
}
encoding, ok := enums.EncodingType_value[messageEncoding]
if !ok {
return nil, fmt.Errorf("%w: %v", ErrInvalidQueueMessageEncodingType, messageEncoding)
return nil, serialization.NewUnknownEncodingTypeError(messageEncoding)
}

encodingType := enums.EncodingType(encoding)
Expand All @@ -181,21 +193,72 @@ func (q *queueV2Store) ReadMessages(
}, nil
}

func (q *queueV2Store) getMinMessageID(request *persistence.InternalReadMessagesRequest) (int, error) {
// TODO: start from the ack level of the queue partition instead of the first message ID when there is no token.
if len(request.NextPageToken) == 0 {
return persistence.FirstQueueMessageID, nil
// CreateQueue inserts the metadata row for the queue identified by request.QueueType and request.QueueName.
//
// The stored metadata is a proto3-encoded persistencespb.Queue with a single partition (key 0) whose MinMessageId
// is persistence.FirstQueueMessageID, written with version 0. The insert uses TemplateCreateQueueQuery, which is a
// conditional "IF NOT EXISTS" write.
//
// Errors:
//   - a wrapped marshal error if the queue metadata cannot be serialized,
//   - the result of gocql.ConvertError("QueueV2CreateQueue", err) if the query itself fails,
//   - an error wrapping persistence.ErrQueueAlreadyExists if the conditional insert was not applied.
func (q *queueV2Store) CreateQueue(
	ctx context.Context,
	request *persistence.InternalCreateQueueRequest,
) (*persistence.InternalCreateQueueResponse, error) {
	queueType := request.QueueType
	queueName := request.QueueName
	queue := persistencespb.Queue{
		Partitions: map[int32]*persistencespb.QueuePartition{
			0: {
				MinMessageId: persistence.FirstQueueMessageID,
			},
		},
	}

	// Do not silently drop a marshal failure: writing an empty payload would create a queue row that can never be
	// decoded again by readers of the metadata.
	payload, err := queue.Marshal()
	if err != nil {
		return nil, fmt.Errorf(
			"marshal metadata for queue with type %v and name %v: %w",
			queueType,
			queueName,
			err,
		)
	}

	applied, err := q.session.Query(
		TemplateCreateQueueQuery,
		queueType,
		queueName,
		payload,
		enums.ENCODING_TYPE_PROTO3.String(),
		0,
	).WithContext(ctx).MapScanCAS(make(map[string]interface{}))
	if err != nil {
		return nil, gocql.ConvertError("QueueV2CreateQueue", err)
	}

	// The conditional insert reports applied == false when a row for this queue already exists.
	if !applied {
		return nil, fmt.Errorf(
			"%w: queue type %v and name %v",
			persistence.ErrQueueAlreadyExists,
			queueType,
			queueName,
		)
	}
	return &persistence.InternalCreateQueueResponse{}, nil
}

func (q *queueV2Store) getMinMessageID(queueType persistence.QueueV2Type, name string, nextPageToken []byte, queue *persistencespb.Queue) (int, error) {
if len(nextPageToken) == 0 {
// Currently, we only have one partition for each queue. However, that might change in the future. If a queue is
// created with more than 1 partition by a server on a future release, and then that server is downgraded, we
// will need to handle this case. Since all DLQ tasks are retried infinitely, we just return an error.
numPartitions := len(queue.Partitions)
if numPartitions != 1 {
return 0, serviceerror.NewInternal(
fmt.Sprintf(
"queue with type %v and name %v has %d partitions, but this implementation only supports"+
" queues with 1 partition. Did you downgrade your Temporal server?",
queueType,
name,
numPartitions,
),
)
}
return int(queue.Partitions[0].MinMessageId), nil
}

var token persistencespb.ReadQueueMessagesNextPageToken

// Skip the first byte. See the comment on pageTokenPrefixByte for more details.
err := token.Unmarshal(request.NextPageToken[1:])
err := token.Unmarshal(nextPageToken[1:])
if err != nil {
return 0, fmt.Errorf(
"%w: %q: %v",
persistence.ErrInvalidReadQueueMessagesNextPageToken,
request.NextPageToken,
nextPageToken,
err,
)
}
Expand Down Expand Up @@ -256,3 +319,47 @@ func (q *queueV2Store) getNextMessageID(ctx context.Context, queueType persisten
// The next message ID is the max message ID + 1.
return maxMessageID + 1, nil
}

// getQueue reads the metadata row for the queue with the given type and name and decodes it into a
// persistencespb.Queue.
//
// It returns persistence.NewQueueNotFoundError when the row does not exist, the converted gocql error for any other
// query failure, an error wrapping an unknown-encoding error when the stored encoding is not proto3, and a
// deserialization error when the payload cannot be unmarshaled.
func (q *queueV2Store) getQueue(
	ctx context.Context,
	queueType persistence.QueueV2Type,
	name string,
) (*persistencespb.Queue, error) {
	var (
		payload  []byte
		encoding string
		version  int64
	)

	query := q.session.Query(TemplateGetQueueQuery, queueType, name).WithContext(ctx)
	if err := query.Scan(&payload, &encoding, &version); err != nil {
		if gocql.IsNotFoundError(err) {
			return nil, persistence.NewQueueNotFoundError(queueType, name)
		}
		return nil, gocql.ConvertError("QueueV2GetQueue", err)
	}

	// Queue metadata is only ever expected in proto3 encoding; anything else is rejected before decoding.
	if expected := enums.ENCODING_TYPE_PROTO3; encoding != expected.String() {
		return nil, fmt.Errorf(
			"queue with type %v and name %v has invalid encoding: %w",
			queueType,
			name,
			serialization.NewUnknownEncodingTypeError(encoding, expected),
		)
	}

	queue := new(persistencespb.Queue)
	if err := queue.Unmarshal(payload); err != nil {
		return nil, serialization.NewDeserializationError(
			enums.ENCODING_TYPE_PROTO3,
			fmt.Errorf("unmarshal payload for queue with type %v and name %v failed: %w", queueType, name, err),
		)
	}

	return queue, nil
}
31 changes: 28 additions & 3 deletions common/persistence/cassandra/queue_v2_store_test.go
Expand Up @@ -35,6 +35,8 @@ import (
commonpb "go.temporal.io/api/common/v1"
enumspb "go.temporal.io/api/enums/v1"
"go.temporal.io/api/serviceerror"

"go.temporal.io/server/common/log"
"go.temporal.io/server/common/persistence"
"go.temporal.io/server/common/persistence/cassandra"
"go.temporal.io/server/common/persistence/nosql/nosqlplugin/cassandra/gocql"
Expand All @@ -60,10 +62,14 @@ func (f failingQuery) Scan(...interface{}) error {
return assert.AnError
}

func TestGetMaxQueueMessageIDQueryErr(t *testing.T) {
// MapScanCAS is the failing stub for conditional (compare-and-set) queries: it ignores its destination map and
// always returns false together with assert.AnError, so callers exercise their query-error path.
func (f failingQuery) MapScanCAS(map[string]interface{}) (bool, error) {
return false, assert.AnError
}

func TestQueueV2EnqueueMessageQueryErr(t *testing.T) {
t.Parallel()

q := cassandra.NewQueueV2Store(failingSession{})
q := newQueue()
_, err := q.EnqueueMessage(context.Background(), &persistence.InternalEnqueueMessageRequest{
QueueType: persistence.QueueTypeHistoryNormal,
QueueName: "test-queue-" + t.Name(),
Expand All @@ -72,7 +78,26 @@ func TestGetMaxQueueMessageIDQueryErr(t *testing.T) {
Data: []byte("1"),
},
})
assertUnavailable(t, err)
}

func TestQueueV2CreateQueueQueryErr(t *testing.T) {
t.Parallel()

q := newQueue()
_, err := q.CreateQueue(context.Background(), &persistence.InternalCreateQueueRequest{
QueueType: persistence.QueueTypeHistoryNormal,
QueueName: "test-queue-" + t.Name(),
})
assertUnavailable(t, err)
}

// assertUnavailable verifies that err is a *serviceerror.Unavailable that carries the message of assert.AnError,
// i.e. that the injected query failure was converted and propagated.
//
// The helper is shared by tests that exercise different store operations, so it deliberately does not assert on
// any operation-specific label (such as "QueueV2GetMaxMessageID"): each operation tags its converted error with
// its own name, and a hard-coded label would fail for every other caller.
func assertUnavailable(t *testing.T, err error) {
	t.Helper()
	assert.ErrorAs(t, err, new(*serviceerror.Unavailable))
	assert.ErrorContains(t, err, assert.AnError.Error())
}

// newQueue builds a QueueV2 store on top of failingSession, paired with a test logger, for the query-error tests
// in this file.
func newQueue() persistence.QueueV2 {
	session := failingSession{}
	logger := log.NewTestLogger()
	return cassandra.NewQueueV2Store(session, logger)
}
22 changes: 19 additions & 3 deletions common/persistence/client/fault_injection.go
Expand Up @@ -148,6 +148,7 @@ func (d *FaultInjectionDataStoreFactory) UpdateRate(rate float64) {
d.Queue.UpdateRate(rate)
d.ClusterMDStore.UpdateRate(rate)
}

func (d *FaultInjectionDataStoreFactory) NewTaskStore() (persistence.TaskStore, error) {
if d.TaskStore == nil {
baseFactory, err := d.baseFactory.NewTaskStore()
Expand All @@ -168,7 +169,6 @@ func (d *FaultInjectionDataStoreFactory) NewTaskStore() (persistence.TaskStore,
}
return d.TaskStore, nil
}

func (d *FaultInjectionDataStoreFactory) NewShardStore() (persistence.ShardStore, error) {
if d.ShardStore == nil {
baseFactory, err := d.baseFactory.NewShardStore()
Expand Down Expand Up @@ -456,20 +456,36 @@ func NewFaultInjectionQueueV2(rate float64, baseQueue persistence.QueueV2) *Faul
}
}

func (f *FaultInjectionQueueV2) EnqueueMessage(ctx context.Context, request *persistence.InternalEnqueueMessageRequest) (*persistence.InternalEnqueueMessageResponse, error) {
func (f *FaultInjectionQueueV2) EnqueueMessage(
ctx context.Context,
request *persistence.InternalEnqueueMessageRequest,
) (*persistence.InternalEnqueueMessageResponse, error) {
if err := f.ErrorGenerator.Generate(); err != nil {
return nil, err
}
return f.baseQueue.EnqueueMessage(ctx, request)
}

func (f *FaultInjectionQueueV2) ReadMessages(ctx context.Context, request *persistence.InternalReadMessagesRequest) (*persistence.InternalReadMessagesResponse, error) {
func (f *FaultInjectionQueueV2) ReadMessages(
ctx context.Context,
request *persistence.InternalReadMessagesRequest,
) (*persistence.InternalReadMessagesResponse, error) {
if err := f.ErrorGenerator.Generate(); err != nil {
return nil, err
}
return f.baseQueue.ReadMessages(ctx, request)
}

// CreateQueue gives the error generator a chance to inject a fault before delegating. An injected fault is
// returned as-is and the base queue is not called; with no fault, the request goes straight to the base queue.
func (f *FaultInjectionQueueV2) CreateQueue(
	ctx context.Context,
	request *persistence.InternalCreateQueueRequest,
) (*persistence.InternalCreateQueueResponse, error) {
	injected := f.ErrorGenerator.Generate()
	if injected != nil {
		return nil, injected
	}
	return f.baseQueue.CreateQueue(ctx, request)
}

func NewFaultInjectionExecutionStore(
rate float64,
executionStore persistence.ExecutionStore,
Expand Down

0 comments on commit b536313

Please sign in to comment.