From eed91445e70be00a0ebeb5d2164a4409cf5060bc Mon Sep 17 00:00:00 2001 From: wxing1292 Date: Fri, 26 Mar 2021 21:04:26 -0700 Subject: [PATCH] Perf optimize NDC RPC replication (#1392) * Rename ReadLevel to MinTaskID, MaxReadLevel to MaxTaskID * Add hint in replication task source reducing unnecessary DB read * Add periodical sanity check in replication task source in case of timeout err * Treat last replication task ID vs replication task pagination ID differently --- .../cassandra/cassandraPersistence.go | 8 +- common/persistence/dataInterfaces.go | 8 +- .../persistence-tests/persistenceTestBase.go | 8 +- common/persistence/sql/sql_execution_tasks.go | 8 +- service/history/configs/config.go | 4 - service/history/handler.go | 1 + service/history/historyEngine.go | 46 ++-- service/history/replicationDLQHandler.go | 4 +- service/history/replicationDLQHandler_test.go | 8 +- service/history/replicationTaskProcessor.go | 45 ++-- .../history/replicationTaskProcessor_test.go | 115 +++++++++- service/history/replicatorQueueProcessor.go | 217 ++++++++++++------ .../history/replicatorQueueProcessor_test.go | 84 +++++++ service/history/shard/context.go | 2 +- service/history/shard/context_impl.go | 7 +- service/history/shard/context_mock.go | 8 +- service/history/shard/context_test.go | 1 + service/history/shard/engine.go | 3 +- service/history/shard/engine_mock.go | 20 +- .../timerQueueTaskExecutorBase_test.go | 2 + service/history/workflowExecutionContext.go | 15 +- 21 files changed, 462 insertions(+), 152 deletions(-) diff --git a/common/persistence/cassandra/cassandraPersistence.go b/common/persistence/cassandra/cassandraPersistence.go index 61e4f624095..b7c5b4f916d 100644 --- a/common/persistence/cassandra/cassandraPersistence.go +++ b/common/persistence/cassandra/cassandraPersistence.go @@ -1894,8 +1894,8 @@ func (d *cassandraPersistence) GetReplicationTasks( rowTypeReplicationWorkflowID, rowTypeReplicationRunID, defaultVisibilityTimestamp, - request.ReadLevel, - request.MaxReadLevel, + request.MinTaskID, + request.MaxTaskID, ).PageSize(request.BatchSize).PageState(request.NextPageToken) return d.populateGetReplicationTasksResponse(query, "GetReplicationTasks") @@ -2556,8 +2556,8 @@ func (d *cassandraPersistence) GetReplicationTasksFromDLQ( request.SourceClusterName, rowTypeDLQRunID, defaultVisibilityTimestamp, - request.ReadLevel, - request.ReadLevel+int64(request.BatchSize), + request.MinTaskID, + request.MinTaskID+int64(request.BatchSize), ).PageSize(request.BatchSize).PageState(request.NextPageToken) return d.populateGetReplicationTasksResponse(query, "GetReplicationTasksFromDLQ") diff --git a/common/persistence/dataInterfaces.go b/common/persistence/dataInterfaces.go index 0901bc10056..35903b90983 100644 --- a/common/persistence/dataInterfaces.go +++ b/common/persistence/dataInterfaces.go @@ -662,8 +662,8 @@ type ( // GetReplicationTasksRequest is used to read tasks from the replication task queue GetReplicationTasksRequest struct { - ReadLevel int64 - MaxReadLevel int64 + MinTaskID int64 + MaxTaskID int64 BatchSize int NextPageToken []byte } @@ -2143,8 +2143,8 @@ func NewGetReplicationTasksFromDLQRequest( return &GetReplicationTasksFromDLQRequest{ SourceClusterName: sourceClusterName, GetReplicationTasksRequest: GetReplicationTasksRequest{ - ReadLevel: readLevel, - MaxReadLevel: maxReadLevel, + MinTaskID: readLevel, + MaxTaskID: maxReadLevel, BatchSize: batchSize, NextPageToken: nextPageToken, }, diff --git a/common/persistence/persistence-tests/persistenceTestBase.go 
b/common/persistence/persistence-tests/persistenceTestBase.go index 72a6bd3be10..2f6026a71a4 100644 --- a/common/persistence/persistence-tests/persistenceTestBase.go +++ b/common/persistence/persistence-tests/persistenceTestBase.go @@ -980,8 +980,8 @@ func (s *TestBase) GetReplicationTasks(batchSize int, getAll bool) ([]*persisten Loop: for { response, err := s.ExecutionManager.GetReplicationTasks(&persistence.GetReplicationTasksRequest{ - ReadLevel: s.GetReplicationReadLevel(), - MaxReadLevel: int64(math.MaxInt64), + MinTaskID: s.GetReplicationReadLevel(), + MaxTaskID: int64(math.MaxInt64), BatchSize: batchSize, NextPageToken: token, }) @@ -1034,8 +1034,8 @@ func (s *TestBase) GetReplicationTasksFromDLQ( return s.ExecutionManager.GetReplicationTasksFromDLQ(&persistence.GetReplicationTasksFromDLQRequest{ SourceClusterName: sourceCluster, GetReplicationTasksRequest: persistence.GetReplicationTasksRequest{ - ReadLevel: readLevel, - MaxReadLevel: maxReadLevel, + MinTaskID: readLevel, + MaxTaskID: maxReadLevel, BatchSize: pageSize, NextPageToken: pageToken, }, diff --git a/common/persistence/sql/sql_execution_tasks.go b/common/persistence/sql/sql_execution_tasks.go index db271d022af..ce0704c3bbf 100644 --- a/common/persistence/sql/sql_execution_tasks.go +++ b/common/persistence/sql/sql_execution_tasks.go @@ -319,7 +319,7 @@ func (m *sqlExecutionManager) GetReplicationTasks( switch err { case nil: - return m.populateGetReplicationTasksResponse(rows, request.MaxReadLevel) + return m.populateGetReplicationTasksResponse(rows, request.MaxTaskID) case sql.ErrNoRows: return &p.GetReplicationTasksResponse{}, nil default: @@ -330,7 +330,7 @@ func (m *sqlExecutionManager) GetReplicationTasks( func getReadLevels( request *p.GetReplicationTasksRequest, ) (readLevel int64, maxReadLevelInclusive int64, err error) { - readLevel = request.ReadLevel + readLevel = request.MinTaskID if len(request.NextPageToken) > 0 { readLevel, err = deserializePageToken(request.NextPageToken) if err != nil { @@ -338,7 +338,7 @@ func getReadLevels( } } - maxReadLevelInclusive = collection.MaxInt64(readLevel+int64(request.BatchSize), request.MaxReadLevel) + maxReadLevelInclusive = collection.MaxInt64(readLevel+int64(request.BatchSize), request.MaxTaskID) return readLevel, maxReadLevelInclusive, nil } @@ -476,7 +476,7 @@ func (m *sqlExecutionManager) GetReplicationTasksFromDLQ( switch err { case nil: - return m.populateGetReplicationDLQTasksResponse(rows, request.MaxReadLevel) + return m.populateGetReplicationDLQTasksResponse(rows, request.MaxTaskID) case sql.ErrNoRows: return &p.GetReplicationTasksResponse{}, nil default: diff --git a/service/history/configs/config.go b/service/history/configs/config.go index c06f9c025e3..485cafbdb6e 100644 --- a/service/history/configs/config.go +++ b/service/history/configs/config.go @@ -205,8 +205,6 @@ type Config struct { ReplicationTaskProcessorNoTaskRetryWait dynamicconfig.DurationPropertyFnWithShardIDFilter ReplicationTaskProcessorCleanupInterval dynamicconfig.DurationPropertyFnWithShardIDFilter ReplicationTaskProcessorCleanupJitterCoefficient dynamicconfig.FloatPropertyFnWithShardIDFilter - ReplicationTaskProcessorStartWait dynamicconfig.DurationPropertyFnWithShardIDFilter - ReplicationTaskProcessorStartWaitJitterCoefficient dynamicconfig.FloatPropertyFnWithShardIDFilter ReplicationTaskProcessorHostQPS dynamicconfig.FloatPropertyFn ReplicationTaskProcessorShardQPS dynamicconfig.FloatPropertyFn @@ -349,8 +347,6 @@ func NewConfig(dc *dynamicconfig.Collection, numberOfShards int32, 
isAdvancedVis ReplicatorProcessorMaxRedispatchQueueSize: dc.GetIntProperty(dynamicconfig.ReplicatorProcessorMaxRedispatchQueueSize, 10000), ReplicatorProcessorEnablePriorityTaskProcessor: dc.GetBoolProperty(dynamicconfig.ReplicatorProcessorEnablePriorityTaskProcessor, false), ReplicatorProcessorFetchTasksBatchSize: dc.GetIntProperty(dynamicconfig.ReplicatorTaskBatchSize, 25), - ReplicationTaskProcessorStartWait: dc.GetDurationPropertyFilteredByShardID(dynamicconfig.ReplicationTaskProcessorStartWait, 5*time.Second), - ReplicationTaskProcessorStartWaitJitterCoefficient: dc.GetFloat64PropertyFilteredByShardID(dynamicconfig.ReplicationTaskProcessorStartWaitJitterCoefficient, 0.9), ReplicationTaskProcessorHostQPS: dc.GetFloat64Property(dynamicconfig.ReplicationTaskProcessorHostQPS, 1500), ReplicationTaskProcessorShardQPS: dc.GetFloat64Property(dynamicconfig.ReplicationTaskProcessorShardQPS, 30), diff --git a/service/history/handler.go b/service/history/handler.go index 614cc89cbc5..8726b942256 100644 --- a/service/history/handler.go +++ b/service/history/handler.go @@ -1260,6 +1260,7 @@ func (h *Handler) GetReplicationMessages(ctx context.Context, request *historyse tasks, err := engine.GetReplicationMessages( ctx, request.GetClusterName(), + token.GetLastProcessedMessageId(), token.GetLastRetrievedMessageId(), ) if err != nil { diff --git a/service/history/historyEngine.go b/service/history/historyEngine.go index ed694a9684f..84d21b31c2f 100644 --- a/service/history/historyEngine.go +++ b/service/history/historyEngine.go @@ -2496,23 +2496,32 @@ func (e *historyEngineImpl) NotifyNewTransferTasks( } } -func (e *historyEngineImpl) NotifyNewVisibilityTasks( +func (e *historyEngineImpl) NotifyNewTimerTasks( tasks []persistence.Task, ) { - if len(tasks) > 0 && e.visibilityProcessor != nil { - e.visibilityProcessor.NotifyNewTask(tasks) + if len(tasks) > 0 { + task := tasks[0] + clusterName := e.clusterMetadata.ClusterNameForFailoverVersion(task.GetVersion()) + e.timerProcessor.NotifyNewTimers(clusterName, tasks) } } -func (e *historyEngineImpl) NotifyNewTimerTasks( +func (e *historyEngineImpl) NotifyNewReplicationTasks( tasks []persistence.Task, ) { - if len(tasks) > 0 { - task := tasks[0] - clusterName := e.clusterMetadata.ClusterNameForFailoverVersion(task.GetVersion()) - e.timerProcessor.NotifyNewTimers(clusterName, tasks) + if len(tasks) > 0 && e.replicatorProcessor != nil { + e.replicatorProcessor.NotifyNewTasks(tasks) + } +} + +func (e *historyEngineImpl) NotifyNewVisibilityTasks( + tasks []persistence.Task, +) { + + if len(tasks) > 0 && e.visibilityProcessor != nil { + e.visibilityProcessor.NotifyNewTask(tasks) } } @@ -2779,27 +2788,32 @@ func getWorkflowAlreadyStartedError(errMsg string, createRequestID string, workf func (e *historyEngineImpl) GetReplicationMessages( ctx context.Context, pollingCluster string, - lastReadMessageID int64, + ackMessageID int64, + queryMessageID int64, ) (*replicationspb.ReplicationMessages, error) { scope := metrics.HistoryGetReplicationMessagesScope sw := e.metricsClient.StartTimer(scope, metrics.GetReplicationMessagesForShardLatency) defer sw.Stop() - replicationMessages, err := e.replicatorProcessor.getTasks( + if ackMessageID != persistence.EmptyQueueMessageID { + if err := e.shard.UpdateClusterReplicationLevel( + pollingCluster, + ackMessageID, + ); err != nil { + e.logger.Error("error updating replication level for shard", tag.Error(err), tag.OperationFailed) + } + } + + replicationMessages, err := e.replicatorProcessor.paginateTasks( ctx, pollingCluster, 
- lastReadMessageID, + queryMessageID, ) if err != nil { e.logger.Error("Failed to retrieve replication messages.", tag.Error(err)) return nil, err } - - // Set cluster status for sync shard info - replicationMessages.SyncShardStatus = &replicationspb.SyncShardStatus{ - StatusTime: timestamp.TimePtr(e.timeSource.Now()), - } e.logger.Debug("Successfully fetched replication messages.", tag.Counter(len(replicationMessages.ReplicationTasks))) return replicationMessages, nil } diff --git a/service/history/replicationDLQHandler.go b/service/history/replicationDLQHandler.go index 35620afb708..58fe428b573 100644 --- a/service/history/replicationDLQHandler.go +++ b/service/history/replicationDLQHandler.go @@ -117,8 +117,8 @@ func (r *replicationDLQHandlerImpl) readMessagesWithAckLevel( resp, err := r.shard.GetExecutionManager().GetReplicationTasksFromDLQ(&persistence.GetReplicationTasksFromDLQRequest{ SourceClusterName: sourceCluster, GetReplicationTasksRequest: persistence.GetReplicationTasksRequest{ - ReadLevel: ackLevel, - MaxReadLevel: lastMessageID, + MinTaskID: ackLevel, + MaxTaskID: lastMessageID, BatchSize: pageSize, NextPageToken: pageToken, }, diff --git a/service/history/replicationDLQHandler_test.go b/service/history/replicationDLQHandler_test.go index 8178ade21fc..b505199379d 100644 --- a/service/history/replicationDLQHandler_test.go +++ b/service/history/replicationDLQHandler_test.go @@ -172,8 +172,8 @@ func (s *replicationDLQHandlerSuite) TestReadMessages_OK() { s.executionManager.EXPECT().GetReplicationTasksFromDLQ(&persistence.GetReplicationTasksFromDLQRequest{ SourceClusterName: s.sourceCluster, GetReplicationTasksRequest: persistence.GetReplicationTasksRequest{ - ReadLevel: persistence.EmptyQueueMessageID, - MaxReadLevel: lastMessageID, + MinTaskID: persistence.EmptyQueueMessageID, + MaxTaskID: lastMessageID, BatchSize: pageSize, NextPageToken: pageToken, }, @@ -255,8 +255,8 @@ func (s *replicationDLQHandlerSuite) TestMergeMessages() { s.executionManager.EXPECT().GetReplicationTasksFromDLQ(&persistence.GetReplicationTasksFromDLQRequest{ SourceClusterName: s.sourceCluster, GetReplicationTasksRequest: persistence.GetReplicationTasksRequest{ - ReadLevel: persistence.EmptyQueueMessageID, - MaxReadLevel: lastMessageID, + MinTaskID: persistence.EmptyQueueMessageID, + MaxTaskID: lastMessageID, BatchSize: pageSize, NextPageToken: pageToken, }, diff --git a/service/history/replicationTaskProcessor.go b/service/history/replicationTaskProcessor.go index 66360a79821..d137c13efd4 100644 --- a/service/history/replicationTaskProcessor.go +++ b/service/history/replicationTaskProcessor.go @@ -54,10 +54,8 @@ import ( ) const ( - dropSyncShardTaskTimeThreshold = 10 * time.Minute - replicationTimeout = 30 * time.Second - taskErrorRetryBackoffCoefficient = 1.2 - taskErrorRetryMaxInterval = 5 * time.Second + dropSyncShardTaskTimeThreshold = 10 * time.Minute + replicationTimeout = 30 * time.Second ) var ( @@ -88,6 +86,8 @@ type ( minTxAckedTaskID int64 // recv side maxRxProcessedTaskID int64 + maxRxReceivedTaskID int64 + rxTaskBackoff time.Duration requestChan chan<- *replicationTaskRequest syncShardChan chan *replicationspb.SyncShardStatus @@ -150,6 +150,7 @@ func NewReplicationTaskProcessor( shutdownChan: make(chan struct{}), minTxAckedTaskID: persistence.EmptyQueueMessageID, maxRxProcessedTaskID: persistence.EmptyQueueMessageID, + maxRxReceivedTaskID: persistence.EmptyQueueMessageID, } } @@ -198,6 +199,9 @@ func (p *ReplicationTaskProcessorImpl) eventLoop() { )) defer cleanupTimer.Stop() + 
replicationTimer := time.NewTimer(0) + defer replicationTimer.Stop() + var syncShardTask *replicationspb.SyncShardStatus for { select { @@ -226,25 +230,30 @@ func (p *ReplicationTaskProcessorImpl) eventLoop() { case <-p.shutdownChan: return - default: + case <-replicationTimer.C: if err := p.pollProcessReplicationTasks(); err != nil { p.logger.Error("unable to process replication tasks", tag.Error(err)) } + replicationTimer.Reset(p.rxTaskBackoff) } } } -func (p *ReplicationTaskProcessorImpl) pollProcessReplicationTasks() error { - taskIterator := collection.NewPagingIterator(p.paginationFn) +func (p *ReplicationTaskProcessorImpl) pollProcessReplicationTasks() (retError error) { + defer func() { + if retError != nil { + p.maxRxReceivedTaskID = p.maxRxProcessedTaskID + p.rxTaskBackoff = p.config.ReplicationTaskFetcherErrorRetryWait() + } + }() - count := 0 + taskIterator := collection.NewPagingIterator(p.paginationFn) for taskIterator.HasNext() && !p.isStopped() { task, err := taskIterator.Next() if err != nil { return err } - count++ replicationTask := task.(*replicationspb.ReplicationTask) if err = p.applyReplicationTask(replicationTask); err != nil { return err @@ -252,11 +261,11 @@ func (p *ReplicationTaskProcessorImpl) pollProcessReplicationTasks() error { p.maxRxProcessedTaskID = replicationTask.GetSourceTaskId() } - // TODO there should be better handling of remote not having replication tasks - // & make the application of replication task evenly distributed (in terms of time) - // stream / long poll API worth considering - if count == 0 { - time.Sleep(p.config.ReplicationTaskProcessorNoTaskRetryWait(p.shard.GetShardID())) + if !p.isStopped() { + // all tasks fetched successfully processed + // setting the receiver side max processed task ID to max received task ID + // since task ID is not contiguous + p.maxRxProcessedTaskID = p.maxRxReceivedTaskID } return nil @@ -415,7 +424,7 @@ func (p *ReplicationTaskProcessorImpl) paginationFn(_ []byte) ([]interface{}, [] token: &replicationspb.ReplicationToken{ ShardId: p.shard.GetShardID(), LastProcessedMessageId: p.maxRxProcessedTaskID, - LastRetrievedMessageId: p.maxRxProcessedTaskID, + LastRetrievedMessageId: p.maxRxReceivedTaskID, }, respChan: respChan, } @@ -438,6 +447,12 @@ func (p *ReplicationTaskProcessorImpl) paginationFn(_ []byte) ([]interface{}, [] for _, task := range resp.GetReplicationTasks() { tasks = append(tasks, task) } + p.maxRxReceivedTaskID = resp.GetLastRetrievedMessageId() + if resp.GetHasMore() { + p.rxTaskBackoff = time.Duration(0) + } else { + p.rxTaskBackoff = p.config.ReplicationTaskProcessorNoTaskRetryWait(p.shard.GetShardID()) + } return tasks, nil, nil case <-p.shutdownChan: diff --git a/service/history/replicationTaskProcessor_test.go b/service/history/replicationTaskProcessor_test.go index bcdfa36a34b..81a5776cdef 100644 --- a/service/history/replicationTaskProcessor_test.go +++ b/service/history/replicationTaskProcessor_test.go @@ -25,6 +25,7 @@ package history import ( + "math/rand" "testing" "time" @@ -377,7 +378,7 @@ func (s *replicationTaskProcessorSuite) TestCleanupReplicationTask_Cleanup() { s.NoError(err) } -func (s *replicationTaskProcessorSuite) TestPaginationFn_Success() { +func (s *replicationTaskProcessorSuite) TestPaginationFn_Success_More() { namespaceID := uuid.NewRandom().String() workflowID := uuid.New() runID := uuid.NewRandom().String() @@ -396,8 +397,11 @@ func (s *replicationTaskProcessorSuite) TestPaginationFn_Success() { syncShardTask := &replicationspb.SyncShardStatus{ StatusTime: 
timestamp.TimeNowPtrUtc(), } + taskID := int64(123) + lastRetrievedMessageID := 2 * taskID task := &replicationspb.ReplicationTask{ - TaskType: enumsspb.REPLICATION_TASK_TYPE_HISTORY_V2_TASK, + SourceTaskId: taskID, + TaskType: enumsspb.REPLICATION_TASK_TYPE_HISTORY_V2_TASK, Attributes: &replicationspb.ReplicationTask_HistoryTaskV2Attributes{ HistoryTaskV2Attributes: &replicationspb.HistoryTaskV2Attributes{ NamespaceId: namespaceID, @@ -412,18 +416,27 @@ func (s *replicationTaskProcessorSuite) TestPaginationFn_Success() { }, } + maxRxProcessedTaskID := rand.Int63() + maxRxReceivedTaskID := rand.Int63() + rxTaskBackoff := time.Duration(rand.Int63()) + s.replicationTaskProcessor.maxRxProcessedTaskID = maxRxProcessedTaskID + s.replicationTaskProcessor.maxRxReceivedTaskID = maxRxReceivedTaskID + s.replicationTaskProcessor.rxTaskBackoff = rxTaskBackoff + requestToken := &replicationspb.ReplicationToken{ ShardId: s.mockShard.GetShardID(), - LastProcessedMessageId: s.replicationTaskProcessor.maxRxProcessedTaskID, - LastRetrievedMessageId: s.replicationTaskProcessor.maxRxProcessedTaskID, + LastProcessedMessageId: maxRxProcessedTaskID, + LastRetrievedMessageId: maxRxReceivedTaskID, } go func() { request := <-s.requestChan s.Equal(requestToken, request.token) request.respChan <- &replicationspb.ReplicationMessages{ - SyncShardStatus: syncShardTask, - ReplicationTasks: []*replicationspb.ReplicationTask{task}, + SyncShardStatus: syncShardTask, + ReplicationTasks: []*replicationspb.ReplicationTask{task}, + LastRetrievedMessageId: lastRetrievedMessageID, + HasMore: true, } close(request.respChan) }() @@ -433,13 +446,95 @@ func (s *replicationTaskProcessorSuite) TestPaginationFn_Success() { s.Equal(1, len(tasks)) s.Equal(task, tasks[0].(*replicationspb.ReplicationTask)) s.Equal(syncShardTask, <-s.replicationTaskProcessor.syncShardChan) + s.Equal(lastRetrievedMessageID, s.replicationTaskProcessor.maxRxReceivedTaskID) + s.Equal(time.Duration(0), s.replicationTaskProcessor.rxTaskBackoff) +} + +func (s *replicationTaskProcessorSuite) TestPaginationFn_Success_NoMore() { + namespaceID := uuid.NewRandom().String() + workflowID := uuid.New() + runID := uuid.NewRandom().String() + events := []*historypb.HistoryEvent{{ + EventId: 1, + Version: 1, + }} + versionHistory := []*historyspb.VersionHistoryItem{{ + EventId: 1, + Version: 1, + }} + serializer := s.mockResource.GetPayloadSerializer() + data, err := serializer.SerializeEvents(events, enumspb.ENCODING_TYPE_PROTO3) + s.NoError(err) + + syncShardTask := &replicationspb.SyncShardStatus{ + StatusTime: timestamp.TimeNowPtrUtc(), + } + taskID := int64(123) + lastRetrievedMessageID := 2 * taskID + task := &replicationspb.ReplicationTask{ + SourceTaskId: taskID, + TaskType: enumsspb.REPLICATION_TASK_TYPE_HISTORY_V2_TASK, + Attributes: &replicationspb.ReplicationTask_HistoryTaskV2Attributes{ + HistoryTaskV2Attributes: &replicationspb.HistoryTaskV2Attributes{ + NamespaceId: namespaceID, + WorkflowId: workflowID, + RunId: runID, + Events: &commonpb.DataBlob{ + EncodingType: enumspb.ENCODING_TYPE_PROTO3, + Data: data.Data, + }, + VersionHistoryItems: versionHistory, + }, + }, + } + + maxRxProcessedTaskID := rand.Int63() + maxRxReceivedTaskID := rand.Int63() + rxTaskBackoff := time.Duration(rand.Int63()) + s.replicationTaskProcessor.maxRxProcessedTaskID = maxRxProcessedTaskID + s.replicationTaskProcessor.maxRxReceivedTaskID = maxRxReceivedTaskID + s.replicationTaskProcessor.rxTaskBackoff = rxTaskBackoff + + requestToken := &replicationspb.ReplicationToken{ + ShardId: 
s.mockShard.GetShardID(), + LastProcessedMessageId: maxRxProcessedTaskID, + LastRetrievedMessageId: maxRxReceivedTaskID, + } + + go func() { + request := <-s.requestChan + s.Equal(requestToken, request.token) + request.respChan <- &replicationspb.ReplicationMessages{ + SyncShardStatus: syncShardTask, + ReplicationTasks: []*replicationspb.ReplicationTask{task}, + LastRetrievedMessageId: lastRetrievedMessageID, + HasMore: false, + } + close(request.respChan) + }() + + tasks, _, err := s.replicationTaskProcessor.paginationFn(nil) + s.NoError(err) + s.Equal(1, len(tasks)) + s.Equal(task, tasks[0].(*replicationspb.ReplicationTask)) + s.Equal(syncShardTask, <-s.replicationTaskProcessor.syncShardChan) + s.Equal(lastRetrievedMessageID, s.replicationTaskProcessor.maxRxReceivedTaskID) + s.NotEqual(time.Duration(0), s.replicationTaskProcessor.rxTaskBackoff) } func (s *replicationTaskProcessorSuite) TestPaginationFn_Error() { + + maxRxProcessedTaskID := rand.Int63() + maxRxReceivedTaskID := rand.Int63() + rxTaskBackoff := time.Duration(rand.Int63()) + s.replicationTaskProcessor.maxRxProcessedTaskID = maxRxProcessedTaskID + s.replicationTaskProcessor.maxRxReceivedTaskID = maxRxReceivedTaskID + s.replicationTaskProcessor.rxTaskBackoff = rxTaskBackoff + requestToken := &replicationspb.ReplicationToken{ ShardId: s.mockShard.GetShardID(), - LastProcessedMessageId: s.replicationTaskProcessor.maxRxProcessedTaskID, - LastRetrievedMessageId: s.replicationTaskProcessor.maxRxProcessedTaskID, + LastProcessedMessageId: maxRxProcessedTaskID, + LastRetrievedMessageId: maxRxReceivedTaskID, } go func() { @@ -455,6 +550,8 @@ func (s *replicationTaskProcessorSuite) TestPaginationFn_Error() { case <-s.replicationTaskProcessor.syncShardChan: s.Fail("should not receive any sync shard task") default: - // noop + s.Equal(maxRxProcessedTaskID, s.replicationTaskProcessor.maxRxProcessedTaskID) + s.Equal(maxRxReceivedTaskID, s.replicationTaskProcessor.maxRxReceivedTaskID) + s.Equal(rxTaskBackoff, s.replicationTaskProcessor.rxTaskBackoff) } } diff --git a/service/history/replicatorQueueProcessor.go b/service/history/replicatorQueueProcessor.go index 4c2f6abf352..ff4892da6f4 100644 --- a/service/history/replicatorQueueProcessor.go +++ b/service/history/replicatorQueueProcessor.go @@ -27,6 +27,7 @@ package history import ( "context" "errors" + "sync" "time" commonpb "go.temporal.io/api/common/v1" @@ -38,11 +39,14 @@ import ( replicationspb "go.temporal.io/server/api/replication/v1" "go.temporal.io/server/common" "go.temporal.io/server/common/backoff" + "go.temporal.io/server/common/convert" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/persistence/versionhistory" + "go.temporal.io/server/common/primitives/timestamp" + "go.temporal.io/server/service/history/configs" "go.temporal.io/server/service/history/shard" ) @@ -50,14 +54,19 @@ type ( replicatorQueueProcessorImpl struct { currentClusterName string shard shard.Context + config *configs.Config historyCache *historyCache executionMgr persistence.ExecutionManager historyMgr persistence.HistoryManager metricsClient metrics.Client logger log.Logger retryPolicy backoff.RetryPolicy - // This is the batch size used by pull based RPC replicator. 
- fetchTasksBatchSize int + pageSize int + + sync.Mutex + // largest replication task ID generated + maxTaskID *int64 + sanityCheckTime time.Time } ) @@ -81,60 +90,58 @@ func newReplicatorQueueProcessor( retryPolicy.SetBackoffCoefficient(1) return &replicatorQueueProcessorImpl{ - currentClusterName: currentClusterName, - shard: shard, - historyCache: historyCache, - executionMgr: executionMgr, - historyMgr: historyMgr, - metricsClient: shard.GetMetricsClient(), - logger: log.With(logger, tag.ComponentReplicatorQueue), - retryPolicy: retryPolicy, - fetchTasksBatchSize: config.ReplicatorProcessorFetchTasksBatchSize(), + currentClusterName: currentClusterName, + shard: shard, + config: shard.GetConfig(), + historyCache: historyCache, + executionMgr: executionMgr, + historyMgr: historyMgr, + metricsClient: shard.GetMetricsClient(), + logger: log.With(logger, tag.ComponentReplicatorQueue), + retryPolicy: retryPolicy, + pageSize: config.ReplicatorProcessorFetchTasksBatchSize(), + + maxTaskID: nil, + sanityCheckTime: time.Time{}, } } -func (p *replicatorQueueProcessorImpl) getTasks( - ctx context.Context, - pollingCluster string, - lastReadTaskID int64, -) (*replicationspb.ReplicationMessages, error) { +func (p *replicatorQueueProcessorImpl) NotifyNewTasks( + tasks []persistence.Task, +) { - if lastReadTaskID == persistence.EmptyQueueMessageID { - lastReadTaskID = p.shard.GetClusterReplicationLevel(pollingCluster) - } else { - if err := p.shard.UpdateClusterReplicationLevel( - pollingCluster, - lastReadTaskID, - ); err != nil { - p.logger.Error("error updating replication level for shard", tag.Error(err), tag.OperationFailed) + if len(tasks) == 0 { + return + } + maxTaskID := tasks[0].GetTaskID() + for _, task := range tasks { + if maxTaskID < task.GetTaskID() { + maxTaskID = task.GetTaskID() } } - taskInfoList, hasMore, err := p.readTasksWithBatchSize(lastReadTaskID, p.fetchTasksBatchSize) - if err != nil { - return nil, err + p.Lock() + defer p.Unlock() + if p.maxTaskID == nil || *p.maxTaskID < maxTaskID { + p.maxTaskID = &maxTaskID } +} - var replicationTasks []*replicationspb.ReplicationTask - readLevel := lastReadTaskID - for _, taskInfo := range taskInfoList { - var replicationTask *replicationspb.ReplicationTask - op := func() error { - var err error - replicationTask, err = p.toReplicationTask(ctx, taskInfo) - return err - } +func (p *replicatorQueueProcessorImpl) paginateTasks( + ctx context.Context, + pollingCluster string, + queryMessageID int64, +) (*replicationspb.ReplicationMessages, error) { - err = backoff.Retry(op, p.retryPolicy, common.IsPersistenceTransientError) - if err != nil { - p.logger.Debug("Failed to get replication task. Return what we have so far.", tag.Error(err)) - hasMore = true - break - } - readLevel = taskInfo.GetTaskId() - if replicationTask != nil { - replicationTasks = append(replicationTasks, replicationTask) - } + minTaskID, maxTaskID := p.taskIDsRange(queryMessageID) + replicationTasks, lastTaskID, err := p.getTasks( + ctx, + minTaskID, + maxTaskID, + p.pageSize, + ) + if err != nil { + return nil, err } // Note this is a very rough indicator of how much the remote DC is behind on this shard. 
@@ -143,13 +150,13 @@ func (p *replicatorQueueProcessorImpl) getTasks( metrics.TargetClusterTag(pollingCluster), ).RecordDistribution( metrics.ReplicationTasksLag, - int(p.shard.GetTransferMaxReadLevel()-readLevel), + int(maxTaskID-lastTaskID), ) p.metricsClient.RecordDistribution( metrics.ReplicatorQueueProcessorScope, metrics.ReplicationTasksFetched, - len(taskInfoList), + len(replicationTasks), ) p.metricsClient.RecordDistribution( @@ -160,11 +167,68 @@ func (p *replicatorQueueProcessorImpl) getTasks( return &replicationspb.ReplicationMessages{ ReplicationTasks: replicationTasks, - HasMore: hasMore, - LastRetrievedMessageId: readLevel, + HasMore: lastTaskID < maxTaskID, + LastRetrievedMessageId: lastTaskID, + SyncShardStatus: &replicationspb.SyncShardStatus{ + StatusTime: timestamp.TimePtr(p.shard.GetTimeSource().Now()), + }, }, nil } +func (p *replicatorQueueProcessorImpl) getTasks( + ctx context.Context, + minTaskID int64, + maxTaskID int64, + batchSize int, +) ([]*replicationspb.ReplicationTask, int64, error) { + + if minTaskID == maxTaskID { + return []*replicationspb.ReplicationTask{}, maxTaskID, nil + } + + var token []byte + tasks := make([]*replicationspb.ReplicationTask, 0, batchSize) + for { + response, err := p.executionMgr.GetReplicationTasks(&persistence.GetReplicationTasksRequest{ + MinTaskID: minTaskID, + MaxTaskID: maxTaskID, + BatchSize: batchSize, + NextPageToken: token, + }) + if err != nil { + return nil, 0, err + } + + token = response.NextPageToken + for _, task := range response.Tasks { + if replicationTask, err := p.taskInfoToTask( + ctx, + &persistence.ReplicationTaskInfoWrapper{ReplicationTaskInfo: task}, + ); err != nil { + return nil, 0, err + } else if replicationTask != nil { + tasks = append(tasks, replicationTask) + } + } + + // break if seen at least one task or no more task + if len(token) == 0 || len(tasks) > 0 { + break + } + } + + // sanity check we will finish pagination or return some tasks + if len(token) != 0 && len(tasks) == 0 { + p.logger.Fatal("replication task reader should finish pagination or return some tasks") + } + + if len(tasks) == 0 { + // len(token) == 0, no more items from DB + return nil, maxTaskID, nil + } + return tasks, tasks[len(tasks)-1].GetSourceTaskId(), nil +} + func (p *replicatorQueueProcessorImpl) getTask( ctx context.Context, taskInfo *replicationspb.ReplicationTaskInfo, @@ -181,29 +245,50 @@ func (p *replicatorQueueProcessorImpl) getTask( Version: taskInfo.GetVersion(), ScheduledId: taskInfo.GetScheduledId(), } - return p.toReplicationTask(ctx, &persistence.ReplicationTaskInfoWrapper{ReplicationTaskInfo: task}) + return p.taskInfoToTask(ctx, &persistence.ReplicationTaskInfoWrapper{ReplicationTaskInfo: task}) } -func (p *replicatorQueueProcessorImpl) readTasksWithBatchSize( - readLevel int64, - batchSize int, -) ([]queueTaskInfo, bool, error) { - response, err := p.executionMgr.GetReplicationTasks(&persistence.GetReplicationTasksRequest{ - ReadLevel: readLevel, - MaxReadLevel: p.shard.GetTransferMaxReadLevel(), - BatchSize: batchSize, - }) +func (p *replicatorQueueProcessorImpl) taskInfoToTask( + ctx context.Context, + taskInfo queueTaskInfo, +) (*replicationspb.ReplicationTask, error) { + var replicationTask *replicationspb.ReplicationTask + op := func() error { + var err error + replicationTask, err = p.toReplicationTask(ctx, taskInfo) + return err + } - if err != nil { - return nil, false, err + if err := backoff.Retry(op, p.retryPolicy, common.IsPersistenceTransientError); err != nil { + return nil, err + } + return 
replicationTask, nil +} + +func (p *replicatorQueueProcessorImpl) taskIDsRange( + lastReadMessageID int64, +) (minTaskID int64, maxTaskID int64) { + minTaskID = lastReadMessageID + maxTaskID = p.shard.GetTransferMaxReadLevel() + + p.Lock() + defer p.Unlock() + defer func() { p.maxTaskID = convert.Int64Ptr(maxTaskID) }() + + now := p.shard.GetTimeSource().Now() + if p.sanityCheckTime.IsZero() || p.sanityCheckTime.Before(now) { + p.sanityCheckTime = now.Add(backoff.JitDuration( + p.config.ReplicatorProcessorMaxPollInterval(), + p.config.ReplicatorProcessorMaxPollIntervalJitterCoefficient(), + )) + return minTaskID, maxTaskID } - tasks := make([]queueTaskInfo, len(response.Tasks)) - for i := range response.Tasks { - tasks[i] = &persistence.ReplicationTaskInfoWrapper{ReplicationTaskInfo: response.Tasks[i]} + if p.maxTaskID != nil && *p.maxTaskID < maxTaskID { + maxTaskID = *p.maxTaskID } - return tasks, len(response.NextPageToken) != 0, nil + return minTaskID, maxTaskID } func (p *replicatorQueueProcessorImpl) toReplicationTask( diff --git a/service/history/replicatorQueueProcessor_test.go b/service/history/replicatorQueueProcessor_test.go index de23e52c74d..fbfafdffb9a 100644 --- a/service/history/replicatorQueueProcessor_test.go +++ b/service/history/replicatorQueueProcessor_test.go @@ -43,6 +43,7 @@ import ( "go.temporal.io/server/common" "go.temporal.io/server/common/cache" "go.temporal.io/server/common/cluster" + "go.temporal.io/server/common/convert" "go.temporal.io/server/common/failure" "go.temporal.io/server/common/log" "go.temporal.io/server/common/payloads" @@ -120,6 +121,89 @@ func (s *replicatorQueueProcessorSuite) TearDownTest() { s.controller.Finish() } +func (s *replicatorQueueProcessorSuite) TestNotifyNewTasks_NotInitialized() { + s.replicatorQueueProcessor.maxTaskID = nil + + s.replicatorQueueProcessor.NotifyNewTasks([]persistence.Task{ + &persistence.HistoryReplicationTask{TaskID: 456}, + &persistence.HistoryReplicationTask{TaskID: 123}, + }) + + s.Equal(*s.replicatorQueueProcessor.maxTaskID, int64(456)) +} + +func (s *replicatorQueueProcessorSuite) TestNotifyNewTasks_Initialized() { + s.replicatorQueueProcessor.maxTaskID = convert.Int64Ptr(123) + + s.replicatorQueueProcessor.NotifyNewTasks([]persistence.Task{ + &persistence.HistoryReplicationTask{TaskID: 100}, + }) + s.Equal(*s.replicatorQueueProcessor.maxTaskID, int64(123)) + + s.replicatorQueueProcessor.NotifyNewTasks([]persistence.Task{ + &persistence.HistoryReplicationTask{TaskID: 234}, + }) + s.Equal(*s.replicatorQueueProcessor.maxTaskID, int64(234)) +} + +func (s *replicatorQueueProcessorSuite) TestTaskIDRange_NotInitialized() { + s.replicatorQueueProcessor.sanityCheckTime = time.Time{} + expectMaxTaskID := s.mockShard.GetTransferMaxReadLevel() + expectMinTaskID := expectMaxTaskID - 100 + s.replicatorQueueProcessor.maxTaskID = convert.Int64Ptr(expectMinTaskID - 100) + + minTaskID, maxTaskID := s.replicatorQueueProcessor.taskIDsRange(expectMinTaskID) + s.Equal(expectMinTaskID, minTaskID) + s.Equal(expectMaxTaskID, maxTaskID) + s.NotEqual(time.Time{}, s.replicatorQueueProcessor.sanityCheckTime) + s.Equal(expectMaxTaskID, *s.replicatorQueueProcessor.maxTaskID) +} + +func (s *replicatorQueueProcessorSuite) TestTaskIDRange_Initialized_UseHighestReplicationTaskID() { + now := time.Now().UTC() + sanityCheckTime := now.Add(2 * time.Minute) + s.replicatorQueueProcessor.sanityCheckTime = sanityCheckTime + expectMinTaskID := s.mockShard.GetTransferMaxReadLevel() - 100 + expectMaxTaskID := s.mockShard.GetTransferMaxReadLevel() - 
50 + s.replicatorQueueProcessor.maxTaskID = convert.Int64Ptr(expectMaxTaskID) + + minTaskID, maxTaskID := s.replicatorQueueProcessor.taskIDsRange(expectMinTaskID) + s.Equal(expectMinTaskID, minTaskID) + s.Equal(expectMaxTaskID, maxTaskID) + s.Equal(sanityCheckTime, s.replicatorQueueProcessor.sanityCheckTime) + s.Equal(expectMaxTaskID, *s.replicatorQueueProcessor.maxTaskID) +} + +func (s *replicatorQueueProcessorSuite) TestTaskIDRange_Initialized_NoHighestReplicationTaskID() { + now := time.Now().UTC() + sanityCheckTime := now.Add(2 * time.Minute) + s.replicatorQueueProcessor.sanityCheckTime = sanityCheckTime + expectMinTaskID := s.mockShard.GetTransferMaxReadLevel() - 100 + expectMaxTaskID := s.mockShard.GetTransferMaxReadLevel() + s.replicatorQueueProcessor.maxTaskID = nil + + minTaskID, maxTaskID := s.replicatorQueueProcessor.taskIDsRange(expectMinTaskID) + s.Equal(expectMinTaskID, minTaskID) + s.Equal(expectMaxTaskID, maxTaskID) + s.Equal(sanityCheckTime, s.replicatorQueueProcessor.sanityCheckTime) + s.Equal(expectMaxTaskID, *s.replicatorQueueProcessor.maxTaskID) +} + +func (s *replicatorQueueProcessorSuite) TestTaskIDRange_Initialized_UseHighestTransferTaskID() { + now := time.Now().UTC() + sanityCheckTime := now.Add(-2 * time.Minute) + s.replicatorQueueProcessor.sanityCheckTime = sanityCheckTime + expectMinTaskID := s.mockShard.GetTransferMaxReadLevel() - 100 + expectMaxTaskID := s.mockShard.GetTransferMaxReadLevel() + s.replicatorQueueProcessor.maxTaskID = convert.Int64Ptr(s.mockShard.GetTransferMaxReadLevel() - 50) + + minTaskID, maxTaskID := s.replicatorQueueProcessor.taskIDsRange(expectMinTaskID) + s.Equal(expectMinTaskID, minTaskID) + s.Equal(expectMaxTaskID, maxTaskID) + s.NotEqual(sanityCheckTime, s.replicatorQueueProcessor.sanityCheckTime) + s.Equal(expectMaxTaskID, *s.replicatorQueueProcessor.maxTaskID) +} + func (s *replicatorQueueProcessorSuite) TestSyncActivity_WorkflowMissing() { ctx := context.Background() namespace := "some random namespace name" diff --git a/service/history/shard/context.go b/service/history/shard/context.go index e4a4b76fad0..6f5418deb8f 100644 --- a/service/history/shard/context.go +++ b/service/history/shard/context.go @@ -87,7 +87,7 @@ type ( UpdateReplicatorDLQAckLevel(sourCluster string, ackLevel int64) error GetClusterReplicationLevel(cluster string) int64 - UpdateClusterReplicationLevel(cluster string, lastTaskID int64) error + UpdateClusterReplicationLevel(cluster string, ackTaskID int64) error GetTimerAckLevel() time.Time UpdateTimerAckLevel(ackLevel time.Time) error diff --git a/service/history/shard/context_impl.go b/service/history/shard/context_impl.go index 8698b19740e..f4bf4512757 100644 --- a/service/history/shard/context_impl.go +++ b/service/history/shard/context_impl.go @@ -261,11 +261,11 @@ func (s *ContextImpl) GetClusterReplicationLevel(cluster string) int64 { return persistence.EmptyQueueMessageID } -func (s *ContextImpl) UpdateClusterReplicationLevel(cluster string, lastTaskID int64) error { +func (s *ContextImpl) UpdateClusterReplicationLevel(cluster string, ackTaskID int64) error { s.Lock() defer s.Unlock() - s.shardInfo.ClusterReplicationLevel[cluster] = lastTaskID + s.shardInfo.ClusterReplicationLevel[cluster] = ackTaskID s.shardInfo.StolenSinceRenew = 0 return s.updateShardInfoLocked() } @@ -709,6 +709,7 @@ func (s *ContextImpl) AddTasks( s.engine.NotifyNewTransferTasks(request.TransferTasks) s.engine.NotifyNewTimerTasks(request.TimerTasks) s.engine.NotifyNewVisibilityTasks(request.VisibilityTasks) + 
s.engine.NotifyNewReplicationTasks(request.ReplicationTasks) return nil case *persistence.TimeoutError: // noop @@ -878,7 +879,7 @@ func (s *ContextImpl) renewRangeLocked(isStealing bool) error { func (s *ContextImpl) updateMaxReadLevelLocked(rl int64) { if rl > s.transferMaxReadLevel { - s.logger.Debug("Updating MaxReadLevel", tag.MaxLevel(rl)) + s.logger.Debug("Updating MaxTaskID", tag.MaxLevel(rl)) s.transferMaxReadLevel = rl } } diff --git a/service/history/shard/context_mock.go b/service/history/shard/context_mock.go index 021bb3d4c6d..7d5a9d78ebc 100644 --- a/service/history/shard/context_mock.go +++ b/service/history/shard/context_mock.go @@ -615,17 +615,17 @@ func (mr *MockContextMockRecorder) SetEngine(arg0 interface{}) *gomock.Call { } // UpdateClusterReplicationLevel mocks base method. -func (m *MockContext) UpdateClusterReplicationLevel(cluster string, lastTaskID int64) error { +func (m *MockContext) UpdateClusterReplicationLevel(cluster string, ackTaskID int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateClusterReplicationLevel", cluster, lastTaskID) + ret := m.ctrl.Call(m, "UpdateClusterReplicationLevel", cluster, ackTaskID) ret0, _ := ret[0].(error) return ret0 } // UpdateClusterReplicationLevel indicates an expected call of UpdateClusterReplicationLevel. -func (mr *MockContextMockRecorder) UpdateClusterReplicationLevel(cluster, lastTaskID interface{}) *gomock.Call { +func (mr *MockContextMockRecorder) UpdateClusterReplicationLevel(cluster, ackTaskID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateClusterReplicationLevel", reflect.TypeOf((*MockContext)(nil).UpdateClusterReplicationLevel), cluster, lastTaskID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateClusterReplicationLevel", reflect.TypeOf((*MockContext)(nil).UpdateClusterReplicationLevel), cluster, ackTaskID) } // UpdateNamespaceNotificationVersion mocks base method. 
diff --git a/service/history/shard/context_test.go b/service/history/shard/context_test.go index 2c08fc4b3f4..8221944ed02 100644 --- a/service/history/shard/context_test.go +++ b/service/history/shard/context_test.go @@ -133,6 +133,7 @@ func (s *contextSuite) TestAddTasks_Success() { s.mockHistoryEngine.EXPECT().NotifyNewTransferTasks(transferTasks) s.mockHistoryEngine.EXPECT().NotifyNewTimerTasks(timerTasks) s.mockHistoryEngine.EXPECT().NotifyNewVisibilityTasks(visibilityTasks) + s.mockHistoryEngine.EXPECT().NotifyNewReplicationTasks(replicationTasks) err := s.shardContext.AddTasks(addTasksRequest) s.NoError(err) diff --git a/service/history/shard/engine.go b/service/history/shard/engine.go index 0c829839330..8645a8b577f 100644 --- a/service/history/shard/engine.go +++ b/service/history/shard/engine.go @@ -69,7 +69,7 @@ type ( ReplicateEventsV2(ctx context.Context, request *historyservice.ReplicateEventsV2Request) error SyncShardStatus(ctx context.Context, request *historyservice.SyncShardStatusRequest) error SyncActivity(ctx context.Context, request *historyservice.SyncActivityRequest) error - GetReplicationMessages(ctx context.Context, pollingCluster string, lastReadMessageID int64) (*replicationspb.ReplicationMessages, error) + GetReplicationMessages(ctx context.Context, pollingCluster string, ackMessageID int64, queryMessageID int64) (*replicationspb.ReplicationMessages, error) GetDLQReplicationMessages(ctx context.Context, taskInfos []*replicationspb.ReplicationTaskInfo) ([]*replicationspb.ReplicationTask, error) QueryWorkflow(ctx context.Context, request *historyservice.QueryWorkflowRequest) (*historyservice.QueryWorkflowResponse, error) ReapplyEvents(ctx context.Context, namespaceUUID string, workflowID string, runID string, events []*historypb.HistoryEvent) error @@ -82,5 +82,6 @@ type ( NotifyNewTransferTasks(tasks []persistence.Task) NotifyNewTimerTasks(tasks []persistence.Task) NotifyNewVisibilityTasks(tasks []persistence.Task) + NotifyNewReplicationTasks(tasks []persistence.Task) } ) diff --git a/service/history/shard/engine_mock.go b/service/history/shard/engine_mock.go index 1582f2be2e7..33feb8cd7a0 100644 --- a/service/history/shard/engine_mock.go +++ b/service/history/shard/engine_mock.go @@ -140,18 +140,18 @@ func (mr *MockEngineMockRecorder) GetMutableState(ctx, request interface{}) *gom } // GetReplicationMessages mocks base method. -func (m *MockEngine) GetReplicationMessages(ctx context.Context, pollingCluster string, lastReadMessageID int64) (*repication.ReplicationMessages, error) { +func (m *MockEngine) GetReplicationMessages(ctx context.Context, pollingCluster string, ackMessageID, queryMessageID int64) (*repication.ReplicationMessages, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetReplicationMessages", ctx, pollingCluster, lastReadMessageID) + ret := m.ctrl.Call(m, "GetReplicationMessages", ctx, pollingCluster, ackMessageID, queryMessageID) ret0, _ := ret[0].(*repication.ReplicationMessages) ret1, _ := ret[1].(error) return ret0, ret1 } // GetReplicationMessages indicates an expected call of GetReplicationMessages. 
-func (mr *MockEngineMockRecorder) GetReplicationMessages(ctx, pollingCluster, lastReadMessageID interface{}) *gomock.Call { +func (mr *MockEngineMockRecorder) GetReplicationMessages(ctx, pollingCluster, ackMessageID, queryMessageID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetReplicationMessages", reflect.TypeOf((*MockEngine)(nil).GetReplicationMessages), ctx, pollingCluster, lastReadMessageID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetReplicationMessages", reflect.TypeOf((*MockEngine)(nil).GetReplicationMessages), ctx, pollingCluster, ackMessageID, queryMessageID) } // MergeDLQMessages mocks base method. @@ -181,6 +181,18 @@ func (mr *MockEngineMockRecorder) NotifyNewHistoryEvent(event interface{}) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotifyNewHistoryEvent", reflect.TypeOf((*MockEngine)(nil).NotifyNewHistoryEvent), event) } +// NotifyNewReplicationTasks mocks base method. +func (m *MockEngine) NotifyNewReplicationTasks(tasks []persistence.Task) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "NotifyNewReplicationTasks", tasks) +} + +// NotifyNewReplicationTasks indicates an expected call of NotifyNewReplicationTasks. +func (mr *MockEngineMockRecorder) NotifyNewReplicationTasks(tasks interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotifyNewReplicationTasks", reflect.TypeOf((*MockEngine)(nil).NotifyNewReplicationTasks), tasks) +} + // NotifyNewTimerTasks mocks base method. func (m *MockEngine) NotifyNewTimerTasks(tasks []persistence.Task) { m.ctrl.T.Helper() diff --git a/service/history/timerQueueTaskExecutorBase_test.go b/service/history/timerQueueTaskExecutorBase_test.go index de28f5b21d9..878673e2862 100644 --- a/service/history/timerQueueTaskExecutorBase_test.go +++ b/service/history/timerQueueTaskExecutorBase_test.go @@ -144,6 +144,7 @@ func (s *timerQueueTaskExecutorBaseSuite) TestDeleteWorkflow_NoErr() { s.mockEngine.EXPECT().NotifyNewTransferTasks(gomock.Any()) s.mockEngine.EXPECT().NotifyNewTimerTasks(gomock.Any()) s.mockEngine.EXPECT().NotifyNewVisibilityTasks(gomock.Any()) + s.mockEngine.EXPECT().NotifyNewReplicationTasks(gomock.Any()) s.mockExecutionManager.EXPECT().DeleteCurrentWorkflowExecution(gomock.Any()).Return(nil) s.mockExecutionManager.EXPECT().DeleteWorkflowExecution(gomock.Any()).Return(nil) @@ -169,6 +170,7 @@ func (s *timerQueueTaskExecutorBaseSuite) TestArchiveHistory_NoErr_InlineArchiva s.mockEngine.EXPECT().NotifyNewTransferTasks(gomock.Any()) s.mockEngine.EXPECT().NotifyNewTimerTasks(gomock.Any()) s.mockEngine.EXPECT().NotifyNewVisibilityTasks(gomock.Any()) + s.mockEngine.EXPECT().NotifyNewReplicationTasks(gomock.Any()) s.mockExecutionManager.EXPECT().DeleteCurrentWorkflowExecution(gomock.Any()).Return(nil) s.mockExecutionManager.EXPECT().DeleteWorkflowExecution(gomock.Any()).Return(nil) diff --git a/service/history/workflowExecutionContext.go b/service/history/workflowExecutionContext.go index bd835589ce1..fc867e6318e 100644 --- a/service/history/workflowExecutionContext.go +++ b/service/history/workflowExecutionContext.go @@ -419,8 +419,8 @@ func (c *workflowExecutionContextImpl) createWorkflowExecution( c.notifyTasks( newWorkflow.TransferTasks, - newWorkflow.ReplicationTasks, newWorkflow.TimerTasks, + newWorkflow.ReplicationTasks, newWorkflow.VisibilityTasks, ) return nil @@ -568,23 +568,23 @@ func (c *workflowExecutionContextImpl) conflictResolveWorkflowExecution( c.notifyTasks( 
resetWorkflow.TransferTasks, - resetWorkflow.ReplicationTasks, resetWorkflow.TimerTasks, + resetWorkflow.ReplicationTasks, resetWorkflow.VisibilityTasks, ) if newWorkflow != nil { c.notifyTasks( newWorkflow.TransferTasks, - newWorkflow.ReplicationTasks, newWorkflow.TimerTasks, + newWorkflow.ReplicationTasks, newWorkflow.VisibilityTasks, ) } if currentWorkflow != nil { c.notifyTasks( currentWorkflow.TransferTasks, - currentWorkflow.ReplicationTasks, currentWorkflow.TimerTasks, + currentWorkflow.ReplicationTasks, currentWorkflow.VisibilityTasks, ) } @@ -785,8 +785,8 @@ func (c *workflowExecutionContextImpl) updateWorkflowExecutionWithNew( // notify current workflow tasks c.notifyTasks( currentWorkflow.TransferTasks, - currentWorkflow.ReplicationTasks, currentWorkflow.TimerTasks, + currentWorkflow.ReplicationTasks, currentWorkflow.VisibilityTasks, ) @@ -794,8 +794,8 @@ func (c *workflowExecutionContextImpl) updateWorkflowExecutionWithNew( if newWorkflow != nil { c.notifyTasks( newWorkflow.TransferTasks, - newWorkflow.ReplicationTasks, newWorkflow.TimerTasks, + newWorkflow.ReplicationTasks, newWorkflow.VisibilityTasks, ) } @@ -826,13 +826,14 @@ func (c *workflowExecutionContextImpl) updateWorkflowExecutionWithNew( func (c *workflowExecutionContextImpl) notifyTasks( transferTasks []persistence.Task, - replicationTasks []persistence.Task, timerTasks []persistence.Task, + replicationTasks []persistence.Task, visibilityTasks []persistence.Task, ) { c.engine.NotifyNewTransferTasks(transferTasks) c.engine.NotifyNewTimerTasks(timerTasks) c.engine.NotifyNewVisibilityTasks(visibilityTasks) + c.engine.NotifyNewReplicationTasks(replicationTasks) } func (c *workflowExecutionContextImpl) mergeContinueAsNewReplicationTasks(
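The "hint" mentioned in the commit message (NotifyNewTasks / taskIDsRange in replicatorQueueProcessor.go) is what removes the unnecessary DB reads: the shard remembers the highest replication task ID it has generated and caps replication-queue reads at that value, while a periodic sanity check falls back to the transfer max read level so a notification lost to a persistence timeout cannot stall replication. A minimal standalone Go sketch of that idea follows; taskIDHint and its plain int64/time parameters are simplified stand-ins for the real shard and config plumbing, not code from this patch.

// Sketch only: condenses the notification hint plus the periodic sanity check.
package main

import (
	"fmt"
	"sync"
	"time"
)

// taskIDHint tracks the largest replication task ID reported via notifications.
type taskIDHint struct {
	sync.Mutex
	maxTaskID       *int64    // highest task ID seen by notifyNewTasks; nil until the first notification
	sanityCheckTime time.Time // next time the hint is ignored and the hard upper bound is used
}

// notifyNewTasks records the highest task ID among freshly generated replication tasks.
func (h *taskIDHint) notifyNewTasks(taskIDs []int64) {
	if len(taskIDs) == 0 {
		return
	}
	highest := taskIDs[0]
	for _, id := range taskIDs {
		if id > highest {
			highest = id
		}
	}
	h.Lock()
	defer h.Unlock()
	if h.maxTaskID == nil || *h.maxTaskID < highest {
		h.maxTaskID = &highest
	}
}

// taskIDsRange returns the [min, max] task ID window for the next queue read.
// Between sanity checks, max is capped at the notification hint, so an idle
// shard does not scan the replication queue up to the transfer max read level.
func (h *taskIDHint) taskIDsRange(lastReadTaskID, transferMaxReadLevel int64, now time.Time, sanityInterval time.Duration) (int64, int64) {
	minID, maxID := lastReadTaskID, transferMaxReadLevel

	h.Lock()
	defer h.Unlock()
	defer func() { h.maxTaskID = &maxID }()

	if h.sanityCheckTime.IsZero() || h.sanityCheckTime.Before(now) {
		// Periodic sanity check: read up to the hard upper bound in case a
		// notification was lost (e.g. a timeout on the task write path).
		h.sanityCheckTime = now.Add(sanityInterval)
		return minID, maxID
	}
	if h.maxTaskID != nil && *h.maxTaskID < maxID {
		maxID = *h.maxTaskID
	}
	return minID, maxID
}

func main() {
	h := &taskIDHint{}
	now := time.Now()

	// First call performs the sanity check and reads up to the transfer level.
	fmt.Println(h.taskIDsRange(100, 500, now, time.Minute)) // 100 500

	// A notification arrives for task 510; before the next sanity check the
	// read window is capped at the hint instead of the transfer level 600.
	h.notifyNewTasks([]int64{510})
	fmt.Println(h.taskIDsRange(500, 600, now, time.Minute)) // 500 510
}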
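The last commit-message bullet, treating the last replication task ID and the pagination task ID differently, shows up on the receiving side as the new maxRxReceivedTaskID field and rxTaskBackoff in replicationTaskProcessor.go: the pagination cursor echoed back by the source cluster advances independently of apply progress, and the processed cursor only jumps to it after a whole page has been applied, because replication task IDs are not contiguous. A condensed sketch under the same caveat, with hypothetical fetchResult/receiver types standing in for the real ReplicationMessages handling:

// Sketch only: receiver-side bookkeeping of processed vs. received task IDs.
package main

import (
	"fmt"
	"time"
)

type fetchResult struct {
	taskIDs                []int64 // SourceTaskId of each returned task
	lastRetrievedMessageID int64   // pagination cursor from the source cluster
	hasMore                bool    // source cluster has more tasks immediately available
}

type receiver struct {
	maxRxProcessedTaskID int64         // acked back to the source as LastProcessedMessageId
	maxRxReceivedTaskID  int64         // sent back to the source as LastRetrievedMessageId
	backoff              time.Duration // wait before the next poll
	noTaskWait           time.Duration
	errorRetryWait       time.Duration
}

// handleFetch applies one page of tasks and updates the two cursors and the backoff.
func (r *receiver) handleFetch(res fetchResult, apply func(int64) error) error {
	r.maxRxReceivedTaskID = res.lastRetrievedMessageID
	if res.hasMore {
		r.backoff = 0 // more tasks waiting: poll again immediately
	} else {
		r.backoff = r.noTaskWait // idle: back off instead of busy-polling
	}

	for _, id := range res.taskIDs {
		if err := apply(id); err != nil {
			// On error, fall back so the unprocessed range is re-fetched next time.
			r.maxRxReceivedTaskID = r.maxRxProcessedTaskID
			r.backoff = r.errorRetryWait
			return err
		}
		r.maxRxProcessedTaskID = id
	}
	// Whole page applied: ack up to the pagination cursor, not just the last task ID.
	r.maxRxProcessedTaskID = r.maxRxReceivedTaskID
	return nil
}

func main() {
	r := &receiver{noTaskWait: time.Second, errorRetryWait: 2 * time.Second}
	_ = r.handleFetch(
		fetchResult{taskIDs: []int64{3, 7}, lastRetrievedMessageID: 10, hasMore: false},
		func(int64) error { return nil },
	)
	fmt.Println(r.maxRxProcessedTaskID, r.maxRxReceivedTaskID, r.backoff) // 10 10 1s
}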