Skip to content

Commit

Permalink
Added unit tests for history handler (#6007)
Browse files Browse the repository at this point in the history
* Added unit tests for history handler
  • Loading branch information
timl3136 committed May 10, 2024
1 parent 684836a commit dbfb1c8
Show file tree
Hide file tree
Showing 2 changed files with 324 additions and 0 deletions.
2 changes: 2 additions & 0 deletions common/resource/resourceTest.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ type (
ArchivalMetadata *archiver.MockArchivalMetadata
ArchiverProvider *provider.MockArchiverProvider
BlobstoreClient *blobstore.MockClient
MockPayloadSerializer *persistence.MockPayloadSerializer

// membership infos
MembershipResolver *membership.MockResolver
Expand Down Expand Up @@ -184,6 +185,7 @@ func NewTest(
ArchivalMetadata: &archiver.MockArchivalMetadata{},
ArchiverProvider: &provider.MockArchiverProvider{},
BlobstoreClient: &blobstore.MockClient{},
MockPayloadSerializer: persistence.NewMockPayloadSerializer(controller),

// membership infos
MembershipResolver: membership.NewMockResolver(controller),
Expand Down
322 changes: 322 additions & 0 deletions service/history/handler/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"go.uber.org/goleak"
"go.uber.org/yarpc/yarpcerrors"

"github.com/uber/cadence/common"
Expand Down Expand Up @@ -2816,6 +2817,327 @@ func (s *handlerSuite) TestSyncActivity() {
}
}

func (s *handlerSuite) TestGetReplicationMessages() {
validInput := &types.GetReplicationMessagesRequest{
ClusterName: "test",
Tokens: []*types.ReplicationToken{
{
ShardID: 1,
LastRetrievedMessageID: 1,
},
{
ShardID: 2,
LastRetrievedMessageID: 2,
},
},
}

testInput := map[string]struct {
input *types.GetReplicationMessagesRequest
expectedError bool
mockFn func()
}{
"shutting down": {
input: validInput,
expectedError: true,
mockFn: func() {
s.handler.shuttingDown = int32(1)
},
},
"success": {
input: validInput,
expectedError: false,
mockFn: func() {
s.mockShardController.EXPECT().GetEngineForShard(int(validInput.Tokens[0].ShardID)).Return(s.mockEngine, nil).Times(1)
s.mockEngine.EXPECT().GetReplicationMessages(gomock.Any(), validInput.ClusterName, validInput.Tokens[0].LastRetrievedMessageID).Return(&types.ReplicationMessages{}, nil).Times(1)
s.mockShardController.EXPECT().GetEngineForShard(int(validInput.Tokens[1].ShardID)).Return(s.mockEngine, nil).Times(1)
s.mockEngine.EXPECT().GetReplicationMessages(gomock.Any(), validInput.ClusterName, validInput.Tokens[1].LastRetrievedMessageID).Return(&types.ReplicationMessages{}, nil).Times(1)
},
},
"cannot get engine and cannot get task": {
input: validInput,
expectedError: false,
mockFn: func() {
s.mockShardController.EXPECT().GetEngineForShard(int(validInput.Tokens[0].ShardID)).Return(nil, errors.New("errors")).Times(1)
s.mockShardController.EXPECT().GetEngineForShard(int(validInput.Tokens[1].ShardID)).Return(s.mockEngine, nil).Times(1)
s.mockEngine.EXPECT().GetReplicationMessages(gomock.Any(), validInput.ClusterName, validInput.Tokens[1].LastRetrievedMessageID).Return(nil, errors.New("errors")).Times(1)
},
},
"maxSize exceeds": {
input: validInput,
expectedError: false,
mockFn: func() {
s.handler.config.MaxResponseSize = 0
s.mockShardController.EXPECT().GetEngineForShard(int(validInput.Tokens[0].ShardID)).Return(s.mockEngine, nil).Times(1)
s.mockEngine.EXPECT().GetReplicationMessages(gomock.Any(), validInput.ClusterName, validInput.Tokens[0].LastRetrievedMessageID).Return(&types.ReplicationMessages{
ReplicationTasks: []*types.ReplicationTask{
{
TaskType: types.ReplicationTaskTypeHistory.Ptr(),
},
{
TaskType: types.ReplicationTaskTypeHistory.Ptr(),
},
},
}, nil).Times(1)
s.mockShardController.EXPECT().GetEngineForShard(int(validInput.Tokens[1].ShardID)).Return(s.mockEngine, nil).Times(1)
s.mockEngine.EXPECT().GetReplicationMessages(gomock.Any(), validInput.ClusterName, validInput.Tokens[1].LastRetrievedMessageID).Return(&types.ReplicationMessages{
ReplicationTasks: []*types.ReplicationTask{
{
TaskType: types.ReplicationTaskTypeHistory.Ptr(),
},
{
TaskType: types.ReplicationTaskTypeHistory.Ptr(),
},
},
}, nil).Times(1)
},
},
}

for name, input := range testInput {
s.Run(name, func() {
input.mockFn()
resp, err := s.handler.GetReplicationMessages(context.Background(), input.input)
s.handler.shuttingDown = int32(0)
if input.expectedError {
s.Nil(resp)
s.Error(err)
} else {
s.NotNil(resp)
s.NoError(err)
}
goleak.VerifyNone(s.T())
})
}
}

func (s *handlerSuite) TestGetDLQReplicationMessages() {
validInput := &types.GetDLQReplicationMessagesRequest{
TaskInfos: []*types.ReplicationTaskInfo{
{
DomainID: testDomainID,
WorkflowID: testWorkflowID,
RunID: testValidUUID,
},
},
}
mockResp := make([]*types.ReplicationTask, 0, 10)
mockResp = append(mockResp, &types.ReplicationTask{
TaskType: types.ReplicationTaskTypeHistory.Ptr(),
})

mockEmptyResp := make([]*types.ReplicationTask, 0)

testInput := map[string]struct {
input *types.GetDLQReplicationMessagesRequest
expectedError bool
mockFn func()
}{
"shutting down": {
input: validInput,
expectedError: true,
mockFn: func() {
s.handler.shuttingDown = int32(1)
},
},
"success": {
input: validInput,
expectedError: false,
mockFn: func() {
s.mockShardController.EXPECT().GetEngine(gomock.Any()).Return(s.mockEngine, nil).Times(1)
s.mockEngine.EXPECT().GetDLQReplicationMessages(gomock.Any(), gomock.Any()).Return(mockResp, nil).Times(1)
},
},
"cannot get engine": {
input: validInput,
expectedError: false,
mockFn: func() {
s.mockShardController.EXPECT().GetEngine(gomock.Any()).Return(nil, errors.New("error")).Times(1)
},
},
"cannot get task": {
input: validInput,
expectedError: false,
mockFn: func() {
s.mockShardController.EXPECT().GetEngine(gomock.Any()).Return(s.mockEngine, nil).Times(1)
s.mockEngine.EXPECT().GetDLQReplicationMessages(gomock.Any(), gomock.Any()).Return(nil, errors.New("error")).Times(1)
},
},
"empty task response": {
input: validInput,
expectedError: false,
mockFn: func() {
s.mockShardController.EXPECT().GetEngine(gomock.Any()).Return(s.mockEngine, nil).Times(1)
s.mockEngine.EXPECT().GetDLQReplicationMessages(gomock.Any(), gomock.Any()).Return(mockEmptyResp, nil).Times(1)
},
},
}

for name, input := range testInput {
s.Run(name, func() {
input.mockFn()
resp, err := s.handler.GetDLQReplicationMessages(context.Background(), input.input)
s.handler.shuttingDown = int32(0)
if input.expectedError {
s.Nil(resp)
s.Error(err)
} else {
s.NotNil(resp)
s.NoError(err)
}
goleak.VerifyNone(s.T())
})

}
}

func (s *handlerSuite) TestReapplyEvents() {
validInput := &types.HistoryReapplyEventsRequest{
DomainUUID: testDomainID,
Request: &types.ReapplyEventsRequest{
WorkflowExecution: &types.WorkflowExecution{
WorkflowID: testWorkflowID,
RunID: testValidUUID,
},
Events: &types.DataBlob{
EncodingType: types.EncodingTypeThriftRW.Ptr(),
Data: []byte{},
},
},
}

testInput := map[string]struct {
input *types.HistoryReapplyEventsRequest
expectedError bool
mockFn func()
}{
"shutting down": {
input: validInput,
expectedError: true,
mockFn: func() {
s.handler.shuttingDown = int32(1)
},
},
"cannot get engine": {
input: validInput,
expectedError: true,
mockFn: func() {
s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(nil, errors.New("error")).Times(1)
},
},
"cannot get serialized": {
input: &types.HistoryReapplyEventsRequest{
DomainUUID: testDomainID,
Request: &types.ReapplyEventsRequest{
WorkflowExecution: &types.WorkflowExecution{
WorkflowID: testWorkflowID,
RunID: testValidUUID,
},
Events: &types.DataBlob{
EncodingType: types.EncodingTypeThriftRW.Ptr(),
Data: []byte{1, 2, 3, 4},
},
},
},
expectedError: true,
mockFn: func() {
s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(s.mockEngine, nil).Times(1)
},
},
"reapplyEvents error": {
input: validInput,
expectedError: true,
mockFn: func() {
s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(s.mockEngine, nil).Times(1)
s.mockEngine.EXPECT().ReapplyEvents(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("error")).Times(1)
},
},
"success": {
input: validInput,
expectedError: false,
mockFn: func() {
s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(s.mockEngine, nil).Times(1)
s.mockEngine.EXPECT().ReapplyEvents(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(1)
},
},
}

for name, input := range testInput {
s.Run(name, func() {
input.mockFn()
err := s.handler.ReapplyEvents(context.Background(), input.input)
s.handler.shuttingDown = int32(0)
if input.expectedError {
s.Error(err)
} else {
s.NoError(err)
}
})
}
}

func (s *handlerSuite) TestCountDLQMessages() {
validInput := &types.CountDLQMessagesRequest{
ForceFetch: true,
}

testInput := map[string]struct {
input *types.CountDLQMessagesRequest
expectedError bool
mockFn func()
}{
"shutting down": {
input: validInput,
expectedError: true,
mockFn: func() {
s.handler.shuttingDown = int32(1)
},
},
"cannot get engine": {
input: validInput,
expectedError: true,
mockFn: func() {
s.mockShardController.EXPECT().ShardIDs().Return([]int32{0}).Times(1)
s.mockShardController.EXPECT().GetEngineForShard(gomock.Any()).Return(nil, errors.New("error")).Times(1)
},
},
"countDLQMessages error": {
input: validInput,
expectedError: true,
mockFn: func() {
s.mockShardController.EXPECT().ShardIDs().Return([]int32{0}).Times(1)
s.mockShardController.EXPECT().GetEngineForShard(gomock.Any()).Return(s.mockEngine, nil).Times(1)
s.mockEngine.EXPECT().CountDLQMessages(gomock.Any(), gomock.Any()).Return(map[string]int64{}, errors.New("error")).Times(1)
},
},
"success": {
input: validInput,
expectedError: false,
mockFn: func() {
s.mockShardController.EXPECT().ShardIDs().Return([]int32{0}).Times(1)
s.mockShardController.EXPECT().GetEngineForShard(gomock.Any()).Return(s.mockEngine, nil).Times(1)
s.mockEngine.EXPECT().CountDLQMessages(gomock.Any(), gomock.Any()).Return(map[string]int64{
"test": 1,
"test2": 2,
}, nil).Times(1)
},
},
}

for name, input := range testInput {
s.Run(name, func() {
input.mockFn()
_, err := s.handler.CountDLQMessages(context.Background(), input.input)
s.handler.shuttingDown = int32(0)
if input.expectedError {
s.Error(err)
} else {
s.NoError(err)
}
})
}
}

func (s *handlerSuite) TestGetCrossClusterTasks() {
numShards := 10
targetCluster := cluster.TestAlternativeClusterName
Expand Down

0 comments on commit dbfb1c8

Please sign in to comment.