Skip to content

Commit

Permalink
New matching task tracker for task loading & tracking (#2263)
Browse files Browse the repository at this point in the history
* Add new matching task tracker
  * Add task iterator
  * Add task tracking functionality
  * UT
  • Loading branch information
wxing1292 committed Dec 11, 2021
1 parent 3de5687 commit 5099ff8
Show file tree
Hide file tree
Showing 15 changed files with 733 additions and 82 deletions.
10 changes: 5 additions & 5 deletions common/persistence/cassandra/matching_task_store.go
Expand Up @@ -220,7 +220,7 @@ func (d *MatchingTaskStore) DeleteTaskQueue(
request *p.DeleteTaskQueueRequest,
) error {
query := d.Session.Query(templateDeleteTaskQueueQuery,
request.TaskQueue.NamespaceID, request.TaskQueue.Name, request.TaskQueue.TaskType, rowTypeTaskQueue, taskQueueTaskID, request.RangeID)
request.TaskQueue.NamespaceID, request.TaskQueue.TaskQueueName, request.TaskQueue.TaskQueueType, rowTypeTaskQueue, taskQueueTaskID, request.RangeID)
previous := make(map[string]interface{})
applied, err := query.MapScanCAS(previous)
if err != nil {
Expand Down Expand Up @@ -323,8 +323,8 @@ func (d *MatchingTaskStore) GetTasks(
request.TaskQueue,
request.TaskType,
rowTypeTask,
request.MinTaskID,
request.MaxTaskID,
request.MinTaskIDExclusive,
request.MaxTaskIDInclusive,
)
iter := query.PageSize(request.PageSize).PageState(request.NextPageToken).Iter()

Expand Down Expand Up @@ -377,8 +377,8 @@ func (d *MatchingTaskStore) CompleteTask(
tli := request.TaskQueue
query := d.Session.Query(templateCompleteTaskQuery,
tli.NamespaceID,
tli.Name,
tli.TaskType,
tli.TaskQueueName,
tli.TaskQueueType,
rowTypeTask,
request.TaskID)

Expand Down
20 changes: 10 additions & 10 deletions common/persistence/dataInterfaces.go
Expand Up @@ -184,9 +184,9 @@ type (

// TaskQueueKey is the struct used to identity TaskQueues
TaskQueueKey struct {
NamespaceID string
Name string
TaskType enumspb.TaskQueueType
NamespaceID string
TaskQueueName string
TaskQueueType enumspb.TaskQueueType
}

// GetOrCreateShardRequest is used to get shard information, or supply
Expand Down Expand Up @@ -675,13 +675,13 @@ type (

// GetTasksRequest is used to retrieve tasks of a task queue
GetTasksRequest struct {
NamespaceID string
TaskQueue string
TaskType enumspb.TaskQueueType
MinTaskID int64 // exclusive
MaxTaskID int64 // inclusive
PageSize int
NextPageToken []byte
NamespaceID string
TaskQueue string
TaskType enumspb.TaskQueueType
MinTaskIDExclusive int64 // exclusive
MaxTaskIDInclusive int64 // inclusive
PageSize int
NextPageToken []byte
}

// GetTasksResponse is the response to GetTasksRequests
Expand Down
18 changes: 9 additions & 9 deletions common/persistence/persistence-tests/matchingPersistenceTest.go
Expand Up @@ -168,12 +168,12 @@ func (s *MatchingPersistenceSuite) TestGetTasksWithNoMaxReadLevel() {
for _, tc := range testCases {
s.Run(fmt.Sprintf("tc_%v_%v", tc.batchSz, tc.readLevel), func() {
response, err := s.TaskMgr.GetTasks(&p.GetTasksRequest{
NamespaceID: namespaceID,
TaskQueue: taskQueue,
TaskType: enumspb.TASK_QUEUE_TYPE_ACTIVITY,
PageSize: tc.batchSz,
MinTaskID: tc.readLevel,
MaxTaskID: math.MaxInt64,
NamespaceID: namespaceID,
TaskQueue: taskQueue,
TaskType: enumspb.TASK_QUEUE_TYPE_ACTIVITY,
PageSize: tc.batchSz,
MinTaskIDExclusive: tc.readLevel,
MaxTaskIDInclusive: math.MaxInt64,
})
s.NoError(err)
s.Equal(len(tc.taskIDs), len(response.Tasks), "wrong number of tasks")
Expand Down Expand Up @@ -395,9 +395,9 @@ func (s *MatchingPersistenceSuite) deleteAllTaskQueue() {
it := i.Data
err = s.TaskMgr.DeleteTaskQueue(&p.DeleteTaskQueueRequest{
TaskQueue: &p.TaskQueueKey{
NamespaceID: it.GetNamespaceId(),
Name: it.Name,
TaskType: it.TaskType,
NamespaceID: it.GetNamespaceId(),
TaskQueueName: it.Name,
TaskQueueType: it.TaskType,
},
RangeID: i.RangeID,
})
Expand Down
16 changes: 8 additions & 8 deletions common/persistence/persistence-tests/persistenceTestBase.go
Expand Up @@ -1226,11 +1226,11 @@ func (s *TestBase) CreateActivityTasks(namespaceID string, workflowExecution com
// GetTasks is a utility method to get tasks from persistence
func (s *TestBase) GetTasks(namespaceID string, taskQueue string, taskType enumspb.TaskQueueType, batchSize int) (*persistence.GetTasksResponse, error) {
response, err := s.TaskMgr.GetTasks(&persistence.GetTasksRequest{
NamespaceID: namespaceID,
TaskQueue: taskQueue,
TaskType: taskType,
PageSize: batchSize,
MaxTaskID: math.MaxInt64,
NamespaceID: namespaceID,
TaskQueue: taskQueue,
TaskType: taskType,
PageSize: batchSize,
MaxTaskIDInclusive: math.MaxInt64,
})

if err != nil {
Expand All @@ -1244,9 +1244,9 @@ func (s *TestBase) GetTasks(namespaceID string, taskQueue string, taskType enums
func (s *TestBase) CompleteTask(namespaceID string, taskQueue string, taskType enumspb.TaskQueueType, taskID int64) error {
return s.TaskMgr.CompleteTask(&persistence.CompleteTaskRequest{
TaskQueue: &persistence.TaskQueueKey{
NamespaceID: namespaceID,
TaskType: taskType,
Name: taskQueue,
NamespaceID: namespaceID,
TaskQueueType: taskType,
TaskQueueName: taskQueue,
},
TaskID: taskID,
})
Expand Down
14 changes: 7 additions & 7 deletions common/persistence/sql/task.go
Expand Up @@ -122,7 +122,7 @@ func (m *sqlTaskManager) GetTaskQueue(request *persistence.InternalGetTaskQueueR
}, nil
case sql.ErrNoRows:
return nil, serviceerror.NewNotFound(
fmt.Sprintf("GetTaskQueue operation failed. TaskQueue: %v, TaskType: %v, Error: %v",
fmt.Sprintf("GetTaskQueue operation failed. TaskQueue: %v, TaskQueueType: %v, Error: %v",
request.TaskQueue, request.TaskType, err))
default:
return nil, serviceerror.NewUnavailable(
Expand Down Expand Up @@ -360,7 +360,7 @@ func (m *sqlTaskManager) DeleteTaskQueue(
if err != nil {
return serviceerror.NewUnavailable(err.Error())
}
tqId, tqHash := m.taskQueueIdAndHash(nidBytes, request.TaskQueue.Name, request.TaskQueue.TaskType)
tqId, tqHash := m.taskQueueIdAndHash(nidBytes, request.TaskQueue.TaskQueueName, request.TaskQueue.TaskQueueType)
result, err := m.Db.DeleteFromTaskQueues(ctx, sqlplugin.TaskQueuesFilter{
RangeHash: tqHash,
TaskQueueID: tqId,
Expand Down Expand Up @@ -430,8 +430,8 @@ func (m *sqlTaskManager) GetTasks(
return nil, serviceerror.NewUnavailable(err.Error())
}

minTaskID := request.MinTaskID
maxTaskID := request.MaxTaskID
minTaskID := request.MinTaskIDExclusive
maxTaskID := request.MaxTaskIDInclusive
if len(request.NextPageToken) != 0 {
token, err := deserializeMatchingTaskPageToken(request.NextPageToken)
if err != nil {
Expand Down Expand Up @@ -484,7 +484,7 @@ func (m *sqlTaskManager) CompleteTask(
}

taskID := request.TaskID
tqId, tqHash := m.taskQueueIdAndHash(nidBytes, request.TaskQueue.Name, request.TaskQueue.TaskType)
tqId, tqHash := m.taskQueueIdAndHash(nidBytes, request.TaskQueue.TaskQueueName, request.TaskQueue.TaskQueueType)
_, err = m.Db.DeleteFromTasks(ctx, sqlplugin.TasksFilter{
RangeHash: tqHash,
TaskQueueID: tqId,
Expand Down Expand Up @@ -521,7 +521,7 @@ func (m *sqlTaskManager) CompleteTasksLessThan(
return int(nRows), nil
}

// Returns uint32 hash for a particular TaskQueue/Task given a Namespace, Name and TaskQueueType
// Returns uint32 hash for a particular TaskQueue/Task given a Namespace, TaskQueueName and TaskQueueType
func (m *sqlTaskManager) calculateTaskQueueHash(
namespaceID primitives.UUID,
name string,
Expand All @@ -530,7 +530,7 @@ func (m *sqlTaskManager) calculateTaskQueueHash(
return farm.Fingerprint32(m.taskQueueId(namespaceID, name, taskType))
}

// Returns uint32 hash for a particular TaskQueue/Task given a Namespace, Name and TaskQueueType
// Returns uint32 hash for a particular TaskQueue/Task given a Namespace, TaskQueueName and TaskQueueType
func (m *sqlTaskManager) taskQueueIdAndHash(
namespaceID primitives.UUID,
name string,
Expand Down
4 changes: 2 additions & 2 deletions common/persistence/task_manager.go
Expand Up @@ -85,7 +85,7 @@ func (m *taskManagerImpl) LeaseTaskQueue(request *LeaseTaskQueueRequest) (*Lease
}
taskQueueInfo, err := m.serializer.TaskQueueInfoFromBlob(taskQueue.TaskQueueInfo)
if err != nil {
return nil, serviceerror.NewUnavailable(fmt.Sprintf("LeaseTaskQueue operation failed during serialization. TaskQueue: %v, TaskType: %v, Error: %v", request.TaskQueue, request.TaskType, err))
return nil, serviceerror.NewUnavailable(fmt.Sprintf("LeaseTaskQueue operation failed during serialization. TaskQueue: %v, TaskQueueType: %v, Error: %v", request.TaskQueue, request.TaskType, err))
}

taskQueueInfo.LastUpdateTime = timestamp.TimeNowPtrUtc()
Expand Down Expand Up @@ -233,7 +233,7 @@ func (m *taskManagerImpl) CreateTasks(request *CreateTasksRequest) (*CreateTasks
}

func (m *taskManagerImpl) GetTasks(request *GetTasksRequest) (*GetTasksResponse, error) {
if request.MinTaskID >= request.MaxTaskID {
if request.MinTaskIDExclusive >= request.MaxTaskIDInclusive {
return &GetTasksResponse{}, nil
}

Expand Down
12 changes: 6 additions & 6 deletions service/frontend/adminHandler.go
Expand Up @@ -1493,12 +1493,12 @@ func (adh *AdminHandler) GetTaskQueueTasks(

maxTaskID := request.GetMaxTaskId()
req := &persistence.GetTasksRequest{
NamespaceID: namespaceID.String(),
TaskQueue: request.GetTaskQueue(),
TaskType: request.GetTaskQueueType(),
MinTaskID: request.GetMinTaskId(),
MaxTaskID: maxTaskID,
PageSize: int(request.GetBatchSize()),
NamespaceID: namespaceID.String(),
TaskQueue: request.GetTaskQueue(),
TaskType: request.GetTaskQueueType(),
MinTaskIDExclusive: request.GetMinTaskId(),
MaxTaskIDInclusive: maxTaskID,
PageSize: int(request.GetBatchSize()),
}

resp, err := taskMgr.GetTasks(req)
Expand Down
18 changes: 9 additions & 9 deletions service/matching/db.go
Expand Up @@ -146,22 +146,22 @@ func (db *taskQueueDB) CreateTasks(tasks []*persistencespb.AllocatedTaskInfo) (*
// GetTasks returns a batch of tasks between the given range
func (db *taskQueueDB) GetTasks(minTaskID int64, maxTaskID int64, batchSize int) (*persistence.GetTasksResponse, error) {
return db.store.GetTasks(&persistence.GetTasksRequest{
NamespaceID: db.namespaceID.String(),
TaskQueue: db.taskQueueName,
TaskType: db.taskType,
PageSize: batchSize,
MinTaskID: minTaskID, // exclusive
MaxTaskID: maxTaskID, // inclusive
NamespaceID: db.namespaceID.String(),
TaskQueue: db.taskQueueName,
TaskType: db.taskType,
PageSize: batchSize,
MinTaskIDExclusive: minTaskID, // exclusive
MaxTaskIDInclusive: maxTaskID, // inclusive
})
}

// CompleteTask deletes a single task from this task queue
func (db *taskQueueDB) CompleteTask(taskID int64) error {
err := db.store.CompleteTask(&persistence.CompleteTaskRequest{
TaskQueue: &persistence.TaskQueueKey{
NamespaceID: db.namespaceID.String(),
Name: db.taskQueueName,
TaskType: db.taskType,
NamespaceID: db.namespaceID.String(),
TaskQueueName: db.taskQueueName,
TaskQueueType: db.taskType,
},
TaskID: taskID,
})
Expand Down

0 comments on commit 5099ff8

Please sign in to comment.