From 4644dc96bf1c200104a3564078243782c627b44a Mon Sep 17 00:00:00 2001 From: Haifeng He Date: Thu, 20 Jul 2023 22:04:43 -0700 Subject: [PATCH] Execute VerifyReplicationTasks as an individual activity (#4656) **What changed?** Divide the GenerateAndVerifyReplicationTasks activity into two activities: GenerateReplicationTasks (reusing the existing one) and VerifyReplicationTasks. **Why?** Based on cluster tests, GenerateReplicationTasks is expensive (10ms latency per `GenerateLastHistoryReplicationTasks` call). In the previous implementation, verification ran after GenerateReplicationTasks within the same activity, and we only got ~60 RPS for GenerateAndVerifyReplicationTasks. By dividing the two, we can achieve ~100 RPS for VerifyReplicationTasks in a single activity (the bottleneck is still GenerateReplicationTasks because of the 10ms latency). Also moved the special handling of the WF not_found case on the target into VerifyReplicationTasks, which reduces the number of `DescribeMutableState` calls on the source cluster. In the previous implementation, `DescribeMutableState` was called for every replication task. Now we only call `DescribeMutableState` if the WF was not found on the target (which should be rare in steady state). The downside is that we can potentially replicate a Zombie WF from source to target, but this should be avoidable by eliminating Zombie WFs during the migration process (i.e., deleting the WF on the target if migration is incomplete). **How did you test it?** Unit tests & cluster tests. **Potential risks** Low; the feature is disabled by default and only affects the force replication workflow. **Is hotfix candidate?** No. 
--- common/metrics/metric_defs.go | 11 +- service/worker/migration/activities.go | 149 +++++-------- service/worker/migration/activities_test.go | 211 +++++++++++++++--- .../migration/force_replication_workflow.go | 54 +++-- .../force_replication_workflow_test.go | 72 +++++- 5 files changed, 349 insertions(+), 148 deletions(-) diff --git a/common/metrics/metric_defs.go b/common/metrics/metric_defs.go index 6cece8e6132..c96e4c7ec3f 100644 --- a/common/metrics/metric_defs.go +++ b/common/metrics/metric_defs.go @@ -1656,10 +1656,13 @@ var ( ScheduleTerminateWorkflowErrors = NewCounterDef("schedule_terminate_workflow_errors") // Force replication - EncounterZombieWorkflowCount = NewCounterDef("encounter_zombie_workflow_count") - CreateReplicationTasksLatency = NewTimerDef("create_replication_tasks_latency") - VerifyReplicationTaskSuccess = NewCounterDef("verify_replication_task_success") - VerifyReplicationTasksLatency = NewTimerDef("verify_replication_tasks_latency") + EncounterZombieWorkflowCount = NewCounterDef("encounter_zombie_workflow_count") + GenerateReplicationTasksLatency = NewTimerDef("generate_replication_tasks_latency") + VerifyReplicationTaskSuccess = NewCounterDef("verify_replication_task_success") + VerifyReplicationTaskNotFound = NewCounterDef("verify_replication_task_not_found") + VerifyReplicationTaskFailed = NewCounterDef("verify_replication_task_failed") + VerifyReplicationTasksLatency = NewTimerDef("verify_replication_tasks_latency") + VerifyDescribeMutableStateLatency = NewTimerDef("verify_describe_mutable_state_latency") // Replication NamespaceReplicationTaskAckLevelGauge = NewGaugeDef("namespace_replication_task_ack_level") diff --git a/service/worker/migration/activities.go b/service/worker/migration/activities.go index aa50165a0e0..b640ea7f6c8 100644 --- a/service/worker/migration/activities.go +++ b/service/worker/migration/activities.go @@ -75,30 +75,22 @@ type ( // State Diagram // -// NOT_CREATED -// │ -// │ -// CREATED_TO_BE_VERIFIED 
+// NOT_VERIFIED // │ // ┌────────┴─────────┐ // │ │ // VERIFIED VERIFIED_SKIPPED const ( - NOT_CREATED VerifyStatus = 0 - CREATED_TO_BE_VERIFIED VerifyStatus = 1 - VERIFIED VerifyStatus = 2 - VERIFY_SKIPPED VerifyStatus = 3 + NOT_VERIFIED VerifyStatus = 0 + VERIFIED VerifyStatus = 1 + VERIFY_SKIPPED VerifyStatus = 2 reasonZombieWorkflow = "Zombie workflow" reasonWorkflowNotFound = "Workflow not found" ) -func (r VerifyResult) isNotCreated() bool { - return r.Status == NOT_CREATED -} - -func (r VerifyResult) isCreatedToBeVerified() bool { - return r.Status == CREATED_TO_BE_VERIFIED +func (r VerifyResult) isNotVerified() bool { + return r.Status == NOT_VERIFIED } func (r VerifyResult) isVerified() bool { @@ -436,6 +428,11 @@ func (a *activities) GenerateReplicationTasks(ctx context.Context, request *gene ctx = a.setCallerInfoForGenReplicationTask(ctx, namespace.ID(request.NamespaceID)) rateLimiter := quotas.NewRateLimiter(request.RPS, int(math.Ceil(request.RPS))) + start := time.Now() + defer func() { + a.forceReplicationMetricsHandler.Timer(metrics.GenerateReplicationTasksLatency.GetMetricName()).Record(time.Since(start)) + }() + startIndex := 0 if activity.HasHeartbeatDetails(ctx) { var finishedIndex int @@ -447,11 +444,12 @@ func (a *activities) GenerateReplicationTasks(ctx context.Context, request *gene for i := startIndex; i < len(request.Executions); i++ { we := request.Executions[i] if err := a.generateWorkflowReplicationTask(ctx, rateLimiter, definition.NewWorkflowKey(request.NamespaceID, we.WorkflowId, we.RunId)); err != nil { - if _, isNotFound := err.(*serviceerror.NotFound); !isNotFound { + if !isNotFoundServiceError(err) { a.logger.Error("force-replication failed to generate replication task", tag.WorkflowNamespaceID(request.NamespaceID), tag.WorkflowID(we.WorkflowId), tag.WorkflowRunID(we.RunId), tag.Error(err)) return err } } + activity.RecordHeartbeat(ctx, i) } @@ -550,63 +548,39 @@ func (a *activities) SeedReplicationQueueWithUserDataEntries(ctx 
context.Context } } -func (a *activities) createReplicationTasks(ctx context.Context, request *genearteAndVerifyReplicationTasksRequest, detail *replicationTasksHeartbeatDetails) error { - start := time.Now() - defer func() { - a.forceReplicationMetricsHandler.Timer(metrics.CreateReplicationTasksLatency.GetMetricName()).Record(time.Since(start)) - }() +func isNotFoundServiceError(err error) bool { + _, ok := err.(*serviceerror.NotFound) + return ok +} - rateLimiter := quotas.NewRateLimiter(request.RPS, int(math.Ceil(request.RPS))) +func (a *activities) verifyHandleNotFoundWorkflow( + ctx context.Context, + namespaceID string, + we *commonpb.WorkflowExecution, + result *VerifyResult, +) error { + tags := []tag.Tag{tag.WorkflowType(forceReplicationWorkflowName), tag.WorkflowNamespaceID(namespaceID), tag.WorkflowID(we.WorkflowId), tag.WorkflowRunID(we.RunId)} + resp, err := a.historyClient.DescribeMutableState(ctx, &historyservice.DescribeMutableStateRequest{ + NamespaceId: namespaceID, + Execution: we, + }) - for i := 0; i < len(request.Executions); i++ { - r := &detail.Results[i] - if r.isCompleted() { - continue + if err != nil { + if isNotFoundServiceError(err) { + // Workflow could be deleted due to retention. 
+ result.Status = VERIFY_SKIPPED + result.Reason = reasonWorkflowNotFound + return nil } - we := request.Executions[i] - tags := []tag.Tag{tag.WorkflowType(forceReplicationWorkflowName), tag.WorkflowNamespaceID(request.NamespaceID), tag.WorkflowID(we.WorkflowId), tag.WorkflowRunID(we.RunId)} - - resp, err := a.historyClient.DescribeMutableState(ctx, &historyservice.DescribeMutableStateRequest{ - NamespaceId: request.NamespaceID, - Execution: &we, - }) - - switch err.(type) { - case nil: - if resp.GetDatabaseMutableState().GetExecutionState().GetState() == enumsspb.WORKFLOW_EXECUTION_STATE_ZOMBIE { - a.forceReplicationMetricsHandler.Counter(metrics.EncounterZombieWorkflowCount.GetMetricName()).Record(1) - a.logger.Info("createReplicationTasks skip Zombie workflow", tags...) - - r.Status = VERIFY_SKIPPED - r.Reason = reasonZombieWorkflow - continue - } - - // Only create replication task if it hasn't been already created - if r.isNotCreated() { - err := a.generateWorkflowReplicationTask(ctx, rateLimiter, definition.NewWorkflowKey(request.NamespaceID, we.WorkflowId, we.RunId)) - - switch err.(type) { - case nil: - r.Status = CREATED_TO_BE_VERIFIED - case *serviceerror.NotFound: - // rare case but in case if execution was deleted after above DescribeMutableState - r.Status = VERIFY_SKIPPED - r.Reason = reasonWorkflowNotFound - default: - a.logger.Error(fmt.Sprintf("createReplicationTasks failed to generate replication task. Error: %v", err), tags...) - return err - } - } - - case *serviceerror.NotFound: - r.Status = VERIFY_SKIPPED - r.Reason = reasonWorkflowNotFound + return err + } - default: - return err - } + if resp.GetDatabaseMutableState().GetExecutionState().GetState() == enumsspb.WORKFLOW_EXECUTION_STATE_ZOMBIE { + a.forceReplicationMetricsHandler.Counter(metrics.EncounterZombieWorkflowCount.GetMetricName()).Record(1) + a.logger.Info("createReplicationTasks skip Zombie workflow", tags...) 
+ result.Status = VERIFY_SKIPPED + result.Reason = reasonZombieWorkflow } return nil @@ -614,7 +588,7 @@ func (a *activities) createReplicationTasks(ctx context.Context, request *genear func (a *activities) verifyReplicationTasks( ctx context.Context, - request *genearteAndVerifyReplicationTasksRequest, + request *verifyReplicationTasksRequest, detail *replicationTasksHeartbeatDetails, remoteClient adminservice.AdminServiceClient, ) (verified bool, progress bool, err error) { @@ -627,32 +601,41 @@ func (a *activities) verifyReplicationTasks( for i := 0; i < len(request.Executions); i++ { r := &detail.Results[i] we := request.Executions[i] - if r.isNotCreated() { - // invalid state - return false, progress, temporal.NewNonRetryableApplicationError(fmt.Sprintf("verifyReplicationTasks: replication task for %v was not created", we), "", nil) - } - if r.isCompleted() { continue } + s := time.Now() // Check if execution exists on remote cluster _, err := remoteClient.DescribeMutableState(ctx, &adminservice.DescribeMutableStateRequest{ Namespace: request.Namespace, Execution: &we, }) + a.forceReplicationMetricsHandler.Timer(metrics.VerifyDescribeMutableStateLatency.GetMetricName()).Record(time.Since(s)) switch err.(type) { case nil: - a.forceReplicationMetricsHandler.Counter(metrics.VerifyReplicationTaskSuccess.GetMetricName()).Record(1) + a.forceReplicationMetricsHandler.WithTags(metrics.NamespaceTag(request.Namespace)).Counter(metrics.VerifyReplicationTaskSuccess.GetMetricName()).Record(1) r.Status = VERIFIED progress = true case *serviceerror.NotFound: - detail.LastNotFoundWorkflowExecution = we - return false, progress, nil + a.forceReplicationMetricsHandler.WithTags(metrics.NamespaceTag(request.Namespace)).Counter(metrics.VerifyReplicationTaskNotFound.GetMetricName()).Record(1) + if err := a.verifyHandleNotFoundWorkflow(ctx, request.NamespaceID, &we, r); err != nil { + return false, progress, err + } + + if r.isNotVerified() { + detail.LastNotFoundWorkflowExecution = 
we + return false, progress, nil + } + + progress = true default: + a.forceReplicationMetricsHandler.WithTags(metrics.NamespaceTag(request.Namespace), metrics.ServiceErrorTypeTag(err)). + Counter(metrics.VerifyReplicationTaskFailed.GetMetricName()).Record(1) + return false, progress, errors.WithMessage(err, "remoteClient.DescribeMutableState call failed") } } @@ -665,7 +648,7 @@ const ( defaultNoProgressNotRetryableTimeout = 15 * time.Minute ) -func (a *activities) GenerateAndVerifyReplicationTasks(ctx context.Context, request *genearteAndVerifyReplicationTasksRequest) error { +func (a *activities) VerifyReplicationTasks(ctx context.Context, request *verifyReplicationTasksRequest) error { ctx = headers.SetCallerInfo(ctx, headers.NewPreemptableCallerInfo(request.Namespace)) remoteClient := a.clientFactory.NewRemoteAdminClientWithTimeout( request.TargetClusterEndpoint, @@ -684,12 +667,6 @@ func (a *activities) GenerateAndVerifyReplicationTasks(ctx context.Context, requ activity.RecordHeartbeat(ctx, details) } - if err := a.createReplicationTasks(ctx, request, &details); err != nil { - return err - } - - activity.RecordHeartbeat(ctx, details) - // Verify if replication tasks exist on target cluster. There are several cases where execution was not found on target cluster. // 1. replication lag // 2. Zombie workflow execution @@ -704,10 +681,8 @@ func (a *activities) GenerateAndVerifyReplicationTasks(ctx context.Context, requ // - more than NonRetryableTimeout, it means potentially we encountered #4. The activity returns // non-retryable error and force-replication workflow will restarted. 
for { - var verified, progress bool - var err error - - if verified, progress, err = a.verifyReplicationTasks(ctx, request, &details, remoteClient); err != nil { + verified, progress, err := a.verifyReplicationTasks(ctx, request, &details, remoteClient) + if err != nil { return err } diff --git a/service/worker/migration/activities_test.go b/service/worker/migration/activities_test.go index 5bb4acd995f..0ac37c4eaa5 100644 --- a/service/worker/migration/activities_test.go +++ b/service/worker/migration/activities_test.go @@ -110,10 +110,13 @@ func (s *activitiesSuite) SetupTest() { s.logger = log.NewNoopLogger() s.mockMetricsHandler = metrics.NewMockHandler(s.controller) + s.mockMetricsHandler.EXPECT().WithTags(gomock.Any()).Return(s.mockMetricsHandler).AnyTimes() s.mockMetricsHandler.EXPECT().Timer(gomock.Any()).Return(metrics.NoopTimerMetricFunc).AnyTimes() s.mockMetricsHandler.EXPECT().Counter(gomock.Any()).Return(metrics.NoopCounterMetricFunc).AnyTimes() s.mockClientFactory.EXPECT().NewRemoteAdminClientWithTimeout(remoteRpcAddress, gomock.Any(), gomock.Any()). Return(s.mockRemoteAdminClient).AnyTimes() + s.mockNamespaceRegistry.EXPECT().GetNamespaceName(gomock.Any()). 
+ Return(namespace.Name(mockedNamespace), nil).AnyTimes() s.a = &activities{ namespaceRegistry: s.mockNamespaceRegistry, @@ -141,37 +144,16 @@ func (s *activitiesSuite) initEnv() (*testsuite.TestActivityEnvironment, *heartb return env, &iceptor } -func (s *activitiesSuite) TestGenerateAndVerifyReplicationTasks_Success() { +func (s *activitiesSuite) TestVerifyReplicationTasks_Success() { env, iceptor := s.initEnv() - request := genearteAndVerifyReplicationTasksRequest{ + request := verifyReplicationTasksRequest{ Namespace: mockedNamespace, NamespaceID: mockedNamespaceID, - RPS: 10, TargetClusterEndpoint: remoteRpcAddress, Executions: []commonpb.WorkflowExecution{execution1, execution2}, } - // Setup create replication tasks - for i := 0; i < len(request.Executions); i++ { - we := request.Executions[i] - s.mockHistoryClient.EXPECT().DescribeMutableState(gomock.Any(), &historyservice.DescribeMutableStateRequest{ - NamespaceId: mockedNamespaceID, - Execution: &we, - }).Return(&historyservice.DescribeMutableStateResponse{ - DatabaseMutableState: &persistencepb.WorkflowMutableState{ - ExecutionState: &persistencepb.WorkflowExecutionState{ - State: enumsspb.WORKFLOW_EXECUTION_STATE_COMPLETED, - }, - }, - }, nil).Times(1) - - s.mockHistoryClient.EXPECT().GenerateLastHistoryReplicationTasks(gomock.Any(), &historyservice.GenerateLastHistoryReplicationTasksRequest{ - NamespaceId: mockedNamespaceID, - Execution: &we, - }).Return(&historyservice.GenerateLastHistoryReplicationTasksResponse{}, nil).Times(1) - } - // Immediately replicated s.mockRemoteAdminClient.EXPECT().DescribeMutableState(gomock.Any(), &adminservice.DescribeMutableStateRequest{ Namespace: mockedNamespace, @@ -188,6 +170,17 @@ func (s *activitiesSuite) TestGenerateAndVerifyReplicationTasks_Success() { {&adminservice.DescribeMutableStateResponse{}, nil}, } + s.mockHistoryClient.EXPECT().DescribeMutableState(gomock.Any(), &historyservice.DescribeMutableStateRequest{ + NamespaceId: mockedNamespaceID, + Execution: 
&execution2, + }).Return(&historyservice.DescribeMutableStateResponse{ + DatabaseMutableState: &persistencepb.WorkflowMutableState{ + ExecutionState: &persistencepb.WorkflowExecutionState{ + State: enumsspb.WORKFLOW_EXECUTION_STATE_COMPLETED, + }, + }, + }, nil).Times(2) + for _, r := range replicationSlowReponses { s.mockRemoteAdminClient.EXPECT().DescribeMutableState(gomock.Any(), &adminservice.DescribeMutableStateRequest{ Namespace: mockedNamespace, @@ -195,7 +188,7 @@ func (s *activitiesSuite) TestGenerateAndVerifyReplicationTasks_Success() { }).Return(r.resp, r.err).Times(1) } - _, err := env.ExecuteActivity(s.a.GenerateAndVerifyReplicationTasks, &request) + _, err := env.ExecuteActivity(s.a.VerifyReplicationTasks, &request) s.NoError(err) s.Greater(len(iceptor.replicationRecordedHeartbeats), 0) @@ -206,7 +199,7 @@ func (s *activitiesSuite) TestGenerateAndVerifyReplicationTasks_Success() { } } -func (s *activitiesSuite) TestGenerateAndVerifyReplicationTasks_Skipped() { +func (s *activitiesSuite) TestVerifyReplicationTasks_NotFound() { mockErr := serviceerror.NewInternal("mock error") var testcases = []struct { resp *historyservice.DescribeMutableStateResponse @@ -236,16 +229,15 @@ func (s *activitiesSuite) TestGenerateAndVerifyReplicationTasks_Skipped() { }, { nil, mockErr, - NOT_CREATED, + NOT_VERIFIED, "", mockErr, }, } - request := genearteAndVerifyReplicationTasksRequest{ + request := verifyReplicationTasksRequest{ Namespace: mockedNamespace, NamespaceID: mockedNamespaceID, - RPS: 10, TargetClusterEndpoint: remoteRpcAddress, Executions: []commonpb.WorkflowExecution{execution1}, } @@ -254,12 +246,17 @@ func (s *activitiesSuite) TestGenerateAndVerifyReplicationTasks_Skipped() { for _, t := range testcases { env, iceptor := s.initEnv() + s.mockRemoteAdminClient.EXPECT().DescribeMutableState(gomock.Any(), &adminservice.DescribeMutableStateRequest{ + Namespace: mockedNamespace, + Execution: &execution1, + }).Return(nil, serviceerror.NewNotFound("")) + 
s.mockHistoryClient.EXPECT().DescribeMutableState(gomock.Any(), &historyservice.DescribeMutableStateRequest{ NamespaceId: mockedNamespaceID, Execution: &execution1, }).Return(t.resp, t.err) - _, err := env.ExecuteActivity(s.a.GenerateAndVerifyReplicationTasks, &request) + _, err := env.ExecuteActivity(s.a.VerifyReplicationTasks, &request) if t.expectedErr == nil { s.NoError(err) } else { @@ -277,3 +274,159 @@ func (s *activitiesSuite) TestGenerateAndVerifyReplicationTasks_Skipped() { s.True(lastHeartBeat.CheckPoint.After(start)) } } + +func (s *activitiesSuite) TestVerifyReplicationTasks_FailedNotFound() { + env, iceptor := s.initEnv() + request := verifyReplicationTasksRequest{ + Namespace: mockedNamespace, + NamespaceID: mockedNamespaceID, + TargetClusterEndpoint: remoteRpcAddress, + Executions: []commonpb.WorkflowExecution{execution1}, + } + + s.mockHistoryClient.EXPECT().DescribeMutableState(gomock.Any(), &historyservice.DescribeMutableStateRequest{ + NamespaceId: mockedNamespaceID, + Execution: &execution1, + }).Return(&historyservice.DescribeMutableStateResponse{ + DatabaseMutableState: &persistencepb.WorkflowMutableState{ + ExecutionState: &persistencepb.WorkflowExecutionState{ + State: enumsspb.WORKFLOW_EXECUTION_STATE_COMPLETED, + }, + }, + }, nil) + + // Workflow not found at target cluster. + s.mockRemoteAdminClient.EXPECT().DescribeMutableState(gomock.Any(), &adminservice.DescribeMutableStateRequest{ + Namespace: mockedNamespace, + Execution: &execution1, + }).Return(nil, serviceerror.NewNotFound("")).AnyTimes() + + // Set CheckPoint to an early to trigger failure. 
+ env.SetHeartbeatDetails(&replicationTasksHeartbeatDetails{ + Results: make([]VerifyResult, len(request.Executions)), + CheckPoint: time.Now().Add(-defaultNoProgressNotRetryableTimeout), + }) + + _, err := env.ExecuteActivity(s.a.VerifyReplicationTasks, &request) + s.Error(err) + s.ErrorContains(err, "verifyReplicationTasks was not able to make progress") + + s.Greater(len(iceptor.replicationRecordedHeartbeats), 0) + lastHeartBeat := iceptor.replicationRecordedHeartbeats[len(iceptor.replicationRecordedHeartbeats)-1] + s.Equal(len(request.Executions), len(lastHeartBeat.Results)) + for _, r := range lastHeartBeat.Results { + s.True(r.isNotVerified()) + } +} + +func (s *activitiesSuite) TestVerifyReplicationTasks_AlreadyVerified() { + env, iceptor := s.initEnv() + request := verifyReplicationTasksRequest{ + Namespace: mockedNamespace, + NamespaceID: mockedNamespaceID, + TargetClusterEndpoint: remoteRpcAddress, + Executions: []commonpb.WorkflowExecution{execution1}, + } + + env.SetHeartbeatDetails(&replicationTasksHeartbeatDetails{ + Results: []VerifyResult{ + {Status: VERIFIED}, + }, + CheckPoint: time.Now(), + }) + + _, err := env.ExecuteActivity(s.a.VerifyReplicationTasks, &request) + s.NoError(err) + + s.Greater(len(iceptor.replicationRecordedHeartbeats), 0) + lastHeartBeat := iceptor.replicationRecordedHeartbeats[len(iceptor.replicationRecordedHeartbeats)-1] + s.Equal(len(request.Executions), len(lastHeartBeat.Results)) + for _, r := range lastHeartBeat.Results { + s.True(r.isVerified()) + } +} + +func (s *activitiesSuite) Test_isNotFoundServiceError() { + s.True(isNotFoundServiceError(serviceerror.NewNotFound(""))) + var err error + s.False(isNotFoundServiceError(err)) + s.False(isNotFoundServiceError(serviceerror.NewInternal(""))) +} + +func (s *activitiesSuite) TestGenerateReplicationTasks_Success() { + env, iceptor := s.initEnv() + + request := generateReplicationTasksRequest{ + NamespaceID: mockedNamespaceID, + RPS: 10, + Executions: 
[]commonpb.WorkflowExecution{execution1, execution2}, + } + + for i := 0; i < len(request.Executions); i++ { + we := request.Executions[i] + s.mockHistoryClient.EXPECT().GenerateLastHistoryReplicationTasks(gomock.Any(), &historyservice.GenerateLastHistoryReplicationTasksRequest{ + NamespaceId: mockedNamespaceID, + Execution: &we, + }).Return(&historyservice.GenerateLastHistoryReplicationTasksResponse{}, nil).Times(1) + } + + _, err := env.ExecuteActivity(s.a.GenerateReplicationTasks, &request) + s.NoError(err) + + s.Greater(len(iceptor.generateReplicationRecordedHeartbeats), 0) + lastIdx := len(iceptor.generateReplicationRecordedHeartbeats) - 1 + lastHeartBeat := iceptor.generateReplicationRecordedHeartbeats[lastIdx] + s.Equal(lastIdx, lastHeartBeat) +} + +func (s *activitiesSuite) TestGenerateReplicationTasks_NotFound() { + env, iceptor := s.initEnv() + + request := generateReplicationTasksRequest{ + NamespaceID: mockedNamespaceID, + RPS: 10, + Executions: []commonpb.WorkflowExecution{execution1}, + } + + s.mockHistoryClient.EXPECT().GenerateLastHistoryReplicationTasks(gomock.Any(), &historyservice.GenerateLastHistoryReplicationTasksRequest{ + NamespaceId: mockedNamespaceID, + Execution: &execution1, + }).Return(nil, serviceerror.NewNotFound("")).Times(1) + + _, err := env.ExecuteActivity(s.a.GenerateReplicationTasks, &request) + s.NoError(err) + + s.Greater(len(iceptor.generateReplicationRecordedHeartbeats), 0) + lastIdx := len(iceptor.generateReplicationRecordedHeartbeats) - 1 + lastHeartBeat := iceptor.generateReplicationRecordedHeartbeats[lastIdx] + s.Equal(0, lastHeartBeat) +} + +func (s *activitiesSuite) TestGenerateReplicationTasks_Failed() { + env, iceptor := s.initEnv() + + request := generateReplicationTasksRequest{ + NamespaceID: mockedNamespaceID, + RPS: 10, + Executions: []commonpb.WorkflowExecution{execution1, execution2}, + } + + s.mockHistoryClient.EXPECT().GenerateLastHistoryReplicationTasks(gomock.Any(), 
&historyservice.GenerateLastHistoryReplicationTasksRequest{ + NamespaceId: mockedNamespaceID, + Execution: &execution1, + }).Return(&historyservice.GenerateLastHistoryReplicationTasksResponse{}, nil).Times(1) + + s.mockHistoryClient.EXPECT().GenerateLastHistoryReplicationTasks(gomock.Any(), &historyservice.GenerateLastHistoryReplicationTasksRequest{ + NamespaceId: mockedNamespaceID, + Execution: &execution2, + }).Return(nil, serviceerror.NewInternal("")) + + _, err := env.ExecuteActivity(s.a.GenerateReplicationTasks, &request) + s.Error(err) + + s.Greater(len(iceptor.generateReplicationRecordedHeartbeats), 0) + lastIdx := len(iceptor.generateReplicationRecordedHeartbeats) - 1 + lastHeartBeat := iceptor.generateReplicationRecordedHeartbeats[lastIdx] + // Only the generation of 1st execution suceeded. + s.Equal(0, lastHeartBeat) +} diff --git a/service/worker/migration/force_replication_workflow.go b/service/worker/migration/force_replication_workflow.go index 501049cba3a..34f5ff96b6d 100644 --- a/service/worker/migration/force_replication_workflow.go +++ b/service/worker/migration/force_replication_workflow.go @@ -102,10 +102,9 @@ type ( RPS float64 } - genearteAndVerifyReplicationTasksRequest struct { + verifyReplicationTasksRequest struct { Namespace string NamespaceID string - RPS float64 TargetClusterEndpoint string VerifyInterval time.Duration `validate:"gte=0"` Executions []commonpb.WorkflowExecution @@ -359,7 +358,8 @@ func listWorkflowsForReplication(ctx workflow.Context, workflowExecutionsCh work func enqueueReplicationTasks(ctx workflow.Context, workflowExecutionsCh workflow.Channel, namespaceID string, params ForceReplicationParams) error { selector := workflow.NewSelector(ctx) - pendingActivities := 0 + pendingGenerateTasks := 0 + pendingVerifyTasks := 0 ao := workflow.ActivityOptions{ StartToCloseTimeout: time.Hour, @@ -374,41 +374,49 @@ func enqueueReplicationTasks(ctx workflow.Context, workflowExecutionsCh workflow var lastActivityErr error for 
workflowExecutionsCh.Receive(ctx, &workflowExecutions) { - var replicationTaskFuture workflow.Future + generateTaskFuture := workflow.ExecuteActivity(actx, a.GenerateReplicationTasks, &generateReplicationTasksRequest{ + NamespaceID: namespaceID, + Executions: workflowExecutions, + RPS: params.OverallRps / float64(params.ConcurrentActivityCount), + }) + + pendingGenerateTasks++ + selector.AddFuture(generateTaskFuture, func(f workflow.Future) { + pendingGenerateTasks-- + + if err := f.Get(ctx, nil); err != nil { + lastActivityErr = err + } + }) + futures = append(futures, generateTaskFuture) + if params.EnableVerification { - replicationTaskFuture = workflow.ExecuteActivity(actx, a.GenerateAndVerifyReplicationTasks, &genearteAndVerifyReplicationTasksRequest{ + verifyTaskFuture := workflow.ExecuteActivity(actx, a.VerifyReplicationTasks, &verifyReplicationTasksRequest{ TargetClusterEndpoint: params.TargetClusterEndpoint, Namespace: params.Namespace, NamespaceID: namespaceID, Executions: workflowExecutions, - RPS: params.OverallRps / float64(params.ConcurrentActivityCount), VerifyInterval: time.Duration(params.VerifyIntervalInSeconds) * time.Second, }) - } else { - replicationTaskFuture = workflow.ExecuteActivity(actx, a.GenerateReplicationTasks, &generateReplicationTasksRequest{ - NamespaceID: namespaceID, - Executions: workflowExecutions, - RPS: params.OverallRps / float64(params.ConcurrentActivityCount), - }) - } - pendingActivities++ - selector.AddFuture(replicationTaskFuture, func(f workflow.Future) { - pendingActivities-- + pendingVerifyTasks++ + selector.AddFuture(verifyTaskFuture, func(f workflow.Future) { + pendingVerifyTasks-- - if err := f.Get(ctx, nil); err != nil { - lastActivityErr = err - } - }) + if err := f.Get(ctx, nil); err != nil { + lastActivityErr = err + } + }) + + futures = append(futures, verifyTaskFuture) + } - if pendingActivities >= params.ConcurrentActivityCount { + for pendingGenerateTasks >= params.ConcurrentActivityCount || 
pendingVerifyTasks >= params.ConcurrentActivityCount { selector.Select(ctx) // this will block until one of the in-flight activities completes if lastActivityErr != nil { return lastActivityErr } } - - futures = append(futures, replicationTaskFuture) } for _, future := range futures { diff --git a/service/worker/migration/force_replication_workflow_test.go b/service/worker/migration/force_replication_workflow_test.go index cbdfbfd5cea..92e97926968 100644 --- a/service/worker/migration/force_replication_workflow_test.go +++ b/service/worker/migration/force_replication_workflow_test.go @@ -85,6 +85,7 @@ func TestForceReplicationWorkflow(t *testing.T) { }).Times(totalPageCount) env.OnActivity(a.GenerateReplicationTasks, mock.Anything, mock.Anything).Return(nil).Times(totalPageCount) + env.OnActivity(a.VerifyReplicationTasks, mock.Anything, mock.Anything).Return(nil).Times(totalPageCount) env.RegisterWorkflow(ForceTaskQueueUserDataReplicationWorkflow) env.OnActivity(a.SeedReplicationQueueWithUserDataEntries, mock.Anything, mock.Anything).Return(nil).Times(1) @@ -96,6 +97,7 @@ func TestForceReplicationWorkflow(t *testing.T) { OverallRps: 10, ListWorkflowsPageSize: 1, PageCountPerExecution: 4, + EnableVerification: true, }) require.True(t, env.IsWorkflowCompleted()) @@ -150,6 +152,7 @@ func TestForceReplicationWorkflow_ContinueAsNew(t *testing.T) { }).Times(maxPageCountPerExecution) env.OnActivity(a.GenerateReplicationTasks, mock.Anything, mock.Anything).Return(nil).Times(maxPageCountPerExecution) + env.OnActivity(a.VerifyReplicationTasks, mock.Anything, mock.Anything).Return(nil).Times(maxPageCountPerExecution) env.RegisterWorkflow(ForceTaskQueueUserDataReplicationWorkflow) env.OnActivity(a.SeedReplicationQueueWithUserDataEntries, mock.Anything, mock.Anything).Return(nil) @@ -161,6 +164,7 @@ func TestForceReplicationWorkflow_ContinueAsNew(t *testing.T) { OverallRps: 10, ListWorkflowsPageSize: 1, PageCountPerExecution: maxPageCountPerExecution, + EnableVerification: 
true, }) require.True(t, env.IsWorkflowCompleted()) @@ -285,10 +289,11 @@ func TestForceReplicationWorkflow_GenerateReplicationTaskNonRetryableError(t *te }, nil }) + var errMsg = "mock generate replication tasks error" // Only expect GenerateReplicationTasks to execute once and workflow will then fail because of // non-retryable error. env.OnActivity(a.GenerateReplicationTasks, mock.Anything, mock.Anything).Return( - temporal.NewNonRetryableApplicationError("mock generate replication tasks error", "", nil), + temporal.NewNonRetryableApplicationError(errMsg, "", nil), ).Times(1) env.RegisterWorkflow(ForceTaskQueueUserDataReplicationWorkflow) @@ -301,12 +306,66 @@ func TestForceReplicationWorkflow_GenerateReplicationTaskNonRetryableError(t *te OverallRps: 10, ListWorkflowsPageSize: 1, PageCountPerExecution: 4, + EnableVerification: true, }) require.True(t, env.IsWorkflowCompleted()) err := env.GetWorkflowError() require.Error(t, err) - require.Contains(t, err.Error(), "mock generate replication tasks error") + require.Contains(t, err.Error(), errMsg) + env.AssertExpectations(t) +} + +func TestForceReplicationWorkflow_VerifyReplicationTaskNonRetryableError(t *testing.T) { + testSuite := &testsuite.WorkflowTestSuite{} + env := testSuite.NewTestWorkflowEnvironment() + + namespaceID := uuid.New() + + var a *activities + env.OnActivity(a.GetMetadata, mock.Anything, metadataRequest{Namespace: "test-ns"}).Return(&metadataResponse{ShardCount: 4, NamespaceID: namespaceID}, nil) + + totalPageCount := 4 + currentPageCount := 0 + env.OnActivity(a.ListWorkflows, mock.Anything, mock.Anything).Return(func(ctx context.Context, request *workflowservice.ListWorkflowExecutionsRequest) (*listWorkflowsResponse, error) { + assert.Equal(t, "test-ns", request.Namespace) + currentPageCount++ + if currentPageCount < totalPageCount { + return &listWorkflowsResponse{ + Executions: []commonpb.WorkflowExecution{}, + NextPageToken: []byte("fake-page-token"), + }, nil + } + // your mock function 
implementation + return &listWorkflowsResponse{ + Executions: []commonpb.WorkflowExecution{}, + NextPageToken: nil, // last page + }, nil + }) + + var errMsg = "mock verify replication tasks error" + env.OnActivity(a.GenerateReplicationTasks, mock.Anything, mock.Anything).Return(nil).Times(1) + env.OnActivity(a.VerifyReplicationTasks, mock.Anything, mock.Anything).Return( + temporal.NewNonRetryableApplicationError(errMsg, "", nil), + ).Times(1) + + env.RegisterWorkflow(ForceTaskQueueUserDataReplicationWorkflow) + env.OnActivity(a.SeedReplicationQueueWithUserDataEntries, mock.Anything, mock.Anything).Return(nil) + + env.ExecuteWorkflow(ForceReplicationWorkflow, ForceReplicationParams{ + Namespace: "test-ns", + Query: "", + ConcurrentActivityCount: 1, + OverallRps: 10, + ListWorkflowsPageSize: 1, + PageCountPerExecution: 4, + EnableVerification: true, + }) + + require.True(t, env.IsWorkflowCompleted()) + err := env.GetWorkflowError() + require.Error(t, err) + require.Contains(t, err.Error(), errMsg) env.AssertExpectations(t) } @@ -424,9 +483,10 @@ type heartbeatRecordingInterceptor struct { interceptor.WorkerInterceptorBase interceptor.ActivityInboundInterceptorBase interceptor.ActivityOutboundInterceptorBase - seedRecordedHeartbeats []seedReplicationQueueWithUserDataEntriesHeartbeatDetails - replicationRecordedHeartbeats []replicationTasksHeartbeatDetails - T *testing.T + seedRecordedHeartbeats []seedReplicationQueueWithUserDataEntriesHeartbeatDetails + replicationRecordedHeartbeats []replicationTasksHeartbeatDetails + generateReplicationRecordedHeartbeats []int + T *testing.T } func (i *heartbeatRecordingInterceptor) InterceptActivity(ctx context.Context, next interceptor.ActivityInboundInterceptor) interceptor.ActivityInboundInterceptor { @@ -444,6 +504,8 @@ func (i *heartbeatRecordingInterceptor) RecordHeartbeat(ctx context.Context, det i.seedRecordedHeartbeats = append(i.seedRecordedHeartbeats, d) } else if d, ok := details[0].(replicationTasksHeartbeatDetails); 
ok { i.replicationRecordedHeartbeats = append(i.replicationRecordedHeartbeats, d) + } else if d, ok := details[0].(int); ok { + i.generateReplicationRecordedHeartbeats = append(i.generateReplicationRecordedHeartbeats, d) } else { assert.Fail(i.T, "invalid heartbeat details") }