Skip to content

Commit

Permalink
Use context for shard acquisition timeout (#2219)
Browse files Browse the repository at this point in the history
  • Loading branch information
dnr authored Nov 23, 2021
1 parent c3fdcb8 commit 78acd4d
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 56 deletions.
69 changes: 35 additions & 34 deletions service/history/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ func (h *Handler) RecordActivityTaskHeartbeat(ctx context.Context, request *hist
}
workflowID := taskToken.GetWorkflowId()

engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand All @@ -258,7 +258,7 @@ func (h *Handler) RecordActivityTaskStarted(ctx context.Context, request *histor
return nil, h.convertError(errNamespaceNotSet)
}

engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand Down Expand Up @@ -287,7 +287,7 @@ func (h *Handler) RecordWorkflowTaskStarted(ctx context.Context, request *histor
return nil, h.convertError(errTaskQueueNotSet)
}

engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
h.GetLogger().Error("RecordWorkflowTaskStarted failed.",
tag.Error(err1),
Expand Down Expand Up @@ -328,7 +328,7 @@ func (h *Handler) RespondActivityTaskCompleted(ctx context.Context, request *his
}
workflowID := taskToken.GetWorkflowId()

engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand Down Expand Up @@ -363,7 +363,7 @@ func (h *Handler) RespondActivityTaskFailed(ctx context.Context, request *histor
}
workflowID := taskToken.GetWorkflowId()

engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand Down Expand Up @@ -398,7 +398,7 @@ func (h *Handler) RespondActivityTaskCanceled(ctx context.Context, request *hist
}
workflowID := taskToken.GetWorkflowId()

engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand Down Expand Up @@ -439,7 +439,7 @@ func (h *Handler) RespondWorkflowTaskCompleted(ctx context.Context, request *his
}
workflowID := token.GetWorkflowId()

engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand Down Expand Up @@ -480,7 +480,7 @@ func (h *Handler) RespondWorkflowTaskFailed(ctx context.Context, request *histor
}
workflowID := token.GetWorkflowId()

engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand All @@ -505,7 +505,7 @@ func (h *Handler) StartWorkflowExecution(ctx context.Context, request *historyse

startRequest := request.StartRequest
workflowID := startRequest.GetWorkflowId()
engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand Down Expand Up @@ -612,7 +612,7 @@ func (h *Handler) DescribeMutableState(ctx context.Context, request *historyserv

workflowExecution := request.Execution
workflowID := workflowExecution.GetWorkflowId()
engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand All @@ -636,7 +636,7 @@ func (h *Handler) GetMutableState(ctx context.Context, request *historyservice.G

workflowExecution := request.Execution
workflowID := workflowExecution.GetWorkflowId()
engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand All @@ -660,7 +660,7 @@ func (h *Handler) PollMutableState(ctx context.Context, request *historyservice.

workflowExecution := request.Execution
workflowID := workflowExecution.GetWorkflowId()
engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand All @@ -684,7 +684,7 @@ func (h *Handler) DescribeWorkflowExecution(ctx context.Context, request *histor

workflowExecution := request.Request.Execution
workflowID := workflowExecution.GetWorkflowId()
engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand Down Expand Up @@ -718,7 +718,7 @@ func (h *Handler) RequestCancelWorkflowExecution(ctx context.Context, request *h
tag.WorkflowRunID(cancelRequest.WorkflowExecution.GetRunId()))

workflowID := cancelRequest.WorkflowExecution.GetWorkflowId()
engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand Down Expand Up @@ -748,7 +748,7 @@ func (h *Handler) SignalWorkflowExecution(ctx context.Context, request *historys

workflowExecution := request.SignalRequest.WorkflowExecution
workflowID := workflowExecution.GetWorkflowId()
engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand Down Expand Up @@ -781,7 +781,7 @@ func (h *Handler) SignalWithStartWorkflowExecution(ctx context.Context, request

signalWithStartRequest := request.SignalWithStartRequest
workflowID := signalWithStartRequest.GetWorkflowId()
engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand Down Expand Up @@ -834,7 +834,7 @@ func (h *Handler) RemoveSignalMutableState(ctx context.Context, request *history

workflowExecution := request.WorkflowExecution
workflowID := workflowExecution.GetWorkflowId()
engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand Down Expand Up @@ -864,7 +864,7 @@ func (h *Handler) TerminateWorkflowExecution(ctx context.Context, request *histo

workflowExecution := request.TerminateRequest.WorkflowExecution
workflowID := workflowExecution.GetWorkflowId()
engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand Down Expand Up @@ -894,7 +894,7 @@ func (h *Handler) ResetWorkflowExecution(ctx context.Context, request *historyse

workflowExecution := request.ResetRequest.WorkflowExecution
workflowID := workflowExecution.GetWorkflowId()
engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand Down Expand Up @@ -922,7 +922,7 @@ func (h *Handler) QueryWorkflow(ctx context.Context, request *historyservice.Que
}

workflowID := request.GetRequest().GetExecution().GetWorkflowId()
engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand Down Expand Up @@ -966,7 +966,7 @@ func (h *Handler) ScheduleWorkflowTask(ctx context.Context, request *historyserv

workflowExecution := request.WorkflowExecution
workflowID := workflowExecution.GetWorkflowId()
engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand Down Expand Up @@ -1000,7 +1000,7 @@ func (h *Handler) RecordChildExecutionCompleted(ctx context.Context, request *hi

workflowExecution := request.WorkflowExecution
workflowID := workflowExecution.GetWorkflowId()
engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand Down Expand Up @@ -1032,7 +1032,7 @@ func (h *Handler) ResetStickyTaskQueue(ctx context.Context, request *historyserv
}

workflowID := request.Execution.GetWorkflowId()
engine, err := h.controller.GetEngine(namespaceID, workflowID)
engine, err := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err != nil {
return nil, h.convertError(err)
}
Expand Down Expand Up @@ -1061,7 +1061,7 @@ func (h *Handler) ReplicateEventsV2(ctx context.Context, request *historyservice

workflowExecution := request.WorkflowExecution
workflowID := workflowExecution.GetWorkflowId()
engine, err1 := h.controller.GetEngine(namespaceID, workflowID)
engine, err1 := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err1 != nil {
return nil, h.convertError(err1)
}
Expand Down Expand Up @@ -1096,7 +1096,7 @@ func (h *Handler) SyncShardStatus(ctx context.Context, request *historyservice.S
}

// shard ID is already provided in the request
engine, err := h.controller.GetEngineForShard(request.GetShardId())
engine, err := h.controller.GetEngineForShard(ctx, request.GetShardId())
if err != nil {
return nil, h.convertError(err)
}
Expand Down Expand Up @@ -1132,7 +1132,7 @@ func (h *Handler) SyncActivity(ctx context.Context, request *historyservice.Sync
}

workflowID := request.GetWorkflowId()
engine, err := h.controller.GetEngine(namespaceID, workflowID)
engine, err := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err != nil {
return nil, h.convertError(err)
}
Expand Down Expand Up @@ -1162,7 +1162,7 @@ func (h *Handler) GetReplicationMessages(ctx context.Context, request *historyse
go func(token *replicationspb.ReplicationToken) {
defer wg.Done()

engine, err := h.controller.GetEngineForShard(token.GetShardId())
engine, err := h.controller.GetEngineForShard(ctx, token.GetShardId())
if err != nil {
h.GetLogger().Warn("History engine not found for shard", tag.Error(err))
return
Expand Down Expand Up @@ -1230,6 +1230,7 @@ func (h *Handler) GetDLQReplicationMessages(ctx context.Context, request *histor
}

engine, err := h.controller.GetEngine(
ctx,
namespace.ID(taskInfos[0].GetNamespaceId()),
taskInfos[0].GetWorkflowId(),
)
Expand Down Expand Up @@ -1278,7 +1279,7 @@ func (h *Handler) ReapplyEvents(ctx context.Context, request *historyservice.Rea

namespaceID := namespace.ID(request.GetNamespaceId())
workflowID := request.GetRequest().GetWorkflowExecution().GetWorkflowId()
engine, err := h.controller.GetEngine(namespaceID, workflowID)
engine, err := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err != nil {
return nil, h.convertError(err)
}
Expand Down Expand Up @@ -1312,7 +1313,7 @@ func (h *Handler) GetDLQMessages(ctx context.Context, request *historyservice.Ge
return nil, errShuttingDown
}

engine, err := h.controller.GetEngineForShard(request.GetShardId())
engine, err := h.controller.GetEngineForShard(ctx, request.GetShardId())
if err != nil {
err = h.convertError(err)
return nil, err
Expand All @@ -1335,7 +1336,7 @@ func (h *Handler) PurgeDLQMessages(ctx context.Context, request *historyservice.
return nil, errShuttingDown
}

engine, err := h.controller.GetEngineForShard(request.GetShardId())
engine, err := h.controller.GetEngineForShard(ctx, request.GetShardId())
if err != nil {
err = h.convertError(err)
return nil, err
Expand All @@ -1357,7 +1358,7 @@ func (h *Handler) MergeDLQMessages(ctx context.Context, request *historyservice.
return nil, errShuttingDown
}

engine, err := h.controller.GetEngineForShard(request.GetShardId())
engine, err := h.controller.GetEngineForShard(ctx, request.GetShardId())
if err != nil {
err = h.convertError(err)
return nil, err
Expand All @@ -1383,7 +1384,7 @@ func (h *Handler) RefreshWorkflowTasks(ctx context.Context, request *historyserv
namespaceID := namespace.ID(request.GetNamespaceId())
execution := request.GetRequest().GetExecution()
workflowID := execution.GetWorkflowId()
engine, err := h.controller.GetEngine(namespaceID, workflowID)
engine, err := h.controller.GetEngine(ctx, namespaceID, workflowID)
if err != nil {
err = h.convertError(err)
return nil, err
Expand Down Expand Up @@ -1418,7 +1419,7 @@ func (h *Handler) GenerateLastHistoryReplicationTasks(
}

namespaceID := namespace.ID(request.GetNamespaceId())
engine, err := h.controller.GetEngine(namespaceID, request.GetExecution().GetWorkflowId())
engine, err := h.controller.GetEngine(ctx, namespaceID, request.GetExecution().GetWorkflowId())
if err != nil {
err = h.convertError(err)
return nil, err
Expand Down Expand Up @@ -1447,7 +1448,7 @@ func (h *Handler) GetReplicationStatus(

resp := &historyservice.GetReplicationStatusResponse{}
for _, shardID := range h.controller.ShardIDs() {
engine, err := h.controller.GetEngineForShard(shardID)
engine, err := h.controller.GetEngineForShard(ctx, shardID)
if err != nil {
return nil, err
}
Expand Down
14 changes: 7 additions & 7 deletions service/history/shard/context_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
package shard

import (
"context"
"sync"
"time"

Expand Down Expand Up @@ -1153,16 +1154,15 @@ func (s *ContextImpl) createEngine() Engine {
return engine
}

func (s *ContextImpl) getOrCreateEngine() (engine Engine, retErr error) {
// Wait on shard acquisition for 1s. Note that this retry is just polling a value in memory.
// Another goroutine is doing the actual work.
// TODO: use context to do timeout here
func (s *ContextImpl) getOrCreateEngine(ctx context.Context) (engine Engine, retErr error) {
// Block on shard acquisition for the lifetime of this context. Note that this retry is just
// polling a value in memory. Another goroutine is doing the actual work.
policy := backoff.NewExponentialRetryPolicy(5 * time.Millisecond)
policy.SetExpirationInterval(1 * time.Second)
policy.SetMaximumInterval(1 * time.Second)

isRetryable := func(err error) bool { return err == ErrShardStatusUnknown }

op := func() error {
op := func(context.Context) error {
s.rLock()
defer s.rUnlock()
err := s.errorByStateLocked()
Expand All @@ -1172,7 +1172,7 @@ func (s *ContextImpl) getOrCreateEngine() (engine Engine, retErr error) {
return err
}

retErr = backoff.Retry(op, policy, isRetryable)
retErr = backoff.RetryContext(ctx, op, policy, isRetryable)
if retErr == nil && engine == nil {
// This shouldn't ever happen, but don't let it return nil error.
retErr = ErrShardStatusUnknown
Expand Down
15 changes: 9 additions & 6 deletions service/history/shard/controller_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
package shard

import (
"context"
"fmt"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -138,19 +139,19 @@ func (c *ControllerImpl) Status() int32 {
return atomic.LoadInt32(&c.status)
}

func (c *ControllerImpl) GetEngine(namespaceID namespace.ID, workflowID string) (Engine, error) {
func (c *ControllerImpl) GetEngine(ctx context.Context, namespaceID namespace.ID, workflowID string) (Engine, error) {
shardID := c.config.GetShardID(namespaceID, workflowID)
return c.GetEngineForShard(shardID)
return c.GetEngineForShard(ctx, shardID)
}

func (c *ControllerImpl) GetEngineForShard(shardID int32) (Engine, error) {
func (c *ControllerImpl) GetEngineForShard(ctx context.Context, shardID int32) (Engine, error) {
sw := c.metricsScope.StartTimer(metrics.GetEngineForShardLatency)
defer sw.Stop()
shard, err := c.getOrCreateShardContext(shardID)
if err != nil {
return nil, err
}
return shard.getOrCreateEngine()
return shard.getOrCreateEngine(ctx)
}

func (c *ControllerImpl) CloseShardByID(shardID int32) {
Expand Down Expand Up @@ -312,10 +313,12 @@ func (c *ControllerImpl) acquireShards() {
c.logger.Error("Error looking up host for shardID", tag.Error(err), tag.OperationFailed, tag.ShardID(shardID))
} else {
if info.Identity() == c.GetHostInfo().Identity() {
if _, err := c.GetEngineForShard(shardID); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
if _, err := c.GetEngineForShard(ctx, shardID); err != nil {
c.metricsScope.IncCounter(metrics.GetEngineForShardErrorCounter)
c.logger.Error("Unable to create history shard engine", tag.Error(err), tag.OperationFailed, tag.ShardID(shardID))
c.logger.Error("Unable to create history shard context", tag.Error(err), tag.OperationFailed, tag.ShardID(shardID))
}
cancel()
}
// TODO: If we're _not_ the owner for this shard, and we have it loaded, we should unload it.
}
Expand Down
Loading

0 comments on commit 78acd4d

Please sign in to comment.