Skip to content

Commit

Permalink
Add String method to ContextImpl to fix a race (#2879)
Browse files Browse the repository at this point in the history
  • Loading branch information
dnr committed May 21, 2022
1 parent b94ddc1 commit b530a35
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 27 deletions.
5 changes: 5 additions & 0 deletions service/history/shard/context_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ const (
minContextTimeout = 2 * time.Second
)

func (s *ContextImpl) String() string {
// constant from initialization, no need for locks
return fmt.Sprintf("Shard(%d)", s.shardID)
}

func (s *ContextImpl) GetShardID() int32 {
// constant from initialization, no need for locks
return s.shardID
Expand Down
49 changes: 22 additions & 27 deletions service/history/shard/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,6 @@ import (
)

type (
contextMatcher struct {
shardID int32
}

controllerSuite struct {
suite.Suite
*require.Assertions
Expand Down Expand Up @@ -151,17 +147,6 @@ func (s *controllerSuite) TearDownTest() {
s.controller.Finish()
}

type getOrCreateShardRequestMatcher int32

func (s getOrCreateShardRequestMatcher) Matches(x interface{}) bool {
req, ok := x.(*persistence.GetOrCreateShardRequest)
return ok && req.ShardID == int32(s)
}

func (s getOrCreateShardRequestMatcher) String() string {
return strconv.Itoa(int(s))
}

func (s *controllerSuite) TestAcquireShardSuccess() {
numShards := int32(10)
s.config.NumberOfShards = numShards
Expand Down Expand Up @@ -702,7 +687,7 @@ func (s *controllerSuite) setupMocksForAcquireShard(shardID int32, mockEngine *M
// s.mockResource.ExecutionMgr.On("Close").Return()
mockEngine.EXPECT().Start().MinTimes(minTimes)
s.mockServiceResolver.EXPECT().Lookup(convert.Int32ToString(shardID)).Return(s.hostInfo, nil).Times(2).MinTimes(minTimes)
s.mockEngineFactory.EXPECT().CreateEngine(newContextMatcher(shardID)).Return(mockEngine).MinTimes(minTimes)
s.mockEngineFactory.EXPECT().CreateEngine(contextMatcher(shardID)).Return(mockEngine).MinTimes(minTimes)
s.mockShardManager.EXPECT().GetOrCreateShard(gomock.Any(), getOrCreateShardRequestMatcher(shardID)).Return(
&persistence.GetOrCreateShardResponse{
ShardInfo: &persistencespb.ShardInfo{
Expand Down Expand Up @@ -750,19 +735,29 @@ func (s *controllerSuite) setupMocksForAcquireShard(shardID int32, mockEngine *M
}).Return(nil).MinTimes(minTimes)
}

func newContextMatcher(shardID int32) *contextMatcher {
return &contextMatcher{shardID: shardID}
}
// This is needed to avoid race conditions when using this matcher, since
// fmt.Sprintf("%v"), used by gomock, would otherwise access private fields.
// See https://github.com/temporalio/temporal/issues/2777
var _ fmt.Stringer = (*ContextImpl)(nil)

func (m *contextMatcher) Matches(x interface{}) bool {
type contextMatcher int32

func (s contextMatcher) Matches(x interface{}) bool {
context, ok := x.(Context)
if !ok {
return false
}
return m.shardID == context.GetShardID()
return ok && context.GetShardID() == int32(s)
}

func (s contextMatcher) String() string {
return strconv.Itoa(int(s))
}

type getOrCreateShardRequestMatcher int32

func (s getOrCreateShardRequestMatcher) Matches(x interface{}) bool {
req, ok := x.(*persistence.GetOrCreateShardRequest)
return ok && req.ShardID == int32(s)
}

func (m *contextMatcher) String() string {
// noop, not used
return ""
func (s getOrCreateShardRequestMatcher) String() string {
return strconv.Itoa(int(s))
}

0 comments on commit b530a35

Please sign in to comment.