diff --git a/.golangci.yml b/.golangci.yml index 40bfdcaa551..d9c770d0712 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -105,7 +105,7 @@ issues: text: "(cyclomatic|cognitive)" linters: - revive - - path: _test\.go + - path: _test\.go|^common/persistence\/tests\/.+\.go # Ignore things like err = errors.New("test error") in tests linters: - goerr113 - path: ^tools\/.+\.go diff --git a/common/persistence/data_interfaces.go b/common/persistence/data_interfaces.go index d878bb2372a..92ff551355b 100644 --- a/common/persistence/data_interfaces.go +++ b/common/persistence/data_interfaces.go @@ -1218,6 +1218,7 @@ type ( ReadTasks(ctx context.Context, request *ReadTasksRequest) (*ReadTasksResponse, error) // CreateQueue must return an ErrQueueAlreadyExists if the queue already exists. CreateQueue(ctx context.Context, request *CreateQueueRequest) (*CreateQueueResponse, error) + DeleteTasks(ctx context.Context, request *DeleteTasksRequest) (*DeleteTasksResponse, error) } HistoryTaskQueueManagerImpl struct { @@ -1282,6 +1283,15 @@ type ( CreateQueueResponse struct { } + + DeleteTasksRequest struct { + QueueKey QueueKey + InclusiveMaxMessageMetadata MessageMetadata + } + + DeleteTasksResponse struct { + // empty + } ) func (e *InvalidPersistenceRequestError) Error() string { diff --git a/common/persistence/data_interfaces_mock.go b/common/persistence/data_interfaces_mock.go index a0b80e22380..f15bcacdd87 100644 --- a/common/persistence/data_interfaces_mock.go +++ b/common/persistence/data_interfaces_mock.go @@ -1328,6 +1328,21 @@ func (mr *MockHistoryTaskQueueManagerMockRecorder) CreateQueue(ctx, request inte return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateQueue", reflect.TypeOf((*MockHistoryTaskQueueManager)(nil).CreateQueue), ctx, request) } +// DeleteTasks mocks base method. +func (m *MockHistoryTaskQueueManager) DeleteTasks(ctx context.Context, request *DeleteTasksRequest) (*DeleteTasksResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteTasks", ctx, request) + ret0, _ := ret[0].(*DeleteTasksResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteTasks indicates an expected call of DeleteTasks. +func (mr *MockHistoryTaskQueueManagerMockRecorder) DeleteTasks(ctx, request interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTasks", reflect.TypeOf((*MockHistoryTaskQueueManager)(nil).DeleteTasks), ctx, request) +} + // EnqueueTask mocks base method. func (m *MockHistoryTaskQueueManager) EnqueueTask(ctx context.Context, request *EnqueueTaskRequest) (*EnqueueTaskResponse, error) { m.ctrl.T.Helper() diff --git a/common/persistence/history_task_queue_manager.go b/common/persistence/history_task_queue_manager.go index 46b21e01785..9aa05f9ac38 100644 --- a/common/persistence/history_task_queue_manager.go +++ b/common/persistence/history_task_queue_manager.go @@ -188,7 +188,10 @@ func (m *HistoryTaskQueueManagerImpl) ReadTasks(ctx context.Context, request *Re }, nil } -func (m *HistoryTaskQueueManagerImpl) CreateQueue(ctx context.Context, request *CreateQueueRequest) (*CreateQueueResponse, error) { +func (m *HistoryTaskQueueManagerImpl) CreateQueue( + ctx context.Context, + request *CreateQueueRequest, +) (*CreateQueueResponse, error) { _, err := m.queue.CreateQueue(ctx, &InternalCreateQueueRequest{ QueueType: request.QueueKey.QueueType, QueueName: request.QueueKey.GetQueueName(), @@ -199,6 +202,20 @@ func (m *HistoryTaskQueueManagerImpl) CreateQueue(ctx context.Context, request * return &CreateQueueResponse{}, nil } +func (m *HistoryTaskQueueManagerImpl) DeleteTasks( + ctx context.Context, + request *DeleteTasksRequest, +) (*DeleteTasksResponse, error) { + _, err := m.queue.RangeDeleteMessages(ctx, &InternalRangeDeleteMessagesRequest{ + QueueType: request.QueueKey.QueueType, + QueueName: request.QueueKey.GetQueueName(), + }) + if err != nil { + return nil, err + } + return &DeleteTasksResponse{}, nil +} + // combineUnique combines the given strings into a single string by hashing the length of each string and the string // itself. This is used to generate a unique suffix for the queue name. func combineUnique(strs ...string) string { diff --git a/common/persistence/history_task_queue_manager_test.go b/common/persistence/history_task_queue_manager_test.go index 876c36810f2..f1730ff26a2 100644 --- a/common/persistence/history_task_queue_manager_test.go +++ b/common/persistence/history_task_queue_manager_test.go @@ -167,65 +167,3 @@ func TestHistoryTaskQueueManager_ReadTasks_NonPositivePageSize(t *testing.T) { "ErrReadTasksNonPositivePageSize when the request's page size is: "+strconv.Itoa(pageSize)) } } - -// failingQueue is a QueueV2 implementation that always returns an error. -type failingQueue struct{} - -func (q failingQueue) EnqueueMessage( - context.Context, - *persistence.InternalEnqueueMessageRequest, -) (*persistence.InternalEnqueueMessageResponse, error) { - return nil, assert.AnError -} - -func (q failingQueue) ReadMessages( - context.Context, - *persistence.InternalReadMessagesRequest, -) (*persistence.InternalReadMessagesResponse, error) { - return nil, assert.AnError -} - -func (q failingQueue) CreateQueue( - context.Context, - *persistence.InternalCreateQueueRequest, -) (*persistence.InternalCreateQueueResponse, error) { - return nil, assert.AnError -} - -func (q failingQueue) RangeDeleteMessages( - context.Context, - *persistence.InternalRangeDeleteMessagesRequest, -) (*persistence.InternalRangeDeleteMessagesResponse, error) { - return nil, assert.AnError -} - -func TestHistoryTaskQueueManager_ReadTasks_ErrReadQueueMessages(t *testing.T) { - t.Parallel() - - m := persistence.NewHistoryTaskQueueManager(failingQueue{}, 1) - _, err := m.ReadTasks(context.Background(), &persistence.ReadTasksRequest{ - QueueKey: persistence.QueueKey{ - Category: tasks.CategoryTransfer, - }, - PageSize: 1, - }) - assert.ErrorIs(t, err, assert.AnError, "ReadTasks should propagate errors from ReadMessages") -} - -func TestHistoryTaskQueueManager_ReadTasks_ErrEnqueueMessage(t *testing.T) { - t.Parallel() - - m := persistence.NewHistoryTaskQueueManager(failingQueue{}, 1) - _, err := m.EnqueueTask(context.Background(), &persistence.EnqueueTaskRequest{ - Task: &tasks.WorkflowTask{}, - }) - assert.ErrorIs(t, err, assert.AnError, "EnqueueTask should propagate errors from EnqueueMessage") -} - -func TestHistoryTaskQueueManager_CreateQueue(t *testing.T) { - t.Parallel() - - m := persistence.NewHistoryTaskQueueManager(failingQueue{}, 1) - _, err := m.CreateQueue(context.Background(), &persistence.CreateQueueRequest{}) - assert.ErrorIs(t, err, assert.AnError, "CreateQueue should propagate errors from the persistence layer") -} diff --git a/common/persistence/tests/history_task_queue_manager_test_suite.go b/common/persistence/tests/history_task_queue_manager_test_suite.go index 35d94f4e2dc..c658a87854e 100644 --- a/common/persistence/tests/history_task_queue_manager_test_suite.go +++ b/common/persistence/tests/history_task_queue_manager_test_suite.go @@ -26,6 +26,7 @@ package tests import ( "context" + "errors" "testing" "time" @@ -44,18 +45,84 @@ import ( "go.temporal.io/server/service/history/tasks" ) +type ( + faultyQueue struct { + base persistence.QueueV2 + enqueueErr error + readMessagesErr error + createQueueErr error + rangeDeleteMessagesErr error + } +) + +func (q faultyQueue) EnqueueMessage( + ctx context.Context, + req *persistence.InternalEnqueueMessageRequest, +) (*persistence.InternalEnqueueMessageResponse, error) { + if q.enqueueErr != nil { + return nil, q.enqueueErr + } + return q.base.EnqueueMessage(ctx, req) +} + +func (q faultyQueue) ReadMessages( + ctx context.Context, + req *persistence.InternalReadMessagesRequest, +) (*persistence.InternalReadMessagesResponse, error) { + if q.readMessagesErr != nil { + return nil, q.readMessagesErr + } + return q.base.ReadMessages(ctx, req) +} + +func (q faultyQueue) CreateQueue( + ctx context.Context, + req *persistence.InternalCreateQueueRequest, +) (*persistence.InternalCreateQueueResponse, error) { + if q.createQueueErr != nil { + return nil, q.createQueueErr + } + return q.base.CreateQueue(ctx, req) +} + +func (q faultyQueue) RangeDeleteMessages( + ctx context.Context, + req *persistence.InternalRangeDeleteMessagesRequest, +) (*persistence.InternalRangeDeleteMessagesResponse, error) { + if q.rangeDeleteMessagesErr != nil { + return nil, q.rangeDeleteMessagesErr + } + return q.base.RangeDeleteMessages(ctx, req) +} + // RunHistoryTaskQueueManagerTestSuite runs all tests for the history task queue manager against a given queue provided by a // particular database. This test suite should be re-used to test all queue implementations. func RunHistoryTaskQueueManagerTestSuite(t *testing.T, queue persistence.QueueV2) { historyTaskQueueManager := persistence.NewHistoryTaskQueueManager(queue, 1) - t.Run("TestHistoryTaskQueueManagerHappyPath", func(t *testing.T) { + t.Run("TestHistoryTaskQueueManagerEnqueueTasks", func(t *testing.T) { + t.Parallel() + testHistoryTaskQueueManagerEnqueueTasks(t, historyTaskQueueManager) + }) + t.Run("TestHistoryTaskQueueManagerEnqueueTasksErr", func(t *testing.T) { + t.Parallel() + testHistoryTaskQueueManagerEnqueueTasksErr(t, queue) + }) + t.Run("TestHistoryTaskQueueManagerCreateQueueErr", func(t *testing.T) { t.Parallel() - testHistoryTaskQueueManagerHappyPath(t, historyTaskQueueManager) + testHistoryTaskQueueManagerCreateQueueErr(t, queue) }) t.Run("TestHistoryTaskQueueManagerErrDeserializeTask", func(t *testing.T) { t.Parallel() testHistoryTaskQueueManagerErrDeserializeHistoryTask(t, queue, historyTaskQueueManager) }) + t.Run("DeleteTasks", func(t *testing.T) { + t.Parallel() + testHistoryTaskQueueManagerDeleteTasks(t, historyTaskQueueManager) + }) + t.Run("DeleteTasksErr", func(t *testing.T) { + t.Parallel() + testHistoryTaskQueueManagerDeleteTasksErr(t, queue) + }) t.Run("GetDLQTasks", func(t *testing.T) { t.Parallel() getdlqtaskstest.TestGetDLQTasks(t, historyTaskQueueManager) @@ -70,7 +137,19 @@ func RunHistoryTaskQueueManagerTestSuite(t *testing.T, queue persistence.QueueV2 }) } -func testHistoryTaskQueueManagerHappyPath(t *testing.T, manager persistence.HistoryTaskQueueManager) { +func testHistoryTaskQueueManagerCreateQueueErr(t *testing.T, queue persistence.QueueV2) { + retErr := errors.New("test") + manager := persistence.NewHistoryTaskQueueManager(faultyQueue{ + base: queue, + createQueueErr: retErr, + }, 1) + _, err := manager.CreateQueue(context.Background(), &persistence.CreateQueueRequest{ + QueueKey: getQueueKey(t), + }) + assert.ErrorIs(t, err, retErr) +} + +func testHistoryTaskQueueManagerEnqueueTasks(t *testing.T, manager persistence.HistoryTaskQueueManager) { numHistoryShards := 5 ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) t.Cleanup(cancel) @@ -81,12 +160,7 @@ func testHistoryTaskQueueManagerHappyPath(t *testing.T, manager persistence.Hist shardID := 2 assert.Equal(t, int32(shardID), common.WorkflowIDToHistoryShard(namespaceID, workflowID, int32(numHistoryShards))) - category := tasks.CategoryTransfer - queueKey := persistence.QueueKey{ - QueueType: persistence.QueueTypeHistoryNormal, - Category: category, - SourceCluster: "test-source-cluster-" + t.Name(), - } + queueKey := getQueueKey(t) _, err := manager.CreateQueue(ctx, &persistence.CreateQueueRequest{ QueueKey: queueKey, }) @@ -97,11 +171,7 @@ func testHistoryTaskQueueManagerHappyPath(t *testing.T, manager persistence.Hist WorkflowKey: workflowKey, TaskID: int64(i + 1), } - res, err := manager.EnqueueTask(ctx, &persistence.EnqueueTaskRequest{ - QueueType: queueKey.QueueType, - SourceCluster: queueKey.SourceCluster, - Task: task, - }) + res, err := enqueueTask(ctx, manager, queueKey, task) require.NoError(t, err) assert.Equal(t, int64(persistence.FirstQueueMessageID+i), res.Metadata.ID) } @@ -127,6 +197,25 @@ func testHistoryTaskQueueManagerHappyPath(t *testing.T, manager persistence.Hist } } +func testHistoryTaskQueueManagerEnqueueTasksErr(t *testing.T, queue persistence.QueueV2) { + ctx := context.Background() + + retErr := errors.New("test") + manager := persistence.NewHistoryTaskQueueManager(faultyQueue{ + base: queue, + enqueueErr: retErr, + }, 1) + queueKey := getQueueKey(t) + _, err := manager.CreateQueue(ctx, &persistence.CreateQueueRequest{ + QueueKey: queueKey, + }) + require.NoError(t, err) + _, err = enqueueTask(ctx, manager, queueKey, &tasks.WorkflowTask{ + TaskID: 1, + }) + assert.ErrorIs(t, err, retErr) +} + func testHistoryTaskQueueManagerErrDeserializeHistoryTask( t *testing.T, queue persistence.QueueV2, @@ -148,6 +237,36 @@ func testHistoryTaskQueueManagerErrDeserializeHistoryTask( }) } +func testHistoryTaskQueueManagerDeleteTasks(t *testing.T, manager *persistence.HistoryTaskQueueManagerImpl) { + ctx := context.Background() + + queueKey := getQueueKey(t) + _, err := manager.CreateQueue(ctx, &persistence.CreateQueueRequest{ + QueueKey: queueKey, + }) + require.NoError(t, err) + for i := 0; i < 2; i++ { + _, err := enqueueTask(ctx, manager, queueKey, &tasks.WorkflowTask{ + TaskID: int64(i + 1), + }) + require.NoError(t, err) + } + _, err = manager.DeleteTasks(ctx, &persistence.DeleteTasksRequest{ + QueueKey: queueKey, + InclusiveMaxMessageMetadata: persistence.MessageMetadata{ + ID: persistence.FirstQueueMessageID, + }, + }) + require.NoError(t, err) + res, err := manager.ReadTasks(ctx, &persistence.ReadTasksRequest{ + QueueKey: queueKey, + PageSize: 10, + }) + require.NoError(t, err) + require.Len(t, res.Tasks, 1) + assert.Equal(t, int64(2), res.Tasks[0].Task.GetTaskID()) +} + func enqueueAndDeserializeBlob( ctx context.Context, t *testing.T, @@ -158,11 +277,7 @@ func enqueueAndDeserializeBlob( t.Helper() queueType := persistence.QueueTypeHistoryNormal - queueKey := persistence.QueueKey{ - QueueType: queueType, - Category: tasks.CategoryTransfer, - SourceCluster: "test-source-cluster-" + t.Name(), - } + queueKey := getQueueKey(t) _, err := queue.CreateQueue(ctx, &persistence.InternalCreateQueueRequest{ QueueType: queueType, QueueName: queueKey.GetQueueName(), @@ -190,3 +305,52 @@ func enqueueAndDeserializeBlob( }) return err } + +func testHistoryTaskQueueManagerDeleteTasksErr(t *testing.T, queue persistence.QueueV2) { + ctx := context.Background() + + retErr := errors.New("test") + manager := persistence.NewHistoryTaskQueueManager(faultyQueue{ + base: queue, + rangeDeleteMessagesErr: retErr, + }, 1) + queueKey := getQueueKey(t) + _, err := manager.CreateQueue(ctx, &persistence.CreateQueueRequest{ + QueueKey: queueKey, + }) + require.NoError(t, err) + _, err = enqueueTask(ctx, manager, queueKey, &tasks.WorkflowTask{ + TaskID: 1, + }) + require.NoError(t, err) + _, err = manager.DeleteTasks(ctx, &persistence.DeleteTasksRequest{ + QueueKey: queueKey, + InclusiveMaxMessageMetadata: persistence.MessageMetadata{ + ID: persistence.FirstQueueMessageID, + }, + }) + assert.ErrorIs(t, err, retErr) +} + +func getQueueKey(t *testing.T) persistence.QueueKey { + return persistence.QueueKey{ + QueueType: persistence.QueueTypeHistoryNormal, + Category: tasks.CategoryTransfer, + SourceCluster: "test-source-cluster-" + t.Name(), + TargetCluster: "test-target-cluster-" + t.Name(), + } +} + +func enqueueTask( + ctx context.Context, + manager persistence.HistoryTaskQueueManager, + queueKey persistence.QueueKey, + task *tasks.WorkflowTask, +) (*persistence.EnqueueTaskResponse, error) { + return manager.EnqueueTask(ctx, &persistence.EnqueueTaskRequest{ + QueueType: queueKey.QueueType, + SourceCluster: queueKey.SourceCluster, + TargetCluster: queueKey.TargetCluster, + Task: task, + }) +} diff --git a/common/persistence/tests/queue_v2_test_suite.go b/common/persistence/tests/queue_v2_test_suite.go index db5da632e4b..63c9e4d8ea3 100644 --- a/common/persistence/tests/queue_v2_test_suite.go +++ b/common/persistence/tests/queue_v2_test_suite.go @@ -28,7 +28,6 @@ import ( "context" "strconv" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -43,8 +42,6 @@ import ( // implementation-specific tests that will not be covered by this suite elsewhere. func RunQueueV2TestSuite(t *testing.T, queue persistence.QueueV2) { ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Minute) - t.Cleanup(cancel) queueType := persistence.QueueTypeHistoryNormal queueName := "test-queue-" + t.Name()