Skip to content

Commit

Permalink
Fix potential deadlock in shard addTask (#2823)
Browse files Browse the repository at this point in the history
  • Loading branch information
yycptt committed May 10, 2022
1 parent 2f43b88 commit bffb755
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 108 deletions.
90 changes: 65 additions & 25 deletions service/history/shard/context_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ import (
"go.temporal.io/server/common/primitives/timestamp"
"go.temporal.io/server/common/searchattribute"
"go.temporal.io/server/service/history/configs"
"go.temporal.io/server/service/history/consts"
"go.temporal.io/server/service/history/events"
"go.temporal.io/server/service/history/tasks"
"go.temporal.io/server/service/history/vclock"
Expand Down Expand Up @@ -616,14 +617,25 @@ func (s *ContextImpl) AddTasks(
return err
}

s.wLock()
defer s.wUnlock()
engine, err := s.GetEngineWithContext(ctx)
if err != nil {
return err
}

s.wLock()
if err := s.errorByStateLocked(); err != nil {
s.wUnlock()
return err
}

return s.addTasksLocked(ctx, request, namespaceEntry)
err = s.addTasksLocked(ctx, request, namespaceEntry)
s.wUnlock()

if OperationPossiblySucceeded(err) {
engine.NotifyNewTasks(namespaceEntry.ActiveClusterName(), request.Tasks)
}

return err
}

func (s *ContextImpl) CreateWorkflowExecution(
Expand Down Expand Up @@ -879,15 +891,7 @@ func (s *ContextImpl) addTasksLocked(

request.RangeID = s.getRangeIDLocked()
err := s.executionManager.AddHistoryTasks(ctx, request)
if err = s.handleWriteErrorAndUpdateMaxReadLevelLocked(err, transferMaxReadLevel); err != nil {
return err
}
engine, err := s.GetEngineWithContext(ctx)
if err != nil {
return err
}
engine.NotifyNewTasks(namespaceEntry.ActiveClusterName(), request.Tasks)
return nil
return s.handleWriteErrorAndUpdateMaxReadLevelLocked(err, transferMaxReadLevel)
}

func (s *ContextImpl) AppendHistoryEvents(
Expand Down Expand Up @@ -935,7 +939,7 @@ func (s *ContextImpl) DeleteWorkflowExecution(
newTaskVersion int64,
startTime *time.Time,
closeTime *time.Time,
) error {
) (retErr error) {
// DeleteWorkflowExecution is a 4-steps process (order is very important and should not be changed):
// 1. Add visibility delete task, i.e. schedule visibility record delete,
// 2. Delete current workflow execution pointer,
Expand All @@ -957,6 +961,11 @@ func (s *ContextImpl) DeleteWorkflowExecution(
}
defer cancel()

engine, err := s.GetEngineWithContext(ctx)
if err != nil {
return err
}

// Do not get namespace cache within shard lock.
namespaceEntry, err := s.GetNamespaceRegistry().GetNamespaceByID(namespace.ID(key.NamespaceID))
deleteVisibilityRecord := true
Expand All @@ -970,6 +979,13 @@ func (s *ContextImpl) DeleteWorkflowExecution(
}
}

var newTasks map[tasks.Category][]tasks.Task
defer func() {
if OperationPossiblySucceeded(retErr) && newTasks != nil {
engine.NotifyNewTasks(namespaceEntry.ActiveClusterName(), newTasks)
}
}()

s.wLock()
defer s.wUnlock()

Expand All @@ -979,24 +995,26 @@ func (s *ContextImpl) DeleteWorkflowExecution(

// Step 1. Delete visibility.
if deleteVisibilityRecord {
// TODO: move to existing task generator logic
newTasks = map[tasks.Category][]tasks.Task{
tasks.CategoryVisibility: {
&tasks.DeleteExecutionVisibilityTask{
// TaskID is set by addTasksLocked
WorkflowKey: key,
VisibilityTimestamp: s.timeSource.Now(),
Version: newTaskVersion,
StartTime: startTime,
CloseTime: closeTime,
},
},
}
addTasksRequest := &persistence.AddHistoryTasksRequest{
ShardID: s.shardID,
NamespaceID: key.NamespaceID,
WorkflowID: key.WorkflowID,
RunID: key.RunID,

Tasks: map[tasks.Category][]tasks.Task{
tasks.CategoryVisibility: {
&tasks.DeleteExecutionVisibilityTask{
// TaskID is set by addTasksLocked
WorkflowKey: key,
VisibilityTimestamp: s.timeSource.Now(),
Version: newTaskVersion,
StartTime: startTime,
CloseTime: closeTime,
},
},
},
Tasks: newTasks,
}
err = s.addTasksLocked(ctx, addTasksRequest, namespaceEntry)
if err != nil {
Expand Down Expand Up @@ -2001,6 +2019,28 @@ func (s *ContextImpl) ensureMinContextTimeout(
return newContext, cancel, nil
}

func OperationPossiblySucceeded(err error) bool {
if err == consts.ErrConflict {
return false
}

switch err.(type) {
case *persistence.CurrentWorkflowConditionFailedError,
*persistence.WorkflowConditionFailedError,
*persistence.ConditionFailedError,
*persistence.ShardOwnershipLostError,
*persistence.InvalidPersistenceRequestError,
*persistence.TransactionSizeLimitError,
*serviceerror.ResourceExhausted,
*serviceerror.NotFound,
*serviceerror.NamespaceNotFound:
// Persistence failure that means that write was definitely not committed.
return false
default:
return true
}
}

func convertAckLevelToTaskKey(
categoryType tasks.CategoryType,
ackLevel int64,
Expand Down
78 changes: 26 additions & 52 deletions service/history/shard/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,13 @@ import (
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"

enumsspb "go.temporal.io/server/api/enums/v1"
persistencespb "go.temporal.io/server/api/persistence/v1"
"go.temporal.io/server/common/clock"
"go.temporal.io/server/common/cluster"
"go.temporal.io/server/common/convert"
"go.temporal.io/server/common/namespace"
"go.temporal.io/server/common/persistence"
"go.temporal.io/server/common/primitives/timestamp"
"go.temporal.io/server/common/resource"
"go.temporal.io/server/service/history/tasks"
"go.temporal.io/server/service/history/tests"
)
Expand All @@ -51,19 +49,15 @@ type (
suite.Suite
*require.Assertions

controller *gomock.Controller
shardContext Context

mockResource *resource.Test

namespaceID namespace.ID
mockNamespaceCache *namespace.MockRegistry
namespaceEntry *namespace.Namespace
timeSource *clock.EventTimeSource

controller *gomock.Controller
mockShard Context
mockClusterMetadata *cluster.MockMetadata
mockShardManager *persistence.MockShardManager
mockExecutionManager *persistence.MockExecutionManager
mockNamespaceCache *namespace.MockRegistry
mockHistoryEngine *MockEngine

timeSource *clock.EventTimeSource
}
)

Expand All @@ -89,20 +83,19 @@ func (s *contextSuite) SetupTest() {
tests.NewDynamicConfig(),
s.timeSource,
)
s.shardContext = shardContext
s.mockShard = shardContext

s.mockResource = shardContext.Resource
shardContext.MockHostInfoProvider.EXPECT().HostInfo().Return(s.mockResource.GetHostInfo()).AnyTimes()
shardContext.MockHostInfoProvider.EXPECT().HostInfo().Return(shardContext.Resource.GetHostInfo()).AnyTimes()

s.namespaceID = "namespace-Id"
s.namespaceEntry = namespace.NewLocalNamespaceForTest(&persistencespb.NamespaceInfo{Id: s.namespaceID.String()}, &persistencespb.NamespaceConfig{}, "")
s.mockNamespaceCache = s.mockResource.NamespaceCache
shardContext.namespaceRegistry = s.mockResource.NamespaceCache
s.mockNamespaceCache = shardContext.Resource.NamespaceCache
s.mockNamespaceCache.EXPECT().GetNamespaceByID(tests.NamespaceID).Return(tests.LocalNamespaceEntry, nil).AnyTimes()

s.mockClusterMetadata = s.mockResource.ClusterMetadata
shardContext.clusterMetadata = s.mockClusterMetadata
s.mockClusterMetadata = shardContext.Resource.ClusterMetadata
s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes()
s.mockClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestAllClusterInfo).AnyTimes()

s.mockExecutionManager = s.mockResource.ExecutionMgr
s.mockExecutionManager = shardContext.Resource.ExecutionMgr
s.mockShardManager = shardContext.Resource.ShardMgr
s.mockHistoryEngine = NewMockEngine(s.controller)
shardContext.engineFuture.Set(s.mockHistoryEngine, nil)
}
Expand All @@ -112,18 +105,6 @@ func (s *contextSuite) TearDownTest() {
}

func (s *contextSuite) TestAddTasks_Success() {
task := &persistencespb.TimerTaskInfo{
NamespaceId: s.namespaceID.String(),
WorkflowId: "workflow-id",
RunId: "run-id",
TaskType: enumsspb.TASK_TYPE_VISIBILITY_DELETE_EXECUTION,
Version: 1,
EventId: 2,
ScheduleAttempt: 1,
TaskId: 12345,
VisibilityTime: timestamp.TimeNowPtrUtc(),
}

tasks := map[tasks.Category][]tasks.Task{
tasks.CategoryTransfer: {&tasks.ActivityTask{}}, // Just for testing purpose. In the real code ActivityTask can't be passed to shardContext.AddTasks.
tasks.CategoryTimer: {&tasks.ActivityRetryTimerTask{}}, // Just for testing purpose. In the real code ActivityRetryTimerTask can't be passed to shardContext.AddTasks.
Expand All @@ -132,20 +113,18 @@ func (s *contextSuite) TestAddTasks_Success() {
}

addTasksRequest := &persistence.AddHistoryTasksRequest{
ShardID: s.shardContext.GetShardID(),
NamespaceID: task.GetNamespaceId(),
WorkflowID: task.GetWorkflowId(),
RunID: task.GetRunId(),
ShardID: s.mockShard.GetShardID(),
NamespaceID: tests.NamespaceID.String(),
WorkflowID: tests.WorkflowID,
RunID: tests.RunID,

Tasks: tasks,
}

s.mockNamespaceCache.EXPECT().GetNamespaceByID(s.namespaceID).Return(s.namespaceEntry, nil)
s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName)
s.mockExecutionManager.EXPECT().AddHistoryTasks(gomock.Any(), addTasksRequest).Return(nil)
s.mockHistoryEngine.EXPECT().NotifyNewTasks(gomock.Any(), tasks)

err := s.shardContext.AddTasks(context.Background(), addTasksRequest)
err := s.mockShard.AddTasks(context.Background(), addTasksRequest)
s.NoError(err)
}

Expand All @@ -159,22 +138,20 @@ func (s *contextSuite) TestTimerMaxReadLevelInitialization() {
cluster.TestCurrentClusterName: timestamp.TimePtr(now),
},
}
s.mockResource.ShardMgr.EXPECT().GetOrCreateShard(gomock.Any(), gomock.Any()).Return(
s.mockShardManager.EXPECT().GetOrCreateShard(gomock.Any(), gomock.Any()).Return(
&persistence.GetOrCreateShardResponse{
ShardInfo: persistenceShardInfo,
},
nil,
)
s.mockResource.ClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestAllClusterInfo).AnyTimes()
s.mockResource.ClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName)

// clear shardInfo and load from persistence
shardContextImpl := s.shardContext.(*ContextTest)
shardContextImpl := s.mockShard.(*ContextTest)
shardContextImpl.shardInfo = nil
err := shardContextImpl.loadShardMetadata(convert.BoolPtr(false))
s.NoError(err)

for clusterName, info := range s.shardContext.GetClusterMetadata().GetAllClusterInfo() {
for clusterName, info := range s.mockShard.GetClusterMetadata().GetAllClusterInfo() {
if !info.Enabled {
continue
}
Expand All @@ -189,18 +166,15 @@ func (s *contextSuite) TestTimerMaxReadLevelInitialization() {
}

func (s *contextSuite) TestTimerMaxReadLevelUpdate() {
clusterName := cluster.TestCurrentClusterName
s.mockResource.ClusterMetadata.EXPECT().GetCurrentClusterName().Return(clusterName).AnyTimes()

now := time.Now()
s.timeSource.Update(now)
maxReadLevel := s.shardContext.GetQueueMaxReadLevel(tasks.CategoryTimer, clusterName)
maxReadLevel := s.mockShard.GetQueueMaxReadLevel(tasks.CategoryTimer, cluster.TestCurrentClusterName)

s.timeSource.Update(now.Add(-time.Minute))
newMaxReadLevel := s.shardContext.GetQueueMaxReadLevel(tasks.CategoryTimer, clusterName)
newMaxReadLevel := s.mockShard.GetQueueMaxReadLevel(tasks.CategoryTimer, cluster.TestCurrentClusterName)
s.Equal(maxReadLevel, newMaxReadLevel)

s.timeSource.Update(now.Add(time.Minute))
newMaxReadLevel = s.shardContext.GetQueueMaxReadLevel(tasks.CategoryTimer, clusterName)
newMaxReadLevel = s.mockShard.GetQueueMaxReadLevel(tasks.CategoryTimer, cluster.TestCurrentClusterName)
s.True(newMaxReadLevel.FireTime.After(maxReadLevel.FireTime))
}
31 changes: 4 additions & 27 deletions service/history/workflow/transaction_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (t *TransactionImpl) CreateWorkflowExecution(
NewWorkflowSnapshot: *newWorkflowSnapshot,
NewWorkflowEvents: newWorkflowEventsSeq,
})
if operationPossiblySucceeded(err) {
if shard.OperationPossiblySucceeded(err) {
NotifyWorkflowSnapshotTasks(engine, newWorkflowSnapshot, clusterName)
}
if err != nil {
Expand Down Expand Up @@ -129,7 +129,7 @@ func (t *TransactionImpl) ConflictResolveWorkflowExecution(
CurrentWorkflowMutation: currentWorkflowMutation,
CurrentWorkflowEvents: currentWorkflowEventsSeq,
})
if operationPossiblySucceeded(err) {
if shard.OperationPossiblySucceeded(err) {
NotifyWorkflowSnapshotTasks(engine, resetWorkflowSnapshot, clusterName)
NotifyWorkflowSnapshotTasks(engine, newWorkflowSnapshot, clusterName)
NotifyWorkflowMutationTasks(engine, currentWorkflowMutation, clusterName)
Expand Down Expand Up @@ -182,7 +182,7 @@ func (t *TransactionImpl) UpdateWorkflowExecution(
NewWorkflowSnapshot: newWorkflowSnapshot,
NewWorkflowEvents: newWorkflowEventsSeq,
})
if operationPossiblySucceeded(err) {
if shard.OperationPossiblySucceeded(err) {
NotifyWorkflowMutationTasks(engine, currentWorkflowMutation, clusterName)
NotifyWorkflowSnapshotTasks(engine, newWorkflowSnapshot, clusterName)
}
Expand Down Expand Up @@ -219,7 +219,7 @@ func (t *TransactionImpl) SetWorkflowExecution(
// RangeID , this is set by shard context
SetWorkflowSnapshot: *workflowSnapshot,
})
if operationPossiblySucceeded(err) {
if shard.OperationPossiblySucceeded(err) {
NotifyWorkflowSnapshotTasks(engine, workflowSnapshot, clusterName)
}
if err != nil {
Expand Down Expand Up @@ -785,26 +785,3 @@ func emitCompletionMetrics(
)
}
}

func operationPossiblySucceeded(err error) bool {
if err == consts.ErrConflict {
return false
}

switch err.(type) {
case *persistence.CurrentWorkflowConditionFailedError,
*persistence.WorkflowConditionFailedError,
*persistence.ConditionFailedError,
*persistence.ShardOwnershipLostError,
*persistence.InvalidPersistenceRequestError,
*persistence.TransactionSizeLimitError,
*serviceerror.ResourceExhausted,
*serviceerror.NotFound,
*serviceerror.NamespaceNotFound:
// Persistence failure that means that write was definitely not committed.
return false
default:
return true
}

}

0 comments on commit bffb755

Please sign in to comment.