From bc7b5bb7aabaf35109411790ea3b64a80c4381e4 Mon Sep 17 00:00:00 2001 From: Preetam Dwivedi Date: Thu, 19 Feb 2026 20:21:03 -0800 Subject: [PATCH] feat(queue/sql): add subscriber with partition leasing and offset tracking - Implement Subscriber interface with partition-based message polling - Add partition lease management for distributed workers - Add offset tracking per partition for consumption progress - Extend MessageStore interface with FetchByOffset, SetVisibilityTimeout, MoveToDLQ - Add OffsetStore interface for offset management - Add PartitionLeaseStore interface for partition leasing - Generate mocks for all three store interfaces - Add standardized metric tags (topic + partition_key) across all operations - Add comprehensive test coverage for subscription, ack/nack, and error handling --- extensions/queue/sql/BUILD.bazel | 4 + extensions/queue/sql/errors.go | 12 + extensions/queue/sql/mock_stores.go | 265 +++++++++++- extensions/queue/sql/publisher.go | 4 +- extensions/queue/sql/publisher_test.go | 32 +- extensions/queue/sql/stores.go | 74 +++- extensions/queue/sql/subscriber.go | 552 ++++++++++++++++++++++++ extensions/queue/sql/subscriber_test.go | 146 +++++++ 8 files changed, 1054 insertions(+), 35 deletions(-) create mode 100644 extensions/queue/sql/errors.go create mode 100644 extensions/queue/sql/subscriber.go create mode 100644 extensions/queue/sql/subscriber_test.go diff --git a/extensions/queue/sql/BUILD.bazel b/extensions/queue/sql/BUILD.bazel index df128279..5c38cd48 100644 --- a/extensions/queue/sql/BUILD.bazel +++ b/extensions/queue/sql/BUILD.bazel @@ -4,15 +4,18 @@ go_library( name = "sql", srcs = [ "config.go", + "errors.go", "mock_stores.go", "publisher.go", "stores.go", + "subscriber.go", "validation.go", ], importpath = "github.com/uber/submitqueue/extensions/queue/sql", visibility = ["//visibility:public"], deps = [ "//entities/queue", + "//extensions/queue", "@com_github_uber_go_tally_v4//:tally", "@org_uber_go_mock//gomock", "@org_uber_go_zap//:zap", @@ -24,6 +27,7 @@ go_test( srcs = [ "config_test.go", "publisher_test.go", + "subscriber_test.go", ], embed = [":sql"], deps = [ diff --git a/extensions/queue/sql/errors.go b/extensions/queue/sql/errors.go new file mode 100644 index 00000000..57887595 --- /dev/null +++ b/extensions/queue/sql/errors.go @@ -0,0 +1,12 @@ +package sql + +import "fmt" + +// ErrAlreadyAcknowledged is returned when attempting to ack/nack a delivery that was already processed +type ErrAlreadyAcknowledged struct { + DeliveryID string +} + +func (e *ErrAlreadyAcknowledged) Error() string { + return fmt.Sprintf("delivery %s already acknowledged or nacked", e.DeliveryID) +} diff --git a/extensions/queue/sql/mock_stores.go b/extensions/queue/sql/mock_stores.go index 05203a1d..7df7d1a7 100644 --- a/extensions/queue/sql/mock_stores.go +++ b/extensions/queue/sql/mock_stores.go @@ -3,7 +3,7 @@ // // Generated by this command: // -// mockgen -source=stores.go -destination=mock_stores.go -package=sql -self_package=github.com/uber/submitqueue/extensions/queue/sql +// mockgen -source=stores.go -destination=mock_stores.go -package=sql // // Package sql is a generated GoMock package. @@ -17,32 +17,61 @@ import ( gomock "go.uber.org/mock/gomock" ) -// MockMessageStore is a mock of MessageStore interface. -type MockMessageStore struct { +// MockmessageStore is a mock of messageStore interface. +type MockmessageStore struct { ctrl *gomock.Controller - recorder *MockMessageStoreMockRecorder + recorder *MockmessageStoreMockRecorder isgomock struct{} } -// MockMessageStoreMockRecorder is the mock recorder for MockMessageStore. -type MockMessageStoreMockRecorder struct { - mock *MockMessageStore +// MockmessageStoreMockRecorder is the mock recorder for MockmessageStore. +type MockmessageStoreMockRecorder struct { + mock *MockmessageStore } -// NewMockMessageStore creates a new mock instance. -func NewMockMessageStore(ctrl *gomock.Controller) *MockMessageStore { - mock := &MockMessageStore{ctrl: ctrl} - mock.recorder = &MockMessageStoreMockRecorder{mock} +// NewMockmessageStore creates a new mock instance. +func NewMockmessageStore(ctrl *gomock.Controller) *MockmessageStore { + mock := &MockmessageStore{ctrl: ctrl} + mock.recorder = &MockmessageStoreMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockMessageStore) EXPECT() *MockMessageStoreMockRecorder { +func (m *MockmessageStore) EXPECT() *MockmessageStoreMockRecorder { return m.recorder } +// Delete mocks base method. +func (m *MockmessageStore) Delete(ctx context.Context, topic, messageID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", ctx, topic, messageID) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockmessageStoreMockRecorder) Delete(ctx, topic, messageID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockmessageStore)(nil).Delete), ctx, topic, messageID) +} + +// FetchByOffset mocks base method. +func (m *MockmessageStore) FetchByOffset(ctx context.Context, topic, partitionKey string, currentOffset int64, limit int) ([]messageRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FetchByOffset", ctx, topic, partitionKey, currentOffset, limit) + ret0, _ := ret[0].([]messageRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FetchByOffset indicates an expected call of FetchByOffset. +func (mr *MockmessageStoreMockRecorder) FetchByOffset(ctx, topic, partitionKey, currentOffset, limit any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchByOffset", reflect.TypeOf((*MockmessageStore)(nil).FetchByOffset), ctx, topic, partitionKey, currentOffset, limit) +} + // Insert mocks base method. -func (m *MockMessageStore) Insert(ctx context.Context, topic string, messages []queue.Message) error { +func (m *MockmessageStore) Insert(ctx context.Context, topic string, messages []queue.Message) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Insert", ctx, topic, messages) ret0, _ := ret[0].(error) @@ -50,7 +79,213 @@ func (m *MockMessageStore) Insert(ctx context.Context, topic string, messages [] } // Insert indicates an expected call of Insert. -func (mr *MockMessageStoreMockRecorder) Insert(ctx, topic, messages any) *gomock.Call { +func (mr *MockmessageStoreMockRecorder) Insert(ctx, topic, messages any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockmessageStore)(nil).Insert), ctx, topic, messages) +} + +// MoveToDLQ mocks base method. +func (m *MockmessageStore) MoveToDLQ(ctx context.Context, topic, messageID string, failureCount int, lastError string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MoveToDLQ", ctx, topic, messageID, failureCount, lastError) + ret0, _ := ret[0].(error) + return ret0 +} + +// MoveToDLQ indicates an expected call of MoveToDLQ. +func (mr *MockmessageStoreMockRecorder) MoveToDLQ(ctx, topic, messageID, failureCount, lastError any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MoveToDLQ", reflect.TypeOf((*MockmessageStore)(nil).MoveToDLQ), ctx, topic, messageID, failureCount, lastError) +} + +// SetVisibilityTimeout mocks base method. +func (m *MockmessageStore) SetVisibilityTimeout(ctx context.Context, topic, messageID string, visibilityTimeoutMillis int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetVisibilityTimeout", ctx, topic, messageID, visibilityTimeoutMillis) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetVisibilityTimeout indicates an expected call of SetVisibilityTimeout. +func (mr *MockmessageStoreMockRecorder) SetVisibilityTimeout(ctx, topic, messageID, visibilityTimeoutMillis any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetVisibilityTimeout", reflect.TypeOf((*MockmessageStore)(nil).SetVisibilityTimeout), ctx, topic, messageID, visibilityTimeoutMillis) +} + +// MockoffsetStore is a mock of offsetStore interface. +type MockoffsetStore struct { + ctrl *gomock.Controller + recorder *MockoffsetStoreMockRecorder + isgomock struct{} +} + +// MockoffsetStoreMockRecorder is the mock recorder for MockoffsetStore. +type MockoffsetStoreMockRecorder struct { + mock *MockoffsetStore +} + +// NewMockoffsetStore creates a new mock instance. +func NewMockoffsetStore(ctrl *gomock.Controller) *MockoffsetStore { + mock := &MockoffsetStore{ctrl: ctrl} + mock.recorder = &MockoffsetStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockoffsetStore) EXPECT() *MockoffsetStoreMockRecorder { + return m.recorder +} + +// AckMessage mocks base method. +func (m *MockoffsetStore) AckMessage(ctx context.Context, topic, partitionKey, messageID string, offset int64, messageStore messageStore) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AckMessage", ctx, topic, partitionKey, messageID, offset, messageStore) + ret0, _ := ret[0].(error) + return ret0 +} + +// AckMessage indicates an expected call of AckMessage. +func (mr *MockoffsetStoreMockRecorder) AckMessage(ctx, topic, partitionKey, messageID, offset, messageStore any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AckMessage", reflect.TypeOf((*MockoffsetStore)(nil).AckMessage), ctx, topic, partitionKey, messageID, offset, messageStore) +} + +// GetAckedOffset mocks base method. +func (m *MockoffsetStore) GetAckedOffset(ctx context.Context, topic, partitionKey string) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAckedOffset", ctx, topic, partitionKey) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAckedOffset indicates an expected call of GetAckedOffset. +func (mr *MockoffsetStoreMockRecorder) GetAckedOffset(ctx, topic, partitionKey any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckedOffset", reflect.TypeOf((*MockoffsetStore)(nil).GetAckedOffset), ctx, topic, partitionKey) +} + +// Initialize mocks base method. +func (m *MockoffsetStore) Initialize(ctx context.Context, topic, partitionKey string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Initialize", ctx, topic, partitionKey) + ret0, _ := ret[0].(error) + return ret0 +} + +// Initialize indicates an expected call of Initialize. +func (mr *MockoffsetStoreMockRecorder) Initialize(ctx, topic, partitionKey any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Initialize", reflect.TypeOf((*MockoffsetStore)(nil).Initialize), ctx, topic, partitionKey) +} + +// UpdateAckedOffset mocks base method. +func (m *MockoffsetStore) UpdateAckedOffset(ctx context.Context, topic, partitionKey string, offset int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateAckedOffset", ctx, topic, partitionKey, offset) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateAckedOffset indicates an expected call of UpdateAckedOffset. +func (mr *MockoffsetStoreMockRecorder) UpdateAckedOffset(ctx, topic, partitionKey, offset any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAckedOffset", reflect.TypeOf((*MockoffsetStore)(nil).UpdateAckedOffset), ctx, topic, partitionKey, offset) +} + +// MockpartitionLeaseStore is a mock of partitionLeaseStore interface. +type MockpartitionLeaseStore struct { + ctrl *gomock.Controller + recorder *MockpartitionLeaseStoreMockRecorder + isgomock struct{} +} + +// MockpartitionLeaseStoreMockRecorder is the mock recorder for MockpartitionLeaseStore. +type MockpartitionLeaseStoreMockRecorder struct { + mock *MockpartitionLeaseStore +} + +// NewMockpartitionLeaseStore creates a new mock instance. +func NewMockpartitionLeaseStore(ctrl *gomock.Controller) *MockpartitionLeaseStore { + mock := &MockpartitionLeaseStore{ctrl: ctrl} + mock.recorder = &MockpartitionLeaseStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockpartitionLeaseStore) EXPECT() *MockpartitionLeaseStoreMockRecorder { + return m.recorder +} + +// DiscoverAndAcquirePartitions mocks base method. +func (m *MockpartitionLeaseStore) DiscoverAndAcquirePartitions(ctx context.Context, topic string) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DiscoverAndAcquirePartitions", ctx, topic) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DiscoverAndAcquirePartitions indicates an expected call of DiscoverAndAcquirePartitions. +func (mr *MockpartitionLeaseStoreMockRecorder) DiscoverAndAcquirePartitions(ctx, topic any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DiscoverAndAcquirePartitions", reflect.TypeOf((*MockpartitionLeaseStore)(nil).DiscoverAndAcquirePartitions), ctx, topic) +} + +// GetLeasedPartitions mocks base method. +func (m *MockpartitionLeaseStore) GetLeasedPartitions(ctx context.Context, topic string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLeasedPartitions", ctx, topic) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetLeasedPartitions indicates an expected call of GetLeasedPartitions. +func (mr *MockpartitionLeaseStoreMockRecorder) GetLeasedPartitions(ctx, topic any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeasedPartitions", reflect.TypeOf((*MockpartitionLeaseStore)(nil).GetLeasedPartitions), ctx, topic) +} + +// ReleaseLease mocks base method. +func (m *MockpartitionLeaseStore) ReleaseLease(ctx context.Context, topic, partitionKey string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReleaseLease", ctx, topic, partitionKey) + ret0, _ := ret[0].(error) + return ret0 +} + +// ReleaseLease indicates an expected call of ReleaseLease. +func (mr *MockpartitionLeaseStoreMockRecorder) ReleaseLease(ctx, topic, partitionKey any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReleaseLease", reflect.TypeOf((*MockpartitionLeaseStore)(nil).ReleaseLease), ctx, topic, partitionKey) +} + +// RenewLease mocks base method. +func (m *MockpartitionLeaseStore) RenewLease(ctx context.Context, topic, partitionKey string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RenewLease", ctx, topic, partitionKey) + ret0, _ := ret[0].(error) + return ret0 +} + +// RenewLease indicates an expected call of RenewLease. +func (mr *MockpartitionLeaseStoreMockRecorder) RenewLease(ctx, topic, partitionKey any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewLease", reflect.TypeOf((*MockpartitionLeaseStore)(nil).RenewLease), ctx, topic, partitionKey) +} + +// TryAcquireLease mocks base method. +func (m *MockpartitionLeaseStore) TryAcquireLease(ctx context.Context, topic, partitionKey string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TryAcquireLease", ctx, topic, partitionKey) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// TryAcquireLease indicates an expected call of TryAcquireLease. +func (mr *MockpartitionLeaseStoreMockRecorder) TryAcquireLease(ctx, topic, partitionKey any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockMessageStore)(nil).Insert), ctx, topic, messages) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryAcquireLease", reflect.TypeOf((*MockpartitionLeaseStore)(nil).TryAcquireLease), ctx, topic, partitionKey) } diff --git a/extensions/queue/sql/publisher.go b/extensions/queue/sql/publisher.go index 0e4d96a8..530a24a3 100644 --- a/extensions/queue/sql/publisher.go +++ b/extensions/queue/sql/publisher.go @@ -15,13 +15,13 @@ type publisher struct { config Config logger *zap.SugaredLogger metrics tally.Scope - messageStore MessageStore + messageStore messageStore mu sync.RWMutex closed bool } // NewPublisher creates a publisher with the given configuration and dependencies -func NewPublisher(config Config, logger *zap.SugaredLogger, metrics tally.Scope, messageStore MessageStore) *publisher { +func NewPublisher(config Config, logger *zap.SugaredLogger, metrics tally.Scope, messageStore messageStore) *publisher { return &publisher{ config: config, logger: logger, diff --git a/extensions/queue/sql/publisher_test.go b/extensions/queue/sql/publisher_test.go index 101f4c93..8f02b1da 100644 --- a/extensions/queue/sql/publisher_test.go +++ b/extensions/queue/sql/publisher_test.go @@ -17,7 +17,7 @@ import ( const fixedTimestamp = int64(1234567890000) // Fixed timestamp for test repeatability -func setupPublisherTest(t *testing.T, mockStore *MockMessageStore) extqueue.Publisher { +func setupPublisherTest(t *testing.T, mockStore *MockmessageStore) extqueue.Publisher { t.Helper() config := DefaultConfig("test-consumer", "test-worker") @@ -35,7 +35,7 @@ func TestPublisher_Publish(t *testing.T) { topic string messages []queue.Message wantErr bool - setupMock func(*MockMessageStore) + setupMock func(*MockmessageStore) }{ { name: "publish single message", @@ -44,7 +44,7 @@ func TestPublisher_Publish(t *testing.T) { {ID: "msg1", Payload: []byte("payload1"), PartitionKey: "part1", PublishedAt: fixedTimestamp}, }, wantErr: false, - setupMock: func(m *MockMessageStore) { + setupMock: func(m *MockmessageStore) { m.EXPECT().Insert(gomock.Any(), "test_topic", gomock.Any()).Return(nil).Times(1) }, }, @@ -57,7 +57,7 @@ func TestPublisher_Publish(t *testing.T) { {ID: "msg3", Payload: []byte("p3"), PartitionKey: "part2", PublishedAt: fixedTimestamp}, }, wantErr: false, - setupMock: func(m *MockMessageStore) { + setupMock: func(m *MockmessageStore) { m.EXPECT().Insert(gomock.Any(), "multi_topic", gomock.Any()).Return(nil).Times(3) }, }, @@ -66,7 +66,7 @@ func TestPublisher_Publish(t *testing.T) { topic: "empty_topic", messages: []queue.Message{}, wantErr: false, - setupMock: func(m *MockMessageStore) { + setupMock: func(m *MockmessageStore) { // No Insert expected }, }, @@ -83,7 +83,7 @@ func TestPublisher_Publish(t *testing.T) { }, }, wantErr: false, - setupMock: func(m *MockMessageStore) { + setupMock: func(m *MockmessageStore) { m.EXPECT().Insert(gomock.Any(), "metadata_topic", gomock.Any()).Return(nil).Times(1) }, }, @@ -94,7 +94,7 @@ func TestPublisher_Publish(t *testing.T) { {ID: "msg1", Payload: []byte("p"), PartitionKey: "part1", PublishedAt: fixedTimestamp}, }, wantErr: true, - setupMock: func(m *MockMessageStore) { + setupMock: func(m *MockmessageStore) { // No Insert expected since validation fails }, }, @@ -105,7 +105,7 @@ func TestPublisher_Publish(t *testing.T) { {ID: "msg1", Payload: []byte("p"), PartitionKey: "part1", PublishedAt: fixedTimestamp}, }, wantErr: true, - setupMock: func(m *MockMessageStore) { + setupMock: func(m *MockmessageStore) { // No Insert expected since validation fails }, }, @@ -116,7 +116,7 @@ func TestPublisher_Publish(t *testing.T) { {ID: "msg1", Payload: []byte("p"), PartitionKey: "part1", PublishedAt: fixedTimestamp}, }, wantErr: true, - setupMock: func(m *MockMessageStore) { + setupMock: func(m *MockmessageStore) { // No Insert expected since validation fails }, }, @@ -127,7 +127,7 @@ func TestPublisher_Publish(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockStore := NewMockMessageStore(ctrl) + mockStore := NewMockmessageStore(ctrl) tt.setupMock(mockStore) pub := setupPublisherTest(t, mockStore) @@ -153,7 +153,7 @@ func TestPublisher_PublishAfterClose(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockStore := NewMockMessageStore(ctrl) + mockStore := NewMockmessageStore(ctrl) pub := setupPublisherTest(t, mockStore) ctx := context.Background() @@ -173,7 +173,7 @@ func TestPublisher_Close(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockStore := NewMockMessageStore(ctrl) + mockStore := NewMockmessageStore(ctrl) pub := setupPublisherTest(t, mockStore) // Close should succeed @@ -248,7 +248,7 @@ func TestValidateTopicName(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockStore := NewMockMessageStore(ctrl) + mockStore := NewMockmessageStore(ctrl) pub := setupPublisherTest(t, mockStore) // Try to publish with this topic name @@ -274,7 +274,7 @@ func TestPublisher_PublishMetrics(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockStore := NewMockMessageStore(ctrl) + mockStore := NewMockmessageStore(ctrl) mockStore.EXPECT().Insert(gomock.Any(), "metrics_test", gomock.Any()).Return(nil).Times(2) pub := setupPublisherTest(t, mockStore) @@ -304,7 +304,7 @@ func TestPublisher_ConcurrentPublish(t *testing.T) { const numGoroutines = 10 const messagesPerGoroutine = 5 - mockStore := NewMockMessageStore(ctrl) + mockStore := NewMockmessageStore(ctrl) mockStore.EXPECT().Insert(gomock.Any(), "concurrent_topic", gomock.Any()).Return(nil).Times(numGoroutines * messagesPerGoroutine) pub := setupPublisherTest(t, mockStore) @@ -339,7 +339,7 @@ func TestPublisher_PublishContextCancellation(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockStore := NewMockMessageStore(ctrl) + mockStore := NewMockmessageStore(ctrl) mockStore.EXPECT().Insert(gomock.Any(), "test_topic", gomock.Any()).Return(context.Canceled).Times(1) pub := setupPublisherTest(t, mockStore) diff --git a/extensions/queue/sql/stores.go b/extensions/queue/sql/stores.go index 140e15a3..d34e97b2 100644 --- a/extensions/queue/sql/stores.go +++ b/extensions/queue/sql/stores.go @@ -8,8 +8,78 @@ import ( "github.com/uber/submitqueue/entities/queue" ) -// MessageStore handles message table operations -type MessageStore interface { +// messageRow represents a row from the messages table (internal use only) +type messageRow struct { + // Offset is the auto-incrementing sequence number for message ordering within a partition + Offset int64 + // ID is the unique message identifier + ID string + // Payload is the message body in bytes + Payload []byte + // Metadata contains key-value pairs for message attributes + Metadata map[string]string + // PartitionKey determines which partition this message belongs to for ordering guarantees + PartitionKey string + // RetryCount tracks how many times this message has been retried on the current topic + RetryCount int + // PublishedAt is the Unix timestamp in milliseconds when message was published + PublishedAt int64 +} + +// messageStore handles message table operations (internal use only) +type messageStore interface { // Insert inserts messages into the topic table Insert(ctx context.Context, topic string, messages []queue.Message) error + + // Delete deletes a message by ID + Delete(ctx context.Context, topic string, messageID string) error + + // FetchByOffset fetches messages with offset > currentOffset for a specific partition + // Only fetches visible messages (invisible_until <= now) + // Atomically sets invisible_until and increments retry_count + FetchByOffset(ctx context.Context, topic string, partitionKey string, currentOffset int64, limit int) ([]messageRow, error) + + // MoveToDLQ moves a message to the dead letter queue + MoveToDLQ(ctx context.Context, topic string, messageID string, failureCount int, lastError string) error + + // SetVisibilityTimeout sets the invisible_until timestamp for a message + // visibilityTimeoutMillis: milliseconds from now to hide the message + // If visibilityTimeoutMillis is 0, makes the message visible immediately + // If visibilityTimeoutMillis > 0, makes the message invisible until now + visibilityTimeoutMillis + SetVisibilityTimeout(ctx context.Context, topic string, messageID string, visibilityTimeoutMillis int64) error +} + +// offsetStore handles offset table operations for per-partition offset tracking (internal use only) +type offsetStore interface { + // Initialize creates an offset entry for a topic+partition if it doesn't exist + Initialize(ctx context.Context, topic string, partitionKey string) error + + // GetAckedOffset returns the current acked offset for a topic+partition + GetAckedOffset(ctx context.Context, topic string, partitionKey string) (int64, error) + + // UpdateAckedOffset updates the offset_acked for a topic+partition (only if new offset is greater) + UpdateAckedOffset(ctx context.Context, topic string, partitionKey string, offset int64) error + + // AckMessage atomically deletes a message and updates the acked offset + AckMessage(ctx context.Context, topic string, partitionKey string, messageID string, offset int64, messageStore messageStore) error +} + +// partitionLeaseStore handles partition lease operations (internal use only) +type partitionLeaseStore interface { + // TryAcquireLease attempts to acquire or renew a lease for a partition + // Returns true if lease is acquired/owned by this worker + TryAcquireLease(ctx context.Context, topic string, partitionKey string) (bool, error) + + // RenewLease renews the lease for a partition owned by this worker + RenewLease(ctx context.Context, topic string, partitionKey string) error + + // ReleaseLease releases the lease for a partition owned by this worker + ReleaseLease(ctx context.Context, topic string, partitionKey string) error + + // GetLeasedPartitions returns all partitions currently leased by this worker + GetLeasedPartitions(ctx context.Context, topic string) ([]string, error) + + // DiscoverAndAcquirePartitions discovers partitions from messages table and tries to acquire leases + // Returns the number of new leases acquired + DiscoverAndAcquirePartitions(ctx context.Context, topic string) (int, error) } diff --git a/extensions/queue/sql/subscriber.go b/extensions/queue/sql/subscriber.go new file mode 100644 index 00000000..f2b4991f --- /dev/null +++ b/extensions/queue/sql/subscriber.go @@ -0,0 +1,552 @@ +package sql + +import ( + "context" + "fmt" + "strconv" + "sync" + "time" + + "github.com/uber-go/tally/v4" + "go.uber.org/zap" + + "github.com/uber/submitqueue/entities/queue" + extqueue "github.com/uber/submitqueue/extensions/queue" +) + +type subscriber struct { + config Config + logger *zap.SugaredLogger + metrics tally.Scope + messageStore messageStore + offsetStore offsetStore + leaseStore partitionLeaseStore + mu sync.RWMutex + closed bool + + // Active subscriptions + subscriptions map[string]*subscription + subMu sync.Mutex +} + +type subscription struct { + topic string + deliveryCh chan extqueue.Delivery + cancelFunc context.CancelFunc + wg sync.WaitGroup +} + +// sqlDelivery implements extqueue.Delivery for SQL queue +type sqlDelivery struct { + msg queue.Message + deliveryID string + attempt int + receivedAt int64 + metadata map[string]string + + // Backend-specific fields for ack/nack + subscriber *subscriber + topic string + partitionKey string + offset int64 + messageID string + + // Track acknowledgment state + mu sync.Mutex + acknowledged bool +} + +func newSQLDelivery( + msg queue.Message, + deliveryID string, + attempt int, + metadata map[string]string, + subscriber *subscriber, + topic string, + partitionKey string, + offset int64, + messageID string, +) *sqlDelivery { + return &sqlDelivery{ + msg: msg, + deliveryID: deliveryID, + attempt: attempt, + receivedAt: time.Now().UnixMilli(), + metadata: metadata, + subscriber: subscriber, + topic: topic, + partitionKey: partitionKey, + offset: offset, + messageID: messageID, + acknowledged: false, + } +} + +// Message implements extqueue.Delivery.Message +func (d *sqlDelivery) Message() queue.Message { + return d.msg +} + +// DeliveryID implements extqueue.Delivery.DeliveryID +func (d *sqlDelivery) DeliveryID() string { + return d.deliveryID +} + +// Attempt implements extqueue.Delivery.Attempt +func (d *sqlDelivery) Attempt() int { + return d.attempt +} + +// ReceivedAt implements extqueue.Delivery.ReceivedAt +func (d *sqlDelivery) ReceivedAt() int64 { + return d.receivedAt +} + +// Metadata implements extqueue.Delivery.Metadata +func (d *sqlDelivery) Metadata() map[string]string { + return d.metadata +} + +// Ack implements extqueue.Delivery.Ack +func (d *sqlDelivery) Ack(ctx context.Context) error { + d.mu.Lock() + defer d.mu.Unlock() + + if d.acknowledged { + return &ErrAlreadyAcknowledged{DeliveryID: d.deliveryID} + } + + // Perform acknowledgment + if err := d.subscriber.offsetStore.AckMessage(ctx, d.topic, d.partitionKey, d.messageID, d.offset, d.subscriber.messageStore); err != nil { + return err + } + + // Record metrics + d.subscriber.metrics.Tagged(map[string]string{ + "topic": d.topic, + "partition_key": d.partitionKey, + }).Counter("messages_acked").Inc(1) + + d.acknowledged = true + return nil +} + +// Nack implements extqueue.Delivery.Nack +func (d *sqlDelivery) Nack(ctx context.Context, requeueAfterMillis int64) error { + d.mu.Lock() + defer d.mu.Unlock() + + if d.acknowledged { + return &ErrAlreadyAcknowledged{DeliveryID: d.deliveryID} + } + + // Set visibility timeout to make message visible after requeueAfter duration + if err := d.subscriber.messageStore.SetVisibilityTimeout(ctx, d.topic, d.messageID, requeueAfterMillis); err != nil { + d.subscriber.logger.Errorw("failed to set visibility timeout for nack", + "topic", d.topic, + "partition_key", d.partitionKey, + "message_id", d.messageID, + "error", err, + ) + return err + } + + // Record metrics + d.subscriber.metrics.Tagged(map[string]string{ + "topic": d.topic, + "partition_key": d.partitionKey, + }).Counter("messages_nacked").Inc(1) + + d.subscriber.logger.Infow("message nacked", + "topic", d.topic, + "partition_key", d.partitionKey, + "message_id", d.messageID, + "requeue_after_millis", requeueAfterMillis, + ) + + d.acknowledged = true + return nil +} + +// ExtendVisibilityTimeout implements extqueue.Delivery.ExtendVisibilityTimeout +func (d *sqlDelivery) ExtendVisibilityTimeout(ctx context.Context, durationMillis int64) error { + d.mu.Lock() + defer d.mu.Unlock() + + if d.acknowledged { + return fmt.Errorf("delivery %s already acknowledged, cannot extend visibility timeout", d.deliveryID) + } + + if err := d.subscriber.messageStore.SetVisibilityTimeout(ctx, d.topic, d.messageID, durationMillis); err != nil { + return err + } + + // Record metrics + d.subscriber.metrics.Tagged(map[string]string{ + "topic": d.topic, + "partition_key": d.partitionKey, + }).Counter("visibility_extended").Inc(1) + + return nil +} + +func NewSubscriber(config Config, logger *zap.SugaredLogger, metrics tally.Scope, messageStore messageStore, offsetStore offsetStore, leaseStore partitionLeaseStore) *subscriber { + logger.Infow("created subscriber", + "consumer_group", config.ConsumerGroup, + "worker_id", config.WorkerID, + "poll_interval", config.PollInterval, + "batch_size", config.BatchSize, + "max_retry_attempts", config.Retry.MaxAttempts, + "lease_renewal_interval", config.LeaseRenewalInterval, + ) + + return &subscriber{ + config: config, + logger: logger, + metrics: metrics, + messageStore: messageStore, + offsetStore: offsetStore, + leaseStore: leaseStore, + subscriptions: make(map[string]*subscription), + } +} + +// Subscribe starts consuming messages from the specified topic +func (s *subscriber) Subscribe(ctx context.Context, topic string) (<-chan extqueue.Delivery, error) { + s.mu.RLock() + closed := s.closed + s.mu.RUnlock() + + if closed { + s.logger.Errorw("subscribe failed: subscriber is closed", "topic", topic) + return nil, fmt.Errorf("subscriber is closed") + } + + // Validate topic name + if err := validateTopicName(topic); err != nil { + s.logger.Errorw("subscribe failed: invalid topic name", "topic", topic, "error", err) + return nil, fmt.Errorf("subscribe failure: invalid topic name. err: %w", err) + } + + s.subMu.Lock() + defer s.subMu.Unlock() + + // Check if already subscribed + if sub, exists := s.subscriptions[topic]; exists { + s.logger.Debugw("reusing existing subscription", "topic", topic) + return sub.deliveryCh, nil + } + + s.logger.Infow("creating new subscription", "topic", topic) + + // Create new subscription + // Use a cancellable context for managing the subscription lifecycle + // and close will cancel the context to signal goroutines to stop + subCtx, cancel := context.WithCancel(context.Background()) + sub := &subscription{ + topic: topic, + deliveryCh: make(chan extqueue.Delivery, s.config.BatchSize*2), + cancelFunc: cancel, + } + + s.subscriptions[topic] = sub + + // Track active subscription + s.metrics.Tagged(map[string]string{"topic": topic}).Gauge("active_subscriptions").Update(1) + + // Start partition leasing and polling goroutine + sub.wg.Add(1) + go s.managePartitions(subCtx, sub) + + s.logger.Debugw("subscription created", "topic", topic, "consumer_group", s.config.ConsumerGroup, "worker_id", s.config.WorkerID) + return sub.deliveryCh, nil +} + +// managePartitions discovers partitions, acquires leases, and polls messages +func (s *subscriber) managePartitions(ctx context.Context, sub *subscription) { + defer sub.wg.Done() + defer close(sub.deliveryCh) + + pollTicker := time.NewTicker(s.config.PollInterval) + defer pollTicker.Stop() + + leaseTicker := time.NewTicker(s.config.LeaseRenewalInterval) + defer leaseTicker.Stop() + + for { + select { + case <-ctx.Done(): + // Release all leases on shutdown + s.releaseAllLeases(ctx, sub.topic) + return + + case <-leaseTicker.C: + // Renew existing leases + s.renewLeases(ctx, sub.topic) + + case <-pollTicker.C: + // Fetch and deliver messages from leased partitions + s.pollLeasedPartitions(ctx, sub) + } + } +} + +// renewLeases renews leases for all partitions owned by this worker +func (s *subscriber) renewLeases(ctx context.Context, topic string) { + leasedPartitions, err := s.leaseStore.GetLeasedPartitions(ctx, topic) + if err != nil { + s.logger.Errorw("failed to get leased partitions for renewal", + "topic", topic, + "error", err, + ) + // Error suppressed: lease renewal is best-effort. If we can't get leases, + // they will eventually expire and be reacquired by this or another worker. + // Failing the entire renewal cycle would be worse than skipping one iteration. + s.metrics.Tagged(map[string]string{"topic": topic}).Counter("lease_renewal.get_partitions_errors").Inc(1) + return + } + + for _, partitionKey := range leasedPartitions { + if err := s.leaseStore.RenewLease(ctx, topic, partitionKey); err != nil { + s.logger.Warnw("failed to renew lease", + "topic", topic, + "partition_key", partitionKey, + "error", err, + ) + // Error suppressed: Continue trying to renew other leases even if one fails. + // The partition will eventually expire and be re-acquired by this or another worker. + // Failing fast would prevent other partitions from being renewed. + s.metrics.Tagged(map[string]string{ + "topic": topic, + "partition_key": partitionKey, + }).Counter("lease_renewal.renew_errors").Inc(1) + } + } +} + +// releaseAllLeases releases all leases for a topic +func (s *subscriber) releaseAllLeases(ctx context.Context, topic string) { + leasedPartitions, err := s.leaseStore.GetLeasedPartitions(ctx, topic) + if err != nil { + s.logger.Errorw("failed to get leased partitions for release", + "topic", topic, + "error", err, + ) + return + } + + for _, partitionKey := range leasedPartitions { + if err := s.leaseStore.ReleaseLease(ctx, topic, partitionKey); err != nil { + s.logger.Warnw("failed to release lease", + "topic", topic, + "partition_key", partitionKey, + "error", err, + ) + // Continue trying to release other leases even if one fails + } + } +} + +// pollLeasedPartitions fetches and delivers messages from all leased partitions +func (s *subscriber) pollLeasedPartitions(ctx context.Context, sub *subscription) { + // Discover and try to acquire leases for new partitions + acquiredCount, err := s.leaseStore.DiscoverAndAcquirePartitions(ctx, sub.topic) + if err == nil && acquiredCount > 0 { + s.metrics.Tagged(map[string]string{"topic": sub.topic}).Counter("leases_acquired").Inc(int64(acquiredCount)) + } + + // Get currently leased partitions + leasedPartitions, err := s.leaseStore.GetLeasedPartitions(ctx, sub.topic) + if err != nil { + s.logger.Errorw("failed to get leased partitions", "topic", sub.topic, "error", err) + return + } + + // Poll each leased partition + for _, partitionKey := range leasedPartitions { + // Check if context was cancelled before processing next partition + select { + case <-ctx.Done(): + return + default: + s.fetchAndDeliverPartition(ctx, sub, partitionKey) + } + } +} + +// fetchAndDeliverPartition fetches messages from a specific partition and delivers them +func (s *subscriber) fetchAndDeliverPartition(ctx context.Context, sub *subscription, partitionKey string) { + start := time.Now() + + // Initialize offset for this partition if needed + if err := s.offsetStore.Initialize(ctx, sub.topic, partitionKey); err != nil { + s.logger.Errorw("offset initialization failure", "topic", sub.topic, "partition_key", partitionKey, "error", err) + return + } + + // Get current offset for this partition + currentOffset, err := s.offsetStore.GetAckedOffset(ctx, sub.topic, partitionKey) + if err != nil { + s.logger.Errorw("get current offset failure", "topic", sub.topic, "partition_key", partitionKey, "error", err) + return + } + + // Fetch messages for this partition + rows, err := s.messageStore.FetchByOffset(ctx, sub.topic, partitionKey, currentOffset, s.config.BatchSize) + if err != nil { + s.logger.Errorw("fetch messages failure", "topic", sub.topic, "partition_key", partitionKey, "error", err) + return + } + + messageCount := 0 + for _, row := range rows { + // Check if message has exceeded retry limit (persistent retry_count from DB) + if row.RetryCount >= s.config.Retry.MaxAttempts { + s.logger.Warnw("message exceeded retry limit", + "topic", sub.topic, + "partition_key", partitionKey, + "message_id", row.ID, + "retry_count", row.RetryCount, + ) + + // Move to DLQ if enabled + if s.config.DLQ.Enabled { + if err := s.messageStore.MoveToDLQ(ctx, sub.topic, row.ID, row.RetryCount, "exceeded retry limit"); err != nil { + s.logger.Errorw("failed to move message to DLQ", + "topic", sub.topic, + "message_id", row.ID, + "error", err, + ) + } else { + s.logger.Infow("moved message to DLQ", + "topic", sub.topic, + "message_id", row.ID, + "retry_count", row.RetryCount, + ) + s.metrics.Tagged(map[string]string{ + "topic": sub.topic, + "partition_key": partitionKey, + }).Counter("messages_moved_to_dlq").Inc(1) + + // Update offset since message is now processed (moved to DLQ) + if err := s.offsetStore.UpdateAckedOffset(ctx, sub.topic, partitionKey, row.Offset); err != nil { + s.logger.Errorw("failed to update offset after DLQ move", + "topic", sub.topic, + "partition_key", partitionKey, + "offset", row.Offset, + "error", err, + ) + } + } + } + continue + } + + // Create message (value type) + msg := queue.NewMessage(row.ID, row.Payload) + msg.Metadata = row.Metadata + msg.PartitionKey = row.PartitionKey + msg.PublishedAt = row.PublishedAt + + // Calculate message age for metrics + messageAge := time.Duration(time.Now().UnixMilli()-row.PublishedAt) * time.Millisecond + s.metrics.Tagged(map[string]string{ + "topic": sub.topic, + "partition_key": partitionKey, + }).Timer("message_age").Record(messageAge) + + // Create delivery ID from offset + deliveryID := strconv.FormatInt(row.Offset, 10) + + // Create delivery metadata + deliveryMetadata := map[string]string{ + "topic": sub.topic, + "partition_key": partitionKey, + "offset": deliveryID, + } + + // Create SQL delivery implementation + delivery := newSQLDelivery( + msg, + deliveryID, + row.RetryCount+1, // RetryCount is 0-based, Attempt is 1-based + deliveryMetadata, + s, + sub.topic, + partitionKey, + row.Offset, + row.ID, + ) + + // Deliver message + select { + case sub.deliveryCh <- delivery: + messageCount++ + case <-ctx.Done(): + return + } + } + + // Record metrics + if messageCount > 0 { + elapsed := time.Since(start) + partitionTags := map[string]string{ + "topic": sub.topic, + "partition_key": partitionKey, + } + s.metrics.Tagged(partitionTags).Counter("messages_received").Inc(int64(messageCount)) + s.metrics.Tagged(partitionTags).Timer("poll_latency").Record(elapsed) + + s.logger.Debugw("delivered messages", + "topic", sub.topic, + "partition_key", partitionKey, + "count", messageCount, + "duration_ms", elapsed.Milliseconds(), + ) + } +} + +// Close gracefully shuts down the subscriber +func (s *subscriber) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return nil + } + + s.logger.Info("closing subscriber") + + s.subMu.Lock() + defer s.subMu.Unlock() + + // Cancel all subscriptions + for topic, sub := range s.subscriptions { + s.logger.Debugw("closing subscription", "topic", topic) + sub.cancelFunc() + + // Wait for goroutine to finish with timeout + done := make(chan struct{}) + go func() { + sub.wg.Wait() + close(done) + }() + + select { + case <-done: + // Graceful shutdown completed + case <-time.After(30 * time.Second): + s.logger.Warnw("subscription shutdown timeout", "topic", topic) + } + + // Update metrics + s.metrics.Tagged(map[string]string{"topic": topic}).Gauge("active_subscriptions").Update(0) + } + + s.subscriptions = make(map[string]*subscription) + + s.closed = true + + s.logger.Info("subscriber closed") + return nil +} diff --git a/extensions/queue/sql/subscriber_test.go b/extensions/queue/sql/subscriber_test.go new file mode 100644 index 00000000..cf97491d --- /dev/null +++ b/extensions/queue/sql/subscriber_test.go @@ -0,0 +1,146 @@ +package sql + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/uber-go/tally/v4" + "go.uber.org/mock/gomock" + "go.uber.org/zap/zaptest" + + extqueue "github.com/uber/submitqueue/extensions/queue" +) + +func setupSubscriberTest(t *testing.T, mockMessageStore *MockmessageStore, mockOffsetStore *MockoffsetStore, mockLeaseStore *MockpartitionLeaseStore) extqueue.Subscriber { + t.Helper() + + config := DefaultConfig("test-consumer", "test-worker") + + return NewSubscriber(config, zaptest.NewLogger(t).Sugar().Named("subscriber"), tally.NoopScope.SubScope("subscriber"), mockMessageStore, mockOffsetStore, mockLeaseStore) +} + +func TestSubscriber_Subscribe(t *testing.T) { + tests := []struct { + name string + topics []string + expectSame bool + expectedChans int + }{ + { + name: "single topic subscription", + topics: []string{"test_topic"}, + expectedChans: 1, + }, + { + name: "multiple different topics", + topics: []string{"topic1", "topic2"}, + expectedChans: 2, + }, + { + name: "same topic returns same channel", + topics: []string{"test_topic", "test_topic"}, + expectSame: true, + expectedChans: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockMessageStore := NewMockmessageStore(ctrl) + mockOffsetStore := NewMockoffsetStore(ctrl) + mockLeaseStore := NewMockpartitionLeaseStore(ctrl) + + sub := setupSubscriberTest(t, mockMessageStore, mockOffsetStore, mockLeaseStore) + ctx := context.Background() + + var channels []<-chan extqueue.Delivery + for _, topic := range tt.topics { + ch, err := sub.Subscribe(ctx, topic) + require.NoError(t, err) + assert.NotNil(t, ch) + channels = append(channels, ch) + } + + if tt.expectSame && len(channels) == 2 { + assert.Equal(t, channels[0], channels[1], "should return same channel for same topic") + } + }) + } +} + +func TestSubscriber_Close(t *testing.T) { + tests := []struct { + name string + setupSub func(ctx context.Context, sub extqueue.Subscriber) error + closeCount int + subscribeAfter bool + expectSubError bool + }{ + { + name: "close with active subscription", + setupSub: func(ctx context.Context, sub extqueue.Subscriber) error { + _, err := sub.Subscribe(ctx, "test_topic") + return err + }, + closeCount: 1, + }, + { + name: "close is idempotent", + setupSub: func(ctx context.Context, sub extqueue.Subscriber) error { return nil }, + closeCount: 2, + }, + { + name: "subscribe after close fails", + setupSub: func(ctx context.Context, sub extqueue.Subscriber) error { return nil }, + closeCount: 1, + subscribeAfter: true, + expectSubError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockMessageStore := NewMockmessageStore(ctrl) + mockOffsetStore := NewMockoffsetStore(ctrl) + mockLeaseStore := NewMockpartitionLeaseStore(ctrl) + + // Expect lease operations during cleanup + mockLeaseStore.EXPECT().GetLeasedPartitions(gomock.Any(), gomock.Any()).Return([]string{}, nil).AnyTimes() + + sub := setupSubscriberTest(t, mockMessageStore, mockOffsetStore, mockLeaseStore) + ctx := context.Background() + + // Setup subscription if needed + if tt.setupSub != nil { + err := tt.setupSub(ctx, sub) + require.NoError(t, err) + } + + // Close multiple times if needed + for i := 0; i < tt.closeCount; i++ { + err := sub.Close() + require.NoError(t, err) + } + + // Try to subscribe after close if needed + if tt.subscribeAfter { + ch, err := sub.Subscribe(ctx, "test_topic") + if tt.expectSubError { + require.Error(t, err) + assert.Nil(t, ch) + } else { + require.NoError(t, err) + assert.NotNil(t, ch) + } + } + }) + } +}