Skip to content

Commit

Permalink
Refactor task queue user data loading (#4487)
Browse files Browse the repository at this point in the history
  • Loading branch information
dnr committed Jun 13, 2023
1 parent 71748a0 commit 03ca8e8
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 69 deletions.
74 changes: 32 additions & 42 deletions service/matching/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ package matching

import (
"context"
"errors"
"fmt"
"sync"
"time"
Expand All @@ -36,6 +35,7 @@ import (

"go.temporal.io/server/api/matchingservice/v1"
persistencespb "go.temporal.io/server/api/persistence/v1"
"go.temporal.io/server/common"
"go.temporal.io/server/common/log"
"go.temporal.io/server/common/log/tag"
"go.temporal.io/server/common/namespace"
Expand Down Expand Up @@ -145,18 +145,16 @@ func (db *taskQueueDB) takeOverTaskQueueLocked(
response.TaskQueueInfo.Kind = db.taskQueueKind
response.TaskQueueInfo.ExpiryTime = db.expiryTime()
response.TaskQueueInfo.LastUpdateTime = timestamp.TimeNowPtrUtc()
_, err := db.store.UpdateTaskQueue(ctx, &persistence.UpdateTaskQueueRequest{
if _, err := db.store.UpdateTaskQueue(ctx, &persistence.UpdateTaskQueueRequest{
RangeID: response.RangeID + 1,
TaskQueueInfo: response.TaskQueueInfo,
PrevRangeID: response.RangeID,
})
if err != nil {
}); err != nil {
return err
}
db.ackLevel = response.TaskQueueInfo.AckLevel
db.rangeID = response.RangeID + 1
_, _, err = db.getUserDataLocked(ctx)
return err
return nil

case *serviceerror.NotFound:
if _, err := db.store.CreateTaskQueue(ctx, &persistence.CreateTaskQueueRequest{
Expand All @@ -166,8 +164,7 @@ func (db *taskQueueDB) takeOverTaskQueueLocked(
return err
}
db.rangeID = initialRangeID
_, _, err = db.getUserDataLocked(ctx)
return err
return nil

default:
return err
Expand Down Expand Up @@ -310,7 +307,7 @@ func (db *taskQueueDB) GetUserData(
) (*persistencespb.VersionedTaskQueueUserData, chan struct{}, error) {
db.Lock()
defer db.Unlock()
return db.getUserDataLocked(ctx)
return db.userData, db.userDataChanged, nil
}

func (db *taskQueueDB) setUserDataLocked(userData *persistencespb.VersionedTaskQueueUserData) {
Expand All @@ -319,32 +316,29 @@ func (db *taskQueueDB) setUserDataLocked(userData *persistencespb.VersionedTaskQ
db.userDataChanged = make(chan struct{})
}

// db.Lock() must be held before calling.
// Returns in-memory cached value or reads from DB and updates the cached value.
// Note: can return nil value with no error.
func (db *taskQueueDB) getUserDataLocked(
ctx context.Context,
) (*persistencespb.VersionedTaskQueueUserData, chan struct{}, error) {
if db.userData == nil {
if !db.DbStoresUserData() {
return nil, db.userDataChanged, nil
}
// Loads user data from db (called only on initialization of taskQueueManager).
func (db *taskQueueDB) loadUserData(ctx context.Context) error {
if !db.DbStoresUserData() {
return nil
}

response, err := db.store.GetTaskQueueUserData(ctx, &persistence.GetTaskQueueUserDataRequest{
NamespaceID: db.namespaceID.String(),
TaskQueue: db.taskQueue.BaseNameString(),
})
if err != nil {
var notFoundError *serviceerror.NotFound
if errors.As(err, &notFoundError) {
return nil, db.userDataChanged, nil
}
return nil, nil, err
}
db.setUserDataLocked(response.UserData)
response, err := db.store.GetTaskQueueUserData(ctx, &persistence.GetTaskQueueUserDataRequest{
NamespaceID: db.namespaceID.String(),
TaskQueue: db.taskQueue.BaseNameString(),
})
if common.IsNotFoundError(err) {
// not all task queues have user data
response, err = &persistence.GetTaskQueueUserDataResponse{}, nil
}
if err != nil {
return err
}

return db.userData, db.userDataChanged, nil
db.Lock()
defer db.Unlock()
db.setUserDataLocked(response.UserData)

return nil
}

// UpdateUserData allows callers to update user data (such as worker build IDs) for this task queue. The pointer passed
Expand All @@ -370,17 +364,13 @@ func (db *taskQueueDB) UpdateUserData(
db.Lock()
defer db.Unlock()

userData, _, err := db.getUserDataLocked(ctx)
if err != nil {
return nil, false, err
}

preUpdateData := userData.GetData()
preUpdateData := db.userData.GetData()
preUpdateVersion := db.userData.GetVersion()
if preUpdateData == nil {
preUpdateData = &persistencespb.TaskQueueUserData{}
}
if knownVersion > 0 && userData.GetVersion() != knownVersion {
return nil, false, serviceerror.NewFailedPrecondition(fmt.Sprintf("user data version mismatch: requested: %d, current: %d", knownVersion, userData.GetVersion()))
if knownVersion > 0 && preUpdateVersion != knownVersion {
return nil, false, serviceerror.NewFailedPrecondition(fmt.Sprintf("user data version mismatch: requested: %d, current: %d", knownVersion, preUpdateVersion))
}
updatedUserData, shouldReplicate, err := updateFn(preUpdateData)
if err != nil {
Expand All @@ -407,13 +397,13 @@ func (db *taskQueueDB) UpdateUserData(
_, err = db.matchingClient.UpdateTaskQueueUserData(ctx, &matchingservice.UpdateTaskQueueUserDataRequest{
NamespaceId: db.namespaceID.String(),
TaskQueue: db.cachedQueueInfo().Name,
UserData: &persistencespb.VersionedTaskQueueUserData{Version: userData.GetVersion(), Data: updatedUserData},
UserData: &persistencespb.VersionedTaskQueueUserData{Version: preUpdateVersion, Data: updatedUserData},
BuildIdsAdded: added,
BuildIdsRemoved: removed,
})
var updatedVersionedData *persistencespb.VersionedTaskQueueUserData
if err == nil {
updatedVersionedData = &persistencespb.VersionedTaskQueueUserData{Version: userData.GetVersion() + 1, Data: updatedUserData}
updatedVersionedData = &persistencespb.VersionedTaskQueueUserData{Version: preUpdateVersion + 1, Data: updatedUserData}
db.setUserDataLocked(updatedVersionedData)
}
return updatedVersionedData, shouldReplicate, err
Expand Down
26 changes: 20 additions & 6 deletions service/matching/matchingEngine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2420,12 +2420,13 @@ func (m *testTaskManager) getTaskQueueManager(id *taskQueueID) *testTaskQueueMan

type testTaskQueueManager struct {
sync.Mutex
rangeID int64
ackLevel int64
createTaskCount int
getTasksCount int
tasks *treemap.Map
userData *persistencespb.VersionedTaskQueueUserData
rangeID int64
ackLevel int64
createTaskCount int
getTasksCount int
getUserDataCount int
tasks *treemap.Map
userData *persistencespb.VersionedTaskQueueUserData
}

func (m *testTaskQueueManager) RangeID() int64 {
Expand Down Expand Up @@ -2680,6 +2681,14 @@ func (m *testTaskManager) getGetTasksCount(taskQueue *taskQueueID) int {
return tlm.getTasksCount
}

// getGetUserDataCount reports the current getUserDataCount of the test task
// queue manager for the given task queue, i.e. how many GetTaskQueueUserData
// calls have been recorded for it. The counter is read under the manager's
// lock.
func (m *testTaskManager) getGetUserDataCount(taskQueue *taskQueueID) int {
	mgr := m.getTaskQueueManager(taskQueue)

	mgr.Lock()
	count := mgr.getUserDataCount
	mgr.Unlock()

	return count
}

func (m *testTaskManager) String() string {
m.Lock()
defer m.Unlock()
Expand Down Expand Up @@ -2708,6 +2717,9 @@ func (m *testTaskManager) String() string {
// GetTaskQueueUserData implements persistence.TaskManager
func (m *testTaskManager) GetTaskQueueUserData(ctx context.Context, request *persistence.GetTaskQueueUserDataRequest) (*persistence.GetTaskQueueUserDataResponse, error) {
tlm := m.getTaskQueueManager(newTestTaskQueueID(namespace.ID(request.NamespaceID), request.TaskQueue, enumspb.TASK_QUEUE_TYPE_WORKFLOW))
tlm.Lock()
defer tlm.Unlock()
tlm.getUserDataCount++
return &persistence.GetTaskQueueUserDataResponse{
UserData: tlm.userData,
}, nil
Expand All @@ -2716,6 +2728,8 @@ func (m *testTaskManager) GetTaskQueueUserData(ctx context.Context, request *per
// UpdateTaskQueueUserData implements persistence.TaskManager
func (m *testTaskManager) UpdateTaskQueueUserData(ctx context.Context, request *persistence.UpdateTaskQueueUserDataRequest) error {
tlm := m.getTaskQueueManager(newTestTaskQueueID(namespace.ID(request.NamespaceID), request.TaskQueue, enumspb.TASK_QUEUE_TYPE_WORKFLOW))
tlm.Lock()
defer tlm.Unlock()
newData := *request.UserData
newData.Version++
tlm.userData = &newData
Expand Down
40 changes: 25 additions & 15 deletions service/matching/taskQueueManager.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,11 +309,7 @@ func (c *taskQueueManagerImpl) Start() {
c.liveness.Start()
c.taskWriter.Start()
c.taskReader.Start()
if c.shouldFetchUserData() {
c.goroGroup.Go(c.fetchUserDataLoop)
} else {
c.userDataInitialFetch.Set(struct{}{}, nil)
}
c.goroGroup.Go(c.fetchUserData)
c.logger.Info("", tag.LifeCycleStarted)
c.taggedMetricsHandler.Counter(metrics.TaskQueueStartedCounter.GetMetricName()).Record(1)
}
Expand Down Expand Up @@ -358,15 +354,6 @@ func (c *taskQueueManagerImpl) managesSpecificVersionSet() bool {
return c.taskQueueID.VersionSet() != ""
}

// shouldFetchUserData is the single place that decides whether this manager
// fetches user data from another task queue. The userDataInitialFetch future is
// set from two places, and both rely on this answer to agree on which of them
// sets it.
func (c *taskQueueManagerImpl) shouldFetchUserData() bool {
	// Loading user data is switched off by config: nothing to fetch.
	if !c.config.LoadUserData() {
		return false
	}
	// If the db stores it, then we definitely should not be fetching.
	if c.db.DbStoresUserData() {
		return false
	}
	// Additionally, we should not fetch for "versioned" tqms.
	return !c.managesSpecificVersionSet()
}

func (c *taskQueueManagerImpl) WaitUntilInitialized(ctx context.Context) error {
_, err := c.initializedError.Get(ctx)
if err != nil {
Expand Down Expand Up @@ -751,9 +738,32 @@ func (c *taskQueueManagerImpl) userDataFetchSource() (string, error) {
return parent.FullName(), nil
}

func (c *taskQueueManagerImpl) fetchUserDataLoop(ctx context.Context) error {
func (c *taskQueueManagerImpl) fetchUserData(ctx context.Context) error {
ctx = c.callerInfoContext(ctx)

if !c.config.LoadUserData() {
// if disabled, mark ready now
c.userDataInitialFetch.Set(struct{}{}, nil)
return nil
}
if c.managesSpecificVersionSet() {
// tqm for specific version set doesn't have its own user data
c.userDataInitialFetch.Set(struct{}{}, nil)
return nil
}
if c.db.DbStoresUserData() {
// root workflow partition "owns" user data, read it from db
err := c.db.loadUserData(ctx)
c.userDataInitialFetch.Set(struct{}{}, err)
if err != nil {
// We can't recover from here without starting over, so unload the whole task queue
c.unloadFromEngine()
}
return err
}

// otherwise fetch from parent partition

fetchSource, err := c.userDataFetchSource()
if err != nil {
if err == errMissingNormalQueueName {
Expand Down
34 changes: 34 additions & 0 deletions service/matching/taskQueueManager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,40 @@ func TestTQMLoadsUserDataFromPersistenceOnInit(t *testing.T) {
tq.Stop()
}

// TestTQMLoadsUserDataFromPersistenceOnInitOnlyOnceWhenNoData starts a task
// queue manager for partition 0 of a workflow task queue and checks that the
// test task manager records exactly one user data read during initialization,
// and that subsequent GetUserData calls return nil data without causing any
// further reads.
func TestTQMLoadsUserDataFromPersistenceOnInitOnlyOnceWhenNoData(t *testing.T) {
	controller := gomock.NewController(t)
	defer controller.Finish()
	ctx := context.Background()

	rootID, err := newTaskQueueIDWithPartition(defaultNamespaceId, defaultRootTqID, enumspb.TASK_QUEUE_TYPE_WORKFLOW, 0)
	require.NoError(t, err)
	opts := defaultTqmTestOpts(controller)
	opts.tqId = rootID

	tq := mustCreateTestTaskQueueManagerWithConfig(t, controller, opts)
	tm := tq.engine.taskManager.(*testTaskManager)
	// readCount is shorthand for the number of user data reads seen so far.
	readCount := func() int { return tm.getGetUserDataCount(rootID) }

	// Nothing has been read before the manager is started.
	require.Equal(t, 0, readCount())

	tq.Start()
	require.NoError(t, tq.WaitUntilInitialized(ctx))

	// Initialization accounts for the one and only read.
	require.Equal(t, 1, readCount())

	// Asking for user data twice must not go back to persistence, even though
	// there is no data to return.
	for i := 0; i < 2; i++ {
		userData, _, err := tq.GetUserData(ctx)
		require.NoError(t, err)
		require.Nil(t, userData)

		require.Equal(t, 1, readCount())
	}

	tq.Stop()
}

func TestTQMFetchesUserDataFromOnInit(t *testing.T) {
controller := gomock.NewController(t)
defer controller.Finish()
Expand Down
16 changes: 10 additions & 6 deletions tests/versioning_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1623,9 +1623,11 @@ func (s *versioningIntegSuite) TestDescribeWorkflowExecution() {
s.waitForChan(ctx, started1)

// describe and check build id
resp, err := s.sdkClient.DescribeWorkflowExecution(ctx, run.GetID(), "")
s.NoError(err)
s.Equal(v1, resp.WorkflowExecutionInfo.MostRecentWorkerVersionStamp.BuildId)
s.Eventually(func() bool {
resp, err := s.sdkClient.DescribeWorkflowExecution(ctx, run.GetID(), "")
s.NoError(err)
return v1 == resp.GetWorkflowExecutionInfo().GetMostRecentWorkerVersionStamp().GetBuildId()
}, 5*time.Second, 100*time.Millisecond)

// now register v11 as newer compatible with v1
s.addCompatibleBuildId(ctx, tq, v11, v1, false)
Expand All @@ -1649,9 +1651,11 @@ func (s *versioningIntegSuite) TestDescribeWorkflowExecution() {
s.NoError(s.sdkClient.SignalWorkflow(ctx, run.GetID(), "", "wait", nil))
s.waitForChan(ctx, started11)

resp, err = s.sdkClient.DescribeWorkflowExecution(ctx, run.GetID(), "")
s.NoError(err)
s.Equal(v11, resp.WorkflowExecutionInfo.MostRecentWorkerVersionStamp.BuildId)
s.Eventually(func() bool {
resp, err := s.sdkClient.DescribeWorkflowExecution(ctx, run.GetID(), "")
s.NoError(err)
return v11 == resp.GetWorkflowExecutionInfo().GetMostRecentWorkerVersionStamp().GetBuildId()
}, 5*time.Second, 100*time.Millisecond)

// unblock. it should complete
s.NoError(s.sdkClient.SignalWorkflow(ctx, run.GetID(), "", "wait", nil))
Expand Down

0 comments on commit 03ca8e8

Please sign in to comment.