From 2a16156870478f78e73a3ccbeb47ffad741481b5 Mon Sep 17 00:00:00 2001
From: Alfred Landrum
Date: Fri, 23 Jun 2023 13:11:41 -0700
Subject: [PATCH] support delaying shard close for membership change

---
 common/dynamicconfig/constants.go        | 10 ++
 common/metrics/metric_defs.go            |  2 +
 service/history/configs/config.go        | 18 ++-
 service/history/shard/controller_impl.go | 57 +++++++-
 service/history/shard/controller_test.go | 157 ++++++++++++++++++++++-
 5 files changed, 232 insertions(+), 12 deletions(-)

diff --git a/common/dynamicconfig/constants.go b/common/dynamicconfig/constants.go
index d90528378df..d49c8cd1dbe 100644
--- a/common/dynamicconfig/constants.go
+++ b/common/dynamicconfig/constants.go
@@ -501,6 +501,16 @@ const (
     AcquireShardInterval = "history.acquireShardInterval"
     // AcquireShardConcurrency is number of goroutines that can be used to acquire shards in the shard controller.
     AcquireShardConcurrency = "history.acquireShardConcurrency"
+    // ShardLingerEnabled configures if the shard controller will temporarily
+    // delay closing shards after a membership update, awaiting a shard
+    // ownership lost error from persistence. Not recommended with persistence
+    // layers that are missing AssertShardOwnership support.
+    ShardLingerEnabled = "history.shardLingerEnabled"
+    // ShardLingerOwnershipCheckQPS is the frequency to perform shard ownership
+    // checks while a shard is lingering.
+    ShardLingerOwnershipCheckQPS = "history.shardLingerOwnershipCheckQPS"
+    // ShardLingerTimeLimit is the upper bound on how long a shard can linger.
+    ShardLingerTimeLimit = "history.shardLingerTimeLimit"
     // HistoryClientOwnershipCachingEnabled configures if history clients try to cache
     // shard ownership information, instead of checking membership for each request.
     HistoryClientOwnershipCachingEnabled = "history.clientOwnershipCachingEnabled"
diff --git a/common/metrics/metric_defs.go b/common/metrics/metric_defs.go
index 1b62151ba8f..4cc463ad01a 100644
--- a/common/metrics/metric_defs.go
+++ b/common/metrics/metric_defs.go
@@ -1426,6 +1426,8 @@ var (
     NamespaceRegistryLockLatency   = NewTimerDef("namespace_registry_lock_latency")
     ClosedWorkflowBufferEventCount = NewCounterDef("closed_workflow_buffer_event_counter")
     InorderBufferedEventsCounter   = NewCounterDef("inordered_buffered_events")
+    ShardLingerSuccess             = NewTimerDef("shard_linger_success")
+    ShardLingerTimeouts            = NewCounterDef("shard_linger_timeouts")
 
     // Matching
     MatchingClientForwardedCounter = NewCounterDef("forwarded")
diff --git a/service/history/configs/config.go b/service/history/configs/config.go
index 5702d1a3e86..1e5f909522e 100644
--- a/service/history/configs/config.go
+++ b/service/history/configs/config.go
@@ -79,9 +79,12 @@ type Config struct {
     EventsCacheTTL dynamicconfig.DurationPropertyFn
 
     // ShardController settings
-    RangeSizeBits           uint
-    AcquireShardInterval    dynamicconfig.DurationPropertyFn
-    AcquireShardConcurrency dynamicconfig.IntPropertyFn
+    RangeSizeBits                uint
+    AcquireShardInterval         dynamicconfig.DurationPropertyFn
+    AcquireShardConcurrency      dynamicconfig.IntPropertyFn
+    ShardLingerEnabled           dynamicconfig.BoolPropertyFn
+    ShardLingerOwnershipCheckQPS dynamicconfig.IntPropertyFn
+    ShardLingerTimeLimit         dynamicconfig.DurationPropertyFn
 
     HistoryClientOwnershipCachingEnabled dynamicconfig.BoolPropertyFn
 
@@ -361,9 +364,12 @@ func NewConfig(
         EventsCacheMaxSizeBytes: dc.GetIntProperty(dynamicconfig.EventsCacheMaxSizeBytes, 512*1024), // 512KB
         EventsCacheTTL:          dc.GetDurationProperty(dynamicconfig.EventsCacheTTL, time.Hour),
 
-        RangeSizeBits:           20, // 20 bits for sequencer, 2^20 sequence number for any range
-        AcquireShardInterval:    dc.GetDurationProperty(dynamicconfig.AcquireShardInterval, time.Minute),
-        AcquireShardConcurrency: dc.GetIntProperty(dynamicconfig.AcquireShardConcurrency, 10),
+        RangeSizeBits:                20, // 20 bits for sequencer, 2^20 sequence number for any range
+        AcquireShardInterval:         dc.GetDurationProperty(dynamicconfig.AcquireShardInterval, time.Minute),
+        AcquireShardConcurrency:      dc.GetIntProperty(dynamicconfig.AcquireShardConcurrency, 10),
+        ShardLingerEnabled:           dc.GetBoolProperty(dynamicconfig.ShardLingerEnabled, false),
+        ShardLingerOwnershipCheckQPS: dc.GetIntProperty(dynamicconfig.ShardLingerOwnershipCheckQPS, 4),
+        ShardLingerTimeLimit:         dc.GetDurationProperty(dynamicconfig.ShardLingerTimeLimit, 3*time.Second),
 
         HistoryClientOwnershipCachingEnabled: dc.GetBoolProperty(dynamicconfig.HistoryClientOwnershipCachingEnabled, false),
 
diff --git a/service/history/shard/controller_impl.go b/service/history/shard/controller_impl.go
index aa10b5e2556..4d45fe3dd24 100644
--- a/service/history/shard/controller_impl.go
+++ b/service/history/shard/controller_impl.go
@@ -34,6 +34,7 @@ import (
     "go.temporal.io/api/serviceerror"
     "golang.org/x/sync/semaphore"
+    "golang.org/x/time/rate"
 
     "go.temporal.io/server/common"
     "go.temporal.io/server/common/headers"
@@ -203,7 +204,7 @@ func (c *ControllerImpl) ShardIDs() []int32 {
     return ids
 }
 
-func (c *ControllerImpl) shardClosedCallback(shard ControllableContext) {
+func (c *ControllerImpl) shardRemoveAndStop(shard ControllableContext) {
     startTime := time.Now().UTC()
     defer func() {
         c.taggedMetricsHandler.Timer(metrics.RemoveEngineForShardLatency.GetMetricName()).Record(time.Since(startTime))
@@ -256,7 +257,7 @@ func (c *ControllerImpl) getOrCreateShardContext(shardID int32) (Context, error)
         return nil, fmt.Errorf("ControllerImpl for host '%v' shutting down", hostInfo.Identity())
     }
 
-    shard, err := c.contextFactory.CreateContext(shardID, c.shardClosedCallback)
+    shard, err := c.contextFactory.CreateContext(shardID, c.shardRemoveAndStop)
     if err != nil {
         return nil, err
     }
@@ -291,6 +292,52 @@ func (c *ControllerImpl) removeShardLocked(shardID int32, expected ControllableC
     return current
 }
 
+// shardLingerThenClose delays closing the shard for a small amount of time,
+// while watching for the shard to become invalid due to receiving a shard
+// ownership lost error. It calls AssertOwnership to probe for lost ownership,
+// but any concurrent request may also see the error, which will mark the shard
+// as invalid.
+func (c *ControllerImpl) shardLingerThenClose(ctx context.Context, shardID int32) {
+    c.RLock()
+    shard, ok := c.historyShards[shardID]
+    c.RUnlock()
+    if !ok {
+        return
+    }
+
+    startTime := time.Now()
+    timeout := util.Min(c.config.ShardLingerTimeLimit(), shardIOTimeout)
+    ctx, cancel := context.WithTimeout(ctx, timeout)
+    defer cancel()
+
+    qps := c.config.ShardLingerOwnershipCheckQPS()
+    // The limiter must be configured with burst>=1. With burst=1,
+    // the first call to Wait() won't be delayed.
+    limiter := rate.NewLimiter(rate.Limit(qps), 1)
+
+    for {
+        if !shard.IsValid() {
+            c.taggedMetricsHandler.Timer(metrics.ShardLingerSuccess.GetMetricName()).Record(time.Since(startTime))
+            break
+        }
+
+        if err := limiter.Wait(ctx); err != nil {
+            c.contextTaggedLogger.Info("shardLinger: wait timed out",
+                tag.ShardID(shardID),
+                tag.NewDurationTag("duration", time.Now().Sub(startTime)),
+            )
+            c.taggedMetricsHandler.Counter(metrics.ShardLingerTimeouts.GetMetricName()).Record(1)
+            break
+        }
+
+        // If this AssertOwnership or any other request on the shard receives
+        // a shard ownership lost error, the shard will be marked as invalid.
+        _ = shard.AssertOwnership(ctx)
+    }
+
+    c.shardRemoveAndStop(shard)
+}
+
 func (c *ControllerImpl) acquireShards(ctx context.Context) {
     c.taggedMetricsHandler.Counter(metrics.AcquireShardsCounter.GetMetricName()).Record(1)
     startTime := time.Now().UTC()
@@ -304,7 +351,11 @@ func (c *ControllerImpl) acquireShards(ctx context.Context) {
         if err := c.ownership.verifyOwnership(shardID); err != nil {
             if IsShardOwnershipLostError(err) {
                 // current host is not owner of shard, unload it if it is already loaded.
-                c.CloseShardByID(shardID)
+                if c.config.ShardLingerEnabled() {
+                    c.shardLingerThenClose(ctx, shardID)
+                } else {
+                    c.CloseShardByID(shardID)
+                }
             }
             return
         }
diff --git a/service/history/shard/controller_test.go b/service/history/shard/controller_test.go
index 39a775c441e..71561278e90 100644
--- a/service/history/shard/controller_test.go
+++ b/service/history/shard/controller_test.go
@@ -48,6 +48,8 @@ import (
     "go.temporal.io/server/common/log"
     "go.temporal.io/server/common/log/tag"
     "go.temporal.io/server/common/membership"
+    "go.temporal.io/server/common/metrics"
+    "go.temporal.io/server/common/metrics/metricstest"
     "go.temporal.io/server/common/persistence"
     "go.temporal.io/server/common/primitives"
     "go.temporal.io/server/common/primitives/timestamp"
@@ -77,6 +79,7 @@ type (
         logger               log.Logger
         shardController      *ControllerImpl
         mockHostInfoProvider *membership.MockHostInfoProvider
+        metricsTestHandler   *metricstest.Handler
     }
 )
 
@@ -85,6 +88,7 @@ func NewTestController(
     config *configs.Config,
     resource *resourcetest.Test,
     hostInfoProvider *membership.MockHostInfoProvider,
+    metricsTestHandler *metricstest.Handler,
 ) *ControllerImpl {
     contextFactory := ContextFactoryProvider(ContextFactoryParams{
         ArchivalMetadata: resource.GetArchivalMetadata(),
@@ -96,7 +100,7 @@ func NewTestController(
         HistoryServiceResolver: resource.GetHistoryServiceResolver(),
         HostInfoProvider: hostInfoProvider,
         Logger: resource.GetLogger(),
-        MetricsHandler: resource.GetMetricsHandler(),
+        MetricsHandler: metricsTestHandler,
         NamespaceRegistry: resource.GetNamespaceRegistry(),
         PayloadSerializer: resource.GetPayloadSerializer(),
         PersistenceExecutionManager: resource.GetExecutionManager(),
@@ -111,7 +115,7 @@ func NewTestController(
         config,
         resource.GetLogger(),
         resource.GetHistoryServiceResolver(),
-        resource.GetMetricsHandler(),
+        metricsTestHandler,
         resource.GetHostInfoProvider(),
         contextFactory,
     ).(*ControllerImpl)
@@ -139,12 +143,16 @@ func (s *controllerSuite) SetupTest() {
 
     s.logger = s.mockResource.Logger
     s.config = tests.NewDynamicConfig()
+    metricsTestHandler, err := metricstest.NewHandler(log.NewNoopLogger(), metrics.ClientConfig{})
+    s.NoError(err)
+    s.metricsTestHandler = metricsTestHandler
 
     s.shardController = NewTestController(
         s.mockEngineFactory,
         s.config,
         s.mockResource,
         s.mockHostInfoProvider,
+        s.metricsTestHandler,
     )
 }
 
@@ -319,6 +327,7 @@ func (s *controllerSuite) TestHistoryEngineClosed() {
         s.config,
         s.mockResource,
         s.mockHostInfoProvider,
+        s.metricsTestHandler,
     )
     historyEngines := make(map[int32]*MockEngine)
     for shardID := int32(1); shardID <= numShards; shardID++ {
@@ -423,6 +432,7 @@ func (s *controllerSuite) TestShardControllerClosed() {
         s.config,
         s.mockResource,
         s.mockHostInfoProvider,
+        s.metricsTestHandler,
     )
     historyEngines := make(map[int32]*MockEngine)
 
@@ -728,6 +738,133 @@ func (s *controllerSuite) Test_GetOrCreateShard_InvalidShardID() {
     s.ErrorIs(err, invalidShardIdUpperBound)
 }
 
+func (s *controllerSuite) TestShardLingerTimeout() {
+    shardID := int32(1)
+    s.config.NumberOfShards = 1
+    s.config.ShardLingerEnabled = func() bool {
+        return true
+    }
+    timeLimit := 1 * time.Second
+    checkQPS := 5
+    s.config.ShardLingerTimeLimit = func() time.Duration {
+        return timeLimit
+    }
+    s.config.ShardLingerOwnershipCheckQPS = func() int {
+        return checkQPS
+    }
+
+    historyEngines := make(map[int32]*MockEngine)
+    mockEngine := NewMockEngine(s.controller)
+    historyEngines[shardID] = mockEngine
+    s.setupMocksForAcquireShard(shardID, mockEngine, 5, 6, true)
+
+    // When the shard is initialized, it uses the two mock functions below to initialize the "current" time of each cluster.
+    s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes()
+    s.mockClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestSingleDCClusterInfo).AnyTimes()
+    s.shardController.acquireShards(context.Background())
+
+    s.Len(s.shardController.ShardIDs(), 1)
+
+    s.mockServiceResolver.EXPECT().Lookup(convert.Int32ToString(shardID)).
+        Return(membership.NewHostInfoFromAddress("newhost"), nil)
+
+    mockEngine.EXPECT().Stop().Return()
+
+    start := time.Now()
+    s.shardController.acquireShards(context.Background())
+    expectedMinimumWait := timeLimit - (time.Second / time.Duration(checkQPS))
+    s.Greater(time.Now().Sub(start), expectedMinimumWait)
+    s.Len(s.shardController.ShardIDs(), 0)
+
+    s.Equal(float64(1), s.readMetricsCounter(
+        metrics.ShardLingerTimeouts.GetMetricName(),
+        metrics.OperationTag(metrics.HistoryShardControllerScope)))
+}
+
+func (s *controllerSuite) TestShardLingerSuccess() {
+    shardID := int32(1)
+    s.config.NumberOfShards = 1
+    s.config.ShardLingerEnabled = func() bool {
+        return true
+    }
+    timeLimit := 1 * time.Second
+    checkQPS := 5
+    s.config.ShardLingerTimeLimit = func() time.Duration {
+        return timeLimit
+    }
+    s.config.ShardLingerOwnershipCheckQPS = func() int {
+        return checkQPS
+    }
+
+    historyEngines := make(map[int32]*MockEngine)
+    mockEngine := NewMockEngine(s.controller)
+    historyEngines[shardID] = mockEngine
+
+    mockEngine.EXPECT().Start().MinTimes(1)
+    mockEngine.EXPECT().NotifyNewTasks(gomock.Any()).MaxTimes(2)
+    s.mockServiceResolver.EXPECT().Lookup(convert.Int32ToString(shardID)).Return(s.hostInfo, nil).Times(2).MinTimes(1)
+    s.mockEngineFactory.EXPECT().CreateEngine(contextMatcher(shardID)).Return(mockEngine).MinTimes(1)
+    s.mockShardManager.EXPECT().GetOrCreateShard(gomock.Any(), getOrCreateShardRequestMatcher(shardID)).Return(
+        &persistence.GetOrCreateShardResponse{
+            ShardInfo: &persistencespb.ShardInfo{
+                ShardId:                shardID,
+                Owner:                  s.hostInfo.Identity(),
+                RangeId:                5,
+                ReplicationDlqAckLevel: map[string]int64{},
+                QueueStates:            s.queueStates(),
+            },
+        }, nil).MinTimes(1)
+    s.mockShardManager.EXPECT().UpdateShard(gomock.Any(), updateShardRequestMatcher(persistence.UpdateShardRequest{
+        ShardInfo: &persistencespb.ShardInfo{
+            ShardId:                shardID,
+            Owner:                  s.hostInfo.Identity(),
+            RangeId:                6,
+            StolenSinceRenew:       1,
+            ReplicationDlqAckLevel: map[string]int64{},
+            QueueStates:            s.queueStates(),
+        },
+        PreviousRangeID: 5,
+    })).Return(nil).MinTimes(1)
+    s.mockShardManager.EXPECT().AssertShardOwnership(gomock.Any(), &persistence.AssertShardOwnershipRequest{
+        ShardID: shardID,
+        RangeID: 6,
+    }).Return(nil).Times(1)
+    s.mockClusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes()
+    s.mockClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestSingleDCClusterInfo).AnyTimes()
+
+    s.shardController.acquireShards(context.Background())
+    s.Len(s.shardController.ShardIDs(), 1)
+    shard, err := s.shardController.getOrCreateShardContext(shardID)
+    s.NoError(err)
+
+    s.mockServiceResolver.EXPECT().Lookup(convert.Int32ToString(shardID)).
+        Return(membership.NewHostInfoFromAddress("newhost"), nil)
+
+    mockEngine.EXPECT().Stop().Return().MinTimes(1)
+
+    // We mock two AssertShardOwnership calls because the first call happens
+    // before any waiting actually occurs in shardLingerThenClose.
+    s.mockShardManager.EXPECT().AssertShardOwnership(gomock.Any(), &persistence.AssertShardOwnershipRequest{
+        ShardID: shardID,
+        RangeID: 6,
+    }).Return(nil).Times(1)
+    s.mockShardManager.EXPECT().AssertShardOwnership(gomock.Any(), &persistence.AssertShardOwnershipRequest{
+        ShardID: shardID,
+        RangeID: 6,
+    }).DoAndReturn(func(_ context.Context, _ *persistence.AssertShardOwnershipRequest) error {
+        shard.UnloadForOwnershipLost()
+        return nil
+    }).Times(1)
+
+    start := time.Now()
+    s.shardController.acquireShards(context.Background())
+    s.Len(s.shardController.ShardIDs(), 0)
+
+    expectedWait := time.Second / time.Duration(checkQPS)
+    expected := start.Add(expectedWait)
+    s.WithinDuration(time.Now(), expected, 100*time.Millisecond)
+}
+
 func (s *controllerSuite) setupMocksForAcquireShard(
     shardID int32,
     mockEngine *MockEngine,
@@ -772,7 +909,7 @@ func (s *controllerSuite) setupMocksForAcquireShard(
     s.mockShardManager.EXPECT().AssertShardOwnership(gomock.Any(), &persistence.AssertShardOwnershipRequest{
         ShardID: shardID,
         RangeID: newRangeID,
-    }).Return(nil).AnyTimes()
+    }).Return(nil).MinTimes(minTimes)
 }
 
 func (s *controllerSuite) queueStates() map[int32]*persistencespb.QueueState {
@@ -869,3 +1006,17 @@ func (m updateShardRequestMatcher) Matches(x interface{}) bool {
 func (m updateShardRequestMatcher) String() string {
     return fmt.Sprintf("%+v", (persistence.UpdateShardRequest)(m))
 }
+
+func (s *controllerSuite) readMetricsCounter(name string, nonSystemTags ...metrics.Tag) float64 {
+    expectedSystemTags := []metrics.Tag{
+        metrics.StringTag("otel_scope_name", "temporal"),
+        metrics.StringTag("otel_scope_version", ""),
+    }
+    snapshot, err := s.metricsTestHandler.Snapshot()
+    s.NoError(err)
+
+    tags := append(nonSystemTags, expectedSystemTags...)
+    value, err := snapshot.Counter(name+"_total", tags...)
+    s.NoError(err)
+    return value
+}
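
The core of the patch is the lingering loop in shardLingerThenClose. Below is a minimal standalone sketch of that loop, not part of the patch: it assumes a simplified, hypothetical Shard interface and a closeShard callback standing in for ControllableContext and shardRemoveAndStop, purely to show the control flow in isolation (wait for the shard to be marked invalid by a shard ownership lost error, probe ownership at a bounded rate, close the shard once it is invalid or the time limit expires).

package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"golang.org/x/time/rate"
)

// Shard is a hypothetical stand-in for the controller's ControllableContext.
type Shard interface {
	// IsValid reports whether the shard context is still usable; it becomes
	// false once a shard ownership lost error has been observed.
	IsValid() bool
	// AssertOwnership asks persistence to confirm this host still owns the shard.
	AssertOwnership(ctx context.Context) error
}

// lingerThenClose mirrors the shape of shardLingerThenClose: wait up to
// timeLimit for the shard to become invalid, probing ownership at most qps
// times per second, then hand the shard to closeShard.
func lingerThenClose(ctx context.Context, shard Shard, qps int, timeLimit time.Duration, closeShard func()) {
	ctx, cancel := context.WithTimeout(ctx, timeLimit)
	defer cancel()

	// burst=1: the first Wait returns immediately; later calls are spaced ~1/qps apart.
	limiter := rate.NewLimiter(rate.Limit(qps), 1)

	for {
		if !shard.IsValid() {
			break // ownership loss already observed; safe to close now
		}
		if err := limiter.Wait(ctx); err != nil {
			break // time limit reached; close anyway
		}
		// A shard ownership lost error returned here (or to any concurrent
		// request) is expected to mark the shard invalid, ending the loop.
		_ = shard.AssertOwnership(ctx)
	}
	closeShard()
}

// fakeShard becomes invalid after a fixed number of ownership probes.
type fakeShard struct{ probesLeft int }

func (f *fakeShard) IsValid() bool { return f.probesLeft > 0 }

func (f *fakeShard) AssertOwnership(context.Context) error {
	f.probesLeft--
	if f.probesLeft <= 0 {
		return errors.New("shard ownership lost")
	}
	return nil
}

func main() {
	shard := &fakeShard{probesLeft: 3}
	start := time.Now()
	lingerThenClose(context.Background(), shard, 4, 3*time.Second, func() {
		fmt.Printf("shard closed after %v\n", time.Since(start).Round(10*time.Millisecond))
	})
}

The cost of lingering is bounded twice: by the context deadline (the patch caps ShardLingerTimeLimit at shardIOTimeout) and by ShardLingerOwnershipCheckQPS, so a lingering shard issues at most a handful of extra AssertShardOwnership calls before it is removed and stopped.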
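The comment above rate.NewLimiter in shardLingerThenClose notes that the limiter must be configured with burst >= 1. A small standalone demonstration of that golang.org/x/time/rate behavior follows (again an illustration, not part of the patch): with burst=1 the limiter starts with one token, so the first Wait returns immediately and later calls are spaced roughly 1/QPS apart, which is why the first ownership probe happens without delay.

package main

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/time/rate"
)

func main() {
	const qps = 4
	// burst=1: the limiter starts with one token available.
	limiter := rate.NewLimiter(rate.Limit(qps), 1)

	start := time.Now()
	for i := 0; i < 3; i++ {
		if err := limiter.Wait(context.Background()); err != nil {
			fmt.Println("wait failed:", err)
			return
		}
		fmt.Printf("wait %d returned after %v\n", i, time.Since(start).Round(10*time.Millisecond))
	}
	// Typical output: wait 0 after ~0s, wait 1 after ~250ms, wait 2 after ~500ms.
}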