diff --git a/Makefile b/Makefile index d339b664..b68e6eb8 100644 --- a/Makefile +++ b/Makefile @@ -89,7 +89,7 @@ e2e-test: build-all-linux ## Run end-to-end tests (hermetic, auto-builds binarie fmt: ## Format Go and YAML code @echo "Formatting Go code..." - @$(BAZEL) run @rules_go//go -- run golang.org/x/tools/cmd/goimports@$(GOIMPORTS_VERSION) -w . + @find . -name '*.go' -not -path './pkg/*' -not -path './bazel-*' | xargs $(BAZEL) run @rules_go//go -- run golang.org/x/tools/cmd/goimports@$(GOIMPORTS_VERSION) -w @echo "Formatting YAML files..." @$(BAZEL) run @rules_go//go -- run github.com/google/yamlfmt/cmd/yamlfmt@$(YAMLFMT_VERSION) @echo "Formatting complete!" diff --git a/extension/queue/mysql/BUILD.bazel b/extension/queue/mysql/BUILD.bazel index c2f332f1..9cd728f6 100644 --- a/extension/queue/mysql/BUILD.bazel +++ b/extension/queue/mysql/BUILD.bazel @@ -8,6 +8,7 @@ gomock( out = "mock_stores.go", mockgen_tool = _MOCKGEN, package = "mysql", + self_package = "github.com/uber/submitqueue/extension/queue/mysql", source = "stores.go", source_importpath = "github.com/uber/submitqueue/extension/queue/mysql", ) @@ -17,6 +18,7 @@ go_library( name = "mysql", srcs = [ "constants.go", + "delivery_state_store.go", "errors.go", "message_store.go", ":mock_stores_src", @@ -26,10 +28,12 @@ go_library( "sql.go", "stores.go", "subscriber.go", + "subscriber_heartbeat_store.go", ], importpath = "github.com/uber/submitqueue/extension/queue/mysql", visibility = ["//visibility:public"], deps = [ + "//core/metrics", "//entity/queue", "//extension/queue", "@com_github_uber_go_tally_v4//:tally", @@ -41,11 +45,13 @@ go_library( go_test( name = "mysql_test", srcs = [ + "delivery_state_store_test.go", "message_store_test.go", "offset_store_test.go", "partition_lease_store_test.go", "publisher_test.go", "sql_test.go", + "subscriber_heartbeat_store_test.go", "subscriber_test.go", ], embed = [":mysql"], diff --git a/extension/queue/mysql/constants.go b/extension/queue/mysql/constants.go index 
55f4dcb4..1d5c7fc9 100644 --- a/extension/queue/mysql/constants.go +++ b/extension/queue/mysql/constants.go @@ -17,16 +17,9 @@ package mysql // Common constants for frequently repeated strings across stores const ( - // Tag key (used in every Tagged() call) - tagErrorType = "error_type" - // Common log field names (used extensively across all stores) logTopic = "topic" logPartitionKey = "partition_key" logMessageID = "message_id" logError = "error" - - // Error types used across multiple methods/stores - errorBeginTx = "begin_transaction" - errorCommit = "commit" ) diff --git a/extension/queue/mysql/ctl/commands.go b/extension/queue/mysql/ctl/commands.go index be3a3e4c..31e2ecb0 100644 --- a/extension/queue/mysql/ctl/commands.go +++ b/extension/queue/mysql/ctl/commands.go @@ -151,8 +151,6 @@ func newTopicStatsCmd(store **lib.AdminStore, jsonOut *bool) *cobra.Command { rows := [][]string{ {"Topic", stats.Topic}, {"Total Messages", strconv.FormatInt(stats.TotalMessages, 10)}, - {"Visible Messages", strconv.FormatInt(stats.VisibleMessages, 10)}, - {"Invisible Messages", strconv.FormatInt(stats.InvisibleMessages, 10)}, {"DLQ Count", strconv.FormatInt(stats.DLQCount, 10)}, {"Partitions", strconv.FormatInt(stats.PartitionCount, 10)}, {"Consumer Groups", strconv.FormatInt(stats.ConsumerGroupCount, 10)}, @@ -181,16 +179,15 @@ func newListMessagesCmd(store **lib.AdminStore, jsonOut *bool) *cobra.Command { if *jsonOut { return lib.FormatJSON(os.Stdout, messages) } - headers := []string{"OFFSET", "ID", "PARTITION", "RETRIES", "INVISIBLE_UNTIL", "CREATED_AT"} + headers := []string{"OFFSET", "ID", "PARTITION", "CREATED_AT", "PUBLISHED_AT"} var rows [][]string for _, m := range messages { rows = append(rows, []string{ strconv.FormatInt(m.Offset, 10), m.ID, m.PartitionKey, - strconv.Itoa(m.RetryCount), - lib.FormatMillis(m.InvisibleUntil), lib.FormatMillis(m.CreatedAt), + lib.FormatMillis(m.PublishedAt), }) } lib.FormatTable(os.Stdout, headers, rows) @@ -226,8 +223,6 @@ func 
newInspectMessageCmd(store **lib.AdminStore, jsonOut *bool) *cobra.Command {"ID", detail.ID}, {"Topic", detail.Topic}, {"Partition", detail.PartitionKey}, - {"Retry Count", strconv.Itoa(detail.RetryCount)}, - {"Invisible Until", lib.FormatMillis(detail.InvisibleUntil)}, {"Created At", lib.FormatMillis(detail.CreatedAt)}, {"Published At", lib.FormatMillis(detail.PublishedAt)}, {"Payload", string(detail.Payload)}, @@ -321,14 +316,13 @@ func newListDLQCmd(store **lib.AdminStore, jsonOut *bool) *cobra.Command { if *jsonOut { return lib.FormatJSON(os.Stdout, messages) } - headers := []string{"OFFSET", "ID", "PARTITION", "RETRIES", "CREATED_AT"} + headers := []string{"OFFSET", "ID", "PARTITION", "CREATED_AT"} var rows [][]string for _, m := range messages { rows = append(rows, []string{ strconv.FormatInt(m.Offset, 10), m.ID, m.PartitionKey, - strconv.Itoa(m.RetryCount), lib.FormatMillis(m.CreatedAt), }) } diff --git a/extension/queue/mysql/ctl/lib/admin.go b/extension/queue/mysql/ctl/lib/admin.go index dea3b758..a3d6edee 100644 --- a/extension/queue/mysql/ctl/lib/admin.go +++ b/extension/queue/mysql/ctl/lib/admin.go @@ -45,10 +45,6 @@ type MessageSummary struct { Topic string // PartitionKey determines message distribution PartitionKey string - // RetryCount tracks retries on the current topic - RetryCount int - // InvisibleUntil is the epoch milliseconds until which the message is hidden - InvisibleUntil int64 // CreatedAt is the epoch milliseconds when the message was created CreatedAt int64 // PublishedAt is the epoch milliseconds when the message was published @@ -116,10 +112,6 @@ type TopicStats struct { Topic string // TotalMessages is the total number of messages TotalMessages int64 - // VisibleMessages is the count of messages currently visible for consumption - VisibleMessages int64 - // InvisibleMessages is the count of messages hidden by visibility timeout - InvisibleMessages int64 // DLQCount is the number of messages in the DLQ for this topic DLQCount int64 
// PartitionCount is the number of distinct partitions @@ -154,7 +146,6 @@ func (s *AdminStore) ListTopics(ctx context.Context) ([]TopicInfo, error) { // GetTopicStats returns detailed statistics for a topic. func (s *AdminStore) GetTopicStats(ctx context.Context, topic string, dlqSuffix string) (TopicStats, error) { stats := TopicStats{Topic: topic} - nowMs := time.Now().UnixMilli() // Total messages err := s.db.QueryRowContext(ctx, @@ -165,18 +156,6 @@ func (s *AdminStore) GetTopicStats(ctx context.Context, topic string, dlqSuffix return stats, fmt.Errorf("count total: %w", err) } - // Visible messages (invisible_until <= now) - err = s.db.QueryRowContext(ctx, - fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE topic = ? AND invisible_until <= ?", mysql.MessagesTableName), - topic, nowMs, - ).Scan(&stats.VisibleMessages) - if err != nil { - return stats, fmt.Errorf("count visible: %w", err) - } - - // Invisible messages - stats.InvisibleMessages = stats.TotalMessages - stats.VisibleMessages - // DLQ count dlqTopic := topic + dlqSuffix err = s.db.QueryRowContext(ctx, @@ -215,12 +194,12 @@ func (s *AdminStore) ListMessages(ctx context.Context, topic string, partition s if partition != "" { rows, err = s.db.QueryContext(ctx, - fmt.Sprintf("SELECT `offset`, id, topic, partition_key, retry_count, invisible_until, created_at, published_at FROM %s WHERE topic = ? AND partition_key = ? ORDER BY `offset` LIMIT ?", mysql.MessagesTableName), + fmt.Sprintf("SELECT `offset`, id, topic, partition_key, created_at, published_at FROM %s WHERE topic = ? AND partition_key = ? ORDER BY `offset` LIMIT ?", mysql.MessagesTableName), topic, partition, limit, ) } else { rows, err = s.db.QueryContext(ctx, - fmt.Sprintf("SELECT `offset`, id, topic, partition_key, retry_count, invisible_until, created_at, published_at FROM %s WHERE topic = ? 
ORDER BY `offset` LIMIT ?", mysql.MessagesTableName), + fmt.Sprintf("SELECT `offset`, id, topic, partition_key, created_at, published_at FROM %s WHERE topic = ? ORDER BY `offset` LIMIT ?", mysql.MessagesTableName), topic, limit, ) } @@ -232,7 +211,7 @@ func (s *AdminStore) ListMessages(ctx context.Context, topic string, partition s var messages []MessageSummary for rows.Next() { var m MessageSummary - if err := rows.Scan(&m.Offset, &m.ID, &m.Topic, &m.PartitionKey, &m.RetryCount, &m.InvisibleUntil, &m.CreatedAt, &m.PublishedAt); err != nil { + if err := rows.Scan(&m.Offset, &m.ID, &m.Topic, &m.PartitionKey, &m.CreatedAt, &m.PublishedAt); err != nil { return nil, fmt.Errorf("scan message row: %w", err) } messages = append(messages, m) @@ -246,9 +225,9 @@ func (s *AdminStore) InspectMessage(ctx context.Context, topic string, messageID var metadataJSON []byte err := s.db.QueryRowContext(ctx, - fmt.Sprintf("SELECT `offset`, id, topic, partition_key, retry_count, invisible_until, created_at, published_at, payload, metadata, failed_at, failure_count, last_error, original_topic FROM %s WHERE topic = ? AND id = ?", mysql.MessagesTableName), + fmt.Sprintf("SELECT `offset`, id, topic, partition_key, created_at, published_at, payload, metadata, failed_at, failure_count, last_error, original_topic FROM %s WHERE topic = ? 
AND id = ?", mysql.MessagesTableName), topic, messageID, - ).Scan(&d.Offset, &d.ID, &d.Topic, &d.PartitionKey, &d.RetryCount, &d.InvisibleUntil, &d.CreatedAt, &d.PublishedAt, &d.Payload, &metadataJSON, &d.FailedAt, &d.FailureCount, &d.LastError, &d.OriginalTopic) + ).Scan(&d.Offset, &d.ID, &d.Topic, &d.PartitionKey, &d.CreatedAt, &d.PublishedAt, &d.Payload, &metadataJSON, &d.FailedAt, &d.FailureCount, &d.LastError, &d.OriginalTopic) if err == sql.ErrNoRows { return d, false, nil } @@ -323,7 +302,7 @@ func (s *AdminStore) RequeueDLQ(ctx context.Context, topic string, messageID str // Insert into original topic with reset fields nowMs := time.Now().UnixMilli() _, err = tx.ExecContext(ctx, - fmt.Sprintf("INSERT INTO %s (topic, partition_key, id, payload, metadata, retry_count, invisible_until, created_at, published_at, failed_at, failure_count, last_error, original_topic) VALUES (?, ?, ?, ?, ?, 0, 0, ?, ?, 0, 0, '', '')", mysql.MessagesTableName), + fmt.Sprintf("INSERT INTO %s (topic, partition_key, id, payload, metadata, created_at, published_at, failed_at, failure_count, last_error, original_topic) VALUES (?, ?, ?, ?, ?, ?, ?, 0, 0, '', '')", mysql.MessagesTableName), topic, partitionKey, messageID, payload, metadataJSON, createdAt, nowMs, ) if err != nil { diff --git a/extension/queue/mysql/ctl/lib/admin_test.go b/extension/queue/mysql/ctl/lib/admin_test.go index 75244e0c..46e07c2b 100644 --- a/extension/queue/mysql/ctl/lib/admin_test.go +++ b/extension/queue/mysql/ctl/lib/admin_test.go @@ -75,11 +75,6 @@ func TestGetTopicStats(t *testing.T) { WithArgs("orders"). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(100)) - // Visible messages - mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM queue_messages WHERE topic = \\? AND invisible_until <= \\?"). - WithArgs("orders", sqlmock.AnyArg()). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(80)) - // DLQ count mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM queue_messages WHERE topic = \\?"). 
WithArgs("orders_dlq"). @@ -99,8 +94,6 @@ func TestGetTopicStats(t *testing.T) { require.NoError(t, err) assert.Equal(t, "orders", stats.Topic) assert.Equal(t, int64(100), stats.TotalMessages) - assert.Equal(t, int64(80), stats.VisibleMessages) - assert.Equal(t, int64(20), stats.InvisibleMessages) assert.Equal(t, int64(3), stats.DLQCount) assert.Equal(t, int64(4), stats.PartitionCount) assert.Equal(t, int64(2), stats.ConsumerGroupCount) @@ -114,9 +107,9 @@ func TestListMessages(t *testing.T) { store := NewAdminStore(db) - rows := sqlmock.NewRows([]string{"offset", "id", "topic", "partition_key", "retry_count", "invisible_until", "created_at", "published_at"}). - AddRow(1, "msg-1", "orders", "repo-1", 0, 0, 1000, 1000). - AddRow(2, "msg-2", "orders", "repo-1", 1, 5000, 2000, 2000) + rows := sqlmock.NewRows([]string{"offset", "id", "topic", "partition_key", "created_at", "published_at"}). + AddRow(1, "msg-1", "orders", "repo-1", 1000, 1000). + AddRow(2, "msg-2", "orders", "repo-1", 2000, 2000) mock.ExpectQuery("SELECT .+ FROM queue_messages WHERE topic = \\? ORDER BY `offset` LIMIT \\?"). WithArgs("orders", 50). WillReturnRows(rows) @@ -127,7 +120,6 @@ func TestListMessages(t *testing.T) { assert.Equal(t, "msg-1", messages[0].ID) assert.Equal(t, int64(1), messages[0].Offset) assert.Equal(t, "msg-2", messages[1].ID) - assert.Equal(t, 1, messages[1].RetryCount) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -138,8 +130,8 @@ func TestListMessagesWithPartition(t *testing.T) { store := NewAdminStore(db) - rows := sqlmock.NewRows([]string{"offset", "id", "topic", "partition_key", "retry_count", "invisible_until", "created_at", "published_at"}). - AddRow(1, "msg-1", "orders", "repo-1", 0, 0, 1000, 1000) + rows := sqlmock.NewRows([]string{"offset", "id", "topic", "partition_key", "created_at", "published_at"}). + AddRow(1, "msg-1", "orders", "repo-1", 1000, 1000) mock.ExpectQuery("SELECT .+ FROM queue_messages WHERE topic = \\? AND partition_key = \\? 
ORDER BY `offset` LIMIT \\?"). WithArgs("orders", "repo-1", 10). WillReturnRows(rows) @@ -158,8 +150,8 @@ func TestInspectMessage(t *testing.T) { store := NewAdminStore(db) - rows := sqlmock.NewRows([]string{"offset", "id", "topic", "partition_key", "retry_count", "invisible_until", "created_at", "published_at", "payload", "metadata", "failed_at", "failure_count", "last_error", "original_topic"}). - AddRow(1, "msg-1", "orders", "repo-1", 0, 0, 1000, 1000, []byte("hello"), []byte(`{"key":"val"}`), 0, 0, "", "") + rows := sqlmock.NewRows([]string{"offset", "id", "topic", "partition_key", "created_at", "published_at", "payload", "metadata", "failed_at", "failure_count", "last_error", "original_topic"}). + AddRow(1, "msg-1", "orders", "repo-1", 1000, 1000, []byte("hello"), []byte(`{"key":"val"}`), 0, 0, "", "") mock.ExpectQuery("SELECT .+ FROM queue_messages WHERE topic = \\? AND id = \\?"). WithArgs("orders", "msg-1"). WillReturnRows(rows) @@ -181,7 +173,7 @@ func TestInspectMessageNotFound(t *testing.T) { store := NewAdminStore(db) - rows := sqlmock.NewRows([]string{"offset", "id", "topic", "partition_key", "retry_count", "invisible_until", "created_at", "published_at", "payload", "metadata", "failed_at", "failure_count", "last_error", "original_topic"}) + rows := sqlmock.NewRows([]string{"offset", "id", "topic", "partition_key", "created_at", "published_at", "payload", "metadata", "failed_at", "failure_count", "last_error", "original_topic"}) mock.ExpectQuery("SELECT .+ FROM queue_messages WHERE topic = \\? AND id = \\?"). WithArgs("orders", "missing"). WillReturnRows(rows) diff --git a/extension/queue/mysql/delivery_state_store.go b/extension/queue/mysql/delivery_state_store.go new file mode 100644 index 00000000..e16ee95e --- /dev/null +++ b/extension/queue/mysql/delivery_state_store.go @@ -0,0 +1,284 @@ +// Copyright (c) 2025 Uber Technologies, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/uber-go/tally/v4" + "github.com/uber/submitqueue/core/metrics" + "go.uber.org/zap" +) + +// sqldeliveryStateStore is the SQL implementation of deliveryStateStore +type sqldeliveryStateStore struct { + db *sql.DB + logger *zap.SugaredLogger + scope tally.Scope +} + +// newDeliveryStateStore creates a new SQL delivery state store +func newDeliveryStateStore(db *sql.DB, logger *zap.SugaredLogger, scope tally.Scope) deliveryStateStore { + return &sqldeliveryStateStore{ + db: db, + logger: logger.Named("delivery_state_store"), + scope: scope.SubScope("delivery_state_store"), + } +} + +// MarkDelivered inserts a row marking message as in-flight for this consumer group. +// Returns the resulting retry_count after the operation. +// +// The INSERT and subsequent SELECT are not in a transaction. This is safe because +// partition leasing guarantees a single writer per (consumer_group, topic, partition_key) +// — only the lease holder calls MarkDelivered for a given partition, so no concurrent +// mutation can occur between the two statements. 
+func (s *sqldeliveryStateStore) MarkDelivered(ctx context.Context, consumerGroup, topic, partitionKey string, offset int64, visibilityTimeoutMs int64) (_ int, retErr error) { + op := metrics.Begin(s.scope, "mark_delivered", + metrics.NewTag("topic", topic), + metrics.NewTag("consumer_group", consumerGroup), + metrics.NewTag("partition_key", partitionKey)) + defer func() { op.Complete(retErr) }() + + now := time.Now().UnixMilli() + invisibleUntil := now + visibilityTimeoutMs + + _, err := s.db.ExecContext(ctx, fmt.Sprintf(` + INSERT INTO %s (consumer_group, topic, partition_key, message_offset, acked, invisible_until, retry_count) + VALUES (?, ?, ?, ?, FALSE, ?, 0) + ON DUPLICATE KEY UPDATE + invisible_until = IF(acked = FALSE, VALUES(invisible_until), invisible_until), + retry_count = IF(acked = FALSE, retry_count + 1, retry_count) + `, DeliveryStateTableName), + consumerGroup, topic, partitionKey, offset, invisibleUntil) + + if err != nil { + return 0, fmt.Errorf("mark delivered topic=%s partition=%s offset=%d: %w", topic, partitionKey, offset, err) + } + + // Read retry_count after INSERT/UPDATE to get the current value. + // For new inserts, retry_count = 0. For updates (redelivery), retry_count was incremented. + // If this SELECT fails, the caller (consumer lifecycle) nacks and redelivers the message, + // which re-executes MarkDelivered idempotently — so the failure is safe to propagate. + var retryCount int + err = s.db.QueryRowContext(ctx, fmt.Sprintf(` + SELECT retry_count FROM %s + WHERE consumer_group = ? AND topic = ? AND partition_key = ? AND message_offset = ? + `, DeliveryStateTableName), consumerGroup, topic, partitionKey, offset).Scan(&retryCount) + if err != nil { + return 0, fmt.Errorf("get retry count after mark delivered topic=%s partition=%s offset=%d: %w", topic, partitionKey, offset, err) + } + + return retryCount, nil +} + +// ExtendVisibility extends the visibility timeout for an in-flight message +// without incrementing retry_count. 
Used by ExtendVisibilityTimeout. +func (s *sqldeliveryStateStore) ExtendVisibility(ctx context.Context, consumerGroup, topic, partitionKey string, offset int64, visibilityTimeoutMs int64) (retErr error) { + op := metrics.Begin(s.scope, "extend_visibility", + metrics.NewTag("topic", topic), + metrics.NewTag("consumer_group", consumerGroup), + metrics.NewTag("partition_key", partitionKey)) + defer func() { op.Complete(retErr) }() + + now := time.Now().UnixMilli() + invisibleUntil := now + visibilityTimeoutMs + + result, err := s.db.ExecContext(ctx, fmt.Sprintf(` + UPDATE %s + SET invisible_until = ? + WHERE consumer_group = ? AND topic = ? AND partition_key = ? AND message_offset = ? AND acked = FALSE + `, DeliveryStateTableName), + invisibleUntil, consumerGroup, topic, partitionKey, offset) + + if err != nil { + return fmt.Errorf("extend visibility topic=%s partition=%s offset=%d: %w", topic, partitionKey, offset, err) + } + + rowsAffected, raErr := result.RowsAffected() + if raErr == nil && rowsAffected == 0 { + s.logger.Warnw("extend visibility matched no rows, lease may have expired or message already acked", + logTopic, topic, + logPartitionKey, partitionKey, + "offset", offset, + ) + } + + return nil +} + +// MarkAcked sets acked = TRUE to indicate this group has processed the message. 
+func (s *sqldeliveryStateStore) MarkAcked(ctx context.Context, consumerGroup, topic, partitionKey string, offset int64) (retErr error) { + op := metrics.Begin(s.scope, "mark_acked", + metrics.NewTag("topic", topic), + metrics.NewTag("consumer_group", consumerGroup), + metrics.NewTag("partition_key", partitionKey)) + defer func() { op.Complete(retErr) }() + + _, err := s.db.ExecContext(ctx, fmt.Sprintf(` + INSERT INTO %s (consumer_group, topic, partition_key, message_offset, acked, invisible_until, retry_count) + VALUES (?, ?, ?, ?, TRUE, 0, 0) + ON DUPLICATE KEY UPDATE acked = TRUE + `, DeliveryStateTableName), + consumerGroup, topic, partitionKey, offset) + + if err != nil { + return fmt.Errorf("mark acked topic=%s partition=%s offset=%d: %w", topic, partitionKey, offset, err) + } + + return nil +} + +// MarkNacked sets invisible_until = now + delay to schedule redelivery. +// retry_count is NOT incremented here — it is incremented by MarkDelivered on redelivery. +func (s *sqldeliveryStateStore) MarkNacked(ctx context.Context, consumerGroup, topic, partitionKey string, offset int64, delayMs int64) (retErr error) { + op := metrics.Begin(s.scope, "mark_nacked", + metrics.NewTag("topic", topic), + metrics.NewTag("consumer_group", consumerGroup), + metrics.NewTag("partition_key", partitionKey)) + defer func() { op.Complete(retErr) }() + + now := time.Now().UnixMilli() + invisibleUntil := now + delayMs + + _, err := s.db.ExecContext(ctx, fmt.Sprintf(` + INSERT INTO %s (consumer_group, topic, partition_key, message_offset, acked, invisible_until, retry_count) + VALUES (?, ?, ?, ?, FALSE, ?, 0) + ON DUPLICATE KEY UPDATE + invisible_until = IF(acked = FALSE, VALUES(invisible_until), invisible_until) + `, DeliveryStateTableName), + consumerGroup, topic, partitionKey, offset, invisibleUntil) + + if err != nil { + return fmt.Errorf("mark nacked topic=%s partition=%s offset=%d: %w", topic, partitionKey, offset, err) + } + + return nil +} + +// GetDeliveryState returns the 
full delivery state for a message offset. +// Returns (state, found, error). found=false means no row (never delivered). +func (s *sqldeliveryStateStore) GetDeliveryState(ctx context.Context, consumerGroup, topic, partitionKey string, offset int64) (_ DeliveryState, _ bool, retErr error) { + op := metrics.Begin(s.scope, "get_delivery_state", + metrics.NewTag("topic", topic), + metrics.NewTag("consumer_group", consumerGroup), + metrics.NewTag("partition_key", partitionKey)) + defer func() { op.Complete(retErr) }() + + var state DeliveryState + err := s.db.QueryRowContext(ctx, fmt.Sprintf(` + SELECT acked, invisible_until, retry_count FROM %s + WHERE consumer_group = ? AND topic = ? AND partition_key = ? AND message_offset = ? + `, DeliveryStateTableName), consumerGroup, topic, partitionKey, offset).Scan(&state.Acked, &state.InvisibleUntil, &state.RetryCount) + + if err == sql.ErrNoRows { + return DeliveryState{}, false, nil + } + if err != nil { + return DeliveryState{}, false, fmt.Errorf("get delivery state topic=%s partition=%s offset=%d: %w", topic, partitionKey, offset, err) + } + + return state, true, nil +} + +// AdvanceWatermark computes the new contiguous acked watermark and cleans up +// delivery state rows that are behind it. +// offsets are the actual message offsets above the current watermark (from messageStore). +// Returns the new watermark (highest contiguous acked offset from currentWatermark). +func (s *sqldeliveryStateStore) AdvanceWatermark(ctx context.Context, consumerGroup, topic, partitionKey string, currentWatermark int64, offsets []int64) (_ int64, retErr error) { + op := metrics.Begin(s.scope, "advance_watermark", + metrics.NewTag("topic", topic), + metrics.NewTag("consumer_group", consumerGroup), + metrics.NewTag("partition_key", partitionKey)) + defer func() { op.Complete(retErr) }() + + if len(offsets) == 0 { + return currentWatermark, nil + } + + // Batch-fetch delivery state for the provided offsets. 
+ placeholders := make([]byte, 0, len(offsets)*2-1) + args := make([]interface{}, 0, 3+len(offsets)) + args = append(args, consumerGroup, topic, partitionKey) + for i, offset := range offsets { + if i > 0 { + placeholders = append(placeholders, ',') + } + placeholders = append(placeholders, '?') + args = append(args, offset) + } + + rows, err := s.db.QueryContext(ctx, fmt.Sprintf(` + SELECT message_offset, acked FROM %s + WHERE consumer_group = ? AND topic = ? AND partition_key = ? + AND message_offset IN (%s) + `, DeliveryStateTableName, string(placeholders)), args...) + if err != nil { + return currentWatermark, fmt.Errorf("query delivery state for watermark topic=%s partition=%s: %w", topic, partitionKey, err) + } + defer rows.Close() + + // Build lookup map: offset -> acked + ackedMap := make(map[int64]bool, len(offsets)) + for rows.Next() { + var offset int64 + var acked bool + if err := rows.Scan(&offset, &acked); err != nil { + return currentWatermark, fmt.Errorf("scan delivery state topic=%s partition=%s: %w", topic, partitionKey, err) + } + ackedMap[offset] = acked + } + if err := rows.Err(); err != nil { + return currentWatermark, fmt.Errorf("delivery state iteration topic=%s partition=%s: %w", topic, partitionKey, err) + } + + // Walk message offsets in order. Advance while contiguous acked. + // Stop at first offset that is not acked (in-flight, nacked, or undelivered). + newWatermark := currentWatermark + for _, offset := range offsets { + acked, exists := ackedMap[offset] + if !exists || !acked { + // No delivery state (undelivered) or not acked — stop + break + } + newWatermark = offset + } + + // Cleanup error is swallowed because the watermark was already computed and + // will be returned to the caller. The stale delivery state rows behind the + // watermark are harmless — they are never read again (all queries use + // offset > watermark). Cleanup is retried on the next AdvanceWatermark call. 
+ if newWatermark > currentWatermark { + _, err := s.db.ExecContext(ctx, fmt.Sprintf(` + DELETE FROM %s + WHERE consumer_group = ? AND topic = ? AND partition_key = ? AND message_offset <= ? + `, DeliveryStateTableName), consumerGroup, topic, partitionKey, newWatermark) + if err != nil { + metrics.NamedCounter(s.scope, "advance_watermark", "cleanup_errors", 1, + metrics.NewTag("topic", topic)) + s.logger.Warnw("failed to clean up delivery state behind watermark, will retry on next advance", + logTopic, topic, + logPartitionKey, partitionKey, + "watermark", newWatermark, + logError, err, + ) + } + } + + return newWatermark, nil +} diff --git a/extension/queue/mysql/delivery_state_store_test.go b/extension/queue/mysql/delivery_state_store_test.go new file mode 100644 index 00000000..bd5a0903 --- /dev/null +++ b/extension/queue/mysql/delivery_state_store_test.go @@ -0,0 +1,453 @@ +// Copyright (c) 2025 Uber Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package mysql + +import ( + "context" + "database/sql" + "database/sql/driver" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/uber-go/tally/v4" + "go.uber.org/zap/zaptest" +) + +func newTestDeliveryStateStoreWithMock(t *testing.T) (deliveryStateStore, *sql.DB, sqlmock.Sqlmock) { + t.Helper() + db, mock, err := sqlmock.New() + require.NoError(t, err) + store := newDeliveryStateStore(db, zaptest.NewLogger(t).Sugar(), tally.NoopScope) + return store, db, mock +} + +func TestDeliveryStateStore_MarkDelivered(t *testing.T) { + tests := []struct { + name string + execErr bool + queryErr bool + wantRetryCount int + }{ + { + name: "success new insert returns 0", + wantRetryCount: 0, + }, + { + name: "success redelivery returns incremented count", + wantRetryCount: 3, + }, + { + name: "exec error", + execErr: true, + }, + { + name: "query error on retry_count SELECT", + queryErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store, db, mock := newTestDeliveryStateStoreWithMock(t) + defer db.Close() + + if tt.execErr { + mock.ExpectExec("INSERT INTO queue_delivery_state"). + WithArgs("group-1", "orders", "part-1", int64(5), sqlmock.AnyArg()). + WillReturnError(assert.AnError) + } else { + mock.ExpectExec("INSERT INTO queue_delivery_state"). + WithArgs("group-1", "orders", "part-1", int64(5), sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + if tt.queryErr { + mock.ExpectQuery("SELECT retry_count FROM queue_delivery_state"). + WithArgs("group-1", "orders", "part-1", int64(5)). + WillReturnError(assert.AnError) + } else { + mock.ExpectQuery("SELECT retry_count FROM queue_delivery_state"). + WithArgs("group-1", "orders", "part-1", int64(5)). 
+ WillReturnRows(sqlmock.NewRows([]string{"retry_count"}).AddRow(tt.wantRetryCount)) + } + } + + retryCount, err := store.MarkDelivered(context.Background(), "group-1", "orders", "part-1", 5, 30000) + + if tt.execErr || tt.queryErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.wantRetryCount, retryCount) + } + assert.NoError(t, mock.ExpectationsWereMet()) + }) + } +} + +func TestDeliveryStateStore_ExtendVisibility(t *testing.T) { + tests := []struct { + name string + wantErr bool + }{ + { + name: "success", + wantErr: false, + }, + { + name: "db error", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store, db, mock := newTestDeliveryStateStoreWithMock(t) + defer db.Close() + + if tt.wantErr { + mock.ExpectExec("UPDATE queue_delivery_state"). + WithArgs(sqlmock.AnyArg(), "group-1", "orders", "part-1", int64(5)). + WillReturnError(assert.AnError) + } else { + mock.ExpectExec("UPDATE queue_delivery_state"). + WithArgs(sqlmock.AnyArg(), "group-1", "orders", "part-1", int64(5)). + WillReturnResult(sqlmock.NewResult(0, 1)) + } + + err := store.ExtendVisibility(context.Background(), "group-1", "orders", "part-1", 5, 60000) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + assert.NoError(t, mock.ExpectationsWereMet()) + }) + } +} + +func TestDeliveryStateStore_MarkAcked(t *testing.T) { + tests := []struct { + name string + wantErr bool + }{ + { + name: "success", + wantErr: false, + }, + { + name: "db error", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store, db, mock := newTestDeliveryStateStoreWithMock(t) + defer db.Close() + + if tt.wantErr { + mock.ExpectExec("INSERT INTO queue_delivery_state"). + WithArgs("group-1", "orders", "part-1", int64(5)). + WillReturnError(assert.AnError) + } else { + mock.ExpectExec("INSERT INTO queue_delivery_state"). 
+ WithArgs("group-1", "orders", "part-1", int64(5)). + WillReturnResult(sqlmock.NewResult(1, 1)) + } + + err := store.MarkAcked(context.Background(), "group-1", "orders", "part-1", 5) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + assert.NoError(t, mock.ExpectationsWereMet()) + }) + } +} + +func TestDeliveryStateStore_MarkNacked(t *testing.T) { + tests := []struct { + name string + wantErr bool + }{ + { + name: "success", + wantErr: false, + }, + { + name: "db error", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store, db, mock := newTestDeliveryStateStoreWithMock(t) + defer db.Close() + + if tt.wantErr { + mock.ExpectExec("INSERT INTO queue_delivery_state"). + WithArgs("group-1", "orders", "part-1", int64(5), sqlmock.AnyArg()). + WillReturnError(assert.AnError) + } else { + mock.ExpectExec("INSERT INTO queue_delivery_state"). + WithArgs("group-1", "orders", "part-1", int64(5), sqlmock.AnyArg()). 
+ WillReturnResult(sqlmock.NewResult(1, 1)) + } + + err := store.MarkNacked(context.Background(), "group-1", "orders", "part-1", 5, 5000) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + assert.NoError(t, mock.ExpectationsWereMet()) + }) + } +} + +func TestDeliveryStateStore_GetDeliveryState(t *testing.T) { + tests := []struct { + name string + acked bool + invisibleUntil int64 + retryCount int + noRows bool + wantErr bool + wantFound bool + }{ + { + name: "no delivery state row returns not found", + noRows: true, + wantFound: false, + }, + { + name: "acked message", + acked: true, + invisibleUntil: 0, + retryCount: 2, + wantFound: true, + }, + { + name: "in-flight message", + acked: false, + invisibleUntil: 9999999999999, + retryCount: 1, + wantFound: true, + }, + { + name: "expired visibility", + acked: false, + invisibleUntil: 1000, + retryCount: 3, + wantFound: true, + }, + { + name: "db error", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store, db, mock := newTestDeliveryStateStoreWithMock(t) + defer db.Close() + + if tt.wantErr { + mock.ExpectQuery("SELECT acked, invisible_until, retry_count FROM queue_delivery_state"). + WithArgs("group-1", "orders", "part-1", int64(5)). + WillReturnError(assert.AnError) + } else if tt.noRows { + mock.ExpectQuery("SELECT acked, invisible_until, retry_count FROM queue_delivery_state"). + WithArgs("group-1", "orders", "part-1", int64(5)). + WillReturnRows(sqlmock.NewRows([]string{"acked", "invisible_until", "retry_count"})) + } else { + mock.ExpectQuery("SELECT acked, invisible_until, retry_count FROM queue_delivery_state"). + WithArgs("group-1", "orders", "part-1", int64(5)). + WillReturnRows(sqlmock.NewRows([]string{"acked", "invisible_until", "retry_count"}). 
+ AddRow(tt.acked, tt.invisibleUntil, tt.retryCount)) + } + + state, found, err := store.GetDeliveryState(context.Background(), "group-1", "orders", "part-1", 5) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.wantFound, found) + if found { + assert.Equal(t, tt.acked, state.Acked) + assert.Equal(t, tt.invisibleUntil, state.InvisibleUntil) + assert.Equal(t, tt.retryCount, state.RetryCount) + } + } + assert.NoError(t, mock.ExpectationsWereMet()) + }) + } +} + +func TestDeliveryStateStore_AdvanceWatermark(t *testing.T) { + tests := []struct { + name string + currentWatermark int64 + offsets []int64 // message offsets passed in (from messageStore) + dsRows []struct { + offset int64 + acked bool + } + dsQueryErr bool + expectWatermark int64 + expectCleanup bool + }{ + { + name: "no offsets", + currentWatermark: 5, + offsets: nil, + expectWatermark: 5, + expectCleanup: false, + }, + { + name: "acked offsets advance watermark", + currentWatermark: 5, + offsets: []int64{6, 7, 8}, + dsRows: []struct { + offset int64 + acked bool + }{ + {offset: 6, acked: true}, + {offset: 7, acked: true}, + {offset: 8, acked: true}, + }, + expectWatermark: 8, + expectCleanup: true, + }, + { + name: "gap in offsets does not block advancement", + currentWatermark: 5, + offsets: []int64{6, 8}, // gap: offset 7 does not exist (AUTO_INCREMENT gap) + dsRows: []struct { + offset int64 + acked bool + }{ + {offset: 6, acked: true}, + {offset: 8, acked: true}, + }, + expectWatermark: 8, + expectCleanup: true, + }, + { + name: "non-acked offset stops advancement", + currentWatermark: 5, + offsets: []int64{6, 7, 8}, + dsRows: []struct { + offset int64 + acked bool + }{ + {offset: 6, acked: true}, + {offset: 7, acked: false}, // in-flight, not acked + {offset: 8, acked: true}, + }, + expectWatermark: 6, + expectCleanup: true, + }, + { + name: "undelivered message stops advancement", + currentWatermark: 5, + offsets: []int64{6, 7}, + dsRows: []struct 
{ + offset int64 + acked bool + }{ + {offset: 6, acked: true}, + // offset 7 has no delivery state row (undelivered) + }, + expectWatermark: 6, + expectCleanup: true, + }, + { + name: "first offset not acked means no advancement", + currentWatermark: 5, + offsets: []int64{6}, + dsRows: []struct { + offset int64 + acked bool + }{ + {offset: 6, acked: false}, // in-flight + }, + expectWatermark: 5, + expectCleanup: false, + }, + { + name: "delivery state query error returns current watermark", + currentWatermark: 5, + offsets: []int64{6, 7}, + dsQueryErr: true, + expectWatermark: 5, + expectCleanup: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store, db, mock := newTestDeliveryStateStoreWithMock(t) + defer db.Close() + + // Delivery state query is only issued if there are offsets + if len(tt.offsets) > 0 { + dsArgs := make([]driver.Value, 0, 3+len(tt.offsets)) + dsArgs = append(dsArgs, "group-1", "orders", "part-1") + for _, offset := range tt.offsets { + dsArgs = append(dsArgs, offset) + } + + if tt.dsQueryErr { + mock.ExpectQuery("SELECT message_offset, acked FROM queue_delivery_state"). + WithArgs(dsArgs...). + WillReturnError(assert.AnError) + } else { + dsResultRows := sqlmock.NewRows([]string{"message_offset", "acked"}) + for _, r := range tt.dsRows { + dsResultRows.AddRow(r.offset, r.acked) + } + mock.ExpectQuery("SELECT message_offset, acked FROM queue_delivery_state"). + WithArgs(dsArgs...). + WillReturnRows(dsResultRows) + } + } + + if tt.expectCleanup { + mock.ExpectExec("DELETE FROM queue_delivery_state"). + WithArgs("group-1", "orders", "part-1", tt.expectWatermark). 
+ WillReturnResult(sqlmock.NewResult(0, tt.expectWatermark-tt.currentWatermark)) + } + + watermark, err := store.AdvanceWatermark(context.Background(), "group-1", "orders", "part-1", tt.currentWatermark, tt.offsets) + + if tt.dsQueryErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + assert.Equal(t, tt.expectWatermark, watermark) + assert.NoError(t, mock.ExpectationsWereMet()) + }) + } +} diff --git a/extension/queue/mysql/errors.go b/extension/queue/mysql/errors.go index 5a79a4fc..4dbfba96 100644 --- a/extension/queue/mysql/errors.go +++ b/extension/queue/mysql/errors.go @@ -14,7 +14,19 @@ package mysql -import "fmt" +import ( + "errors" + "fmt" +) + +// ErrPublisherClosed is returned when attempting to publish after the publisher has been closed. +// This is a graceful error, not a programming bug — concurrent goroutines may still hold +// references to the publisher when Close() is called, and their subsequent Publish calls +// return this error to signal they should stop. +var ErrPublisherClosed = errors.New("publisher is closed") + +// ErrSubscriberClosed is returned when attempting to subscribe after the subscriber has been closed. +var ErrSubscriberClosed = errors.New("subscriber is closed") // ErrAlreadyAcknowledged is returned when attempting to ack/nack a delivery that was already processed type ErrAlreadyAcknowledged struct { @@ -24,3 +36,16 @@ type ErrAlreadyAcknowledged struct { func (e *ErrAlreadyAcknowledged) Error() string { return fmt.Sprintf("delivery %s already acknowledged or nacked", e.DeliveryID) } + +// ErrLeaseExpired is returned when a lease renewal fails because the lease +// is no longer owned by this worker (rows affected == 0). +type ErrLeaseExpired struct { + // Topic is the topic the lease was for. + Topic string + // PartitionKey is the partition the lease was for. 
+ PartitionKey string +} + +func (e *ErrLeaseExpired) Error() string { + return fmt.Sprintf("lease expired for topic=%s partition=%s", e.Topic, e.PartitionKey) +} diff --git a/extension/queue/mysql/message_store.go b/extension/queue/mysql/message_store.go index 8e5374fb..e0815451 100644 --- a/extension/queue/mysql/message_store.go +++ b/extension/queue/mysql/message_store.go @@ -24,43 +24,30 @@ import ( "github.com/uber-go/tally/v4" "go.uber.org/zap" + "github.com/uber/submitqueue/core/metrics" "github.com/uber/submitqueue/entity/queue" ) // sqlmessageStore is the SQL implementation of messageStore type sqlmessageStore struct { - db *sql.DB - logger *zap.SugaredLogger - metrics tally.Scope + db *sql.DB + logger *zap.SugaredLogger + scope tally.Scope } -// Metric names for message store -const ( - metricInsertErrors = "insert.errors" - metricFetchErrors = "fetch.errors" - metricMoveToDLQErrors = "move_to_dlq.errors" -) - // newMessageStore creates a new SQL message store -func newMessageStore(db *sql.DB, logger *zap.Logger, metrics tally.Scope) messageStore { +func newMessageStore(db *sql.DB, logger *zap.SugaredLogger, scope tally.Scope) messageStore { return &sqlmessageStore{ - db: db, - logger: logger.Sugar().Named("message_store"), - metrics: metrics.SubScope("message_store"), + db: db, + logger: logger.Named("message_store"), + scope: scope.SubScope("message_store"), } } // Insert inserts messages into the messages table -func (s *sqlmessageStore) Insert(ctx context.Context, topic string, messages []queue.Message) error { - start := time.Now() - success := false - defer func() { - result := "error" - if success { - result = "success" - } - s.metrics.Tagged(map[string]string{"result": result}).Timer("insert.latency").Record(time.Since(start)) - }() +func (s *sqlmessageStore) Insert(ctx context.Context, topic string, messages []queue.Message) (retErr error) { + op := metrics.Begin(s.scope, "insert", metrics.NewTag("topic", topic)) + defer func() { 
op.Complete(retErr) }() if len(messages) == 0 { return nil @@ -73,42 +60,26 @@ func (s *sqlmessageStore) Insert(ctx context.Context, topic string, messages []q tx, err := s.db.BeginTx(ctx, nil) if err != nil { - s.logger.Errorw("failed to begin transaction", - logTopic, topic, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "begin_transaction"}).Counter(metricInsertErrors).Inc(1) - return fmt.Errorf("failed to begin transaction: %w", err) + return fmt.Errorf("begin transaction topic=%s: %w", topic, err) } defer tx.Rollback() stmt, err := tx.PrepareContext(ctx, fmt.Sprintf(` - INSERT INTO %s (topic, id, payload, metadata, partition_key, created_at, published_at, retry_count, invisible_until, failed_at, failure_count, last_error, original_topic) - VALUES (?, ?, ?, ?, ?, ?, ?, 0, 0, 0, 0, '', '') + INSERT INTO %s (topic, id, payload, metadata, partition_key, created_at, published_at, failed_at, failure_count, last_error, original_topic) + VALUES (?, ?, ?, ?, ?, ?, ?, 0, 0, '', '') `, MessagesTableName)) if err != nil { - s.logger.Errorw("failed to prepare statement", - logTopic, topic, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "prepare_statement"}).Counter(metricInsertErrors).Inc(1) - return fmt.Errorf("failed to prepare statement: %w", err) + return fmt.Errorf("prepare statement topic=%s: %w", topic, err) } defer stmt.Close() - now := start.UnixMilli() + now := time.Now().UnixMilli() for _, msg := range messages { var metadataJSON []byte if len(msg.Metadata) > 0 { metadataJSON, err = json.Marshal(msg.Metadata) if err != nil { - s.logger.Errorw("failed to marshal metadata", - logTopic, topic, - logMessageID, msg.ID, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "marshal_metadata"}).Counter(metricInsertErrors).Inc(1) - return fmt.Errorf("failed to marshal metadata: %w", err) + return fmt.Errorf("marshal metadata topic=%s message=%s: %w", topic, msg.ID, err) } } @@ -122,124 +93,57 @@ func (s 
*sqlmessageStore) Insert(ctx context.Context, topic string, messages []q msg.PublishedAt, ) if err != nil { - s.logger.Errorw("failed to insert message", - logTopic, topic, - logMessageID, msg.ID, - logPartitionKey, msg.PartitionKey, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "exec_statement"}).Counter(metricInsertErrors).Inc(1) - return fmt.Errorf("failed to insert message: %w", err) + return fmt.Errorf("insert message topic=%s message=%s partition=%s: %w", topic, msg.ID, msg.PartitionKey, err) } } if err := tx.Commit(); err != nil { - s.logger.Errorw("failed to commit transaction", - logTopic, topic, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "commit"}).Counter(metricInsertErrors).Inc(1) - return fmt.Errorf("failed to commit transaction: %w", err) + return fmt.Errorf("commit transaction topic=%s: %w", topic, err) } - s.metrics.Counter("insert.success").Inc(1) - s.metrics.Counter("messages.inserted").Inc(int64(len(messages))) s.logger.Debugw("inserted messages", logTopic, topic, "count", len(messages), - "duration_ms", time.Since(start).Milliseconds(), ) - success = true return nil } -// Delete deletes a message by topic and ID -func (s *sqlmessageStore) Delete(ctx context.Context, topic string, messageID string) error { - start := time.Now() - success := false - defer func() { - result := "error" - if success { - result = "success" - } - s.metrics.Tagged(map[string]string{"result": result}).Timer("delete.latency").Record(time.Since(start)) - }() +// Delete deletes a message by topic, partition key, and ID +func (s *sqlmessageStore) Delete(ctx context.Context, topic string, partitionKey string, messageID string) (retErr error) { + op := metrics.Begin(s.scope, "delete", metrics.NewTag("topic", topic)) + defer func() { op.Complete(retErr) }() - result, err := s.db.ExecContext(ctx, fmt.Sprintf(` - DELETE FROM %s WHERE topic = ? AND id = ? 
- `, MessagesTableName), topic, messageID) + _, err := s.db.ExecContext(ctx, fmt.Sprintf(` + DELETE FROM %s WHERE topic = ? AND partition_key = ? AND id = ? + `, MessagesTableName), topic, partitionKey, messageID) if err != nil { - s.logger.Errorw("failed to delete message", - logTopic, topic, - logMessageID, messageID, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "exec_delete"}).Counter("delete.errors").Inc(1) - return err - } - - rows, _ := result.RowsAffected() - s.metrics.Counter("delete.success").Inc(1) - if rows > 0 { - s.metrics.Counter("messages.deleted").Inc(rows) + return fmt.Errorf("delete message topic=%s partition=%s message=%s: %w", topic, partitionKey, messageID, err) } - success = true return nil } -// FetchByOffset fetches visible messages with offset > currentOffset for a specific partition -// Atomically sets invisible_until and increments retry_count for fetched messages -func (s *sqlmessageStore) FetchByOffset(ctx context.Context, topic string, partitionKey string, currentOffset int64, limit int, visibilityTimeoutMs int64) ([]messageRow, error) { - start := time.Now() - success := false - defer func() { - result := "error" - if success { - result = "success" - } - s.metrics.Tagged(map[string]string{"result": result}).Timer("fetch.latency").Record(time.Since(start)) - }() - - now := start.UnixMilli() - invisibleUntil := now + visibilityTimeoutMs - - // Start transaction to atomically fetch and update messages - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - s.logger.Errorw("failed to begin transaction for fetch", - logTopic, topic, - logPartitionKey, partitionKey, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "begin_transaction"}).Counter(metricFetchErrors).Inc(1) - return nil, fmt.Errorf("failed to begin transaction: %w", err) - } - defer tx.Rollback() +// FetchByOffset fetches messages with offset > currentOffset for a specific partition. 
+// Messages are fetched from the immutable log; no per-message mutation occurs. +func (s *sqlmessageStore) FetchByOffset(ctx context.Context, topic string, partitionKey string, currentOffset int64, limit int) (_ []messageRow, retErr error) { + op := metrics.Begin(s.scope, "fetch", metrics.NewTag("topic", topic)) + defer func() { op.Complete(retErr) }() - // Fetch visible messages (invisible_until <= now) - rows, err := tx.QueryContext(ctx, fmt.Sprintf(` - SELECT offset, id, payload, metadata, partition_key, retry_count, published_at, failed_at, failure_count, last_error, original_topic + rows, err := s.db.QueryContext(ctx, fmt.Sprintf(` + SELECT offset, id, payload, metadata, partition_key, published_at, failed_at, failure_count, last_error, original_topic FROM %s - WHERE topic = ? AND partition_key = ? AND offset > ? AND invisible_until <= ? + WHERE topic = ? AND partition_key = ? AND offset > ? ORDER BY offset LIMIT ? - `, MessagesTableName), topic, partitionKey, currentOffset, now, limit) + `, MessagesTableName), topic, partitionKey, currentOffset, limit) if err != nil { - s.logger.Errorw("failed to query messages", - logTopic, topic, - logPartitionKey, partitionKey, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "query"}).Counter(metricFetchErrors).Inc(1) - return nil, fmt.Errorf("failed to query messages: %w", err) + return nil, fmt.Errorf("query messages topic=%s partition=%s: %w", topic, partitionKey, err) } defer rows.Close() var results []messageRow - var messageIDs []string for rows.Next() { var ( @@ -248,7 +152,6 @@ func (s *sqlmessageStore) FetchByOffset(ctx context.Context, topic string, parti payload []byte metadataJSON []byte partKey string - retryCount int publishedAtMilli int64 failedAt int64 failureCount int @@ -256,27 +159,14 @@ func (s *sqlmessageStore) FetchByOffset(ctx context.Context, topic string, parti originalTopic string ) - if err := rows.Scan(&offset, &id, &payload, &metadataJSON, &partKey, &retryCount, 
&publishedAtMilli, &failedAt, &failureCount, &lastError, &originalTopic); err != nil { - s.logger.Errorw("failed to scan message row", - logTopic, topic, - logPartitionKey, partitionKey, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "scan_row"}).Counter(metricFetchErrors).Inc(1) - return nil, fmt.Errorf("failed to scan row: %w", err) + if err := rows.Scan(&offset, &id, &payload, &metadataJSON, &partKey, &publishedAtMilli, &failedAt, &failureCount, &lastError, &originalTopic); err != nil { + return nil, fmt.Errorf("scan row topic=%s partition=%s: %w", topic, partitionKey, err) } var metadata map[string]string if len(metadataJSON) > 0 { if err := json.Unmarshal(metadataJSON, &metadata); err != nil { - s.logger.Errorw("failed to unmarshal metadata", - logTopic, topic, - logPartitionKey, partitionKey, - logMessageID, id, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "unmarshal_metadata"}).Counter(metricFetchErrors).Inc(1) - return nil, fmt.Errorf("failed to unmarshal metadata: %w", err) + return nil, fmt.Errorf("unmarshal metadata topic=%s partition=%s message=%s: %w", topic, partitionKey, id, err) } } if metadata == nil { @@ -289,111 +179,40 @@ func (s *sqlmessageStore) FetchByOffset(ctx context.Context, topic string, parti Payload: payload, Metadata: metadata, PartitionKey: partKey, - RetryCount: retryCount, PublishedAt: publishedAtMilli, FailedAt: failedAt, FailureCount: failureCount, LastError: lastError, OriginalTopic: originalTopic, }) - - messageIDs = append(messageIDs, id) } if err := rows.Err(); err != nil { - s.logger.Errorw("row iteration error", - logTopic, topic, - logPartitionKey, partitionKey, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "row_iteration"}).Counter(metricFetchErrors).Inc(1) - return nil, fmt.Errorf("row iteration error: %w", err) + return nil, fmt.Errorf("row iteration topic=%s partition=%s: %w", topic, partitionKey, err) } - // Update invisible_until and 
increment retry_count for fetched messages - if len(messageIDs) > 0 { - // Build IN clause for message IDs - placeholders := "" - for i := range messageIDs { - if i > 0 { - placeholders += "," - } - placeholders += "?" - } - - query := fmt.Sprintf(` - UPDATE %s - SET invisible_until = ?, retry_count = retry_count + 1 - WHERE topic = ? AND partition_key = ? AND id IN (%s) - `, MessagesTableName, placeholders) - - args := []interface{}{invisibleUntil, topic, partitionKey} - for _, id := range messageIDs { - args = append(args, id) - } - - _, err = tx.ExecContext(ctx, query, args...) - if err != nil { - s.logger.Errorw("failed to update message visibility", - logTopic, topic, - logPartitionKey, partitionKey, - "message_count", len(messageIDs), - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "update_visibility"}).Counter(metricFetchErrors).Inc(1) - return nil, fmt.Errorf("failed to update messages: %w", err) - } - } - - if err := tx.Commit(); err != nil { - s.logger.Errorw("failed to commit fetch transaction", - logTopic, topic, - logPartitionKey, partitionKey, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "commit"}).Counter(metricFetchErrors).Inc(1) - return nil, fmt.Errorf("failed to commit transaction: %w", err) - } - - s.metrics.Counter("fetch.success").Inc(1) - s.metrics.Counter("messages.fetched").Inc(int64(len(results))) s.logger.Debugw("fetched messages", logTopic, topic, logPartitionKey, partitionKey, "count", len(results), - "duration_ms", time.Since(start).Milliseconds(), ) - success = true return results, nil } // MoveToDLQ atomically moves a message to the DLQ by reinserting it with the DLQ topic name // The message is inserted back into queue_messages table with the DLQ topic (original + suffix) // This allows DLQ messages to be consumed using the normal subscriber -func (s *sqlmessageStore) MoveToDLQ(ctx context.Context, topic string, messageID string, failureCount int, lastError string, dlqTopicSuffix 
string) error { - start := time.Now() - success := false - defer func() { - result := "error" - if success { - result = "success" - } - s.metrics.Tagged(map[string]string{"result": result}).Timer("move_to_dlq.latency").Record(time.Since(start)) - }() +func (s *sqlmessageStore) MoveToDLQ(ctx context.Context, topic string, partitionKey string, messageID string, failureCount int, lastError string, dlqTopicSuffix string) (retErr error) { + op := metrics.Begin(s.scope, "move_to_dlq", metrics.NewTag("topic", topic)) + defer func() { op.Complete(retErr) }() // Construct DLQ topic name dlqTopic := topic + dlqTopicSuffix tx, err := s.db.BeginTx(ctx, nil) if err != nil { - s.logger.Errorw("failed to begin transaction for DLQ move", - logTopic, topic, - logMessageID, messageID, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "begin_transaction"}).Counter(metricMoveToDLQErrors).Inc(1) - return fmt.Errorf("failed to begin transaction: %w", err) + return fmt.Errorf("begin transaction topic=%s message=%s: %w", topic, messageID, err) } defer tx.Rollback() @@ -401,17 +220,16 @@ func (s *sqlmessageStore) MoveToDLQ(ctx context.Context, topic string, messageID var ( payload []byte metadataJSON []byte - partitionKey string + fetchPartKey string createdAtMilli int64 publishedAtMilli int64 - retryCount int ) err = tx.QueryRowContext(ctx, fmt.Sprintf(` - SELECT payload, metadata, partition_key, created_at, published_at, retry_count + SELECT payload, metadata, partition_key, created_at, published_at FROM %s - WHERE topic = ? AND id = ? - `, MessagesTableName), topic, messageID).Scan(&payload, &metadataJSON, &partitionKey, &createdAtMilli, &publishedAtMilli, &retryCount) + WHERE topic = ? AND partition_key = ? AND id = ? 
+ `, MessagesTableName), topic, partitionKey, messageID).Scan(&payload, &metadataJSON, &fetchPartKey, &createdAtMilli, &publishedAtMilli) if err != nil { if err == sql.ErrNoRows { @@ -422,141 +240,108 @@ func (s *sqlmessageStore) MoveToDLQ(ctx context.Context, topic string, messageID ) return nil } - s.logger.Errorw("failed to fetch message for DLQ", - logTopic, topic, - logMessageID, messageID, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "fetch_message"}).Counter(metricMoveToDLQErrors).Inc(1) - return fmt.Errorf("failed to fetch message: %w", err) + return fmt.Errorf("fetch message for DLQ topic=%s partition=%s message=%s: %w", topic, partitionKey, messageID, err) } // Insert into queue_messages table with DLQ topic name and DLQ-specific fields - // Reset retry_count to 0 since this is a new topic (DLQ processing starts fresh) - // Store the original failure count for tracking purposes - now := start.UnixMilli() + now := time.Now().UnixMilli() _, err = tx.ExecContext(ctx, fmt.Sprintf(` - INSERT INTO %s (topic, id, payload, metadata, partition_key, created_at, published_at, invisible_until, retry_count, failed_at, failure_count, last_error, original_topic) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - `, MessagesTableName), dlqTopic, messageID, payload, metadataJSON, partitionKey, createdAtMilli, publishedAtMilli, int64(0), 0, now, failureCount, lastError, topic) + INSERT INTO %s (topic, id, payload, metadata, partition_key, created_at, published_at, failed_at, failure_count, last_error, original_topic) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ `, MessagesTableName), dlqTopic, messageID, payload, metadataJSON, fetchPartKey, createdAtMilli, publishedAtMilli, now, failureCount, lastError, topic) if err != nil { - s.logger.Errorw("failed to insert into DLQ topic", - logTopic, topic, - "dlq_topic", dlqTopic, - logMessageID, messageID, - logPartitionKey, partitionKey, - "failure_count", failureCount, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "insert_dlq"}).Counter(metricMoveToDLQErrors).Inc(1) - return fmt.Errorf("failed to insert into DLQ: %w", err) + return fmt.Errorf("insert into DLQ topic=%s dlq=%s partition=%s message=%s: %w", topic, dlqTopic, partitionKey, messageID, err) } // Delete from original topic _, err = tx.ExecContext(ctx, fmt.Sprintf(` - DELETE FROM %s WHERE topic = ? AND id = ? - `, MessagesTableName), topic, messageID) + DELETE FROM %s WHERE topic = ? AND partition_key = ? AND id = ? + `, MessagesTableName), topic, partitionKey, messageID) if err != nil { - s.logger.Errorw("failed to delete from main table after DLQ insert", - logTopic, topic, - logMessageID, messageID, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "delete_from_main"}).Counter(metricMoveToDLQErrors).Inc(1) - return fmt.Errorf("failed to delete from main table: %w", err) + return fmt.Errorf("delete from main table topic=%s partition=%s message=%s: %w", topic, partitionKey, messageID, err) } if err := tx.Commit(); err != nil { - s.logger.Errorw("failed to commit DLQ transaction", - logTopic, topic, - logMessageID, messageID, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "commit"}).Counter(metricMoveToDLQErrors).Inc(1) - return fmt.Errorf("failed to commit transaction: %w", err) + return fmt.Errorf("commit DLQ transaction topic=%s message=%s: %w", topic, messageID, err) } - s.metrics.Counter("move_to_dlq.success").Inc(1) - s.metrics.Counter("messages.moved_to_dlq").Inc(1) - s.logger.Infow("moved message to DLQ", - logTopic, topic, - "dlq_topic", 
dlqTopic, - logMessageID, messageID, - logPartitionKey, partitionKey, - "failure_count", failureCount, - "last_error", lastError, - "duration_ms", time.Since(start).Milliseconds(), - ) - - success = true return nil } -// SetVisibilityTimeout sets the invisible_until timestamp for a message -// visibilityTimeoutMillis: milliseconds from now to hide the message -// If visibilityTimeoutMillis is 0, makes the message visible immediately -// If visibilityTimeoutMillis > 0, makes the message invisible until now + visibilityTimeoutMillis -func (s *sqlmessageStore) SetVisibilityTimeout(ctx context.Context, topic string, messageID string, visibilityTimeoutMillis int64) error { - start := time.Now() - success := false - defer func() { - result := "error" - if success { - result = "success" - } - s.metrics.Tagged(map[string]string{"result": result}).Timer("set_visibility.latency").Record(time.Since(start)) - }() - - var invisibleUntil int64 - if visibilityTimeoutMillis > 0 { - invisibleUntil = start.UnixMilli() + visibilityTimeoutMillis - } else { - invisibleUntil = 0 +// GarbageCollect deletes messages with offset <= minAckedOffset. +// The caller provides minAckedOffset (from offsetStore), keeping messageStore +// free of cross-table queries. +// Returns the number of rows deleted. +func (s *sqlmessageStore) GarbageCollect(ctx context.Context, topic string, partitionKey string, minAckedOffset int64) (_ int64, retErr error) { + op := metrics.Begin(s.scope, "gc", metrics.NewTag("topic", topic)) + defer func() { op.Complete(retErr) }() + + if minAckedOffset == 0 { + return 0, nil } + // Delete messages up to the minimum acked offset result, err := s.db.ExecContext(ctx, fmt.Sprintf(` - UPDATE %s - SET invisible_until = ? - WHERE topic = ? AND id = ? - `, MessagesTableName), invisibleUntil, topic, messageID) + DELETE FROM %s WHERE topic = ? AND partition_key = ? AND offset <= ? 
+ `, MessagesTableName), topic, partitionKey, minAckedOffset) if err != nil { - s.logger.Errorw("failed to set visibility timeout", - logTopic, topic, - logMessageID, messageID, - "timeout_ms", visibilityTimeoutMillis, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "exec_set"}).Counter("set_visibility.errors").Inc(1) - return fmt.Errorf("failed to set visibility timeout: %w", err) + return 0, fmt.Errorf("garbage collect messages topic=%s partition=%s: %w", topic, partitionKey, err) } - rows, err := result.RowsAffected() + // RowsAffected error is swallowed because the DELETE query itself succeeded. + // This is a driver-level diagnostic failure — the messages are already deleted. + // We log for visibility but the GC operation is complete. + deleted, err := result.RowsAffected() if err != nil { - s.logger.Warnw("failed to check rows affected", + s.logger.Warnw("garbage collect succeeded but row count unavailable (driver diagnostic failure), no impact on correctness", logTopic, topic, - logMessageID, messageID, + logPartitionKey, partitionKey, logError, err, ) } - - if rows == 0 { - s.logger.Debugw("no rows updated when setting visibility", + if deleted > 0 { + s.logger.Debugw("garbage collected messages", logTopic, topic, - logMessageID, messageID, + logPartitionKey, partitionKey, + "deleted", deleted, + "min_offset", minAckedOffset, ) + metrics.NamedCounter(s.scope, "gc", "messages_deleted", deleted, metrics.NewTag("topic", topic)) } - s.metrics.Counter("set_visibility.success").Inc(1) - s.logger.Debugw("set visibility timeout", - logTopic, topic, - logMessageID, messageID, - "timeout_ms", visibilityTimeoutMillis, - "duration_ms", time.Since(start).Milliseconds(), - ) + return deleted, nil +} - success = true - return nil +// GetOffsetsAbove returns message offsets above afterOffset for a partition, ordered ascending. 
+func (s *sqlmessageStore) GetOffsetsAbove(ctx context.Context, topic string, partitionKey string, afterOffset int64, limit int) (_ []int64, retErr error) { + op := metrics.Begin(s.scope, "get_offsets_above", metrics.NewTag("topic", topic)) + defer func() { op.Complete(retErr) }() + + rows, err := s.db.QueryContext(ctx, fmt.Sprintf(` + SELECT offset FROM %s + WHERE topic = ? AND partition_key = ? AND offset > ? + ORDER BY offset ASC + LIMIT ? + `, MessagesTableName), topic, partitionKey, afterOffset, limit) + if err != nil { + return nil, fmt.Errorf("query offsets topic=%s partition=%s: %w", topic, partitionKey, err) + } + defer rows.Close() + + var offsets []int64 + for rows.Next() { + var offset int64 + if err := rows.Scan(&offset); err != nil { + return nil, fmt.Errorf("scan offset topic=%s partition=%s: %w", topic, partitionKey, err) + } + offsets = append(offsets, offset) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("offset iteration topic=%s partition=%s: %w", topic, partitionKey, err) + } + + return offsets, nil } diff --git a/extension/queue/mysql/message_store_test.go b/extension/queue/mysql/message_store_test.go index 1f487aa9..1b9bdd18 100644 --- a/extension/queue/mysql/message_store_test.go +++ b/extension/queue/mysql/message_store_test.go @@ -17,10 +17,12 @@ package mysql import ( "context" "database/sql" + "fmt" "testing" "time" "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/uber-go/tally/v4" "go.uber.org/zap/zaptest" @@ -39,12 +41,12 @@ func setupmessageStoreTest(t *testing.T) (*sql.DB, sqlmock.Sqlmock, messageStore db, mock, err := sqlmock.New() require.NoError(t, err) - store := newMessageStore(db, zaptest.NewLogger(t), testMetrics()) + store := newMessageStore(db, zaptest.NewLogger(t).Sugar(), testMetrics()) return db, mock, store } -func TestmessageStore_Insert(t *testing.T) { +func TestMessageStore_Insert(t *testing.T) { tests := []struct { name 
string messages []queue.Message @@ -96,24 +98,25 @@ func TestmessageStore_Insert(t *testing.T) { } } -func TestmessageStore_Delete(t *testing.T) { +func TestMessageStore_Delete(t *testing.T) { db, mock, store := setupmessageStoreTest(t) defer db.Close() ctx := context.Background() topic := "test_topic" + partitionKey := "part1" messageID := "msg1" mock.ExpectExec("DELETE FROM queue_messages"). - WithArgs(topic, messageID). + WithArgs(topic, partitionKey, messageID). WillReturnResult(sqlmock.NewResult(0, 1)) - err := store.Delete(ctx, topic, messageID) + err := store.Delete(ctx, topic, partitionKey, messageID) require.NoError(t, err) require.NoError(t, mock.ExpectationsWereMet()) } -func TestmessageStore_FetchByOffset(t *testing.T) { +func TestMessageStore_FetchByOffset(t *testing.T) { db, mock, store := setupmessageStoreTest(t) defer db.Close() @@ -122,57 +125,29 @@ func TestmessageStore_FetchByOffset(t *testing.T) { partitionKey := "part1" currentOffset := int64(0) limit := 10 - visibilityTimeoutMs := int64(60000) // 60 seconds in milliseconds - - // Expect transaction begin - mock.ExpectBegin() - // Mock query results (including DLQ columns) - rows := sqlmock.NewRows([]string{"offset", "id", "payload", "metadata", "partition_key", "retry_count", "published_at", "failed_at", "failure_count", "last_error", "original_topic"}). - AddRow(int64(1), "msg1", []byte("payload1"), []byte("{}"), "part1", 0, time.Now().UnixMilli(), int64(0), 0, "", "") + // Mock query results (no transaction, simple SELECT) + rows := sqlmock.NewRows([]string{"offset", "id", "payload", "metadata", "partition_key", "published_at", "failed_at", "failure_count", "last_error", "original_topic"}). + AddRow(int64(1), "msg1", []byte("payload1"), []byte("{}"), "part1", time.Now().UnixMilli(), int64(0), 0, "", "") mock.ExpectQuery("SELECT (.+) FROM queue_messages"). - WithArgs(topic, partitionKey, currentOffset, sqlmock.AnyArg(), limit). + WithArgs(topic, partitionKey, currentOffset, limit). 
WillReturnRows(rows) - // Expect update for visibility timeout - mock.ExpectExec("UPDATE queue_messages"). - WillReturnResult(sqlmock.NewResult(0, 1)) - - // Expect commit - mock.ExpectCommit() - - results, err := store.FetchByOffset(ctx, topic, partitionKey, currentOffset, limit, visibilityTimeoutMs) + results, err := store.FetchByOffset(ctx, topic, partitionKey, currentOffset, limit) require.NoError(t, err) require.Len(t, results, 1) require.Equal(t, "msg1", results[0].ID) require.NoError(t, mock.ExpectationsWereMet()) } -func TestmessageStore_SetVisibilityTimeout(t *testing.T) { - db, mock, store := setupmessageStoreTest(t) - defer db.Close() - - ctx := context.Background() - topic := "test_topic" - messageID := "msg1" - visibilityTimeoutMillis := int64(5000) - - mock.ExpectExec("UPDATE queue_messages"). - WithArgs(sqlmock.AnyArg(), topic, messageID). - WillReturnResult(sqlmock.NewResult(0, 1)) - - err := store.SetVisibilityTimeout(ctx, topic, messageID, visibilityTimeoutMillis) - require.NoError(t, err) - require.NoError(t, mock.ExpectationsWereMet()) -} - -func TestmessageStore_MoveToDLQ(t *testing.T) { +func TestMessageStore_MoveToDLQ(t *testing.T) { db, mock, store := setupmessageStoreTest(t) defer db.Close() ctx := context.Background() topic := "test_topic" + partitionKey := "part1" messageID := "msg1" failureCount := 3 lastError := "test error" @@ -182,30 +157,150 @@ func TestmessageStore_MoveToDLQ(t *testing.T) { // Expect transaction begin mock.ExpectBegin() - // Mock query for fetching message - SELECT payload, metadata, partition_key, created_at, published_at, retry_count - rows := sqlmock.NewRows([]string{"payload", "metadata", "partition_key", "created_at", "published_at", "retry_count"}). 
- AddRow([]byte("payload1"), []byte(`{"key":"value"}`), "part1", time.Now().UnixMilli(), time.Now().UnixMilli(), failureCount) + // Mock query for fetching message (now includes partition_key in WHERE) + rows := sqlmock.NewRows([]string{"payload", "metadata", "partition_key", "created_at", "published_at"}). + AddRow([]byte("payload1"), []byte(`{"key":"value"}`), "part1", time.Now().UnixMilli(), time.Now().UnixMilli()) mock.ExpectQuery("SELECT (.+) FROM queue_messages"). - WithArgs(topic, messageID). + WithArgs(topic, partitionKey, messageID). WillReturnRows(rows) - // Expect insert into queue_messages with DLQ topic and DLQ-specific columns - // Columns: topic, id, payload, metadata, partition_key, created_at, published_at, invisible_until, retry_count, failed_at, failure_count, last_error, original_topic - // Note: retry_count is reset to 0 for DLQ processing, but failure_count preserves the original attempts + // Expect insert into queue_messages with DLQ topic mock.ExpectExec("INSERT INTO queue_messages"). - WithArgs(dlqTopic, messageID, sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), int64(0), 0, sqlmock.AnyArg(), failureCount, lastError, topic). + WithArgs(dlqTopic, messageID, sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), failureCount, lastError, topic). WillReturnResult(sqlmock.NewResult(1, 1)) - // Expect delete from main table + // Expect delete from main table (now includes partition_key in WHERE) mock.ExpectExec("DELETE FROM queue_messages"). - WithArgs(topic, messageID). + WithArgs(topic, partitionKey, messageID). 
WillReturnResult(sqlmock.NewResult(0, 1)) // Expect commit mock.ExpectCommit() - err := store.MoveToDLQ(ctx, topic, messageID, failureCount, lastError, dlqTopicSuffix) + err := store.MoveToDLQ(ctx, topic, partitionKey, messageID, failureCount, lastError, dlqTopicSuffix) require.NoError(t, err) require.NoError(t, mock.ExpectationsWereMet()) } + +func TestMessageStore_GetOffsetsAbove(t *testing.T) { + tests := []struct { + name string + afterOffset int64 + limit int + offsets []int64 + wantErr bool + }{ + { + name: "returns offsets in ascending order", + afterOffset: 5, + limit: 1000, + offsets: []int64{6, 7, 8}, + }, + { + name: "returns offsets with AUTO_INCREMENT gaps", + afterOffset: 5, + limit: 1000, + offsets: []int64{6, 9, 15}, + }, + { + name: "no messages above offset", + afterOffset: 100, + limit: 1000, + offsets: nil, + }, + { + name: "db error", + afterOffset: 0, + limit: 1000, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, mock, store := setupmessageStoreTest(t) + defer db.Close() + + if tt.wantErr { + mock.ExpectQuery("SELECT offset FROM queue_messages"). + WithArgs("test_topic", "part-1", tt.afterOffset, tt.limit). + WillReturnError(fmt.Errorf("db error")) + } else { + rows := sqlmock.NewRows([]string{"offset"}) + for _, offset := range tt.offsets { + rows.AddRow(offset) + } + mock.ExpectQuery("SELECT offset FROM queue_messages"). + WithArgs("test_topic", "part-1", tt.afterOffset, tt.limit). 
+ WillReturnRows(rows) + } + + offsets, err := store.GetOffsetsAbove(context.Background(), "test_topic", "part-1", tt.afterOffset, tt.limit) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.offsets, offsets) + } + require.NoError(t, mock.ExpectationsWereMet()) + }) + } +} + +func TestMessageStore_GarbageCollect(t *testing.T) { + tests := []struct { + name string + minAckedOffset int64 + deleteErr bool + wantDeleted int64 + wantErr bool + }{ + { + name: "deletes messages up to min offset", + minAckedOffset: 10, + wantDeleted: 5, + }, + { + name: "zero offset returns 0 deleted", + minAckedOffset: 0, + wantDeleted: 0, + }, + { + name: "delete error", + minAckedOffset: 10, + deleteErr: true, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, mock, store := setupmessageStoreTest(t) + defer db.Close() + + if tt.minAckedOffset > 0 { + if tt.deleteErr { + mock.ExpectExec("DELETE FROM queue_messages"). + WithArgs("test_topic", "part-1", tt.minAckedOffset). + WillReturnError(fmt.Errorf("db error")) + } else { + mock.ExpectExec("DELETE FROM queue_messages"). + WithArgs("test_topic", "part-1", tt.minAckedOffset). + WillReturnResult(sqlmock.NewResult(0, tt.wantDeleted)) + } + } + + deleted, err := store.GarbageCollect(context.Background(), "test_topic", "part-1", tt.minAckedOffset) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.wantDeleted, deleted) + } + require.NoError(t, mock.ExpectationsWereMet()) + }) + } +} diff --git a/extension/queue/mysql/mock_stores.go b/extension/queue/mysql/mock_stores.go new file mode 100644 index 00000000..cf8477de --- /dev/null +++ b/extension/queue/mysql/mock_stores.go @@ -0,0 +1,489 @@ +// Code generated by MockGen. DO NOT EDIT. 
+// Source: stores.go +// +// Generated by this command: +// +// mockgen -source=stores.go -destination=mock_stores.go -package=mysql +// + +// Package mysql is a generated GoMock package. +package mysql + +import ( + context "context" + reflect "reflect" + + queue "github.com/uber/submitqueue/entity/queue" + gomock "go.uber.org/mock/gomock" +) + +// MockmessageStore is a mock of messageStore interface. +type MockmessageStore struct { + ctrl *gomock.Controller + recorder *MockmessageStoreMockRecorder + isgomock struct{} +} + +// MockmessageStoreMockRecorder is the mock recorder for MockmessageStore. +type MockmessageStoreMockRecorder struct { + mock *MockmessageStore +} + +// NewMockmessageStore creates a new mock instance. +func NewMockmessageStore(ctrl *gomock.Controller) *MockmessageStore { + mock := &MockmessageStore{ctrl: ctrl} + mock.recorder = &MockmessageStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockmessageStore) EXPECT() *MockmessageStoreMockRecorder { + return m.recorder +} + +// Delete mocks base method. +func (m *MockmessageStore) Delete(ctx context.Context, topic, partitionKey, messageID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", ctx, topic, partitionKey, messageID) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockmessageStoreMockRecorder) Delete(ctx, topic, partitionKey, messageID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockmessageStore)(nil).Delete), ctx, topic, partitionKey, messageID) +} + +// FetchByOffset mocks base method. 
+func (m *MockmessageStore) FetchByOffset(ctx context.Context, topic, partitionKey string, currentOffset int64, limit int) ([]messageRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FetchByOffset", ctx, topic, partitionKey, currentOffset, limit) + ret0, _ := ret[0].([]messageRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FetchByOffset indicates an expected call of FetchByOffset. +func (mr *MockmessageStoreMockRecorder) FetchByOffset(ctx, topic, partitionKey, currentOffset, limit any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchByOffset", reflect.TypeOf((*MockmessageStore)(nil).FetchByOffset), ctx, topic, partitionKey, currentOffset, limit) +} + +// GarbageCollect mocks base method. +func (m *MockmessageStore) GarbageCollect(ctx context.Context, topic, partitionKey string, minAckedOffset int64) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GarbageCollect", ctx, topic, partitionKey, minAckedOffset) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GarbageCollect indicates an expected call of GarbageCollect. +func (mr *MockmessageStoreMockRecorder) GarbageCollect(ctx, topic, partitionKey, minAckedOffset any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GarbageCollect", reflect.TypeOf((*MockmessageStore)(nil).GarbageCollect), ctx, topic, partitionKey, minAckedOffset) +} + +// GetOffsetsAbove mocks base method. +func (m *MockmessageStore) GetOffsetsAbove(ctx context.Context, topic, partitionKey string, afterOffset int64, limit int) ([]int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOffsetsAbove", ctx, topic, partitionKey, afterOffset, limit) + ret0, _ := ret[0].([]int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOffsetsAbove indicates an expected call of GetOffsetsAbove. 
+func (mr *MockmessageStoreMockRecorder) GetOffsetsAbove(ctx, topic, partitionKey, afterOffset, limit any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOffsetsAbove", reflect.TypeOf((*MockmessageStore)(nil).GetOffsetsAbove), ctx, topic, partitionKey, afterOffset, limit) +} + +// Insert mocks base method. +func (m *MockmessageStore) Insert(ctx context.Context, topic string, messages []queue.Message) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Insert", ctx, topic, messages) + ret0, _ := ret[0].(error) + return ret0 +} + +// Insert indicates an expected call of Insert. +func (mr *MockmessageStoreMockRecorder) Insert(ctx, topic, messages any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockmessageStore)(nil).Insert), ctx, topic, messages) +} + +// MoveToDLQ mocks base method. +func (m *MockmessageStore) MoveToDLQ(ctx context.Context, topic, partitionKey, messageID string, failureCount int, lastError, dlqTopicSuffix string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MoveToDLQ", ctx, topic, partitionKey, messageID, failureCount, lastError, dlqTopicSuffix) + ret0, _ := ret[0].(error) + return ret0 +} + +// MoveToDLQ indicates an expected call of MoveToDLQ. +func (mr *MockmessageStoreMockRecorder) MoveToDLQ(ctx, topic, partitionKey, messageID, failureCount, lastError, dlqTopicSuffix any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MoveToDLQ", reflect.TypeOf((*MockmessageStore)(nil).MoveToDLQ), ctx, topic, partitionKey, messageID, failureCount, lastError, dlqTopicSuffix) +} + +// MockoffsetStore is a mock of offsetStore interface. +type MockoffsetStore struct { + ctrl *gomock.Controller + recorder *MockoffsetStoreMockRecorder + isgomock struct{} +} + +// MockoffsetStoreMockRecorder is the mock recorder for MockoffsetStore. 
+type MockoffsetStoreMockRecorder struct { + mock *MockoffsetStore +} + +// NewMockoffsetStore creates a new mock instance. +func NewMockoffsetStore(ctrl *gomock.Controller) *MockoffsetStore { + mock := &MockoffsetStore{ctrl: ctrl} + mock.recorder = &MockoffsetStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockoffsetStore) EXPECT() *MockoffsetStoreMockRecorder { + return m.recorder +} + +// GetAckedOffset mocks base method. +func (m *MockoffsetStore) GetAckedOffset(ctx context.Context, topic, partitionKey, consumerGroup string) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAckedOffset", ctx, topic, partitionKey, consumerGroup) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAckedOffset indicates an expected call of GetAckedOffset. +func (mr *MockoffsetStoreMockRecorder) GetAckedOffset(ctx, topic, partitionKey, consumerGroup any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckedOffset", reflect.TypeOf((*MockoffsetStore)(nil).GetAckedOffset), ctx, topic, partitionKey, consumerGroup) +} + +// GetMinAckedOffset mocks base method. +func (m *MockoffsetStore) GetMinAckedOffset(ctx context.Context, topic, partitionKey string) (int64, bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMinAckedOffset", ctx, topic, partitionKey) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetMinAckedOffset indicates an expected call of GetMinAckedOffset. +func (mr *MockoffsetStoreMockRecorder) GetMinAckedOffset(ctx, topic, partitionKey any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMinAckedOffset", reflect.TypeOf((*MockoffsetStore)(nil).GetMinAckedOffset), ctx, topic, partitionKey) +} + +// Initialize mocks base method. 
+func (m *MockoffsetStore) Initialize(ctx context.Context, topic, partitionKey, consumerGroup string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Initialize", ctx, topic, partitionKey, consumerGroup) + ret0, _ := ret[0].(error) + return ret0 +} + +// Initialize indicates an expected call of Initialize. +func (mr *MockoffsetStoreMockRecorder) Initialize(ctx, topic, partitionKey, consumerGroup any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Initialize", reflect.TypeOf((*MockoffsetStore)(nil).Initialize), ctx, topic, partitionKey, consumerGroup) +} + +// UpdateAckedOffset mocks base method. +func (m *MockoffsetStore) UpdateAckedOffset(ctx context.Context, topic, partitionKey string, offset int64, consumerGroup string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateAckedOffset", ctx, topic, partitionKey, offset, consumerGroup) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateAckedOffset indicates an expected call of UpdateAckedOffset. +func (mr *MockoffsetStoreMockRecorder) UpdateAckedOffset(ctx, topic, partitionKey, offset, consumerGroup any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAckedOffset", reflect.TypeOf((*MockoffsetStore)(nil).UpdateAckedOffset), ctx, topic, partitionKey, offset, consumerGroup) +} + +// MockpartitionLeaseStore is a mock of partitionLeaseStore interface. +type MockpartitionLeaseStore struct { + ctrl *gomock.Controller + recorder *MockpartitionLeaseStoreMockRecorder + isgomock struct{} +} + +// MockpartitionLeaseStoreMockRecorder is the mock recorder for MockpartitionLeaseStore. +type MockpartitionLeaseStoreMockRecorder struct { + mock *MockpartitionLeaseStore +} + +// NewMockpartitionLeaseStore creates a new mock instance. 
+func NewMockpartitionLeaseStore(ctrl *gomock.Controller) *MockpartitionLeaseStore { + mock := &MockpartitionLeaseStore{ctrl: ctrl} + mock.recorder = &MockpartitionLeaseStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockpartitionLeaseStore) EXPECT() *MockpartitionLeaseStoreMockRecorder { + return m.recorder +} + +// DiscoverAndAcquirePartitions mocks base method. +func (m *MockpartitionLeaseStore) DiscoverAndAcquirePartitions(ctx context.Context, topic, subscriberName, consumerGroup string, leaseDurationMs int64, maxPartitions int) (int, []string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DiscoverAndAcquirePartitions", ctx, topic, subscriberName, consumerGroup, leaseDurationMs, maxPartitions) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].([]string) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// DiscoverAndAcquirePartitions indicates an expected call of DiscoverAndAcquirePartitions. +func (mr *MockpartitionLeaseStoreMockRecorder) DiscoverAndAcquirePartitions(ctx, topic, subscriberName, consumerGroup, leaseDurationMs, maxPartitions any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DiscoverAndAcquirePartitions", reflect.TypeOf((*MockpartitionLeaseStore)(nil).DiscoverAndAcquirePartitions), ctx, topic, subscriberName, consumerGroup, leaseDurationMs, maxPartitions) +} + +// GetLeasedPartitions mocks base method. +func (m *MockpartitionLeaseStore) GetLeasedPartitions(ctx context.Context, topic, subscriberName, consumerGroup string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLeasedPartitions", ctx, topic, subscriberName, consumerGroup) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetLeasedPartitions indicates an expected call of GetLeasedPartitions. 
+func (mr *MockpartitionLeaseStoreMockRecorder) GetLeasedPartitions(ctx, topic, subscriberName, consumerGroup any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeasedPartitions", reflect.TypeOf((*MockpartitionLeaseStore)(nil).GetLeasedPartitions), ctx, topic, subscriberName, consumerGroup) +} + +// ReleaseLease mocks base method. +func (m *MockpartitionLeaseStore) ReleaseLease(ctx context.Context, topic, partitionKey, subscriberName, consumerGroup string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReleaseLease", ctx, topic, partitionKey, subscriberName, consumerGroup) + ret0, _ := ret[0].(error) + return ret0 +} + +// ReleaseLease indicates an expected call of ReleaseLease. +func (mr *MockpartitionLeaseStoreMockRecorder) ReleaseLease(ctx, topic, partitionKey, subscriberName, consumerGroup any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReleaseLease", reflect.TypeOf((*MockpartitionLeaseStore)(nil).ReleaseLease), ctx, topic, partitionKey, subscriberName, consumerGroup) +} + +// RenewLease mocks base method. +func (m *MockpartitionLeaseStore) RenewLease(ctx context.Context, topic, partitionKey, subscriberName, consumerGroup string, leaseDurationMs int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RenewLease", ctx, topic, partitionKey, subscriberName, consumerGroup, leaseDurationMs) + ret0, _ := ret[0].(error) + return ret0 +} + +// RenewLease indicates an expected call of RenewLease. +func (mr *MockpartitionLeaseStoreMockRecorder) RenewLease(ctx, topic, partitionKey, subscriberName, consumerGroup, leaseDurationMs any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewLease", reflect.TypeOf((*MockpartitionLeaseStore)(nil).RenewLease), ctx, topic, partitionKey, subscriberName, consumerGroup, leaseDurationMs) +} + +// TryAcquireLease mocks base method. 
+func (m *MockpartitionLeaseStore) TryAcquireLease(ctx context.Context, topic, partitionKey, subscriberName, consumerGroup string, leaseDurationMs int64) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TryAcquireLease", ctx, topic, partitionKey, subscriberName, consumerGroup, leaseDurationMs) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// TryAcquireLease indicates an expected call of TryAcquireLease. +func (mr *MockpartitionLeaseStoreMockRecorder) TryAcquireLease(ctx, topic, partitionKey, subscriberName, consumerGroup, leaseDurationMs any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryAcquireLease", reflect.TypeOf((*MockpartitionLeaseStore)(nil).TryAcquireLease), ctx, topic, partitionKey, subscriberName, consumerGroup, leaseDurationMs) +} + +// MocksubscriberHeartbeatStore is a mock of subscriberHeartbeatStore interface. +type MocksubscriberHeartbeatStore struct { + ctrl *gomock.Controller + recorder *MocksubscriberHeartbeatStoreMockRecorder + isgomock struct{} +} + +// MocksubscriberHeartbeatStoreMockRecorder is the mock recorder for MocksubscriberHeartbeatStore. +type MocksubscriberHeartbeatStoreMockRecorder struct { + mock *MocksubscriberHeartbeatStore +} + +// NewMocksubscriberHeartbeatStore creates a new mock instance. +func NewMocksubscriberHeartbeatStore(ctrl *gomock.Controller) *MocksubscriberHeartbeatStore { + mock := &MocksubscriberHeartbeatStore{ctrl: ctrl} + mock.recorder = &MocksubscriberHeartbeatStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MocksubscriberHeartbeatStore) EXPECT() *MocksubscriberHeartbeatStoreMockRecorder { + return m.recorder +} + +// ActiveSubscribers mocks base method. 
+func (m *MocksubscriberHeartbeatStore) ActiveSubscribers(ctx context.Context, topic, consumerGroup string, staleDurationMs int64) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ActiveSubscribers", ctx, topic, consumerGroup, staleDurationMs) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ActiveSubscribers indicates an expected call of ActiveSubscribers. +func (mr *MocksubscriberHeartbeatStoreMockRecorder) ActiveSubscribers(ctx, topic, consumerGroup, staleDurationMs any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActiveSubscribers", reflect.TypeOf((*MocksubscriberHeartbeatStore)(nil).ActiveSubscribers), ctx, topic, consumerGroup, staleDurationMs) +} + +// Deregister mocks base method. +func (m *MocksubscriberHeartbeatStore) Deregister(ctx context.Context, topic, subscriberName, consumerGroup string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Deregister", ctx, topic, subscriberName, consumerGroup) + ret0, _ := ret[0].(error) + return ret0 +} + +// Deregister indicates an expected call of Deregister. +func (mr *MocksubscriberHeartbeatStoreMockRecorder) Deregister(ctx, topic, subscriberName, consumerGroup any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Deregister", reflect.TypeOf((*MocksubscriberHeartbeatStore)(nil).Deregister), ctx, topic, subscriberName, consumerGroup) +} + +// Heartbeat mocks base method. +func (m *MocksubscriberHeartbeatStore) Heartbeat(ctx context.Context, topic, subscriberName, consumerGroup string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Heartbeat", ctx, topic, subscriberName, consumerGroup) + ret0, _ := ret[0].(error) + return ret0 +} + +// Heartbeat indicates an expected call of Heartbeat. 
+func (mr *MocksubscriberHeartbeatStoreMockRecorder) Heartbeat(ctx, topic, subscriberName, consumerGroup any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MocksubscriberHeartbeatStore)(nil).Heartbeat), ctx, topic, subscriberName, consumerGroup) +} + +// MockdeliveryStateStore is a mock of deliveryStateStore interface. +type MockdeliveryStateStore struct { + ctrl *gomock.Controller + recorder *MockdeliveryStateStoreMockRecorder + isgomock struct{} +} + +// MockdeliveryStateStoreMockRecorder is the mock recorder for MockdeliveryStateStore. +type MockdeliveryStateStoreMockRecorder struct { + mock *MockdeliveryStateStore +} + +// NewMockdeliveryStateStore creates a new mock instance. +func NewMockdeliveryStateStore(ctrl *gomock.Controller) *MockdeliveryStateStore { + mock := &MockdeliveryStateStore{ctrl: ctrl} + mock.recorder = &MockdeliveryStateStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockdeliveryStateStore) EXPECT() *MockdeliveryStateStoreMockRecorder { + return m.recorder +} + +// AdvanceWatermark mocks base method. +func (m *MockdeliveryStateStore) AdvanceWatermark(ctx context.Context, consumerGroup, topic, partitionKey string, currentWatermark int64, offsets []int64) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AdvanceWatermark", ctx, consumerGroup, topic, partitionKey, currentWatermark, offsets) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AdvanceWatermark indicates an expected call of AdvanceWatermark. 
+func (mr *MockdeliveryStateStoreMockRecorder) AdvanceWatermark(ctx, consumerGroup, topic, partitionKey, currentWatermark, offsets any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AdvanceWatermark", reflect.TypeOf((*MockdeliveryStateStore)(nil).AdvanceWatermark), ctx, consumerGroup, topic, partitionKey, currentWatermark, offsets) +} + +// ExtendVisibility mocks base method. +func (m *MockdeliveryStateStore) ExtendVisibility(ctx context.Context, consumerGroup, topic, partitionKey string, offset, visibilityTimeoutMs int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExtendVisibility", ctx, consumerGroup, topic, partitionKey, offset, visibilityTimeoutMs) + ret0, _ := ret[0].(error) + return ret0 +} + +// ExtendVisibility indicates an expected call of ExtendVisibility. +func (mr *MockdeliveryStateStoreMockRecorder) ExtendVisibility(ctx, consumerGroup, topic, partitionKey, offset, visibilityTimeoutMs any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExtendVisibility", reflect.TypeOf((*MockdeliveryStateStore)(nil).ExtendVisibility), ctx, consumerGroup, topic, partitionKey, offset, visibilityTimeoutMs) +} + +// GetDeliveryState mocks base method. +func (m *MockdeliveryStateStore) GetDeliveryState(ctx context.Context, consumerGroup, topic, partitionKey string, offset int64) (DeliveryState, bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDeliveryState", ctx, consumerGroup, topic, partitionKey, offset) + ret0, _ := ret[0].(DeliveryState) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetDeliveryState indicates an expected call of GetDeliveryState. 
+func (mr *MockdeliveryStateStoreMockRecorder) GetDeliveryState(ctx, consumerGroup, topic, partitionKey, offset any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDeliveryState", reflect.TypeOf((*MockdeliveryStateStore)(nil).GetDeliveryState), ctx, consumerGroup, topic, partitionKey, offset) +} + +// MarkAcked mocks base method. +func (m *MockdeliveryStateStore) MarkAcked(ctx context.Context, consumerGroup, topic, partitionKey string, offset int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MarkAcked", ctx, consumerGroup, topic, partitionKey, offset) + ret0, _ := ret[0].(error) + return ret0 +} + +// MarkAcked indicates an expected call of MarkAcked. +func (mr *MockdeliveryStateStoreMockRecorder) MarkAcked(ctx, consumerGroup, topic, partitionKey, offset any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkAcked", reflect.TypeOf((*MockdeliveryStateStore)(nil).MarkAcked), ctx, consumerGroup, topic, partitionKey, offset) +} + +// MarkDelivered mocks base method. +func (m *MockdeliveryStateStore) MarkDelivered(ctx context.Context, consumerGroup, topic, partitionKey string, offset, visibilityTimeoutMs int64) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MarkDelivered", ctx, consumerGroup, topic, partitionKey, offset, visibilityTimeoutMs) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MarkDelivered indicates an expected call of MarkDelivered. +func (mr *MockdeliveryStateStoreMockRecorder) MarkDelivered(ctx, consumerGroup, topic, partitionKey, offset, visibilityTimeoutMs any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkDelivered", reflect.TypeOf((*MockdeliveryStateStore)(nil).MarkDelivered), ctx, consumerGroup, topic, partitionKey, offset, visibilityTimeoutMs) +} + +// MarkNacked mocks base method. 
+func (m *MockdeliveryStateStore) MarkNacked(ctx context.Context, consumerGroup, topic, partitionKey string, offset, delayMs int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MarkNacked", ctx, consumerGroup, topic, partitionKey, offset, delayMs) + ret0, _ := ret[0].(error) + return ret0 +} + +// MarkNacked indicates an expected call of MarkNacked. +func (mr *MockdeliveryStateStoreMockRecorder) MarkNacked(ctx, consumerGroup, topic, partitionKey, offset, delayMs any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkNacked", reflect.TypeOf((*MockdeliveryStateStore)(nil).MarkNacked), ctx, consumerGroup, topic, partitionKey, offset, delayMs) +} diff --git a/extension/queue/mysql/offset_store.go b/extension/queue/mysql/offset_store.go index 07adc788..61318e35 100644 --- a/extension/queue/mysql/offset_store.go +++ b/extension/queue/mysql/offset_store.go @@ -21,32 +21,31 @@ import ( "time" "github.com/uber-go/tally/v4" - "go.uber.org/zap" + "github.com/uber/submitqueue/core/metrics" ) // sqloffsetStore is the SQL implementation of offsetStore type sqloffsetStore struct { - db *sql.DB - logger *zap.SugaredLogger - metrics tally.Scope + db *sql.DB + scope tally.Scope } -// Metric names for offset store -const ( - metricAckMessageErrors = "ack_message.errors" -) - // newOffsetStore creates a new SQL offset store -func newOffsetStore(db *sql.DB, logger *zap.Logger, metrics tally.Scope) offsetStore { +func newOffsetStore(db *sql.DB, scope tally.Scope) offsetStore { return &sqloffsetStore{ - db: db, - logger: logger.Sugar().Named("offset_store"), - metrics: metrics.SubScope("offset_store"), + db: db, + scope: scope.SubScope("offset_store"), } } // Initialize creates an offset entry for a topic+partition if it doesn't exist -func (s *sqloffsetStore) Initialize(ctx context.Context, topic string, partitionKey string, consumerGroup string) error { +func (s *sqloffsetStore) Initialize(ctx context.Context, topic string, 
partitionKey string, consumerGroup string) (retErr error) { + op := metrics.Begin(s.scope, "initialize", + metrics.NewTag("topic", topic), + metrics.NewTag("partition_key", partitionKey), + metrics.NewTag("consumer_group", consumerGroup)) + defer func() { op.Complete(retErr) }() + now := time.Now().UnixMilli() // Try to insert, ignore if already exists @@ -55,11 +54,21 @@ func (s *sqloffsetStore) Initialize(ctx context.Context, topic string, partition VALUES (?, ?, ?, 0, ?) `, OffsetsTableName), consumerGroup, topic, partitionKey, now) - return err + if err != nil { + return fmt.Errorf("initialize offset topic=%s partition=%s: %w", topic, partitionKey, err) + } + + return nil } // GetAckedOffset returns the current acked offset for a topic+partition -func (s *sqloffsetStore) GetAckedOffset(ctx context.Context, topic string, partitionKey string, consumerGroup string) (int64, error) { +func (s *sqloffsetStore) GetAckedOffset(ctx context.Context, topic string, partitionKey string, consumerGroup string) (_ int64, retErr error) { + op := metrics.Begin(s.scope, "get_acked_offset", + metrics.NewTag("topic", topic), + metrics.NewTag("partition_key", partitionKey), + metrics.NewTag("consumer_group", consumerGroup)) + defer func() { op.Complete(retErr) }() + var offset int64 err := s.db.QueryRowContext(ctx, fmt.Sprintf(` SELECT offset_acked FROM %s WHERE consumer_group = ? AND topic = ? AND partition_key = ? 
@@ -71,14 +80,20 @@ func (s *sqloffsetStore) GetAckedOffset(ctx context.Context, topic string, parti } if err != nil { - return 0, fmt.Errorf("failed to get acked offset: %w", err) + return 0, fmt.Errorf("get acked offset topic=%s partition=%s: %w", topic, partitionKey, err) } return offset, nil } // UpdateAckedOffset updates the offset_acked for a topic+partition (only if new offset is greater) -func (s *sqloffsetStore) UpdateAckedOffset(ctx context.Context, topic string, partitionKey string, offset int64, consumerGroup string) error { +func (s *sqloffsetStore) UpdateAckedOffset(ctx context.Context, topic string, partitionKey string, offset int64, consumerGroup string) (retErr error) { + op := metrics.Begin(s.scope, "update_acked_offset", + metrics.NewTag("topic", topic), + metrics.NewTag("partition_key", partitionKey), + metrics.NewTag("consumer_group", consumerGroup)) + defer func() { op.Complete(retErr) }() + now := time.Now().UnixMilli() _, err := s.db.ExecContext(ctx, fmt.Sprintf(` @@ -87,67 +102,33 @@ func (s *sqloffsetStore) UpdateAckedOffset(ctx context.Context, topic string, pa WHERE consumer_group = ? AND topic = ? AND partition_key = ? AND offset_acked < ? 
`, OffsetsTableName), offset, now, consumerGroup, topic, partitionKey, offset) - return err -} - -// AckMessage atomically deletes a message and updates the acked offset -func (s *sqloffsetStore) AckMessage(ctx context.Context, topic string, partitionKey string, messageID string, offset int64, consumerGroup string, messageStore messageStore) error { - start := time.Now() - success := false - defer func() { - result := "error" - if success { - result = "success" - } - s.metrics.Tagged(map[string]string{"result": result}).Timer("ack_message.latency").Record(time.Since(start)) - }() - - tx, err := s.db.BeginTx(ctx, nil) if err != nil { - s.metrics.Tagged(map[string]string{tagErrorType: "begin_transaction"}).Counter(metricAckMessageErrors).Inc(1) - return fmt.Errorf("failed to begin transaction: %w", err) + return fmt.Errorf("update acked offset topic=%s partition=%s: %w", topic, partitionKey, err) } - defer tx.Rollback() - // Delete message - _, err = tx.ExecContext(ctx, fmt.Sprintf(` - DELETE FROM %s WHERE topic = ? AND partition_key = ? AND id = ? - `, MessagesTableName), topic, partitionKey, messageID) - if err != nil { - s.metrics.Tagged(map[string]string{tagErrorType: "delete_message"}).Counter(metricAckMessageErrors).Inc(1) - return fmt.Errorf("failed to delete message: %w", err) - } + return nil +} + +// GetMinAckedOffset returns the minimum offset_acked across all consumer groups +// for a topic+partition. Returns (0, false, nil) if no offset rows exist. +func (s *sqloffsetStore) GetMinAckedOffset(ctx context.Context, topic string, partitionKey string) (_ int64, _ bool, retErr error) { + op := metrics.Begin(s.scope, "get_min_acked_offset", + metrics.NewTag("topic", topic), + metrics.NewTag("partition_key", partitionKey)) + defer func() { op.Complete(retErr) }() - now := start.UnixMilli() + var minOffset int64 + err := s.db.QueryRowContext(ctx, fmt.Sprintf(` + SELECT COALESCE(MIN(offset_acked), 0) FROM %s WHERE topic = ? AND partition_key = ? 
+ `, OffsetsTableName), topic, partitionKey).Scan(&minOffset) - // Update offset_acked (insert if not exists) - _, err = tx.ExecContext(ctx, fmt.Sprintf(` - INSERT INTO %s (consumer_group, topic, partition_key, offset_acked, updated_at) - VALUES (?, ?, ?, ?, ?) - ON DUPLICATE KEY UPDATE - offset_acked = IF(VALUES(offset_acked) > offset_acked, VALUES(offset_acked), offset_acked), - updated_at = VALUES(updated_at) - `, OffsetsTableName), consumerGroup, topic, partitionKey, offset, now) if err != nil { - s.metrics.Tagged(map[string]string{tagErrorType: "update_offset"}).Counter(metricAckMessageErrors).Inc(1) - return fmt.Errorf("failed to update offset: %w", err) + return 0, false, fmt.Errorf("query min acked offset topic=%s partition=%s: %w", topic, partitionKey, err) } - if err := tx.Commit(); err != nil { - s.metrics.Tagged(map[string]string{tagErrorType: "commit"}).Counter(metricAckMessageErrors).Inc(1) - return fmt.Errorf("failed to commit transaction: %w", err) + if minOffset == 0 { + return 0, false, nil } - // Log and emit metrics after transaction completes - s.metrics.Counter("ack_message.success").Inc(1) - s.logger.Debugw("acked message", - logTopic, topic, - logPartitionKey, partitionKey, - logMessageID, messageID, - "offset", offset, - "duration_ms", time.Since(start).Milliseconds(), - ) - - success = true - return nil + return minOffset, true, nil } diff --git a/extension/queue/mysql/offset_store_test.go b/extension/queue/mysql/offset_store_test.go index 4ad10bc4..c145c7c5 100644 --- a/extension/queue/mysql/offset_store_test.go +++ b/extension/queue/mysql/offset_store_test.go @@ -17,11 +17,11 @@ package mysql import ( "context" "database/sql" + "fmt" "testing" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/require" - "go.uber.org/zap/zaptest" ) const ( @@ -35,12 +35,12 @@ func setupoffsetStoreTest(t *testing.T) (*sql.DB, sqlmock.Sqlmock, offsetStore) db, mock, err := sqlmock.New() require.NoError(t, err) - store := newOffsetStore(db, 
zaptest.NewLogger(t), testMetrics()) + store := newOffsetStore(db, testMetrics()) return db, mock, store } -func TestoffsetStore_Initialize(t *testing.T) { +func TestOffsetStore_Initialize(t *testing.T) { db, mock, store := setupoffsetStoreTest(t) defer db.Close() @@ -57,7 +57,7 @@ func TestoffsetStore_Initialize(t *testing.T) { require.NoError(t, mock.ExpectationsWereMet()) } -func TestoffsetStore_GetAckedOffset(t *testing.T) { +func TestOffsetStore_GetAckedOffset(t *testing.T) { tests := []struct { name string setup func(mock sqlmock.Sqlmock) @@ -110,7 +110,7 @@ func TestoffsetStore_GetAckedOffset(t *testing.T) { } } -func TestoffsetStore_UpdateAckedOffset(t *testing.T) { +func TestOffsetStore_UpdateAckedOffset(t *testing.T) { db, mock, store := setupoffsetStoreTest(t) defer db.Close() @@ -128,36 +128,35 @@ func TestoffsetStore_UpdateAckedOffset(t *testing.T) { require.NoError(t, mock.ExpectationsWereMet()) } -func TestoffsetStore_AckMessage(t *testing.T) { +func TestOffsetStore_GetMinAckedOffset(t *testing.T) { tests := []struct { - name string - setup func(mock sqlmock.Sqlmock) - wantErr bool + name string + minOffset int64 + queryErr bool + wantOffset int64 + wantFound bool + wantErr bool }{ { - name: "successful ack", - setup: func(mock sqlmock.Sqlmock) { - mock.ExpectBegin() - mock.ExpectExec("DELETE FROM queue_messages"). - WithArgs("test_topic", "part1", "msg1"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec("INSERT INTO queue_offsets"). - WithArgs(testConsumerGroup, "test_topic", "part1", int64(100), sqlmock.AnyArg()). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - }, - wantErr: false, + name: "returns min offset across consumer groups", + minOffset: 10, + wantOffset: 10, + wantFound: true, }, { - name: "transaction error", - setup: func(mock sqlmock.Sqlmock) { - mock.ExpectBegin() - mock.ExpectExec("DELETE FROM queue_messages"). - WithArgs("test_topic", "part1", "msg1"). 
- WillReturnError(sql.ErrConnDone) - mock.ExpectRollback() - }, - wantErr: true, + name: "no offset rows returns not found", + minOffset: 0, + wantFound: false, + }, + { + name: "zero offset returns not found", + minOffset: 0, + wantFound: false, + }, + { + name: "query error", + queryErr: true, + wantErr: true, }, } @@ -166,19 +165,24 @@ func TestoffsetStore_AckMessage(t *testing.T) { db, mock, store := setupoffsetStoreTest(t) defer db.Close() - ctx := context.Background() - topic := "test_topic" - partitionKey := "part1" - messageID := "msg1" - offset := int64(100) + if tt.queryErr { + mock.ExpectQuery("SELECT COALESCE\\(MIN\\(offset_acked\\), 0\\) FROM queue_offsets"). + WithArgs("test_topic", "part-1"). + WillReturnError(fmt.Errorf("db error")) + } else { + mock.ExpectQuery("SELECT COALESCE\\(MIN\\(offset_acked\\), 0\\) FROM queue_offsets"). + WithArgs("test_topic", "part-1"). + WillReturnRows(sqlmock.NewRows([]string{"min"}).AddRow(tt.minOffset)) + } - tt.setup(mock) + offset, found, err := store.GetMinAckedOffset(context.Background(), "test_topic", "part-1") - err := store.AckMessage(ctx, topic, partitionKey, messageID, offset, testConsumerGroup, nil) if tt.wantErr { require.Error(t, err) } else { require.NoError(t, err) + require.Equal(t, tt.wantFound, found) + require.Equal(t, tt.wantOffset, offset) } require.NoError(t, mock.ExpectationsWereMet()) }) diff --git a/extension/queue/mysql/partition_lease_store.go b/extension/queue/mysql/partition_lease_store.go index 3144586f..0fc9982e 100644 --- a/extension/queue/mysql/partition_lease_store.go +++ b/extension/queue/mysql/partition_lease_store.go @@ -18,49 +18,36 @@ import ( "context" "database/sql" "fmt" + "sort" "time" "github.com/uber-go/tally/v4" + "github.com/uber/submitqueue/core/metrics" "go.uber.org/zap" ) // sqlpartitionLeaseStore is the SQL implementation of partitionLeaseStore type sqlpartitionLeaseStore struct { - db *sql.DB - logger *zap.SugaredLogger - metrics tally.Scope + db *sql.DB + logger 
*zap.SugaredLogger + scope tally.Scope } -// Metric names for partition lease store -const ( - metricTryAcquireLeaseErrors = "try_acquire_lease.errors" - metricRenewLeaseErrors = "renew_lease.errors" - metricGetLeasedPartitionsErrors = "get_leased_partitions.errors" - metricDiscoverAndAcquireErrors = "discover_and_acquire.errors" -) - // newPartitionLeaseStore creates a new SQL partition lease store -func newPartitionLeaseStore(db *sql.DB, logger *zap.Logger, metrics tally.Scope) partitionLeaseStore { +func newPartitionLeaseStore(db *sql.DB, logger *zap.SugaredLogger, scope tally.Scope) partitionLeaseStore { return &sqlpartitionLeaseStore{ - db: db, - logger: logger.Sugar().Named("partition_lease_store"), - metrics: metrics.SubScope("partition_lease_store"), + db: db, + logger: logger.Named("partition_lease_store"), + scope: scope.SubScope("partition_lease_store"), } } // TryAcquireLease attempts to acquire or renew a lease for a partition -func (s *sqlpartitionLeaseStore) TryAcquireLease(ctx context.Context, topic string, partitionKey string, subscriberName string, consumerGroup string, leaseDurationMs int64) (bool, error) { - start := time.Now() - success := false - defer func() { - result := "error" - if success { - result = "success" - } - s.metrics.Tagged(map[string]string{"result": result}).Timer("try_acquire_lease.latency").Record(time.Since(start)) - }() +func (s *sqlpartitionLeaseStore) TryAcquireLease(ctx context.Context, topic string, partitionKey string, subscriberName string, consumerGroup string, leaseDurationMs int64) (_ bool, retErr error) { + op := metrics.Begin(s.scope, "try_acquire_lease", metrics.NewTag("topic", topic)) + defer func() { op.Complete(retErr) }() - now := start.UnixMilli() + now := currentTimeMillis() staleThreshold := now - leaseDurationMs // Try to insert or update stale lease @@ -76,13 +63,7 @@ func (s *sqlpartitionLeaseStore) TryAcquireLease(ctx context.Context, topic stri staleThreshold, staleThreshold, staleThreshold) if err 
!= nil { - s.logger.Errorw("failed to acquire lease", - logTopic, topic, - logPartitionKey, partitionKey, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "exec_acquire"}).Counter(metricTryAcquireLeaseErrors).Inc(1) - return false, fmt.Errorf("failed to acquire lease: %w", err) + return false, fmt.Errorf("acquire lease topic=%s partition=%s: %w", topic, partitionKey, err) } // Check if we own the lease @@ -93,43 +74,29 @@ func (s *sqlpartitionLeaseStore) TryAcquireLease(ctx context.Context, topic stri `, PartitionLeasesTableName), consumerGroup, topic, partitionKey).Scan(&owner) if err != nil { - s.logger.Errorw("failed to check lease ownership", - logTopic, topic, - logPartitionKey, partitionKey, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "check_ownership"}).Counter(metricTryAcquireLeaseErrors).Inc(1) - return false, fmt.Errorf("failed to check lease ownership: %w", err) + return false, fmt.Errorf("check lease ownership topic=%s partition=%s: %w", topic, partitionKey, err) } acquired := owner == subscriberName if acquired { - s.metrics.Counter("try_acquire_lease.acquired").Inc(1) + metrics.NamedCounter(s.scope, "try_acquire_lease", "acquired", 1, metrics.NewTag("topic", topic)) s.logger.Debugw("acquired lease", logTopic, topic, logPartitionKey, partitionKey, ) } else { - s.metrics.Counter("try_acquire_lease.not_acquired").Inc(1) + metrics.NamedCounter(s.scope, "try_acquire_lease", "not_acquired", 1, metrics.NewTag("topic", topic)) } - success = true return acquired, nil } // RenewLease renews the lease for a partition owned by this worker -func (s *sqlpartitionLeaseStore) RenewLease(ctx context.Context, topic string, partitionKey string, subscriberName string, consumerGroup string, leaseDurationMs int64) error { - start := time.Now() - success := false - defer func() { - result := "error" - if success { - result = "success" - } - s.metrics.Tagged(map[string]string{"result": 
result}).Timer("renew_lease.latency").Record(time.Since(start)) - }() +func (s *sqlpartitionLeaseStore) RenewLease(ctx context.Context, topic string, partitionKey string, subscriberName string, consumerGroup string, leaseDurationMs int64) (retErr error) { + op := metrics.Begin(s.scope, "renew_lease", metrics.NewTag("topic", topic)) + defer func() { op.Complete(retErr) }() - now := start.UnixMilli() + now := currentTimeMillis() result, err := s.db.ExecContext(ctx, fmt.Sprintf(` UPDATE %s @@ -138,57 +105,30 @@ func (s *sqlpartitionLeaseStore) RenewLease(ctx context.Context, topic string, p `, PartitionLeasesTableName), now, consumerGroup, topic, partitionKey, subscriberName) if err != nil { - s.logger.Errorw("failed to renew lease", - logTopic, topic, - logPartitionKey, partitionKey, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "exec_renew"}).Counter(metricRenewLeaseErrors).Inc(1) - return fmt.Errorf("failed to renew lease: %w", err) + return fmt.Errorf("renew lease topic=%s partition=%s: %w", topic, partitionKey, err) } rows, err := result.RowsAffected() if err != nil { - s.logger.Errorw("failed to check renewal result", - logTopic, topic, - logPartitionKey, partitionKey, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "check_rows_affected"}).Counter(metricRenewLeaseErrors).Inc(1) - return fmt.Errorf("failed to check renewal result: %w", err) + return fmt.Errorf("check renewal result topic=%s partition=%s: %w", topic, partitionKey, err) } if rows == 0 { - s.logger.Warnw("lease not owned by this worker or already expired", - logTopic, topic, - logPartitionKey, partitionKey, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "not_owned"}).Counter(metricRenewLeaseErrors).Inc(1) - return fmt.Errorf("lease not owned by this worker or already expired") + return &ErrLeaseExpired{Topic: topic, PartitionKey: partitionKey} } - s.metrics.Counter("renew_lease.success").Inc(1) s.logger.Debugw("renewed lease", logTopic, 
topic, logPartitionKey, partitionKey, - "duration_ms", time.Since(start).Milliseconds(), ) - success = true return nil } // ReleaseLease releases the lease for a partition owned by this worker -func (s *sqlpartitionLeaseStore) ReleaseLease(ctx context.Context, topic string, partitionKey string, subscriberName string, consumerGroup string) error { - start := time.Now() - success := false - defer func() { - result := "error" - if success { - result = "success" - } - s.metrics.Tagged(map[string]string{"result": result}).Timer("release_lease.latency").Record(time.Since(start)) - }() +func (s *sqlpartitionLeaseStore) ReleaseLease(ctx context.Context, topic string, partitionKey string, subscriberName string, consumerGroup string) (retErr error) { + op := metrics.Begin(s.scope, "release_lease", metrics.NewTag("topic", topic)) + defer func() { op.Complete(retErr) }() result, err := s.db.ExecContext(ctx, fmt.Sprintf(` DELETE FROM %s @@ -196,41 +136,34 @@ func (s *sqlpartitionLeaseStore) ReleaseLease(ctx context.Context, topic string, `, PartitionLeasesTableName), consumerGroup, topic, partitionKey, subscriberName) if err != nil { - s.logger.Errorw("failed to release lease", + return fmt.Errorf("release lease topic=%s partition=%s: %w", topic, partitionKey, err) + } + + // RowsAffected error is swallowed because the DELETE query itself succeeded. + // This is a driver-level diagnostic failure — the lease is already released. + // We log for visibility but the release operation is complete. 
+ rows, err := result.RowsAffected() + if err != nil { + s.logger.Warnw("failed to get rows affected after release lease", logTopic, topic, logPartitionKey, partitionKey, logError, err, ) - s.metrics.Tagged(map[string]string{tagErrorType: "exec_release"}).Counter("release_lease.errors").Inc(1) - return fmt.Errorf("failed to release lease: %w", err) } - - // Only increment success counter if we actually deleted a row (idempotent) - rows, _ := result.RowsAffected() if rows > 0 { - s.metrics.Counter("release_lease.success").Inc(1) s.logger.Debugw("released lease", logTopic, topic, logPartitionKey, partitionKey, - "duration_ms", time.Since(start).Milliseconds(), ) } - success = true return nil } // GetLeasedPartitions returns all partitions currently leased by this worker -func (s *sqlpartitionLeaseStore) GetLeasedPartitions(ctx context.Context, topic string, subscriberName string, consumerGroup string) ([]string, error) { - start := time.Now() - success := false - defer func() { - result := "error" - if success { - result = "success" - } - s.metrics.Tagged(map[string]string{"result": result}).Timer("get_leased_partitions.latency").Record(time.Since(start)) - }() +func (s *sqlpartitionLeaseStore) GetLeasedPartitions(ctx context.Context, topic string, subscriberName string, consumerGroup string) (_ []string, retErr error) { + op := metrics.Begin(s.scope, "get_leased_partitions", metrics.NewTag("topic", topic)) + defer func() { op.Complete(retErr) }() rows, err := s.db.QueryContext(ctx, fmt.Sprintf(` SELECT partition_key FROM %s @@ -238,12 +171,7 @@ func (s *sqlpartitionLeaseStore) GetLeasedPartitions(ctx context.Context, topic `, PartitionLeasesTableName), consumerGroup, topic, subscriberName) if err != nil { - s.logger.Errorw("failed to get leased partitions", - logTopic, topic, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "query"}).Counter(metricGetLeasedPartitionsErrors).Inc(1) - return nil, fmt.Errorf("failed to get leased partitions: %w", 
err) + return nil, fmt.Errorf("get leased partitions topic=%s: %w", topic, err) } defer rows.Close() @@ -251,54 +179,40 @@ func (s *sqlpartitionLeaseStore) GetLeasedPartitions(ctx context.Context, topic for rows.Next() { var partition string if err := rows.Scan(&partition); err != nil { - s.logger.Errorw("failed to scan partition", - logTopic, topic, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "scan_partition"}).Counter(metricGetLeasedPartitionsErrors).Inc(1) - return nil, fmt.Errorf("failed to scan partition: %w", err) + return nil, fmt.Errorf("scan partition topic=%s: %w", topic, err) } partitions = append(partitions, partition) } - s.metrics.Counter("get_leased_partitions.success").Inc(1) - s.metrics.Counter("partitions.leased").Inc(int64(len(partitions))) + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("row iteration topic=%s: %w", topic, err) + } + s.logger.Debugw("retrieved leased partitions", logTopic, topic, "count", len(partitions), - "duration_ms", time.Since(start).Milliseconds(), ) - success = true return partitions, nil } -// DiscoverAndAcquirePartitions discovers partitions from messages table and tries to acquire leases -// Returns the number of new leases acquired -func (s *sqlpartitionLeaseStore) DiscoverAndAcquirePartitions(ctx context.Context, topic string, subscriberName string, consumerGroup string, leaseDurationMs int64) (int, error) { - start := time.Now() - success := false - defer func() { - result := "error" - if success { - result = "success" - } - s.metrics.Tagged(map[string]string{"result": result}).Timer("discover_and_acquire.latency").Record(time.Since(start)) - }() - - // Query distinct partition_keys from messages table - // LIMIT 100: Cap discovery to prevent overwhelming the system when there are many partitions. - // Workers will naturally discover and acquire partitions over multiple discovery cycles. 
+// DiscoverAndAcquirePartitions discovers partitions from messages table and tries to acquire leases. +// Returns the number of new leases acquired and the full list of discovered partitions. +// maxPartitions limits how many total partitions this subscriber can own (0 = unlimited) +func (s *sqlpartitionLeaseStore) DiscoverAndAcquirePartitions(ctx context.Context, topic string, subscriberName string, consumerGroup string, leaseDurationMs int64, maxPartitions int) (_ int, _ []string, retErr error) { + op := metrics.Begin(s.scope, "discover_and_acquire", metrics.NewTag("topic", topic)) + defer func() { op.Complete(retErr) }() + + // Query distinct partition_keys from messages table. + // No LIMIT is applied because all partitions must be discoverable for fair + // share computation to be accurate — a LIMIT would silently exclude partitions, + // making them permanently unprocessable. The maxPartitions cap only limits how + // many leases this subscriber acquires, not how many partitions are visible. rows, err := s.db.QueryContext(ctx, fmt.Sprintf(` - SELECT DISTINCT partition_key FROM %s WHERE topic = ? LIMIT 100 + SELECT DISTINCT partition_key FROM %s WHERE topic = ? 
ORDER BY partition_key `, MessagesTableName), topic) if err != nil { - s.logger.Errorw("failed to discover partitions", - logTopic, topic, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "query_partitions"}).Counter(metricDiscoverAndAcquireErrors).Inc(1) - return 0, fmt.Errorf("failed to discover partitions: %w", err) + return 0, nil, fmt.Errorf("discover partitions topic=%s: %w", topic, err) } defer rows.Close() @@ -306,28 +220,58 @@ func (s *sqlpartitionLeaseStore) DiscoverAndAcquirePartitions(ctx context.Contex for rows.Next() { var partitionKey string if err := rows.Scan(&partitionKey); err != nil { - s.logger.Errorw("failed to scan partition key", - logTopic, topic, - logError, err, - ) - s.metrics.Tagged(map[string]string{tagErrorType: "scan_partition"}).Counter(metricDiscoverAndAcquireErrors).Inc(1) - return 0, fmt.Errorf("failed to scan partition key: %w", err) + return 0, nil, fmt.Errorf("scan partition key topic=%s: %w", topic, err) } partitions = append(partitions, partitionKey) } + if err := rows.Err(); err != nil { + return 0, nil, fmt.Errorf("row iteration topic=%s: %w", topic, err) + } + s.logger.Debugw("discovered partitions", logTopic, topic, "count", len(partitions), ) + // Query owned partitions once before the loop to avoid N+1 queries. + // Build a set of already-owned partition keys so we can distinguish + // re-acquiring an already-owned partition from acquiring a new one. 
+ ownedCount := 0 + ownedSet := make(map[string]struct{}) + if maxPartitions > 0 { + owned, err := s.GetLeasedPartitions(ctx, topic, subscriberName, consumerGroup) + if err != nil { + return 0, nil, fmt.Errorf("get owned partitions for cap check topic=%s: %w", topic, err) + } + ownedCount = len(owned) + for _, pk := range owned { + ownedSet[pk] = struct{}{} + } + } + + // Sort partitions deterministically + sort.Strings(partitions) + // Try to acquire leases for discovered partitions acquiredCount := 0 for _, partitionKey := range partitions { + // Enforce maxPartitions cap using local count + if maxPartitions > 0 && ownedCount >= maxPartitions { + s.logger.Infow("reached max partitions cap, stopping acquisition", + logTopic, topic, + "max_partitions", maxPartitions, + "owned_count", ownedCount, + ) + break + } + acquired, err := s.TryAcquireLease(ctx, topic, partitionKey, subscriberName, consumerGroup, leaseDurationMs) if err != nil { - // Log but continue trying other partitions - s.logger.Warnw("failed to acquire lease for partition", + // Per-partition error is swallowed because one partition's DB failure + // should not prevent acquiring leases for other partitions. The failed + // partition is retried on the next discovery cycle. + s.logger.Errorw("failed to acquire lease for partition", logTopic, topic, logPartitionKey, partitionKey, logError, err, @@ -335,20 +279,28 @@ func (s *sqlpartitionLeaseStore) DiscoverAndAcquirePartitions(ctx context.Contex continue } if acquired { - acquiredCount++ + // Only count as newly acquired if not already owned. + // TryAcquireLease returns true for already-owned partitions (renew), + // so we must not double-count them against the maxPartitions cap. 
+ if _, alreadyOwned := ownedSet[partitionKey]; !alreadyOwned { + acquiredCount++ + ownedCount++ + } } } - s.metrics.Counter("discover_and_acquire.success").Inc(1) - s.metrics.Counter("partitions.discovered").Inc(int64(len(partitions))) - s.metrics.Counter("partitions.acquired").Inc(int64(acquiredCount)) + metrics.NamedCounter(s.scope, "discover_and_acquire", "partitions_discovered", int64(len(partitions)), metrics.NewTag("topic", topic)) + metrics.NamedCounter(s.scope, "discover_and_acquire", "partitions_acquired", int64(acquiredCount), metrics.NewTag("topic", topic)) s.logger.Infow("completed partition discovery and acquisition", logTopic, topic, "discovered_count", len(partitions), "acquired_count", acquiredCount, - "duration_ms", time.Since(start).Milliseconds(), ) - success = true - return acquiredCount, nil + return acquiredCount, partitions, nil +} + +// currentTimeMillis returns the current time in milliseconds since epoch. +func currentTimeMillis() int64 { + return time.Now().UnixMilli() } diff --git a/extension/queue/mysql/partition_lease_store_test.go b/extension/queue/mysql/partition_lease_store_test.go index c8565a90..d18291c8 100644 --- a/extension/queue/mysql/partition_lease_store_test.go +++ b/extension/queue/mysql/partition_lease_store_test.go @@ -33,12 +33,12 @@ func setuppartitionLeaseStoreTest(t *testing.T) (*sql.DB, sqlmock.Sqlmock, parti db, mock, err := sqlmock.New() require.NoError(t, err) - store := newPartitionLeaseStore(db, zaptest.NewLogger(t), tally.NoopScope) + store := newPartitionLeaseStore(db, zaptest.NewLogger(t).Sugar(), tally.NoopScope) return db, mock, store } -func TestpartitionLeaseStore_TryAcquireLease(t *testing.T) { +func TestPartitionLeaseStore_TryAcquireLease(t *testing.T) { tests := []struct { name string setup func(mock sqlmock.Sqlmock) @@ -97,7 +97,7 @@ func TestpartitionLeaseStore_TryAcquireLease(t *testing.T) { } } -func TestpartitionLeaseStore_RenewLease(t *testing.T) { +func TestPartitionLeaseStore_RenewLease(t 
*testing.T) { tests := []struct { name string setup func(mock sqlmock.Sqlmock) @@ -145,7 +145,7 @@ func TestpartitionLeaseStore_RenewLease(t *testing.T) { } } -func TestpartitionLeaseStore_ReleaseLease(t *testing.T) { +func TestPartitionLeaseStore_ReleaseLease(t *testing.T) { tests := []struct { name string setup func(mock sqlmock.Sqlmock) @@ -193,7 +193,7 @@ func TestpartitionLeaseStore_ReleaseLease(t *testing.T) { } } -func TestpartitionLeaseStore_GetLeasedPartitions(t *testing.T) { +func TestPartitionLeaseStore_GetLeasedPartitions(t *testing.T) { db, mock, store := setuppartitionLeaseStoreTest(t) defer db.Close() @@ -216,40 +216,148 @@ func TestpartitionLeaseStore_GetLeasedPartitions(t *testing.T) { require.NoError(t, mock.ExpectationsWereMet()) } -func TestpartitionLeaseStore_DiscoverAndAcquirePartitions(t *testing.T) { - db, mock, store := setuppartitionLeaseStoreTest(t) - defer db.Close() +func TestPartitionLeaseStore_DiscoverAndAcquirePartitions(t *testing.T) { + tests := []struct { + name string + maxPartitions int + setup func(mock sqlmock.Sqlmock) + wantAcquired int + wantErr bool + }{ + { + name: "unlimited - acquires all available", + maxPartitions: 0, + setup: func(mock sqlmock.Sqlmock) { + // Discover partitions + rows := sqlmock.NewRows([]string{"partition_key"}). + AddRow("part1"). + AddRow("part2") + mock.ExpectQuery("SELECT DISTINCT partition_key FROM queue_messages"). + WithArgs("test_topic"). + WillReturnRows(rows) - ctx := context.Background() - topic := "test_topic" + // Acquire part1 - success + mock.ExpectExec("INSERT INTO queue_partition_leases"). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectQuery("SELECT leased_by FROM queue_partition_leases"). + WillReturnRows(sqlmock.NewRows([]string{"leased_by"}).AddRow(testSubscriberName)) - // Expect query for distinct partition keys - rows := sqlmock.NewRows([]string{"partition_key"}). - AddRow("part1"). 
- AddRow("part2") + // Acquire part2 - taken by other worker + mock.ExpectExec("INSERT INTO queue_partition_leases"). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectQuery("SELECT leased_by FROM queue_partition_leases"). + WillReturnRows(sqlmock.NewRows([]string{"leased_by"}).AddRow("other-worker")) + }, + wantAcquired: 1, + }, + { + name: "stops acquiring when cap reached", + maxPartitions: 2, + setup: func(mock sqlmock.Sqlmock) { + // Discover 3 partitions + rows := sqlmock.NewRows([]string{"partition_key"}). + AddRow("part1"). + AddRow("part2"). + AddRow("part3") + mock.ExpectQuery("SELECT DISTINCT partition_key FROM queue_messages"). + WithArgs("test_topic"). + WillReturnRows(rows) - mock.ExpectQuery("SELECT DISTINCT partition_key FROM queue_messages"). - WithArgs(topic). - WillReturnRows(rows) + // Pre-loop GetLeasedPartitions: owns 0 partitions + mock.ExpectQuery("SELECT partition_key FROM queue_partition_leases"). + WithArgs(testConsumerGroup, "test_topic", testSubscriberName). + WillReturnRows(sqlmock.NewRows([]string{"partition_key"})) + + // Acquire part1 - success + mock.ExpectExec("INSERT INTO queue_partition_leases"). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectQuery("SELECT leased_by FROM queue_partition_leases"). + WillReturnRows(sqlmock.NewRows([]string{"leased_by"}).AddRow(testSubscriberName)) - // For each partition, expect acquire attempt - for i := 0; i < 2; i++ { - // Expect insert/update - mock.ExpectExec("INSERT INTO queue_partition_leases"). - WillReturnResult(sqlmock.NewResult(1, 1)) - - // Expect ownership check - first one acquired, second not - owner := testSubscriberName - if i == 1 { - owner = "other-worker" - } - ownerRows := sqlmock.NewRows([]string{"leased_by"}).AddRow(owner) - mock.ExpectQuery("SELECT leased_by FROM queue_partition_leases"). - WillReturnRows(ownerRows) + // Acquire part2 - success (now at cap of 2, stops) + mock.ExpectExec("INSERT INTO queue_partition_leases"). 
+ WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectQuery("SELECT leased_by FROM queue_partition_leases"). + WillReturnRows(sqlmock.NewRows([]string{"leased_by"}).AddRow(testSubscriberName)) + + // part3 is never attempted because ownedCount (2) >= maxPartitions (2) + }, + wantAcquired: 2, + }, + { + name: "pre-owned partitions count toward cap", + maxPartitions: 3, + setup: func(mock sqlmock.Sqlmock) { + // Discover 3 partitions + rows := sqlmock.NewRows([]string{"partition_key"}). + AddRow("part1"). + AddRow("part2"). + AddRow("part3") + mock.ExpectQuery("SELECT DISTINCT partition_key FROM queue_messages"). + WithArgs("test_topic"). + WillReturnRows(rows) + + // Pre-loop GetLeasedPartitions: already owns 2 partitions + mock.ExpectQuery("SELECT partition_key FROM queue_partition_leases"). + WithArgs(testConsumerGroup, "test_topic", testSubscriberName). + WillReturnRows(sqlmock.NewRows([]string{"partition_key"}). + AddRow("existing1"). + AddRow("existing2")) + + // Acquire part1 - success (now at 3, cap reached) + mock.ExpectExec("INSERT INTO queue_partition_leases"). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectQuery("SELECT leased_by FROM queue_partition_leases"). + WillReturnRows(sqlmock.NewRows([]string{"leased_by"}).AddRow(testSubscriberName)) + + // part2, part3 never attempted because ownedCount (3) >= maxPartitions (3) + }, + wantAcquired: 1, + }, + { + name: "already at cap - acquires nothing", + maxPartitions: 2, + setup: func(mock sqlmock.Sqlmock) { + // Discover 2 partitions + rows := sqlmock.NewRows([]string{"partition_key"}). + AddRow("part1"). + AddRow("part2") + mock.ExpectQuery("SELECT DISTINCT partition_key FROM queue_messages"). + WithArgs("test_topic"). + WillReturnRows(rows) + + // Pre-loop GetLeasedPartitions: already owns 2 partitions (at cap) + mock.ExpectQuery("SELECT partition_key FROM queue_partition_leases"). + WithArgs(testConsumerGroup, "test_topic", testSubscriberName). 
+ WillReturnRows(sqlmock.NewRows([]string{"partition_key"}). + AddRow("existing1"). + AddRow("existing2")) + + // No acquire attempts - immediately breaks + }, + wantAcquired: 0, + }, } - acquired, err := store.DiscoverAndAcquirePartitions(ctx, topic, testSubscriberName, testConsumerGroup, testLeaseDurationMs) - require.NoError(t, err) - require.Equal(t, 1, acquired) // Only 1 out of 2 was acquired - require.NoError(t, mock.ExpectationsWereMet()) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, mock, store := setuppartitionLeaseStoreTest(t) + defer db.Close() + + ctx := context.Background() + topic := "test_topic" + + tt.setup(mock) + + acquired, discoveredPartitions, err := store.DiscoverAndAcquirePartitions(ctx, topic, testSubscriberName, testConsumerGroup, testLeaseDurationMs, tt.maxPartitions) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.wantAcquired, acquired) + require.NotNil(t, discoveredPartitions) + } + require.NoError(t, mock.ExpectationsWereMet()) + }) + } } diff --git a/extension/queue/mysql/publisher.go b/extension/queue/mysql/publisher.go index 0b86ad0e..e0fe05a5 100644 --- a/extension/queue/mysql/publisher.go +++ b/extension/queue/mysql/publisher.go @@ -22,46 +22,46 @@ import ( "github.com/uber-go/tally/v4" "go.uber.org/zap" + "github.com/uber/submitqueue/core/metrics" "github.com/uber/submitqueue/entity/queue" ) type publisher struct { logger *zap.SugaredLogger - metrics tally.Scope + scope tally.Scope messageStore messageStore mu sync.RWMutex closed bool } // NewPublisher creates a publisher with the given dependencies -func NewPublisher(logger *zap.SugaredLogger, metrics tally.Scope, messageStore messageStore) *publisher { +func NewPublisher(logger *zap.SugaredLogger, scope tally.Scope, messageStore messageStore) *publisher { return &publisher{ - logger: logger, - metrics: metrics, + logger: logger.Named("publisher"), + scope: scope, messageStore: messageStore, } } // 
Publish sends a message to the specified topic -func (p *publisher) Publish(ctx context.Context, topic string, message queue.Message) error { +func (p *publisher) Publish(ctx context.Context, topic string, message queue.Message) (retErr error) { + op := metrics.Begin(p.scope, "publish", metrics.NewTag("topic", topic)) + defer func() { op.Complete(retErr) }() + // Check if closed (under lock) p.mu.RLock() closed := p.closed p.mu.RUnlock() if closed { - p.logger.Errorw("publish failure: publisher is closed", "topic", topic) - return fmt.Errorf("publisher is closed") + return ErrPublisherClosed } if err := p.messageStore.Insert(ctx, topic, []queue.Message{message}); err != nil { - p.metrics.Tagged(map[string]string{"topic": topic}).Counter("publish_errors").Inc(1) - p.logger.Errorw("publish failure: message store insert error", "topic", topic, "error", err) return fmt.Errorf("publish message store insert error: %w", err) } - p.metrics.Tagged(map[string]string{"topic": topic}).Counter("messages_published").Inc(1) - p.logger.Debugw("published message", "topic", topic, "message_id", message.ID) + p.logger.Debugw("published message", logTopic, topic, logMessageID, message.ID) return nil } @@ -72,6 +72,6 @@ func (p *publisher) Close() error { p.closed = true p.mu.Unlock() - p.logger.Info("publisher closed") + p.logger.Infow("publisher closed") return nil } diff --git a/extension/queue/mysql/publisher_test.go b/extension/queue/mysql/publisher_test.go index 5e38b296..6138aabd 100644 --- a/extension/queue/mysql/publisher_test.go +++ b/extension/queue/mysql/publisher_test.go @@ -16,6 +16,7 @@ package mysql import ( "context" + "errors" "fmt" "testing" @@ -156,6 +157,7 @@ func TestPublisher_PublishAfterClose(t *testing.T) { msg := queue.NewMessage("msg1", []byte("payload"), "part1", nil) err = pub.Publish(ctx, "test_topic", msg) require.Error(t, err) + require.True(t, errors.Is(err, ErrPublisherClosed)) } func TestPublisher_Close(t *testing.T) { diff --git 
a/extension/queue/mysql/schema/BUILD.bazel b/extension/queue/mysql/schema/BUILD.bazel index 9f8f6151..89c62abb 100644 --- a/extension/queue/mysql/schema/BUILD.bazel +++ b/extension/queue/mysql/schema/BUILD.bazel @@ -1,9 +1,11 @@ filegroup( name = "schema", srcs = [ + "queue_delivery_state.sql", "queue_messages.sql", "queue_offsets.sql", "queue_partition_leases.sql", + "queue_subscriber_heartbeats.sql", ], visibility = ["//visibility:public"], ) diff --git a/extension/queue/mysql/schema/queue_delivery_state.sql b/extension/queue/mysql/schema/queue_delivery_state.sql new file mode 100644 index 00000000..1b9ddc8f --- /dev/null +++ b/extension/queue/mysql/schema/queue_delivery_state.sql @@ -0,0 +1,35 @@ +-- DELIVERY STATE TABLE +-- Per-consumer-group delivery tracking for messages in the immutable log. +-- Tracks visibility, ack state, and retry count independently per consumer group. +-- +-- State encoding: +-- acked = TRUE → processed, never redeliver +-- acked = FALSE, invisible_until > now → in-flight or nack delay +-- acked = FALSE, invisible_until <= now → ready for (re-)delivery + +CREATE TABLE IF NOT EXISTS queue_delivery_state ( + -- Consumer group this delivery state belongs to + consumer_group VARCHAR(255) NOT NULL, + + -- Topic of the message + topic VARCHAR(255) NOT NULL, + + -- Partition key of the message + partition_key VARCHAR(255) NOT NULL, + + -- Offset of the message in the immutable log + message_offset BIGINT UNSIGNED NOT NULL, + + -- Whether this consumer group has successfully processed this message + acked BOOLEAN NOT NULL DEFAULT FALSE, + + -- Visibility timeout (epoch milliseconds) + -- Only meaningful when acked = FALSE. + -- Future timestamp = in-flight or nack delay, 0/past = ready for delivery. 
+ invisible_until BIGINT UNSIGNED NOT NULL DEFAULT 0, + + -- Number of times this message has been redelivered to this consumer group + retry_count INT UNSIGNED NOT NULL DEFAULT 0, + + PRIMARY KEY (consumer_group, topic, partition_key, message_offset) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin; diff --git a/extension/queue/mysql/schema/queue_messages.sql b/extension/queue/mysql/schema/queue_messages.sql index a60851c5..50e887b2 100644 --- a/extension/queue/mysql/schema/queue_messages.sql +++ b/extension/queue/mysql/schema/queue_messages.sql @@ -1,5 +1,6 @@ --- MESSAGES TABLE +-- MESSAGES TABLE (Immutable Log) -- Single table for all topics. Partition key determines distribution across workers. +-- Messages are append-only; per-consumer-group delivery tracking is in queue_delivery_state. -- Example: topic="merge_queue", partition_key="uber/cadence" CREATE TABLE IF NOT EXISTS queue_messages ( @@ -20,13 +21,6 @@ CREATE TABLE IF NOT EXISTS queue_messages ( payload BLOB NOT NULL, metadata JSON, - -- Retry tracking (persistent across workers) - retry_count INT UNSIGNED NOT NULL, - - -- Visibility timeout (epoch milliseconds) - -- Messages invisible until this timestamp expires - invisible_until BIGINT UNSIGNED NOT NULL, - -- Timestamps (epoch milliseconds) created_at BIGINT UNSIGNED NOT NULL, published_at BIGINT UNSIGNED NOT NULL, @@ -34,15 +28,13 @@ CREATE TABLE IF NOT EXISTS queue_messages ( -- DLQ-specific fields (0/"" for normal messages, populated for DLQ messages) failed_at BIGINT UNSIGNED NOT NULL, -- failure_count stores how many times the message failed on the ORIGINAL topic before moving to DLQ - -- This is different from retry_count, which tracks retries on the CURRENT topic and gets reset to 0 on DLQ move - -- We need both because: retry_count must reset for DLQ processing, but we still need to know original failure count failure_count INT UNSIGNED NOT NULL, last_error TEXT NOT NULL, original_topic VARCHAR(255) NOT NULL, - -- Supports: 
SELECT ... WHERE topic=? AND partition_key=? AND invisible_until<=? ORDER BY offset - -- Used by subscribers to poll for ready-to-process messages within their assigned partition - INDEX idx_topic_partition_visible_offset (topic, partition_key, invisible_until, offset), + -- Supports: SELECT ... WHERE topic=? AND partition_key=? AND offset > ? ORDER BY offset + -- Used by subscribers to poll for messages within their assigned partition + INDEX idx_topic_partition_offset (topic, partition_key, offset), -- Supports: INSERT ... ON DUPLICATE KEY to enforce idempotent publishes -- Also enables efficient lookups for message updates/deletes by ID diff --git a/extension/queue/mysql/schema/queue_offsets.sql b/extension/queue/mysql/schema/queue_offsets.sql index 888aa666..0a171769 100644 --- a/extension/queue/mysql/schema/queue_offsets.sql +++ b/extension/queue/mysql/schema/queue_offsets.sql @@ -1,6 +1,10 @@ -- CONSUMER OFFSETS TABLE -- Tracks consumption progress per consumer group + topic + partition. -- Each partition has independent offset tracking for crash recovery. +-- +-- The primary key (consumer_group, topic, partition_key) serves as the main +-- lookup index for all queries in offsetStore. No additional indexes are needed +-- because all queries filter by the full primary key or a left prefix of it. CREATE TABLE IF NOT EXISTS queue_offsets ( -- Consumer group consuming the topic @@ -18,15 +22,12 @@ CREATE TABLE IF NOT EXISTS queue_offsets ( -- Last update timestamp (epoch milliseconds) updated_at BIGINT UNSIGNED NOT NULL, - -- Primary key ensures each consumer group has one offset per topic/partition - -- Supports: INSERT ... ON DUPLICATE KEY UPDATE for idempotent offset updates + -- Primary key ensures each consumer group has one offset per topic/partition. + -- Supports: INSERT ... ON DUPLICATE KEY UPDATE for idempotent offset updates. -- Also enables efficient lookups: SELECT ... WHERE consumer_group=? AND topic=? AND partition_key=? 
+ -- Left-prefix covers: SELECT ... WHERE consumer_group=? (all offsets for a group) PRIMARY KEY (consumer_group, topic, partition_key), - -- Supports: SELECT ... WHERE consumer_group=? - -- Used for querying all offsets for a specific consumer group (e.g., for monitoring or rebalancing) - INDEX idx_consumer_group (consumer_group), - -- Supports: SELECT ... WHERE topic=? -- Used for querying all consumer groups consuming a specific topic INDEX idx_topic (topic) diff --git a/extension/queue/mysql/schema/queue_partition_leases.sql b/extension/queue/mysql/schema/queue_partition_leases.sql index 715b6d03..e1015764 100644 --- a/extension/queue/mysql/schema/queue_partition_leases.sql +++ b/extension/queue/mysql/schema/queue_partition_leases.sql @@ -1,6 +1,10 @@ -- PARTITION LEASES TABLE -- Tracks which worker has leased which partition for exclusive processing. -- Workers must renew leases to maintain ownership; stale leases can be stolen. +-- +-- The primary key (consumer_group, topic, partition_key) serves as the main +-- lookup index. Queries by leased_by always include consumer_group and topic, +-- so the primary key's left-prefix is sufficient. CREATE TABLE IF NOT EXISTS queue_partition_leases ( -- Consumer group (e.g., "orchestrator") @@ -22,15 +26,12 @@ CREATE TABLE IF NOT EXISTS queue_partition_leases ( -- Used to detect stale leases lease_renewed_at BIGINT UNSIGNED NOT NULL, - -- Primary key ensures each partition can only be leased by one worker per consumer group - -- Supports: INSERT ... ON DUPLICATE KEY UPDATE for lease acquisition and renewal + -- Primary key ensures each partition can only be leased by one worker per consumer group. + -- Supports: INSERT ... ON DUPLICATE KEY UPDATE for lease acquisition and renewal. -- Also enables efficient lookups: SELECT ... WHERE consumer_group=? AND topic=? AND partition_key=? + -- Left-prefix covers: SELECT ... WHERE consumer_group=? AND topic=? AND leased_by=? 
PRIMARY KEY (consumer_group, topic, partition_key), - -- Supports: SELECT ... WHERE leased_by=? - -- Used for querying all partitions owned by a specific worker (e.g., for graceful shutdown or rebalancing) - INDEX idx_leased_by (leased_by), - -- Supports: SELECT ... WHERE lease_renewed_at0 means deregistered at that time. + deregistered_at BIGINT UNSIGNED NOT NULL, + + PRIMARY KEY (consumer_group, topic, subscriber_name) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin; diff --git a/extension/queue/mysql/sql.go b/extension/queue/mysql/sql.go index 9a176b05..aa7b9273 100644 --- a/extension/queue/mysql/sql.go +++ b/extension/queue/mysql/sql.go @@ -18,6 +18,7 @@ import ( "database/sql" "errors" "fmt" + "time" "github.com/uber-go/tally/v4" "go.uber.org/zap" @@ -50,29 +51,33 @@ func NewQueue(params Params) (queue.Queue, error) { return nil, fmt.Errorf("failed to ping database: %w", err) } - logger := params.Logger.Sugar().Named("queue.sql") - logger.Info("created SQL queue") + logger := params.Logger.Sugar().Named("queue_mysql") + logger.Infow("created SQL queue") // Create stores - messageStore := newMessageStore(params.DB, params.Logger, params.MetricsScope) - offsetStore := newOffsetStore(params.DB, params.Logger, params.MetricsScope) - leaseStore := newPartitionLeaseStore(params.DB, params.Logger, params.MetricsScope) + messageStore := newMessageStore(params.DB, logger, params.MetricsScope) + offsetStore := newOffsetStore(params.DB, params.MetricsScope) + leaseStore := newPartitionLeaseStore(params.DB, logger, params.MetricsScope) + heartbeatStore := newSubscriberHeartbeatStore(params.DB, logger, params.MetricsScope, time.Now) + deliveryStateStore := newDeliveryStateStore(params.DB, logger, params.MetricsScope) queueMetrics := params.MetricsScope.SubScope("queue") // Create publisher and subscriber publisher := NewPublisher( - logger.Named("publisher"), + logger, queueMetrics.SubScope("publisher"), messageStore, ) subscriber := NewSubscriber( - 
logger.Named("subscriber"), + logger, queueMetrics.SubScope("subscriber"), messageStore, offsetStore, leaseStore, + heartbeatStore, + deliveryStateStore, ) return &queueImpl{ diff --git a/extension/queue/mysql/stores.go b/extension/queue/mysql/stores.go index 1dce2c9b..10fa7e59 100644 --- a/extension/queue/mysql/stores.go +++ b/extension/queue/mysql/stores.go @@ -24,10 +24,11 @@ import ( const ( // Fixed table names for single-table design - MessagesTableName = "queue_messages" - PartitionLeasesTableName = "queue_partition_leases" - OffsetsTableName = "queue_offsets" - DLQTableName = "queue_dlq" + MessagesTableName = "queue_messages" + PartitionLeasesTableName = "queue_partition_leases" + OffsetsTableName = "queue_offsets" + SubscriberHeartbeatsTableName = "queue_subscriber_heartbeats" + DeliveryStateTableName = "queue_delivery_state" ) // messageRow represents a row from the messages table (internal use only) @@ -42,8 +43,6 @@ type messageRow struct { Metadata map[string]string // PartitionKey determines which partition this message belongs to for ordering guarantees PartitionKey string - // RetryCount tracks how many times this message has been retried on the current topic - RetryCount int // PublishedAt is the Unix timestamp in milliseconds when message was published PublishedAt int64 // FailedAt is the Unix timestamp in milliseconds when the message failed (0 for normal messages, >0 for DLQ) @@ -61,24 +60,30 @@ type messageStore interface { // Insert inserts messages into the topic table Insert(ctx context.Context, topic string, messages []queue.Message) error - // Delete deletes a message by ID - Delete(ctx context.Context, topic string, messageID string) error + // Delete deletes a message by topic, partition key, and ID + Delete(ctx context.Context, topic string, partitionKey string, messageID string) error - // FetchByOffset fetches messages with offset > currentOffset for a specific partition - // Only fetches visible messages (invisible_until <= now) - // 
Atomically sets invisible_until and increments retry_count - // visibilityTimeoutMs specifies how long messages should be invisible after fetching (in milliseconds) - FetchByOffset(ctx context.Context, topic string, partitionKey string, currentOffset int64, limit int, visibilityTimeoutMs int64) ([]messageRow, error) + // FetchByOffset fetches messages with offset > currentOffset for a specific partition. + // Messages are returned from the immutable log; per-consumer-group visibility + // is handled by the deliveryStateStore. + FetchByOffset(ctx context.Context, topic string, partitionKey string, currentOffset int64, limit int) ([]messageRow, error) // MoveToDLQ moves a message to the dead letter queue // dlqTopicSuffix is appended to the original topic to form the DLQ topic name - MoveToDLQ(ctx context.Context, topic string, messageID string, failureCount int, lastError string, dlqTopicSuffix string) error - - // SetVisibilityTimeout sets the invisible_until timestamp for a message - // visibilityTimeoutMillis: milliseconds from now to hide the message - // If visibilityTimeoutMillis is 0, makes the message visible immediately - // If visibilityTimeoutMillis > 0, makes the message invisible until now + visibilityTimeoutMillis - SetVisibilityTimeout(ctx context.Context, topic string, messageID string, visibilityTimeoutMillis int64) error + MoveToDLQ(ctx context.Context, topic string, partitionKey string, messageID string, failureCount int, lastError string, dlqTopicSuffix string) error + + // GarbageCollect deletes messages with offset <= minAckedOffset. + // The caller (subscriber) is responsible for computing minAckedOffset from the + // offsetStore, keeping messageStore free of cross-table queries. + // Returns the number of rows deleted. 
+ GarbageCollect(ctx context.Context, topic string, partitionKey string, minAckedOffset int64) (int64, error) + + // GetOffsetsAbove returns message offsets above afterOffset for a partition, + // ordered ascending, up to limit rows. Used by the subscriber to drive + // watermark advancement without requiring a cross-table JOIN in the delivery + // state store. Watermark advancement is incremental and idempotent, so + // limiting the result set is safe — it converges over multiple calls. + GetOffsetsAbove(ctx context.Context, topic string, partitionKey string, afterOffset int64, limit int) ([]int64, error) } // offsetStore handles offset table operations for per-partition offset tracking (internal use only) @@ -92,8 +97,11 @@ type offsetStore interface { // UpdateAckedOffset updates the offset_acked for a topic+partition (only if new offset is greater) UpdateAckedOffset(ctx context.Context, topic string, partitionKey string, offset int64, consumerGroup string) error - // AckMessage atomically deletes a message and updates the acked offset - AckMessage(ctx context.Context, topic string, partitionKey string, messageID string, offset int64, consumerGroup string, messageStore messageStore) error + // GetMinAckedOffset returns the minimum offset_acked across all consumer groups + // for a topic+partition. Returns (0, false, nil) if no offset rows exist. + // Used by the subscriber to compute the GC threshold without messageStore + // needing to query the offsets table. 
+ GetMinAckedOffset(ctx context.Context, topic string, partitionKey string) (offset int64, found bool, err error) } // partitionLeaseStore handles partition lease operations (internal use only) @@ -113,8 +121,61 @@ type partitionLeaseStore interface { // GetLeasedPartitions returns all partitions currently leased by this worker GetLeasedPartitions(ctx context.Context, topic string, subscriberName string, consumerGroup string) ([]string, error) - // DiscoverAndAcquirePartitions discovers partitions from messages table and tries to acquire leases - // Returns the number of new leases acquired + // DiscoverAndAcquirePartitions discovers partitions from messages table and tries to acquire leases. + // Returns the number of new leases acquired and the full list of discovered partitions. // leaseDurationMs is how long the lease is valid (in milliseconds) - DiscoverAndAcquirePartitions(ctx context.Context, topic string, subscriberName string, consumerGroup string, leaseDurationMs int64) (int, error) + // maxPartitions limits how many total partitions this subscriber can own (0 = unlimited) + DiscoverAndAcquirePartitions(ctx context.Context, topic string, subscriberName string, consumerGroup string, leaseDurationMs int64, maxPartitions int) (acquiredCount int, discoveredPartitions []string, err error) +} + +// subscriberHeartbeatStore handles subscriber heartbeat operations for fair partition leasing (internal use only) +type subscriberHeartbeatStore interface { + // Heartbeat registers or renews a subscriber's heartbeat + Heartbeat(ctx context.Context, topic string, subscriberName string, consumerGroup string) error + + // ActiveSubscribers returns the names of subscribers with a recent heartbeat. + // staleDurationMs defines the staleness threshold: subscribers without a heartbeat + // within this duration are considered dead. 
+ ActiveSubscribers(ctx context.Context, topic string, consumerGroup string, staleDurationMs int64) ([]string, error) + + // Deregister removes a subscriber's heartbeat entry + Deregister(ctx context.Context, topic string, subscriberName string, consumerGroup string) error +} + +// DeliveryState represents the full per-message delivery tracking state. +type DeliveryState struct { + // Acked indicates whether this consumer group has processed the message + Acked bool + // InvisibleUntil is the epoch milliseconds until which the message is hidden + InvisibleUntil int64 + // RetryCount tracks how many times the message has been delivered + RetryCount int +} + +// deliveryStateStore handles per-consumer-group delivery tracking (internal use only) +type deliveryStateStore interface { + // MarkDelivered inserts a row marking message as in-flight for this consumer group. + // Increments retry_count on redelivery (ON DUPLICATE KEY UPDATE). + // Returns the resulting retry_count after the operation. + MarkDelivered(ctx context.Context, consumerGroup, topic, partitionKey string, offset int64, visibilityTimeoutMs int64) (retryCount int, err error) + + // ExtendVisibility extends the visibility timeout for an in-flight message + // without incrementing retry_count. + ExtendVisibility(ctx context.Context, consumerGroup, topic, partitionKey string, offset int64, visibilityTimeoutMs int64) error + + // MarkAcked sets acked = TRUE to indicate this group has processed the message. + MarkAcked(ctx context.Context, consumerGroup, topic, partitionKey string, offset int64) error + + // MarkNacked sets invisible_until = now + delay to schedule redelivery. + MarkNacked(ctx context.Context, consumerGroup, topic, partitionKey string, offset int64, delayMs int64) error + + // GetDeliveryState returns the full delivery state for a message offset. + // Returns (state, found, error). found=false means no row (never delivered). 
+ GetDeliveryState(ctx context.Context, consumerGroup, topic, partitionKey string, offset int64) (DeliveryState, bool, error) + + // AdvanceWatermark computes the new contiguous acked watermark and cleans up + // delivery state rows behind it. + // offsets are the actual message offsets above the current watermark (from messageStore). + // Returns the new watermark (highest contiguous acked offset from currentWatermark). + AdvanceWatermark(ctx context.Context, consumerGroup, topic, partitionKey string, currentWatermark int64, offsets []int64) (int64, error) } diff --git a/extension/queue/mysql/subscriber.go b/extension/queue/mysql/subscriber.go index 4a8a07f2..fbd95e50 100644 --- a/extension/queue/mysql/subscriber.go +++ b/extension/queue/mysql/subscriber.go @@ -17,6 +17,7 @@ package mysql import ( "context" "fmt" + "sort" "strconv" "sync" "time" @@ -24,6 +25,7 @@ import ( "github.com/uber-go/tally/v4" "go.uber.org/zap" + "github.com/uber/submitqueue/core/metrics" "github.com/uber/submitqueue/entity/queue" extqueue "github.com/uber/submitqueue/extension/queue" ) @@ -40,16 +42,29 @@ const ( // subscriptionShutdownTimeout is the maximum time to wait for the // managePartitions goroutine to finish during Close(). subscriptionShutdownTimeout = 30 * time.Second + + // watermarkAdvancementLimit is the max number of message offsets fetched per + // advanceWatermark call. Watermark advancement is incremental and idempotent, + // so it converges over multiple calls even with large backlogs. + watermarkAdvancementLimit = 1000 + + // gcIdleTickInterval controls how often GC runs during idle poll ticks. + // GC runs every Nth idle tick instead of every tick to avoid excessive + // queries when many partitions are idle (e.g., 50 idle partitions at 100ms + // poll interval = 500 GC queries/sec without throttling). 
+ gcIdleTickInterval = 100 ) type subscriber struct { - logger *zap.SugaredLogger - metrics tally.Scope - messageStore messageStore - offsetStore offsetStore - leaseStore partitionLeaseStore - mu sync.RWMutex - closed bool + logger *zap.SugaredLogger + scope tally.Scope + messageStore messageStore + offsetStore offsetStore + leaseStore partitionLeaseStore + heartbeatStore subscriberHeartbeatStore + deliveryStateStore deliveryStateStore + mu sync.RWMutex + closed bool // Active subscriptions subscriptions map[string]*subscription @@ -77,6 +92,11 @@ type subscription struct { // be called during shutdown. workers map[string]*partitionWorker workersMu sync.Mutex + + // lastDiscoveredPartitions is cached from the most recent + // DiscoverAndAcquirePartitions call. Used by fairShareCap during + // rebalance to avoid a redundant discovery query. + lastDiscoveredPartitions []string } // partitionWorker handles polling and delivering messages for a single partition. @@ -94,6 +114,9 @@ type partitionWorker struct { // partition. Set once on the first successful poll, avoiding repeated // initialization calls on every tick. offsetInitialized bool + // gcCounter counts idle poll ticks. GC only runs every gcIdleTickInterval + // ticks to avoid excessive queries when many partitions are idle. + gcCounter int } // sqlDelivery implements extqueue.Delivery for SQL queue @@ -184,17 +207,13 @@ func (d *sqlDelivery) Ack(ctx context.Context) error { return &ErrAlreadyAcknowledged{DeliveryID: d.deliveryID} } - // Perform acknowledgment - if err := d.subscriber.offsetStore.AckMessage(ctx, d.topic, d.partitionKey, d.messageID, d.offset, d.consumerGroup, d.subscriber.messageStore); err != nil { + // Mark as acked in delivery state (per consumer group). + // Watermark advancement is deferred to the poll loop to reduce per-ack + // latency from 4-5 DB round trips to 1. 
+ if err := d.subscriber.deliveryStateStore.MarkAcked(ctx, d.consumerGroup, d.topic, d.partitionKey, d.offset); err != nil { return err } - // Record metrics - d.subscriber.metrics.Tagged(map[string]string{ - "topic": d.topic, - "partition_key": d.partitionKey, - }).Counter("messages_acked").Inc(1) - d.acknowledged = true return nil } @@ -208,24 +227,12 @@ func (d *sqlDelivery) Nack(ctx context.Context, requeueAfterMillis int64) error return &ErrAlreadyAcknowledged{DeliveryID: d.deliveryID} } - // Set visibility timeout to make message visible after requeueAfter duration - if err := d.subscriber.messageStore.SetVisibilityTimeout(ctx, d.topic, d.messageID, requeueAfterMillis); err != nil { - d.subscriber.logger.Errorw("failed to set visibility timeout for nack", - "topic", d.topic, - "partition_key", d.partitionKey, - "message_id", d.messageID, - "error", err, - ) + // Mark as nacked in delivery state (per consumer group, with delay and retry_count) + if err := d.subscriber.deliveryStateStore.MarkNacked(ctx, d.consumerGroup, d.topic, d.partitionKey, d.offset, requeueAfterMillis); err != nil { return err } - // Record metrics - d.subscriber.metrics.Tagged(map[string]string{ - "topic": d.topic, - "partition_key": d.partitionKey, - }).Counter("messages_nacked").Inc(1) - - d.subscriber.logger.Infow("message nacked", + d.subscriber.logger.Debugw("message nacked", "topic", d.topic, "partition_key", d.partitionKey, "message_id", d.messageID, @@ -248,39 +255,24 @@ func (d *sqlDelivery) Reject(ctx context.Context, reason string) error { if d.dlqConfig.Enabled { // Move to DLQ if err := d.subscriber.messageStore.MoveToDLQ( - ctx, d.topic, d.messageID, d.attempt, reason, d.dlqConfig.TopicSuffix, + ctx, d.topic, d.partitionKey, d.messageID, d.attempt, reason, d.dlqConfig.TopicSuffix, ); err != nil { return fmt.Errorf("failed to move message to DLQ: %w", err) } - // Update offset tracking - if err := d.subscriber.offsetStore.UpdateAckedOffset( - ctx, d.topic, d.partitionKey, 
d.offset, d.consumerGroup, - ); err != nil { - // Log but don't fail — message is already in DLQ - d.subscriber.logger.Errorw("failed to update offset after DLQ move", - "topic", d.topic, - "message_id", d.messageID, - "error", err, - ) + // Mark as acked in delivery state. Watermark advancement is deferred + // to the poll loop, same as Ack. + if err := d.subscriber.deliveryStateStore.MarkAcked(ctx, d.consumerGroup, d.topic, d.partitionKey, d.offset); err != nil { + return fmt.Errorf("mark acked after DLQ move: %w", err) } - d.subscriber.metrics.Tagged(map[string]string{ - "topic": d.topic, - "partition_key": d.partitionKey, - }).Counter("messages_rejected_to_dlq").Inc(1) } else { - // DLQ disabled — fall back to ack (remove from queue) - if err := d.subscriber.offsetStore.AckMessage( - ctx, d.topic, d.partitionKey, d.messageID, d.offset, d.consumerGroup, d.subscriber.messageStore, - ); err != nil { + // DLQ disabled — mark as acked (remove from processing). + // Watermark advancement is deferred to the poll loop, same as Ack. 
+ if err := d.subscriber.deliveryStateStore.MarkAcked(ctx, d.consumerGroup, d.topic, d.partitionKey, d.offset); err != nil { return err } - d.subscriber.metrics.Tagged(map[string]string{ - "topic": d.topic, - "partition_key": d.partitionKey, - }).Counter("messages_rejected_no_dlq").Inc(1) } d.acknowledged = true @@ -296,41 +288,65 @@ func (d *sqlDelivery) ExtendVisibilityTimeout(ctx context.Context, durationMilli return fmt.Errorf("delivery %s already acknowledged, cannot extend visibility timeout", d.deliveryID) } - if err := d.subscriber.messageStore.SetVisibilityTimeout(ctx, d.topic, d.messageID, durationMillis); err != nil { + // Extend visibility without incrementing retry_count + if err := d.subscriber.deliveryStateStore.ExtendVisibility(ctx, d.consumerGroup, d.topic, d.partitionKey, d.offset, durationMillis); err != nil { return err } - // Record metrics - d.subscriber.metrics.Tagged(map[string]string{ - "topic": d.topic, - "partition_key": d.partitionKey, - }).Counter("visibility_extended").Inc(1) - return nil } -func NewSubscriber(logger *zap.SugaredLogger, metrics tally.Scope, messageStore messageStore, offsetStore offsetStore, leaseStore partitionLeaseStore) *subscriber { - logger.Info("created subscriber") - +func NewSubscriber(logger *zap.SugaredLogger, scope tally.Scope, messageStore messageStore, offsetStore offsetStore, leaseStore partitionLeaseStore, heartbeatStore subscriberHeartbeatStore, deliveryStateStore deliveryStateStore) *subscriber { return &subscriber{ - logger: logger, - metrics: metrics, - messageStore: messageStore, - offsetStore: offsetStore, - leaseStore: leaseStore, - subscriptions: make(map[string]*subscription), + logger: logger.Named("subscriber"), + scope: scope.SubScope("subscriber"), + messageStore: messageStore, + offsetStore: offsetStore, + leaseStore: leaseStore, + heartbeatStore: heartbeatStore, + deliveryStateStore: deliveryStateStore, + subscriptions: make(map[string]*subscription), } } +// advanceWatermark advances 
offset_acked to the highest contiguous acked offset. +// All operations are idempotent — safe to call from multiple paths (Reject, retry-limit, +// poll loop) and safe to retry on failure. +func (s *subscriber) advanceWatermark(ctx context.Context, consumerGroup, topic, partitionKey string) error { + currentOffset, err := s.offsetStore.GetAckedOffset(ctx, topic, partitionKey, consumerGroup) + if err != nil { + return fmt.Errorf("get acked offset for watermark advance: %w", err) + } + + offsets, err := s.messageStore.GetOffsetsAbove(ctx, topic, partitionKey, currentOffset, watermarkAdvancementLimit) + if err != nil { + return fmt.Errorf("get message offsets for watermark advance: %w", err) + } + + newWatermark, err := s.deliveryStateStore.AdvanceWatermark(ctx, consumerGroup, topic, partitionKey, currentOffset, offsets) + if err != nil { + return fmt.Errorf("advance watermark: %w", err) + } + + if newWatermark > currentOffset { + if err := s.offsetStore.UpdateAckedOffset(ctx, topic, partitionKey, newWatermark, consumerGroup); err != nil { + return fmt.Errorf("update acked offset after watermark advance: %w", err) + } + } + return nil +} + // Subscribe starts consuming messages from the specified topic -func (s *subscriber) Subscribe(ctx context.Context, topic string, config extqueue.SubscriptionConfig) (<-chan extqueue.Delivery, error) { +func (s *subscriber) Subscribe(ctx context.Context, topic string, config extqueue.SubscriptionConfig) (_ <-chan extqueue.Delivery, retErr error) { + op := metrics.Begin(s.scope, "subscribe", metrics.NewTag("topic", topic)) + defer func() { op.Complete(retErr) }() + s.mu.RLock() closed := s.closed s.mu.RUnlock() if closed { - s.logger.Errorw("subscribe failed: subscriber is closed", "topic", topic) - return nil, fmt.Errorf("subscriber is closed") + return nil, ErrSubscriberClosed } // Create subscription key (topic + consumer group must be unique) @@ -368,7 +384,7 @@ func (s *subscriber) Subscribe(ctx context.Context, topic string, 
config extqueu s.subscriptions[subKey] = sub // Track active subscription - s.metrics.Tagged(map[string]string{"topic": topic}).Gauge("active_subscriptions").Update(1) + metrics.NamedGauge(s.scope, "subscribe", "active_subscriptions", 1, metrics.NewTag("topic", topic)) // Start the supervisor goroutine. It will discover partitions, acquire // leases, and spawn per-partition worker goroutines. The supervisor runs @@ -389,27 +405,44 @@ func (s *subscriber) Subscribe(ctx context.Context, topic string, config extqueu // // Goroutine hierarchy: // -// managePartitions (this goroutine) ← supervisor, tracked by sub.wg -// ├── partitionWorker("part-1") ← tracked by sub.workerWg -// ├── partitionWorker("part-2") -// └── partitionWorker("part-N") +// managePartitions (this goroutine) <- supervisor, tracked by sub.wg +// +-- partitionWorker("part-1") <- tracked by sub.workerWg +// +-- partitionWorker("part-2") +// +-- partitionWorker("part-N") // // Shutdown sequence (triggered by ctx cancellation): // 1. stopAllWorkers: cancels each worker's context and removes from map // 2. releaseAllLeases: releases DB partition leases (fresh context, not cancelled) -// 3. workerWg.Wait(): blocks until all workers have fully exited — this ensures +// 3. workerWg.Wait(): blocks until all workers have fully exited -- this ensures // no worker can send on deliveryCh after step 4 // 4. close(deliveryCh): safe because step 3 guarantees no senders remain -// 5. managePartitions returns → wg.Done() fires → Close() unblocks +// 5. managePartitions returns -> wg.Done() fires -> Close() unblocks func (s *subscriber) managePartitions(ctx context.Context, sub *subscription) { defer sub.wg.Done() - discoveryTicker := time.NewTicker(time.Duration(sub.config.PollIntervalMs) * time.Millisecond) + cfg := sub.config + // Common log fields for all operations in this subscription's lifecycle. 
+ logFields := []interface{}{ + "topic", sub.topic, + "consumer_group", cfg.ConsumerGroup, + "subscriber_name", cfg.SubscriberName, + } + + discoveryTicker := time.NewTicker(time.Duration(cfg.PollIntervalMs) * time.Millisecond) defer discoveryTicker.Stop() - leaseTicker := time.NewTicker(time.Duration(sub.config.LeaseRenewalIntervalMs) * time.Millisecond) + leaseTicker := time.NewTicker(time.Duration(cfg.LeaseRenewalIntervalMs) * time.Millisecond) defer leaseTicker.Stop() + // Send initial heartbeat so this subscriber is immediately visible to + // ActiveSubscribers. Without this, other subscribers compute incorrect + // fair shares until the first leaseTicker fires. + // Initial heartbeat failure is non-fatal — the next leaseTicker fires within + // LeaseRenewalIntervalMs and retries. + if err := s.sendHeartbeat(ctx, sub); err != nil { + s.logger.Errorw("initial heartbeat failed", append(logFields, "error", err)...) + } + for { select { case <-ctx.Done(): @@ -417,39 +450,98 @@ func (s *subscriber) managePartitions(ctx context.Context, sub *subscription) { // Release all leases on shutdown with a fresh context cleanupCtx, cancel := context.WithTimeout(context.Background(), leaseReleaseTimeout) defer cancel() - s.releaseAllLeases(cleanupCtx, sub) + + // Best-effort shutdown cleanup — log errors but don't block shutdown. + // Leases expire naturally after LeaseDurationMs if release fails. + // Heartbeat becomes stale after the same duration. + if err := s.releaseAllLeases(cleanupCtx, sub); err != nil { + s.logger.Errorw("failed to release leases during shutdown", append(logFields, "error", err)...) + } + if err := s.deregisterHeartbeat(cleanupCtx, sub); err != nil { + s.logger.Errorw("failed to deregister heartbeat during shutdown", append(logFields, "error", err)...) 
+ } + // Wait for all workers to fully exit, then close channel sub.workerWg.Wait() close(sub.deliveryCh) return case <-leaseTicker.C: - s.renewLeases(ctx, sub) + // Fetch leased partitions once for this tick — shared by rebalance + // and renewLeases to avoid redundant queries. + leasedPartitions, err := s.leaseStore.GetLeasedPartitions(ctx, sub.topic, cfg.SubscriberName, cfg.ConsumerGroup) + if err != nil { + s.logger.Errorw("get leased partitions failed", append(logFields, "error", err)...) + // Skip rebalance+renew on this tick; retry next tick. + if err := s.sendHeartbeat(ctx, sub); err != nil { + s.logger.Errorw("heartbeat failed during lease error recovery", append(logFields, "error", err)...) + } + continue + } + + // Rebalance, renew, and heartbeat are independent operations. + // Each can fail without affecting the others — the next tick retries. + if err := s.rebalance(ctx, sub, leasedPartitions); err != nil { + s.logger.Errorw("rebalance failed", append(logFields, "error", err)...) + } + if err := s.renewLeases(ctx, sub, leasedPartitions); err != nil { + s.logger.Errorw("lease renewal failed", append(logFields, "error", err)...) + } + if err := s.sendHeartbeat(ctx, sub); err != nil { + s.logger.Errorw("periodic heartbeat failed", append(logFields, "error", err)...) + } case <-discoveryTicker.C: - s.discoverAndReconcileWorkers(ctx, sub) + if err := s.discoverAndReconcileWorkers(ctx, sub); err != nil { + s.logger.Errorw("partition discovery failed, will retry on next tick", append(logFields, "error", err)...) + } } } } // discoverAndReconcileWorkers discovers new partitions and reconciles workers. -func (s *subscriber) discoverAndReconcileWorkers(ctx context.Context, sub *subscription) { +// Uses load-based fair share to limit how many partitions this subscriber acquires. 
+func (s *subscriber) discoverAndReconcileWorkers(ctx context.Context, sub *subscription) error { cfg := sub.config - // Discover and try to acquire leases for new partitions - acquiredCount, err := s.leaseStore.DiscoverAndAcquirePartitions(ctx, sub.topic, cfg.SubscriberName, cfg.ConsumerGroup, cfg.LeaseDurationMs) - if err == nil && acquiredCount > 0 { - s.metrics.Tagged(map[string]string{"topic": sub.topic}).Counter("leases_acquired").Inc(int64(acquiredCount)) + // Get current leased partitions for fair share computation. + leasedPartitions, err := s.leaseStore.GetLeasedPartitions(ctx, sub.topic, cfg.SubscriberName, cfg.ConsumerGroup) + if err != nil { + return fmt.Errorf("get leased partitions: %w", err) + } + + // Use cached discovered partitions from last tick for fair share cap. + // On the first tick, lastDiscoveredPartitions is nil → fairShareCap uses + // only owned partitions, which gives unlimited cap for new subscribers. + sub.workersMu.Lock() + cachedDiscovered := sub.lastDiscoveredPartitions + sub.workersMu.Unlock() + + maxPartitions, err := s.fairShareCap(ctx, sub, leasedPartitions, cachedDiscovered) + if err != nil { + return fmt.Errorf("compute fair share cap: %w", err) } - // Get currently leased partitions - leasedPartitions, err := s.leaseStore.GetLeasedPartitions(ctx, sub.topic, cfg.SubscriberName, cfg.ConsumerGroup) + // Discover and try to acquire leases for new partitions. + // Returns discovered partitions to cache for the next tick. + _, discoveredPartitions, err := s.leaseStore.DiscoverAndAcquirePartitions(ctx, sub.topic, cfg.SubscriberName, cfg.ConsumerGroup, cfg.LeaseDurationMs, maxPartitions) if err != nil { - s.logger.Errorw("failed to get leased partitions", "topic", sub.topic, "error", err) - return + return fmt.Errorf("discover and acquire partitions: %w", err) + } + + // Cache discovered partitions for fairShareCap reuse by rebalance and next tick. 
+ sub.workersMu.Lock() + sub.lastDiscoveredPartitions = discoveredPartitions + sub.workersMu.Unlock() + + // Refresh leased partitions after acquisition (new leases may have been acquired) + leasedPartitions, err = s.leaseStore.GetLeasedPartitions(ctx, sub.topic, cfg.SubscriberName, cfg.ConsumerGroup) + if err != nil { + return fmt.Errorf("get leased partitions after acquire: %w", err) } s.reconcilePartitionWorkers(ctx, sub, leasedPartitions) + return nil } // reconcilePartitionWorkers diffs the current set of workers against the current @@ -529,7 +621,7 @@ func (s *subscriber) startPartitionWorker(ctx context.Context, sub *subscription // reconciliation can start a replacement if the lease is re-acquired. The old // worker's context is cancelled, so its DB calls will fail and it will exit // imminently. workerWg still tracks the old goroutine, so Close() blocks until -// it fully exits — preventing sends on a closed deliveryCh. +// it fully exits -- preventing sends on a closed deliveryCh. // // The select with workerStopTimeout is purely for observability: if the worker // takes longer than expected to exit, a warning is logged but no action is needed @@ -547,7 +639,7 @@ func (s *subscriber) stopPartitionWorker(sub *subscription, partitionKey string) // Always remove from map so reconcile can start a replacement if needed. // The old worker's context is cancelled so it will exit imminently. - // workerWg still tracks it for shutdown — Close() won't return until it exits. + // workerWg still tracks it for shutdown -- Close() won't return until it exits. sub.workersMu.Lock() delete(sub.workers, partitionKey) sub.workersMu.Unlock() @@ -582,7 +674,7 @@ func (s *subscriber) stopAllWorkers(sub *subscription) { // run is the per-partition goroutine loop. It polls the DB on a ticker and // sends fetched messages to the shared deliveryCh. Each partition worker runs -// independently — a slow or blocked partition does not affect other partitions. 
+// independently -- a slow or blocked partition does not affect other partitions. // // Lifecycle: // - Started by startPartitionWorker, tracked by sub.workerWg @@ -600,13 +692,33 @@ func (w *partitionWorker) run(ctx context.Context) { case <-ctx.Done(): return case <-pollTicker.C: - w.pollAndDeliver(ctx) + // Errors are logged here rather than propagated because run() is a + // long-lived goroutine on a ticker. There is no upstream caller to + // return to — the only recovery is to retry on the next tick, which + // happens automatically. All pollAndDeliver operations are idempotent. + if err := w.pollAndDeliver(ctx); err != nil { + w.subscriber.logger.Errorw("poll failed", + "topic", w.sub.topic, + "partition_key", w.partitionKey, + "consumer_group", w.sub.config.ConsumerGroup, + "subscriber_name", w.sub.config.SubscriberName, + "error", err, + ) + } } } } // pollAndDeliver fetches messages from this worker's partition and delivers them. -func (w *partitionWorker) pollAndDeliver(ctx context.Context) { +// Returns an error if any DB operation fails — the caller logs once and the ticker +// retries on the next tick. All operations are idempotent, so retries are safe. +// +// Design note: GetDeliveryState and MarkDelivered are called per-message rather than +// batched. This keeps the store interfaces simple and the delivery logic straightforward. +// Partition leasing guarantees a single writer, so the TOCTOU gap between +// GetDeliveryState and MarkDelivered cannot cause incorrect behavior — no other +// worker can mutate the same (consumer_group, topic, partition_key, offset). 
+func (w *partitionWorker) pollAndDeliver(ctx context.Context) error { start := time.Now() s := w.subscriber sub := w.sub @@ -616,8 +728,7 @@ func (w *partitionWorker) pollAndDeliver(ctx context.Context) { // Initialize offset for this partition once per worker lifetime if !w.offsetInitialized { if err := s.offsetStore.Initialize(ctx, sub.topic, partitionKey, cfg.ConsumerGroup); err != nil { - s.logger.Errorw("offset initialization failure", "topic", sub.topic, "partition_key", partitionKey, "error", err) - return + return fmt.Errorf("initialize offset: %w", err) } w.offsetInitialized = true } @@ -625,61 +736,63 @@ func (w *partitionWorker) pollAndDeliver(ctx context.Context) { // Get current offset for this partition currentOffset, err := s.offsetStore.GetAckedOffset(ctx, sub.topic, partitionKey, cfg.ConsumerGroup) if err != nil { - s.logger.Errorw("get current offset failure", "topic", sub.topic, "partition_key", partitionKey, "error", err) - return + return fmt.Errorf("get acked offset: %w", err) } - // Fetch messages for this partition - rows, err := s.messageStore.FetchByOffset(ctx, sub.topic, partitionKey, currentOffset, cfg.BatchSize, cfg.VisibilityTimeoutMs) + // Fetch messages from immutable log + rows, err := s.messageStore.FetchByOffset(ctx, sub.topic, partitionKey, currentOffset, cfg.BatchSize) if err != nil { - s.logger.Errorw("fetch messages failure", "topic", sub.topic, "partition_key", partitionKey, "error", err) - return + return fmt.Errorf("fetch messages: %w", err) } messageCount := 0 for _, row := range rows { - // Check if message has exceeded retry limit (persistent retry_count from DB) - if row.RetryCount >= cfg.Retry.MaxAttempts { + // Check per-consumer-group deliverability via delivery state. + // Single query replaces separate IsDeliverable + GetRetryCount calls. 
+ state, found, err := s.deliveryStateStore.GetDeliveryState(ctx, cfg.ConsumerGroup, sub.topic, partitionKey, row.Offset) + if err != nil { + return fmt.Errorf("get delivery state offset=%d: %w", row.Offset, err) + } + + // Determine deliverability in-memory: + // !found → new message, deliverable + // state.Acked → already processed, skip + // state.InvisibleUntil > now → in-flight or nack delay, skip + now := time.Now().UnixMilli() + if found && (state.Acked || state.InvisibleUntil > now) { + continue + } + + // Mark as delivered (in-flight) in delivery state. + // Returns the resulting retry_count, avoiding a separate GetRetryCount call. + retryCount, err := s.deliveryStateStore.MarkDelivered(ctx, cfg.ConsumerGroup, sub.topic, partitionKey, row.Offset, cfg.VisibilityTimeoutMs) + if err != nil { + return fmt.Errorf("mark delivered offset=%d: %w", row.Offset, err) + } + + // Check if message has exceeded retry limit + if retryCount >= cfg.Retry.MaxAttempts { s.logger.Warnw("message exceeded retry limit", "topic", sub.topic, + "consumer_group", cfg.ConsumerGroup, "partition_key", partitionKey, "message_id", row.ID, - "retry_count", row.RetryCount, + "retry_count", retryCount, ) - // Move to DLQ if enabled + // Move to DLQ if enabled — must succeed before marking acked, + // otherwise the message is lost from both main queue and DLQ. 
if cfg.DLQ.Enabled { - dlqTopic := sub.topic + cfg.DLQ.TopicSuffix - if err := s.messageStore.MoveToDLQ(ctx, sub.topic, row.ID, row.RetryCount, "exceeded retry limit", cfg.DLQ.TopicSuffix); err != nil { - s.logger.Errorw("failed to move message to DLQ", - "topic", sub.topic, - "dlq_topic", dlqTopic, - "message_id", row.ID, - "error", err, - ) - } else { - s.logger.Infow("moved message to DLQ", - "topic", sub.topic, - "dlq_topic", dlqTopic, - "message_id", row.ID, - "retry_count", row.RetryCount, - ) - s.metrics.Tagged(map[string]string{ - "topic": sub.topic, - "partition_key": partitionKey, - }).Counter("messages_moved_to_dlq").Inc(1) - - // Update offset since message is now processed (moved to DLQ) - if err := s.offsetStore.UpdateAckedOffset(ctx, sub.topic, partitionKey, row.Offset, cfg.ConsumerGroup); err != nil { - s.logger.Errorw("failed to update offset after DLQ move", - "topic", sub.topic, - "partition_key", partitionKey, - "offset", row.Offset, - "error", err, - ) - } + if err := s.messageStore.MoveToDLQ(ctx, sub.topic, partitionKey, row.ID, retryCount, "exceeded retry limit", cfg.DLQ.TopicSuffix); err != nil { + return fmt.Errorf("move to DLQ message=%s: %w", row.ID, err) } } + + // Mark as acked so watermark can advance past it. + // Watermark advancement is deferred to the poll loop. 
+ if err := s.deliveryStateStore.MarkAcked(ctx, cfg.ConsumerGroup, sub.topic, partitionKey, row.Offset); err != nil { + return fmt.Errorf("mark acked after retry limit message=%s: %w", row.ID, err) + } continue } @@ -689,10 +802,10 @@ func (w *partitionWorker) pollAndDeliver(ctx context.Context) { // Calculate message age for metrics messageAge := time.Duration(time.Now().UnixMilli()-row.PublishedAt) * time.Millisecond - s.metrics.Tagged(map[string]string{ - "topic": sub.topic, - "partition_key": partitionKey, - }).Timer("message_age").Record(messageAge) + metrics.NamedTimer(s.scope, "poll", "message_age", messageAge, + metrics.NewTag("topic", sub.topic), + metrics.NewTag("partition_key", partitionKey), + ) // Create delivery ID from offset deliveryID := strconv.FormatInt(row.Offset, 10) @@ -722,7 +835,7 @@ func (w *partitionWorker) pollAndDeliver(ctx context.Context) { delivery := newSQLDelivery( msg, deliveryID, - row.RetryCount+1, // RetryCount is 0-based, Attempt is 1-based + retryCount+1, // RetryCount is 0-based, Attempt is 1-based deliveryMetadata, s, sub.topic, @@ -738,78 +851,199 @@ func (w *partitionWorker) pollAndDeliver(ctx context.Context) { case sub.deliveryCh <- delivery: messageCount++ case <-ctx.Done(): - return + return nil } } - // Record metrics - if messageCount > 0 { - elapsed := time.Since(start) - partitionTags := map[string]string{ - "topic": sub.topic, - "partition_key": partitionKey, - } - s.metrics.Tagged(partitionTags).Counter("messages_received").Inc(int64(messageCount)) - s.metrics.Tagged(partitionTags).Timer("poll_latency").Record(elapsed) - - s.logger.Debugw("delivered messages", + // Advance watermark periodically (on every poll tick). + // This is deferred from Ack() to reduce per-ack latency to 1 DB call. + // advanceWatermark is idempotent and incremental — safe to call every tick. 
+ if err := s.advanceWatermark(ctx, cfg.ConsumerGroup, sub.topic, partitionKey); err != nil { + s.logger.Warnw("watermark advancement failed", "topic", sub.topic, "partition_key", partitionKey, - "count", messageCount, - "duration_ms", elapsed.Milliseconds(), + "consumer_group", cfg.ConsumerGroup, + "error", err, ) } + + // Run GC periodically (throttled to every Nth idle tick) + if messageCount == 0 { + w.gcCounter++ + if w.gcCounter >= gcIdleTickInterval { + w.gcCounter = 0 + if err := w.garbageCollect(ctx); err != nil { + return fmt.Errorf("garbage collect: %w", err) + } + } + } else { + w.gcCounter = 0 + } + + // Record poll metrics + if messageCount > 0 { + elapsed := time.Since(start) + metrics.NamedCounter(s.scope, "poll", "messages_delivered", int64(messageCount), + metrics.NewTag("topic", sub.topic), + metrics.NewTag("partition_key", partitionKey), + ) + metrics.NamedTimer(s.scope, "poll", "latency", elapsed, + metrics.NewTag("topic", sub.topic), + metrics.NewTag("partition_key", partitionKey), + ) + } + + return nil } -// renewLeases renews leases for all partitions owned by this worker -func (s *subscriber) renewLeases(ctx context.Context, sub *subscription) { - cfg := sub.config - leasedPartitions, err := s.leaseStore.GetLeasedPartitions(ctx, sub.topic, cfg.SubscriberName, cfg.ConsumerGroup) +// garbageCollect orchestrates GC by querying the offsetStore for the minimum +// acked offset across all consumer groups, then telling the messageStore to +// delete messages up to that offset. This keeps each store querying only its +// own table. 
+func (w *partitionWorker) garbageCollect(ctx context.Context) error { + s := w.subscriber + + minOffset, found, err := s.offsetStore.GetMinAckedOffset(ctx, w.sub.topic, w.partitionKey) if err != nil { - s.logger.Errorw("failed to get leased partitions for renewal", - "topic", sub.topic, - "error", err, - ) - s.metrics.Tagged(map[string]string{"topic": sub.topic}).Counter("lease_renewal.get_partitions_errors").Inc(1) - return + return fmt.Errorf("get min acked offset: %w", err) + } + if !found { + return nil + } + + if _, err := s.messageStore.GarbageCollect(ctx, w.sub.topic, w.partitionKey, minOffset); err != nil { + return fmt.Errorf("delete messages: %w", err) } + return nil +} + +// renewLeases renews leases for all partitions owned by this worker. +func (s *subscriber) renewLeases(ctx context.Context, sub *subscription, leasedPartitions []string) error { + cfg := sub.config + for _, partitionKey := range leasedPartitions { if err := s.leaseStore.RenewLease(ctx, sub.topic, partitionKey, cfg.SubscriberName, cfg.ConsumerGroup, cfg.LeaseDurationMs); err != nil { - s.logger.Warnw("failed to renew lease", - "topic", sub.topic, - "partition_key", partitionKey, - "error", err, - ) - s.metrics.Tagged(map[string]string{ - "topic": sub.topic, - "partition_key": partitionKey, - }).Counter("lease_renewal.renew_errors").Inc(1) + return fmt.Errorf("renew lease partition=%s: %w", partitionKey, err) } } + return nil } -// releaseAllLeases releases all leases for a topic -func (s *subscriber) releaseAllLeases(ctx context.Context, sub *subscription) { +// releaseAllLeases releases all leases for a topic. 
+func (s *subscriber) releaseAllLeases(ctx context.Context, sub *subscription) error { cfg := sub.config leasedPartitions, err := s.leaseStore.GetLeasedPartitions(ctx, sub.topic, cfg.SubscriberName, cfg.ConsumerGroup) if err != nil { - s.logger.Errorw("failed to get leased partitions for release", - "topic", sub.topic, - "error", err, - ) - return + return fmt.Errorf("get leased partitions for release: %w", err) } for _, partitionKey := range leasedPartitions { if err := s.leaseStore.ReleaseLease(ctx, sub.topic, partitionKey, cfg.SubscriberName, cfg.ConsumerGroup); err != nil { - s.logger.Warnw("failed to release lease", - "topic", sub.topic, - "partition_key", partitionKey, - "error", err, - ) + return fmt.Errorf("release lease partition=%s: %w", partitionKey, err) } } + return nil +} + +// sendHeartbeat sends a heartbeat for this subscriber. +func (s *subscriber) sendHeartbeat(ctx context.Context, sub *subscription) error { + cfg := sub.config + if err := s.heartbeatStore.Heartbeat(ctx, sub.topic, cfg.SubscriberName, cfg.ConsumerGroup); err != nil { + return fmt.Errorf("heartbeat: %w", err) + } + return nil +} + +// deregisterHeartbeat removes this subscriber's heartbeat entry during shutdown. +func (s *subscriber) deregisterHeartbeat(ctx context.Context, sub *subscription) error { + cfg := sub.config + if err := s.heartbeatStore.Deregister(ctx, sub.topic, cfg.SubscriberName, cfg.ConsumerGroup); err != nil { + return fmt.Errorf("deregister heartbeat: %w", err) + } + return nil +} + +// rebalance checks if this subscriber holds more partitions than its fair share +// and releases extras so other subscribers can pick them up. +func (s *subscriber) rebalance(ctx context.Context, sub *subscription, owned []string) error { + cfg := sub.config + + // Use cached discovered partitions from the most recent discovery tick. 
+ sub.workersMu.Lock() + discoveredPartitions := sub.lastDiscoveredPartitions + sub.workersMu.Unlock() + + maxPart, err := s.fairShareCap(ctx, sub, owned, discoveredPartitions) + if err != nil { + return fmt.Errorf("compute fair share cap: %w", err) + } + if maxPart == 0 || len(owned) <= maxPart { + return nil + } + + // Sort deterministically so the same partitions are released across runs. + sort.Strings(owned) + + // Release excess partitions + for _, pk := range owned[maxPart:] { + if err := s.leaseStore.ReleaseLease(ctx, sub.topic, pk, cfg.SubscriberName, cfg.ConsumerGroup); err != nil { + return fmt.Errorf("release partition %s during rebalance: %w", pk, err) + } + + // Stop the worker immediately to prevent duplicate processing. + s.stopPartitionWorker(sub, pk) + + s.logger.Infow("released partition for rebalance", + "topic", sub.topic, + "consumer_group", cfg.ConsumerGroup, + "partition_key", pk, + "owned", len(owned), + "max_partitions", maxPart, + ) + } + return nil +} + +// fairShareCap computes the max partitions this subscriber should own. +// Returns (maxPart, error). maxPart=0 means unlimited. +// owned is the caller-provided list of leased partitions. +// discoveredPartitions is an optional pre-fetched list of all known partitions; +// if nil, only owned partitions are used for fair share computation. +func (s *subscriber) fairShareCap(ctx context.Context, sub *subscription, owned []string, discoveredPartitions []string) (int, error) { + cfg := sub.config + + active, err := s.heartbeatStore.ActiveSubscribers(ctx, sub.topic, cfg.ConsumerGroup, cfg.LeaseDurationMs) + if err != nil { + return 0, err + } + if len(active) <= 1 { + return 0, nil + } + + activeSubscribers := len(active) + + // Count all known partitions as the union of owned + discovered. + // Using max(owned, discovered) would undercount when some partitions + // have leases but no messages, or vice versa. 
+ partitionSet := make(map[string]struct{}, len(owned)) + for _, pk := range owned { + partitionSet[pk] = struct{}{} + } + if discoveredPartitions != nil { + for _, pk := range discoveredPartitions { + partitionSet[pk] = struct{}{} + } + } + totalPartitions := len(partitionSet) + + // ceil(totalPartitions / activeSubscribers) + maxPart := (totalPartitions + activeSubscribers - 1) / activeSubscribers + if maxPart < 1 { + maxPart = 1 + } + + return maxPart, nil } // Close gracefully shuts down the subscriber and all its subscriptions. @@ -820,7 +1054,10 @@ func (s *subscriber) releaseAllLeases(ctx context.Context, sub *subscription) { // Close() does not block indefinitely if a subscription hangs // 3. managePartitions internally handles stopping workers and closing deliveryCh // (see managePartitions shutdown sequence) -func (s *subscriber) Close() error { +func (s *subscriber) Close() (retErr error) { + op := metrics.Begin(s.scope, "close") + defer func() { op.Complete(retErr) }() + s.mu.Lock() defer s.mu.Unlock() @@ -828,18 +1065,21 @@ func (s *subscriber) Close() error { return nil } - s.logger.Info("closing subscriber") + s.logger.Infow("closing subscriber") s.subMu.Lock() defer s.subMu.Unlock() // Cancel all subscriptions - for topic, sub := range s.subscriptions { - s.logger.Debugw("closing subscription", "topic", topic) + for _, sub := range s.subscriptions { + s.logger.Debugw("closing subscription", + "topic", sub.topic, + "consumer_group", sub.config.ConsumerGroup, + ) sub.cancelFunc() // Wait for the managePartitions goroutine to finish. We wrap the - // blocking Wait in a goroutine so we can enforce a timeout — if + // blocking Wait in a goroutine so we can enforce a timeout -- if // managePartitions is stuck, we log a warning and move on rather // than blocking Close() indefinitely. 
done := make(chan struct{}) @@ -852,17 +1092,20 @@ func (s *subscriber) Close() error { case <-done: // Graceful shutdown completed case <-time.After(subscriptionShutdownTimeout): - s.logger.Warnw("subscription shutdown timeout", "topic", topic) + s.logger.Warnw("subscription shutdown timeout", + "topic", sub.topic, + "consumer_group", sub.config.ConsumerGroup, + ) } // Update metrics - s.metrics.Tagged(map[string]string{"topic": topic}).Gauge("active_subscriptions").Update(0) + metrics.NamedGauge(s.scope, "subscribe", "active_subscriptions", 0, metrics.NewTag("topic", sub.topic)) } s.subscriptions = make(map[string]*subscription) s.closed = true - s.logger.Info("subscriber closed") + s.logger.Infow("subscriber closed") return nil } diff --git a/extension/queue/mysql/subscriber_heartbeat_store.go b/extension/queue/mysql/subscriber_heartbeat_store.go new file mode 100644 index 00000000..19d9567a --- /dev/null +++ b/extension/queue/mysql/subscriber_heartbeat_store.go @@ -0,0 +1,127 @@ +// Copyright (c) 2025 Uber Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package mysql + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/uber-go/tally/v4" + "github.com/uber/submitqueue/core/metrics" + "go.uber.org/zap" +) + +// sqlSubscriberHeartbeatStore is the SQL implementation of subscriberHeartbeatStore +type sqlSubscriberHeartbeatStore struct { + db *sql.DB + logger *zap.SugaredLogger + scope tally.Scope + nowFunc func() time.Time +} + +// newSubscriberHeartbeatStore creates a new SQL subscriber heartbeat store +func newSubscriberHeartbeatStore(db *sql.DB, logger *zap.SugaredLogger, scope tally.Scope, nowFunc func() time.Time) subscriberHeartbeatStore { + return &sqlSubscriberHeartbeatStore{ + db: db, + logger: logger.Named("subscriber_heartbeat_store"), + scope: scope.SubScope("subscriber_heartbeat_store"), + nowFunc: nowFunc, + } +} + +// Heartbeat registers or renews a subscriber's heartbeat. +func (s *sqlSubscriberHeartbeatStore) Heartbeat(ctx context.Context, topic string, subscriberName string, consumerGroup string) (retErr error) { + op := metrics.Begin(s.scope, "heartbeat") + defer func() { op.Complete(retErr) }() + + now := s.nowFunc().UnixMilli() + + _, err := s.db.ExecContext(ctx, fmt.Sprintf(` + INSERT INTO %s (consumer_group, topic, subscriber_name, heartbeat_at, deregistered_at) + VALUES (?, ?, ?, ?, 0) + ON DUPLICATE KEY UPDATE heartbeat_at = VALUES(heartbeat_at), deregistered_at = 0 + `, SubscriberHeartbeatsTableName), consumerGroup, topic, subscriberName, now) + + if err != nil { + return fmt.Errorf("failed to send heartbeat: %w", err) + } + + return nil +} + +// ActiveSubscribers returns the names of subscribers with a heartbeat newer than the stale threshold. 
+func (s *sqlSubscriberHeartbeatStore) ActiveSubscribers(ctx context.Context, topic string, consumerGroup string, staleDurationMs int64) (_ []string, retErr error) { + op := metrics.Begin(s.scope, "active_subscribers") + defer func() { op.Complete(retErr) }() + + staleThreshold := s.nowFunc().UnixMilli() - staleDurationMs + + rows, err := s.db.QueryContext(ctx, fmt.Sprintf(` + SELECT subscriber_name FROM %s + WHERE consumer_group = ? AND topic = ? AND heartbeat_at >= ? AND deregistered_at = 0 + `, SubscriberHeartbeatsTableName), consumerGroup, topic, staleThreshold) + if err != nil { + return nil, fmt.Errorf("failed to query active subscribers: %w", err) + } + defer rows.Close() + + var names []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, fmt.Errorf("failed to scan subscriber name: %w", err) + } + names = append(names, name) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("row iteration error: %w", err) + } + + s.logger.Debugw("found active subscribers", + logTopic, topic, + "count", len(names), + "subscribers", names, + ) + + return names, nil +} + +// Deregister soft-deletes a subscriber's heartbeat entry by setting deregistered_at. +// Idempotent: no-op if already deregistered. +func (s *sqlSubscriberHeartbeatStore) Deregister(ctx context.Context, topic string, subscriberName string, consumerGroup string) (retErr error) { + op := metrics.Begin(s.scope, "deregister") + defer func() { op.Complete(retErr) }() + + now := s.nowFunc().UnixMilli() + + _, err := s.db.ExecContext(ctx, fmt.Sprintf(` + UPDATE %s SET deregistered_at = ? + WHERE consumer_group = ? AND topic = ? AND subscriber_name = ? 
AND deregistered_at = 0 + `, SubscriberHeartbeatsTableName), now, consumerGroup, topic, subscriberName) + + if err != nil { + return fmt.Errorf("failed to deregister subscriber: %w", err) + } + + s.logger.Debugw("deregistered subscriber", + logTopic, topic, + "subscriber_name", subscriberName, + ) + + return nil +} diff --git a/extension/queue/mysql/subscriber_heartbeat_store_test.go b/extension/queue/mysql/subscriber_heartbeat_store_test.go new file mode 100644 index 00000000..50d1e2b7 --- /dev/null +++ b/extension/queue/mysql/subscriber_heartbeat_store_test.go @@ -0,0 +1,276 @@ +// Copyright (c) 2025 Uber Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package mysql + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/require" + "github.com/uber-go/tally/v4" + "go.uber.org/zap/zaptest" +) + +func setupSubscriberHeartbeatStoreTest(t *testing.T) (*sql.DB, sqlmock.Sqlmock, subscriberHeartbeatStore) { + t.Helper() + + db, mock, err := sqlmock.New() + require.NoError(t, err) + + store := newSubscriberHeartbeatStore(db, zaptest.NewLogger(t).Sugar(), tally.NoopScope, time.Now) + + return db, mock, store +} + +func TestSubscriberHeartbeatStore_Heartbeat(t *testing.T) { + tests := []struct { + name string + setup func(mock sqlmock.Sqlmock) + wantErr bool + }{ + { + name: "successfully send heartbeat", + setup: func(mock sqlmock.Sqlmock) { + mock.ExpectExec("INSERT INTO queue_subscriber_heartbeats"). + WithArgs(testConsumerGroup, "test_topic", testSubscriberName, sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + }, + wantErr: false, + }, + { + name: "update existing heartbeat", + setup: func(mock sqlmock.Sqlmock) { + mock.ExpectExec("INSERT INTO queue_subscriber_heartbeats"). + WithArgs(testConsumerGroup, "test_topic", testSubscriberName, sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(0, 2)) // ON DUPLICATE KEY UPDATE returns 2 for update + }, + wantErr: false, + }, + { + name: "database error", + setup: func(mock sqlmock.Sqlmock) { + mock.ExpectExec("INSERT INTO queue_subscriber_heartbeats"). + WithArgs(testConsumerGroup, "test_topic", testSubscriberName, sqlmock.AnyArg()). 
+ WillReturnError(fmt.Errorf("db error")) + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, mock, store := setupSubscriberHeartbeatStoreTest(t) + defer db.Close() + + ctx := context.Background() + tt.setup(mock) + + err := store.Heartbeat(ctx, "test_topic", testSubscriberName, testConsumerGroup) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.NoError(t, mock.ExpectationsWereMet()) + }) + } +} + +func TestSubscriberHeartbeatStore_ActiveSubscribers(t *testing.T) { + tests := []struct { + name string + setup func(mock sqlmock.Sqlmock) + wantNames []string + wantErr bool + }{ + { + name: "multiple active subscribers", + setup: func(mock sqlmock.Sqlmock) { + rows := sqlmock.NewRows([]string{"subscriber_name"}). + AddRow("sub-1").AddRow("sub-2").AddRow("sub-3") + mock.ExpectQuery("SELECT subscriber_name"). + WithArgs(testConsumerGroup, "test_topic", sqlmock.AnyArg()). + WillReturnRows(rows) + }, + wantNames: []string{"sub-1", "sub-2", "sub-3"}, + wantErr: false, + }, + { + name: "no active subscribers", + setup: func(mock sqlmock.Sqlmock) { + rows := sqlmock.NewRows([]string{"subscriber_name"}) + mock.ExpectQuery("SELECT subscriber_name"). + WithArgs(testConsumerGroup, "test_topic", sqlmock.AnyArg()). + WillReturnRows(rows) + }, + wantNames: nil, + wantErr: false, + }, + { + name: "database error", + setup: func(mock sqlmock.Sqlmock) { + mock.ExpectQuery("SELECT subscriber_name"). + WithArgs(testConsumerGroup, "test_topic", sqlmock.AnyArg()). 
+ WillReturnError(fmt.Errorf("db error")) + }, + wantNames: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, mock, store := setupSubscriberHeartbeatStoreTest(t) + defer db.Close() + + ctx := context.Background() + tt.setup(mock) + + names, err := store.ActiveSubscribers(ctx, "test_topic", testConsumerGroup, testLeaseDurationMs) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.wantNames, names) + } + require.NoError(t, mock.ExpectationsWereMet()) + }) + } +} + +func TestSubscriberHeartbeatStore_ActiveSubscribers_ExcludesDeregistered(t *testing.T) { + db, mock, store := setupSubscriberHeartbeatStoreTest(t) + defer db.Close() + + ctx := context.Background() + + // Verify the query filters by deregistered_at = 0 + rows := sqlmock.NewRows([]string{"subscriber_name"}).AddRow("sub-1").AddRow("sub-2") + mock.ExpectQuery(`SELECT subscriber_name FROM queue_subscriber_heartbeats.*deregistered_at = 0`). + WithArgs(testConsumerGroup, "test_topic", sqlmock.AnyArg()). + WillReturnRows(rows) + + names, err := store.ActiveSubscribers(ctx, "test_topic", testConsumerGroup, testLeaseDurationMs) + require.NoError(t, err) + require.Equal(t, []string{"sub-1", "sub-2"}, names) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestSubscriberHeartbeatStore_Deregister_SoftDelete(t *testing.T) { + db, mock, store := setupSubscriberHeartbeatStoreTest(t) + defer db.Close() + + ctx := context.Background() + + // Verify deregister uses UPDATE (not DELETE) and targets only active rows (deregistered_at = 0) + mock.ExpectExec(`UPDATE queue_subscriber_heartbeats SET deregistered_at.*AND deregistered_at = 0`). + WithArgs(sqlmock.AnyArg(), testConsumerGroup, "test_topic", testSubscriberName). 
+ WillReturnResult(sqlmock.NewResult(0, 1)) + + err := store.Deregister(ctx, "test_topic", testSubscriberName, testConsumerGroup) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestSubscriberHeartbeatStore_ReRegistration(t *testing.T) { + db, mock, store := setupSubscriberHeartbeatStoreTest(t) + defer db.Close() + + ctx := context.Background() + + // Step 1: Initial heartbeat registers the subscriber + mock.ExpectExec("INSERT INTO queue_subscriber_heartbeats"). + WithArgs(testConsumerGroup, "test_topic", testSubscriberName, sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // Step 2: Deregister soft-deletes the subscriber + mock.ExpectExec("UPDATE queue_subscriber_heartbeats"). + WithArgs(sqlmock.AnyArg(), testConsumerGroup, "test_topic", testSubscriberName). + WillReturnResult(sqlmock.NewResult(0, 1)) + + // Step 3: Heartbeat again re-registers (ON DUPLICATE KEY UPDATE resets deregistered_at = 0) + mock.ExpectExec("INSERT INTO queue_subscriber_heartbeats"). + WithArgs(testConsumerGroup, "test_topic", testSubscriberName, sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(0, 2)) // 2 = ON DUPLICATE KEY UPDATE + + err := store.Heartbeat(ctx, "test_topic", testSubscriberName, testConsumerGroup) + require.NoError(t, err) + + err = store.Deregister(ctx, "test_topic", testSubscriberName, testConsumerGroup) + require.NoError(t, err) + + err = store.Heartbeat(ctx, "test_topic", testSubscriberName, testConsumerGroup) + require.NoError(t, err) + + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestSubscriberHeartbeatStore_Deregister(t *testing.T) { + tests := []struct { + name string + setup func(mock sqlmock.Sqlmock) + wantErr bool + }{ + { + name: "successfully deregister", + setup: func(mock sqlmock.Sqlmock) { + mock.ExpectExec("UPDATE queue_subscriber_heartbeats"). + WithArgs(sqlmock.AnyArg(), testConsumerGroup, "test_topic", testSubscriberName). 
+ WillReturnResult(sqlmock.NewResult(0, 1)) + }, + wantErr: false, + }, + { + name: "idempotent - already deregistered", + setup: func(mock sqlmock.Sqlmock) { + mock.ExpectExec("UPDATE queue_subscriber_heartbeats"). + WithArgs(sqlmock.AnyArg(), testConsumerGroup, "test_topic", testSubscriberName). + WillReturnResult(sqlmock.NewResult(0, 0)) + }, + wantErr: false, + }, + { + name: "database error", + setup: func(mock sqlmock.Sqlmock) { + mock.ExpectExec("UPDATE queue_subscriber_heartbeats"). + WithArgs(sqlmock.AnyArg(), testConsumerGroup, "test_topic", testSubscriberName). + WillReturnError(fmt.Errorf("db error")) + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, mock, store := setupSubscriberHeartbeatStoreTest(t) + defer db.Close() + + ctx := context.Background() + tt.setup(mock) + + err := store.Deregister(ctx, "test_topic", testSubscriberName, testConsumerGroup) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.NoError(t, mock.ExpectationsWereMet()) + }) + } +} diff --git a/extension/queue/mysql/subscriber_test.go b/extension/queue/mysql/subscriber_test.go index cd909424..31e4d4b3 100644 --- a/extension/queue/mysql/subscriber_test.go +++ b/extension/queue/mysql/subscriber_test.go @@ -16,6 +16,7 @@ package mysql import ( "context" + "errors" "fmt" "testing" "time" @@ -34,9 +35,36 @@ func testSubscriptionConfig() extqueue.SubscriptionConfig { return extqueue.DefaultSubscriptionConfig("test-subscriber", "test-consumer") } +// newTestHeartbeatStore creates a mock heartbeat store that allows all calls +func newTestHeartbeatStore(ctrl *gomock.Controller) *MocksubscriberHeartbeatStore { + mockHB := NewMocksubscriberHeartbeatStore(ctrl) + mockHB.EXPECT().Heartbeat(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockHB.EXPECT().ActiveSubscribers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]string{"self"}, nil).AnyTimes() + 
mockHB.EXPECT().Deregister(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + return mockHB +} + +// newTestDeliveryStateStore creates a mock delivery state store that allows all calls +func newTestDeliveryStateStore(ctrl *gomock.Controller) *MockdeliveryStateStore { + mockDS := NewMockdeliveryStateStore(ctrl) + mockDS.EXPECT().MarkDelivered(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(0, nil).AnyTimes() + mockDS.EXPECT().MarkAcked(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockDS.EXPECT().MarkNacked(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockDS.EXPECT().GetDeliveryState(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(DeliveryState{}, false, nil).AnyTimes() + mockDS.EXPECT().AdvanceWatermark(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(0), nil).AnyTimes() + mockDS.EXPECT().ExtendVisibility(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + return mockDS +} + func setupSubscriberTest(t *testing.T, mockMessageStore *MockmessageStore, mockOffsetStore *MockoffsetStore, mockLeaseStore *MockpartitionLeaseStore) extqueue.Subscriber { t.Helper() - return NewSubscriber(zaptest.NewLogger(t).Sugar().Named("subscriber"), tally.NoopScope.SubScope("subscriber"), mockMessageStore, mockOffsetStore, mockLeaseStore) + ctrl := gomock.NewController(t) + mockHeartbeatStore := newTestHeartbeatStore(ctrl) + mockDeliveryStateStore := newTestDeliveryStateStore(ctrl) + // Allow watermark advancement calls from poll loop + mockOffsetStore.EXPECT().GetAckedOffset(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(0), nil).AnyTimes() + mockMessageStore.EXPECT().GetOffsetsAbove(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), 
gomock.Any()).Return(nil, nil).AnyTimes() + return NewSubscriber(zaptest.NewLogger(t).Sugar().Named("subscriber"), tally.NoopScope.SubScope("subscriber"), mockMessageStore, mockOffsetStore, mockLeaseStore, mockHeartbeatStore, mockDeliveryStateStore) } func TestSubscriber_Subscribe(t *testing.T) { @@ -92,13 +120,84 @@ func TestSubscriber_Subscribe(t *testing.T) { } } +func TestSQLDelivery_Ack(t *testing.T) { + tests := []struct { + name string + alreadyAcked bool + markAckedErr error + expectErr bool + }{ + { + name: "successful ack", + }, + { + name: "already acknowledged returns error", + alreadyAcked: true, + expectErr: true, + }, + { + name: "MarkAcked failure returns error", + markAckedErr: fmt.Errorf("db error"), + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockMsgStore := NewMockmessageStore(ctrl) + mockOffStore := NewMockoffsetStore(ctrl) + mockLeaseStore := NewMockpartitionLeaseStore(ctrl) + mockDeliveryState := NewMockdeliveryStateStore(ctrl) + + sub := NewSubscriber( + zaptest.NewLogger(t).Sugar(), + tally.NoopScope, + mockMsgStore, + mockOffStore, + mockLeaseStore, + newTestHeartbeatStore(ctrl), + mockDeliveryState, + ) + + msg := queue.NewMessage("msg-1", []byte("payload"), "part-1", nil) + d := newSQLDelivery( + msg, "1", 1, nil, + sub, "test_topic", "part-1", 100, "msg-1", "test-group", + extqueue.DLQConfig{}, + ) + + if tt.alreadyAcked { + d.acknowledged = true + } + + if !tt.alreadyAcked { + // Ack only calls MarkAcked — watermark is deferred to poll loop + mockDeliveryState.EXPECT().MarkAcked( + gomock.Any(), "test-group", "test_topic", "part-1", int64(100), + ).Return(tt.markAckedErr) + } + + err := d.Ack(context.Background()) + + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.True(t, d.acknowledged) + } + }) + } +} + func TestSQLDelivery_Reject(t *testing.T) { tests := []struct { name string 
dlqEnabled bool alreadyAcked bool moveToDLQErr error - ackMessageErr error expectErr bool expectMoveDLQ bool expectAck bool @@ -109,7 +208,7 @@ func TestSQLDelivery_Reject(t *testing.T) { expectMoveDLQ: true, }, { - name: "DLQ disabled falls back to ack", + name: "DLQ disabled marks as acked", expectAck: true, }, { @@ -125,12 +224,6 @@ func TestSQLDelivery_Reject(t *testing.T) { expectErr: true, expectMoveDLQ: true, }, - { - name: "DLQ disabled but AckMessage fails", - ackMessageErr: fmt.Errorf("db error"), - expectErr: true, - expectAck: true, - }, } for _, tt := range tests { @@ -141,6 +234,7 @@ func TestSQLDelivery_Reject(t *testing.T) { mockMsgStore := NewMockmessageStore(ctrl) mockOffStore := NewMockoffsetStore(ctrl) mockLeaseStore := NewMockpartitionLeaseStore(ctrl) + mockDeliveryState := NewMockdeliveryStateStore(ctrl) sub := NewSubscriber( zaptest.NewLogger(t).Sugar(), @@ -148,6 +242,8 @@ func TestSQLDelivery_Reject(t *testing.T) { mockMsgStore, mockOffStore, mockLeaseStore, + newTestHeartbeatStore(ctrl), + mockDeliveryState, ) msg := queue.NewMessage("msg-1", []byte("payload"), "part-1", nil) @@ -168,20 +264,20 @@ func TestSQLDelivery_Reject(t *testing.T) { if tt.expectMoveDLQ { mockMsgStore.EXPECT().MoveToDLQ( - gomock.Any(), "test_topic", "msg-1", 1, "bad payload", "_dlq", + gomock.Any(), "test_topic", "part-1", "msg-1", 1, "bad payload", "_dlq", ).Return(tt.moveToDLQErr) if tt.moveToDLQErr == nil { - mockOffStore.EXPECT().UpdateAckedOffset( - gomock.Any(), "test_topic", "part-1", int64(100), "test-group", + mockDeliveryState.EXPECT().MarkAcked( + gomock.Any(), "test-group", "test_topic", "part-1", int64(100), ).Return(nil) } } if tt.expectAck { - mockOffStore.EXPECT().AckMessage( - gomock.Any(), "test_topic", "part-1", "msg-1", int64(100), "test-group", mockMsgStore, - ).Return(tt.ackMessageErr) + mockDeliveryState.EXPECT().MarkAcked( + gomock.Any(), "test-group", "test_topic", "part-1", int64(100), + ).Return(nil) } err := 
d.Reject(context.Background(), "bad payload") @@ -260,6 +356,7 @@ func TestSubscriber_Close(t *testing.T) { ch, err := sub.Subscribe(ctx, "test_topic", testSubscriptionConfig()) if tt.expectSubError { require.Error(t, err) + require.True(t, errors.Is(err, ErrSubscriberClosed)) assert.Nil(t, ch) } else { require.NoError(t, err) @@ -309,12 +406,17 @@ func TestSubscriber_ReconcilePartitionWorkers(t *testing.T) { mockMessageStore, mockOffsetStore, mockLeaseStore, + newTestHeartbeatStore(ctrl), + newTestDeliveryStateStore(ctrl), ) - // Allow offset initialization and fetch calls from workers + // Allow offset initialization, fetch, and watermark calls from workers mockOffsetStore.EXPECT().Initialize(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockOffsetStore.EXPECT().GetAckedOffset(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(0), nil).AnyTimes() - mockMessageStore.EXPECT().FetchByOffset(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockMessageStore.EXPECT().FetchByOffset(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockMessageStore.EXPECT().GetOffsetsAbove(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockMessageStore.EXPECT().GarbageCollect(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(0), nil).AnyTimes() + mockOffsetStore.EXPECT().GetMinAckedOffset(gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(0), false, nil).AnyTimes() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -357,6 +459,7 @@ func TestSubscriber_PartitionWorkerPollAndDeliver(t *testing.T) { mockMessageStore := NewMockmessageStore(ctrl) mockOffsetStore := NewMockoffsetStore(ctrl) mockLeaseStore := NewMockpartitionLeaseStore(ctrl) + mockDeliveryState := NewMockdeliveryStateStore(ctrl) s := NewSubscriber( 
zaptest.NewLogger(t).Sugar(), @@ -364,6 +467,8 @@ func TestSubscriber_PartitionWorkerPollAndDeliver(t *testing.T) { mockMessageStore, mockOffsetStore, mockLeaseStore, + newTestHeartbeatStore(ctrl), + mockDeliveryState, ) cfg := testSubscriptionConfig() @@ -378,7 +483,8 @@ func TestSubscriber_PartitionWorkerPollAndDeliver(t *testing.T) { ctx := context.Background() mockOffsetStore.EXPECT().Initialize(gomock.Any(), "test_topic", "part-1", cfg.ConsumerGroup).Return(nil) - mockOffsetStore.EXPECT().GetAckedOffset(gomock.Any(), "test_topic", "part-1", cfg.ConsumerGroup).Return(int64(0), nil) + // GetAckedOffset is called twice: once by pollAndDeliver, once by advanceWatermark + mockOffsetStore.EXPECT().GetAckedOffset(gomock.Any(), "test_topic", "part-1", cfg.ConsumerGroup).Return(int64(0), nil).Times(2) row := messageRow{ ID: "msg-1", @@ -386,11 +492,19 @@ func TestSubscriber_PartitionWorkerPollAndDeliver(t *testing.T) { PartitionKey: "part-1", Payload: []byte("payload"), PublishedAt: time.Now().UnixMilli(), - RetryCount: 0, } - mockMessageStore.EXPECT().FetchByOffset(gomock.Any(), "test_topic", "part-1", int64(0), cfg.BatchSize, cfg.VisibilityTimeoutMs). + mockMessageStore.EXPECT().FetchByOffset(gomock.Any(), "test_topic", "part-1", int64(0), cfg.BatchSize). 
Return([]messageRow{row}, nil) + // Delivery state checks — GetDeliveryState returns not-found (new message) + mockDeliveryState.EXPECT().GetDeliveryState(gomock.Any(), cfg.ConsumerGroup, "test_topic", "part-1", int64(1)).Return(DeliveryState{}, false, nil) + // MarkDelivered returns retry count 0 (first delivery) + mockDeliveryState.EXPECT().MarkDelivered(gomock.Any(), cfg.ConsumerGroup, "test_topic", "part-1", int64(1), cfg.VisibilityTimeoutMs).Return(0, nil) + + // advanceWatermark called at end of pollAndDeliver + mockMessageStore.EXPECT().GetOffsetsAbove(gomock.Any(), "test_topic", "part-1", int64(0), watermarkAdvancementLimit).Return([]int64{1}, nil) + mockDeliveryState.EXPECT().AdvanceWatermark(gomock.Any(), cfg.ConsumerGroup, "test_topic", "part-1", int64(0), []int64{1}).Return(int64(0), nil) + w := &partitionWorker{ partitionKey: "part-1", sub: sub, @@ -426,12 +540,17 @@ func TestSubscriber_StopAllWorkers(t *testing.T) { mockMessageStore, mockOffsetStore, mockLeaseStore, + newTestHeartbeatStore(ctrl), + newTestDeliveryStateStore(ctrl), ) - // Allow worker polling + // Allow worker polling and watermark advancement mockOffsetStore.EXPECT().Initialize(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockOffsetStore.EXPECT().GetAckedOffset(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(0), nil).AnyTimes() - mockMessageStore.EXPECT().FetchByOffset(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockMessageStore.EXPECT().FetchByOffset(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockMessageStore.EXPECT().GetOffsetsAbove(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockMessageStore.EXPECT().GarbageCollect(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(0), nil).AnyTimes() + 
mockOffsetStore.EXPECT().GetMinAckedOffset(gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(0), false, nil).AnyTimes() ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/test/integration/extension/queue/mysql/queue_test.go b/test/integration/extension/queue/mysql/queue_test.go index ae15ea8a..da6232ee 100644 --- a/test/integration/extension/queue/mysql/queue_test.go +++ b/test/integration/extension/queue/mysql/queue_test.go @@ -40,6 +40,17 @@ import ( // testTimeout is the safety-net duration for channel waits in integration tests. const testTimeout = 10 * time.Second +// Timing constants for rebalance tests. Keep poll/lease intervals short to make +// tests converge fast, but lease duration long enough that active subscribers +// don't expire each other. +const ( + rebalancePollIntervalMs = 100 + rebalanceLeaseRenewalIntervalMs = 200 + rebalanceLeaseDurationMs = 1000 + rebalanceConvergeTimeout = 10 * time.Second + rebalanceConvergePollInterval = 200 * time.Millisecond +) + type SQLQueueIntegrationSuite struct { suite.Suite ctx context.Context @@ -1379,12 +1390,10 @@ func (s *SQLQueueIntegrationSuite) TestAdmin_TopicStatsAfterPublish() { require.NoError(t, err) assert.Equal(t, int64(3), stats.TotalMessages) - assert.Equal(t, int64(3), stats.VisibleMessages) - assert.Equal(t, int64(0), stats.InvisibleMessages) assert.Equal(t, int64(2), stats.PartitionCount) // p1, p2 assert.Equal(t, int64(0), stats.DLQCount) - t.Logf("Topic stats verified: total=%d visible=%d partitions=%d", stats.TotalMessages, stats.VisibleMessages, stats.PartitionCount) + t.Logf("Topic stats verified: total=%d partitions=%d", stats.TotalMessages, stats.PartitionCount) } func (s *SQLQueueIntegrationSuite) TestAdmin_InspectMessage() { @@ -1529,6 +1538,10 @@ func (s *SQLQueueIntegrationSuite) TestAdmin_LeasesAndOffsets() { delivery := receiveWithTimeout(t, deliveryChan, 5*time.Second) require.NoError(t, delivery.Ack(s.ctx)) + // Wait for the next poll tick to 
advance the watermark (deferred from Ack). + // PollIntervalMs=100, so 2 ticks (200ms) guarantees at least one full cycle. + time.Sleep(2 * time.Duration(subConfig.PollIntervalMs) * time.Millisecond) + admin := queueAdmin.NewAdminStore(s.db) // Verify leases are visible @@ -1620,3 +1633,797 @@ func (s *SQLQueueIntegrationSuite) TestAdmin_ResetOffsetAndReleaseLease() { t.Logf("Reset offset and release lease verified") } + +// --- Rebalance integration tests --- + +// getPartitionLeases queries the partition lease table and returns a map from +// subscriber name to the set of partition keys it owns for the given topic and +// consumer group. +func getPartitionLeases(db *sql.DB, topic, consumerGroup string) (map[string][]string, error) { + rows, err := db.Query( + "SELECT leased_by, partition_key FROM queue_partition_leases WHERE topic = ? AND consumer_group = ? ORDER BY leased_by, partition_key", + topic, consumerGroup, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + result := make(map[string][]string) + for rows.Next() { + var owner, pk string + if err := rows.Scan(&owner, &pk); err != nil { + return nil, err + } + result[owner] = append(result[owner], pk) + } + return result, nil +} + +// rebalanceSubConfig returns a SubscriptionConfig tuned for fast rebalance tests. +func rebalanceSubConfig(subscriberName, consumerGroup string) extqueue.SubscriptionConfig { + cfg := extqueue.DefaultSubscriptionConfig(subscriberName, consumerGroup) + cfg.PollIntervalMs = rebalancePollIntervalMs + cfg.LeaseRenewalIntervalMs = rebalanceLeaseRenewalIntervalMs + cfg.LeaseDurationMs = rebalanceLeaseDurationMs + return cfg +} + +func (s *SQLQueueIntegrationSuite) TestRebalance_EvenDistribution() { + t := s.T() + + topic := "rebalance_even_topic" + consumerGroup := "rebalance-even-cg" + partitions := []string{"pk-1", "pk-2", "pk-3", "pk-4"} + + // Publish one message per partition so they are discoverable. 
+ pubQ, err := queueMySQL.NewQueue(queueMySQL.Params{ + DB: s.db, Logger: zaptest.NewLogger(t), MetricsScope: tally.NoopScope, + }) + require.NoError(t, err) + defer pubQ.Close() + + for i, pk := range partitions { + msg := queue.NewMessage(fmt.Sprintf("rb-even-%d", i), []byte("x"), pk, nil) + require.NoError(t, pubQ.Publisher().Publish(s.ctx, topic, msg)) + } + + // S1: subscribe, should acquire all 4 partitions (only subscriber). + q1, err := queueMySQL.NewQueue(queueMySQL.Params{ + DB: s.db, Logger: zaptest.NewLogger(t), MetricsScope: tally.NoopScope, + }) + require.NoError(t, err) + defer q1.Close() + + _, err = q1.Subscriber().Subscribe(s.ctx, topic, rebalanceSubConfig("s1", consumerGroup)) + require.NoError(t, err) + + require.Eventually(t, func() bool { + leases, _ := getPartitionLeases(s.db, topic, consumerGroup) + return len(leases["s1"]) == 4 + }, rebalanceConvergeTimeout, rebalanceConvergePollInterval, "S1 should acquire all 4 partitions") + + // S2: subscribe. After rebalancing, each should own 2. + q2, err := queueMySQL.NewQueue(queueMySQL.Params{ + DB: s.db, Logger: zaptest.NewLogger(t), MetricsScope: tally.NoopScope, + }) + require.NoError(t, err) + defer q2.Close() + + _, err = q2.Subscriber().Subscribe(s.ctx, topic, rebalanceSubConfig("s2", consumerGroup)) + require.NoError(t, err) + + require.Eventually(t, func() bool { + leases, _ := getPartitionLeases(s.db, topic, consumerGroup) + return len(leases["s1"]) == 2 && len(leases["s2"]) == 2 + }, rebalanceConvergeTimeout, rebalanceConvergePollInterval, "each subscriber should own exactly 2 partitions") + + t.Logf("Even distribution verified: 4 partitions split evenly across 2 subscribers") +} + +func (s *SQLQueueIntegrationSuite) TestRebalance_SubscriberLeaves() { + t := s.T() + + topic := "rebalance_leave_topic" + consumerGroup := "rebalance-leave-cg" + partitions := []string{"pk-1", "pk-2", "pk-3", "pk-4"} + + // Publish messages. 
+ pubQ, err := queueMySQL.NewQueue(queueMySQL.Params{ + DB: s.db, Logger: zaptest.NewLogger(t), MetricsScope: tally.NoopScope, + }) + require.NoError(t, err) + defer pubQ.Close() + + for i, pk := range partitions { + msg := queue.NewMessage(fmt.Sprintf("rb-leave-%d", i), []byte("x"), pk, nil) + require.NoError(t, pubQ.Publisher().Publish(s.ctx, topic, msg)) + } + + // S1 + S2 start, wait for 2+2 split. + q1, err := queueMySQL.NewQueue(queueMySQL.Params{ + DB: s.db, Logger: zaptest.NewLogger(t), MetricsScope: tally.NoopScope, + }) + require.NoError(t, err) + defer q1.Close() + + q2, err := queueMySQL.NewQueue(queueMySQL.Params{ + DB: s.db, Logger: zaptest.NewLogger(t), MetricsScope: tally.NoopScope, + }) + require.NoError(t, err) + // no defer close — we close explicitly below + + _, err = q1.Subscriber().Subscribe(s.ctx, topic, rebalanceSubConfig("s1", consumerGroup)) + require.NoError(t, err) + _, err = q2.Subscriber().Subscribe(s.ctx, topic, rebalanceSubConfig("s2", consumerGroup)) + require.NoError(t, err) + + require.Eventually(t, func() bool { + leases, _ := getPartitionLeases(s.db, topic, consumerGroup) + return len(leases["s1"])+len(leases["s2"]) == 4 && len(leases["s1"]) == 2 && len(leases["s2"]) == 2 + }, rebalanceConvergeTimeout, rebalanceConvergePollInterval, "2+2 split should converge") + + // S2 leaves: close releases leases and deregisters heartbeat. + require.NoError(t, q2.Close()) + + // S1's discovery loop will detect orphaned (expired) partitions and acquire them. 
+ require.Eventually(t, func() bool { + leases, _ := getPartitionLeases(s.db, topic, consumerGroup) + return len(leases["s1"]) == 4 + }, rebalanceConvergeTimeout, rebalanceConvergePollInterval, "S1 should reacquire all 4 partitions after S2 leaves") + + t.Logf("Subscriber leave verified: S1 owns all 4 partitions after S2 departed") +} + +func (s *SQLQueueIntegrationSuite) TestRebalance_OddPartitions() { + t := s.T() + + topic := "rebalance_odd_topic" + consumerGroup := "rebalance-odd-cg" + partitions := []string{"pk-1", "pk-2", "pk-3", "pk-4", "pk-5"} + + pubQ, err := queueMySQL.NewQueue(queueMySQL.Params{ + DB: s.db, Logger: zaptest.NewLogger(t), MetricsScope: tally.NoopScope, + }) + require.NoError(t, err) + defer pubQ.Close() + + for i, pk := range partitions { + msg := queue.NewMessage(fmt.Sprintf("rb-odd-%d", i), []byte("x"), pk, nil) + require.NoError(t, pubQ.Publisher().Publish(s.ctx, topic, msg)) + } + + q1, err := queueMySQL.NewQueue(queueMySQL.Params{ + DB: s.db, Logger: zaptest.NewLogger(t), MetricsScope: tally.NoopScope, + }) + require.NoError(t, err) + defer q1.Close() + + q2, err := queueMySQL.NewQueue(queueMySQL.Params{ + DB: s.db, Logger: zaptest.NewLogger(t), MetricsScope: tally.NoopScope, + }) + require.NoError(t, err) + defer q2.Close() + + _, err = q1.Subscriber().Subscribe(s.ctx, topic, rebalanceSubConfig("s1", consumerGroup)) + require.NoError(t, err) + _, err = q2.Subscriber().Subscribe(s.ctx, topic, rebalanceSubConfig("s2", consumerGroup)) + require.NoError(t, err) + + // maxPart = ceil(5/2) = 3. One gets 3, the other gets 2. 
+ require.Eventually(t, func() bool { + leases, _ := getPartitionLeases(s.db, topic, consumerGroup) + total := len(leases["s1"]) + len(leases["s2"]) + max := len(leases["s1"]) + if len(leases["s2"]) > max { + max = len(leases["s2"]) + } + min := len(leases["s1"]) + if len(leases["s2"]) < min { + min = len(leases["s2"]) + } + return total == 5 && max == 3 && min == 2 + }, rebalanceConvergeTimeout, rebalanceConvergePollInterval, "5 partitions should split 3+2 across 2 subscribers") + + t.Logf("Odd partition distribution verified: 5 partitions split 3+2") +} + +func (s *SQLQueueIntegrationSuite) TestRebalance_NoOrphans() { + t := s.T() + + topic := "rebalance_orphan_topic" + consumerGroup := "rebalance-orphan-cg" + partitions := []string{"pk-1", "pk-2", "pk-3", "pk-4", "pk-5", "pk-6"} + + pubQ, err := queueMySQL.NewQueue(queueMySQL.Params{ + DB: s.db, Logger: zaptest.NewLogger(t), MetricsScope: tally.NoopScope, + }) + require.NoError(t, err) + defer pubQ.Close() + + for i, pk := range partitions { + msg := queue.NewMessage(fmt.Sprintf("rb-orphan-%d", i), []byte("x"), pk, nil) + require.NoError(t, pubQ.Publisher().Publish(s.ctx, topic, msg)) + } + + // 3 subscribers → 2 each. 
+ queues := make([]extqueue.Queue, 3) + subNames := []string{"s1", "s2", "s3"} + for i, name := range subNames { + q, err := queueMySQL.NewQueue(queueMySQL.Params{ + DB: s.db, Logger: zaptest.NewLogger(t), MetricsScope: tally.NoopScope, + }) + require.NoError(t, err) + queues[i] = q + _, err = q.Subscriber().Subscribe(s.ctx, topic, rebalanceSubConfig(name, consumerGroup)) + require.NoError(t, err) + } + defer queues[0].Close() + defer queues[1].Close() + // queues[2] will be closed explicitly + + require.Eventually(t, func() bool { + leases, _ := getPartitionLeases(s.db, topic, consumerGroup) + total := 0 + for _, pks := range leases { + total += len(pks) + } + return total == 6 + }, rebalanceConvergeTimeout, rebalanceConvergePollInterval, "all 6 partitions should be assigned across 3 subscribers") + + // Remove S3 → remaining 2 should pick up orphans. maxPart = ceil(6/2) = 3. + require.NoError(t, queues[2].Close()) + + // S1/S2 discovery loops will detect orphaned (expired) partitions and acquire them. 
+ require.Eventually(t, func() bool { + leases, _ := getPartitionLeases(s.db, topic, consumerGroup) + total := len(leases["s1"]) + len(leases["s2"]) + // s3 leases should be gone (released on close or expired) + return total == 6 && len(leases["s3"]) == 0 + }, rebalanceConvergeTimeout, rebalanceConvergePollInterval, "remaining 2 subscribers should own all 6 partitions") + + t.Logf("No orphan partitions: all 6 reassigned after subscriber left") +} + +func (s *SQLQueueIntegrationSuite) TestRebalance_MoreSubscribersThanPartitions() { + t := s.T() + + topic := "rebalance_excess_topic" + consumerGroup := "rebalance-excess-cg" + partitions := []string{"pk-1", "pk-2"} + + pubQ, err := queueMySQL.NewQueue(queueMySQL.Params{ + DB: s.db, Logger: zaptest.NewLogger(t), MetricsScope: tally.NoopScope, + }) + require.NoError(t, err) + defer pubQ.Close() + + for i, pk := range partitions { + msg := queue.NewMessage(fmt.Sprintf("rb-excess-%d", i), []byte("x"), pk, nil) + require.NoError(t, pubQ.Publisher().Publish(s.ctx, topic, msg)) + } + + // 4 subscribers competing for 2 partitions. maxPart = ceil(2/4) = 1. 
+ subNames := []string{"s1", "s2", "s3", "s4"} + var queues []extqueue.Queue + for _, name := range subNames { + q, err := queueMySQL.NewQueue(queueMySQL.Params{ + DB: s.db, Logger: zaptest.NewLogger(t), MetricsScope: tally.NoopScope, + }) + require.NoError(t, err) + queues = append(queues, q) + _, err = q.Subscriber().Subscribe(s.ctx, topic, rebalanceSubConfig(name, consumerGroup)) + require.NoError(t, err) + } + defer func() { + for _, q := range queues { + q.Close() + } + }() + + require.Eventually(t, func() bool { + leases, _ := getPartitionLeases(s.db, topic, consumerGroup) + total := 0 + maxOwned := 0 + for _, pks := range leases { + total += len(pks) + if len(pks) > maxOwned { + maxOwned = len(pks) + } + } + return total == 2 && maxOwned <= 1 + }, rebalanceConvergeTimeout, rebalanceConvergePollInterval, + "2 partitions across 4 subscribers: total=2, max per subscriber=1") + + t.Logf("More subscribers than partitions verified: 2 partitions, 4 subscribers, max 1 each") +} + +// TestNackDoesNotBlockOtherMessages verifies that nacking a message does not +// block delivery of subsequent messages in the same partition. The nacked +// message should be skipped (invisible) while later messages are delivered. 
+func (s *SQLQueueIntegrationSuite) TestNackDoesNotBlockOtherMessages() { + t := s.T() + + q, err := queueMySQL.NewQueue(queueMySQL.Params{ + DB: s.db, Logger: zaptest.NewLogger(t), MetricsScope: tally.NoopScope, + }) + require.NoError(t, err) + defer q.Close() + + topic := "nack_nonblocking_topic" + partition := "nack-nb-part" + + // Subscribe with batch=10 to fetch multiple messages per poll + subConfig := extqueue.DefaultSubscriptionConfig("worker-1", "nack-nb-cg") + subConfig.PollIntervalMs = 50 + subConfig.BatchSize = 10 + deliveryCh, err := q.Subscriber().Subscribe(s.ctx, topic, subConfig) + require.NoError(t, err) + + // Publish 3 messages in order + for i := 1; i <= 3; i++ { + msg := queue.NewMessage(fmt.Sprintf("msg-%d", i), []byte(fmt.Sprintf("payload-%d", i)), partition, nil) + require.NoError(t, q.Publisher().Publish(s.ctx, topic, msg)) + } + + // Receive first message and nack it with a long delay + d1 := receiveWithTimeout(t, deliveryCh, testTimeout) + assert.Equal(t, "msg-1", d1.Message().ID) + require.NoError(t, d1.Nack(s.ctx, 30000)) // 30s delay — won't come back during test + t.Logf("Nacked msg-1 with 30s delay") + + // Messages 2 and 3 should still be deliverable despite msg-1 being nacked + d2 := receiveWithTimeout(t, deliveryCh, testTimeout) + assert.Equal(t, "msg-2", d2.Message().ID) + require.NoError(t, d2.Ack(s.ctx)) + t.Logf("Received and acked msg-2") + + d3 := receiveWithTimeout(t, deliveryCh, testTimeout) + assert.Equal(t, "msg-3", d3.Message().ID) + require.NoError(t, d3.Ack(s.ctx)) + t.Logf("Received and acked msg-3") + + t.Logf("Verified: nacked message did not block subsequent messages") +} + +// TestBatchSizeOneStrictSerialization verifies that with batchSize=1, messages +// within a partition are processed strictly in order — only one message is +// in-flight at a time. 
+func (s *SQLQueueIntegrationSuite) TestBatchSizeOneStrictSerialization() { + t := s.T() + + q, err := queueMySQL.NewQueue(queueMySQL.Params{ + DB: s.db, Logger: zaptest.NewLogger(t), MetricsScope: tally.NoopScope, + }) + require.NoError(t, err) + defer q.Close() + + topic := "serial_topic" + partition := "serial-part" + + // Subscribe with batchSize=1 for strict serialization + subConfig := extqueue.DefaultSubscriptionConfig("worker-1", "serial-cg") + subConfig.PollIntervalMs = 50 + subConfig.BatchSize = 1 + deliveryCh, err := q.Subscriber().Subscribe(s.ctx, topic, subConfig) + require.NoError(t, err) + + // Publish 5 messages + for i := 1; i <= 5; i++ { + msg := queue.NewMessage(fmt.Sprintf("serial-%d", i), []byte(strconv.Itoa(i)), partition, nil) + require.NoError(t, q.Publisher().Publish(s.ctx, topic, msg)) + } + + // Receive each message strictly in order, acking before receiving next + for i := 1; i <= 5; i++ { + delivery := receiveWithTimeout(t, deliveryCh, testTimeout) + assert.Equal(t, fmt.Sprintf("serial-%d", i), delivery.Message().ID, + "expected message %d but got %s", i, delivery.Message().ID) + require.NoError(t, delivery.Ack(s.ctx)) + t.Logf("Strictly ordered delivery: serial-%d", i) + } + + // Verify no more messages + select { + case d := <-deliveryCh: + t.Fatalf("unexpected extra delivery: %s", d.Message().ID) + case <-time.After(500 * time.Millisecond): + // Expected — no more messages + } + + t.Logf("Verified: batchSize=1 enforces strict serialization") +} + +// TestMultipleConsumerGroupsIndependentState verifies that two consumer groups +// can independently process, nack, retry, and ack the same messages without +// interfering with each other. 
+func (s *SQLQueueIntegrationSuite) TestMultipleConsumerGroupsIndependentState() {
+	t := s.T()
+
+	q, err := queueMySQL.NewQueue(queueMySQL.Params{
+		DB: s.db, Logger: zaptest.NewLogger(t), MetricsScope: tally.NoopScope,
+	})
+	require.NoError(t, err)
+	defer q.Close()
+
+	topic := "multi_cg_state_topic"
+	partition := "multi-cg-part"
+
+	// Two consumer groups subscribing to the same topic
+	cfg1 := extqueue.DefaultSubscriptionConfig("worker-1", "cg-alpha")
+	cfg1.PollIntervalMs = 50
+	cfg2 := extqueue.DefaultSubscriptionConfig("worker-2", "cg-beta")
+	cfg2.PollIntervalMs = 50
+
+	ch1, err := q.Subscriber().Subscribe(s.ctx, topic, cfg1)
+	require.NoError(t, err)
+	ch2, err := q.Subscriber().Subscribe(s.ctx, topic, cfg2)
+	require.NoError(t, err)
+
+	// Publish 2 messages
+	for i := 1; i <= 2; i++ {
+		msg := queue.NewMessage(fmt.Sprintf("shared-%d", i), []byte(strconv.Itoa(i)), partition, nil)
+		require.NoError(t, q.Publisher().Publish(s.ctx, topic, msg))
+	}
+
+	// CG-alpha: nack shared-1, ack shared-2
+	d1a := receiveWithTimeout(t, ch1, testTimeout)
+	assert.Equal(t, "shared-1", d1a.Message().ID)
+	require.NoError(t, d1a.Nack(s.ctx, 200)) // short nack delay
+	t.Logf("cg-alpha nacked shared-1")
+
+	d2a := receiveWithTimeout(t, ch1, testTimeout)
+	assert.Equal(t, "shared-2", d2a.Message().ID)
+	require.NoError(t, d2a.Ack(s.ctx))
+	t.Logf("cg-alpha acked shared-2")
+
+	// CG-beta: ack both messages immediately (independent state)
+	d1b := receiveWithTimeout(t, ch2, testTimeout)
+	assert.Equal(t, "shared-1", d1b.Message().ID)
+	require.NoError(t, d1b.Ack(s.ctx))
+	t.Logf("cg-beta acked shared-1")
+
+	d2b := receiveWithTimeout(t, ch2, testTimeout)
+	assert.Equal(t, "shared-2", d2b.Message().ID)
+	require.NoError(t, d2b.Ack(s.ctx))
+	t.Logf("cg-beta acked shared-2")
+
+	// CG-alpha should get shared-1 redelivered after nack delay
+	d1aRetry := receiveWithTimeout(t, ch1, testTimeout)
+	assert.Equal(t, "shared-1", d1aRetry.Message().ID)
+	require.Greater(t, d1aRetry.Attempt(), 1, "should be a retry attempt")
+	require.NoError(t, d1aRetry.Ack(s.ctx))
+	t.Logf("cg-alpha received retry of shared-1 (attempt=%d)", d1aRetry.Attempt())
+
+	// CG-beta should NOT get shared-1 again (already acked independently)
+	select {
+	case d := <-ch2:
+		t.Fatalf("cg-beta should not receive more messages, got: %s", d.Message().ID)
+	case <-time.After(500 * time.Millisecond):
+		// Expected — cg-beta is done
+	}
+
+	t.Logf("Verified: consumer groups have fully independent delivery state")
+}
+
+// TestCrashAfterRejectDoesNotLoseMessages verifies that rejecting a later message
+// (which sends it to DLQ) does not cause earlier in-flight messages to be lost
+// after a process crash. This is a regression test for P0-1 where Reject() called
+// UpdateAckedOffset directly, bypassing watermark contiguity.
+func (s *SQLQueueIntegrationSuite) TestCrashAfterRejectDoesNotLoseMessages() {
+	t := s.T()
+
+	topic := "crash_reject_topic"
+
+	q1, err := queueMySQL.NewQueue(queueMySQL.Params{
+		DB:           s.db,
+		Logger:       zaptest.NewLogger(t),
+		MetricsScope: tally.NoopScope,
+	})
+	require.NoError(t, err)
+
+	publisher := q1.Publisher()
+
+	// Publish 3 messages to the same partition
+	require.NoError(t, publisher.Publish(s.ctx, topic, queue.NewMessage("msg-A", []byte("A"), "same-part", nil)))
+	require.NoError(t, publisher.Publish(s.ctx, topic, queue.NewMessage("msg-B", []byte("B"), "same-part", nil)))
+	require.NoError(t, publisher.Publish(s.ctx, topic, queue.NewMessage("msg-C", []byte("C"), "same-part", nil)))
+
+	// Subscribe with short timeouts for fast test
+	subConfig := extqueue.DefaultSubscriptionConfig("worker-1", "crash-reject-cg")
+	subConfig.PollIntervalMs = 100
+	subConfig.VisibilityTimeoutMs = 2000
+	subConfig.LeaseDurationMs = 3000
+	subConfig.LeaseRenewalIntervalMs = 1000
+	subConfig.BatchSize = 10
+	subConfig.Retry.MaxAttempts = 3
+	subConfig.DLQ.Enabled = true
+
+	deliveryChan1, err := q1.Subscriber().Subscribe(s.ctx, topic, subConfig)
+	require.NoError(t, err)
+
+	// Receive all 3 messages
+	deliveries := make(map[string]extqueue.Delivery)
+	receiveNWithTimeout(t, deliveryChan1, 3, testTimeout, func(d extqueue.Delivery, _ int) {
+		deliveries[d.Message().ID] = d
+		t.Logf("Received %s", d.Message().ID)
+	})
+
+	// Ack A, Reject B (→ DLQ), leave C in-flight
+	require.NoError(t, deliveries["msg-A"].Ack(s.ctx))
+	t.Logf("Acked msg-A")
+
+	require.NoError(t, deliveries["msg-B"].Reject(s.ctx, "bad payload"))
+	t.Logf("Rejected msg-B → DLQ")
+
+	// Do NOT ack msg-C — simulating in-flight at crash time
+
+	// Simulate crash
+	q1.Close()
+	t.Logf("Worker-1 crashed (queue closed)")
+
+	// Wait for lease + visibility to expire
+	waitTime := time.Duration(subConfig.LeaseDurationMs+subConfig.VisibilityTimeoutMs)*time.Millisecond + time.Second
+	t.Logf("Waiting %v for lease and visibility to expire", waitTime)
+	time.Sleep(waitTime)
+
+	// Start worker-2 with same consumer group
+	q2, err := queueMySQL.NewQueue(queueMySQL.Params{
+		DB:           s.db,
+		Logger:       zaptest.NewLogger(t),
+		MetricsScope: tally.NoopScope,
+	})
+	require.NoError(t, err)
+	defer q2.Close()
+
+	subConfig2 := extqueue.DefaultSubscriptionConfig("worker-2", "crash-reject-cg")
+	subConfig2.PollIntervalMs = 100
+	subConfig2.VisibilityTimeoutMs = 2000
+	subConfig2.LeaseDurationMs = 3000
+	subConfig2.LeaseRenewalIntervalMs = 1000
+	subConfig2.BatchSize = 10
+	subConfig2.Retry.MaxAttempts = 3
+	subConfig2.DLQ.Enabled = true
+
+	deliveryChan2, err := q2.Subscriber().Subscribe(s.ctx, topic, subConfig2)
+	require.NoError(t, err)
+
+	// Worker-2 MUST receive msg-C (it must NOT be lost)
+	delivery := receiveWithTimeout(t, deliveryChan2, testTimeout)
+	assert.Equal(t, "msg-C", delivery.Message().ID, "msg-C must be recovered after crash")
+	require.NoError(t, delivery.Ack(s.ctx))
+	t.Logf("Worker-2 recovered msg-C (attempt=%d)", delivery.Attempt())
+
+	// Verify DLQ contains msg-B
+	dlqTopic := topic + subConfig.DLQ.TopicSuffix
+	dlqConfig := extqueue.DefaultSubscriptionConfig("worker-2", "crash-reject-cg")
+	dlqConfig.PollIntervalMs = 100
+	dlqChan, err := q2.Subscriber().Subscribe(s.ctx, dlqTopic, dlqConfig)
+	require.NoError(t, err)
+
+	dlqDelivery := receiveWithTimeout(t, dlqChan, testTimeout)
+	assert.Equal(t, "msg-B", dlqDelivery.Message().ID, "msg-B should be in DLQ")
+	require.NoError(t, dlqDelivery.Ack(s.ctx))
+
+	// Verify consumer lag is 0. Track whether an entry for this consumer group
+	// was seen at all, so the check cannot pass vacuously on an empty or
+	// unrelated result set.
+	admin := queueAdmin.NewAdminStore(s.db)
+	lags, err := admin.ConsumerLag(s.ctx, topic)
+	require.NoError(t, err)
+	lagSeen := false
+	for _, lag := range lags {
+		if lag.ConsumerGroup == "crash-reject-cg" {
+			lagSeen = true
+			assert.Equal(t, int64(0), lag.Lag, "consumer lag should be 0 after recovery")
+		}
+	}
+	assert.True(t, lagSeen, "expected a consumer lag entry for crash-reject-cg")
+
+	t.Logf("Verified: crash after reject does not lose messages")
+}
+
+// TestCrashAfterRetryLimitDoesNotLoseMessages verifies that the retry-limit
+// auto-DLQ path does not cause earlier in-flight messages to be lost after crash.
+// This is a regression test for P0-1 where the retry-limit path in pollAndDeliver
+// called UpdateAckedOffset directly, bypassing watermark contiguity.
+func (s *SQLQueueIntegrationSuite) TestCrashAfterRetryLimitDoesNotLoseMessages() {
+	t := s.T()
+
+	topic := "crash_retry_limit_topic"
+
+	q1, err := queueMySQL.NewQueue(queueMySQL.Params{
+		DB:           s.db,
+		Logger:       zaptest.NewLogger(t),
+		MetricsScope: tally.NoopScope,
+	})
+	require.NoError(t, err)
+
+	publisher := q1.Publisher()
+
+	// Publish 3 messages to the same partition
+	require.NoError(t, publisher.Publish(s.ctx, topic, queue.NewMessage("msg-A", []byte("A"), "same-part", nil)))
+	require.NoError(t, publisher.Publish(s.ctx, topic, queue.NewMessage("msg-B", []byte("B"), "same-part", nil)))
+	require.NoError(t, publisher.Publish(s.ctx, topic, queue.NewMessage("msg-C", []byte("C"), "same-part", nil)))
+
+	// MaxAttempts=2: msg-B needs nack → redeliver → retry_count=2 → auto-DLQ.
+	// Use long visibility so msg-C stays in-flight and isn't auto-DLQ'd.
+	subConfig := extqueue.DefaultSubscriptionConfig("worker-1", "crash-retry-cg")
+	subConfig.PollIntervalMs = 100
+	subConfig.VisibilityTimeoutMs = 30000 // long visibility so msg-C stays in-flight
+	subConfig.LeaseDurationMs = 5000
+	subConfig.LeaseRenewalIntervalMs = 2000
+	subConfig.BatchSize = 10
+	subConfig.Retry.MaxAttempts = 2
+	subConfig.DLQ.Enabled = true
+
+	deliveryChan1, err := q1.Subscriber().Subscribe(s.ctx, topic, subConfig)
+	require.NoError(t, err)
+
+	// Receive all 3 messages
+	deliveries := make(map[string]extqueue.Delivery)
+	receiveNWithTimeout(t, deliveryChan1, 3, testTimeout, func(d extqueue.Delivery, _ int) {
+		deliveries[d.Message().ID] = d
+		t.Logf("Received %s (attempt=%d)", d.Message().ID, d.Attempt())
+	})
+
+	// Ack A
+	require.NoError(t, deliveries["msg-A"].Ack(s.ctx))
+	t.Logf("Acked msg-A")
+
+	// Nack B with short delay so it becomes visible quickly for redelivery
+	require.NoError(t, deliveries["msg-B"].Nack(s.ctx, 100))
+	t.Logf("Nacked msg-B, waiting for retry-limit to trigger auto-DLQ")
+
+	// Wait for nack delay to expire so msg-B gets redelivered.
+	// The poll loop will see retry_count=2 >= MaxAttempts=2 → auto-DLQ.
+	time.Sleep(1 * time.Second)
+
+	// Do NOT ack msg-C — simulating in-flight at crash time.
+	// msg-C has long visibility (30s) so it stays in-flight during the crash.
+
+	// Simulate crash
+	q1.Close()
+	t.Logf("Worker-1 crashed (queue closed)")
+
+	// Wait for lease to expire (visibility for C is 30s, but we use a fresh
+	// subscriber with its own visibility, so C becomes deliverable after
+	// the original visibility expires)
+	leaseWait := time.Duration(subConfig.LeaseDurationMs)*time.Millisecond + time.Second
+	t.Logf("Waiting %v for lease to expire", leaseWait)
+	time.Sleep(leaseWait)
+
+	// Wait for msg-C's visibility to expire
+	visibilityWait := time.Duration(subConfig.VisibilityTimeoutMs) * time.Millisecond
+	t.Logf("Waiting %v for msg-C visibility to expire", visibilityWait)
+	time.Sleep(visibilityWait)
+
+	// Start worker-2 with same consumer group
+	q2, err := queueMySQL.NewQueue(queueMySQL.Params{
+		DB:           s.db,
+		Logger:       zaptest.NewLogger(t),
+		MetricsScope: tally.NoopScope,
+	})
+	require.NoError(t, err)
+	defer q2.Close()
+
+	subConfig2 := extqueue.DefaultSubscriptionConfig("worker-2", "crash-retry-cg")
+	subConfig2.PollIntervalMs = 100
+	subConfig2.VisibilityTimeoutMs = 5000
+	subConfig2.LeaseDurationMs = 5000
+	subConfig2.LeaseRenewalIntervalMs = 2000
+	subConfig2.BatchSize = 10
+	subConfig2.Retry.MaxAttempts = 10 // high limit so recovered messages aren't DLQ'd
+	subConfig2.DLQ.Enabled = true
+
+	deliveryChan2, err := q2.Subscriber().Subscribe(s.ctx, topic, subConfig2)
+	require.NoError(t, err)
+
+	// msg-B was nacked but may or may not have hit the retry limit (and been
+	// moved to the DLQ) before the crash, depending on timing. Only msg-C is
+	// guaranteed to come back, so accept at most two deliveries and stop as
+	// soon as msg-C is seen. Unconditionally waiting for a second delivery
+	// would hit the receive timeout whenever msg-B was already dead-lettered.
+	recovered := make(map[string]bool)
+	for attempt := 0; attempt < 2 && !recovered["msg-C"]; attempt++ {
+		delivery := receiveWithTimeout(t, deliveryChan2, testTimeout)
+		recovered[delivery.Message().ID] = true
+		require.NoError(t, delivery.Ack(s.ctx))
+		t.Logf("Worker-2 recovered %s (attempt=%d)", delivery.Message().ID, delivery.Attempt())
+	}
+	assert.True(t, recovered["msg-C"], "msg-C must be recovered after crash")
+
+	t.Logf("Verified: crash after retry-limit does not lose messages")
+}
+
+// TestWatermarkAdvancesContiguously verifies that the acked offset watermark
+// advances correctly with out-of-order acks. The watermark should only advance
+// when all preceding messages have been acked (contiguous).
+func (s *SQLQueueIntegrationSuite) TestWatermarkAdvancesContiguously() {
+	t := s.T()
+
+	topic := "watermark_contiguous_topic"
+
+	q, err := queueMySQL.NewQueue(queueMySQL.Params{
+		DB:           s.db,
+		Logger:       zaptest.NewLogger(t),
+		MetricsScope: tally.NoopScope,
+	})
+	require.NoError(t, err)
+	defer q.Close()
+
+	publisher := q.Publisher()
+
+	// Publish 5 messages to the same partition
+	for i := 1; i <= 5; i++ {
+		msg := queue.NewMessage(
+			fmt.Sprintf("wm-msg-%d", i),
+			[]byte(fmt.Sprintf("payload-%d", i)),
+			"wm-part",
+			nil,
+		)
+		require.NoError(t, publisher.Publish(s.ctx, topic, msg))
+	}
+
+	subConfig := extqueue.DefaultSubscriptionConfig("worker-1", "watermark-cg")
+	subConfig.PollIntervalMs = 100
+	subConfig.VisibilityTimeoutMs = 30000 // long visibility so nothing re-delivers
+	subConfig.BatchSize = 10
+
+	deliveryChan, err := q.Subscriber().Subscribe(s.ctx, topic, subConfig)
+	require.NoError(t, err)
+
+	// Receive all 5
+	deliveries := make(map[string]extqueue.Delivery)
+	receiveNWithTimeout(t, deliveryChan, 5, testTimeout, func(d extqueue.Delivery, _ int) {
+		deliveries[d.Message().ID] = d
+		t.Logf("Received %s", d.Message().ID)
+	})
+
+	admin := queueAdmin.NewAdminStore(s.db)
+
+	// lookupLag returns the lag reported for watermark-cg on wm-part, or -1
+	// when no such entry exists. It reports errors instead of failing the test
+	// so it is safe to call from the require.Eventually condition, which does
+	// not run on the test goroutine.
+	lookupLag := func() (int64, error) {
+		lags, err := admin.ConsumerLag(s.ctx, topic)
+		if err != nil {
+			return 0, err
+		}
+		for _, lag := range lags {
+			if lag.ConsumerGroup == "watermark-cg" && lag.PartitionKey == "wm-part" {
+				return lag.Lag, nil
+			}
+		}
+		return -1, nil
+	}
+
+	// Helper to get consumer lag (fails the test on lookup error)
+	getLag := func() int64 {
+		lag, err := lookupLag()
+		require.NoError(t, err)
+		return lag
+	}
+
+	// waitForLag polls until the reported lag equals want, instead of relying
+	// on a fixed sleep being long enough for the poll loop to run.
+	waitForLag := func(want int64, msg string) {
+		require.Eventually(t, func() bool {
+			lag, err := lookupLag()
+			return err == nil && lag == want
+		}, 5*time.Second, 50*time.Millisecond, msg)
+	}
+
+	// Ack message 3 first (out of order)
+	require.NoError(t, deliveries["wm-msg-3"].Ack(s.ctx))
+	t.Logf("Acked msg-3")
+
+	// Ack messages 1 and 2 — now 1,2,3 are contiguous
+	require.NoError(t, deliveries["wm-msg-1"].Ack(s.ctx))
+	require.NoError(t, deliveries["wm-msg-2"].Ack(s.ctx))
+	t.Logf("Acked msg-1 and msg-2")
+
+	// After acking 1,2,3: watermark should advance to 3, lag should be 2 (msg-4, msg-5)
+	waitForLag(2, "lag should be 2 after acking 1,2,3 (4 and 5 remain)")
+	t.Logf("After acking 1,2,3: lag=%d", getLag())
+
+	// Ack message 5 (skip 4) — watermark should NOT advance past 3
+	require.NoError(t, deliveries["wm-msg-5"].Ack(s.ctx))
+	t.Logf("Acked msg-5 (skipping msg-4)")
+
+	// Negative check: give the poll loop time to (incorrectly) advance the
+	// watermark before asserting that the lag is unchanged.
+	time.Sleep(500 * time.Millisecond)
+
+	lag := getLag()
+	assert.Equal(t, int64(2), lag, "lag should still be 2 after acking 5 but not 4")
+	t.Logf("After acking 5 (not 4): lag=%d", lag)
+
+	// Ack message 4 — now all 5 are contiguous, watermark should advance to 5
+	require.NoError(t, deliveries["wm-msg-4"].Ack(s.ctx))
+	t.Logf("Acked msg-4")
+
+	waitForLag(0, "lag should be 0 after acking all 5 messages")
+	t.Logf("After acking all 5: lag=%d", getLag())
+
+	t.Logf("Verified: watermark advances contiguously with out-of-order acks")
+}