diff --git a/service/matching/taskReader.go b/service/matching/taskReader.go index 5a427f0d8b5a..9e4dc20254a3 100644 --- a/service/matching/taskReader.go +++ b/service/matching/taskReader.go @@ -65,7 +65,7 @@ func newTaskReader(tlMgr *taskQueueManagerImpl) *taskReader { return &taskReader{ status: common.DaemonStatusInitialized, tlMgr: tlMgr, - taskValidator: newTaskValidator(tlMgr.newIOContext, tlMgr.engine.historyClient, tlMgr.metricsHandler), + taskValidator: newTaskValidator(tlMgr.newIOContext, tlMgr.clusterMeta, tlMgr.namespaceRegistry, tlMgr.engine.historyClient, tlMgr.metricsHandler), notifyC: make(chan struct{}, 1), // we always dequeue the head of the buffer and try to dispatch it to a poller // so allocate one less than desired target buffer size diff --git a/service/matching/task_validation.go b/service/matching/task_validation.go index 16b488209092..475e056ddae1 100644 --- a/service/matching/task_validation.go +++ b/service/matching/task_validation.go @@ -34,7 +34,9 @@ import ( "go.temporal.io/server/api/historyservice/v1" persistencespb "go.temporal.io/server/api/persistence/v1" + "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/metrics" + "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/primitives/timestamp" ) @@ -57,9 +59,11 @@ type ( } taskValidatorImpl struct { - newIOContextFn func() (context.Context, context.CancelFunc) - historyClient historyservice.HistoryServiceClient - metricsHandler metrics.Handler + newIOContextFn func() (context.Context, context.CancelFunc) + clusterMetadata cluster.Metadata + namespaceCache namespace.Registry + historyClient historyservice.HistoryServiceClient + metricsHandler metrics.Handler lastValidatedTaskInfo taskValidationInfo } @@ -67,13 +71,17 @@ type ( func newTaskValidator( newIOContextFn func() (context.Context, context.CancelFunc), + clusterMetadata cluster.Metadata, + namespaceCache namespace.Registry, historyClient historyservice.HistoryServiceClient, metricsHandler metrics.Handler, ) *taskValidatorImpl { return &taskValidatorImpl{ - newIOContextFn: newIOContextFn, - historyClient: historyClient, - metricsHandler: metricsHandler, + newIOContextFn: newIOContextFn, + clusterMetadata: clusterMetadata, + namespaceCache: namespaceCache, + historyClient: historyClient, + metricsHandler: metricsHandler, } } @@ -104,6 +112,22 @@ func (v *taskValidatorImpl) maybeValidate( // preValidate track a task and return if validation should be done func (v *taskValidatorImpl) preValidate( task *persistencespb.AllocatedTaskInfo, +) bool { + namespaceID := task.Data.NamespaceId + namespaceEntry, err := v.namespaceCache.GetNamespaceByID(namespace.ID(namespaceID)) + if err != nil { + // if cannot find the namespace entry, treat task as active + return v.preValidateActive(task) + } + if v.clusterMetadata.GetCurrentClusterName() == namespaceEntry.ActiveClusterName() { + return v.preValidateActive(task) + } + return v.preValidatePassive(task) +} + +// preValidateActive track a task and return if validation should be done, if namespace is active +func (v *taskValidatorImpl) preValidateActive( + task *persistencespb.AllocatedTaskInfo, ) bool { if v.lastValidatedTaskInfo.taskID != task.TaskId { // first time seen the task, caller should try to dispatch first @@ -125,6 +149,29 @@ func (v *taskValidatorImpl) preValidate( return time.Since(v.lastValidatedTaskInfo.validationTime) > taskReaderValidationThreshold } +// preValidatePassive track a task and return if validation should be done, if namespace is passive +func (v *taskValidatorImpl) preValidatePassive( + task *persistencespb.AllocatedTaskInfo, +) bool { + if v.lastValidatedTaskInfo.taskID != task.TaskId { + // first time seen the task, make a decision based on task creation time + if task.Data.CreateTime != nil { + v.lastValidatedTaskInfo = taskValidationInfo{ + taskID: task.TaskId, + validationTime: *task.Data.CreateTime, // task is valid when created + } + } else { + v.lastValidatedTaskInfo = taskValidationInfo{ + taskID: task.TaskId, + validationTime: time.Now().UTC(), // if no creation time specified, use now + } + } + } + + // this task has been validated before + return time.Since(v.lastValidatedTaskInfo.validationTime) > taskReaderValidationThreshold +} + // postValidate update tracked task info func (v *taskValidatorImpl) postValidate( task *persistencespb.AllocatedTaskInfo, diff --git a/service/matching/task_validation_test.go b/service/matching/task_validation_test.go index 84d0bd315f2d..fc3cb3b6dfdb 100644 --- a/service/matching/task_validation_test.go +++ b/service/matching/task_validation_test.go @@ -41,7 +41,9 @@ import ( "go.temporal.io/server/api/historyservice/v1" "go.temporal.io/server/api/historyservicemock/v1" persistencespb "go.temporal.io/server/api/persistence/v1" + "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/metrics" + "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/primitives/timestamp" ) @@ -50,8 +52,10 @@ type ( suite.Suite *require.Assertions - controller *gomock.Controller - historyClient *historyservicemock.MockHistoryServiceClient + controller *gomock.Controller + clusterMetadata *cluster.MockMetadata + historyClient *historyservicemock.MockHistoryServiceClient + namespaceCache *namespace.MockRegistry namespaceID string workflowID string @@ -72,7 +76,9 @@ func (s *taskValidatorSuite) SetupTest() { s.Assertions = require.New(s.T()) s.controller = gomock.NewController(s.T()) + s.clusterMetadata = cluster.NewMockMetadata(s.controller) s.historyClient = historyservicemock.NewMockHistoryServiceClient(s.controller) + s.namespaceCache = namespace.NewMockRegistry(s.controller) s.namespaceID = uuid.New().String() s.workflowID = uuid.New().String() @@ -90,21 +96,21 @@ func (s *taskValidatorSuite) SetupTest() { s.taskValidator = newTaskValidator(func() (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), 4*time.Second) - }, s.historyClient, metrics.NoopMetricsHandler) + }, s.clusterMetadata, s.namespaceCache, s.historyClient, metrics.NoopMetricsHandler) } func (s *taskValidatorSuite) TeardownTest() { s.controller.Finish() } -func (s *taskValidatorSuite) TestPreValidate_NewTask_Skip_WithCreationTIme() { +func (s *taskValidatorSuite) TestPreValidateActive_NewTask_Skip_WithCreationTime() { s.taskValidator.lastValidatedTaskInfo = taskValidationInfo{ taskID: s.task.TaskId - 1, validationTime: time.Unix(0, rand.Int63()), } s.task.Data.CreateTime = timestamp.TimePtr(time.Unix(0, rand.Int63())) - shouldValidate := s.taskValidator.preValidate(s.task) + shouldValidate := s.taskValidator.preValidateActive(s.task) s.False(shouldValidate) s.Equal(taskValidationInfo{ taskID: s.task.TaskId, @@ -112,36 +118,99 @@ func (s *taskValidatorSuite) TestPreValidate_NewTask_Skip_WithCreationTIme() { }, s.taskValidator.lastValidatedTaskInfo) } -func (s *taskValidatorSuite) TestPreValidate_NewTask_Skip_WithoutCreationTIme() { +func (s *taskValidatorSuite) TestPreValidateActive_NewTask_Skip_WithoutCreationTime() { s.taskValidator.lastValidatedTaskInfo = taskValidationInfo{ taskID: s.task.TaskId - 1, validationTime: time.Unix(0, rand.Int63()), } s.task.Data.CreateTime = nil - shouldValidate := s.taskValidator.preValidate(s.task) + shouldValidate := s.taskValidator.preValidateActive(s.task) s.False(shouldValidate) s.Equal(s.task.TaskId, s.taskValidator.lastValidatedTaskInfo.taskID) s.True(time.Now().Sub(s.taskValidator.lastValidatedTaskInfo.validationTime) < time.Second) } -func (s *taskValidatorSuite) TestPreValidate_ExistingTask_Validate() { +func (s *taskValidatorSuite) TestPreValidateActive_ExistingTask_Validate() { s.taskValidator.lastValidatedTaskInfo = taskValidationInfo{ taskID: s.task.TaskId, - validationTime: time.Now().Add(-2 * taskReaderValidationThreshold), + validationTime: time.Now().Add(-taskReaderValidationThreshold * 2), } - shouldValidate := s.taskValidator.preValidate(s.task) + shouldValidate := s.taskValidator.preValidateActive(s.task) s.True(shouldValidate) } -func (s *taskValidatorSuite) TestPreValidate_ExistingTask_Skip() { +func (s *taskValidatorSuite) TestPreValidateActive_ExistingTask_Skip() { s.taskValidator.lastValidatedTaskInfo = taskValidationInfo{ taskID: s.task.TaskId, - validationTime: time.Now().Add(2 * taskReaderValidationThreshold), + validationTime: time.Now().Add(taskReaderValidationThreshold * 2), } - shouldValidate := s.taskValidator.preValidate(s.task) + shouldValidate := s.taskValidator.preValidateActive(s.task) + s.False(shouldValidate) +} + +func (s *taskValidatorSuite) TestPreValidatePassive_NewTask_Skip_WithCreationTime() { + s.taskValidator.lastValidatedTaskInfo = taskValidationInfo{ + taskID: s.task.TaskId - 1, + validationTime: time.Unix(0, rand.Int63()), + } + s.task.Data.CreateTime = timestamp.TimePtr(time.Now().Add(-taskReaderValidationThreshold / 2)) + + shouldValidate := s.taskValidator.preValidatePassive(s.task) + s.False(shouldValidate) + s.Equal(taskValidationInfo{ + taskID: s.task.TaskId, + validationTime: *s.task.Data.CreateTime, + }, s.taskValidator.lastValidatedTaskInfo) +} + +func (s *taskValidatorSuite) TestPreValidatePassive_NewTask_Validate_WithCreationTime() { + s.taskValidator.lastValidatedTaskInfo = taskValidationInfo{ + taskID: s.task.TaskId - 1, + validationTime: time.Unix(0, rand.Int63()), + } + s.task.Data.CreateTime = timestamp.TimePtr(time.Now().Add(-taskReaderValidationThreshold * 2)) + + shouldValidate := s.taskValidator.preValidatePassive(s.task) + s.True(shouldValidate) + s.Equal(taskValidationInfo{ + taskID: s.task.TaskId, + validationTime: *s.task.Data.CreateTime, + }, s.taskValidator.lastValidatedTaskInfo) +} + +func (s *taskValidatorSuite) TestPreValidatePassive_NewTask_Skip_WithoutCreationTime() { + s.taskValidator.lastValidatedTaskInfo = taskValidationInfo{ + taskID: s.task.TaskId - 1, + validationTime: time.Unix(0, rand.Int63()), + } + s.task.Data.CreateTime = nil + + shouldValidate := s.taskValidator.preValidatePassive(s.task) + s.False(shouldValidate) + s.Equal(s.task.TaskId, s.taskValidator.lastValidatedTaskInfo.taskID) + s.True(time.Now().Sub(s.taskValidator.lastValidatedTaskInfo.validationTime) < time.Second) +} + +func (s *taskValidatorSuite) TestPreValidatePassive_ExistingTask_Validate() { + s.taskValidator.lastValidatedTaskInfo = taskValidationInfo{ + taskID: s.task.TaskId, + validationTime: time.Now().Add(-taskReaderValidationThreshold * 2), + } + + shouldValidate := s.taskValidator.preValidatePassive(s.task) + s.True(shouldValidate) +} + +func (s *taskValidatorSuite) TestPreValidatePassive_ExistingTask_Skip() { + s.taskValidator.lastValidatedTaskInfo = taskValidationInfo{ + taskID: s.task.TaskId, + validationTime: time.Now().Add(taskReaderValidationThreshold * 2), + } + + shouldValidate := s.taskValidator.preValidatePassive(s.task) s.False(shouldValidate) }