diff --git a/common/dynamicconfig/constants.go b/common/dynamicconfig/constants.go index e66c91e2b6e..d491d3cffa6 100644 --- a/common/dynamicconfig/constants.go +++ b/common/dynamicconfig/constants.go @@ -384,6 +384,11 @@ const ( MatchingLongPollExpirationInterval = "matching.longPollExpirationInterval" // MatchingSyncMatchWaitDuration is to wait time for sync match MatchingSyncMatchWaitDuration = "matching.syncMatchWaitDuration" + // MatchingLoadUserData can be used to entirely disable loading user data from persistence (and the inter node RPCs + // that propoagate it). When turned off, features that rely on user data (e.g. worker versioning) will essentially + // be disabled. When disabled, matching will drop tasks for versioned workflows and activities to avoid breaking + // versioning semantics. Operator intervention will be required to reschedule the dropped tasks. + MatchingLoadUserData = "matching.loadUserData" // MatchingUpdateAckInterval is the interval for update ack MatchingUpdateAckInterval = "matching.updateAckInterval" // MatchingMaxTaskQueueIdleTime is the time after which an idle task queue will be unloaded diff --git a/service/matching/config.go b/service/matching/config.go index dd90ba26c3a..c62187a6021 100644 --- a/service/matching/config.go +++ b/service/matching/config.go @@ -84,6 +84,8 @@ type ( EnableReadFromSecondaryVisibility dynamicconfig.BoolPropertyFnWithNamespaceFilter VisibilityDisableOrderByClause dynamicconfig.BoolPropertyFnWithNamespaceFilter VisibilityEnableManualPagination dynamicconfig.BoolPropertyFnWithNamespaceFilter + + LoadUserData dynamicconfig.BoolPropertyFnWithTaskQueueInfoFilters } forwarderConfig struct { @@ -119,6 +121,11 @@ type ( AdminNamespaceToPartitionDispatchRate func() float64 // partition qps = AdminNamespaceTaskQueueToPartitionDispatchRate(namespace, task_queue) AdminNamespaceTaskQueueToPartitionDispatchRate func() float64 + + // If set to false, matching does not load user data from DB for root partitions or fetch it via RPC from the + // root. When disbled, features that rely on user data (e.g. worker versioning) will essentially be disabled. + // See the documentation for constants.MatchingLoadUserData for the implications on versioning. + LoadUserData func() bool } ) @@ -150,6 +157,7 @@ func NewConfig( PersistenceDynamicRateLimitingParams: dc.GetMapProperty(dynamicconfig.MatchingPersistenceDynamicRateLimitingParams, dynamicconfig.DefaultDynamicRateLimitingParams), SyncMatchWaitDuration: dc.GetDurationPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingSyncMatchWaitDuration, 200*time.Millisecond), TestDisableSyncMatch: dc.GetBoolProperty(dynamicconfig.TestMatchingDisableSyncMatch, false), + LoadUserData: dc.GetBoolPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingLoadUserData, true), RPS: dc.GetIntProperty(dynamicconfig.MatchingRPS, 1200), RangeSize: 100000, GetTasksBatchSize: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingGetTasksBatchSize, 1000), @@ -206,6 +214,9 @@ func newTaskQueueConfig(id *taskQueueID, config *Config, namespace namespace.Nam return config.SyncMatchWaitDuration(namespace.String(), taskQueueName, taskType) }, TestDisableSyncMatch: config.TestDisableSyncMatch, + LoadUserData: func() bool { + return config.LoadUserData(namespace.String(), taskQueueName, taskType) + }, LongPollExpirationInterval: func() time.Duration { return config.LongPollExpirationInterval(namespace.String(), taskQueueName, taskType) }, diff --git a/service/matching/matchingEngine.go b/service/matching/matchingEngine.go index d83eb4d721a..65925ff9478 100644 --- a/service/matching/matchingEngine.go +++ b/service/matching/matchingEngine.go @@ -320,6 +320,13 @@ func (e *matchingEngineImpl) AddWorkflowTask( return false, err } + shouldDrop, err := e.shouldDropTask(origTaskQueue, addRequest.VersionDirective) + if err != nil { + return false, err + } else if shouldDrop { + return true, nil + } + // We don't need the userDataChanged channel here because: // - if we sync match or sticky worker unavailable, we're done // - if we spool to db, we'll re-resolve when it comes out of the db @@ -380,6 +387,13 @@ func (e *matchingEngineImpl) AddActivityTask( return false, err } + shouldDrop, err := e.shouldDropTask(origTaskQueue, addRequest.VersionDirective) + if err != nil { + return false, err + } else if shouldDrop { + return true, nil + } + // We don't need the userDataChanged channel here because: // - if we sync match, we're done // - if we spool to db, we'll re-resolve when it comes out of the db @@ -434,6 +448,12 @@ func (e *matchingEngineImpl) DispatchSpooledTask( unversionedOrigTaskQueue := newTaskQueueIDWithVersionSet(origTaskQueue, "") // Redirect and re-resolve if we're blocked in matcher and user data changes. for { + shouldDrop, err := e.shouldDropTask(unversionedOrigTaskQueue, directive) + if err != nil { + return err + } else if shouldDrop { + return nil + } taskQueue, userDataChanged, err := e.redirectToVersionedQueueForAdd( ctx, unversionedOrigTaskQueue, directive, stickyInfo) if err != nil { @@ -669,6 +689,12 @@ func (e *matchingEngineImpl) QueryWorkflow( return nil, err } + shouldDrop, err := e.shouldDropTask(origTaskQueue, queryRequest.VersionDirective) + if err != nil { + return nil, err + } else if shouldDrop { + return nil, serviceerror.NewFailedPrecondition("Operations on versioned workflows are disabled") + } // We don't need the userDataChanged channel here because we either do this sync (local or remote) // or fail with a relatively short timeout. taskQueue, _, err := e.redirectToVersionedQueueForAdd(ctx, origTaskQueue, queryRequest.VersionDirective, stickyInfo) @@ -1432,6 +1458,26 @@ func (e *matchingEngineImpl) redirectToVersionedQueueForPoll( return newTaskQueueIDWithVersionSet(taskQueue, versionSet), nil } +// When user data loading is disabled, we intentionally drop tasks for versioned workflows to avoid breaking versioning +// semantics and dispatching tasks to the wrong workers. +func (e *matchingEngineImpl) shouldDropTask(taskQueue *taskQueueID, directive *taskqueuespb.TaskVersionDirective) (bool, error) { + isVersioned := false + switch directive.GetValue().(type) { + case *taskqueuespb.TaskVersionDirective_UseDefault, + *taskqueuespb.TaskVersionDirective_BuildId: + isVersioned = true + } + if !isVersioned { + return false, nil + } + namespaceEntry, err := e.namespaceRegistry.GetNamespaceByID(taskQueue.namespaceID) + if err != nil { + return false, err + } + shouldDrop := !e.config.LoadUserData(namespaceEntry.Name().String(), taskQueue.BaseNameString(), taskQueue.taskType) + return shouldDrop, nil +} + func (e *matchingEngineImpl) redirectToVersionedQueueForAdd( ctx context.Context, taskQueue *taskQueueID, diff --git a/service/matching/matchingEngine_test.go b/service/matching/matchingEngine_test.go index 649c5a4e233..e6d6f87c73b 100644 --- a/service/matching/matchingEngine_test.go +++ b/service/matching/matchingEngine_test.go @@ -49,11 +49,13 @@ import ( "go.temporal.io/api/workflowservice/v1" clockspb "go.temporal.io/server/api/clock/v1" + "go.temporal.io/server/api/enums/v1" "go.temporal.io/server/api/historyservice/v1" "go.temporal.io/server/api/historyservicemock/v1" "go.temporal.io/server/api/matchingservice/v1" "go.temporal.io/server/api/matchingservicemock/v1" persistencespb "go.temporal.io/server/api/persistence/v1" + "go.temporal.io/server/api/taskqueue/v1" tokenspb "go.temporal.io/server/api/token/v1" "go.temporal.io/server/common" "go.temporal.io/server/common/clock" @@ -2246,6 +2248,75 @@ func (s *matchingEngineSuite) TestUpdateUserData_FailsOnKnownVersionMismatch() { s.ErrorAs(err, &failedPreconditionError) } +func (s *matchingEngineSuite) TestAddWorkflowTask_ForVersionedWorkflows_SilentlyDroppedWhenDisablingLoadingUserData() { + namespaceId := uuid.New() + tq := taskqueuepb.TaskQueue{ + Name: "test", + Kind: enumspb.TASK_QUEUE_KIND_NORMAL, + } + s.matchingEngine.config.LoadUserData = func(string, string, enumspb.TaskQueueType) bool { return false } + + _, err := s.matchingEngine.AddWorkflowTask(context.Background(), &matchingservice.AddWorkflowTaskRequest{ + NamespaceId: namespaceId, + Execution: &commonpb.WorkflowExecution{ + WorkflowId: "test", + RunId: uuid.New(), + }, + TaskQueue: &tq, + ScheduledEventId: 7, + Source: enums.TASK_SOURCE_HISTORY, + VersionDirective: &taskqueue.TaskVersionDirective{ + Value: &taskqueue.TaskVersionDirective_UseDefault{UseDefault: &types.Empty{}}, + }, + }) + s.Require().NoError(err) +} + +func (s *matchingEngineSuite) TestAddActivityTask_ForVersionedWorkflows_SilentlyDroppedWhenDisablingLoadingUserData() { + namespaceId := uuid.New() + tq := taskqueuepb.TaskQueue{ + Name: "test", + Kind: enumspb.TASK_QUEUE_KIND_NORMAL, + } + s.matchingEngine.config.LoadUserData = func(string, string, enumspb.TaskQueueType) bool { return false } + + _, err := s.matchingEngine.AddActivityTask(context.Background(), &matchingservice.AddActivityTaskRequest{ + NamespaceId: namespaceId, + Execution: &commonpb.WorkflowExecution{ + WorkflowId: "test", + RunId: uuid.New(), + }, + TaskQueue: &tq, + ScheduledEventId: 7, + Source: enums.TASK_SOURCE_HISTORY, + VersionDirective: &taskqueue.TaskVersionDirective{ + Value: &taskqueue.TaskVersionDirective_UseDefault{UseDefault: &types.Empty{}}, + }, + }) + s.Require().NoError(err) +} + +func (s *matchingEngineSuite) TestDispatchSpooledTask_ForVersionedWorkflows_SilentlyDroppedWhenDisablingLoadingUserData() { + namespaceId := namespace.ID(uuid.New()) + tqId, err := newTaskQueueID(namespaceId, "foo", enumspb.TASK_QUEUE_TYPE_ACTIVITY) + s.Require().NoError(err) + s.matchingEngine.config.LoadUserData = func(string, string, enumspb.TaskQueueType) bool { return false } + + err = s.matchingEngine.DispatchSpooledTask(context.Background(), &internalTask{ + event: &genericTaskInfo{ + &persistencespb.AllocatedTaskInfo{ + Data: &persistencespb.TaskInfo{ + VersionDirective: &taskqueue.TaskVersionDirective{ + Value: &taskqueue.TaskVersionDirective_UseDefault{UseDefault: &types.Empty{}}, + }, + }, + }, + func(ati *persistencespb.AllocatedTaskInfo, err error) {}, + }, + }, tqId, stickyInfo{}) + s.Require().NoError(err) +} + func (s *matchingEngineSuite) setupRecordActivityTaskStartedMock(tlName string) { activityTypeName := "activity1" activityID := "activityId1" diff --git a/service/matching/taskQueueManager.go b/service/matching/taskQueueManager.go index 2a21146c4d0..284c1a500dc 100644 --- a/service/matching/taskQueueManager.go +++ b/service/matching/taskQueueManager.go @@ -307,6 +307,8 @@ func (c *taskQueueManagerImpl) Start() { c.taskReader.Start() if c.shouldFetchUserData() { c.goroGroup.Go(c.fetchUserDataLoop) + } else { + c.userDataInitialFetch.Set(struct{}{}, nil) } c.logger.Info("", tag.LifeCycleStarted) c.taggedMetricsHandler.Counter(metrics.TaskQueueStartedCounter.GetMetricName()).Record(1) @@ -358,7 +360,7 @@ func (c *taskQueueManagerImpl) managesSpecificVersionSet() bool { func (c *taskQueueManagerImpl) shouldFetchUserData() bool { // 1. If the db stores it, then we definitely should not be fetching. // 2. Additionally, we should not fetch for "versioned" tqms. - return !c.db.DbStoresUserData() && !c.managesSpecificVersionSet() + return c.config.LoadUserData() && !c.db.DbStoresUserData() && !c.managesSpecificVersionSet() } func (c *taskQueueManagerImpl) WaitUntilInitialized(ctx context.Context) error { @@ -491,11 +493,17 @@ func (c *taskQueueManagerImpl) GetUserData(ctx context.Context) (*persistencespb if c.managesSpecificVersionSet() { return nil, nil, errNoUserDataOnVersionedTQM } + if !c.config.LoadUserData() { + return nil, nil, nil + } return c.db.GetUserData(ctx) } // UpdateUserData updates user data for this task queue and replicates across clusters if necessary. func (c *taskQueueManagerImpl) UpdateUserData(ctx context.Context, options UserDataUpdateOptions, updateFn UserDataUpdateFunc) error { + if !c.config.LoadUserData() { + return serviceerror.NewFailedPrecondition("Task queue user data operations are disabled") + } newData, shouldReplicate, err := c.db.UpdateUserData(ctx, updateFn, options.KnownVersion, options.TaskQueueLimitPerBuildId) if err != nil { return err diff --git a/service/matching/taskQueueManager_test.go b/service/matching/taskQueueManager_test.go index c1bbf3ff4dd..ac54948dadc 100644 --- a/service/matching/taskQueueManager_test.go +++ b/service/matching/taskQueueManager_test.go @@ -858,3 +858,19 @@ func TestUpdateOnNonRootFails(t *testing.T) { require.Error(t, err) require.ErrorIs(t, err, errUserDataNoMutateNonRoot) } + +func TestDisableLoadUserData_NonRootDoesNotRequestUserDataFromRoot(t *testing.T) { + ctx := context.Background() + controller := gomock.NewController(t) + defer controller.Finish() + taskQueueId, err := newTaskQueueIDWithPartition(defaultNamespaceId, defaultRootTqID, enumspb.TASK_QUEUE_TYPE_WORKFLOW, 1) + require.NoError(t, err) + tqCfg := defaultTqmTestOpts(controller) + tqCfg.tqId = taskQueueId + mgr := mustCreateTestTaskQueueManagerWithConfig(t, controller, tqCfg) + tqCfg.matchingClientMock.EXPECT().GetTaskQueueUserData(gomock.Any(), gomock.Any()).Times(0) + mgr.config.LoadUserData = func() bool { return false } + mgr.Start() + err = mgr.WaitUntilInitialized(ctx) + require.NoError(t, err) +} diff --git a/service/matching/taskWriter.go b/service/matching/taskWriter.go index 91540029923..e92b4d2e97a 100644 --- a/service/matching/taskWriter.go +++ b/service/matching/taskWriter.go @@ -219,9 +219,6 @@ func (w *taskWriter) appendTasks( func (w *taskWriter) taskWriterLoop(ctx context.Context) error { err := w.initReadWriteState(ctx) w.tlMgr.initializedError.Set(struct{}{}, err) - if !w.tlMgr.shouldFetchUserData() { - w.tlMgr.userDataInitialFetch.Set(struct{}{}, err) - } if err != nil { // We can't recover from here without starting over, so unload the whole task queue w.tlMgr.unloadFromEngine() diff --git a/tests/versioning_test.go b/tests/versioning_test.go index de9519774fe..f6ae642596c 100644 --- a/tests/versioning_test.go +++ b/tests/versioning_test.go @@ -39,14 +39,14 @@ import ( enumspb "go.temporal.io/api/enums/v1" "go.temporal.io/api/serviceerror" - "go.temporal.io/api/workflowservice/v1" + workflowservice "go.temporal.io/api/workflowservice/v1" "go.temporal.io/sdk/activity" sdkclient "go.temporal.io/sdk/client" "go.temporal.io/sdk/temporal" "go.temporal.io/sdk/worker" "go.temporal.io/sdk/workflow" - "go.temporal.io/server/api/matchingservice/v1" + matchingservice "go.temporal.io/server/api/matchingservice/v1" persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/log/tag" @@ -1370,6 +1370,122 @@ func (s *versioningIntegSuite) dispatchCron() { s.GreaterOrEqual(runs2.Load(), int32(3)) } +func (s *versioningIntegSuite) TestDisableLoadUserData() { + tq := s.T().Name() + v1 := s.prefixed("v1") + v2 := s.prefixed("v2") + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // First insert some data (we'll try to read it below) + s.addNewDefaultBuildId(ctx, tq, v1) + + dc := s.testCluster.host.dcClient + defer dc.RemoveOverride(dynamicconfig.MatchingLoadUserData) + dc.OverrideValue(dynamicconfig.MatchingLoadUserData, false) + + // Verify update fails + _, err := s.engine.UpdateWorkerBuildIdCompatibility(ctx, &workflowservice.UpdateWorkerBuildIdCompatibilityRequest{ + Namespace: s.namespace, + TaskQueue: tq, + Operation: &workflowservice.UpdateWorkerBuildIdCompatibilityRequest_AddNewBuildIdInNewDefaultSet{ + AddNewBuildIdInNewDefaultSet: v2, + }, + }) + var failedPreconditionError *serviceerror.FailedPrecondition + s.Require().ErrorAs(err, &failedPreconditionError) + + s.unloadTaskQueue(ctx, tq) + + // Verify read returns empty + res, err := s.engine.GetWorkerBuildIdCompatibility(ctx, &workflowservice.GetWorkerBuildIdCompatibilityRequest{ + Namespace: s.namespace, + TaskQueue: tq, + }) + s.Require().NoError(err) + s.Require().Equal(0, len(res.GetMajorVersionSets())) +} + +func (s *versioningIntegSuite) TestWorkflowGetsStuckWhenDisablingLoadingUserData() { + tq := s.T().Name() + v1 := "v1" + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + s.addNewDefaultBuildId(ctx, tq, v1) + + dc := s.testCluster.host.dcClient + defer dc.RemoveOverride(dynamicconfig.MatchingLoadUserData) + dc.OverrideValue(dynamicconfig.MatchingLoadUserData, false) + + s.unloadTaskQueue(ctx, tq) + + var runs atomic.Int32 + wf := func(ctx workflow.Context) error { + runs.Add(1) + return nil + } + wrk := worker.New(s.sdkClient, tq, worker.Options{ + BuildID: s.prefixed("v1"), + UseBuildIDForVersioning: true, + MaxConcurrentWorkflowTaskPollers: numPollers, + }) + wrk.RegisterWorkflowWithOptions(wf, workflow.RegisterOptions{Name: "wf"}) + s.NoError(wrk.Start()) + defer wrk.Stop() + + run, err := s.sdkClient.ExecuteWorkflow(ctx, sdkclient.StartWorkflowOptions{ + TaskQueue: tq, + WorkflowExecutionTimeout: 5 * time.Second, + }, "wf") + s.Require().NoError(err) + err = run.Get(ctx, nil) + var timeoutError *temporal.TimeoutError + s.Require().ErrorAs(err, &timeoutError) + s.Require().Equal(int32(0), runs.Load()) +} + +func (s *versioningIntegSuite) TestWorkflowQueryTimesOutWhenDisablingLoadingUserData() { + tq := s.T().Name() + v1 := "v1" + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + s.addNewDefaultBuildId(ctx, tq, v1) + + var runs atomic.Int32 + wf := func(ctx workflow.Context) error { + return workflow.SetQueryHandler(ctx, "query", func() (string, error) { + runs.Add(1) + return "response", nil + }) + } + wrk := worker.New(s.sdkClient, tq, worker.Options{ + BuildID: s.prefixed("v1"), + UseBuildIDForVersioning: true, + MaxConcurrentWorkflowTaskPollers: numPollers, + }) + wrk.RegisterWorkflowWithOptions(wf, workflow.RegisterOptions{Name: "wf"}) + s.NoError(wrk.Start()) + defer wrk.Stop() + + run, err := s.sdkClient.ExecuteWorkflow(ctx, sdkclient.StartWorkflowOptions{ + TaskQueue: tq, + WorkflowExecutionTimeout: 5 * time.Second, + }, "wf") + s.Require().NoError(err) + + dc := s.testCluster.host.dcClient + defer dc.RemoveOverride(dynamicconfig.MatchingLoadUserData) + dc.OverrideValue(dynamicconfig.MatchingLoadUserData, false) + + s.unloadTaskQueue(ctx, tq) + + _, err = s.sdkClient.QueryWorkflow(ctx, run.GetID(), run.GetRunID(), "query") + var deadlineExceededError *serviceerror.DeadlineExceeded + s.Require().ErrorAs(err, &deadlineExceededError) + s.Require().Equal(int32(0), runs.Load()) +} + // Add a per test prefix to avoid hitting the namespace limit of mapped task queue per build id func (s *versioningIntegSuite) prefixed(buildId string) string { return fmt.Sprintf("t%x:%s", 0xffff&farm.Hash32([]byte(s.T().Name())), buildId) @@ -1454,6 +1570,15 @@ func (s *versioningIntegSuite) waitForChan(ctx context.Context, ch chan struct{} } } +func (s *versioningIntegSuite) unloadTaskQueue(ctx context.Context, tq string) { + _, err := s.testCluster.GetMatchingClient().ForceUnloadTaskQueue(ctx, &matchingservice.ForceUnloadTaskQueueRequest{ + NamespaceId: s.getNamespaceID(s.namespace), + TaskQueue: tq, + TaskQueueType: enumspb.TASK_QUEUE_TYPE_WORKFLOW, + }) + s.Require().NoError(err) +} + func containsBuildId(data *persistencespb.VersioningData, buildId string) bool { for _, set := range data.GetVersionSets() { for _, id := range set.BuildIds {