From 7e2de6902fa98cc98884b4f4003b5a0e8a91936e Mon Sep 17 00:00:00 2001 From: Michael Snowden Date: Thu, 13 Jul 2023 17:05:58 -0700 Subject: [PATCH 1/3] Fix some minor warnings in the matching package (#4608) **What changed?** I went through all the GoLand IDE inspection errors in the matching package, and I fixed the ones that seemed relevant. Most of them are little typos. **Why?** To clean things up a bit, make it easier to find actual errors from inspection, make sure these don't have to get included in other actual behavioral changes. **How did you test it?** I made sure not to include any behavioral changes, even small things like replacing `err ==` with `errors.Is`. **Potential risks** **Is hotfix candidate?** --- .golangci.yml | 10 ++-- service/matching/ack_manager.go | 2 +- service/matching/config.go | 2 +- service/matching/db.go | 6 +- service/matching/forwarder.go | 9 +-- service/matching/forwarder_test.go | 10 ++-- service/matching/fx.go | 6 +- service/matching/matcher.go | 15 +++-- service/matching/matching_engine.go | 40 ++++++-------- service/matching/matching_engine_test.go | 61 ++++++++++----------- service/matching/task.go | 2 +- service/matching/task_gc.go | 4 +- service/matching/task_queue_manager.go | 25 ++++----- service/matching/task_queue_manager_test.go | 4 +- service/matching/task_reader.go | 40 +++++++++----- service/matching/task_writer.go | 17 ++++-- service/matching/taskqueue.go | 2 +- service/matching/version_sets.go | 6 +- 18 files changed, 137 insertions(+), 124 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index f31761c64fb..88ca3dd0c25 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -30,6 +30,8 @@ linters-settings: # Disabled rules - name: add-constant disabled: true + - name: argument-limit + disabled: true - name: bare-return disabled: true - name: banned-characters @@ -78,9 +80,6 @@ linters-settings: disabled: true # Rule tuning - - name: argument-limit - arguments: - - 10 - name: cognitive-complexity arguments: - 25 @@ -92,8 +91,9 @@ linters-settings: - 3 - name: unhandled-error arguments: - - "fmt.Printf" - - "fmt.Println" + - "fmt.*" + - "bytes.Buffer.*" + - "strings.Builder.*" issues: # Exclude cyclomatic and cognitive complexity rules for functional tests in the `tests` root directory. exclude-rules: diff --git a/service/matching/ack_manager.go b/service/matching/ack_manager.go index a05c51cfb6f..c52aae27c39 100644 --- a/service/matching/ack_manager.go +++ b/service/matching/ack_manager.go @@ -120,7 +120,7 @@ func (m *ackManager) completeTask(taskID int64) (ackLevel int64) { m.backlogCounter.Dec() } - // TODO the ack level management shuld be done by a dedicated coroutine + // TODO the ack level management should be done by a dedicated coroutine // this is only a temporarily solution taskIDs := maps.Keys(m.outstandingTasks) diff --git a/service/matching/config.go b/service/matching/config.go index d7877391d03..31f5a8ed3ce 100644 --- a/service/matching/config.go +++ b/service/matching/config.go @@ -124,7 +124,7 @@ type ( AdminNamespaceTaskQueueToPartitionDispatchRate func() float64 // If set to false, matching does not load user data from DB for root partitions or fetch it via RPC from the - // root. When disbled, features that rely on user data (e.g. worker versioning) will essentially be disabled. + // root. When disabled, features that rely on user data (e.g. worker versioning) will essentially be disabled. // See the documentation for constants.MatchingLoadUserData for the implications on versioning. 
LoadUserData func() bool diff --git a/service/matching/db.go b/service/matching/db.go index b030a19706c..dfd342ea2e1 100644 --- a/service/matching/db.go +++ b/service/matching/db.go @@ -310,8 +310,8 @@ func (db *taskQueueDB) CompleteTasksLessThan( return n, err } -// Returns true if we are storing user data in the db. We need to be the root partition, -// workflow type, unversioned, and also a normal queue. +// DbStoresUserData returns true if we are storing user data in the db. We need to be the root partition, workflow type, +// unversioned, and also a normal queue. func (db *taskQueueDB) DbStoresUserData() bool { return db.taskQueue.OwnsUserData() && db.taskQueueKind == enumspb.TASK_QUEUE_KIND_NORMAL } @@ -319,7 +319,7 @@ func (db *taskQueueDB) DbStoresUserData() bool { // GetUserData returns the versioning data for this task queue. Do not mutate the returned pointer, as doing so // will cause cache inconsistency. func (db *taskQueueDB) GetUserData( - ctx context.Context, + context.Context, ) (*persistencespb.VersionedTaskQueueUserData, chan struct{}, error) { db.Lock() defer db.Unlock() diff --git a/service/matching/forwarder.go b/service/matching/forwarder.go index 5e64dbbf4b0..4a3b2c33d52 100644 --- a/service/matching/forwarder.go +++ b/service/matching/forwarder.go @@ -113,7 +113,7 @@ func newForwarder( return fwdr } -// ForwardTask forwards an activity or workflow task to the parent task queue partition if it exist +// ForwardTask forwards an activity or workflow task to the parent task queue partition if it exists func (fwdr *Forwarder) ForwardTask(ctx context.Context, task *internalTask) error { if fwdr.taskQueueKind == enumspb.TASK_QUEUE_KIND_STICKY { return errTaskQueueKind @@ -131,15 +131,12 @@ func (fwdr *Forwarder) ForwardTask(ctx context.Context, task *internalTask) erro var expirationDuration time.Duration expirationTime := timestamp.TimeValue(task.event.Data.ExpiryTime) - if expirationTime.IsZero() { - // noop - } else { + if !expirationTime.IsZero() { expirationDuration = time.Until(expirationTime) if expirationDuration <= 0 { return nil } } - switch fwdr.taskQueueID.taskType { case enumspb.TASK_QUEUE_TYPE_WORKFLOW: _, err = fwdr.client.AddWorkflowTask(ctx, &matchingservice.AddWorkflowTaskRequest{ @@ -178,7 +175,7 @@ func (fwdr *Forwarder) ForwardTask(ctx context.Context, task *internalTask) erro return fwdr.handleErr(err) } -// ForwardQueryTask forwards a query task to parent task queue partition, if it exist +// ForwardQueryTask forwards a query task to parent task queue partition, if it exists func (fwdr *Forwarder) ForwardQueryTask( ctx context.Context, task *internalTask, diff --git a/service/matching/forwarder_test.go b/service/matching/forwarder_test.go index 51a0dd249bf..79bc2f2d1d1 100644 --- a/service/matching/forwarder_test.go +++ b/service/matching/forwarder_test.go @@ -100,7 +100,7 @@ func (t *ForwarderTestSuite) TestForwardWorkflowTask() { t.NoError(t.fwdr.ForwardTask(context.Background(), task)) t.NotNil(request) t.Equal(mustParent(t.taskQueue.Name, 20).FullName(), request.TaskQueue.GetName()) - t.Equal(enumspb.TaskQueueKind(t.fwdr.taskQueueKind), request.TaskQueue.GetKind()) + t.Equal(t.fwdr.taskQueueKind, request.TaskQueue.GetKind()) t.Equal(taskInfo.Data.GetNamespaceId(), request.GetNamespaceId()) t.Equal(taskInfo.Data.GetWorkflowId(), request.GetExecution().GetWorkflowId()) t.Equal(taskInfo.Data.GetRunId(), request.GetExecution().GetRunId()) @@ -175,7 +175,7 @@ func (t *ForwarderTestSuite) TestForwardQueryTask() { gotResp, err := 
t.fwdr.ForwardQueryTask(context.Background(), task) t.NoError(err) t.Equal(mustParent(t.taskQueue.Name, 20).FullName(), request.TaskQueue.GetName()) - t.Equal(enumspb.TaskQueueKind(t.fwdr.taskQueueKind), request.TaskQueue.GetKind()) + t.Equal(t.fwdr.taskQueueKind, request.TaskQueue.GetKind()) t.Equal(task.query.request.QueryRequest, request.QueryRequest) t.Equal(resp, gotResp) } @@ -191,7 +191,7 @@ func (t *ForwarderTestSuite) TestForwardQueryTaskRateNotEnforced() { t.NoError(err) } _, err := t.fwdr.ForwardQueryTask(context.Background(), task) - t.NoError(err) // no rateliming should be enforced for query task + t.NoError(err) // no rate limiting should be enforced for query task } func (t *ForwarderTestSuite) TestForwardPollError() { @@ -228,7 +228,7 @@ func (t *ForwarderTestSuite) TestForwardPollWorkflowTaskQueue() { t.Equal(t.taskQueue.namespaceID, namespace.ID(request.GetNamespaceId())) t.Equal("id1", request.GetPollRequest().GetIdentity()) t.Equal(mustParent(t.taskQueue.Name, 20).FullName(), request.GetPollRequest().GetTaskQueue().GetName()) - t.Equal(enumspb.TaskQueueKind(t.fwdr.taskQueueKind), request.GetPollRequest().GetTaskQueue().GetKind()) + t.Equal(t.fwdr.taskQueueKind, request.GetPollRequest().GetTaskQueue().GetKind()) t.Equal(resp, task.pollWorkflowTaskQueueResponse()) t.Nil(task.pollActivityTaskQueueResponse()) } @@ -256,7 +256,7 @@ func (t *ForwarderTestSuite) TestForwardPollForActivity() { t.Equal(t.taskQueue.namespaceID, namespace.ID(request.GetNamespaceId())) t.Equal("id1", request.GetPollRequest().GetIdentity()) t.Equal(mustParent(t.taskQueue.Name, 20).FullName(), request.GetPollRequest().GetTaskQueue().GetName()) - t.Equal(enumspb.TaskQueueKind(t.fwdr.taskQueueKind), request.GetPollRequest().GetTaskQueue().GetKind()) + t.Equal(t.fwdr.taskQueueKind, request.GetPollRequest().GetTaskQueue().GetKind()) t.Equal(resp, task.pollActivityTaskQueueResponse()) t.Nil(task.pollWorkflowTaskQueueResponse()) } diff --git a/service/matching/fx.go b/service/matching/fx.go index b992bd1af4a..544f34cbc45 100644 --- a/service/matching/fx.go +++ b/service/matching/fx.go @@ -110,7 +110,7 @@ func RateLimitInterceptorProvider( ) } -// This function is the same between services but uses different config sources. +// PersistenceRateLimitingParamsProvider is the same between services but uses different config sources. // if-case comes from resourceImpl.New. func PersistenceRateLimitingParamsProvider( serviceConfig *Config, @@ -129,8 +129,8 @@ func ServiceResolverProvider(membershipMonitor membership.Monitor) (membership.S return membershipMonitor.GetResolver(primitives.MatchingService) } -// This type is used to ensure the replicator only gets set if global namespaces are enabled on this cluster. -// See NamespaceReplicationQueueProvider below. +// TaskQueueReplicatorNamespaceReplicationQueue is used to ensure the replicator only gets set if global namespaces are +// enabled on this cluster. See NamespaceReplicationQueueProvider below. 
type TaskQueueReplicatorNamespaceReplicationQueue persistence.NamespaceReplicationQueue func NamespaceReplicationQueueProvider( diff --git a/service/matching/matcher.go b/service/matching/matcher.go index 48a8eded9f9..e9eb6d5258f 100644 --- a/service/matching/matcher.go +++ b/service/matching/matcher.go @@ -45,10 +45,10 @@ type TaskMatcher struct { // synchronous task channel to match producer/consumer taskC chan *internalTask - // synchronous task channel to match query task - the reason to have - // separate channel for this is because there are cases when consumers - // are interested in queryTasks but not others. Example is when namespace is - // not active in a cluster + // synchronous task channel to match query task - the reason to have a + // separate channel for this is that there are cases where consumers + // are interested in queryTasks but not others. One example is when a + // namespace is not active in a cluster. queryTaskC chan *internalTask // dynamicRate is the dynamic rate & burst for rate limiter @@ -75,9 +75,8 @@ var ( errInterrupted = errors.New("interrupted offer") ) -// newTaskMatcher returns an task matcher instance. The returned instance can be -// used by task producers and consumers to find a match. Both sync matches and non-sync -// matches should use this implementation +// newTaskMatcher returns a task matcher instance. The returned instance can be used by task producers and consumers to +// find a match. Both sync matches and non-sync matches should use this implementation func newTaskMatcher(config *taskQueueConfig, fwdr *Forwarder, metricsHandler metrics.Handler) *TaskMatcher { dynamicRateBurst := quotas.NewMutableRateBurst( defaultTaskDispatchRPS, @@ -383,7 +382,7 @@ func (tm *TaskMatcher) poll(ctx context.Context, pollMetadata *pollMetadata, que default: } - // 3. forwarding (and all other clauses repeated again) + // 3. forwarding (and all other clauses repeated) select { case <-ctx.Done(): tm.metricsHandler.Counter(metrics.PollTimeoutPerTaskQueueCounter.GetMetricName()).Record(1) diff --git a/service/matching/matching_engine.go b/service/matching/matching_engine.go index ed00a355d45..ac5011501da 100644 --- a/service/matching/matching_engine.go +++ b/service/matching/matching_engine.go @@ -250,7 +250,7 @@ func (e *matchingEngineImpl) String() string { // // Note that stickyInfo is not used as part of the task queue identity. That means that if // getTaskQueueManager is called twice with the same taskQueue but different stickyInfo, the -// properties of the taskQueueManager will depend on which call came first. In general we can +// properties of the taskQueueManager will depend on which call came first. In general, we can // rely on kind being the same for all calls now, but normalName was a later addition to the // protocol and is not always set consistently. normalName is only required when using // versioning, and SDKs that support versioning will always set it. 
The current server version @@ -347,9 +347,7 @@ func (e *matchingEngineImpl) AddWorkflowTask( var expirationTime *time.Time now := timestamp.TimePtr(time.Now().UTC()) expirationDuration := timestamp.DurationValue(addRequest.GetScheduleToStartTimeout()) - if expirationDuration == 0 { - // noop - } else { + if expirationDuration != 0 { expirationTime = timestamp.TimePtr(now.Add(expirationDuration)) } taskInfo := &persistencespb.TaskInfo{ @@ -412,9 +410,7 @@ func (e *matchingEngineImpl) AddActivityTask( var expirationTime *time.Time now := timestamp.TimePtr(time.Now().UTC()) expirationDuration := timestamp.DurationValue(addRequest.GetScheduleToStartTimeout()) - if expirationDuration == 0 { - // noop - } else { + if expirationDuration != 0 { expirationTime = timestamp.TimePtr(now.Add(expirationDuration)) } taskInfo := &persistencespb.TaskInfo{ @@ -714,13 +710,13 @@ func (e *matchingEngineImpl) QueryWorkflow( taskID := uuid.New() resp, err := tqm.DispatchQueryTask(ctx, taskID, queryRequest) - // if get response or error it means that query task was handled by forwarding to another matching host + // if we get a response or error it means that query task was handled by forwarding to another matching host // this remote host's result can be returned directly if resp != nil || err != nil { return resp, err } - // if get here it means that dispatch of query task has occurred locally + // if we get here it means that dispatch of query task has occurred locally // must wait on result channel to get query result queryResultCh := make(chan *queryResult, 1) e.lockableQueryTaskMap.put(taskID, queryResultCh) @@ -747,7 +743,7 @@ func (e *matchingEngineImpl) QueryWorkflow( } func (e *matchingEngineImpl) RespondQueryTaskCompleted( - ctx context.Context, + _ context.Context, request *matchingservice.RespondQueryTaskCompletedRequest, opMetrics metrics.Handler, ) error { @@ -768,7 +764,7 @@ func (e *matchingEngineImpl) deliverQueryResult(taskID string, queryResult *quer } func (e *matchingEngineImpl) CancelOutstandingPoll( - ctx context.Context, + _ context.Context, request *matchingservice.CancelOutstandingPollRequest, ) error { e.pollMap.cancel(request.PollerId) @@ -796,7 +792,7 @@ func (e *matchingEngineImpl) DescribeTaskQueue( } func (e *matchingEngineImpl) ListTaskQueuePartitions( - ctx context.Context, + _ context.Context, request *matchingservice.ListTaskQueuePartitionsRequest, ) (*matchingservice.ListTaskQueuePartitionsResponse, error) { activityTaskQueueInfo, err := e.listTaskQueuePartitions(request, enumspb.TASK_QUEUE_TYPE_ACTIVITY) @@ -868,12 +864,12 @@ func (e *matchingEngineImpl) UpdateWorkerBuildIdCompatibility( } err = tqMgr.UpdateUserData(ctx, updateOptions, func(data *persistencespb.TaskQueueUserData) (*persistencespb.TaskQueueUserData, bool, error) { - clock := data.GetClock() - if clock == nil { + clk := data.GetClock() + if clk == nil { tmp := hlc.Zero(e.clusterMeta.GetClusterID()) - clock = &tmp + clk = &tmp } - updatedClock := hlc.Next(*clock, e.timeSource) + updatedClock := hlc.Next(*clk, e.timeSource) var versioningData *persistencespb.VersioningData switch req.GetOperation().(type) { case *matchingservice.UpdateWorkerBuildIdCompatibilityRequest_ApplyPublicRequest_: @@ -1172,12 +1168,12 @@ func (e *matchingEngineImpl) getHostInfo(partitionKey string) (string, error) { } func (e *matchingEngineImpl) getAllPartitions( - namespace namespace.Name, + ns namespace.Name, taskQueue taskqueuepb.TaskQueue, taskQueueType enumspb.TaskQueueType, ) ([]string, error) { var partitionKeys []string - 
namespaceID, err := e.namespaceRegistry.GetNamespaceID(namespace) + namespaceID, err := e.namespaceRegistry.GetNamespaceID(ns) if err != nil { return partitionKeys, err } @@ -1186,7 +1182,7 @@ func (e *matchingEngineImpl) getAllPartitions( return partitionKeys, err } - n := e.config.NumTaskqueueWritePartitions(namespace.String(), taskQueueID.BaseNameString(), taskQueueType) + n := e.config.NumTaskqueueWritePartitions(ns.String(), taskQueueID.BaseNameString(), taskQueueType) for i := 0; i < n; i++ { partitionKeys = append(partitionKeys, taskQueueID.WithPartition(i).FullName()) } @@ -1265,14 +1261,14 @@ func (e *matchingEngineImpl) unloadTaskQueue(unloadTQM taskQueueManager) { func (e *matchingEngineImpl) updateTaskQueueGauge(countKey taskQueueCounterKey, taskQueueCount int) { nsEntry, err := e.namespaceRegistry.GetNamespaceByID(countKey.namespaceID) - namespace := namespace.Name("unknown") + ns := namespace.Name("unknown") if err == nil { - namespace = nsEntry.Name() + ns = nsEntry.Name() } e.metricsHandler.Gauge(metrics.LoadedTaskQueueGauge.GetMetricName()).Record( float64(taskQueueCount), - metrics.NamespaceTag(namespace.String()), + metrics.NamespaceTag(ns.String()), metrics.TaskTypeTag(countKey.taskType.String()), metrics.QueueTypeTag(countKey.kind.String()), ) diff --git a/service/matching/matching_engine_test.go b/service/matching/matching_engine_test.go index f9688552e71..45aca4c12cc 100644 --- a/service/matching/matching_engine_test.go +++ b/service/matching/matching_engine_test.go @@ -95,7 +95,6 @@ type ( const ( matchingTestNamespace = "matching-test" - matchingTestTaskQueue = "matching-test-taskqueue" ) func TestMatchingEngineSuite(t *testing.T) { @@ -495,7 +494,7 @@ func (s *matchingEngineSuite) TestPollWorkflowTaskQueues_NamespaceHandover() { ScheduleToStartTimeout: timestamp.DurationFromSeconds(100), } - // add multiple workflow tasks, but matching should not keeping polling new tasks + // add multiple workflow tasks, but matching should not keep polling new tasks // upon getting namespace handover error when recording start for the first task _, err := s.matchingEngine.AddWorkflowTask(context.Background(), &addRequest) s.NoError(err) @@ -527,7 +526,7 @@ func (s *matchingEngineSuite) TestPollActivityTaskQueues_NamespaceHandover() { ScheduleToStartTimeout: timestamp.DurationFromSeconds(100), } - // add multiple activity tasks, but matching should not keeping polling new tasks + // add multiple activity tasks, but matching should not keep polling new tasks // upon getting namespace handover error when recording start for the first task _, err := s.matchingEngine.AddActivityTask(context.Background(), &addRequest) s.NoError(err) @@ -804,7 +803,7 @@ func (s *matchingEngineSuite) TestAddThenConsumeActivities() { } func (s *matchingEngineSuite) TestSyncMatchActivities() { - // Set a short long poll expiration so we don't have to wait too long for 0 throttling cases + // Set a short long poll expiration so that we don't have to wait too long for 0 throttling cases s.matchingEngine.config.LongPollExpirationInterval = dynamicconfig.GetDurationPropertyFnFilteredByTaskQueueInfo(2 * time.Second) runID := uuid.NewRandom().String() @@ -1006,7 +1005,7 @@ func (s *matchingEngineSuite) TestConcurrentPublishConsumeActivities() { func (s *matchingEngineSuite) TestConcurrentPublishConsumeActivitiesWithZeroDispatch() { s.T().Skip("Racy - times out ~50% of the time running locally with --race") - // Set a short long poll expiration so we don't have to wait too long for 0 throttling cases + // Set a 
short long poll expiration so that we don't have to wait too long for 0 throttling cases s.matchingEngine.config.LongPollExpirationInterval = dynamicconfig.GetDurationPropertyFnFilteredByTaskQueueInfo(20 * time.Millisecond) dispatchLimitFn := func(wc int, tc int64) float64 { if tc%50 == 0 && wc%5 == 0 { // Gets triggered atleast 20 times @@ -1724,7 +1723,7 @@ func (s *matchingEngineSuite) TestTaskQueueManagerGetTaskBatch() { s.True(ok, "taskQueueManger doesn't implement taskQueueManager interface") s.EqualValues(taskCount, s.taskManager.getTaskCount(tlID)) - // wait until all tasks are read by the task pump and enqeued into the in-memory buffer + // wait until all tasks are read by the task pump and enqueued into the in-memory buffer // at the end of this step, ackManager readLevel will also be equal to the buffer size expectedBufSize := util.Min(cap(tlMgr.taskReader.taskBuffer), taskCount) s.True(s.awaitCondition(func() bool { return len(tlMgr.taskReader.taskBuffer) == expectedBufSize }, time.Second)) @@ -1736,18 +1735,18 @@ func (s *matchingEngineSuite) TestTaskQueueManagerGetTaskBatch() { // setReadLevel should NEVER be called without updating ackManager.outstandingTasks // This is only for unit test purpose tlMgr.taskAckManager.setReadLevel(tlMgr.taskWriter.GetMaxReadLevel()) - tasks, readLevel, isReadBatchDone, err := tlMgr.taskReader.getTaskBatch(context.Background()) + batch, err := tlMgr.taskReader.getTaskBatch(context.Background()) s.Nil(err) - s.EqualValues(0, len(tasks)) - s.EqualValues(tlMgr.taskWriter.GetMaxReadLevel(), readLevel) - s.True(isReadBatchDone) + s.EqualValues(0, len(batch.tasks)) + s.EqualValues(tlMgr.taskWriter.GetMaxReadLevel(), batch.readLevel) + s.True(batch.isReadBatchDone) tlMgr.taskAckManager.setReadLevel(0) - tasks, readLevel, isReadBatchDone, err = tlMgr.taskReader.getTaskBatch(context.Background()) + batch, err = tlMgr.taskReader.getTaskBatch(context.Background()) s.Nil(err) - s.EqualValues(rangeSize, len(tasks)) - s.EqualValues(rangeSize, readLevel) - s.True(isReadBatchDone) + s.EqualValues(rangeSize, len(batch.tasks)) + s.EqualValues(rangeSize, batch.readLevel) + s.True(batch.isReadBatchDone) s.setupRecordActivityTaskStartedMock(tl) @@ -1776,10 +1775,10 @@ func (s *matchingEngineSuite) TestTaskQueueManagerGetTaskBatch() { } } s.EqualValues(taskCount-rangeSize, s.taskManager.getTaskCount(tlID)) - tasks, _, isReadBatchDone, err = tlMgr.taskReader.getTaskBatch(context.Background()) + batch, err = tlMgr.taskReader.getTaskBatch(context.Background()) s.Nil(err) - s.True(0 < len(tasks) && len(tasks) <= rangeSize) - s.True(isReadBatchDone) + s.True(0 < len(batch.tasks) && len(batch.tasks) <= rangeSize) + s.True(batch.isReadBatchDone) } func (s *matchingEngineSuite) TestTaskQueueManagerGetTaskBatch_ReadBatchDone() { @@ -1805,17 +1804,17 @@ func (s *matchingEngineSuite) TestTaskQueueManagerGetTaskBatch_ReadBatchDone() { tlMgr.taskAckManager.setReadLevel(0) atomic.StoreInt64(&tlMgr.taskWriter.maxReadLevel, maxReadLevel) - tasks, readLevel, isReadBatchDone, err := tlMgr.taskReader.getTaskBatch(context.Background()) - s.Empty(tasks) - s.Equal(int64(rangeSize*10), readLevel) - s.False(isReadBatchDone) + batch, err := tlMgr.taskReader.getTaskBatch(context.Background()) + s.Empty(batch.tasks) + s.Equal(int64(rangeSize*10), batch.readLevel) + s.False(batch.isReadBatchDone) s.NoError(err) - tlMgr.taskAckManager.setReadLevel(readLevel) - tasks, readLevel, isReadBatchDone, err = tlMgr.taskReader.getTaskBatch(context.Background()) - s.Empty(tasks) - s.Equal(maxReadLevel, 
readLevel) - s.True(isReadBatchDone) + tlMgr.taskAckManager.setReadLevel(batch.readLevel) + batch, err = tlMgr.taskReader.getTaskBatch(context.Background()) + s.Empty(batch.tasks) + s.Equal(maxReadLevel, batch.readLevel) + s.True(batch.isReadBatchDone) s.NoError(err) } @@ -2731,7 +2730,7 @@ func (m *testTaskManager) String() string { } // GetTaskQueueData implements persistence.TaskManager -func (m *testTaskManager) GetTaskQueueUserData(ctx context.Context, request *persistence.GetTaskQueueUserDataRequest) (*persistence.GetTaskQueueUserDataResponse, error) { +func (m *testTaskManager) GetTaskQueueUserData(_ context.Context, request *persistence.GetTaskQueueUserDataRequest) (*persistence.GetTaskQueueUserDataResponse, error) { tlm := m.getTaskQueueManager(newTestTaskQueueID(namespace.ID(request.NamespaceID), request.TaskQueue, enumspb.TASK_QUEUE_TYPE_WORKFLOW)) tlm.Lock() defer tlm.Unlock() @@ -2742,7 +2741,7 @@ func (m *testTaskManager) GetTaskQueueUserData(ctx context.Context, request *per } // UpdateTaskQueueUserData implements persistence.TaskManager -func (m *testTaskManager) UpdateTaskQueueUserData(ctx context.Context, request *persistence.UpdateTaskQueueUserDataRequest) error { +func (m *testTaskManager) UpdateTaskQueueUserData(_ context.Context, request *persistence.UpdateTaskQueueUserDataRequest) error { tlm := m.getTaskQueueManager(newTestTaskQueueID(namespace.ID(request.NamespaceID), request.TaskQueue, enumspb.TASK_QUEUE_TYPE_WORKFLOW)) tlm.Lock() defer tlm.Unlock() @@ -2753,19 +2752,19 @@ func (m *testTaskManager) UpdateTaskQueueUserData(ctx context.Context, request * } // ListTaskQueueUserDataEntries implements persistence.TaskManager -func (*testTaskManager) ListTaskQueueUserDataEntries(ctx context.Context, request *persistence.ListTaskQueueUserDataEntriesRequest) (*persistence.ListTaskQueueUserDataEntriesResponse, error) { +func (*testTaskManager) ListTaskQueueUserDataEntries(context.Context, *persistence.ListTaskQueueUserDataEntriesRequest) (*persistence.ListTaskQueueUserDataEntriesResponse, error) { // No need to implement this for unit tests panic("unimplemented") } // GetTaskQueuesByBuildId implements persistence.TaskManager -func (*testTaskManager) GetTaskQueuesByBuildId(ctx context.Context, request *persistence.GetTaskQueuesByBuildIdRequest) ([]string, error) { +func (*testTaskManager) GetTaskQueuesByBuildId(context.Context, *persistence.GetTaskQueuesByBuildIdRequest) ([]string, error) { // No need to implement this for unit tests panic("unimplemented") } // CountTaskQueuesByBuildId implements persistence.TaskManager -func (*testTaskManager) CountTaskQueuesByBuildId(ctx context.Context, request *persistence.CountTaskQueuesByBuildIdRequest) (int, error) { +func (*testTaskManager) CountTaskQueuesByBuildId(context.Context, *persistence.CountTaskQueuesByBuildIdRequest) (int, error) { // This is only used to validate that the build id to task queue mapping is enforced (at the time of writing), report 0. 
return 0, nil } diff --git a/service/matching/task.go b/service/matching/task.go index a04d8f4490a..9d3041bc16c 100644 --- a/service/matching/task.go +++ b/service/matching/task.go @@ -164,7 +164,7 @@ func (task *internalTask) finish(err error) { case task.responseC != nil: task.responseC <- err case task.event.completionFunc != nil: - // TODO: this probably should not be done synchronosly in PollWorkflow/ActivityTaskQueue + // TODO: this probably should not be done synchronously in PollWorkflow/ActivityTaskQueue task.event.completionFunc(task.event.AllocatedTaskInfo, err) } } diff --git a/service/matching/task_gc.go b/service/matching/task_gc.go index 5f9818147ce..1786e74f840 100644 --- a/service/matching/task_gc.go +++ b/service/matching/task_gc.go @@ -57,13 +57,13 @@ func newTaskGC(db *taskQueueDB, config *taskQueueConfig) *taskGC { return &taskGC{db: db, config: config} } -// Run deletes a batch of completed tasks, if its possible to do so +// Run deletes a batch of completed tasks, if it's possible to do so // Only attempts deletion if size or time thresholds are met func (tgc *taskGC) Run(ctx context.Context, ackLevel int64) { tgc.tryDeleteNextBatch(ctx, ackLevel, false) } -// RunNow deletes a batch of completed tasks if its possible to do so +// RunNow deletes a batch of completed tasks if it's possible to do so // This method attempts deletions without waiting for size/time threshold to be met func (tgc *taskGC) RunNow(ctx context.Context, ackLevel int64) { tgc.tryDeleteNextBatch(ctx, ackLevel, true) diff --git a/service/matching/task_queue_manager.go b/service/matching/task_queue_manager.go index 2b57b5ff4cb..e19630fb2e2 100644 --- a/service/matching/task_queue_manager.go +++ b/service/matching/task_queue_manager.go @@ -73,8 +73,8 @@ const ( ) var ( - // this retry policy is currenly only used for matching persistence operations - // that, if failed, the entire task queue needs to be reload + // this retry policy is currently only used for matching persistence operations + // that, if failed, the entire task queue needs to be reloaded persistenceOperationRetryPolicy = backoff.NewExponentialRetryPolicy(50 * time.Millisecond). WithMaximumInterval(1 * time.Second). WithExpirationInterval(30 * time.Second) @@ -132,7 +132,7 @@ type ( // DispatchQueryTask will dispatch query to local or remote poller. If forwarded then result or error is returned, // if dispatched to local poller then nil and nil is returned. DispatchQueryTask(ctx context.Context, taskID string, request *matchingservice.QueryWorkflowRequest) (*matchingservice.QueryWorkflowResponse, error) - // GetUserData returns the verioned user data for this task queue + // GetUserData returns the versioned user data for this task queue GetUserData(ctx context.Context) (*persistencespb.VersionedTaskQueueUserData, chan struct{}, error) // UpdateUserData updates user data for this task queue and replicates across clusters if necessary. // Extra care should be taken to avoid mutating the existing data in the update function. @@ -353,11 +353,10 @@ func (c *taskQueueManagerImpl) Stop() { c.unloadFromEngine() } -// managesSpecificVersionSet returns true if this is a tqm for a specific version set in the -// build-id-based versioning feature. Note that this is a different concept from the overall -// task queue having versioning data associated with it, which is the usual meaning of -// "versioned task queue". These task queues are not interacted with directly outside outside -// of a single matching node. 
+// managesSpecificVersionSet returns true if this is a tqm for a specific version set in the build-id-based versioning +// feature. Note that this is a different concept from the overall task queue having versioning data associated with it, +// which is the usual meaning of "versioned task queue". These task queues are not interacted with directly outside of +// a single matching node. func (c *taskQueueManagerImpl) managesSpecificVersionSet() bool { return c.taskQueueID.VersionSet() != "" } @@ -441,7 +440,7 @@ func (c *taskQueueManagerImpl) AddTask( return false, errRemoteSyncMatchFailed } - // Ensure that tasks with the "default" versioning directive get spooled in the unversioned queue as they not + // Ensure that tasks with the "default" versioning directive get spooled in the unversioned queue as they are not // associated with any version set until their execution is touched by a version specific worker. // "compatible" tasks OTOH are associated with a specific version set and should be stored along with all tasks for // that version set. @@ -651,7 +650,7 @@ func (c *taskQueueManagerImpl) completeTask(task *persistencespb.AllocatedTaskIn if err != nil { // OK, we also failed to write to persistence. // This should only happen in very extreme cases where persistence is completely down. - // We still can't lose the old task so we just unload the entire task queue + // We still can't lose the old task, so we just unload the entire task queue c.logger.Error("Persistent store operation failure", tag.StoreOperationStopTaskQueue, tag.Error(err), @@ -755,8 +754,8 @@ func (c *taskQueueManagerImpl) LongPollExpirationInterval() time.Duration { } func (c *taskQueueManagerImpl) callerInfoContext(ctx context.Context) context.Context { - namespace, _ := c.namespaceRegistry.GetNamespaceName(c.taskQueueID.namespaceID) - return headers.SetCallerInfo(ctx, headers.NewBackgroundCallerInfo(namespace.String())) + ns, _ := c.namespaceRegistry.GetNamespaceName(c.taskQueueID.namespaceID) + return headers.SetCallerInfo(ctx, headers.NewBackgroundCallerInfo(ns.String())) } func (c *taskQueueManagerImpl) newIOContext() (context.Context, context.CancelFunc) { @@ -896,7 +895,7 @@ func (c *taskQueueManagerImpl) fetchUserData(ctx context.Context) error { _ = backoff.ThrottleRetryContext(ctx, op, c.config.GetUserDataRetryPolicy, nil) elapsed := time.Since(start) - // In general we want to start a new call immediately on completion of the previous + // In general, we want to start a new call immediately on completion of the previous // one. But if the remote is broken and returns success immediately, we might end up // spinning. So enforce a minimum wait time that increases as long as we keep getting // very fast replies. 
diff --git a/service/matching/task_queue_manager_test.go b/service/matching/task_queue_manager_test.go index 18030014c28..a2c18019a4e 100644 --- a/service/matching/task_queue_manager_test.go +++ b/service/matching/task_queue_manager_test.go @@ -288,7 +288,7 @@ func TestSyncMatchLeasingUnavailable(t *testing.T) { func TestForeignPartitionOwnerCausesUnload(t *testing.T) { cfg := NewConfig(dynamicconfig.NewNoopCollection(), false, false) cfg.RangeSize = 1 // TaskID block size - var leaseErr error = nil + var leaseErr error tqm := mustCreateTestTaskQueueManager(t, gomock.NewController(t), makeTestBlocAlloc(func() (taskQueueState, error) { return taskQueueState{rangeID: 1}, leaseErr @@ -505,7 +505,7 @@ func TestCheckIdleTaskQueue(t *testing.T) { // Active poll-er tlm = mustCreateTestTaskQueueManagerWithConfig(t, controller, tqCfg) tlm.Start() - tlm.pollerHistory.updatePollerInfo(pollerIdentity("test-poll"), &pollMetadata{}) + tlm.pollerHistory.updatePollerInfo("test-poll", &pollMetadata{}) require.Equal(t, 1, len(tlm.GetAllPollerInfo())) time.Sleep(1 * time.Second) require.Equal(t, common.DaemonStatusStarted, atomic.LoadInt32(&tlm.status)) diff --git a/service/matching/task_reader.go b/service/matching/task_reader.go index 3fa89120e83..cfa961c72ac 100644 --- a/service/matching/task_reader.go +++ b/service/matching/task_reader.go @@ -182,30 +182,30 @@ Loop: return nil case <-tr.notifyC: - tasks, readLevel, isReadBatchDone, err := tr.getTaskBatch(ctx) + batch, err := tr.getTaskBatch(ctx) tr.tlMgr.signalIfFatal(err) if err != nil { // TODO: Should we ever stop retrying on db errors? if common.IsResourceExhausted(err) { - tr.backoff(taskReaderThrottleRetryDelay) + tr.reEnqueueAfterDelay(taskReaderThrottleRetryDelay) } else { - tr.backoff(tr.retrier.NextBackOff()) + tr.reEnqueueAfterDelay(tr.retrier.NextBackOff()) } continue Loop } tr.retrier.Reset() - if len(tasks) == 0 { - tr.tlMgr.taskAckManager.setReadLevelAfterGap(readLevel) - if !isReadBatchDone { + if len(batch.tasks) == 0 { + tr.tlMgr.taskAckManager.setReadLevelAfterGap(batch.readLevel) + if !batch.isReadBatchDone { tr.Signal() } continue Loop } - // only error here is due to context cancelation which we also + // only error here is due to context cancellation which we also // handle above - _ = tr.addTasksToBuffer(ctx, tasks) + _ = tr.addTasksToBuffer(ctx, batch.tasks) // There maybe more tasks. We yield now, but signal pump to check again later. tr.Signal() @@ -236,10 +236,16 @@ func (tr *taskReader) getTaskBatchWithRange( return response.Tasks, err } +type getTasksBatchResponse struct { + tasks []*persistencespb.AllocatedTaskInfo + readLevel int64 + isReadBatchDone bool +} + // Returns a batch of tasks from persistence starting form current read level. 
// Also return a number that can be used to update readLevel // Also return a bool to indicate whether read is finished -func (tr *taskReader) getTaskBatch(ctx context.Context) ([]*persistencespb.AllocatedTaskInfo, int64, bool, error) { +func (tr *taskReader) getTaskBatch(ctx context.Context) (*getTasksBatchResponse, error) { var tasks []*persistencespb.AllocatedTaskInfo readLevel := tr.tlMgr.taskAckManager.getReadLevel() maxReadLevel := tr.tlMgr.taskWriter.GetMaxReadLevel() @@ -252,15 +258,23 @@ func (tr *taskReader) getTaskBatch(ctx context.Context) ([]*persistencespb.Alloc } tasks, err := tr.getTaskBatchWithRange(ctx, readLevel, upper) if err != nil { - return nil, readLevel, true, err + return nil, err } // return as long as it grabs any tasks if len(tasks) > 0 { - return tasks, upper, true, nil + return &getTasksBatchResponse{ + tasks: tasks, + readLevel: upper, + isReadBatchDone: true, + }, nil } readLevel = upper } - return tasks, readLevel, readLevel == maxReadLevel, nil // caller will update readLevel when no task grabbed + return &getTasksBatchResponse{ + tasks: tasks, + readLevel: readLevel, + isReadBatchDone: readLevel == maxReadLevel, + }, nil // caller will update readLevel when no task grabbed } func (tr *taskReader) addTasksToBuffer( @@ -316,7 +330,7 @@ func (tr *taskReader) emitTaskLagMetric(ackLevel int64) { tr.taggedMetricsHandler().Gauge(metrics.TaskLagPerTaskQueueGauge.GetMetricName()).Record(float64(maxReadLevel - ackLevel)) } -func (tr *taskReader) backoff(duration time.Duration) { +func (tr *taskReader) reEnqueueAfterDelay(duration time.Duration) { tr.backoffTimerLock.Lock() defer tr.backoffTimerLock.Unlock() diff --git a/service/matching/task_writer.go b/service/matching/task_writer.go index 3c1cc823000..3588deb59ce 100644 --- a/service/matching/task_writer.go +++ b/service/matching/task_writer.go @@ -26,6 +26,7 @@ package matching import ( "context" + "errors" "fmt" "sync/atomic" "time" @@ -76,10 +77,13 @@ type ( } ) -// errShutdown indicates that the task queue is shutting down -var errShutdown = &persistence.ConditionFailedError{Msg: "task queue shutting down"} +var ( + // errShutdown indicates that the task queue is shutting down + errShutdown = &persistence.ConditionFailedError{Msg: "task queue shutting down"} + errNonContiguousBlocks = errors.New("previous block end is not equal to current block") -var noTaskIDs = taskIDBlock{start: 1, end: 0} + noTaskIDs = taskIDBlock{start: 1, end: 0} +) func newTaskWriter( tlMgr *taskQueueManagerImpl, @@ -310,7 +314,12 @@ func (w *taskWriter) allocTaskIDBlock(ctx context.Context, prevBlockEnd int64) ( currBlock := rangeIDToTaskIDBlock(w.idAlloc.RangeID(), w.config.RangeSize) if currBlock.end != prevBlockEnd { return taskIDBlock{}, - fmt.Errorf("allocTaskIDBlock: invalid state: prevBlockEnd:%v != currTaskIDBlock:%+v", prevBlockEnd, currBlock) + fmt.Errorf( + "%w: allocTaskIDBlock: invalid state: prevBlockEnd:%v != currTaskIDBlock:%+v", + errNonContiguousBlocks, + prevBlockEnd, + currBlock, + ) } state, err := w.renewLeaseWithRetry(ctx, persistenceOperationRetryPolicy, common.IsPersistenceTransientError) if err != nil { diff --git a/service/matching/taskqueue.go b/service/matching/taskqueue.go index 79b9f2c80e5..ee2cb380e92 100644 --- a/service/matching/taskqueue.go +++ b/service/matching/taskqueue.go @@ -42,7 +42,7 @@ type ( } ) -// newTaskQueueID returns taskQueueID which uniquely identfies as task queue +// newTaskQueueID returns taskQueueID which uniquely identifies as task queue func newTaskQueueID(namespaceID 
namespace.ID, taskQueueName string, taskType enumspb.TaskQueueType) (*taskQueueID, error) { return newTaskQueueIDWithPartition(namespaceID, taskQueueName, taskType, -1) } diff --git a/service/matching/version_sets.go b/service/matching/version_sets.go index afc91be1831..387eda5440a 100644 --- a/service/matching/version_sets.go +++ b/service/matching/version_sets.go @@ -386,7 +386,7 @@ func checkVersionForStickyPoll(data *persistencespb.VersioningData, caps *common // Note data may be nil here, findVersion will return -1 then. setIdx, indexInSet := worker_versioning.FindBuildId(data, caps.BuildId) if setIdx < 0 { - // A poller is using a build ID but we don't know about that build ID. See comments in + // A poller is using a build ID, but we don't know about that build ID. See comments in // lookupVersionSetForPoll. If we consider it the default for its set, then we should // leave it on the sticky queue here. return nil @@ -444,7 +444,7 @@ func checkVersionForStickyAdd(data *persistencespb.VersioningData, buildId strin // Note data may be nil here, findVersion will return -1 then. setIdx, indexInSet := worker_versioning.FindBuildId(data, buildId) if setIdx < 0 { - // A poller is using a build ID but we don't know about that build ID. See comments in + // A poller is using a build ID, but we don't know about that build ID. See comments in // lookupVersionSetForAdd. If we consider it the default for its set, then we should // leave it on the sticky queue here. return nil @@ -457,7 +457,7 @@ func checkVersionForStickyAdd(data *persistencespb.VersioningData, buildId strin } // getSetID returns an arbitrary but consistent member of the set. -// We want Add and Poll requests for the same set to converge on a single id so we can match +// We want Add and Poll requests for the same set to converge on a single id, so we can match // them, but we don't have a single id for a set in the general case: in rare cases we may have // multiple ids (due to failovers). We can do this by picking an arbitrary id in the set, e.g. // the first. If the versioning data changes in any way, we'll re-resolve the set id, so this From 2fab90b0f062dfc96f22025226324df346fe8964 Mon Sep 17 00:00:00 2001 From: Yu Xia Date: Fri, 14 Jul 2023 09:45:37 -0700 Subject: [PATCH 2/3] Drop standby tasks in standby task executor (#4626) **What changed?** 1. Revert "Drop task when namespace is not on the cluster (#4444)" 2. Add logic to drop standby task in transfer / timer standby task executor (see the sketch below).
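For reference, the check added to both standby executors boils down to the following. This is a simplified sketch mirroring the hunks in `timer_queue_standby_task_executor.go` and `transfer_queue_standby_task_executor.go` further down in this patch; `task` stands in for the concrete `timerTask` / `taskInfo` arguments, and the surrounding method bodies are omitted:

```go
// Look up the namespace record for the task and drop the task if this
// cluster is not in the namespace's cluster list.
nsRecord, err := t.shard.GetNamespaceRegistry().GetNamespaceByID(namespace.ID(task.GetNamespaceID()))
if err != nil {
	return err
}
if !nsRecord.IsOnCluster(t.clusterName) {
	// discard standby tasks for namespaces that are not on this cluster
	return consts.ErrTaskDiscarded
}
```

This replaces the cluster-membership discard that previously lived in the generic `executable.go` path (reverted in this change), so the decision is now made only by the standby task executors.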
**Why?** **How did you test it?** **Potential risks** **Is hotfix candidate?** --- common/namespace/registry.go | 1 - .../archival_queue_task_executor_test.go | 1 - service/history/queues/executable.go | 17 ++++------------- service/history/queues/executable_test.go | 14 -------------- .../timer_queue_standby_task_executor.go | 9 +++++++++ .../transfer_queue_standby_task_executor.go | 9 +++++++++ 6 files changed, 22 insertions(+), 29 deletions(-) diff --git a/common/namespace/registry.go b/common/namespace/registry.go index bc2499b64c9..213a5fbd855 100644 --- a/common/namespace/registry.go +++ b/common/namespace/registry.go @@ -33,7 +33,6 @@ import ( "time" "go.temporal.io/api/serviceerror" - "go.temporal.io/server/common" "go.temporal.io/server/common/cache" "go.temporal.io/server/common/clock" diff --git a/service/history/archival_queue_task_executor_test.go b/service/history/archival_queue_task_executor_test.go index 76c3dd71e4b..9d3ab1561c2 100644 --- a/service/history/archival_queue_task_executor_test.go +++ b/service/history/archival_queue_task_executor_test.go @@ -354,7 +354,6 @@ func TestArchivalQueueTaskExecutor(t *testing.T) { shardContext.EXPECT().GetConfig().Return(cfg).AnyTimes() mockMetadata := cluster.NewMockMetadata(p.Controller) mockMetadata.EXPECT().IsGlobalNamespaceEnabled().Return(true).AnyTimes() - mockMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() shardContext.EXPECT().GetClusterMetadata().Return(mockMetadata).AnyTimes() shardID := int32(1) diff --git a/service/history/queues/executable.go b/service/history/queues/executable.go index 84a287bee7e..2f9e146174d 100644 --- a/service/history/queues/executable.go +++ b/service/history/queues/executable.go @@ -177,18 +177,15 @@ func (e *executableImpl) Execute() (retErr error) { e.Unlock() return nil } - var namespaceName string - ns, err := e.namespaceRegistry.GetNamespaceByID(namespace.ID(e.GetNamespaceID())) - if err == nil { - namespaceName = ns.Name().String() - } + + ns, _ := e.namespaceRegistry.GetNamespaceName(namespace.ID(e.GetNamespaceID())) var callerInfo headers.CallerInfo switch e.priority { case ctasks.PriorityHigh: - callerInfo = headers.NewBackgroundCallerInfo(namespaceName) + callerInfo = headers.NewBackgroundCallerInfo(ns.String()) default: // priority low or unknown - callerInfo = headers.NewPreemptableCallerInfo(namespaceName) + callerInfo = headers.NewPreemptableCallerInfo(ns.String()) } ctx := headers.SetCallerInfo( metrics.AddMetricsContext(context.Background()), @@ -232,12 +229,6 @@ func (e *executableImpl) Execute() (retErr error) { // Not doing it here as for certain errors latency for the attempt should not be counted }() - if ns != nil && !ns.IsOnCluster(e.clusterMetadata.GetCurrentClusterName()) { - // Discard task if the namespace is not on the current cluster. - e.taggedMetricsHandler = e.metricsHandler.WithTags(e.estimateTaskMetricTag()...) - return consts.ErrTaskDiscarded - } - metricsTags, isActive, err := e.executor.Execute(ctx, e) e.taggedMetricsHandler = e.metricsHandler.WithTags(metricsTags...) 
diff --git a/service/history/queues/executable_test.go b/service/history/queues/executable_test.go index 122a9f6ca17..7ea1dbd099c 100644 --- a/service/history/queues/executable_test.go +++ b/service/history/queues/executable_test.go @@ -37,7 +37,6 @@ import ( enumspb "go.temporal.io/api/enums/v1" "go.temporal.io/api/serviceerror" - persistencepb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/common/clock" "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/definition" @@ -271,19 +270,6 @@ func (s *executableSuite) TestExecute_CallerInfo() { s.NoError(executable.Execute()) } -func (s *executableSuite) TestExecute_DiscardTask() { - executable := s.newTestExecutable() - registry := namespace.NewMockRegistry(s.controller) - executable.(*executableImpl).namespaceRegistry = registry - ns := namespace.NewGlobalNamespaceForTest(nil, nil, &persistencepb.NamespaceReplicationConfig{ - ActiveClusterName: "nonCurrentCluster", - Clusters: []string{"nonCurrentCluster"}, - }, 1) - - registry.EXPECT().GetNamespaceByID(gomock.Any()).Return(ns, nil).Times(2) - s.ErrorIs(executable.Execute(), consts.ErrTaskDiscarded) -} - func (s *executableSuite) TestExecuteHandleErr_ResetAttempt() { executable := s.newTestExecutable() s.mockExecutor.EXPECT().Execute(gomock.Any(), executable).Return(nil, true, errors.New("some random error")) diff --git a/service/history/timer_queue_standby_task_executor.go b/service/history/timer_queue_standby_task_executor.go index 16b6888492c..647db8dbf4e 100644 --- a/service/history/timer_queue_standby_task_executor.go +++ b/service/history/timer_queue_standby_task_executor.go @@ -435,6 +435,15 @@ func (t *timerQueueStandbyTaskExecutor) processTimer( ctx, cancel := context.WithTimeout(ctx, taskTimeout) defer cancel() + nsRecord, err := t.shard.GetNamespaceRegistry().GetNamespaceByID(namespace.ID(timerTask.GetNamespaceID())) + if err != nil { + return err + } + if !nsRecord.IsOnCluster(t.clusterName) { + // discard standby tasks + return consts.ErrTaskDiscarded + } + executionContext, release, err := getWorkflowExecutionContextForTask(ctx, t.cache, timerTask) if err != nil { return err diff --git a/service/history/transfer_queue_standby_task_executor.go b/service/history/transfer_queue_standby_task_executor.go index 1f689cf5ae2..94b8ac0896a 100644 --- a/service/history/transfer_queue_standby_task_executor.go +++ b/service/history/transfer_queue_standby_task_executor.go @@ -503,6 +503,15 @@ func (t *transferQueueStandbyTaskExecutor) processTransfer( ctx, cancel := context.WithTimeout(ctx, taskTimeout) defer cancel() + nsRecord, err := t.shard.GetNamespaceRegistry().GetNamespaceByID(namespace.ID(taskInfo.GetNamespaceID())) + if err != nil { + return err + } + if !nsRecord.IsOnCluster(t.clusterName) { + // discard standby tasks + return consts.ErrTaskDiscarded + } + weContext, release, err := getWorkflowExecutionContextForTask(ctx, t.cache, taskInfo) if err != nil { return err From a01a309af6999c42ed7b0e383903764b0a91a90a Mon Sep 17 00:00:00 2001 From: Yu Xia Date: Fri, 14 Jul 2023 09:51:54 -0700 Subject: [PATCH 3/3] Support cluster metadata customized tags (#4622) **What changed?** Support cluster metadata customized tags **Why?** Allow static tagging cluster **How did you test it?** Local testing **Potential risks** **Is hotfix candidate?** No --- common/cluster/metadata.go | 16 ++++++++++++--- common/cluster/metadata_test.go | 36 ++++++++++++++++++++++++++++++++- temporal/fx.go | 8 +++++++- 3 files changed, 55 insertions(+), 5 deletions(-) diff --git 
a/common/cluster/metadata.go b/common/cluster/metadata.go index 11e1ca93079..1db9593040d 100644 --- a/common/cluster/metadata.go +++ b/common/cluster/metadata.go @@ -35,6 +35,8 @@ import ( "sync/atomic" "time" + "golang.org/x/exp/maps" + "go.temporal.io/server/common" "go.temporal.io/server/common/collection" "go.temporal.io/server/common/dynamicconfig" @@ -98,6 +100,8 @@ type ( CurrentClusterName string `yaml:"currentClusterName"` // ClusterInformation contains all cluster names to corresponding information about that cluster ClusterInformation map[string]ClusterInformation `yaml:"clusterInformation"` + // Tag contains customized tag about the current cluster + Tags map[string]string `yaml:"tags"` } // ClusterInformation contains the information about each cluster which participated in cross DC @@ -107,8 +111,9 @@ type ( // Address indicate the remote service address(Host:Port). Host can be DNS name. RPCAddress string `yaml:"rpcAddress"` // Cluster ID allows to explicitly set the ID of the cluster. Optional. - ClusterID string `yaml:"-"` - ShardCount int32 `yaml:"-"` // Ignore this field when loading config. + ClusterID string `yaml:"-"` + ShardCount int32 `yaml:"-"` // Ignore this field when loading config. + Tags map[string]string `yaml:"-"` // Ignore this field. Use cluster.Config.Tags for customized tags. // private field to track cluster information updates version int64 } @@ -463,12 +468,14 @@ func (m *metadataImpl) refreshClusterMetadata(ctx context.Context) error { InitialFailoverVersion: newClusterInfo.InitialFailoverVersion, RPCAddress: newClusterInfo.RPCAddress, ShardCount: newClusterInfo.ShardCount, + Tags: newClusterInfo.Tags, version: newClusterInfo.version, } } else if newClusterInfo.version > oldClusterInfo.version { if newClusterInfo.Enabled == oldClusterInfo.Enabled && newClusterInfo.RPCAddress == oldClusterInfo.RPCAddress && - newClusterInfo.InitialFailoverVersion == oldClusterInfo.InitialFailoverVersion { + newClusterInfo.InitialFailoverVersion == oldClusterInfo.InitialFailoverVersion && + maps.Equal(newClusterInfo.Tags, oldClusterInfo.Tags) { // key cluster info does not change continue } @@ -478,6 +485,7 @@ func (m *metadataImpl) refreshClusterMetadata(ctx context.Context) error { InitialFailoverVersion: oldClusterInfo.InitialFailoverVersion, RPCAddress: oldClusterInfo.RPCAddress, ShardCount: oldClusterInfo.ShardCount, + Tags: oldClusterInfo.Tags, version: oldClusterInfo.version, } newEntries[clusterName] = &ClusterInformation{ @@ -485,6 +493,7 @@ func (m *metadataImpl) refreshClusterMetadata(ctx context.Context) error { InitialFailoverVersion: newClusterInfo.InitialFailoverVersion, RPCAddress: newClusterInfo.RPCAddress, ShardCount: newClusterInfo.ShardCount, + Tags: newClusterInfo.Tags, version: newClusterInfo.version, } } @@ -589,6 +598,7 @@ func (m *metadataImpl) listAllClusterMetadataFromDB( InitialFailoverVersion: getClusterResp.GetInitialFailoverVersion(), RPCAddress: getClusterResp.GetClusterAddress(), ShardCount: getClusterResp.GetHistoryShardCount(), + Tags: getClusterResp.GetTags(), version: getClusterResp.Version, } } diff --git a/common/cluster/metadata_test.go b/common/cluster/metadata_test.go index db2229e36c8..0a4a56be4df 100644 --- a/common/cluster/metadata_test.go +++ b/common/cluster/metadata_test.go @@ -53,6 +53,7 @@ type ( failoverVersionIncrement int64 clusterName string secondClusterName string + thirdClusterName string } ) @@ -77,6 +78,7 @@ func (s *metadataSuite) SetupTest() { s.failoverVersionIncrement = 100 s.clusterName = uuid.New() 
s.secondClusterName = uuid.New() + s.thirdClusterName = uuid.New() clusterInfo := map[string]ClusterInformation{ s.clusterName: { @@ -93,6 +95,13 @@ func (s *metadataSuite) SetupTest() { ShardCount: 2, version: 1, }, + s.thirdClusterName: { + Enabled: true, + InitialFailoverVersion: int64(5), + RPCAddress: uuid.New(), + ShardCount: 1, + version: 1, + }, } s.metadata = NewMetadata( s.isGlobalNamespaceEnabled, @@ -143,7 +152,7 @@ func (s *metadataSuite) Test_RegisterMetadataChangeCallback() { s.metadata.RegisterMetadataChangeCallback( s, func(oldClusterMetadata map[string]*ClusterInformation, newClusterMetadata map[string]*ClusterInformation) { - s.Equal(2, len(newClusterMetadata)) + s.Equal(3, len(newClusterMetadata)) }) s.metadata.UnRegisterMetadataChangeCallback(s) @@ -166,12 +175,20 @@ func (s *metadataSuite) Test_RefreshClusterMetadata_Success() { newMetadata, ok = newClusterMetadata[s.secondClusterName] s.True(ok) s.Nil(newMetadata) + + oldMetadata, ok = oldClusterMetadata[s.thirdClusterName] + s.True(ok) + s.NotNil(oldMetadata) + newMetadata, ok = newClusterMetadata[s.thirdClusterName] + s.True(ok) + s.NotNil(newMetadata) } s.mockClusterMetadataStore.EXPECT().ListClusterMetadata(gomock.Any(), gomock.Any()).Return( &persistence.ListClusterMetadataResponse{ ClusterMetadata: []*persistence.GetClusterMetadataResponse{ { + // No change and not include in callback ClusterMetadata: persistencespb.ClusterMetadata{ ClusterName: s.clusterName, IsConnectionEnabled: true, @@ -182,12 +199,26 @@ func (s *metadataSuite) Test_RefreshClusterMetadata_Success() { Version: 1, }, { + // Updated, included in callback + ClusterMetadata: persistencespb.ClusterMetadata{ + ClusterName: s.thirdClusterName, + IsConnectionEnabled: true, + InitialFailoverVersion: 1, + HistoryShardCount: 1, + ClusterAddress: uuid.New(), + Tags: map[string]string{"test": "test"}, + }, + Version: 2, + }, + { + // Newly added, included in callback ClusterMetadata: persistencespb.ClusterMetadata{ ClusterName: id, IsConnectionEnabled: true, InitialFailoverVersion: 2, HistoryShardCount: 2, ClusterAddress: uuid.New(), + Tags: map[string]string{"test": "test"}, }, Version: 2, }, @@ -195,6 +226,9 @@ func (s *metadataSuite) Test_RefreshClusterMetadata_Success() { }, nil) err := s.metadata.refreshClusterMetadata(context.Background()) s.NoError(err) + clusterInfo := s.metadata.GetAllClusterInfo() + s.Equal("test", clusterInfo[s.thirdClusterName].Tags["test"]) + s.Equal("test", clusterInfo[id].Tags["test"]) } func (s *metadataSuite) Test_ListAllClusterMetadataFromDB_Success() { diff --git a/temporal/fx.go b/temporal/fx.go index 2e8897a8e01..c0ea0a75f25 100644 --- a/temporal/fx.go +++ b/temporal/fx.go @@ -41,6 +41,7 @@ import ( "go.temporal.io/api/serviceerror" "go.uber.org/fx" "go.uber.org/fx/fxevent" + "golang.org/x/exp/maps" "google.golang.org/grpc" persistencespb "go.temporal.io/server/api/persistence/v1" @@ -719,6 +720,7 @@ func loadClusterInformationFromStore(ctx context.Context, svc *config.Config, cl InitialFailoverVersion: metadata.InitialFailoverVersion, RPCAddress: metadata.ClusterAddress, ShardCount: shardCount, + Tags: metadata.Tags, } if staticClusterMetadata, ok := svc.ClusterMetadata.ClusterInformation[metadata.ClusterName]; ok { if metadata.ClusterName != svc.ClusterMetadata.CurrentClusterName { @@ -770,6 +772,7 @@ func initCurrentClusterMetadataRecord( IsConnectionEnabled: currentClusterInfo.Enabled, UseClusterIdMembership: true, // Enable this for new cluster after 1.19. This is to prevent two clusters join into one ring. 
IndexSearchAttributes: initialIndexSearchAttributes, + Tags: svc.ClusterMetadata.Tags, }, }) if err != nil { @@ -804,7 +807,10 @@ func updateCurrentClusterMetadataRecord( currentClusterDBRecord.ClusterAddress = currentCLusterInfo.RPCAddress updateDBRecord = true } - // TODO: Add cluster tags + if !maps.Equal(currentClusterDBRecord.Tags, svc.ClusterMetadata.Tags) { + currentClusterDBRecord.Tags = svc.ClusterMetadata.Tags + updateDBRecord = true + } if !updateDBRecord { return nil