support delaying shard close for membership change
alfred-landrum committed Jul 20, 2023
1 parent ccf3c05 commit 2a16156
Showing 5 changed files with 232 additions and 12 deletions.
10 changes: 10 additions & 0 deletions common/dynamicconfig/constants.go
@@ -501,6 +501,16 @@ const (
AcquireShardInterval = "history.acquireShardInterval"
// AcquireShardConcurrency is number of goroutines that can be used to acquire shards in the shard controller.
AcquireShardConcurrency = "history.acquireShardConcurrency"
// ShardLingerEnabled configures if the shard controller will temporarily
// delay closing shards after a membership update, awaiting a shard
// ownership lost error from persistence. Not recommended with persistence
// layers that are missing AssertShardOwnership support.
ShardLingerEnabled = "history.shardLingerEnabled"
// ShardLingerOwnershipCheckQPS is the frequency to perform shard ownership
// checks while a shard is lingering.
ShardLingerOwnershipCheckQPS = "history.shardLingerOwnershipCheckQPS"
// ShardLingerTimeLimit is the upper bound on how long a shard can linger.
ShardLingerTimeLimit = "history.shardLingerTimeLimit"
// HistoryClientOwnershipCachingEnabled configures if history clients try to cache
// shard ownership information, instead of checking membership for each request.
HistoryClientOwnershipCachingEnabled = "history.clientOwnershipCachingEnabled"
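Taken together, these three settings bound both how long an unowned shard may linger and how often its ownership is probed. A rough sketch of that relationship follows, using the default values wired up in service/history/configs/config.go below; the helper and package name are illustrative only and not part of this commit, and the effective window is additionally capped at shardIOTimeout inside shardLingerThenClose.

package shardlingersketch

import "time"

// approxOwnershipProbes estimates how many AssertOwnership probes a lingering
// shard will see. The limiter is created with burst=1, so the first probe runs
// immediately and each subsequent probe waits 1/qps seconds, giving roughly
// 1 + timeLimit*qps probes over the linger window.
func approxOwnershipProbes(timeLimit time.Duration, qps int) int {
	return 1 + int(timeLimit.Seconds()*float64(qps))
}

// With the defaults below (ShardLingerTimeLimit = 3s, ShardLingerOwnershipCheckQPS = 4),
// approxOwnershipProbes(3*time.Second, 4) == 13, i.e. at most about 13 probes
// before the time limit fires.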
2 changes: 2 additions & 0 deletions common/metrics/metric_defs.go
@@ -1426,6 +1426,8 @@ var (
NamespaceRegistryLockLatency = NewTimerDef("namespace_registry_lock_latency")
ClosedWorkflowBufferEventCount = NewCounterDef("closed_workflow_buffer_event_counter")
InorderBufferedEventsCounter = NewCounterDef("inordered_buffered_events")
ShardLingerSuccess = NewTimerDef("shard_linger_success")
ShardLingerTimeouts = NewCounterDef("shard_linger_timeouts")

// Matching
MatchingClientForwardedCounter = NewCounterDef("forwarded")
18 changes: 12 additions & 6 deletions service/history/configs/config.go
@@ -79,9 +79,12 @@ type Config struct {
EventsCacheTTL dynamicconfig.DurationPropertyFn

// ShardController settings
RangeSizeBits uint
AcquireShardInterval dynamicconfig.DurationPropertyFn
AcquireShardConcurrency dynamicconfig.IntPropertyFn
RangeSizeBits uint
AcquireShardInterval dynamicconfig.DurationPropertyFn
AcquireShardConcurrency dynamicconfig.IntPropertyFn
ShardLingerEnabled dynamicconfig.BoolPropertyFn
ShardLingerOwnershipCheckQPS dynamicconfig.IntPropertyFn
ShardLingerTimeLimit dynamicconfig.DurationPropertyFn

HistoryClientOwnershipCachingEnabled dynamicconfig.BoolPropertyFn

@@ -361,9 +364,12 @@ func NewConfig(
EventsCacheMaxSizeBytes: dc.GetIntProperty(dynamicconfig.EventsCacheMaxSizeBytes, 512*1024), // 512KB
EventsCacheTTL: dc.GetDurationProperty(dynamicconfig.EventsCacheTTL, time.Hour),

RangeSizeBits: 20, // 20 bits for sequencer, 2^20 sequence number for any range
AcquireShardInterval: dc.GetDurationProperty(dynamicconfig.AcquireShardInterval, time.Minute),
AcquireShardConcurrency: dc.GetIntProperty(dynamicconfig.AcquireShardConcurrency, 10),
RangeSizeBits: 20, // 20 bits for sequencer, 2^20 sequence number for any range
AcquireShardInterval: dc.GetDurationProperty(dynamicconfig.AcquireShardInterval, time.Minute),
AcquireShardConcurrency: dc.GetIntProperty(dynamicconfig.AcquireShardConcurrency, 10),
ShardLingerEnabled: dc.GetBoolProperty(dynamicconfig.ShardLingerEnabled, false),
ShardLingerOwnershipCheckQPS: dc.GetIntProperty(dynamicconfig.ShardLingerOwnershipCheckQPS, 4),
ShardLingerTimeLimit: dc.GetDurationProperty(dynamicconfig.ShardLingerTimeLimit, 3*time.Second),

HistoryClientOwnershipCachingEnabled: dc.GetBoolProperty(dynamicconfig.HistoryClientOwnershipCachingEnabled, false),

57 changes: 54 additions & 3 deletions service/history/shard/controller_impl.go
@@ -34,6 +34,7 @@ import (

"go.temporal.io/api/serviceerror"
"golang.org/x/sync/semaphore"
"golang.org/x/time/rate"

"go.temporal.io/server/common"
"go.temporal.io/server/common/headers"
@@ -203,7 +204,7 @@ func (c *ControllerImpl) ShardIDs() []int32 {
return ids
}

func (c *ControllerImpl) shardClosedCallback(shard ControllableContext) {
func (c *ControllerImpl) shardRemoveAndStop(shard ControllableContext) {
startTime := time.Now().UTC()
defer func() {
c.taggedMetricsHandler.Timer(metrics.RemoveEngineForShardLatency.GetMetricName()).Record(time.Since(startTime))
@@ -256,7 +257,7 @@ func (c *ControllerImpl) getOrCreateShardContext(shardID int32) (Context, error)
return nil, fmt.Errorf("ControllerImpl for host '%v' shutting down", hostInfo.Identity())
}

shard, err := c.contextFactory.CreateContext(shardID, c.shardClosedCallback)
shard, err := c.contextFactory.CreateContext(shardID, c.shardRemoveAndStop)
if err != nil {
return nil, err
}
@@ -291,6 +292,52 @@ func (c *ControllerImpl) removeShardLocked(shardID int32, expected ControllableC
return current
}

// shardLingerThenClose delays closing the shard for a small amount of time,
// while watching for the shard to become invalid due to receiving a shard
// ownership lost error. It calls AssertOwnership to probe for lost ownership,
// but any concurrent request may also see the error, which will mark the shard
// as invalid.
func (c *ControllerImpl) shardLingerThenClose(ctx context.Context, shardID int32) {
c.RLock()
shard, ok := c.historyShards[shardID]
c.RUnlock()
if !ok {
return
}

startTime := time.Now()
timeout := util.Min(c.config.ShardLingerTimeLimit(), shardIOTimeout)
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

qps := c.config.ShardLingerOwnershipCheckQPS()
// The limiter must be configured with burst>=1. With burst=1,
// the first call to Wait() won't be delayed.
limiter := rate.NewLimiter(rate.Limit(qps), 1)

for {
if !shard.IsValid() {
c.taggedMetricsHandler.Timer(metrics.ShardLingerSuccess.GetMetricName()).Record(time.Since(startTime))
break
}

if err := limiter.Wait(ctx); err != nil {
c.contextTaggedLogger.Info("shardLinger: wait timed out",
tag.ShardID(shardID),
tag.NewDurationTag("duration", time.Now().Sub(startTime)),
)
c.taggedMetricsHandler.Counter(metrics.ShardLingerTimeouts.GetMetricName()).Record(1)
break
}

// If this AssertOwnership or any other request on the shard receives
// a shard ownership lost error, the shard will be marked as invalid.
_ = shard.AssertOwnership(ctx)
}

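// Whether the loop above ended because the shard became invalid (linger
// success) or because the time limit expired, the shard is unconditionally
// removed and stopped here.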
c.shardRemoveAndStop(shard)
}

func (c *ControllerImpl) acquireShards(ctx context.Context) {
c.taggedMetricsHandler.Counter(metrics.AcquireShardsCounter.GetMetricName()).Record(1)
startTime := time.Now().UTC()
@@ -304,7 +351,11 @@ func (c *ControllerImpl) acquireShards(ctx context.Context) {
if err := c.ownership.verifyOwnership(shardID); err != nil {
if IsShardOwnershipLostError(err) {
// current host is not owner of shard, unload it if it is already loaded.
c.CloseShardByID(shardID)
if c.config.ShardLingerEnabled() {
c.shardLingerThenClose(ctx, shardID)
} else {
c.CloseShardByID(shardID)
}
}
return
}
157 changes: 154 additions & 3 deletions service/history/shard/controller_test.go
@@ -48,6 +48,8 @@ import (
"go.temporal.io/server/common/log"
"go.temporal.io/server/common/log/tag"
"go.temporal.io/server/common/membership"
"go.temporal.io/server/common/metrics"
"go.temporal.io/server/common/metrics/metricstest"
"go.temporal.io/server/common/persistence"
"go.temporal.io/server/common/primitives"
"go.temporal.io/server/common/primitives/timestamp"
@@ -77,6 +79,7 @@ type (
logger log.Logger
shardController *ControllerImpl
mockHostInfoProvider *membership.MockHostInfoProvider
metricsTestHandler *metricstest.Handler
}
)

@@ -85,6 +88,7 @@ func NewTestController(
config *configs.Config,
resource *resourcetest.Test,
hostInfoProvider *membership.MockHostInfoProvider,
metricsTestHandler *metricstest.Handler,
) *ControllerImpl {
contextFactory := ContextFactoryProvider(ContextFactoryParams{
ArchivalMetadata: resource.GetArchivalMetadata(),
@@ -96,7 +100,7 @@ HistoryServiceResolver: resource.GetHistoryServiceResolver(),
HistoryServiceResolver: resource.GetHistoryServiceResolver(),
HostInfoProvider: hostInfoProvider,
Logger: resource.GetLogger(),
MetricsHandler: resource.GetMetricsHandler(),
MetricsHandler: metricsTestHandler,
NamespaceRegistry: resource.GetNamespaceRegistry(),
PayloadSerializer: resource.GetPayloadSerializer(),
PersistenceExecutionManager: resource.GetExecutionManager(),
@@ -111,7 +115,7 @@ config,
config,
resource.GetLogger(),
resource.GetHistoryServiceResolver(),
resource.GetMetricsHandler(),
metricsTestHandler,
resource.GetHostInfoProvider(),
contextFactory,
).(*ControllerImpl)
@@ -139,12 +143,16 @@ func (s *controllerSuite) SetupTest() {

s.logger = s.mockResource.Logger
s.config = tests.NewDynamicConfig()
metricsTestHandler, err := metricstest.NewHandler(log.NewNoopLogger(), metrics.ClientConfig{})
s.NoError(err)
s.metricsTestHandler = metricsTestHandler

s.shardController = NewTestController(
s.mockEngineFactory,
s.config,
s.mockResource,
s.mockHostInfoProvider,
s.metricsTestHandler,
)
}

@@ -319,6 +327,7 @@ func (s *controllerSuite) TestHistoryEngineClosed() {
s.config,
s.mockResource,
s.mockHostInfoProvider,
s.metricsTestHandler,
)
historyEngines := make(map[int32]*MockEngine)
for shardID := int32(1); shardID <= numShards; shardID++ {
@@ -423,6 +432,7 @@ func (s *controllerSuite) TestShardControllerClosed() {
s.config,
s.mockResource,
s.mockHostInfoProvider,
s.metricsTestHandler,
)

historyEngines := make(map[int32]*MockEngine)
@@ -728,6 +738,133 @@ func (s *controllerSuite) Test_GetOrCreateShard_InvalidShardID() {
s.ErrorIs(err, invalidShardIdUpperBound)
}

func (s *controllerSuite) TestShardLingerTimeout() {
shardID := int32(1)
s.config.NumberOfShards = 1
s.config.ShardLingerEnabled = func() bool {
return true
}
timeLimit := 1 * time.Second
checkQPS := 5
s.config.ShardLingerTimeLimit = func() time.Duration {
return timeLimit
}
s.config.ShardLingerOwnershipCheckQPS = func() int {
return checkQPS
}

historyEngines := make(map[int32]*MockEngine)
mockEngine := NewMockEngine(s.controller)
historyEngines[shardID] = mockEngine
s.setupMocksForAcquireShard(shardID, mockEngine, 5, 6, true)

// When the shard is initialized, it uses the two mock functions below to initialize the "current" time of each cluster.
s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes()
s.mockClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestSingleDCClusterInfo).AnyTimes()
s.shardController.acquireShards(context.Background())

s.Len(s.shardController.ShardIDs(), 1)

s.mockServiceResolver.EXPECT().Lookup(convert.Int32ToString(shardID)).
Return(membership.NewHostInfoFromAddress("newhost"), nil)

mockEngine.EXPECT().Stop().Return()

start := time.Now()
s.shardController.acquireShards(context.Background())
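// With timeLimit=1s and checkQPS=5, limiter.Wait fails as soon as the next
// permitted probe would land past the context deadline, so the linger loop may
// exit up to one check interval (200ms) before the full second elapses; hence
// the 800ms lower bound asserted below.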
expectedMinimumWait := timeLimit - (time.Second / time.Duration(checkQPS))
s.Greater(time.Now().Sub(start), expectedMinimumWait)
s.Len(s.shardController.ShardIDs(), 0)

s.Equal(float64(1), s.readMetricsCounter(
metrics.ShardLingerTimeouts.GetMetricName(),
metrics.OperationTag(metrics.HistoryShardControllerScope)))
}

func (s *controllerSuite) TestShardLingerSuccess() {
shardID := int32(1)
s.config.NumberOfShards = 1
s.config.ShardLingerEnabled = func() bool {
return true
}
timeLimit := 1 * time.Second
checkQPS := 5
s.config.ShardLingerTimeLimit = func() time.Duration {
return timeLimit
}
s.config.ShardLingerOwnershipCheckQPS = func() int {
return checkQPS
}

historyEngines := make(map[int32]*MockEngine)
mockEngine := NewMockEngine(s.controller)
historyEngines[shardID] = mockEngine

mockEngine.EXPECT().Start().MinTimes(1)
mockEngine.EXPECT().NotifyNewTasks(gomock.Any()).MaxTimes(2)
s.mockServiceResolver.EXPECT().Lookup(convert.Int32ToString(shardID)).Return(s.hostInfo, nil).Times(2).MinTimes(1)
s.mockEngineFactory.EXPECT().CreateEngine(contextMatcher(shardID)).Return(mockEngine).MinTimes(1)
s.mockShardManager.EXPECT().GetOrCreateShard(gomock.Any(), getOrCreateShardRequestMatcher(shardID)).Return(
&persistence.GetOrCreateShardResponse{
ShardInfo: &persistencespb.ShardInfo{
ShardId: shardID,
Owner: s.hostInfo.Identity(),
RangeId: 5,
ReplicationDlqAckLevel: map[string]int64{},
QueueStates: s.queueStates(),
},
}, nil).MinTimes(1)
s.mockShardManager.EXPECT().UpdateShard(gomock.Any(), updateShardRequestMatcher(persistence.UpdateShardRequest{
ShardInfo: &persistencespb.ShardInfo{
ShardId: shardID,
Owner: s.hostInfo.Identity(),
RangeId: 6,
StolenSinceRenew: 1,
ReplicationDlqAckLevel: map[string]int64{},
QueueStates: s.queueStates(),
},
PreviousRangeID: 5,
})).Return(nil).MinTimes(1)
s.mockShardManager.EXPECT().AssertShardOwnership(gomock.Any(), &persistence.AssertShardOwnershipRequest{
ShardID: shardID,
RangeID: 6,
}).Return(nil).Times(1)
s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes()
s.mockClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestSingleDCClusterInfo).AnyTimes()

s.shardController.acquireShards(context.Background())
s.Len(s.shardController.ShardIDs(), 1)
shard, err := s.shardController.getOrCreateShardContext(shardID)
s.NoError(err)

s.mockServiceResolver.EXPECT().Lookup(convert.Int32ToString(shardID)).
Return(membership.NewHostInfoFromAddress("newhost"), nil)

mockEngine.EXPECT().Stop().Return().MinTimes(1)

// We mock two AssertShardOwnership calls because the first call happens
// before any waiting actually occurs in shardLingerThenClose.
s.mockShardManager.EXPECT().AssertShardOwnership(gomock.Any(), &persistence.AssertShardOwnershipRequest{
ShardID: shardID,
RangeID: 6,
}).Return(nil).Times(1)
s.mockShardManager.EXPECT().AssertShardOwnership(gomock.Any(), &persistence.AssertShardOwnershipRequest{
ShardID: shardID,
RangeID: 6,
}).DoAndReturn(func(_ context.Context, _ *persistence.AssertShardOwnershipRequest) error {
shard.UnloadForOwnershipLost()
return nil
}).Times(1)

start := time.Now()
s.shardController.acquireShards(context.Background())
s.Len(s.shardController.ShardIDs(), 0)

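// The first AssertShardOwnership probe runs immediately (limiter burst=1) and
// returns nil; the second runs one check interval later (1s/5 = 200ms) and
// unloads the shard, so acquireShards should return roughly 200ms after start.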
expectedWait := time.Second / time.Duration(checkQPS)
expected := start.Add(expectedWait)
s.WithinDuration(time.Now(), expected, 100*time.Millisecond)
}

func (s *controllerSuite) setupMocksForAcquireShard(
shardID int32,
mockEngine *MockEngine,
@@ -772,7 +909,7 @@ func (s *controllerSuite) setupMocksForAcquireShard(
s.mockShardManager.EXPECT().AssertShardOwnership(gomock.Any(), &persistence.AssertShardOwnershipRequest{
ShardID: shardID,
RangeID: newRangeID,
}).Return(nil).AnyTimes()
}).Return(nil).MinTimes(minTimes)
}

func (s *controllerSuite) queueStates() map[int32]*persistencespb.QueueState {
@@ -869,3 +1006,17 @@ func (m updateShardRequestMatcher) Matches(x interface{}) bool {
func (m updateShardRequestMatcher) String() string {
return fmt.Sprintf("%+v", (persistence.UpdateShardRequest)(m))
}

func (s *controllerSuite) readMetricsCounter(name string, nonSystemTags ...metrics.Tag) float64 {
expectedSystemTags := []metrics.Tag{
metrics.StringTag("otel_scope_name", "temporal"),
metrics.StringTag("otel_scope_version", ""),
}
snapshot, err := s.metricsTestHandler.Snapshot()
s.NoError(err)

tags := append(nonSystemTags, expectedSystemTags...)
value, err := snapshot.Counter(name+"_total", tags...)
s.NoError(err)
return value
}
