From f4c08526e6bea18addca2d0b0bde9c9509bcb650 Mon Sep 17 00:00:00 2001 From: Yichao Yang Date: Mon, 21 Mar 2022 15:36:28 -0700 Subject: [PATCH] Persistence Context Part 1: Execution Manager (#2622) --- common/archiver/historyIterator.go | 3 +- common/archiver/historyIterator_test.go | 10 +- common/persistence/dataInterfaces.go | 57 +- common/persistence/dataInterfaces_mock.go | 225 +++---- common/persistence/execution_manager.go | 29 +- common/persistence/history_manager.go | 13 +- common/persistence/history_manager_util.go | 25 +- .../persistence-tests/executionManagerTest.go | 563 +++++++++--------- .../executionManagerTestForEventsV2.go | 44 +- .../historyV2PersistenceTest.go | 54 +- .../persistence-tests/persistenceTestBase.go | 201 ++++--- .../persistence/persistenceMetricClients.go | 180 ++++-- .../persistenceRateLimitedClients.go | 182 ++++-- .../tests/execution_mutable_state.go | 111 ++-- .../tests/execution_mutable_state_task.go | 15 +- common/persistence/tests/history_store.go | 27 +- host/archival_test.go | 4 +- host/ndc/replication_integration_test.go | 3 +- service/frontend/adminHandler.go | 4 +- service/frontend/adminHandler_test.go | 2 +- service/frontend/workflowHandler.go | 17 +- service/frontend/workflowHandler_test.go | 9 +- service/history/events/cache.go | 10 +- service/history/events/cache_test.go | 19 +- service/history/events/events_cache_mock.go | 9 +- service/history/handler.go | 4 +- service/history/historyEngine.go | 63 +- service/history/historyEngine2_test.go | 111 ++-- .../history/historyEngine3_eventsv2_test.go | 16 +- service/history/historyEngine_test.go | 342 +++++------ service/history/nDCActivityReplicator.go | 3 +- service/history/nDCActivityReplicator_test.go | 12 +- service/history/nDCBranchMgr.go | 3 +- service/history/nDCBranchMgr_test.go | 10 +- service/history/nDCHistoryReplicator.go | 2 +- service/history/nDCStateRebuilder.go | 6 +- service/history/nDCStateRebuilder_test.go | 12 +- service/history/nDCTaskUtil.go | 22 +- service/history/nDCTransactionMgr.go | 10 +- .../nDCTransactionMgrForExistingWorkflow.go | 13 +- ...CTransactionMgrForExistingWorkflow_test.go | 9 + .../nDCTransactionMgrForNewWorkflow.go | 17 +- .../nDCTransactionMgrForNewWorkflow_test.go | 9 + service/history/nDCTransactionMgr_test.go | 40 +- service/history/nDCWorkflowResetter.go | 2 +- service/history/nDCWorkflowResetter_test.go | 2 +- service/history/replicationDLQHandler.go | 6 +- service/history/replicationDLQHandler_mock.go | 8 +- service/history/replicationDLQHandler_test.go | 9 +- service/history/replicationTaskProcessor.go | 3 +- .../history/replicationTaskProcessor_test.go | 5 +- service/history/replicatorQueueProcessor.go | 9 +- .../history/replicatorQueueProcessor_test.go | 2 +- service/history/shard/context.go | 15 +- service/history/shard/context_impl.go | 30 +- service/history/shard/context_mock.go | 57 +- service/history/shard/context_test.go | 5 +- service/history/timerQueueAckMgr.go | 3 +- service/history/timerQueueAckMgr_test.go | 26 +- .../history/timerQueueActiveTaskExecutor.go | 29 +- .../timerQueueActiveTaskExecutor_test.go | 56 +- service/history/timerQueueProcessor.go | 2 +- .../history/timerQueueStandbyTaskExecutor.go | 4 +- .../timerQueueStandbyTaskExecutor_test.go | 30 +- service/history/timerQueueTaskExecutorBase.go | 5 +- .../timerQueueTaskExecutorBase_test.go | 6 +- .../transferQueueActiveTaskExecutor.go | 58 +- .../transferQueueActiveTaskExecutor_test.go | 58 +- service/history/transferQueueProcessor.go | 2 +- service/history/transferQueueProcessorBase.go | 4 +- .../transferQueueStandbyTaskExecutor.go | 4 +- .../transferQueueStandbyTaskExecutor_test.go | 28 +- .../history/transferQueueTaskExecutorBase.go | 3 +- service/history/visibilityQueueProcessor.go | 4 +- .../history/visibilityQueueTaskExecutor.go | 8 +- .../visibilityQueueTaskExecutor_test.go | 6 +- service/history/workflow/cache.go | 8 +- service/history/workflow/context.go | 50 +- service/history/workflow/context_mock.go | 88 +-- service/history/workflow/delete_manager.go | 22 +- .../history/workflow/delete_manager_mock.go | 25 +- .../history/workflow/delete_manager_test.go | 18 +- service/history/workflow/mutable_state.go | 11 +- .../history/workflow/mutable_state_impl.go | 22 +- .../workflow/mutable_state_impl_test.go | 9 +- .../history/workflow/mutable_state_mock.go | 41 +- service/history/workflow/retry.go | 2 + service/history/workflow/task_refresher.go | 24 +- .../history/workflow/task_refresher_mock.go | 9 +- service/history/workflow/transaction.go | 6 + service/history/workflow/transaction_impl.go | 40 +- service/history/workflow/transaction_mock.go | 33 +- service/history/workflow/transaction_test.go | 10 +- service/history/workflowExecutionUtil.go | 10 +- service/history/workflowRebuilder.go | 9 +- service/history/workflowResetter.go | 17 +- service/history/workflowResetter_test.go | 29 +- service/history/workflowTaskHandler.go | 59 +- .../history/workflowTaskHandlerCallbacks.go | 15 +- service/worker/archiver/activities.go | 2 +- service/worker/archiver/activities_test.go | 2 +- .../executions/history_event_id_validator.go | 7 +- .../worker/scanner/executions/interfaces.go | 4 +- .../executions/mutable_state_id_validator.go | 2 + service/worker/scanner/executions/task.go | 7 +- service/worker/scanner/history/scavenger.go | 10 +- .../worker/scanner/history/scavenger_test.go | 32 +- .../scanner/taskqueue/scavenger_test.go | 2 +- tools/cli/adminCommands.go | 8 +- tools/cli/adminDBScanCommand.go | 15 +- 110 files changed, 2109 insertions(+), 1548 deletions(-) diff --git a/common/archiver/historyIterator.go b/common/archiver/historyIterator.go index 6910c4d8e6f..eee30d6231b 100644 --- a/common/archiver/historyIterator.go +++ b/common/archiver/historyIterator.go @@ -28,6 +28,7 @@ package archiver import ( "bytes" + "context" "encoding/json" "errors" @@ -220,7 +221,7 @@ func (i *historyIterator) readHistory(firstEventID int64) ([]*historypb.History, PageSize: i.historyPageSize, ShardID: i.request.ShardID, } - historyBatches, _, _, err := persistence.ReadFullPageEventsByBatch(i.executionManager, req) + historyBatches, _, _, err := persistence.ReadFullPageEventsByBatch(context.TODO(), i.executionManager, req) return historyBatches, err } diff --git a/common/archiver/historyIterator_test.go b/common/archiver/historyIterator_test.go index 5d6937d6568..3fad9687062 100644 --- a/common/archiver/historyIterator_test.go +++ b/common/archiver/historyIterator_test.go @@ -106,7 +106,7 @@ func (s *HistoryIteratorSuite) TearDownTest() { } func (s *HistoryIteratorSuite) TestReadHistory_Failed_EventsV2() { - s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(gomock.Any()).Return(nil, errors.New("got error reading history branch")) + s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(gomock.Any(), gomock.Any()).Return(nil, errors.New("got error reading history branch")) itr := s.constructTestHistoryIterator(s.mockExecutionMgr, testDefaultTargetHistoryBlobSize, nil) history, err := itr.readHistory(common.FirstEventID) s.Error(err) @@ -118,7 +118,7 @@ func (s *HistoryIteratorSuite) TestReadHistory_Success_EventsV2() { History: []*historypb.History{}, NextPageToken: []byte{}, } - s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(gomock.Any()).Return(&resp, nil) + s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(gomock.Any(), gomock.Any()).Return(&resp, nil) itr := s.constructTestHistoryIterator(s.mockExecutionMgr, testDefaultTargetHistoryBlobSize, nil) history, err := itr.readHistory(common.FirstEventID) s.NoError(err) @@ -628,14 +628,14 @@ func (s *HistoryIteratorSuite) initMockExecutionManager(batchInfo []int, returnE ShardID: testShardId, } if returnErrorOnPage == i { - s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(req).Return(nil, errors.New("got error getting workflow execution history")) + s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(gomock.Any(), req).Return(nil, errors.New("got error getting workflow execution history")) return } resp := &persistence.ReadHistoryBranchByBatchResponse{ History: s.constructHistoryBatches(batchInfo, p, firstEventIDs[p.firstbatchIdx]), } - s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(req).Return(resp, nil).MaxTimes(2) + s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(gomock.Any(), req).Return(resp, nil).MaxTimes(2) } if addNotExistCall { @@ -646,7 +646,7 @@ func (s *HistoryIteratorSuite) initMockExecutionManager(batchInfo []int, returnE PageSize: testDefaultPersistencePageSize, ShardID: testShardId, } - s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(req).Return(nil, serviceerror.NewNotFound("Reach the end")) + s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(gomock.Any(), req).Return(nil, serviceerror.NewNotFound("Reach the end")) } } diff --git a/common/persistence/dataInterfaces.go b/common/persistence/dataInterfaces.go index f618e1b9830..4ccaec3f909 100644 --- a/common/persistence/dataInterfaces.go +++ b/common/persistence/dataInterfaces.go @@ -985,62 +985,61 @@ type ( Closeable GetName() string - CreateWorkflowExecution(request *CreateWorkflowExecutionRequest) (*CreateWorkflowExecutionResponse, error) - UpdateWorkflowExecution(request *UpdateWorkflowExecutionRequest) (*UpdateWorkflowExecutionResponse, error) - ConflictResolveWorkflowExecution(request *ConflictResolveWorkflowExecutionRequest) (*ConflictResolveWorkflowExecutionResponse, error) - DeleteWorkflowExecution(request *DeleteWorkflowExecutionRequest) error - DeleteCurrentWorkflowExecution(request *DeleteCurrentWorkflowExecutionRequest) error - GetCurrentExecution(request *GetCurrentExecutionRequest) (*GetCurrentExecutionResponse, error) - GetWorkflowExecution(request *GetWorkflowExecutionRequest) (*GetWorkflowExecutionResponse, error) - SetWorkflowExecution(request *SetWorkflowExecutionRequest) (*SetWorkflowExecutionResponse, error) + CreateWorkflowExecution(ctx context.Context, request *CreateWorkflowExecutionRequest) (*CreateWorkflowExecutionResponse, error) + UpdateWorkflowExecution(ctx context.Context, request *UpdateWorkflowExecutionRequest) (*UpdateWorkflowExecutionResponse, error) + ConflictResolveWorkflowExecution(ctx context.Context, request *ConflictResolveWorkflowExecutionRequest) (*ConflictResolveWorkflowExecutionResponse, error) + DeleteWorkflowExecution(ctx context.Context, request *DeleteWorkflowExecutionRequest) error + DeleteCurrentWorkflowExecution(ctx context.Context, request *DeleteCurrentWorkflowExecutionRequest) error + GetCurrentExecution(ctx context.Context, request *GetCurrentExecutionRequest) (*GetCurrentExecutionResponse, error) + GetWorkflowExecution(ctx context.Context, request *GetWorkflowExecutionRequest) (*GetWorkflowExecutionResponse, error) + SetWorkflowExecution(ctx context.Context, request *SetWorkflowExecutionRequest) (*SetWorkflowExecutionResponse, error) // Scan operations - ListConcreteExecutions(request *ListConcreteExecutionsRequest) (*ListConcreteExecutionsResponse, error) + ListConcreteExecutions(ctx context.Context, request *ListConcreteExecutionsRequest) (*ListConcreteExecutionsResponse, error) // Tasks related APIs - AddHistoryTasks(request *AddHistoryTasksRequest) error - GetHistoryTask(request *GetHistoryTaskRequest) (*GetHistoryTaskResponse, error) - GetHistoryTasks(request *GetHistoryTasksRequest) (*GetHistoryTasksResponse, error) - CompleteHistoryTask(request *CompleteHistoryTaskRequest) error - RangeCompleteHistoryTasks(request *RangeCompleteHistoryTasksRequest) error + AddHistoryTasks(ctx context.Context, request *AddHistoryTasksRequest) error + GetHistoryTask(ctx context.Context, request *GetHistoryTaskRequest) (*GetHistoryTaskResponse, error) + GetHistoryTasks(ctx context.Context, request *GetHistoryTasksRequest) (*GetHistoryTasksResponse, error) + CompleteHistoryTask(ctx context.Context, request *CompleteHistoryTaskRequest) error + RangeCompleteHistoryTasks(ctx context.Context, request *RangeCompleteHistoryTasksRequest) error - PutReplicationTaskToDLQ(request *PutReplicationTaskToDLQRequest) error - GetReplicationTasksFromDLQ(request *GetReplicationTasksFromDLQRequest) (*GetHistoryTasksResponse, error) - DeleteReplicationTaskFromDLQ(request *DeleteReplicationTaskFromDLQRequest) error - RangeDeleteReplicationTaskFromDLQ(request *RangeDeleteReplicationTaskFromDLQRequest) error + PutReplicationTaskToDLQ(ctx context.Context, request *PutReplicationTaskToDLQRequest) error + GetReplicationTasksFromDLQ(ctx context.Context, request *GetReplicationTasksFromDLQRequest) (*GetHistoryTasksResponse, error) + DeleteReplicationTaskFromDLQ(ctx context.Context, request *DeleteReplicationTaskFromDLQRequest) error + RangeDeleteReplicationTaskFromDLQ(ctx context.Context, request *RangeDeleteReplicationTaskFromDLQRequest) error // The below are history V2 APIs // V2 regards history events growing as a tree, decoupled from workflow concepts // For Temporal, treeID is new runID, except for fork(reset), treeID will be the runID that it forks from. // AppendHistoryNodes add a node to history node table - AppendHistoryNodes(request *AppendHistoryNodesRequest) (*AppendHistoryNodesResponse, error) + AppendHistoryNodes(ctx context.Context, request *AppendHistoryNodesRequest) (*AppendHistoryNodesResponse, error) // ReadHistoryBranch returns history node data for a branch - ReadHistoryBranch(request *ReadHistoryBranchRequest) (*ReadHistoryBranchResponse, error) + ReadHistoryBranch(ctx context.Context, request *ReadHistoryBranchRequest) (*ReadHistoryBranchResponse, error) // ReadHistoryBranchByBatch returns history node data for a branch ByBatch - ReadHistoryBranchByBatch(request *ReadHistoryBranchRequest) (*ReadHistoryBranchByBatchResponse, error) + ReadHistoryBranchByBatch(ctx context.Context, request *ReadHistoryBranchRequest) (*ReadHistoryBranchByBatchResponse, error) // ReadHistoryBranch returns history node data for a branch - ReadHistoryBranchReverse(request *ReadHistoryBranchReverseRequest) (*ReadHistoryBranchReverseResponse, error) + ReadHistoryBranchReverse(ctx context.Context, request *ReadHistoryBranchReverseRequest) (*ReadHistoryBranchReverseResponse, error) // ReadRawHistoryBranch returns history node raw data for a branch ByBatch // NOTE: this API should only be used by 3+DC - ReadRawHistoryBranch(request *ReadHistoryBranchRequest) (*ReadRawHistoryBranchResponse, error) + ReadRawHistoryBranch(ctx context.Context, request *ReadHistoryBranchRequest) (*ReadRawHistoryBranchResponse, error) // ForkHistoryBranch forks a new branch from a old branch - ForkHistoryBranch(request *ForkHistoryBranchRequest) (*ForkHistoryBranchResponse, error) + ForkHistoryBranch(ctx context.Context, request *ForkHistoryBranchRequest) (*ForkHistoryBranchResponse, error) // DeleteHistoryBranch removes a branch // If this is the last branch to delete, it will also remove the root node - DeleteHistoryBranch(request *DeleteHistoryBranchRequest) error + DeleteHistoryBranch(ctx context.Context, request *DeleteHistoryBranchRequest) error // TrimHistoryBranch validate & trim a history branch - TrimHistoryBranch(request *TrimHistoryBranchRequest) (*TrimHistoryBranchResponse, error) + TrimHistoryBranch(ctx context.Context, request *TrimHistoryBranchRequest) (*TrimHistoryBranchResponse, error) // GetHistoryTree returns all branch information of a tree - GetHistoryTree(request *GetHistoryTreeRequest) (*GetHistoryTreeResponse, error) + GetHistoryTree(ctx context.Context, request *GetHistoryTreeRequest) (*GetHistoryTreeResponse, error) // GetAllHistoryTreeBranches returns all branches of all trees - GetAllHistoryTreeBranches(request *GetAllHistoryTreeBranchesRequest) (*GetAllHistoryTreeBranchesResponse, error) + GetAllHistoryTreeBranches(ctx context.Context, request *GetAllHistoryTreeBranchesRequest) (*GetAllHistoryTreeBranchesResponse, error) } // TaskManager is used to manage tasks - // TODO: consider change the range for GetTasks and CompleteTasks to be [inclusive, exclusive) TaskManager interface { Closeable GetName() string diff --git a/common/persistence/dataInterfaces_mock.go b/common/persistence/dataInterfaces_mock.go index 38876b1e848..592273e572a 100644 --- a/common/persistence/dataInterfaces_mock.go +++ b/common/persistence/dataInterfaces_mock.go @@ -29,6 +29,7 @@ package persistence import ( + context "context" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -171,32 +172,32 @@ func (m *MockExecutionManager) EXPECT() *MockExecutionManagerMockRecorder { } // AddHistoryTasks mocks base method. -func (m *MockExecutionManager) AddHistoryTasks(request *AddHistoryTasksRequest) error { +func (m *MockExecutionManager) AddHistoryTasks(ctx context.Context, request *AddHistoryTasksRequest) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddHistoryTasks", request) + ret := m.ctrl.Call(m, "AddHistoryTasks", ctx, request) ret0, _ := ret[0].(error) return ret0 } // AddHistoryTasks indicates an expected call of AddHistoryTasks. -func (mr *MockExecutionManagerMockRecorder) AddHistoryTasks(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) AddHistoryTasks(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddHistoryTasks", reflect.TypeOf((*MockExecutionManager)(nil).AddHistoryTasks), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddHistoryTasks", reflect.TypeOf((*MockExecutionManager)(nil).AddHistoryTasks), ctx, request) } // AppendHistoryNodes mocks base method. -func (m *MockExecutionManager) AppendHistoryNodes(request *AppendHistoryNodesRequest) (*AppendHistoryNodesResponse, error) { +func (m *MockExecutionManager) AppendHistoryNodes(ctx context.Context, request *AppendHistoryNodesRequest) (*AppendHistoryNodesResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AppendHistoryNodes", request) + ret := m.ctrl.Call(m, "AppendHistoryNodes", ctx, request) ret0, _ := ret[0].(*AppendHistoryNodesResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // AppendHistoryNodes indicates an expected call of AppendHistoryNodes. -func (mr *MockExecutionManagerMockRecorder) AppendHistoryNodes(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) AppendHistoryNodes(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendHistoryNodes", reflect.TypeOf((*MockExecutionManager)(nil).AppendHistoryNodes), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendHistoryNodes", reflect.TypeOf((*MockExecutionManager)(nil).AppendHistoryNodes), ctx, request) } // Close mocks base method. @@ -212,193 +213,193 @@ func (mr *MockExecutionManagerMockRecorder) Close() *gomock.Call { } // CompleteHistoryTask mocks base method. -func (m *MockExecutionManager) CompleteHistoryTask(request *CompleteHistoryTaskRequest) error { +func (m *MockExecutionManager) CompleteHistoryTask(ctx context.Context, request *CompleteHistoryTaskRequest) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CompleteHistoryTask", request) + ret := m.ctrl.Call(m, "CompleteHistoryTask", ctx, request) ret0, _ := ret[0].(error) return ret0 } // CompleteHistoryTask indicates an expected call of CompleteHistoryTask. -func (mr *MockExecutionManagerMockRecorder) CompleteHistoryTask(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) CompleteHistoryTask(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CompleteHistoryTask", reflect.TypeOf((*MockExecutionManager)(nil).CompleteHistoryTask), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CompleteHistoryTask", reflect.TypeOf((*MockExecutionManager)(nil).CompleteHistoryTask), ctx, request) } // ConflictResolveWorkflowExecution mocks base method. -func (m *MockExecutionManager) ConflictResolveWorkflowExecution(request *ConflictResolveWorkflowExecutionRequest) (*ConflictResolveWorkflowExecutionResponse, error) { +func (m *MockExecutionManager) ConflictResolveWorkflowExecution(ctx context.Context, request *ConflictResolveWorkflowExecutionRequest) (*ConflictResolveWorkflowExecutionResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ConflictResolveWorkflowExecution", request) + ret := m.ctrl.Call(m, "ConflictResolveWorkflowExecution", ctx, request) ret0, _ := ret[0].(*ConflictResolveWorkflowExecutionResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // ConflictResolveWorkflowExecution indicates an expected call of ConflictResolveWorkflowExecution. -func (mr *MockExecutionManagerMockRecorder) ConflictResolveWorkflowExecution(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) ConflictResolveWorkflowExecution(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConflictResolveWorkflowExecution", reflect.TypeOf((*MockExecutionManager)(nil).ConflictResolveWorkflowExecution), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConflictResolveWorkflowExecution", reflect.TypeOf((*MockExecutionManager)(nil).ConflictResolveWorkflowExecution), ctx, request) } // CreateWorkflowExecution mocks base method. -func (m *MockExecutionManager) CreateWorkflowExecution(request *CreateWorkflowExecutionRequest) (*CreateWorkflowExecutionResponse, error) { +func (m *MockExecutionManager) CreateWorkflowExecution(ctx context.Context, request *CreateWorkflowExecutionRequest) (*CreateWorkflowExecutionResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateWorkflowExecution", request) + ret := m.ctrl.Call(m, "CreateWorkflowExecution", ctx, request) ret0, _ := ret[0].(*CreateWorkflowExecutionResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateWorkflowExecution indicates an expected call of CreateWorkflowExecution. -func (mr *MockExecutionManagerMockRecorder) CreateWorkflowExecution(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) CreateWorkflowExecution(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateWorkflowExecution", reflect.TypeOf((*MockExecutionManager)(nil).CreateWorkflowExecution), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateWorkflowExecution", reflect.TypeOf((*MockExecutionManager)(nil).CreateWorkflowExecution), ctx, request) } // DeleteCurrentWorkflowExecution mocks base method. -func (m *MockExecutionManager) DeleteCurrentWorkflowExecution(request *DeleteCurrentWorkflowExecutionRequest) error { +func (m *MockExecutionManager) DeleteCurrentWorkflowExecution(ctx context.Context, request *DeleteCurrentWorkflowExecutionRequest) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteCurrentWorkflowExecution", request) + ret := m.ctrl.Call(m, "DeleteCurrentWorkflowExecution", ctx, request) ret0, _ := ret[0].(error) return ret0 } // DeleteCurrentWorkflowExecution indicates an expected call of DeleteCurrentWorkflowExecution. -func (mr *MockExecutionManagerMockRecorder) DeleteCurrentWorkflowExecution(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) DeleteCurrentWorkflowExecution(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCurrentWorkflowExecution", reflect.TypeOf((*MockExecutionManager)(nil).DeleteCurrentWorkflowExecution), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCurrentWorkflowExecution", reflect.TypeOf((*MockExecutionManager)(nil).DeleteCurrentWorkflowExecution), ctx, request) } // DeleteHistoryBranch mocks base method. -func (m *MockExecutionManager) DeleteHistoryBranch(request *DeleteHistoryBranchRequest) error { +func (m *MockExecutionManager) DeleteHistoryBranch(ctx context.Context, request *DeleteHistoryBranchRequest) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteHistoryBranch", request) + ret := m.ctrl.Call(m, "DeleteHistoryBranch", ctx, request) ret0, _ := ret[0].(error) return ret0 } // DeleteHistoryBranch indicates an expected call of DeleteHistoryBranch. -func (mr *MockExecutionManagerMockRecorder) DeleteHistoryBranch(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) DeleteHistoryBranch(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteHistoryBranch", reflect.TypeOf((*MockExecutionManager)(nil).DeleteHistoryBranch), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteHistoryBranch", reflect.TypeOf((*MockExecutionManager)(nil).DeleteHistoryBranch), ctx, request) } // DeleteReplicationTaskFromDLQ mocks base method. -func (m *MockExecutionManager) DeleteReplicationTaskFromDLQ(request *DeleteReplicationTaskFromDLQRequest) error { +func (m *MockExecutionManager) DeleteReplicationTaskFromDLQ(ctx context.Context, request *DeleteReplicationTaskFromDLQRequest) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteReplicationTaskFromDLQ", request) + ret := m.ctrl.Call(m, "DeleteReplicationTaskFromDLQ", ctx, request) ret0, _ := ret[0].(error) return ret0 } // DeleteReplicationTaskFromDLQ indicates an expected call of DeleteReplicationTaskFromDLQ. -func (mr *MockExecutionManagerMockRecorder) DeleteReplicationTaskFromDLQ(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) DeleteReplicationTaskFromDLQ(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteReplicationTaskFromDLQ", reflect.TypeOf((*MockExecutionManager)(nil).DeleteReplicationTaskFromDLQ), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteReplicationTaskFromDLQ", reflect.TypeOf((*MockExecutionManager)(nil).DeleteReplicationTaskFromDLQ), ctx, request) } // DeleteWorkflowExecution mocks base method. -func (m *MockExecutionManager) DeleteWorkflowExecution(request *DeleteWorkflowExecutionRequest) error { +func (m *MockExecutionManager) DeleteWorkflowExecution(ctx context.Context, request *DeleteWorkflowExecutionRequest) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteWorkflowExecution", request) + ret := m.ctrl.Call(m, "DeleteWorkflowExecution", ctx, request) ret0, _ := ret[0].(error) return ret0 } // DeleteWorkflowExecution indicates an expected call of DeleteWorkflowExecution. -func (mr *MockExecutionManagerMockRecorder) DeleteWorkflowExecution(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) DeleteWorkflowExecution(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkflowExecution", reflect.TypeOf((*MockExecutionManager)(nil).DeleteWorkflowExecution), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkflowExecution", reflect.TypeOf((*MockExecutionManager)(nil).DeleteWorkflowExecution), ctx, request) } // ForkHistoryBranch mocks base method. -func (m *MockExecutionManager) ForkHistoryBranch(request *ForkHistoryBranchRequest) (*ForkHistoryBranchResponse, error) { +func (m *MockExecutionManager) ForkHistoryBranch(ctx context.Context, request *ForkHistoryBranchRequest) (*ForkHistoryBranchResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ForkHistoryBranch", request) + ret := m.ctrl.Call(m, "ForkHistoryBranch", ctx, request) ret0, _ := ret[0].(*ForkHistoryBranchResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // ForkHistoryBranch indicates an expected call of ForkHistoryBranch. -func (mr *MockExecutionManagerMockRecorder) ForkHistoryBranch(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) ForkHistoryBranch(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ForkHistoryBranch", reflect.TypeOf((*MockExecutionManager)(nil).ForkHistoryBranch), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ForkHistoryBranch", reflect.TypeOf((*MockExecutionManager)(nil).ForkHistoryBranch), ctx, request) } // GetAllHistoryTreeBranches mocks base method. -func (m *MockExecutionManager) GetAllHistoryTreeBranches(request *GetAllHistoryTreeBranchesRequest) (*GetAllHistoryTreeBranchesResponse, error) { +func (m *MockExecutionManager) GetAllHistoryTreeBranches(ctx context.Context, request *GetAllHistoryTreeBranchesRequest) (*GetAllHistoryTreeBranchesResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAllHistoryTreeBranches", request) + ret := m.ctrl.Call(m, "GetAllHistoryTreeBranches", ctx, request) ret0, _ := ret[0].(*GetAllHistoryTreeBranchesResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // GetAllHistoryTreeBranches indicates an expected call of GetAllHistoryTreeBranches. -func (mr *MockExecutionManagerMockRecorder) GetAllHistoryTreeBranches(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) GetAllHistoryTreeBranches(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllHistoryTreeBranches", reflect.TypeOf((*MockExecutionManager)(nil).GetAllHistoryTreeBranches), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllHistoryTreeBranches", reflect.TypeOf((*MockExecutionManager)(nil).GetAllHistoryTreeBranches), ctx, request) } // GetCurrentExecution mocks base method. -func (m *MockExecutionManager) GetCurrentExecution(request *GetCurrentExecutionRequest) (*GetCurrentExecutionResponse, error) { +func (m *MockExecutionManager) GetCurrentExecution(ctx context.Context, request *GetCurrentExecutionRequest) (*GetCurrentExecutionResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetCurrentExecution", request) + ret := m.ctrl.Call(m, "GetCurrentExecution", ctx, request) ret0, _ := ret[0].(*GetCurrentExecutionResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // GetCurrentExecution indicates an expected call of GetCurrentExecution. -func (mr *MockExecutionManagerMockRecorder) GetCurrentExecution(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) GetCurrentExecution(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCurrentExecution", reflect.TypeOf((*MockExecutionManager)(nil).GetCurrentExecution), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCurrentExecution", reflect.TypeOf((*MockExecutionManager)(nil).GetCurrentExecution), ctx, request) } // GetHistoryTask mocks base method. -func (m *MockExecutionManager) GetHistoryTask(request *GetHistoryTaskRequest) (*GetHistoryTaskResponse, error) { +func (m *MockExecutionManager) GetHistoryTask(ctx context.Context, request *GetHistoryTaskRequest) (*GetHistoryTaskResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetHistoryTask", request) + ret := m.ctrl.Call(m, "GetHistoryTask", ctx, request) ret0, _ := ret[0].(*GetHistoryTaskResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // GetHistoryTask indicates an expected call of GetHistoryTask. -func (mr *MockExecutionManagerMockRecorder) GetHistoryTask(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) GetHistoryTask(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHistoryTask", reflect.TypeOf((*MockExecutionManager)(nil).GetHistoryTask), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHistoryTask", reflect.TypeOf((*MockExecutionManager)(nil).GetHistoryTask), ctx, request) } // GetHistoryTasks mocks base method. -func (m *MockExecutionManager) GetHistoryTasks(request *GetHistoryTasksRequest) (*GetHistoryTasksResponse, error) { +func (m *MockExecutionManager) GetHistoryTasks(ctx context.Context, request *GetHistoryTasksRequest) (*GetHistoryTasksResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetHistoryTasks", request) + ret := m.ctrl.Call(m, "GetHistoryTasks", ctx, request) ret0, _ := ret[0].(*GetHistoryTasksResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // GetHistoryTasks indicates an expected call of GetHistoryTasks. -func (mr *MockExecutionManagerMockRecorder) GetHistoryTasks(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) GetHistoryTasks(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHistoryTasks", reflect.TypeOf((*MockExecutionManager)(nil).GetHistoryTasks), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHistoryTasks", reflect.TypeOf((*MockExecutionManager)(nil).GetHistoryTasks), ctx, request) } // GetHistoryTree mocks base method. -func (m *MockExecutionManager) GetHistoryTree(request *GetHistoryTreeRequest) (*GetHistoryTreeResponse, error) { +func (m *MockExecutionManager) GetHistoryTree(ctx context.Context, request *GetHistoryTreeRequest) (*GetHistoryTreeResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetHistoryTree", request) + ret := m.ctrl.Call(m, "GetHistoryTree", ctx, request) ret0, _ := ret[0].(*GetHistoryTreeResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // GetHistoryTree indicates an expected call of GetHistoryTree. -func (mr *MockExecutionManagerMockRecorder) GetHistoryTree(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) GetHistoryTree(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHistoryTree", reflect.TypeOf((*MockExecutionManager)(nil).GetHistoryTree), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHistoryTree", reflect.TypeOf((*MockExecutionManager)(nil).GetHistoryTree), ctx, request) } // GetName mocks base method. @@ -416,195 +417,195 @@ func (mr *MockExecutionManagerMockRecorder) GetName() *gomock.Call { } // GetReplicationTasksFromDLQ mocks base method. -func (m *MockExecutionManager) GetReplicationTasksFromDLQ(request *GetReplicationTasksFromDLQRequest) (*GetHistoryTasksResponse, error) { +func (m *MockExecutionManager) GetReplicationTasksFromDLQ(ctx context.Context, request *GetReplicationTasksFromDLQRequest) (*GetHistoryTasksResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetReplicationTasksFromDLQ", request) + ret := m.ctrl.Call(m, "GetReplicationTasksFromDLQ", ctx, request) ret0, _ := ret[0].(*GetHistoryTasksResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // GetReplicationTasksFromDLQ indicates an expected call of GetReplicationTasksFromDLQ. -func (mr *MockExecutionManagerMockRecorder) GetReplicationTasksFromDLQ(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) GetReplicationTasksFromDLQ(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetReplicationTasksFromDLQ", reflect.TypeOf((*MockExecutionManager)(nil).GetReplicationTasksFromDLQ), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetReplicationTasksFromDLQ", reflect.TypeOf((*MockExecutionManager)(nil).GetReplicationTasksFromDLQ), ctx, request) } // GetWorkflowExecution mocks base method. -func (m *MockExecutionManager) GetWorkflowExecution(request *GetWorkflowExecutionRequest) (*GetWorkflowExecutionResponse, error) { +func (m *MockExecutionManager) GetWorkflowExecution(ctx context.Context, request *GetWorkflowExecutionRequest) (*GetWorkflowExecutionResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetWorkflowExecution", request) + ret := m.ctrl.Call(m, "GetWorkflowExecution", ctx, request) ret0, _ := ret[0].(*GetWorkflowExecutionResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // GetWorkflowExecution indicates an expected call of GetWorkflowExecution. -func (mr *MockExecutionManagerMockRecorder) GetWorkflowExecution(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) GetWorkflowExecution(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkflowExecution", reflect.TypeOf((*MockExecutionManager)(nil).GetWorkflowExecution), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkflowExecution", reflect.TypeOf((*MockExecutionManager)(nil).GetWorkflowExecution), ctx, request) } // ListConcreteExecutions mocks base method. -func (m *MockExecutionManager) ListConcreteExecutions(request *ListConcreteExecutionsRequest) (*ListConcreteExecutionsResponse, error) { +func (m *MockExecutionManager) ListConcreteExecutions(ctx context.Context, request *ListConcreteExecutionsRequest) (*ListConcreteExecutionsResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListConcreteExecutions", request) + ret := m.ctrl.Call(m, "ListConcreteExecutions", ctx, request) ret0, _ := ret[0].(*ListConcreteExecutionsResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // ListConcreteExecutions indicates an expected call of ListConcreteExecutions. -func (mr *MockExecutionManagerMockRecorder) ListConcreteExecutions(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) ListConcreteExecutions(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListConcreteExecutions", reflect.TypeOf((*MockExecutionManager)(nil).ListConcreteExecutions), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListConcreteExecutions", reflect.TypeOf((*MockExecutionManager)(nil).ListConcreteExecutions), ctx, request) } // PutReplicationTaskToDLQ mocks base method. -func (m *MockExecutionManager) PutReplicationTaskToDLQ(request *PutReplicationTaskToDLQRequest) error { +func (m *MockExecutionManager) PutReplicationTaskToDLQ(ctx context.Context, request *PutReplicationTaskToDLQRequest) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PutReplicationTaskToDLQ", request) + ret := m.ctrl.Call(m, "PutReplicationTaskToDLQ", ctx, request) ret0, _ := ret[0].(error) return ret0 } // PutReplicationTaskToDLQ indicates an expected call of PutReplicationTaskToDLQ. -func (mr *MockExecutionManagerMockRecorder) PutReplicationTaskToDLQ(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) PutReplicationTaskToDLQ(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutReplicationTaskToDLQ", reflect.TypeOf((*MockExecutionManager)(nil).PutReplicationTaskToDLQ), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutReplicationTaskToDLQ", reflect.TypeOf((*MockExecutionManager)(nil).PutReplicationTaskToDLQ), ctx, request) } // RangeCompleteHistoryTasks mocks base method. -func (m *MockExecutionManager) RangeCompleteHistoryTasks(request *RangeCompleteHistoryTasksRequest) error { +func (m *MockExecutionManager) RangeCompleteHistoryTasks(ctx context.Context, request *RangeCompleteHistoryTasksRequest) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RangeCompleteHistoryTasks", request) + ret := m.ctrl.Call(m, "RangeCompleteHistoryTasks", ctx, request) ret0, _ := ret[0].(error) return ret0 } // RangeCompleteHistoryTasks indicates an expected call of RangeCompleteHistoryTasks. -func (mr *MockExecutionManagerMockRecorder) RangeCompleteHistoryTasks(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) RangeCompleteHistoryTasks(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RangeCompleteHistoryTasks", reflect.TypeOf((*MockExecutionManager)(nil).RangeCompleteHistoryTasks), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RangeCompleteHistoryTasks", reflect.TypeOf((*MockExecutionManager)(nil).RangeCompleteHistoryTasks), ctx, request) } // RangeDeleteReplicationTaskFromDLQ mocks base method. -func (m *MockExecutionManager) RangeDeleteReplicationTaskFromDLQ(request *RangeDeleteReplicationTaskFromDLQRequest) error { +func (m *MockExecutionManager) RangeDeleteReplicationTaskFromDLQ(ctx context.Context, request *RangeDeleteReplicationTaskFromDLQRequest) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RangeDeleteReplicationTaskFromDLQ", request) + ret := m.ctrl.Call(m, "RangeDeleteReplicationTaskFromDLQ", ctx, request) ret0, _ := ret[0].(error) return ret0 } // RangeDeleteReplicationTaskFromDLQ indicates an expected call of RangeDeleteReplicationTaskFromDLQ. -func (mr *MockExecutionManagerMockRecorder) RangeDeleteReplicationTaskFromDLQ(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) RangeDeleteReplicationTaskFromDLQ(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RangeDeleteReplicationTaskFromDLQ", reflect.TypeOf((*MockExecutionManager)(nil).RangeDeleteReplicationTaskFromDLQ), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RangeDeleteReplicationTaskFromDLQ", reflect.TypeOf((*MockExecutionManager)(nil).RangeDeleteReplicationTaskFromDLQ), ctx, request) } // ReadHistoryBranch mocks base method. -func (m *MockExecutionManager) ReadHistoryBranch(request *ReadHistoryBranchRequest) (*ReadHistoryBranchResponse, error) { +func (m *MockExecutionManager) ReadHistoryBranch(ctx context.Context, request *ReadHistoryBranchRequest) (*ReadHistoryBranchResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReadHistoryBranch", request) + ret := m.ctrl.Call(m, "ReadHistoryBranch", ctx, request) ret0, _ := ret[0].(*ReadHistoryBranchResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // ReadHistoryBranch indicates an expected call of ReadHistoryBranch. -func (mr *MockExecutionManagerMockRecorder) ReadHistoryBranch(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) ReadHistoryBranch(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadHistoryBranch", reflect.TypeOf((*MockExecutionManager)(nil).ReadHistoryBranch), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadHistoryBranch", reflect.TypeOf((*MockExecutionManager)(nil).ReadHistoryBranch), ctx, request) } // ReadHistoryBranchByBatch mocks base method. -func (m *MockExecutionManager) ReadHistoryBranchByBatch(request *ReadHistoryBranchRequest) (*ReadHistoryBranchByBatchResponse, error) { +func (m *MockExecutionManager) ReadHistoryBranchByBatch(ctx context.Context, request *ReadHistoryBranchRequest) (*ReadHistoryBranchByBatchResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReadHistoryBranchByBatch", request) + ret := m.ctrl.Call(m, "ReadHistoryBranchByBatch", ctx, request) ret0, _ := ret[0].(*ReadHistoryBranchByBatchResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // ReadHistoryBranchByBatch indicates an expected call of ReadHistoryBranchByBatch. -func (mr *MockExecutionManagerMockRecorder) ReadHistoryBranchByBatch(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) ReadHistoryBranchByBatch(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadHistoryBranchByBatch", reflect.TypeOf((*MockExecutionManager)(nil).ReadHistoryBranchByBatch), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadHistoryBranchByBatch", reflect.TypeOf((*MockExecutionManager)(nil).ReadHistoryBranchByBatch), ctx, request) } // ReadHistoryBranchReverse mocks base method. -func (m *MockExecutionManager) ReadHistoryBranchReverse(request *ReadHistoryBranchReverseRequest) (*ReadHistoryBranchReverseResponse, error) { +func (m *MockExecutionManager) ReadHistoryBranchReverse(ctx context.Context, request *ReadHistoryBranchReverseRequest) (*ReadHistoryBranchReverseResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReadHistoryBranchReverse", request) + ret := m.ctrl.Call(m, "ReadHistoryBranchReverse", ctx, request) ret0, _ := ret[0].(*ReadHistoryBranchReverseResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // ReadHistoryBranchReverse indicates an expected call of ReadHistoryBranchReverse. -func (mr *MockExecutionManagerMockRecorder) ReadHistoryBranchReverse(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) ReadHistoryBranchReverse(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadHistoryBranchReverse", reflect.TypeOf((*MockExecutionManager)(nil).ReadHistoryBranchReverse), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadHistoryBranchReverse", reflect.TypeOf((*MockExecutionManager)(nil).ReadHistoryBranchReverse), ctx, request) } // ReadRawHistoryBranch mocks base method. -func (m *MockExecutionManager) ReadRawHistoryBranch(request *ReadHistoryBranchRequest) (*ReadRawHistoryBranchResponse, error) { +func (m *MockExecutionManager) ReadRawHistoryBranch(ctx context.Context, request *ReadHistoryBranchRequest) (*ReadRawHistoryBranchResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReadRawHistoryBranch", request) + ret := m.ctrl.Call(m, "ReadRawHistoryBranch", ctx, request) ret0, _ := ret[0].(*ReadRawHistoryBranchResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // ReadRawHistoryBranch indicates an expected call of ReadRawHistoryBranch. -func (mr *MockExecutionManagerMockRecorder) ReadRawHistoryBranch(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) ReadRawHistoryBranch(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadRawHistoryBranch", reflect.TypeOf((*MockExecutionManager)(nil).ReadRawHistoryBranch), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadRawHistoryBranch", reflect.TypeOf((*MockExecutionManager)(nil).ReadRawHistoryBranch), ctx, request) } // SetWorkflowExecution mocks base method. -func (m *MockExecutionManager) SetWorkflowExecution(request *SetWorkflowExecutionRequest) (*SetWorkflowExecutionResponse, error) { +func (m *MockExecutionManager) SetWorkflowExecution(ctx context.Context, request *SetWorkflowExecutionRequest) (*SetWorkflowExecutionResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetWorkflowExecution", request) + ret := m.ctrl.Call(m, "SetWorkflowExecution", ctx, request) ret0, _ := ret[0].(*SetWorkflowExecutionResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // SetWorkflowExecution indicates an expected call of SetWorkflowExecution. -func (mr *MockExecutionManagerMockRecorder) SetWorkflowExecution(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) SetWorkflowExecution(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWorkflowExecution", reflect.TypeOf((*MockExecutionManager)(nil).SetWorkflowExecution), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWorkflowExecution", reflect.TypeOf((*MockExecutionManager)(nil).SetWorkflowExecution), ctx, request) } // TrimHistoryBranch mocks base method. -func (m *MockExecutionManager) TrimHistoryBranch(request *TrimHistoryBranchRequest) (*TrimHistoryBranchResponse, error) { +func (m *MockExecutionManager) TrimHistoryBranch(ctx context.Context, request *TrimHistoryBranchRequest) (*TrimHistoryBranchResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TrimHistoryBranch", request) + ret := m.ctrl.Call(m, "TrimHistoryBranch", ctx, request) ret0, _ := ret[0].(*TrimHistoryBranchResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // TrimHistoryBranch indicates an expected call of TrimHistoryBranch. -func (mr *MockExecutionManagerMockRecorder) TrimHistoryBranch(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) TrimHistoryBranch(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TrimHistoryBranch", reflect.TypeOf((*MockExecutionManager)(nil).TrimHistoryBranch), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TrimHistoryBranch", reflect.TypeOf((*MockExecutionManager)(nil).TrimHistoryBranch), ctx, request) } // UpdateWorkflowExecution mocks base method. -func (m *MockExecutionManager) UpdateWorkflowExecution(request *UpdateWorkflowExecutionRequest) (*UpdateWorkflowExecutionResponse, error) { +func (m *MockExecutionManager) UpdateWorkflowExecution(ctx context.Context, request *UpdateWorkflowExecutionRequest) (*UpdateWorkflowExecutionResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateWorkflowExecution", request) + ret := m.ctrl.Call(m, "UpdateWorkflowExecution", ctx, request) ret0, _ := ret[0].(*UpdateWorkflowExecutionResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // UpdateWorkflowExecution indicates an expected call of UpdateWorkflowExecution. -func (mr *MockExecutionManagerMockRecorder) UpdateWorkflowExecution(request interface{}) *gomock.Call { +func (mr *MockExecutionManagerMockRecorder) UpdateWorkflowExecution(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkflowExecution", reflect.TypeOf((*MockExecutionManager)(nil).UpdateWorkflowExecution), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkflowExecution", reflect.TypeOf((*MockExecutionManager)(nil).UpdateWorkflowExecution), ctx, request) } // MockTaskManager is a mock of TaskManager interface. diff --git a/common/persistence/execution_manager.go b/common/persistence/execution_manager.go index fcf8d43dffc..bb279d4ba0d 100644 --- a/common/persistence/execution_manager.go +++ b/common/persistence/execution_manager.go @@ -25,6 +25,7 @@ package persistence import ( + "context" "fmt" commonpb "go.temporal.io/api/common/v1" @@ -81,6 +82,7 @@ func (m *executionManagerImpl) GetName() string { // The below three APIs are related to serialization/deserialization func (m *executionManagerImpl) CreateWorkflowExecution( + _ context.Context, request *CreateWorkflowExecutionRequest, ) (*CreateWorkflowExecutionResponse, error) { @@ -131,6 +133,7 @@ func (m *executionManagerImpl) CreateWorkflowExecution( } func (m *executionManagerImpl) UpdateWorkflowExecution( + _ context.Context, request *UpdateWorkflowExecutionRequest, ) (*UpdateWorkflowExecutionResponse, error) { @@ -218,6 +221,7 @@ func (m *executionManagerImpl) UpdateWorkflowExecution( } func (m *executionManagerImpl) ConflictResolveWorkflowExecution( + _ context.Context, request *ConflictResolveWorkflowExecutionRequest, ) (*ConflictResolveWorkflowExecutionResponse, error) { @@ -332,6 +336,7 @@ func (m *executionManagerImpl) ConflictResolveWorkflowExecution( } func (m *executionManagerImpl) GetWorkflowExecution( + _ context.Context, request *GetWorkflowExecutionRequest, ) (*GetWorkflowExecutionResponse, error) { response, err := m.persistence.GetWorkflowExecution(request) @@ -357,6 +362,7 @@ func (m *executionManagerImpl) GetWorkflowExecution( } func (m *executionManagerImpl) SetWorkflowExecution( + _ context.Context, request *SetWorkflowExecutionRequest, ) (*SetWorkflowExecutionResponse, error) { serializedWorkflowSnapshot, err := m.SerializeWorkflowSnapshot(&request.SetWorkflowSnapshot) @@ -401,7 +407,7 @@ func (m *executionManagerImpl) serializeWorkflowEventBatches( return workflowNewEvents, &historyStatistics, nil } -func (m *executionManagerImpl) DeserializeBufferedEvents( +func (m *executionManagerImpl) DeserializeBufferedEvents( // unexport blobs []*commonpb.DataBlob, ) ([]*historypb.HistoryEvent, error) { @@ -446,7 +452,7 @@ func (m *executionManagerImpl) serializeWorkflowEvents( return m.serializeAppendHistoryNodesRequest(request) } -func (m *executionManagerImpl) SerializeWorkflowMutation( +func (m *executionManagerImpl) SerializeWorkflowMutation( // unexport input *WorkflowMutation, ) (*InternalWorkflowMutation, error) { @@ -549,7 +555,7 @@ func (m *executionManagerImpl) SerializeWorkflowMutation( return result, nil } -func (m *executionManagerImpl) SerializeWorkflowSnapshot( +func (m *executionManagerImpl) SerializeWorkflowSnapshot( // unexport input *WorkflowSnapshot, ) (*InternalWorkflowSnapshot, error) { @@ -641,18 +647,21 @@ func (m *executionManagerImpl) SerializeWorkflowSnapshot( } func (m *executionManagerImpl) DeleteWorkflowExecution( + _ context.Context, request *DeleteWorkflowExecutionRequest, ) error { return m.persistence.DeleteWorkflowExecution(request) } func (m *executionManagerImpl) DeleteCurrentWorkflowExecution( + _ context.Context, request *DeleteCurrentWorkflowExecutionRequest, ) error { return m.persistence.DeleteCurrentWorkflowExecution(request) } func (m *executionManagerImpl) GetCurrentExecution( + _ context.Context, request *GetCurrentExecutionRequest, ) (*GetCurrentExecutionResponse, error) { internalResp, err := m.persistence.GetCurrentExecution(request) @@ -669,6 +678,7 @@ func (m *executionManagerImpl) GetCurrentExecution( } func (m *executionManagerImpl) ListConcreteExecutions( + _ context.Context, request *ListConcreteExecutionsRequest, ) (*ListConcreteExecutionsResponse, error) { response, err := m.persistence.ListConcreteExecutions(request) @@ -690,6 +700,7 @@ func (m *executionManagerImpl) ListConcreteExecutions( } func (m *executionManagerImpl) AddHistoryTasks( + _ context.Context, input *AddHistoryTasksRequest, ) error { tasks, err := serializeTasks(m.serializer, input.Tasks) @@ -710,6 +721,7 @@ func (m *executionManagerImpl) AddHistoryTasks( } func (m *executionManagerImpl) GetHistoryTask( + _ context.Context, request *GetHistoryTaskRequest, ) (*GetHistoryTaskResponse, error) { resp, err := m.persistence.GetHistoryTask(request) @@ -727,6 +739,7 @@ func (m *executionManagerImpl) GetHistoryTask( } func (m *executionManagerImpl) GetHistoryTasks( + _ context.Context, request *GetHistoryTasksRequest, ) (*GetHistoryTasksResponse, error) { if err := validateTaskRange( @@ -758,12 +771,14 @@ func (m *executionManagerImpl) GetHistoryTasks( } func (m *executionManagerImpl) CompleteHistoryTask( + _ context.Context, request *CompleteHistoryTaskRequest, ) error { return m.persistence.CompleteHistoryTask(request) } func (m *executionManagerImpl) RangeCompleteHistoryTasks( + _ context.Context, request *RangeCompleteHistoryTasksRequest, ) error { if err := validateTaskRange( @@ -778,12 +793,14 @@ func (m *executionManagerImpl) RangeCompleteHistoryTasks( } func (m *executionManagerImpl) PutReplicationTaskToDLQ( + _ context.Context, request *PutReplicationTaskToDLQRequest, ) error { return m.persistence.PutReplicationTaskToDLQ(request) } func (m *executionManagerImpl) GetReplicationTasksFromDLQ( + _ context.Context, request *GetReplicationTasksFromDLQRequest, ) (*GetHistoryTasksResponse, error) { resp, err := m.persistence.GetReplicationTasksFromDLQ(request) @@ -808,12 +825,14 @@ func (m *executionManagerImpl) GetReplicationTasksFromDLQ( } func (m *executionManagerImpl) DeleteReplicationTaskFromDLQ( + _ context.Context, request *DeleteReplicationTaskFromDLQRequest, ) error { return m.persistence.DeleteReplicationTaskFromDLQ(request) } func (m *executionManagerImpl) RangeDeleteReplicationTaskFromDLQ( + _ context.Context, request *RangeDeleteReplicationTaskFromDLQRequest, ) error { return m.persistence.RangeDeleteReplicationTaskFromDLQ(request) @@ -829,7 +848,7 @@ func (m *executionManagerImpl) trimHistoryNode( workflowID string, runID string, ) { - response, err := m.GetWorkflowExecution(&GetWorkflowExecutionRequest{ + response, err := m.GetWorkflowExecution(context.TODO(), &GetWorkflowExecutionRequest{ ShardID: shardID, NamespaceID: namespaceID, WorkflowID: workflowID, @@ -852,7 +871,7 @@ func (m *executionManagerImpl) trimHistoryNode( } mutableStateLastNodeID := executionInfo.LastFirstEventId mutableStateLastNodeTransactionID := executionInfo.LastFirstEventTxnId - if _, err := m.TrimHistoryBranch(&TrimHistoryBranchRequest{ + if _, err := m.TrimHistoryBranch(context.TODO(), &TrimHistoryBranchRequest{ ShardID: shardID, BranchToken: branchToken, NodeID: mutableStateLastNodeID, diff --git a/common/persistence/history_manager.go b/common/persistence/history_manager.go index 7423e0d08e0..12c4257be6e 100644 --- a/common/persistence/history_manager.go +++ b/common/persistence/history_manager.go @@ -25,6 +25,7 @@ package persistence import ( + "context" "fmt" "github.com/pborman/uuid" @@ -52,6 +53,7 @@ var _ ExecutionManager = (*executionManagerImpl)(nil) // ForkHistoryBranch forks a new branch from a old branch func (m *executionManagerImpl) ForkHistoryBranch( + _ context.Context, request *ForkHistoryBranchRequest, ) (*ForkHistoryBranchResponse, error) { @@ -137,6 +139,7 @@ func (m *executionManagerImpl) ForkHistoryBranch( // DeleteHistoryBranch removes a branch func (m *executionManagerImpl) DeleteHistoryBranch( + _ context.Context, request *DeleteHistoryBranchRequest, ) error { @@ -155,7 +158,7 @@ func (m *executionManagerImpl) DeleteHistoryBranch( }) // Get the entire history tree, so we know if any part of the target branch is referenced by other branches. - historyTreeResp, err := m.GetHistoryTree(&GetHistoryTreeRequest{ + historyTreeResp, err := m.GetHistoryTree(context.TODO(), &GetHistoryTreeRequest{ TreeID: branch.TreeId, ShardID: &request.ShardID, BranchToken: request.BranchToken, @@ -214,6 +217,7 @@ findDeleteRanges: // TrimHistoryBranch trims a branch func (m *executionManagerImpl) TrimHistoryBranch( + _ context.Context, request *TrimHistoryBranchRequest, ) (*TrimHistoryBranchResponse, error) { @@ -309,6 +313,7 @@ func (m *executionManagerImpl) TrimHistoryBranch( // GetHistoryTree returns all branch information of a tree func (m *executionManagerImpl) GetHistoryTree( + _ context.Context, request *GetHistoryTreeRequest, ) (*GetHistoryTreeResponse, error) { @@ -422,6 +427,7 @@ func (m *executionManagerImpl) serializeAppendHistoryNodesRequest( // AppendHistoryNodes add a node to history node table func (m *executionManagerImpl) AppendHistoryNodes( + _ context.Context, request *AppendHistoryNodesRequest, ) (*AppendHistoryNodesResponse, error) { @@ -441,6 +447,7 @@ func (m *executionManagerImpl) AppendHistoryNodes( // ReadHistoryBranchByBatch returns history node data for a branch by batch // Pagination is implemented here, the actual minNodeID passing to persistence layer is calculated along with token's LastNodeID func (m *executionManagerImpl) ReadHistoryBranchByBatch( + _ context.Context, request *ReadHistoryBranchRequest, ) (*ReadHistoryBranchByBatchResponse, error) { @@ -453,6 +460,7 @@ func (m *executionManagerImpl) ReadHistoryBranchByBatch( // ReadHistoryBranch returns history node data for a branch // Pagination is implemented here, the actual minNodeID passing to persistence layer is calculated along with token's LastNodeID func (m *executionManagerImpl) ReadHistoryBranch( + _ context.Context, request *ReadHistoryBranchRequest, ) (*ReadHistoryBranchResponse, error) { @@ -466,6 +474,7 @@ func (m *executionManagerImpl) ReadHistoryBranch( // Pagination is implemented here, the actual minNodeID passing to persistence layer is calculated along with token's LastNodeID // NOTE: this API should only be used by 3+DC func (m *executionManagerImpl) ReadRawHistoryBranch( + _ context.Context, request *ReadHistoryBranchRequest, ) (*ReadRawHistoryBranchResponse, error) { @@ -489,6 +498,7 @@ func (m *executionManagerImpl) ReadRawHistoryBranch( // ReadHistoryBranchReverse returns history node data for a branch // Pagination is implemented here, the actual minNodeID passing to persistence layer is calculated along with token's LastNodeID func (m *executionManagerImpl) ReadHistoryBranchReverse( + _ context.Context, request *ReadHistoryBranchReverseRequest, ) (*ReadHistoryBranchReverseResponse, error) { resp := &ReadHistoryBranchReverseResponse{} @@ -498,6 +508,7 @@ func (m *executionManagerImpl) ReadHistoryBranchReverse( } func (m *executionManagerImpl) GetAllHistoryTreeBranches( + _ context.Context, request *GetAllHistoryTreeBranchesRequest, ) (*GetAllHistoryTreeBranchesResponse, error) { resp, err := m.persistence.GetAllHistoryTreeBranches(request) diff --git a/common/persistence/history_manager_util.go b/common/persistence/history_manager_util.go index e95c87ae3cd..a04451620fc 100644 --- a/common/persistence/history_manager_util.go +++ b/common/persistence/history_manager_util.go @@ -25,6 +25,7 @@ package persistence import ( + "context" "sort" historypb "go.temporal.io/api/history/v1" @@ -35,11 +36,15 @@ import ( // ReadFullPageEvents reads a full page of history events from ExecutionManager. Due to storage format of V2 History // it is not guaranteed that pageSize amount of data is returned. Function returns the list of history events, the size // of data read, the next page token, and an error if present. -func ReadFullPageEvents(executionMgr ExecutionManager, req *ReadHistoryBranchRequest) ([]*historypb.HistoryEvent, int, []byte, error) { +func ReadFullPageEvents( + ctx context.Context, + executionMgr ExecutionManager, + req *ReadHistoryBranchRequest, +) ([]*historypb.HistoryEvent, int, []byte, error) { var historyEvents []*historypb.HistoryEvent size := 0 for { - response, err := executionMgr.ReadHistoryBranch(req) + response, err := executionMgr.ReadHistoryBranch(ctx, req) if err != nil { return nil, 0, nil, err } @@ -55,12 +60,16 @@ func ReadFullPageEvents(executionMgr ExecutionManager, req *ReadHistoryBranchReq // ReadFullPageEventsByBatch reads a full page of history events by batch from ExecutionManager. Due to storage format of V2 History // it is not guaranteed that pageSize amount of data is returned. Function returns the list of history batches, the size // of data read, the next page token, and an error if present. -func ReadFullPageEventsByBatch(executionMgr ExecutionManager, req *ReadHistoryBranchRequest) ([]*historypb.History, int, []byte, error) { +func ReadFullPageEventsByBatch( + ctx context.Context, + executionMgr ExecutionManager, + req *ReadHistoryBranchRequest, +) ([]*historypb.History, int, []byte, error) { var historyBatches []*historypb.History eventsRead := 0 size := 0 for { - response, err := executionMgr.ReadHistoryBranchByBatch(req) + response, err := executionMgr.ReadHistoryBranchByBatch(ctx, req) if err != nil { return nil, 0, nil, err } @@ -79,11 +88,15 @@ func ReadFullPageEventsByBatch(executionMgr ExecutionManager, req *ReadHistoryBr // ReadFullPageEventsReverse reads a full page of history events from ExecutionManager in reverse orcer. Due to storage // format of V2 History it is not guaranteed that pageSize amount of data is returned. Function returns the list of // history events, the size of data read, the next page token, and an error if present. -func ReadFullPageEventsReverse(executionMgr ExecutionManager, req *ReadHistoryBranchReverseRequest) ([]*historypb.HistoryEvent, int, []byte, error) { +func ReadFullPageEventsReverse( + ctx context.Context, + executionMgr ExecutionManager, + req *ReadHistoryBranchReverseRequest, +) ([]*historypb.HistoryEvent, int, []byte, error) { var historyEvents []*historypb.HistoryEvent size := 0 for { - response, err := executionMgr.ReadHistoryBranchReverse(req) + response, err := executionMgr.ReadHistoryBranchReverse(ctx, req) if err != nil { return nil, 0, nil, err } diff --git a/common/persistence/persistence-tests/executionManagerTest.go b/common/persistence/persistence-tests/executionManagerTest.go index 1e4026407e8..55d04b276a9 100644 --- a/common/persistence/persistence-tests/executionManagerTest.go +++ b/common/persistence/persistence-tests/executionManagerTest.go @@ -25,6 +25,7 @@ package persistencetests import ( + "context" "fmt" "math" "math/rand" @@ -61,6 +62,9 @@ type ( // override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test, // not merely log an error *require.Assertions + + ctx context.Context + cancel context.CancelFunc } ) @@ -83,7 +87,13 @@ func (s *ExecutionManagerSuite) TearDownSuite() { func (s *ExecutionManagerSuite) SetupTest() { // Have to define our overridden assertions in the test setup. If we did it earlier, s.T() will return nil s.Assertions = require.New(s.T()) - s.ClearTasks() + s.ctx, s.cancel = context.WithTimeout(context.Background(), time.Second*30) + + s.ClearTasks(s.ctx) +} + +func (s *ExecutionManagerSuite) TearDownTest() { + s.cancel() } func (s *ExecutionManagerSuite) newRandomChecksum() *persistencespb.Checksum { @@ -149,16 +159,16 @@ func (s *ExecutionManagerSuite) TestCreateWorkflowExecutionDeDup() { Mode: p.CreateWorkflowModeBrandNew, } - _, err := s.ExecutionManager.CreateWorkflowExecution(req) + _, err := s.ExecutionManager.CreateWorkflowExecution(s.ctx, req) s.Nil(err) - info, err := s.GetWorkflowMutableState(namespaceID, workflowExecution) + info, err := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.Nil(err) s.assertChecksumsEqual(csum, info.Checksum) updatedInfo := copyWorkflowExecutionInfo(info.ExecutionInfo) updatedState := copyWorkflowExecutionState(info.ExecutionState) updatedState.State = enumsspb.WORKFLOW_EXECUTION_STATE_COMPLETED updatedState.Status = enumspb.WORKFLOW_EXECUTION_STATUS_COMPLETED - _, err = s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err = s.ExecutionManager.UpdateWorkflowExecution(s.ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), UpdateWorkflowMutation: p.WorkflowMutation{ ExecutionInfo: updatedInfo, @@ -174,7 +184,7 @@ func (s *ExecutionManagerSuite) TestCreateWorkflowExecutionDeDup() { req.Mode = p.CreateWorkflowModeWorkflowIDReuse req.PreviousRunID = runID req.PreviousLastWriteVersion = common.EmptyVersion - _, err = s.ExecutionManager.CreateWorkflowExecution(req) + _, err = s.ExecutionManager.CreateWorkflowExecution(s.ctx, req) s.Error(err) s.IsType(&p.WorkflowConditionFailedError{}, err) } @@ -232,13 +242,13 @@ func (s *ExecutionManagerSuite) TestCreateWorkflowExecutionStateStatus() { req.NewWorkflowSnapshot.ExecutionState.State = enumsspb.WORKFLOW_EXECUTION_STATE_CREATED for _, invalidStatus := range invalidStatuses { req.NewWorkflowSnapshot.ExecutionState.Status = invalidStatus - _, err := s.ExecutionManager.CreateWorkflowExecution(req) + _, err := s.ExecutionManager.CreateWorkflowExecution(s.ctx, req) s.IsType(&serviceerror.Internal{}, err) } req.NewWorkflowSnapshot.ExecutionState.Status = enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING - _, err := s.ExecutionManager.CreateWorkflowExecution(req) + _, err := s.ExecutionManager.CreateWorkflowExecution(s.ctx, req) s.Nil(err) - info, err := s.GetWorkflowMutableState(namespaceID, workflowExecutionStatusCreated) + info, err := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecutionStatusCreated) s.Nil(err) s.Equal(enumsspb.WORKFLOW_EXECUTION_STATE_CREATED, info.ExecutionState.State) s.Equal(enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, info.ExecutionState.Status) @@ -253,13 +263,13 @@ func (s *ExecutionManagerSuite) TestCreateWorkflowExecutionStateStatus() { req.NewWorkflowSnapshot.ExecutionState.State = enumsspb.WORKFLOW_EXECUTION_STATE_RUNNING for _, invalidStatus := range invalidStatuses { req.NewWorkflowSnapshot.ExecutionState.Status = invalidStatus - _, err := s.ExecutionManager.CreateWorkflowExecution(req) + _, err := s.ExecutionManager.CreateWorkflowExecution(s.ctx, req) s.IsType(&serviceerror.Internal{}, err) } req.NewWorkflowSnapshot.ExecutionState.Status = enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING - _, err = s.ExecutionManager.CreateWorkflowExecution(req) + _, err = s.ExecutionManager.CreateWorkflowExecution(s.ctx, req) s.Nil(err) - info, err = s.GetWorkflowMutableState(namespaceID, workflowExecutionStatusRunning) + info, err = s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecutionStatusRunning) s.Nil(err) s.Equal(enumsspb.WORKFLOW_EXECUTION_STATE_RUNNING, info.ExecutionState.State) s.Equal(enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, info.ExecutionState.Status) @@ -277,13 +287,13 @@ func (s *ExecutionManagerSuite) TestCreateWorkflowExecutionStateStatus() { req.NewWorkflowSnapshot.ExecutionState.State = enumsspb.WORKFLOW_EXECUTION_STATE_ZOMBIE for _, invalidStatus := range invalidStatuses { req.NewWorkflowSnapshot.ExecutionState.Status = invalidStatus - _, err := s.ExecutionManager.CreateWorkflowExecution(req) + _, err := s.ExecutionManager.CreateWorkflowExecution(s.ctx, req) s.IsType(&serviceerror.Internal{}, err) } req.NewWorkflowSnapshot.ExecutionState.Status = enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING - _, err = s.ExecutionManager.CreateWorkflowExecution(req) + _, err = s.ExecutionManager.CreateWorkflowExecution(s.ctx, req) s.Nil(err) - info, err = s.GetWorkflowMutableState(namespaceID, workflowExecutionStatusZombie) + info, err = s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecutionStatusZombie) s.Nil(err) s.Equal(enumsspb.WORKFLOW_EXECUTION_STATE_ZOMBIE, info.ExecutionState.State) s.Equal(enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, info.ExecutionState.Status) @@ -333,9 +343,9 @@ func (s *ExecutionManagerSuite) TestCreateWorkflowExecutionWithZombieState() { RangeID: s.ShardInfo.GetRangeId(), Mode: p.CreateWorkflowModeZombie, } - _, err := s.ExecutionManager.CreateWorkflowExecution(req) + _, err := s.ExecutionManager.CreateWorkflowExecution(s.ctx, req) s.Nil(err) // allow creating a zombie workflow if no current running workflow - _, err = s.GetCurrentWorkflowRunID(namespaceID, workflowID) + _, err = s.GetCurrentWorkflowRunID(s.ctx, namespaceID, workflowID) s.IsType(&serviceerror.NotFound{}, err) // no current workflow workflowExecutionRunning := commonpb.WorkflowExecution{ @@ -346,9 +356,9 @@ func (s *ExecutionManagerSuite) TestCreateWorkflowExecutionWithZombieState() { req.Mode = p.CreateWorkflowModeBrandNew req.NewWorkflowSnapshot.ExecutionState.State = enumsspb.WORKFLOW_EXECUTION_STATE_RUNNING req.NewWorkflowSnapshot.ExecutionState.Status = enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING - _, err = s.ExecutionManager.CreateWorkflowExecution(req) + _, err = s.ExecutionManager.CreateWorkflowExecution(s.ctx, req) s.Nil(err) - currentRunID, err := s.GetCurrentWorkflowRunID(namespaceID, workflowID) + currentRunID, err := s.GetCurrentWorkflowRunID(s.ctx, namespaceID, workflowID) s.Nil(err) s.Equal(workflowExecutionRunning.GetRunId(), currentRunID) @@ -360,13 +370,13 @@ func (s *ExecutionManagerSuite) TestCreateWorkflowExecutionWithZombieState() { req.Mode = p.CreateWorkflowModeZombie req.NewWorkflowSnapshot.ExecutionState.State = enumsspb.WORKFLOW_EXECUTION_STATE_ZOMBIE req.NewWorkflowSnapshot.ExecutionState.Status = enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING - _, err = s.ExecutionManager.CreateWorkflowExecution(req) + _, err = s.ExecutionManager.CreateWorkflowExecution(s.ctx, req) s.Nil(err) // current run ID is still the prev running run ID - currentRunID, err = s.GetCurrentWorkflowRunID(namespaceID, workflowExecutionRunning.GetWorkflowId()) + currentRunID, err = s.GetCurrentWorkflowRunID(s.ctx, namespaceID, workflowExecutionRunning.GetWorkflowId()) s.Nil(err) s.Equal(workflowExecutionRunning.GetRunId(), currentRunID) - info, err := s.GetWorkflowMutableState(namespaceID, workflowExecutionZombie) + info, err := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecutionZombie) s.Nil(err) s.Equal(enumsspb.WORKFLOW_EXECUTION_STATE_ZOMBIE, info.ExecutionState.State) s.Equal(enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, info.ExecutionState.Status) @@ -425,9 +435,9 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflowExecutionStateStatus() { req.NewWorkflowSnapshot.ExecutionState.State = enumsspb.WORKFLOW_EXECUTION_STATE_CREATED req.NewWorkflowSnapshot.ExecutionState.Status = enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING - _, err := s.ExecutionManager.CreateWorkflowExecution(req) + _, err := s.ExecutionManager.CreateWorkflowExecution(s.ctx, req) s.Nil(err) - state, err := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.Nil(err) s.Equal(enumsspb.WORKFLOW_EXECUTION_STATE_CREATED, state.ExecutionState.State) s.Equal(enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, state.ExecutionState.Status) @@ -438,7 +448,7 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflowExecutionStateStatus() { updatedState := copyWorkflowExecutionState(state.ExecutionState) updatedState.State = enumsspb.WORKFLOW_EXECUTION_STATE_RUNNING updatedState.Status = enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING - _, err = s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err = s.ExecutionManager.UpdateWorkflowExecution(s.ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), UpdateWorkflowMutation: p.WorkflowMutation{ ExecutionInfo: updatedInfo, @@ -451,7 +461,7 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflowExecutionStateStatus() { Mode: p.UpdateWorkflowModeUpdateCurrent, }) s.NoError(err) - state, err = s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err = s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.Nil(err) s.Equal(enumsspb.WORKFLOW_EXECUTION_STATE_RUNNING, state.ExecutionState.State) s.Equal(enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, state.ExecutionState.Status) @@ -462,7 +472,7 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflowExecutionStateStatus() { updatedState.State = enumsspb.WORKFLOW_EXECUTION_STATE_RUNNING for _, status := range statuses { updatedState.Status = status - _, err = s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err = s.ExecutionManager.UpdateWorkflowExecution(s.ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), UpdateWorkflowMutation: p.WorkflowMutation{ ExecutionInfo: updatedInfo, @@ -480,7 +490,7 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflowExecutionStateStatus() { updatedState = copyWorkflowExecutionState(state.ExecutionState) updatedState.State = enumsspb.WORKFLOW_EXECUTION_STATE_COMPLETED updatedState.Status = enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING - _, err = s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err = s.ExecutionManager.UpdateWorkflowExecution(s.ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), UpdateWorkflowMutation: p.WorkflowMutation{ ExecutionInfo: updatedInfo, @@ -495,7 +505,7 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflowExecutionStateStatus() { for _, status := range statuses { updatedState.Status = status - _, err = s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err = s.ExecutionManager.UpdateWorkflowExecution(s.ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), UpdateWorkflowMutation: p.WorkflowMutation{ ExecutionInfo: updatedInfo, @@ -507,7 +517,7 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflowExecutionStateStatus() { Mode: p.UpdateWorkflowModeUpdateCurrent, }) s.Nil(err) - state, err = s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err = s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.Nil(err) s.Equal(enumsspb.WORKFLOW_EXECUTION_STATE_COMPLETED, state.ExecutionState.State) s.EqualValues(status, state.ExecutionState.Status) @@ -526,14 +536,14 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflowExecutionStateStatus() { req.PreviousLastWriteVersion = common.EmptyVersion req.NewWorkflowSnapshot.ExecutionState.State = enumsspb.WORKFLOW_EXECUTION_STATE_RUNNING req.NewWorkflowSnapshot.ExecutionState.Status = enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING - _, err = s.ExecutionManager.CreateWorkflowExecution(req) + _, err = s.ExecutionManager.CreateWorkflowExecution(s.ctx, req) s.Nil(err) updatedInfo = copyWorkflowExecutionInfo(state.ExecutionInfo) updatedState = copyWorkflowExecutionState(state.ExecutionState) updatedState.State = enumsspb.WORKFLOW_EXECUTION_STATE_ZOMBIE updatedState.Status = enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING - _, err = s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err = s.ExecutionManager.UpdateWorkflowExecution(s.ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), UpdateWorkflowMutation: p.WorkflowMutation{ ExecutionInfo: updatedInfo, @@ -545,7 +555,7 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflowExecutionStateStatus() { Mode: p.UpdateWorkflowModeBypassCurrent, }) s.NoError(err) - state, err = s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err = s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.Nil(err) s.Equal(enumsspb.WORKFLOW_EXECUTION_STATE_ZOMBIE, state.ExecutionState.State) s.Equal(enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, state.ExecutionState.Status) @@ -555,7 +565,7 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflowExecutionStateStatus() { updatedState.State = enumsspb.WORKFLOW_EXECUTION_STATE_ZOMBIE for _, status := range statuses { updatedState.Status = status - _, err = s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err = s.ExecutionManager.UpdateWorkflowExecution(s.ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), UpdateWorkflowMutation: p.WorkflowMutation{ ExecutionInfo: updatedInfo, @@ -614,13 +624,13 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflowExecutionWithZombieState() { RangeID: s.ShardInfo.GetRangeId(), Mode: p.CreateWorkflowModeBrandNew, } - _, err := s.ExecutionManager.CreateWorkflowExecution(req) + _, err := s.ExecutionManager.CreateWorkflowExecution(s.ctx, req) s.Nil(err) - currentRunID, err := s.GetCurrentWorkflowRunID(namespaceID, workflowID) + currentRunID, err := s.GetCurrentWorkflowRunID(s.ctx, namespaceID, workflowID) s.Nil(err) s.Equal(workflowExecution.GetRunId(), currentRunID) - info, err := s.GetWorkflowMutableState(namespaceID, workflowExecution) + info, err := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.Nil(err) s.assertChecksumsEqual(csum, info.Checksum) @@ -629,7 +639,7 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflowExecutionWithZombieState() { updatedState := copyWorkflowExecutionState(info.ExecutionState) updatedState.State = enumsspb.WORKFLOW_EXECUTION_STATE_ZOMBIE updatedState.Status = enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING - _, err = s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err = s.ExecutionManager.UpdateWorkflowExecution(s.ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), UpdateWorkflowMutation: p.WorkflowMutation{ ExecutionInfo: updatedInfo, @@ -647,7 +657,7 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflowExecutionWithZombieState() { updatedState = copyWorkflowExecutionState(info.ExecutionState) updatedState.State = enumsspb.WORKFLOW_EXECUTION_STATE_COMPLETED updatedState.Status = enumspb.WORKFLOW_EXECUTION_STATUS_COMPLETED - _, err = s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err = s.ExecutionManager.UpdateWorkflowExecution(s.ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), UpdateWorkflowMutation: p.WorkflowMutation{ ExecutionInfo: updatedInfo, @@ -674,21 +684,21 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflowExecutionWithZombieState() { req.NewWorkflowSnapshot.ExecutionState.State = enumsspb.WORKFLOW_EXECUTION_STATE_RUNNING req.NewWorkflowSnapshot.ExecutionState.Status = enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING req.NewWorkflowSnapshot.Checksum = csum - _, err = s.ExecutionManager.CreateWorkflowExecution(req) + _, err = s.ExecutionManager.CreateWorkflowExecution(s.ctx, req) s.Nil(err) - currentRunID, err = s.GetCurrentWorkflowRunID(namespaceID, workflowID) + currentRunID, err = s.GetCurrentWorkflowRunID(s.ctx, namespaceID, workflowID) s.Nil(err) s.Equal(workflowExecutionRunning.GetRunId(), currentRunID) // get the workflow to be turned into a zombie - info, err = s.GetWorkflowMutableState(namespaceID, workflowExecution) + info, err = s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.Nil(err) s.assertChecksumsEqual(csum, info.Checksum) updatedInfo = copyWorkflowExecutionInfo(info.ExecutionInfo) updatedState = copyWorkflowExecutionState(info.ExecutionState) updatedState.State = enumsspb.WORKFLOW_EXECUTION_STATE_ZOMBIE updatedState.Status = enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING - _, err = s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err = s.ExecutionManager.UpdateWorkflowExecution(s.ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), UpdateWorkflowMutation: p.WorkflowMutation{ ExecutionInfo: updatedInfo, @@ -701,13 +711,13 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflowExecutionWithZombieState() { Mode: p.UpdateWorkflowModeBypassCurrent, }) s.NoError(err) - info, err = s.GetWorkflowMutableState(namespaceID, workflowExecution) + info, err = s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.Nil(err) s.Equal(enumsspb.WORKFLOW_EXECUTION_STATE_ZOMBIE, info.ExecutionState.State) s.Equal(enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, info.ExecutionState.Status) s.assertChecksumsEqual(csum, info.Checksum) // check current run ID is un touched - currentRunID, err = s.GetCurrentWorkflowRunID(namespaceID, workflowID) + currentRunID, err = s.GetCurrentWorkflowRunID(s.ctx, namespaceID, workflowID) s.Nil(err) s.Equal(workflowExecutionRunning.GetRunId(), currentRunID) } @@ -754,9 +764,9 @@ func (s *ExecutionManagerSuite) TestCreateWorkflowExecutionBrandNew() { Mode: p.CreateWorkflowModeBrandNew, } - _, err := s.ExecutionManager.CreateWorkflowExecution(req) + _, err := s.ExecutionManager.CreateWorkflowExecution(s.ctx, req) s.Nil(err) - _, err = s.ExecutionManager.CreateWorkflowExecution(req) + _, err = s.ExecutionManager.CreateWorkflowExecution(s.ctx, req) s.NotNil(err) alreadyStartedErr, ok := err.(*p.CurrentWorkflowConditionFailedError) s.True(ok, "err is not CurrentWorkflowConditionFailedError") @@ -809,13 +819,13 @@ func (s *ExecutionManagerSuite) TestUpsertWorkflowActivity() { RangeID: s.ShardInfo.GetRangeId(), Mode: p.CreateWorkflowModeBrandNew, } - _, err := s.ExecutionManager.CreateWorkflowExecution(req) + _, err := s.ExecutionManager.CreateWorkflowExecution(s.ctx, req) s.Nil(err) - currentRunID, err := s.GetCurrentWorkflowRunID(namespaceID, workflowID) + currentRunID, err := s.GetCurrentWorkflowRunID(s.ctx, namespaceID, workflowID) s.Nil(err) s.Equal(workflowExecution.GetRunId(), currentRunID) - info, err := s.GetWorkflowMutableState(namespaceID, workflowExecution) + info, err := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.Nil(err) s.assertChecksumsEqual(csum, info.Checksum) s.Equal(0, len(info.ActivityInfos)) @@ -823,7 +833,7 @@ func (s *ExecutionManagerSuite) TestUpsertWorkflowActivity() { // insert a new activity updatedInfo := copyWorkflowExecutionInfo(info.ExecutionInfo) updatedState := copyWorkflowExecutionState(info.ExecutionState) - _, err = s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err = s.ExecutionManager.UpdateWorkflowExecution(s.ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), UpdateWorkflowMutation: p.WorkflowMutation{ ExecutionInfo: updatedInfo, @@ -844,13 +854,13 @@ func (s *ExecutionManagerSuite) TestUpsertWorkflowActivity() { }) s.Nil(err) - info2, err := s.GetWorkflowMutableState(namespaceID, workflowExecution) + info2, err := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.Nil(err) s.Equal(1, len(info2.ActivityInfos)) s.Equal("test-activity-tasktlist-1", info2.ActivityInfos[100].TaskQueue) // upsert the previous activity - _, err = s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err = s.ExecutionManager.UpdateWorkflowExecution(s.ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), UpdateWorkflowMutation: p.WorkflowMutation{ ExecutionInfo: updatedInfo, @@ -869,7 +879,7 @@ func (s *ExecutionManagerSuite) TestUpsertWorkflowActivity() { RangeID: s.ShardInfo.RangeId, Mode: p.UpdateWorkflowModeUpdateCurrent, }) - info3, err := s.GetWorkflowMutableState(namespaceID, workflowExecution) + info3, err := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.Nil(err) s.Equal(1, len(info3.ActivityInfos)) s.Equal("test-activity-tasktlist-2", info3.ActivityInfos[100].TaskQueue) @@ -890,11 +900,11 @@ func (s *ExecutionManagerSuite) TestCreateWorkflowExecutionRunIDReuseWithoutRepl nextEventID := int64(3) workflowTaskScheduleID := int64(2) - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, taskqueue, workflowType, workflowTimeout, workflowTaskTimeout, nextEventID, lastProcessedEventID, workflowTaskScheduleID, nil) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, taskqueue, workflowType, workflowTimeout, workflowTaskTimeout, nextEventID, lastProcessedEventID, workflowTaskScheduleID, nil) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - state0, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) s.assertChecksumsEqual(testWorkflowChecksum, state0.Checksum) info0 := state0.ExecutionInfo @@ -904,7 +914,7 @@ func (s *ExecutionManagerSuite) TestCreateWorkflowExecutionRunIDReuseWithoutRepl closeState.Status = enumspb.WORKFLOW_EXECUTION_STATUS_COMPLETED closeInfo.LastWorkflowTaskStartId = int64(2) - err2 := s.UpdateWorkflowExecution(closeInfo, closeState, int64(5), nil, nil, nextEventID, + err2 := s.UpdateWorkflowExecution(s.ctx, closeInfo, closeState, int64(5), nil, nil, nextEventID, nil, nil, nil, nil, nil) s.NoError(err2) @@ -914,7 +924,7 @@ func (s *ExecutionManagerSuite) TestCreateWorkflowExecutionRunIDReuseWithoutRepl } // this create should work since we are relying the business logic in history engine // to check whether the existing running workflow has finished - _, err3 := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err3 := s.ExecutionManager.CreateWorkflowExecution(s.ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), NewWorkflowSnapshot: p.WorkflowSnapshot{ ExecutionInfo: &persistencespb.WorkflowExecutionInfo{ @@ -961,7 +971,7 @@ func (s *ExecutionManagerSuite) TestCreateWorkflowExecutionConcurrentCreate() { nextEventID := int64(3) workflowTaskScheduleID := int64(2) - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, taskqueue, workflowType, workflowTimeout, workflowTaskTimeout, nextEventID, lastProcessedEventID, workflowTaskScheduleID, nil) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, taskqueue, workflowType, workflowTimeout, workflowTaskTimeout, nextEventID, lastProcessedEventID, workflowTaskScheduleID, nil) s.Nil(err0, "No error expected.") s.NotNil(task0, "Expected non empty task identifier.") @@ -977,7 +987,7 @@ func (s *ExecutionManagerSuite) TestCreateWorkflowExecutionConcurrentCreate() { RunId: uuid.New(), } - state0, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) info0 := state0.ExecutionInfo continueAsNewInfo := copyWorkflowExecutionInfo(info0) @@ -986,7 +996,7 @@ func (s *ExecutionManagerSuite) TestCreateWorkflowExecutionConcurrentCreate() { continueAsNewNextEventID := int64(5) continueAsNewInfo.LastWorkflowTaskStartId = int64(2) - err2 := s.ContinueAsNewExecution(continueAsNewInfo, continueAsNewState, continueAsNewNextEventID, state0.NextEventId, newExecution, int64(3), int64(2), nil) + err2 := s.ContinueAsNewExecution(s.ctx, continueAsNewInfo, continueAsNewState, continueAsNewNextEventID, state0.NextEventId, newExecution, int64(3), int64(2), nil) if err2 != nil { errCount := atomic.AddInt32(&numOfErr, 1) if errCount > 1 { @@ -1015,11 +1025,11 @@ func (s *ExecutionManagerSuite) TestPersistenceStartWorkflow() { workflowExecution.WorkflowId, workflowExecution.RunId, ) - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - task1, err1 := s.CreateWorkflowExecution(namespaceID, workflowExecution, "queue1", "wType1", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(14), 3, 0, 2, nil) + task1, err1 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "queue1", "wType1", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(14), 3, 0, 2, nil) s.Error(err1, "Expected workflow creation to fail.") startedErr, ok := err1.(*p.CurrentWorkflowConditionFailedError) s.True(ok, fmt.Sprintf("Expected CurrentWorkflowConditionFailedError, but actual is %v", err1)) @@ -1029,7 +1039,7 @@ func (s *ExecutionManagerSuite) TestPersistenceStartWorkflow() { s.Equal(common.EmptyVersion, startedErr.LastWriteVersion, startedErr.Msg) s.Empty(task1, "Expected empty task identifier.") - response, err2 := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + response, err2 := s.ExecutionManager.CreateWorkflowExecution(s.ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), NewWorkflowSnapshot: p.WorkflowSnapshot{ ExecutionInfo: &persistencespb.WorkflowExecutionInfo{ @@ -1155,11 +1165,11 @@ func (s *ExecutionManagerSuite) TestGetWorkflow() { Mode: p.CreateWorkflowModeBrandNew, } - createResp, err := s.ExecutionManager.CreateWorkflowExecution(createReq) + createResp, err := s.ExecutionManager.CreateWorkflowExecution(s.ctx, createReq) s.NoError(err) s.NotNil(createResp, "Expected non empty task identifier.") - state, err := s.GetWorkflowMutableState(createReq.NewWorkflowSnapshot.ExecutionInfo.NamespaceId, + state, err := s.GetWorkflowMutableState(s.ctx, createReq.NewWorkflowSnapshot.ExecutionInfo.NamespaceId, commonpb.WorkflowExecution{ WorkflowId: createReq.NewWorkflowSnapshot.ExecutionInfo.WorkflowId, RunId: createReq.NewWorkflowSnapshot.ExecutionState.RunId, @@ -1219,11 +1229,11 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflow() { WorkflowId: "update-workflow-test", RunId: "5ba5e531-e46b-48d9-b4b3-859919839553", } - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - state0, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) info0 := state0.ExecutionInfo s.NotNil(info0, "Valid Workflow info expected.") @@ -1284,10 +1294,10 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflow() { updatedInfo.Memo = map[string]*commonpb.Payload{memoKey: memoVal} updatedInfo.ExecutionStats.HistorySize = math.MaxInt64 - err2 := s.UpdateWorkflowExecution(updatedInfo, updatedState, int64(5), []int64{int64(4)}, nil, int64(3), nil, nil, nil, nil, nil) + err2 := s.UpdateWorkflowExecution(s.ctx, updatedInfo, updatedState, int64(5), []int64{int64(4)}, nil, int64(3), nil, nil, nil, nil, nil) s.NoError(err2) - state1, err3 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state1, err3 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err3) info1 := state1.ExecutionInfo s.NotNil(info1, "Valid Workflow info expected.") @@ -1333,11 +1343,11 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflow() { failedUpdateInfo := copyWorkflowExecutionInfo(updatedInfo) failedUpdateState := copyWorkflowExecutionState(updatedState) - err4 := s.UpdateWorkflowExecution(failedUpdateInfo, failedUpdateState, state0.NextEventId, []int64{int64(5)}, nil, int64(3), nil, nil, nil, nil, nil) + err4 := s.UpdateWorkflowExecution(s.ctx, failedUpdateInfo, failedUpdateState, state0.NextEventId, []int64{int64(5)}, nil, int64(3), nil, nil, nil, nil, nil) s.Error(err4, "expected non nil error.") s.IsType(&p.WorkflowConditionFailedError{}, err4) - state2, err4 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state2, err4 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err4) info2 := state2.ExecutionInfo s.NotNil(info2, "Valid Workflow info expected.") @@ -1378,11 +1388,11 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflow() { s.Equal(memoVal, memoVal2) s.assertChecksumsEqual(testWorkflowChecksum, state2.Checksum) - err5 := s.UpdateWorkflowExecutionWithRangeID(failedUpdateInfo, failedUpdateState, state0.NextEventId, []int64{int64(5)}, nil, int64(12345), int64(5), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + err5 := s.UpdateWorkflowExecutionWithRangeID(s.ctx, failedUpdateInfo, failedUpdateState, state0.NextEventId, []int64{int64(5)}, nil, int64(12345), int64(5), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) s.Error(err5, "expected non nil error.") s.IsType(&p.ShardOwnershipLostError{}, err5) - state3, err6 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state3, err6 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err6) info3 := state3.ExecutionInfo s.NotNil(info3, "Valid Workflow info expected.") @@ -1424,11 +1434,11 @@ func (s *ExecutionManagerSuite) TestUpdateWorkflow() { s.assertChecksumsEqual(testWorkflowChecksum, state3.Checksum) // update with incorrect rangeID and condition(next_event_id) - err7 := s.UpdateWorkflowExecutionWithRangeID(failedUpdateInfo, failedUpdateState, state0.NextEventId, []int64{int64(5)}, nil, int64(12345), int64(3), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + err7 := s.UpdateWorkflowExecutionWithRangeID(s.ctx, failedUpdateInfo, failedUpdateState, state0.NextEventId, []int64{int64(5)}, nil, int64(12345), int64(3), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) s.Error(err7, "expected non nil error.") s.IsType(&p.ShardOwnershipLostError{}, err7) - state4, err8 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state4, err8 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err8) info4 := state4.ExecutionInfo s.NotNil(info4, "Valid Workflow info expected.") @@ -1476,11 +1486,11 @@ func (s *ExecutionManagerSuite) TestDeleteWorkflow() { WorkflowId: "delete-workflow-test", RunId: "4e0917f2-9361-4a14-b16f-1fafe09b287a", } - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - state0, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) info0 := state0.ExecutionInfo s.NotNil(info0, "Valid Workflow info expected.") @@ -1500,16 +1510,16 @@ func (s *ExecutionManagerSuite) TestDeleteWorkflow() { s.Equal(common.EmptyEventID, info0.WorkflowTaskStartedId) s.EqualValues(1, int64(info0.WorkflowTaskTimeout.Seconds())) - err4 := s.DeleteCurrentWorkflowExecution(info0, state0.ExecutionState) + err4 := s.DeleteCurrentWorkflowExecution(s.ctx, info0, state0.ExecutionState) s.NoError(err4) - err4 = s.DeleteWorkflowExecution(info0, state0.ExecutionState) + err4 = s.DeleteWorkflowExecution(s.ctx, info0, state0.ExecutionState) s.NoError(err4) - _, err3 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + _, err3 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.Error(err3, "expected non nil error.") s.IsType(&serviceerror.NotFound{}, err3) - err5 := s.DeleteWorkflowExecution(info0, state0.ExecutionState) + err5 := s.DeleteWorkflowExecution(s.ctx, info0, state0.ExecutionState) s.NoError(err5) } @@ -1525,24 +1535,24 @@ func (s *ExecutionManagerSuite) TestDeleteCurrentWorkflow() { RunId: "6cae4054-6ba7-46d3-8755-e3c2db6f74ea", } - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - runID0, err1 := s.GetCurrentWorkflowRunID(namespaceID, workflowExecution.GetWorkflowId()) + runID0, err1 := s.GetCurrentWorkflowRunID(s.ctx, namespaceID, workflowExecution.GetWorkflowId()) s.NoError(err1) s.Equal(workflowExecution.GetRunId(), runID0) - info0, err2 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + info0, err2 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err2) updatedInfo1 := copyWorkflowExecutionInfo(info0.ExecutionInfo) updatedState1 := copyWorkflowExecutionState(info0.ExecutionState) updatedInfo1.LastWorkflowTaskStartId = int64(2) - err3 := s.UpdateWorkflowExecutionAndFinish(updatedInfo1, updatedState1, int64(6), int64(3)) + err3 := s.UpdateWorkflowExecutionAndFinish(s.ctx, updatedInfo1, updatedState1, int64(6), int64(3)) s.NoError(err3) - runID4, err4 := s.GetCurrentWorkflowRunID(namespaceID, workflowExecution.GetWorkflowId()) + runID4, err4 := s.GetCurrentWorkflowRunID(s.ctx, namespaceID, workflowExecution.GetWorkflowId()) s.NoError(err4) s.Equal(workflowExecution.GetRunId(), runID4) @@ -1552,23 +1562,23 @@ func (s *ExecutionManagerSuite) TestDeleteCurrentWorkflow() { } // test wrong run id with conditional delete - s.NoError(s.DeleteCurrentWorkflowExecution(fakeInfo, &persistencespb.WorkflowExecutionState{RunId: uuid.New()})) + s.NoError(s.DeleteCurrentWorkflowExecution(s.ctx, fakeInfo, &persistencespb.WorkflowExecutionState{RunId: uuid.New()})) - runID5, err5 := s.GetCurrentWorkflowRunID(namespaceID, workflowExecution.GetWorkflowId()) + runID5, err5 := s.GetCurrentWorkflowRunID(s.ctx, namespaceID, workflowExecution.GetWorkflowId()) s.NoError(err5) s.Equal(workflowExecution.GetRunId(), runID5) // simulate a timer_task deleting execution after retention - s.NoError(s.DeleteCurrentWorkflowExecution(info0.ExecutionInfo, info0.ExecutionState)) + s.NoError(s.DeleteCurrentWorkflowExecution(s.ctx, info0.ExecutionInfo, info0.ExecutionState)) - runID0, err1 = s.GetCurrentWorkflowRunID(namespaceID, workflowExecution.GetWorkflowId()) + runID0, err1 = s.GetCurrentWorkflowRunID(s.ctx, namespaceID, workflowExecution.GetWorkflowId()) s.Error(err1) s.Empty(runID0) _, ok := err1.(*serviceerror.NotFound) s.True(ok) // execution record should still be there - _, err2 = s.GetWorkflowMutableState(namespaceID, workflowExecution) + _, err2 = s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err2) } @@ -1581,41 +1591,41 @@ func (s *ExecutionManagerSuite) TestUpdateDeleteWorkflow() { RunId: "6cae4054-6ba7-46d3-8755-e3c2db6f74ea", } - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - runID0, err1 := s.GetCurrentWorkflowRunID(namespaceID, workflowExecution.GetWorkflowId()) + runID0, err1 := s.GetCurrentWorkflowRunID(s.ctx, namespaceID, workflowExecution.GetWorkflowId()) s.NoError(err1) s.Equal(workflowExecution.GetRunId(), runID0) - info0, err2 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + info0, err2 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err2) updatedInfo1 := copyWorkflowExecutionInfo(info0.ExecutionInfo) updatedState1 := copyWorkflowExecutionState(info0.ExecutionState) updatedInfo1.LastWorkflowTaskStartId = int64(2) - err3 := s.UpdateWorkflowExecutionAndFinish(updatedInfo1, updatedState1, int64(6), int64(3)) + err3 := s.UpdateWorkflowExecutionAndFinish(s.ctx, updatedInfo1, updatedState1, int64(6), int64(3)) s.NoError(err3) - runID4, err4 := s.GetCurrentWorkflowRunID(namespaceID, workflowExecution.GetWorkflowId()) + runID4, err4 := s.GetCurrentWorkflowRunID(s.ctx, namespaceID, workflowExecution.GetWorkflowId()) s.NoError(err4) s.Equal(workflowExecution.GetRunId(), runID4) // simulate a timer_task deleting execution after retention - err5 := s.DeleteCurrentWorkflowExecution(info0.ExecutionInfo, info0.ExecutionState) + err5 := s.DeleteCurrentWorkflowExecution(s.ctx, info0.ExecutionInfo, info0.ExecutionState) s.NoError(err5) - err6 := s.DeleteWorkflowExecution(info0.ExecutionInfo, info0.ExecutionState) + err6 := s.DeleteWorkflowExecution(s.ctx, info0.ExecutionInfo, info0.ExecutionState) s.NoError(err6) time.Sleep(time.Duration(finishedCurrentExecutionRetentionTTL*2) * time.Second) - runID0, err1 = s.GetCurrentWorkflowRunID(namespaceID, workflowExecution.GetWorkflowId()) + runID0, err1 = s.GetCurrentWorkflowRunID(s.ctx, namespaceID, workflowExecution.GetWorkflowId()) s.Error(err1) s.Empty(runID0) _, ok := err1.(*serviceerror.NotFound) // execution record should still be there - _, err2 = s.GetWorkflowMutableState(namespaceID, workflowExecution) + _, err2 = s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.Error(err2) _, ok = err2.(*serviceerror.NotFound) s.True(ok) @@ -1629,34 +1639,34 @@ func (s *ExecutionManagerSuite) TestCleanupCorruptedWorkflow() { RunId: "6cae4054-6ba7-46d3-8755-e3c2db6f74ea", } - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - runID0, err1 := s.GetCurrentWorkflowRunID(namespaceID, workflowExecution.GetWorkflowId()) + runID0, err1 := s.GetCurrentWorkflowRunID(s.ctx, namespaceID, workflowExecution.GetWorkflowId()) s.NoError(err1) s.Equal(workflowExecution.GetRunId(), runID0) - info0, err2 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + info0, err2 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err2) // deleting current record and verify - err3 := s.DeleteCurrentWorkflowExecution(info0.ExecutionInfo, info0.ExecutionState) + err3 := s.DeleteCurrentWorkflowExecution(s.ctx, info0.ExecutionInfo, info0.ExecutionState) s.NoError(err3) - runID0, err4 := s.GetCurrentWorkflowRunID(namespaceID, workflowExecution.GetWorkflowId()) + runID0, err4 := s.GetCurrentWorkflowRunID(s.ctx, namespaceID, workflowExecution.GetWorkflowId()) s.Error(err4) s.Empty(runID0) _, ok := err4.(*serviceerror.NotFound) s.True(ok) // we should still be able to load with runID - info1, err5 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + info1, err5 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err5) s.Equal(info0, info1) // mark it as corrupted info0.ExecutionState.State = enumsspb.WORKFLOW_EXECUTION_STATE_CORRUPTED - _, err6 := s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err6 := s.ExecutionManager.UpdateWorkflowExecution(s.ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), UpdateWorkflowMutation: p.WorkflowMutation{ ExecutionInfo: info0.ExecutionInfo, @@ -1671,7 +1681,7 @@ func (s *ExecutionManagerSuite) TestCleanupCorruptedWorkflow() { s.NoError(err6) // we should still be able to load with runID - info2, err7 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + info2, err7 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err7) s.Equal(enumsspb.WORKFLOW_EXECUTION_STATE_CORRUPTED, info2.ExecutionState.State) info2.ExecutionState.State = info1.ExecutionState.State @@ -1679,13 +1689,13 @@ func (s *ExecutionManagerSuite) TestCleanupCorruptedWorkflow() { s.Equal(info2, info1) // delete the run - err8 := s.DeleteCurrentWorkflowExecution(info0.ExecutionInfo, info0.ExecutionState) + err8 := s.DeleteCurrentWorkflowExecution(s.ctx, info0.ExecutionInfo, info0.ExecutionState) s.NoError(err8) - err8 = s.DeleteWorkflowExecution(info0.ExecutionInfo, info0.ExecutionState) + err8 = s.DeleteWorkflowExecution(s.ctx, info0.ExecutionInfo, info0.ExecutionState) s.NoError(err8) // execution record should be gone - _, err9 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + _, err9 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.Error(err9) _, ok = err9.(*serviceerror.NotFound) s.True(ok) @@ -1699,11 +1709,11 @@ func (s *ExecutionManagerSuite) TestGetCurrentWorkflow() { RunId: "6cae4054-6ba7-46d3-8755-e3c2db6f74ea", } - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - response, err := s.ExecutionManager.GetCurrentExecution(&p.GetCurrentExecutionRequest{ + response, err := s.ExecutionManager.GetCurrentExecution(s.ctx, &p.GetCurrentExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), NamespaceID: namespaceID, WorkflowID: workflowExecution.GetWorkflowId(), @@ -1711,16 +1721,16 @@ func (s *ExecutionManagerSuite) TestGetCurrentWorkflow() { s.NoError(err) s.Equal(workflowExecution.GetRunId(), response.RunID) - info0, err2 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + info0, err2 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err2) updatedInfo1 := copyWorkflowExecutionInfo(info0.ExecutionInfo) updatedState1 := copyWorkflowExecutionState(info0.ExecutionState) updatedInfo1.LastWorkflowTaskStartId = int64(2) - err3 := s.UpdateWorkflowExecutionAndFinish(updatedInfo1, updatedState1, int64(6), int64(3)) + err3 := s.UpdateWorkflowExecutionAndFinish(s.ctx, updatedInfo1, updatedState1, int64(6), int64(3)) s.NoError(err3) - runID4, err4 := s.GetCurrentWorkflowRunID(namespaceID, workflowExecution.GetWorkflowId()) + runID4, err4 := s.GetCurrentWorkflowRunID(s.ctx, namespaceID, workflowExecution.GetWorkflowId()) s.NoError(err4) s.Equal(workflowExecution.GetRunId(), runID4) @@ -1729,7 +1739,7 @@ func (s *ExecutionManagerSuite) TestGetCurrentWorkflow() { RunId: "c3ff4bc6-de18-4643-83b2-037a33f45322", } - task1, err5 := s.CreateWorkflowExecution(namespaceID, workflowExecution2, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task1, err5 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution2, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.Error(err5, "Error expected.") s.Empty(task1, "Expected empty task identifier.") } @@ -1747,11 +1757,11 @@ func (s *ExecutionManagerSuite) TestTransferTasksThroughUpdate() { workflowExecution.RunId, ) - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - tasks1, err1 := s.GetTransferTasks(1, false) + tasks1, err1 := s.GetTransferTasks(s.ctx, 1, false) s.NoError(err1) s.NotNil(tasks1, "expected valid list of tasks.") s.Equal(1, len(tasks1), "Expected 1 workflow task.") @@ -1765,19 +1775,19 @@ func (s *ExecutionManagerSuite) TestTransferTasksThroughUpdate() { Version: 0, }, task1) - err3 := s.CompleteTransferTask(task1.GetTaskID()) + err3 := s.CompleteTransferTask(s.ctx, task1.GetTaskID()) s.NoError(err3) - state0, err11 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err11 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err11) info0 := state0.ExecutionInfo updatedInfo := copyWorkflowExecutionInfo(info0) updatedState0 := copyWorkflowExecutionState(state0.ExecutionState) updatedInfo.LastWorkflowTaskStartId = int64(2) - err2 := s.UpdateWorkflowExecution(updatedInfo, updatedState0, int64(5), nil, []int64{int64(4)}, int64(3), nil, nil, nil, nil, nil) + err2 := s.UpdateWorkflowExecution(s.ctx, updatedInfo, updatedState0, int64(5), nil, []int64{int64(4)}, int64(3), nil, nil, nil, nil, nil) s.NoError(err2) - tasks2, err1 := s.GetTransferTasks(1, false) + tasks2, err1 := s.GetTransferTasks(s.ctx, 1, false) s.NoError(err1) s.NotNil(tasks2, "expected valid list of tasks.") s.Equal(1, len(tasks2), "Expected 1 workflow task.") @@ -1792,26 +1802,26 @@ func (s *ExecutionManagerSuite) TestTransferTasksThroughUpdate() { Version: 0, }, task2) - err4 := s.CompleteTransferTask(task2.GetTaskID()) + err4 := s.CompleteTransferTask(s.ctx, task2.GetTaskID()) s.NoError(err4) - state1, _ := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state1, _ := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) info1 := state1.ExecutionInfo updatedInfo1 := copyWorkflowExecutionInfo(info1) updatedState1 := copyWorkflowExecutionState(state1.ExecutionState) updatedInfo1.LastWorkflowTaskStartId = int64(2) - err5 := s.UpdateWorkflowExecutionAndFinish(updatedInfo1, updatedState1, int64(6), int64(5)) + err5 := s.UpdateWorkflowExecutionAndFinish(s.ctx, updatedInfo1, updatedState1, int64(6), int64(5)) s.NoError(err5) newExecution := commonpb.WorkflowExecution{ WorkflowId: workflowExecution.GetWorkflowId(), RunId: "2a038c8f-b575-4151-8d2c-d443e999ab5a", } - runID6, err6 := s.GetCurrentWorkflowRunID(namespaceID, newExecution.GetWorkflowId()) + runID6, err6 := s.GetCurrentWorkflowRunID(s.ctx, namespaceID, newExecution.GetWorkflowId()) s.NoError(err6) s.Equal(workflowExecution.GetRunId(), runID6) - tasks3, err7 := s.GetTransferTasks(1, false) + tasks3, err7 := s.GetTransferTasks(s.ctx, 1, false) s.NoError(err7) s.NotNil(tasks3, "expected valid list of tasks.") s.Equal(1, len(tasks3), "Expected 1 workflow task.") @@ -1823,15 +1833,15 @@ func (s *ExecutionManagerSuite) TestTransferTasksThroughUpdate() { Version: 0, }, task3) - err8 := s.CompleteTransferTask(task3.GetTaskID()) + err8 := s.CompleteTransferTask(s.ctx, task3.GetTaskID()) s.NoError(err8) - _, err9 := s.CreateWorkflowExecution(namespaceID, newExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + _, err9 := s.CreateWorkflowExecution(s.ctx, namespaceID, newExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.Error(err9, "createWFExecution (brand_new) must fail when there is a previous instance of workflow state already in DB") - err10 := s.DeleteCurrentWorkflowExecution(info1, state1.ExecutionState) + err10 := s.DeleteCurrentWorkflowExecution(s.ctx, info1, state1.ExecutionState) s.NoError(err10) - err10 = s.DeleteWorkflowExecution(info1, state1.ExecutionState) + err10 = s.DeleteWorkflowExecution(s.ctx, info1, state1.ExecutionState) s.NoError(err10) } @@ -1848,19 +1858,19 @@ func (s *ExecutionManagerSuite) TestCancelTransferTaskTasks() { workflowExecution.RunId, ) - task0, err := s.CreateWorkflowExecution(namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err) s.NotNil(task0, "Expected non empty task identifier.") - taskD, err := s.GetTransferTasks(1, false) + taskD, err := s.GetTransferTasks(s.ctx, 1, false) s.Equal(1, len(taskD), "Expected 1 workflow task.") - err = s.CompleteTransferTask(taskD[0].GetTaskID()) + err = s.CompleteTransferTask(s.ctx, taskD[0].GetTaskID()) s.NoError(err) // Lookup is time-sensitive, hence retry var deleteCheck []tasks.Task for i := 0; i < 3; i++ { - deleteCheck, err = s.GetTransferTasks(1, false) + deleteCheck, err = s.GetTransferTasks(s.ctx, 1, false) if len(deleteCheck) == 0 { break } @@ -1869,7 +1879,7 @@ func (s *ExecutionManagerSuite) TestCancelTransferTaskTasks() { s.NoError(err) s.Equal(0, len(deleteCheck), "Expected no workflow task.") - state1, err := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state1, err := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err) s.NotNil(state1.ExecutionInfo, "Valid Workflow info expected.") updatedInfo1 := copyWorkflowExecutionInfo(state1.ExecutionInfo) @@ -1888,10 +1898,10 @@ func (s *ExecutionManagerSuite) TestCancelTransferTaskTasks() { TargetChildWorkflowOnly: targetChildWorkflowOnly, InitiatedID: 1, }} - err = s.UpdateWorkflowExecutionWithTransferTasks(updatedInfo1, updatedState1, state1.NextEventId, int64(3), transferTasks, nil) + err = s.UpdateWorkflowExecutionWithTransferTasks(s.ctx, updatedInfo1, updatedState1, state1.NextEventId, int64(3), transferTasks, nil) s.NoError(err) - tasks1, err := s.GetTransferTasks(1, false) + tasks1, err := s.GetTransferTasks(s.ctx, 1, false) s.NoError(err) s.NotNil(tasks1, "expected valid list of tasks.") s.Equal(1, len(tasks1), "Expected 1 cancel task.") @@ -1908,7 +1918,7 @@ func (s *ExecutionManagerSuite) TestCancelTransferTaskTasks() { Version: 0, }, task1) - err = s.CompleteTransferTask(task1.GetTaskID()) + err = s.CompleteTransferTask(s.ctx, task1.GetTaskID()) s.NoError(err) targetNamespaceID = "f2bfaab6-7e8b-4fac-9a62-17da8d37becb" @@ -1925,17 +1935,17 @@ func (s *ExecutionManagerSuite) TestCancelTransferTaskTasks() { InitiatedID: 3, }} - state2, err := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state2, err := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err) info2 := state2.ExecutionInfo s.NotNil(info2, "Valid Workflow info expected.") updatedInfo2 := copyWorkflowExecutionInfo(info2) updatedState2 := copyWorkflowExecutionState(state2.ExecutionState) - err = s.UpdateWorkflowExecutionWithTransferTasks(updatedInfo2, updatedState2, state2.NextEventId, int64(3), transferTasks, nil) + err = s.UpdateWorkflowExecutionWithTransferTasks(s.ctx, updatedInfo2, updatedState2, state2.NextEventId, int64(3), transferTasks, nil) s.NoError(err) - tasks2, err := s.GetTransferTasks(1, false) + tasks2, err := s.GetTransferTasks(s.ctx, 1, false) s.NoError(err) s.NotNil(tasks2, "expected valid list of tasks.") s.Equal(1, len(tasks2), "Expected 1 cancel task.") @@ -1952,7 +1962,7 @@ func (s *ExecutionManagerSuite) TestCancelTransferTaskTasks() { Version: 0, }, task2) - err = s.CompleteTransferTask(task2.GetTaskID()) + err = s.CompleteTransferTask(s.ctx, task2.GetTaskID()) s.NoError(err) } @@ -1982,16 +1992,16 @@ func (s *ExecutionManagerSuite) TestSignalTransferTaskTasks() { workflowExecution.RunId, ) - task0, err := s.CreateWorkflowExecution(namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err) s.NotNil(task0, "Expected non empty task identifier.") - taskD, err := s.GetTransferTasks(1, false) + taskD, err := s.GetTransferTasks(s.ctx, 1, false) s.Equal(1, len(taskD), "Expected 1 workflow task.") - err = s.CompleteTransferTask(taskD[0].GetTaskID()) + err = s.CompleteTransferTask(s.ctx, taskD[0].GetTaskID()) s.NoError(err) - state1, err := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state1, err := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err) info1 := state1.ExecutionInfo s.NotNil(info1, "Valid Workflow info expected.") @@ -2011,10 +2021,10 @@ func (s *ExecutionManagerSuite) TestSignalTransferTaskTasks() { TargetChildWorkflowOnly: targetChildWorkflowOnly, InitiatedID: 1, }} - err = s.UpdateWorkflowExecutionWithTransferTasks(updatedInfo1, updatedState1, state1.NextEventId, int64(3), transferTasks, nil) + err = s.UpdateWorkflowExecutionWithTransferTasks(s.ctx, updatedInfo1, updatedState1, state1.NextEventId, int64(3), transferTasks, nil) s.NoError(err) - tasks1, err := s.GetTransferTasks(1, false) + tasks1, err := s.GetTransferTasks(s.ctx, 1, false) s.NoError(err) s.NotNil(tasks1, "expected valid list of tasks.") s.Equal(1, len(tasks1), "Expected 1 cancel task.") @@ -2031,7 +2041,7 @@ func (s *ExecutionManagerSuite) TestSignalTransferTaskTasks() { Version: 0, }, task1) - err = s.CompleteTransferTask(task1.GetTaskID()) + err = s.CompleteTransferTask(s.ctx, task1.GetTaskID()) s.NoError(err) targetNamespaceID = "f2bfaab6-7e8b-4fac-9a62-17da8d37becb" @@ -2048,17 +2058,17 @@ func (s *ExecutionManagerSuite) TestSignalTransferTaskTasks() { InitiatedID: 3, }} - state2, err := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state2, err := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err) info2 := state2.ExecutionInfo s.NotNil(info2, "Valid Workflow info expected.") updatedInfo2 := copyWorkflowExecutionInfo(info2) updatedState2 := copyWorkflowExecutionState(state2.ExecutionState) - err = s.UpdateWorkflowExecutionWithTransferTasks(updatedInfo2, updatedState2, state2.NextEventId, int64(3), transferTasks, nil) + err = s.UpdateWorkflowExecutionWithTransferTasks(s.ctx, updatedInfo2, updatedState2, state2.NextEventId, int64(3), transferTasks, nil) s.NoError(err) - tasks2, err := s.GetTransferTasks(1, false) + tasks2, err := s.GetTransferTasks(s.ctx, 1, false) s.NoError(err) s.NotNil(tasks2, "expected valid list of tasks.") s.Equal(1, len(tasks2), "Expected 1 cancel task.") @@ -2075,7 +2085,7 @@ func (s *ExecutionManagerSuite) TestSignalTransferTaskTasks() { Version: 0, }, task2) - err = s.CompleteTransferTask(task2.GetTaskID()) + err = s.CompleteTransferTask(s.ctx, task2.GetTaskID()) s.NoError(err) } @@ -2092,15 +2102,15 @@ func (s *ExecutionManagerSuite) TestReplicationTasks() { workflowExecution.RunId, ) - task0, err := s.CreateWorkflowExecution(namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err) s.NotNil(task0, "Expected non empty task identifier.") - taskD, err := s.GetTransferTasks(1, false) + taskD, err := s.GetTransferTasks(s.ctx, 1, false) s.Equal(1, len(taskD), "Expected 1 workflow task.") - err = s.CompleteTransferTask(taskD[0].GetTaskID()) + err = s.CompleteTransferTask(s.ctx, taskD[0].GetTaskID()) s.NoError(err) - state1, err := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state1, err := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err) info1 := state1.ExecutionInfo s.NotNil(info1, "Valid Workflow info expected.") @@ -2129,19 +2139,19 @@ func (s *ExecutionManagerSuite) TestReplicationTasks() { ScheduledID: 99, }, } - err = s.UpdateWorklowStateAndReplication(updatedInfo1, updatedState1, state1.NextEventId, int64(3), replicationTasks) + err = s.UpdateWorklowStateAndReplication(s.ctx, updatedInfo1, updatedState1, state1.NextEventId, int64(3), replicationTasks) s.NoError(err) // test only tasks within requested range will be returned for _, replicationTask := range replicationTasks { taskID := replicationTask.GetTaskID() - tasks, err := s.GetReplicationTasksInRange(taskID, taskID+1, 100) + tasks, err := s.GetReplicationTasksInRange(s.ctx, taskID, taskID+1, 100) s.NoError(err) s.Equal(1, len(tasks)) } // test pagination - respTasks, err := s.GetReplicationTasks(1, true) + respTasks, err := s.GetReplicationTasks(s.ctx, 1, true) s.NoError(err) s.Equal(len(replicationTasks), len(respTasks)) @@ -2158,7 +2168,7 @@ func (s *ExecutionManagerSuite) TestReplicationTasks() { case *tasks.SyncActivityTask: s.Equal(expected.ScheduledID, respTasks[index].(*tasks.SyncActivityTask).ScheduledID) } - err = s.CompleteReplicationTask(respTasks[index].GetTaskID()) + err = s.CompleteReplicationTask(s.ctx, respTasks[index].GetTaskID()) s.NoError(err) } } @@ -2177,11 +2187,11 @@ func (s *ExecutionManagerSuite) TestTransferTasksComplete() { ) taskqueue := "some random taskqueue" - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, taskqueue, "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, taskqueue, "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - tasks1, err1 := s.GetTransferTasks(1, false) + tasks1, err1 := s.GetTransferTasks(s.ctx, 1, false) s.NoError(err1) s.NotNil(tasks1, "expected valid list of tasks.") s.Equal(1, len(tasks1), "Expected 1 workflow task.") @@ -2195,10 +2205,10 @@ func (s *ExecutionManagerSuite) TestTransferTasksComplete() { ScheduleID: scheduleId, Version: 0, }, task1) - err3 := s.CompleteTransferTask(task1.GetTaskID()) + err3 := s.CompleteTransferTask(s.ctx, task1.GetTaskID()) s.NoError(err3) - state0, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) info0 := state0.ExecutionInfo s.NotNil(info0, "Valid Workflow info expected.") @@ -2221,10 +2231,10 @@ func (s *ExecutionManagerSuite) TestTransferTasksComplete() { &tasks.SignalExecutionTask{workflowKey, now, currentTransferID + 10005, targetNamespaceID, targetWorkflowID, targetRunID, true, scheduleID, 555}, &tasks.StartChildExecutionTask{workflowKey, now, currentTransferID + 10006, targetNamespaceID, targetWorkflowID, scheduleID, 666}, } - err2 := s.UpdateWorklowStateAndReplication(updatedInfo, updatedState, int64(6), int64(3), taskSlice) + err2 := s.UpdateWorklowStateAndReplication(s.ctx, updatedInfo, updatedState, int64(6), int64(3), taskSlice) s.NoError(err2) - txTasks, err1 := s.GetTransferTasks(1, true) // use page size one to force pagination + txTasks, err1 := s.GetTransferTasks(s.ctx, 1, true) // use page size one to force pagination s.NoError(err1) s.NotNil(txTasks, "expected valid list of tasks.") s.Equal(len(taskSlice), len(txTasks)) @@ -2245,25 +2255,25 @@ func (s *ExecutionManagerSuite) TestTransferTasksComplete() { s.Equal(int64(555), txTasks[4].GetVersion()) s.Equal(int64(666), txTasks[5].GetVersion()) - err2 = s.CompleteTransferTask(txTasks[0].GetTaskID()) + err2 = s.CompleteTransferTask(s.ctx, txTasks[0].GetTaskID()) s.NoError(err2) - err2 = s.CompleteTransferTask(txTasks[1].GetTaskID()) + err2 = s.CompleteTransferTask(s.ctx, txTasks[1].GetTaskID()) s.NoError(err2) - err2 = s.CompleteTransferTask(txTasks[2].GetTaskID()) + err2 = s.CompleteTransferTask(s.ctx, txTasks[2].GetTaskID()) s.NoError(err2) - err2 = s.CompleteTransferTask(txTasks[3].GetTaskID()) + err2 = s.CompleteTransferTask(s.ctx, txTasks[3].GetTaskID()) s.NoError(err2) - err2 = s.CompleteTransferTask(txTasks[4].GetTaskID()) + err2 = s.CompleteTransferTask(s.ctx, txTasks[4].GetTaskID()) s.NoError(err2) - err2 = s.CompleteTransferTask(txTasks[5].GetTaskID()) + err2 = s.CompleteTransferTask(s.ctx, txTasks[5].GetTaskID()) s.NoError(err2) - txTasks, err2 = s.GetTransferTasks(100, false) + txTasks, err2 = s.GetTransferTasks(s.ctx, 100, false) s.NoError(err2) s.Empty(txTasks, "expected empty task queue.") } @@ -2282,11 +2292,11 @@ func (s *ExecutionManagerSuite) TestTransferTasksRangeComplete() { ) taskqueue := "some random taskqueue" - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, taskqueue, "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, taskqueue, "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - tasks1, err1 := s.GetTransferTasks(1, false) + tasks1, err1 := s.GetTransferTasks(s.ctx, 1, false) s.NoError(err1) s.NotNil(tasks1, "expected valid list of tasks.") s.Equal(1, len(tasks1), "Expected 1 workflow task.") @@ -2300,10 +2310,10 @@ func (s *ExecutionManagerSuite) TestTransferTasksRangeComplete() { Version: 0, }, task1) - err3 := s.CompleteTransferTask(task1.GetTaskID()) + err3 := s.CompleteTransferTask(s.ctx, task1.GetTaskID()) s.NoError(err3) - state0, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) info0 := state0.ExecutionInfo s.NotNil(info0, "Valid Workflow info expected.") @@ -2325,10 +2335,10 @@ func (s *ExecutionManagerSuite) TestTransferTasksRangeComplete() { &tasks.SignalExecutionTask{workflowKey, now, currentTransferID + 10005, targetNamespaceID, targetWorkflowID, targetRunID, true, scheduleID, 555}, &tasks.StartChildExecutionTask{workflowKey, now, currentTransferID + 10006, targetNamespaceID, targetWorkflowID, scheduleID, 666}, } - err2 := s.UpdateWorklowStateAndReplication(updatedInfo, updatedState, int64(6), int64(3), taskSlice) + err2 := s.UpdateWorklowStateAndReplication(s.ctx, updatedInfo, updatedState, int64(6), int64(3), taskSlice) s.NoError(err2) - txTasks, err1 := s.GetTransferTasks(2, true) // use page size one to force pagination + txTasks, err1 := s.GetTransferTasks(s.ctx, 2, true) // use page size one to force pagination s.NoError(err1) s.NotNil(txTasks, "expected valid list of tasks.") s.Equal(len(taskSlice), len(txTasks)) @@ -2355,10 +2365,10 @@ func (s *ExecutionManagerSuite) TestTransferTasksRangeComplete() { s.Equal(currentTransferID+10005, txTasks[4].GetTaskID()) s.Equal(currentTransferID+10006, txTasks[5].GetTaskID()) - err2 = s.RangeCompleteTransferTask(txTasks[0].GetTaskID(), txTasks[5].GetTaskID()+1) + err2 = s.RangeCompleteTransferTask(s.ctx, txTasks[0].GetTaskID(), txTasks[5].GetTaskID()+1) s.NoError(err2) - txTasks, err2 = s.GetTransferTasks(100, false) + txTasks, err2 = s.GetTransferTasks(s.ctx, 100, false) s.NoError(err2) s.Empty(txTasks, "expected empty task queue.") } @@ -2379,11 +2389,11 @@ func (s *ExecutionManagerSuite) TestTimerTasksComplete() { now := time.Now().UTC() initialTasks := []tasks.Task{&tasks.WorkflowTaskTimeoutTask{workflowKey, now.Add(1 * time.Second), 1, 2, 3, enumspb.TIMEOUT_TYPE_START_TO_CLOSE, 11}} - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, "taskQueue", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, initialTasks) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "taskQueue", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, initialTasks) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - state0, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) info0 := state0.ExecutionInfo s.NotNil(info0, "Valid Workflow info expected.") @@ -2397,13 +2407,13 @@ func (s *ExecutionManagerSuite) TestTimerTasksComplete() { &tasks.ActivityTimeoutTask{WorkflowKey: workflowKey, VisibilityTimestamp: now.Add(3 * time.Second), TaskID: 4, TimeoutType: enumspb.TIMEOUT_TYPE_START_TO_CLOSE, EventID: 7, Version: 14}, &tasks.UserTimerTask{WorkflowKey: workflowKey, VisibilityTimestamp: now.Add(3 * time.Second), TaskID: 5, EventID: 7, Version: 15}, } - err2 := s.UpdateWorkflowExecution(updatedInfo, updatedState, int64(5), []int64{int64(4)}, nil, int64(3), taskSlice, nil, nil, nil, nil) + err2 := s.UpdateWorkflowExecution(s.ctx, updatedInfo, updatedState, int64(5), []int64{int64(4)}, nil, int64(3), taskSlice, nil, nil, nil, nil) s.NoError(err2) // try the following a couple of times to give cassandra time to catch up var timerTasks []tasks.Task for i := 0; i < 3; i++ { - timerTasks, err1 = s.GetTimerTasks(1, true) // use page size one to force pagination + timerTasks, err1 = s.GetTimerTasks(s.ctx, 1, true) // use page size one to force pagination s.NoError(err1) s.NotNil(timerTasks, "expected valid list of tasks.") if len(taskSlice)+len(initialTasks) == len(timerTasks) { @@ -2425,10 +2435,10 @@ func (s *ExecutionManagerSuite) TestTimerTasksComplete() { visTimer0 := timerTasks[0].GetVisibilityTime() visTimer4 := timerTasks[4].GetVisibilityTime().Add(1 * time.Second) - err2 = s.RangeCompleteTimerTask(visTimer0, visTimer4) + err2 = s.RangeCompleteTimerTask(s.ctx, visTimer0, visTimer4) s.NoError(err2) - timerTasks2, err2 := s.GetTimerTasks(100, false) + timerTasks2, err2 := s.GetTimerTasks(s.ctx, 100, false) s.NoError(err2) s.Empty(timerTasks2, "expected empty task queue.") } @@ -2446,11 +2456,11 @@ func (s *ExecutionManagerSuite) TestTimerTasksRangeComplete() { workflowExecution.RunId, ) - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, "taskQueue", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "taskQueue", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - state0, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) info0 := state0.ExecutionInfo s.NotNil(info0, "Valid Workflow info expected.") @@ -2465,13 +2475,13 @@ func (s *ExecutionManagerSuite) TestTimerTasksRangeComplete() { &tasks.ActivityTimeoutTask{WorkflowKey: workflowKey, VisibilityTimestamp: time.Now().UTC(), TaskID: 4, TimeoutType: enumspb.TIMEOUT_TYPE_START_TO_CLOSE, EventID: 7, Version: 14}, &tasks.UserTimerTask{WorkflowKey: workflowKey, VisibilityTimestamp: time.Now().UTC(), TaskID: 5, EventID: 7, Version: 15}, } - err2 := s.UpdateWorkflowExecution(updatedInfo, updatedState, int64(5), []int64{int64(4)}, nil, int64(3), taskSlice, nil, nil, nil, nil) + err2 := s.UpdateWorkflowExecution(s.ctx, updatedInfo, updatedState, int64(5), []int64{int64(4)}, nil, int64(3), taskSlice, nil, nil, nil, nil) s.NoError(err2) var timerTasks []tasks.Task // Try a couple of times to avoid flakiness for i := 0; i < 3; i++ { - timerTasks, err1 = s.GetTimerTasks(1, true) // use page size one to force pagination + timerTasks, err1 = s.GetTimerTasks(s.ctx, 1, true) // use page size one to force pagination if len(taskSlice) == len(timerTasks) { break } @@ -2492,25 +2502,25 @@ func (s *ExecutionManagerSuite) TestTimerTasksRangeComplete() { s.Equal(int64(14), timerTasks[3].GetVersion()) s.Equal(int64(15), timerTasks[4].GetVersion()) - err2 = s.UpdateWorkflowExecution(updatedInfo, updatedState, int64(5), nil, nil, int64(5), nil, nil, nil, nil, nil) + err2 = s.UpdateWorkflowExecution(s.ctx, updatedInfo, updatedState, int64(5), nil, nil, int64(5), nil, nil, nil, nil, nil) s.NoError(err2) - err2 = s.CompleteTimerTask(timerTasks[0].GetVisibilityTime(), timerTasks[0].GetTaskID()) + err2 = s.CompleteTimerTask(s.ctx, timerTasks[0].GetVisibilityTime(), timerTasks[0].GetTaskID()) s.NoError(err2) - err2 = s.CompleteTimerTask(timerTasks[1].GetVisibilityTime(), timerTasks[1].GetTaskID()) + err2 = s.CompleteTimerTask(s.ctx, timerTasks[1].GetVisibilityTime(), timerTasks[1].GetTaskID()) s.NoError(err2) - err2 = s.CompleteTimerTask(timerTasks[2].GetVisibilityTime(), timerTasks[2].GetTaskID()) + err2 = s.CompleteTimerTask(s.ctx, timerTasks[2].GetVisibilityTime(), timerTasks[2].GetTaskID()) s.NoError(err2) - err2 = s.CompleteTimerTask(timerTasks[3].GetVisibilityTime(), timerTasks[3].GetTaskID()) + err2 = s.CompleteTimerTask(s.ctx, timerTasks[3].GetVisibilityTime(), timerTasks[3].GetTaskID()) s.NoError(err2) - err2 = s.CompleteTimerTask(timerTasks[4].GetVisibilityTime(), timerTasks[4].GetTaskID()) + err2 = s.CompleteTimerTask(s.ctx, timerTasks[4].GetVisibilityTime(), timerTasks[4].GetTaskID()) s.NoError(err2) - timerTasks2, err2 := s.GetTimerTasks(100, false) + timerTasks2, err2 := s.GetTimerTasks(s.ctx, 100, false) s.NoError(err2) s.Empty(timerTasks2, "expected empty task queue.") } @@ -2523,11 +2533,11 @@ func (s *ExecutionManagerSuite) TestWorkflowMutableStateActivities() { RunId: "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", } - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, "taskQueue", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "taskQueue", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - state0, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) info0 := state0.ExecutionInfo s.NotNil(info0, "Valid Workflow info expected.") @@ -2569,10 +2579,10 @@ func (s *ExecutionManagerSuite) TestWorkflowMutableStateActivities() { RetryLastWorkerIdentity: uuid.New(), RetryLastFailure: failure.NewServerFailure("some random error", false), }} - err2 := s.UpdateWorkflowExecution(updatedInfo, updatedState, int64(5), []int64{int64(4)}, nil, int64(3), nil, activityInfos, nil, nil, nil) + err2 := s.UpdateWorkflowExecution(s.ctx, updatedInfo, updatedState, int64(5), []int64{int64(4)}, nil, int64(3), nil, activityInfos, nil, nil, nil) s.NoError(err2) - state, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) s.NotNil(state, "expected valid state.") s.Equal(1, len(state.ActivityInfos)) @@ -2611,10 +2621,10 @@ func (s *ExecutionManagerSuite) TestWorkflowMutableStateActivities() { s.Equal(activityInfos[0].RetryLastFailure, ai.RetryLastFailure) s.Equal(activityInfos[0].RetryLastWorkerIdentity, ai.RetryLastWorkerIdentity) - err2 = s.UpdateWorkflowExecution(updatedInfo, updatedState, int64(5), nil, nil, int64(5), nil, nil, []int64{1}, nil, nil) + err2 = s.UpdateWorkflowExecution(s.ctx, updatedInfo, updatedState, int64(5), nil, nil, int64(5), nil, nil, []int64{1}, nil, nil) s.NoError(err2) - state, err2 = s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err2 = s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err2) s.NotNil(state, "expected valid state.") s.Equal(0, len(state.ActivityInfos)) @@ -2628,11 +2638,11 @@ func (s *ExecutionManagerSuite) TestWorkflowMutableStateTimers() { RunId: "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", } - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, "taskQueue", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "taskQueue", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - state0, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) info0 := state0.ExecutionInfo s.NotNil(info0, "Valid Workflow info expected.") @@ -2649,10 +2659,10 @@ func (s *ExecutionManagerSuite) TestWorkflowMutableStateTimers() { TaskStatus: 2, StartedId: 5, }} - err2 := s.UpdateWorkflowExecution(updatedInfo, updatedState, int64(5), []int64{int64(4)}, nil, int64(3), nil, nil, nil, timerInfos, nil) + err2 := s.UpdateWorkflowExecution(s.ctx, updatedInfo, updatedState, int64(5), []int64{int64(4)}, nil, int64(3), nil, nil, nil, timerInfos, nil) s.NoError(err2) - state, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) s.NotNil(state, "expected valid state.") s.Equal(1, len(state.TimerInfos)) @@ -2662,10 +2672,10 @@ func (s *ExecutionManagerSuite) TestWorkflowMutableStateTimers() { s.Equal(int64(2), state.TimerInfos[timerID].TaskStatus) s.Equal(int64(5), state.TimerInfos[timerID].GetStartedId()) - err2 = s.UpdateWorkflowExecution(updatedInfo, updatedState, int64(5), nil, nil, int64(5), nil, nil, nil, nil, []string{timerID}) + err2 = s.UpdateWorkflowExecution(s.ctx, updatedInfo, updatedState, int64(5), nil, nil, int64(5), nil, nil, nil, nil, []string{timerID}) s.NoError(err2) - state, err2 = s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err2 = s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err2) s.NotNil(state, "expected valid state.") s.Equal(0, len(state.TimerInfos)) @@ -2685,11 +2695,11 @@ func (s *ExecutionManagerSuite) TestWorkflowMutableStateChildExecutions() { RunId: "73e89362-25ec-4305-bcb8-d9448b90856c", } - task0, err0 := s.CreateChildWorkflowExecution(namespaceID, workflowExecution, parentNamespaceID, parentExecution, 1, "taskQueue", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err0 := s.CreateChildWorkflowExecution(s.ctx, namespaceID, workflowExecution, parentNamespaceID, parentExecution, 1, "taskQueue", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - state0, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) info0 := state0.ExecutionInfo s.NotNil(info0, "Valid Workflow info expected.") @@ -2709,10 +2719,10 @@ func (s *ExecutionManagerSuite) TestWorkflowMutableStateChildExecutions() { CreateRequestId: createRequestID, ParentClosePolicy: enumspb.PARENT_CLOSE_POLICY_TERMINATE, }} - err2 := s.UpsertChildExecutionsState(updatedInfo, updatedState, int64(5), int64(3), childExecutionInfos) + err2 := s.UpsertChildExecutionsState(s.ctx, updatedInfo, updatedState, int64(5), int64(3), childExecutionInfos) s.NoError(err2) - state, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) s.NotNil(state, "expected valid state.") s.Equal(1, len(state.ChildExecutionInfos)) @@ -2725,10 +2735,10 @@ func (s *ExecutionManagerSuite) TestWorkflowMutableStateChildExecutions() { s.Equal(int64(2), ci.StartedId) s.Equal(createRequestID, ci.CreateRequestId) - err2 = s.DeleteChildExecutionsState(updatedInfo, updatedState, int64(5), int64(5), int64(1)) + err2 = s.DeleteChildExecutionsState(s.ctx, updatedInfo, updatedState, int64(5), int64(5), int64(1)) s.NoError(err2) - state, err2 = s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err2 = s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err2) s.NotNil(state, "expected valid state.") s.Equal(0, len(state.ChildExecutionInfos)) @@ -2742,11 +2752,11 @@ func (s *ExecutionManagerSuite) TestWorkflowMutableStateRequestCancel() { RunId: "87f96253-b925-426e-90db-aa4ee89b5aca", } - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, "taskQueue", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "taskQueue", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - state0, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) info0 := state0.ExecutionInfo s.NotNil(info0, "Valid Workflow info expected.") @@ -2760,10 +2770,10 @@ func (s *ExecutionManagerSuite) TestWorkflowMutableStateRequestCancel() { InitiatedEventBatchId: 1, CancelRequestId: uuid.New(), } - err2 := s.UpsertRequestCancelState(updatedInfo, updatedState, int64(5), int64(3), []*persistencespb.RequestCancelInfo{requestCancelInfo}) + err2 := s.UpsertRequestCancelState(s.ctx, updatedInfo, updatedState, int64(5), int64(3), []*persistencespb.RequestCancelInfo{requestCancelInfo}) s.NoError(err2) - state, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) s.NotNil(state, "expected valid state.") s.Equal(1, len(state.RequestCancelInfos)) @@ -2772,10 +2782,10 @@ func (s *ExecutionManagerSuite) TestWorkflowMutableStateRequestCancel() { s.NotNil(ri) s.Equal(requestCancelInfo, ri) - err2 = s.DeleteCancelState(updatedInfo, updatedState, int64(5), int64(5), requestCancelInfo.GetInitiatedId()) + err2 = s.DeleteCancelState(s.ctx, updatedInfo, updatedState, int64(5), int64(5), requestCancelInfo.GetInitiatedId()) s.NoError(err2) - state, err2 = s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err2 = s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err2) s.NotNil(state, "expected valid state.") s.Equal(0, len(state.RequestCancelInfos)) @@ -2790,11 +2800,11 @@ func (s *ExecutionManagerSuite) TestWorkflowMutableStateSignalInfo() { RunId: runID, } - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, "taskQueue", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "taskQueue", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - state0, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) info0 := state0.ExecutionInfo s.NotNil(info0, "Valid Workflow info expected.") @@ -2808,10 +2818,10 @@ func (s *ExecutionManagerSuite) TestWorkflowMutableStateSignalInfo() { InitiatedEventBatchId: 1, RequestId: uuid.New(), } - err2 := s.UpsertSignalInfoState(updatedInfo, updatedState, int64(5), int64(3), []*persistencespb.SignalInfo{signalInfo}) + err2 := s.UpsertSignalInfoState(s.ctx, updatedInfo, updatedState, int64(5), int64(3), []*persistencespb.SignalInfo{signalInfo}) s.NoError(err2) - state, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) s.NotNil(state, "expected valid state.") s.Equal(1, len(state.SignalInfos)) @@ -2820,10 +2830,10 @@ func (s *ExecutionManagerSuite) TestWorkflowMutableStateSignalInfo() { s.NotNil(si) s.Equal(signalInfo, si) - err2 = s.DeleteSignalState(updatedInfo, updatedState, int64(5), int64(5), signalInfo.GetInitiatedId()) + err2 = s.DeleteSignalState(s.ctx, updatedInfo, updatedState, int64(5), int64(5), signalInfo.GetInitiatedId()) s.NoError(err2) - state, err2 = s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err2 = s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err2) s.NotNil(state, "expected valid state.") s.Equal(0, len(state.SignalInfos)) @@ -2838,11 +2848,11 @@ func (s *ExecutionManagerSuite) TestWorkflowMutableStateSignalRequested() { RunId: runID, } - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, "taskQueue", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "taskQueue", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - state0, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) info0 := state0.ExecutionInfo s.NotNil(info0, "Valid Workflow info expected.") @@ -2852,19 +2862,19 @@ func (s *ExecutionManagerSuite) TestWorkflowMutableStateSignalRequested() { updatedInfo.LastWorkflowTaskStartId = int64(2) signalRequestedID := uuid.New() signalsRequested := []string{signalRequestedID} - err2 := s.UpsertSignalsRequestedState(updatedInfo, updatedState, int64(5), int64(3), signalsRequested) + err2 := s.UpsertSignalsRequestedState(s.ctx, updatedInfo, updatedState, int64(5), int64(3), signalsRequested) s.NoError(err2) - state, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) s.NotNil(state, "expected valid state.") s.Equal(1, len(state.SignalRequestedIds)) s.Equal(signalRequestedID, state.SignalRequestedIds[0]) - err2 = s.DeleteSignalsRequestedState(updatedInfo, updatedState, int64(5), int64(5), []string{signalRequestedID, uuid.New()}) + err2 = s.DeleteSignalsRequestedState(s.ctx, updatedInfo, updatedState, int64(5), int64(5), []string{signalRequestedID, uuid.New()}) s.NoError(err2) - state, err2 = s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err2 = s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err2) s.NotNil(state, "expected valid state.") s.Equal(0, len(state.SignalRequestedIds)) @@ -2878,11 +2888,11 @@ func (s *ExecutionManagerSuite) TestWorkflowMutableStateInfo() { RunId: "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", } - task0, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, "taskQueue", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "taskQueue", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) s.NotNil(task0, "Expected non empty task identifier.") - state0, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) info0 := state0.ExecutionInfo s.NotNil(info0, "Valid Workflow info expected.") @@ -2892,10 +2902,10 @@ func (s *ExecutionManagerSuite) TestWorkflowMutableStateInfo() { updatedNextEventID := int64(5) updatedInfo.LastWorkflowTaskStartId = int64(2) - err2 := s.UpdateWorkflowExecution(updatedInfo, updatedState, updatedNextEventID, []int64{int64(4)}, nil, int64(3), nil, nil, nil, nil, nil) + err2 := s.UpdateWorkflowExecution(s.ctx, updatedInfo, updatedState, updatedNextEventID, []int64{int64(4)}, nil, int64(3), nil, nil, nil, nil, nil) s.NoError(err2) - state, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) s.NotNil(state, "expected valid state.") s.NotNil(state.ExecutionInfo, "expected valid MS Info state.") @@ -2911,10 +2921,10 @@ func (s *ExecutionManagerSuite) TestContinueAsNew() { RunId: "551c88d2-d9e6-404f-8131-9eec14f36643", } - _, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + _, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) - state0, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) info0 := state0.ExecutionInfo continueAsNewInfo := copyWorkflowExecutionInfo(info0) @@ -2942,10 +2952,10 @@ func (s *ExecutionManagerSuite) TestContinueAsNew() { }, } - err2 := s.ContinueAsNewExecution(continueAsNewInfo, continueAsNewState, continueAsNewNextEventID, state0.NextEventId, newWorkflowExecution, int64(3), int64(2), &testResetPoints) + err2 := s.ContinueAsNewExecution(s.ctx, continueAsNewInfo, continueAsNewState, continueAsNewNextEventID, state0.NextEventId, newWorkflowExecution, int64(3), int64(2), &testResetPoints) s.NoError(err2) - prevExecutionState, err3 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + prevExecutionState, err3 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err3) prevExecutionInfo := prevExecutionState.ExecutionInfo s.Equal(enumsspb.WORKFLOW_EXECUTION_STATE_COMPLETED, prevExecutionState.ExecutionState.State) @@ -2954,7 +2964,7 @@ func (s *ExecutionManagerSuite) TestContinueAsNew() { s.Equal(int64(2), prevExecutionInfo.LastWorkflowTaskStartId) s.True(reflect.DeepEqual(prevExecutionInfo.AutoResetPoints, &workflowpb.ResetPoints{})) - newExecutionState, err4 := s.GetWorkflowMutableState(namespaceID, newWorkflowExecution) + newExecutionState, err4 := s.GetWorkflowMutableState(s.ctx, namespaceID, newWorkflowExecution) s.NoError(err4) newExecutionInfo := newExecutionState.ExecutionInfo s.Equal(enumsspb.WORKFLOW_EXECUTION_STATE_CREATED, newExecutionState.ExecutionState.State) @@ -2964,14 +2974,14 @@ func (s *ExecutionManagerSuite) TestContinueAsNew() { s.Equal(int64(2), newExecutionInfo.WorkflowTaskScheduleId) s.Equal(testResetPoints.String(), newExecutionInfo.AutoResetPoints.String()) - newRunID, err5 := s.GetCurrentWorkflowRunID(namespaceID, workflowExecution.GetWorkflowId()) + newRunID, err5 := s.GetCurrentWorkflowRunID(s.ctx, namespaceID, workflowExecution.GetWorkflowId()) s.NoError(err5) s.Equal(newWorkflowExecution.GetRunId(), newRunID) } // TestReplicationTransferTaskTasks test func (s *ExecutionManagerSuite) TestReplicationTransferTaskTasks() { - s.ClearReplicationQueue() + s.ClearReplicationQueue(s.ctx) namespaceID := "2466d7de-6602-4ad8-b939-fb8f8c36c711" workflowExecution := commonpb.WorkflowExecution{ WorkflowId: "replication-transfer-task-test", @@ -2983,16 +2993,16 @@ func (s *ExecutionManagerSuite) TestReplicationTransferTaskTasks() { workflowExecution.RunId, ) - task0, err := s.CreateWorkflowExecution(namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err) s.NotNil(task0, "Expected non empty task identifier.") - taskD, err := s.GetTransferTasks(1, false) + taskD, err := s.GetTransferTasks(s.ctx, 1, false) s.Equal(1, len(taskD), "Expected 1 workflow task.") - err = s.CompleteTransferTask(taskD[0].GetTaskID()) + err = s.CompleteTransferTask(s.ctx, taskD[0].GetTaskID()) s.NoError(err) - state1, err := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state1, err := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err) info1 := state1.ExecutionInfo s.NotNil(info1, "Valid Workflow info expected.") @@ -3006,10 +3016,10 @@ func (s *ExecutionManagerSuite) TestReplicationTransferTaskTasks() { NextEventID: int64(3), Version: int64(9), }} - err = s.UpdateWorklowStateAndReplication(updatedInfo1, updatedState1, state1.NextEventId, int64(3), replicationTasks) + err = s.UpdateWorklowStateAndReplication(s.ctx, updatedInfo1, updatedState1, state1.NextEventId, int64(3), replicationTasks) s.NoError(err) - tasks1, err := s.GetReplicationTasks(1, false) + tasks1, err := s.GetReplicationTasks(s.ctx, 1, false) s.NoError(err) s.NotNil(tasks1, "expected valid list of tasks.") s.Equal(1, len(tasks1), "Expected 1 replication task.") @@ -3025,20 +3035,20 @@ func (s *ExecutionManagerSuite) TestReplicationTransferTaskTasks() { NewRunBranchToken: nil, }, task1) - err = s.CompleteReplicationTask(task1.GetTaskID()) + err = s.CompleteReplicationTask(s.ctx, task1.GetTaskID()) s.NoError(err) // NOTE: GetReplicationTasks will return empty result even if // there's no CompleteReplicationTask above because the minTaskID // already advanced beyond task1's ID inside GetReplicationTasks - tasks2, err := s.GetReplicationTasks(1, false) + tasks2, err := s.GetReplicationTasks(s.ctx, 1, false) s.NoError(err) s.Equal(0, len(tasks2)) } // TestReplicationTransferTaskRangeComplete test func (s *ExecutionManagerSuite) TestReplicationTransferTaskRangeComplete() { - s.ClearReplicationQueue() + s.ClearReplicationQueue(s.ctx) namespaceID := uuid.New() workflowExecution := commonpb.WorkflowExecution{ WorkflowId: "replication-transfer-task--range-complete-test", @@ -3050,16 +3060,16 @@ func (s *ExecutionManagerSuite) TestReplicationTransferTaskRangeComplete() { workflowExecution.RunId, ) - task0, err := s.CreateWorkflowExecution(namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + task0, err := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err) s.NotNil(task0, "Expected non empty task identifier.") - taskD, err := s.GetTransferTasks(1, false) + taskD, err := s.GetTransferTasks(s.ctx, 1, false) s.Equal(1, len(taskD), "Expected 1 workflow task.") - err = s.CompleteTransferTask(taskD[0].GetTaskID()) + err = s.CompleteTransferTask(s.ctx, taskD[0].GetTaskID()) s.NoError(err) - state1, err := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state1, err := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err) info1 := state1.ExecutionInfo s.NotNil(info1, "Valid Workflow info expected.") @@ -3083,6 +3093,7 @@ func (s *ExecutionManagerSuite) TestReplicationTransferTaskRangeComplete() { }, } err = s.UpdateWorklowStateAndReplication( + s.ctx, updatedInfo1, updatedState1, state1.NextEventId, @@ -3091,7 +3102,7 @@ func (s *ExecutionManagerSuite) TestReplicationTransferTaskRangeComplete() { ) s.NoError(err) - tasks1, err := s.GetReplicationTasks(1, false) + tasks1, err := s.GetReplicationTasks(s.ctx, 1, false) s.NoError(err) s.NotNil(tasks1, "expected valid list of tasks.") s.Equal(1, len(tasks1), "Expected 1 replication task.") @@ -3107,9 +3118,9 @@ func (s *ExecutionManagerSuite) TestReplicationTransferTaskRangeComplete() { NewRunBranchToken: nil, }, task1) - err = s.RangeCompleteReplicationTask(task1.GetTaskID() + 1) + err = s.RangeCompleteReplicationTask(s.ctx, task1.GetTaskID()+1) s.NoError(err) - tasks2, err := s.GetReplicationTasks(1, false) + tasks2, err := s.GetReplicationTasks(s.ctx, 1, false) s.NoError(err) s.NotNil(tasks2, "expected valid list of tasks.") task2 := tasks2[0] @@ -3123,9 +3134,9 @@ func (s *ExecutionManagerSuite) TestReplicationTransferTaskRangeComplete() { BranchToken: nil, NewRunBranchToken: nil, }, task2) - err = s.CompleteReplicationTask(task2.GetTaskID()) + err = s.CompleteReplicationTask(s.ctx, task2.GetTaskID()) s.NoError(err) - tasks3, err := s.GetReplicationTasks(1, false) + tasks3, err := s.GetReplicationTasks(s.ctx, 1, false) s.NoError(err) s.Equal(0, len(tasks3)) } @@ -3313,14 +3324,14 @@ func (s *ExecutionManagerSuite) TestReplicationDLQ() { TaskId: 0, TaskType: 1, } - err := s.PutReplicationTaskToDLQ(sourceCluster, taskInfo) + err := s.PutReplicationTaskToDLQ(s.ctx, sourceCluster, taskInfo) s.NoError(err) - resp, err := s.GetReplicationTasksFromDLQ(sourceCluster, 0, 1, 1, nil) + resp, err := s.GetReplicationTasksFromDLQ(s.ctx, sourceCluster, 0, 1, 1, nil) s.NoError(err) s.Len(resp.Tasks, 1) - err = s.DeleteReplicationTaskFromDLQ(sourceCluster, 0) + err = s.DeleteReplicationTaskFromDLQ(s.ctx, sourceCluster, 0) s.NoError(err) - resp, err = s.GetReplicationTasksFromDLQ(sourceCluster, 0, 1, 1, nil) + resp, err = s.GetReplicationTasksFromDLQ(s.ctx, sourceCluster, 0, 1, 1, nil) s.NoError(err) s.Len(resp.Tasks, 0) @@ -3338,19 +3349,19 @@ func (s *ExecutionManagerSuite) TestReplicationDLQ() { TaskId: 10, TaskType: 1, } - err = s.PutReplicationTaskToDLQ(sourceCluster, taskInfo1) + err = s.PutReplicationTaskToDLQ(s.ctx, sourceCluster, taskInfo1) s.NoError(err) - err = s.PutReplicationTaskToDLQ(sourceCluster, taskInfo2) + err = s.PutReplicationTaskToDLQ(s.ctx, sourceCluster, taskInfo2) s.NoError(err) - resp, err = s.GetReplicationTasksFromDLQ(sourceCluster, 1, 5, 100, nil) + resp, err = s.GetReplicationTasksFromDLQ(s.ctx, sourceCluster, 1, 5, 100, nil) s.NoError(err) s.Len(resp.Tasks, 1) - resp, err = s.GetReplicationTasksFromDLQ(sourceCluster, 1, 11, 2, nil) + resp, err = s.GetReplicationTasksFromDLQ(s.ctx, sourceCluster, 1, 11, 2, nil) s.NoError(err) s.Len(resp.Tasks, 2) - err = s.RangeDeleteReplicationTaskFromDLQ(sourceCluster, 1, 11) + err = s.RangeDeleteReplicationTaskFromDLQ(s.ctx, sourceCluster, 1, 11) s.NoError(err) - resp, err = s.GetReplicationTasksFromDLQ(sourceCluster, 1, 11, 2, nil) + resp, err = s.GetReplicationTasksFromDLQ(s.ctx, sourceCluster, 1, 11, 2, nil) s.NoError(err) s.Len(resp.Tasks, 0) } diff --git a/common/persistence/persistence-tests/executionManagerTestForEventsV2.go b/common/persistence/persistence-tests/executionManagerTestForEventsV2.go index 395e315feaf..9bb425bbb13 100644 --- a/common/persistence/persistence-tests/executionManagerTestForEventsV2.go +++ b/common/persistence/persistence-tests/executionManagerTestForEventsV2.go @@ -25,6 +25,7 @@ package persistencetests import ( + "context" "runtime/debug" "testing" "time" @@ -52,6 +53,9 @@ type ( // override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test, // not merely log an error *require.Assertions + + ctx context.Context + cancel context.CancelFunc } ) @@ -79,7 +83,13 @@ func (s *ExecutionManagerSuiteForEventsV2) SetupTest() { defer failOnPanic(s.T()) // Have to define our overridden assertions in the test setup. If we did it earlier, s.T() will return nil s.Assertions = require.New(s.T()) - s.ClearTasks() + s.ctx, s.cancel = context.WithTimeout(context.Background(), time.Second*30) + + s.ClearTasks(s.ctx) +} + +func (s *ExecutionManagerSuiteForEventsV2) TearDownTest() { + s.cancel() } func (s *ExecutionManagerSuiteForEventsV2) newRandomChecksum() *persistencespb.Checksum { @@ -115,7 +125,7 @@ func (s *ExecutionManagerSuiteForEventsV2) TestWorkflowCreation() { csum := s.newRandomChecksum() - _, err0 := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err0 := s.ExecutionManager.CreateWorkflowExecution(s.ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), NewWorkflowSnapshot: p.WorkflowSnapshot{ ExecutionInfo: &persistencespb.WorkflowExecutionInfo{ @@ -156,7 +166,7 @@ func (s *ExecutionManagerSuiteForEventsV2) TestWorkflowCreation() { s.NoError(err0) - state0, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) info0 := state0.ExecutionInfo s.NotNil(info0, "Valid Workflow info expected.") @@ -175,10 +185,10 @@ func (s *ExecutionManagerSuiteForEventsV2) TestWorkflowCreation() { StartedId: 5, }} - err2 := s.UpdateWorkflowExecution(updatedInfo, updatedState, int64(5), []int64{int64(4)}, nil, int64(3), nil, nil, nil, timerInfos, nil) + err2 := s.UpdateWorkflowExecution(s.ctx, updatedInfo, updatedState, int64(5), []int64{int64(4)}, nil, int64(3), nil, nil, nil, timerInfos, nil) s.NoError(err2) - state, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) s.NotNil(state, "expected valid state.") s.Equal(1, len(state.TimerInfos)) @@ -189,10 +199,10 @@ func (s *ExecutionManagerSuiteForEventsV2) TestWorkflowCreation() { s.Equal(int64(5), state.TimerInfos[timerID].GetStartedId()) s.assertChecksumsEqual(testWorkflowChecksum, state.Checksum) - err2 = s.UpdateWorkflowExecution(updatedInfo, updatedState, int64(5), nil, nil, int64(5), nil, nil, nil, nil, []string{timerID}) + err2 = s.UpdateWorkflowExecution(s.ctx, updatedInfo, updatedState, int64(5), nil, nil, int64(5), nil, nil, nil, nil, []string{timerID}) s.NoError(err2) - state, err2 = s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err2 = s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err2) s.NotNil(state, "expected valid state.") s.Equal(0, len(state.TimerInfos)) @@ -215,7 +225,7 @@ func (s *ExecutionManagerSuiteForEventsV2) TestWorkflowCreationWithVersionHistor csum := s.newRandomChecksum() - _, err0 := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err0 := s.ExecutionManager.CreateWorkflowExecution(s.ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), RangeID: s.ShardInfo.GetRangeId(), NewWorkflowSnapshot: p.WorkflowSnapshot{ @@ -261,7 +271,7 @@ func (s *ExecutionManagerSuiteForEventsV2) TestWorkflowCreationWithVersionHistor s.NoError(err0) - state0, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) info0 := state0.ExecutionInfo s.NotNil(info0, "Valid Workflow info expected.") @@ -286,10 +296,10 @@ func (s *ExecutionManagerSuiteForEventsV2) TestWorkflowCreationWithVersionHistor s.NoError(err) updatedInfo.VersionHistories = versionHistories - err2 := s.UpdateWorkflowExecution(updatedInfo, updatedState, state0.NextEventId, []int64{int64(4)}, nil, common.EmptyEventID, nil, nil, nil, timerInfos, nil) + err2 := s.UpdateWorkflowExecution(s.ctx, updatedInfo, updatedState, state0.NextEventId, []int64{int64(4)}, nil, common.EmptyEventID, nil, nil, nil, timerInfos, nil) s.NoError(err2) - state, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) s.NotNil(state, "expected valid state.") s.Equal(1, len(state.TimerInfos)) @@ -315,10 +325,10 @@ func (s *ExecutionManagerSuiteForEventsV2) TestContinueAsNew() { workflowExecution.RunId, ) - _, err0 := s.CreateWorkflowExecution(namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) + _, err0 := s.CreateWorkflowExecution(s.ctx, namespaceID, workflowExecution, "queue1", "wType", timestamp.DurationFromSeconds(20), timestamp.DurationFromSeconds(13), 3, 0, 2, nil) s.NoError(err0) - state0, err1 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + state0, err1 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err1) info0 := state0.ExecutionInfo updatedInfo := copyWorkflowExecutionInfo(info0) @@ -339,7 +349,7 @@ func (s *ExecutionManagerSuiteForEventsV2) TestContinueAsNew() { ScheduleID: int64(2), } - _, err2 := s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err2 := s.ExecutionManager.UpdateWorkflowExecution(s.ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), UpdateWorkflowMutation: p.WorkflowMutation{ ExecutionInfo: updatedInfo, @@ -382,14 +392,14 @@ func (s *ExecutionManagerSuiteForEventsV2) TestContinueAsNew() { s.NoError(err2) - prevExecutionState, err3 := s.GetWorkflowMutableState(namespaceID, workflowExecution) + prevExecutionState, err3 := s.GetWorkflowMutableState(s.ctx, namespaceID, workflowExecution) s.NoError(err3) prevExecutionInfo := prevExecutionState.ExecutionInfo s.Equal(enumsspb.WORKFLOW_EXECUTION_STATE_COMPLETED, prevExecutionState.ExecutionState.State) s.Equal(int64(5), prevExecutionState.NextEventId) s.Equal(int64(2), prevExecutionInfo.LastWorkflowTaskStartId) - newExecutionState, err4 := s.GetWorkflowMutableState(namespaceID, newWorkflowExecution) + newExecutionState, err4 := s.GetWorkflowMutableState(s.ctx, namespaceID, newWorkflowExecution) s.NoError(err4) newExecutionInfo := newExecutionState.ExecutionInfo s.Equal(enumsspb.WORKFLOW_EXECUTION_STATE_RUNNING, newExecutionState.ExecutionState.State) @@ -398,7 +408,7 @@ func (s *ExecutionManagerSuiteForEventsV2) TestContinueAsNew() { s.Equal(common.EmptyEventID, newExecutionInfo.LastWorkflowTaskStartId) s.Equal(int64(2), newExecutionInfo.WorkflowTaskScheduleId) - newRunID, err5 := s.GetCurrentWorkflowRunID(namespaceID, workflowExecution.WorkflowId) + newRunID, err5 := s.GetCurrentWorkflowRunID(s.ctx, namespaceID, workflowExecution.WorkflowId) s.NoError(err5) s.Equal(newWorkflowExecution.RunId, newRunID) } diff --git a/common/persistence/persistence-tests/historyV2PersistenceTest.go b/common/persistence/persistence-tests/historyV2PersistenceTest.go index baed3495318..bc36f179c40 100644 --- a/common/persistence/persistence-tests/historyV2PersistenceTest.go +++ b/common/persistence/persistence-tests/historyV2PersistenceTest.go @@ -25,6 +25,7 @@ package persistencetests import ( + "context" "math/rand" "reflect" "sync" @@ -33,7 +34,6 @@ import ( "github.com/pborman/uuid" "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" historypb "go.temporal.io/api/history/v1" "go.temporal.io/api/serviceerror" @@ -46,11 +46,14 @@ import ( type ( // HistoryV2PersistenceSuite contains history persistence tests HistoryV2PersistenceSuite struct { - suite.Suite + // suite.Suite TestBase // override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test, // not merely log an error *require.Assertions + + ctx context.Context + cancel context.CancelFunc } ) @@ -79,15 +82,22 @@ func isConditionFail(err error) bool { func (s *HistoryV2PersistenceSuite) SetupSuite() { } +// TearDownSuite implementation +func (s *HistoryV2PersistenceSuite) TearDownSuite() { + s.TearDownWorkflowStore() +} + // SetupTest implementation func (s *HistoryV2PersistenceSuite) SetupTest() { // Have to define our overridden assertions in the test setup. If we did it earlier, s.T() will return nil s.Assertions = require.New(s.T()) + + s.ctx, s.cancel = context.WithTimeout(context.Background(), time.Second*30) } -// TearDownSuite implementation -func (s *HistoryV2PersistenceSuite) TearDownSuite() { - s.TearDownWorkflowStore() +// TearDownTest implementation +func (s *HistoryV2PersistenceSuite) TearDownTest() { + s.cancel() } // TestGenUUIDs testing uuid.New() can generate unique UUID @@ -119,7 +129,7 @@ func (s *HistoryV2PersistenceSuite) TestScanAllTrees() { return } - resp, err := s.ExecutionManager.GetAllHistoryTreeBranches(&p.GetAllHistoryTreeBranchesRequest{ + resp, err := s.ExecutionManager.GetAllHistoryTreeBranches(s.ctx, &p.GetAllHistoryTreeBranchesRequest{ PageSize: 1, }) s.Nil(err) @@ -142,7 +152,7 @@ func (s *HistoryV2PersistenceSuite) TestScanAllTrees() { var pgToken []byte for { - resp, err := s.ExecutionManager.GetAllHistoryTreeBranches(&p.GetAllHistoryTreeBranchesRequest{ + resp, err := s.ExecutionManager.GetAllHistoryTreeBranches(s.ctx, &p.GetAllHistoryTreeBranchesRequest{ PageSize: pgSize, NextPageToken: pgToken, }) @@ -219,7 +229,7 @@ func (s *HistoryV2PersistenceSuite) TestReadBranchByPagination() { ShardID: s.ShardInfo.GetShardId(), } // first page - resp, err := s.ExecutionManager.ReadHistoryBranch(req) + resp, err := s.ExecutionManager.ReadHistoryBranch(s.ctx, req) s.Nil(err) s.Equal(4, len(resp.HistoryEvents)) s.Equal(int64(6), resp.HistoryEvents[0].GetEventId()) @@ -282,7 +292,7 @@ func (s *HistoryV2PersistenceSuite) TestReadBranchByPagination() { } // first page - resp, err = s.ExecutionManager.ReadHistoryBranch(req) + resp, err = s.ExecutionManager.ReadHistoryBranch(s.ctx, req) s.Nil(err) s.Equal(8, len(resp.HistoryEvents)) @@ -292,13 +302,13 @@ func (s *HistoryV2PersistenceSuite) TestReadBranchByPagination() { // this page is all stale batches // doe to difference in Cassandra / MySQL pagination // the stale event batch may get returned - resp, err = s.ExecutionManager.ReadHistoryBranch(req) + resp, err = s.ExecutionManager.ReadHistoryBranch(s.ctx, req) s.Nil(err) historyR.Events = append(historyR.Events, resp.HistoryEvents...) req.NextPageToken = resp.NextPageToken if len(resp.HistoryEvents) == 0 { // second page - resp, err = s.ExecutionManager.ReadHistoryBranch(req) + resp, err = s.ExecutionManager.ReadHistoryBranch(s.ctx, req) s.Nil(err) s.Equal(3, len(resp.HistoryEvents)) historyR.Events = append(historyR.Events, resp.HistoryEvents...) @@ -310,14 +320,14 @@ func (s *HistoryV2PersistenceSuite) TestReadBranchByPagination() { } // 3rd page, since we fork from nodeID=13, we can only see one batch of 12 here - resp, err = s.ExecutionManager.ReadHistoryBranch(req) + resp, err = s.ExecutionManager.ReadHistoryBranch(s.ctx, req) s.Nil(err) s.Equal(1, len(resp.HistoryEvents)) historyR.Events = append(historyR.Events, resp.HistoryEvents...) req.NextPageToken = resp.NextPageToken // 4th page, 13~17 - resp, err = s.ExecutionManager.ReadHistoryBranch(req) + resp, err = s.ExecutionManager.ReadHistoryBranch(s.ctx, req) s.Nil(err) s.Equal(5, len(resp.HistoryEvents)) historyR.Events = append(historyR.Events, resp.HistoryEvents...) @@ -329,13 +339,13 @@ func (s *HistoryV2PersistenceSuite) TestReadBranchByPagination() { // If it does return a token, we need to ensure that if the token returned is used // to get history again, no error and history events should be returned. req.PageSize = 1 - resp, err = s.ExecutionManager.ReadHistoryBranch(req) + resp, err = s.ExecutionManager.ReadHistoryBranch(s.ctx, req) s.Nil(err) s.Equal(3, len(resp.HistoryEvents)) historyR.Events = append(historyR.Events, resp.HistoryEvents...) req.NextPageToken = resp.NextPageToken if len(resp.NextPageToken) != 0 { - resp, err = s.ExecutionManager.ReadHistoryBranch(req) + resp, err = s.ExecutionManager.ReadHistoryBranch(s.ctx, req) s.Nil(err) s.Equal(0, len(resp.HistoryEvents)) } @@ -347,7 +357,7 @@ func (s *HistoryV2PersistenceSuite) TestReadBranchByPagination() { // is empty), the call should return an error. req.MinEventID = 19 req.NextPageToken = nil - _, err = s.ExecutionManager.ReadHistoryBranch(req) + _, err = s.ExecutionManager.ReadHistoryBranch(s.ctx, req) s.IsType(&serviceerror.NotFound{}, err) err = s.deleteHistoryBranch(bi2) @@ -709,7 +719,7 @@ func (s *HistoryV2PersistenceSuite) deleteHistoryBranch(branch []byte) error { op := func() error { var err error - err = s.ExecutionManager.DeleteHistoryBranch(&p.DeleteHistoryBranchRequest{ + err = s.ExecutionManager.DeleteHistoryBranch(s.ctx, &p.DeleteHistoryBranchRequest{ BranchToken: branch, ShardID: s.ShardInfo.GetShardId(), }) @@ -721,7 +731,7 @@ func (s *HistoryV2PersistenceSuite) deleteHistoryBranch(branch []byte) error { // persistence helper func (s *HistoryV2PersistenceSuite) descTreeByToken(br []byte) []*persistencespb.HistoryBranch { - resp, err := s.ExecutionManager.GetHistoryTree(&p.GetHistoryTreeRequest{ + resp, err := s.ExecutionManager.GetHistoryTree(s.ctx, &p.GetHistoryTreeRequest{ BranchToken: br, ShardID: convert.Int32Ptr(s.ShardInfo.GetShardId()), }) @@ -730,7 +740,7 @@ func (s *HistoryV2PersistenceSuite) descTreeByToken(br []byte) []*persistencespb } func (s *HistoryV2PersistenceSuite) descTree(treeID string) []*persistencespb.HistoryBranch { - resp, err := s.ExecutionManager.GetHistoryTree(&p.GetHistoryTreeRequest{ + resp, err := s.ExecutionManager.GetHistoryTree(s.ctx, &p.GetHistoryTreeRequest{ TreeID: treeID, ShardID: convert.Int32Ptr(s.ShardInfo.GetShardId()), }) @@ -752,7 +762,7 @@ func (s *HistoryV2PersistenceSuite) readWithError(branch []byte, minID, maxID in res := make([]*historypb.HistoryEvent, 0) token := []byte{} for { - resp, err := s.ExecutionManager.ReadHistoryBranch(&p.ReadHistoryBranchRequest{ + resp, err := s.ExecutionManager.ReadHistoryBranch(s.ctx, &p.ReadHistoryBranchRequest{ BranchToken: branch, MinEventID: minID, MaxEventID: maxID, @@ -801,7 +811,7 @@ func (s *HistoryV2PersistenceSuite) append(branch []byte, events []*historypb.Hi op := func() error { var err error - resp, err = s.ExecutionManager.AppendHistoryNodes(&p.AppendHistoryNodesRequest{ + resp, err = s.ExecutionManager.AppendHistoryNodes(s.ctx, &p.AppendHistoryNodesRequest{ IsNewBranch: isNewBranch, Info: branchInfo, BranchToken: branch, @@ -828,7 +838,7 @@ func (s *HistoryV2PersistenceSuite) fork(forkBranch []byte, forkNodeID int64) ([ op := func() error { var err error - resp, err := s.ExecutionManager.ForkHistoryBranch(&p.ForkHistoryBranchRequest{ + resp, err := s.ExecutionManager.ForkHistoryBranch(s.ctx, &p.ForkHistoryBranchRequest{ ForkBranchToken: forkBranch, ForkNodeID: forkNodeID, Info: testForkRunID, diff --git a/common/persistence/persistence-tests/persistenceTestBase.go b/common/persistence/persistence-tests/persistenceTestBase.go index 868c7ecd5c5..2491e19ab5b 100644 --- a/common/persistence/persistence-tests/persistenceTestBase.go +++ b/common/persistence/persistence-tests/persistenceTestBase.go @@ -25,6 +25,7 @@ package persistencetests import ( + "context" "fmt" "math" "math/rand" @@ -301,10 +302,10 @@ func (s *TestBase) UpdateShard(updatedInfo *persistencespb.ShardInfo, previousRa } // CreateWorkflowExecutionWithBranchToken test util function -func (s *TestBase) CreateWorkflowExecutionWithBranchToken(namespaceID string, workflowExecution commonpb.WorkflowExecution, taskQueue, +func (s *TestBase) CreateWorkflowExecutionWithBranchToken(ctx context.Context, namespaceID string, workflowExecution commonpb.WorkflowExecution, taskQueue, wType string, wTimeout *time.Duration, workflowTaskTimeout *time.Duration, nextEventID int64, lastProcessedEventID int64, workflowTaskScheduleID int64, branchToken []byte, timerTasks []tasks.Task) (*persistence.CreateWorkflowExecutionResponse, error) { - response, err := s.ExecutionManager.CreateWorkflowExecution(&persistence.CreateWorkflowExecutionRequest{ + response, err := s.ExecutionManager.CreateWorkflowExecution(ctx, &persistence.CreateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), NewWorkflowSnapshot: persistence.WorkflowSnapshot{ ExecutionInfo: &persistencespb.WorkflowExecutionInfo{ @@ -355,13 +356,13 @@ func (s *TestBase) CreateWorkflowExecutionWithBranchToken(namespaceID string, wo } // CreateWorkflowExecution is a utility method to create workflow executions -func (s *TestBase) CreateWorkflowExecution(namespaceID string, workflowExecution commonpb.WorkflowExecution, taskQueue, wType string, wTimeout *time.Duration, workflowTaskTimeout *time.Duration, nextEventID, lastProcessedEventID, workflowTaskScheduleID int64, timerTasks []tasks.Task) (*persistence.CreateWorkflowExecutionResponse, error) { - return s.CreateWorkflowExecutionWithBranchToken(namespaceID, workflowExecution, taskQueue, wType, wTimeout, workflowTaskTimeout, +func (s *TestBase) CreateWorkflowExecution(ctx context.Context, namespaceID string, workflowExecution commonpb.WorkflowExecution, taskQueue, wType string, wTimeout *time.Duration, workflowTaskTimeout *time.Duration, nextEventID, lastProcessedEventID, workflowTaskScheduleID int64, timerTasks []tasks.Task) (*persistence.CreateWorkflowExecutionResponse, error) { + return s.CreateWorkflowExecutionWithBranchToken(ctx, namespaceID, workflowExecution, taskQueue, wType, wTimeout, workflowTaskTimeout, nextEventID, lastProcessedEventID, workflowTaskScheduleID, nil, timerTasks) } // CreateWorkflowExecutionManyTasks is a utility method to create workflow executions -func (s *TestBase) CreateWorkflowExecutionManyTasks(namespaceID string, workflowExecution commonpb.WorkflowExecution, +func (s *TestBase) CreateWorkflowExecutionManyTasks(ctx context.Context, namespaceID string, workflowExecution commonpb.WorkflowExecution, taskQueue string, nextEventID int64, lastProcessedEventID int64, workflowTaskScheduleIDs []int64, activityScheduleIDs []int64) (*persistence.CreateWorkflowExecutionResponse, error) { @@ -392,7 +393,7 @@ func (s *TestBase) CreateWorkflowExecutionManyTasks(namespaceID string, workflow }) } - response, err := s.ExecutionManager.CreateWorkflowExecution(&persistence.CreateWorkflowExecutionRequest{ + response, err := s.ExecutionManager.CreateWorkflowExecution(ctx, &persistence.CreateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), NewWorkflowSnapshot: persistence.WorkflowSnapshot{ ExecutionInfo: &persistencespb.WorkflowExecutionInfo{ @@ -426,11 +427,11 @@ func (s *TestBase) CreateWorkflowExecutionManyTasks(namespaceID string, workflow } // CreateChildWorkflowExecution is a utility method to create child workflow executions -func (s *TestBase) CreateChildWorkflowExecution(namespaceID string, workflowExecution commonpb.WorkflowExecution, +func (s *TestBase) CreateChildWorkflowExecution(ctx context.Context, namespaceID string, workflowExecution commonpb.WorkflowExecution, parentNamespaceID string, parentExecution commonpb.WorkflowExecution, initiatedID int64, taskQueue, wType string, wTimeout *time.Duration, workflowTaskTimeout *time.Duration, nextEventID int64, lastProcessedEventID int64, workflowTaskScheduleID int64, timerTasks []tasks.Task) (*persistence.CreateWorkflowExecutionResponse, error) { - response, err := s.ExecutionManager.CreateWorkflowExecution(&persistence.CreateWorkflowExecutionRequest{ + response, err := s.ExecutionManager.CreateWorkflowExecution(ctx, &persistence.CreateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), NewWorkflowSnapshot: persistence.WorkflowSnapshot{ ExecutionInfo: &persistencespb.WorkflowExecutionInfo{ @@ -482,9 +483,9 @@ func (s *TestBase) CreateChildWorkflowExecution(namespaceID string, workflowExec } // GetWorkflowExecutionInfo is a utility method to retrieve execution info -func (s *TestBase) GetWorkflowMutableState(namespaceID string, workflowExecution commonpb.WorkflowExecution) ( +func (s *TestBase) GetWorkflowMutableState(ctx context.Context, namespaceID string, workflowExecution commonpb.WorkflowExecution) ( *persistencespb.WorkflowMutableState, error) { - response, err := s.ExecutionManager.GetWorkflowExecution(&persistence.GetWorkflowExecutionRequest{ + response, err := s.ExecutionManager.GetWorkflowExecution(ctx, &persistence.GetWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), NamespaceID: namespaceID, WorkflowID: workflowExecution.GetWorkflowId(), @@ -497,8 +498,8 @@ func (s *TestBase) GetWorkflowMutableState(namespaceID string, workflowExecution } // GetCurrentWorkflowRunID returns the workflow run ID for the given params -func (s *TestBase) GetCurrentWorkflowRunID(namespaceID, workflowID string) (string, error) { - response, err := s.ExecutionManager.GetCurrentExecution(&persistence.GetCurrentExecutionRequest{ +func (s *TestBase) GetCurrentWorkflowRunID(ctx context.Context, namespaceID, workflowID string) (string, error) { + response, err := s.ExecutionManager.GetCurrentExecution(ctx, &persistence.GetCurrentExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), NamespaceID: namespaceID, WorkflowID: workflowID, @@ -512,7 +513,7 @@ func (s *TestBase) GetCurrentWorkflowRunID(namespaceID, workflowID string) (stri } // ContinueAsNewExecution is a utility method to create workflow executions -func (s *TestBase) ContinueAsNewExecution(updatedInfo *persistencespb.WorkflowExecutionInfo, updatedState *persistencespb.WorkflowExecutionState, updatedNextEventID int64, condition int64, +func (s *TestBase) ContinueAsNewExecution(ctx context.Context, updatedInfo *persistencespb.WorkflowExecutionInfo, updatedState *persistencespb.WorkflowExecutionState, updatedNextEventID int64, condition int64, newExecution commonpb.WorkflowExecution, nextEventID, workflowTaskScheduleID int64, prevResetPoints *workflowpb.ResetPoints) error { newworkflowTask := &tasks.WorkflowTask{ @@ -571,24 +572,24 @@ func (s *TestBase) ContinueAsNewExecution(updatedInfo *persistencespb.WorkflowEx } req.UpdateWorkflowMutation.ExecutionState.State = enumsspb.WORKFLOW_EXECUTION_STATE_COMPLETED req.UpdateWorkflowMutation.ExecutionState.Status = enumspb.WORKFLOW_EXECUTION_STATUS_CONTINUED_AS_NEW - _, err := s.ExecutionManager.UpdateWorkflowExecution(req) + _, err := s.ExecutionManager.UpdateWorkflowExecution(ctx, req) return err } // UpdateWorkflowExecution is a utility method to update workflow execution -func (s *TestBase) UpdateWorkflowExecution(updatedInfo *persistencespb.WorkflowExecutionInfo, +func (s *TestBase) UpdateWorkflowExecution(ctx context.Context, updatedInfo *persistencespb.WorkflowExecutionInfo, updatedState *persistencespb.WorkflowExecutionState, nextEventID int64, workflowTaskScheduleIDs []int64, activityScheduleIDs []int64, condition int64, timerTasks []tasks.Task, upsertActivityInfos []*persistencespb.ActivityInfo, deleteActivityInfos []int64, upsertTimerInfos []*persistencespb.TimerInfo, deleteTimerInfos []string) error { - return s.UpdateWorkflowExecutionWithRangeID(updatedInfo, updatedState, nextEventID, workflowTaskScheduleIDs, activityScheduleIDs, + return s.UpdateWorkflowExecutionWithRangeID(ctx, updatedInfo, updatedState, nextEventID, workflowTaskScheduleIDs, activityScheduleIDs, s.ShardInfo.GetRangeId(), condition, timerTasks, upsertActivityInfos, deleteActivityInfos, upsertTimerInfos, deleteTimerInfos, nil, nil, nil, nil, nil, nil, nil, nil) } // UpdateWorkflowExecutionAndFinish is a utility method to update workflow execution -func (s *TestBase) UpdateWorkflowExecutionAndFinish(updatedInfo *persistencespb.WorkflowExecutionInfo, +func (s *TestBase) UpdateWorkflowExecutionAndFinish(ctx context.Context, updatedInfo *persistencespb.WorkflowExecutionInfo, updatedState *persistencespb.WorkflowExecutionState, nextEventID int64, condition int64) error { var transferTasks []tasks.Task transferTasks = append(transferTasks, &tasks.CloseExecutionTask{ @@ -599,7 +600,7 @@ func (s *TestBase) UpdateWorkflowExecutionAndFinish(updatedInfo *persistencespb. ), TaskID: s.GetNextSequenceNumber(), }) - _, err := s.ExecutionManager.UpdateWorkflowExecution(&persistence.UpdateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.UpdateWorkflowExecution(ctx, &persistence.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), RangeID: s.ShardInfo.GetRangeId(), UpdateWorkflowMutation: persistence.WorkflowMutation{ @@ -620,95 +621,95 @@ func (s *TestBase) UpdateWorkflowExecutionAndFinish(updatedInfo *persistencespb. } // UpsertChildExecutionsState is a utility method to update mutable state of workflow execution -func (s *TestBase) UpsertChildExecutionsState(updatedInfo *persistencespb.WorkflowExecutionInfo, +func (s *TestBase) UpsertChildExecutionsState(ctx context.Context, updatedInfo *persistencespb.WorkflowExecutionInfo, updatedState *persistencespb.WorkflowExecutionState, nextEventID int64, condition int64, upsertChildInfos []*persistencespb.ChildExecutionInfo) error { - return s.UpdateWorkflowExecutionWithRangeID(updatedInfo, updatedState, nextEventID, nil, nil, + return s.UpdateWorkflowExecutionWithRangeID(ctx, updatedInfo, updatedState, nextEventID, nil, nil, s.ShardInfo.GetRangeId(), condition, nil, nil, nil, nil, nil, upsertChildInfos, nil, nil, nil, nil, nil, nil, nil) } // UpsertRequestCancelState is a utility method to update mutable state of workflow execution -func (s *TestBase) UpsertRequestCancelState(updatedInfo *persistencespb.WorkflowExecutionInfo, +func (s *TestBase) UpsertRequestCancelState(ctx context.Context, updatedInfo *persistencespb.WorkflowExecutionInfo, updatedState *persistencespb.WorkflowExecutionState, nextEventID int64, condition int64, upsertCancelInfos []*persistencespb.RequestCancelInfo) error { - return s.UpdateWorkflowExecutionWithRangeID(updatedInfo, updatedState, nextEventID, nil, nil, + return s.UpdateWorkflowExecutionWithRangeID(ctx, updatedInfo, updatedState, nextEventID, nil, nil, s.ShardInfo.GetRangeId(), condition, nil, nil, nil, nil, nil, nil, nil, upsertCancelInfos, nil, nil, nil, nil, nil) } // UpsertSignalInfoState is a utility method to update mutable state of workflow execution -func (s *TestBase) UpsertSignalInfoState(updatedInfo *persistencespb.WorkflowExecutionInfo, +func (s *TestBase) UpsertSignalInfoState(ctx context.Context, updatedInfo *persistencespb.WorkflowExecutionInfo, updatedState *persistencespb.WorkflowExecutionState, nextEventID int64, condition int64, upsertSignalInfos []*persistencespb.SignalInfo) error { - return s.UpdateWorkflowExecutionWithRangeID(updatedInfo, updatedState, nextEventID, nil, nil, + return s.UpdateWorkflowExecutionWithRangeID(ctx, updatedInfo, updatedState, nextEventID, nil, nil, s.ShardInfo.GetRangeId(), condition, nil, nil, nil, nil, nil, nil, nil, nil, nil, upsertSignalInfos, nil, nil, nil) } // UpsertSignalsRequestedState is a utility method to update mutable state of workflow execution -func (s *TestBase) UpsertSignalsRequestedState(updatedInfo *persistencespb.WorkflowExecutionInfo, +func (s *TestBase) UpsertSignalsRequestedState(ctx context.Context, updatedInfo *persistencespb.WorkflowExecutionInfo, updatedState *persistencespb.WorkflowExecutionState, nextEventID int64, condition int64, upsertSignalsRequested []string) error { - return s.UpdateWorkflowExecutionWithRangeID(updatedInfo, updatedState, nextEventID, nil, nil, + return s.UpdateWorkflowExecutionWithRangeID(ctx, updatedInfo, updatedState, nextEventID, nil, nil, s.ShardInfo.GetRangeId(), condition, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, upsertSignalsRequested, nil) } // DeleteChildExecutionsState is a utility method to delete child execution from mutable state -func (s *TestBase) DeleteChildExecutionsState(updatedInfo *persistencespb.WorkflowExecutionInfo, +func (s *TestBase) DeleteChildExecutionsState(ctx context.Context, updatedInfo *persistencespb.WorkflowExecutionInfo, updatedState *persistencespb.WorkflowExecutionState, nextEventID int64, condition int64, deleteChildInfo int64) error { - return s.UpdateWorkflowExecutionWithRangeID(updatedInfo, updatedState, nextEventID, nil, nil, + return s.UpdateWorkflowExecutionWithRangeID(ctx, updatedInfo, updatedState, nextEventID, nil, nil, s.ShardInfo.GetRangeId(), condition, nil, nil, nil, nil, nil, nil, []int64{deleteChildInfo}, nil, nil, nil, nil, nil, nil) } // DeleteCancelState is a utility method to delete request cancel state from mutable state -func (s *TestBase) DeleteCancelState(updatedInfo *persistencespb.WorkflowExecutionInfo, +func (s *TestBase) DeleteCancelState(ctx context.Context, updatedInfo *persistencespb.WorkflowExecutionInfo, updatedState *persistencespb.WorkflowExecutionState, nextEventID int64, condition int64, deleteCancelInfo int64) error { - return s.UpdateWorkflowExecutionWithRangeID(updatedInfo, updatedState, nextEventID, nil, nil, + return s.UpdateWorkflowExecutionWithRangeID(ctx, updatedInfo, updatedState, nextEventID, nil, nil, s.ShardInfo.GetRangeId(), condition, nil, nil, nil, nil, nil, nil, nil, nil, []int64{deleteCancelInfo}, nil, nil, nil, nil) } // DeleteSignalState is a utility method to delete request cancel state from mutable state -func (s *TestBase) DeleteSignalState(updatedInfo *persistencespb.WorkflowExecutionInfo, +func (s *TestBase) DeleteSignalState(ctx context.Context, updatedInfo *persistencespb.WorkflowExecutionInfo, updatedState *persistencespb.WorkflowExecutionState, nextEventID int64, condition int64, deleteSignalInfo int64) error { - return s.UpdateWorkflowExecutionWithRangeID(updatedInfo, updatedState, nextEventID, nil, nil, + return s.UpdateWorkflowExecutionWithRangeID(ctx, updatedInfo, updatedState, nextEventID, nil, nil, s.ShardInfo.GetRangeId(), condition, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, []int64{deleteSignalInfo}, nil, nil) } // DeleteSignalsRequestedState is a utility method to delete mutable state of workflow execution -func (s *TestBase) DeleteSignalsRequestedState(updatedInfo *persistencespb.WorkflowExecutionInfo, +func (s *TestBase) DeleteSignalsRequestedState(ctx context.Context, updatedInfo *persistencespb.WorkflowExecutionInfo, updatedState *persistencespb.WorkflowExecutionState, nextEventID int64, condition int64, deleteSignalsRequestedIDs []string) error { - return s.UpdateWorkflowExecutionWithRangeID(updatedInfo, updatedState, nextEventID, nil, nil, + return s.UpdateWorkflowExecutionWithRangeID(ctx, updatedInfo, updatedState, nextEventID, nil, nil, s.ShardInfo.GetRangeId(), condition, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, deleteSignalsRequestedIDs) } // UpdateWorklowStateAndReplication is a utility method to update workflow execution -func (s *TestBase) UpdateWorklowStateAndReplication(updatedInfo *persistencespb.WorkflowExecutionInfo, +func (s *TestBase) UpdateWorklowStateAndReplication(ctx context.Context, updatedInfo *persistencespb.WorkflowExecutionInfo, updatedState *persistencespb.WorkflowExecutionState, nextEventID int64, condition int64, txTasks []tasks.Task) error { - return s.UpdateWorkflowExecutionWithReplication(updatedInfo, updatedState, nextEventID, nil, nil, + return s.UpdateWorkflowExecutionWithReplication(ctx, updatedInfo, updatedState, nextEventID, nil, nil, s.ShardInfo.GetRangeId(), condition, nil, txTasks, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) } // UpdateWorkflowExecutionWithRangeID is a utility method to update workflow execution -func (s *TestBase) UpdateWorkflowExecutionWithRangeID(updatedInfo *persistencespb.WorkflowExecutionInfo, +func (s *TestBase) UpdateWorkflowExecutionWithRangeID(ctx context.Context, updatedInfo *persistencespb.WorkflowExecutionInfo, updatedState *persistencespb.WorkflowExecutionState, nextEventID int64, workflowTaskScheduleIDs []int64, activityScheduleIDs []int64, rangeID, condition int64, timerTasks []tasks.Task, upsertActivityInfos []*persistencespb.ActivityInfo, deleteActivityInfos []int64, upsertTimerInfos []*persistencespb.TimerInfo, @@ -716,14 +717,14 @@ func (s *TestBase) UpdateWorkflowExecutionWithRangeID(updatedInfo *persistencesp upsertCancelInfos []*persistencespb.RequestCancelInfo, deleteCancelInfos []int64, upsertSignalInfos []*persistencespb.SignalInfo, deleteSignalInfos []int64, upsertSignalRequestedIDs []string, deleteSignalRequestedIDs []string) error { - return s.UpdateWorkflowExecutionWithReplication(updatedInfo, updatedState, nextEventID, workflowTaskScheduleIDs, activityScheduleIDs, rangeID, + return s.UpdateWorkflowExecutionWithReplication(ctx, updatedInfo, updatedState, nextEventID, workflowTaskScheduleIDs, activityScheduleIDs, rangeID, condition, timerTasks, []tasks.Task{}, upsertActivityInfos, deleteActivityInfos, upsertTimerInfos, deleteTimerInfos, upsertChildInfos, deleteChildInfos, upsertCancelInfos, deleteCancelInfos, upsertSignalInfos, deleteSignalInfos, upsertSignalRequestedIDs, deleteSignalRequestedIDs) } // UpdateWorkflowExecutionWithReplication is a utility method to update workflow execution -func (s *TestBase) UpdateWorkflowExecutionWithReplication(updatedInfo *persistencespb.WorkflowExecutionInfo, +func (s *TestBase) UpdateWorkflowExecutionWithReplication(ctx context.Context, updatedInfo *persistencespb.WorkflowExecutionInfo, updatedState *persistencespb.WorkflowExecutionState, nextEventID int64, workflowTaskScheduleIDs []int64, activityScheduleIDs []int64, rangeID, condition int64, timerTasks []tasks.Task, txTasks []tasks.Task, upsertActivityInfos []*persistencespb.ActivityInfo, @@ -764,7 +765,7 @@ func (s *TestBase) UpdateWorkflowExecutionWithReplication(updatedInfo *persisten TaskQueue: updatedInfo.TaskQueue, ScheduleID: activityScheduleID}) } - _, err := s.ExecutionManager.UpdateWorkflowExecution(&persistence.UpdateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.UpdateWorkflowExecution(ctx, &persistence.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), RangeID: rangeID, UpdateWorkflowMutation: persistence.WorkflowMutation{ @@ -800,8 +801,14 @@ func (s *TestBase) UpdateWorkflowExecutionWithReplication(updatedInfo *persisten // UpdateWorkflowExecutionWithTransferTasks is a utility method to update workflow execution func (s *TestBase) UpdateWorkflowExecutionWithTransferTasks( - updatedInfo *persistencespb.WorkflowExecutionInfo, updatedState *persistencespb.WorkflowExecutionState, updatedNextEventID int64, condition int64, transferTasks []tasks.Task, upsertActivityInfo []*persistencespb.ActivityInfo) error { - _, err := s.ExecutionManager.UpdateWorkflowExecution(&persistence.UpdateWorkflowExecutionRequest{ + ctx context.Context, + updatedInfo *persistencespb.WorkflowExecutionInfo, + updatedState *persistencespb.WorkflowExecutionState, + updatedNextEventID int64, condition int64, + transferTasks []tasks.Task, + upsertActivityInfo []*persistencespb.ActivityInfo, +) error { + _, err := s.ExecutionManager.UpdateWorkflowExecution(ctx, &persistence.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), UpdateWorkflowMutation: persistence.WorkflowMutation{ ExecutionInfo: updatedInfo, @@ -820,8 +827,8 @@ func (s *TestBase) UpdateWorkflowExecutionWithTransferTasks( // UpdateWorkflowExecutionForChildExecutionsInitiated is a utility method to update workflow execution func (s *TestBase) UpdateWorkflowExecutionForChildExecutionsInitiated( - updatedInfo *persistencespb.WorkflowExecutionInfo, updatedNextEventID int64, condition int64, transferTasks []tasks.Task, childInfos []*persistencespb.ChildExecutionInfo) error { - _, err := s.ExecutionManager.UpdateWorkflowExecution(&persistence.UpdateWorkflowExecutionRequest{ + ctx context.Context, updatedInfo *persistencespb.WorkflowExecutionInfo, updatedNextEventID int64, condition int64, transferTasks []tasks.Task, childInfos []*persistencespb.ChildExecutionInfo) error { + _, err := s.ExecutionManager.UpdateWorkflowExecution(ctx, &persistence.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), UpdateWorkflowMutation: persistence.WorkflowMutation{ ExecutionInfo: updatedInfo, @@ -839,9 +846,9 @@ func (s *TestBase) UpdateWorkflowExecutionForChildExecutionsInitiated( // UpdateWorkflowExecutionForRequestCancel is a utility method to update workflow execution func (s *TestBase) UpdateWorkflowExecutionForRequestCancel( - updatedInfo *persistencespb.WorkflowExecutionInfo, updatedNextEventID int64, condition int64, transferTasks []tasks.Task, + ctx context.Context, updatedInfo *persistencespb.WorkflowExecutionInfo, updatedNextEventID int64, condition int64, transferTasks []tasks.Task, upsertRequestCancelInfo []*persistencespb.RequestCancelInfo) error { - _, err := s.ExecutionManager.UpdateWorkflowExecution(&persistence.UpdateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.UpdateWorkflowExecution(ctx, &persistence.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), UpdateWorkflowMutation: persistence.WorkflowMutation{ ExecutionInfo: updatedInfo, @@ -859,9 +866,13 @@ func (s *TestBase) UpdateWorkflowExecutionForRequestCancel( // UpdateWorkflowExecutionForSignal is a utility method to update workflow execution func (s *TestBase) UpdateWorkflowExecutionForSignal( - updatedInfo *persistencespb.WorkflowExecutionInfo, updatedNextEventID int64, condition int64, transferTasks []tasks.Task, - upsertSignalInfos []*persistencespb.SignalInfo) error { - _, err := s.ExecutionManager.UpdateWorkflowExecution(&persistence.UpdateWorkflowExecutionRequest{ + ctx context.Context, + updatedInfo *persistencespb.WorkflowExecutionInfo, + updatedNextEventID int64, condition int64, + transferTasks []tasks.Task, + upsertSignalInfos []*persistencespb.SignalInfo, +) error { + _, err := s.ExecutionManager.UpdateWorkflowExecution(ctx, &persistence.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), UpdateWorkflowMutation: persistence.WorkflowMutation{ ExecutionInfo: updatedInfo, @@ -879,9 +890,10 @@ func (s *TestBase) UpdateWorkflowExecutionForSignal( // UpdateWorkflowExecutionForBufferEvents is a utility method to update workflow execution func (s *TestBase) UpdateWorkflowExecutionForBufferEvents( + ctx context.Context, updatedInfo *persistencespb.WorkflowExecutionInfo, updatedNextEventID, condition int64, bufferEvents []*historypb.HistoryEvent, clearBufferedEvents bool) error { - _, err := s.ExecutionManager.UpdateWorkflowExecution(&persistence.UpdateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.UpdateWorkflowExecution(ctx, &persistence.UpdateWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), UpdateWorkflowMutation: persistence.WorkflowMutation{ ExecutionInfo: updatedInfo, @@ -896,7 +908,11 @@ func (s *TestBase) UpdateWorkflowExecutionForBufferEvents( } // UpdateAllMutableState is a utility method to update workflow execution -func (s *TestBase) UpdateAllMutableState(updatedMutableState *persistencespb.WorkflowMutableState, condition int64) error { +func (s *TestBase) UpdateAllMutableState( + ctx context.Context, + updatedMutableState *persistencespb.WorkflowMutableState, + condition int64, +) error { var aInfos []*persistencespb.ActivityInfo for _, ai := range updatedMutableState.ActivityInfos { aInfos = append(aInfos, ai) @@ -926,7 +942,7 @@ func (s *TestBase) UpdateAllMutableState(updatedMutableState *persistencespb.Wor for _, id := range updatedMutableState.SignalRequestedIds { srIDs = append(srIDs, id) } - _, err := s.ExecutionManager.UpdateWorkflowExecution(&persistence.UpdateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.UpdateWorkflowExecution(ctx, &persistence.UpdateWorkflowExecutionRequest{ RangeID: s.ShardInfo.GetRangeId(), UpdateWorkflowMutation: persistence.WorkflowMutation{ ExecutionInfo: updatedMutableState.ExecutionInfo, @@ -944,8 +960,8 @@ func (s *TestBase) UpdateAllMutableState(updatedMutableState *persistencespb.Wor } // DeleteWorkflowExecution is a utility method to delete a workflow execution -func (s *TestBase) DeleteWorkflowExecution(info *persistencespb.WorkflowExecutionInfo, state *persistencespb.WorkflowExecutionState) error { - return s.ExecutionManager.DeleteWorkflowExecution(&persistence.DeleteWorkflowExecutionRequest{ +func (s *TestBase) DeleteWorkflowExecution(ctx context.Context, info *persistencespb.WorkflowExecutionInfo, state *persistencespb.WorkflowExecutionState) error { + return s.ExecutionManager.DeleteWorkflowExecution(ctx, &persistence.DeleteWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), NamespaceID: info.NamespaceId, WorkflowID: info.WorkflowId, @@ -954,8 +970,8 @@ func (s *TestBase) DeleteWorkflowExecution(info *persistencespb.WorkflowExecutio } // DeleteCurrentWorkflowExecution is a utility method to delete the workflow current execution -func (s *TestBase) DeleteCurrentWorkflowExecution(info *persistencespb.WorkflowExecutionInfo, state *persistencespb.WorkflowExecutionState) error { - return s.ExecutionManager.DeleteCurrentWorkflowExecution(&persistence.DeleteCurrentWorkflowExecutionRequest{ +func (s *TestBase) DeleteCurrentWorkflowExecution(ctx context.Context, info *persistencespb.WorkflowExecutionInfo, state *persistencespb.WorkflowExecutionState) error { + return s.ExecutionManager.DeleteCurrentWorkflowExecution(ctx, &persistence.DeleteCurrentWorkflowExecutionRequest{ ShardID: s.ShardInfo.GetShardId(), NamespaceID: info.NamespaceId, WorkflowID: info.WorkflowId, @@ -964,13 +980,13 @@ func (s *TestBase) DeleteCurrentWorkflowExecution(info *persistencespb.WorkflowE } // GetTransferTasks is a utility method to get tasks from transfer task queue -func (s *TestBase) GetTransferTasks(batchSize int, getAll bool) ([]tasks.Task, error) { +func (s *TestBase) GetTransferTasks(ctx context.Context, batchSize int, getAll bool) ([]tasks.Task, error) { result := []tasks.Task{} var token []byte Loop: for { - response, err := s.ExecutionManager.GetHistoryTasks(&persistence.GetHistoryTasksRequest{ + response, err := s.ExecutionManager.GetHistoryTasks(ctx, &persistence.GetHistoryTasksRequest{ ShardID: s.ShardInfo.GetShardId(), TaskCategory: tasks.CategoryTransfer, InclusiveMinTaskKey: tasks.Key{ @@ -1001,13 +1017,13 @@ Loop: } // GetReplicationTasks is a utility method to get tasks from replication task queue -func (s *TestBase) GetReplicationTasks(batchSize int, getAll bool) ([]tasks.Task, error) { +func (s *TestBase) GetReplicationTasks(ctx context.Context, batchSize int, getAll bool) ([]tasks.Task, error) { result := []tasks.Task{} var token []byte Loop: for { - response, err := s.ExecutionManager.GetHistoryTasks(&persistence.GetHistoryTasksRequest{ + response, err := s.ExecutionManager.GetHistoryTasks(ctx, &persistence.GetHistoryTasksRequest{ ShardID: s.ShardInfo.GetShardId(), TaskCategory: tasks.CategoryReplication, InclusiveMinTaskKey: tasks.Key{ @@ -1038,11 +1054,12 @@ Loop: } func (s *TestBase) GetReplicationTasksInRange( + ctx context.Context, inclusiveMinTaskID int64, exclusiveMaxTaskID int64, batchSize int, ) ([]tasks.Task, error) { - response, err := s.ExecutionManager.GetHistoryTasks(&persistence.GetHistoryTasksRequest{ + response, err := s.ExecutionManager.GetHistoryTasks(ctx, &persistence.GetHistoryTasksRequest{ ShardID: s.ShardInfo.GetShardId(), TaskCategory: tasks.CategoryReplication, InclusiveMinTaskKey: tasks.Key{TaskID: inclusiveMinTaskID}, @@ -1057,8 +1074,8 @@ func (s *TestBase) GetReplicationTasksInRange( } // RangeCompleteReplicationTask is a utility method to complete a range of replication tasks -func (s *TestBase) RangeCompleteReplicationTask(exclusiveEndTaskID int64) error { - return s.ExecutionManager.RangeCompleteHistoryTasks(&persistence.RangeCompleteHistoryTasksRequest{ +func (s *TestBase) RangeCompleteReplicationTask(ctx context.Context, exclusiveEndTaskID int64) error { + return s.ExecutionManager.RangeCompleteHistoryTasks(ctx, &persistence.RangeCompleteHistoryTasksRequest{ ShardID: s.ShardInfo.GetShardId(), TaskCategory: tasks.CategoryReplication, ExclusiveMaxTaskKey: tasks.Key{ @@ -1069,11 +1086,12 @@ func (s *TestBase) RangeCompleteReplicationTask(exclusiveEndTaskID int64) error // PutReplicationTaskToDLQ is a utility method to insert a replication task info func (s *TestBase) PutReplicationTaskToDLQ( + ctx context.Context, sourceCluster string, taskInfo *persistencespb.ReplicationTaskInfo, ) error { - return s.ExecutionManager.PutReplicationTaskToDLQ(&persistence.PutReplicationTaskToDLQRequest{ + return s.ExecutionManager.PutReplicationTaskToDLQ(ctx, &persistence.PutReplicationTaskToDLQRequest{ ShardID: s.ShardInfo.GetShardId(), SourceClusterName: sourceCluster, TaskInfo: taskInfo, @@ -1082,6 +1100,7 @@ func (s *TestBase) PutReplicationTaskToDLQ( // GetReplicationTasksFromDLQ is a utility method to read replication task info func (s *TestBase) GetReplicationTasksFromDLQ( + ctx context.Context, sourceCluster string, inclusiveMinLevel int64, exclusiveMaxLevel int64, @@ -1089,7 +1108,7 @@ func (s *TestBase) GetReplicationTasksFromDLQ( pageToken []byte, ) (*persistence.GetHistoryTasksResponse, error) { - return s.ExecutionManager.GetReplicationTasksFromDLQ(&persistence.GetReplicationTasksFromDLQRequest{ + return s.ExecutionManager.GetReplicationTasksFromDLQ(ctx, &persistence.GetReplicationTasksFromDLQRequest{ GetHistoryTasksRequest: persistence.GetHistoryTasksRequest{ ShardID: s.ShardInfo.GetShardId(), TaskCategory: tasks.CategoryReplication, @@ -1104,11 +1123,12 @@ func (s *TestBase) GetReplicationTasksFromDLQ( // DeleteReplicationTaskFromDLQ is a utility method to delete a replication task info func (s *TestBase) DeleteReplicationTaskFromDLQ( + ctx context.Context, sourceCluster string, taskID int64, ) error { - return s.ExecutionManager.DeleteReplicationTaskFromDLQ(&persistence.DeleteReplicationTaskFromDLQRequest{ + return s.ExecutionManager.DeleteReplicationTaskFromDLQ(ctx, &persistence.DeleteReplicationTaskFromDLQRequest{ CompleteHistoryTaskRequest: persistence.CompleteHistoryTaskRequest{ ShardID: s.ShardInfo.GetShardId(), TaskCategory: tasks.CategoryReplication, @@ -1120,12 +1140,13 @@ func (s *TestBase) DeleteReplicationTaskFromDLQ( // RangeDeleteReplicationTaskFromDLQ is a utility method to delete replication task info func (s *TestBase) RangeDeleteReplicationTaskFromDLQ( + ctx context.Context, sourceCluster string, inclusiveMinTaskID int64, exclusiveMaxTaskID int64, ) error { - return s.ExecutionManager.RangeDeleteReplicationTaskFromDLQ(&persistence.RangeDeleteReplicationTaskFromDLQRequest{ + return s.ExecutionManager.RangeDeleteReplicationTaskFromDLQ(ctx, &persistence.RangeDeleteReplicationTaskFromDLQRequest{ RangeCompleteHistoryTasksRequest: persistence.RangeCompleteHistoryTasksRequest{ ShardID: s.ShardInfo.GetShardId(), TaskCategory: tasks.CategoryReplication, @@ -1137,9 +1158,9 @@ func (s *TestBase) RangeDeleteReplicationTaskFromDLQ( } // CompleteTransferTask is a utility method to complete a transfer task -func (s *TestBase) CompleteTransferTask(taskID int64) error { +func (s *TestBase) CompleteTransferTask(ctx context.Context, taskID int64) error { - return s.ExecutionManager.CompleteHistoryTask(&persistence.CompleteHistoryTaskRequest{ + return s.ExecutionManager.CompleteHistoryTask(ctx, &persistence.CompleteHistoryTaskRequest{ ShardID: s.ShardInfo.GetShardId(), TaskCategory: tasks.CategoryTransfer, TaskKey: tasks.Key{ @@ -1149,8 +1170,8 @@ func (s *TestBase) CompleteTransferTask(taskID int64) error { } // RangeCompleteTransferTask is a utility method to complete a range of transfer tasks -func (s *TestBase) RangeCompleteTransferTask(inclusiveMinTaskID int64, exclusiveMaxTaskID int64) error { - return s.ExecutionManager.RangeCompleteHistoryTasks(&persistence.RangeCompleteHistoryTasksRequest{ +func (s *TestBase) RangeCompleteTransferTask(ctx context.Context, inclusiveMinTaskID int64, exclusiveMaxTaskID int64) error { + return s.ExecutionManager.RangeCompleteHistoryTasks(ctx, &persistence.RangeCompleteHistoryTasksRequest{ ShardID: s.ShardInfo.GetShardId(), TaskCategory: tasks.CategoryTransfer, InclusiveMinTaskKey: tasks.Key{ @@ -1163,9 +1184,9 @@ func (s *TestBase) RangeCompleteTransferTask(inclusiveMinTaskID int64, exclusive } // CompleteReplicationTask is a utility method to complete a replication task -func (s *TestBase) CompleteReplicationTask(taskID int64) error { +func (s *TestBase) CompleteReplicationTask(ctx context.Context, taskID int64) error { - return s.ExecutionManager.CompleteHistoryTask(&persistence.CompleteHistoryTaskRequest{ + return s.ExecutionManager.CompleteHistoryTask(ctx, &persistence.CompleteHistoryTaskRequest{ ShardID: s.ShardInfo.GetShardId(), TaskCategory: tasks.CategoryReplication, TaskKey: tasks.Key{ @@ -1175,13 +1196,13 @@ func (s *TestBase) CompleteReplicationTask(taskID int64) error { } // GetTimerTasks is a utility method to get tasks from transfer task queue -func (s *TestBase) GetTimerTasks(batchSize int, getAll bool) ([]tasks.Task, error) { +func (s *TestBase) GetTimerTasks(ctx context.Context, batchSize int, getAll bool) ([]tasks.Task, error) { result := []tasks.Task{} var token []byte Loop: for { - response, err := s.ExecutionManager.GetHistoryTasks(&persistence.GetHistoryTasksRequest{ + response, err := s.ExecutionManager.GetHistoryTasks(ctx, &persistence.GetHistoryTasksRequest{ ShardID: s.ShardInfo.GetShardId(), TaskCategory: tasks.CategoryTimer, InclusiveMinTaskKey: tasks.Key{ @@ -1208,8 +1229,8 @@ Loop: } // CompleteTimerTask is a utility method to complete a timer task -func (s *TestBase) CompleteTimerTask(ts time.Time, taskID int64) error { - return s.ExecutionManager.CompleteHistoryTask(&persistence.CompleteHistoryTaskRequest{ +func (s *TestBase) CompleteTimerTask(ctx context.Context, ts time.Time, taskID int64) error { + return s.ExecutionManager.CompleteHistoryTask(ctx, &persistence.CompleteHistoryTaskRequest{ ShardID: s.ShardInfo.GetShardId(), TaskCategory: tasks.CategoryTimer, TaskKey: tasks.Key{ @@ -1219,17 +1240,17 @@ func (s *TestBase) CompleteTimerTask(ts time.Time, taskID int64) error { }) } -func (s *TestBase) CompleteTimerTaskProto(ts *types.Timestamp, taskID int64) error { +func (s *TestBase) CompleteTimerTaskProto(ctx context.Context, ts *types.Timestamp, taskID int64) error { t, err := types.TimestampFromProto(ts) if err != nil { return err } - return s.CompleteTimerTask(t, taskID) + return s.CompleteTimerTask(ctx, t, taskID) } // RangeCompleteTimerTask is a utility method to complete a range of timer tasks -func (s *TestBase) RangeCompleteTimerTask(inclusiveBeginTimestamp time.Time, exclusiveEndTimestamp time.Time) error { - return s.ExecutionManager.RangeCompleteHistoryTasks(&persistence.RangeCompleteHistoryTasksRequest{ +func (s *TestBase) RangeCompleteTimerTask(ctx context.Context, inclusiveBeginTimestamp time.Time, exclusiveEndTimestamp time.Time) error { + return s.ExecutionManager.RangeCompleteHistoryTasks(ctx, &persistence.RangeCompleteHistoryTasksRequest{ ShardID: s.ShardInfo.GetShardId(), TaskCategory: tasks.CategoryTimer, InclusiveMinTaskKey: tasks.Key{ @@ -1271,15 +1292,15 @@ func (s *TestBase) GetReplicationReadLevel() int64 { } // ClearTasks completes all transfer tasks and replication tasks -func (s *TestBase) ClearTasks() { - s.ClearTransferQueue() - s.ClearReplicationQueue() +func (s *TestBase) ClearTasks(ctx context.Context) { + s.ClearTransferQueue(ctx) + s.ClearReplicationQueue(ctx) } // ClearTransferQueue completes all tasks in transfer queue -func (s *TestBase) ClearTransferQueue() { +func (s *TestBase) ClearTransferQueue(ctx context.Context) { s.Logger.Info("Clearing transfer tasks", tag.ShardRangeID(s.ShardInfo.GetRangeId()), tag.ReadLevel(s.GetTransferReadLevel())) - tasks, err := s.GetTransferTasks(100, true) + tasks, err := s.GetTransferTasks(ctx, 100, true) if err != nil { s.Logger.Fatal("Error during cleanup", tag.Error(err)) } @@ -1287,7 +1308,7 @@ func (s *TestBase) ClearTransferQueue() { counter := 0 for _, t := range tasks { s.Logger.Info("Deleting transfer task with ID", tag.TaskID(t.GetTaskID())) - s.NoError(s.CompleteTransferTask(t.GetTaskID())) + s.NoError(s.CompleteTransferTask(ctx, t.GetTaskID())) counter++ } @@ -1296,9 +1317,9 @@ func (s *TestBase) ClearTransferQueue() { } // ClearReplicationQueue completes all tasks in replication queue -func (s *TestBase) ClearReplicationQueue() { +func (s *TestBase) ClearReplicationQueue(ctx context.Context) { s.Logger.Info("Clearing replication tasks", tag.ShardRangeID(s.ShardInfo.GetRangeId()), tag.ReadLevel(s.GetReplicationReadLevel())) - tasks, err := s.GetReplicationTasks(100, true) + tasks, err := s.GetReplicationTasks(ctx, 100, true) if err != nil { s.Logger.Fatal("Error during cleanup", tag.Error(err)) } @@ -1306,7 +1327,7 @@ func (s *TestBase) ClearReplicationQueue() { counter := 0 for _, t := range tasks { s.Logger.Info("Deleting replication task with ID", tag.TaskID(t.GetTaskID())) - s.NoError(s.CompleteReplicationTask(t.GetTaskID())) + s.NoError(s.CompleteReplicationTask(ctx, t.GetTaskID())) counter++ } diff --git a/common/persistence/persistenceMetricClients.go b/common/persistence/persistenceMetricClients.go index f1e67e6b858..f47e76dc1d0 100644 --- a/common/persistence/persistenceMetricClients.go +++ b/common/persistence/persistenceMetricClients.go @@ -25,6 +25,7 @@ package persistence import ( + "context" "fmt" commonpb "go.temporal.io/api/common/v1" @@ -187,11 +188,14 @@ func (p *executionPersistenceClient) GetName() string { return p.persistence.GetName() } -func (p *executionPersistenceClient) CreateWorkflowExecution(request *CreateWorkflowExecutionRequest) (*CreateWorkflowExecutionResponse, error) { +func (p *executionPersistenceClient) CreateWorkflowExecution( + ctx context.Context, + request *CreateWorkflowExecutionRequest, +) (*CreateWorkflowExecutionResponse, error) { p.metricClient.IncCounter(metrics.PersistenceCreateWorkflowExecutionScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceCreateWorkflowExecutionScope, metrics.PersistenceLatency) - response, err := p.persistence.CreateWorkflowExecution(request) + response, err := p.persistence.CreateWorkflowExecution(ctx, request) sw.Stop() if err != nil { @@ -201,11 +205,14 @@ func (p *executionPersistenceClient) CreateWorkflowExecution(request *CreateWork return response, err } -func (p *executionPersistenceClient) GetWorkflowExecution(request *GetWorkflowExecutionRequest) (*GetWorkflowExecutionResponse, error) { +func (p *executionPersistenceClient) GetWorkflowExecution( + ctx context.Context, + request *GetWorkflowExecutionRequest, +) (*GetWorkflowExecutionResponse, error) { p.metricClient.IncCounter(metrics.PersistenceGetWorkflowExecutionScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceGetWorkflowExecutionScope, metrics.PersistenceLatency) - response, err := p.persistence.GetWorkflowExecution(request) + response, err := p.persistence.GetWorkflowExecution(ctx, request) sw.Stop() if err != nil { @@ -216,12 +223,13 @@ func (p *executionPersistenceClient) GetWorkflowExecution(request *GetWorkflowEx } func (p *executionPersistenceClient) SetWorkflowExecution( + ctx context.Context, request *SetWorkflowExecutionRequest, ) (*SetWorkflowExecutionResponse, error) { p.metricClient.IncCounter(metrics.PersistenceSetWorkflowExecutionScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceSetWorkflowExecutionScope, metrics.PersistenceLatency) - response, err := p.persistence.SetWorkflowExecution(request) + response, err := p.persistence.SetWorkflowExecution(ctx, request) sw.Stop() if err != nil { @@ -231,11 +239,14 @@ func (p *executionPersistenceClient) SetWorkflowExecution( return response, err } -func (p *executionPersistenceClient) UpdateWorkflowExecution(request *UpdateWorkflowExecutionRequest) (*UpdateWorkflowExecutionResponse, error) { +func (p *executionPersistenceClient) UpdateWorkflowExecution( + ctx context.Context, + request *UpdateWorkflowExecutionRequest, +) (*UpdateWorkflowExecutionResponse, error) { p.metricClient.IncCounter(metrics.PersistenceUpdateWorkflowExecutionScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceUpdateWorkflowExecutionScope, metrics.PersistenceLatency) - resp, err := p.persistence.UpdateWorkflowExecution(request) + resp, err := p.persistence.UpdateWorkflowExecution(ctx, request) sw.Stop() if err != nil { @@ -245,11 +256,14 @@ func (p *executionPersistenceClient) UpdateWorkflowExecution(request *UpdateWork return resp, err } -func (p *executionPersistenceClient) ConflictResolveWorkflowExecution(request *ConflictResolveWorkflowExecutionRequest) (*ConflictResolveWorkflowExecutionResponse, error) { +func (p *executionPersistenceClient) ConflictResolveWorkflowExecution( + ctx context.Context, + request *ConflictResolveWorkflowExecutionRequest, +) (*ConflictResolveWorkflowExecutionResponse, error) { p.metricClient.IncCounter(metrics.PersistenceConflictResolveWorkflowExecutionScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceConflictResolveWorkflowExecutionScope, metrics.PersistenceLatency) - response, err := p.persistence.ConflictResolveWorkflowExecution(request) + response, err := p.persistence.ConflictResolveWorkflowExecution(ctx, request) sw.Stop() if err != nil { @@ -259,11 +273,14 @@ func (p *executionPersistenceClient) ConflictResolveWorkflowExecution(request *C return response, err } -func (p *executionPersistenceClient) DeleteWorkflowExecution(request *DeleteWorkflowExecutionRequest) error { +func (p *executionPersistenceClient) DeleteWorkflowExecution( + ctx context.Context, + request *DeleteWorkflowExecutionRequest, +) error { p.metricClient.IncCounter(metrics.PersistenceDeleteWorkflowExecutionScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceDeleteWorkflowExecutionScope, metrics.PersistenceLatency) - err := p.persistence.DeleteWorkflowExecution(request) + err := p.persistence.DeleteWorkflowExecution(ctx, request) sw.Stop() if err != nil { @@ -273,11 +290,14 @@ func (p *executionPersistenceClient) DeleteWorkflowExecution(request *DeleteWork return err } -func (p *executionPersistenceClient) DeleteCurrentWorkflowExecution(request *DeleteCurrentWorkflowExecutionRequest) error { +func (p *executionPersistenceClient) DeleteCurrentWorkflowExecution( + ctx context.Context, + request *DeleteCurrentWorkflowExecutionRequest, +) error { p.metricClient.IncCounter(metrics.PersistenceDeleteCurrentWorkflowExecutionScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceDeleteCurrentWorkflowExecutionScope, metrics.PersistenceLatency) - err := p.persistence.DeleteCurrentWorkflowExecution(request) + err := p.persistence.DeleteCurrentWorkflowExecution(ctx, request) sw.Stop() if err != nil { @@ -287,11 +307,14 @@ func (p *executionPersistenceClient) DeleteCurrentWorkflowExecution(request *Del return err } -func (p *executionPersistenceClient) GetCurrentExecution(request *GetCurrentExecutionRequest) (*GetCurrentExecutionResponse, error) { +func (p *executionPersistenceClient) GetCurrentExecution( + ctx context.Context, + request *GetCurrentExecutionRequest, +) (*GetCurrentExecutionResponse, error) { p.metricClient.IncCounter(metrics.PersistenceGetCurrentExecutionScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceGetCurrentExecutionScope, metrics.PersistenceLatency) - response, err := p.persistence.GetCurrentExecution(request) + response, err := p.persistence.GetCurrentExecution(ctx, request) sw.Stop() if err != nil { @@ -301,11 +324,14 @@ func (p *executionPersistenceClient) GetCurrentExecution(request *GetCurrentExec return response, err } -func (p *executionPersistenceClient) ListConcreteExecutions(request *ListConcreteExecutionsRequest) (*ListConcreteExecutionsResponse, error) { +func (p *executionPersistenceClient) ListConcreteExecutions( + ctx context.Context, + request *ListConcreteExecutionsRequest, +) (*ListConcreteExecutionsResponse, error) { p.metricClient.IncCounter(metrics.PersistenceListConcreteExecutionsScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceListConcreteExecutionsScope, metrics.PersistenceLatency) - response, err := p.persistence.ListConcreteExecutions(request) + response, err := p.persistence.ListConcreteExecutions(ctx, request) sw.Stop() if err != nil { @@ -315,11 +341,14 @@ func (p *executionPersistenceClient) ListConcreteExecutions(request *ListConcret return response, err } -func (p *executionPersistenceClient) AddHistoryTasks(request *AddHistoryTasksRequest) error { +func (p *executionPersistenceClient) AddHistoryTasks( + ctx context.Context, + request *AddHistoryTasksRequest, +) error { p.metricClient.IncCounter(metrics.PersistenceAddTasksScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceAddTasksScope, metrics.PersistenceLatency) - err := p.persistence.AddHistoryTasks(request) + err := p.persistence.AddHistoryTasks(ctx, request) sw.Stop() if err != nil { @@ -329,7 +358,10 @@ func (p *executionPersistenceClient) AddHistoryTasks(request *AddHistoryTasksReq return err } -func (p *executionPersistenceClient) GetHistoryTask(request *GetHistoryTaskRequest) (*GetHistoryTaskResponse, error) { +func (p *executionPersistenceClient) GetHistoryTask( + ctx context.Context, + request *GetHistoryTaskRequest, +) (*GetHistoryTaskResponse, error) { var scopeIdx int switch request.TaskCategory.ID() { case tasks.CategoryIDTransfer: @@ -347,7 +379,7 @@ func (p *executionPersistenceClient) GetHistoryTask(request *GetHistoryTaskReque p.metricClient.IncCounter(scopeIdx, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(scopeIdx, metrics.PersistenceLatency) - response, err := p.persistence.GetHistoryTask(request) + response, err := p.persistence.GetHistoryTask(ctx, request) sw.Stop() if err != nil { @@ -357,7 +389,10 @@ func (p *executionPersistenceClient) GetHistoryTask(request *GetHistoryTaskReque return response, err } -func (p *executionPersistenceClient) GetHistoryTasks(request *GetHistoryTasksRequest) (*GetHistoryTasksResponse, error) { +func (p *executionPersistenceClient) GetHistoryTasks( + ctx context.Context, + request *GetHistoryTasksRequest, +) (*GetHistoryTasksResponse, error) { var scopeIdx int switch request.TaskCategory.ID() { case tasks.CategoryIDTransfer: @@ -375,7 +410,7 @@ func (p *executionPersistenceClient) GetHistoryTasks(request *GetHistoryTasksReq p.metricClient.IncCounter(scopeIdx, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(scopeIdx, metrics.PersistenceLatency) - response, err := p.persistence.GetHistoryTasks(request) + response, err := p.persistence.GetHistoryTasks(ctx, request) sw.Stop() if err != nil { @@ -385,7 +420,10 @@ func (p *executionPersistenceClient) GetHistoryTasks(request *GetHistoryTasksReq return response, err } -func (p *executionPersistenceClient) CompleteHistoryTask(request *CompleteHistoryTaskRequest) error { +func (p *executionPersistenceClient) CompleteHistoryTask( + ctx context.Context, + request *CompleteHistoryTaskRequest, +) error { var scopeIdx int switch request.TaskCategory.ID() { case tasks.CategoryIDTransfer: @@ -403,7 +441,7 @@ func (p *executionPersistenceClient) CompleteHistoryTask(request *CompleteHistor p.metricClient.IncCounter(scopeIdx, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(scopeIdx, metrics.PersistenceLatency) - err := p.persistence.CompleteHistoryTask(request) + err := p.persistence.CompleteHistoryTask(ctx, request) sw.Stop() if err != nil { @@ -413,7 +451,10 @@ func (p *executionPersistenceClient) CompleteHistoryTask(request *CompleteHistor return err } -func (p *executionPersistenceClient) RangeCompleteHistoryTasks(request *RangeCompleteHistoryTasksRequest) error { +func (p *executionPersistenceClient) RangeCompleteHistoryTasks( + ctx context.Context, + request *RangeCompleteHistoryTasksRequest, +) error { var scopeIdx int switch request.TaskCategory.ID() { case tasks.CategoryIDTransfer: @@ -431,7 +472,7 @@ func (p *executionPersistenceClient) RangeCompleteHistoryTasks(request *RangeCom p.metricClient.IncCounter(scopeIdx, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(scopeIdx, metrics.PersistenceLatency) - err := p.persistence.RangeCompleteHistoryTasks(request) + err := p.persistence.RangeCompleteHistoryTasks(ctx, request) sw.Stop() if err != nil { @@ -442,12 +483,13 @@ func (p *executionPersistenceClient) RangeCompleteHistoryTasks(request *RangeCom } func (p *executionPersistenceClient) PutReplicationTaskToDLQ( + ctx context.Context, request *PutReplicationTaskToDLQRequest, ) error { p.metricClient.IncCounter(metrics.PersistencePutReplicationTaskToDLQScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistencePutReplicationTaskToDLQScope, metrics.PersistenceLatency) - err := p.persistence.PutReplicationTaskToDLQ(request) + err := p.persistence.PutReplicationTaskToDLQ(ctx, request) sw.Stop() if err != nil { @@ -458,12 +500,13 @@ func (p *executionPersistenceClient) PutReplicationTaskToDLQ( } func (p *executionPersistenceClient) GetReplicationTasksFromDLQ( + ctx context.Context, request *GetReplicationTasksFromDLQRequest, ) (*GetHistoryTasksResponse, error) { p.metricClient.IncCounter(metrics.PersistenceGetReplicationTasksFromDLQScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceGetReplicationTasksFromDLQScope, metrics.PersistenceLatency) - response, err := p.persistence.GetReplicationTasksFromDLQ(request) + response, err := p.persistence.GetReplicationTasksFromDLQ(ctx, request) sw.Stop() if err != nil { @@ -474,12 +517,13 @@ func (p *executionPersistenceClient) GetReplicationTasksFromDLQ( } func (p *executionPersistenceClient) DeleteReplicationTaskFromDLQ( + ctx context.Context, request *DeleteReplicationTaskFromDLQRequest, ) error { p.metricClient.IncCounter(metrics.PersistenceDeleteReplicationTaskFromDLQScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceDeleteReplicationTaskFromDLQScope, metrics.PersistenceLatency) - err := p.persistence.DeleteReplicationTaskFromDLQ(request) + err := p.persistence.DeleteReplicationTaskFromDLQ(ctx, request) sw.Stop() if err != nil { @@ -490,12 +534,13 @@ func (p *executionPersistenceClient) DeleteReplicationTaskFromDLQ( } func (p *executionPersistenceClient) RangeDeleteReplicationTaskFromDLQ( + ctx context.Context, request *RangeDeleteReplicationTaskFromDLQRequest, ) error { p.metricClient.IncCounter(metrics.PersistenceRangeDeleteReplicationTaskFromDLQScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceRangeDeleteReplicationTaskFromDLQScope, metrics.PersistenceLatency) - err := p.persistence.RangeDeleteReplicationTaskFromDLQ(request) + err := p.persistence.RangeDeleteReplicationTaskFromDLQ(ctx, request) sw.Stop() if err != nil { @@ -752,10 +797,13 @@ func (p *metadataPersistenceClient) Close() { } // AppendHistoryNodes add a node to history node table -func (p *executionPersistenceClient) AppendHistoryNodes(request *AppendHistoryNodesRequest) (*AppendHistoryNodesResponse, error) { +func (p *executionPersistenceClient) AppendHistoryNodes( + ctx context.Context, + request *AppendHistoryNodesRequest, +) (*AppendHistoryNodesResponse, error) { p.metricClient.IncCounter(metrics.PersistenceAppendHistoryNodesScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceAppendHistoryNodesScope, metrics.PersistenceLatency) - resp, err := p.persistence.AppendHistoryNodes(request) + resp, err := p.persistence.AppendHistoryNodes(ctx, request) sw.Stop() if err != nil { p.updateErrorMetric(metrics.PersistenceAppendHistoryNodesScope, err) @@ -764,10 +812,13 @@ func (p *executionPersistenceClient) AppendHistoryNodes(request *AppendHistoryNo } // ReadHistoryBranch returns history node data for a branch -func (p *executionPersistenceClient) ReadHistoryBranch(request *ReadHistoryBranchRequest) (*ReadHistoryBranchResponse, error) { +func (p *executionPersistenceClient) ReadHistoryBranch( + ctx context.Context, + request *ReadHistoryBranchRequest, +) (*ReadHistoryBranchResponse, error) { p.metricClient.IncCounter(metrics.PersistenceReadHistoryBranchScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceReadHistoryBranchScope, metrics.PersistenceLatency) - response, err := p.persistence.ReadHistoryBranch(request) + response, err := p.persistence.ReadHistoryBranch(ctx, request) sw.Stop() if err != nil { p.updateErrorMetric(metrics.PersistenceReadHistoryBranchScope, err) @@ -775,13 +826,13 @@ func (p *executionPersistenceClient) ReadHistoryBranch(request *ReadHistoryBranc return response, err } -func (p *executionPersistenceClient) ReadHistoryBranchReverse(request *ReadHistoryBranchReverseRequest) ( - *ReadHistoryBranchReverseResponse, - error, -) { +func (p *executionPersistenceClient) ReadHistoryBranchReverse( + ctx context.Context, + request *ReadHistoryBranchReverseRequest, +) (*ReadHistoryBranchReverseResponse, error) { p.metricClient.IncCounter(metrics.PersistenceReadHistoryBranchReverseScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceReadHistoryBranchReverseScope, metrics.PersistenceLatency) - response, err := p.persistence.ReadHistoryBranchReverse(request) + response, err := p.persistence.ReadHistoryBranchReverse(ctx, request) sw.Stop() if err != nil { p.updateErrorMetric(metrics.PersistenceReadHistoryBranchReverseScope, err) @@ -790,10 +841,13 @@ func (p *executionPersistenceClient) ReadHistoryBranchReverse(request *ReadHisto } // ReadHistoryBranchByBatch returns history node data for a branch ByBatch -func (p *executionPersistenceClient) ReadHistoryBranchByBatch(request *ReadHistoryBranchRequest) (*ReadHistoryBranchByBatchResponse, error) { +func (p *executionPersistenceClient) ReadHistoryBranchByBatch( + ctx context.Context, + request *ReadHistoryBranchRequest, +) (*ReadHistoryBranchByBatchResponse, error) { p.metricClient.IncCounter(metrics.PersistenceReadHistoryBranchScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceReadHistoryBranchScope, metrics.PersistenceLatency) - response, err := p.persistence.ReadHistoryBranchByBatch(request) + response, err := p.persistence.ReadHistoryBranchByBatch(ctx, request) sw.Stop() if err != nil { p.updateErrorMetric(metrics.PersistenceReadHistoryBranchScope, err) @@ -802,10 +856,13 @@ func (p *executionPersistenceClient) ReadHistoryBranchByBatch(request *ReadHisto } // ReadRawHistoryBranch returns history node raw data for a branch ByBatch -func (p *executionPersistenceClient) ReadRawHistoryBranch(request *ReadHistoryBranchRequest) (*ReadRawHistoryBranchResponse, error) { +func (p *executionPersistenceClient) ReadRawHistoryBranch( + ctx context.Context, + request *ReadHistoryBranchRequest, +) (*ReadRawHistoryBranchResponse, error) { p.metricClient.IncCounter(metrics.PersistenceReadHistoryBranchScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceReadHistoryBranchScope, metrics.PersistenceLatency) - response, err := p.persistence.ReadRawHistoryBranch(request) + response, err := p.persistence.ReadRawHistoryBranch(ctx, request) sw.Stop() if err != nil { p.updateErrorMetric(metrics.PersistenceReadHistoryBranchScope, err) @@ -814,10 +871,13 @@ func (p *executionPersistenceClient) ReadRawHistoryBranch(request *ReadHistoryBr } // ForkHistoryBranch forks a new branch from a old branch -func (p *executionPersistenceClient) ForkHistoryBranch(request *ForkHistoryBranchRequest) (*ForkHistoryBranchResponse, error) { +func (p *executionPersistenceClient) ForkHistoryBranch( + ctx context.Context, + request *ForkHistoryBranchRequest, +) (*ForkHistoryBranchResponse, error) { p.metricClient.IncCounter(metrics.PersistenceForkHistoryBranchScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceForkHistoryBranchScope, metrics.PersistenceLatency) - response, err := p.persistence.ForkHistoryBranch(request) + response, err := p.persistence.ForkHistoryBranch(ctx, request) sw.Stop() if err != nil { p.updateErrorMetric(metrics.PersistenceForkHistoryBranchScope, err) @@ -826,10 +886,13 @@ func (p *executionPersistenceClient) ForkHistoryBranch(request *ForkHistoryBranc } // DeleteHistoryBranch removes a branch -func (p *executionPersistenceClient) DeleteHistoryBranch(request *DeleteHistoryBranchRequest) error { +func (p *executionPersistenceClient) DeleteHistoryBranch( + ctx context.Context, + request *DeleteHistoryBranchRequest, +) error { p.metricClient.IncCounter(metrics.PersistenceDeleteHistoryBranchScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceDeleteHistoryBranchScope, metrics.PersistenceLatency) - err := p.persistence.DeleteHistoryBranch(request) + err := p.persistence.DeleteHistoryBranch(ctx, request) sw.Stop() if err != nil { p.updateErrorMetric(metrics.PersistenceDeleteHistoryBranchScope, err) @@ -838,10 +901,13 @@ func (p *executionPersistenceClient) DeleteHistoryBranch(request *DeleteHistoryB } // TrimHistoryBranch trims a branch -func (p *executionPersistenceClient) TrimHistoryBranch(request *TrimHistoryBranchRequest) (*TrimHistoryBranchResponse, error) { +func (p *executionPersistenceClient) TrimHistoryBranch( + ctx context.Context, + request *TrimHistoryBranchRequest, +) (*TrimHistoryBranchResponse, error) { p.metricClient.IncCounter(metrics.PersistenceTrimHistoryBranchScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceTrimHistoryBranchScope, metrics.PersistenceLatency) - resp, err := p.persistence.TrimHistoryBranch(request) + resp, err := p.persistence.TrimHistoryBranch(ctx, request) sw.Stop() if err != nil { p.updateErrorMetric(metrics.PersistenceTrimHistoryBranchScope, err) @@ -849,10 +915,13 @@ func (p *executionPersistenceClient) TrimHistoryBranch(request *TrimHistoryBranc return resp, err } -func (p *executionPersistenceClient) GetAllHistoryTreeBranches(request *GetAllHistoryTreeBranchesRequest) (*GetAllHistoryTreeBranchesResponse, error) { +func (p *executionPersistenceClient) GetAllHistoryTreeBranches( + ctx context.Context, + request *GetAllHistoryTreeBranchesRequest, +) (*GetAllHistoryTreeBranchesResponse, error) { p.metricClient.IncCounter(metrics.PersistenceGetAllHistoryTreeBranchesScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceGetAllHistoryTreeBranchesScope, metrics.PersistenceLatency) - response, err := p.persistence.GetAllHistoryTreeBranches(request) + response, err := p.persistence.GetAllHistoryTreeBranches(ctx, request) sw.Stop() if err != nil { p.updateErrorMetric(metrics.PersistenceGetAllHistoryTreeBranchesScope, err) @@ -861,10 +930,13 @@ func (p *executionPersistenceClient) GetAllHistoryTreeBranches(request *GetAllHi } // GetHistoryTree returns all branch information of a tree -func (p *executionPersistenceClient) GetHistoryTree(request *GetHistoryTreeRequest) (*GetHistoryTreeResponse, error) { +func (p *executionPersistenceClient) GetHistoryTree( + ctx context.Context, + request *GetHistoryTreeRequest, +) (*GetHistoryTreeResponse, error) { p.metricClient.IncCounter(metrics.PersistenceGetHistoryTreeScope, metrics.PersistenceRequests) sw := p.metricClient.StartTimer(metrics.PersistenceGetHistoryTreeScope, metrics.PersistenceLatency) - response, err := p.persistence.GetHistoryTree(request) + response, err := p.persistence.GetHistoryTree(ctx, request) sw.Stop() if err != nil { p.updateErrorMetric(metrics.PersistenceGetHistoryTreeScope, err) diff --git a/common/persistence/persistenceRateLimitedClients.go b/common/persistence/persistenceRateLimitedClients.go index daf70a8d97e..6046f98b7af 100644 --- a/common/persistence/persistenceRateLimitedClients.go +++ b/common/persistence/persistenceRateLimitedClients.go @@ -25,6 +25,8 @@ package persistence import ( + "context" + commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" "go.temporal.io/api/serviceerror" @@ -167,170 +169,216 @@ func (p *executionRateLimitedPersistenceClient) GetName() string { return p.persistence.GetName() } -func (p *executionRateLimitedPersistenceClient) CreateWorkflowExecution(request *CreateWorkflowExecutionRequest) (*CreateWorkflowExecutionResponse, error) { +func (p *executionRateLimitedPersistenceClient) CreateWorkflowExecution( + ctx context.Context, + request *CreateWorkflowExecutionRequest, +) (*CreateWorkflowExecutionResponse, error) { if ok := p.rateLimiter.Allow(); !ok { return nil, ErrPersistenceLimitExceeded } - response, err := p.persistence.CreateWorkflowExecution(request) + response, err := p.persistence.CreateWorkflowExecution(ctx, request) return response, err } -func (p *executionRateLimitedPersistenceClient) GetWorkflowExecution(request *GetWorkflowExecutionRequest) (*GetWorkflowExecutionResponse, error) { +func (p *executionRateLimitedPersistenceClient) GetWorkflowExecution( + ctx context.Context, + request *GetWorkflowExecutionRequest, +) (*GetWorkflowExecutionResponse, error) { if ok := p.rateLimiter.Allow(); !ok { return nil, ErrPersistenceLimitExceeded } - response, err := p.persistence.GetWorkflowExecution(request) + response, err := p.persistence.GetWorkflowExecution(ctx, request) return response, err } -func (p *executionRateLimitedPersistenceClient) SetWorkflowExecution(request *SetWorkflowExecutionRequest) (*SetWorkflowExecutionResponse, error) { +func (p *executionRateLimitedPersistenceClient) SetWorkflowExecution( + ctx context.Context, + request *SetWorkflowExecutionRequest, +) (*SetWorkflowExecutionResponse, error) { if ok := p.rateLimiter.Allow(); !ok { return nil, ErrPersistenceLimitExceeded } - response, err := p.persistence.SetWorkflowExecution(request) + response, err := p.persistence.SetWorkflowExecution(ctx, request) return response, err } -func (p *executionRateLimitedPersistenceClient) UpdateWorkflowExecution(request *UpdateWorkflowExecutionRequest) (*UpdateWorkflowExecutionResponse, error) { +func (p *executionRateLimitedPersistenceClient) UpdateWorkflowExecution( + ctx context.Context, + request *UpdateWorkflowExecutionRequest, +) (*UpdateWorkflowExecutionResponse, error) { if ok := p.rateLimiter.Allow(); !ok { return nil, ErrPersistenceLimitExceeded } - resp, err := p.persistence.UpdateWorkflowExecution(request) + resp, err := p.persistence.UpdateWorkflowExecution(ctx, request) return resp, err } -func (p *executionRateLimitedPersistenceClient) ConflictResolveWorkflowExecution(request *ConflictResolveWorkflowExecutionRequest) (*ConflictResolveWorkflowExecutionResponse, error) { +func (p *executionRateLimitedPersistenceClient) ConflictResolveWorkflowExecution( + ctx context.Context, + request *ConflictResolveWorkflowExecutionRequest, +) (*ConflictResolveWorkflowExecutionResponse, error) { if ok := p.rateLimiter.Allow(); !ok { return nil, ErrPersistenceLimitExceeded } - response, err := p.persistence.ConflictResolveWorkflowExecution(request) + response, err := p.persistence.ConflictResolveWorkflowExecution(ctx, request) return response, err } -func (p *executionRateLimitedPersistenceClient) DeleteWorkflowExecution(request *DeleteWorkflowExecutionRequest) error { +func (p *executionRateLimitedPersistenceClient) DeleteWorkflowExecution( + ctx context.Context, + request *DeleteWorkflowExecutionRequest, +) error { if ok := p.rateLimiter.Allow(); !ok { return ErrPersistenceLimitExceeded } - err := p.persistence.DeleteWorkflowExecution(request) + err := p.persistence.DeleteWorkflowExecution(ctx, request) return err } -func (p *executionRateLimitedPersistenceClient) DeleteCurrentWorkflowExecution(request *DeleteCurrentWorkflowExecutionRequest) error { +func (p *executionRateLimitedPersistenceClient) DeleteCurrentWorkflowExecution( + ctx context.Context, + request *DeleteCurrentWorkflowExecutionRequest, +) error { if ok := p.rateLimiter.Allow(); !ok { return ErrPersistenceLimitExceeded } - err := p.persistence.DeleteCurrentWorkflowExecution(request) + err := p.persistence.DeleteCurrentWorkflowExecution(ctx, request) return err } -func (p *executionRateLimitedPersistenceClient) GetCurrentExecution(request *GetCurrentExecutionRequest) (*GetCurrentExecutionResponse, error) { +func (p *executionRateLimitedPersistenceClient) GetCurrentExecution( + ctx context.Context, + request *GetCurrentExecutionRequest, +) (*GetCurrentExecutionResponse, error) { if ok := p.rateLimiter.Allow(); !ok { return nil, ErrPersistenceLimitExceeded } - response, err := p.persistence.GetCurrentExecution(request) + response, err := p.persistence.GetCurrentExecution(ctx, request) return response, err } -func (p *executionRateLimitedPersistenceClient) ListConcreteExecutions(request *ListConcreteExecutionsRequest) (*ListConcreteExecutionsResponse, error) { +func (p *executionRateLimitedPersistenceClient) ListConcreteExecutions( + ctx context.Context, + request *ListConcreteExecutionsRequest, +) (*ListConcreteExecutionsResponse, error) { if ok := p.rateLimiter.Allow(); !ok { return nil, ErrPersistenceLimitExceeded } - response, err := p.persistence.ListConcreteExecutions(request) + response, err := p.persistence.ListConcreteExecutions(ctx, request) return response, err } -func (p *executionRateLimitedPersistenceClient) AddHistoryTasks(request *AddHistoryTasksRequest) error { +func (p *executionRateLimitedPersistenceClient) AddHistoryTasks( + ctx context.Context, + request *AddHistoryTasksRequest, +) error { if ok := p.rateLimiter.Allow(); !ok { return ErrPersistenceLimitExceeded } - err := p.persistence.AddHistoryTasks(request) + err := p.persistence.AddHistoryTasks(ctx, request) return err } -func (p *executionRateLimitedPersistenceClient) GetHistoryTask(request *GetHistoryTaskRequest) (*GetHistoryTaskResponse, error) { +func (p *executionRateLimitedPersistenceClient) GetHistoryTask( + ctx context.Context, + request *GetHistoryTaskRequest, +) (*GetHistoryTaskResponse, error) { if ok := p.rateLimiter.Allow(); !ok { return nil, ErrPersistenceLimitExceeded } - response, err := p.persistence.GetHistoryTask(request) + response, err := p.persistence.GetHistoryTask(ctx, request) return response, err } -func (p *executionRateLimitedPersistenceClient) GetHistoryTasks(request *GetHistoryTasksRequest) (*GetHistoryTasksResponse, error) { +func (p *executionRateLimitedPersistenceClient) GetHistoryTasks( + ctx context.Context, + request *GetHistoryTasksRequest, +) (*GetHistoryTasksResponse, error) { if ok := p.rateLimiter.Allow(); !ok { return nil, ErrPersistenceLimitExceeded } - response, err := p.persistence.GetHistoryTasks(request) + response, err := p.persistence.GetHistoryTasks(ctx, request) return response, err } -func (p *executionRateLimitedPersistenceClient) CompleteHistoryTask(request *CompleteHistoryTaskRequest) error { +func (p *executionRateLimitedPersistenceClient) CompleteHistoryTask( + ctx context.Context, + request *CompleteHistoryTaskRequest, +) error { if ok := p.rateLimiter.Allow(); !ok { return ErrPersistenceLimitExceeded } - err := p.persistence.CompleteHistoryTask(request) + err := p.persistence.CompleteHistoryTask(ctx, request) return err } -func (p *executionRateLimitedPersistenceClient) RangeCompleteHistoryTasks(request *RangeCompleteHistoryTasksRequest) error { +func (p *executionRateLimitedPersistenceClient) RangeCompleteHistoryTasks( + ctx context.Context, + request *RangeCompleteHistoryTasksRequest, +) error { if ok := p.rateLimiter.Allow(); !ok { return ErrPersistenceLimitExceeded } - err := p.persistence.RangeCompleteHistoryTasks(request) + err := p.persistence.RangeCompleteHistoryTasks(ctx, request) return err } func (p *executionRateLimitedPersistenceClient) PutReplicationTaskToDLQ( + ctx context.Context, request *PutReplicationTaskToDLQRequest, ) error { if ok := p.rateLimiter.Allow(); !ok { return ErrPersistenceLimitExceeded } - return p.persistence.PutReplicationTaskToDLQ(request) + return p.persistence.PutReplicationTaskToDLQ(ctx, request) } func (p *executionRateLimitedPersistenceClient) GetReplicationTasksFromDLQ( + ctx context.Context, request *GetReplicationTasksFromDLQRequest, ) (*GetHistoryTasksResponse, error) { if ok := p.rateLimiter.Allow(); !ok { return nil, ErrPersistenceLimitExceeded } - return p.persistence.GetReplicationTasksFromDLQ(request) + return p.persistence.GetReplicationTasksFromDLQ(ctx, request) } func (p *executionRateLimitedPersistenceClient) DeleteReplicationTaskFromDLQ( + ctx context.Context, request *DeleteReplicationTaskFromDLQRequest, ) error { if ok := p.rateLimiter.Allow(); !ok { return ErrPersistenceLimitExceeded } - return p.persistence.DeleteReplicationTaskFromDLQ(request) + return p.persistence.DeleteReplicationTaskFromDLQ(ctx, request) } func (p *executionRateLimitedPersistenceClient) RangeDeleteReplicationTaskFromDLQ( + ctx context.Context, request *RangeDeleteReplicationTaskFromDLQRequest, ) error { if ok := p.rateLimiter.Allow(); !ok { return ErrPersistenceLimitExceeded } - return p.persistence.RangeDeleteReplicationTaskFromDLQ(request) + return p.persistence.RangeDeleteReplicationTaskFromDLQ(ctx, request) } func (p *executionRateLimitedPersistenceClient) Close() { @@ -495,90 +543,120 @@ func (p *metadataRateLimitedPersistenceClient) Close() { } // AppendHistoryNodes add a node to history node table -func (p *executionRateLimitedPersistenceClient) AppendHistoryNodes(request *AppendHistoryNodesRequest) (*AppendHistoryNodesResponse, error) { +func (p *executionRateLimitedPersistenceClient) AppendHistoryNodes( + ctx context.Context, + request *AppendHistoryNodesRequest, +) (*AppendHistoryNodesResponse, error) { if ok := p.rateLimiter.Allow(); !ok { return nil, ErrPersistenceLimitExceeded } - return p.persistence.AppendHistoryNodes(request) + return p.persistence.AppendHistoryNodes(ctx, request) } // ReadHistoryBranch returns history node data for a branch -func (p *executionRateLimitedPersistenceClient) ReadHistoryBranch(request *ReadHistoryBranchRequest) (*ReadHistoryBranchResponse, error) { +func (p *executionRateLimitedPersistenceClient) ReadHistoryBranch( + ctx context.Context, + request *ReadHistoryBranchRequest, +) (*ReadHistoryBranchResponse, error) { if ok := p.rateLimiter.Allow(); !ok { return nil, ErrPersistenceLimitExceeded } - response, err := p.persistence.ReadHistoryBranch(request) + response, err := p.persistence.ReadHistoryBranch(ctx, request) return response, err } // ReadHistoryBranch returns history node data for a branch -func (p *executionRateLimitedPersistenceClient) ReadHistoryBranchReverse(request *ReadHistoryBranchReverseRequest) (*ReadHistoryBranchReverseResponse, error) { +func (p *executionRateLimitedPersistenceClient) ReadHistoryBranchReverse( + ctx context.Context, + request *ReadHistoryBranchReverseRequest, +) (*ReadHistoryBranchReverseResponse, error) { if ok := p.rateLimiter.Allow(); !ok { return nil, ErrPersistenceLimitExceeded } - response, err := p.persistence.ReadHistoryBranchReverse(request) + response, err := p.persistence.ReadHistoryBranchReverse(ctx, request) return response, err } // ReadHistoryBranchByBatch returns history node data for a branch -func (p *executionRateLimitedPersistenceClient) ReadHistoryBranchByBatch(request *ReadHistoryBranchRequest) (*ReadHistoryBranchByBatchResponse, error) { +func (p *executionRateLimitedPersistenceClient) ReadHistoryBranchByBatch( + ctx context.Context, + request *ReadHistoryBranchRequest, +) (*ReadHistoryBranchByBatchResponse, error) { if ok := p.rateLimiter.Allow(); !ok { return nil, ErrPersistenceLimitExceeded } - response, err := p.persistence.ReadHistoryBranchByBatch(request) + response, err := p.persistence.ReadHistoryBranchByBatch(ctx, request) return response, err } // ReadHistoryBranchByBatch returns history node data for a branch -func (p *executionRateLimitedPersistenceClient) ReadRawHistoryBranch(request *ReadHistoryBranchRequest) (*ReadRawHistoryBranchResponse, error) { +func (p *executionRateLimitedPersistenceClient) ReadRawHistoryBranch( + ctx context.Context, + request *ReadHistoryBranchRequest, +) (*ReadRawHistoryBranchResponse, error) { if ok := p.rateLimiter.Allow(); !ok { return nil, ErrPersistenceLimitExceeded } - response, err := p.persistence.ReadRawHistoryBranch(request) + response, err := p.persistence.ReadRawHistoryBranch(ctx, request) return response, err } // ForkHistoryBranch forks a new branch from a old branch -func (p *executionRateLimitedPersistenceClient) ForkHistoryBranch(request *ForkHistoryBranchRequest) (*ForkHistoryBranchResponse, error) { +func (p *executionRateLimitedPersistenceClient) ForkHistoryBranch( + ctx context.Context, + request *ForkHistoryBranchRequest, +) (*ForkHistoryBranchResponse, error) { if ok := p.rateLimiter.Allow(); !ok { return nil, ErrPersistenceLimitExceeded } - response, err := p.persistence.ForkHistoryBranch(request) + response, err := p.persistence.ForkHistoryBranch(ctx, request) return response, err } // DeleteHistoryBranch removes a branch -func (p *executionRateLimitedPersistenceClient) DeleteHistoryBranch(request *DeleteHistoryBranchRequest) error { +func (p *executionRateLimitedPersistenceClient) DeleteHistoryBranch( + ctx context.Context, + request *DeleteHistoryBranchRequest, +) error { if ok := p.rateLimiter.Allow(); !ok { return ErrPersistenceLimitExceeded } - err := p.persistence.DeleteHistoryBranch(request) + err := p.persistence.DeleteHistoryBranch(ctx, request) return err } // TrimHistoryBranch trims a branch -func (p *executionRateLimitedPersistenceClient) TrimHistoryBranch(request *TrimHistoryBranchRequest) (*TrimHistoryBranchResponse, error) { +func (p *executionRateLimitedPersistenceClient) TrimHistoryBranch( + ctx context.Context, + request *TrimHistoryBranchRequest, +) (*TrimHistoryBranchResponse, error) { if ok := p.rateLimiter.Allow(); !ok { return nil, ErrPersistenceLimitExceeded } - resp, err := p.persistence.TrimHistoryBranch(request) + resp, err := p.persistence.TrimHistoryBranch(ctx, request) return resp, err } // GetHistoryTree returns all branch information of a tree -func (p *executionRateLimitedPersistenceClient) GetHistoryTree(request *GetHistoryTreeRequest) (*GetHistoryTreeResponse, error) { +func (p *executionRateLimitedPersistenceClient) GetHistoryTree( + ctx context.Context, + request *GetHistoryTreeRequest, +) (*GetHistoryTreeResponse, error) { if ok := p.rateLimiter.Allow(); !ok { return nil, ErrPersistenceLimitExceeded } - response, err := p.persistence.GetHistoryTree(request) + response, err := p.persistence.GetHistoryTree(ctx, request) return response, err } -func (p *executionRateLimitedPersistenceClient) GetAllHistoryTreeBranches(request *GetAllHistoryTreeBranchesRequest) (*GetAllHistoryTreeBranchesResponse, error) { +func (p *executionRateLimitedPersistenceClient) GetAllHistoryTreeBranches( + ctx context.Context, + request *GetAllHistoryTreeBranchesRequest, +) (*GetAllHistoryTreeBranchesResponse, error) { if ok := p.rateLimiter.Allow(); !ok { return nil, ErrPersistenceLimitExceeded } - response, err := p.persistence.GetAllHistoryTreeBranches(request) + response, err := p.persistence.GetAllHistoryTreeBranches(ctx, request) return response, err } diff --git a/common/persistence/tests/execution_mutable_state.go b/common/persistence/tests/execution_mutable_state.go index 794f21d93ea..ddf39eaf811 100644 --- a/common/persistence/tests/execution_mutable_state.go +++ b/common/persistence/tests/execution_mutable_state.go @@ -25,6 +25,7 @@ package tests import ( + "context" "math/rand" "testing" "time" @@ -60,6 +61,9 @@ type ( ShardManager p.ShardManager ExecutionManager p.ExecutionManager Logger log.Logger + + Ctx context.Context + Cancel context.CancelFunc } ) @@ -96,6 +100,7 @@ func (s *ExecutionMutableStateSuite) TearDownSuite() { func (s *ExecutionMutableStateSuite) SetupTest() { s.Assertions = require.New(s.T()) + s.Ctx, s.Cancel = context.WithTimeout(context.Background(), time.Second*30) s.ShardID = 1 + rand.Int31n(16) resp, err := s.ShardManager.GetOrCreateShard(&p.GetOrCreateShardRequest{ @@ -121,7 +126,7 @@ func (s *ExecutionMutableStateSuite) SetupTest() { } func (s *ExecutionMutableStateSuite) TearDownTest() { - + s.Cancel() } func (s *ExecutionMutableStateSuite) TestCreate_BrandNew() { @@ -144,7 +149,7 @@ func (s *ExecutionMutableStateSuite) TestCreate_BrandNew_CurrentConflict() { rand.Int63(), ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeBrandNew, @@ -189,7 +194,7 @@ func (s *ExecutionMutableStateSuite) TestCreate_Reuse() { rand.Int63(), ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeWorkflowIDReuse, @@ -214,7 +219,7 @@ func (s *ExecutionMutableStateSuite) TestCreate_Reuse_CurrentConflict() { rand.Int63(), ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeWorkflowIDReuse, @@ -259,7 +264,7 @@ func (s *ExecutionMutableStateSuite) TestCreate_Zombie() { rand.Int63(), ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeZombie, @@ -284,7 +289,7 @@ func (s *ExecutionMutableStateSuite) TestCreate_Conflict() { rand.Int63(), ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeWorkflowIDReuse, @@ -315,7 +320,7 @@ func (s *ExecutionMutableStateSuite) TestUpdate_NotZombie() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, currentSnapshot.DBRecordVersion+1, ) - _, err := s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.UpdateWorkflowExecution(s.Ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.UpdateWorkflowModeUpdateCurrent, @@ -348,7 +353,7 @@ func (s *ExecutionMutableStateSuite) TestUpdate_NotZombie_CurrentConflict() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err := s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.UpdateWorkflowExecution(s.Ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.UpdateWorkflowModeUpdateCurrent, @@ -385,7 +390,7 @@ func (s *ExecutionMutableStateSuite) TestUpdate_NotZombie_Conflict() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err := s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.UpdateWorkflowExecution(s.Ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.UpdateWorkflowModeUpdateCurrent, @@ -427,7 +432,7 @@ func (s *ExecutionMutableStateSuite) TestUpdate_NotZombie_WithNew() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err := s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.UpdateWorkflowExecution(s.Ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.UpdateWorkflowModeUpdateCurrent, @@ -461,7 +466,7 @@ func (s *ExecutionMutableStateSuite) TestUpdate_Zombie() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeZombie, @@ -483,7 +488,7 @@ func (s *ExecutionMutableStateSuite) TestUpdate_Zombie() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, zombieSnapshot.DBRecordVersion+1, ) - _, err = s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err = s.ExecutionManager.UpdateWorkflowExecution(s.Ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.UpdateWorkflowModeBypassCurrent, @@ -516,7 +521,7 @@ func (s *ExecutionMutableStateSuite) TestUpdate_Zombie_CurrentConflict() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, currentSnapshot.DBRecordVersion+1, ) - _, err := s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.UpdateWorkflowExecution(s.Ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.UpdateWorkflowModeBypassCurrent, @@ -549,7 +554,7 @@ func (s *ExecutionMutableStateSuite) TestUpdate_Zombie_Conflict() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeZombie, @@ -571,7 +576,7 @@ func (s *ExecutionMutableStateSuite) TestUpdate_Zombie_Conflict() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err = s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err = s.ExecutionManager.UpdateWorkflowExecution(s.Ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.UpdateWorkflowModeBypassCurrent, @@ -604,7 +609,7 @@ func (s *ExecutionMutableStateSuite) TestUpdate_Zombie_WithNew() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeZombie, @@ -635,7 +640,7 @@ func (s *ExecutionMutableStateSuite) TestUpdate_Zombie_WithNew() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err = s.ExecutionManager.UpdateWorkflowExecution(&p.UpdateWorkflowExecutionRequest{ + _, err = s.ExecutionManager.UpdateWorkflowExecution(s.Ctx, &p.UpdateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.UpdateWorkflowModeBypassCurrent, @@ -670,7 +675,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_SuppressCurrent() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeZombie, @@ -701,7 +706,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_SuppressCurrent() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, currentSnapshot.DBRecordVersion+1, ) - _, err = s.ExecutionManager.ConflictResolveWorkflowExecution(&p.ConflictResolveWorkflowExecutionRequest{ + _, err = s.ExecutionManager.ConflictResolveWorkflowExecution(s.Ctx, &p.ConflictResolveWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.ConflictResolveWorkflowModeUpdateCurrent, @@ -739,7 +744,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_SuppressCurrent_Current enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeZombie, @@ -770,7 +775,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_SuppressCurrent_Current enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err = s.ExecutionManager.ConflictResolveWorkflowExecution(&p.ConflictResolveWorkflowExecutionRequest{ + _, err = s.ExecutionManager.ConflictResolveWorkflowExecution(s.Ctx, &p.ConflictResolveWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.ConflictResolveWorkflowModeUpdateCurrent, @@ -808,7 +813,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_SuppressCurrent_Conflic enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeZombie, @@ -839,7 +844,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_SuppressCurrent_Conflic enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err = s.ExecutionManager.ConflictResolveWorkflowExecution(&p.ConflictResolveWorkflowExecutionRequest{ + _, err = s.ExecutionManager.ConflictResolveWorkflowExecution(s.Ctx, &p.ConflictResolveWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.ConflictResolveWorkflowModeUpdateCurrent, @@ -877,7 +882,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_SuppressCurrent_Conflic enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeZombie, @@ -908,7 +913,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_SuppressCurrent_Conflic enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, currentSnapshot.DBRecordVersion+1, ) - _, err = s.ExecutionManager.ConflictResolveWorkflowExecution(&p.ConflictResolveWorkflowExecutionRequest{ + _, err = s.ExecutionManager.ConflictResolveWorkflowExecution(s.Ctx, &p.ConflictResolveWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.ConflictResolveWorkflowModeUpdateCurrent, @@ -946,7 +951,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_SuppressCurrent_WithNew enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeZombie, @@ -986,7 +991,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_SuppressCurrent_WithNew enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, currentSnapshot.DBRecordVersion+1, ) - _, err = s.ExecutionManager.ConflictResolveWorkflowExecution(&p.ConflictResolveWorkflowExecutionRequest{ + _, err = s.ExecutionManager.ConflictResolveWorkflowExecution(s.Ctx, &p.ConflictResolveWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.ConflictResolveWorkflowModeUpdateCurrent, @@ -1024,7 +1029,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_ResetCurrent() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, baseSnapshot.DBRecordVersion+1, ) - _, err := s.ExecutionManager.ConflictResolveWorkflowExecution(&p.ConflictResolveWorkflowExecutionRequest{ + _, err := s.ExecutionManager.ConflictResolveWorkflowExecution(s.Ctx, &p.ConflictResolveWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.ConflictResolveWorkflowModeUpdateCurrent, @@ -1060,7 +1065,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_ResetCurrent_CurrentCon enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeZombie, @@ -1082,7 +1087,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_ResetCurrent_CurrentCon enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, baseSnapshot.DBRecordVersion+1, ) - _, err = s.ExecutionManager.ConflictResolveWorkflowExecution(&p.ConflictResolveWorkflowExecutionRequest{ + _, err = s.ExecutionManager.ConflictResolveWorkflowExecution(s.Ctx, &p.ConflictResolveWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.ConflictResolveWorkflowModeUpdateCurrent, @@ -1118,7 +1123,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_ResetCurrent_Conflict() enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err := s.ExecutionManager.ConflictResolveWorkflowExecution(&p.ConflictResolveWorkflowExecutionRequest{ + _, err := s.ExecutionManager.ConflictResolveWorkflowExecution(s.Ctx, &p.ConflictResolveWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.ConflictResolveWorkflowModeUpdateCurrent, @@ -1163,7 +1168,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_ResetCurrent_WithNew() enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err := s.ExecutionManager.ConflictResolveWorkflowExecution(&p.ConflictResolveWorkflowExecutionRequest{ + _, err := s.ExecutionManager.ConflictResolveWorkflowExecution(s.Ctx, &p.ConflictResolveWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.ConflictResolveWorkflowModeUpdateCurrent, @@ -1200,7 +1205,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_Zombie() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeZombie, @@ -1222,7 +1227,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_Zombie() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, baseSnapshot.DBRecordVersion+1, ) - _, err = s.ExecutionManager.ConflictResolveWorkflowExecution(&p.ConflictResolveWorkflowExecutionRequest{ + _, err = s.ExecutionManager.ConflictResolveWorkflowExecution(s.Ctx, &p.ConflictResolveWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.ConflictResolveWorkflowModeBypassCurrent, @@ -1258,7 +1263,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_Zombie_CurrentConflict( enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, baseSnapshot.DBRecordVersion+1, ) - _, err := s.ExecutionManager.ConflictResolveWorkflowExecution(&p.ConflictResolveWorkflowExecutionRequest{ + _, err := s.ExecutionManager.ConflictResolveWorkflowExecution(s.Ctx, &p.ConflictResolveWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.ConflictResolveWorkflowModeBypassCurrent, @@ -1294,7 +1299,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_Zombie_Conflict() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeZombie, @@ -1316,7 +1321,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_Zombie_Conflict() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err = s.ExecutionManager.ConflictResolveWorkflowExecution(&p.ConflictResolveWorkflowExecutionRequest{ + _, err = s.ExecutionManager.ConflictResolveWorkflowExecution(s.Ctx, &p.ConflictResolveWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.ConflictResolveWorkflowModeBypassCurrent, @@ -1352,7 +1357,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_Zombie_WithNew() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeZombie, @@ -1383,7 +1388,7 @@ func (s *ExecutionMutableStateSuite) TestConflictResolve_Zombie_WithNew() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err = s.ExecutionManager.ConflictResolveWorkflowExecution(&p.ConflictResolveWorkflowExecutionRequest{ + _, err = s.ExecutionManager.ConflictResolveWorkflowExecution(s.Ctx, &p.ConflictResolveWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.ConflictResolveWorkflowModeBypassCurrent, @@ -1413,7 +1418,7 @@ func (s *ExecutionMutableStateSuite) TestSet_NotExists() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err := s.ExecutionManager.SetWorkflowExecution(&p.SetWorkflowExecutionRequest{ + _, err := s.ExecutionManager.SetWorkflowExecution(s.Ctx, &p.SetWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, @@ -1441,7 +1446,7 @@ func (s *ExecutionMutableStateSuite) TestSet_Conflict() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, rand.Int63(), ) - _, err := s.ExecutionManager.SetWorkflowExecution(&p.SetWorkflowExecutionRequest{ + _, err := s.ExecutionManager.SetWorkflowExecution(s.Ctx, &p.SetWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, @@ -1469,7 +1474,7 @@ func (s *ExecutionMutableStateSuite) TestSet() { enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, snapshot.DBRecordVersion+1, ) - _, err := s.ExecutionManager.SetWorkflowExecution(&p.SetWorkflowExecutionRequest{ + _, err := s.ExecutionManager.SetWorkflowExecution(s.Ctx, &p.SetWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, @@ -1488,7 +1493,7 @@ func (s *ExecutionMutableStateSuite) TestDeleteCurrent_IsCurrent() { rand.Int63(), ) - err := s.ExecutionManager.DeleteCurrentWorkflowExecution(&p.DeleteCurrentWorkflowExecutionRequest{ + err := s.ExecutionManager.DeleteCurrentWorkflowExecution(s.Ctx, &p.DeleteCurrentWorkflowExecutionRequest{ ShardID: s.ShardID, NamespaceID: s.NamespaceID, WorkflowID: s.WorkflowID, @@ -1496,7 +1501,7 @@ func (s *ExecutionMutableStateSuite) TestDeleteCurrent_IsCurrent() { }) s.NoError(err) - _, err = s.ExecutionManager.GetCurrentExecution(&p.GetCurrentExecutionRequest{ + _, err = s.ExecutionManager.GetCurrentExecution(s.Ctx, &p.GetCurrentExecutionRequest{ ShardID: s.ShardID, NamespaceID: s.NamespaceID, WorkflowID: s.WorkflowID, @@ -1517,7 +1522,7 @@ func (s *ExecutionMutableStateSuite) TestDeleteCurrent_NotCurrent() { rand.Int63(), ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeZombie, @@ -1530,7 +1535,7 @@ func (s *ExecutionMutableStateSuite) TestDeleteCurrent_NotCurrent() { }) s.NoError(err) - err = s.ExecutionManager.DeleteCurrentWorkflowExecution(&p.DeleteCurrentWorkflowExecutionRequest{ + err = s.ExecutionManager.DeleteCurrentWorkflowExecution(s.Ctx, &p.DeleteCurrentWorkflowExecutionRequest{ ShardID: s.ShardID, NamespaceID: s.NamespaceID, WorkflowID: s.WorkflowID, @@ -1538,7 +1543,7 @@ func (s *ExecutionMutableStateSuite) TestDeleteCurrent_NotCurrent() { }) s.NoError(err) - _, err = s.ExecutionManager.GetCurrentExecution(&p.GetCurrentExecutionRequest{ + _, err = s.ExecutionManager.GetCurrentExecution(s.Ctx, &p.GetCurrentExecutionRequest{ ShardID: s.ShardID, NamespaceID: s.NamespaceID, WorkflowID: s.WorkflowID, @@ -1559,7 +1564,7 @@ func (s *ExecutionMutableStateSuite) TestDelete_Exists() { rand.Int63(), ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeZombie, @@ -1572,7 +1577,7 @@ func (s *ExecutionMutableStateSuite) TestDelete_Exists() { }) s.NoError(err) - err = s.ExecutionManager.DeleteWorkflowExecution(&p.DeleteWorkflowExecutionRequest{ + err = s.ExecutionManager.DeleteWorkflowExecution(s.Ctx, &p.DeleteWorkflowExecutionRequest{ ShardID: s.ShardID, NamespaceID: s.NamespaceID, WorkflowID: s.WorkflowID, @@ -1584,7 +1589,7 @@ func (s *ExecutionMutableStateSuite) TestDelete_Exists() { } func (s *ExecutionMutableStateSuite) TestDelete_NotExists() { - err := s.ExecutionManager.DeleteWorkflowExecution(&p.DeleteWorkflowExecutionRequest{ + err := s.ExecutionManager.DeleteWorkflowExecution(s.Ctx, &p.DeleteWorkflowExecutionRequest{ ShardID: s.ShardID, NamespaceID: s.NamespaceID, WorkflowID: s.WorkflowID, @@ -1610,7 +1615,7 @@ func (s *ExecutionMutableStateSuite) CreateWorkflow( status, dbRecordVersion, ) - _, err := s.ExecutionManager.CreateWorkflowExecution(&p.CreateWorkflowExecutionRequest{ + _, err := s.ExecutionManager.CreateWorkflowExecution(s.Ctx, &p.CreateWorkflowExecutionRequest{ ShardID: s.ShardID, RangeID: s.RangeID, Mode: p.CreateWorkflowModeBrandNew, @@ -1630,7 +1635,7 @@ func (s *ExecutionMutableStateSuite) AssertMissingFromDB( workflowID string, runID string, ) { - _, err := s.ExecutionManager.GetWorkflowExecution(&p.GetWorkflowExecutionRequest{ + _, err := s.ExecutionManager.GetWorkflowExecution(s.Ctx, &p.GetWorkflowExecutionRequest{ ShardID: s.ShardID, NamespaceID: namespaceID, WorkflowID: workflowID, @@ -1643,7 +1648,7 @@ func (s *ExecutionMutableStateSuite) AssertEqualWithDB( snapshot *p.WorkflowSnapshot, mutations ...*p.WorkflowMutation, ) { - resp, err := s.ExecutionManager.GetWorkflowExecution(&p.GetWorkflowExecutionRequest{ + resp, err := s.ExecutionManager.GetWorkflowExecution(s.Ctx, &p.GetWorkflowExecutionRequest{ ShardID: s.ShardID, NamespaceID: snapshot.ExecutionInfo.NamespaceId, WorkflowID: snapshot.ExecutionInfo.WorkflowId, diff --git a/common/persistence/tests/execution_mutable_state_task.go b/common/persistence/tests/execution_mutable_state_task.go index 0ba2eac1d1d..b3daf70d57a 100644 --- a/common/persistence/tests/execution_mutable_state_task.go +++ b/common/persistence/tests/execution_mutable_state_task.go @@ -25,6 +25,7 @@ package tests import ( + "context" "math" "math/rand" "testing" @@ -55,6 +56,9 @@ type ( ShardManager p.ShardManager ExecutionManager p.ExecutionManager Logger log.Logger + + Ctx context.Context + Cancel context.CancelFunc } ) @@ -83,6 +87,7 @@ func NewExecutionMutableStateTaskSuite( func (s *ExecutionMutableStateTaskSuite) SetupTest() { s.Assertions = require.New(s.T()) + s.Ctx, s.Cancel = context.WithTimeout(context.Background(), time.Second*30) s.ShardID = 1 + rand.Int31n(16) resp, err := s.ShardManager.GetOrCreateShard(&p.GetOrCreateShardRequest{ @@ -112,7 +117,7 @@ func (s *ExecutionMutableStateTaskSuite) SetupTest() { func (s *ExecutionMutableStateTaskSuite) TearDownTest() { for _, category := range []tasks.Category{tasks.CategoryTransfer, tasks.CategoryReplication, tasks.CategoryVisibility} { - err := s.ExecutionManager.RangeCompleteHistoryTasks(&p.RangeCompleteHistoryTasksRequest{ + err := s.ExecutionManager.RangeCompleteHistoryTasks(s.Ctx, &p.RangeCompleteHistoryTasksRequest{ ShardID: s.ShardID, TaskCategory: category, InclusiveMinTaskKey: tasks.Key{TaskID: 0}, @@ -120,13 +125,15 @@ func (s *ExecutionMutableStateTaskSuite) TearDownTest() { }) s.NoError(err) } - err := s.ExecutionManager.RangeCompleteHistoryTasks(&p.RangeCompleteHistoryTasksRequest{ + err := s.ExecutionManager.RangeCompleteHistoryTasks(s.Ctx, &p.RangeCompleteHistoryTasksRequest{ ShardID: s.ShardID, TaskCategory: tasks.CategoryTimer, InclusiveMinTaskKey: tasks.Key{FireTime: time.Unix(0, 0)}, ExclusiveMaxTaskKey: tasks.Key{FireTime: time.Unix(0, math.MaxInt64)}, }) s.NoError(err) + + s.Cancel() } func (s *ExecutionMutableStateTaskSuite) TestAddGetTransferTasks_Multiple() { @@ -239,7 +246,7 @@ func (s *ExecutionMutableStateTaskSuite) AddRandomTasks( now = now.Add(time.Duration(rand.Int63n(1000_000_000)) + time.Millisecond) } - err := s.ExecutionManager.AddHistoryTasks(&p.AddHistoryTasksRequest{ + err := s.ExecutionManager.AddHistoryTasks(s.Ctx, &p.AddHistoryTasksRequest{ ShardID: s.ShardID, RangeID: s.RangeID, NamespaceID: s.WorkflowKey.NamespaceID, @@ -269,7 +276,7 @@ func (s *ExecutionMutableStateTaskSuite) PaginateTasks( } var loadedTasks []tasks.Task for { - response, err := s.ExecutionManager.GetHistoryTasks(request) + response, err := s.ExecutionManager.GetHistoryTasks(s.Ctx, request) s.NoError(err) s.True(len(response.Tasks) <= batchSize) loadedTasks = append(loadedTasks, response.Tasks...) diff --git a/common/persistence/tests/history_store.go b/common/persistence/tests/history_store.go index 37f43d7cda9..55d470dd04e 100644 --- a/common/persistence/tests/history_store.go +++ b/common/persistence/tests/history_store.go @@ -25,6 +25,7 @@ package tests import ( + "context" "math/rand" "testing" "time" @@ -61,6 +62,9 @@ type ( store p.ExecutionManager logger log.Logger + + Ctx context.Context + Cancel context.CancelFunc } ) @@ -91,10 +95,11 @@ func (s *HistoryEventsSuite) TearDownSuite() { func (s *HistoryEventsSuite) SetupTest() { s.Assertions = require.New(s.T()) + s.Ctx, s.Cancel = context.WithTimeout(context.Background(), time.Second*30) } func (s *HistoryEventsSuite) TearDownTest() { - + s.Cancel() } func (s *HistoryEventsSuite) TestAppendSelect_First() { @@ -515,7 +520,7 @@ func (s *HistoryEventsSuite) TestForkDeleteBranch_DeleteBaseBranchFirst() { s.deleteHistoryBranch(shardID, br2Token) // at this point, both branch1 and branch2 are deleted. - _, err = s.store.ReadHistoryBranch(&p.ReadHistoryBranchRequest{ + _, err = s.store.ReadHistoryBranch(s.Ctx, &p.ReadHistoryBranchRequest{ ShardID: shardID, BranchToken: br1Token, MinEventID: common.FirstEventID, @@ -524,7 +529,7 @@ func (s *HistoryEventsSuite) TestForkDeleteBranch_DeleteBaseBranchFirst() { }) s.Error(err, "Workflow execution history not found.") - _, err = s.store.ReadHistoryBranch(&p.ReadHistoryBranchRequest{ + _, err = s.store.ReadHistoryBranch(s.Ctx, &p.ReadHistoryBranchRequest{ ShardID: shardID, BranchToken: br2Token, MinEventID: common.FirstEventID, @@ -568,7 +573,7 @@ func (s *HistoryEventsSuite) TestForkDeleteBranch_DeleteForkedBranchFirst() { s.Equal(append(eventsPacket0.events, eventsPacket1.events...), s.listAllHistoryEvents(shardID, br1Token)) // branch2:[4,5] should be deleted - _, err = s.store.ReadHistoryBranch(&p.ReadHistoryBranchRequest{ + _, err = s.store.ReadHistoryBranch(s.Ctx, &p.ReadHistoryBranchRequest{ ShardID: shardID, BranchToken: br2Token, MinEventID: 4, @@ -581,7 +586,7 @@ func (s *HistoryEventsSuite) TestForkDeleteBranch_DeleteForkedBranchFirst() { s.deleteHistoryBranch(shardID, br1Token) // branch1 should be deleted - _, err = s.store.ReadHistoryBranch(&p.ReadHistoryBranchRequest{ + _, err = s.store.ReadHistoryBranch(s.Ctx, &p.ReadHistoryBranchRequest{ ShardID: shardID, BranchToken: br1Token, MinEventID: common.FirstEventID, @@ -596,7 +601,7 @@ func (s *HistoryEventsSuite) appendHistoryEvents( branchToken []byte, packet HistoryEventsPacket, ) { - _, err := s.store.AppendHistoryNodes(&p.AppendHistoryNodesRequest{ + _, err := s.store.AppendHistoryNodes(s.Ctx, &p.AppendHistoryNodesRequest{ ShardID: shardID, BranchToken: branchToken, Events: packet.events, @@ -613,7 +618,7 @@ func (s *HistoryEventsSuite) forkHistoryBranch( branchToken []byte, newNodeID int64, ) []byte { - resp, err := s.store.ForkHistoryBranch(&p.ForkHistoryBranchRequest{ + resp, err := s.store.ForkHistoryBranch(s.Ctx, &p.ForkHistoryBranchRequest{ ShardID: shardID, ForkBranchToken: branchToken, ForkNodeID: newNodeID, @@ -627,7 +632,7 @@ func (s *HistoryEventsSuite) deleteHistoryBranch( shardID int32, branchToken []byte, ) { - err := s.store.DeleteHistoryBranch(&p.DeleteHistoryBranchRequest{ + err := s.store.DeleteHistoryBranch(s.Ctx, &p.DeleteHistoryBranchRequest{ ShardID: shardID, BranchToken: branchToken, }) @@ -640,7 +645,7 @@ func (s *HistoryEventsSuite) trimHistoryBranch( nodeID int64, transactionID int64, ) { - _, err := s.store.TrimHistoryBranch(&p.TrimHistoryBranchRequest{ + _, err := s.store.TrimHistoryBranch(s.Ctx, &p.TrimHistoryBranchRequest{ ShardID: shardID, BranchToken: branchToken, NodeID: nodeID, @@ -658,7 +663,7 @@ func (s *HistoryEventsSuite) listHistoryEvents( var token []byte var events []*historypb.HistoryEvent for doContinue := true; doContinue; doContinue = len(token) > 0 { - resp, err := s.store.ReadHistoryBranch(&p.ReadHistoryBranchRequest{ + resp, err := s.store.ReadHistoryBranch(s.Ctx, &p.ReadHistoryBranchRequest{ ShardID: shardID, BranchToken: branchToken, MinEventID: startEventID, @@ -680,7 +685,7 @@ func (s *HistoryEventsSuite) listAllHistoryEvents( var token []byte var events []*historypb.HistoryEvent for doContinue := true; doContinue; doContinue = len(token) > 0 { - resp, err := s.store.ReadHistoryBranch(&p.ReadHistoryBranchRequest{ + resp, err := s.store.ReadHistoryBranch(s.Ctx, &p.ReadHistoryBranchRequest{ ShardID: shardID, BranchToken: branchToken, MinEventID: common.FirstEventID, diff --git a/host/archival_test.go b/host/archival_test.go index 1fa7f83a489..b32b423ba2f 100644 --- a/host/archival_test.go +++ b/host/archival_test.go @@ -193,7 +193,7 @@ func (s *integrationSuite) isHistoryDeleted(execution *commonpb.WorkflowExecutio ShardID: convert.Int32Ptr(shardID), } for i := 0; i < retryLimit; i++ { - resp, err := s.testCluster.testBase.ExecutionManager.GetHistoryTree(request) + resp, err := s.testCluster.testBase.ExecutionManager.GetHistoryTree(NewContext(), request) s.NoError(err) if len(resp.Branches) == 0 { return true @@ -214,7 +214,7 @@ func (s *integrationSuite) isMutableStateDeleted(namespaceID string, execution * } for i := 0; i < retryLimit; i++ { - _, err := s.testCluster.testBase.ExecutionManager.GetWorkflowExecution(request) + _, err := s.testCluster.testBase.ExecutionManager.GetWorkflowExecution(NewContext(), request) if _, ok := err.(*serviceerror.NotFound); ok { return true } diff --git a/host/ndc/replication_integration_test.go b/host/ndc/replication_integration_test.go index 73d6d67bfcc..f1166419a26 100644 --- a/host/ndc/replication_integration_test.go +++ b/host/ndc/replication_integration_test.go @@ -34,6 +34,7 @@ import ( "go.temporal.io/server/common/persistence" test "go.temporal.io/server/common/testing" + "go.temporal.io/server/host" "go.temporal.io/server/service/history/tasks" ) @@ -139,7 +140,7 @@ Loop: var token []byte for doPaging := true; doPaging; doPaging = len(token) > 0 { request.NextPageToken = token - response, err := executionManager.GetReplicationTasksFromDLQ(request) + response, err := executionManager.GetReplicationTasksFromDLQ(host.NewContext(), request) if err != nil { continue Loop } diff --git a/service/frontend/adminHandler.go b/service/frontend/adminHandler.go index 5eff565eca3..d3ad2692105 100644 --- a/service/frontend/adminHandler.go +++ b/service/frontend/adminHandler.go @@ -571,7 +571,7 @@ func (adh *AdminHandler) ListHistoryTasks( maxTaskKey.FireTime = timestamp.TimeValue(taskRange.ExclusiveMaxTaskKey.FireTime) maxTaskKey.TaskID = taskRange.ExclusiveMaxTaskKey.TaskId } - resp, err := adh.persistenceExecutionManager.GetHistoryTasks(&persistence.GetHistoryTasksRequest{ + resp, err := adh.persistenceExecutionManager.GetHistoryTasks(ctx, &persistence.GetHistoryTasksRequest{ ShardID: request.ShardId, TaskCategory: taskCategory, InclusiveMinTaskKey: minTaskKey, @@ -741,7 +741,7 @@ func (adh *AdminHandler) GetWorkflowExecutionRawHistoryV2(ctx context.Context, r execution.GetWorkflowId(), adh.numberOfHistoryShards, ) - rawHistoryResponse, err := adh.persistenceExecutionManager.ReadRawHistoryBranch(&persistence.ReadHistoryBranchRequest{ + rawHistoryResponse, err := adh.persistenceExecutionManager.ReadRawHistoryBranch(ctx, &persistence.ReadHistoryBranchRequest{ BranchToken: targetVersionHistory.GetBranchToken(), // GetWorkflowExecutionRawHistoryV2 is exclusive exclusive. // ReadRawHistoryBranch is inclusive exclusive. diff --git a/service/frontend/adminHandler_test.go b/service/frontend/adminHandler_test.go index 5d8b77876f1..3588dae9f96 100644 --- a/service/frontend/adminHandler_test.go +++ b/service/frontend/adminHandler_test.go @@ -247,7 +247,7 @@ func (s *adminHandlerSuite) Test_GetWorkflowExecutionRawHistoryV2() { } s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any()).Return(mState, nil).AnyTimes() - s.mockExecutionMgr.EXPECT().ReadRawHistoryBranch(gomock.Any()).Return(&persistence.ReadRawHistoryBranchResponse{ + s.mockExecutionMgr.EXPECT().ReadRawHistoryBranch(gomock.Any(), gomock.Any()).Return(&persistence.ReadRawHistoryBranchResponse{ HistoryEventBlobs: []*commonpb.DataBlob{}, NextPageToken: []byte{}, Size: 0, diff --git a/service/frontend/workflowHandler.go b/service/frontend/workflowHandler.go index a9e0b14cb60..10834626062 100644 --- a/service/frontend/workflowHandler.go +++ b/service/frontend/workflowHandler.go @@ -568,6 +568,7 @@ func (wh *WorkflowHandler) GetWorkflowExecutionHistory(ctx context.Context, requ if !isWorkflowRunning { if rawHistoryQueryEnabled { historyBlob, _, err = wh.getRawHistory( + ctx, wh.metricsScope(ctx), namespaceID, *execution, @@ -586,6 +587,7 @@ func (wh *WorkflowHandler) GetWorkflowExecutionHistory(ctx context.Context, requ historyBlob = historyBlob[len(historyBlob)-1:] } else { history, _, err = wh.getHistory( + ctx, wh.metricsScope(ctx), namespaceID, namespace.Name(request.GetNamespace()), @@ -621,6 +623,7 @@ func (wh *WorkflowHandler) GetWorkflowExecutionHistory(ctx context.Context, requ } else { if rawHistoryQueryEnabled { historyBlob, continuationToken.PersistenceToken, err = wh.getRawHistory( + ctx, wh.metricsScope(ctx), namespaceID, *execution, @@ -633,6 +636,7 @@ func (wh *WorkflowHandler) GetWorkflowExecutionHistory(ctx context.Context, requ ) } else { history, continuationToken.PersistenceToken, err = wh.getHistory( + ctx, wh.metricsScope(ctx), namespaceID, namespace.Name(request.GetNamespace()), @@ -804,6 +808,7 @@ func (wh *WorkflowHandler) GetWorkflowExecutionHistoryReverse(ctx context.Contex history.Events = []*historypb.HistoryEvent{} // return all events history, continuationToken.PersistenceToken, continuationToken.NextEventId, err = wh.getHistoryReverse( + ctx, wh.metricsScope(ctx), namespaceID, namespace.Name(request.GetNamespace()), @@ -2998,6 +3003,7 @@ func (wh *WorkflowHandler) ListTaskQueuePartitions(ctx context.Context, request } func (wh *WorkflowHandler) getRawHistory( + ctx context.Context, scope metrics.Scope, namespaceID namespace.ID, execution commonpb.WorkflowExecution, @@ -3011,7 +3017,7 @@ func (wh *WorkflowHandler) getRawHistory( var rawHistory []*commonpb.DataBlob shardID := common.WorkflowIDToHistoryShard(namespaceID.String(), execution.GetWorkflowId(), wh.config.NumHistoryShards) - resp, err := wh.persistenceExecutionManager.ReadRawHistoryBranch(&persistence.ReadHistoryBranchRequest{ + resp, err := wh.persistenceExecutionManager.ReadRawHistoryBranch(ctx, &persistence.ReadHistoryBranchRequest{ BranchToken: branchToken, MinEventID: firstEventID, MaxEventID: nextEventID, @@ -3063,6 +3069,7 @@ func (wh *WorkflowHandler) getRawHistory( } func (wh *WorkflowHandler) getHistory( + ctx context.Context, scope metrics.Scope, namespaceID namespace.ID, namespace namespace.Name, @@ -3080,7 +3087,7 @@ func (wh *WorkflowHandler) getHistory( shardID := common.WorkflowIDToHistoryShard(namespaceID.String(), execution.GetWorkflowId(), wh.config.NumHistoryShards) var err error var historyEvents []*historypb.HistoryEvent - historyEvents, size, nextPageToken, err = persistence.ReadFullPageEvents(wh.persistenceExecutionManager, &persistence.ReadHistoryBranchRequest{ + historyEvents, size, nextPageToken, err = persistence.ReadFullPageEvents(ctx, wh.persistenceExecutionManager, &persistence.ReadHistoryBranchRequest{ BranchToken: branchToken, MinEventID: firstEventID, MaxEventID: nextEventID, @@ -3142,6 +3149,7 @@ func (wh *WorkflowHandler) getHistory( } func (wh *WorkflowHandler) getHistoryReverse( + ctx context.Context, scope metrics.Scope, namespaceID namespace.ID, namespace namespace.Name, @@ -3157,7 +3165,7 @@ func (wh *WorkflowHandler) getHistoryReverse( var err error var historyEvents []*historypb.HistoryEvent - historyEvents, size, nextPageToken, err = persistence.ReadFullPageEventsReverse(wh.persistenceExecutionManager, &persistence.ReadHistoryBranchReverseRequest{ + historyEvents, size, nextPageToken, err = persistence.ReadFullPageEventsReverse(ctx, wh.persistenceExecutionManager, &persistence.ReadHistoryBranchReverseRequest{ BranchToken: branchToken, MaxEventID: nextEventID, LastFirstTransactionID: lastFirstTxnID, @@ -3320,6 +3328,7 @@ func (wh *WorkflowHandler) createPollWorkflowTaskQueueResponse( } }() history, persistenceToken, err = wh.getHistory( + ctx, wh.metricsScope(ctx), namespaceID, namespaceEntry.Name(), @@ -3697,7 +3706,7 @@ func (wh *WorkflowHandler) trimHistoryNode( return // abort } - _, err = wh.persistenceExecutionManager.TrimHistoryBranch(&persistence.TrimHistoryBranchRequest{ + _, err = wh.persistenceExecutionManager.TrimHistoryBranch(ctx, &persistence.TrimHistoryBranchRequest{ ShardID: common.WorkflowIDToHistoryShard(namespaceID, workflowID, wh.config.NumHistoryShards), BranchToken: response.CurrentBranchToken, NodeID: response.GetLastFirstEventId(), diff --git a/service/frontend/workflowHandler_test.go b/service/frontend/workflowHandler_test.go index ba277d2ace4..59e48fbe659 100644 --- a/service/frontend/workflowHandler_test.go +++ b/service/frontend/workflowHandler_test.go @@ -1282,7 +1282,7 @@ func (s *workflowHandlerSuite) TestGetHistory() { NextPageToken: []byte{}, ShardID: shardID, } - s.mockExecutionManager.EXPECT().ReadHistoryBranch(req).Return(&persistence.ReadHistoryBranchResponse{ + s.mockExecutionManager.EXPECT().ReadHistoryBranch(gomock.Any(), req).Return(&persistence.ReadHistoryBranchResponse{ HistoryEvents: []*historypb.HistoryEvent{ { EventId: int64(100), @@ -1313,6 +1313,7 @@ func (s *workflowHandlerSuite) TestGetHistory() { wh := s.getWorkflowHandler(s.newConfig()) history, token, err := wh.getHistory( + context.Background(), metrics.NoopScope, namespaceID, namespace, @@ -1371,7 +1372,7 @@ func (s *workflowHandlerSuite) TestGetWorkflowExecutionHistory() { }, nil).Times(2) // GetWorkflowExecutionHistory will request the last event - s.mockExecutionManager.EXPECT().ReadHistoryBranch(&persistence.ReadHistoryBranchRequest{ + s.mockExecutionManager.EXPECT().ReadHistoryBranch(gomock.Any(), &persistence.ReadHistoryBranchRequest{ BranchToken: branchToken, MinEventID: 5, MaxEventID: 6, @@ -1397,7 +1398,7 @@ func (s *workflowHandlerSuite) TestGetWorkflowExecutionHistory() { Size: 1, }, nil).Times(2) - s.mockExecutionManager.EXPECT().TrimHistoryBranch(gomock.Any()).Return(nil, nil).AnyTimes() + s.mockExecutionManager.EXPECT().TrimHistoryBranch(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() s.mockSearchAttributesProvider.EXPECT().GetSearchAttributes(gomock.Any(), false).Return(searchattribute.TestNameTypeMap, nil).AnyTimes() wh := s.getWorkflowHandler(s.newConfig()) @@ -1492,7 +1493,7 @@ func (s *workflowHandlerSuite) TestGetWorkflowExecutionHistory_RawHistoryWithTra enumspb.ENCODING_TYPE_PROTO3, ) s.NoError(err) - s.mockExecutionManager.EXPECT().ReadRawHistoryBranch(&persistence.ReadHistoryBranchRequest{ + s.mockExecutionManager.EXPECT().ReadRawHistoryBranch(gomock.Any(), &persistence.ReadHistoryBranchRequest{ BranchToken: branchToken, MinEventID: 1, MaxEventID: 5, diff --git a/service/history/events/cache.go b/service/history/events/cache.go index d4b6711458e..af85d3345d5 100644 --- a/service/history/events/cache.go +++ b/service/history/events/cache.go @@ -27,6 +27,7 @@ package events import ( + "context" "time" historypb "go.temporal.io/api/history/v1" @@ -51,7 +52,7 @@ type ( } Cache interface { - GetEvent(key EventKey, firstEventID int64, branchToken []byte) (*historypb.HistoryEvent, error) + GetEvent(ctx context.Context, key EventKey, firstEventID int64, branchToken []byte) (*historypb.HistoryEvent, error) PutEvent(key EventKey, event *historypb.HistoryEvent) DeleteEvent(key EventKey) } @@ -109,7 +110,7 @@ func (e *CacheImpl) validateKey(key EventKey) bool { return true } -func (e *CacheImpl) GetEvent(key EventKey, firstEventID int64, branchToken []byte) (*historypb.HistoryEvent, error) { +func (e *CacheImpl) GetEvent(ctx context.Context, key EventKey, firstEventID int64, branchToken []byte) (*historypb.HistoryEvent, error) { e.metricsClient.IncCounter(metrics.EventsCacheGetEventScope, metrics.CacheRequests) sw := e.metricsClient.StartTimer(metrics.EventsCacheGetEventScope, metrics.CacheLatency) defer sw.Stop() @@ -125,7 +126,7 @@ func (e *CacheImpl) GetEvent(key EventKey, firstEventID int64, branchToken []byt } e.metricsClient.IncCounter(metrics.EventsCacheGetEventScope, metrics.CacheMissCounter) - event, err := e.getHistoryEventFromStore(key, firstEventID, branchToken) + event, err := e.getHistoryEventFromStore(ctx, key, firstEventID, branchToken) if err != nil { e.metricsClient.IncCounter(metrics.EventsCacheGetEventScope, metrics.CacheFailures) e.logger.Error("Cache unable to retrieve event from store", @@ -165,6 +166,7 @@ func (e *CacheImpl) DeleteEvent(key EventKey) { } func (e *CacheImpl) getHistoryEventFromStore( + ctx context.Context, key EventKey, firstEventID int64, branchToken []byte, @@ -174,7 +176,7 @@ func (e *CacheImpl) getHistoryEventFromStore( sw := e.metricsClient.StartTimer(metrics.EventsCacheGetFromStoreScope, metrics.CacheLatency) defer sw.Stop() - response, err := e.eventsMgr.ReadHistoryBranch(&persistence.ReadHistoryBranchRequest{ + response, err := e.eventsMgr.ReadHistoryBranch(ctx, &persistence.ReadHistoryBranchRequest{ BranchToken: branchToken, MinEventID: firstEventID, MaxEventID: key.EventID + 1, diff --git a/service/history/events/cache_test.go b/service/history/events/cache_test.go index 2d0abac688d..49c2af5de22 100644 --- a/service/history/events/cache_test.go +++ b/service/history/events/cache_test.go @@ -25,6 +25,7 @@ package events import ( + "context" "errors" "testing" "time" @@ -113,6 +114,7 @@ func (s *eventsCacheSuite) TestEventsCacheHitSuccess() { EventKey{namespaceID, workflowID, runID, eventID, common.EmptyVersion}, event) actualEvent, err := s.cache.GetEvent( + context.Background(), EventKey{namespaceID, workflowID, runID, eventID, common.EmptyVersion}, eventID, nil) s.Nil(err) @@ -155,7 +157,7 @@ func (s *eventsCacheSuite) TestEventsCacheMissMultiEventsBatchV2Success() { } shardID := int32(10) - s.mockExecutionManager.EXPECT().ReadHistoryBranch(&persistence.ReadHistoryBranchRequest{ + s.mockExecutionManager.EXPECT().ReadHistoryBranch(gomock.Any(), &persistence.ReadHistoryBranchRequest{ BranchToken: []byte("store_token"), MinEventID: event1.GetEventId(), MaxEventID: event6.GetEventId() + 1, @@ -171,6 +173,7 @@ func (s *eventsCacheSuite) TestEventsCacheMissMultiEventsBatchV2Success() { EventKey{namespaceID, workflowID, runID, event2.GetEventId(), common.EmptyVersion}, event2) actualEvent, err := s.cache.GetEvent( + context.Background(), EventKey{namespaceID, workflowID, runID, event6.GetEventId(), common.EmptyVersion}, event1.GetEventId(), []byte("store_token")) s.Nil(err) @@ -184,7 +187,7 @@ func (s *eventsCacheSuite) TestEventsCacheMissV2Failure() { shardID := int32(10) expectedErr := errors.New("persistence call failed") - s.mockExecutionManager.EXPECT().ReadHistoryBranch(&persistence.ReadHistoryBranchRequest{ + s.mockExecutionManager.EXPECT().ReadHistoryBranch(gomock.Any(), &persistence.ReadHistoryBranchRequest{ BranchToken: []byte("store_token"), MinEventID: int64(11), MaxEventID: int64(15), @@ -194,6 +197,7 @@ func (s *eventsCacheSuite) TestEventsCacheMissV2Failure() { }).Return(nil, expectedErr) actualEvent, err := s.cache.GetEvent( + context.Background(), EventKey{namespaceID, workflowID, runID, int64(14), common.EmptyVersion}, int64(11), []byte("store_token")) s.Nil(actualEvent) @@ -216,7 +220,7 @@ func (s *eventsCacheSuite) TestEventsCacheDisableSuccess() { } shardID := int32(10) - s.mockExecutionManager.EXPECT().ReadHistoryBranch(&persistence.ReadHistoryBranchRequest{ + s.mockExecutionManager.EXPECT().ReadHistoryBranch(gomock.Any(), &persistence.ReadHistoryBranchRequest{ BranchToken: []byte("store_token"), MinEventID: event2.GetEventId(), MaxEventID: event2.GetEventId() + 1, @@ -236,6 +240,7 @@ func (s *eventsCacheSuite) TestEventsCacheDisableSuccess() { event2) s.cache.disabled = true actualEvent, err := s.cache.GetEvent( + context.Background(), EventKey{namespaceID, workflowID, runID, event2.GetEventId(), common.EmptyVersion}, event2.GetEventId(), []byte("store_token")) s.Nil(err) @@ -253,7 +258,7 @@ func (s *eventsCacheSuite) TestEventsCacheGetCachesResult() { EventId: 14, EventType: enumspb.EVENT_TYPE_ACTIVITY_TASK_STARTED, } - s.mockExecutionManager.EXPECT().ReadHistoryBranch(&persistence.ReadHistoryBranchRequest{ + s.mockExecutionManager.EXPECT().ReadHistoryBranch(gomock.Any(), &persistence.ReadHistoryBranchRequest{ BranchToken: branchToken, MinEventID: int64(11), MaxEventID: int64(15), @@ -266,10 +271,12 @@ func (s *eventsCacheSuite) TestEventsCacheGetCachesResult() { }, nil).Times(1) // will only be called once with two calls to GetEvent gotEvent1, _ := s.cache.GetEvent( + context.Background(), EventKey{namespaceID, workflowID, runID, int64(14), common.EmptyVersion}, int64(11), branchToken) s.Equal(gotEvent1, event1) gotEvent2, _ := s.cache.GetEvent( + context.Background(), EventKey{namespaceID, workflowID, runID, int64(14), common.EmptyVersion}, int64(11), branchToken) s.Equal(gotEvent2, event1) @@ -286,7 +293,7 @@ func (s *eventsCacheSuite) TestEventsCacheInvalidKey() { EventId: 14, EventType: enumspb.EVENT_TYPE_ACTIVITY_TASK_STARTED, } - s.mockExecutionManager.EXPECT().ReadHistoryBranch(&persistence.ReadHistoryBranchRequest{ + s.mockExecutionManager.EXPECT().ReadHistoryBranch(gomock.Any(), &persistence.ReadHistoryBranchRequest{ BranchToken: branchToken, MinEventID: int64(11), MaxEventID: int64(15), @@ -303,10 +310,12 @@ func (s *eventsCacheSuite) TestEventsCacheInvalidKey() { event1) gotEvent1, _ := s.cache.GetEvent( + context.Background(), EventKey{namespaceID, workflowID, runID, int64(14), common.EmptyVersion}, int64(11), branchToken) s.Equal(gotEvent1, event1) gotEvent2, _ := s.cache.GetEvent( + context.Background(), EventKey{namespaceID, workflowID, runID, int64(14), common.EmptyVersion}, int64(11), branchToken) s.Equal(gotEvent2, event1) diff --git a/service/history/events/events_cache_mock.go b/service/history/events/events_cache_mock.go index d57bc1ee29c..87ea198bc82 100644 --- a/service/history/events/events_cache_mock.go +++ b/service/history/events/events_cache_mock.go @@ -29,6 +29,7 @@ package events import ( + context "context" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -71,18 +72,18 @@ func (mr *MockCacheMockRecorder) DeleteEvent(key interface{}) *gomock.Call { } // GetEvent mocks base method. -func (m *MockCache) GetEvent(key EventKey, firstEventID int64, branchToken []byte) (*v1.HistoryEvent, error) { +func (m *MockCache) GetEvent(ctx context.Context, key EventKey, firstEventID int64, branchToken []byte) (*v1.HistoryEvent, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetEvent", key, firstEventID, branchToken) + ret := m.ctrl.Call(m, "GetEvent", ctx, key, firstEventID, branchToken) ret0, _ := ret[0].(*v1.HistoryEvent) ret1, _ := ret[1].(error) return ret0, ret1 } // GetEvent indicates an expected call of GetEvent. -func (mr *MockCacheMockRecorder) GetEvent(key, firstEventID, branchToken interface{}) *gomock.Call { +func (mr *MockCacheMockRecorder) GetEvent(ctx, key, firstEventID, branchToken interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEvent", reflect.TypeOf((*MockCache)(nil).GetEvent), key, firstEventID, branchToken) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEvent", reflect.TypeOf((*MockCache)(nil).GetEvent), ctx, key, firstEventID, branchToken) } // PutEvent mocks base method. diff --git a/service/history/handler.go b/service/history/handler.go index 290330ff895..0a1ff66997f 100644 --- a/service/history/handler.go +++ b/service/history/handler.go @@ -563,7 +563,7 @@ func (h *Handler) DescribeHistoryHost(_ context.Context, _ *historyservice.Descr } // RemoveTask returns information about the internal states of a history host -func (h *Handler) RemoveTask(_ context.Context, request *historyservice.RemoveTaskRequest) (_ *historyservice.RemoveTaskResponse, retError error) { +func (h *Handler) RemoveTask(ctx context.Context, request *historyservice.RemoveTaskRequest) (_ *historyservice.RemoveTaskResponse, retError error) { var err error var category tasks.Category switch categoryID := request.GetCategory(); categoryID { @@ -583,7 +583,7 @@ func (h *Handler) RemoveTask(_ context.Context, request *historyservice.RemoveTa } } - err = h.persistenceExecutionManager.CompleteHistoryTask(&persistence.CompleteHistoryTaskRequest{ + err = h.persistenceExecutionManager.CompleteHistoryTask(ctx, &persistence.CompleteHistoryTaskRequest{ ShardID: request.GetShardId(), TaskCategory: category, TaskKey: tasks.Key{ diff --git a/service/history/historyEngine.go b/service/history/historyEngine.go index 19b18090c0e..83b0fe57a67 100644 --- a/service/history/historyEngine.go +++ b/service/history/historyEngine.go @@ -554,6 +554,7 @@ func (e *historyEngineImpl) StartWorkflowExecution( prevRunID := "" prevLastWriteVersion := int64(0) err = weContext.CreateWorkflowExecution( + ctx, now, createMode, prevRunID, @@ -633,6 +634,7 @@ func (e *historyEngineImpl) StartWorkflowExecution( } if err = weContext.CreateWorkflowExecution( + ctx, now, persistence.CreateWorkflowModeWorkflowIDReuse, prevRunID, @@ -873,7 +875,7 @@ func (e *historyEngineImpl) QueryWorkflow( return nil, err } defer func() { release(retErr) }() - mutableState, err := context.LoadWorkflowExecution() + mutableState, err := context.LoadWorkflowExecution(ctx) if err != nil { return nil, err } @@ -1082,7 +1084,7 @@ func (e *historyEngineImpl) getMutableState( } defer func() { release(retError) }() - mutableState, err := context.LoadWorkflowExecution() + mutableState, err := context.LoadWorkflowExecution(ctx) if err != nil { return nil, err } @@ -1158,7 +1160,7 @@ func (e *historyEngineImpl) DescribeMutableState( // clear mutable state to force reload from persistence. This API returns both cached and persisted version. context.Clear() - mutableState, err := context.LoadWorkflowExecution() + mutableState, err := context.LoadWorkflowExecution(ctx) if err != nil { return nil, err } @@ -1230,7 +1232,7 @@ func (e *historyEngineImpl) DescribeWorkflowExecution( } defer func() { release(retError) }() - mutableState, err1 := context.LoadWorkflowExecution() + mutableState, err1 := context.LoadWorkflowExecution(ctx) if err1 != nil { return nil, err1 } @@ -1274,7 +1276,7 @@ func (e *historyEngineImpl) DescribeWorkflowExecution( if executionState.State == enumsspb.WORKFLOW_EXECUTION_STATE_COMPLETED { // for closed workflow result.WorkflowExecutionInfo.Status = executionState.Status - completionEvent, err := mutableState.GetCompletionEvent() + completionEvent, err := mutableState.GetCompletionEvent(ctx) if err != nil { return nil, err } @@ -1298,7 +1300,7 @@ func (e *historyEngineImpl) DescribeWorkflowExecution( p.HeartbeatDetails = ai.LastHeartbeatDetails } // TODO: move to mutable state instead of loading it from event - scheduledEvent, err := mutableState.GetActivityScheduledEvent(ai.ScheduleId) + scheduledEvent, err := mutableState.GetActivityScheduledEvent(ctx, ai.ScheduleId) if err != nil { return nil, err } @@ -1407,7 +1409,7 @@ func (e *historyEngineImpl) RecordActivityTaskStarted( return nil, consts.ErrActivityTaskNotFound } - scheduledEvent, err := mutableState.GetActivityScheduledEvent(scheduleID) + scheduledEvent, err := mutableState.GetActivityScheduledEvent(ctx, scheduleID) if err != nil { return nil, err } @@ -2035,7 +2037,7 @@ func (e *historyEngineImpl) SignalWithStartWorkflowExecution( Just_Signal_Loop: for ; attempt <= conditionalRetryCount; attempt++ { // workflow not exist, will create workflow then signal - mutableState, err1 := context.LoadWorkflowExecution() + mutableState, err1 := context.LoadWorkflowExecution(ctx) if err1 != nil { if _, ok := err1.(*serviceerror.NotFound); ok { break @@ -2086,7 +2088,7 @@ func (e *historyEngineImpl) SignalWithStartWorkflowExecution( // We apply the update to execution using optimistic concurrency. If it fails due to a conflict then reload // the history and try the operation again. - if err := context.UpdateWorkflowExecutionAsActive(e.shard.GetTimeSource().Now()); err != nil { + if err := context.UpdateWorkflowExecutionAsActive(ctx, e.shard.GetTimeSource().Now()); err != nil { if err == consts.ErrConflict { continue Just_Signal_Loop } @@ -2152,6 +2154,7 @@ func (e *historyEngineImpl) SignalWithStartWorkflowExecution( if prevExecutionUpdateAction != nil { err := e.updateWorkflowWithNewHelper( + ctx, newWorkflowContext(prevContext, release, prevMutableState), prevExecutionUpdateAction, func() (workflow.Context, workflow.MutableState, error) { @@ -2205,6 +2208,7 @@ func (e *historyEngineImpl) SignalWithStartWorkflowExecution( } } err = context.CreateWorkflowExecution( + ctx, now, createMode, prevRunID, @@ -2339,6 +2343,7 @@ func (e *historyEngineImpl) DeleteWorkflowExecution( defer func() { wfCtx.getReleaseFn()(retError) }() return e.workflowDeleteManager.AddDeleteWorkflowExecutionTask( + ctx, nsID, commonpb.WorkflowExecution{ WorkflowId: request.GetWorkflowExecution().GetWorkflowId(), @@ -2476,7 +2481,7 @@ func (e *historyEngineImpl) ResetWorkflowExecution( } defer func() { baseReleaseFn(retError) }() - baseMutableState, err := baseContext.LoadWorkflowExecution() + baseMutableState, err := baseContext.LoadWorkflowExecution(ctx) if err != nil { return nil, err } @@ -2497,7 +2502,7 @@ func (e *historyEngineImpl) ResetWorkflowExecution( } // also load the current run of the workflow, it can be different from the base runID - resp, err := e.executionManager.GetCurrentExecution(&persistence.GetCurrentExecutionRequest{ + resp, err := e.executionManager.GetCurrentExecution(ctx, &persistence.GetCurrentExecutionRequest{ ShardID: e.shard.GetShardID(), NamespaceID: namespaceID.String(), WorkflowID: request.WorkflowExecution.GetWorkflowId(), @@ -2532,7 +2537,7 @@ func (e *historyEngineImpl) ResetWorkflowExecution( } defer func() { currentReleaseFn(retError) }() - currentMutableState, err = currentContext.LoadWorkflowExecution() + currentMutableState, err = currentContext.LoadWorkflowExecution(ctx) if err != nil { return nil, err } @@ -2606,7 +2611,7 @@ func (e *historyEngineImpl) updateWorkflow( } defer func() { workflowContext.getReleaseFn()(retError) }() - return e.updateWorkflowWithNewHelper(workflowContext, action, nil) + return e.updateWorkflowWithNewHelper(ctx, workflowContext, action, nil) } func (e *historyEngineImpl) updateWorkflowExecution( @@ -2622,7 +2627,7 @@ func (e *historyEngineImpl) updateWorkflowExecution( } defer func() { workflowContext.getReleaseFn()(retError) }() - return e.updateWorkflowWithNewHelper(workflowContext, action, nil) + return e.updateWorkflowWithNewHelper(ctx, workflowContext, action, nil) } func (e *historyEngineImpl) updateWorkflowExecutionWithNew( @@ -2644,10 +2649,11 @@ func (e *historyEngineImpl) updateWorkflowExecutionWithNew( } defer func() { workflowContext.getReleaseFn()(retError) }() - return e.updateWorkflowWithNewHelper(workflowContext, action, newWorkflowFn) + return e.updateWorkflowWithNewHelper(ctx, workflowContext, action, newWorkflowFn) } func (e *historyEngineImpl) updateWorkflowWithNewHelper( + ctx context.Context, workflowContext workflowContext, action updateWorkflowActionFunc, newWorkflowFn func() (workflow.Context, workflow.MutableState, error), @@ -2666,7 +2672,7 @@ UpdateHistoryLoop: // Reload workflow execution history workflowContext.getContext().Clear() if attempt != conditionalRetryCount { - _, err = workflowContext.reloadMutableState() + _, err = workflowContext.reloadMutableState(ctx) if err != nil { return err } @@ -2704,19 +2710,21 @@ UpdateHistoryLoop: } err = workflowContext.getContext().UpdateWorkflowExecutionWithNewAsActive( + ctx, e.shard.GetTimeSource().Now(), newContext, newMutableState, ) } else { err = workflowContext.getContext().UpdateWorkflowExecutionAsActive( + ctx, e.shard.GetTimeSource().Now(), ) } if err == consts.ErrConflict { if attempt != conditionalRetryCount { - _, err = workflowContext.reloadMutableState() + _, err = workflowContext.reloadMutableState(ctx) if err != nil { return err } @@ -2753,7 +2761,8 @@ func (e *historyEngineImpl) newWorkflowVersionCheck( } func (e *historyEngineImpl) failWorkflowTask( - context workflow.Context, + ctx context.Context, + wfContext workflow.Context, scheduleID int64, startedID int64, wtFailedCause *workflowTaskFailedCause, @@ -2761,10 +2770,10 @@ func (e *historyEngineImpl) failWorkflowTask( ) (workflow.MutableState, error) { // clear any updates we have accumulated so far - context.Clear() + wfContext.Clear() // Reload workflow execution so we can apply the workflow task failure event - mutableState, err := context.LoadWorkflowExecution() + mutableState, err := wfContext.LoadWorkflowExecution(ctx) if err != nil { return nil, err } @@ -3341,6 +3350,7 @@ func (e *historyEngineImpl) PurgeDLQMessages( } return e.replicationDLQHandler.purgeMessages( + ctx, request.GetSourceCluster(), request.GetInclusiveEndMessageId(), ) @@ -3409,7 +3419,7 @@ func (e *historyEngineImpl) RefreshWorkflowTasks( } defer func() { release(retError) }() - mutableState, err := context.LoadWorkflowExecution() + mutableState, err := context.LoadWorkflowExecution(ctx) if err != nil { return err } @@ -3428,12 +3438,12 @@ func (e *historyEngineImpl) RefreshWorkflowTasks( now := e.shard.GetTimeSource().Now() - err = mutableStateTaskRefresher.RefreshTasks(now, mutableState) + err = mutableStateTaskRefresher.RefreshTasks(ctx, now, mutableState) if err != nil { return err } - err = context.UpdateWorkflowExecutionAsActive(now) + err = context.UpdateWorkflowExecutionAsActive(ctx, now) if err != nil { return err } @@ -3461,7 +3471,7 @@ func (e *historyEngineImpl) GenerateLastHistoryReplicationTasks( } defer func() { release(retError) }() - mutableState, err := context.LoadWorkflowExecution() + mutableState, err := context.LoadWorkflowExecution(ctx) if err != nil { return nil, err } @@ -3472,7 +3482,7 @@ func (e *historyEngineImpl) GenerateLastHistoryReplicationTasks( return nil, err } - err = e.shard.AddTasks(&persistence.AddHistoryTasksRequest{ + err = e.shard.AddTasks(ctx, &persistence.AddHistoryTasksRequest{ ShardID: e.shard.GetShardID(), // RangeID is set by shard NamespaceID: string(namespaceID), @@ -3530,7 +3540,7 @@ func (e *historyEngineImpl) loadWorkflowOnce( return nil, err } - mutableState, err := context.LoadWorkflowExecution() + mutableState, err := context.LoadWorkflowExecution(ctx) if err != nil { release(err) return nil, err @@ -3563,6 +3573,7 @@ func (e *historyEngineImpl) loadWorkflow( // workflow not running, need to check current record resp, err := e.shard.GetExecutionManager().GetCurrentExecution( + ctx, &persistence.GetCurrentExecutionRequest{ ShardID: e.shard.GetShardID(), NamespaceID: namespaceID.String(), diff --git a/service/history/historyEngine2_test.go b/service/history/historyEngine2_test.go index 7b42dd18790..47eaeb76cac 100644 --- a/service/history/historyEngine2_test.go +++ b/service/history/historyEngine2_test.go @@ -210,8 +210,8 @@ func (s *engine2Suite) TestRecordWorkflowTaskStartedSuccessStickyEnabled() { gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) request := historyservice.RecordWorkflowTaskStartedRequest{ NamespaceId: namespaceID.String(), @@ -266,7 +266,7 @@ func (s *engine2Suite) TestRecordWorkflowTaskStartedIfNoExecution() { identity := "testIdentity" tl := "testTaskQueue" - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(nil, serviceerror.NewNotFound("")) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, serviceerror.NewNotFound("")) response, err := s.historyEngine.RecordWorkflowTaskStarted(metrics.AddMetricsContext(context.Background()), &historyservice.RecordWorkflowTaskStartedRequest{ NamespaceId: namespaceID.String(), @@ -296,7 +296,7 @@ func (s *engine2Suite) TestRecordWorkflowTaskStartedIfGetExecutionFailed() { identity := "testIdentity" tl := "testTaskQueue" - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(nil, errors.New("FAILED")) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, errors.New("FAILED")) response, err := s.historyEngine.RecordWorkflowTaskStarted(metrics.AddMetricsContext(context.Background()), &historyservice.RecordWorkflowTaskStartedRequest{ NamespaceId: namespaceID.String(), @@ -329,7 +329,7 @@ func (s *engine2Suite) TestRecordWorkflowTaskStartedIfTaskAlreadyStarted() { msBuilder := s.createExecutionStartedState(workflowExecution, tl, identity, true) ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) response, err := s.historyEngine.RecordWorkflowTaskStarted(metrics.AddMetricsContext(context.Background()), &historyservice.RecordWorkflowTaskStartedRequest{ NamespaceId: namespaceID.String(), @@ -366,7 +366,7 @@ func (s *engine2Suite) TestRecordWorkflowTaskStartedIfTaskAlreadyCompleted() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) response, err := s.historyEngine.RecordWorkflowTaskStarted(metrics.AddMetricsContext(context.Background()), &historyservice.RecordWorkflowTaskStartedRequest{ NamespaceId: namespaceID.String(), @@ -402,14 +402,14 @@ func (s *engine2Suite) TestRecordWorkflowTaskStartedConflictOnUpdate() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(nil, &persistence.ConditionFailedError{}) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, &persistence.ConditionFailedError{}) ms2 := workflow.TestCloneToProto(msBuilder) gwmsResponse2 := &persistence.GetWorkflowExecutionResponse{State: ms2} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse2, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse2, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) response, err := s.historyEngine.RecordWorkflowTaskStarted(metrics.AddMetricsContext(context.Background()), &historyservice.RecordWorkflowTaskStartedRequest{ NamespaceId: namespaceID.String(), @@ -446,13 +446,13 @@ func (s *engine2Suite) TestRecordWorkflowTaskRetrySameRequest() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(nil, &persistence.ConditionFailedError{}) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, &persistence.ConditionFailedError{}) startedEventID := addWorkflowTaskStartedEventWithRequestID(msBuilder, int64(2), requestID, tl, identity) ms2 := workflow.TestCloneToProto(msBuilder) gwmsResponse2 := &persistence.GetWorkflowExecutionResponse{State: ms2} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse2, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse2, nil) response, err := s.historyEngine.RecordWorkflowTaskStarted(metrics.AddMetricsContext(context.Background()), &historyservice.RecordWorkflowTaskStartedRequest{ NamespaceId: namespaceID.String(), @@ -489,14 +489,14 @@ func (s *engine2Suite) TestRecordWorkflowTaskRetryDifferentRequest() { msBuilder := s.createExecutionStartedState(workflowExecution, tl, identity, false) ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(nil, &persistence.ConditionFailedError{}) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, &persistence.ConditionFailedError{}) // Add event. addWorkflowTaskStartedEventWithRequestID(msBuilder, int64(2), "some_other_req", tl, identity) ms2 := workflow.TestCloneToProto(msBuilder) gwmsResponse2 := &persistence.GetWorkflowExecutionResponse{State: ms2} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse2, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse2, nil) response, err := s.historyEngine.RecordWorkflowTaskStarted(metrics.AddMetricsContext(context.Background()), &historyservice.RecordWorkflowTaskStartedRequest{ NamespaceId: namespaceID.String(), @@ -533,10 +533,10 @@ func (s *engine2Suite) TestRecordWorkflowTaskStartedMaxAttemptsExceeded() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) } - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(nil, + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, &persistence.ConditionFailedError{}).Times(conditionalRetryCount) response, err := s.historyEngine.RecordWorkflowTaskStarted(metrics.AddMetricsContext(context.Background()), &historyservice.RecordWorkflowTaskStartedRequest{ @@ -571,8 +571,8 @@ func (s *engine2Suite) TestRecordWorkflowTaskSuccess() { msBuilder := s.createExecutionStartedState(workflowExecution, tl, identity, false) ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) // load mutable state such that it already exists in memory when respond workflow task is called // this enables us to set query registry on it @@ -583,7 +583,7 @@ func (s *engine2Suite) TestRecordWorkflowTaskSuccess() { workflow.CallerTypeAPI, ) s.NoError(err) - loadedMS, err := ctx.LoadWorkflowExecution() + loadedMS, err := ctx.LoadWorkflowExecution(context.Background()) s.NoError(err) qr := workflow.NewQueryRegistry() id1, _ := qr.BufferQuery(&querypb.WorkflowQuery{}) @@ -629,7 +629,7 @@ func (s *engine2Suite) TestRecordActivityTaskStartedIfNoExecution() { identity := "testIdentity" tl := "testTaskQueue" - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(nil, serviceerror.NewNotFound("")) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, serviceerror.NewNotFound("")) response, err := s.historyEngine.RecordActivityTaskStarted( metrics.AddMetricsContext(context.Background()), @@ -676,10 +676,11 @@ func (s *engine2Suite) TestRecordActivityTaskStartedSuccess() { ms1 := workflow.TestCloneToProto(msBuilder) gwmsResponse1 := &persistence.GetWorkflowExecutionResponse{State: ms1} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse1, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse1, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) s.mockEventsCache.EXPECT().GetEvent( + gomock.Any(), events.EventKey{ NamespaceID: namespaceID, WorkflowID: workflowExecution.GetWorkflowId(), @@ -722,8 +723,8 @@ func (s *engine2Suite) TestRequestCancelWorkflowExecution_Running() { ms1 := workflow.TestCloneToProto(msBuilder) gwmsResponse1 := &persistence.GetWorkflowExecutionResponse{State: ms1} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse1, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse1, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err := s.historyEngine.RequestCancelWorkflowExecution(metrics.AddMetricsContext(context.Background()), &historyservice.RequestCancelWorkflowExecutionRequest{ NamespaceId: namespaceID.String(), @@ -756,7 +757,7 @@ func (s *engine2Suite) TestRequestCancelWorkflowExecution_Finished() { ms1 := workflow.TestCloneToProto(msBuilder) gwmsResponse1 := &persistence.GetWorkflowExecutionResponse{State: ms1} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse1, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse1, nil) err := s.historyEngine.RequestCancelWorkflowExecution(metrics.AddMetricsContext(context.Background()), &historyservice.RequestCancelWorkflowExecutionRequest{ NamespaceId: namespaceID.String(), @@ -778,7 +779,7 @@ func (s *engine2Suite) TestRequestCancelWorkflowExecution_NotFound() { RunId: tests.RunID, } - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(nil, &serviceerror.NotFound{}) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, &serviceerror.NotFound{}) err := s.historyEngine.RequestCancelWorkflowExecution(metrics.AddMetricsContext(context.Background()), &historyservice.RequestCancelWorkflowExecutionRequest{ NamespaceId: namespaceID.String(), @@ -817,7 +818,7 @@ func (s *engine2Suite) TestRequestCancelWorkflowExecution_ParentMismatch() { ms1 := workflow.TestCloneToProto(msBuilder) gwmsResponse1 := &persistence.GetWorkflowExecutionResponse{State: ms1} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse1, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse1, nil) err := s.historyEngine.RequestCancelWorkflowExecution(metrics.AddMetricsContext(context.Background()), &historyservice.RequestCancelWorkflowExecutionRequest{ NamespaceId: namespaceID.String(), @@ -863,8 +864,8 @@ func (s *engine2Suite) TestTerminateWorkflowExecution_ParentMismatch() { } gwmsResponse1 := &persistence.GetWorkflowExecutionResponse{State: ms1} - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(currentExecutionResp, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse1, nil) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(currentExecutionResp, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse1, nil) err := s.historyEngine.TerminateWorkflowExecution(metrics.AddMetricsContext(context.Background()), &historyservice.TerminateWorkflowExecutionRequest{ NamespaceId: namespaceID.String(), @@ -954,8 +955,8 @@ func (s *engine2Suite) TestRespondWorkflowTaskCompletedRecordMarkerCommand() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) _, err := s.historyEngine.RespondWorkflowTaskCompleted(metrics.AddMetricsContext(context.Background()), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: namespaceID.String(), @@ -1010,10 +1011,10 @@ func (s *engine2Suite) TestRespondWorkflowTaskCompleted_StartChildWithSearchAttr ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) s.mockNamespaceCache.EXPECT().GetNamespace(tests.Namespace).Return(tests.LocalNamespaceEntry, nil).AnyTimes() - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).DoAndReturn(func(request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { eventsToSave := request.UpdateWorkflowEvents[0].Events s.Len(eventsToSave, 2) s.Equal(enumspb.EVENT_TYPE_WORKFLOW_TASK_COMPLETED, eventsToSave[0].GetEventType()) @@ -1048,7 +1049,7 @@ func (s *engine2Suite) TestStartWorkflowExecution_BrandNew() { taskQueue := "testTaskQueue" identity := "testIdentity" - s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any()).Return(tests.CreateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.CreateWorkflowExecutionResponse, nil) requestID := uuid.New() resp, err := s.historyEngine.StartWorkflowExecution(metrics.AddMetricsContext(context.Background()), &historyservice.StartWorkflowExecutionRequest{ @@ -1077,7 +1078,7 @@ func (s *engine2Suite) TestStartWorkflowExecution_BrandNew_SearchAttributes() { taskQueue := "testTaskQueue" identity := "testIdentity" - s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any()).DoAndReturn(func(request *persistence.CreateWorkflowExecutionRequest) (*persistence.CreateWorkflowExecutionResponse, error) { + s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, request *persistence.CreateWorkflowExecutionRequest) (*persistence.CreateWorkflowExecutionResponse, error) { eventsToSave := request.NewWorkflowEvents[0].Events s.Len(eventsToSave, 2) s.Equal(enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED, eventsToSave[0].GetEventType()) @@ -1121,7 +1122,7 @@ func (s *engine2Suite) TestStartWorkflowExecution_StillRunning_Dedup() { requestID := "requestID" lastWriteVersion := common.EmptyVersion - s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any()).Return(nil, &persistence.CurrentWorkflowConditionFailedError{ + s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, &persistence.CurrentWorkflowConditionFailedError{ Msg: "random message", RequestID: requestID, RunID: runID, @@ -1157,7 +1158,7 @@ func (s *engine2Suite) TestStartWorkflowExecution_StillRunning_NonDeDup() { identity := "testIdentity" lastWriteVersion := common.EmptyVersion - s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any()).Return(nil, &persistence.CurrentWorkflowConditionFailedError{ + s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, &persistence.CurrentWorkflowConditionFailedError{ Msg: "random message", RequestID: "oldRequestID", RunID: runID, @@ -1204,6 +1205,7 @@ func (s *engine2Suite) TestStartWorkflowExecution_NotRunning_PrevSuccess() { expecedErrs := []bool{true, false, true} s.mockExecutionMgr.EXPECT().CreateWorkflowExecution( + gomock.Any(), newCreateWorkflowExecutionRequestMatcher(func(request *persistence.CreateWorkflowExecutionRequest) bool { return request.Mode == persistence.CreateWorkflowModeBrandNew }), @@ -1219,6 +1221,7 @@ func (s *engine2Suite) TestStartWorkflowExecution_NotRunning_PrevSuccess() { for index, option := range options { if !expecedErrs[index] { s.mockExecutionMgr.EXPECT().CreateWorkflowExecution( + gomock.Any(), newCreateWorkflowExecutionRequestMatcher(func(request *persistence.CreateWorkflowExecutionRequest) bool { return request.Mode == persistence.CreateWorkflowModeWorkflowIDReuse && request.PreviousRunID == runID && @@ -1282,6 +1285,7 @@ func (s *engine2Suite) TestStartWorkflowExecution_NotRunning_PrevFail() { for i, status := range statuses { s.mockExecutionMgr.EXPECT().CreateWorkflowExecution( + gomock.Any(), newCreateWorkflowExecutionRequestMatcher(func(request *persistence.CreateWorkflowExecutionRequest) bool { return request.Mode == persistence.CreateWorkflowModeBrandNew }), @@ -1298,6 +1302,7 @@ func (s *engine2Suite) TestStartWorkflowExecution_NotRunning_PrevFail() { if !expecedErrs[j] { s.mockExecutionMgr.EXPECT().CreateWorkflowExecution( + gomock.Any(), newCreateWorkflowExecutionRequestMatcher(func(request *persistence.CreateWorkflowExecutionRequest) bool { return request.Mode == persistence.CreateWorkflowModeWorkflowIDReuse && request.PreviousRunID == runIDs[i] && @@ -1368,9 +1373,9 @@ func (s *engine2Suite) TestSignalWithStartWorkflowExecution_JustSignal() { gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} gceResponse := &persistence.GetCurrentExecutionResponse{RunID: runID} - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(gceResponse, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(gceResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) resp, err := s.historyEngine.SignalWithStartWorkflowExecution(metrics.AddMetricsContext(context.Background()), sRequest) s.Nil(err) @@ -1409,8 +1414,8 @@ func (s *engine2Suite) TestSignalWithStartWorkflowExecution_WorkflowNotExist() { notExistErr := serviceerror.NewNotFound("Workflow not exist") - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(nil, notExistErr) - s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any()).Return(tests.CreateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(nil, notExistErr) + s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.CreateWorkflowExecutionResponse, nil) resp, err := s.historyEngine.SignalWithStartWorkflowExecution(metrics.AddMetricsContext(context.Background()), sRequest) s.Nil(err) @@ -1469,9 +1474,9 @@ func (s *engine2Suite) TestSignalWithStartWorkflowExecution_WorkflowNotRunning() gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} gceResponse := &persistence.GetCurrentExecutionResponse{RunID: runID} - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(gceResponse, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any()).Return(tests.CreateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(gceResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.CreateWorkflowExecutionResponse, nil) resp, err := s.historyEngine.SignalWithStartWorkflowExecution(metrics.AddMetricsContext(context.Background()), sRequest) s.Nil(err) @@ -1534,9 +1539,9 @@ func (s *engine2Suite) TestSignalWithStartWorkflowExecution_Start_DuplicateReque LastWriteVersion: common.EmptyVersion, } - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(gceResponse, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any()).Return(nil, workflowAlreadyStartedErr) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(gceResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, workflowAlreadyStartedErr) ctx := metrics.AddMetricsContext(context.Background()) resp, err := s.historyEngine.SignalWithStartWorkflowExecution(ctx, sRequest) @@ -1609,9 +1614,9 @@ func (s *engine2Suite) TestSignalWithStartWorkflowExecution_Start_WorkflowAlread LastWriteVersion: common.EmptyVersion, } - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(gceResponse, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any()).Return(nil, workflowAlreadyStartedErr) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(gceResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, workflowAlreadyStartedErr) resp, err := s.historyEngine.SignalWithStartWorkflowExecution(metrics.AddMetricsContext(context.Background()), sRequest) s.Nil(resp) diff --git a/service/history/historyEngine3_eventsv2_test.go b/service/history/historyEngine3_eventsv2_test.go index 9d2ee4cf80e..47ddc606256 100644 --- a/service/history/historyEngine3_eventsv2_test.go +++ b/service/history/historyEngine3_eventsv2_test.go @@ -192,8 +192,8 @@ func (s *engine3Suite) TestRecordWorkflowTaskStartedSuccessStickyEnabled() { gwmsResponse := &p.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) request := historyservice.RecordWorkflowTaskStartedRequest{ NamespaceId: namespaceID.String(), @@ -249,7 +249,7 @@ func (s *engine3Suite) TestStartWorkflowExecution_BrandNew() { taskQueue := "testTaskQueue" identity := "testIdentity" - s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any()).Return(tests.CreateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.CreateWorkflowExecutionResponse, nil) requestID := uuid.New() resp, err := s.historyEngine.StartWorkflowExecution(context.Background(), &historyservice.StartWorkflowExecutionRequest{ @@ -309,9 +309,9 @@ func (s *engine3Suite) TestSignalWithStartWorkflowExecution_JustSignal() { gwmsResponse := &p.GetWorkflowExecutionResponse{State: ms} gceResponse := &p.GetCurrentExecutionResponse{RunID: runID} - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(gceResponse, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(gceResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) resp, err := s.historyEngine.SignalWithStartWorkflowExecution(context.Background(), sRequest) s.Nil(err) @@ -355,8 +355,8 @@ func (s *engine3Suite) TestSignalWithStartWorkflowExecution_WorkflowNotExist() { notExistErr := serviceerror.NewNotFound("Workflow not exist") - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(nil, notExistErr) - s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any()).Return(tests.CreateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(nil, notExistErr) + s.mockExecutionMgr.EXPECT().CreateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.CreateWorkflowExecutionResponse, nil) resp, err := s.historyEngine.SignalWithStartWorkflowExecution(context.Background(), sRequest) s.Nil(err) diff --git a/service/history/historyEngine_test.go b/service/history/historyEngine_test.go index e746089b4e6..a26def14619 100644 --- a/service/history/historyEngine_test.go +++ b/service/history/historyEngine_test.go @@ -229,7 +229,7 @@ func (s *engineSuite) TestGetMutableStateSync() { ms := workflow.TestCloneToProto(msBuilder) gweResponse := &persistence.GetWorkflowExecutionResponse{State: ms} // right now the next event ID is 4 - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gweResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gweResponse, nil) // test get the next event ID instantly response, err := s.mockHistoryEngine.GetMutableState(ctx, &historyservice.GetMutableStateRequest{ @@ -262,7 +262,7 @@ func (s *engineSuite) TestGetMutableState_EmptyRunID() { WorkflowId: "test-get-workflow-execution-event-id", } - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(nil, serviceerror.NewNotFound("")) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(nil, serviceerror.NewNotFound("")) _, err := s.mockHistoryEngine.GetMutableState(ctx, &historyservice.GetMutableStateRequest{ NamespaceId: tests.NamespaceID.String(), @@ -289,7 +289,7 @@ func (s *engineSuite) TestGetMutableStateLongPoll() { ms := workflow.TestCloneToProto(msBuilder) gweResponse := &persistence.GetWorkflowExecutionResponse{State: ms} // right now the next event ID is 4 - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gweResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gweResponse, nil) // test long poll on next event ID change waitGroup := &sync.WaitGroup{} @@ -302,7 +302,7 @@ func (s *engineSuite) TestGetMutableStateLongPoll() { ScheduleId: 2, } taskToken, _ := tt.Marshal() - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) timer := time.NewTimer(delay) @@ -364,7 +364,7 @@ func (s *engineSuite) TestGetMutableStateLongPoll_CurrentBranchChanged() { ms := workflow.TestCloneToProto(msBuilder) gweResponse := &persistence.GetWorkflowExecutionResponse{State: ms} // right now the next event ID is 4 - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gweResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gweResponse, nil) // test long poll on next event ID change asyncBranchTokenUpdate := func(delay time.Duration) { @@ -426,7 +426,7 @@ func (s *engineSuite) TestGetMutableStateLongPollTimeout() { ms := workflow.TestCloneToProto(msBuilder) gweResponse := &persistence.GetWorkflowExecutionResponse{State: ms} // right now the next event ID is 4 - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gweResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gweResponse, nil) // long poll, no event happen after long poll timeout response, err := s.mockHistoryEngine.GetMutableState(ctx, &historyservice.GetMutableStateRequest{ @@ -455,7 +455,7 @@ func (s *engineSuite) TestQueryWorkflow_RejectBasedOnCompleted() { addCompleteWorkflowEvent(msBuilder, event.GetEventId(), nil) ms := workflow.TestCloneToProto(msBuilder) gweResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gweResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gweResponse, nil) request := &historyservice.QueryWorkflowRequest{ NamespaceId: tests.NamespaceID.String(), @@ -489,7 +489,7 @@ func (s *engineSuite) TestQueryWorkflow_RejectBasedOnFailed() { addFailWorkflowEvent(msBuilder, event.GetEventId(), failure.NewServerFailure("failure reason", true), enumspb.RETRY_STATE_NON_RETRYABLE_FAILURE) ms := workflow.TestCloneToProto(msBuilder) gweResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gweResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gweResponse, nil) request := &historyservice.QueryWorkflowRequest{ NamespaceId: tests.NamespaceID.String(), @@ -536,7 +536,7 @@ func (s *engineSuite) TestQueryWorkflow_DirectlyThroughMatching() { ms := workflow.TestCloneToProto(msBuilder) gweResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gweResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gweResponse, nil) s.mockMatchingClient.EXPECT().QueryWorkflow(gomock.Any(), gomock.Any()).Return(&matchingservice.QueryWorkflowResponse{QueryResult: payloads.EncodeBytes([]byte{1, 2, 3})}, nil) s.mockHistoryEngine.matchingClient = s.mockMatchingClient request := &historyservice.QueryWorkflowRequest{ @@ -576,7 +576,7 @@ func (s *engineSuite) TestQueryWorkflow_WorkflowTaskDispatch_Timeout() { ms := workflow.TestCloneToProto(msBuilder) gweResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gweResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gweResponse, nil) request := &historyservice.QueryWorkflowRequest{ NamespaceId: tests.NamespaceID.String(), Request: &workflowservice.QueryWorkflowRequest{ @@ -630,7 +630,7 @@ func (s *engineSuite) TestQueryWorkflow_ConsistentQueryBufferFull() { ms := workflow.TestCloneToProto(msBuilder) gweResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gweResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gweResponse, nil) // buffer query so that when history.QueryWorkflow is called buffer is already full ctx, release, err := s.mockHistoryEngine.historyCache.GetOrCreateWorkflowExecution( @@ -640,7 +640,7 @@ func (s *engineSuite) TestQueryWorkflow_ConsistentQueryBufferFull() { workflow.CallerTypeAPI, ) s.NoError(err) - loadedMS, err := ctx.LoadWorkflowExecution() + loadedMS, err := ctx.LoadWorkflowExecution(context.Background()) s.NoError(err) qr := workflow.NewQueryRegistry() qr.BufferQuery(&querypb.WorkflowQuery{}) @@ -676,7 +676,7 @@ func (s *engineSuite) TestQueryWorkflow_WorkflowTaskDispatch_Complete() { ms := workflow.TestCloneToProto(msBuilder) gweResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gweResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gweResponse, nil) waitGroup := &sync.WaitGroup{} waitGroup.Add(1) @@ -747,7 +747,7 @@ func (s *engineSuite) TestQueryWorkflow_WorkflowTaskDispatch_Unblocked() { ms := workflow.TestCloneToProto(msBuilder) gweResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gweResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gweResponse, nil) s.mockMatchingClient.EXPECT().QueryWorkflow(gomock.Any(), gomock.Any()).Return(&matchingservice.QueryWorkflowResponse{QueryResult: payloads.EncodeBytes([]byte{1, 2, 3})}, nil) s.mockHistoryEngine.matchingClient = s.mockMatchingClient waitGroup := &sync.WaitGroup{} @@ -823,7 +823,7 @@ func (s *engineSuite) TestRespondWorkflowTaskCompletedIfNoExecution() { taskToken, _ := tt.Marshal() identity := "testIdentity" - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(nil, serviceerror.NewNotFound("")) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, serviceerror.NewNotFound("")) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -847,7 +847,7 @@ func (s *engineSuite) TestRespondWorkflowTaskCompletedIfGetExecutionFailed() { taskToken, _ := tt.Marshal() identity := "testIdentity" - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(nil, errors.New("FAILED")) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, errors.New("FAILED")) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -885,8 +885,8 @@ func (s *engineSuite) TestRespondWorkflowTaskCompletedUpdateExecutionFailed() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, errors.New("FAILED")) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, errors.New("FAILED")) s.mockShardManager.EXPECT().UpdateShard(gomock.Any()).Return(nil).AnyTimes() // might be called in background goroutine _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ @@ -926,7 +926,7 @@ func (s *engineSuite) TestRespondWorkflowTaskCompletedIfTaskCompleted() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -963,7 +963,7 @@ func (s *engineSuite) TestRespondWorkflowTaskCompletedIfTaskNotStarted() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -1038,8 +1038,8 @@ func (s *engineSuite) TestRespondWorkflowTaskCompletedConflictOnUpdate() { addActivityTaskCompletedEvent(msBuilder, activity2ScheduledEvent.EventId, activity2StartedEvent.EventId, activity2Result, identity) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, &persistence.ConditionFailedError{}) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, &persistence.ConditionFailedError{}) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -1122,7 +1122,7 @@ func (s *engineSuite) TestRespondWorkflowTaskCompletedMaxAttemptsExceeded() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) } _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ @@ -1189,14 +1189,14 @@ func (s *engineSuite) TestRespondWorkflowTaskCompletedCompleteWorkflowFailed() { ms1 := workflow.TestCloneToProto(msBuilder) gwmsResponse1 := &persistence.GetWorkflowExecutionResponse{State: ms1} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse1, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse1, nil) ms2 := proto.Clone(ms1).(*persistencespb.WorkflowMutableState) gwmsResponse2 := &persistence.GetWorkflowExecutionResponse{State: ms2} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse2, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse2, nil) var updatedWorkflowMutation persistence.WorkflowMutation - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).DoAndReturn(func(request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { updatedWorkflowMutation = request.UpdateWorkflowMutation return tests.UpdateWorkflowExecutionResponse, nil }) @@ -1273,14 +1273,14 @@ func (s *engineSuite) TestRespondWorkflowTaskCompletedFailWorkflowFailed() { ms1 := workflow.TestCloneToProto(msBuilder) gwmsResponse1 := &persistence.GetWorkflowExecutionResponse{State: ms1} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse1, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse1, nil) ms2 := proto.Clone(ms1).(*persistencespb.WorkflowMutableState) gwmsResponse2 := &persistence.GetWorkflowExecutionResponse{State: ms2} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse2, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse2, nil) var updatedWorkflowMutation persistence.WorkflowMutation - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).DoAndReturn(func(request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { updatedWorkflowMutation = request.UpdateWorkflowMutation return tests.UpdateWorkflowExecutionResponse, nil }) @@ -1346,9 +1346,9 @@ func (s *engineSuite) TestRespondWorkflowTaskCompletedBadCommandAttributes() { gwmsResponse1 := &persistence.GetWorkflowExecutionResponse{State: workflow.TestCloneToProto(msBuilder)} gwmsResponse2 := &persistence.GetWorkflowExecutionResponse{State: workflow.TestCloneToProto(msBuilder)} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse1, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse2, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse1, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse2, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -1459,15 +1459,15 @@ func (s *engineSuite) TestRespondWorkflowTaskCompletedSingleActivityScheduledAtt }} gwmsResponse1 := &persistence.GetWorkflowExecutionResponse{State: workflow.TestCloneToProto(msBuilder)} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse1, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse1, nil) ms2 := workflow.TestCloneToProto(msBuilder) if iVar.expectWorkflowTaskFail { gwmsResponse2 := &persistence.GetWorkflowExecutionResponse{State: ms2} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse2, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse2, nil) } var updatedWorkflowMutation persistence.WorkflowMutation - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).DoAndReturn(func(request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { updatedWorkflowMutation = request.UpdateWorkflowMutation return tests.UpdateWorkflowExecutionResponse, nil }) @@ -1541,10 +1541,10 @@ func (s *engineSuite) TestRespondWorkflowTaskCompletedBadBinary() { ms2 := workflow.TestCloneToProto(msBuilder) gwmsResponse2 := &persistence.GetWorkflowExecutionResponse{State: ms2} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse1, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse2, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse1, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse2, nil) var updatedWorkflowMutation persistence.WorkflowMutation - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).DoAndReturn(func(request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { updatedWorkflowMutation = request.UpdateWorkflowMutation return tests.UpdateWorkflowExecutionResponse, nil }) @@ -1609,8 +1609,8 @@ func (s *engineSuite) TestRespondWorkflowTaskCompletedSingleActivityScheduledWor ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -1667,8 +1667,8 @@ func (s *engineSuite) TestRespondWorkflowTaskCompleted_WorkflowTaskHeartbeatTime ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -1710,8 +1710,8 @@ func (s *engineSuite) TestRespondWorkflowTaskCompleted_WorkflowTaskHeartbeatNotT ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -1753,8 +1753,8 @@ func (s *engineSuite) TestRespondWorkflowTaskCompleted_WorkflowTaskHeartbeatNotT ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -1801,8 +1801,8 @@ func (s *engineSuite) TestRespondWorkflowTaskCompletedCompleteWorkflowSuccess() ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -1853,8 +1853,8 @@ func (s *engineSuite) TestRespondWorkflowTaskCompletedFailWorkflowSuccess() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -1910,8 +1910,8 @@ func (s *engineSuite) TestRespondWorkflowTaskCompletedSignalExternalWorkflowSucc ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -1965,8 +1965,8 @@ func (s *engineSuite) TestRespondWorkflowTaskCompletedStartChildWorkflowWithAban ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -2028,8 +2028,8 @@ func (s *engineSuite) TestRespondWorkflowTaskCompletedStartChildWorkflowWithTerm ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -2142,7 +2142,7 @@ func (s *engineSuite) TestRespondWorkflowTaskCompletedSignalExternalWorkflowFail ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) s.mockNamespaceCache.EXPECT().GetNamespace(foreignNamespace).Return( nil, errors.New("get foreign namespace error"), ) @@ -2188,7 +2188,7 @@ func (s *engineSuite) TestRespondActivityTaskCompletedIfNoExecution() { taskToken, _ := tt.Marshal() identity := "testIdentity" - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(nil, serviceerror.NewNotFound("")) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, serviceerror.NewNotFound("")) err := s.mockHistoryEngine.RespondActivityTaskCompleted(context.Background(), &historyservice.RespondActivityTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -2211,7 +2211,7 @@ func (s *engineSuite) TestRespondActivityTaskCompletedIfNoRunID() { taskToken, _ := tt.Marshal() identity := "testIdentity" - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(nil, serviceerror.NewNotFound("")) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(nil, serviceerror.NewNotFound("")) err := s.mockHistoryEngine.RespondActivityTaskCompleted(context.Background(), &historyservice.RespondActivityTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -2235,7 +2235,7 @@ func (s *engineSuite) TestRespondActivityTaskCompletedIfGetExecutionFailed() { taskToken, _ := tt.Marshal() identity := "testIdentity" - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(nil, errors.New("FAILED")) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, errors.New("FAILED")) err := s.mockHistoryEngine.RespondActivityTaskCompleted(context.Background(), &historyservice.RespondActivityTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -2270,8 +2270,8 @@ func (s *engineSuite) TestRespondActivityTaskCompletedIfNoAIdProvided() { gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} gceResponse := &persistence.GetCurrentExecutionResponse{RunID: tests.RunID} - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(gceResponse, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(gceResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) err := s.mockHistoryEngine.RespondActivityTaskCompleted(context.Background(), &historyservice.RespondActivityTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -2307,8 +2307,8 @@ func (s *engineSuite) TestRespondActivityTaskCompletedIfNotFound() { gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} gceResponse := &persistence.GetCurrentExecutionResponse{RunID: tests.RunID} - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(gceResponse, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(gceResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) err := s.mockHistoryEngine.RespondActivityTaskCompleted(context.Background(), &historyservice.RespondActivityTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -2352,8 +2352,8 @@ func (s *engineSuite) TestRespondActivityTaskCompletedUpdateExecutionFailed() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, errors.New("FAILED")) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, errors.New("FAILED")) s.mockShardManager.EXPECT().UpdateShard(gomock.Any()).Return(nil).AnyTimes() // might be called in background goroutine err := s.mockHistoryEngine.RespondActivityTaskCompleted(context.Background(), &historyservice.RespondActivityTaskCompletedRequest{ @@ -2402,7 +2402,7 @@ func (s *engineSuite) TestRespondActivityTaskCompletedIfTaskCompleted() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) err := s.mockHistoryEngine.RespondActivityTaskCompleted(context.Background(), &historyservice.RespondActivityTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -2447,7 +2447,7 @@ func (s *engineSuite) TestRespondActivityTaskCompletedIfTaskNotStarted() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) err := s.mockHistoryEngine.RespondActivityTaskCompleted(context.Background(), &historyservice.RespondActivityTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -2501,11 +2501,11 @@ func (s *engineSuite) TestRespondActivityTaskCompletedConflictOnUpdate() { ms2 := workflow.TestCloneToProto(msBuilder) gwmsResponse2 := &persistence.GetWorkflowExecutionResponse{State: ms2} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse1, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, &persistence.ConditionFailedError{}) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse1, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, &persistence.ConditionFailedError{}) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse2, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse2, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err := s.mockHistoryEngine.RespondActivityTaskCompleted(context.Background(), &historyservice.RespondActivityTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -2562,8 +2562,8 @@ func (s *engineSuite) TestRespondActivityTaskCompletedMaxAttemptsExceeded() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, &persistence.ConditionFailedError{}) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, &persistence.ConditionFailedError{}) } err := s.mockHistoryEngine.RespondActivityTaskCompleted(context.Background(), &historyservice.RespondActivityTaskCompletedRequest{ @@ -2609,8 +2609,8 @@ func (s *engineSuite) TestRespondActivityTaskCompletedSuccess() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err := s.mockHistoryEngine.RespondActivityTaskCompleted(context.Background(), &historyservice.RespondActivityTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -2668,9 +2668,9 @@ func (s *engineSuite) TestRespondActivityTaskCompletedByIdSuccess() { gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} gceResponse := &persistence.GetCurrentExecutionResponse{RunID: we.RunId} - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(gceResponse, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(gceResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err := s.mockHistoryEngine.RespondActivityTaskCompleted(context.Background(), &historyservice.RespondActivityTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -2722,7 +2722,7 @@ func (s *engineSuite) TestRespondActivityTaskFailedIfNoExecution() { taskToken, _ := tt.Marshal() identity := "testIdentity" - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(nil, + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, serviceerror.NewNotFound("")) err := s.mockHistoryEngine.RespondActivityTaskFailed(context.Background(), &historyservice.RespondActivityTaskFailedRequest{ @@ -2746,7 +2746,7 @@ func (s *engineSuite) TestRespondActivityTaskFailedIfNoRunID() { taskToken, _ := tt.Marshal() identity := "testIdentity" - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(nil, + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(nil, serviceerror.NewNotFound("")) err := s.mockHistoryEngine.RespondActivityTaskFailed(context.Background(), &historyservice.RespondActivityTaskFailedRequest{ @@ -2771,7 +2771,7 @@ func (s *engineSuite) TestRespondActivityTaskFailedIfGetExecutionFailed() { taskToken, _ := tt.Marshal() identity := "testIdentity" - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(nil, + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, errors.New("FAILED")) err := s.mockHistoryEngine.RespondActivityTaskFailed(context.Background(), &historyservice.RespondActivityTaskFailedRequest{ @@ -2807,8 +2807,8 @@ func (s *engineSuite) TestRespondActivityTaskFailededIfNoAIdProvided() { gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} gceResponse := &persistence.GetCurrentExecutionResponse{RunID: tests.RunID} - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(gceResponse, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(gceResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) err := s.mockHistoryEngine.RespondActivityTaskFailed(context.Background(), &historyservice.RespondActivityTaskFailedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -2844,8 +2844,8 @@ func (s *engineSuite) TestRespondActivityTaskFailededIfNotFound() { gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} gceResponse := &persistence.GetCurrentExecutionResponse{RunID: tests.RunID} - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(gceResponse, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(gceResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) err := s.mockHistoryEngine.RespondActivityTaskFailed(context.Background(), &historyservice.RespondActivityTaskFailedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -2888,8 +2888,8 @@ func (s *engineSuite) TestRespondActivityTaskFailedUpdateExecutionFailed() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, errors.New("FAILED")) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, errors.New("FAILED")) s.mockShardManager.EXPECT().UpdateShard(gomock.Any()).Return(nil).AnyTimes() // might be called in background goroutine err := s.mockHistoryEngine.RespondActivityTaskFailed(context.Background(), &historyservice.RespondActivityTaskFailedRequest{ @@ -2936,7 +2936,7 @@ func (s *engineSuite) TestRespondActivityTaskFailedIfTaskCompleted() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) err := s.mockHistoryEngine.RespondActivityTaskFailed(context.Background(), &historyservice.RespondActivityTaskFailedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -2980,7 +2980,7 @@ func (s *engineSuite) TestRespondActivityTaskFailedIfTaskNotStarted() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) err := s.mockHistoryEngine.RespondActivityTaskFailed(context.Background(), &historyservice.RespondActivityTaskFailedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -3038,11 +3038,11 @@ func (s *engineSuite) TestRespondActivityTaskFailedConflictOnUpdate() { ms2 := workflow.TestCloneToProto(msBuilder) gwmsResponse2 := &persistence.GetWorkflowExecutionResponse{State: ms2} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse1, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, &persistence.ConditionFailedError{}) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse1, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, &persistence.ConditionFailedError{}) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse2, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse2, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err := s.mockHistoryEngine.RespondActivityTaskFailed(context.Background(), &historyservice.RespondActivityTaskFailedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -3098,8 +3098,8 @@ func (s *engineSuite) TestRespondActivityTaskFailedMaxAttemptsExceeded() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, &persistence.ConditionFailedError{}) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, &persistence.ConditionFailedError{}) } err := s.mockHistoryEngine.RespondActivityTaskFailed(context.Background(), &historyservice.RespondActivityTaskFailedRequest{ @@ -3144,8 +3144,8 @@ func (s *engineSuite) TestRespondActivityTaskFailedSuccess() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err := s.mockHistoryEngine.RespondActivityTaskFailed(context.Background(), &historyservice.RespondActivityTaskFailedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -3202,8 +3202,8 @@ func (s *engineSuite) TestRespondActivityTaskFailedWithHeartbeatSuccess() { ms.ActivityInfos[activityInfo.ScheduleId] = activityInfo gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) details := payloads.EncodeString("details") @@ -3268,9 +3268,9 @@ func (s *engineSuite) TestRespondActivityTaskFailedByIdSuccess() { gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} gceResponse := &persistence.GetCurrentExecutionResponse{RunID: we.RunId} - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(gceResponse, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(gceResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err := s.mockHistoryEngine.RespondActivityTaskFailed(context.Background(), &historyservice.RespondActivityTaskFailedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -3325,8 +3325,8 @@ func (s *engineSuite) TestRecordActivityTaskHeartBeatSuccess_NoTimer() { // No HeartBeat timer running. ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) detais := payloads.EncodeString("details") @@ -3373,8 +3373,8 @@ func (s *engineSuite) TestRecordActivityTaskHeartBeatSuccess_TimerRunning() { gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} // HeartBeat timer running. - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) detais := payloads.EncodeString("details") @@ -3426,8 +3426,8 @@ func (s *engineSuite) TestRecordActivityTaskHeartBeatByIDSuccess() { // No HeartBeat timer running. ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) detais := payloads.EncodeString("details") @@ -3472,7 +3472,7 @@ func (s *engineSuite) TestRespondActivityTaskCanceled_Scheduled() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) err := s.mockHistoryEngine.RespondActivityTaskCanceled(context.Background(), &historyservice.RespondActivityTaskCanceledRequest{ NamespaceId: tests.NamespaceID.String(), @@ -3519,8 +3519,8 @@ func (s *engineSuite) TestRespondActivityTaskCanceled_Started() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err = s.mockHistoryEngine.RespondActivityTaskCanceled(context.Background(), &historyservice.RespondActivityTaskCanceledRequest{ NamespaceId: tests.NamespaceID.String(), @@ -3578,9 +3578,9 @@ func (s *engineSuite) TestRespondActivityTaskCanceledById_Started() { gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} gceResponse := &persistence.GetCurrentExecutionResponse{RunID: we.RunId} - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(gceResponse, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(gceResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err = s.mockHistoryEngine.RespondActivityTaskCanceled(context.Background(), &historyservice.RespondActivityTaskCanceledRequest{ NamespaceId: tests.NamespaceID.String(), @@ -3614,7 +3614,7 @@ func (s *engineSuite) TestRespondActivityTaskCanceledIfNoRunID() { taskToken, _ := tt.Marshal() identity := "testIdentity" - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(nil, serviceerror.NewNotFound("")) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(nil, serviceerror.NewNotFound("")) err := s.mockHistoryEngine.RespondActivityTaskCanceled(context.Background(), &historyservice.RespondActivityTaskCanceledRequest{ NamespaceId: tests.NamespaceID.String(), @@ -3651,8 +3651,8 @@ func (s *engineSuite) TestRespondActivityTaskCanceledIfNoAIdProvided() { gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} gceResponse := &persistence.GetCurrentExecutionResponse{RunID: tests.RunID} - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(gceResponse, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(gceResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) err := s.mockHistoryEngine.RespondActivityTaskCanceled(context.Background(), &historyservice.RespondActivityTaskCanceledRequest{ NamespaceId: tests.NamespaceID.String(), @@ -3689,8 +3689,8 @@ func (s *engineSuite) TestRespondActivityTaskCanceledIfNotFound() { gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} gceResponse := &persistence.GetCurrentExecutionResponse{RunID: tests.RunID} - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(gceResponse, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(gceResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) err := s.mockHistoryEngine.RespondActivityTaskCanceled(context.Background(), &historyservice.RespondActivityTaskCanceledRequest{ NamespaceId: tests.NamespaceID.String(), @@ -3737,10 +3737,10 @@ func (s *engineSuite) TestRequestCancel_RespondWorkflowTaskCompleted_NotSchedule ms2 := workflow.TestCloneToProto(msBuilder) gwmsResponse2 := &persistence.GetWorkflowExecutionResponse{State: ms2} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse1, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse2, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse1, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse2, nil) var updatedWorkflowMutation persistence.WorkflowMutation - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).DoAndReturn(func(request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { updatedWorkflowMutation = request.UpdateWorkflowMutation return tests.UpdateWorkflowExecutionResponse, nil }) @@ -3802,8 +3802,8 @@ func (s *engineSuite) TestRequestCancel_RespondWorkflowTaskCompleted_Scheduled() }}, }} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -3866,8 +3866,8 @@ func (s *engineSuite) TestRequestCancel_RespondWorkflowTaskCompleted_Started() { }}, }} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -3933,8 +3933,8 @@ func (s *engineSuite) TestRequestCancel_RespondWorkflowTaskCompleted_Completed() ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -3993,8 +3993,8 @@ func (s *engineSuite) TestRequestCancel_RespondWorkflowTaskCompleted_NoHeartBeat }}, }} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -4013,7 +4013,7 @@ func (s *engineSuite) TestRequestCancel_RespondWorkflowTaskCompleted_NoHeartBeat s.False(executionBuilder.HasPendingWorkflowTask()) // Try recording activity heartbeat - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) att := &tokenspb.Task{ ScheduleAttempt: 1, @@ -4036,7 +4036,7 @@ func (s *engineSuite) TestRequestCancel_RespondWorkflowTaskCompleted_NoHeartBeat s.True(hbResponse.CancelRequested) // Try cancelling the request. - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err = s.mockHistoryEngine.RespondActivityTaskCanceled(context.Background(), &historyservice.RespondActivityTaskCanceledRequest{ NamespaceId: tests.NamespaceID.String(), @@ -4095,8 +4095,8 @@ func (s *engineSuite) TestRequestCancel_RespondWorkflowTaskCompleted_Success() { }}, }} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -4115,7 +4115,7 @@ func (s *engineSuite) TestRequestCancel_RespondWorkflowTaskCompleted_Success() { s.False(executionBuilder.HasPendingWorkflowTask()) // Try recording activity heartbeat - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) att := &tokenspb.Task{ ScheduleAttempt: 1, @@ -4138,7 +4138,7 @@ func (s *engineSuite) TestRequestCancel_RespondWorkflowTaskCompleted_Success() { s.True(hbResponse.CancelRequested) // Try cancelling the request. - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err = s.mockHistoryEngine.RespondActivityTaskCanceled(context.Background(), &historyservice.RespondActivityTaskCanceledRequest{ NamespaceId: tests.NamespaceID.String(), @@ -4196,8 +4196,8 @@ func (s *engineSuite) TestRequestCancel_RespondWorkflowTaskCompleted_SuccessWith }}, }} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) // load mutable state such that it already exists in memory when respond workflow task is called // this enables us to set query registry on it @@ -4208,7 +4208,7 @@ func (s *engineSuite) TestRequestCancel_RespondWorkflowTaskCompleted_SuccessWith workflow.CallerTypeAPI, ) s.NoError(err) - loadedMS, err := ctx.LoadWorkflowExecution() + loadedMS, err := ctx.LoadWorkflowExecution(context.Background()) s.NoError(err) qr := workflow.NewQueryRegistry() id1, _ := qr.BufferQuery(&querypb.WorkflowQuery{}) @@ -4262,7 +4262,7 @@ func (s *engineSuite) TestRequestCancel_RespondWorkflowTaskCompleted_SuccessWith s.Equal(workflow.QueryTerminationTypeUnblocked, unblocked1.QueryTerminationType) // Try recording activity heartbeat - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) att := &tokenspb.Task{ ScheduleAttempt: 1, @@ -4285,7 +4285,7 @@ func (s *engineSuite) TestRequestCancel_RespondWorkflowTaskCompleted_SuccessWith s.True(hbResponse.CancelRequested) // Try cancelling the request. - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err = s.mockHistoryEngine.RespondActivityTaskCanceled(context.Background(), &historyservice.RespondActivityTaskCanceledRequest{ NamespaceId: tests.NamespaceID.String(), @@ -4339,8 +4339,8 @@ func (s *engineSuite) TestStarTimer_DuplicateTimerID() { }}, }} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -4370,9 +4370,9 @@ func (s *engineSuite) TestStarTimer_DuplicateTimerID() { gwmsResponse2 := &persistence.GetWorkflowExecutionResponse{State: ms2} workflowTaskFailedEvent := false - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse2, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse2, nil) var updatedWorkflowMutation persistence.WorkflowMutation - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).DoAndReturn(func(request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { for _, newEvents := range request.UpdateWorkflowEvents { decTaskIndex := len(newEvents.Events) - 1 if decTaskIndex >= 0 && newEvents.Events[decTaskIndex].EventType == enumspb.EVENT_TYPE_WORKFLOW_TASK_FAILED { @@ -4442,8 +4442,8 @@ func (s *engineSuite) TestUserTimer_RespondWorkflowTaskCompleted() { }}, }} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) _, err := s.mockHistoryEngine.RespondWorkflowTaskCompleted(context.Background(), &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: tests.NamespaceID.String(), @@ -4499,10 +4499,10 @@ func (s *engineSuite) TestCancelTimer_RespondWorkflowTaskCompleted_NoStartTimer( }}, }} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse2, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse2, nil) var updatedWorkflowMutation persistence.WorkflowMutation - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).DoAndReturn(func(request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { updatedWorkflowMutation = request.UpdateWorkflowMutation return tests.UpdateWorkflowExecutionResponse, nil }) @@ -4568,9 +4568,9 @@ func (s *engineSuite) TestCancelTimer_RespondWorkflowTaskCompleted_TimerFired() }}, }} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).DoAndReturn( - func(request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { s.True(request.UpdateWorkflowMutation.ClearBufferedEvents) return tests.UpdateWorkflowExecutionResponse, nil }) @@ -4625,8 +4625,8 @@ func (s *engineSuite) TestSignalWorkflowExecution() { ms.ExecutionInfo.NamespaceId = tests.NamespaceID.String() gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err = s.mockHistoryEngine.SignalWorkflowExecution(context.Background(), signalRequest) s.Nil(err) @@ -4669,7 +4669,7 @@ func (s *engineSuite) TestSignalWorkflowExecution_DuplicateRequest() { ms.ExecutionInfo.NamespaceId = tests.NamespaceID.String() gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) err = s.mockHistoryEngine.SignalWorkflowExecution(context.Background(), signalRequest) s.Nil(err) @@ -4713,7 +4713,7 @@ func (s *engineSuite) TestSignalWorkflowExecution_DuplicateRequest_Completed() { ms.ExecutionState.State = enumsspb.WORKFLOW_EXECUTION_STATE_COMPLETED gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) err = s.mockHistoryEngine.SignalWorkflowExecution(context.Background(), signalRequest) s.Nil(err) @@ -4751,7 +4751,7 @@ func (s *engineSuite) TestSignalWorkflowExecution_Failed() { ms.ExecutionState.State = enumsspb.WORKFLOW_EXECUTION_STATE_COMPLETED gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) err = s.mockHistoryEngine.SignalWorkflowExecution(context.Background(), signalRequest) s.EqualError(err, "workflow execution already completed") @@ -4783,8 +4783,8 @@ func (s *engineSuite) TestRemoveSignalMutableState() { ms.ExecutionInfo.NamespaceId = tests.NamespaceID.String() gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err = s.mockHistoryEngine.RemoveSignalMutableState(context.Background(), removeRequest) s.Nil(err) @@ -4817,8 +4817,8 @@ func (s *engineSuite) TestReapplyEvents_ReturnSuccess() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} gceResponse := &persistence.GetCurrentExecutionResponse{RunID: tests.RunID} - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(gceResponse, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(gceResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) s.mockEventsReapplier.EXPECT().reapplyEvents(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) err := s.mockHistoryEngine.ReapplyEvents( @@ -4860,8 +4860,8 @@ func (s *engineSuite) TestReapplyEvents_IgnoreSameVersionEvents() { ms := workflow.TestCloneToProto(msBuilder) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} gceResponse := &persistence.GetCurrentExecutionResponse{RunID: tests.RunID} - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(gceResponse, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(gceResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) s.mockEventsReapplier.EXPECT().reapplyEvents(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) err := s.mockHistoryEngine.ReapplyEvents( @@ -4908,8 +4908,8 @@ func (s *engineSuite) TestReapplyEvents_ResetWorkflow() { ms.ExecutionInfo.VersionHistories = versionhistory.NewVersionHistories(versionHistory) gwmsResponse := &persistence.GetWorkflowExecutionResponse{State: ms} gceResponse := &persistence.GetCurrentExecutionResponse{RunID: tests.RunID} - s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any()).Return(gceResponse, nil) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(gwmsResponse, nil) + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(gceResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(gwmsResponse, nil) s.mockEventsReapplier.EXPECT().reapplyEvents(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) s.mockWorkflowResetter.EXPECT().resetWorkflow( gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), @@ -4941,9 +4941,11 @@ func (s *engineSuite) getBuilder(testNamespaceID namespace.ID, we commonpb.Workf return context.(*workflow.ContextImpl).MutableState } -func (s *engineSuite) getActivityScheduledEvent(msBuilder workflow.MutableState, - scheduleID int64) *historypb.HistoryEvent { - event, _ := msBuilder.GetActivityScheduledEvent(scheduleID) +func (s *engineSuite) getActivityScheduledEvent( + msBuilder workflow.MutableState, + scheduleID int64, +) *historypb.HistoryEvent { + event, _ := msBuilder.GetActivityScheduledEvent(context.Background(), scheduleID) return event } diff --git a/service/history/nDCActivityReplicator.go b/service/history/nDCActivityReplicator.go index c5c5575b304..5d5209da682 100644 --- a/service/history/nDCActivityReplicator.go +++ b/service/history/nDCActivityReplicator.go @@ -113,7 +113,7 @@ func (r *nDCActivityReplicatorImpl) SyncActivity( } defer func() { release(retError) }() - mutableState, err := executionContext.LoadWorkflowExecution() + mutableState, err := executionContext.LoadWorkflowExecution(ctx) if err != nil { if _, ok := err.(*serviceerror.NotFound); !ok { return err @@ -190,6 +190,7 @@ func (r *nDCActivityReplicatorImpl) SyncActivity( } return executionContext.UpdateWorkflowExecutionWithNew( + ctx, now, updateMode, nil, // no new workflow diff --git a/service/history/nDCActivityReplicator_test.go b/service/history/nDCActivityReplicator_test.go index 30056ae3c96..38babdfa3ba 100644 --- a/service/history/nDCActivityReplicator_test.go +++ b/service/history/nDCActivityReplicator_test.go @@ -610,7 +610,7 @@ func (s *activityReplicatorSuite) TestSyncActivity_WorkflowNotFound() { WorkflowId: workflowID, RunId: runID, } - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(&persistence.GetWorkflowExecutionRequest{ + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), &persistence.GetWorkflowExecutionRequest{ ShardID: s.mockShard.GetShardID(), NamespaceID: namespaceID.String(), WorkflowID: workflowID, @@ -668,7 +668,7 @@ func (s *activityReplicatorSuite) TestSyncActivity_WorkflowClosed() { key := definition.NewWorkflowKey(namespaceID.String(), workflowID, runID) weContext := workflow.NewMockContext(s.controller) - weContext.EXPECT().LoadWorkflowExecution().Return(s.mockMutableState, nil) + weContext.EXPECT().LoadWorkflowExecution(gomock.Any()).Return(s.mockMutableState, nil) weContext.EXPECT().Lock(gomock.Any(), workflow.CallerTypeAPI).Return(nil) weContext.EXPECT().Unlock(workflow.CallerTypeAPI) _, err := s.historyCache.PutIfNotExist(key, weContext) @@ -742,7 +742,7 @@ func (s *activityReplicatorSuite) TestSyncActivity_ActivityNotFound() { key := definition.NewWorkflowKey(namespaceID.String(), workflowID, runID) weContext := workflow.NewMockContext(s.controller) - weContext.EXPECT().LoadWorkflowExecution().Return(s.mockMutableState, nil) + weContext.EXPECT().LoadWorkflowExecution(gomock.Any()).Return(s.mockMutableState, nil) weContext.EXPECT().Lock(gomock.Any(), workflow.CallerTypeAPI).Return(nil) weContext.EXPECT().Unlock(workflow.CallerTypeAPI) _, err := s.historyCache.PutIfNotExist(key, weContext) @@ -817,7 +817,7 @@ func (s *activityReplicatorSuite) TestSyncActivity_ActivityFound_Zombie() { key := definition.NewWorkflowKey(namespaceID.String(), workflowID, runID) weContext := workflow.NewMockContext(s.controller) - weContext.EXPECT().LoadWorkflowExecution().Return(s.mockMutableState, nil) + weContext.EXPECT().LoadWorkflowExecution(gomock.Any()).Return(s.mockMutableState, nil) weContext.EXPECT().Lock(gomock.Any(), workflow.CallerTypeAPI).Return(nil) weContext.EXPECT().Unlock(workflow.CallerTypeAPI) _, err := s.historyCache.PutIfNotExist(key, weContext) @@ -847,6 +847,7 @@ func (s *activityReplicatorSuite) TestSyncActivity_ActivityFound_Zombie() { s.mockClusterMetadata.EXPECT().IsVersionFromSameCluster(version, version).Return(true) weContext.EXPECT().UpdateWorkflowExecutionWithNew( + gomock.Any(), gomock.Any(), persistence.UpdateWorkflowModeBypassCurrent, workflow.Context(nil), @@ -907,7 +908,7 @@ func (s *activityReplicatorSuite) TestSyncActivity_ActivityFound_NonZombie() { key := definition.NewWorkflowKey(namespaceID.String(), workflowID, runID) weContext := workflow.NewMockContext(s.controller) - weContext.EXPECT().LoadWorkflowExecution().Return(s.mockMutableState, nil) + weContext.EXPECT().LoadWorkflowExecution(gomock.Any()).Return(s.mockMutableState, nil) weContext.EXPECT().Lock(gomock.Any(), workflow.CallerTypeAPI).Return(nil) weContext.EXPECT().Unlock(workflow.CallerTypeAPI) _, err := s.historyCache.PutIfNotExist(key, weContext) @@ -937,6 +938,7 @@ func (s *activityReplicatorSuite) TestSyncActivity_ActivityFound_NonZombie() { s.mockClusterMetadata.EXPECT().IsVersionFromSameCluster(version, version).Return(true) weContext.EXPECT().UpdateWorkflowExecutionWithNew( + gomock.Any(), gomock.Any(), persistence.UpdateWorkflowModeUpdateCurrent, workflow.Context(nil), diff --git a/service/history/nDCBranchMgr.go b/service/history/nDCBranchMgr.go index 595f4335eb0..35928b7659d 100644 --- a/service/history/nDCBranchMgr.go +++ b/service/history/nDCBranchMgr.go @@ -185,6 +185,7 @@ func (r *nDCBranchMgrImpl) flushBufferedEvents( } // the workflow must be updated as active, to send out replication tasks if err := targetWorkflow.context.UpdateWorkflowExecutionAsActive( + ctx, r.shard.GetTimeSource().Now(), ); err != nil { return nil, 0, err @@ -243,7 +244,7 @@ func (r *nDCBranchMgrImpl) createNewBranch( namespaceID := executionInfo.NamespaceId workflowID := executionInfo.WorkflowId - resp, err := r.executionMgr.ForkHistoryBranch(&persistence.ForkHistoryBranchRequest{ + resp, err := r.executionMgr.ForkHistoryBranch(ctx, &persistence.ForkHistoryBranchRequest{ ForkBranchToken: baseBranchToken, ForkNodeID: baseBranchLastEventID + 1, Info: persistence.BuildHistoryGarbageCleanupInfo(namespaceID, workflowID, uuid.New()), diff --git a/service/history/nDCBranchMgr_test.go b/service/history/nDCBranchMgr_test.go index e32921488ee..b12f9700ff1 100644 --- a/service/history/nDCBranchMgr_test.go +++ b/service/history/nDCBranchMgr_test.go @@ -144,8 +144,8 @@ func (s *nDCBranchMgrSuite) TestCreateNewBranch() { }).AnyTimes() shardID := s.mockShard.GetShardID() - s.mockExecutionManager.EXPECT().ForkHistoryBranch(gomock.Any()).DoAndReturn( - func(input *persistence.ForkHistoryBranchRequest) (*persistence.ForkHistoryBranchResponse, error) { + s.mockExecutionManager.EXPECT().ForkHistoryBranch(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, input *persistence.ForkHistoryBranchRequest) (*persistence.ForkHistoryBranchResponse, error) { input.Info = "" s.Equal(&persistence.ForkHistoryBranchRequest{ ForkBranchToken: baseBranchToken, @@ -218,7 +218,7 @@ func (s *nDCBranchMgrSuite) TestFlushBufferedEvents() { s.mockClusterMetadata.EXPECT().ClusterNameForFailoverVersion(true, lastWriteVersion).Return(cluster.TestCurrentClusterName).AnyTimes() s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() - s.mockContext.EXPECT().UpdateWorkflowExecutionAsActive(gomock.Any()).Return(nil) + s.mockContext.EXPECT().UpdateWorkflowExecutionAsActive(gomock.Any(), gomock.Any()).Return(nil) ctx := context.Background() @@ -324,8 +324,8 @@ func (s *nDCBranchMgrSuite) TestPrepareVersionHistory_BranchNotAppendable_NoMiss }).AnyTimes() shardID := s.mockShard.GetShardID() - s.mockExecutionManager.EXPECT().ForkHistoryBranch(gomock.Any()).DoAndReturn( - func(input *persistence.ForkHistoryBranchRequest) (*persistence.ForkHistoryBranchResponse, error) { + s.mockExecutionManager.EXPECT().ForkHistoryBranch(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, input *persistence.ForkHistoryBranchRequest) (*persistence.ForkHistoryBranchResponse, error) { input.Info = "" s.Equal(&persistence.ForkHistoryBranchRequest{ ForkBranchToken: baseBranchToken, diff --git a/service/history/nDCHistoryReplicator.go b/service/history/nDCHistoryReplicator.go index eef3b3a73c0..3845966b3c5 100644 --- a/service/history/nDCHistoryReplicator.go +++ b/service/history/nDCHistoryReplicator.go @@ -247,7 +247,7 @@ func (r *nDCHistoryReplicatorImpl) applyEvents( default: // apply events, other than simple start workflow execution // the continue as new + start workflow execution combination will also be processed here - mutableState, err := context.LoadWorkflowExecution() + mutableState, err := context.LoadWorkflowExecution(ctx) switch err.(type) { case nil: // Sanity check to make only 3DC mutable state here diff --git a/service/history/nDCStateRebuilder.go b/service/history/nDCStateRebuilder.go index cc3eaaf6394..660d50d05af 100644 --- a/service/history/nDCStateRebuilder.go +++ b/service/history/nDCStateRebuilder.go @@ -119,6 +119,7 @@ func (r *nDCStateRebuilderImpl) rebuild( requestID string, ) (workflow.MutableState, int64, error) { iter := collection.NewPagingIterator(r.getPaginationFn( + ctx, common.FirstEventID, baseLastEventID+1, baseBranchToken, @@ -194,7 +195,7 @@ func (r *nDCStateRebuilderImpl) rebuild( rebuiltMutableState.GetExecutionInfo().LastFirstEventTxnId = lastTxnId // refresh tasks to be generated - if err := r.taskRefresher.RefreshTasks(now, rebuiltMutableState); err != nil { + if err := r.taskRefresher.RefreshTasks(ctx, now, rebuiltMutableState); err != nil { return nil, 0, err } @@ -245,12 +246,13 @@ func (r *nDCStateRebuilderImpl) applyEvents( } func (r *nDCStateRebuilderImpl) getPaginationFn( + ctx context.Context, firstEventID int64, nextEventID int64, branchToken []byte, ) collection.PaginationFn { return func(paginationToken []byte) ([]interface{}, []byte, error) { - resp, err := r.executionMgr.ReadHistoryBranchByBatch(&persistence.ReadHistoryBranchRequest{ + resp, err := r.executionMgr.ReadHistoryBranchByBatch(ctx, &persistence.ReadHistoryBranchRequest{ BranchToken: branchToken, MinEventID: firstEventID, MaxEventID: nextEventID, diff --git a/service/history/nDCStateRebuilder_test.go b/service/history/nDCStateRebuilder_test.go index bd9226b612f..b2a74650d20 100644 --- a/service/history/nDCStateRebuilder_test.go +++ b/service/history/nDCStateRebuilder_test.go @@ -206,7 +206,7 @@ func (s *nDCStateRebuilderSuite) TestPagination() { pageToken := []byte("some random token") shardID := s.mockShard.GetShardID() - s.mockExecutionManager.EXPECT().ReadHistoryBranchByBatch(&persistence.ReadHistoryBranchRequest{ + s.mockExecutionManager.EXPECT().ReadHistoryBranchByBatch(gomock.Any(), &persistence.ReadHistoryBranchRequest{ BranchToken: branchToken, MinEventID: firstEventID, MaxEventID: nextEventID, @@ -219,7 +219,7 @@ func (s *nDCStateRebuilderSuite) TestPagination() { NextPageToken: pageToken, Size: 12345, }, nil) - s.mockExecutionManager.EXPECT().ReadHistoryBranchByBatch(&persistence.ReadHistoryBranchRequest{ + s.mockExecutionManager.EXPECT().ReadHistoryBranchByBatch(gomock.Any(), &persistence.ReadHistoryBranchRequest{ BranchToken: branchToken, MinEventID: firstEventID, MaxEventID: nextEventID, @@ -233,7 +233,7 @@ func (s *nDCStateRebuilderSuite) TestPagination() { Size: 67890, }, nil) - paginationFn := s.nDCStateRebuilder.getPaginationFn(firstEventID, nextEventID, branchToken) + paginationFn := s.nDCStateRebuilder.getPaginationFn(context.Background(), firstEventID, nextEventID, branchToken) iter := collection.NewPagingIterator(paginationFn) var result []*HistoryBlobsPaginationItem @@ -298,7 +298,7 @@ func (s *nDCStateRebuilderSuite) TestRebuild() { historySize1 := 12345 historySize2 := 67890 shardID := s.mockShard.GetShardID() - s.mockExecutionManager.EXPECT().ReadHistoryBranchByBatch(&persistence.ReadHistoryBranchRequest{ + s.mockExecutionManager.EXPECT().ReadHistoryBranchByBatch(gomock.Any(), &persistence.ReadHistoryBranchRequest{ BranchToken: branchToken, MinEventID: firstEventID, MaxEventID: nextEventID, @@ -312,7 +312,7 @@ func (s *nDCStateRebuilderSuite) TestRebuild() { Size: historySize1, }, nil) expectedLastFirstTransactionID := int64(20) - s.mockExecutionManager.EXPECT().ReadHistoryBranchByBatch(&persistence.ReadHistoryBranchRequest{ + s.mockExecutionManager.EXPECT().ReadHistoryBranchByBatch(gomock.Any(), &persistence.ReadHistoryBranchRequest{ BranchToken: branchToken, MinEventID: firstEventID, MaxEventID: nextEventID, @@ -338,7 +338,7 @@ func (s *nDCStateRebuilderSuite) TestRebuild() { }, 1234, ), nil).AnyTimes() - s.mockTaskRefresher.EXPECT().RefreshTasks(s.now, gomock.Any()).Return(nil) + s.mockTaskRefresher.EXPECT().RefreshTasks(gomock.Any(), s.now, gomock.Any()).Return(nil) rebuildMutableState, rebuiltHistorySize, err := s.nDCStateRebuilder.rebuild( context.Background(), diff --git a/service/history/nDCTaskUtil.go b/service/history/nDCTaskUtil.go index 8c44eeea1a7..38013cfcc16 100644 --- a/service/history/nDCTaskUtil.go +++ b/service/history/nDCTaskUtil.go @@ -25,6 +25,7 @@ package history import ( + "context" "time" persistencespb "go.temporal.io/server/api/persistence/v1" @@ -71,13 +72,15 @@ func VerifyTaskVersion( // load mutable state, if mutable state's next event ID <= task ID, will attempt to refresh // if still mutable state's next event ID <= task ID, will return nil, nil func loadMutableStateForTransferTask( - context workflow.Context, + ctx context.Context, + wfContext workflow.Context, transferTask tasks.Task, metricsClient metrics.Client, logger log.Logger, ) (workflow.MutableState, error) { return LoadMutableStateForTask( - context, + ctx, + wfContext, transferTask, getTransferTaskEventIDAndRetryable, metricsClient.Scope(metrics.TransferQueueProcessorScope), @@ -88,13 +91,15 @@ func loadMutableStateForTransferTask( // load mutable state, if mutable state's next event ID <= task ID, will attempt to refresh // if still mutable state's next event ID <= task ID, will return nil, nil func loadMutableStateForTimerTask( - context workflow.Context, + ctx context.Context, + wfContext workflow.Context, timerTask tasks.Task, metricsClient metrics.Client, logger log.Logger, ) (workflow.MutableState, error) { return LoadMutableStateForTask( - context, + ctx, + wfContext, timerTask, getTimerTaskEventIDAndRetryable, metricsClient.Scope(metrics.TimerQueueProcessorScope), @@ -103,14 +108,15 @@ func loadMutableStateForTimerTask( } func LoadMutableStateForTask( - context workflow.Context, + ctx context.Context, + wfContext workflow.Context, task tasks.Task, taskEventIDAndRetryable func(task tasks.Task, executionInfo *persistencespb.WorkflowExecutionInfo) (int64, bool), scope metrics.Scope, logger log.Logger, ) (workflow.MutableState, error) { - mutableState, err := context.LoadWorkflowExecution() + mutableState, err := wfContext.LoadWorkflowExecution(ctx) if err != nil { return nil, err } @@ -124,9 +130,9 @@ func LoadMutableStateForTask( } scope.IncCounter(metrics.StaleMutableStateCounter) - context.Clear() + wfContext.Clear() - mutableState, err = context.LoadWorkflowExecution() + mutableState, err = wfContext.LoadWorkflowExecution(ctx) if err != nil { return nil, err } diff --git a/service/history/nDCTransactionMgr.go b/service/history/nDCTransactionMgr.go index 7bf3afd683d..d398844d7d1 100644 --- a/service/history/nDCTransactionMgr.go +++ b/service/history/nDCTransactionMgr.go @@ -250,6 +250,7 @@ func (r *nDCTransactionMgrImpl) backfillWorkflow( }() if _, err := targetWorkflow.getContext().PersistWorkflowEvents( + ctx, targetWorkflowEvents, ); err != nil { return err @@ -265,6 +266,7 @@ func (r *nDCTransactionMgrImpl) backfillWorkflow( } return targetWorkflow.getContext().UpdateWorkflowExecutionWithNew( + ctx, now, updateMode, nil, @@ -382,13 +384,14 @@ func (r *nDCTransactionMgrImpl) backfillWorkflowEventsReapply( } func (r *nDCTransactionMgrImpl) checkWorkflowExists( - _ context.Context, + ctx context.Context, namespaceID namespace.ID, workflowID string, runID string, ) (bool, error) { _, err := r.shard.GetExecutionManager().GetWorkflowExecution( + ctx, &persistence.GetWorkflowExecutionRequest{ ShardID: r.shard.GetShardID(), NamespaceID: namespaceID.String(), @@ -408,12 +411,13 @@ func (r *nDCTransactionMgrImpl) checkWorkflowExists( } func (r *nDCTransactionMgrImpl) getCurrentWorkflowRunID( - _ context.Context, + ctx context.Context, namespaceID namespace.ID, workflowID string, ) (string, error) { resp, err := r.shard.GetExecutionManager().GetCurrentExecution( + ctx, &persistence.GetCurrentExecutionRequest{ ShardID: r.shard.GetShardID(), NamespaceID: namespaceID.String(), @@ -452,7 +456,7 @@ func (r *nDCTransactionMgrImpl) loadNDCWorkflow( return nil, err } - msBuilder, err := weContext.LoadWorkflowExecution() + msBuilder, err := weContext.LoadWorkflowExecution(ctx) if err != nil { // no matter what error happen, we need to retry release(err) diff --git a/service/history/nDCTransactionMgrForExistingWorkflow.go b/service/history/nDCTransactionMgrForExistingWorkflow.go index 6eefb2ec803..536a0ec8af1 100644 --- a/service/history/nDCTransactionMgrForExistingWorkflow.go +++ b/service/history/nDCTransactionMgrForExistingWorkflow.go @@ -227,17 +227,18 @@ func (r *nDCTransactionMgrForExistingWorkflowImpl) dispatchWorkflowUpdateAsZombi } func (r *nDCTransactionMgrForExistingWorkflowImpl) updateAsCurrent( - _ context.Context, + ctx context.Context, now time.Time, targetWorkflow nDCWorkflow, newWorkflow nDCWorkflow, ) error { if newWorkflow == nil { - return targetWorkflow.getContext().UpdateWorkflowExecutionAsPassive(now) + return targetWorkflow.getContext().UpdateWorkflowExecutionAsPassive(ctx, now) } return targetWorkflow.getContext().UpdateWorkflowExecutionWithNewAsPassive( + ctx, now, newWorkflow.getContext(), newWorkflow.getMutableState(), @@ -308,6 +309,7 @@ func (r *nDCTransactionMgrForExistingWorkflowImpl) updateAsZombie( currentWorkflow = nil return targetWorkflow.getContext().UpdateWorkflowExecutionWithNew( + ctx, now, persistence.UpdateWorkflowModeBypassCurrent, newContext, @@ -318,7 +320,7 @@ func (r *nDCTransactionMgrForExistingWorkflowImpl) updateAsZombie( } func (r *nDCTransactionMgrForExistingWorkflowImpl) suppressCurrentAndUpdateAsCurrent( - _ context.Context, + ctx context.Context, now time.Time, currentWorkflow nDCWorkflow, targetWorkflow nDCWorkflow, @@ -351,6 +353,7 @@ func (r *nDCTransactionMgrForExistingWorkflowImpl) suppressCurrentAndUpdateAsCur } return targetWorkflow.getContext().ConflictResolveWorkflowExecution( + ctx, now, persistence.ConflictResolveWorkflowModeUpdateCurrent, targetWorkflow.getMutableState(), @@ -363,7 +366,7 @@ func (r *nDCTransactionMgrForExistingWorkflowImpl) suppressCurrentAndUpdateAsCur } func (r *nDCTransactionMgrForExistingWorkflowImpl) conflictResolveAsCurrent( - _ context.Context, + ctx context.Context, now time.Time, targetWorkflow nDCWorkflow, newWorkflow nDCWorkflow, @@ -377,6 +380,7 @@ func (r *nDCTransactionMgrForExistingWorkflowImpl) conflictResolveAsCurrent( } return targetWorkflow.getContext().ConflictResolveWorkflowExecution( + ctx, now, persistence.ConflictResolveWorkflowModeUpdateCurrent, targetWorkflow.getMutableState(), @@ -449,6 +453,7 @@ func (r *nDCTransactionMgrForExistingWorkflowImpl) conflictResolveAsZombie( currentWorkflow = nil return targetWorkflow.getContext().ConflictResolveWorkflowExecution( + ctx, now, persistence.ConflictResolveWorkflowModeBypassCurrent, targetWorkflow.getMutableState(), diff --git a/service/history/nDCTransactionMgrForExistingWorkflow_test.go b/service/history/nDCTransactionMgrForExistingWorkflow_test.go index 58cf952bfb2..03fcb6264f9 100644 --- a/service/history/nDCTransactionMgrForExistingWorkflow_test.go +++ b/service/history/nDCTransactionMgrForExistingWorkflow_test.go @@ -97,6 +97,7 @@ func (s *nDCTransactionMgrForExistingWorkflowSuite) TestDispatchForExistingWorkf targetMutableState.EXPECT().IsCurrentWorkflowGuaranteed().Return(true).AnyTimes() targetContext.EXPECT().UpdateWorkflowExecutionWithNewAsPassive( + gomock.Any(), now, newContext, newMutableState, @@ -196,6 +197,7 @@ func (s *nDCTransactionMgrForExistingWorkflowSuite) TestDispatchForExistingWorkf targetWorkflow.EXPECT().revive().Return(nil) targetContext.EXPECT().ConflictResolveWorkflowExecution( + gomock.Any(), now, persistence.ConflictResolveWorkflowModeUpdateCurrent, targetMutableState, @@ -271,6 +273,7 @@ func (s *nDCTransactionMgrForExistingWorkflowSuite) TestDispatchForExistingWorkf targetWorkflow.EXPECT().revive().Return(nil) targetContext.EXPECT().ConflictResolveWorkflowExecution( + gomock.Any(), now, persistence.ConflictResolveWorkflowModeUpdateCurrent, targetMutableState, @@ -348,6 +351,7 @@ func (s *nDCTransactionMgrForExistingWorkflowSuite) TestDispatchForExistingWorkf newWorkflow.EXPECT().suppressBy(currentWorkflow).Return(workflow.TransactionPolicyPassive, nil) targetContext.EXPECT().UpdateWorkflowExecutionWithNew( + gomock.Any(), now, persistence.UpdateWorkflowModeBypassCurrent, newContext, @@ -423,6 +427,7 @@ func (s *nDCTransactionMgrForExistingWorkflowSuite) TestDispatchForExistingWorkf newWorkflow.EXPECT().suppressBy(currentWorkflow).Return(workflow.TransactionPolicyPassive, nil) targetContext.EXPECT().UpdateWorkflowExecutionWithNew( + gomock.Any(), now, persistence.UpdateWorkflowModeBypassCurrent, (workflow.Context)(nil), @@ -477,6 +482,7 @@ func (s *nDCTransactionMgrForExistingWorkflowSuite) TestDispatchForExistingWorkf s.mockTransactionMgr.EXPECT().getCurrentWorkflowRunID(ctx, namespaceID, workflowID).Return(targetRunID, nil) targetContext.EXPECT().ConflictResolveWorkflowExecution( + gomock.Any(), now, persistence.ConflictResolveWorkflowModeUpdateCurrent, targetMutableState, @@ -550,6 +556,7 @@ func (s *nDCTransactionMgrForExistingWorkflowSuite) TestDispatchForExistingWorkf targetWorkflow.EXPECT().revive().Return(nil) targetContext.EXPECT().ConflictResolveWorkflowExecution( + gomock.Any(), now, persistence.ConflictResolveWorkflowModeUpdateCurrent, targetMutableState, @@ -626,6 +633,7 @@ func (s *nDCTransactionMgrForExistingWorkflowSuite) TestDispatchForExistingWorkf newWorkflow.EXPECT().suppressBy(currentWorkflow).Return(workflow.TransactionPolicyPassive, nil) targetContext.EXPECT().ConflictResolveWorkflowExecution( + gomock.Any(), now, persistence.ConflictResolveWorkflowModeBypassCurrent, targetMutableState, @@ -702,6 +710,7 @@ func (s *nDCTransactionMgrForExistingWorkflowSuite) TestDispatchForExistingWorkf newWorkflow.EXPECT().suppressBy(currentWorkflow).Return(workflow.TransactionPolicyPassive, nil) targetContext.EXPECT().ConflictResolveWorkflowExecution( + gomock.Any(), now, persistence.ConflictResolveWorkflowModeBypassCurrent, targetMutableState, diff --git a/service/history/nDCTransactionMgrForNewWorkflow.go b/service/history/nDCTransactionMgrForNewWorkflow.go index e64f123501a..f85e24d385b 100644 --- a/service/history/nDCTransactionMgrForNewWorkflow.go +++ b/service/history/nDCTransactionMgrForNewWorkflow.go @@ -150,7 +150,7 @@ func (r *nDCTransactionMgrForNewWorkflowImpl) dispatchForNewWorkflow( } func (r *nDCTransactionMgrForNewWorkflowImpl) createAsCurrent( - _ context.Context, + ctx context.Context, now time.Time, currentWorkflow nDCWorkflow, targetWorkflow nDCWorkflow, @@ -177,6 +177,7 @@ func (r *nDCTransactionMgrForNewWorkflowImpl) createAsCurrent( return err } return targetWorkflow.getContext().CreateWorkflowExecution( + ctx, now, createMode, prevRunID, @@ -192,6 +193,7 @@ func (r *nDCTransactionMgrForNewWorkflowImpl) createAsCurrent( prevRunID := "" prevLastWriteVersion := int64(0) return targetWorkflow.getContext().CreateWorkflowExecution( + ctx, now, createMode, prevRunID, @@ -203,7 +205,7 @@ func (r *nDCTransactionMgrForNewWorkflowImpl) createAsCurrent( } func (r *nDCTransactionMgrForNewWorkflowImpl) createAsZombie( - _ context.Context, + ctx context.Context, now time.Time, currentWorkflow nDCWorkflow, targetWorkflow nDCWorkflow, @@ -245,6 +247,7 @@ func (r *nDCTransactionMgrForNewWorkflowImpl) createAsZombie( prevRunID := "" prevLastWriteVersion := int64(0) err = targetWorkflow.getContext().CreateWorkflowExecution( + ctx, now, createMode, prevRunID, @@ -265,7 +268,7 @@ func (r *nDCTransactionMgrForNewWorkflowImpl) createAsZombie( } func (r *nDCTransactionMgrForNewWorkflowImpl) suppressCurrentAndCreateAsCurrent( - _ context.Context, + ctx context.Context, now time.Time, currentWorkflow nDCWorkflow, targetWorkflow nDCWorkflow, @@ -282,6 +285,7 @@ func (r *nDCTransactionMgrForNewWorkflowImpl) suppressCurrentAndCreateAsCurrent( } return currentWorkflow.getContext().UpdateWorkflowExecutionWithNew( + ctx, now, persistence.UpdateWorkflowModeUpdateCurrent, targetWorkflow.getContext(), @@ -351,10 +355,3 @@ func (r *nDCTransactionMgrForNewWorkflowImpl) cleanupTransaction( targetWorkflow.getReleaseFn()(err) } } - -func (r *nDCTransactionMgrForNewWorkflowImpl) persistNewNDCWorkflowEvents( - targetNewWorkflow nDCWorkflow, - targetNewWorkflowEvents *persistence.WorkflowEvents, -) (int64, error) { - return targetNewWorkflow.getContext().PersistWorkflowEvents(targetNewWorkflowEvents) -} diff --git a/service/history/nDCTransactionMgrForNewWorkflow_test.go b/service/history/nDCTransactionMgrForNewWorkflow_test.go index 268a2b98b5b..43d508fc4dc 100644 --- a/service/history/nDCTransactionMgrForNewWorkflow_test.go +++ b/service/history/nDCTransactionMgrForNewWorkflow_test.go @@ -137,6 +137,7 @@ func (s *nDCTransactionMgrForNewWorkflowSuite) TestDispatchForNewWorkflow_BrandN ).Return("", nil) weContext.EXPECT().CreateWorkflowExecution( + gomock.Any(), now, persistence.CreateWorkflowModeBrandNew, "", @@ -191,6 +192,7 @@ func (s *nDCTransactionMgrForNewWorkflowSuite) TestDispatchForNewWorkflow_BrandN ).Return("", nil) weContext.EXPECT().CreateWorkflowExecution( + gomock.Any(), now, persistence.CreateWorkflowModeBrandNew, "", @@ -264,6 +266,7 @@ func (s *nDCTransactionMgrForNewWorkflowSuite) TestDispatchForNewWorkflow_Create currentWorkflow.EXPECT().getVectorClock().Return(currentLastWriteVersion, int64(0), nil) targetContext.EXPECT().CreateWorkflowExecution( + gomock.Any(), now, persistence.CreateWorkflowModeWorkflowIDReuse, currentRunID, @@ -338,6 +341,7 @@ func (s *nDCTransactionMgrForNewWorkflowSuite) TestDispatchForNewWorkflow_Create currentWorkflow.EXPECT().getVectorClock().Return(currentLastWriteVersion, int64(0), nil) targetContext.EXPECT().CreateWorkflowExecution( + gomock.Any(), now, persistence.CreateWorkflowModeWorkflowIDReuse, currentRunID, @@ -406,6 +410,7 @@ func (s *nDCTransactionMgrForNewWorkflowSuite) TestDispatchForNewWorkflow_Create targetWorkflow.EXPECT().suppressBy(currentWorkflow).Return(workflow.TransactionPolicyPassive, nil) targetContext.EXPECT().CreateWorkflowExecution( + gomock.Any(), now, persistence.CreateWorkflowModeZombie, "", @@ -475,6 +480,7 @@ func (s *nDCTransactionMgrForNewWorkflowSuite) TestDispatchForNewWorkflow_Create targetWorkflow.EXPECT().suppressBy(currentWorkflow).Return(workflow.TransactionPolicyPassive, nil) targetContext.EXPECT().CreateWorkflowExecution( + gomock.Any(), now, persistence.CreateWorkflowModeZombie, "", @@ -544,6 +550,7 @@ func (s *nDCTransactionMgrForNewWorkflowSuite) TestDispatchForNewWorkflow_Create targetWorkflow.EXPECT().suppressBy(currentWorkflow).Return(workflow.TransactionPolicyPassive, nil) targetContext.EXPECT().CreateWorkflowExecution( + gomock.Any(), now, persistence.CreateWorkflowModeZombie, "", @@ -613,6 +620,7 @@ func (s *nDCTransactionMgrForNewWorkflowSuite) TestDispatchForNewWorkflow_Create targetWorkflow.EXPECT().suppressBy(currentWorkflow).Return(workflow.TransactionPolicyPassive, nil) targetContext.EXPECT().CreateWorkflowExecution( + gomock.Any(), now, persistence.CreateWorkflowModeZombie, "", @@ -675,6 +683,7 @@ func (s *nDCTransactionMgrForNewWorkflowSuite) TestDispatchForNewWorkflow_Suppre targetWorkflow.EXPECT().revive().Return(nil) currentContext.EXPECT().UpdateWorkflowExecutionWithNew( + gomock.Any(), now, persistence.UpdateWorkflowModeUpdateCurrent, targetContext, diff --git a/service/history/nDCTransactionMgr_test.go b/service/history/nDCTransactionMgr_test.go index 7872379ae4a..d6533556756 100644 --- a/service/history/nDCTransactionMgr_test.go +++ b/service/history/nDCTransactionMgr_test.go @@ -169,9 +169,9 @@ func (s *nDCTransactionMgrSuite) TestBackfillWorkflow_CurrentWorkflow_Active_Ope mutableState.EXPECT().IsWorkflowExecutionRunning().Return(true).AnyTimes() mutableState.EXPECT().GetNamespaceEntry().Return(s.namespaceEntry).AnyTimes() mutableState.EXPECT().GetExecutionState().Return(&persistencespb.WorkflowExecutionState{RunId: runID}) - weContext.EXPECT().PersistWorkflowEvents(workflowEvents).Return(int64(0), nil) + weContext.EXPECT().PersistWorkflowEvents(gomock.Any(), workflowEvents).Return(int64(0), nil) weContext.EXPECT().UpdateWorkflowExecutionWithNew( - now, persistence.UpdateWorkflowModeUpdateCurrent, nil, nil, workflow.TransactionPolicyActive, (*workflow.TransactionPolicy)(nil), + gomock.Any(), now, persistence.UpdateWorkflowModeUpdateCurrent, nil, nil, workflow.TransactionPolicyActive, (*workflow.TransactionPolicy)(nil), ).Return(nil) err := s.transactionMgr.backfillWorkflow(ctx, now, targetWorkflow, workflowEvents) s.NoError(err) @@ -240,15 +240,15 @@ func (s *nDCTransactionMgrSuite) TestBackfillWorkflow_CurrentWorkflow_Active_Clo enumspb.RESET_REAPPLY_TYPE_SIGNAL, ).Return(nil) - s.mockExecutionMgr.EXPECT().GetCurrentExecution(&persistence.GetCurrentExecutionRequest{ + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), &persistence.GetCurrentExecutionRequest{ ShardID: s.mockShard.GetShardID(), NamespaceID: namespaceID.String(), WorkflowID: workflowID, }).Return(&persistence.GetCurrentExecutionResponse{RunID: runID}, nil) - weContext.EXPECT().PersistWorkflowEvents(workflowEvents).Return(int64(0), nil) + weContext.EXPECT().PersistWorkflowEvents(gomock.Any(), workflowEvents).Return(int64(0), nil) weContext.EXPECT().UpdateWorkflowExecutionWithNew( - now, persistence.UpdateWorkflowModeBypassCurrent, nil, nil, workflow.TransactionPolicyPassive, (*workflow.TransactionPolicy)(nil), + gomock.Any(), now, persistence.UpdateWorkflowModeBypassCurrent, nil, nil, workflow.TransactionPolicyPassive, (*workflow.TransactionPolicy)(nil), ).Return(nil) err := s.transactionMgr.backfillWorkflow(ctx, now, targetWorkflow, workflowEvents) @@ -281,9 +281,9 @@ func (s *nDCTransactionMgrSuite) TestBackfillWorkflow_CurrentWorkflow_Passive_Op mutableState.EXPECT().IsWorkflowExecutionRunning().Return(true).AnyTimes() mutableState.EXPECT().GetNamespaceEntry().Return(s.namespaceEntry).AnyTimes() weContext.EXPECT().ReapplyEvents([]*persistence.WorkflowEvents{workflowEvents}) - weContext.EXPECT().PersistWorkflowEvents(workflowEvents).Return(int64(0), nil) + weContext.EXPECT().PersistWorkflowEvents(gomock.Any(), workflowEvents).Return(int64(0), nil) weContext.EXPECT().UpdateWorkflowExecutionWithNew( - now, persistence.UpdateWorkflowModeUpdateCurrent, nil, nil, workflow.TransactionPolicyPassive, (*workflow.TransactionPolicy)(nil), + gomock.Any(), now, persistence.UpdateWorkflowModeUpdateCurrent, nil, nil, workflow.TransactionPolicyPassive, (*workflow.TransactionPolicy)(nil), ).Return(nil) err := s.transactionMgr.backfillWorkflow(ctx, now, targetWorkflow, workflowEvents) s.NoError(err) @@ -325,15 +325,15 @@ func (s *nDCTransactionMgrSuite) TestBackfillWorkflow_CurrentWorkflow_Passive_Cl RunId: runID, }).AnyTimes() - s.mockExecutionMgr.EXPECT().GetCurrentExecution(&persistence.GetCurrentExecutionRequest{ + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), &persistence.GetCurrentExecutionRequest{ ShardID: s.mockShard.GetShardID(), NamespaceID: namespaceID.String(), WorkflowID: workflowID, }).Return(&persistence.GetCurrentExecutionResponse{RunID: runID}, nil) weContext.EXPECT().ReapplyEvents([]*persistence.WorkflowEvents{workflowEvents}) - weContext.EXPECT().PersistWorkflowEvents(workflowEvents).Return(int64(0), nil) + weContext.EXPECT().PersistWorkflowEvents(gomock.Any(), workflowEvents).Return(int64(0), nil) weContext.EXPECT().UpdateWorkflowExecutionWithNew( - now, persistence.UpdateWorkflowModeUpdateCurrent, nil, nil, workflow.TransactionPolicyPassive, (*workflow.TransactionPolicy)(nil), + gomock.Any(), now, persistence.UpdateWorkflowModeUpdateCurrent, nil, nil, workflow.TransactionPolicyPassive, (*workflow.TransactionPolicy)(nil), ).Return(nil) err := s.transactionMgr.backfillWorkflow(ctx, now, targetWorkflow, workflowEvents) @@ -383,15 +383,15 @@ func (s *nDCTransactionMgrSuite) TestBackfillWorkflow_NotCurrentWorkflow_Active( RunId: runID, }).AnyTimes() - s.mockExecutionMgr.EXPECT().GetCurrentExecution(&persistence.GetCurrentExecutionRequest{ + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), &persistence.GetCurrentExecutionRequest{ ShardID: s.mockShard.GetShardID(), NamespaceID: namespaceID.String(), WorkflowID: workflowID, }).Return(&persistence.GetCurrentExecutionResponse{RunID: currentRunID}, nil) weContext.EXPECT().ReapplyEvents([]*persistence.WorkflowEvents{workflowEvents}) - weContext.EXPECT().PersistWorkflowEvents(workflowEvents).Return(int64(0), nil) + weContext.EXPECT().PersistWorkflowEvents(gomock.Any(), workflowEvents).Return(int64(0), nil) weContext.EXPECT().UpdateWorkflowExecutionWithNew( - now, persistence.UpdateWorkflowModeBypassCurrent, nil, nil, workflow.TransactionPolicyPassive, (*workflow.TransactionPolicy)(nil), + gomock.Any(), now, persistence.UpdateWorkflowModeBypassCurrent, nil, nil, workflow.TransactionPolicyPassive, (*workflow.TransactionPolicy)(nil), ).Return(nil) err := s.transactionMgr.backfillWorkflow(ctx, now, targetWorkflow, workflowEvents) s.NoError(err) @@ -440,15 +440,15 @@ func (s *nDCTransactionMgrSuite) TestBackfillWorkflow_NotCurrentWorkflow_Passive RunId: runID, }).AnyTimes() - s.mockExecutionMgr.EXPECT().GetCurrentExecution(&persistence.GetCurrentExecutionRequest{ + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), &persistence.GetCurrentExecutionRequest{ ShardID: s.mockShard.GetShardID(), NamespaceID: namespaceID.String(), WorkflowID: workflowID, }).Return(&persistence.GetCurrentExecutionResponse{RunID: currentRunID}, nil) weContext.EXPECT().ReapplyEvents([]*persistence.WorkflowEvents{workflowEvents}) - weContext.EXPECT().PersistWorkflowEvents(workflowEvents).Return(int64(0), nil) + weContext.EXPECT().PersistWorkflowEvents(gomock.Any(), workflowEvents).Return(int64(0), nil) weContext.EXPECT().UpdateWorkflowExecutionWithNew( - now, persistence.UpdateWorkflowModeBypassCurrent, nil, nil, workflow.TransactionPolicyPassive, (*workflow.TransactionPolicy)(nil), + gomock.Any(), now, persistence.UpdateWorkflowModeBypassCurrent, nil, nil, workflow.TransactionPolicyPassive, (*workflow.TransactionPolicy)(nil), ).Return(nil) err := s.transactionMgr.backfillWorkflow(ctx, now, targetWorkflow, workflowEvents) s.NoError(err) @@ -461,7 +461,7 @@ func (s *nDCTransactionMgrSuite) TestCheckWorkflowExists_DoesNotExists() { workflowID := "some random workflow ID" runID := "some random run ID" - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(&persistence.GetWorkflowExecutionRequest{ + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), &persistence.GetWorkflowExecutionRequest{ ShardID: s.mockShard.GetShardID(), NamespaceID: namespaceID.String(), WorkflowID: workflowID, @@ -479,7 +479,7 @@ func (s *nDCTransactionMgrSuite) TestCheckWorkflowExists_DoesExists() { workflowID := "some random workflow ID" runID := "some random run ID" - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(&persistence.GetWorkflowExecutionRequest{ + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), &persistence.GetWorkflowExecutionRequest{ ShardID: s.mockShard.GetShardID(), NamespaceID: namespaceID.String(), WorkflowID: workflowID, @@ -496,7 +496,7 @@ func (s *nDCTransactionMgrSuite) TestGetWorkflowCurrentRunID_Missing() { namespaceID := namespace.ID("some random namespace ID") workflowID := "some random workflow ID" - s.mockExecutionMgr.EXPECT().GetCurrentExecution(&persistence.GetCurrentExecutionRequest{ + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), &persistence.GetCurrentExecutionRequest{ ShardID: s.mockShard.GetShardID(), NamespaceID: namespaceID.String(), WorkflowID: workflowID, @@ -513,7 +513,7 @@ func (s *nDCTransactionMgrSuite) TestGetWorkflowCurrentRunID_Exists() { workflowID := "some random workflow ID" runID := "some random run ID" - s.mockExecutionMgr.EXPECT().GetCurrentExecution(&persistence.GetCurrentExecutionRequest{ + s.mockExecutionMgr.EXPECT().GetCurrentExecution(gomock.Any(), &persistence.GetCurrentExecutionRequest{ ShardID: s.mockShard.GetShardID(), NamespaceID: namespaceID.String(), WorkflowID: workflowID, diff --git a/service/history/nDCWorkflowResetter.go b/service/history/nDCWorkflowResetter.go index eead8540e35..6051875bbea 100644 --- a/service/history/nDCWorkflowResetter.go +++ b/service/history/nDCWorkflowResetter.go @@ -213,7 +213,7 @@ func (r *nDCWorkflowResetterImpl) getResetBranchToken( // fork a new history branch shardID := r.shard.GetShardID() - resp, err := r.executionMgr.ForkHistoryBranch(&persistence.ForkHistoryBranchRequest{ + resp, err := r.executionMgr.ForkHistoryBranch(ctx, &persistence.ForkHistoryBranchRequest{ ForkBranchToken: baseBranchToken, ForkNodeID: baseLastEventID + 1, Info: persistence.BuildHistoryGarbageCleanupInfo(r.namespaceID.String(), r.workflowID, r.newRunID), diff --git a/service/history/nDCWorkflowResetter_test.go b/service/history/nDCWorkflowResetter_test.go index 9ae303c936b..732a3cbe149 100644 --- a/service/history/nDCWorkflowResetter_test.go +++ b/service/history/nDCWorkflowResetter_test.go @@ -189,7 +189,7 @@ func (s *nDCWorkflowResetterSuite) TestResetWorkflow_NoError() { ).Return(s.mockRebuiltMutableState, rebuiltHistorySize, nil) shardID := s.mockShard.GetShardID() - s.mockExecManager.EXPECT().ForkHistoryBranch(&persistence.ForkHistoryBranchRequest{ + s.mockExecManager.EXPECT().ForkHistoryBranch(gomock.Any(), &persistence.ForkHistoryBranchRequest{ ForkBranchToken: branchToken, ForkNodeID: baseEventID + 1, Info: persistence.BuildHistoryGarbageCleanupInfo(s.namespaceID.String(), s.workflowID, s.newRunID), diff --git a/service/history/replicationDLQHandler.go b/service/history/replicationDLQHandler.go index d9d3774344c..1eca423cc47 100644 --- a/service/history/replicationDLQHandler.go +++ b/service/history/replicationDLQHandler.go @@ -64,6 +64,7 @@ type ( pageToken []byte, ) ([]*replicationspb.ReplicationTask, []byte, error) purgeMessages( + ctx context.Context, sourceCluster string, lastMessageID int64, ) error @@ -130,7 +131,7 @@ func (r *replicationDLQHandlerImpl) readMessagesWithAckLevel( ) ([]*replicationspb.ReplicationTask, int64, []byte, error) { ackLevel := r.shard.GetReplicatorDLQAckLevel(sourceCluster) - resp, err := r.shard.GetExecutionManager().GetReplicationTasksFromDLQ(&persistence.GetReplicationTasksFromDLQRequest{ + resp, err := r.shard.GetExecutionManager().GetReplicationTasksFromDLQ(ctx, &persistence.GetReplicationTasksFromDLQRequest{ GetHistoryTasksRequest: persistence.GetHistoryTasksRequest{ ShardID: r.shard.GetShardID(), TaskCategory: tasks.CategoryReplication, @@ -197,12 +198,14 @@ func (r *replicationDLQHandlerImpl) readMessagesWithAckLevel( } func (r *replicationDLQHandlerImpl) purgeMessages( + ctx context.Context, sourceCluster string, lastMessageID int64, ) error { ackLevel := r.shard.GetReplicatorDLQAckLevel(sourceCluster) err := r.shard.GetExecutionManager().RangeDeleteReplicationTaskFromDLQ( + ctx, &persistence.RangeDeleteReplicationTaskFromDLQRequest{ RangeCompleteHistoryTasksRequest: persistence.RangeCompleteHistoryTasksRequest{ ShardID: r.shard.GetShardID(), @@ -258,6 +261,7 @@ func (r *replicationDLQHandlerImpl) mergeMessages( } err = r.shard.GetExecutionManager().RangeDeleteReplicationTaskFromDLQ( + ctx, &persistence.RangeDeleteReplicationTaskFromDLQRequest{ RangeCompleteHistoryTasksRequest: persistence.RangeCompleteHistoryTasksRequest{ ShardID: r.shard.GetShardID(), diff --git a/service/history/replicationDLQHandler_mock.go b/service/history/replicationDLQHandler_mock.go index 39e216fd5d8..440f3b12b94 100644 --- a/service/history/replicationDLQHandler_mock.go +++ b/service/history/replicationDLQHandler_mock.go @@ -91,15 +91,15 @@ func (mr *MockreplicationDLQHandlerMockRecorder) mergeMessages(ctx, sourceCluste } // purgeMessages mocks base method. -func (m *MockreplicationDLQHandler) purgeMessages(sourceCluster string, lastMessageID int64) error { +func (m *MockreplicationDLQHandler) purgeMessages(ctx context.Context, sourceCluster string, lastMessageID int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "purgeMessages", sourceCluster, lastMessageID) + ret := m.ctrl.Call(m, "purgeMessages", ctx, sourceCluster, lastMessageID) ret0, _ := ret[0].(error) return ret0 } // purgeMessages indicates an expected call of purgeMessages. -func (mr *MockreplicationDLQHandlerMockRecorder) purgeMessages(sourceCluster, lastMessageID interface{}) *gomock.Call { +func (mr *MockreplicationDLQHandlerMockRecorder) purgeMessages(ctx, sourceCluster, lastMessageID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "purgeMessages", reflect.TypeOf((*MockreplicationDLQHandler)(nil).purgeMessages), sourceCluster, lastMessageID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "purgeMessages", reflect.TypeOf((*MockreplicationDLQHandler)(nil).purgeMessages), ctx, sourceCluster, lastMessageID) } diff --git a/service/history/replicationDLQHandler_test.go b/service/history/replicationDLQHandler_test.go index 83e3ba2185a..1c501b16659 100644 --- a/service/history/replicationDLQHandler_test.go +++ b/service/history/replicationDLQHandler_test.go @@ -172,7 +172,7 @@ func (s *replicationDLQHandlerSuite) TestReadMessages_OK() { }, } - s.executionManager.EXPECT().GetReplicationTasksFromDLQ(&persistence.GetReplicationTasksFromDLQRequest{ + s.executionManager.EXPECT().GetReplicationTasksFromDLQ(gomock.Any(), &persistence.GetReplicationTasksFromDLQRequest{ GetHistoryTasksRequest: persistence.GetHistoryTasksRequest{ ShardID: s.mockShard.GetShardID(), TaskCategory: tasks.CategoryReplication, @@ -199,6 +199,7 @@ func (s *replicationDLQHandlerSuite) TestPurgeMessages() { lastMessageID := int64(1) s.executionManager.EXPECT().RangeDeleteReplicationTaskFromDLQ( + gomock.Any(), &persistence.RangeDeleteReplicationTaskFromDLQRequest{ RangeCompleteHistoryTasksRequest: persistence.RangeCompleteHistoryTasksRequest{ ShardID: s.mockShard.GetShardID(), @@ -210,7 +211,7 @@ func (s *replicationDLQHandlerSuite) TestPurgeMessages() { }).Return(nil) s.shardManager.EXPECT().UpdateShard(gomock.Any()).Return(nil) - err := s.replicationMessageHandler.purgeMessages(s.sourceCluster, lastMessageID) + err := s.replicationMessageHandler.purgeMessages(context.Background(), s.sourceCluster, lastMessageID) s.NoError(err) } func (s *replicationDLQHandlerSuite) TestMergeMessages() { @@ -260,7 +261,7 @@ func (s *replicationDLQHandlerSuite) TestMergeMessages() { }, } - s.executionManager.EXPECT().GetReplicationTasksFromDLQ(&persistence.GetReplicationTasksFromDLQRequest{ + s.executionManager.EXPECT().GetReplicationTasksFromDLQ(gomock.Any(), &persistence.GetReplicationTasksFromDLQRequest{ GetHistoryTasksRequest: persistence.GetHistoryTasksRequest{ ShardID: s.mockShard.GetShardID(), TaskCategory: tasks.CategoryReplication, @@ -278,7 +279,7 @@ func (s *replicationDLQHandlerSuite) TestMergeMessages() { ReplicationTasks: []*replicationspb.ReplicationTask{remoteTask}, }, nil) s.taskExecutor.EXPECT().execute(remoteTask, true).Return(0, nil) - s.executionManager.EXPECT().RangeDeleteReplicationTaskFromDLQ(&persistence.RangeDeleteReplicationTaskFromDLQRequest{ + s.executionManager.EXPECT().RangeDeleteReplicationTaskFromDLQ(gomock.Any(), &persistence.RangeDeleteReplicationTaskFromDLQRequest{ RangeCompleteHistoryTasksRequest: persistence.RangeCompleteHistoryTasksRequest{ ShardID: s.mockShard.GetShardID(), TaskCategory: tasks.CategoryReplication, diff --git a/service/history/replicationTaskProcessor.go b/service/history/replicationTaskProcessor.go index 7d020ffc55b..3bc599449b3 100644 --- a/service/history/replicationTaskProcessor.go +++ b/service/history/replicationTaskProcessor.go @@ -365,7 +365,7 @@ func (p *ReplicationTaskProcessorImpl) handleReplicationDLQTask( ) // The following is guaranteed to success or retry forever until processor is shutdown. return backoff.Retry(func() error { - err := p.shard.GetExecutionManager().PutReplicationTaskToDLQ(request) + err := p.shard.GetExecutionManager().PutReplicationTaskToDLQ(context.TODO(), request) if err != nil { p.logger.Error("failed to enqueue replication task to DLQ", tag.Error(err)) p.metricsClient.IncCounter(metrics.ReplicationTaskFetcherScope, metrics.ReplicationDLQFailed) @@ -503,6 +503,7 @@ func (p *ReplicationTaskProcessorImpl) cleanupReplicationTasks() error { int(p.shard.GetQueueMaxReadLevel(tasks.CategoryReplication, currentCluster).TaskID-*minAckedTaskID), ) err := p.shard.GetExecutionManager().RangeCompleteHistoryTasks( + context.TODO(), &persistence.RangeCompleteHistoryTasksRequest{ ShardID: p.shard.GetShardID(), TaskCategory: tasks.CategoryReplication, diff --git a/service/history/replicationTaskProcessor_test.go b/service/history/replicationTaskProcessor_test.go index d5c7dadd0e3..9807e20997e 100644 --- a/service/history/replicationTaskProcessor_test.go +++ b/service/history/replicationTaskProcessor_test.go @@ -257,7 +257,7 @@ func (s *replicationTaskProcessorSuite) TestHandleReplicationDLQTask_SyncActivit }, } - s.mockExecutionManager.EXPECT().PutReplicationTaskToDLQ(request).Return(nil) + s.mockExecutionManager.EXPECT().PutReplicationTaskToDLQ(gomock.Any(), request).Return(nil) err := s.replicationTaskProcessor.handleReplicationDLQTask(request) s.NoError(err) } @@ -281,7 +281,7 @@ func (s *replicationTaskProcessorSuite) TestHandleReplicationDLQTask_History() { }, } - s.mockExecutionManager.EXPECT().PutReplicationTaskToDLQ(request).Return(nil) + s.mockExecutionManager.EXPECT().PutReplicationTaskToDLQ(gomock.Any(), request).Return(nil) err := s.replicationTaskProcessor.handleReplicationDLQTask(request) s.NoError(err) } @@ -384,6 +384,7 @@ func (s *replicationTaskProcessorSuite) TestCleanupReplicationTask_Cleanup() { s.replicationTaskProcessor.minTxAckedTaskID = ackedTaskID - 1 s.mockExecutionManager.EXPECT().RangeCompleteHistoryTasks( + gomock.Any(), &persistence.RangeCompleteHistoryTasksRequest{ ShardID: s.shardID, TaskCategory: tasks.CategoryReplication, diff --git a/service/history/replicatorQueueProcessor.go b/service/history/replicatorQueueProcessor.go index 1d77be5bb9d..06f542d61a0 100644 --- a/service/history/replicatorQueueProcessor.go +++ b/service/history/replicatorQueueProcessor.go @@ -206,7 +206,7 @@ func (p *replicatorQueueProcessorImpl) getTasks( var token []byte replicationTasks := make([]*replicationspb.ReplicationTask, 0, batchSize) for { - response, err := p.executionMgr.GetHistoryTasks(&persistence.GetHistoryTasksRequest{ + response, err := p.executionMgr.GetHistoryTasks(ctx, &persistence.GetHistoryTasksRequest{ ShardID: p.shard.GetShardID(), TaskCategory: tasks.CategoryReplication, InclusiveMinTaskKey: tasks.Key{ @@ -443,6 +443,7 @@ func (p *replicatorQueueProcessorImpl) generateHistoryReplicationTask( } eventsBlob, err := p.getEventsBlob( + ctx, taskInfo.BranchToken, taskInfo.FirstEventID, taskInfo.NextEventID, @@ -455,6 +456,7 @@ func (p *replicatorQueueProcessorImpl) generateHistoryReplicationTask( if len(taskInfo.NewRunBranchToken) != 0 { // only get the first batch newRunEventsBlob, err = p.getEventsBlob( + ctx, taskInfo.NewRunBranchToken, common.FirstEventID, common.FirstEventID+1, @@ -485,6 +487,7 @@ func (p *replicatorQueueProcessorImpl) generateHistoryReplicationTask( } func (p *replicatorQueueProcessorImpl) getEventsBlob( + ctx context.Context, branchToken []byte, firstEventID int64, nextEventID int64, @@ -502,7 +505,7 @@ func (p *replicatorQueueProcessorImpl) getEventsBlob( } for { - resp, err := p.executionMgr.ReadRawHistoryBranch(req) + resp, err := p.executionMgr.ReadRawHistoryBranch(ctx, req) if err != nil { return nil, err } @@ -572,7 +575,7 @@ func (p *replicatorQueueProcessorImpl) processReplication( } defer func() { release(retError) }() - msBuilder, err := context.LoadWorkflowExecution() + msBuilder, err := context.LoadWorkflowExecution(ctx) switch err.(type) { case nil: if !processTaskIfClosed && !msBuilder.IsWorkflowExecutionRunning() { diff --git a/service/history/replicatorQueueProcessor_test.go b/service/history/replicatorQueueProcessor_test.go index 18bae880b4e..f1af0a181be 100644 --- a/service/history/replicatorQueueProcessor_test.go +++ b/service/history/replicatorQueueProcessor_test.go @@ -227,7 +227,7 @@ func (s *replicatorQueueProcessorSuite) TestSyncActivity_WorkflowMissing() { Version: version, ScheduledID: scheduleID, } - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(&persistence.GetWorkflowExecutionRequest{ + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), &persistence.GetWorkflowExecutionRequest{ ShardID: s.mockShard.GetShardID(), NamespaceID: namespaceID.String(), WorkflowID: workflowID, diff --git a/service/history/shard/context.go b/service/history/shard/context.go index a9fcfff6a90..adbc878f7a5 100644 --- a/service/history/shard/context.go +++ b/service/history/shard/context.go @@ -25,6 +25,7 @@ package shard import ( + "context" "time" commonpb "go.temporal.io/api/common/v1" @@ -94,16 +95,16 @@ type ( UpdateNamespaceNotificationVersion(namespaceNotificationVersion int64) error UpdateHandoverNamespaces(newNamespaces []*namespace.Namespace, maxRepTaskID int64) - AppendHistoryEvents(request *persistence.AppendHistoryNodesRequest, namespaceID namespace.ID, execution commonpb.WorkflowExecution) (int, error) + AppendHistoryEvents(ctx context.Context, request *persistence.AppendHistoryNodesRequest, namespaceID namespace.ID, execution commonpb.WorkflowExecution) (int, error) - AddTasks(request *persistence.AddHistoryTasksRequest) error - CreateWorkflowExecution(request *persistence.CreateWorkflowExecutionRequest) (*persistence.CreateWorkflowExecutionResponse, error) - UpdateWorkflowExecution(request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) - ConflictResolveWorkflowExecution(request *persistence.ConflictResolveWorkflowExecutionRequest) (*persistence.ConflictResolveWorkflowExecutionResponse, error) - SetWorkflowExecution(request *persistence.SetWorkflowExecutionRequest) (*persistence.SetWorkflowExecutionResponse, error) + AddTasks(ctx context.Context, request *persistence.AddHistoryTasksRequest) error + CreateWorkflowExecution(ctx context.Context, request *persistence.CreateWorkflowExecutionRequest) (*persistence.CreateWorkflowExecutionResponse, error) + UpdateWorkflowExecution(ctx context.Context, request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) + ConflictResolveWorkflowExecution(ctx context.Context, request *persistence.ConflictResolveWorkflowExecutionRequest) (*persistence.ConflictResolveWorkflowExecutionResponse, error) + SetWorkflowExecution(ctx context.Context, request *persistence.SetWorkflowExecutionRequest) (*persistence.SetWorkflowExecutionResponse, error) // DeleteWorkflowExecution deletes workflow execution, current workflow execution, and add task to delete visibility. // If branchToken != nil, then delete history also, otherwise leave history. - DeleteWorkflowExecution(workflowKey definition.WorkflowKey, branchToken []byte, version int64, closeTime *time.Time) error + DeleteWorkflowExecution(ctx context.Context, workflowKey definition.WorkflowKey, branchToken []byte, version int64, closeTime *time.Time) error GetRemoteAdminClient(cluster string) adminservice.AdminServiceClient GetHistoryClient() historyservice.HistoryServiceClient diff --git a/service/history/shard/context_impl.go b/service/history/shard/context_impl.go index 75af1fef029..495298276c8 100644 --- a/service/history/shard/context_impl.go +++ b/service/history/shard/context_impl.go @@ -547,6 +547,7 @@ func (s *ContextImpl) UpdateHandoverNamespaces(namespaces []*namespace.Namespace } func (s *ContextImpl) AddTasks( + ctx context.Context, request *persistence.AddHistoryTasksRequest, ) error { // do not try to get namespace cache within shard lock @@ -563,10 +564,11 @@ func (s *ContextImpl) AddTasks( return err } - return s.addTasksLocked(request, namespaceEntry) + return s.addTasksLocked(ctx, request, namespaceEntry) } func (s *ContextImpl) CreateWorkflowExecution( + ctx context.Context, request *persistence.CreateWorkflowExecutionRequest, ) (*persistence.CreateWorkflowExecutionResponse, error) { // do not try to get namespace cache within shard lock @@ -596,7 +598,7 @@ func (s *ContextImpl) CreateWorkflowExecution( currentRangeID := s.getRangeIDLocked() request.RangeID = currentRangeID - resp, err := s.executionManager.CreateWorkflowExecution(request) + resp, err := s.executionManager.CreateWorkflowExecution(ctx, request) if err = s.handleErrorAndUpdateMaxReadLevelLocked(err, transferMaxReadLevel); err != nil { return nil, err } @@ -604,6 +606,7 @@ func (s *ContextImpl) CreateWorkflowExecution( } func (s *ContextImpl) UpdateWorkflowExecution( + ctx context.Context, request *persistence.UpdateWorkflowExecutionRequest, ) (*persistence.UpdateWorkflowExecutionResponse, error) { // do not try to get namespace cache within shard lock @@ -643,7 +646,7 @@ func (s *ContextImpl) UpdateWorkflowExecution( currentRangeID := s.getRangeIDLocked() request.RangeID = currentRangeID - resp, err := s.executionManager.UpdateWorkflowExecution(request) + resp, err := s.executionManager.UpdateWorkflowExecution(ctx, request) if err = s.handleErrorAndUpdateMaxReadLevelLocked(err, transferMaxReadLevel); err != nil { return nil, err } @@ -651,6 +654,7 @@ func (s *ContextImpl) UpdateWorkflowExecution( } func (s *ContextImpl) ConflictResolveWorkflowExecution( + ctx context.Context, request *persistence.ConflictResolveWorkflowExecutionRequest, ) (*persistence.ConflictResolveWorkflowExecutionResponse, error) { // do not try to get namespace cache within shard lock @@ -700,7 +704,7 @@ func (s *ContextImpl) ConflictResolveWorkflowExecution( currentRangeID := s.getRangeIDLocked() request.RangeID = currentRangeID - resp, err := s.executionManager.ConflictResolveWorkflowExecution(request) + resp, err := s.executionManager.ConflictResolveWorkflowExecution(ctx, request) if err = s.handleErrorAndUpdateMaxReadLevelLocked(err, transferMaxReadLevel); err != nil { return nil, err } @@ -708,6 +712,7 @@ func (s *ContextImpl) ConflictResolveWorkflowExecution( } func (s *ContextImpl) SetWorkflowExecution( + ctx context.Context, request *persistence.SetWorkflowExecutionRequest, ) (*persistence.SetWorkflowExecutionResponse, error) { // do not try to get namespace cache within shard lock @@ -737,7 +742,7 @@ func (s *ContextImpl) SetWorkflowExecution( currentRangeID := s.getRangeIDLocked() request.RangeID = currentRangeID - resp, err := s.executionManager.SetWorkflowExecution(request) + resp, err := s.executionManager.SetWorkflowExecution(ctx, request) if err = s.handleErrorAndUpdateMaxReadLevelLocked(err, transferMaxReadLevel); err != nil { return nil, err } @@ -745,6 +750,7 @@ func (s *ContextImpl) SetWorkflowExecution( } func (s *ContextImpl) addTasksLocked( + ctx context.Context, request *persistence.AddHistoryTasksRequest, namespaceEntry *namespace.Namespace, ) error { @@ -759,7 +765,7 @@ func (s *ContextImpl) addTasksLocked( } request.RangeID = s.getRangeIDLocked() - err := s.executionManager.AddHistoryTasks(request) + err := s.executionManager.AddHistoryTasks(ctx, request) if err = s.handleErrorAndUpdateMaxReadLevelLocked(err, transferMaxReadLevel); err != nil { return err } @@ -768,6 +774,7 @@ func (s *ContextImpl) addTasksLocked( } func (s *ContextImpl) AppendHistoryEvents( + ctx context.Context, request *persistence.AppendHistoryNodesRequest, namespaceID namespace.ID, execution commonpb.WorkflowExecution, @@ -800,7 +807,7 @@ func (s *ContextImpl) AppendHistoryEvents( tag.WorkflowHistorySizeBytes(size)) } }() - resp, err0 := s.GetExecutionManager().AppendHistoryNodes(request) + resp, err0 := s.GetExecutionManager().AppendHistoryNodes(ctx, request) if resp != nil { size = resp.Size } @@ -808,6 +815,7 @@ func (s *ContextImpl) AppendHistoryEvents( } func (s *ContextImpl) DeleteWorkflowExecution( + ctx context.Context, key definition.WorkflowKey, branchToken []byte, newTaskVersion int64, @@ -868,7 +876,7 @@ func (s *ContextImpl) DeleteWorkflowExecution( }, }, } - err = s.addTasksLocked(addTasksRequest, namespaceEntry) + err = s.addTasksLocked(ctx, addTasksRequest, namespaceEntry) if err != nil { return err } @@ -882,7 +890,7 @@ func (s *ContextImpl) DeleteWorkflowExecution( RunID: key.RunID, } op := func() error { - return s.GetExecutionManager().DeleteCurrentWorkflowExecution(delCurRequest) + return s.GetExecutionManager().DeleteCurrentWorkflowExecution(ctx, delCurRequest) } err = backoff.Retry(op, persistenceOperationRetryPolicy, common.IsPersistenceTransientError) if err != nil { @@ -897,7 +905,7 @@ func (s *ContextImpl) DeleteWorkflowExecution( RunID: key.RunID, } op = func() error { - return s.GetExecutionManager().DeleteWorkflowExecution(delRequest) + return s.GetExecutionManager().DeleteWorkflowExecution(ctx, delRequest) } err = backoff.Retry(op, persistenceOperationRetryPolicy, common.IsPersistenceTransientError) if err != nil { @@ -911,7 +919,7 @@ func (s *ContextImpl) DeleteWorkflowExecution( ShardID: s.shardID, } op := func() error { - return s.GetExecutionManager().DeleteHistoryBranch(delHistoryRequest) + return s.GetExecutionManager().DeleteHistoryBranch(ctx, delHistoryRequest) } err = backoff.Retry(op, persistenceOperationRetryPolicy, common.IsPersistenceTransientError) if err != nil { diff --git a/service/history/shard/context_mock.go b/service/history/shard/context_mock.go index fba996d4b26..41f7628c236 100644 --- a/service/history/shard/context_mock.go +++ b/service/history/shard/context_mock.go @@ -29,6 +29,7 @@ package shard import ( + context "context" reflect "reflect" time "time" @@ -75,62 +76,62 @@ func (m *MockContext) EXPECT() *MockContextMockRecorder { } // AddTasks mocks base method. -func (m *MockContext) AddTasks(request *persistence.AddHistoryTasksRequest) error { +func (m *MockContext) AddTasks(ctx context.Context, request *persistence.AddHistoryTasksRequest) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddTasks", request) + ret := m.ctrl.Call(m, "AddTasks", ctx, request) ret0, _ := ret[0].(error) return ret0 } // AddTasks indicates an expected call of AddTasks. -func (mr *MockContextMockRecorder) AddTasks(request interface{}) *gomock.Call { +func (mr *MockContextMockRecorder) AddTasks(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddTasks", reflect.TypeOf((*MockContext)(nil).AddTasks), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddTasks", reflect.TypeOf((*MockContext)(nil).AddTasks), ctx, request) } // AppendHistoryEvents mocks base method. -func (m *MockContext) AppendHistoryEvents(request *persistence.AppendHistoryNodesRequest, namespaceID namespace.ID, execution v1.WorkflowExecution) (int, error) { +func (m *MockContext) AppendHistoryEvents(ctx context.Context, request *persistence.AppendHistoryNodesRequest, namespaceID namespace.ID, execution v1.WorkflowExecution) (int, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AppendHistoryEvents", request, namespaceID, execution) + ret := m.ctrl.Call(m, "AppendHistoryEvents", ctx, request, namespaceID, execution) ret0, _ := ret[0].(int) ret1, _ := ret[1].(error) return ret0, ret1 } // AppendHistoryEvents indicates an expected call of AppendHistoryEvents. -func (mr *MockContextMockRecorder) AppendHistoryEvents(request, namespaceID, execution interface{}) *gomock.Call { +func (mr *MockContextMockRecorder) AppendHistoryEvents(ctx, request, namespaceID, execution interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendHistoryEvents", reflect.TypeOf((*MockContext)(nil).AppendHistoryEvents), request, namespaceID, execution) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendHistoryEvents", reflect.TypeOf((*MockContext)(nil).AppendHistoryEvents), ctx, request, namespaceID, execution) } // ConflictResolveWorkflowExecution mocks base method. -func (m *MockContext) ConflictResolveWorkflowExecution(request *persistence.ConflictResolveWorkflowExecutionRequest) (*persistence.ConflictResolveWorkflowExecutionResponse, error) { +func (m *MockContext) ConflictResolveWorkflowExecution(ctx context.Context, request *persistence.ConflictResolveWorkflowExecutionRequest) (*persistence.ConflictResolveWorkflowExecutionResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ConflictResolveWorkflowExecution", request) + ret := m.ctrl.Call(m, "ConflictResolveWorkflowExecution", ctx, request) ret0, _ := ret[0].(*persistence.ConflictResolveWorkflowExecutionResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // ConflictResolveWorkflowExecution indicates an expected call of ConflictResolveWorkflowExecution. -func (mr *MockContextMockRecorder) ConflictResolveWorkflowExecution(request interface{}) *gomock.Call { +func (mr *MockContextMockRecorder) ConflictResolveWorkflowExecution(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConflictResolveWorkflowExecution", reflect.TypeOf((*MockContext)(nil).ConflictResolveWorkflowExecution), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConflictResolveWorkflowExecution", reflect.TypeOf((*MockContext)(nil).ConflictResolveWorkflowExecution), ctx, request) } // CreateWorkflowExecution mocks base method. -func (m *MockContext) CreateWorkflowExecution(request *persistence.CreateWorkflowExecutionRequest) (*persistence.CreateWorkflowExecutionResponse, error) { +func (m *MockContext) CreateWorkflowExecution(ctx context.Context, request *persistence.CreateWorkflowExecutionRequest) (*persistence.CreateWorkflowExecutionResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateWorkflowExecution", request) + ret := m.ctrl.Call(m, "CreateWorkflowExecution", ctx, request) ret0, _ := ret[0].(*persistence.CreateWorkflowExecutionResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateWorkflowExecution indicates an expected call of CreateWorkflowExecution. -func (mr *MockContextMockRecorder) CreateWorkflowExecution(request interface{}) *gomock.Call { +func (mr *MockContextMockRecorder) CreateWorkflowExecution(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateWorkflowExecution", reflect.TypeOf((*MockContext)(nil).CreateWorkflowExecution), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateWorkflowExecution", reflect.TypeOf((*MockContext)(nil).CreateWorkflowExecution), ctx, request) } // DeleteFailoverLevel mocks base method. @@ -148,17 +149,17 @@ func (mr *MockContextMockRecorder) DeleteFailoverLevel(category, failoverID inte } // DeleteWorkflowExecution mocks base method. -func (m *MockContext) DeleteWorkflowExecution(workflowKey definition.WorkflowKey, branchToken []byte, version int64, closeTime *time.Time) error { +func (m *MockContext) DeleteWorkflowExecution(ctx context.Context, workflowKey definition.WorkflowKey, branchToken []byte, version int64, closeTime *time.Time) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteWorkflowExecution", workflowKey, branchToken, version, closeTime) + ret := m.ctrl.Call(m, "DeleteWorkflowExecution", ctx, workflowKey, branchToken, version, closeTime) ret0, _ := ret[0].(error) return ret0 } // DeleteWorkflowExecution indicates an expected call of DeleteWorkflowExecution. -func (mr *MockContextMockRecorder) DeleteWorkflowExecution(workflowKey, branchToken, version, closeTime interface{}) *gomock.Call { +func (mr *MockContextMockRecorder) DeleteWorkflowExecution(ctx, workflowKey, branchToken, version, closeTime interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkflowExecution", reflect.TypeOf((*MockContext)(nil).DeleteWorkflowExecution), workflowKey, branchToken, version, closeTime) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkflowExecution", reflect.TypeOf((*MockContext)(nil).DeleteWorkflowExecution), ctx, workflowKey, branchToken, version, closeTime) } // GenerateTaskID mocks base method. @@ -585,18 +586,18 @@ func (mr *MockContextMockRecorder) SetCurrentTime(cluster, currentTime interface } // SetWorkflowExecution mocks base method. -func (m *MockContext) SetWorkflowExecution(request *persistence.SetWorkflowExecutionRequest) (*persistence.SetWorkflowExecutionResponse, error) { +func (m *MockContext) SetWorkflowExecution(ctx context.Context, request *persistence.SetWorkflowExecutionRequest) (*persistence.SetWorkflowExecutionResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetWorkflowExecution", request) + ret := m.ctrl.Call(m, "SetWorkflowExecution", ctx, request) ret0, _ := ret[0].(*persistence.SetWorkflowExecutionResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // SetWorkflowExecution indicates an expected call of SetWorkflowExecution. -func (mr *MockContextMockRecorder) SetWorkflowExecution(request interface{}) *gomock.Call { +func (mr *MockContextMockRecorder) SetWorkflowExecution(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWorkflowExecution", reflect.TypeOf((*MockContext)(nil).SetWorkflowExecution), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWorkflowExecution", reflect.TypeOf((*MockContext)(nil).SetWorkflowExecution), ctx, request) } // Unload mocks base method. @@ -706,16 +707,16 @@ func (mr *MockContextMockRecorder) UpdateReplicatorDLQAckLevel(sourCluster, ackL } // UpdateWorkflowExecution mocks base method. -func (m *MockContext) UpdateWorkflowExecution(request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { +func (m *MockContext) UpdateWorkflowExecution(ctx context.Context, request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateWorkflowExecution", request) + ret := m.ctrl.Call(m, "UpdateWorkflowExecution", ctx, request) ret0, _ := ret[0].(*persistence.UpdateWorkflowExecutionResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // UpdateWorkflowExecution indicates an expected call of UpdateWorkflowExecution. -func (mr *MockContextMockRecorder) UpdateWorkflowExecution(request interface{}) *gomock.Call { +func (mr *MockContextMockRecorder) UpdateWorkflowExecution(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkflowExecution", reflect.TypeOf((*MockContext)(nil).UpdateWorkflowExecution), request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkflowExecution", reflect.TypeOf((*MockContext)(nil).UpdateWorkflowExecution), ctx, request) } diff --git a/service/history/shard/context_test.go b/service/history/shard/context_test.go index 79727faca56..f62318091d5 100644 --- a/service/history/shard/context_test.go +++ b/service/history/shard/context_test.go @@ -25,6 +25,7 @@ package shard import ( + "context" "testing" "time" @@ -141,10 +142,10 @@ func (s *contextSuite) TestAddTasks_Success() { s.mockNamespaceCache.EXPECT().GetNamespaceByID(s.namespaceID).Return(s.namespaceEntry, nil) s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName) - s.mockExecutionManager.EXPECT().AddHistoryTasks(addTasksRequest).Return(nil) + s.mockExecutionManager.EXPECT().AddHistoryTasks(gomock.Any(), addTasksRequest).Return(nil) s.mockHistoryEngine.EXPECT().NotifyNewTasks(gomock.Any(), tasks) - err := s.shardContext.AddTasks(addTasksRequest) + err := s.shardContext.AddTasks(context.Background(), addTasksRequest) s.NoError(err) } diff --git a/service/history/timerQueueAckMgr.go b/service/history/timerQueueAckMgr.go index a2426468ba5..8165a01fd50 100644 --- a/service/history/timerQueueAckMgr.go +++ b/service/history/timerQueueAckMgr.go @@ -25,6 +25,7 @@ package history import ( + "context" "math" "sort" "sync" @@ -375,7 +376,7 @@ func (t *timerQueueAckMgrImpl) getTimerTasks(minTimestamp time.Time, maxTimestam BatchSize: batchSize, NextPageToken: pageToken, } - response, err := t.executionMgr.GetHistoryTasks(request) + response, err := t.executionMgr.GetHistoryTasks(context.TODO(), request) if err != nil { return nil, nil, err } diff --git a/service/history/timerQueueAckMgr_test.go b/service/history/timerQueueAckMgr_test.go index c205e96efb6..0db0e96fbb9 100644 --- a/service/history/timerQueueAckMgr_test.go +++ b/service/history/timerQueueAckMgr_test.go @@ -212,7 +212,7 @@ func (s *timerQueueAckMgrSuite) TestGetTimerTasks_More() { NextPageToken: []byte("some random output next page token"), } - s.mockExecutionMgr.EXPECT().GetHistoryTasks(request).Return(response, nil) + s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any(), request).Return(response, nil) timers, token, err := s.timerQueueAckMgr.getTimerTasks(minTimestamp, maxTimestamp, batchSize, request.NextPageToken) s.Nil(err) @@ -256,7 +256,7 @@ func (s *timerQueueAckMgrSuite) TestGetTimerTasks_NoMore() { NextPageToken: nil, } - s.mockExecutionMgr.EXPECT().GetHistoryTasks(request).Return(response, nil) + s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any(), request).Return(response, nil) timers, token, err := s.timerQueueAckMgr.getTimerTasks(minTimestamp, maxTimestamp, batchSize, request.NextPageToken) s.Nil(err) @@ -294,8 +294,8 @@ func (s *timerQueueAckMgrSuite) TestReadTimerTasks_NoLookAhead_NoNextPage() { NextPageToken: nil, } s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() - s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any()).Return(response, nil) - s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any()).Return(&persistence.GetHistoryTasksResponse{}, nil) + s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any(), gomock.Any()).Return(response, nil) + s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any(), gomock.Any()).Return(&persistence.GetHistoryTasksResponse{}, nil) filteredTasks, nextFireTime, moreTasks, err := s.timerQueueAckMgr.readTimerTasks() s.Nil(err) s.Equal([]tasks.Task{timer}, filteredTasks) @@ -343,7 +343,7 @@ func (s *timerQueueAckMgrSuite) TestReadTimerTasks_NoLookAhead_HasNextPage() { NextPageToken: []byte("some random next page token"), } s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() - s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any()).Return(response, nil) + s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any(), gomock.Any()).Return(response, nil) readTimestamp := time.Now().UTC() // the approximate time of calling readTimerTasks filteredTasks, nextFireTime, moreTasks, err := s.timerQueueAckMgr.readTimerTasks() s.Nil(err) @@ -392,7 +392,7 @@ func (s *timerQueueAckMgrSuite) TestReadTimerTasks_HasLookAhead_NoNextPage() { NextPageToken: nil, } s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() - s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any()).Return(response, nil) + s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any(), gomock.Any()).Return(response, nil) filteredTasks, nextFireTime, moreTasks, err := s.timerQueueAckMgr.readTimerTasks() s.Nil(err) s.Equal([]tasks.Task{}, filteredTasks) @@ -437,7 +437,7 @@ func (s *timerQueueAckMgrSuite) TestReadTimerTasks_HasLookAhead_HasNextPage() { NextPageToken: []byte("some random next page token"), } s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() - s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any()).Return(response, nil) + s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any(), gomock.Any()).Return(response, nil) filteredTasks, nextFireTime, moreTasks, err := s.timerQueueAckMgr.readTimerTasks() s.Nil(err) s.Equal([]tasks.Task{}, filteredTasks) @@ -496,8 +496,8 @@ func (s *timerQueueAckMgrSuite) TestReadCompleteUpdateTimerTasks() { } s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() s.mockClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestAllClusterInfo).AnyTimes() - s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any()).Return(response, nil) - s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any()).Return(&persistence.GetHistoryTasksResponse{}, nil) + s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any(), gomock.Any()).Return(response, nil) + s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any(), gomock.Any()).Return(&persistence.GetHistoryTasksResponse{}, nil) filteredTasks, nextFireTime, moreTasks, err := s.timerQueueAckMgr.readTimerTasks() s.Nil(err) s.Equal([]tasks.Task{timer1, timer2, timer3}, filteredTasks) @@ -563,7 +563,7 @@ func (s *timerQueueAckMgrSuite) TestReadLookAheadTask() { Tasks: []tasks.Task{timer}, NextPageToken: []byte("some random next page token"), } - s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any()).Return(response, nil) + s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any(), gomock.Any()).Return(response, nil) nextFireTime, err := s.timerQueueAckMgr.readLookAheadTask() s.Nil(err) s.Equal(timer.GetVisibilityTime(), *nextFireTime) @@ -696,7 +696,7 @@ func (s *timerQueueFailoverAckMgrSuite) TestReadTimerTasks_HasNextPage() { NextPageToken: []byte("some random next page token"), } - s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any()).Return(response, nil) + s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any(), gomock.Any()).Return(response, nil) readTimestamp := time.Now().UTC() // the approximate time of calling readTimerTasks timers, lookAheadTimer, more, err := s.timerQueueFailoverAckMgr.readTimerTasks() s.Nil(err) @@ -727,7 +727,7 @@ func (s *timerQueueFailoverAckMgrSuite) TestReadTimerTasks_NoNextPage() { NextPageToken: nil, } s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() - s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any()).Return(response, nil) + s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any(), gomock.Any()).Return(response, nil) readTimestamp := time.Now().UTC() // the approximate time of calling readTimerTasks timers, lookAheadTimer, more, err := s.timerQueueFailoverAckMgr.readTimerTasks() @@ -814,7 +814,7 @@ func (s *timerQueueFailoverAckMgrSuite) TestReadCompleteUpdateTimerTasks() { } s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() s.mockClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestAllClusterInfo).AnyTimes() - s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any()).Return(response, nil) + s.mockExecutionMgr.EXPECT().GetHistoryTasks(gomock.Any(), gomock.Any()).Return(response, nil) filteredTasks, nextFireTime, moreTasks, err := s.timerQueueFailoverAckMgr.readTimerTasks() s.Nil(err) s.Equal([]tasks.Task{timer1, timer2, timer3}, filteredTasks) diff --git a/service/history/timerQueueActiveTaskExecutor.go b/service/history/timerQueueActiveTaskExecutor.go index 264f6bbc026..ed2bd06beec 100644 --- a/service/history/timerQueueActiveTaskExecutor.go +++ b/service/history/timerQueueActiveTaskExecutor.go @@ -132,7 +132,7 @@ func (t *timerQueueActiveTaskExecutor) executeUserTimerTimeoutTask( } defer func() { release(retError) }() - mutableState, err := loadMutableStateForTimerTask(weContext, task, t.metricsClient, t.logger) + mutableState, err := loadMutableStateForTimerTask(ctx, weContext, task, t.metricsClient, t.logger) if err != nil { return err } @@ -169,7 +169,7 @@ Loop: return nil } - return t.updateWorkflowExecution(weContext, mutableState, timerFired) + return t.updateWorkflowExecution(ctx, weContext, mutableState, timerFired) } func (t *timerQueueActiveTaskExecutor) executeActivityTimeoutTask( @@ -192,7 +192,7 @@ func (t *timerQueueActiveTaskExecutor) executeActivityTimeoutTask( } defer func() { release(retError) }() - mutableState, err := loadMutableStateForTimerTask(weContext, task, t.metricsClient, t.logger) + mutableState, err := loadMutableStateForTimerTask(ctx, weContext, task, t.metricsClient, t.logger) if err != nil { return err } @@ -279,7 +279,7 @@ Loop: if !updateMutableState { return nil } - return t.updateWorkflowExecution(weContext, mutableState, scheduleWorkflowTask) + return t.updateWorkflowExecution(ctx, weContext, mutableState, scheduleWorkflowTask) } func (t *timerQueueActiveTaskExecutor) executeWorkflowTaskTimeoutTask( @@ -302,7 +302,7 @@ func (t *timerQueueActiveTaskExecutor) executeWorkflowTaskTimeoutTask( } defer func() { release(retError) }() - mutableState, err := loadMutableStateForTimerTask(weContext, task, t.metricsClient, t.logger) + mutableState, err := loadMutableStateForTimerTask(ctx, weContext, task, t.metricsClient, t.logger) if err != nil { return err } @@ -357,7 +357,7 @@ func (t *timerQueueActiveTaskExecutor) executeWorkflowTaskTimeoutTask( scheduleWorkflowTask = true } - return t.updateWorkflowExecution(weContext, mutableState, scheduleWorkflowTask) + return t.updateWorkflowExecution(ctx, weContext, mutableState, scheduleWorkflowTask) } func (t *timerQueueActiveTaskExecutor) executeWorkflowBackoffTimerTask( @@ -380,7 +380,7 @@ func (t *timerQueueActiveTaskExecutor) executeWorkflowBackoffTimerTask( } defer func() { release(retError) }() - mutableState, err := loadMutableStateForTimerTask(weContext, task, t.metricsClient, t.logger) + mutableState, err := loadMutableStateForTimerTask(ctx, weContext, task, t.metricsClient, t.logger) if err != nil { return err } @@ -400,7 +400,7 @@ func (t *timerQueueActiveTaskExecutor) executeWorkflowBackoffTimerTask( } // schedule first workflow task - return t.updateWorkflowExecution(weContext, mutableState, true) + return t.updateWorkflowExecution(ctx, weContext, mutableState, true) } func (t *timerQueueActiveTaskExecutor) executeActivityRetryTimerTask( @@ -423,7 +423,7 @@ func (t *timerQueueActiveTaskExecutor) executeActivityRetryTimerTask( } defer func() { release(retError) }() - mutableState, err := loadMutableStateForTimerTask(weContext, task, t.metricsClient, t.logger) + mutableState, err := loadMutableStateForTimerTask(ctx, weContext, task, t.metricsClient, t.logger) if err != nil { return err } @@ -497,7 +497,7 @@ func (t *timerQueueActiveTaskExecutor) executeWorkflowTimeoutTask( } defer func() { release(retError) }() - mutableState, err := loadMutableStateForTimerTask(weContext, task, t.metricsClient, t.logger) + mutableState, err := loadMutableStateForTimerTask(ctx, weContext, task, t.metricsClient, t.logger) if err != nil { return err } @@ -551,10 +551,10 @@ func (t *timerQueueActiveTaskExecutor) executeWorkflowTimeoutTask( if initiator == enumspb.CONTINUE_AS_NEW_INITIATOR_UNSPECIFIED { // We apply the update to execution using optimistic concurrency. If it fails due to a conflict than reload // the history and try the operation again. - return t.updateWorkflowExecution(weContext, mutableState, false) + return t.updateWorkflowExecution(ctx, weContext, mutableState, false) } - startEvent, err := mutableState.GetStartEvent() + startEvent, err := mutableState.GetStartEvent(ctx) if err != nil { return err } @@ -569,6 +569,7 @@ func (t *timerQueueActiveTaskExecutor) executeWorkflowTimeoutTask( return err } err = workflow.SetupNewWorkflowForRetryOrCron( + ctx, mutableState, newMutableState, newRunID, @@ -585,6 +586,7 @@ func (t *timerQueueActiveTaskExecutor) executeWorkflowTimeoutTask( newExecutionInfo := newMutableState.GetExecutionInfo() newExecutionState := newMutableState.GetExecutionState() return weContext.UpdateWorkflowExecutionWithNewAsActive( + ctx, t.shard.GetTimeSource().Now(), workflow.NewContext( t.shard, @@ -608,6 +610,7 @@ func (t *timerQueueActiveTaskExecutor) getTimerSequence( } func (t *timerQueueActiveTaskExecutor) updateWorkflowExecution( + ctx context.Context, context workflow.Context, mutableState workflow.MutableState, scheduleNewWorkflowTask bool, @@ -623,7 +626,7 @@ func (t *timerQueueActiveTaskExecutor) updateWorkflowExecution( } now := t.shard.GetTimeSource().Now() - err = context.UpdateWorkflowExecutionAsActive(now) + err = context.UpdateWorkflowExecutionAsActive(ctx, now) if err != nil { if shard.IsShardOwnershipLostError(err) { // Shard is stolen. Stop timer processing to reduce duplicates diff --git a/service/history/timerQueueActiveTaskExecutor_test.go b/service/history/timerQueueActiveTaskExecutor_test.go index 8155ab0c6cc..139f5c99187 100644 --- a/service/history/timerQueueActiveTaskExecutor_test.go +++ b/service/history/timerQueueActiveTaskExecutor_test.go @@ -257,8 +257,8 @@ func (s *timerQueueActiveTaskExecutorSuite) TestProcessUserTimerTimeout_Fire() { } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) s.timeSource.Update(s.now.Add(2 * timerTimeout)) err = s.timerQueueActiveTaskExecutor.execute(context.Background(), timerTask, true) @@ -332,7 +332,7 @@ func (s *timerQueueActiveTaskExecutorSuite) TestProcessUserTimerTimeout_Noop() { mutableState.FlushBufferedEvents() persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.timeSource.Update(s.now.Add(2 * timerTimeout)) err = s.timerQueueActiveTaskExecutor.execute(context.Background(), timerTask, true) @@ -408,8 +408,8 @@ func (s *timerQueueActiveTaskExecutorSuite) TestProcessActivityTimeout_NoRetryPo } persistenceMutableState := s.createPersistenceMutableState(mutableState, scheduledEvent.GetEventId(), scheduledEvent.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) s.timeSource.Update(s.now.Add(2 * timerTimeout)) err = s.timerQueueActiveTaskExecutor.execute(context.Background(), timerTask, true) @@ -494,7 +494,7 @@ func (s *timerQueueActiveTaskExecutorSuite) TestProcessActivityTimeout_NoRetryPo mutableState.FlushBufferedEvents() persistenceMutableState := s.createPersistenceMutableState(mutableState, completeEvent.GetEventId(), completeEvent.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.timeSource.Update(s.now.Add(2 * timerTimeout)) err = s.timerQueueActiveTaskExecutor.execute(context.Background(), timerTask, true) @@ -580,8 +580,8 @@ func (s *timerQueueActiveTaskExecutorSuite) TestProcessActivityTimeout_RetryPoli } persistenceMutableState := s.createPersistenceMutableState(mutableState, scheduledEvent.GetEventId(), scheduledEvent.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) s.timeSource.Update(s.now.Add(2 * timerTimeout)) err = s.timerQueueActiveTaskExecutor.execute(context.Background(), timerTask, true) @@ -671,8 +671,8 @@ func (s *timerQueueActiveTaskExecutorSuite) TestProcessActivityTimeout_RetryPoli } persistenceMutableState := s.createPersistenceMutableState(mutableState, scheduledEvent.GetEventId(), scheduledEvent.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) s.timeSource.Update(s.now.Add(2 * timerTimeout)) err = s.timerQueueActiveTaskExecutor.execute(context.Background(), timerTask, true) @@ -764,7 +764,7 @@ func (s *timerQueueActiveTaskExecutorSuite) TestProcessActivityTimeout_RetryPoli mutableState.FlushBufferedEvents() persistenceMutableState := s.createPersistenceMutableState(mutableState, completeEvent.GetEventId(), completeEvent.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.timeSource.Update(s.now.Add(2 * timerTimeout)) err = s.timerQueueActiveTaskExecutor.execute(context.Background(), timerTask, true) @@ -851,7 +851,7 @@ func (s *timerQueueActiveTaskExecutorSuite) TestProcessActivityTimeout_Heartbeat } persistenceMutableState := s.createPersistenceMutableState(mutableState, scheduledEvent.GetEventId(), scheduledEvent.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) err = s.timerQueueActiveTaskExecutor.execute(context.Background(), timerTask, true) s.NoError(err) @@ -900,8 +900,8 @@ func (s *timerQueueActiveTaskExecutorSuite) TestWorkflowTaskTimeout_Fire() { } persistenceMutableState := s.createPersistenceMutableState(mutableState, startedEvent.GetEventId(), startedEvent.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err = s.timerQueueActiveTaskExecutor.execute(context.Background(), timerTask, true) s.NoError(err) @@ -956,7 +956,7 @@ func (s *timerQueueActiveTaskExecutorSuite) TestWorkflowTaskTimeout_Noop() { } persistenceMutableState := s.createPersistenceMutableState(mutableState, startedEvent.GetEventId(), startedEvent.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) err = s.timerQueueActiveTaskExecutor.execute(context.Background(), timerTask, true) s.NoError(err) @@ -1000,8 +1000,8 @@ func (s *timerQueueActiveTaskExecutorSuite) TestWorkflowBackoffTimer_Fire() { } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err = s.timerQueueActiveTaskExecutor.execute(context.Background(), timerTask, true) s.NoError(err) @@ -1056,7 +1056,7 @@ func (s *timerQueueActiveTaskExecutorSuite) TestWorkflowBackoffTimer_Noop() { } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) err = s.timerQueueActiveTaskExecutor.execute(context.Background(), timerTask, true) s.NoError(err) @@ -1134,7 +1134,7 @@ func (s *timerQueueActiveTaskExecutorSuite) TestActivityRetryTimer_Fire() { } persistenceMutableState := s.createPersistenceMutableState(mutableState, scheduledEvent.GetEventId(), scheduledEvent.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockMatchingClient.EXPECT().AddActivityTask( gomock.Any(), &matchingservice.AddActivityTaskRequest{ @@ -1226,7 +1226,7 @@ func (s *timerQueueActiveTaskExecutorSuite) TestActivityRetryTimer_Noop() { } persistenceMutableState := s.createPersistenceMutableState(mutableState, scheduledEvent.GetEventId(), scheduledEvent.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) err = s.timerQueueActiveTaskExecutor.execute(context.Background(), timerTask, true) s.NoError(err) @@ -1275,8 +1275,8 @@ func (s *timerQueueActiveTaskExecutorSuite) TestWorkflowTimeout_Fire() { } persistenceMutableState := s.createPersistenceMutableState(mutableState, completionEvent.GetEventId(), completionEvent.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err = s.timerQueueActiveTaskExecutor.execute(context.Background(), timerTask, true) s.NoError(err) @@ -1336,9 +1336,9 @@ func (s *timerQueueActiveTaskExecutorSuite) TestWorkflowTimeout_Retry() { } persistenceMutableState := s.createPersistenceMutableState(mutableState, completionEvent.GetEventId(), completionEvent.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) // one for current workflow, one for new - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err = s.timerQueueActiveTaskExecutor.execute(context.Background(), timerTask, true) s.NoError(err) @@ -1394,9 +1394,9 @@ func (s *timerQueueActiveTaskExecutorSuite) TestWorkflowTimeout_Cron() { } persistenceMutableState := s.createPersistenceMutableState(mutableState, completionEvent.GetEventId(), completionEvent.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) // one for current workflow, one for new - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err = s.timerQueueActiveTaskExecutor.execute(context.Background(), timerTask, true) s.NoError(err) @@ -1451,8 +1451,8 @@ func (s *timerQueueActiveTaskExecutorSuite) TestWorkflowTimeout_WorkflowExpired( } persistenceMutableState := s.createPersistenceMutableState(mutableState, completionEvent.GetEventId(), completionEvent.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) err = s.timerQueueActiveTaskExecutor.execute(context.Background(), timerTask, true) s.NoError(err) diff --git a/service/history/timerQueueProcessor.go b/service/history/timerQueueProcessor.go index f790d0c8339..adf31ec7bf8 100644 --- a/service/history/timerQueueProcessor.go +++ b/service/history/timerQueueProcessor.go @@ -316,7 +316,7 @@ func (t *timerQueueProcessorImpl) completeTimers() error { t.metricsClient.IncCounter(metrics.TimerQueueProcessorScope, metrics.TaskBatchCompleteCounter) if lowerAckLevel.FireTime.Before(upperAckLevel.FireTime) { - err := t.shard.GetExecutionManager().RangeCompleteHistoryTasks(&persistence.RangeCompleteHistoryTasksRequest{ + err := t.shard.GetExecutionManager().RangeCompleteHistoryTasks(context.TODO(), &persistence.RangeCompleteHistoryTasksRequest{ ShardID: t.shard.GetShardID(), TaskCategory: tasks.CategoryTimer, InclusiveMinTaskKey: tasks.Key{ diff --git a/service/history/timerQueueStandbyTaskExecutor.go b/service/history/timerQueueStandbyTaskExecutor.go index 57f5efcd972..695cfd531f5 100644 --- a/service/history/timerQueueStandbyTaskExecutor.go +++ b/service/history/timerQueueStandbyTaskExecutor.go @@ -253,7 +253,7 @@ func (t *timerQueueStandbyTaskExecutor) executeActivityTimeoutTask( return nil, err } - err = context.UpdateWorkflowExecutionAsPassive(now) + err = context.UpdateWorkflowExecutionAsPassive(ctx, now) return nil, err } @@ -436,7 +436,7 @@ func (t *timerQueueStandbyTaskExecutor) processTimer( } }() - mutableState, err := loadMutableStateForTimerTask(executionContext, timerTask, t.metricsClient, t.logger) + mutableState, err := loadMutableStateForTimerTask(ctx, executionContext, timerTask, t.metricsClient, t.logger) if err != nil { return err } diff --git a/service/history/timerQueueStandbyTaskExecutor_test.go b/service/history/timerQueueStandbyTaskExecutor_test.go index ee1a8316d48..2c9a95f5255 100644 --- a/service/history/timerQueueStandbyTaskExecutor_test.go +++ b/service/history/timerQueueStandbyTaskExecutor_test.go @@ -256,7 +256,7 @@ func (s *timerQueueStandbyTaskExecutorSuite) TestProcessUserTimerTimeout_Pending } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, s.now) err = s.timerQueueStandbyTaskExecutor.execute(context.Background(), timerTask, true) @@ -349,7 +349,7 @@ func (s *timerQueueStandbyTaskExecutorSuite) TestProcessUserTimerTimeout_Success mutableState.FlushBufferedEvents() persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, s.now) err = s.timerQueueStandbyTaskExecutor.execute(context.Background(), timerTask, true) @@ -418,7 +418,7 @@ func (s *timerQueueStandbyTaskExecutorSuite) TestProcessUserTimerTimeout_Multipl mutableState.FlushBufferedEvents() persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, s.now) err = s.timerQueueStandbyTaskExecutor.execute(context.Background(), timerTask, true) @@ -485,7 +485,7 @@ func (s *timerQueueStandbyTaskExecutorSuite) TestProcessActivityTimeout_Pending( } persistenceMutableState := s.createPersistenceMutableState(mutableState, scheduledEvent.GetEventId(), scheduledEvent.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, s.now) err = s.timerQueueStandbyTaskExecutor.execute(context.Background(), timerTask, true) @@ -581,7 +581,7 @@ func (s *timerQueueStandbyTaskExecutorSuite) TestProcessActivityTimeout_Success( mutableState.FlushBufferedEvents() persistenceMutableState := s.createPersistenceMutableState(mutableState, completeEvent.GetEventId(), completeEvent.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, s.now) err = s.timerQueueStandbyTaskExecutor.execute(context.Background(), timerTask, true) @@ -652,7 +652,7 @@ func (s *timerQueueStandbyTaskExecutorSuite) TestProcessActivityTimeout_Heartbea } persistenceMutableState := s.createPersistenceMutableState(mutableState, startedEvent.GetEventId(), startedEvent.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, s.now) err = s.timerQueueStandbyTaskExecutor.execute(context.Background(), timerTask, true) @@ -733,9 +733,9 @@ func (s *timerQueueStandbyTaskExecutorSuite) TestProcessActivityTimeout_Multiple mutableState.FlushBufferedEvents() persistenceMutableState := s.createPersistenceMutableState(mutableState, completeEvent1.GetEventId(), completeEvent1.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).DoAndReturn( - func(input *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, input *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { s.Equal(1, len(input.UpdateWorkflowMutation.Tasks[tasks.CategoryTimer])) s.Equal(1, len(input.UpdateWorkflowMutation.UpsertActivityInfos)) mutableState.GetExecutionInfo().LastUpdateTime = input.UpdateWorkflowMutation.ExecutionInfo.LastUpdateTime @@ -829,7 +829,7 @@ func (s *timerQueueStandbyTaskExecutorSuite) TestProcessWorkflowTaskTimeout_Pend } persistenceMutableState := s.createPersistenceMutableState(mutableState, startedEvent.GetEventId(), startedEvent.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, s.now) err = s.timerQueueStandbyTaskExecutor.execute(context.Background(), timerTask, true) @@ -935,7 +935,7 @@ func (s *timerQueueStandbyTaskExecutorSuite) TestProcessWorkflowTaskTimeout_Succ } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, s.now) err = s.timerQueueStandbyTaskExecutor.execute(context.Background(), timerTask, true) @@ -984,7 +984,7 @@ func (s *timerQueueStandbyTaskExecutorSuite) TestProcessWorkflowBackoffTimer_Pen } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, s.now) err = s.timerQueueStandbyTaskExecutor.execute(context.Background(), timerTask, true) @@ -1057,7 +1057,7 @@ func (s *timerQueueStandbyTaskExecutorSuite) TestProcessWorkflowBackoffTimer_Suc } persistenceMutableState := s.createPersistenceMutableState(mutableState, di.ScheduleID, di.Version) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, s.now) err = s.timerQueueStandbyTaskExecutor.execute(context.Background(), timerTask, true) @@ -1109,7 +1109,7 @@ func (s *timerQueueStandbyTaskExecutorSuite) TestProcessWorkflowTimeout_Pending( } persistenceMutableState := s.createPersistenceMutableState(mutableState, completionEvent.GetEventId(), completionEvent.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, s.now) err = s.timerQueueStandbyTaskExecutor.execute(context.Background(), timerTask, true) @@ -1185,7 +1185,7 @@ func (s *timerQueueStandbyTaskExecutorSuite) TestProcessWorkflowTimeout_Success( } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, s.now) err = s.timerQueueStandbyTaskExecutor.execute(context.Background(), timerTask, true) diff --git a/service/history/timerQueueTaskExecutorBase.go b/service/history/timerQueueTaskExecutorBase.go index 2dfc82982f0..b044cb0b62a 100644 --- a/service/history/timerQueueTaskExecutorBase.go +++ b/service/history/timerQueueTaskExecutorBase.go @@ -93,7 +93,7 @@ func (t *timerQueueTaskExecutorBase) executeDeleteHistoryEventTask( } defer func() { release(retError) }() - mutableState, err := loadMutableStateForTimerTask(weContext, task, t.metricsClient, t.logger) + mutableState, err := loadMutableStateForTimerTask(ctx, weContext, task, t.metricsClient, t.logger) switch err.(type) { case nil: if mutableState == nil { @@ -126,6 +126,7 @@ func (t *timerQueueTaskExecutorBase) executeDeleteHistoryEventTask( } return t.deleteManager.DeleteWorkflowExecutionByRetention( + ctx, namespace.ID(task.GetNamespaceID()), workflowExecution, weContext, @@ -146,7 +147,7 @@ func (t *timerQueueTaskExecutorBase) getNamespaceIDAndWorkflowExecution( func (t *timerQueueTaskExecutorBase) deleteHistoryBranch(branchToken []byte) error { if len(branchToken) > 0 { - return t.shard.GetExecutionManager().DeleteHistoryBranch(&persistence.DeleteHistoryBranchRequest{ + return t.shard.GetExecutionManager().DeleteHistoryBranch(context.TODO(), &persistence.DeleteHistoryBranchRequest{ ShardID: t.shard.GetShardID(), BranchToken: branchToken, }) diff --git a/service/history/timerQueueTaskExecutorBase_test.go b/service/history/timerQueueTaskExecutorBase_test.go index 0d1c3460376..e7cdfe13779 100644 --- a/service/history/timerQueueTaskExecutorBase_test.go +++ b/service/history/timerQueueTaskExecutorBase_test.go @@ -124,7 +124,7 @@ func (s *timerQueueTaskExecutorBaseSuite) Test_executeDeleteHistoryEventTask_NoE s.mockCache.EXPECT().GetOrCreateWorkflowExecution(gomock.Any(), tests.NamespaceID, we, workflow.CallerTypeTask).Return(mockWeCtx, workflow.NoopReleaseFn, nil) - mockWeCtx.EXPECT().LoadWorkflowExecution().Return(mockMutableState, nil) + mockWeCtx.EXPECT().LoadWorkflowExecution(gomock.Any()).Return(mockMutableState, nil) mockMutableState.EXPECT().GetLastWriteVersion().Return(int64(1), nil) mockMutableState.EXPECT().GetExecutionInfo().Return(&persistencespb.WorkflowExecutionInfo{}) mockMutableState.EXPECT().GetNextEventID().Return(int64(2)) @@ -132,6 +132,7 @@ func (s *timerQueueTaskExecutorBaseSuite) Test_executeDeleteHistoryEventTask_NoE s.testShardContext.Resource.ClusterMetadata.EXPECT().IsGlobalNamespaceEnabled().Return(false) s.mockDeleteManager.EXPECT().DeleteWorkflowExecutionByRetention( + gomock.Any(), tests.NamespaceID, we, mockWeCtx, @@ -166,7 +167,7 @@ func (s *timerQueueTaskExecutorBaseSuite) TestArchiveHistory_DeleteFailed() { s.mockCache.EXPECT().GetOrCreateWorkflowExecution(gomock.Any(), tests.NamespaceID, we, workflow.CallerTypeTask).Return(mockWeCtx, workflow.NoopReleaseFn, nil) - mockWeCtx.EXPECT().LoadWorkflowExecution().Return(mockMutableState, nil) + mockWeCtx.EXPECT().LoadWorkflowExecution(gomock.Any()).Return(mockMutableState, nil) mockMutableState.EXPECT().GetLastWriteVersion().Return(int64(1), nil) mockMutableState.EXPECT().GetExecutionInfo().Return(&persistencespb.WorkflowExecutionInfo{}) mockMutableState.EXPECT().GetNextEventID().Return(int64(2)) @@ -174,6 +175,7 @@ func (s *timerQueueTaskExecutorBaseSuite) TestArchiveHistory_DeleteFailed() { s.testShardContext.Resource.ClusterMetadata.EXPECT().IsGlobalNamespaceEnabled().Return(false) s.mockDeleteManager.EXPECT().DeleteWorkflowExecutionByRetention( + gomock.Any(), tests.NamespaceID, we, mockWeCtx, diff --git a/service/history/transferQueueActiveTaskExecutor.go b/service/history/transferQueueActiveTaskExecutor.go index d13ac4f9942..38d635254b6 100644 --- a/service/history/transferQueueActiveTaskExecutor.go +++ b/service/history/transferQueueActiveTaskExecutor.go @@ -157,7 +157,7 @@ func (t *transferQueueActiveTaskExecutor) processActivityTask( } defer func() { release(retError) }() - mutableState, err := loadMutableStateForTransferTask(context, task, t.metricsClient, t.logger) + mutableState, err := loadMutableStateForTransferTask(ctx, context, task, t.metricsClient, t.logger) if err != nil { return err } @@ -202,7 +202,7 @@ func (t *transferQueueActiveTaskExecutor) processWorkflowTask( } defer func() { release(retError) }() - mutableState, err := loadMutableStateForTransferTask(context, task, t.metricsClient, t.logger) + mutableState, err := loadMutableStateForTransferTask(ctx, context, task, t.metricsClient, t.logger) if err != nil { return err } @@ -271,7 +271,7 @@ func (t *transferQueueActiveTaskExecutor) processCloseExecution( } defer func() { release(retError) }() - mutableState, err := loadMutableStateForTransferTask(weContext, task, t.metricsClient, t.logger) + mutableState, err := loadMutableStateForTransferTask(ctx, weContext, task, t.metricsClient, t.logger) if err != nil { return err } @@ -291,7 +291,7 @@ func (t *transferQueueActiveTaskExecutor) processCloseExecution( executionInfo := mutableState.GetExecutionInfo() executionState := mutableState.GetExecutionState() replyToParentWorkflow := mutableState.HasParentExecution() && executionInfo.NewExecutionRunId == "" - completionEvent, err := mutableState.GetCompletionEvent() + completionEvent, err := mutableState.GetCompletionEvent(ctx) if err != nil { return err } @@ -389,7 +389,7 @@ func (t *transferQueueActiveTaskExecutor) processCancelExecution( } defer func() { release(retError) }() - mutableState, err := loadMutableStateForTransferTask(executionContext, task, t.metricsClient, t.logger) + mutableState, err := loadMutableStateForTransferTask(ctx, executionContext, task, t.metricsClient, t.logger) if err != nil { return err } @@ -415,7 +415,7 @@ func (t *transferQueueActiveTaskExecutor) processCancelExecution( // handle workflow cancel itself if task.NamespaceID == task.TargetNamespaceID && task.WorkflowID == task.TargetWorkflowID { // it does not matter if the run ID is a mismatch - err = t.requestCancelExternalExecutionFailed(task, executionContext, targetNamespace, task.TargetWorkflowID, task.TargetRunID) + err = t.requestCancelExternalExecutionFailed(ctx, task, executionContext, targetNamespace, task.TargetWorkflowID, task.TargetRunID) if _, ok := err.(*serviceerror.NotFound); ok { // this could happen if this is a duplicate processing of the task, and the execution has already completed. return nil @@ -437,6 +437,7 @@ func (t *transferQueueActiveTaskExecutor) processCancelExecution( return err } return t.requestCancelExternalExecutionFailed( + ctx, task, executionContext, targetNamespace, @@ -447,6 +448,7 @@ func (t *transferQueueActiveTaskExecutor) processCancelExecution( // Record ExternalWorkflowExecutionCancelRequested in source execution return t.requestCancelExternalExecutionCompleted( + ctx, task, executionContext, targetNamespace, @@ -474,7 +476,7 @@ func (t *transferQueueActiveTaskExecutor) processSignalExecution( } defer func() { release(retError) }() - mutableState, err := loadMutableStateForTransferTask(weContext, task, t.metricsClient, t.logger) + mutableState, err := loadMutableStateForTransferTask(ctx, weContext, task, t.metricsClient, t.logger) if err != nil { return err } @@ -494,7 +496,7 @@ func (t *transferQueueActiveTaskExecutor) processSignalExecution( return nil } - initiatedEvent, err := mutableState.GetSignalExternalInitiatedEvent(task.InitiatedID) + initiatedEvent, err := mutableState.GetSignalExternalInitiatedEvent(ctx, task.InitiatedID) if err != nil { return err } @@ -510,6 +512,7 @@ func (t *transferQueueActiveTaskExecutor) processSignalExecution( if task.NamespaceID == task.TargetNamespaceID && task.WorkflowID == task.TargetWorkflowID { // it does not matter if the run ID is a mismatch return t.signalExternalExecutionFailed( + ctx, task, weContext, targetNamespace, @@ -534,6 +537,7 @@ func (t *transferQueueActiveTaskExecutor) processSignalExecution( return err } return t.signalExternalExecutionFailed( + ctx, task, weContext, targetNamespace, @@ -544,6 +548,7 @@ func (t *transferQueueActiveTaskExecutor) processSignalExecution( } err = t.signalExternalExecutionCompleted( + ctx, task, weContext, targetNamespace, @@ -593,7 +598,7 @@ func (t *transferQueueActiveTaskExecutor) processStartChildExecution( } defer func() { release(retError) }() - mutableState, err := loadMutableStateForTransferTask(context, task, t.metricsClient, t.logger) + mutableState, err := loadMutableStateForTransferTask(ctx, context, task, t.metricsClient, t.logger) if err != nil { return err } @@ -634,7 +639,7 @@ func (t *transferQueueActiveTaskExecutor) processStartChildExecution( return nil } - initiatedEvent, err := mutableState.GetChildExecutionInitiatedEvent(task.InitiatedID) + initiatedEvent, err := mutableState.GetChildExecutionInitiatedEvent(ctx, task.InitiatedID) if err != nil { return err } @@ -666,7 +671,7 @@ func (t *transferQueueActiveTaskExecutor) processStartChildExecution( // Check to see if the error is non-transient, in which case add StartChildWorkflowExecutionFailed // event and complete transfer task by setting the err = nil if _, ok := err.(*serviceerror.WorkflowExecutionAlreadyStarted); ok { - err = t.recordStartChildExecutionFailed(task, context, attributes) + err = t.recordStartChildExecutionFailed(ctx, task, context, attributes) } return err @@ -676,7 +681,7 @@ func (t *transferQueueActiveTaskExecutor) processStartChildExecution( tag.WorkflowID(attributes.WorkflowId), tag.WorkflowRunID(childRunID)) // Child execution is successfully started, record ChildExecutionStartedEvent in parent execution - err = t.recordChildExecutionStarted(task, context, attributes, childRunID) + err = t.recordChildExecutionStarted(ctx, task, context, attributes, childRunID) if err != nil { return err } @@ -710,7 +715,7 @@ func (t *transferQueueActiveTaskExecutor) processResetWorkflow( } defer func() { currentRelease(retError) }() - currentMutableState, err := loadMutableStateForTransferTask(currentContext, task, t.metricsClient, t.logger) + currentMutableState, err := loadMutableStateForTransferTask(ctx, currentContext, task, t.metricsClient, t.logger) if err != nil { return err } @@ -728,7 +733,7 @@ func (t *transferQueueActiveTaskExecutor) processResetWorkflow( if !currentMutableState.IsWorkflowExecutionRunning() { // it means this this might not be current anymore, we need to check var resp *persistence.GetCurrentExecutionResponse - resp, err = t.shard.GetExecutionManager().GetCurrentExecution(&persistence.GetCurrentExecutionRequest{ + resp, err = t.shard.GetExecutionManager().GetCurrentExecution(ctx, &persistence.GetCurrentExecutionRequest{ ShardID: t.shard.GetShardID(), NamespaceID: task.NamespaceID, WorkflowID: task.WorkflowID, @@ -800,7 +805,7 @@ func (t *transferQueueActiveTaskExecutor) processResetWorkflow( return err } defer func() { baseRelease(retError) }() - baseMutableState, err = loadMutableStateForTransferTask(baseContext, task, t.metricsClient, t.logger) + baseMutableState, err = loadMutableStateForTransferTask(ctx, baseContext, task, t.metricsClient, t.logger) if err != nil { return err } @@ -824,13 +829,14 @@ func (t *transferQueueActiveTaskExecutor) processResetWorkflow( } func (t *transferQueueActiveTaskExecutor) recordChildExecutionStarted( + ctx context.Context, task *tasks.StartChildExecutionTask, context workflow.Context, initiatedAttributes *historypb.StartChildWorkflowExecutionInitiatedEventAttributes, runID string, ) error { - return t.updateWorkflowExecution(context, true, + return t.updateWorkflowExecution(ctx, context, true, func(mutableState workflow.MutableState) error { if !mutableState.IsWorkflowExecutionRunning() { return serviceerror.NewNotFound("Workflow execution already completed.") @@ -858,12 +864,13 @@ func (t *transferQueueActiveTaskExecutor) recordChildExecutionStarted( } func (t *transferQueueActiveTaskExecutor) recordStartChildExecutionFailed( + ctx context.Context, task *tasks.StartChildExecutionTask, context workflow.Context, initiatedAttributes *historypb.StartChildWorkflowExecutionInitiatedEventAttributes, ) error { - return t.updateWorkflowExecution(context, true, + return t.updateWorkflowExecution(ctx, context, true, func(mutableState workflow.MutableState) error { if !mutableState.IsWorkflowExecutionRunning() { return serviceerror.NewNotFound("Workflow execution already completed.") @@ -910,6 +917,7 @@ func (t *transferQueueActiveTaskExecutor) createFirstWorkflowTask( } func (t *transferQueueActiveTaskExecutor) requestCancelExternalExecutionCompleted( + ctx context.Context, task *tasks.CancelExecutionTask, context workflow.Context, targetNamespace namespace.Name, @@ -917,7 +925,7 @@ func (t *transferQueueActiveTaskExecutor) requestCancelExternalExecutionComplete targetRunID string, ) error { - err := t.updateWorkflowExecution(context, true, + err := t.updateWorkflowExecution(ctx, context, true, func(mutableState workflow.MutableState) error { if !mutableState.IsWorkflowExecutionRunning() { return &serviceerror.NotFound{Message: "Workflow execution already completed."} @@ -946,6 +954,7 @@ func (t *transferQueueActiveTaskExecutor) requestCancelExternalExecutionComplete } func (t *transferQueueActiveTaskExecutor) signalExternalExecutionCompleted( + ctx context.Context, task *tasks.SignalExecutionTask, context workflow.Context, targetNamespace namespace.Name, @@ -954,7 +963,7 @@ func (t *transferQueueActiveTaskExecutor) signalExternalExecutionCompleted( control string, ) error { - err := t.updateWorkflowExecution(context, true, + err := t.updateWorkflowExecution(ctx, context, true, func(mutableState workflow.MutableState) error { if !mutableState.IsWorkflowExecutionRunning() { return &serviceerror.NotFound{Message: "Workflow execution already completed."} @@ -984,6 +993,7 @@ func (t *transferQueueActiveTaskExecutor) signalExternalExecutionCompleted( } func (t *transferQueueActiveTaskExecutor) requestCancelExternalExecutionFailed( + ctx context.Context, task *tasks.CancelExecutionTask, context workflow.Context, targetNamespace namespace.Name, @@ -991,7 +1001,7 @@ func (t *transferQueueActiveTaskExecutor) requestCancelExternalExecutionFailed( targetRunID string, ) error { - err := t.updateWorkflowExecution(context, true, + err := t.updateWorkflowExecution(ctx, context, true, func(mutableState workflow.MutableState) error { if !mutableState.IsWorkflowExecutionRunning() { return &serviceerror.NotFound{Message: "Workflow execution already completed."} @@ -1021,6 +1031,7 @@ func (t *transferQueueActiveTaskExecutor) requestCancelExternalExecutionFailed( } func (t *transferQueueActiveTaskExecutor) signalExternalExecutionFailed( + ctx context.Context, task *tasks.SignalExecutionTask, context workflow.Context, targetNamespace namespace.Name, @@ -1029,7 +1040,7 @@ func (t *transferQueueActiveTaskExecutor) signalExternalExecutionFailed( control string, ) error { - err := t.updateWorkflowExecution(context, true, + err := t.updateWorkflowExecution(ctx, context, true, func(mutableState workflow.MutableState) error { if !mutableState.IsWorkflowExecutionRunning() { return &serviceerror.NotFound{Message: "Workflow is not running."} @@ -1060,12 +1071,13 @@ func (t *transferQueueActiveTaskExecutor) signalExternalExecutionFailed( } func (t *transferQueueActiveTaskExecutor) updateWorkflowExecution( + ctx context.Context, context workflow.Context, createWorkflowTask bool, action func(builder workflow.MutableState) error, ) error { - mutableState, err := context.LoadWorkflowExecution() + mutableState, err := context.LoadWorkflowExecution(ctx) if err != nil { return err } @@ -1082,7 +1094,7 @@ func (t *transferQueueActiveTaskExecutor) updateWorkflowExecution( } } - return context.UpdateWorkflowExecutionAsActive(t.shard.GetTimeSource().Now()) + return context.UpdateWorkflowExecutionAsActive(ctx, t.shard.GetTimeSource().Now()) } func (t *transferQueueActiveTaskExecutor) requestCancelExternalExecutionWithRetry( diff --git a/service/history/transferQueueActiveTaskExecutor_test.go b/service/history/transferQueueActiveTaskExecutor_test.go index d4ca16abd3d..c6f644759bf 100644 --- a/service/history/transferQueueActiveTaskExecutor_test.go +++ b/service/history/transferQueueActiveTaskExecutor_test.go @@ -292,7 +292,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessActivityTask_Success() } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockMatchingClient.EXPECT().AddActivityTask(gomock.Any(), s.createAddActivityTaskRequest(transferTask, ai), gomock.Any()).Return(&matchingservice.AddActivityTaskResponse{}, nil) err = s.transferQueueActiveTaskExecutor.execute(context.Background(), transferTask, true) @@ -355,7 +355,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessActivityTask_Duplicati mutableState.FlushBufferedEvents() persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) err = s.transferQueueActiveTaskExecutor.execute(context.Background(), transferTask, true) s.Nil(err) @@ -406,7 +406,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessWorkflowTask_FirstWork } persistenceMutableState := s.createPersistenceMutableState(mutableState, di.ScheduleID, di.Version) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockMatchingClient.EXPECT().AddWorkflowTask(gomock.Any(), s.createAddWorkflowTaskRequest(transferTask, mutableState), gomock.Any()).Return(&matchingservice.AddWorkflowTaskResponse{}, nil) err = s.transferQueueActiveTaskExecutor.execute(context.Background(), transferTask, true) @@ -465,7 +465,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessWorkflowTask_NonFirstW } persistenceMutableState := s.createPersistenceMutableState(mutableState, di.ScheduleID, di.Version) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockMatchingClient.EXPECT().AddWorkflowTask(gomock.Any(), s.createAddWorkflowTaskRequest(transferTask, mutableState), gomock.Any()).Return(&matchingservice.AddWorkflowTaskResponse{}, nil) err = s.transferQueueActiveTaskExecutor.execute(context.Background(), transferTask, true) @@ -527,7 +527,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessWorkflowTask_Sticky_No } persistenceMutableState := s.createPersistenceMutableState(mutableState, di.ScheduleID, di.Version) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockMatchingClient.EXPECT().AddWorkflowTask(gomock.Any(), s.createAddWorkflowTaskRequest(transferTask, mutableState), gomock.Any()).Return(&matchingservice.AddWorkflowTaskResponse{}, nil) err = s.transferQueueActiveTaskExecutor.execute(context.Background(), transferTask, true) @@ -592,7 +592,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessWorkflowTask_WorkflowT } persistenceMutableState := s.createPersistenceMutableState(mutableState, di.ScheduleID, di.Version) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockMatchingClient.EXPECT().AddWorkflowTask(gomock.Any(), s.createAddWorkflowTaskRequest(transferTask, mutableState), gomock.Any()).Return(&matchingservice.AddWorkflowTaskResponse{}, nil) err = s.transferQueueActiveTaskExecutor.execute(context.Background(), transferTask, true) @@ -644,7 +644,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessWorkflowTask_Duplicati } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) err = s.transferQueueActiveTaskExecutor.execute(context.Background(), transferTask, true) s.Nil(err) @@ -709,7 +709,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessCloseExecution_HasPare } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockHistoryClient.EXPECT().RecordChildExecutionCompleted(gomock.Any(), &historyservice.RecordChildExecutionCompletedRequest{ NamespaceId: parentNamespaceID, WorkflowExecution: parentExecution, @@ -768,7 +768,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessCloseExecution_NoParen } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockArchivalMetadata.EXPECT().GetVisibilityConfig().Return(archiver.NewArchivalConfig("enabled", dc.GetStringPropertyFn("enabled"), dc.GetBoolPropertyFn(true), "disabled", "random URI")) s.mockArchivalClient.EXPECT().Archive(gomock.Any(), gomock.Any()).Return(nil, nil) s.mockSearchAttributesProvider.EXPECT().GetSearchAttributes(gomock.Any(), false) @@ -910,7 +910,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessCloseExecution_NoParen } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockArchivalMetadata.EXPECT().GetVisibilityConfig().Return(archiver.NewDisabledArchvialConfig()) s.mockHistoryClient.EXPECT().RequestCancelWorkflowExecution(gomock.Any(), gomock.Any()).DoAndReturn( func(_ context.Context, request *historyservice.RequestCancelWorkflowExecutionRequest, _ ...grpc.CallOption) (*historyservice.RequestCancelWorkflowExecutionResponse, error) { @@ -1019,7 +1019,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessCloseExecution_NoParen } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockArchivalMetadata.EXPECT().GetVisibilityConfig().Return(archiver.NewDisabledArchvialConfig()) s.mockParentClosePolicyClient.EXPECT().SendParentClosePolicyRequest(gomock.Any()).DoAndReturn( func(request parentclosepolicy.Request) error { @@ -1116,7 +1116,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessCloseExecution_NoParen } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockArchivalMetadata.EXPECT().GetVisibilityConfig().Return(archiver.NewDisabledArchvialConfig()) err = s.transferQueueActiveTaskExecutor.execute(context.Background(), transferTask, true) @@ -1177,9 +1177,9 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessCancelExecution_Succes } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockHistoryClient.EXPECT().RequestCancelWorkflowExecution(gomock.Any(), s.createRequestCancelWorkflowExecutionRequest(s.targetNamespace, transferTask, rci)).Return(nil, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) s.mockClusterMetadata.EXPECT().ClusterNameForFailoverVersion(s.namespaceEntry.IsGlobalNamespace(), s.version).Return(cluster.TestCurrentClusterName).AnyTimes() err = s.transferQueueActiveTaskExecutor.execute(context.Background(), transferTask, true) @@ -1240,9 +1240,9 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessCancelExecution_Failur } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockHistoryClient.EXPECT().RequestCancelWorkflowExecution(gomock.Any(), s.createRequestCancelWorkflowExecutionRequest(s.targetNamespace, transferTask, rci)).Return(nil, serviceerror.NewNotFound("")) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) s.mockClusterMetadata.EXPECT().ClusterNameForFailoverVersion(gomock.Any(), s.version).Return(cluster.TestCurrentClusterName).AnyTimes() err = s.transferQueueActiveTaskExecutor.execute(context.Background(), transferTask, true) @@ -1307,7 +1307,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessCancelExecution_Duplic mutableState.FlushBufferedEvents() persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) err = s.transferQueueActiveTaskExecutor.execute(context.Background(), transferTask, true) s.Nil(err) @@ -1376,9 +1376,9 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessSignalExecution_Succes } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockHistoryClient.EXPECT().SignalWorkflowExecution(gomock.Any(), s.createSignalWorkflowExecutionRequest(s.targetNamespace, transferTask, si, attributes)).Return(nil, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) s.mockClusterMetadata.EXPECT().ClusterNameForFailoverVersion(s.namespaceEntry.IsGlobalNamespace(), s.version).Return(cluster.TestCurrentClusterName).AnyTimes() s.mockHistoryClient.EXPECT().RemoveSignalMutableState(gomock.Any(), &historyservice.RemoveSignalMutableStateRequest{ @@ -1457,9 +1457,9 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessSignalExecution_Failur } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockHistoryClient.EXPECT().SignalWorkflowExecution(gomock.Any(), s.createSignalWorkflowExecutionRequest(s.targetNamespace, transferTask, si, attributes)).Return(nil, serviceerror.NewNotFound("")) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) s.mockClusterMetadata.EXPECT().ClusterNameForFailoverVersion(s.namespaceEntry.IsGlobalNamespace(), s.version).Return(cluster.TestCurrentClusterName).AnyTimes() err = s.transferQueueActiveTaskExecutor.execute(context.Background(), transferTask, true) @@ -1532,7 +1532,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessSignalExecution_Duplic mutableState.FlushBufferedEvents() persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) err = s.transferQueueActiveTaskExecutor.execute(context.Background(), transferTask, true) s.Nil(err) @@ -1593,7 +1593,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessStartChildExecution_Su } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockHistoryClient.EXPECT().StartWorkflowExecution(gomock.Any(), s.createChildWorkflowExecutionRequest( s.namespace, s.childNamespace, @@ -1601,7 +1601,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessStartChildExecution_Su mutableState, ci, )).Return(&historyservice.StartWorkflowExecutionResponse{RunId: childRunID}, nil) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) s.mockClusterMetadata.EXPECT().ClusterNameForFailoverVersion(s.namespaceEntry.IsGlobalNamespace(), s.version).Return(cluster.TestCurrentClusterName).AnyTimes() s.mockHistoryClient.EXPECT().ScheduleWorkflowTask(gomock.Any(), &historyservice.ScheduleWorkflowTaskRequest{ NamespaceId: tests.ChildNamespaceID.String(), @@ -1682,7 +1682,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessStartChildExecution_Fa } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockHistoryClient.EXPECT().StartWorkflowExecution(gomock.Any(), s.createChildWorkflowExecutionRequest( s.namespace, s.childNamespace, @@ -1690,7 +1690,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessStartChildExecution_Fa mutableState, ci, )).Return(nil, serviceerror.NewWorkflowExecutionAlreadyStarted("msg", "", "")) - s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) + s.mockExecutionMgr.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(tests.UpdateWorkflowExecutionResponse, nil) s.mockClusterMetadata.EXPECT().ClusterNameForFailoverVersion(s.namespaceEntry.IsGlobalNamespace(), s.version).Return(cluster.TestCurrentClusterName).AnyTimes() err = s.transferQueueActiveTaskExecutor.execute(context.Background(), transferTask, true) @@ -1768,7 +1768,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessStartChildExecution_Su ci.StartedId = event.GetEventId() persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockHistoryClient.EXPECT().ScheduleWorkflowTask(gomock.Any(), &historyservice.ScheduleWorkflowTaskRequest{ NamespaceId: tests.ChildNamespaceID.String(), WorkflowExecution: &commonpb.WorkflowExecution{ @@ -1859,7 +1859,7 @@ func (s *transferQueueActiveTaskExecutorSuite) TestProcessStartChildExecution_Du mutableState.FlushBufferedEvents() persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) err = s.transferQueueActiveTaskExecutor.execute(context.Background(), transferTask, true) s.Nil(err) @@ -2028,7 +2028,7 @@ func (s *transferQueueActiveTaskExecutorSuite) createChildWorkflowExecutionReque ci *persistencespb.ChildExecutionInfo, ) *historyservice.StartWorkflowExecutionRequest { - event, err := mutableState.GetChildExecutionInitiatedEvent(task.InitiatedID) + event, err := mutableState.GetChildExecutionInitiatedEvent(context.Background(), task.InitiatedID) s.NoError(err) attributes := event.GetStartChildWorkflowExecutionInitiatedEventAttributes() execution := commonpb.WorkflowExecution{ diff --git a/service/history/transferQueueProcessor.go b/service/history/transferQueueProcessor.go index b9b13e311fc..6e6308542ce 100644 --- a/service/history/transferQueueProcessor.go +++ b/service/history/transferQueueProcessor.go @@ -312,7 +312,7 @@ func (t *transferQueueProcessorImpl) completeTransfer() error { t.metricsClient.IncCounter(metrics.TransferQueueProcessorScope, metrics.TaskBatchCompleteCounter) if lowerAckLevel < upperAckLevel { - err := t.shard.GetExecutionManager().RangeCompleteHistoryTasks(&persistence.RangeCompleteHistoryTasksRequest{ + err := t.shard.GetExecutionManager().RangeCompleteHistoryTasks(context.TODO(), &persistence.RangeCompleteHistoryTasksRequest{ ShardID: t.shard.GetShardID(), TaskCategory: tasks.CategoryTransfer, InclusiveMinTaskKey: tasks.Key{ diff --git a/service/history/transferQueueProcessorBase.go b/service/history/transferQueueProcessorBase.go index 56e21144ac4..472aa1b22f1 100644 --- a/service/history/transferQueueProcessorBase.go +++ b/service/history/transferQueueProcessorBase.go @@ -25,6 +25,8 @@ package history import ( + "context" + "go.temporal.io/server/common/log" "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/persistence" @@ -73,7 +75,7 @@ func (t *transferQueueProcessorBase) readTasks( readLevel int64, ) ([]tasks.Task, bool, error) { - response, err := t.executionManager.GetHistoryTasks(&persistence.GetHistoryTasksRequest{ + response, err := t.executionManager.GetHistoryTasks(context.TODO(), &persistence.GetHistoryTasksRequest{ ShardID: t.shard.GetShardID(), TaskCategory: tasks.CategoryTransfer, InclusiveMinTaskKey: tasks.Key{ diff --git a/service/history/transferQueueStandbyTaskExecutor.go b/service/history/transferQueueStandbyTaskExecutor.go index e5765fdb5db..e598d38d653 100644 --- a/service/history/transferQueueStandbyTaskExecutor.go +++ b/service/history/transferQueueStandbyTaskExecutor.go @@ -244,7 +244,7 @@ func (t *transferQueueStandbyTaskExecutor) processCloseExecution( return nil, nil } - completionEvent, err := mutableState.GetCompletionEvent() + completionEvent, err := mutableState.GetCompletionEvent(ctx) if err != nil { return nil, err } @@ -439,7 +439,7 @@ func (t *transferQueueStandbyTaskExecutor) processTransfer( } }() - mutableState, err := loadMutableStateForTransferTask(context, taskInfo, t.metricsClient, t.logger) + mutableState, err := loadMutableStateForTransferTask(ctx, context, taskInfo, t.metricsClient, t.logger) if err != nil || mutableState == nil { return err } diff --git a/service/history/transferQueueStandbyTaskExecutor_test.go b/service/history/transferQueueStandbyTaskExecutor_test.go index e209ee1a393..736c79ab8f2 100644 --- a/service/history/transferQueueStandbyTaskExecutor_test.go +++ b/service/history/transferQueueStandbyTaskExecutor_test.go @@ -245,7 +245,7 @@ func (s *transferQueueStandbyTaskExecutorSuite) TestProcessActivityTask_Pending( } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, now) err = s.transferQueueStandbyTaskExecutor.execute(context.Background(), transferTask, true) @@ -303,7 +303,7 @@ func (s *transferQueueStandbyTaskExecutorSuite) TestProcessActivityTask_Pending_ } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockMatchingClient.EXPECT().AddActivityTask(gomock.Any(), gomock.Any(), gomock.Any()).Return(&matchingservice.AddActivityTaskResponse{}, nil) s.mockShard.SetCurrentTime(s.clusterName, now) @@ -368,7 +368,7 @@ func (s *transferQueueStandbyTaskExecutorSuite) TestProcessActivityTask_Success( mutableState.FlushBufferedEvents() persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, now) err = s.transferQueueStandbyTaskExecutor.execute(context.Background(), transferTask, true) @@ -418,7 +418,7 @@ func (s *transferQueueStandbyTaskExecutorSuite) TestProcessWorkflowTask_Pending( } persistenceMutableState := s.createPersistenceMutableState(mutableState, di.ScheduleID, di.Version) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, now) err = s.transferQueueStandbyTaskExecutor.execute(context.Background(), transferTask, true) @@ -469,7 +469,7 @@ func (s *transferQueueStandbyTaskExecutorSuite) TestProcessWorkflowTask_Pending_ } persistenceMutableState := s.createPersistenceMutableState(mutableState, di.ScheduleID, di.Version) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockMatchingClient.EXPECT().AddWorkflowTask(gomock.Any(), gomock.Any(), gomock.Any()).Return(&matchingservice.AddWorkflowTaskResponse{}, nil) s.mockShard.SetCurrentTime(s.clusterName, now) @@ -523,7 +523,7 @@ func (s *transferQueueStandbyTaskExecutorSuite) TestProcessWorkflowTask_Success_ di.StartedID = event.GetEventId() persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, now) err = s.transferQueueStandbyTaskExecutor.execute(context.Background(), transferTask, true) @@ -581,7 +581,7 @@ func (s *transferQueueStandbyTaskExecutorSuite) TestProcessWorkflowTask_Success_ di.StartedID = event.GetEventId() persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, now) err = s.transferQueueStandbyTaskExecutor.execute(context.Background(), transferTask, true) @@ -634,7 +634,7 @@ func (s *transferQueueStandbyTaskExecutorSuite) TestProcessCloseExecution() { } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockArchivalMetadata.EXPECT().GetVisibilityConfig().Return(archiver.NewDisabledArchvialConfig()) s.mockShard.SetCurrentTime(s.clusterName, now) @@ -699,7 +699,7 @@ func (s *transferQueueStandbyTaskExecutorSuite) TestProcessCancelExecution_Pendi } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, now) err = s.transferQueueStandbyTaskExecutor.execute(context.Background(), transferTask, true) @@ -789,7 +789,7 @@ func (s *transferQueueStandbyTaskExecutorSuite) TestProcessCancelExecution_Succe mutableState.FlushBufferedEvents() persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, now) err = s.transferQueueStandbyTaskExecutor.execute(context.Background(), transferTask, true) @@ -854,7 +854,7 @@ func (s *transferQueueStandbyTaskExecutorSuite) TestProcessSignalExecution_Pendi } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, now) err = s.transferQueueStandbyTaskExecutor.execute(context.Background(), transferTask, true) @@ -946,7 +946,7 @@ func (s *transferQueueStandbyTaskExecutorSuite) TestProcessSignalExecution_Succe mutableState.FlushBufferedEvents() persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, now) err = s.transferQueueStandbyTaskExecutor.execute(context.Background(), transferTask, true) @@ -1008,7 +1008,7 @@ func (s *transferQueueStandbyTaskExecutorSuite) TestProcessStartChildExecution_P } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, now) err = s.transferQueueStandbyTaskExecutor.execute(context.Background(), transferTask, true) @@ -1099,7 +1099,7 @@ func (s *transferQueueStandbyTaskExecutorSuite) TestProcessStartChildExecution_S childInfo.StartedId = event.GetEventId() persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockShard.SetCurrentTime(s.clusterName, now) err = s.transferQueueStandbyTaskExecutor.execute(context.Background(), transferTask, true) diff --git a/service/history/transferQueueTaskExecutorBase.go b/service/history/transferQueueTaskExecutorBase.go index 944e071e3c1..5f40aca3847 100644 --- a/service/history/transferQueueTaskExecutorBase.go +++ b/service/history/transferQueueTaskExecutorBase.go @@ -237,7 +237,7 @@ func (t *transferQueueTaskExecutorBase) processDeleteExecutionTask( } defer func() { release(retError) }() - mutableState, err := loadMutableStateForTransferTask(weCtx, task, t.metricsClient, t.logger) + mutableState, err := loadMutableStateForTransferTask(ctx, weCtx, task, t.metricsClient, t.logger) if err != nil { return err } @@ -252,6 +252,7 @@ func (t *transferQueueTaskExecutorBase) processDeleteExecutionTask( } return t.workflowDeleteManager.DeleteWorkflowExecution( + ctx, namespace.ID(task.GetNamespaceID()), workflowExecution, weCtx, diff --git a/service/history/visibilityQueueProcessor.go b/service/history/visibilityQueueProcessor.go index 1ae4f27e538..9a94250dabd 100644 --- a/service/history/visibilityQueueProcessor.go +++ b/service/history/visibilityQueueProcessor.go @@ -255,7 +255,7 @@ func (t *visibilityQueueProcessorImpl) completeTask() error { t.metricsClient.IncCounter(metrics.VisibilityQueueProcessorScope, metrics.TaskBatchCompleteCounter) if lowerAckLevel < upperAckLevel { - err := t.shard.GetExecutionManager().RangeCompleteHistoryTasks(&persistence.RangeCompleteHistoryTasksRequest{ + err := t.shard.GetExecutionManager().RangeCompleteHistoryTasks(context.TODO(), &persistence.RangeCompleteHistoryTasksRequest{ ShardID: t.shard.GetShardID(), TaskCategory: tasks.CategoryVisibility, InclusiveMinTaskKey: tasks.Key{ @@ -306,7 +306,7 @@ func (t *visibilityQueueProcessorImpl) readTasks( readLevel int64, ) ([]tasks.Task, bool, error) { - response, err := t.executionManager.GetHistoryTasks(&persistence.GetHistoryTasksRequest{ + response, err := t.executionManager.GetHistoryTasks(context.TODO(), &persistence.GetHistoryTasksRequest{ ShardID: t.shard.GetShardID(), TaskCategory: tasks.CategoryVisibility, InclusiveMinTaskKey: tasks.Key{ diff --git a/service/history/visibilityQueueTaskExecutor.go b/service/history/visibilityQueueTaskExecutor.go index c1de7c861f2..88b938fdbd3 100644 --- a/service/history/visibilityQueueTaskExecutor.go +++ b/service/history/visibilityQueueTaskExecutor.go @@ -115,7 +115,7 @@ func (t *visibilityQueueTaskExecutor) processStartExecution( } defer func() { release(retError) }() - mutableState, err := weContext.LoadWorkflowExecution() + mutableState, err := weContext.LoadWorkflowExecution(ctx) if err != nil { return err } @@ -189,7 +189,7 @@ func (t *visibilityQueueTaskExecutor) processUpsertExecution( } defer func() { release(retError) }() - mutableState, err := weContext.LoadWorkflowExecution() + mutableState, err := weContext.LoadWorkflowExecution(ctx) if err != nil { return err } @@ -337,7 +337,7 @@ func (t *visibilityQueueTaskExecutor) processCloseExecution( } defer func() { release(retError) }() - mutableState, err := weContext.LoadWorkflowExecution() + mutableState, err := weContext.LoadWorkflowExecution(ctx) if err != nil { return err } @@ -356,7 +356,7 @@ func (t *visibilityQueueTaskExecutor) processCloseExecution( executionInfo := mutableState.GetExecutionInfo() executionState := mutableState.GetExecutionState() - completionEvent, err := mutableState.GetCompletionEvent() + completionEvent, err := mutableState.GetCompletionEvent(ctx) if err != nil { return err } diff --git a/service/history/visibilityQueueTaskExecutor_test.go b/service/history/visibilityQueueTaskExecutor_test.go index 5375587ed88..2122af12182 100644 --- a/service/history/visibilityQueueTaskExecutor_test.go +++ b/service/history/visibilityQueueTaskExecutor_test.go @@ -235,7 +235,7 @@ func (s *visibilityQueueTaskExecutorSuite) TestProcessCloseExecution() { } persistenceMutableState := s.createPersistenceMutableState(mutableState, event.GetEventId(), event.GetVersion()) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockVisibilityMgr.EXPECT().RecordWorkflowExecutionClosed(gomock.Any()).Return(nil) err = s.visibilityQueueTaskExecutor.execute(context.Background(), visibilityTask, true) @@ -287,7 +287,7 @@ func (s *visibilityQueueTaskExecutorSuite) TestProcessRecordWorkflowStartedTask( } persistenceMutableState := s.createPersistenceMutableState(mutableState, di.ScheduleID, di.Version) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockVisibilityMgr.EXPECT().RecordWorkflowExecutionStarted(s.createRecordWorkflowExecutionStartedRequest(s.namespace, event, visibilityTask, mutableState, backoff, taskQueueName)).Return(nil) err = s.visibilityQueueTaskExecutor.execute(context.Background(), visibilityTask, true) @@ -334,7 +334,7 @@ func (s *visibilityQueueTaskExecutorSuite) TestProcessUpsertWorkflowSearchAttrib } persistenceMutableState := s.createPersistenceMutableState(mutableState, di.ScheduleID, di.Version) - s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) + s.mockExecutionMgr.EXPECT().GetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.GetWorkflowExecutionResponse{State: persistenceMutableState}, nil) s.mockVisibilityMgr.EXPECT().UpsertWorkflowExecution(s.createUpsertWorkflowSearchAttributesRequest(s.namespace, visibilityTask, mutableState, taskQueueName)).Return(nil) err = s.visibilityQueueTaskExecutor.execute(context.Background(), visibilityTask, true) diff --git a/service/history/workflow/cache.go b/service/history/workflow/cache.go index 967bf0e2a6f..983bedfa92b 100644 --- a/service/history/workflow/cache.go +++ b/service/history/workflow/cache.go @@ -138,7 +138,7 @@ func (c *CacheImpl) GetOrCreateWorkflowExecution( caller CallerType, ) (Context, ReleaseCacheFunc, error) { - if err := c.validateWorkflowExecutionInfo(namespaceID, &execution); err != nil { + if err := c.validateWorkflowExecutionInfo(ctx, namespaceID, &execution); err != nil { return nil, nil, err } @@ -227,6 +227,7 @@ func (c *CacheImpl) makeReleaseFunc( } func (c *CacheImpl) validateWorkflowExecutionInfo( + ctx context.Context, namespaceID namespace.ID, execution *commonpb.WorkflowExecution, ) error { @@ -242,7 +243,7 @@ func (c *CacheImpl) validateWorkflowExecutionInfo( // RunID is not provided, lets try to retrieve the RunID for current active execution if execution.GetRunId() == "" { - response, err := c.getCurrentExecutionWithRetry(&persistence.GetCurrentExecutionRequest{ + response, err := c.getCurrentExecutionWithRetry(ctx, &persistence.GetCurrentExecutionRequest{ ShardID: c.shard.GetShardID(), NamespaceID: namespaceID.String(), WorkflowID: execution.GetWorkflowId(), @@ -260,13 +261,14 @@ func (c *CacheImpl) validateWorkflowExecutionInfo( } func (c *CacheImpl) getCurrentExecutionWithRetry( + ctx context.Context, request *persistence.GetCurrentExecutionRequest, ) (*persistence.GetCurrentExecutionResponse, error) { var response *persistence.GetCurrentExecutionResponse op := func() error { var err error - response, err = c.executionManager.GetCurrentExecution(request) + response, err = c.executionManager.GetCurrentExecution(ctx, request) return err } diff --git a/service/history/workflow/context.go b/service/history/workflow/context.go index 2063587418c..94173e56539 100644 --- a/service/history/workflow/context.go +++ b/service/history/workflow/context.go @@ -72,8 +72,8 @@ type ( GetWorkflowID() string GetRunID() string - LoadWorkflowExecution() (MutableState, error) - LoadExecutionStats() (*persistencespb.ExecutionStats, error) + LoadWorkflowExecution(ctx context.Context) (MutableState, error) + LoadExecutionStats(ctx context.Context) (*persistencespb.ExecutionStats, error) Clear() Lock(ctx context.Context, caller CallerType) error @@ -87,10 +87,12 @@ type ( ) error PersistWorkflowEvents( + ctx context.Context, workflowEvents *persistence.WorkflowEvents, ) (int64, error) CreateWorkflowExecution( + ctx context.Context, now time.Time, createMode persistence.CreateWorkflowMode, prevRunID string, @@ -100,6 +102,7 @@ type ( newWorkflowEvents []*persistence.WorkflowEvents, ) error ConflictResolveWorkflowExecution( + ctx context.Context, now time.Time, conflictResolveMode persistence.ConflictResolveWorkflowMode, resetMutableState MutableState, @@ -110,22 +113,27 @@ type ( currentTransactionPolicy *TransactionPolicy, ) error UpdateWorkflowExecutionAsActive( + ctx context.Context, now time.Time, ) error UpdateWorkflowExecutionWithNewAsActive( + ctx context.Context, now time.Time, newContext Context, newMutableState MutableState, ) error UpdateWorkflowExecutionAsPassive( + ctx context.Context, now time.Time, ) error UpdateWorkflowExecutionWithNewAsPassive( + ctx context.Context, now time.Time, newContext Context, newMutableState MutableState, ) error UpdateWorkflowExecutionWithNew( + ctx context.Context, now time.Time, updateMode persistence.UpdateWorkflowMode, newContext Context, @@ -134,6 +142,7 @@ type ( newWorkflowTransactionPolicy *TransactionPolicy, ) error SetWorkflowExecution( + ctx context.Context, now time.Time, ) error } @@ -247,15 +256,15 @@ func (c *ContextImpl) SetHistorySize(size int64) { c.stats.HistorySize = size } -func (c *ContextImpl) LoadExecutionStats() (*persistencespb.ExecutionStats, error) { - _, err := c.LoadWorkflowExecution() +func (c *ContextImpl) LoadExecutionStats(ctx context.Context) (*persistencespb.ExecutionStats, error) { + _, err := c.LoadWorkflowExecution(ctx) if err != nil { return nil, err } return c.stats, nil } -func (c *ContextImpl) LoadWorkflowExecution() (MutableState, error) { +func (c *ContextImpl) LoadWorkflowExecution(ctx context.Context) (MutableState, error) { namespaceEntry, err := c.shard.GetNamespaceRegistry().GetNamespaceByID(c.GetNamespaceID()) if err != nil { return nil, err @@ -273,6 +282,7 @@ func (c *ContextImpl) LoadWorkflowExecution() (MutableState, error) { } c.MutableState, err = newMutableStateBuilderFromDB( + ctx, c.shard, c.shard.GetEventsCache(), c.logger, @@ -296,6 +306,7 @@ func (c *ContextImpl) LoadWorkflowExecution() (MutableState, error) { } if err = c.UpdateWorkflowExecutionAsActive( + ctx, c.shard.GetTimeSource().Now(), ); err != nil { return nil, err @@ -313,12 +324,14 @@ func (c *ContextImpl) LoadWorkflowExecution() (MutableState, error) { } func (c *ContextImpl) PersistWorkflowEvents( + ctx context.Context, workflowEvents *persistence.WorkflowEvents, ) (int64, error) { - return PersistWorkflowEvents(c.shard, workflowEvents) + return PersistWorkflowEvents(ctx, c.shard, workflowEvents) } func (c *ContextImpl) CreateWorkflowExecution( + ctx context.Context, _ time.Time, createMode persistence.CreateWorkflowMode, prevRunID string, @@ -346,6 +359,7 @@ func (c *ContextImpl) CreateWorkflowExecution( } resp, err := createWorkflowExecutionWithRetry( + ctx, c.shard, createRequest, ) @@ -365,6 +379,7 @@ func (c *ContextImpl) CreateWorkflowExecution( } func (c *ContextImpl) ConflictResolveWorkflowExecution( + ctx context.Context, now time.Time, conflictResolveMode persistence.ConflictResolveWorkflowMode, resetMutableState MutableState, @@ -446,6 +461,7 @@ func (c *ContextImpl) ConflictResolveWorkflowExecution( } if resetWorkflowSizeDiff, newWorkflowSizeDiff, currentWorkflowSizeDiff, err := c.transaction.ConflictResolveWorkflowExecution( + ctx, conflictResolveMode, resetWorkflow, resetWorkflowEventsSeq, @@ -474,16 +490,18 @@ func (c *ContextImpl) ConflictResolveWorkflowExecution( } func (c *ContextImpl) UpdateWorkflowExecutionAsActive( + ctx context.Context, now time.Time, ) error { // We only perform this check on active cluster for the namespace - forceTerminate, err := c.enforceSizeCheck() + forceTerminate, err := c.enforceSizeCheck(ctx) if err != nil { return err } if err := c.UpdateWorkflowExecutionWithNew( + ctx, now, persistence.UpdateWorkflowModeUpdateCurrent, nil, @@ -505,12 +523,14 @@ func (c *ContextImpl) UpdateWorkflowExecutionAsActive( } func (c *ContextImpl) UpdateWorkflowExecutionWithNewAsActive( + ctx context.Context, now time.Time, newContext Context, newMutableState MutableState, ) error { return c.UpdateWorkflowExecutionWithNew( + ctx, now, persistence.UpdateWorkflowModeUpdateCurrent, newContext, @@ -521,10 +541,12 @@ func (c *ContextImpl) UpdateWorkflowExecutionWithNewAsActive( } func (c *ContextImpl) UpdateWorkflowExecutionAsPassive( + ctx context.Context, now time.Time, ) error { return c.UpdateWorkflowExecutionWithNew( + ctx, now, persistence.UpdateWorkflowModeUpdateCurrent, nil, @@ -535,12 +557,14 @@ func (c *ContextImpl) UpdateWorkflowExecutionAsPassive( } func (c *ContextImpl) UpdateWorkflowExecutionWithNewAsPassive( + ctx context.Context, now time.Time, newContext Context, newMutableState MutableState, ) error { return c.UpdateWorkflowExecutionWithNew( + ctx, now, persistence.UpdateWorkflowModeUpdateCurrent, newContext, @@ -551,6 +575,7 @@ func (c *ContextImpl) UpdateWorkflowExecutionWithNewAsPassive( } func (c *ContextImpl) UpdateWorkflowExecutionWithNew( + ctx context.Context, now time.Time, updateMode persistence.UpdateWorkflowMode, newContext Context, @@ -614,6 +639,7 @@ func (c *ContextImpl) UpdateWorkflowExecutionWithNew( } if currentWorkflowSizeDiff, newWorkflowSizeDiff, err := c.transaction.UpdateWorkflowExecution( + ctx, updateMode, currentWorkflow, currentWorkflowEventsSeq, @@ -644,7 +670,7 @@ func (c *ContextImpl) UpdateWorkflowExecutionWithNew( return nil } -func (c *ContextImpl) SetWorkflowExecution(now time.Time) (retError error) { +func (c *ContextImpl) SetWorkflowExecution(ctx context.Context, now time.Time) (retError error) { defer func() { if retError != nil { c.Clear() @@ -667,6 +693,7 @@ func (c *ContextImpl) SetWorkflowExecution(now time.Time) (retError error) { } return c.transaction.SetWorkflowExecution( + ctx, resetWorkflowSnapshot, c.MutableState.GetNamespaceEntry().ActiveClusterName(), ) @@ -794,6 +821,7 @@ func (c *ContextImpl) ReapplyEvents( return err } + // TODO: should we pass in a context instead of using the default one? ctx, cancel := context.WithTimeout(context.Background(), defaultRemoteCallTimeout) defer cancel() @@ -840,7 +868,9 @@ func (c *ContextImpl) ReapplyEvents( } // Returns true if execution is forced terminated -func (c *ContextImpl) enforceSizeCheck() (bool, error) { +func (c *ContextImpl) enforceSizeCheck( + ctx context.Context, +) (bool, error) { namespaceName := c.GetNamespace().String() historySizeLimitWarn := c.config.HistorySizeLimitWarn(namespaceName) historySizeLimitError := c.config.HistorySizeLimitError(namespaceName) @@ -864,7 +894,7 @@ func (c *ContextImpl) enforceSizeCheck() (bool, error) { c.Clear() // Reload mutable state - mutableState, err := c.LoadWorkflowExecution() + mutableState, err := c.LoadWorkflowExecution(ctx) if err != nil { return false, err } diff --git a/service/history/workflow/context_mock.go b/service/history/workflow/context_mock.go index ee7fe7b9b47..d5a141bcf69 100644 --- a/service/history/workflow/context_mock.go +++ b/service/history/workflow/context_mock.go @@ -75,31 +75,31 @@ func (mr *MockContextMockRecorder) Clear() *gomock.Call { } // ConflictResolveWorkflowExecution mocks base method. -func (m *MockContext) ConflictResolveWorkflowExecution(now time.Time, conflictResolveMode persistence.ConflictResolveWorkflowMode, resetMutableState MutableState, newContext Context, newMutableState MutableState, currentContext Context, currentMutableState MutableState, currentTransactionPolicy *TransactionPolicy) error { +func (m *MockContext) ConflictResolveWorkflowExecution(ctx context.Context, now time.Time, conflictResolveMode persistence.ConflictResolveWorkflowMode, resetMutableState MutableState, newContext Context, newMutableState MutableState, currentContext Context, currentMutableState MutableState, currentTransactionPolicy *TransactionPolicy) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ConflictResolveWorkflowExecution", now, conflictResolveMode, resetMutableState, newContext, newMutableState, currentContext, currentMutableState, currentTransactionPolicy) + ret := m.ctrl.Call(m, "ConflictResolveWorkflowExecution", ctx, now, conflictResolveMode, resetMutableState, newContext, newMutableState, currentContext, currentMutableState, currentTransactionPolicy) ret0, _ := ret[0].(error) return ret0 } // ConflictResolveWorkflowExecution indicates an expected call of ConflictResolveWorkflowExecution. -func (mr *MockContextMockRecorder) ConflictResolveWorkflowExecution(now, conflictResolveMode, resetMutableState, newContext, newMutableState, currentContext, currentMutableState, currentTransactionPolicy interface{}) *gomock.Call { +func (mr *MockContextMockRecorder) ConflictResolveWorkflowExecution(ctx, now, conflictResolveMode, resetMutableState, newContext, newMutableState, currentContext, currentMutableState, currentTransactionPolicy interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConflictResolveWorkflowExecution", reflect.TypeOf((*MockContext)(nil).ConflictResolveWorkflowExecution), now, conflictResolveMode, resetMutableState, newContext, newMutableState, currentContext, currentMutableState, currentTransactionPolicy) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConflictResolveWorkflowExecution", reflect.TypeOf((*MockContext)(nil).ConflictResolveWorkflowExecution), ctx, now, conflictResolveMode, resetMutableState, newContext, newMutableState, currentContext, currentMutableState, currentTransactionPolicy) } // CreateWorkflowExecution mocks base method. -func (m *MockContext) CreateWorkflowExecution(now time.Time, createMode persistence.CreateWorkflowMode, prevRunID string, prevLastWriteVersion int64, newMutableState MutableState, newWorkflow *persistence.WorkflowSnapshot, newWorkflowEvents []*persistence.WorkflowEvents) error { +func (m *MockContext) CreateWorkflowExecution(ctx context.Context, now time.Time, createMode persistence.CreateWorkflowMode, prevRunID string, prevLastWriteVersion int64, newMutableState MutableState, newWorkflow *persistence.WorkflowSnapshot, newWorkflowEvents []*persistence.WorkflowEvents) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateWorkflowExecution", now, createMode, prevRunID, prevLastWriteVersion, newMutableState, newWorkflow, newWorkflowEvents) + ret := m.ctrl.Call(m, "CreateWorkflowExecution", ctx, now, createMode, prevRunID, prevLastWriteVersion, newMutableState, newWorkflow, newWorkflowEvents) ret0, _ := ret[0].(error) return ret0 } // CreateWorkflowExecution indicates an expected call of CreateWorkflowExecution. -func (mr *MockContextMockRecorder) CreateWorkflowExecution(now, createMode, prevRunID, prevLastWriteVersion, newMutableState, newWorkflow, newWorkflowEvents interface{}) *gomock.Call { +func (mr *MockContextMockRecorder) CreateWorkflowExecution(ctx, now, createMode, prevRunID, prevLastWriteVersion, newMutableState, newWorkflow, newWorkflowEvents interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateWorkflowExecution", reflect.TypeOf((*MockContext)(nil).CreateWorkflowExecution), now, createMode, prevRunID, prevLastWriteVersion, newMutableState, newWorkflow, newWorkflowEvents) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateWorkflowExecution", reflect.TypeOf((*MockContext)(nil).CreateWorkflowExecution), ctx, now, createMode, prevRunID, prevLastWriteVersion, newMutableState, newWorkflow, newWorkflowEvents) } // GetHistorySize mocks base method. @@ -173,33 +173,33 @@ func (mr *MockContextMockRecorder) GetWorkflowID() *gomock.Call { } // LoadExecutionStats mocks base method. -func (m *MockContext) LoadExecutionStats() (*v1.ExecutionStats, error) { +func (m *MockContext) LoadExecutionStats(ctx context.Context) (*v1.ExecutionStats, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LoadExecutionStats") + ret := m.ctrl.Call(m, "LoadExecutionStats", ctx) ret0, _ := ret[0].(*v1.ExecutionStats) ret1, _ := ret[1].(error) return ret0, ret1 } // LoadExecutionStats indicates an expected call of LoadExecutionStats. -func (mr *MockContextMockRecorder) LoadExecutionStats() *gomock.Call { +func (mr *MockContextMockRecorder) LoadExecutionStats(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadExecutionStats", reflect.TypeOf((*MockContext)(nil).LoadExecutionStats)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadExecutionStats", reflect.TypeOf((*MockContext)(nil).LoadExecutionStats), ctx) } // LoadWorkflowExecution mocks base method. -func (m *MockContext) LoadWorkflowExecution() (MutableState, error) { +func (m *MockContext) LoadWorkflowExecution(ctx context.Context) (MutableState, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LoadWorkflowExecution") + ret := m.ctrl.Call(m, "LoadWorkflowExecution", ctx) ret0, _ := ret[0].(MutableState) ret1, _ := ret[1].(error) return ret0, ret1 } // LoadWorkflowExecution indicates an expected call of LoadWorkflowExecution. -func (mr *MockContextMockRecorder) LoadWorkflowExecution() *gomock.Call { +func (mr *MockContextMockRecorder) LoadWorkflowExecution(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWorkflowExecution", reflect.TypeOf((*MockContext)(nil).LoadWorkflowExecution)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWorkflowExecution", reflect.TypeOf((*MockContext)(nil).LoadWorkflowExecution), ctx) } // Lock mocks base method. @@ -217,18 +217,18 @@ func (mr *MockContextMockRecorder) Lock(ctx, caller interface{}) *gomock.Call { } // PersistWorkflowEvents mocks base method. -func (m *MockContext) PersistWorkflowEvents(workflowEvents *persistence.WorkflowEvents) (int64, error) { +func (m *MockContext) PersistWorkflowEvents(ctx context.Context, workflowEvents *persistence.WorkflowEvents) (int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PersistWorkflowEvents", workflowEvents) + ret := m.ctrl.Call(m, "PersistWorkflowEvents", ctx, workflowEvents) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(error) return ret0, ret1 } // PersistWorkflowEvents indicates an expected call of PersistWorkflowEvents. -func (mr *MockContextMockRecorder) PersistWorkflowEvents(workflowEvents interface{}) *gomock.Call { +func (mr *MockContextMockRecorder) PersistWorkflowEvents(ctx, workflowEvents interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PersistWorkflowEvents", reflect.TypeOf((*MockContext)(nil).PersistWorkflowEvents), workflowEvents) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PersistWorkflowEvents", reflect.TypeOf((*MockContext)(nil).PersistWorkflowEvents), ctx, workflowEvents) } // ReapplyEvents mocks base method. @@ -258,17 +258,17 @@ func (mr *MockContextMockRecorder) SetHistorySize(size interface{}) *gomock.Call } // SetWorkflowExecution mocks base method. -func (m *MockContext) SetWorkflowExecution(now time.Time) error { +func (m *MockContext) SetWorkflowExecution(ctx context.Context, now time.Time) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetWorkflowExecution", now) + ret := m.ctrl.Call(m, "SetWorkflowExecution", ctx, now) ret0, _ := ret[0].(error) return ret0 } // SetWorkflowExecution indicates an expected call of SetWorkflowExecution. -func (mr *MockContextMockRecorder) SetWorkflowExecution(now interface{}) *gomock.Call { +func (mr *MockContextMockRecorder) SetWorkflowExecution(ctx, now interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWorkflowExecution", reflect.TypeOf((*MockContext)(nil).SetWorkflowExecution), now) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWorkflowExecution", reflect.TypeOf((*MockContext)(nil).SetWorkflowExecution), ctx, now) } // Unlock mocks base method. @@ -284,71 +284,71 @@ func (mr *MockContextMockRecorder) Unlock(caller interface{}) *gomock.Call { } // UpdateWorkflowExecutionAsActive mocks base method. -func (m *MockContext) UpdateWorkflowExecutionAsActive(now time.Time) error { +func (m *MockContext) UpdateWorkflowExecutionAsActive(ctx context.Context, now time.Time) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateWorkflowExecutionAsActive", now) + ret := m.ctrl.Call(m, "UpdateWorkflowExecutionAsActive", ctx, now) ret0, _ := ret[0].(error) return ret0 } // UpdateWorkflowExecutionAsActive indicates an expected call of UpdateWorkflowExecutionAsActive. -func (mr *MockContextMockRecorder) UpdateWorkflowExecutionAsActive(now interface{}) *gomock.Call { +func (mr *MockContextMockRecorder) UpdateWorkflowExecutionAsActive(ctx, now interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkflowExecutionAsActive", reflect.TypeOf((*MockContext)(nil).UpdateWorkflowExecutionAsActive), now) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkflowExecutionAsActive", reflect.TypeOf((*MockContext)(nil).UpdateWorkflowExecutionAsActive), ctx, now) } // UpdateWorkflowExecutionAsPassive mocks base method. -func (m *MockContext) UpdateWorkflowExecutionAsPassive(now time.Time) error { +func (m *MockContext) UpdateWorkflowExecutionAsPassive(ctx context.Context, now time.Time) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateWorkflowExecutionAsPassive", now) + ret := m.ctrl.Call(m, "UpdateWorkflowExecutionAsPassive", ctx, now) ret0, _ := ret[0].(error) return ret0 } // UpdateWorkflowExecutionAsPassive indicates an expected call of UpdateWorkflowExecutionAsPassive. -func (mr *MockContextMockRecorder) UpdateWorkflowExecutionAsPassive(now interface{}) *gomock.Call { +func (mr *MockContextMockRecorder) UpdateWorkflowExecutionAsPassive(ctx, now interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkflowExecutionAsPassive", reflect.TypeOf((*MockContext)(nil).UpdateWorkflowExecutionAsPassive), now) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkflowExecutionAsPassive", reflect.TypeOf((*MockContext)(nil).UpdateWorkflowExecutionAsPassive), ctx, now) } // UpdateWorkflowExecutionWithNew mocks base method. -func (m *MockContext) UpdateWorkflowExecutionWithNew(now time.Time, updateMode persistence.UpdateWorkflowMode, newContext Context, newMutableState MutableState, currentWorkflowTransactionPolicy TransactionPolicy, newWorkflowTransactionPolicy *TransactionPolicy) error { +func (m *MockContext) UpdateWorkflowExecutionWithNew(ctx context.Context, now time.Time, updateMode persistence.UpdateWorkflowMode, newContext Context, newMutableState MutableState, currentWorkflowTransactionPolicy TransactionPolicy, newWorkflowTransactionPolicy *TransactionPolicy) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateWorkflowExecutionWithNew", now, updateMode, newContext, newMutableState, currentWorkflowTransactionPolicy, newWorkflowTransactionPolicy) + ret := m.ctrl.Call(m, "UpdateWorkflowExecutionWithNew", ctx, now, updateMode, newContext, newMutableState, currentWorkflowTransactionPolicy, newWorkflowTransactionPolicy) ret0, _ := ret[0].(error) return ret0 } // UpdateWorkflowExecutionWithNew indicates an expected call of UpdateWorkflowExecutionWithNew. -func (mr *MockContextMockRecorder) UpdateWorkflowExecutionWithNew(now, updateMode, newContext, newMutableState, currentWorkflowTransactionPolicy, newWorkflowTransactionPolicy interface{}) *gomock.Call { +func (mr *MockContextMockRecorder) UpdateWorkflowExecutionWithNew(ctx, now, updateMode, newContext, newMutableState, currentWorkflowTransactionPolicy, newWorkflowTransactionPolicy interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkflowExecutionWithNew", reflect.TypeOf((*MockContext)(nil).UpdateWorkflowExecutionWithNew), now, updateMode, newContext, newMutableState, currentWorkflowTransactionPolicy, newWorkflowTransactionPolicy) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkflowExecutionWithNew", reflect.TypeOf((*MockContext)(nil).UpdateWorkflowExecutionWithNew), ctx, now, updateMode, newContext, newMutableState, currentWorkflowTransactionPolicy, newWorkflowTransactionPolicy) } // UpdateWorkflowExecutionWithNewAsActive mocks base method. -func (m *MockContext) UpdateWorkflowExecutionWithNewAsActive(now time.Time, newContext Context, newMutableState MutableState) error { +func (m *MockContext) UpdateWorkflowExecutionWithNewAsActive(ctx context.Context, now time.Time, newContext Context, newMutableState MutableState) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateWorkflowExecutionWithNewAsActive", now, newContext, newMutableState) + ret := m.ctrl.Call(m, "UpdateWorkflowExecutionWithNewAsActive", ctx, now, newContext, newMutableState) ret0, _ := ret[0].(error) return ret0 } // UpdateWorkflowExecutionWithNewAsActive indicates an expected call of UpdateWorkflowExecutionWithNewAsActive. -func (mr *MockContextMockRecorder) UpdateWorkflowExecutionWithNewAsActive(now, newContext, newMutableState interface{}) *gomock.Call { +func (mr *MockContextMockRecorder) UpdateWorkflowExecutionWithNewAsActive(ctx, now, newContext, newMutableState interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkflowExecutionWithNewAsActive", reflect.TypeOf((*MockContext)(nil).UpdateWorkflowExecutionWithNewAsActive), now, newContext, newMutableState) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkflowExecutionWithNewAsActive", reflect.TypeOf((*MockContext)(nil).UpdateWorkflowExecutionWithNewAsActive), ctx, now, newContext, newMutableState) } // UpdateWorkflowExecutionWithNewAsPassive mocks base method. -func (m *MockContext) UpdateWorkflowExecutionWithNewAsPassive(now time.Time, newContext Context, newMutableState MutableState) error { +func (m *MockContext) UpdateWorkflowExecutionWithNewAsPassive(ctx context.Context, now time.Time, newContext Context, newMutableState MutableState) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateWorkflowExecutionWithNewAsPassive", now, newContext, newMutableState) + ret := m.ctrl.Call(m, "UpdateWorkflowExecutionWithNewAsPassive", ctx, now, newContext, newMutableState) ret0, _ := ret[0].(error) return ret0 } // UpdateWorkflowExecutionWithNewAsPassive indicates an expected call of UpdateWorkflowExecutionWithNewAsPassive. -func (mr *MockContextMockRecorder) UpdateWorkflowExecutionWithNewAsPassive(now, newContext, newMutableState interface{}) *gomock.Call { +func (mr *MockContextMockRecorder) UpdateWorkflowExecutionWithNewAsPassive(ctx, now, newContext, newMutableState interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkflowExecutionWithNewAsPassive", reflect.TypeOf((*MockContext)(nil).UpdateWorkflowExecutionWithNewAsPassive), now, newContext, newMutableState) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkflowExecutionWithNewAsPassive", reflect.TypeOf((*MockContext)(nil).UpdateWorkflowExecutionWithNewAsPassive), ctx, now, newContext, newMutableState) } diff --git a/service/history/workflow/delete_manager.go b/service/history/workflow/delete_manager.go index c7cd0af6697..398f6a93425 100644 --- a/service/history/workflow/delete_manager.go +++ b/service/history/workflow/delete_manager.go @@ -48,9 +48,9 @@ import ( type ( DeleteManager interface { - AddDeleteWorkflowExecutionTask(nsID namespace.ID, we commonpb.WorkflowExecution, ms MutableState) error - DeleteWorkflowExecution(nsID namespace.ID, we commonpb.WorkflowExecution, weCtx Context, ms MutableState, sourceTaskVersion int64) error - DeleteWorkflowExecutionByRetention(nsID namespace.ID, we commonpb.WorkflowExecution, weCtx Context, ms MutableState, sourceTaskVersion int64) error + AddDeleteWorkflowExecutionTask(ctx context.Context, nsID namespace.ID, we commonpb.WorkflowExecution, ms MutableState) error + DeleteWorkflowExecution(ctx context.Context, nsID namespace.ID, we commonpb.WorkflowExecution, weCtx Context, ms MutableState, sourceTaskVersion int64) error + DeleteWorkflowExecutionByRetention(ctx context.Context, nsID namespace.ID, we commonpb.WorkflowExecution, weCtx Context, ms MutableState, sourceTaskVersion int64) error } DeleteManagerImpl struct { @@ -84,6 +84,7 @@ func NewDeleteManager( return deleteManager } func (m *DeleteManagerImpl) AddDeleteWorkflowExecutionTask( + ctx context.Context, nsID namespace.ID, we commonpb.WorkflowExecution, ms MutableState, @@ -101,7 +102,7 @@ func (m *DeleteManagerImpl) AddDeleteWorkflowExecutionTask( return err } - err = m.shard.AddTasks(&persistence.AddHistoryTasksRequest{ + err = m.shard.AddTasks(ctx, &persistence.AddHistoryTasksRequest{ ShardID: m.shard.GetShardID(), // RangeID is set by shard NamespaceID: nsID.String(), @@ -119,6 +120,7 @@ func (m *DeleteManagerImpl) AddDeleteWorkflowExecutionTask( } func (m *DeleteManagerImpl) DeleteWorkflowExecution( + ctx context.Context, nsID namespace.ID, we commonpb.WorkflowExecution, weCtx Context, @@ -135,6 +137,7 @@ func (m *DeleteManagerImpl) DeleteWorkflowExecution( } err := m.deleteWorkflowExecutionInternal( + ctx, nsID, we, weCtx, @@ -148,6 +151,7 @@ func (m *DeleteManagerImpl) DeleteWorkflowExecution( } func (m *DeleteManagerImpl) DeleteWorkflowExecutionByRetention( + ctx context.Context, nsID namespace.ID, we commonpb.WorkflowExecution, weCtx Context, @@ -163,6 +167,7 @@ func (m *DeleteManagerImpl) DeleteWorkflowExecutionByRetention( } err := m.deleteWorkflowExecutionInternal( + ctx, nsID, we, weCtx, @@ -176,6 +181,7 @@ func (m *DeleteManagerImpl) DeleteWorkflowExecutionByRetention( } func (m *DeleteManagerImpl) deleteWorkflowExecutionInternal( + ctx context.Context, namespaceID namespace.ID, we commonpb.WorkflowExecution, weCtx Context, @@ -192,7 +198,7 @@ func (m *DeleteManagerImpl) deleteWorkflowExecutionInternal( shouldDeleteHistory := true if archiveIfEnabled { - shouldDeleteHistory, err = m.archiveWorkflowIfEnabled(namespaceID, we, currentBranchToken, weCtx, ms, scope) + shouldDeleteHistory, err = m.archiveWorkflowIfEnabled(ctx, namespaceID, we, currentBranchToken, weCtx, ms, scope) if err != nil { return err } @@ -203,12 +209,13 @@ func (m *DeleteManagerImpl) deleteWorkflowExecutionInternal( currentBranchToken = nil } - completionEvent, err := ms.GetCompletionEvent() + completionEvent, err := ms.GetCompletionEvent(ctx) if err != nil { return err } if err := m.shard.DeleteWorkflowExecution( + ctx, definition.WorkflowKey{ NamespaceID: namespaceID.String(), WorkflowID: we.GetWorkflowId(), @@ -229,6 +236,7 @@ func (m *DeleteManagerImpl) deleteWorkflowExecutionInternal( } func (m *DeleteManagerImpl) archiveWorkflowIfEnabled( + ctx context.Context, namespaceID namespace.ID, workflowExecution commonpb.WorkflowExecution, currentBranchToken []byte, @@ -271,7 +279,7 @@ func (m *DeleteManagerImpl) archiveWorkflowIfEnabled( CallerService: common.HistoryServiceName, AttemptArchiveInline: false, // archive in workflow by default } - executionStats, err := weCtx.LoadExecutionStats() + executionStats, err := weCtx.LoadExecutionStats(ctx) if err == nil && executionStats.HistorySize < int64(m.config.TimerProcessorHistoryArchivalSizeLimit()) { req.AttemptArchiveInline = true } diff --git a/service/history/workflow/delete_manager_mock.go b/service/history/workflow/delete_manager_mock.go index a638c4d6ec5..6a2eab1720b 100644 --- a/service/history/workflow/delete_manager_mock.go +++ b/service/history/workflow/delete_manager_mock.go @@ -29,6 +29,7 @@ package workflow import ( + context "context" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -60,43 +61,43 @@ func (m *MockDeleteManager) EXPECT() *MockDeleteManagerMockRecorder { } // AddDeleteWorkflowExecutionTask mocks base method. -func (m *MockDeleteManager) AddDeleteWorkflowExecutionTask(nsID namespace.ID, we v1.WorkflowExecution, ms MutableState) error { +func (m *MockDeleteManager) AddDeleteWorkflowExecutionTask(ctx context.Context, nsID namespace.ID, we v1.WorkflowExecution, ms MutableState) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddDeleteWorkflowExecutionTask", nsID, we, ms) + ret := m.ctrl.Call(m, "AddDeleteWorkflowExecutionTask", ctx, nsID, we, ms) ret0, _ := ret[0].(error) return ret0 } // AddDeleteWorkflowExecutionTask indicates an expected call of AddDeleteWorkflowExecutionTask. -func (mr *MockDeleteManagerMockRecorder) AddDeleteWorkflowExecutionTask(nsID, we, ms interface{}) *gomock.Call { +func (mr *MockDeleteManagerMockRecorder) AddDeleteWorkflowExecutionTask(ctx, nsID, we, ms interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddDeleteWorkflowExecutionTask", reflect.TypeOf((*MockDeleteManager)(nil).AddDeleteWorkflowExecutionTask), nsID, we, ms) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddDeleteWorkflowExecutionTask", reflect.TypeOf((*MockDeleteManager)(nil).AddDeleteWorkflowExecutionTask), ctx, nsID, we, ms) } // DeleteWorkflowExecution mocks base method. -func (m *MockDeleteManager) DeleteWorkflowExecution(nsID namespace.ID, we v1.WorkflowExecution, weCtx Context, ms MutableState, sourceTaskVersion int64) error { +func (m *MockDeleteManager) DeleteWorkflowExecution(ctx context.Context, nsID namespace.ID, we v1.WorkflowExecution, weCtx Context, ms MutableState, sourceTaskVersion int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteWorkflowExecution", nsID, we, weCtx, ms, sourceTaskVersion) + ret := m.ctrl.Call(m, "DeleteWorkflowExecution", ctx, nsID, we, weCtx, ms, sourceTaskVersion) ret0, _ := ret[0].(error) return ret0 } // DeleteWorkflowExecution indicates an expected call of DeleteWorkflowExecution. -func (mr *MockDeleteManagerMockRecorder) DeleteWorkflowExecution(nsID, we, weCtx, ms, sourceTaskVersion interface{}) *gomock.Call { +func (mr *MockDeleteManagerMockRecorder) DeleteWorkflowExecution(ctx, nsID, we, weCtx, ms, sourceTaskVersion interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkflowExecution", reflect.TypeOf((*MockDeleteManager)(nil).DeleteWorkflowExecution), nsID, we, weCtx, ms, sourceTaskVersion) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkflowExecution", reflect.TypeOf((*MockDeleteManager)(nil).DeleteWorkflowExecution), ctx, nsID, we, weCtx, ms, sourceTaskVersion) } // DeleteWorkflowExecutionByRetention mocks base method. -func (m *MockDeleteManager) DeleteWorkflowExecutionByRetention(nsID namespace.ID, we v1.WorkflowExecution, weCtx Context, ms MutableState, sourceTaskVersion int64) error { +func (m *MockDeleteManager) DeleteWorkflowExecutionByRetention(ctx context.Context, nsID namespace.ID, we v1.WorkflowExecution, weCtx Context, ms MutableState, sourceTaskVersion int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteWorkflowExecutionByRetention", nsID, we, weCtx, ms, sourceTaskVersion) + ret := m.ctrl.Call(m, "DeleteWorkflowExecutionByRetention", ctx, nsID, we, weCtx, ms, sourceTaskVersion) ret0, _ := ret[0].(error) return ret0 } // DeleteWorkflowExecutionByRetention indicates an expected call of DeleteWorkflowExecutionByRetention. -func (mr *MockDeleteManagerMockRecorder) DeleteWorkflowExecutionByRetention(nsID, we, weCtx, ms, sourceTaskVersion interface{}) *gomock.Call { +func (mr *MockDeleteManagerMockRecorder) DeleteWorkflowExecutionByRetention(ctx, nsID, we, weCtx, ms, sourceTaskVersion interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkflowExecutionByRetention", reflect.TypeOf((*MockDeleteManager)(nil).DeleteWorkflowExecutionByRetention), nsID, we, weCtx, ms, sourceTaskVersion) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkflowExecutionByRetention", reflect.TypeOf((*MockDeleteManager)(nil).DeleteWorkflowExecutionByRetention), ctx, nsID, we, weCtx, ms, sourceTaskVersion) } diff --git a/service/history/workflow/delete_manager_test.go b/service/history/workflow/delete_manager_test.go index 959ab2358ec..72221a990be 100644 --- a/service/history/workflow/delete_manager_test.go +++ b/service/history/workflow/delete_manager_test.go @@ -25,6 +25,7 @@ package workflow import ( + "context" "errors" "testing" "time" @@ -120,9 +121,10 @@ func (s *deleteManagerWorkflowSuite) TestDeleteDeletedWorkflowExecution() { EventType: enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED, EventTime: &closeTime, } - mockMutableState.EXPECT().GetCompletionEvent().Return(completionEvent, nil) + mockMutableState.EXPECT().GetCompletionEvent(gomock.Any()).Return(completionEvent, nil) s.mockShardContext.EXPECT().DeleteWorkflowExecution( + gomock.Any(), definition.WorkflowKey{ NamespaceID: tests.NamespaceID.String(), WorkflowID: tests.WorkflowID, @@ -135,6 +137,7 @@ func (s *deleteManagerWorkflowSuite) TestDeleteDeletedWorkflowExecution() { mockWeCtx.EXPECT().Clear() err := s.deleteManager.DeleteWorkflowExecution( + context.Background(), tests.NamespaceID, we, mockWeCtx, @@ -160,9 +163,10 @@ func (s *deleteManagerWorkflowSuite) TestDeleteDeletedWorkflowExecution_Error() EventType: enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED, EventTime: &closeTime, } - mockMutableState.EXPECT().GetCompletionEvent().Return(completionEvent, nil) + mockMutableState.EXPECT().GetCompletionEvent(gomock.Any()).Return(completionEvent, nil) s.mockShardContext.EXPECT().DeleteWorkflowExecution( + gomock.Any(), definition.WorkflowKey{ NamespaceID: tests.NamespaceID.String(), WorkflowID: tests.WorkflowID, @@ -174,6 +178,7 @@ func (s *deleteManagerWorkflowSuite) TestDeleteDeletedWorkflowExecution_Error() ).Return(serviceerror.NewInternal("test error")) err := s.deleteManager.DeleteWorkflowExecution( + context.Background(), tests.NamespaceID, we, mockWeCtx, @@ -200,7 +205,7 @@ func (s *deleteManagerWorkflowSuite) TestDeleteWorkflowExecutionRetention_Archiv EventType: enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED, EventTime: &closeTime, } - mockMutableState.EXPECT().GetCompletionEvent().Return(completionEvent, nil) + mockMutableState.EXPECT().GetCompletionEvent(gomock.Any()).Return(completionEvent, nil) // ====================== Archival mocks ======================================= mockNamespaceRegistry := namespace.NewMockRegistry(s.controller) @@ -223,7 +228,7 @@ func (s *deleteManagerWorkflowSuite) TestDeleteWorkflowExecutionRetention_Archiv mockMutableState.EXPECT().GetLastWriteVersion().Return(int64(1), nil) s.mockShardContext.EXPECT().GetShardID().Return(int32(1)) mockMutableState.EXPECT().GetNextEventID().Return(int64(1)) - mockWeCtx.EXPECT().LoadExecutionStats().Return(&persistencespb.ExecutionStats{ + mockWeCtx.EXPECT().LoadExecutionStats(gomock.Any()).Return(&persistencespb.ExecutionStats{ HistorySize: 22, }, nil) mockSearchAttributesProvider := searchattribute.NewMockProvider(s.controller) @@ -235,6 +240,7 @@ func (s *deleteManagerWorkflowSuite) TestDeleteWorkflowExecutionRetention_Archiv // ============================================================= s.mockShardContext.EXPECT().DeleteWorkflowExecution( + gomock.Any(), definition.WorkflowKey{ NamespaceID: tests.NamespaceID.String(), WorkflowID: tests.WorkflowID, @@ -247,6 +253,7 @@ func (s *deleteManagerWorkflowSuite) TestDeleteWorkflowExecutionRetention_Archiv mockWeCtx.EXPECT().Clear() err := s.deleteManager.DeleteWorkflowExecutionByRetention( + context.Background(), tests.NamespaceID, we, mockWeCtx, @@ -289,7 +296,7 @@ func (s *deleteManagerWorkflowSuite) TestDeleteWorkflowExecutionRetention_Archiv mockMutableState.EXPECT().GetLastWriteVersion().Return(int64(1), nil) s.mockShardContext.EXPECT().GetShardID().Return(int32(1)) mockMutableState.EXPECT().GetNextEventID().Return(int64(1)) - mockWeCtx.EXPECT().LoadExecutionStats().Return(&persistencespb.ExecutionStats{ + mockWeCtx.EXPECT().LoadExecutionStats(gomock.Any()).Return(&persistencespb.ExecutionStats{ HistorySize: 22 * 1024 * 1024 * 1024, }, nil) mockSearchAttributesProvider := searchattribute.NewMockProvider(s.controller) @@ -299,6 +306,7 @@ func (s *deleteManagerWorkflowSuite) TestDeleteWorkflowExecutionRetention_Archiv // ============================================================= err := s.deleteManager.DeleteWorkflowExecutionByRetention( + context.Background(), tests.NamespaceID, we, mockWeCtx, diff --git a/service/history/workflow/mutable_state.go b/service/history/workflow/mutable_state.go index fbf7d491732..2ae4fb048f0 100644 --- a/service/history/workflow/mutable_state.go +++ b/service/history/workflow/mutable_state.go @@ -27,6 +27,7 @@ package workflow import ( + "context" "time" commandpb "go.temporal.io/api/command/v1" @@ -140,14 +141,14 @@ type ( GetActivityByActivityID(string) (*persistencespb.ActivityInfo, bool) GetActivityInfo(int64) (*persistencespb.ActivityInfo, bool) GetActivityInfoWithTimerHeartbeat(scheduleEventID int64) (*persistencespb.ActivityInfo, time.Time, bool) - GetActivityScheduledEvent(int64) (*historypb.HistoryEvent, error) + GetActivityScheduledEvent(context.Context, int64) (*historypb.HistoryEvent, error) GetChildExecutionInfo(int64) (*persistencespb.ChildExecutionInfo, bool) - GetChildExecutionInitiatedEvent(int64) (*historypb.HistoryEvent, error) - GetCompletionEvent() (*historypb.HistoryEvent, error) + GetChildExecutionInitiatedEvent(context.Context, int64) (*historypb.HistoryEvent, error) + GetCompletionEvent(context.Context) (*historypb.HistoryEvent, error) GetWorkflowTaskInfo(int64) (*WorkflowTaskInfo, bool) GetNamespaceEntry() *namespace.Namespace - GetStartEvent() (*historypb.HistoryEvent, error) - GetSignalExternalInitiatedEvent(int64) (*historypb.HistoryEvent, error) + GetStartEvent(context.Context) (*historypb.HistoryEvent, error) + GetSignalExternalInitiatedEvent(context.Context, int64) (*historypb.HistoryEvent, error) GetFirstRunID() (string, error) GetCurrentBranchToken() ([]byte, error) GetCurrentVersion() int64 diff --git a/service/history/workflow/mutable_state_impl.go b/service/history/workflow/mutable_state_impl.go index a641362b969..2f9d20515a0 100644 --- a/service/history/workflow/mutable_state_impl.go +++ b/service/history/workflow/mutable_state_impl.go @@ -25,6 +25,7 @@ package workflow import ( + "context" "fmt" "math/rand" "time" @@ -267,6 +268,7 @@ func NewMutableState( } func newMutableStateBuilderFromDB( + ctx context.Context, shard shard.Context, eventsCache events.Cache, logger log.Logger, @@ -317,7 +319,7 @@ func newMutableStateBuilderFromDB( // Workflows created before 1.11 doesn't have ExecutionTime and it must be computed for backwards compatibility. // Remove this "if" block when it is ok to rely on executionInfo.ExecutionTime only (added 6/9/21). if mutableState.executionInfo.ExecutionTime == nil { - startEvent, err := mutableState.GetStartEvent() + startEvent, err := mutableState.GetStartEvent(ctx) if err != nil { return nil, err } @@ -582,6 +584,7 @@ func (e *MutableStateImpl) GetQueryRegistry() QueryRegistry { } func (e *MutableStateImpl) GetActivityScheduledEvent( + ctx context.Context, scheduleEventID int64, ) (*historypb.HistoryEvent, error) { @@ -595,6 +598,7 @@ func (e *MutableStateImpl) GetActivityScheduledEvent( return nil, err } event, err := e.eventsCache.GetEvent( + ctx, events.EventKey{ NamespaceID: namespace.ID(e.executionInfo.NamespaceId), WorkflowID: e.executionInfo.WorkflowId, @@ -657,6 +661,7 @@ func (e *MutableStateImpl) GetChildExecutionInfo( // GetChildExecutionInitiatedEvent reads out the ChildExecutionInitiatedEvent from mutable state for in-progress child // executions func (e *MutableStateImpl) GetChildExecutionInitiatedEvent( + ctx context.Context, initiatedEventID int64, ) (*historypb.HistoryEvent, error) { @@ -670,6 +675,7 @@ func (e *MutableStateImpl) GetChildExecutionInitiatedEvent( return nil, err } event, err := e.eventsCache.GetEvent( + ctx, events.EventKey{ NamespaceID: namespace.ID(e.executionInfo.NamespaceId), WorkflowID: e.executionInfo.WorkflowId, @@ -739,6 +745,7 @@ func (e *MutableStateImpl) GetSignalInfo( // GetSignalExternalInitiatedEvent get the details about signal external workflow func (e *MutableStateImpl) GetSignalExternalInitiatedEvent( + ctx context.Context, initiatedEventID int64, ) (*historypb.HistoryEvent, error) { si, ok := e.pendingSignalInfoIDs[initiatedEventID] @@ -751,6 +758,7 @@ func (e *MutableStateImpl) GetSignalExternalInitiatedEvent( return nil, err } event, err := e.eventsCache.GetEvent( + ctx, events.EventKey{ NamespaceID: namespace.ID(e.executionInfo.NamespaceId), WorkflowID: e.executionInfo.WorkflowId, @@ -771,7 +779,9 @@ func (e *MutableStateImpl) GetSignalExternalInitiatedEvent( } // GetCompletionEvent retrieves the workflow completion event from mutable state -func (e *MutableStateImpl) GetCompletionEvent() (*historypb.HistoryEvent, error) { +func (e *MutableStateImpl) GetCompletionEvent( + ctx context.Context, +) (*historypb.HistoryEvent, error) { if e.executionState.State != enumsspb.WORKFLOW_EXECUTION_STATE_COMPLETED { return nil, ErrMissingWorkflowCompletionEvent } @@ -786,6 +796,7 @@ func (e *MutableStateImpl) GetCompletionEvent() (*historypb.HistoryEvent, error) } event, err := e.eventsCache.GetEvent( + ctx, events.EventKey{ NamespaceID: namespace.ID(e.executionInfo.NamespaceId), WorkflowID: e.executionInfo.WorkflowId, @@ -806,7 +817,9 @@ func (e *MutableStateImpl) GetCompletionEvent() (*historypb.HistoryEvent, error) } // GetStartEvent retrieves the workflow start event from mutable state -func (e *MutableStateImpl) GetStartEvent() (*historypb.HistoryEvent, error) { +func (e *MutableStateImpl) GetStartEvent( + ctx context.Context, +) (*historypb.HistoryEvent, error) { currentBranchToken, err := e.GetCurrentBranchToken() if err != nil { @@ -818,6 +831,7 @@ func (e *MutableStateImpl) GetStartEvent() (*historypb.HistoryEvent, error) { } event, err := e.eventsCache.GetEvent( + ctx, events.EventKey{ NamespaceID: namespace.ID(e.executionInfo.NamespaceId), WorkflowID: e.executionInfo.WorkflowId, @@ -845,7 +859,7 @@ func (e *MutableStateImpl) GetFirstRunID() (string, error) { if len(firstRunID) != 0 { return firstRunID, nil } - currentStartEvent, err := e.GetStartEvent() + currentStartEvent, err := e.GetStartEvent(context.TODO()) if err != nil { return "", err } diff --git a/service/history/workflow/mutable_state_impl_test.go b/service/history/workflow/mutable_state_impl_test.go index 841496475f0..2bdd557ad4f 100644 --- a/service/history/workflow/mutable_state_impl_test.go +++ b/service/history/workflow/mutable_state_impl_test.go @@ -25,6 +25,7 @@ package workflow import ( + "context" "testing" "time" @@ -246,7 +247,7 @@ func (s *mutableStateSuite) TestChecksum() { // create mutable state and verify checksum is generated on close loadErrors = loadErrorsFunc() var err error - s.mutableState, err = newMutableStateBuilderFromDB(s.mockShard, s.mockEventsCache, s.logger, tests.LocalNamespaceEntry, dbState, 123) + s.mutableState, err = newMutableStateBuilderFromDB(context.Background(), s.mockShard, s.mockEventsCache, s.logger, tests.LocalNamespaceEntry, dbState, 123) s.NoError(err) s.Equal(loadErrors, loadErrorsFunc()) // no errors expected s.EqualValues(dbState.Checksum, s.mutableState.checksum) @@ -260,7 +261,7 @@ func (s *mutableStateSuite) TestChecksum() { // verify checksum is verified on Load dbState.Checksum = csum - s.mutableState, err = newMutableStateBuilderFromDB(s.mockShard, s.mockEventsCache, s.logger, tests.LocalNamespaceEntry, dbState, 123) + s.mutableState, err = newMutableStateBuilderFromDB(context.Background(), s.mockShard, s.mockEventsCache, s.logger, tests.LocalNamespaceEntry, dbState, 123) s.NoError(err) s.Equal(loadErrors, loadErrorsFunc()) @@ -272,7 +273,7 @@ func (s *mutableStateSuite) TestChecksum() { // modify checksum and verify Load fails dbState.Checksum.Value[0]++ - s.mutableState, err = newMutableStateBuilderFromDB(s.mockShard, s.mockEventsCache, s.logger, tests.LocalNamespaceEntry, dbState, 123) + s.mutableState, err = newMutableStateBuilderFromDB(context.Background(), s.mockShard, s.mockEventsCache, s.logger, tests.LocalNamespaceEntry, dbState, 123) s.NoError(err) s.Equal(loadErrors+1, loadErrorsFunc()) s.EqualValues(dbState.Checksum, s.mutableState.checksum) @@ -282,7 +283,7 @@ func (s *mutableStateSuite) TestChecksum() { s.mockConfig.MutableStateChecksumInvalidateBefore = func(...dynamicconfig.FilterOption) float64 { return float64((s.mutableState.executionInfo.LastUpdateTime.UnixNano() / int64(time.Second)) + 1) } - s.mutableState, err = newMutableStateBuilderFromDB(s.mockShard, s.mockEventsCache, s.logger, tests.LocalNamespaceEntry, dbState, 123) + s.mutableState, err = newMutableStateBuilderFromDB(context.Background(), s.mockShard, s.mockEventsCache, s.logger, tests.LocalNamespaceEntry, dbState, 123) s.NoError(err) s.Equal(loadErrors, loadErrorsFunc()) s.Nil(s.mutableState.checksum) diff --git a/service/history/workflow/mutable_state_mock.go b/service/history/workflow/mutable_state_mock.go index 2c7392257e1..6dc3c8b047c 100644 --- a/service/history/workflow/mutable_state_mock.go +++ b/service/history/workflow/mutable_state_mock.go @@ -29,6 +29,7 @@ package workflow import ( + context "context" reflect "reflect" time "time" @@ -953,18 +954,18 @@ func (mr *MockMutableStateMockRecorder) GetActivityInfoWithTimerHeartbeat(schedu } // GetActivityScheduledEvent mocks base method. -func (m *MockMutableState) GetActivityScheduledEvent(arg0 int64) (*v13.HistoryEvent, error) { +func (m *MockMutableState) GetActivityScheduledEvent(arg0 context.Context, arg1 int64) (*v13.HistoryEvent, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetActivityScheduledEvent", arg0) + ret := m.ctrl.Call(m, "GetActivityScheduledEvent", arg0, arg1) ret0, _ := ret[0].(*v13.HistoryEvent) ret1, _ := ret[1].(error) return ret0, ret1 } // GetActivityScheduledEvent indicates an expected call of GetActivityScheduledEvent. -func (mr *MockMutableStateMockRecorder) GetActivityScheduledEvent(arg0 interface{}) *gomock.Call { +func (mr *MockMutableStateMockRecorder) GetActivityScheduledEvent(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActivityScheduledEvent", reflect.TypeOf((*MockMutableState)(nil).GetActivityScheduledEvent), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActivityScheduledEvent", reflect.TypeOf((*MockMutableState)(nil).GetActivityScheduledEvent), arg0, arg1) } // GetChildExecutionInfo mocks base method. @@ -983,33 +984,33 @@ func (mr *MockMutableStateMockRecorder) GetChildExecutionInfo(arg0 interface{}) } // GetChildExecutionInitiatedEvent mocks base method. -func (m *MockMutableState) GetChildExecutionInitiatedEvent(arg0 int64) (*v13.HistoryEvent, error) { +func (m *MockMutableState) GetChildExecutionInitiatedEvent(arg0 context.Context, arg1 int64) (*v13.HistoryEvent, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetChildExecutionInitiatedEvent", arg0) + ret := m.ctrl.Call(m, "GetChildExecutionInitiatedEvent", arg0, arg1) ret0, _ := ret[0].(*v13.HistoryEvent) ret1, _ := ret[1].(error) return ret0, ret1 } // GetChildExecutionInitiatedEvent indicates an expected call of GetChildExecutionInitiatedEvent. -func (mr *MockMutableStateMockRecorder) GetChildExecutionInitiatedEvent(arg0 interface{}) *gomock.Call { +func (mr *MockMutableStateMockRecorder) GetChildExecutionInitiatedEvent(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChildExecutionInitiatedEvent", reflect.TypeOf((*MockMutableState)(nil).GetChildExecutionInitiatedEvent), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChildExecutionInitiatedEvent", reflect.TypeOf((*MockMutableState)(nil).GetChildExecutionInitiatedEvent), arg0, arg1) } // GetCompletionEvent mocks base method. -func (m *MockMutableState) GetCompletionEvent() (*v13.HistoryEvent, error) { +func (m *MockMutableState) GetCompletionEvent(arg0 context.Context) (*v13.HistoryEvent, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetCompletionEvent") + ret := m.ctrl.Call(m, "GetCompletionEvent", arg0) ret0, _ := ret[0].(*v13.HistoryEvent) ret1, _ := ret[1].(error) return ret0, ret1 } // GetCompletionEvent indicates an expected call of GetCompletionEvent. -func (mr *MockMutableStateMockRecorder) GetCompletionEvent() *gomock.Call { +func (mr *MockMutableStateMockRecorder) GetCompletionEvent(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCompletionEvent", reflect.TypeOf((*MockMutableState)(nil).GetCompletionEvent)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCompletionEvent", reflect.TypeOf((*MockMutableState)(nil).GetCompletionEvent), arg0) } // GetCronBackoffDuration mocks base method. @@ -1315,18 +1316,18 @@ func (mr *MockMutableStateMockRecorder) GetRetryBackoffDuration(failure interfac } // GetSignalExternalInitiatedEvent mocks base method. -func (m *MockMutableState) GetSignalExternalInitiatedEvent(arg0 int64) (*v13.HistoryEvent, error) { +func (m *MockMutableState) GetSignalExternalInitiatedEvent(arg0 context.Context, arg1 int64) (*v13.HistoryEvent, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSignalExternalInitiatedEvent", arg0) + ret := m.ctrl.Call(m, "GetSignalExternalInitiatedEvent", arg0, arg1) ret0, _ := ret[0].(*v13.HistoryEvent) ret1, _ := ret[1].(error) return ret0, ret1 } // GetSignalExternalInitiatedEvent indicates an expected call of GetSignalExternalInitiatedEvent. -func (mr *MockMutableStateMockRecorder) GetSignalExternalInitiatedEvent(arg0 interface{}) *gomock.Call { +func (mr *MockMutableStateMockRecorder) GetSignalExternalInitiatedEvent(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSignalExternalInitiatedEvent", reflect.TypeOf((*MockMutableState)(nil).GetSignalExternalInitiatedEvent), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSignalExternalInitiatedEvent", reflect.TypeOf((*MockMutableState)(nil).GetSignalExternalInitiatedEvent), arg0, arg1) } // GetSignalInfo mocks base method. @@ -1345,18 +1346,18 @@ func (mr *MockMutableStateMockRecorder) GetSignalInfo(arg0 interface{}) *gomock. } // GetStartEvent mocks base method. -func (m *MockMutableState) GetStartEvent() (*v13.HistoryEvent, error) { +func (m *MockMutableState) GetStartEvent(arg0 context.Context) (*v13.HistoryEvent, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetStartEvent") + ret := m.ctrl.Call(m, "GetStartEvent", arg0) ret0, _ := ret[0].(*v13.HistoryEvent) ret1, _ := ret[1].(error) return ret0, ret1 } // GetStartEvent indicates an expected call of GetStartEvent. -func (mr *MockMutableStateMockRecorder) GetStartEvent() *gomock.Call { +func (mr *MockMutableStateMockRecorder) GetStartEvent(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStartEvent", reflect.TypeOf((*MockMutableState)(nil).GetStartEvent)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStartEvent", reflect.TypeOf((*MockMutableState)(nil).GetStartEvent), arg0) } // GetStartVersion mocks base method. diff --git a/service/history/workflow/retry.go b/service/history/workflow/retry.go index 3355f2e160c..6ebc44472d1 100644 --- a/service/history/workflow/retry.go +++ b/service/history/workflow/retry.go @@ -25,6 +25,7 @@ package workflow import ( + "context" "math" "time" @@ -163,6 +164,7 @@ func matchNonRetryableTypes( // Helpers for creating new retry/cron workflows: func SetupNewWorkflowForRetryOrCron( + ctx context.Context, previousMutableState MutableState, newMutableState MutableState, newRunID string, diff --git a/service/history/workflow/task_refresher.go b/service/history/workflow/task_refresher.go index bd84ba2ca26..b8d62ddcfa2 100644 --- a/service/history/workflow/task_refresher.go +++ b/service/history/workflow/task_refresher.go @@ -27,6 +27,7 @@ package workflow import ( + "context" "time" enumspb "go.temporal.io/api/enums/v1" @@ -45,7 +46,7 @@ import ( type ( TaskRefresher interface { - RefreshTasks(now time.Time, mutableState MutableState) error + RefreshTasks(ctx context.Context, now time.Time, mutableState MutableState) error } TaskRefresherImpl struct { @@ -75,6 +76,7 @@ func NewTaskRefresher( } func (r *TaskRefresherImpl) RefreshTasks( + ctx context.Context, now time.Time, mutableState MutableState, ) error { @@ -85,6 +87,7 @@ func (r *TaskRefresherImpl) RefreshTasks( ) if err := r.refreshTasksForWorkflowStart( + ctx, now, mutableState, taskGenerator, @@ -101,6 +104,7 @@ func (r *TaskRefresherImpl) RefreshTasks( } if err := r.refreshTasksForRecordWorkflowStarted( + ctx, now, mutableState, taskGenerator, @@ -117,6 +121,7 @@ func (r *TaskRefresherImpl) RefreshTasks( } if err := r.refreshTasksForActivity( + ctx, now, mutableState, taskGenerator, @@ -133,6 +138,7 @@ func (r *TaskRefresherImpl) RefreshTasks( } if err := r.refreshTasksForChildWorkflow( + ctx, now, mutableState, taskGenerator, @@ -141,6 +147,7 @@ func (r *TaskRefresherImpl) RefreshTasks( } if err := r.refreshTasksForRequestCancelExternalWorkflow( + ctx, now, mutableState, taskGenerator, @@ -149,6 +156,7 @@ func (r *TaskRefresherImpl) RefreshTasks( } if err := r.refreshTasksForSignalExternalWorkflow( + ctx, now, mutableState, taskGenerator, @@ -170,12 +178,13 @@ func (r *TaskRefresherImpl) RefreshTasks( } func (r *TaskRefresherImpl) refreshTasksForWorkflowStart( + ctx context.Context, now time.Time, mutableState MutableState, taskGenerator TaskGenerator, ) error { - startEvent, err := mutableState.GetStartEvent() + startEvent, err := mutableState.GetStartEvent(ctx) if err != nil { return err } @@ -218,12 +227,13 @@ func (r *TaskRefresherImpl) refreshTasksForWorkflowClose( } func (r *TaskRefresherImpl) refreshTasksForRecordWorkflowStarted( + ctx context.Context, now time.Time, mutableState MutableState, taskGenerator TaskGenerator, ) error { - startEvent, err := mutableState.GetStartEvent() + startEvent, err := mutableState.GetStartEvent(ctx) if err != nil { return err } @@ -272,6 +282,7 @@ func (r *TaskRefresherImpl) refreshWorkflowTaskTasks( } func (r *TaskRefresherImpl) refreshTasksForActivity( + ctx context.Context, now time.Time, mutableState MutableState, taskGenerator TaskGenerator, @@ -303,6 +314,7 @@ Loop: } scheduleEvent, err := r.eventsCache.GetEvent( + ctx, events.EventKey{ NamespaceID: namespace.ID(executionInfo.NamespaceId), WorkflowID: executionInfo.WorkflowId, @@ -366,6 +378,7 @@ func (r *TaskRefresherImpl) refreshTasksForTimer( } func (r *TaskRefresherImpl) refreshTasksForChildWorkflow( + ctx context.Context, now time.Time, mutableState MutableState, taskGenerator TaskGenerator, @@ -387,6 +400,7 @@ Loop: } scheduleEvent, err := r.eventsCache.GetEvent( + ctx, events.EventKey{ NamespaceID: namespace.ID(executionInfo.NamespaceId), WorkflowID: executionInfo.WorkflowId, @@ -413,6 +427,7 @@ Loop: } func (r *TaskRefresherImpl) refreshTasksForRequestCancelExternalWorkflow( + ctx context.Context, now time.Time, mutableState MutableState, taskGenerator TaskGenerator, @@ -429,6 +444,7 @@ func (r *TaskRefresherImpl) refreshTasksForRequestCancelExternalWorkflow( for _, requestCancelInfo := range pendingRequestCancelInfos { initiateEvent, err := r.eventsCache.GetEvent( + ctx, events.EventKey{ NamespaceID: namespace.ID(executionInfo.NamespaceId), WorkflowID: executionInfo.WorkflowId, @@ -455,6 +471,7 @@ func (r *TaskRefresherImpl) refreshTasksForRequestCancelExternalWorkflow( } func (r *TaskRefresherImpl) refreshTasksForSignalExternalWorkflow( + ctx context.Context, now time.Time, mutableState MutableState, taskGenerator TaskGenerator, @@ -471,6 +488,7 @@ func (r *TaskRefresherImpl) refreshTasksForSignalExternalWorkflow( for _, signalInfo := range pendingSignalInfos { initiateEvent, err := r.eventsCache.GetEvent( + ctx, events.EventKey{ NamespaceID: namespace.ID(executionInfo.NamespaceId), WorkflowID: executionInfo.WorkflowId, diff --git a/service/history/workflow/task_refresher_mock.go b/service/history/workflow/task_refresher_mock.go index a76419075f6..b4557c06d0c 100644 --- a/service/history/workflow/task_refresher_mock.go +++ b/service/history/workflow/task_refresher_mock.go @@ -29,6 +29,7 @@ package workflow import ( + context "context" reflect "reflect" time "time" @@ -59,15 +60,15 @@ func (m *MockTaskRefresher) EXPECT() *MockTaskRefresherMockRecorder { } // RefreshTasks mocks base method. -func (m *MockTaskRefresher) RefreshTasks(now time.Time, mutableState MutableState) error { +func (m *MockTaskRefresher) RefreshTasks(ctx context.Context, now time.Time, mutableState MutableState) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RefreshTasks", now, mutableState) + ret := m.ctrl.Call(m, "RefreshTasks", ctx, now, mutableState) ret0, _ := ret[0].(error) return ret0 } // RefreshTasks indicates an expected call of RefreshTasks. -func (mr *MockTaskRefresherMockRecorder) RefreshTasks(now, mutableState interface{}) *gomock.Call { +func (mr *MockTaskRefresherMockRecorder) RefreshTasks(ctx, now, mutableState interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshTasks", reflect.TypeOf((*MockTaskRefresher)(nil).RefreshTasks), now, mutableState) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshTasks", reflect.TypeOf((*MockTaskRefresher)(nil).RefreshTasks), ctx, now, mutableState) } diff --git a/service/history/workflow/transaction.go b/service/history/workflow/transaction.go index 9dd4fdedd8c..e160bf6a999 100644 --- a/service/history/workflow/transaction.go +++ b/service/history/workflow/transaction.go @@ -25,6 +25,8 @@ package workflow import ( + "context" + "go.temporal.io/server/common/persistence" ) @@ -32,6 +34,7 @@ import ( type ( Transaction interface { CreateWorkflowExecution( + ctx context.Context, createMode persistence.CreateWorkflowMode, newWorkflowSnapshot *persistence.WorkflowSnapshot, newWorkflowEventsSeq []*persistence.WorkflowEvents, @@ -39,6 +42,7 @@ type ( ) (int64, error) ConflictResolveWorkflowExecution( + ctx context.Context, conflictResolveMode persistence.ConflictResolveWorkflowMode, resetWorkflowSnapshot *persistence.WorkflowSnapshot, resetWorkflowEventsSeq []*persistence.WorkflowEvents, @@ -50,6 +54,7 @@ type ( ) (int64, int64, int64, error) UpdateWorkflowExecution( + ctx context.Context, updateMode persistence.UpdateWorkflowMode, currentWorkflowMutation *persistence.WorkflowMutation, currentWorkflowEventsSeq []*persistence.WorkflowEvents, @@ -59,6 +64,7 @@ type ( ) (int64, int64, error) SetWorkflowExecution( + ctx context.Context, workflowSnapshot *persistence.WorkflowSnapshot, clusterName string, ) error diff --git a/service/history/workflow/transaction_impl.go b/service/history/workflow/transaction_impl.go index 67845b39582..83bdbe3e714 100644 --- a/service/history/workflow/transaction_impl.go +++ b/service/history/workflow/transaction_impl.go @@ -25,6 +25,8 @@ package workflow import ( + "context" + commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" "go.temporal.io/api/serviceerror" @@ -66,6 +68,7 @@ func NewTransaction( } func (t *TransactionImpl) CreateWorkflowExecution( + ctx context.Context, createMode persistence.CreateWorkflowMode, newWorkflowSnapshot *persistence.WorkflowSnapshot, newWorkflowEventsSeq []*persistence.WorkflowEvents, @@ -77,7 +80,7 @@ func (t *TransactionImpl) CreateWorkflowExecution( return 0, err } - resp, err := createWorkflowExecutionWithRetry(t.shard, &persistence.CreateWorkflowExecutionRequest{ + resp, err := createWorkflowExecutionWithRetry(ctx, t.shard, &persistence.CreateWorkflowExecutionRequest{ ShardID: t.shard.GetShardID(), // RangeID , this is set by shard context Mode: createMode, @@ -99,6 +102,7 @@ func (t *TransactionImpl) CreateWorkflowExecution( } func (t *TransactionImpl) ConflictResolveWorkflowExecution( + ctx context.Context, conflictResolveMode persistence.ConflictResolveWorkflowMode, resetWorkflowSnapshot *persistence.WorkflowSnapshot, resetWorkflowEventsSeq []*persistence.WorkflowEvents, @@ -114,7 +118,7 @@ func (t *TransactionImpl) ConflictResolveWorkflowExecution( return 0, 0, 0, err } - resp, err := conflictResolveWorkflowExecutionWithRetry(t.shard, &persistence.ConflictResolveWorkflowExecutionRequest{ + resp, err := conflictResolveWorkflowExecutionWithRetry(ctx, t.shard, &persistence.ConflictResolveWorkflowExecutionRequest{ ShardID: t.shard.GetShardID(), // RangeID , this is set by shard context Mode: conflictResolveMode, @@ -156,6 +160,7 @@ func (t *TransactionImpl) ConflictResolveWorkflowExecution( } func (t *TransactionImpl) UpdateWorkflowExecution( + ctx context.Context, updateMode persistence.UpdateWorkflowMode, currentWorkflowMutation *persistence.WorkflowMutation, currentWorkflowEventsSeq []*persistence.WorkflowEvents, @@ -168,7 +173,7 @@ func (t *TransactionImpl) UpdateWorkflowExecution( if err != nil { return 0, 0, err } - resp, err := updateWorkflowExecutionWithRetry(t.shard, &persistence.UpdateWorkflowExecutionRequest{ + resp, err := updateWorkflowExecutionWithRetry(ctx, t.shard, &persistence.UpdateWorkflowExecutionRequest{ ShardID: t.shard.GetShardID(), // RangeID , this is set by shard context Mode: updateMode, @@ -200,6 +205,7 @@ func (t *TransactionImpl) UpdateWorkflowExecution( } func (t *TransactionImpl) SetWorkflowExecution( + ctx context.Context, workflowSnapshot *persistence.WorkflowSnapshot, clusterName string, ) error { @@ -208,7 +214,7 @@ func (t *TransactionImpl) SetWorkflowExecution( if err != nil { return err } - _, err = setWorkflowExecutionWithRetry(t.shard, &persistence.SetWorkflowExecutionRequest{ + _, err = setWorkflowExecutionWithRetry(ctx, t.shard, &persistence.SetWorkflowExecutionRequest{ ShardID: t.shard.GetShardID(), // RangeID , this is set by shard context SetWorkflowSnapshot: *workflowSnapshot, @@ -224,6 +230,7 @@ func (t *TransactionImpl) SetWorkflowExecution( } func PersistWorkflowEvents( + ctx context.Context, shard shard.Context, workflowEvents *persistence.WorkflowEvents, ) (int64, error) { @@ -234,12 +241,13 @@ func PersistWorkflowEvents( firstEventID := workflowEvents.Events[0].EventId if firstEventID == common.FirstEventID { - return persistFirstWorkflowEvents(shard, workflowEvents) + return persistFirstWorkflowEvents(ctx, shard, workflowEvents) } - return persistNonFirstWorkflowEvents(shard, workflowEvents) + return persistNonFirstWorkflowEvents(ctx, shard, workflowEvents) } func persistFirstWorkflowEvents( + ctx context.Context, shard shard.Context, workflowEvents *persistence.WorkflowEvents, ) (int64, error) { @@ -257,6 +265,7 @@ func persistFirstWorkflowEvents( txnID := workflowEvents.TxnID size, err := appendHistoryV2EventsWithRetry( + ctx, shard, namespaceID, execution, @@ -273,6 +282,7 @@ func persistFirstWorkflowEvents( } func persistNonFirstWorkflowEvents( + ctx context.Context, shard shard.Context, workflowEvents *persistence.WorkflowEvents, ) (int64, error) { @@ -292,6 +302,7 @@ func persistNonFirstWorkflowEvents( txnID := workflowEvents.TxnID size, err := appendHistoryV2EventsWithRetry( + ctx, shard, namespaceID, execution, @@ -307,6 +318,7 @@ func persistNonFirstWorkflowEvents( } func appendHistoryV2EventsWithRetry( + ctx context.Context, shard shard.Context, namespaceID namespace.ID, execution commonpb.WorkflowExecution, @@ -316,7 +328,7 @@ func appendHistoryV2EventsWithRetry( resp := 0 op := func() error { var err error - resp, err = shard.AppendHistoryEvents(request, namespaceID, execution) + resp, err = shard.AppendHistoryEvents(ctx, request, namespaceID, execution) return err } @@ -329,6 +341,7 @@ func appendHistoryV2EventsWithRetry( } func createWorkflowExecutionWithRetry( + ctx context.Context, shard shard.Context, request *persistence.CreateWorkflowExecutionRequest, ) (*persistence.CreateWorkflowExecutionResponse, error) { @@ -336,7 +349,7 @@ func createWorkflowExecutionWithRetry( var resp *persistence.CreateWorkflowExecutionResponse op := func() error { var err error - resp, err = shard.CreateWorkflowExecution(request) + resp, err = shard.CreateWorkflowExecution(ctx, request) return err } @@ -377,6 +390,7 @@ func createWorkflowExecutionWithRetry( } func conflictResolveWorkflowExecutionWithRetry( + ctx context.Context, shard shard.Context, request *persistence.ConflictResolveWorkflowExecutionRequest, ) (*persistence.ConflictResolveWorkflowExecutionResponse, error) { @@ -384,7 +398,7 @@ func conflictResolveWorkflowExecutionWithRetry( var resp *persistence.ConflictResolveWorkflowExecutionResponse op := func() error { var err error - resp, err = shard.ConflictResolveWorkflowExecution(request) + resp, err = shard.ConflictResolveWorkflowExecution(ctx, request) return err } @@ -441,7 +455,7 @@ func getWorkflowExecutionWithRetry( var resp *persistence.GetWorkflowExecutionResponse op := func() error { var err error - resp, err = shard.GetExecutionManager().GetWorkflowExecution(request) + resp, err = shard.GetExecutionManager().GetWorkflowExecution(context.TODO(), request) return err } @@ -480,6 +494,7 @@ func getWorkflowExecutionWithRetry( } func updateWorkflowExecutionWithRetry( + ctx context.Context, shard shard.Context, request *persistence.UpdateWorkflowExecutionRequest, ) (*persistence.UpdateWorkflowExecutionResponse, error) { @@ -487,7 +502,7 @@ func updateWorkflowExecutionWithRetry( var resp *persistence.UpdateWorkflowExecutionResponse var err error op := func() error { - resp, err = shard.UpdateWorkflowExecution(request) + resp, err = shard.UpdateWorkflowExecution(ctx, request) return err } @@ -536,6 +551,7 @@ func updateWorkflowExecutionWithRetry( } func setWorkflowExecutionWithRetry( + ctx context.Context, shard shard.Context, request *persistence.SetWorkflowExecutionRequest, ) (*persistence.SetWorkflowExecutionResponse, error) { @@ -543,7 +559,7 @@ func setWorkflowExecutionWithRetry( var resp *persistence.SetWorkflowExecutionResponse var err error op := func() error { - resp, err = shard.SetWorkflowExecution(request) + resp, err = shard.SetWorkflowExecution(ctx, request) return err } diff --git a/service/history/workflow/transaction_mock.go b/service/history/workflow/transaction_mock.go index a600dcae76a..13221699ad7 100644 --- a/service/history/workflow/transaction_mock.go +++ b/service/history/workflow/transaction_mock.go @@ -29,6 +29,7 @@ package workflow import ( + context "context" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -59,9 +60,9 @@ func (m *MockTransaction) EXPECT() *MockTransactionMockRecorder { } // ConflictResolveWorkflowExecution mocks base method. -func (m *MockTransaction) ConflictResolveWorkflowExecution(conflictResolveMode persistence.ConflictResolveWorkflowMode, resetWorkflowSnapshot *persistence.WorkflowSnapshot, resetWorkflowEventsSeq []*persistence.WorkflowEvents, newWorkflowSnapshot *persistence.WorkflowSnapshot, newWorkflowEventsSeq []*persistence.WorkflowEvents, currentWorkflowMutation *persistence.WorkflowMutation, currentWorkflowEventsSeq []*persistence.WorkflowEvents, clusterName string) (int64, int64, int64, error) { +func (m *MockTransaction) ConflictResolveWorkflowExecution(ctx context.Context, conflictResolveMode persistence.ConflictResolveWorkflowMode, resetWorkflowSnapshot *persistence.WorkflowSnapshot, resetWorkflowEventsSeq []*persistence.WorkflowEvents, newWorkflowSnapshot *persistence.WorkflowSnapshot, newWorkflowEventsSeq []*persistence.WorkflowEvents, currentWorkflowMutation *persistence.WorkflowMutation, currentWorkflowEventsSeq []*persistence.WorkflowEvents, clusterName string) (int64, int64, int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ConflictResolveWorkflowExecution", conflictResolveMode, resetWorkflowSnapshot, resetWorkflowEventsSeq, newWorkflowSnapshot, newWorkflowEventsSeq, currentWorkflowMutation, currentWorkflowEventsSeq, clusterName) + ret := m.ctrl.Call(m, "ConflictResolveWorkflowExecution", ctx, conflictResolveMode, resetWorkflowSnapshot, resetWorkflowEventsSeq, newWorkflowSnapshot, newWorkflowEventsSeq, currentWorkflowMutation, currentWorkflowEventsSeq, clusterName) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(int64) ret2, _ := ret[2].(int64) @@ -70,44 +71,44 @@ func (m *MockTransaction) ConflictResolveWorkflowExecution(conflictResolveMode p } // ConflictResolveWorkflowExecution indicates an expected call of ConflictResolveWorkflowExecution. -func (mr *MockTransactionMockRecorder) ConflictResolveWorkflowExecution(conflictResolveMode, resetWorkflowSnapshot, resetWorkflowEventsSeq, newWorkflowSnapshot, newWorkflowEventsSeq, currentWorkflowMutation, currentWorkflowEventsSeq, clusterName interface{}) *gomock.Call { +func (mr *MockTransactionMockRecorder) ConflictResolveWorkflowExecution(ctx, conflictResolveMode, resetWorkflowSnapshot, resetWorkflowEventsSeq, newWorkflowSnapshot, newWorkflowEventsSeq, currentWorkflowMutation, currentWorkflowEventsSeq, clusterName interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConflictResolveWorkflowExecution", reflect.TypeOf((*MockTransaction)(nil).ConflictResolveWorkflowExecution), conflictResolveMode, resetWorkflowSnapshot, resetWorkflowEventsSeq, newWorkflowSnapshot, newWorkflowEventsSeq, currentWorkflowMutation, currentWorkflowEventsSeq, clusterName) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConflictResolveWorkflowExecution", reflect.TypeOf((*MockTransaction)(nil).ConflictResolveWorkflowExecution), ctx, conflictResolveMode, resetWorkflowSnapshot, resetWorkflowEventsSeq, newWorkflowSnapshot, newWorkflowEventsSeq, currentWorkflowMutation, currentWorkflowEventsSeq, clusterName) } // CreateWorkflowExecution mocks base method. -func (m *MockTransaction) CreateWorkflowExecution(createMode persistence.CreateWorkflowMode, newWorkflowSnapshot *persistence.WorkflowSnapshot, newWorkflowEventsSeq []*persistence.WorkflowEvents, clusterName string) (int64, error) { +func (m *MockTransaction) CreateWorkflowExecution(ctx context.Context, createMode persistence.CreateWorkflowMode, newWorkflowSnapshot *persistence.WorkflowSnapshot, newWorkflowEventsSeq []*persistence.WorkflowEvents, clusterName string) (int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateWorkflowExecution", createMode, newWorkflowSnapshot, newWorkflowEventsSeq, clusterName) + ret := m.ctrl.Call(m, "CreateWorkflowExecution", ctx, createMode, newWorkflowSnapshot, newWorkflowEventsSeq, clusterName) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateWorkflowExecution indicates an expected call of CreateWorkflowExecution. -func (mr *MockTransactionMockRecorder) CreateWorkflowExecution(createMode, newWorkflowSnapshot, newWorkflowEventsSeq, clusterName interface{}) *gomock.Call { +func (mr *MockTransactionMockRecorder) CreateWorkflowExecution(ctx, createMode, newWorkflowSnapshot, newWorkflowEventsSeq, clusterName interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateWorkflowExecution", reflect.TypeOf((*MockTransaction)(nil).CreateWorkflowExecution), createMode, newWorkflowSnapshot, newWorkflowEventsSeq, clusterName) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateWorkflowExecution", reflect.TypeOf((*MockTransaction)(nil).CreateWorkflowExecution), ctx, createMode, newWorkflowSnapshot, newWorkflowEventsSeq, clusterName) } // SetWorkflowExecution mocks base method. -func (m *MockTransaction) SetWorkflowExecution(workflowSnapshot *persistence.WorkflowSnapshot, clusterName string) error { +func (m *MockTransaction) SetWorkflowExecution(ctx context.Context, workflowSnapshot *persistence.WorkflowSnapshot, clusterName string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetWorkflowExecution", workflowSnapshot, clusterName) + ret := m.ctrl.Call(m, "SetWorkflowExecution", ctx, workflowSnapshot, clusterName) ret0, _ := ret[0].(error) return ret0 } // SetWorkflowExecution indicates an expected call of SetWorkflowExecution. -func (mr *MockTransactionMockRecorder) SetWorkflowExecution(workflowSnapshot, clusterName interface{}) *gomock.Call { +func (mr *MockTransactionMockRecorder) SetWorkflowExecution(ctx, workflowSnapshot, clusterName interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWorkflowExecution", reflect.TypeOf((*MockTransaction)(nil).SetWorkflowExecution), workflowSnapshot, clusterName) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWorkflowExecution", reflect.TypeOf((*MockTransaction)(nil).SetWorkflowExecution), ctx, workflowSnapshot, clusterName) } // UpdateWorkflowExecution mocks base method. -func (m *MockTransaction) UpdateWorkflowExecution(updateMode persistence.UpdateWorkflowMode, currentWorkflowMutation *persistence.WorkflowMutation, currentWorkflowEventsSeq []*persistence.WorkflowEvents, newWorkflowSnapshot *persistence.WorkflowSnapshot, newWorkflowEventsSeq []*persistence.WorkflowEvents, clusterName string) (int64, int64, error) { +func (m *MockTransaction) UpdateWorkflowExecution(ctx context.Context, updateMode persistence.UpdateWorkflowMode, currentWorkflowMutation *persistence.WorkflowMutation, currentWorkflowEventsSeq []*persistence.WorkflowEvents, newWorkflowSnapshot *persistence.WorkflowSnapshot, newWorkflowEventsSeq []*persistence.WorkflowEvents, clusterName string) (int64, int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateWorkflowExecution", updateMode, currentWorkflowMutation, currentWorkflowEventsSeq, newWorkflowSnapshot, newWorkflowEventsSeq, clusterName) + ret := m.ctrl.Call(m, "UpdateWorkflowExecution", ctx, updateMode, currentWorkflowMutation, currentWorkflowEventsSeq, newWorkflowSnapshot, newWorkflowEventsSeq, clusterName) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(int64) ret2, _ := ret[2].(error) @@ -115,7 +116,7 @@ func (m *MockTransaction) UpdateWorkflowExecution(updateMode persistence.UpdateW } // UpdateWorkflowExecution indicates an expected call of UpdateWorkflowExecution. -func (mr *MockTransactionMockRecorder) UpdateWorkflowExecution(updateMode, currentWorkflowMutation, currentWorkflowEventsSeq, newWorkflowSnapshot, newWorkflowEventsSeq, clusterName interface{}) *gomock.Call { +func (mr *MockTransactionMockRecorder) UpdateWorkflowExecution(ctx, updateMode, currentWorkflowMutation, currentWorkflowEventsSeq, newWorkflowSnapshot, newWorkflowEventsSeq, clusterName interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkflowExecution", reflect.TypeOf((*MockTransaction)(nil).UpdateWorkflowExecution), updateMode, currentWorkflowMutation, currentWorkflowEventsSeq, newWorkflowSnapshot, newWorkflowEventsSeq, clusterName) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkflowExecution", reflect.TypeOf((*MockTransaction)(nil).UpdateWorkflowExecution), ctx, updateMode, currentWorkflowMutation, currentWorkflowEventsSeq, newWorkflowSnapshot, newWorkflowEventsSeq, clusterName) } diff --git a/service/history/workflow/transaction_test.go b/service/history/workflow/transaction_test.go index 7eac056ae70..ac281451aae 100644 --- a/service/history/workflow/transaction_test.go +++ b/service/history/workflow/transaction_test.go @@ -25,6 +25,7 @@ package workflow import ( + "context" "errors" "testing" @@ -111,10 +112,11 @@ func (s *transactionSuite) TestCreateWorkflowExecution_NotifyTaskWhenFailed() { timeoutErr := &persistence.TimeoutError{} s.True(operationPossiblySucceeded(timeoutErr)) - s.mockShard.EXPECT().CreateWorkflowExecution(gomock.Any()).Return(nil, timeoutErr) + s.mockShard.EXPECT().CreateWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, timeoutErr) s.setupMockForTaskNotification() _, err := s.transaction.CreateWorkflowExecution( + context.Background(), persistence.CreateWorkflowModeBrandNew, &persistence.WorkflowSnapshot{ ExecutionInfo: &persistencespb.WorkflowExecutionInfo{ @@ -135,11 +137,12 @@ func (s *transactionSuite) TestUpdateWorkflowExecution_NotifyTaskWhenFailed() { timeoutErr := &persistence.TimeoutError{} s.True(operationPossiblySucceeded(timeoutErr)) - s.mockShard.EXPECT().UpdateWorkflowExecution(gomock.Any()).Return(nil, timeoutErr) + s.mockShard.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, timeoutErr) s.setupMockForTaskNotification() // for current workflow mutation s.setupMockForTaskNotification() // for new workflow snapshot _, _, err := s.transaction.UpdateWorkflowExecution( + context.Background(), persistence.UpdateWorkflowModeUpdateCurrent, &persistence.WorkflowMutation{ ExecutionInfo: &persistencespb.WorkflowExecutionInfo{ @@ -162,12 +165,13 @@ func (s *transactionSuite) TestConflictResolveWorkflowExecution_NotifyTaskWhenFa timeoutErr := &persistence.TimeoutError{} s.True(operationPossiblySucceeded(timeoutErr)) - s.mockShard.EXPECT().ConflictResolveWorkflowExecution(gomock.Any()).Return(nil, timeoutErr) + s.mockShard.EXPECT().ConflictResolveWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, timeoutErr) s.setupMockForTaskNotification() // for reset workflow snapshot s.setupMockForTaskNotification() // for new workflow snapshot s.setupMockForTaskNotification() // for current workflow mutation _, _, _, err := s.transaction.ConflictResolveWorkflowExecution( + context.Background(), persistence.ConflictResolveWorkflowModeUpdateCurrent, &persistence.WorkflowSnapshot{ ExecutionInfo: &persistencespb.WorkflowExecutionInfo{ diff --git a/service/history/workflowExecutionUtil.go b/service/history/workflowExecutionUtil.go index 8469dd21e1e..0b8d4f37283 100644 --- a/service/history/workflowExecutionUtil.go +++ b/service/history/workflowExecutionUtil.go @@ -25,13 +25,15 @@ package history import ( + "context" + "go.temporal.io/server/service/history/workflow" ) type workflowContext interface { getContext() workflow.Context getMutableState() workflow.MutableState - reloadMutableState() (workflow.MutableState, error) + reloadMutableState(context.Context) (workflow.MutableState, error) getReleaseFn() workflow.ReleaseCacheFunc getWorkflowID() string getRunID() string @@ -67,8 +69,10 @@ func (w *workflowContextImpl) getMutableState() workflow.MutableState { return w.mutableState } -func (w *workflowContextImpl) reloadMutableState() (workflow.MutableState, error) { - mutableState, err := w.getContext().LoadWorkflowExecution() +func (w *workflowContextImpl) reloadMutableState( + ctx context.Context, +) (workflow.MutableState, error) { + mutableState, err := w.getContext().LoadWorkflowExecution(ctx) if err != nil { return nil, err } diff --git a/service/history/workflowRebuilder.go b/service/history/workflowRebuilder.go index d616c9aa4c6..b6266c56756 100644 --- a/service/history/workflowRebuilder.go +++ b/service/history/workflowRebuilder.go @@ -103,7 +103,7 @@ func (r *workflowRebuilderImpl) rebuild( context.Clear() }() - msRecord, dbRecordVersion, err := r.getMutableState(workflowKey) + msRecord, dbRecordVersion, err := r.getMutableState(ctx, workflowKey) if err != nil { return err } @@ -127,7 +127,7 @@ func (r *workflowRebuilderImpl) rebuild( if err != nil { return err } - return r.persistToDB(rebuildMutableState, rebuildHistorySize) + return r.persistToDB(ctx, rebuildMutableState, rebuildHistorySize) } func (r *workflowRebuilderImpl) replayResetWorkflow( @@ -162,6 +162,7 @@ func (r *workflowRebuilderImpl) replayResetWorkflow( } func (r *workflowRebuilderImpl) persistToDB( + ctx context.Context, mutableState workflow.MutableState, historySize int64, ) error { @@ -181,6 +182,7 @@ func (r *workflowRebuilderImpl) persistToDB( HistorySize: historySize, } if err := r.transaction.SetWorkflowExecution( + ctx, resetWorkflowSnapshot, mutableState.GetNamespaceEntry().ActiveClusterName(), ); err != nil { @@ -190,9 +192,10 @@ func (r *workflowRebuilderImpl) persistToDB( } func (r *workflowRebuilderImpl) getMutableState( + ctx context.Context, workflowKey definition.WorkflowKey, ) (*persistencespb.WorkflowMutableState, int64, error) { - record, err := r.executionMgr.GetWorkflowExecution(&persistence.GetWorkflowExecutionRequest{ + record, err := r.executionMgr.GetWorkflowExecution(ctx, &persistence.GetWorkflowExecutionRequest{ ShardID: r.shard.GetShardID(), NamespaceID: workflowKey.NamespaceID, WorkflowID: workflowKey.WorkflowID, diff --git a/service/history/workflowResetter.go b/service/history/workflowResetter.go index facae7e3113..ff7c71288e1 100644 --- a/service/history/workflowResetter.go +++ b/service/history/workflowResetter.go @@ -231,6 +231,7 @@ func (r *workflowResetterImpl) resetWorkflow( } return r.persistToDB( + ctx, currentWorkflow, currentWorkflowMutation, currentWorkflowEventsSeq, @@ -341,6 +342,7 @@ func (r *workflowResetterImpl) reapplyEventsToResetWorkflow( } func (r *workflowResetterImpl) persistToDB( + ctx context.Context, currentWorkflow nDCWorkflow, currentWorkflowMutation *persistence.WorkflowMutation, currentWorkflowEventsSeq []*persistence.WorkflowEvents, @@ -361,6 +363,7 @@ func (r *workflowResetterImpl) persistToDB( if currentWorkflowMutation != nil { if currentWorkflowSizeDiff, resetWorkflowSizeDiff, err := r.transaction.UpdateWorkflowExecution( + ctx, persistence.UpdateWorkflowModeUpdateCurrent, currentWorkflowMutation, currentWorkflowEventsSeq, @@ -384,6 +387,7 @@ func (r *workflowResetterImpl) persistToDB( } return resetWorkflow.getContext().CreateWorkflowExecution( + ctx, now, persistence.CreateWorkflowModeWorkflowIDReuse, currentRunID, @@ -407,6 +411,7 @@ func (r *workflowResetterImpl) replayResetWorkflow( ) (nDCWorkflow, error) { resetBranchToken, err := r.forkAndGenerateBranchToken( + ctx, namespaceID, workflowID, baseBranchToken, @@ -545,6 +550,7 @@ func (r *workflowResetterImpl) failInflightActivity( } func (r *workflowResetterImpl) forkAndGenerateBranchToken( + ctx context.Context, namespaceID namespace.ID, workflowID string, forkBranchToken []byte, @@ -553,7 +559,7 @@ func (r *workflowResetterImpl) forkAndGenerateBranchToken( ) ([]byte, error) { // fork a new history branch shardID := r.shard.GetShardID() - resp, err := r.executionMgr.ForkHistoryBranch(&persistence.ForkHistoryBranchRequest{ + resp, err := r.executionMgr.ForkHistoryBranch(ctx, &persistence.ForkHistoryBranchRequest{ ForkBranchToken: forkBranchToken, ForkNodeID: forkNodeID, Info: persistence.BuildHistoryGarbageCleanupInfo(namespaceID.String(), workflowID, resetRunID), @@ -599,6 +605,7 @@ func (r *workflowResetterImpl) reapplyContinueAsNewWorkflowEvents( // first special handling the remaining events for base workflow nextRunID, err := r.reapplyWorkflowEvents( + ctx, resetMutableState, baseRebuildNextEventID, baseNextEventID, @@ -630,7 +637,7 @@ func (r *workflowResetterImpl) reapplyContinueAsNewWorkflowEvents( } defer func() { release(retError) }() - mutableState, err := context.LoadWorkflowExecution() + mutableState, err := context.LoadWorkflowExecution(ctx) if err != nil { // no matter what error happen, we need to retry return 0, nil, err @@ -653,6 +660,7 @@ func (r *workflowResetterImpl) reapplyContinueAsNewWorkflowEvents( } nextRunID, err = r.reapplyWorkflowEvents( + ctx, resetMutableState, common.FirstEventID, nextWorkflowNextEventID, @@ -673,6 +681,7 @@ func (r *workflowResetterImpl) reapplyContinueAsNewWorkflowEvents( } func (r *workflowResetterImpl) reapplyWorkflowEvents( + ctx context.Context, mutableState workflow.MutableState, firstEventID int64, nextEventID int64, @@ -684,6 +693,7 @@ func (r *workflowResetterImpl) reapplyWorkflowEvents( // after the above change, this API do not have to return the continue as new run ID iter := collection.NewPagingIterator(r.getPaginationFn( + ctx, firstEventID, nextEventID, branchToken, @@ -737,6 +747,7 @@ func (r *workflowResetterImpl) reapplyEvents( } func (r *workflowResetterImpl) getPaginationFn( + ctx context.Context, firstEventID int64, nextEventID int64, branchToken []byte, @@ -744,7 +755,7 @@ func (r *workflowResetterImpl) getPaginationFn( return func(paginationToken []byte) ([]interface{}, []byte, error) { - resp, err := r.executionMgr.ReadHistoryBranchByBatch(&persistence.ReadHistoryBranchRequest{ + resp, err := r.executionMgr.ReadHistoryBranchByBatch(ctx, &persistence.ReadHistoryBranchRequest{ BranchToken: branchToken, MinEventID: firstEventID, MaxEventID: nextEventID, diff --git a/service/history/workflowResetter_test.go b/service/history/workflowResetter_test.go index a665b4c3aff..b767eed1fb9 100644 --- a/service/history/workflowResetter_test.go +++ b/service/history/workflowResetter_test.go @@ -199,6 +199,7 @@ func (s *workflowResetterSuite) TestPersistToDB_CurrentTerminated() { resetContext.EXPECT().SetHistorySize(resetEventsSize + resetNewEventsSize) s.mockTransaction.EXPECT().UpdateWorkflowExecution( + gomock.Any(), persistence.UpdateWorkflowModeUpdateCurrent, currentMutation, currentEventsSeq, @@ -207,7 +208,7 @@ func (s *workflowResetterSuite) TestPersistToDB_CurrentTerminated() { "active-cluster-name", ).Return(currentNewEventsSize, resetNewEventsSize, nil) - err := s.workflowResetter.persistToDB(currentWorkflow, currentMutation, currentEventsSeq, resetWorkflow) + err := s.workflowResetter.persistToDB(context.Background(), currentWorkflow, currentMutation, currentEventsSeq, resetWorkflow) s.NoError(err) // persistToDB function is not charged of releasing locks s.False(currentReleaseCalled) @@ -258,6 +259,7 @@ func (s *workflowResetterSuite) TestPersistToDB_CurrentNotTerminated() { ).Return(resetSnapshot, resetEventsSeq, nil) resetContext.EXPECT().GetHistorySize().Return(int64(123)).AnyTimes() resetContext.EXPECT().CreateWorkflowExecution( + gomock.Any(), gomock.Any(), persistence.CreateWorkflowModeWorkflowIDReuse, s.currentRunID, @@ -267,7 +269,7 @@ func (s *workflowResetterSuite) TestPersistToDB_CurrentNotTerminated() { resetEventsSeq, ).Return(nil) - err := s.workflowResetter.persistToDB(currentWorkflow, nil, nil, resetWorkflow) + err := s.workflowResetter.persistToDB(context.Background(), currentWorkflow, nil, nil, resetWorkflow) s.NoError(err) // persistToDB function is not charged of releasing locks s.False(currentReleaseCalled) @@ -287,7 +289,7 @@ func (s *workflowResetterSuite) TestReplayResetWorkflow() { resetMutableState := workflow.NewMockMutableState(s.controller) shardID := s.mockShard.GetShardID() - s.mockExecutionMgr.EXPECT().ForkHistoryBranch(&persistence.ForkHistoryBranchRequest{ + s.mockExecutionMgr.EXPECT().ForkHistoryBranch(gomock.Any(), &persistence.ForkHistoryBranchRequest{ ForkBranchToken: baseBranchToken, ForkNodeID: baseNodeID, Info: persistence.BuildHistoryGarbageCleanupInfo(s.namespaceID.String(), s.workflowID, s.resetRunID), @@ -496,7 +498,7 @@ func (s *workflowResetterSuite) TestGenerateBranchToken() { resetBranchToken := []byte("some random reset branch token") shardID := s.mockShard.GetShardID() - s.mockExecutionMgr.EXPECT().ForkHistoryBranch(&persistence.ForkHistoryBranchRequest{ + s.mockExecutionMgr.EXPECT().ForkHistoryBranch(gomock.Any(), &persistence.ForkHistoryBranchRequest{ ForkBranchToken: baseBranchToken, ForkNodeID: baseNodeID, Info: persistence.BuildHistoryGarbageCleanupInfo(s.namespaceID.String(), s.workflowID, s.resetRunID), @@ -504,7 +506,7 @@ func (s *workflowResetterSuite) TestGenerateBranchToken() { }).Return(&persistence.ForkHistoryBranchResponse{NewBranchToken: resetBranchToken}, nil) newBranchToken, err := s.workflowResetter.forkAndGenerateBranchToken( - s.namespaceID, s.workflowID, baseBranchToken, baseNodeID, s.resetRunID, + context.Background(), s.namespaceID, s.workflowID, baseBranchToken, baseNodeID, s.resetRunID, ) s.NoError(err) s.Equal(resetBranchToken, newBranchToken) @@ -575,7 +577,7 @@ func (s *workflowResetterSuite) TestReapplyContinueAsNewWorkflowEvents_WithOutCo baseEvents := []*historypb.HistoryEvent{baseEvent1, baseEvent2, baseEvent3, baseEvent4} shardID := s.mockShard.GetShardID() - s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(&persistence.ReadHistoryBranchRequest{ + s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(gomock.Any(), &persistence.ReadHistoryBranchRequest{ BranchToken: baseBranchToken, MinEventID: baseFirstEventID, MaxEventID: baseNextEventID, @@ -665,7 +667,7 @@ func (s *workflowResetterSuite) TestReapplyContinueAsNewWorkflowEvents_WithConti baseEvents := []*historypb.HistoryEvent{baseEvent1, baseEvent2, baseEvent3, baseEvent4} shardID := s.mockShard.GetShardID() - s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(&persistence.ReadHistoryBranchRequest{ + s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(gomock.Any(), &persistence.ReadHistoryBranchRequest{ BranchToken: baseBranchToken, MinEventID: baseFirstEventID, MaxEventID: baseNextEventID, @@ -678,7 +680,7 @@ func (s *workflowResetterSuite) TestReapplyContinueAsNewWorkflowEvents_WithConti }, nil) newEvents := []*historypb.HistoryEvent{newEvent1, newEvent2, newEvent3, newEvent4, newEvent5} - s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(&persistence.ReadHistoryBranchRequest{ + s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(gomock.Any(), &persistence.ReadHistoryBranchRequest{ BranchToken: newBranchToken, MinEventID: newFirstEventID, MaxEventID: newNextEventID, @@ -694,7 +696,7 @@ func (s *workflowResetterSuite) TestReapplyContinueAsNewWorkflowEvents_WithConti resetContext.EXPECT().Lock(gomock.Any(), workflow.CallerTypeAPI).Return(nil) resetContext.EXPECT().Unlock(workflow.CallerTypeAPI) resetMutableState := workflow.NewMockMutableState(s.controller) - resetContext.EXPECT().LoadWorkflowExecution().Return(resetMutableState, nil) + resetContext.EXPECT().LoadWorkflowExecution(gomock.Any()).Return(resetMutableState, nil) resetMutableState.EXPECT().GetNextEventID().Return(newNextEventID).AnyTimes() resetMutableState.EXPECT().GetCurrentBranchToken().Return(newBranchToken, nil).AnyTimes() resetContextCacheKey := definition.NewWorkflowKey(s.namespaceID.String(), s.workflowID, newRunID) @@ -751,7 +753,7 @@ func (s *workflowResetterSuite) TestReapplyWorkflowEvents() { } events := []*historypb.HistoryEvent{event1, event2, event3, event4, event5} shardID := s.mockShard.GetShardID() - s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(&persistence.ReadHistoryBranchRequest{ + s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(gomock.Any(), &persistence.ReadHistoryBranchRequest{ BranchToken: branchToken, MinEventID: firstEventID, MaxEventID: nextEventID, @@ -766,6 +768,7 @@ func (s *workflowResetterSuite) TestReapplyWorkflowEvents() { mutableState := workflow.NewMockMutableState(s.controller) nextRunID, err := s.workflowResetter.reapplyWorkflowEvents( + context.Background(), mutableState, firstEventID, nextEventID, @@ -857,7 +860,7 @@ func (s *workflowResetterSuite) TestPagination() { pageToken := []byte("some random token") shardID := s.mockShard.GetShardID() - s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(&persistence.ReadHistoryBranchRequest{ + s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(gomock.Any(), &persistence.ReadHistoryBranchRequest{ BranchToken: branchToken, MinEventID: firstEventID, MaxEventID: nextEventID, @@ -869,7 +872,7 @@ func (s *workflowResetterSuite) TestPagination() { NextPageToken: pageToken, Size: 12345, }, nil) - s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(&persistence.ReadHistoryBranchRequest{ + s.mockExecutionMgr.EXPECT().ReadHistoryBranchByBatch(gomock.Any(), &persistence.ReadHistoryBranchRequest{ BranchToken: branchToken, MinEventID: firstEventID, MaxEventID: nextEventID, @@ -882,7 +885,7 @@ func (s *workflowResetterSuite) TestPagination() { Size: 67890, }, nil) - paginationFn := s.workflowResetter.getPaginationFn(firstEventID, nextEventID, branchToken) + paginationFn := s.workflowResetter.getPaginationFn(context.Background(), firstEventID, nextEventID, branchToken) iter := collection.NewPagingIterator(paginationFn) var result []*historypb.History diff --git a/service/history/workflowTaskHandler.go b/service/history/workflowTaskHandler.go index 47a77c00e49..88a11b83c10 100644 --- a/service/history/workflowTaskHandler.go +++ b/service/history/workflowTaskHandler.go @@ -25,6 +25,7 @@ package history import ( + "context" "fmt" "time" @@ -126,6 +127,7 @@ func newWorkflowTaskHandler( } func (handler *workflowTaskHandlerImpl) handleCommands( + ctx context.Context, commands []*commandpb.Command, ) error { if err := handler.attrValidator.validateCommandSequence( @@ -134,7 +136,7 @@ func (handler *workflowTaskHandlerImpl) handleCommands( return err } for _, command := range commands { - err := handler.handleCommand(command) + err := handler.handleCommand(ctx, command) if err != nil || handler.stopProcessing { return err } @@ -142,46 +144,46 @@ func (handler *workflowTaskHandlerImpl) handleCommands( return nil } -func (handler *workflowTaskHandlerImpl) handleCommand(command *commandpb.Command) error { +func (handler *workflowTaskHandlerImpl) handleCommand(ctx context.Context, command *commandpb.Command) error { switch command.GetCommandType() { case enumspb.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK: - return handler.handleCommandScheduleActivity(command.GetScheduleActivityTaskCommandAttributes()) + return handler.handleCommandScheduleActivity(ctx, command.GetScheduleActivityTaskCommandAttributes()) case enumspb.COMMAND_TYPE_COMPLETE_WORKFLOW_EXECUTION: - return handler.handleCommandCompleteWorkflow(command.GetCompleteWorkflowExecutionCommandAttributes()) + return handler.handleCommandCompleteWorkflow(ctx, command.GetCompleteWorkflowExecutionCommandAttributes()) case enumspb.COMMAND_TYPE_FAIL_WORKFLOW_EXECUTION: - return handler.handleCommandFailWorkflow(command.GetFailWorkflowExecutionCommandAttributes()) + return handler.handleCommandFailWorkflow(ctx, command.GetFailWorkflowExecutionCommandAttributes()) case enumspb.COMMAND_TYPE_CANCEL_WORKFLOW_EXECUTION: - return handler.handleCommandCancelWorkflow(command.GetCancelWorkflowExecutionCommandAttributes()) + return handler.handleCommandCancelWorkflow(ctx, command.GetCancelWorkflowExecutionCommandAttributes()) case enumspb.COMMAND_TYPE_START_TIMER: - return handler.handleCommandStartTimer(command.GetStartTimerCommandAttributes()) + return handler.handleCommandStartTimer(ctx, command.GetStartTimerCommandAttributes()) case enumspb.COMMAND_TYPE_REQUEST_CANCEL_ACTIVITY_TASK: - return handler.handleCommandRequestCancelActivity(command.GetRequestCancelActivityTaskCommandAttributes()) + return handler.handleCommandRequestCancelActivity(ctx, command.GetRequestCancelActivityTaskCommandAttributes()) case enumspb.COMMAND_TYPE_CANCEL_TIMER: - return handler.handleCommandCancelTimer(command.GetCancelTimerCommandAttributes()) + return handler.handleCommandCancelTimer(ctx, command.GetCancelTimerCommandAttributes()) case enumspb.COMMAND_TYPE_RECORD_MARKER: - return handler.handleCommandRecordMarker(command.GetRecordMarkerCommandAttributes()) + return handler.handleCommandRecordMarker(ctx, command.GetRecordMarkerCommandAttributes()) case enumspb.COMMAND_TYPE_REQUEST_CANCEL_EXTERNAL_WORKFLOW_EXECUTION: - return handler.handleCommandRequestCancelExternalWorkflow(command.GetRequestCancelExternalWorkflowExecutionCommandAttributes()) + return handler.handleCommandRequestCancelExternalWorkflow(ctx, command.GetRequestCancelExternalWorkflowExecutionCommandAttributes()) case enumspb.COMMAND_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION: - return handler.handleCommandSignalExternalWorkflow(command.GetSignalExternalWorkflowExecutionCommandAttributes()) + return handler.handleCommandSignalExternalWorkflow(ctx, command.GetSignalExternalWorkflowExecutionCommandAttributes()) case enumspb.COMMAND_TYPE_CONTINUE_AS_NEW_WORKFLOW_EXECUTION: - return handler.handleCommandContinueAsNewWorkflow(command.GetContinueAsNewWorkflowExecutionCommandAttributes()) + return handler.handleCommandContinueAsNewWorkflow(ctx, command.GetContinueAsNewWorkflowExecutionCommandAttributes()) case enumspb.COMMAND_TYPE_START_CHILD_WORKFLOW_EXECUTION: - return handler.handleCommandStartChildWorkflow(command.GetStartChildWorkflowExecutionCommandAttributes()) + return handler.handleCommandStartChildWorkflow(ctx, command.GetStartChildWorkflowExecutionCommandAttributes()) case enumspb.COMMAND_TYPE_UPSERT_WORKFLOW_SEARCH_ATTRIBUTES: - return handler.handleCommandUpsertWorkflowSearchAttributes(command.GetUpsertWorkflowSearchAttributesCommandAttributes()) + return handler.handleCommandUpsertWorkflowSearchAttributes(ctx, command.GetUpsertWorkflowSearchAttributesCommandAttributes()) default: return serviceerror.NewInvalidArgument(fmt.Sprintf("Unknown command type: %v", command.GetCommandType())) @@ -189,6 +191,7 @@ func (handler *workflowTaskHandlerImpl) handleCommand(command *commandpb.Command } func (handler *workflowTaskHandlerImpl) handleCommandScheduleActivity( + _ context.Context, attr *commandpb.ScheduleActivityTaskCommandAttributes, ) error { @@ -245,6 +248,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandScheduleActivity( } func (handler *workflowTaskHandlerImpl) handleCommandRequestCancelActivity( + _ context.Context, attr *commandpb.RequestCancelActivityTaskCommandAttributes, ) error { @@ -298,6 +302,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandRequestCancelActivity( } func (handler *workflowTaskHandlerImpl) handleCommandStartTimer( + _ context.Context, attr *commandpb.StartTimerCommandAttributes, ) error { @@ -326,6 +331,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandStartTimer( } func (handler *workflowTaskHandlerImpl) handleCommandCompleteWorkflow( + ctx context.Context, attr *commandpb.CompleteWorkflowExecutionCommandAttributes, ) error { @@ -385,13 +391,14 @@ func (handler *workflowTaskHandlerImpl) handleCommandCompleteWorkflow( // Check if this workflow has a cron schedule if cronBackoff != backoff.NoBackoff { - return handler.handleCron(cronBackoff, attr.GetResult(), nil, newExecutionRunID) + return handler.handleCron(ctx, cronBackoff, attr.GetResult(), nil, newExecutionRunID) } return nil } func (handler *workflowTaskHandlerImpl) handleCommandFailWorkflow( + ctx context.Context, attr *commandpb.FailWorkflowExecutionCommandAttributes, ) error { @@ -462,9 +469,9 @@ func (handler *workflowTaskHandlerImpl) handleCommandFailWorkflow( // Handle retry or cron if retryBackoff != backoff.NoBackoff { - return handler.handleRetry(retryBackoff, retryState, attr.GetFailure(), newExecutionRunID) + return handler.handleRetry(ctx, retryBackoff, retryState, attr.GetFailure(), newExecutionRunID) } else if cronBackoff != backoff.NoBackoff { - return handler.handleCron(cronBackoff, nil, attr.GetFailure(), newExecutionRunID) + return handler.handleCron(ctx, cronBackoff, nil, attr.GetFailure(), newExecutionRunID) } // No retry or cron @@ -472,6 +479,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandFailWorkflow( } func (handler *workflowTaskHandlerImpl) handleCommandCancelTimer( + _ context.Context, attr *commandpb.CancelTimerCommandAttributes, ) error { @@ -509,6 +517,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandCancelTimer( } func (handler *workflowTaskHandlerImpl) handleCommandCancelWorkflow( + _ context.Context, attr *commandpb.CancelWorkflowExecutionCommandAttributes, ) error { @@ -549,6 +558,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandCancelWorkflow( } func (handler *workflowTaskHandlerImpl) handleCommandRequestCancelExternalWorkflow( + _ context.Context, attr *commandpb.RequestCancelExternalWorkflowExecutionCommandAttributes, ) error { @@ -591,6 +601,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandRequestCancelExternalWorkfl } func (handler *workflowTaskHandlerImpl) handleCommandRecordMarker( + _ context.Context, attr *commandpb.RecordMarkerCommandAttributes, ) error { @@ -623,6 +634,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandRecordMarker( } func (handler *workflowTaskHandlerImpl) handleCommandContinueAsNewWorkflow( + _ context.Context, attr *commandpb.ContinueAsNewWorkflowExecutionCommandAttributes, ) error { @@ -726,6 +738,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandContinueAsNewWorkflow( } func (handler *workflowTaskHandlerImpl) handleCommandStartChildWorkflow( + _ context.Context, attr *commandpb.StartChildWorkflowExecutionCommandAttributes, ) error { handler.metricsClient.IncCounter( @@ -822,6 +835,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandStartChildWorkflow( } func (handler *workflowTaskHandlerImpl) handleCommandSignalExternalWorkflow( + _ context.Context, attr *commandpb.SignalExternalWorkflowExecutionCommandAttributes, ) error { @@ -872,6 +886,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandSignalExternalWorkflow( } func (handler *workflowTaskHandlerImpl) handleCommandUpsertWorkflowSearchAttributes( + _ context.Context, attr *commandpb.UpsertWorkflowSearchAttributesCommandAttributes, ) error { @@ -946,12 +961,13 @@ func searchAttributesSize(fields map[string]*commonpb.Payload) int { } func (handler *workflowTaskHandlerImpl) handleRetry( + ctx context.Context, backoffInterval time.Duration, retryState enumspb.RetryState, failure *failurepb.Failure, newRunID string, ) error { - startEvent, err := handler.mutableState.GetStartEvent() + startEvent, err := handler.mutableState.GetStartEvent(ctx) if err != nil { return err } @@ -966,6 +982,7 @@ func (handler *workflowTaskHandlerImpl) handleRetry( return err } err = workflow.SetupNewWorkflowForRetryOrCron( + ctx, handler.mutableState, newStateBuilder, newRunID, @@ -984,12 +1001,13 @@ func (handler *workflowTaskHandlerImpl) handleRetry( } func (handler *workflowTaskHandlerImpl) handleCron( + ctx context.Context, backoffInterval time.Duration, lastCompletionResult *commonpb.Payloads, failure *failurepb.Failure, newRunID string, ) error { - startEvent, err := handler.mutableState.GetStartEvent() + startEvent, err := handler.mutableState.GetStartEvent(ctx) if err != nil { return err } @@ -1008,6 +1026,7 @@ func (handler *workflowTaskHandlerImpl) handleCron( return err } err = workflow.SetupNewWorkflowForRetryOrCron( + ctx, handler.mutableState, newStateBuilder, newRunID, diff --git a/service/history/workflowTaskHandlerCallbacks.go b/service/history/workflowTaskHandlerCallbacks.go index dc12f106c35..a42ab45f1d0 100644 --- a/service/history/workflowTaskHandlerCallbacks.go +++ b/service/history/workflowTaskHandlerCallbacks.go @@ -139,7 +139,7 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskScheduled( }, nil } - startEvent, err := mutableState.GetStartEvent() + startEvent, err := mutableState.GetStartEvent(ctx) if err != nil { return nil, err } @@ -342,7 +342,7 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted( var currentWorkflowTask *workflow.WorkflowTaskInfo var currentWorkflowTaskRunning bool for attempt := 1; ; attempt++ { - msBuilder, err = weContext.LoadWorkflowExecution() + msBuilder, err = weContext.LoadWorkflowExecution(ctx) if err != nil { return nil, err } @@ -368,7 +368,7 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted( } executionInfo := msBuilder.GetExecutionInfo() - executionStats, err := weContext.LoadExecutionStats() + executionStats, err := weContext.LoadExecutionStats(ctx) if err != nil { return nil, err } @@ -469,6 +469,7 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted( ) if err := workflowTaskHandler.handleCommands( + ctx, request.Commands, ); err != nil { return nil, err @@ -498,7 +499,7 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted( // drop this workflow task if it keeps failing. This will cause the workflow task to timeout and get retried after timeout. return nil, serviceerror.NewInvalidArgument(wtFailedCause.Message()) } - msBuilder, err = handler.historyEngine.failWorkflowTask(weContext, scheduleID, startedID, wtFailedCause, request) + msBuilder, err = handler.historyEngine.failWorkflowTask(ctx, weContext, scheduleID, startedID, wtFailedCause, request) if err != nil { return nil, err } @@ -546,6 +547,7 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted( newWorkflowExecutionInfo := newStateBuilder.GetExecutionInfo() newWorkflowExecutionState := newStateBuilder.GetExecutionState() updateErr = weContext.UpdateWorkflowExecutionWithNewAsActive( + ctx, handler.shard.GetTimeSource().Now(), workflow.NewContext( handler.shard, @@ -559,7 +561,7 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted( newStateBuilder, ) } else { - updateErr = weContext.UpdateWorkflowExecutionAsActive(handler.shard.GetTimeSource().Now()) + updateErr = weContext.UpdateWorkflowExecutionAsActive(ctx, handler.shard.GetTimeSource().Now()) } if updateErr != nil { @@ -572,7 +574,7 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted( case *persistence.TransactionSizeLimitError: // must reload mutable state because the first call to updateWorkflowExecutionWithContext or continueAsNewWorkflowExecution // clears mutable state if error is returned - msBuilder, err = weContext.LoadWorkflowExecution() + msBuilder, err = weContext.LoadWorkflowExecution(ctx) if err != nil { return nil, err } @@ -588,6 +590,7 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted( return nil, err } if err := weContext.UpdateWorkflowExecutionAsActive( + ctx, handler.shard.GetTimeSource().Now(), ); err != nil { return nil, err diff --git a/service/worker/archiver/activities.go b/service/worker/archiver/activities.go index e56f8c47dbf..503d211f923 100644 --- a/service/worker/archiver/activities.go +++ b/service/worker/archiver/activities.go @@ -110,7 +110,7 @@ func deleteHistoryActivity(ctx context.Context, request ArchiveRequest) (err err err = temporal.NewNonRetryableApplicationError(err.Error(), "", nil) } }() - err = container.HistoryV2Manager.DeleteHistoryBranch(&persistence.DeleteHistoryBranchRequest{ + err = container.HistoryV2Manager.DeleteHistoryBranch(ctx, &persistence.DeleteHistoryBranchRequest{ BranchToken: request.BranchToken, ShardID: request.ShardID, }) diff --git a/service/worker/archiver/activities_test.go b/service/worker/archiver/activities_test.go index f364e95dea9..bec347aa944 100644 --- a/service/worker/archiver/activities_test.go +++ b/service/worker/archiver/activities_test.go @@ -247,7 +247,7 @@ func (s *activitiesSuite) TestUploadHistory_Success() { func (s *activitiesSuite) TestDeleteHistoryActivity_Fail_DeleteFromV2NonRetryableError() { s.metricsClient.EXPECT().Scope(metrics.ArchiverDeleteHistoryActivityScope, []metrics.Tag{metrics.NamespaceTag(testNamespace)}).Return(s.metricsScope) s.metricsScope.EXPECT().IncCounter(metrics.ArchiverNonRetryableErrorCount) - s.mockExecutionMgr.EXPECT().DeleteHistoryBranch(gomock.Any()).Return(errPersistenceNonRetryable) + s.mockExecutionMgr.EXPECT().DeleteHistoryBranch(gomock.Any(), gomock.Any()).Return(errPersistenceNonRetryable) container := &BootstrapContainer{ Logger: s.logger, MetricsClient: s.metricsClient, diff --git a/service/worker/scanner/executions/history_event_id_validator.go b/service/worker/scanner/executions/history_event_id_validator.go index 552786144de..d2cee7eedaf 100644 --- a/service/worker/scanner/executions/history_event_id_validator.go +++ b/service/worker/scanner/executions/history_event_id_validator.go @@ -25,6 +25,8 @@ package executions import ( + "context" + "go.temporal.io/api/serviceerror" "go.temporal.io/server/common" @@ -59,6 +61,7 @@ func NewHistoryEventIDValidator( } func (v *historyEventIDValidator) Validate( + ctx context.Context, mutableState *MutableState, ) ([]MutableStateValidationResult, error) { currentVersionHistory, err := versionhistory.GetCurrentVersionHistory( @@ -71,7 +74,7 @@ func (v *historyEventIDValidator) Validate( // TODO currently history event ID validator only verifies // the first event batch exists, before doing whole history // validation, ensure not too much capacity is consumed - _, err = v.executionManager.ReadRawHistoryBranch(&persistence.ReadHistoryBranchRequest{ + _, err = v.executionManager.ReadRawHistoryBranch(ctx, &persistence.ReadHistoryBranchRequest{ MinEventID: common.FirstEventID, MaxEventID: common.FirstEventID + 1, BranchToken: currentVersionHistory.BranchToken, @@ -85,7 +88,7 @@ func (v *historyEventIDValidator) Validate( case *serviceerror.NotFound, *serviceerror.DataLoss: // additionally validate mutable state is still present in DB - _, err = v.executionManager.GetWorkflowExecution(&persistence.GetWorkflowExecutionRequest{ + _, err = v.executionManager.GetWorkflowExecution(ctx, &persistence.GetWorkflowExecutionRequest{ ShardID: v.shardID, NamespaceID: mutableState.GetExecutionInfo().NamespaceId, WorkflowID: mutableState.GetExecutionInfo().WorkflowId, diff --git a/service/worker/scanner/executions/interfaces.go b/service/worker/scanner/executions/interfaces.go index a3e4fd756e0..3afa17b8f76 100644 --- a/service/worker/scanner/executions/interfaces.go +++ b/service/worker/scanner/executions/interfaces.go @@ -25,6 +25,8 @@ package executions import ( + "context" + persistencespb "go.temporal.io/server/api/persistence/v1" ) @@ -41,6 +43,6 @@ type ( } Validator interface { - Validate(mutableState *MutableState) ([]MutableStateValidationResult, error) + Validate(ctx context.Context, mutableState *MutableState) ([]MutableStateValidationResult, error) } ) diff --git a/service/worker/scanner/executions/mutable_state_id_validator.go b/service/worker/scanner/executions/mutable_state_id_validator.go index fcb0df33154..7c898532d85 100644 --- a/service/worker/scanner/executions/mutable_state_id_validator.go +++ b/service/worker/scanner/executions/mutable_state_id_validator.go @@ -25,6 +25,7 @@ package executions import ( + "context" "fmt" persistencespb "go.temporal.io/server/api/persistence/v1" @@ -56,6 +57,7 @@ func NewMutableStateIDValidator() *mutableStateIDValidator { // Validate does shallow correctness check of IDs in mutable state. func (v *mutableStateIDValidator) Validate( + ctx context.Context, mutableState *MutableState, ) ([]MutableStateValidationResult, error) { diff --git a/service/worker/scanner/executions/task.go b/service/worker/scanner/executions/task.go index ba9af38e665..d61442fdc4e 100644 --- a/service/worker/scanner/executions/task.go +++ b/service/worker/scanner/executions/task.go @@ -82,7 +82,7 @@ func newTask( logger: logger, scavenger: scavenger, - ctx: context.Background(), + ctx: context.Background(), // TODO: use context from ExecutionsScavengerActivity rateLimiter: rateLimiter, } } @@ -128,6 +128,7 @@ func (t *task) validate( ) if validationResults, err := NewMutableStateIDValidator().Validate( + t.ctx, mutableState, ); err != nil { t.logger.Error("unable to validate mutable state ID", @@ -144,7 +145,7 @@ func (t *task) validate( if validationResults, err := NewHistoryEventIDValidator( t.shardID, t.executionManager, - ).Validate(mutableState); err != nil { + ).Validate(t.ctx, mutableState); err != nil { t.logger.Error("unable to validate history event ID being contiguous", tag.ShardID(t.shardID), tag.WorkflowNamespaceID(mutableState.GetExecutionInfo().GetNamespaceId()), @@ -166,7 +167,7 @@ func (t *task) getPaginationFn() collection.PaginationFn { PageSize: executionsPageSize, PageToken: paginationToken, } - resp, err := t.executionManager.ListConcreteExecutions(req) + resp, err := t.executionManager.ListConcreteExecutions(t.ctx, req) if err != nil { return nil, nil, err } diff --git a/service/worker/scanner/history/scavenger.go b/service/worker/scanner/history/scavenger.go index 142ce176889..4e62be27e36 100644 --- a/service/worker/scanner/history/scavenger.go +++ b/service/worker/scanner/history/scavenger.go @@ -147,7 +147,7 @@ func (s *Scavenger) loadTasks( defer close(reqCh) - iter := collection.NewPagingIteratorWithToken(s.getPaginationFn(), s.hbd.NextPageToken) + iter := collection.NewPagingIteratorWithToken(s.getPaginationFn(ctx), s.hbd.NextPageToken) for iter.HasNext() { if err := s.rateLimiter.Wait(ctx); err != nil { // context done @@ -274,7 +274,7 @@ func (s *Scavenger) handleTask( return err } - err = s.db.DeleteHistoryBranch(&persistence.DeleteHistoryBranchRequest{ + err = s.db.DeleteHistoryBranch(ctx, &persistence.DeleteHistoryBranchRequest{ ShardID: task.shardID, BranchToken: branchToken, }) @@ -301,13 +301,15 @@ func (s *Scavenger) handleErr( s.hbd.SuccessCount++ } -func (s *Scavenger) getPaginationFn() collection.PaginationFn { +func (s *Scavenger) getPaginationFn( + ctx context.Context, +) collection.PaginationFn { return func(paginationToken []byte) ([]interface{}, []byte, error) { req := &persistence.GetAllHistoryTreeBranchesRequest{ PageSize: pageSize, NextPageToken: paginationToken, } - resp, err := s.db.GetAllHistoryTreeBranches(req) + resp, err := s.db.GetAllHistoryTreeBranches(ctx, req) if err != nil { return nil, nil, err } diff --git a/service/worker/scanner/history/scavenger_test.go b/service/worker/scanner/history/scavenger_test.go index c2d18be5329..3c0568896ca 100644 --- a/service/worker/scanner/history/scavenger_test.go +++ b/service/worker/scanner/history/scavenger_test.go @@ -94,7 +94,7 @@ func (s *ScavengerTestSuite) createTestScavenger( func (s *ScavengerTestSuite) TestAllSkipTasksTwoPages() { db, _, scvgr, controller := s.createTestScavenger(100) defer controller.Finish() - db.EXPECT().GetAllHistoryTreeBranches(&p.GetAllHistoryTreeBranchesRequest{ + db.EXPECT().GetAllHistoryTreeBranches(gomock.Any(), &p.GetAllHistoryTreeBranchesRequest{ PageSize: pageSize, }).Return(&p.GetAllHistoryTreeBranchesResponse{ NextPageToken: []byte("page1"), @@ -114,7 +114,7 @@ func (s *ScavengerTestSuite) TestAllSkipTasksTwoPages() { }, }, nil) - db.EXPECT().GetAllHistoryTreeBranches(&p.GetAllHistoryTreeBranchesRequest{ + db.EXPECT().GetAllHistoryTreeBranches(gomock.Any(), &p.GetAllHistoryTreeBranchesRequest{ PageSize: pageSize, NextPageToken: []byte("page1"), }).Return(&p.GetAllHistoryTreeBranchesResponse{ @@ -146,7 +146,7 @@ func (s *ScavengerTestSuite) TestAllSkipTasksTwoPages() { func (s *ScavengerTestSuite) TestAllErrorSplittingTasksTwoPages() { db, _, scvgr, controller := s.createTestScavenger(100) defer controller.Finish() - db.EXPECT().GetAllHistoryTreeBranches(&p.GetAllHistoryTreeBranchesRequest{ + db.EXPECT().GetAllHistoryTreeBranches(gomock.Any(), &p.GetAllHistoryTreeBranchesRequest{ PageSize: pageSize, }).Return(&p.GetAllHistoryTreeBranchesResponse{ NextPageToken: []byte("page1"), @@ -166,7 +166,7 @@ func (s *ScavengerTestSuite) TestAllErrorSplittingTasksTwoPages() { }, }, nil) - db.EXPECT().GetAllHistoryTreeBranches(&p.GetAllHistoryTreeBranchesRequest{ + db.EXPECT().GetAllHistoryTreeBranches(gomock.Any(), &p.GetAllHistoryTreeBranchesRequest{ PageSize: pageSize, NextPageToken: []byte("page1"), }).Return(&p.GetAllHistoryTreeBranchesResponse{ @@ -198,7 +198,7 @@ func (s *ScavengerTestSuite) TestAllErrorSplittingTasksTwoPages() { func (s *ScavengerTestSuite) TestNoGarbageTwoPages() { db, client, scvgr, controller := s.createTestScavenger(100) defer controller.Finish() - db.EXPECT().GetAllHistoryTreeBranches(&p.GetAllHistoryTreeBranchesRequest{ + db.EXPECT().GetAllHistoryTreeBranches(gomock.Any(), &p.GetAllHistoryTreeBranchesRequest{ PageSize: pageSize, }).Return(&p.GetAllHistoryTreeBranchesResponse{ NextPageToken: []byte("page1"), @@ -218,7 +218,7 @@ func (s *ScavengerTestSuite) TestNoGarbageTwoPages() { }, }, nil) - db.EXPECT().GetAllHistoryTreeBranches(&p.GetAllHistoryTreeBranchesRequest{ + db.EXPECT().GetAllHistoryTreeBranches(gomock.Any(), &p.GetAllHistoryTreeBranchesRequest{ PageSize: pageSize, NextPageToken: []byte("page1"), }).Return(&p.GetAllHistoryTreeBranchesResponse{ @@ -279,7 +279,7 @@ func (s *ScavengerTestSuite) TestNoGarbageTwoPages() { func (s *ScavengerTestSuite) TestDeletingBranchesTwoPages() { db, client, scvgr, controller := s.createTestScavenger(100) defer controller.Finish() - db.EXPECT().GetAllHistoryTreeBranches(&p.GetAllHistoryTreeBranchesRequest{ + db.EXPECT().GetAllHistoryTreeBranches(gomock.Any(), &p.GetAllHistoryTreeBranchesRequest{ PageSize: pageSize, }).Return(&p.GetAllHistoryTreeBranchesResponse{ NextPageToken: []byte("page1"), @@ -298,7 +298,7 @@ func (s *ScavengerTestSuite) TestDeletingBranchesTwoPages() { }, }, }, nil) - db.EXPECT().GetAllHistoryTreeBranches(&p.GetAllHistoryTreeBranchesRequest{ + db.EXPECT().GetAllHistoryTreeBranches(gomock.Any(), &p.GetAllHistoryTreeBranchesRequest{ PageSize: pageSize, NextPageToken: []byte("page1"), }).Return(&p.GetAllHistoryTreeBranchesResponse{ @@ -349,25 +349,25 @@ func (s *ScavengerTestSuite) TestDeletingBranchesTwoPages() { branchToken1, err := p.NewHistoryBranchTokenByBranchID(treeID1, branchID1) s.Nil(err) - db.EXPECT().DeleteHistoryBranch(&p.DeleteHistoryBranchRequest{ + db.EXPECT().DeleteHistoryBranch(gomock.Any(), &p.DeleteHistoryBranchRequest{ BranchToken: branchToken1, ShardID: common.WorkflowIDToHistoryShard("namespaceID1", "workflowID1", s.numShards), }).Return(nil) branchToken2, err := p.NewHistoryBranchTokenByBranchID(treeID2, branchID2) s.Nil(err) - db.EXPECT().DeleteHistoryBranch(&p.DeleteHistoryBranchRequest{ + db.EXPECT().DeleteHistoryBranch(gomock.Any(), &p.DeleteHistoryBranchRequest{ BranchToken: branchToken2, ShardID: common.WorkflowIDToHistoryShard("namespaceID2", "workflowID2", s.numShards), }).Return(nil) branchToken3, err := p.NewHistoryBranchTokenByBranchID(treeID3, branchID3) s.Nil(err) - db.EXPECT().DeleteHistoryBranch(&p.DeleteHistoryBranchRequest{ + db.EXPECT().DeleteHistoryBranch(gomock.Any(), &p.DeleteHistoryBranchRequest{ BranchToken: branchToken3, ShardID: common.WorkflowIDToHistoryShard("namespaceID3", "workflowID3", s.numShards), }).Return(nil) branchToken4, err := p.NewHistoryBranchTokenByBranchID(treeID4, branchID4) s.Nil(err) - db.EXPECT().DeleteHistoryBranch(&p.DeleteHistoryBranchRequest{ + db.EXPECT().DeleteHistoryBranch(gomock.Any(), &p.DeleteHistoryBranchRequest{ BranchToken: branchToken4, ShardID: common.WorkflowIDToHistoryShard("namespaceID4", "workflowID4", s.numShards), }).Return(nil) @@ -384,7 +384,7 @@ func (s *ScavengerTestSuite) TestDeletingBranchesTwoPages() { func (s *ScavengerTestSuite) TestMixesTwoPages() { db, client, scvgr, controller := s.createTestScavenger(100) defer controller.Finish() - db.EXPECT().GetAllHistoryTreeBranches(&p.GetAllHistoryTreeBranchesRequest{ + db.EXPECT().GetAllHistoryTreeBranches(gomock.Any(), &p.GetAllHistoryTreeBranchesRequest{ PageSize: pageSize, }).Return(&p.GetAllHistoryTreeBranchesResponse{ NextPageToken: []byte("page1"), @@ -405,7 +405,7 @@ func (s *ScavengerTestSuite) TestMixesTwoPages() { }, }, }, nil) - db.EXPECT().GetAllHistoryTreeBranches(&p.GetAllHistoryTreeBranchesRequest{ + db.EXPECT().GetAllHistoryTreeBranches(gomock.Any(), &p.GetAllHistoryTreeBranchesRequest{ PageSize: pageSize, NextPageToken: []byte("page1"), }).Return(&p.GetAllHistoryTreeBranchesResponse{ @@ -459,14 +459,14 @@ func (s *ScavengerTestSuite) TestMixesTwoPages() { branchToken3, err := p.NewHistoryBranchTokenByBranchID(treeID3, branchID3) s.Nil(err) - db.EXPECT().DeleteHistoryBranch(&p.DeleteHistoryBranchRequest{ + db.EXPECT().DeleteHistoryBranch(gomock.Any(), &p.DeleteHistoryBranchRequest{ BranchToken: branchToken3, ShardID: common.WorkflowIDToHistoryShard("namespaceID3", "workflowID3", s.numShards), }).Return(nil) branchToken4, err := p.NewHistoryBranchTokenByBranchID(treeID4, branchID4) s.Nil(err) - db.EXPECT().DeleteHistoryBranch(&p.DeleteHistoryBranchRequest{ + db.EXPECT().DeleteHistoryBranch(gomock.Any(), &p.DeleteHistoryBranchRequest{ BranchToken: branchToken4, ShardID: common.WorkflowIDToHistoryShard("namespaceID4", "workflowID4", s.numShards), }).Return(fmt.Errorf("failed to delete history")) diff --git a/service/worker/scanner/taskqueue/scavenger_test.go b/service/worker/scanner/taskqueue/scavenger_test.go index c4c1105cbc2..6bd50e3aaee 100644 --- a/service/worker/scanner/taskqueue/scavenger_test.go +++ b/service/worker/scanner/taskqueue/scavenger_test.go @@ -68,7 +68,7 @@ func (s *ScavengerTestSuite) SetupTest() { executorPollInterval = time.Millisecond * 50 } -func (s *ScavengerTestSuite) TeardownTest() { +func (s *ScavengerTestSuite) TearDownTest() { s.controller.Finish() } diff --git a/tools/cli/adminCommands.go b/tools/cli/adminCommands.go index f377fac56fa..11d9c83eb5c 100644 --- a/tools/cli/adminCommands.go +++ b/tools/cli/adminCommands.go @@ -208,6 +208,8 @@ func AdminDeleteWorkflow(c *cli.Context) { branchTokens = append(branchTokens, historyItem.GetBranchToken()) } + ctx, cancel := newContext(c) + defer cancel() for _, branchToken := range branchTokens { branchInfo, err := serialization.HistoryBranchFromBlob(branchToken, enumspb.ENCODING_TYPE_PROTO3.String()) if err != nil { @@ -222,7 +224,7 @@ func AdminDeleteWorkflow(c *cli.Context) { log.NewNoopLogger(), dynamicconfig.GetIntPropertyFn(common.DefaultTransactionSizeLimit), ) - err = execMgr.DeleteHistoryBranch(&persistence.DeleteHistoryBranchRequest{ + err = execMgr.DeleteHistoryBranch(ctx, &persistence.DeleteHistoryBranchRequest{ BranchToken: branchToken, ShardID: int32(shardIDInt), }) @@ -439,7 +441,9 @@ func AdminDescribeTask(c *cli.Context) { if err != nil { ErrorAndExit("Failed to initialize execution manager", err) } - task, err := executionManager.GetHistoryTask(&persistence.GetHistoryTaskRequest{ + ctx, cancel := newContext(c) + defer cancel() + task, err := executionManager.GetHistoryTask(ctx, &persistence.GetHistoryTaskRequest{ ShardID: int32(sid), TaskCategory: historyTaskCategory, TaskKey: taskKey, diff --git a/tools/cli/adminDBScanCommand.go b/tools/cli/adminDBScanCommand.go index 6cb0c2ba8fb..30dae8630ce 100644 --- a/tools/cli/adminDBScanCommand.go +++ b/tools/cli/adminDBScanCommand.go @@ -241,9 +241,12 @@ func AdminDBScan(c *cli.Context) { reports := make(chan *ShardScanReport) for i := int32(0); i < scanWorkerCount; i++ { go func(workerIdx int32) { + ctx, cancel := newIndefiniteContext(c) + defer cancel() for shardID := lowerShardBound; shardID < upperShardBound; shardID++ { if shardID%scanWorkerCount == workerIdx { reports <- scanShard( + ctx, session, shardID, scanOutputDirectories, @@ -271,6 +274,7 @@ func AdminDBScan(c *cli.Context) { } func scanShard( + ctx context.Context, session gocql.Session, shardID int32, scanOutputDirectories *ScanOutputDirectories, @@ -309,7 +313,7 @@ func scanShard( PageToken: token, } preconditionForDBCall(&report.TotalDBRequests, limiter) - resp, err := execMan.ListConcreteExecutions(req) + resp, err := execMan.ListConcreteExecutions(ctx, req) if err != nil { report.Failure = &ShardScanReportFailure{ Note: "failed to call ListConcreteExecutions", @@ -380,6 +384,7 @@ func scanShard( } currentExecutionVerificationResult := verifyCurrentExecution( + ctx, s.ExecutionInfo, s.ExecutionState, s.NextEventId, @@ -631,6 +636,7 @@ func verifyActivityIds( } func verifyCurrentExecution( + ctx context.Context, executionInfo *persistencespb.WorkflowExecutionInfo, executionState *persistencespb.WorkflowExecutionState, nextEventID int64, @@ -651,9 +657,9 @@ func verifyCurrentExecution( WorkflowID: executionInfo.WorkflowId, } preconditionForDBCall(totalDBRequests, limiter) - currentExecution, err := execMan.GetCurrentExecution(getCurrentExecutionRequest) + currentExecution, err := execMan.GetCurrentExecution(ctx, getCurrentExecutionRequest) - ecf, stillOpen := concreteExecutionStillOpen(executionInfo, executionState, shardID, execMan, limiter, totalDBRequests) + ecf, stillOpen := concreteExecutionStillOpen(ctx, executionInfo, executionState, shardID, execMan, limiter, totalDBRequests) if ecf != nil { checkFailureWriter.Add(ecf) return VerificationResultCheckFailure @@ -748,6 +754,7 @@ func concreteExecutionStillExists( } func concreteExecutionStillOpen( + ctx context.Context, executionInfo *persistencespb.WorkflowExecutionInfo, executionState *persistencespb.WorkflowExecutionState, shardID int32, @@ -762,7 +769,7 @@ func concreteExecutionStillOpen( RunID: executionState.GetRunId(), } preconditionForDBCall(totalDBRequests, limiter) - ce, err := execMan.GetWorkflowExecution(getConcreteExecution) + ce, err := execMan.GetWorkflowExecution(ctx, getConcreteExecution) if err != nil { return &ExecutionCheckFailure{ ShardID: shardID,