Skip to content

Commit

Permalink
Create a task category registry object
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelSnowden committed Oct 15, 2023
1 parent 1eeb92e commit 5dd37df
Show file tree
Hide file tree
Showing 45 changed files with 493 additions and 340 deletions.
59 changes: 22 additions & 37 deletions api/historyservice/v1/request_response.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 5 additions & 4 deletions client/history/historytest/clienttest.go
Expand Up @@ -97,7 +97,7 @@ func TestClient(t *testing.T, historyTaskQueueManager persistence.HistoryTaskQue
require.NoError(t, err)
enqueueTasks(t, historyTaskQueueManager, 2, queueKey.SourceCluster, queueKey.TargetCluster)
dlqKey := &commonspb.HistoryDLQKey{
TaskCategory: tasks.CategoryTransfer.ID(),
TaskCategory: int32(tasks.CategoryTransfer.ID()),
SourceCluster: queueKey.SourceCluster,
TargetCluster: queueKey.TargetCluster,
}
Expand Down Expand Up @@ -139,7 +139,7 @@ func readTasks(
for i := 0; i < numTasks; i++ {
res, err := client.GetDLQTasks(context.Background(), &historyservice.GetDLQTasksRequest{
DlqKey: &commonspb.HistoryDLQKey{
TaskCategory: tasks.CategoryTransfer.ID(),
TaskCategory: int32(tasks.CategoryTransfer.ID()),
SourceCluster: sourceCluster,
TargetCluster: targetCluster,
},
Expand All @@ -156,8 +156,9 @@ func readTasks(
func createServer(historyTaskQueueManager persistence.HistoryTaskQueueManager) *grpc.Server {
// TODO: find a better way to create a history handler
historyHandler := historyserver.HandlerProvider(historyserver.NewHandlerArgs{
TaskQueueManager: historyTaskQueueManager,
TracerProvider: fakeTracerProvider{},
TaskQueueManager: historyTaskQueueManager,
TracerProvider: fakeTracerProvider{},
TaskCategoryRegistry: tasks.NewDefaultTaskCategoryRegistry(),
})
grpcServer := grpc.NewServer()
historyservice.RegisterHistoryServiceServer(grpcServer, historyHandler)
Expand Down
3 changes: 2 additions & 1 deletion cmd/tools/tdbg/main.go
Expand Up @@ -27,10 +27,11 @@ package main
import (
"os"

"go.temporal.io/server/service/history/tasks"
"go.temporal.io/server/tools/tdbg"
)

func main() {
app := tdbg.NewCliApp(tdbg.NewClientFactory())
app := tdbg.NewCliApp(tdbg.NewClientFactory(), tasks.NewDefaultTaskCategoryRegistry())
_ = app.Run(os.Args)
}
12 changes: 6 additions & 6 deletions common/persistence/sql/execution_tasks.go
Expand Up @@ -146,7 +146,7 @@ func (m *sqlExecutionStore) getHistoryImmediateTasks(

rows, err := m.Db.RangeSelectFromHistoryImmediateTasks(ctx, sqlplugin.HistoryImmediateTasksRangeFilter{
ShardID: request.ShardID,
CategoryID: categoryID,
CategoryID: int32(categoryID),
InclusiveMinTaskID: inclusiveMinTaskID,
ExclusiveMaxTaskID: exclusiveMaxTaskID,
PageSize: request.BatchSize,
Expand Down Expand Up @@ -200,7 +200,7 @@ func (m *sqlExecutionStore) completeHistoryImmediateTask(

if _, err := m.Db.DeleteFromHistoryImmediateTasks(ctx, sqlplugin.HistoryImmediateTasksFilter{
ShardID: request.ShardID,
CategoryID: categoryID,
CategoryID: int32(categoryID),
TaskID: request.TaskKey.TaskID,
}); err != nil {
return serviceerror.NewUnavailable(
Expand Down Expand Up @@ -229,7 +229,7 @@ func (m *sqlExecutionStore) rangeCompleteHistoryImmediateTasks(

if _, err := m.Db.RangeDeleteFromHistoryImmediateTasks(ctx, sqlplugin.HistoryImmediateTasksRangeFilter{
ShardID: request.ShardID,
CategoryID: categoryID,
CategoryID: int32(categoryID),
InclusiveMinTaskID: request.InclusiveMinTaskKey.TaskID,
ExclusiveMaxTaskID: request.ExclusiveMaxTaskKey.TaskID,
}); err != nil {
Expand Down Expand Up @@ -263,7 +263,7 @@ func (m *sqlExecutionStore) getHistoryScheduledTasks(

rows, err := m.Db.RangeSelectFromHistoryScheduledTasks(ctx, sqlplugin.HistoryScheduledTasksRangeFilter{
ShardID: request.ShardID,
CategoryID: categoryID,
CategoryID: int32(categoryID),
InclusiveMinVisibilityTimestamp: pageToken.Timestamp,
InclusiveMinTaskID: pageToken.TaskID,
ExclusiveMaxVisibilityTimestamp: request.ExclusiveMaxTaskKey.FireTime,
Expand Down Expand Up @@ -313,7 +313,7 @@ func (m *sqlExecutionStore) completeHistoryScheduledTask(

if _, err := m.Db.DeleteFromHistoryScheduledTasks(ctx, sqlplugin.HistoryScheduledTasksFilter{
ShardID: request.ShardID,
CategoryID: categoryID,
CategoryID: int32(categoryID),
VisibilityTimestamp: request.TaskKey.FireTime,
TaskID: request.TaskKey.TaskID,
}); err != nil {
Expand All @@ -338,7 +338,7 @@ func (m *sqlExecutionStore) rangeCompleteHistoryScheduledTasks(
end := request.ExclusiveMaxTaskKey.FireTime
if _, err := m.Db.RangeDeleteFromHistoryScheduledTasks(ctx, sqlplugin.HistoryScheduledTasksRangeFilter{
ShardID: request.ShardID,
CategoryID: categoryID,
CategoryID: int32(categoryID),
InclusiveMinVisibilityTimestamp: start,
ExclusiveMaxVisibilityTimestamp: end,
}); err != nil {
Expand Down
8 changes: 4 additions & 4 deletions common/persistence/sql/execution_util.go
Expand Up @@ -698,7 +698,7 @@ func createImmediateTasks(
ctx context.Context,
tx sqlplugin.Tx,
shardID int32,
categoryID int32,
categoryID int,
immedidateTasks []p.InternalHistoryTask,
) error {
// This is for backward compatiblity.
Expand All @@ -721,7 +721,7 @@ func createImmediateTasks(
for _, task := range immedidateTasks {
immediateTasksRows = append(immediateTasksRows, sqlplugin.HistoryImmediateTasksRow{
ShardID: shardID,
CategoryID: categoryID,
CategoryID: int32(categoryID),
TaskID: task.Key.TaskID,
Data: task.Blob.Data,
DataEncoding: task.Blob.EncodingType.String(),
Expand All @@ -748,7 +748,7 @@ func createScheduledTasks(
ctx context.Context,
tx sqlplugin.Tx,
shardID int32,
categoryID int32,
categoryID int,
scheduledTasks []p.InternalHistoryTask,
) error {
// This is for backward compatiblity.
Expand All @@ -766,7 +766,7 @@ func createScheduledTasks(
for _, task := range scheduledTasks {
scheduledTasksRows = append(scheduledTasksRows, sqlplugin.HistoryScheduledTasksRow{
ShardID: shardID,
CategoryID: categoryID,
CategoryID: int32(categoryID),
VisibilityTimestamp: task.Key.FireTime,
TaskID: task.Key.TaskID,
Data: task.Blob.Data,
Expand Down
Expand Up @@ -778,9 +778,10 @@ message AddTasksRequest {
int32 shard_id = 1;

message Task {
// category is needed to deserialize the tasks. Examples include "transfer", "timer", etc. See the history/tasks
// package for a definitive list. Warning: this is not the same as the stringified value of a TaskCategory enum.
string category = 1;
// category is needed to deserialize the tasks. See TaskCategory for a list of options here. However, keep in mind
// that the list of valid options is registered dynamically with the server in the history/tasks package, so that
// enum is not comprehensive.
int32 category = 1;
// blob is the serialized task.
temporal.api.common.v1.DataBlob blob = 2;
}
Expand Down
7 changes: 6 additions & 1 deletion service/frontend/admin_handler.go
Expand Up @@ -123,6 +123,8 @@ type (

// DEPRECATED
persistenceExecutionManager persistence.ExecutionManager

taskCategoryRegistry tasks.TaskCategoryRegistry
}

NewAdminHandlerArgs struct {
Expand Down Expand Up @@ -153,6 +155,8 @@ type (

// DEPRECATED
PersistenceExecutionManager persistence.ExecutionManager

CategoryRegistry tasks.TaskCategoryRegistry
}
)

Expand Down Expand Up @@ -202,6 +206,7 @@ func NewAdminHandler(
saManager: args.SaManager,
clusterMetadata: args.ClusterMetadata,
healthServer: args.HealthServer,
taskCategoryRegistry: args.CategoryRegistry,
}
}

Expand Down Expand Up @@ -794,7 +799,7 @@ func (adh *AdminHandler) ListHistoryTasks(
return nil, errTaskRangeNotSet
}

taskCategory, ok := tasks.GetCategoryByID(int32(request.Category))
taskCategory, ok := adh.taskCategoryRegistry.GetCategoryByID(int(request.Category))
if !ok {
return nil, &serviceerror.InvalidArgument{
Message: fmt.Sprintf("unknown task category: %v", request.Category),
Expand Down
5 changes: 3 additions & 2 deletions service/frontend/admin_handler_test.go
Expand Up @@ -170,6 +170,7 @@ func (s *adminHandlerSuite) SetupTest() {
serialization.NewSerializer(),
clock.NewRealTimeSource(),
s.mockResource.GetExecutionManager(),
tasks.NewDefaultTaskCategoryRegistry(),
}
s.mockMetadata.EXPECT().GetCurrentClusterName().Return(uuid.New()).AnyTimes()
s.handler = NewAdminHandler(args)
Expand Down Expand Up @@ -1160,7 +1161,7 @@ func (s *adminHandlerSuite) TestGetDLQTasks() {
blob := &commonpb.DataBlob{}
expectation := s.mockHistoryClient.EXPECT().GetDLQTasks(gomock.Any(), &historyservice.GetDLQTasksRequest{
DlqKey: &commonspb.HistoryDLQKey{
TaskCategory: tasks.CategoryTransfer.ID(),
TaskCategory: int32(tasks.CategoryTransfer.ID()),
SourceCluster: "test-source-cluster",
TargetCluster: "test-target-cluster",
},
Expand All @@ -1187,7 +1188,7 @@ func (s *adminHandlerSuite) TestGetDLQTasks() {
}
response, err := s.handler.GetDLQTasks(context.Background(), &adminservice.GetDLQTasksRequest{
DlqKey: &commonspb.HistoryDLQKey{
TaskCategory: tasks.CategoryTransfer.ID(),
TaskCategory: int32(tasks.CategoryTransfer.ID()),
SourceCluster: "test-source-cluster",
TargetCluster: "test-target-cluster",
},
Expand Down
3 changes: 3 additions & 0 deletions service/frontend/fx.go
Expand Up @@ -65,6 +65,7 @@ import (
"go.temporal.io/server/common/telemetry"
"go.temporal.io/server/service"
"go.temporal.io/server/service/frontend/configs"
"go.temporal.io/server/service/history/tasks"
)

type FEReplicatorNamespaceReplicationQueue persistence.NamespaceReplicationQueue
Expand Down Expand Up @@ -503,6 +504,7 @@ func AdminHandlerProvider(
healthServer *health.Server,
eventSerializer serialization.Serializer,
timeSource clock.TimeSource,
taskCategoryRegistry tasks.TaskCategoryRegistry,
) *AdminHandler {
args := NewAdminHandlerArgs{
persistenceConfig,
Expand Down Expand Up @@ -530,6 +532,7 @@ func AdminHandlerProvider(
eventSerializer,
timeSource,
persistenceExecutionManager,
taskCategoryRegistry,
}
return NewAdminHandler(args)
}
Expand Down
22 changes: 5 additions & 17 deletions service/history/api/addtasks/api.go
Expand Up @@ -34,6 +34,7 @@ import (
"go.temporal.io/server/api/historyservice/v1"
"go.temporal.io/server/common/definition"
"go.temporal.io/server/common/persistence"
"go.temporal.io/server/service/history/api"
"go.temporal.io/server/service/history/shard"
"go.temporal.io/server/service/history/tasks"
)
Expand Down Expand Up @@ -63,6 +64,7 @@ func Invoke(
deserializer TaskDeserializer,
numShards int,
req *historyservice.AddTasksRequest,
taskRegistry tasks.TaskCategoryRegistry,
) (*historyservice.AddTasksResponse, error) {
if len(req.Tasks) > maxTasksPerRequest {
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf(
Expand All @@ -83,12 +85,9 @@ func Invoke(
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf("Nil task at index: %d", i))
}

category, ok := getCategoryByName(task.Category)
if !ok {
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid task category: %s",
task.Category,
))
category, err := api.GetTaskCategory(int(task.GetCategory()), taskRegistry)
if err != nil {
return nil, err
}

if task.Blob == nil {
Expand Down Expand Up @@ -139,14 +138,3 @@ func Invoke(

return &historyservice.AddTasksResponse{}, nil
}

func getCategoryByName(categoryName string) (tasks.Category, bool) {
categories := tasks.GetCategories()
for _, category := range categories {
if category.Name() == categoryName {
return category, true
}
}

return tasks.Category{}, false
}

0 comments on commit 5dd37df

Please sign in to comment.