From 7608a0c8b16bfcea48625f50e68a6216544915f5 Mon Sep 17 00:00:00 2001 From: Yichao Yang Date: Wed, 7 Jun 2023 18:16:32 -0700 Subject: [PATCH 1/9] Add retryable client for persistence task manager (#4457) --- common/persistence/client/factory.go | 1 + 1 file changed, 1 insertion(+) diff --git a/common/persistence/client/factory.go b/common/persistence/client/factory.go index 4242fa11528..ee32c4e01ac 100644 --- a/common/persistence/client/factory.go +++ b/common/persistence/client/factory.go @@ -117,6 +117,7 @@ func (f *factoryImpl) NewTaskManager() (p.TaskManager, error) { if f.metricsHandler != nil && f.healthSignals != nil { result = p.NewTaskPersistenceMetricsClient(result, f.metricsHandler, f.healthSignals, f.logger) } + result = p.NewTaskPersistenceRetryableClient(result, retryPolicy, IsPersistenceTransientError) return result, nil } From f8453ceb7de4daa0c23dea74bbba58b586704830 Mon Sep 17 00:00:00 2001 From: David Reiss Date: Thu, 8 Jun 2023 08:08:14 -0700 Subject: [PATCH 2/9] Hide BuildIds search attribute for schedules (#4458) --- service/frontend/workflow_handler.go | 3 ++- tests/schedule_test.go | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/service/frontend/workflow_handler.go b/service/frontend/workflow_handler.go index 4f90ce4afaf..8efd193047f 100644 --- a/service/frontend/workflow_handler.go +++ b/service/frontend/workflow_handler.go @@ -4871,9 +4871,10 @@ func (wh *WorkflowHandler) cleanScheduleSearchAttributes(searchAttributes *commo delete(fields, searchattribute.TemporalSchedulePaused) delete(fields, "TemporalScheduleInfoJSON") // used by older version, clean this up if present - // this isn't schedule-related but isn't relevant to the user for + // these aren't schedule-related but they aren't relevant to the user for // scheduler workflows since it's the server worker delete(fields, searchattribute.BinaryChecksums) + delete(fields, searchattribute.BuildIds) if len(fields) == 0 { return nil diff --git a/tests/schedule_test.go b/tests/schedule_test.go index fdeaded15e6..fa8d4df880a 100644 --- a/tests/schedule_test.go +++ b/tests/schedule_test.go @@ -254,6 +254,8 @@ func (s *scheduleIntegrationSuite) TestBasics() { s.EqualValues(365*24*3600, describeResp.Schedule.Policies.CatchupWindow.Seconds()) // set to default value s.Equal(schSAValue.Data, describeResp.SearchAttributes.IndexedFields[csa].Data) + s.Nil(describeResp.SearchAttributes.IndexedFields[searchattribute.BinaryChecksums]) + s.Nil(describeResp.SearchAttributes.IndexedFields[searchattribute.BuildIds]) s.Equal(schMemo.Data, describeResp.Memo.Fields["schedmemo1"].Data) s.Equal(wfSAValue.Data, describeResp.Schedule.Action.GetStartWorkflow().SearchAttributes.IndexedFields[csa].Data) s.Equal(wfMemo.Data, describeResp.Schedule.Action.GetStartWorkflow().Memo.Fields["wfmemo1"].Data) From e5d58dc7ecc469262be6d02215fcecd962c32d4c Mon Sep 17 00:00:00 2001 From: David Reiss Date: Thu, 8 Jun 2023 08:15:19 -0700 Subject: [PATCH 3/9] More functional tests for worker versioning (#4446) --- service/frontend/workflow_handler.go | 4 +- service/matching/taskQueueManager.go | 17 +- tests/versioning_test.go | 325 ++++++++++++++++++++++++++- 3 files changed, 329 insertions(+), 17 deletions(-) diff --git a/service/frontend/workflow_handler.go b/service/frontend/workflow_handler.go index 8efd193047f..b20a83574ca 100644 --- a/service/frontend/workflow_handler.go +++ b/service/frontend/workflow_handler.go @@ -878,7 +878,7 @@ func (wh *WorkflowHandler) PollWorkflowTaskQueue(ctx context.Context, request *w return 
&workflowservice.PollWorkflowTaskQueueResponse{}, nil } - // For newer build error, return silently. + // These errors are expected based on certain client behavior. We should not log them, it'd be too noisy. var newerBuild *serviceerror.NewerBuildExists if errors.As(err, &newerBuild) { return nil, err @@ -1114,7 +1114,7 @@ func (wh *WorkflowHandler) PollActivityTaskQueue(ctx context.Context, request *w return &workflowservice.PollActivityTaskQueueResponse{}, nil } - // For newer build error, return silently. + // These errors are expected based on certain client behavior. We should not log them, it'd be too noisy. var newerBuild *serviceerror.NewerBuildExists if errors.As(err, &newerBuild) { return nil, err diff --git a/service/matching/taskQueueManager.go b/service/matching/taskQueueManager.go index f548e4d8c38..171fa2bb70c 100644 --- a/service/matching/taskQueueManager.go +++ b/service/matching/taskQueueManager.go @@ -346,13 +346,12 @@ func (c *taskQueueManagerImpl) Stop() { c.unloadFromEngine() } -// isVersioned returns true if this is a tqm for a "versioned [low-level] task queue". Note -// that this is a different concept from the overall [high-level] task queue having versioning -// data associated with it, which is the usual meaning of "versioned task queue". In this case, -// it means whether this is a tqm processing a specific version set id. Unlike non-root -// partitions which are known (at some level) by other services, [low-level] task queues with a -// version set should not be interacted with outside of the matching service. -func (c *taskQueueManagerImpl) isVersioned() bool { +// managesSpecificVersionSet returns true if this is a tqm for a specific version set in the +// build-id-based versioning feature. Note that this is a different concept from the overall +// task queue having versioning data associated with it, which is the usual meaning of +// "versioned task queue". These task queues are not interacted with directly outside outside +// of a single matching node. +func (c *taskQueueManagerImpl) managesSpecificVersionSet() bool { return c.taskQueueID.VersionSet() != "" } @@ -362,7 +361,7 @@ func (c *taskQueueManagerImpl) isVersioned() bool { func (c *taskQueueManagerImpl) shouldFetchUserData() bool { // 1. If the db stores it, then we definitely should not be fetching. // 2. Additionally, we should not fetch for "versioned" tqms. - return !c.db.DbStoresUserData() && !c.isVersioned() + return !c.db.DbStoresUserData() && !c.managesSpecificVersionSet() } func (c *taskQueueManagerImpl) WaitUntilInitialized(ctx context.Context) error { @@ -510,7 +509,7 @@ func (c *taskQueueManagerImpl) DispatchQueryTask( // GetUserData returns the user data for the task queue if any. // Note: can return nil value with no error. 
func (c *taskQueueManagerImpl) GetUserData(ctx context.Context) (*persistencespb.VersionedTaskQueueUserData, chan struct{}, error) { - if c.isVersioned() { + if c.managesSpecificVersionSet() { return nil, nil, errNoUserDataOnVersionedTQM } return c.db.GetUserData(ctx) diff --git a/tests/versioning_test.go b/tests/versioning_test.go index eadb6bd4188..72db5a5757c 100644 --- a/tests/versioning_test.go +++ b/tests/versioning_test.go @@ -171,7 +171,7 @@ func (s *versioningIntegSuite) TestLinkToNonexistentCompatibleVersionReturnsNotF func (s *versioningIntegSuite) TestVersioningStatePersistsAcrossUnload() { ctx := NewContext() - tq := "integration-versioning-not-destroyed" + tq := "integration-versioning-persists" s.addNewDefaultBuildId(ctx, tq, "foo") @@ -311,6 +311,50 @@ func (s *versioningIntegSuite) dispatchNewWorkflow() { s.Equal("done!", out) } +func (s *versioningIntegSuite) TestDispatchNotUsingVersioning() { + s.testWithMatchingBehavior(s.dispatchNotUsingVersioning) +} + +func (s *versioningIntegSuite) dispatchNotUsingVersioning() { + tq := s.randomizeStr(s.T().Name()) + + wf1nover := func(ctx workflow.Context) (string, error) { + return "done without versioning!", nil + } + wf1 := func(ctx workflow.Context) (string, error) { + return "done with versioning!", nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + s.addNewDefaultBuildId(ctx, tq, "v1") + s.waitForPropagation(ctx, tq, "v1") + + w1nover := worker.New(s.sdkClient, tq, worker.Options{ + BuildID: s.prefixed("v1"), + UseBuildIDForVersioning: false, + MaxConcurrentWorkflowTaskPollers: numPollers, + }) + w1 := worker.New(s.sdkClient, tq, worker.Options{ + BuildID: s.prefixed("v1"), + UseBuildIDForVersioning: true, + MaxConcurrentWorkflowTaskPollers: numPollers, + }) + w1nover.RegisterWorkflowWithOptions(wf1nover, workflow.RegisterOptions{Name: "wf"}) + w1.RegisterWorkflowWithOptions(wf1, workflow.RegisterOptions{Name: "wf"}) + s.NoError(w1nover.Start()) + defer w1nover.Stop() + s.NoError(w1.Start()) + defer w1.Stop() + + run, err := s.sdkClient.ExecuteWorkflow(ctx, sdkclient.StartWorkflowOptions{TaskQueue: tq}, "wf") + s.NoError(err) + var out string + s.NoError(run.Get(ctx, &out)) + s.Equal("done with versioning!", out) +} + func (s *versioningIntegSuite) TestDispatchNewWorkflowStartWorkerFirst() { s.testWithMatchingBehavior(s.dispatchNewWorkflowStartWorkerFirst) } @@ -468,11 +512,27 @@ func (s *versioningIntegSuite) dispatchUpgrade(stopOld bool) { s.Equal("done from 1.1!", out) } +type activityFailMode int + +const ( + dontFailActivity = iota + failActivity + timeoutActivity +) + func (s *versioningIntegSuite) TestDispatchActivity() { - s.testWithMatchingBehavior(s.dispatchActivity) + s.testWithMatchingBehavior(func() { s.dispatchActivity(dontFailActivity) }) +} + +func (s *versioningIntegSuite) TestDispatchActivityFail() { + s.testWithMatchingBehavior(func() { s.dispatchActivity(failActivity) }) } -func (s *versioningIntegSuite) dispatchActivity() { +func (s *versioningIntegSuite) TestDispatchActivityTimeout() { + s.testWithMatchingBehavior(func() { s.dispatchActivity(timeoutActivity) }) +} + +func (s *versioningIntegSuite) dispatchActivity(failMode activityFailMode) { // This also implicitly tests that a workflow stays on a compatible version set if a new // incompatible set is registered, because wf2 just panics. 
It further tests that // stickiness on v1 is not broken by registering v2, because the channel send will panic on @@ -482,8 +542,32 @@ func (s *versioningIntegSuite) dispatchActivity() { started := make(chan struct{}, 1) - act1 := func() (string, error) { return "v1", nil } - act2 := func() (string, error) { return "v2", nil } + var act1state, act2state atomic.Int32 + + act1 := func() (string, error) { + if act1state.Add(1) == 1 { + switch failMode { + case failActivity: + return "", errors.New("try again") + case timeoutActivity: + time.Sleep(5 * time.Second) + return "ignored", nil + } + } + return "v1", nil + } + act2 := func() (string, error) { + if act2state.Add(1) == 1 { + switch failMode { + case failActivity: + return "", errors.New("try again") + case timeoutActivity: + time.Sleep(5 * time.Second) + return "ignored", nil + } + } + return "v2", nil + } wf1 := func(ctx workflow.Context) (string, error) { started <- struct{}{} // wait for signal @@ -493,11 +577,13 @@ func (s *versioningIntegSuite) dispatchActivity() { ScheduleToCloseTimeout: time.Minute, DisableEagerExecution: true, VersioningIntent: temporal.VersioningIntentCompatible, + StartToCloseTimeout: 1 * time.Second, }), "act") fut2 := workflow.ExecuteActivity(workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ ScheduleToCloseTimeout: time.Minute, DisableEagerExecution: true, VersioningIntent: temporal.VersioningIntentDefault, // this one should go to default + StartToCloseTimeout: 1 * time.Second, }), "act") var val1, val2 string s.NoError(fut1.Get(ctx, &val1)) @@ -528,7 +614,7 @@ func (s *versioningIntegSuite) dispatchActivity() { s.NoError(err) // wait for it to start on v1 s.waitForChan(ctx, started) - close(started) //force panic if replayed + close(started) // force panic if replayed // now register v2 as default s.addNewDefaultBuildId(ctx, tq, "v2") @@ -552,6 +638,78 @@ func (s *versioningIntegSuite) dispatchActivity() { s.Equal("v1v2", out) } +func (s *versioningIntegSuite) TestDispatchActivityCompatible() { + s.testWithMatchingBehavior(s.dispatchActivityCompatible) +} + +func (s *versioningIntegSuite) dispatchActivityCompatible() { + tq := s.randomizeStr(s.T().Name()) + + started := make(chan struct{}, 2) + + act1 := func() (string, error) { return "v1", nil } + act11 := func() (string, error) { return "v1.1", nil } + wf1 := func(ctx workflow.Context) (string, error) { + started <- struct{}{} + // wait for signal + workflow.GetSignalChannel(ctx, "wait").Receive(ctx, nil) + // run activity + fut11 := workflow.ExecuteActivity(workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ + ScheduleToCloseTimeout: time.Minute, + DisableEagerExecution: true, + VersioningIntent: temporal.VersioningIntentCompatible, + }), "act") + var val11 string + s.NoError(fut11.Get(ctx, &val11)) + return val11, nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + s.addNewDefaultBuildId(ctx, tq, "v1") + s.waitForPropagation(ctx, tq, "v1") + + w1 := worker.New(s.sdkClient, tq, worker.Options{ + BuildID: s.prefixed("v1"), + UseBuildIDForVersioning: true, + MaxConcurrentWorkflowTaskPollers: numPollers, + }) + w1.RegisterWorkflowWithOptions(wf1, workflow.RegisterOptions{Name: "wf"}) + w1.RegisterActivityWithOptions(act1, activity.RegisterOptions{Name: "act"}) + s.NoError(w1.Start()) + defer w1.Stop() + + run, err := s.sdkClient.ExecuteWorkflow(ctx, sdkclient.StartWorkflowOptions{TaskQueue: tq}, "wf") + s.NoError(err) + // wait for it to start on v1 + s.waitForChan(ctx, started) + + 
// now register v1.1 as compatible + s.addCompatibleBuildId(ctx, tq, "v1.1", "v1", false) + s.waitForPropagation(ctx, tq, "v1.1") + // start worker for v1.1 + w11 := worker.New(s.sdkClient, tq, worker.Options{ + BuildID: s.prefixed("v1.1"), + UseBuildIDForVersioning: true, + MaxConcurrentWorkflowTaskPollers: numPollers, + }) + w11.RegisterWorkflowWithOptions(wf1, workflow.RegisterOptions{Name: "wf"}) + w11.RegisterActivityWithOptions(act11, activity.RegisterOptions{Name: "act"}) + s.NoError(w11.Start()) + defer w11.Stop() + + // wait for w1 long polls to all time out + time.Sleep(longPollTime) + + // unblock the workflow + s.NoError(s.sdkClient.SignalWorkflow(ctx, run.GetID(), run.GetRunID(), "wait", nil)) + + var out string + s.NoError(run.Get(ctx, &out)) + s.Equal("v1.1", out) +} + func (s *versioningIntegSuite) TestDispatchChildWorkflow() { s.testWithMatchingBehavior(s.dispatchChildWorkflow) } @@ -630,6 +788,161 @@ func (s *versioningIntegSuite) dispatchChildWorkflow() { s.Equal("v1v2", out) } +func (s *versioningIntegSuite) TestDispatchChildWorkflowUpgrade() { + s.testWithMatchingBehavior(s.dispatchChildWorkflowUpgrade) +} + +func (s *versioningIntegSuite) dispatchChildWorkflowUpgrade() { + tq := s.randomizeStr(s.T().Name()) + + started := make(chan struct{}, 2) + + child1 := func(workflow.Context) (string, error) { return "v1", nil } + child11 := func(workflow.Context) (string, error) { return "v1.1", nil } + wf1 := func(ctx workflow.Context) (string, error) { + started <- struct{}{} + // wait for signal + workflow.GetSignalChannel(ctx, "wait").Receive(ctx, nil) + // run child + fut11 := workflow.ExecuteChildWorkflow(workflow.WithChildOptions(ctx, workflow.ChildWorkflowOptions{}), "child") + var val11 string + s.NoError(fut11.Get(ctx, &val11)) + return val11, nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + s.addNewDefaultBuildId(ctx, tq, "v1") + s.waitForPropagation(ctx, tq, "v1") + + w1 := worker.New(s.sdkClient, tq, worker.Options{ + BuildID: s.prefixed("v1"), + UseBuildIDForVersioning: true, + MaxConcurrentWorkflowTaskPollers: numPollers, + }) + w1.RegisterWorkflowWithOptions(wf1, workflow.RegisterOptions{Name: "wf"}) + w1.RegisterWorkflowWithOptions(child1, workflow.RegisterOptions{Name: "child"}) + s.NoError(w1.Start()) + defer w1.Stop() + + run, err := s.sdkClient.ExecuteWorkflow(ctx, sdkclient.StartWorkflowOptions{TaskQueue: tq}, "wf") + s.NoError(err) + // wait for it to start on v1 + s.waitForChan(ctx, started) + + // now register v1.1 as compatible + s.addCompatibleBuildId(ctx, tq, "v1.1", "v1", false) + s.waitForPropagation(ctx, tq, "v1.1") + // start worker for v1.1 + w11 := worker.New(s.sdkClient, tq, worker.Options{ + BuildID: s.prefixed("v1.1"), + UseBuildIDForVersioning: true, + MaxConcurrentWorkflowTaskPollers: numPollers, + }) + w11.RegisterWorkflowWithOptions(wf1, workflow.RegisterOptions{Name: "wf"}) + w11.RegisterWorkflowWithOptions(child11, workflow.RegisterOptions{Name: "child"}) + s.NoError(w11.Start()) + defer w11.Stop() + + // wait for w1 long polls to all time out + time.Sleep(longPollTime) + + // unblock the workflow + s.NoError(s.sdkClient.SignalWorkflow(ctx, run.GetID(), run.GetRunID(), "wait", nil)) + + var out string + s.NoError(run.Get(ctx, &out)) + s.Equal("v1.1", out) +} + +func (s *versioningIntegSuite) TestDispatchQuery() { + s.testWithMatchingBehavior(s.dispatchQuery) +} + +func (s *versioningIntegSuite) dispatchQuery() { + tq := s.randomizeStr(s.T().Name()) + + started := make(chan 
struct{}, 2) + + wf1 := func(ctx workflow.Context) error { + workflow.SetQueryHandler(ctx, "query", func() (string, error) { return "v1", nil }) + started <- struct{}{} + workflow.GetSignalChannel(ctx, "wait").Receive(ctx, nil) + return nil + } + wf11 := func(ctx workflow.Context) error { + workflow.SetQueryHandler(ctx, "query", func() (string, error) { return "v1.1", nil }) + started <- struct{}{} + workflow.GetSignalChannel(ctx, "wait").Receive(ctx, nil) + return nil + } + wf2 := func(ctx workflow.Context) error { + workflow.SetQueryHandler(ctx, "query", func() (string, error) { return "v2", nil }) + workflow.GetSignalChannel(ctx, "wait").Receive(ctx, nil) + return nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + s.addNewDefaultBuildId(ctx, tq, "v1") + s.waitForPropagation(ctx, tq, "v1") + + w1 := worker.New(s.sdkClient, tq, worker.Options{ + BuildID: s.prefixed("v1"), + UseBuildIDForVersioning: true, + MaxConcurrentWorkflowTaskPollers: numPollers, + }) + w1.RegisterWorkflowWithOptions(wf1, workflow.RegisterOptions{Name: "wf"}) + s.NoError(w1.Start()) + defer w1.Stop() + + run, err := s.sdkClient.ExecuteWorkflow(ctx, sdkclient.StartWorkflowOptions{TaskQueue: tq}, "wf") + s.NoError(err) + // wait for it to start on v1 + s.waitForChan(ctx, started) + + // now register v1.1 as compatible + // now register v11 as newer compatible with v1 AND v2 as a new default + s.addCompatibleBuildId(ctx, tq, "v11", "v1", false) + s.addNewDefaultBuildId(ctx, tq, "v2") + s.waitForPropagation(ctx, tq, "v2") + // add another 100ms to make sure it got to sticky queues also + time.Sleep(100 * time.Millisecond) + + // start worker for v1.1 and v2 + w11 := worker.New(s.sdkClient, tq, worker.Options{ + BuildID: s.prefixed("v11"), + UseBuildIDForVersioning: true, + MaxConcurrentWorkflowTaskPollers: numPollers, + }) + w11.RegisterWorkflowWithOptions(wf11, workflow.RegisterOptions{Name: "wf"}) + s.NoError(w11.Start()) + defer w11.Stop() + w2 := worker.New(s.sdkClient, tq, worker.Options{ + BuildID: s.prefixed("v2"), + UseBuildIDForVersioning: true, + MaxConcurrentWorkflowTaskPollers: numPollers, + }) + w2.RegisterWorkflowWithOptions(wf2, workflow.RegisterOptions{Name: "wf"}) + s.NoError(w2.Start()) + defer w2.Stop() + + // wait for w1 long polls to all time out + time.Sleep(longPollTime) + + // query + val, err := s.sdkClient.QueryWorkflow(ctx, run.GetID(), run.GetRunID(), "query") + s.NoError(err) + var out string + s.NoError(val.Get(&out)) + s.Equal("v1.1", out) + + // let the workflow exit + s.NoError(s.sdkClient.SignalWorkflow(ctx, run.GetID(), run.GetRunID(), "wait", nil)) +} + func (s *versioningIntegSuite) TestDispatchContinueAsNew() { s.testWithMatchingBehavior(s.dispatchContinueAsNew) } From 27d0a1f48570517ebd9810144aa50377120c0f29 Mon Sep 17 00:00:00 2001 From: David Reiss Date: Thu, 8 Jun 2023 08:16:40 -0700 Subject: [PATCH 4/9] Use task queue base name for dynamic config (#4445) --- service/matching/config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/service/matching/config.go b/service/matching/config.go index 132b3c209ad..7e56a2e5e56 100644 --- a/service/matching/config.go +++ b/service/matching/config.go @@ -168,7 +168,7 @@ func NewConfig(dc *dynamicconfig.Collection) *Config { } func newTaskQueueConfig(id *taskQueueID, config *Config, namespace namespace.Name) *taskQueueConfig { - taskQueueName := id.FullName() + taskQueueName := id.BaseNameString() taskType := id.taskType return &taskQueueConfig{ From 
a4bc99111ff78c3404f492b311664901a9910aa9 Mon Sep 17 00:00:00 2001 From: Alex Shtin Date: Thu, 8 Jun 2023 10:56:02 -0700 Subject: [PATCH 5/9] Add more asserts on history events attributes (#4456) --- tests/update_workflow_test.go | 241 +++++++++++++++++----------------- 1 file changed, 120 insertions(+), 121 deletions(-) diff --git a/tests/update_workflow_test.go b/tests/update_workflow_test.go index 39dbd10d45e..15bb9dd49be 100644 --- a/tests/update_workflow_test.go +++ b/tests/update_workflow_test.go @@ -220,15 +220,15 @@ func (s *integrationSuite) TestUpdateWorkflow_NewSpeculativeWorkflowTask_AcceptC }}, }}, nil case 2: - // Speculative WT, with update.Request message. s.EqualHistory(` 1 WorkflowExecutionStarted 2 WorkflowTaskScheduled 3 WorkflowTaskStarted 4 WorkflowTaskCompleted 5 ActivityTaskScheduled - 6 WorkflowTaskScheduled - 7 WorkflowTaskStarted`, history) + 6 WorkflowTaskScheduled // Speculative WT events are not written to the history yet. + 7 WorkflowTaskStarted +`, history) return s.acceptCompleteUpdateCommands(tv, "1"), nil case 3: s.EqualHistory(` @@ -280,7 +280,7 @@ func (s *integrationSuite) TestUpdateWorkflow_NewSpeculativeWorkflowTask_AcceptC T: s.T(), } - // Drain exiting first workflow task. + // Drain first WT. _, err := poller.PollAndProcessWorkflowTask(true, false) s.NoError(err) @@ -314,10 +314,10 @@ func (s *integrationSuite) TestUpdateWorkflow_NewSpeculativeWorkflowTask_AcceptC 3 WorkflowTaskStarted 4 WorkflowTaskCompleted 5 ActivityTaskScheduled - 6 WorkflowTaskScheduled // Normal WT events were written for speculative WT at complete. + 6 WorkflowTaskScheduled // Was speculative WT... 7 WorkflowTaskStarted - 8 WorkflowTaskCompleted - 9 WorkflowExecutionUpdateAccepted + 8 WorkflowTaskCompleted // ...and events were written to the history when WT completes. + 9 WorkflowExecutionUpdateAccepted {"AcceptedRequestSequencingEventId": 6} // WTScheduled event which delivered update to the worker. 10 WorkflowExecutionUpdateCompleted {"AcceptedEventId": 9} 11 WorkflowTaskScheduled 12 WorkflowTaskStarted @@ -361,7 +361,8 @@ func (s *integrationSuite) TestUpdateWorkflow_FirstNormalStartedWorkflowTask_Acc s.EqualHistory(` 1 WorkflowExecutionStarted 2 WorkflowTaskScheduled - 3 WorkflowTaskStarted`, history) + 3 WorkflowTaskStarted // First normal WT. No speculative WT was created. +`, history) return s.acceptCompleteUpdateCommands(tv, "1"), nil case 2: s.EqualHistory(` @@ -443,8 +444,8 @@ func (s *integrationSuite) TestUpdateWorkflow_FirstNormalStartedWorkflowTask_Acc 2 WorkflowTaskScheduled 3 WorkflowTaskStarted 4 WorkflowTaskCompleted - 5 WorkflowExecutionUpdateAccepted - 6 WorkflowExecutionUpdateCompleted + 5 WorkflowExecutionUpdateAccepted {"AcceptedRequestSequencingEventId": 2} // WTScheduled event which delivered update to the worker. + 6 WorkflowExecutionUpdateCompleted {"AcceptedEventId": 5} 7 WorkflowTaskScheduled 8 WorkflowTaskStarted 9 WorkflowTaskCompleted @@ -501,7 +502,7 @@ func (s *integrationSuite) TestUpdateWorkflow_NormalScheduledWorkflowTask_Accept 4 WorkflowTaskCompleted 5 ActivityTaskScheduled 6 WorkflowExecutionSignaled - 7 WorkflowTaskScheduled + 7 WorkflowTaskScheduled // This WT was already created by signal and no speculative WT was created. 
8 WorkflowTaskStarted`, history) return s.acceptCompleteUpdateCommands(tv, "1"), nil case 3: @@ -599,8 +600,8 @@ func (s *integrationSuite) TestUpdateWorkflow_NormalScheduledWorkflowTask_Accept 7 WorkflowTaskScheduled 8 WorkflowTaskStarted 9 WorkflowTaskCompleted - 10 WorkflowExecutionUpdateAccepted - 11 WorkflowExecutionUpdateCompleted + 10 WorkflowExecutionUpdateAccepted {"AcceptedRequestSequencingEventId": 7} // WTScheduled event which delivered update to the worker. + 11 WorkflowExecutionUpdateCompleted {"AcceptedEventId": 10} 12 WorkflowTaskScheduled 13 WorkflowTaskStarted 14 WorkflowTaskCompleted @@ -652,7 +653,7 @@ func (s *integrationSuite) TestUpdateWorkflow_ValidateWorkerMessages() { updRequest := unmarshalAny[*updatepb.Request](s, reqMsg.GetBody()) return []*protocolpb.Message{ { - Id: tv.Any(), // Random message Id. + Id: tv.Any(), ProtocolInstanceId: updRequest.GetMeta().GetUpdateId(), SequencingId: nil, Body: marshalAny(s, &updatepb.Acceptance{ @@ -945,7 +946,6 @@ func (s *integrationSuite) TestUpdateWorkflow_ValidateWorkerMessages() { // Process update in workflow. _, err := poller.PollAndProcessWorkflowTask(false, false) if tc.RespondWorkflowTaskError != "" { - // respond workflow task should return error require.Error(s.T(), err, "RespondWorkflowTaskCompleted should return an error contains `%v`", tc.RespondWorkflowTaskError) require.Contains(s.T(), err.Error(), tc.RespondWorkflowTaskError) } else { @@ -996,12 +996,11 @@ func (s *integrationSuite) TestUpdateWorkflow_NewStickySpeculativeWorkflowTask_A }}, }}, nil case 2: - // Speculative WT, with update.Request message. - // This WT contains partial history because it is sticky enabled. + // This WT contains partial history because sticky was enabled. s.EqualHistory(` 4 WorkflowTaskCompleted 5 ActivityTaskScheduled - 6 WorkflowTaskScheduled + 6 WorkflowTaskScheduled // Speculative WT. 7 WorkflowTaskStarted`, history) return s.acceptCompleteUpdateCommands(tv, "1"), nil case 3: @@ -1094,8 +1093,8 @@ func (s *integrationSuite) TestUpdateWorkflow_NewStickySpeculativeWorkflowTask_A 6 WorkflowTaskScheduled 7 WorkflowTaskStarted 8 WorkflowTaskCompleted - 9 WorkflowExecutionUpdateAccepted - 10 WorkflowExecutionUpdateCompleted + 9 WorkflowExecutionUpdateAccepted {"AcceptedRequestSequencingEventId": 6} // WTScheduled event which delivered update to the worker. + 10 WorkflowExecutionUpdateCompleted {"AcceptedEventId": 9} 11 WorkflowTaskScheduled 12 WorkflowTaskStarted 13 WorkflowTaskCompleted @@ -1125,7 +1124,6 @@ func (s *integrationSuite) TestUpdateWorkflow_NewStickySpeculativeWorkflowTask_A }}, }}, nil case 2: - // Speculative WT, with update.Request message. // Worker gets full history because update was issued after sticky worker is gone. s.EqualHistory(` 1 WorkflowExecutionStarted @@ -1133,8 +1131,9 @@ func (s *integrationSuite) TestUpdateWorkflow_NewStickySpeculativeWorkflowTask_A 3 WorkflowTaskStarted 4 WorkflowTaskCompleted 5 ActivityTaskScheduled - 6 WorkflowTaskScheduled - 7 WorkflowTaskStarted`, history) + 6 WorkflowTaskScheduled // Speculative WT. 
+ 7 WorkflowTaskStarted +`, history) return s.acceptCompleteUpdateCommands(tv, "1"), nil case 3: s.EqualHistory(` @@ -1232,8 +1231,8 @@ func (s *integrationSuite) TestUpdateWorkflow_NewStickySpeculativeWorkflowTask_A 6 WorkflowTaskScheduled 7 WorkflowTaskStarted 8 WorkflowTaskCompleted - 9 WorkflowExecutionUpdateAccepted - 10 WorkflowExecutionUpdateCompleted + 9 WorkflowExecutionUpdateAccepted {"AcceptedRequestSequencingEventId": 6} // WTScheduled event which delivered update to the worker. + 10 WorkflowExecutionUpdateCompleted {"AcceptedEventId": 9} 11 WorkflowTaskScheduled 12 WorkflowTaskStarted 13 WorkflowTaskCompleted @@ -1327,10 +1326,10 @@ func (s *integrationSuite) TestUpdateWorkflow_FirstNormalWorkflowTask_Reject() { s.EqualHistoryEvents(` 1 WorkflowExecutionStarted - 2 WorkflowTaskScheduled + 2 WorkflowTaskScheduled // First normal WT was scheduled before update and therefore all 3 events have to be written even if update was rejected. 3 WorkflowTaskStarted 4 WorkflowTaskCompleted - 5 WorkflowTaskScheduled + 5 WorkflowTaskScheduled // Empty completed WT. No new events were created after it. 6 WorkflowTaskStarted 7 WorkflowTaskCompleted 8 WorkflowExecutionCompleted`, events) @@ -1362,12 +1361,13 @@ func (s *integrationSuite) TestUpdateWorkflow_NewSpeculativeWorkflowTask_Reject( 3 WorkflowTaskStarted 4 WorkflowTaskCompleted 5 ActivityTaskScheduled - 6 WorkflowTaskScheduled - 7 WorkflowTaskStarted`, history) + 6 WorkflowTaskScheduled // Speculative WT. + 7 WorkflowTaskStarted +`, history) return nil, nil case 3: s.EqualHistory(` - 4 WorkflowTaskCompleted + 4 WorkflowTaskCompleted // Speculative WT was dropped and history starts from 4 again. 5 ActivityTaskScheduled 6 WorkflowTaskScheduled 7 WorkflowTaskStarted`, history) @@ -1414,7 +1414,7 @@ func (s *integrationSuite) TestUpdateWorkflow_NewSpeculativeWorkflowTask_Reject( T: s.T(), } - // Drain existing first WT. + // Drain first WT. _, err := poller.PollAndProcessWorkflowTask(true, false) s.NoError(err) @@ -1447,7 +1447,7 @@ func (s *integrationSuite) TestUpdateWorkflow_NewSpeculativeWorkflowTask_Reject( 3 WorkflowTaskStarted 4 WorkflowTaskCompleted 5 ActivityTaskScheduled - 6 WorkflowTaskScheduled + 6 WorkflowTaskScheduled // Speculative WT is not present in the history. 7 WorkflowTaskStarted 8 WorkflowTaskCompleted 9 WorkflowExecutionCompleted`, events) @@ -1474,28 +1474,28 @@ func (s *integrationSuite) TestUpdateWorkflow_1stAccept_2ndAccept_2ndComplete_1s 2 WorkflowTaskScheduled 3 WorkflowTaskStarted 4 WorkflowTaskCompleted - 5 WorkflowExecutionUpdateAccepted - 6 WorkflowTaskScheduled + 5 WorkflowExecutionUpdateAccepted // 1st update is accepted. + 6 WorkflowTaskScheduled // New speculative WT is created because of the 2nd update. 7 WorkflowTaskStarted`, history) return s.acceptUpdateCommands(tv, "2"), nil case 3: s.EqualHistory(` 8 WorkflowTaskCompleted - 9 WorkflowExecutionUpdateAccepted + 9 WorkflowExecutionUpdateAccepted // 2nd update is accepted. 10 WorkflowTaskScheduled 11 WorkflowTaskStarted`, history) return s.completeUpdateCommands(tv, "2"), nil case 4: s.EqualHistory(` 12 WorkflowTaskCompleted - 13 WorkflowExecutionUpdateCompleted + 13 WorkflowExecutionUpdateCompleted // 2nd update is completed. 14 WorkflowTaskScheduled 15 WorkflowTaskStarted`, history) return s.completeUpdateCommands(tv, "1"), nil case 5: s.EqualHistory(` 16 WorkflowTaskCompleted - 17 WorkflowExecutionUpdateCompleted + 17 WorkflowExecutionUpdateCompleted // 1st update is completed. 
18 WorkflowTaskScheduled 19 WorkflowTaskStarted`, history) return []*commandpb.Command{{ @@ -1604,19 +1604,19 @@ func (s *integrationSuite) TestUpdateWorkflow_1stAccept_2ndAccept_2ndComplete_1s 2 WorkflowTaskScheduled 3 WorkflowTaskStarted 4 WorkflowTaskCompleted - 5 WorkflowExecutionUpdateAccepted + 5 WorkflowExecutionUpdateAccepted {"AcceptedRequestSequencingEventId": 2} // WTScheduled event which delivered update to the worker. 6 WorkflowTaskScheduled 7 WorkflowTaskStarted 8 WorkflowTaskCompleted - 9 WorkflowExecutionUpdateAccepted + 9 WorkflowExecutionUpdateAccepted {"AcceptedRequestSequencingEventId": 6} // WTScheduled event which delivered update to the worker. 10 WorkflowTaskScheduled 11 WorkflowTaskStarted 12 WorkflowTaskCompleted - 13 WorkflowExecutionUpdateCompleted + 13 WorkflowExecutionUpdateCompleted {"AcceptedEventId": 9} 14 WorkflowTaskScheduled 15 WorkflowTaskStarted 16 WorkflowTaskCompleted - 17 WorkflowExecutionUpdateCompleted + 17 WorkflowExecutionUpdateCompleted {"AcceptedEventId": 5} 18 WorkflowTaskScheduled 19 WorkflowTaskStarted 20 WorkflowTaskCompleted @@ -1644,23 +1644,23 @@ func (s *integrationSuite) TestUpdateWorkflow_1stAccept_2ndReject_1stComplete() 2 WorkflowTaskScheduled 3 WorkflowTaskStarted 4 WorkflowTaskCompleted - 5 WorkflowExecutionUpdateAccepted - 6 WorkflowTaskScheduled - 7 WorkflowTaskStarted`, history) + 5 WorkflowExecutionUpdateAccepted // 1st update is accepted. + 6 WorkflowTaskScheduled // Speculative WT. Will disappear from the history because update is rejected. + 7 WorkflowTaskStarted +`, history) // Message handler rejects 2nd update. return nil, nil case 3: - // WT2 was speculative. s.EqualHistory(` - 4 WorkflowTaskCompleted - 5 WorkflowExecutionUpdateAccepted + 4 WorkflowTaskCompleted // WT2 (events 6 and 7) was speculative. + 5 WorkflowExecutionUpdateAccepted 6 WorkflowTaskScheduled 7 WorkflowTaskStarted`, history) return s.completeUpdateCommands(tv, "1"), nil case 4: s.EqualHistory(` 8 WorkflowTaskCompleted - 9 WorkflowExecutionUpdateCompleted + 9 WorkflowExecutionUpdateCompleted // 1st update is completed. 10 WorkflowTaskScheduled 11 WorkflowTaskStarted`, history) return []*commandpb.Command{{ @@ -1760,11 +1760,11 @@ func (s *integrationSuite) TestUpdateWorkflow_1stAccept_2ndReject_1stComplete() 2 WorkflowTaskScheduled 3 WorkflowTaskStarted 4 WorkflowTaskCompleted - 5 WorkflowExecutionUpdateAccepted + 5 WorkflowExecutionUpdateAccepted {"AcceptedRequestSequencingEventId": 2} // WTScheduled event which delivered update to the worker. 6 WorkflowTaskScheduled 7 WorkflowTaskStarted 8 WorkflowTaskCompleted - 9 WorkflowExecutionUpdateCompleted + 9 WorkflowExecutionUpdateCompleted {"AcceptedEventId": 5} 10 WorkflowTaskScheduled 11 WorkflowTaskStarted 12 WorkflowTaskCompleted @@ -1842,7 +1842,7 @@ func (s *integrationSuite) TestUpdateWorkflow_FailWorkflowTask() { return []*protocolpb.Message{ { Id: tv.MessageID("update-accepted", "1"), - ProtocolInstanceId: tv.Any(), // Random update Id. + ProtocolInstanceId: tv.Any(), SequencingId: nil, Body: marshalAny(s, &updatepb.Acceptance{ AcceptedRequestMessageId: updRequestMsg.GetId(), @@ -1906,8 +1906,7 @@ func (s *integrationSuite) TestUpdateWorkflow_FailWorkflowTask() { }, }) assert.Error(s.T(), err1) - // UpdateWorkflowExecution is timed out after 2 seconds. 
- assert.True(s.T(), common.IsContextDeadlineExceededErr(err1)) + assert.True(s.T(), common.IsContextDeadlineExceededErr(err1), "UpdateWorkflowExecution must timeout after 2 seconds") assert.Nil(s.T(), updateResponse) updateResultCh <- struct{}{} } @@ -1977,15 +1976,15 @@ func (s *integrationSuite) TestUpdateWorkflow_ConvertStartedSpeculativeWorkflowT }}, }}, nil case 2: - // Speculative WT, with update.Request message. s.EqualHistory(` 1 WorkflowExecutionStarted 2 WorkflowTaskScheduled 3 WorkflowTaskStarted 4 WorkflowTaskCompleted 5 ActivityTaskScheduled - 6 WorkflowTaskScheduled - 7 WorkflowTaskStarted`, history) + 6 WorkflowTaskScheduled // Speculative WT. Events 6 and 7 are written into the history when signal is received. + 7 WorkflowTaskStarted +`, history) // Send signal which will be buffered. This will persist MS and speculative WT must be converted to normal. err := s.sendSignal(s.namespace, tv.WorkflowExecution(), tv.Any(), payloads.EncodeString(tv.Any()), tv.Any()) s.NoError(err) @@ -1993,7 +1992,7 @@ func (s *integrationSuite) TestUpdateWorkflow_ConvertStartedSpeculativeWorkflowT case 3: s.EqualHistory(` 8 WorkflowTaskCompleted - 9 WorkflowExecutionSignaled + 9 WorkflowExecutionSignaled // It was buffered and got to the history after WT is completed. 10 WorkflowTaskScheduled 11 WorkflowTaskStarted`, history) return []*commandpb.Command{{ @@ -2073,7 +2072,7 @@ func (s *integrationSuite) TestUpdateWorkflow_ConvertStartedSpeculativeWorkflowT 5 ActivityTaskScheduled 6 WorkflowTaskScheduled 7 WorkflowTaskStarted - 8 WorkflowTaskCompleted + 8 WorkflowTaskCompleted // Update was rejected but events 6-8 are in the history because of buffered signal. 9 WorkflowExecutionSignaled 10 WorkflowTaskScheduled 11 WorkflowTaskStarted @@ -2102,14 +2101,13 @@ func (s *integrationSuite) TestUpdateWorkflow_ConvertScheduledSpeculativeWorkflo }}, }}, nil case 2: - // Speculative WT was already converted to normal because of the signal. s.EqualHistory(` 1 WorkflowExecutionStarted 2 WorkflowTaskScheduled 3 WorkflowTaskStarted 4 WorkflowTaskCompleted 5 ActivityTaskScheduled - 6 WorkflowTaskScheduled + 6 WorkflowTaskScheduled // It was initially speculative WT but was already converted to normal when signal was received. 7 WorkflowExecutionSignaled 8 WorkflowTaskStarted`, history) return nil, nil @@ -2203,7 +2201,7 @@ func (s *integrationSuite) TestUpdateWorkflow_ConvertScheduledSpeculativeWorkflo 6 WorkflowTaskScheduled 7 WorkflowExecutionSignaled 8 WorkflowTaskStarted - 9 WorkflowTaskCompleted + 9 WorkflowTaskCompleted // Update was rejected but events 6,8,9 are in the history because of signal. 10 WorkflowTaskScheduled 11 WorkflowTaskStarted 12 WorkflowTaskCompleted @@ -2243,15 +2241,15 @@ func (s *integrationSuite) TestUpdateWorkflow_StartToCloseTimeoutSpeculativeWork }}, }}, nil case 2: - // Speculative WT, with update.Request message. s.EqualHistory(` 1 WorkflowExecutionStarted 2 WorkflowTaskScheduled 3 WorkflowTaskStarted 4 WorkflowTaskCompleted 5 ActivityTaskScheduled - 6 WorkflowTaskScheduled - 7 WorkflowTaskStarted`, history) + 6 WorkflowTaskScheduled // Speculative WT. + 7 WorkflowTaskStarted +`, history) // Emulate slow worker: sleep more than WT timeout. time.Sleep(1*time.Second + 100*time.Millisecond) // This doesn't matter because WT times out before update is applied. 
@@ -2267,7 +2265,7 @@ func (s *integrationSuite) TestUpdateWorkflow_StartToCloseTimeoutSpeculativeWork 6 WorkflowTaskScheduled 7 WorkflowTaskStarted 8 WorkflowTaskTimedOut - 9 WorkflowTaskScheduled + 9 WorkflowTaskScheduled {"Attempt":2 } // Transient WT. 10 WorkflowTaskStarted`, history) commands := append(s.acceptUpdateCommands(tv, "1"), &commandpb.Command{ @@ -2314,7 +2312,7 @@ func (s *integrationSuite) TestUpdateWorkflow_StartToCloseTimeoutSpeculativeWork T: s.T(), } - // Start activity using existing workflow task. + // Drain first WT. _, err = poller.PollAndProcessWorkflowTask(true, false) s.NoError(err) @@ -2350,12 +2348,12 @@ func (s *integrationSuite) TestUpdateWorkflow_StartToCloseTimeoutSpeculativeWork 5 ActivityTaskScheduled 6 WorkflowTaskScheduled 7 WorkflowTaskStarted - 8 WorkflowTaskTimedOut - 9 WorkflowTaskScheduled + 8 WorkflowTaskTimedOut // Timeout of speculative WT writes events 6-8 + 9 WorkflowTaskScheduled {"Attempt":2 } 10 WorkflowTaskStarted 11 WorkflowTaskCompleted - 12 WorkflowExecutionUpdateAccepted - 13 WorkflowExecutionUpdateCompleted + 12 WorkflowExecutionUpdateAccepted {"AcceptedRequestSequencingEventId": 9} // WTScheduled event which delivered update to the worker. + 13 WorkflowExecutionUpdateCompleted {"AcceptedEventId": 12} 14 WorkflowExecutionCompleted`, events) } @@ -2380,16 +2378,16 @@ func (s *integrationSuite) TestUpdateWorkflow_ScheduleToStartTimeoutSpeculativeW }}, }}, nil case 2: - // Speculative WT, timed out on sticky task queue. Server sent full history with sticky timeout event. + // Speculative WT timed out on sticky task queue. Server sent full history with sticky timeout event. s.EqualHistory(` 1 WorkflowExecutionStarted 2 WorkflowTaskScheduled 3 WorkflowTaskStarted 4 WorkflowTaskCompleted 5 ActivityTaskScheduled - 6 WorkflowTaskScheduled + 6 WorkflowTaskScheduled // Speculative WT. 7 WorkflowTaskTimedOut - 8 WorkflowTaskScheduled + 8 WorkflowTaskScheduled {"Attempt":1} // Normal WT. 9 WorkflowTaskStarted`, history) return nil, nil case 3: @@ -2470,11 +2468,11 @@ func (s *integrationSuite) TestUpdateWorkflow_ScheduleToStartTimeoutSpeculativeW 3 WorkflowTaskStarted 4 WorkflowTaskCompleted 5 ActivityTaskScheduled - 6 WorkflowTaskScheduled + 6 WorkflowTaskScheduled // Speculative WT was written into the history because of timeout. 7 WorkflowTaskTimedOut - 8 WorkflowTaskScheduled + 8 WorkflowTaskScheduled {"Attempt":1} // Second attempt WT is normal WT (clear stickiness reset attempts count). 9 WorkflowTaskStarted - 10 WorkflowTaskCompleted + 10 WorkflowTaskCompleted // Normal WT is completed and events are in the history even update was rejected. 11 WorkflowTaskScheduled 12 WorkflowTaskStarted 13 WorkflowTaskCompleted @@ -2516,7 +2514,7 @@ func (s *integrationSuite) TestUpdateWorkflow_StartedSpeculativeWorkflowTask_Ter 3 WorkflowTaskStarted 4 WorkflowTaskCompleted 5 ActivityTaskScheduled - 6 WorkflowTaskScheduled + 6 WorkflowTaskScheduled // Speculative WT. 7 WorkflowTaskStarted`, history) return s.acceptCompleteUpdateCommands(tv, "1"), nil default: @@ -2596,7 +2594,7 @@ func (s *integrationSuite) TestUpdateWorkflow_StartedSpeculativeWorkflowTask_Ter 3 WorkflowTaskStarted 4 WorkflowTaskCompleted 5 ActivityTaskScheduled - 6 WorkflowTaskScheduled + 6 WorkflowTaskScheduled // Speculative WT was converted to normal WT during termination. 
7 WorkflowTaskStarted 8 WorkflowTaskFailed 9 WorkflowExecutionTerminated`, events) @@ -2606,8 +2604,7 @@ func (s *integrationSuite) TestUpdateWorkflow_StartedSpeculativeWorkflowTask_Ter Execution: tv.WorkflowExecution(), }) s.NoError(err) - // completion_event_batch_id should point to WTFailed event. - s.EqualValues(8, msResp.GetDatabaseMutableState().GetExecutionInfo().GetCompletionEventBatchId()) + s.EqualValues(8, msResp.GetDatabaseMutableState().GetExecutionInfo().GetCompletionEventBatchId(), "completion_event_batch_id should point to WTFailed event") } func (s *integrationSuite) TestUpdateWorkflow_ScheduledSpeculativeWorkflowTask_TerminateWorkflow() { @@ -2707,15 +2704,15 @@ func (s *integrationSuite) TestUpdateWorkflow_ScheduledSpeculativeWorkflowTask_T 3 WorkflowTaskStarted 4 WorkflowTaskCompleted 5 ActivityTaskScheduled - 6 WorkflowExecutionTerminated`, events) + 6 WorkflowExecutionTerminated // Speculative WTScheduled event is not written to history if WF is terminated. +`, events) msResp, err := s.adminClient.DescribeMutableState(NewContext(), &adminservice.DescribeMutableStateRequest{ Namespace: s.namespace, Execution: tv.WorkflowExecution(), }) s.NoError(err) - // completion_event_batch_id should point to WFTerminated event. - s.EqualValues(6, msResp.GetDatabaseMutableState().GetExecutionInfo().GetCompletionEventBatchId()) + s.EqualValues(6, msResp.GetDatabaseMutableState().GetExecutionInfo().GetCompletionEventBatchId(), "completion_event_batch_id should point to WFTerminated event") } func (s *integrationSuite) TestUpdateWorkflow_CompleteWorkflow_TerminateUpdate() { @@ -2806,7 +2803,7 @@ func (s *integrationSuite) TestUpdateWorkflow_CompleteWorkflow_TerminateUpdate() T: s.T(), } - // Drain exiting first workflow task. + // Drain first WT. _, err := poller.PollAndProcessWorkflowTask(true, false) s.NoError(err) @@ -2870,24 +2867,23 @@ func (s *integrationSuite) TestUpdateWorkflow_SpeculativeWorkflowTask_Heartbeat( }}, }}, nil case 2: - // Last two events (6 and 7) are for speculative WT, but they won't disappear after reject because - // speculative WT is converted to normal during heartbeat. s.EqualHistory(` 1 WorkflowExecutionStarted 2 WorkflowTaskScheduled 3 WorkflowTaskStarted 4 WorkflowTaskCompleted 5 ActivityTaskScheduled - 6 WorkflowTaskScheduled - 7 WorkflowTaskStarted`, history) + 6 WorkflowTaskScheduled // Events (6 and 7) are for speculative WT, but they won't disappear after reject because speculative WT is converted to normal during heartbeat. + 7 WorkflowTaskStarted +`, history) // Heartbeat from speculative WT (no messages, no commands). return nil, nil case 3: - // New WT (after heartbeat) is normal and won't disappear from the history after reject. s.EqualHistory(` 8 WorkflowTaskCompleted - 9 WorkflowTaskScheduled - 10 WorkflowTaskStarted`, history) + 9 WorkflowTaskScheduled // New WT (after heartbeat) is normal and won't disappear from the history after reject. + 10 WorkflowTaskStarted +`, history) // Reject update. return nil, nil case 4: @@ -2936,7 +2932,7 @@ func (s *integrationSuite) TestUpdateWorkflow_SpeculativeWorkflowTask_Heartbeat( T: s.T(), } - // Drain exiting first WT. + // Drain first WT. 
_, err := poller.PollAndProcessWorkflowTask(true, false) s.NoError(err) @@ -2967,6 +2963,7 @@ func (s *integrationSuite) TestUpdateWorkflow_SpeculativeWorkflowTask_Heartbeat( s.Equal(4, msgHandlerCalls) events := s.getHistory(s.namespace, tv.WorkflowExecution()) + s.EqualHistoryEvents(` 1 WorkflowExecutionStarted 2 WorkflowTaskScheduled @@ -2975,10 +2972,10 @@ func (s *integrationSuite) TestUpdateWorkflow_SpeculativeWorkflowTask_Heartbeat( 5 ActivityTaskScheduled 6 WorkflowTaskScheduled 7 WorkflowTaskStarted - 8 WorkflowTaskCompleted + 8 WorkflowTaskCompleted // Heartbeat response. 9 WorkflowTaskScheduled 10 WorkflowTaskStarted - 11 WorkflowTaskCompleted + 11 WorkflowTaskCompleted // After heartbeat new normal WT was created and events are written into the history even update is rejected. 12 WorkflowTaskScheduled 13 WorkflowTaskStarted 14 WorkflowTaskCompleted @@ -3050,7 +3047,7 @@ func (s *integrationSuite) TestUpdateWorkflow_NewScheduledSpeculativeWorkflowTas T: s.T(), } - // Drain exiting first workflow task. + // Drain first WT. _, err := poller.PollAndProcessWorkflowTask(true, false) s.NoError(err) @@ -3143,17 +3140,17 @@ func (s *integrationSuite) TestUpdateWorkflow_NewStartedSpeculativeWorkflowTaskL }}, }}, nil case 2: - // Speculative WT, with update.Request message. s.EqualHistory(` 1 WorkflowExecutionStarted 2 WorkflowTaskScheduled 3 WorkflowTaskStarted 4 WorkflowTaskCompleted 5 ActivityTaskScheduled - 6 WorkflowTaskScheduled - 7 WorkflowTaskStarted`, history) + 6 WorkflowTaskScheduled // Speculative WT. Events 6 and 7 will be lost. + 7 WorkflowTaskStarted +`, history) - // Close shard, Speculative WT (6 and 7) will be lost, and NotFound error will be returned to RespondWorkflowTaskCompleted. + // Close shard. NotFound error will be returned to RespondWorkflowTaskCompleted. s.closeShard(tv.WorkflowID()) return s.acceptCompleteUpdateCommands(tv, "1"), nil @@ -3207,7 +3204,7 @@ func (s *integrationSuite) TestUpdateWorkflow_NewStartedSpeculativeWorkflowTaskL T: s.T(), } - // Drain exiting first workflow task. + // Drain first WT. _, err := poller.PollAndProcessWorkflowTask(true, false) s.NoError(err) @@ -3292,15 +3289,15 @@ func (s *integrationSuite) TestUpdateWorkflow_ScheduledSpeculativeWorkflowTask_D }}, }}, nil case 2: - // Speculative WT, with update.Request message. s.EqualHistory(` 1 WorkflowExecutionStarted 2 WorkflowTaskScheduled 3 WorkflowTaskStarted 4 WorkflowTaskCompleted 5 ActivityTaskScheduled - 6 WorkflowTaskScheduled - 7 WorkflowTaskStarted`, history) + 6 WorkflowTaskScheduled // Speculative WT. + 7 WorkflowTaskStarted +`, history) return s.acceptCompleteUpdateCommands(tv, "1"), nil case 3: s.EqualHistory(` @@ -3349,7 +3346,7 @@ func (s *integrationSuite) TestUpdateWorkflow_ScheduledSpeculativeWorkflowTask_D T: s.T(), } - // Drain exiting first workflow task. + // Drain first WT. _, err := poller.PollAndProcessWorkflowTask(true, false) s.NoError(err) @@ -3394,8 +3391,8 @@ func (s *integrationSuite) TestUpdateWorkflow_ScheduledSpeculativeWorkflowTask_D 6 WorkflowTaskScheduled 7 WorkflowTaskStarted 8 WorkflowTaskCompleted - 9 WorkflowExecutionUpdateAccepted - 10 WorkflowExecutionUpdateCompleted + 9 WorkflowExecutionUpdateAccepted {"AcceptedRequestSequencingEventId": 6} // WTScheduled event which delivered update to the worker. 
+ 10 WorkflowExecutionUpdateCompleted {"AcceptedEventId": 9} 11 WorkflowTaskScheduled 12 WorkflowTaskStarted 13 WorkflowTaskCompleted @@ -3430,15 +3427,15 @@ func (s *integrationSuite) TestUpdateWorkflow_StartedSpeculativeWorkflowTask_Ded updateResultCh2 <- s.sendUpdateNoError(tv, "1") }() - // Speculative WT, with update.Request message. s.EqualHistory(` 1 WorkflowExecutionStarted 2 WorkflowTaskScheduled 3 WorkflowTaskStarted 4 WorkflowTaskCompleted 5 ActivityTaskScheduled - 6 WorkflowTaskScheduled - 7 WorkflowTaskStarted`, history) + 6 WorkflowTaskScheduled // Speculative WT. + 7 WorkflowTaskStarted +`, history) return s.acceptCompleteUpdateCommands(tv, "1"), nil case 3: s.EqualHistory(` @@ -3486,7 +3483,7 @@ func (s *integrationSuite) TestUpdateWorkflow_StartedSpeculativeWorkflowTask_Ded T: s.T(), } - // Drain exiting first workflow task. + // Drain first WT. _, err := poller.PollAndProcessWorkflowTask(true, false) s.NoError(err) @@ -3526,8 +3523,8 @@ func (s *integrationSuite) TestUpdateWorkflow_StartedSpeculativeWorkflowTask_Ded 6 WorkflowTaskScheduled 7 WorkflowTaskStarted 8 WorkflowTaskCompleted - 9 WorkflowExecutionUpdateAccepted - 10 WorkflowExecutionUpdateCompleted + 9 WorkflowExecutionUpdateAccepted {"AcceptedRequestSequencingEventId": 6} // WTScheduled event which delivered update to the worker. + 10 WorkflowExecutionUpdateCompleted {"AcceptedEventId": 9} 11 WorkflowTaskScheduled 12 WorkflowTaskStarted 13 WorkflowTaskCompleted @@ -3571,15 +3568,15 @@ func (s *integrationSuite) TestUpdateWorkflow_CompletedSpeculativeWorkflowTask_D }}, }}, nil case 2: - // Speculative WT, with update.Request message. s.EqualHistory(` 1 WorkflowExecutionStarted 2 WorkflowTaskScheduled 3 WorkflowTaskStarted 4 WorkflowTaskCompleted 5 ActivityTaskScheduled - 6 WorkflowTaskScheduled - 7 WorkflowTaskStarted`, history) + 6 WorkflowTaskScheduled // Speculative WT. + 7 WorkflowTaskStarted +`, history) return s.acceptCompleteUpdateCommands(tv, "1"), nil case 3: return []*commandpb.Command{{ @@ -3620,7 +3617,7 @@ func (s *integrationSuite) TestUpdateWorkflow_CompletedSpeculativeWorkflowTask_D T: s.T(), } - // Drain exiting first workflow task. + // Drain first WT. _, err := poller.PollAndProcessWorkflowTask(true, false) s.NoError(err) @@ -3655,11 +3652,13 @@ func (s *integrationSuite) TestUpdateWorkflow_CompletedSpeculativeWorkflowTask_D Identity: tv.WorkerIdentity(), }) s.NoError(err) - s.Nil(pollResponse.Messages) + s.Nil(pollResponse.Messages, "there must be no new WT") - // But results of the first update are available. updateResult2 := <-updateResultCh2 - s.EqualValues(tv.String("success-result", "1"), decodeString(s, updateResult2.GetOutcome().GetSuccess())) + s.EqualValues( + tv.String("success-result", "1"), + decodeString(s, updateResult2.GetOutcome().GetSuccess()), + "results of the first update must be available") // Send signal to schedule new WT. err = s.sendSignal(s.namespace, tv.WorkflowExecution(), tv.Any(), payloads.EncodeString(tv.Any()), tv.Any()) @@ -3684,8 +3683,8 @@ func (s *integrationSuite) TestUpdateWorkflow_CompletedSpeculativeWorkflowTask_D 6 WorkflowTaskScheduled 7 WorkflowTaskStarted 8 WorkflowTaskCompleted - 9 WorkflowExecutionUpdateAccepted - 10 WorkflowExecutionUpdateCompleted + 9 WorkflowExecutionUpdateAccepted {"AcceptedRequestSequencingEventId": 6} // WTScheduled event which delivered update to the worker. 
+ 10 WorkflowExecutionUpdateCompleted {"AcceptedEventId": 9} 11 WorkflowExecutionSignaled 12 WorkflowTaskScheduled 13 WorkflowTaskStarted From ac1944fde4f89205dc335de769011696f7858a1f Mon Sep 17 00:00:00 2001 From: wxing1292 Date: Thu, 8 Jun 2023 11:18:04 -0700 Subject: [PATCH 6/9] Enforce uniqueness of replication stream (#4448) * Enforce uniqueness of replication stream within stream monitor for both sender side & receiver side * Consolidate replication stream implementations; move logic to replication package --- .golangci.yml | 2 +- service/history/api/replication/stream.go | 455 ------------------ .../history/api/replication/stream_mock.go | 75 --- service/history/handler.go | 34 +- .../history/replication/grpc_stream_client.go | 16 +- service/history/replication/stream.go | 58 +++ .../history/replication/stream_receiver.go | 44 +- .../replication/stream_receiver_mock.go | 110 +++++ .../replication/stream_receiver_monitor.go | 119 ++++- .../stream_receiver_monitor_test.go | 306 +++++++++--- .../replication/stream_receiver_test.go | 2 +- service/history/replication/stream_sender.go | 422 ++++++++++++++++ .../history/replication/stream_sender_mock.go | 150 ++++++ .../stream_sender_test.go} | 146 +++--- 14 files changed, 1197 insertions(+), 742 deletions(-) delete mode 100644 service/history/api/replication/stream.go delete mode 100644 service/history/api/replication/stream_mock.go create mode 100644 service/history/replication/stream.go create mode 100644 service/history/replication/stream_receiver_mock.go create mode 100644 service/history/replication/stream_sender.go create mode 100644 service/history/replication/stream_sender_mock.go rename service/history/{api/replication/stream_test.go => replication/stream_sender_test.go} (82%) diff --git a/.golangci.yml b/.golangci.yml index 83760ad14d3..94bc6ce1b9d 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -5,7 +5,7 @@ linters: - goerr113 - errcheck - goimports - - paralleltest + # - paralleltest # missing the call to method parallel, but testify does not seem to work well with parallel test: https://github.com/stretchr/testify/issues/187 - revive # revive supersedes golint, which is now archived - staticcheck - vet diff --git a/service/history/api/replication/stream.go b/service/history/api/replication/stream.go deleted file mode 100644 index d79f6bfe1ea..00000000000 --- a/service/history/api/replication/stream.go +++ /dev/null @@ -1,455 +0,0 @@ -// The MIT License -// -// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. -// -// Copyright (c) 2020 Uber Technologies, Inc. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -//go:generate mockgen -copyright_file ../../../../LICENSE -package $GOPACKAGE -source $GOFILE -destination stream_mock.go - -package replication - -import ( - "context" - "fmt" - "math" - - "go.temporal.io/api/serviceerror" - "golang.org/x/sync/errgroup" - - enumsspb "go.temporal.io/server/api/enums/v1" - "go.temporal.io/server/api/historyservice/v1" - persistencespb "go.temporal.io/server/api/persistence/v1" - replicationspb "go.temporal.io/server/api/replication/v1" - historyclient "go.temporal.io/server/client/history" - "go.temporal.io/server/common" - "go.temporal.io/server/common/cluster" - "go.temporal.io/server/common/log/tag" - "go.temporal.io/server/common/metrics" - "go.temporal.io/server/common/namespace" - "go.temporal.io/server/common/primitives/timestamp" - "go.temporal.io/server/service/history/replication" - "go.temporal.io/server/service/history/shard" - "go.temporal.io/server/service/history/tasks" -) - -type ( - TaskConvertorImpl struct { - Ctx context.Context - Engine shard.Engine - NamespaceCache namespace.Registry - ClientClusterShardCount int32 - ClientClusterName string - ClientClusterShardID historyclient.ClusterShardID - } - TaskConvertor interface { - Convert(task tasks.Task) (*replicationspb.ReplicationTask, error) - } -) - -func StreamReplicationTasks( - server historyservice.HistoryService_StreamWorkflowReplicationMessagesServer, - shardContext shard.Context, - clientClusterShardID historyclient.ClusterShardID, - serverClusterShardID historyclient.ClusterShardID, -) error { - allClusterInfo := shardContext.GetClusterMetadata().GetAllClusterInfo() - clientClusterName, clientShardCount, err := clusterIDToClusterNameShardCount(allClusterInfo, clientClusterShardID.ClusterID) - if err != nil { - return err - } - _, serverShardCount, err := clusterIDToClusterNameShardCount(allClusterInfo, int32(shardContext.GetClusterMetadata().GetClusterID())) - if err != nil { - return err - } - err = common.VerifyShardIDMapping(clientShardCount, serverShardCount, clientClusterShardID.ShardID, serverClusterShardID.ShardID) - if err != nil { - return err - } - engine, err := shardContext.GetEngine(server.Context()) - if err != nil { - return err - } - filter := &TaskConvertorImpl{ - Ctx: server.Context(), - Engine: engine, - NamespaceCache: shardContext.GetNamespaceRegistry(), - ClientClusterShardCount: clientShardCount, - ClientClusterName: clientClusterName, - ClientClusterShardID: clientClusterShardID, - } - - errGroup, ctx := errgroup.WithContext(server.Context()) - errGroup.Go(func() error { - return recvLoop(ctx, server, shardContext, clientClusterShardID, serverClusterShardID) - }) - errGroup.Go(func() error { - return sendLoop(ctx, server, shardContext, filter, clientClusterShardID, serverClusterShardID) - }) - return errGroup.Wait() -} - -func recvLoop( - ctx context.Context, - server historyservice.HistoryService_StreamWorkflowReplicationMessagesServer, - shardContext shard.Context, - clientClusterShardID historyclient.ClusterShardID, - serverClusterShardID historyclient.ClusterShardID, -) error { - for ctx.Err() == nil { - req, err := server.Recv() - if err != nil { - return err - } - shardContext.GetLogger().Debug(fmt.Sprintf( - "cluster shard ID %v/%v <- cluster shard ID %v/%v", - 
serverClusterShardID.ClusterID, serverClusterShardID.ShardID, - clientClusterShardID.ClusterID, clientClusterShardID.ShardID, - )) - switch attr := req.GetAttributes().(type) { - case *historyservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState: - if err := recvSyncReplicationState( - shardContext, - attr.SyncReplicationState, - clientClusterShardID, - ); err != nil { - shardContext.GetLogger().Error( - "StreamWorkflowReplication unable to handle SyncReplicationState", - tag.Error(err), - tag.ShardID(shardContext.GetShardID()), - ) - return err - } - shardContext.GetMetricsHandler().Counter(metrics.ReplicationTasksRecv.GetMetricName()).Record( - int64(1), - metrics.FromClusterIDTag(clientClusterShardID.ClusterID), - metrics.ToClusterIDTag(serverClusterShardID.ClusterID), - metrics.OperationTag(metrics.SyncWatermarkScope), - ) - default: - return serviceerror.NewInternal(fmt.Sprintf( - "StreamReplicationMessages encountered unknown type: %T %v", attr, attr, - )) - } - } - return ctx.Err() -} - -func recvSyncReplicationState( - shardContext shard.Context, - attr *replicationspb.SyncReplicationState, - clientClusterShardID historyclient.ClusterShardID, -) error { - inclusiveLowWatermark := attr.GetInclusiveLowWatermark() - inclusiveLowWatermarkTime := attr.GetInclusiveLowWatermarkTime() - - readerID := shard.ReplicationReaderIDFromClusterShardID( - int64(clientClusterShardID.ClusterID), - clientClusterShardID.ShardID, - ) - readerState := &persistencespb.QueueReaderState{ - Scopes: []*persistencespb.QueueSliceScope{{ - Range: &persistencespb.QueueSliceRange{ - InclusiveMin: shard.ConvertToPersistenceTaskKey( - tasks.NewImmediateKey(inclusiveLowWatermark), - ), - ExclusiveMax: shard.ConvertToPersistenceTaskKey( - tasks.NewImmediateKey(math.MaxInt64), - ), - }, - Predicate: &persistencespb.Predicate{ - PredicateType: enumsspb.PREDICATE_TYPE_UNIVERSAL, - Attributes: &persistencespb.Predicate_UniversalPredicateAttributes{}, - }, - }}, - } - if err := shardContext.UpdateReplicationQueueReaderState( - readerID, - readerState, - ); err != nil { - return err - } - shardContext.UpdateRemoteClusterInfo( - string(clientClusterShardID.ClusterID), - inclusiveLowWatermark-1, - *inclusiveLowWatermarkTime, - ) - return nil -} - -func sendLoop( - ctx context.Context, - server historyservice.HistoryService_StreamWorkflowReplicationMessagesServer, - shardContext shard.Context, - taskConvertor TaskConvertor, - clientClusterShardID historyclient.ClusterShardID, - serverClusterShardID historyclient.ClusterShardID, -) error { - engine, err := shardContext.GetEngine(ctx) - if err != nil { - return err - } - newTaskNotificationChan, subscriberID := engine.SubscribeReplicationNotification() - defer engine.UnsubscribeReplicationNotification(subscriberID) - - catchupEndExclusiveWatermark, err := sendCatchUp( - ctx, - server, - shardContext, - taskConvertor, - clientClusterShardID, - serverClusterShardID, - ) - if err != nil { - shardContext.GetLogger().Error( - "StreamWorkflowReplication unable to catch up replication tasks", - tag.Error(err), - ) - return err - } - if err := sendLive( - ctx, - server, - shardContext, - taskConvertor, - clientClusterShardID, - serverClusterShardID, - newTaskNotificationChan, - catchupEndExclusiveWatermark, - ); err != nil { - shardContext.GetLogger().Error( - "StreamWorkflowReplication unable to stream replication tasks", - tag.Error(err), - ) - return err - } - shardContext.GetLogger().Info("StreamWorkflowReplication finish", tag.ShardID(shardContext.GetShardID())) - 
return nil -} - -func sendCatchUp( - ctx context.Context, - server historyservice.HistoryService_StreamWorkflowReplicationMessagesServer, - shardContext shard.Context, - taskConvertor TaskConvertor, - clientClusterShardID historyclient.ClusterShardID, - serverClusterShardID historyclient.ClusterShardID, -) (int64, error) { - - readerID := shard.ReplicationReaderIDFromClusterShardID( - int64(clientClusterShardID.ClusterID), - clientClusterShardID.ShardID, - ) - - var catchupBeginInclusiveWatermark int64 - queueState, ok := shardContext.GetQueueState( - tasks.CategoryReplication, - ) - if !ok { - catchupBeginInclusiveWatermark = 0 - } else { - readerState, ok := queueState.ReaderStates[readerID] - if !ok { - catchupBeginInclusiveWatermark = 0 - } else { - catchupBeginInclusiveWatermark = readerState.Scopes[0].Range.InclusiveMin.TaskId - } - } - catchupEndExclusiveWatermark := shardContext.GetImmediateQueueExclusiveHighReadWatermark().TaskID - if err := sendTasks( - ctx, - server, - shardContext, - taskConvertor, - clientClusterShardID, - serverClusterShardID, - catchupBeginInclusiveWatermark, - catchupEndExclusiveWatermark, - ); err != nil { - return 0, err - } - return catchupEndExclusiveWatermark, nil -} - -func sendLive( - ctx context.Context, - server historyservice.HistoryService_StreamWorkflowReplicationMessagesServer, - shardContext shard.Context, - taskConvertor TaskConvertor, - clientClusterShardID historyclient.ClusterShardID, - serverClusterShardID historyclient.ClusterShardID, - newTaskNotificationChan <-chan struct{}, - beginInclusiveWatermark int64, -) error { - for { - select { - case <-newTaskNotificationChan: - endExclusiveWatermark := shardContext.GetImmediateQueueExclusiveHighReadWatermark().TaskID - if err := sendTasks( - ctx, - server, - shardContext, - taskConvertor, - clientClusterShardID, - serverClusterShardID, - beginInclusiveWatermark, - endExclusiveWatermark, - ); err != nil { - return err - } - beginInclusiveWatermark = endExclusiveWatermark - case <-ctx.Done(): - return ctx.Err() - } - } -} - -func sendTasks( - ctx context.Context, - server historyservice.HistoryService_StreamWorkflowReplicationMessagesServer, - shardContext shard.Context, - taskConvertor TaskConvertor, - clientClusterShardID historyclient.ClusterShardID, - serverClusterShardID historyclient.ClusterShardID, - beginInclusiveWatermark int64, - endExclusiveWatermark int64, -) error { - if beginInclusiveWatermark > endExclusiveWatermark { - err := serviceerror.NewInternal(fmt.Sprintf("StreamWorkflowReplication encountered invalid task range [%v, %v)", - beginInclusiveWatermark, - endExclusiveWatermark, - )) - shardContext.GetLogger().Error("StreamWorkflowReplication unable to", tag.Error(err)) - return err - } - if beginInclusiveWatermark == endExclusiveWatermark { - return server.Send(&historyservice.StreamWorkflowReplicationMessagesResponse{ - Attributes: &historyservice.StreamWorkflowReplicationMessagesResponse_Messages{ - Messages: &replicationspb.WorkflowReplicationMessages{ - ReplicationTasks: nil, - ExclusiveHighWatermark: endExclusiveWatermark, - ExclusiveHighWatermarkTime: timestamp.TimeNowPtrUtc(), - }, - }, - }) - } - - engine, err := shardContext.GetEngine(ctx) - if err != nil { - return err - } - iter, err := engine.GetReplicationTasksIter( - ctx, - string(clientClusterShardID.ClusterID), - beginInclusiveWatermark, - endExclusiveWatermark, - ) - if err != nil { - return err - } -Loop: - for iter.HasNext() { - if ctx.Err() != nil { - return ctx.Err() - } - - item, err := iter.Next() - if 
err != nil { - return err - } - task, err := taskConvertor.Convert(item) - if err != nil { - return err - } - if task == nil { - continue Loop - } - if err := server.Send(&historyservice.StreamWorkflowReplicationMessagesResponse{ - Attributes: &historyservice.StreamWorkflowReplicationMessagesResponse_Messages{ - Messages: &replicationspb.WorkflowReplicationMessages{ - ReplicationTasks: []*replicationspb.ReplicationTask{task}, - ExclusiveHighWatermark: task.SourceTaskId + 1, - ExclusiveHighWatermarkTime: task.VisibilityTime, - }, - }, - }); err != nil { - return err - } - shardContext.GetMetricsHandler().Counter(metrics.ReplicationTasksSend.GetMetricName()).Record( - int64(1), - metrics.FromClusterIDTag(serverClusterShardID.ClusterID), - metrics.ToClusterIDTag(clientClusterShardID.ClusterID), - metrics.OperationTag(replication.TaskOperationTag(task)), - ) - } - return server.Send(&historyservice.StreamWorkflowReplicationMessagesResponse{ - Attributes: &historyservice.StreamWorkflowReplicationMessagesResponse_Messages{ - Messages: &replicationspb.WorkflowReplicationMessages{ - ReplicationTasks: nil, - ExclusiveHighWatermark: endExclusiveWatermark, - ExclusiveHighWatermarkTime: timestamp.TimeNowPtrUtc(), - }, - }, - }) -} - -func (f *TaskConvertorImpl) Convert( - task tasks.Task, -) (*replicationspb.ReplicationTask, error) { - if namespaceEntry, err := f.NamespaceCache.GetNamespaceByID( - namespace.ID(task.GetNamespaceID()), - ); err == nil { - shouldProcessTask := false - FilterLoop: - for _, targetCluster := range namespaceEntry.ClusterNames() { - if f.ClientClusterName == targetCluster { - shouldProcessTask = true - break FilterLoop - } - } - if !shouldProcessTask { - return nil, nil - } - } - // if there is error, then blindly send the task, better safe than sorry - - sourceShardID := common.WorkflowIDToHistoryShard(task.GetNamespaceID(), task.GetWorkflowID(), f.ClientClusterShardCount) - if sourceShardID != f.ClientClusterShardID.ShardID { - return nil, nil - } - - replicationTask, err := f.Engine.ConvertReplicationTask(f.Ctx, task) - if err != nil { - return nil, err - } - return replicationTask, nil -} - -func clusterIDToClusterNameShardCount( - allClusterInfo map[string]cluster.ClusterInformation, - clusterID int32, -) (string, int32, error) { - for clusterName, clusterInfo := range allClusterInfo { - if int32(clusterInfo.InitialFailoverVersion) == clusterID { - return clusterName, clusterInfo.ShardCount, nil - } - } - return "", 0, serviceerror.NewInternal(fmt.Sprintf("unknown cluster ID: %v", clusterID)) -} diff --git a/service/history/api/replication/stream_mock.go b/service/history/api/replication/stream_mock.go deleted file mode 100644 index d6086deb7a7..00000000000 --- a/service/history/api/replication/stream_mock.go +++ /dev/null @@ -1,75 +0,0 @@ -// The MIT License -// -// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. -// -// Copyright (c) 2020 Uber Technologies, Inc. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. 
-// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -// Code generated by MockGen. DO NOT EDIT. -// Source: stream.go - -// Package replication is a generated GoMock package. -package replication - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - repication "go.temporal.io/server/api/replication/v1" - tasks "go.temporal.io/server/service/history/tasks" -) - -// MockTaskConvertor is a mock of TaskConvertor interface. -type MockTaskConvertor struct { - ctrl *gomock.Controller - recorder *MockTaskConvertorMockRecorder -} - -// MockTaskConvertorMockRecorder is the mock recorder for MockTaskConvertor. -type MockTaskConvertorMockRecorder struct { - mock *MockTaskConvertor -} - -// NewMockTaskConvertor creates a new mock instance. -func NewMockTaskConvertor(ctrl *gomock.Controller) *MockTaskConvertor { - mock := &MockTaskConvertor{ctrl: ctrl} - mock.recorder = &MockTaskConvertorMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTaskConvertor) EXPECT() *MockTaskConvertorMockRecorder { - return m.recorder -} - -// Convert mocks base method. -func (m *MockTaskConvertor) Convert(task tasks.Task) (*repication.ReplicationTask, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Convert", task) - ret0, _ := ret[0].(*repication.ReplicationTask) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Convert indicates an expected call of Convert. 
-func (mr *MockTaskConvertorMockRecorder) Convert(task interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Convert", reflect.TypeOf((*MockTaskConvertor)(nil).Convert), task) -} diff --git a/service/history/handler.go b/service/history/handler.go index 6d6457ebd41..81d76b2e6ab 100644 --- a/service/history/handler.go +++ b/service/history/handler.go @@ -63,7 +63,6 @@ import ( "go.temporal.io/server/common/searchattribute" serviceerrors "go.temporal.io/server/common/serviceerror" "go.temporal.io/server/service/history/api" - replicationapi "go.temporal.io/server/service/history/api/replication" "go.temporal.io/server/service/history/configs" "go.temporal.io/server/service/history/events" "go.temporal.io/server/service/history/replication" @@ -1910,10 +1909,41 @@ func (h *Handler) StreamWorkflowReplicationMessages( if err != nil { return h.convertError(err) } - err = replicationapi.StreamReplicationTasks(server, shardContext, clientClusterShardID, serverClusterShardID) + engine, err := shardContext.GetEngine(server.Context()) if err != nil { return h.convertError(err) } + allClusterInfo := shardContext.GetClusterMetadata().GetAllClusterInfo() + clientClusterName, clientShardCount, err := replication.ClusterIDToClusterNameShardCount(allClusterInfo, clientClusterShardID.ClusterID) + if err != nil { + return h.convertError(err) + } + _, serverShardCount, err := replication.ClusterIDToClusterNameShardCount(allClusterInfo, int32(shardContext.GetClusterMetadata().GetClusterID())) + if err != nil { + return h.convertError(err) + } + err = common.VerifyShardIDMapping(clientShardCount, serverShardCount, clientClusterShardID.ShardID, serverClusterShardID.ShardID) + if err != nil { + return h.convertError(err) + } + streamSender := replication.NewStreamSender( + server, + shardContext, + engine, + replication.NewSourceTaskConvertor( + engine, + shardContext.GetNamespaceRegistry(), + clientShardCount, + clientClusterName, + replication.NewClusterShardKey(clientClusterShardID.ClusterID, clientClusterShardID.ShardID), + ), + replication.NewClusterShardKey(clientClusterShardID.ClusterID, clientClusterShardID.ShardID), + replication.NewClusterShardKey(serverClusterShardID.ClusterID, serverClusterShardID.ShardID), + ) + h.streamReceiverMonitor.RegisterInboundStream(streamSender) + streamSender.Start() + defer streamSender.Stop() + streamSender.Wait() return nil } diff --git a/service/history/replication/grpc_stream_client.go b/service/history/replication/grpc_stream_client.go index c40b33ae1ab..0f2e9a60b43 100644 --- a/service/history/replication/grpc_stream_client.go +++ b/service/history/replication/grpc_stream_client.go @@ -26,9 +26,7 @@ package replication import ( "context" - "fmt" - "go.temporal.io/api/serviceerror" "google.golang.org/grpc/metadata" "go.temporal.io/server/api/adminservice/v1" @@ -60,7 +58,7 @@ func (p *StreamBiDirectionStreamClientProvider) Get( serverShardKey ClusterShardKey, ) (BiDirectionStreamClient[*adminservice.StreamWorkflowReplicationMessagesRequest, *adminservice.StreamWorkflowReplicationMessagesResponse], error) { allClusterInfo := p.clusterMetadata.GetAllClusterInfo() - clusterName, err := clusterIDToClusterName(allClusterInfo, serverShardKey.ClusterID) + clusterName, _, err := ClusterIDToClusterNameShardCount(allClusterInfo, serverShardKey.ClusterID) if err != nil { return nil, err } @@ -80,15 +78,3 @@ func (p *StreamBiDirectionStreamClientProvider) Get( )) return adminClient.StreamWorkflowReplicationMessages(ctx) } - 
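+// NOTE: the per-client helper removed below is superseded by the shared
+// ClusterIDToClusterNameShardCount in service/history/replication/stream.go, which also
+// returns the peer cluster's shard count so callers can verify the shard-ID mapping before
+// opening a stream. A minimal caller sketch (variable names are illustrative only):
+//
+//	allClusterInfo := p.clusterMetadata.GetAllClusterInfo()
+//	peerName, peerShardCount, err := ClusterIDToClusterNameShardCount(allClusterInfo, serverShardKey.ClusterID)
+//	if err != nil {
+//		return nil, err
+//	}
+//	// peerName selects the remote admin client; peerShardCount feeds common.VerifyShardIDMapping.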
-func clusterIDToClusterName( - allClusterInfo map[string]cluster.ClusterInformation, - clusterID int32, -) (string, error) { - for clusterName, clusterInfo := range allClusterInfo { - if int32(clusterInfo.InitialFailoverVersion) == clusterID { - return clusterName, nil - } - } - return "", serviceerror.NewInternal(fmt.Sprintf("unknown cluster ID: %v", clusterID)) -} diff --git a/service/history/replication/stream.go b/service/history/replication/stream.go new file mode 100644 index 00000000000..4419d76e02e --- /dev/null +++ b/service/history/replication/stream.go @@ -0,0 +1,58 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package replication + +import ( + "fmt" + + "go.temporal.io/api/serviceerror" + + "go.temporal.io/server/api/adminservice/v1" + "go.temporal.io/server/common/cluster" +) + +type ( + Stream BiDirectionStream[*adminservice.StreamWorkflowReplicationMessagesRequest, *adminservice.StreamWorkflowReplicationMessagesResponse] + ClusterShardKey struct { + ClusterID int32 + ShardID int32 + } + ClusterShardKeyPair struct { + Client ClusterShardKey + Server ClusterShardKey + } +) + +func ClusterIDToClusterNameShardCount( + allClusterInfo map[string]cluster.ClusterInformation, + clusterID int32, +) (string, int32, error) { + for clusterName, clusterInfo := range allClusterInfo { + if int32(clusterInfo.InitialFailoverVersion) == clusterID { + return clusterName, clusterInfo.ShardCount, nil + } + } + return "", 0, serviceerror.NewInternal(fmt.Sprintf("unknown cluster ID: %v", clusterID)) +} diff --git a/service/history/replication/stream_receiver.go b/service/history/replication/stream_receiver.go index d29439e54b5..5e1a833601f 100644 --- a/service/history/replication/stream_receiver.go +++ b/service/history/replication/stream_receiver.go @@ -22,6 +22,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
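+// Stream uniqueness in this package is keyed by ClusterShardKeyPair (see stream.go): the
+// StreamReceiverMonitor tracks outbound streams (StreamReceiver, created by the monitor from
+// its generated key set) and inbound streams (StreamSender, registered by the history handler
+// when a peer opens StreamWorkflowReplicationMessages) in separate maps, and
+// RegisterInboundStream stops any stale sender sharing the same key, so at most one live
+// stream exists per client/server shard pair in each direction.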
+//go:generate mockgen -copyright_file ../../../LICENSE -package $GOPACKAGE -source $GOFILE -destination stream_receiver_mock.go + package replication import ( @@ -41,17 +43,12 @@ import ( ) type ( - ClusterShardKey struct { - ClusterID int32 - ShardID int32 - } - ClusterShardKeyPair struct { - Client ClusterShardKey - Server ClusterShardKey + StreamReceiver interface { + common.Daemon + IsValid() bool + Key() ClusterShardKeyPair } - - Stream BiDirectionStream[*adminservice.StreamWorkflowReplicationMessagesRequest, *adminservice.StreamWorkflowReplicationMessagesResponse] - StreamReceiver struct { + StreamReceiverImpl struct { ProcessToolBox status int32 @@ -78,10 +75,10 @@ func NewStreamReceiver( processToolBox ProcessToolBox, clientShardKey ClusterShardKey, serverShardKey ClusterShardKey, -) *StreamReceiver { +) *StreamReceiverImpl { logger := log.With(processToolBox.Logger, tag.ShardID(clientShardKey.ShardID)) taskTracker := NewExecutableTaskTracker(logger) - return &StreamReceiver{ + return &StreamReceiverImpl{ ProcessToolBox: processToolBox, status: common.DaemonStatusInitialized, @@ -99,7 +96,7 @@ func NewStreamReceiver( } // Start starts the processor -func (r *StreamReceiver) Start() { +func (r *StreamReceiverImpl) Start() { if !atomic.CompareAndSwapInt32( &r.status, common.DaemonStatusInitialized, @@ -115,7 +112,7 @@ func (r *StreamReceiver) Start() { } // Stop stops the processor -func (r *StreamReceiver) Stop() { +func (r *StreamReceiverImpl) Stop() { if !atomic.CompareAndSwapInt32( &r.status, common.DaemonStatusStarted, @@ -131,11 +128,18 @@ func (r *StreamReceiver) Stop() { r.logger.Info("StreamReceiver shutting down.") } -func (r *StreamReceiver) IsValid() bool { +func (r *StreamReceiverImpl) IsValid() bool { return atomic.LoadInt32(&r.status) == common.DaemonStatusStarted } -func (r *StreamReceiver) sendEventLoop() { +func (r *StreamReceiverImpl) Key() ClusterShardKeyPair { + return ClusterShardKeyPair{ + Client: r.clientShardKey, + Server: r.serverShardKey, + } +} + +func (r *StreamReceiverImpl) sendEventLoop() { defer r.Stop() timer := time.NewTicker(r.Config.ReplicationStreamSyncStatusDuration()) defer timer.Stop() @@ -154,14 +158,14 @@ func (r *StreamReceiver) sendEventLoop() { } } -func (r *StreamReceiver) recvEventLoop() { +func (r *StreamReceiverImpl) recvEventLoop() { defer r.Stop() err := r.processMessages(r.stream) r.logger.Error("StreamReceiver exit recv loop", tag.Error(err)) } -func (r *StreamReceiver) ackMessage( +func (r *StreamReceiverImpl) ackMessage( stream Stream, ) error { watermarkInfo := r.taskTracker.LowWatermark() @@ -194,11 +198,11 @@ func (r *StreamReceiver) ackMessage( return nil } -func (r *StreamReceiver) processMessages( +func (r *StreamReceiverImpl) processMessages( stream Stream, ) error { allClusterInfo := r.ClusterMetadata.GetAllClusterInfo() - clusterName, err := clusterIDToClusterName(allClusterInfo, r.serverShardKey.ClusterID) + clusterName, _, err := ClusterIDToClusterNameShardCount(allClusterInfo, r.serverShardKey.ClusterID) if err != nil { return err } diff --git a/service/history/replication/stream_receiver_mock.go b/service/history/replication/stream_receiver_mock.go new file mode 100644 index 00000000000..f124daf0dee --- /dev/null +++ b/service/history/replication/stream_receiver_mock.go @@ -0,0 +1,110 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +// Code generated by MockGen. DO NOT EDIT. +// Source: stream_receiver.go + +// Package replication is a generated GoMock package. +package replication + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockStreamReceiver is a mock of StreamReceiver interface. +type MockStreamReceiver struct { + ctrl *gomock.Controller + recorder *MockStreamReceiverMockRecorder +} + +// MockStreamReceiverMockRecorder is the mock recorder for MockStreamReceiver. +type MockStreamReceiverMockRecorder struct { + mock *MockStreamReceiver +} + +// NewMockStreamReceiver creates a new mock instance. +func NewMockStreamReceiver(ctrl *gomock.Controller) *MockStreamReceiver { + mock := &MockStreamReceiver{ctrl: ctrl} + mock.recorder = &MockStreamReceiverMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStreamReceiver) EXPECT() *MockStreamReceiverMockRecorder { + return m.recorder +} + +// IsValid mocks base method. +func (m *MockStreamReceiver) IsValid() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsValid") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsValid indicates an expected call of IsValid. +func (mr *MockStreamReceiverMockRecorder) IsValid() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsValid", reflect.TypeOf((*MockStreamReceiver)(nil).IsValid)) +} + +// Key mocks base method. +func (m *MockStreamReceiver) Key() ClusterShardKeyPair { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Key") + ret0, _ := ret[0].(ClusterShardKeyPair) + return ret0 +} + +// Key indicates an expected call of Key. +func (mr *MockStreamReceiverMockRecorder) Key() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Key", reflect.TypeOf((*MockStreamReceiver)(nil).Key)) +} + +// Start mocks base method. +func (m *MockStreamReceiver) Start() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Start") +} + +// Start indicates an expected call of Start. +func (mr *MockStreamReceiverMockRecorder) Start() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockStreamReceiver)(nil).Start)) +} + +// Stop mocks base method. +func (m *MockStreamReceiver) Stop() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Stop") +} + +// Stop indicates an expected call of Stop. 
+func (mr *MockStreamReceiverMockRecorder) Stop() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockStreamReceiver)(nil).Stop)) +} diff --git a/service/history/replication/stream_receiver_monitor.go b/service/history/replication/stream_receiver_monitor.go index fe865223525..017b9d6edbb 100644 --- a/service/history/replication/stream_receiver_monitor.go +++ b/service/history/replication/stream_receiver_monitor.go @@ -42,6 +42,7 @@ const ( type ( StreamReceiverMonitor interface { common.Daemon + RegisterInboundStream(streamSender StreamSender) } StreamReceiverMonitorImpl struct { ProcessToolBox @@ -51,7 +52,8 @@ type ( shutdownOnce channel.ShutdownOnce sync.Mutex - streams map[ClusterShardKeyPair]*StreamReceiver + inboundStreams map[ClusterShardKeyPair]StreamSender + outboundStreams map[ClusterShardKeyPair]StreamReceiver } ) @@ -66,7 +68,8 @@ func NewStreamReceiverMonitor( status: streamStatusInitialized, shutdownOnce: channel.NewShutdownOnce(), - streams: make(map[ClusterShardKeyPair]*StreamReceiver), + inboundStreams: make(map[ClusterShardKeyPair]StreamSender), + outboundStreams: make(map[ClusterShardKeyPair]StreamReceiver), } } @@ -102,13 +105,28 @@ func (m *StreamReceiverMonitorImpl) Stop() { m.shutdownOnce.Shutdown() m.Lock() defer m.Unlock() - for serverKey, stream := range m.streams { + for serverKey, stream := range m.outboundStreams { stream.Stop() - delete(m.streams, serverKey) + delete(m.outboundStreams, serverKey) } m.Logger.Info("StreamReceiverMonitor stopped.") } +func (m *StreamReceiverMonitorImpl) RegisterInboundStream( + streamSender StreamSender, +) { + streamKey := streamSender.Key() + + m.Lock() + defer m.Unlock() + + if staleSender, ok := m.inboundStreams[streamKey]; ok { + staleSender.Stop() + delete(m.inboundStreams, streamKey) + } + m.inboundStreams[streamKey] = streamSender +} + func (m *StreamReceiverMonitorImpl) eventLoop() { defer m.Stop() ticker := time.NewTicker(streamReceiverMonitorInterval) @@ -122,27 +140,71 @@ func (m *StreamReceiverMonitorImpl) eventLoop() { } }) defer m.ClusterMetadata.UnRegisterMetadataChangeCallback(m) - m.reconcileStreams() + m.reconcileOutboundStreams() Loop: for !m.shutdownOnce.IsShutdown() { select { case <-clusterMetadataChangeChan: - m.reconcileStreams() + m.reconcileInboundStreams() + m.reconcileOutboundStreams() case <-ticker.C: - m.reconcileStreams() + m.reconcileInboundStreams() + m.reconcileOutboundStreams() case <-m.shutdownOnce.Channel(): break Loop } } } -func (m *StreamReceiverMonitorImpl) reconcileStreams() { - streamKeys := m.generateStreamKeys() - m.doReconcileStreams(streamKeys) +func (m *StreamReceiverMonitorImpl) reconcileInboundStreams() { + streamKeys := m.generateInboundStreamKeys() + m.doReconcileInboundStreams(streamKeys) +} + +func (m *StreamReceiverMonitorImpl) reconcileOutboundStreams() { + streamKeys := m.generateOutboundStreamKeys() + m.doReconcileOutboundStreams(streamKeys) } -func (m *StreamReceiverMonitorImpl) generateStreamKeys() map[ClusterShardKeyPair]struct{} { +func (m *StreamReceiverMonitorImpl) generateInboundStreamKeys() map[ClusterShardKeyPair]struct{} { + allClusterInfo := m.ClusterMetadata.GetAllClusterInfo() + + clientClusterIDs := make(map[int32]struct{}) + serverClusterID := int32(m.ClusterMetadata.GetClusterID()) + clusterIDToShardCount := make(map[int32]int32) + for _, clusterInfo := range allClusterInfo { + clusterIDToShardCount[int32(clusterInfo.InitialFailoverVersion)] = clusterInfo.ShardCount + + if 
!clusterInfo.Enabled || int32(clusterInfo.InitialFailoverVersion) == serverClusterID { + continue + } + clientClusterIDs[int32(clusterInfo.InitialFailoverVersion)] = struct{}{} + } + streamKeys := make(map[ClusterShardKeyPair]struct{}) + for _, shardID := range m.ShardController.ShardIDs() { + for clientClusterID := range clientClusterIDs { + serverShardID := shardID + for _, clientShardID := range common.MapShardID( + clusterIDToShardCount[serverClusterID], + clusterIDToShardCount[clientClusterID], + serverShardID, + ) { + m.Logger.Debug(fmt.Sprintf( + "inbound cluster shard ID %v/%v -> cluster shard ID %v/%v", + clientClusterID, clientShardID, serverClusterID, serverShardID, + )) + streamKeys[ClusterShardKeyPair{ + Client: NewClusterShardKey(clientClusterID, clientShardID), + Server: NewClusterShardKey(serverClusterID, serverShardID), + }] = struct{}{} + } + } + } + return streamKeys +} + +func (m *StreamReceiverMonitorImpl) generateOutboundStreamKeys() map[ClusterShardKeyPair]struct{} { allClusterInfo := m.ClusterMetadata.GetAllClusterInfo() clientClusterID := int32(m.ClusterMetadata.GetClusterID()) @@ -166,7 +228,7 @@ func (m *StreamReceiverMonitorImpl) generateStreamKeys() map[ClusterShardKeyPair clientShardID, ) { m.Logger.Debug(fmt.Sprintf( - "cluster shard ID %v/%v -> cluster shard ID %v/%v", + "outbound cluster shard ID %v/%v -> cluster shard ID %v/%v", clientClusterID, clientShardID, serverClusterID, serverShardID, )) streamKeys[ClusterShardKeyPair{ @@ -179,7 +241,7 @@ func (m *StreamReceiverMonitorImpl) generateStreamKeys() map[ClusterShardKeyPair return streamKeys } -func (m *StreamReceiverMonitorImpl) doReconcileStreams( +func (m *StreamReceiverMonitorImpl) doReconcileInboundStreams( streamKeys map[ClusterShardKeyPair]struct{}, ) { m.Lock() @@ -188,25 +250,44 @@ func (m *StreamReceiverMonitorImpl) doReconcileStreams( return } - for streamKey, stream := range m.streams { + for streamKey, stream := range m.inboundStreams { if !stream.IsValid() { stream.Stop() - delete(m.streams, streamKey) + delete(m.inboundStreams, streamKey) + } else if _, ok := streamKeys[streamKey]; !ok { + stream.Stop() + delete(m.inboundStreams, streamKey) } - if _, ok := streamKeys[streamKey]; !ok { + } +} + +func (m *StreamReceiverMonitorImpl) doReconcileOutboundStreams( + streamKeys map[ClusterShardKeyPair]struct{}, +) { + m.Lock() + defer m.Unlock() + if m.shutdownOnce.IsShutdown() { + return + } + + for streamKey, stream := range m.outboundStreams { + if !stream.IsValid() { + stream.Stop() + delete(m.outboundStreams, streamKey) + } else if _, ok := streamKeys[streamKey]; !ok { stream.Stop() - delete(m.streams, streamKey) + delete(m.outboundStreams, streamKey) } } for streamKey := range streamKeys { - if _, ok := m.streams[streamKey]; !ok { + if _, ok := m.outboundStreams[streamKey]; !ok { stream := NewStreamReceiver( m.ProcessToolBox, streamKey.Client, streamKey.Server, ) stream.Start() - m.streams[streamKey] = stream + m.outboundStreams[streamKey] = stream } } } diff --git a/service/history/replication/stream_receiver_monitor_test.go b/service/history/replication/stream_receiver_monitor_test.go index a60d2785f67..a3fb5562391 100644 --- a/service/history/replication/stream_receiver_monitor_test.go +++ b/service/history/replication/stream_receiver_monitor_test.go @@ -56,9 +56,6 @@ type ( clientBean *client.MockBean shardController *shard.MockController - clientClusterID int32 - serverClusterID int32 - streamReceiverMonitor *StreamReceiverMonitorImpl } ) @@ -84,9 +81,6 @@ func (s 
*streamReceiverMonitorSuite) SetupTest() { s.clientBean = client.NewMockBean(s.controller) s.shardController = shard.NewMockController(s.controller) - s.clientClusterID = int32(cluster.TestCurrentClusterInitialFailoverVersion) - s.serverClusterID = int32(cluster.TestAlternativeClusterInitialFailoverVersion) - s.streamReceiverMonitor = NewStreamReceiverMonitor( ProcessToolBox{ Config: configs.NewConfig( @@ -124,98 +118,176 @@ func (s *streamReceiverMonitorSuite) TearDownTest() { s.streamReceiverMonitor.Lock() defer s.streamReceiverMonitor.Unlock() - for serverKey, stream := range s.streamReceiverMonitor.streams { + for serverKey, stream := range s.streamReceiverMonitor.outboundStreams { stream.Stop() - delete(s.streamReceiverMonitor.streams, serverKey) + delete(s.streamReceiverMonitor.outboundStreams, serverKey) } } -func (s *streamReceiverMonitorSuite) TestGenerateStreamKeys_1To4() { - s.clusterMetadata.EXPECT().GetClusterID().Return(int64(s.clientClusterID)).AnyTimes() +func (s *streamReceiverMonitorSuite) TestGenerateInboundStreamKeys_1From4() { + s.clusterMetadata.EXPECT().GetClusterID().Return(cluster.TestAlternativeClusterInitialFailoverVersion).AnyTimes() + s.clusterMetadata.EXPECT().GetAllClusterInfo().Return(map[string]cluster.ClusterInformation{ + cluster.TestCurrentClusterName: { + Enabled: true, + InitialFailoverVersion: cluster.TestAlternativeClusterInitialFailoverVersion, + RPCAddress: cluster.TestCurrentClusterFrontendAddress, + ShardCount: 1, + }, + cluster.TestAlternativeClusterName: { + Enabled: true, + InitialFailoverVersion: cluster.TestCurrentClusterInitialFailoverVersion, + RPCAddress: cluster.TestAlternativeClusterFrontendAddress, + ShardCount: 4, + }, + }).AnyTimes() + s.shardController.EXPECT().ShardIDs().Return([]int32{1}).AnyTimes() + + streamKeys := s.streamReceiverMonitor.generateInboundStreamKeys() + s.Equal(map[ClusterShardKeyPair]struct{}{ + ClusterShardKeyPair{ + Client: NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), 1), + Server: NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), 1), + }: {}, + ClusterShardKeyPair{ + Client: NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), 2), + Server: NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), 1), + }: {}, + ClusterShardKeyPair{ + Client: NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), 3), + Server: NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), 1), + }: {}, + ClusterShardKeyPair{ + Client: NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), 4), + Server: NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), 1), + }: {}, + }, streamKeys) +} + +func (s *streamReceiverMonitorSuite) TestGenerateInboundStreamKeys_4From1() { + s.clusterMetadata.EXPECT().GetClusterID().Return(cluster.TestAlternativeClusterInitialFailoverVersion).AnyTimes() + s.clusterMetadata.EXPECT().GetAllClusterInfo().Return(map[string]cluster.ClusterInformation{ + cluster.TestCurrentClusterName: { + Enabled: true, + InitialFailoverVersion: cluster.TestAlternativeClusterInitialFailoverVersion, + RPCAddress: cluster.TestCurrentClusterFrontendAddress, + ShardCount: 4, + }, + cluster.TestAlternativeClusterName: { + Enabled: true, + InitialFailoverVersion: cluster.TestCurrentClusterInitialFailoverVersion, + RPCAddress: cluster.TestAlternativeClusterFrontendAddress, + ShardCount: 1, + }, + }).AnyTimes() + 
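+	// 4-shard local cluster receiving from a 1-shard peer: common.MapShardID pairs the single
+	// client shard with each of the four local shards, so four inbound stream keys are
+	// expected below.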
s.shardController.EXPECT().ShardIDs().Return([]int32{1, 2, 3, 4}).AnyTimes() + + streamKeys := s.streamReceiverMonitor.generateInboundStreamKeys() + s.Equal(map[ClusterShardKeyPair]struct{}{ + ClusterShardKeyPair{ + Client: NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), 1), + Server: NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), 1), + }: {}, + ClusterShardKeyPair{ + Client: NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), 1), + Server: NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), 2), + }: {}, + ClusterShardKeyPair{ + Client: NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), 1), + Server: NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), 3), + }: {}, + ClusterShardKeyPair{ + Client: NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), 1), + Server: NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), 4), + }: {}, + }, streamKeys) +} + +func (s *streamReceiverMonitorSuite) TestGenerateOutboundStreamKeys_1To4() { + s.clusterMetadata.EXPECT().GetClusterID().Return(cluster.TestCurrentClusterInitialFailoverVersion).AnyTimes() s.clusterMetadata.EXPECT().GetAllClusterInfo().Return(map[string]cluster.ClusterInformation{ cluster.TestCurrentClusterName: { Enabled: true, - InitialFailoverVersion: int64(s.clientClusterID), + InitialFailoverVersion: cluster.TestCurrentClusterInitialFailoverVersion, RPCAddress: cluster.TestCurrentClusterFrontendAddress, ShardCount: 1, }, cluster.TestAlternativeClusterName: { Enabled: true, - InitialFailoverVersion: int64(s.serverClusterID), + InitialFailoverVersion: cluster.TestAlternativeClusterInitialFailoverVersion, RPCAddress: cluster.TestAlternativeClusterFrontendAddress, ShardCount: 4, }, }).AnyTimes() s.shardController.EXPECT().ShardIDs().Return([]int32{1}).AnyTimes() - streamKeys := s.streamReceiverMonitor.generateStreamKeys() + streamKeys := s.streamReceiverMonitor.generateOutboundStreamKeys() s.Equal(map[ClusterShardKeyPair]struct{}{ ClusterShardKeyPair{ - Client: NewClusterShardKey(s.clientClusterID, 1), - Server: NewClusterShardKey(s.serverClusterID, 1), + Client: NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), 1), + Server: NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), 1), }: {}, ClusterShardKeyPair{ - Client: NewClusterShardKey(s.clientClusterID, 1), - Server: NewClusterShardKey(s.serverClusterID, 2), + Client: NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), 1), + Server: NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), 2), }: {}, ClusterShardKeyPair{ - Client: NewClusterShardKey(s.clientClusterID, 1), - Server: NewClusterShardKey(s.serverClusterID, 3), + Client: NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), 1), + Server: NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), 3), }: {}, ClusterShardKeyPair{ - Client: NewClusterShardKey(s.clientClusterID, 1), - Server: NewClusterShardKey(s.serverClusterID, 4), + Client: NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), 1), + Server: NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), 4), }: {}, }, streamKeys) } -func (s *streamReceiverMonitorSuite) TestGenerateStreamKeys_4To1() { - 
s.clusterMetadata.EXPECT().GetClusterID().Return(int64(s.clientClusterID)).AnyTimes() +func (s *streamReceiverMonitorSuite) TestGenerateOutboundStreamKeys_4To1() { + s.clusterMetadata.EXPECT().GetClusterID().Return(cluster.TestCurrentClusterInitialFailoverVersion).AnyTimes() s.clusterMetadata.EXPECT().GetAllClusterInfo().Return(map[string]cluster.ClusterInformation{ cluster.TestCurrentClusterName: { Enabled: true, - InitialFailoverVersion: int64(s.clientClusterID), + InitialFailoverVersion: cluster.TestCurrentClusterInitialFailoverVersion, RPCAddress: cluster.TestCurrentClusterFrontendAddress, ShardCount: 4, }, cluster.TestAlternativeClusterName: { Enabled: true, - InitialFailoverVersion: int64(s.serverClusterID), + InitialFailoverVersion: cluster.TestAlternativeClusterInitialFailoverVersion, RPCAddress: cluster.TestAlternativeClusterFrontendAddress, ShardCount: 1, }, }).AnyTimes() s.shardController.EXPECT().ShardIDs().Return([]int32{1, 2, 3, 4}).AnyTimes() - streamKeys := s.streamReceiverMonitor.generateStreamKeys() + streamKeys := s.streamReceiverMonitor.generateOutboundStreamKeys() s.Equal(map[ClusterShardKeyPair]struct{}{ ClusterShardKeyPair{ - Client: NewClusterShardKey(s.clientClusterID, 1), - Server: NewClusterShardKey(s.serverClusterID, 1), + Client: NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), 1), + Server: NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), 1), }: {}, ClusterShardKeyPair{ - Client: NewClusterShardKey(s.clientClusterID, 2), - Server: NewClusterShardKey(s.serverClusterID, 1), + Client: NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), 2), + Server: NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), 1), }: {}, ClusterShardKeyPair{ - Client: NewClusterShardKey(s.clientClusterID, 3), - Server: NewClusterShardKey(s.serverClusterID, 1), + Client: NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), 3), + Server: NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), 1), }: {}, ClusterShardKeyPair{ - Client: NewClusterShardKey(s.clientClusterID, 4), - Server: NewClusterShardKey(s.serverClusterID, 1), + Client: NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), 4), + Server: NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), 1), }: {}, }, streamKeys) } -func (s *streamReceiverMonitorSuite) TestDoReconcileStreams_Add() { +func (s *streamReceiverMonitorSuite) TestDoReconcileInboundStreams_Add() { s.clusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestAllClusterInfo).AnyTimes() - clientKey := NewClusterShardKey(s.clientClusterID, rand.Int31()) - serverKey := NewClusterShardKey(s.serverClusterID, rand.Int31()) + clientKey := NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), rand.Int31()) + serverKey := NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), rand.Int31()) s.streamReceiverMonitor.Lock() - s.Equal(0, len(s.streamReceiverMonitor.streams)) + s.Equal(0, len(s.streamReceiverMonitor.inboundStreams)) s.streamReceiverMonitor.Unlock() streamKeys := map[ClusterShardKeyPair]struct{}{ @@ -224,63 +296,165 @@ func (s *streamReceiverMonitorSuite) TestDoReconcileStreams_Add() { Server: serverKey, }: {}, } - s.streamReceiverMonitor.doReconcileStreams(streamKeys) + streamSender := NewMockStreamSender(s.controller) + streamSender.EXPECT().Key().Return(ClusterShardKeyPair{ + Client: clientKey, + Server: 
serverKey, + }).AnyTimes() + streamSender.EXPECT().IsValid().Return(true) + s.streamReceiverMonitor.RegisterInboundStream(streamSender) + s.streamReceiverMonitor.doReconcileInboundStreams(streamKeys) s.streamReceiverMonitor.Lock() defer s.streamReceiverMonitor.Unlock() - s.Equal(1, len(s.streamReceiverMonitor.streams)) - stream, ok := s.streamReceiverMonitor.streams[ClusterShardKeyPair{ + s.Equal(1, len(s.streamReceiverMonitor.inboundStreams)) + stream, ok := s.streamReceiverMonitor.inboundStreams[ClusterShardKeyPair{ Client: clientKey, Server: serverKey, }] s.True(ok) - s.True(stream.IsValid()) + s.Equal(streamSender, stream) +} + +func (s *streamReceiverMonitorSuite) TestDoReconcileInboundStreams_Remove() { + s.clusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestAllClusterInfo).AnyTimes() + + clientKey := NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), rand.Int31()) + serverKey := NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), rand.Int31()) + streamSender := NewMockStreamSender(s.controller) + streamSender.EXPECT().Key().Return(ClusterShardKeyPair{ + Client: clientKey, + Server: serverKey, + }).AnyTimes() + streamSender.EXPECT().IsValid().Return(false) + streamSender.EXPECT().Stop() + s.streamReceiverMonitor.RegisterInboundStream(streamSender) + + s.streamReceiverMonitor.Lock() + s.Equal(1, len(s.streamReceiverMonitor.inboundStreams)) + s.streamReceiverMonitor.Unlock() + + s.streamReceiverMonitor.doReconcileInboundStreams(map[ClusterShardKeyPair]struct{}{}) + + s.streamReceiverMonitor.Lock() + defer s.streamReceiverMonitor.Unlock() + s.Equal(0, len(s.streamReceiverMonitor.inboundStreams)) +} + +func (s *streamReceiverMonitorSuite) TestDoReconcileInboundStreams_Reactivate() { + s.clusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestAllClusterInfo).AnyTimes() + + clientKey := NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), rand.Int31()) + serverKey := NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), rand.Int31()) + streamSenderStale := NewMockStreamSender(s.controller) + streamSenderStale.EXPECT().Key().Return(ClusterShardKeyPair{ + Client: clientKey, + Server: serverKey, + }).AnyTimes() + streamSenderStale.EXPECT().Stop() + s.streamReceiverMonitor.RegisterInboundStream(streamSenderStale) + + s.streamReceiverMonitor.Lock() + s.Equal(1, len(s.streamReceiverMonitor.inboundStreams)) + s.streamReceiverMonitor.Unlock() + + streamSenderValid := NewMockStreamSender(s.controller) + streamSenderValid.EXPECT().Key().Return(ClusterShardKeyPair{ + Client: clientKey, + Server: serverKey, + }).AnyTimes() + s.streamReceiverMonitor.RegisterInboundStream(streamSenderValid) + + s.streamReceiverMonitor.Lock() + defer s.streamReceiverMonitor.Unlock() + s.Equal(1, len(s.streamReceiverMonitor.inboundStreams)) + stream, ok := s.streamReceiverMonitor.inboundStreams[ClusterShardKeyPair{ + Client: clientKey, + Server: serverKey, + }] + s.True(ok) + s.Equal(streamSenderValid, stream) } -func (s *streamReceiverMonitorSuite) TestDoReconcileStreams_Remove() { +func (s *streamReceiverMonitorSuite) TestDoReconcileOutboundStreams_Add() { s.clusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestAllClusterInfo).AnyTimes() - clientKey := NewClusterShardKey(s.clientClusterID, rand.Int31()) - serverKey := NewClusterShardKey(s.serverClusterID, rand.Int31()) - stream := NewStreamReceiver(s.streamReceiverMonitor.ProcessToolBox, clientKey, serverKey) + clientKey := 
NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), rand.Int31()) + serverKey := NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), rand.Int31()) s.streamReceiverMonitor.Lock() - s.Equal(0, len(s.streamReceiverMonitor.streams)) - stream.Start() + s.Equal(0, len(s.streamReceiverMonitor.outboundStreams)) + s.streamReceiverMonitor.Unlock() + + streamKeys := map[ClusterShardKeyPair]struct{}{ + ClusterShardKeyPair{ + Client: clientKey, + Server: serverKey, + }: {}, + } + s.streamReceiverMonitor.doReconcileOutboundStreams(streamKeys) + + s.streamReceiverMonitor.Lock() + defer s.streamReceiverMonitor.Unlock() + s.Equal(1, len(s.streamReceiverMonitor.outboundStreams)) + stream, ok := s.streamReceiverMonitor.outboundStreams[ClusterShardKeyPair{ + Client: clientKey, + Server: serverKey, + }] + s.True(ok) s.True(stream.IsValid()) - s.streamReceiverMonitor.streams[ClusterShardKeyPair{ +} + +func (s *streamReceiverMonitorSuite) TestDoReconcileOutboundStreams_Remove() { + s.clusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestAllClusterInfo).AnyTimes() + + clientKey := NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), rand.Int31()) + serverKey := NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), rand.Int31()) + streamReceiver := NewMockStreamReceiver(s.controller) + streamReceiver.EXPECT().Key().Return(ClusterShardKeyPair{ Client: clientKey, Server: serverKey, - }] = stream + }).AnyTimes() + streamReceiver.EXPECT().IsValid().Return(true) + streamReceiver.EXPECT().Stop() + + s.streamReceiverMonitor.Lock() + s.Equal(0, len(s.streamReceiverMonitor.outboundStreams)) + s.streamReceiverMonitor.outboundStreams[ClusterShardKeyPair{ + Client: clientKey, + Server: serverKey, + }] = streamReceiver s.streamReceiverMonitor.Unlock() - s.streamReceiverMonitor.doReconcileStreams(map[ClusterShardKeyPair]struct{}{}) + s.streamReceiverMonitor.doReconcileOutboundStreams(map[ClusterShardKeyPair]struct{}{}) s.streamReceiverMonitor.Lock() defer s.streamReceiverMonitor.Unlock() - s.Equal(0, len(s.streamReceiverMonitor.streams)) - s.False(stream.IsValid()) + s.Equal(0, len(s.streamReceiverMonitor.outboundStreams)) } -func (s *streamReceiverMonitorSuite) TestDoReconcileStreams_Reactivate() { +func (s *streamReceiverMonitorSuite) TestDoReconcileOutboundStreams_Reactivate() { s.clusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestAllClusterInfo).AnyTimes() - clientKey := NewClusterShardKey(s.clientClusterID, rand.Int31()) - serverKey := NewClusterShardKey(s.serverClusterID, rand.Int31()) - stream := NewStreamReceiver(s.streamReceiverMonitor.ProcessToolBox, clientKey, serverKey) + clientKey := NewClusterShardKey(int32(cluster.TestCurrentClusterInitialFailoverVersion), rand.Int31()) + serverKey := NewClusterShardKey(int32(cluster.TestAlternativeClusterInitialFailoverVersion), rand.Int31()) + streamReceiverStale := NewMockStreamReceiver(s.controller) + streamReceiverStale.EXPECT().Key().Return(ClusterShardKeyPair{ + Client: clientKey, + Server: serverKey, + }).AnyTimes() + streamReceiverStale.EXPECT().IsValid().Return(false) + streamReceiverStale.EXPECT().Stop() s.streamReceiverMonitor.Lock() - s.Equal(0, len(s.streamReceiverMonitor.streams)) - stream.Start() - stream.Stop() - s.False(stream.IsValid()) - s.streamReceiverMonitor.streams[ClusterShardKeyPair{ + s.Equal(0, len(s.streamReceiverMonitor.outboundStreams)) + s.streamReceiverMonitor.outboundStreams[ClusterShardKeyPair{ Client: clientKey, Server: 
serverKey, - }] = stream + }] = streamReceiverStale s.streamReceiverMonitor.Unlock() - s.streamReceiverMonitor.doReconcileStreams(map[ClusterShardKeyPair]struct{}{ + s.streamReceiverMonitor.doReconcileOutboundStreams(map[ClusterShardKeyPair]struct{}{ ClusterShardKeyPair{ Client: clientKey, Server: serverKey, @@ -289,8 +463,8 @@ func (s *streamReceiverMonitorSuite) TestDoReconcileStreams_Reactivate() { s.streamReceiverMonitor.Lock() defer s.streamReceiverMonitor.Unlock() - s.Equal(1, len(s.streamReceiverMonitor.streams)) - stream, ok := s.streamReceiverMonitor.streams[ClusterShardKeyPair{ + s.Equal(1, len(s.streamReceiverMonitor.outboundStreams)) + stream, ok := s.streamReceiverMonitor.outboundStreams[ClusterShardKeyPair{ Client: clientKey, Server: serverKey, }] diff --git a/service/history/replication/stream_receiver_test.go b/service/history/replication/stream_receiver_test.go index 1cb5f3d65f1..de69fb91792 100644 --- a/service/history/replication/stream_receiver_test.go +++ b/service/history/replication/stream_receiver_test.go @@ -55,7 +55,7 @@ type ( stream *mockStream taskScheduler *mockScheduler - streamReceiver *StreamReceiver + streamReceiver *StreamReceiverImpl } mockStream struct { diff --git a/service/history/replication/stream_sender.go b/service/history/replication/stream_sender.go new file mode 100644 index 00000000000..4e2cceb56e7 --- /dev/null +++ b/service/history/replication/stream_sender.go @@ -0,0 +1,422 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
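+// StreamSenderImpl serves one inbound replication stream for a single client/server shard
+// pair: recvEventLoop consumes SyncReplicationState acks from the receiving cluster, while
+// sendEventLoop first catches up from the last acknowledged watermark and then streams live
+// tasks as new-task notifications arrive; either loop exiting stops the sender. The call
+// pattern below mirrors handler.StreamWorkflowReplicationMessages and is only a sketch with
+// illustrative variable names:
+//
+//	sender := NewStreamSender(server, shardContext, engine, taskConvertor, clientShardKey, serverShardKey)
+//	monitor.RegisterInboundStream(sender) // replaces any stale sender with the same key
+//	sender.Start()
+//	defer sender.Stop()
+//	sender.Wait() // blocks until either event loop exits and the stream shuts down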
+ +//go:generate mockgen -copyright_file ../../../LICENSE -package $GOPACKAGE -source $GOFILE -destination stream_sender_mock.go + +package replication + +import ( + "context" + "fmt" + "math" + "sync/atomic" + + "go.temporal.io/api/serviceerror" + + enumsspb "go.temporal.io/server/api/enums/v1" + "go.temporal.io/server/api/historyservice/v1" + persistencespb "go.temporal.io/server/api/persistence/v1" + replicationspb "go.temporal.io/server/api/replication/v1" + "go.temporal.io/server/common" + "go.temporal.io/server/common/channel" + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/log/tag" + "go.temporal.io/server/common/metrics" + "go.temporal.io/server/common/namespace" + "go.temporal.io/server/common/primitives/timestamp" + "go.temporal.io/server/service/history/shard" + "go.temporal.io/server/service/history/tasks" +) + +type ( + SourceTaskConvertorImpl struct { + historyEngine shard.Engine + namespaceCache namespace.Registry + clientClusterShardCount int32 + clientClusterName string + clientShardKey ClusterShardKey + } + SourceTaskConvertor interface { + Convert(task tasks.Task) (*replicationspb.ReplicationTask, error) + } + StreamSender interface { + common.Daemon + IsValid() bool + Key() ClusterShardKeyPair + } + StreamSenderImpl struct { + server historyservice.HistoryService_StreamWorkflowReplicationMessagesServer + shardContext shard.Context + historyEngine shard.Engine + taskConvertor SourceTaskConvertor + metrics metrics.Handler + logger log.Logger + + status int32 + clientShardKey ClusterShardKey + serverShardKey ClusterShardKey + shutdownChan channel.ShutdownOnce + } +) + +func NewStreamSender( + server historyservice.HistoryService_StreamWorkflowReplicationMessagesServer, + shardContext shard.Context, + historyEngine shard.Engine, + taskConvertor SourceTaskConvertor, + clientShardKey ClusterShardKey, + serverShardKey ClusterShardKey, +) *StreamSenderImpl { + return &StreamSenderImpl{ + server: server, + shardContext: shardContext, + historyEngine: historyEngine, + taskConvertor: taskConvertor, + metrics: shardContext.GetMetricsHandler(), + logger: shardContext.GetLogger(), + + status: common.DaemonStatusInitialized, + clientShardKey: clientShardKey, + serverShardKey: serverShardKey, + shutdownChan: channel.NewShutdownOnce(), + } +} + +func (s *StreamSenderImpl) Start() { + if !atomic.CompareAndSwapInt32( + &s.status, + common.DaemonStatusInitialized, + common.DaemonStatusStarted, + ) { + return + } + + go func() { _ = s.sendEventLoop() }() + go func() { _ = s.recvEventLoop() }() + + s.logger.Info("StreamSender started.") +} + +func (s *StreamSenderImpl) Stop() { + if !atomic.CompareAndSwapInt32( + &s.status, + common.DaemonStatusStarted, + common.DaemonStatusStopped, + ) { + return + } + + s.shutdownChan.Shutdown() + s.logger.Info("StreamSender stopped.") +} + +func (s *StreamSenderImpl) IsValid() bool { + return atomic.LoadInt32(&s.status) == common.DaemonStatusStarted +} + +func (s *StreamSenderImpl) Wait() { + <-s.shutdownChan.Channel() +} + +func (s *StreamSenderImpl) Key() ClusterShardKeyPair { + return ClusterShardKeyPair{ + Client: s.clientShardKey, + Server: s.serverShardKey, + } +} + +func (s *StreamSenderImpl) recvEventLoop() error { + defer s.Stop() + + for !s.shutdownChan.IsShutdown() { + req, err := s.server.Recv() + if err != nil { + s.logger.Error("StreamSender exit recv loop", tag.Error(err)) + return err + } + switch attr := req.GetAttributes().(type) { + case *historyservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState: 
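+			// A SyncReplicationState message is the receiver's periodic ack: its inclusive low
+			// watermark is recorded as the replication queue reader state for this client shard
+			// and used to refresh the remote cluster info on the shard, then the message is
+			// counted in metrics.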
+ if err := s.recvSyncReplicationState(attr.SyncReplicationState); err != nil { + s.logger.Error("StreamSender unable to handle SyncReplicationState", tag.Error(err)) + return err + } + s.metrics.Counter(metrics.ReplicationTasksRecv.GetMetricName()).Record( + int64(1), + metrics.FromClusterIDTag(s.clientShardKey.ClusterID), + metrics.ToClusterIDTag(s.serverShardKey.ClusterID), + metrics.OperationTag(metrics.SyncWatermarkScope), + ) + default: + err := serviceerror.NewInternal(fmt.Sprintf( + "StreamReplicationMessages encountered unknown type: %T %v", attr, attr, + )) + s.logger.Error("StreamSender unable to handle request", tag.Error(err)) + return err + } + } + return nil +} + +func (s *StreamSenderImpl) sendEventLoop() error { + defer s.Stop() + + newTaskNotificationChan, subscriberID := s.historyEngine.SubscribeReplicationNotification() + defer s.historyEngine.UnsubscribeReplicationNotification(subscriberID) + + catchupEndExclusiveWatermark, err := s.sendCatchUp() + if err != nil { + s.logger.Error("StreamSender unable to catch up replication tasks", tag.Error(err)) + return err + } + if err := s.sendLive( + newTaskNotificationChan, + catchupEndExclusiveWatermark, + ); err != nil { + s.logger.Error("StreamSender unable to stream replication tasks", tag.Error(err)) + return err + } + return nil +} + +func (s *StreamSenderImpl) recvSyncReplicationState( + attr *replicationspb.SyncReplicationState, +) error { + inclusiveLowWatermark := attr.GetInclusiveLowWatermark() + inclusiveLowWatermarkTime := attr.GetInclusiveLowWatermarkTime() + + readerID := shard.ReplicationReaderIDFromClusterShardID( + int64(s.clientShardKey.ClusterID), + s.clientShardKey.ShardID, + ) + readerState := &persistencespb.QueueReaderState{ + Scopes: []*persistencespb.QueueSliceScope{{ + Range: &persistencespb.QueueSliceRange{ + InclusiveMin: shard.ConvertToPersistenceTaskKey( + tasks.NewImmediateKey(inclusiveLowWatermark), + ), + ExclusiveMax: shard.ConvertToPersistenceTaskKey( + tasks.NewImmediateKey(math.MaxInt64), + ), + }, + Predicate: &persistencespb.Predicate{ + PredicateType: enumsspb.PREDICATE_TYPE_UNIVERSAL, + Attributes: &persistencespb.Predicate_UniversalPredicateAttributes{}, + }, + }}, + } + if err := s.shardContext.UpdateReplicationQueueReaderState( + readerID, + readerState, + ); err != nil { + return err + } + s.shardContext.UpdateRemoteClusterInfo( + string(s.clientShardKey.ClusterID), + inclusiveLowWatermark-1, + *inclusiveLowWatermarkTime, + ) + return nil +} + +func (s *StreamSenderImpl) sendCatchUp() (int64, error) { + readerID := shard.ReplicationReaderIDFromClusterShardID( + int64(s.clientShardKey.ClusterID), + s.clientShardKey.ShardID, + ) + + var catchupBeginInclusiveWatermark int64 + queueState, ok := s.shardContext.GetQueueState( + tasks.CategoryReplication, + ) + if !ok { + catchupBeginInclusiveWatermark = 0 + } else { + readerState, ok := queueState.ReaderStates[readerID] + if !ok { + catchupBeginInclusiveWatermark = 0 + } else { + catchupBeginInclusiveWatermark = readerState.Scopes[0].Range.InclusiveMin.TaskId + } + } + catchupEndExclusiveWatermark := s.shardContext.GetImmediateQueueExclusiveHighReadWatermark().TaskID + if err := s.sendTasks( + catchupBeginInclusiveWatermark, + catchupEndExclusiveWatermark, + ); err != nil { + return 0, err + } + return catchupEndExclusiveWatermark, nil +} + +func (s *StreamSenderImpl) sendLive( + newTaskNotificationChan <-chan struct{}, + beginInclusiveWatermark int64, +) error { + for { + select { + case <-newTaskNotificationChan: + endExclusiveWatermark 
:= s.shardContext.GetImmediateQueueExclusiveHighReadWatermark().TaskID + if err := s.sendTasks( + beginInclusiveWatermark, + endExclusiveWatermark, + ); err != nil { + return err + } + beginInclusiveWatermark = endExclusiveWatermark + case <-s.shutdownChan.Channel(): + return nil + } + } +} + +func (s *StreamSenderImpl) sendTasks( + beginInclusiveWatermark int64, + endExclusiveWatermark int64, +) error { + if beginInclusiveWatermark > endExclusiveWatermark { + err := serviceerror.NewInternal(fmt.Sprintf("StreamWorkflowReplication encountered invalid task range [%v, %v)", + beginInclusiveWatermark, + endExclusiveWatermark, + )) + s.logger.Error("StreamSender unable to send tasks", tag.Error(err)) + return err + } + if beginInclusiveWatermark == endExclusiveWatermark { + return s.server.Send(&historyservice.StreamWorkflowReplicationMessagesResponse{ + Attributes: &historyservice.StreamWorkflowReplicationMessagesResponse_Messages{ + Messages: &replicationspb.WorkflowReplicationMessages{ + ReplicationTasks: nil, + ExclusiveHighWatermark: endExclusiveWatermark, + ExclusiveHighWatermarkTime: timestamp.TimeNowPtrUtc(), + }, + }, + }) + } + + ctx, cancel := context.WithTimeout(context.Background(), replicationTimeout) + defer cancel() + iter, err := s.historyEngine.GetReplicationTasksIter( + ctx, + string(s.clientShardKey.ClusterID), + beginInclusiveWatermark, + endExclusiveWatermark, + ) + if err != nil { + return err + } +Loop: + for iter.HasNext() { + if s.shutdownChan.IsShutdown() { + return nil + } + + item, err := iter.Next() + if err != nil { + return err + } + task, err := s.taskConvertor.Convert(item) + if err != nil { + return err + } + if task == nil { + continue Loop + } + if err := s.server.Send(&historyservice.StreamWorkflowReplicationMessagesResponse{ + Attributes: &historyservice.StreamWorkflowReplicationMessagesResponse_Messages{ + Messages: &replicationspb.WorkflowReplicationMessages{ + ReplicationTasks: []*replicationspb.ReplicationTask{task}, + ExclusiveHighWatermark: task.SourceTaskId + 1, + ExclusiveHighWatermarkTime: task.VisibilityTime, + }, + }, + }); err != nil { + return err + } + s.metrics.Counter(metrics.ReplicationTasksSend.GetMetricName()).Record( + int64(1), + metrics.FromClusterIDTag(s.serverShardKey.ClusterID), + metrics.ToClusterIDTag(s.clientShardKey.ClusterID), + metrics.OperationTag(TaskOperationTag(task)), + ) + } + return s.server.Send(&historyservice.StreamWorkflowReplicationMessagesResponse{ + Attributes: &historyservice.StreamWorkflowReplicationMessagesResponse_Messages{ + Messages: &replicationspb.WorkflowReplicationMessages{ + ReplicationTasks: nil, + ExclusiveHighWatermark: endExclusiveWatermark, + ExclusiveHighWatermarkTime: timestamp.TimeNowPtrUtc(), + }, + }, + }) +} + +func NewSourceTaskConvertor( + historyEngine shard.Engine, + namespaceCache namespace.Registry, + clientClusterShardCount int32, + clientClusterName string, + clientShardKey ClusterShardKey, +) *SourceTaskConvertorImpl { + return &SourceTaskConvertorImpl{ + historyEngine: historyEngine, + namespaceCache: namespaceCache, + clientClusterShardCount: clientClusterShardCount, + clientClusterName: clientClusterName, + clientShardKey: clientShardKey, + } +} + +func (c *SourceTaskConvertorImpl) Convert( + task tasks.Task, +) (*replicationspb.ReplicationTask, error) { + if namespaceEntry, err := c.namespaceCache.GetNamespaceByID( + namespace.ID(task.GetNamespaceID()), + ); err == nil { + shouldProcessTask := false + FilterLoop: + for _, targetCluster := range namespaceEntry.ClusterNames() { + 
if c.clientClusterName == targetCluster { + shouldProcessTask = true + break FilterLoop + } + } + if !shouldProcessTask { + return nil, nil + } + } + // if there is error, then blindly send the task, better safe than sorry + + clientShardID := common.WorkflowIDToHistoryShard(task.GetNamespaceID(), task.GetWorkflowID(), c.clientClusterShardCount) + if clientShardID != c.clientShardKey.ShardID { + return nil, nil + } + + ctx, cancel := context.WithTimeout(context.Background(), replicationTimeout) + defer cancel() + replicationTask, err := c.historyEngine.ConvertReplicationTask(ctx, task) + if err != nil { + return nil, err + } + return replicationTask, nil +} diff --git a/service/history/replication/stream_sender_mock.go b/service/history/replication/stream_sender_mock.go new file mode 100644 index 00000000000..0599dccbeab --- /dev/null +++ b/service/history/replication/stream_sender_mock.go @@ -0,0 +1,150 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +// Code generated by MockGen. DO NOT EDIT. +// Source: stream_sender.go + +// Package replication is a generated GoMock package. +package replication + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + v1 "go.temporal.io/server/api/replication/v1" + tasks "go.temporal.io/server/service/history/tasks" +) + +// MockSourceTaskConvertor is a mock of SourceTaskConvertor interface. +type MockSourceTaskConvertor struct { + ctrl *gomock.Controller + recorder *MockSourceTaskConvertorMockRecorder +} + +// MockSourceTaskConvertorMockRecorder is the mock recorder for MockSourceTaskConvertor. +type MockSourceTaskConvertorMockRecorder struct { + mock *MockSourceTaskConvertor +} + +// NewMockSourceTaskConvertor creates a new mock instance. +func NewMockSourceTaskConvertor(ctrl *gomock.Controller) *MockSourceTaskConvertor { + mock := &MockSourceTaskConvertor{ctrl: ctrl} + mock.recorder = &MockSourceTaskConvertorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSourceTaskConvertor) EXPECT() *MockSourceTaskConvertorMockRecorder { + return m.recorder +} + +// Convert mocks base method. 
+func (m *MockSourceTaskConvertor) Convert(task tasks.Task) (*v1.ReplicationTask, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Convert", task) + ret0, _ := ret[0].(*v1.ReplicationTask) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Convert indicates an expected call of Convert. +func (mr *MockSourceTaskConvertorMockRecorder) Convert(task interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Convert", reflect.TypeOf((*MockSourceTaskConvertor)(nil).Convert), task) +} + +// MockStreamSender is a mock of StreamSender interface. +type MockStreamSender struct { + ctrl *gomock.Controller + recorder *MockStreamSenderMockRecorder +} + +// MockStreamSenderMockRecorder is the mock recorder for MockStreamSender. +type MockStreamSenderMockRecorder struct { + mock *MockStreamSender +} + +// NewMockStreamSender creates a new mock instance. +func NewMockStreamSender(ctrl *gomock.Controller) *MockStreamSender { + mock := &MockStreamSender{ctrl: ctrl} + mock.recorder = &MockStreamSenderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStreamSender) EXPECT() *MockStreamSenderMockRecorder { + return m.recorder +} + +// IsValid mocks base method. +func (m *MockStreamSender) IsValid() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsValid") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsValid indicates an expected call of IsValid. +func (mr *MockStreamSenderMockRecorder) IsValid() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsValid", reflect.TypeOf((*MockStreamSender)(nil).IsValid)) +} + +// Key mocks base method. +func (m *MockStreamSender) Key() ClusterShardKeyPair { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Key") + ret0, _ := ret[0].(ClusterShardKeyPair) + return ret0 +} + +// Key indicates an expected call of Key. +func (mr *MockStreamSenderMockRecorder) Key() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Key", reflect.TypeOf((*MockStreamSender)(nil).Key)) +} + +// Start mocks base method. +func (m *MockStreamSender) Start() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Start") +} + +// Start indicates an expected call of Start. +func (mr *MockStreamSenderMockRecorder) Start() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockStreamSender)(nil).Start)) +} + +// Stop mocks base method. +func (m *MockStreamSender) Stop() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Stop") +} + +// Stop indicates an expected call of Stop. 
+func (mr *MockStreamSenderMockRecorder) Stop() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockStreamSender)(nil).Stop)) +} diff --git a/service/history/api/replication/stream_test.go b/service/history/replication/stream_sender_test.go similarity index 82% rename from service/history/api/replication/stream_test.go rename to service/history/replication/stream_sender_test.go index c3f64a0ec0a..86c6b7529be 100644 --- a/service/history/api/replication/stream_test.go +++ b/service/history/replication/stream_sender_test.go @@ -25,7 +25,6 @@ package replication import ( - "context" "math" "math/rand" "testing" @@ -40,7 +39,6 @@ import ( "go.temporal.io/server/api/historyservicemock/v1" persistencespb "go.temporal.io/server/api/persistence/v1" replicationspb "go.temporal.io/server/api/replication/v1" - historyclient "go.temporal.io/server/client/history" "go.temporal.io/server/common/collection" "go.temporal.io/server/common/log" "go.temporal.io/server/common/metrics" @@ -52,7 +50,7 @@ import ( ) type ( - streamSuite struct { + streamSenderSuite struct { suite.Suite *require.Assertions @@ -60,59 +58,61 @@ type ( server *historyservicemock.MockHistoryService_StreamWorkflowReplicationMessagesServer shardContext *shard.MockContext historyEngine *shard.MockEngine - taskConvertor *MockTaskConvertor + taskConvertor *MockSourceTaskConvertor - ctx context.Context - cancel context.CancelFunc - clientClusterShardID historyclient.ClusterShardID - serverClusterShardID historyclient.ClusterShardID + clientShardKey ClusterShardKey + serverShardKey ClusterShardKey + + streamSender *StreamSenderImpl } ) -func TestStreamSuite(t *testing.T) { - s := new(streamSuite) +func TestStreamSenderSuite(t *testing.T) { + s := new(streamSenderSuite) suite.Run(t, s) } -func (s *streamSuite) SetupSuite() { +func (s *streamSenderSuite) SetupSuite() { rand.Seed(time.Now().UnixNano()) } -func (s *streamSuite) TearDownSuite() { +func (s *streamSenderSuite) TearDownSuite() { } -func (s *streamSuite) SetupTest() { +func (s *streamSenderSuite) SetupTest() { s.Assertions = require.New(s.T()) s.controller = gomock.NewController(s.T()) s.server = historyservicemock.NewMockHistoryService_StreamWorkflowReplicationMessagesServer(s.controller) s.shardContext = shard.NewMockContext(s.controller) s.historyEngine = shard.NewMockEngine(s.controller) - s.taskConvertor = NewMockTaskConvertor(s.controller) + s.taskConvertor = NewMockSourceTaskConvertor(s.controller) - s.ctx, s.cancel = context.WithCancel(context.Background()) - s.clientClusterShardID = historyclient.ClusterShardID{ - ClusterID: rand.Int31(), - ShardID: rand.Int31(), - } - s.serverClusterShardID = historyclient.ClusterShardID{ - ClusterID: rand.Int31(), - ShardID: rand.Int31(), - } + s.clientShardKey = NewClusterShardKey(rand.Int31(), rand.Int31()) + s.serverShardKey = NewClusterShardKey(rand.Int31(), rand.Int31()) s.shardContext.EXPECT().GetEngine(gomock.Any()).Return(s.historyEngine, nil).AnyTimes() s.shardContext.EXPECT().GetMetricsHandler().Return(metrics.NoopMetricsHandler).AnyTimes() s.shardContext.EXPECT().GetLogger().Return(log.NewNoopLogger()).AnyTimes() + + s.streamSender = NewStreamSender( + s.server, + s.shardContext, + s.historyEngine, + s.taskConvertor, + s.clientShardKey, + s.serverShardKey, + ) } -func (s *streamSuite) TearDownTest() { +func (s *streamSenderSuite) TearDownTest() { s.controller.Finish() } -func (s *streamSuite) TestRecvSyncReplicationState_Success() { +func (s 
*streamSenderSuite) TestRecvSyncReplicationState_Success() { readerID := shard.ReplicationReaderIDFromClusterShardID( - int64(s.clientClusterShardID.ClusterID), - s.clientClusterShardID.ShardID, + int64(s.clientShardKey.ClusterID), + s.clientShardKey.ShardID, ) replicationState := &replicationspb.SyncReplicationState{ InclusiveLowWatermark: rand.Int63(), @@ -139,19 +139,19 @@ func (s *streamSuite) TestRecvSyncReplicationState_Success() { }, ).Return(nil) s.shardContext.EXPECT().UpdateRemoteClusterInfo( - string(s.clientClusterShardID.ClusterID), + string(s.clientShardKey.ClusterID), replicationState.InclusiveLowWatermark-1, *replicationState.InclusiveLowWatermarkTime, ) - err := recvSyncReplicationState(s.shardContext, replicationState, s.clientClusterShardID) + err := s.streamSender.recvSyncReplicationState(replicationState) s.NoError(err) } -func (s *streamSuite) TestRecvSyncReplicationState_Error() { +func (s *streamSenderSuite) TestRecvSyncReplicationState_Error() { readerID := shard.ReplicationReaderIDFromClusterShardID( - int64(s.clientClusterShardID.ClusterID), - s.clientClusterShardID.ShardID, + int64(s.clientShardKey.ClusterID), + s.clientShardKey.ShardID, ) replicationState := &replicationspb.SyncReplicationState{ InclusiveLowWatermark: rand.Int63(), @@ -185,15 +185,15 @@ func (s *streamSuite) TestRecvSyncReplicationState_Error() { }, ).Return(ownershipLost) - err := recvSyncReplicationState(s.shardContext, replicationState, s.clientClusterShardID) + err := s.streamSender.recvSyncReplicationState(replicationState) s.Error(err) s.Equal(ownershipLost, err) } -func (s *streamSuite) TestSendCatchUp() { +func (s *streamSenderSuite) TestSendCatchUp() { readerID := shard.ReplicationReaderIDFromClusterShardID( - int64(s.clientClusterShardID.ClusterID), - s.clientClusterShardID.ShardID, + int64(s.clientShardKey.ClusterID), + s.clientShardKey.ShardID, ) beginInclusiveWatermark := rand.Int63() endExclusiveWatermark := beginInclusiveWatermark + 1 @@ -230,8 +230,8 @@ func (s *streamSuite) TestSendCatchUp() { }, ) s.historyEngine.EXPECT().GetReplicationTasksIter( - s.ctx, - string(s.clientClusterShardID.ClusterID), + gomock.Any(), + string(s.clientShardKey.ClusterID), beginInclusiveWatermark, endExclusiveWatermark, ).Return(iter, nil) @@ -241,19 +241,12 @@ func (s *streamSuite) TestSendCatchUp() { return nil }) - taskID, err := sendCatchUp( - s.ctx, - s.server, - s.shardContext, - s.taskConvertor, - s.clientClusterShardID, - s.serverClusterShardID, - ) + taskID, err := s.streamSender.sendCatchUp() s.NoError(err) s.Equal(endExclusiveWatermark, taskID) } -func (s *streamSuite) TestSendLive() { +func (s *streamSenderSuite) TestSendLive() { channel := make(chan struct{}) watermark0 := rand.Int63() watermark1 := watermark0 + 1 + rand.Int63n(100) @@ -274,14 +267,14 @@ func (s *streamSuite) TestSendLive() { ) gomock.InOrder( s.historyEngine.EXPECT().GetReplicationTasksIter( - s.ctx, - string(s.clientClusterShardID.ClusterID), + gomock.Any(), + string(s.clientShardKey.ClusterID), watermark0, watermark1, ).Return(iter, nil), s.historyEngine.EXPECT().GetReplicationTasksIter( - s.ctx, - string(s.clientClusterShardID.ClusterID), + gomock.Any(), + string(s.clientShardKey.ClusterID), watermark1, watermark2, ).Return(iter, nil), @@ -301,22 +294,17 @@ func (s *streamSuite) TestSendLive() { go func() { channel <- struct{}{} channel <- struct{}{} - s.cancel() + s.streamSender.shutdownChan.Shutdown() }() - err := sendLive( - s.ctx, - s.server, - s.shardContext, - s.taskConvertor, - s.clientClusterShardID, - 
s.serverClusterShardID, + err := s.streamSender.sendLive( channel, watermark0, ) - s.Equal(s.ctx.Err(), err) + s.Nil(err) + s.True(!s.streamSender.IsValid()) } -func (s *streamSuite) TestSendTasks_Noop() { +func (s *streamSenderSuite) TestSendTasks_Noop() { beginInclusiveWatermark := rand.Int63() endExclusiveWatermark := beginInclusiveWatermark @@ -326,20 +314,14 @@ func (s *streamSuite) TestSendTasks_Noop() { return nil }) - err := sendTasks( - s.ctx, - s.server, - s.shardContext, - s.taskConvertor, - s.clientClusterShardID, - s.serverClusterShardID, + err := s.streamSender.sendTasks( beginInclusiveWatermark, endExclusiveWatermark, ) s.NoError(err) } -func (s *streamSuite) TestSendTasks_WithoutTasks() { +func (s *streamSenderSuite) TestSendTasks_WithoutTasks() { beginInclusiveWatermark := rand.Int63() endExclusiveWatermark := beginInclusiveWatermark + 100 @@ -349,8 +331,8 @@ func (s *streamSuite) TestSendTasks_WithoutTasks() { }, ) s.historyEngine.EXPECT().GetReplicationTasksIter( - s.ctx, - string(s.clientClusterShardID.ClusterID), + gomock.Any(), + string(s.clientShardKey.ClusterID), beginInclusiveWatermark, endExclusiveWatermark, ).Return(iter, nil) @@ -360,20 +342,14 @@ func (s *streamSuite) TestSendTasks_WithoutTasks() { return nil }) - err := sendTasks( - s.ctx, - s.server, - s.shardContext, - s.taskConvertor, - s.clientClusterShardID, - s.serverClusterShardID, + err := s.streamSender.sendTasks( beginInclusiveWatermark, endExclusiveWatermark, ) s.NoError(err) } -func (s *streamSuite) TestSendTasks_WithTasks() { +func (s *streamSenderSuite) TestSendTasks_WithTasks() { beginInclusiveWatermark := rand.Int63() endExclusiveWatermark := beginInclusiveWatermark + 100 item0 := tasks.NewMockTask(s.controller) @@ -394,8 +370,8 @@ func (s *streamSuite) TestSendTasks_WithTasks() { }, ) s.historyEngine.EXPECT().GetReplicationTasksIter( - s.ctx, - string(s.clientClusterShardID.ClusterID), + gomock.Any(), + string(s.clientShardKey.ClusterID), beginInclusiveWatermark, endExclusiveWatermark, ).Return(iter, nil) @@ -428,13 +404,7 @@ func (s *streamSuite) TestSendTasks_WithTasks() { }), ) - err := sendTasks( - s.ctx, - s.server, - s.shardContext, - s.taskConvertor, - s.clientClusterShardID, - s.serverClusterShardID, + err := s.streamSender.sendTasks( beginInclusiveWatermark, endExclusiveWatermark, ) From dd4c7de946a907a5ce824ef559f15648b0ad65f5 Mon Sep 17 00:00:00 2001 From: Alex Shtin Date: Thu, 8 Jun 2023 12:12:03 -0700 Subject: [PATCH 7/9] Enable workflow update is server development configs (#4461) --- config/dynamicconfig/development-cass.yaml | 4 ++++ config/dynamicconfig/development-sql.yaml | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/config/dynamicconfig/development-cass.yaml b/config/dynamicconfig/development-cass.yaml index fb49ebc3edb..b560b529e09 100644 --- a/config/dynamicconfig/development-cass.yaml +++ b/config/dynamicconfig/development-cass.yaml @@ -32,3 +32,7 @@ # - value: true system.enableEagerWorkflowStart: - value: true +frontend.enableUpdateWorkflowExecution: + - value: true +frontend.enableUpdateWorkflowExecutionAsyncAccepted: + - value: true diff --git a/config/dynamicconfig/development-sql.yaml b/config/dynamicconfig/development-sql.yaml index 23ce85333d6..0d6d3902bed 100644 --- a/config/dynamicconfig/development-sql.yaml +++ b/config/dynamicconfig/development-sql.yaml @@ -35,3 +35,7 @@ system.enableEagerWorkflowStart: limit.maxIDLength: - value: 255 constraints: {} +frontend.enableUpdateWorkflowExecution: + - value: true 
+frontend.enableUpdateWorkflowExecutionAsyncAccepted: + - value: true From 137886a7bdfb2c290abfbe9dbf9d57c3a6de3c60 Mon Sep 17 00:00:00 2001 From: Spencer Judge Date: Thu, 8 Jun 2023 13:31:07 -0700 Subject: [PATCH 8/9] Implement merge sets operation for Build IDs (#4447) --- go.mod | 2 +- go.sum | 4 +- service/matching/version_sets.go | 38 ++++++++- service/matching/version_sets_test.go | 106 ++++++++++++++++++++++++++ 4 files changed, 146 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index 1f039ee58ee..713b813d132 100644 --- a/go.mod +++ b/go.mod @@ -44,7 +44,7 @@ require ( go.opentelemetry.io/otel/metric v1.16.0 go.opentelemetry.io/otel/sdk v1.16.0 go.opentelemetry.io/otel/sdk/metric v0.39.0 - go.temporal.io/api v1.22.0 + go.temporal.io/api v1.23.0 go.temporal.io/sdk v1.23.0 go.temporal.io/version v0.3.0 go.uber.org/atomic v1.10.0 diff --git a/go.sum b/go.sum index 89578de4ee7..1b77681dd70 100644 --- a/go.sum +++ b/go.sum @@ -1125,8 +1125,8 @@ go.opentelemetry.io/proto/otlp v0.15.0/go.mod h1:H7XAot3MsfNsj7EXtrA2q5xSNQ10UqI go.opentelemetry.io/proto/otlp v0.19.0 h1:IVN6GR+mhC4s5yfcTbmzHYODqvWAp3ZedA2SJPI1Nnw= go.opentelemetry.io/proto/otlp v0.19.0/go.mod h1:H7XAot3MsfNsj7EXtrA2q5xSNQ10UqI405h3+duxN4U= go.temporal.io/api v1.21.0/go.mod h1:xlsUEakkN2vU2/WV7e5NqMG4N93nfuNfvbXdaXUpU8w= -go.temporal.io/api v1.22.0 h1:XshAWMbKbyJws+brrNe2UCcI501iiIAJAd0uX40yw7s= -go.temporal.io/api v1.22.0/go.mod h1:AcJd1+rc1j0zte+ZBIkOHGHjntR/17LnZWFz+gMFHQ0= +go.temporal.io/api v1.23.0 h1:4y9mTQjEHsE0Du0WJ2ExJUcP/1/a+B/UefzIDm4ALTE= +go.temporal.io/api v1.23.0/go.mod h1:AcJd1+rc1j0zte+ZBIkOHGHjntR/17LnZWFz+gMFHQ0= go.temporal.io/sdk v1.23.0 h1:oa9/1f3bbcBLiNGbYf9woIx7uWFJ153q0JOkPeZqJtQ= go.temporal.io/sdk v1.23.0/go.mod h1:S7vWxU01lGcCny0sWx03bkkYw4VtVrpzeqBTn2A6y+E= go.temporal.io/version v0.3.0 h1:dMrei9l9NyHt8nG6EB8vAwDLLTwx2SvRyucCSumAiig= diff --git a/service/matching/version_sets.go b/service/matching/version_sets.go index a8436d5d340..2f3c817cb27 100644 --- a/service/matching/version_sets.go +++ b/service/matching/version_sets.go @@ -253,6 +253,37 @@ func updateImpl(timestamp hlc.Clock, existingData *persistencespb.VersioningData BuildIds: buildIDsCopy, } makeVersionInSetDefault(&modifiedData, targetSetIdx, versionInSetIdx, ×tamp) + } else if mergeSets := req.GetMergeSets(); mergeSets != nil { + if targetSetIdx == -1 { + return nil, serviceerror.NewNotFound(fmt.Sprintf("targeted primary version %v not found", targetedVersion)) + } + secondaryBuildID := mergeSets.GetSecondarySetBuildId() + secondarySetIdx, _ := findVersion(&modifiedData, secondaryBuildID) + if secondarySetIdx == -1 { + return nil, serviceerror.NewNotFound(fmt.Sprintf("targeted secondary version %v not found", secondaryBuildID)) + } + if targetSetIdx == secondarySetIdx { + // Nothing to be done + return existingData, nil + } + // Merge the sets together, preserving the primary set's default by making it have the most recent timestamp. 
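+		// Concretely: wrap the primary set in a temporary VersioningData stamped with the current
+		// HLC timestamp, give the secondary set the union of both sets' set ids, and let
+		// MergeVersioningData fold the two structures back into a single set.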
+ primarySet := modifiedData.VersionSets[targetSetIdx] + justPrimaryData := &persistencespb.VersioningData{ + VersionSets: []*persistencespb.CompatibleVersionSet{{ + SetIds: primarySet.SetIds, + BuildIds: primarySet.BuildIds, + DefaultUpdateTimestamp: ×tamp, + }}, + DefaultUpdateTimestamp: modifiedData.DefaultUpdateTimestamp, + } + secondarySet := modifiedData.VersionSets[secondarySetIdx] + modifiedData.VersionSets[secondarySetIdx] = &persistencespb.CompatibleVersionSet{ + SetIds: mergeSetIDs(primarySet.SetIds, secondarySet.SetIds), + BuildIds: secondarySet.BuildIds, + DefaultUpdateTimestamp: secondarySet.DefaultUpdateTimestamp, + } + mergedData := MergeVersioningData(justPrimaryData, &modifiedData) + modifiedData = *mergedData } return &modifiedData, nil @@ -265,13 +296,18 @@ func extractTargetedVersion(req *workflowservice.UpdateWorkerBuildIdCompatibilit return req.GetPromoteSetByBuildId() } else if req.GetPromoteBuildIdWithinSet() != "" { return req.GetPromoteBuildIdWithinSet() + } else if req.GetAddNewBuildIdInNewDefaultSet() != "" { + return req.GetAddNewBuildIdInNewDefaultSet() } - return req.GetAddNewBuildIdInNewDefaultSet() + return req.GetMergeSets().GetPrimarySetBuildId() } // Finds the version in the version sets, returning (set index, index within that set) // Returns -1, -1 if not found. func findVersion(data *persistencespb.VersioningData, buildID string) (setIndex, indexInSet int) { + if buildID == "" { + return -1, -1 + } for setIndex, set := range data.GetVersionSets() { for indexInSet, version := range set.GetBuildIds() { if version.Id == buildID { diff --git a/service/matching/version_sets_test.go b/service/matching/version_sets_test.go index 4b25adb5e51..1481d883395 100644 --- a/service/matching/version_sets_test.go +++ b/service/matching/version_sets_test.go @@ -98,8 +98,19 @@ func mkPromoteInSet(id string) *workflowservice.UpdateWorkerBuildIdCompatibility }, } } +func mkMergeSet(primaryId string, secondaryId string) *workflowservice.UpdateWorkerBuildIdCompatibilityRequest { + return &workflowservice.UpdateWorkerBuildIdCompatibilityRequest{ + Operation: &workflowservice.UpdateWorkerBuildIdCompatibilityRequest_MergeSets_{ + MergeSets: &workflowservice.UpdateWorkerBuildIdCompatibilityRequest_MergeSets{ + PrimarySetBuildId: primaryId, + SecondarySetBuildId: secondaryId, + }, + }, + } +} func TestNewDefaultUpdate(t *testing.T) { + t.Parallel() clock := hlc.Zero(1) initialData := mkInitialData(2, clock) @@ -136,6 +147,7 @@ func TestNewDefaultUpdate(t *testing.T) { } func TestNewDefaultSetUpdateOfEmptyData(t *testing.T) { + t.Parallel() clock := hlc.Zero(1) initialData := mkInitialData(0, clock) @@ -159,6 +171,7 @@ func TestNewDefaultSetUpdateOfEmptyData(t *testing.T) { } func TestNewDefaultSetUpdateCompatWithCurDefault(t *testing.T) { + t.Parallel() clock := hlc.Zero(1) initialData := mkInitialData(2, clock) @@ -190,6 +203,7 @@ func TestNewDefaultSetUpdateCompatWithCurDefault(t *testing.T) { } func TestNewDefaultSetUpdateCompatWithNonDefaultSet(t *testing.T) { + t.Parallel() clock := hlc.Zero(1) initialData := mkInitialData(2, clock) @@ -221,6 +235,7 @@ func TestNewDefaultSetUpdateCompatWithNonDefaultSet(t *testing.T) { } func TestNewCompatibleWithVerInOlderSet(t *testing.T) { + t.Parallel() clock := hlc.Zero(1) initialData := mkInitialData(2, clock) @@ -255,6 +270,7 @@ func TestNewCompatibleWithVerInOlderSet(t *testing.T) { } func TestNewCompatibleWithNonDefaultSetUpdate(t *testing.T) { + t.Parallel() clock0 := hlc.Zero(1) data := mkInitialData(2, clock0) @@ -320,6 +336,7 
@@ func TestNewCompatibleWithNonDefaultSetUpdate(t *testing.T) { } func TestCompatibleTargetsNotFound(t *testing.T) { + t.Parallel() clock := hlc.Zero(1) data := mkInitialData(1, clock) @@ -331,6 +348,7 @@ func TestCompatibleTargetsNotFound(t *testing.T) { } func TestMakeExistingSetDefault(t *testing.T) { + t.Parallel() clock0 := hlc.Zero(1) data := mkInitialData(3, clock0) @@ -397,6 +415,7 @@ func TestMakeExistingSetDefault(t *testing.T) { } func TestSayVersionIsCompatWithDifferentSetThanItsAlreadyCompatWithNotAllowed(t *testing.T) { + t.Parallel() clock := hlc.Zero(1) data := mkInitialData(3, clock) @@ -411,6 +430,7 @@ func TestSayVersionIsCompatWithDifferentSetThanItsAlreadyCompatWithNotAllowed(t } func TestLimitsMaxSets(t *testing.T) { + t.Parallel() clock := hlc.Zero(1) maxSets := 10 data := mkInitialData(maxSets, clock) @@ -422,6 +442,7 @@ func TestLimitsMaxSets(t *testing.T) { } func TestLimitsMaxBuildIds(t *testing.T) { + t.Parallel() clock := hlc.Zero(1) maxBuildIds := 10 data := mkInitialData(maxBuildIds, clock) @@ -433,6 +454,7 @@ func TestLimitsMaxBuildIds(t *testing.T) { } func TestPromoteWithinVersion(t *testing.T) { + t.Parallel() clock0 := hlc.Zero(1) data := mkInitialData(2, clock0) @@ -472,6 +494,7 @@ func TestPromoteWithinVersion(t *testing.T) { } func TestAddNewDefaultAlreadyExtantVersionWithNoConflictSucceeds(t *testing.T) { + t.Parallel() clock := hlc.Zero(1) original := mkInitialData(3, clock) @@ -482,6 +505,7 @@ func TestAddNewDefaultAlreadyExtantVersionWithNoConflictSucceeds(t *testing.T) { } func TestAddToExistingSetAlreadyExtantVersionWithNoConflictSucceeds(t *testing.T) { + t.Parallel() clock := hlc.Zero(1) req := mkNewCompatReq("1.1", "1", false) original, err := UpdateVersionSets(clock, mkInitialData(3, clock), req, 0, 0) @@ -492,6 +516,7 @@ func TestAddToExistingSetAlreadyExtantVersionWithNoConflictSucceeds(t *testing.T } func TestAddToExistingSetAlreadyExtantVersionErrorsIfNotDefault(t *testing.T) { + t.Parallel() clock := hlc.Zero(1) req := mkNewCompatReq("1.1", "1", true) original, err := UpdateVersionSets(clock, mkInitialData(3, clock), req, 0, 0) @@ -503,6 +528,7 @@ func TestAddToExistingSetAlreadyExtantVersionErrorsIfNotDefault(t *testing.T) { } func TestAddToExistingSetAlreadyExtantVersionErrorsIfNotDefaultSet(t *testing.T) { + t.Parallel() clock := hlc.Zero(1) req := mkNewCompatReq("1.1", "1", false) original, err := UpdateVersionSets(clock, mkInitialData(3, clock), req, 0, 0) @@ -514,6 +540,7 @@ func TestAddToExistingSetAlreadyExtantVersionErrorsIfNotDefaultSet(t *testing.T) } func TestPromoteWithinSetAlreadyPromotedIsANoop(t *testing.T) { + t.Parallel() clock0 := hlc.Zero(1) original := mkInitialData(3, clock0) req := mkPromoteInSet("1") @@ -524,6 +551,7 @@ func TestPromoteWithinSetAlreadyPromotedIsANoop(t *testing.T) { } func TestPromoteSetAlreadyPromotedIsANoop(t *testing.T) { + t.Parallel() clock0 := hlc.Zero(1) original := mkInitialData(3, clock0) req := mkExistingDefault("2") @@ -534,6 +562,7 @@ func TestPromoteSetAlreadyPromotedIsANoop(t *testing.T) { } func TestAddAlreadyExtantVersionAsDefaultErrors(t *testing.T) { + t.Parallel() clock := hlc.Zero(1) data := mkInitialData(3, clock) @@ -544,6 +573,7 @@ func TestAddAlreadyExtantVersionAsDefaultErrors(t *testing.T) { } func TestAddAlreadyExtantVersionToAnotherSetErrors(t *testing.T) { + t.Parallel() clock := hlc.Zero(1) data := mkInitialData(3, clock) @@ -554,6 +584,7 @@ func TestAddAlreadyExtantVersionToAnotherSetErrors(t *testing.T) { } func 
TestMakeSetDefaultTargetingNonexistentVersionErrors(t *testing.T) { + t.Parallel() clock := hlc.Zero(1) data := mkInitialData(3, clock) @@ -564,6 +595,7 @@ func TestMakeSetDefaultTargetingNonexistentVersionErrors(t *testing.T) { } func TestPromoteWithinSetTargetingNonexistentVersionErrors(t *testing.T) { + t.Parallel() clock := hlc.Zero(1) data := mkInitialData(3, clock) @@ -574,6 +606,7 @@ func TestPromoteWithinSetTargetingNonexistentVersionErrors(t *testing.T) { } func TestToBuildIdOrderingResponseTrimsResponse(t *testing.T) { + t.Parallel() clock := hlc.Zero(1) data := mkInitialData(3, clock) actual := ToBuildIdOrderingResponse(data, 2) @@ -582,6 +615,7 @@ func TestToBuildIdOrderingResponseTrimsResponse(t *testing.T) { } func TestToBuildIdOrderingResponseOmitsDeleted(t *testing.T) { + t.Parallel() clock := hlc.Zero(1) data := &persistencespb.VersioningData{ DefaultUpdateTimestamp: &clock, @@ -602,11 +636,13 @@ func TestToBuildIdOrderingResponseOmitsDeleted(t *testing.T) { } func TestHashBuildId(t *testing.T) { + t.Parallel() // This function should never change. assert.Equal(t, "ftrPuUeORv2JD4Wp2wTU", hashBuildId("my-build-id")) } func TestGetBuildIdDeltas(t *testing.T) { + t.Parallel() clock := hlc.Zero(0) prev := &persistencespb.VersioningData{ DefaultUpdateTimestamp: &clock, @@ -644,7 +680,77 @@ func TestGetBuildIdDeltas(t *testing.T) { } func TestGetBuildIdDeltas_AcceptsNils(t *testing.T) { + t.Parallel() added, removed := GetBuildIdDeltas(nil, nil) assert.Equal(t, []string(nil), removed) assert.Equal(t, []string(nil), added) } + +func TestMergeSets(t *testing.T) { + t.Parallel() + clock := hlc.Zero(1) + initialData := mkInitialData(4, clock) + + req := mkMergeSet("1", "2") + nextClock := hlc.Next(clock, commonclock.NewRealTimeSource()) + updatedData, err := UpdateVersionSets(nextClock, initialData, req, 0, 0) + assert.NoError(t, err) + // Should only be three sets now + assert.Equal(t, 3, len(updatedData.VersionSets)) + // The overall default set should not have changed + assert.Equal(t, "3", updatedData.GetVersionSets()[2].GetBuildIds()[0].Id) + // But set 1 should now have 2, maintaining 1 as the default ID + assert.Equal(t, "1", updatedData.GetVersionSets()[1].GetBuildIds()[1].Id) + assert.Equal(t, "2", updatedData.GetVersionSets()[1].GetBuildIds()[0].Id) + // Ensure it has the set ids of both sets + bothSetIds := mergeSetIDs([]string{hashBuildId("1")}, []string{hashBuildId("2")}) + assert.Equal(t, bothSetIds, updatedData.GetVersionSets()[1].GetSetIds()) + assert.Equal(t, initialData.DefaultUpdateTimestamp, updatedData.DefaultUpdateTimestamp) + assert.Equal(t, nextClock, *updatedData.GetVersionSets()[1].DefaultUpdateTimestamp) + // Initial data should not have changed + assert.Equal(t, 4, len(initialData.VersionSets)) + for _, set := range initialData.VersionSets { + assert.Equal(t, 1, len(set.GetSetIds())) + assert.Equal(t, clock, *set.DefaultUpdateTimestamp) + } + + // Same merge request must be idempotent + nextClock2 := hlc.Next(nextClock, commonclock.NewRealTimeSource()) + updatedData2, err := UpdateVersionSets(nextClock2, updatedData, req, 0, 0) + assert.NoError(t, err) + assert.Equal(t, 3, len(updatedData2.VersionSets)) + assert.Equal(t, "3", updatedData2.GetVersionSets()[2].GetBuildIds()[0].Id) + assert.Equal(t, "1", updatedData2.GetVersionSets()[1].GetBuildIds()[1].Id) + assert.Equal(t, "2", updatedData2.GetVersionSets()[1].GetBuildIds()[0].Id) + assert.Equal(t, initialData.DefaultUpdateTimestamp, updatedData2.DefaultUpdateTimestamp) + // Clock shouldn't have changed + 
assert.Equal(t, nextClock, *updatedData2.GetVersionSets()[1].DefaultUpdateTimestamp) + + // Verify merging into the current default maintains that set as the default + req = mkMergeSet("3", "0") + nextClock3 := hlc.Next(nextClock2, commonclock.NewRealTimeSource()) + updatedData3, err := UpdateVersionSets(nextClock3, updatedData2, req, 0, 0) + assert.NoError(t, err) + assert.Equal(t, 2, len(updatedData3.VersionSets)) + assert.Equal(t, "3", updatedData3.GetVersionSets()[1].GetBuildIds()[1].Id) + assert.Equal(t, "0", updatedData3.GetVersionSets()[1].GetBuildIds()[0].Id) + assert.Equal(t, "1", updatedData3.GetVersionSets()[0].GetBuildIds()[1].Id) + assert.Equal(t, "2", updatedData3.GetVersionSets()[0].GetBuildIds()[0].Id) + assert.Equal(t, initialData.DefaultUpdateTimestamp, updatedData3.DefaultUpdateTimestamp) + assert.Equal(t, nextClock3, *updatedData3.GetVersionSets()[1].DefaultUpdateTimestamp) +} + +func TestMergeInvalidTargets(t *testing.T) { + t.Parallel() + clock := hlc.Zero(1) + initialData := mkInitialData(4, clock) + + nextClock := hlc.Next(clock, commonclock.NewRealTimeSource()) + req := mkMergeSet("lol", "2") + _, err := UpdateVersionSets(nextClock, initialData, req, 0, 0) + assert.Error(t, err) + + req2 := mkMergeSet("2", "nope") + _, err2 := UpdateVersionSets(nextClock, initialData, req2, 0, 0) + assert.Error(t, err2) +} From 912b8890bb1d156c321db8b6c6304d69bdb04992 Mon Sep 17 00:00:00 2001 From: Michael Snowden Date: Thu, 8 Jun 2023 21:15:20 -0700 Subject: [PATCH 9/9] Remove unused MetricType concept (#4452) --- common/metrics/defs.go | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/common/metrics/defs.go b/common/metrics/defs.go index f38c285deaf..a7c41ed34fb 100644 --- a/common/metrics/defs.go +++ b/common/metrics/defs.go @@ -30,14 +30,10 @@ type ( // MetricName is the name of the metric MetricName string - // MetricType is the type of the metric - MetricType int - MetricUnit string // metricDefinition contains the definition for a metric metricDefinition struct { - metricType MetricType // metric type metricName MetricName // metric name unit MetricUnit } @@ -54,14 +50,6 @@ const ( Bytes = "By" ) -// MetricTypes which are supported -const ( - Counter MetricType = iota - Timer - Gauge - Histogram -) - // Empty returns true if the metricName is an empty string func (mn MetricName) Empty() bool { return mn == "" @@ -72,10 +60,6 @@ func (mn MetricName) String() string { return string(mn) } -func (md metricDefinition) GetMetricType() MetricType { - return md.metricType -} - func (md metricDefinition) GetMetricName() string { return md.metricName.String() } @@ -85,21 +69,21 @@ func (md metricDefinition) GetMetricUnit() MetricUnit { } func NewTimerDef(name string) metricDefinition { - return metricDefinition{metricName: MetricName(name), metricType: Timer, unit: Milliseconds} + return metricDefinition{metricName: MetricName(name), unit: Milliseconds} } func NewBytesHistogramDef(name string) metricDefinition { - return metricDefinition{metricName: MetricName(name), metricType: Histogram, unit: Bytes} + return metricDefinition{metricName: MetricName(name), unit: Bytes} } func NewDimensionlessHistogramDef(name string) metricDefinition { - return metricDefinition{metricName: MetricName(name), metricType: Histogram, unit: Dimensionless} + return metricDefinition{metricName: MetricName(name), unit: Dimensionless} } func NewCounterDef(name string) metricDefinition { - return metricDefinition{metricName: MetricName(name), metricType: Counter} + 
return metricDefinition{metricName: MetricName(name)} } func NewGaugeDef(name string) metricDefinition { - return metricDefinition{metricName: MetricName(name), metricType: Gauge} + return metricDefinition{metricName: MetricName(name)} }
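
With this change the kind of a metric is implied solely by which Handler method the emit site calls, as in the stream sender earlier in this series. A minimal sketch of that usage, assuming it sits alongside defs.go in the metrics package; the definition name and helper function are illustrative assumptions, only NewCounterDef, GetMetricName, and Handler.Counter(...).Record(...) mirror calls already shown in these patches:

	// Hypothetical definition; follows the NewCounterDef constructor above.
	var exampleStreamMessages = NewCounterDef("example_stream_messages")

	// recordExampleStreamMessage emits one count. Choosing handler.Counter here is what
	// makes this metric a counter now that metricDefinition no longer carries a MetricType.
	func recordExampleStreamMessage(handler Handler) {
		handler.Counter(exampleStreamMessages.GetMetricName()).Record(1)
	}

This matches the removal above: the per-definition metricType field carried no information that the emitting call site did not already express.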