Add tests for the modified lines
Also some refactoring of the existing tests: don't duplicate the test name
in table tests, and don't set mock expectations in SetupTest.
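
For context, a minimal, self-contained sketch of the table-test shape this commit moves toward (the package, suite, and function names below are hypothetical, not part of the Cadence codebase): the map key doubles as the subtest name passed to s.Run, so no duplicated caseName field is needed, and per-case setup lives inside each subtest rather than in SetupTest.

package example

import (
    "errors"
    "testing"

    "github.com/stretchr/testify/suite"
)

// doThing is a stand-in for the code under test.
func doThing(input string) error {
    if input == "" {
        return errors.New("empty input")
    }
    return nil
}

type exampleSuite struct {
    suite.Suite
}

func (s *exampleSuite) TestDoThing() {
    // The map key is the case name; there is no separate caseName field
    // to keep in sync with it.
    tests := map[string]struct {
        input       string
        expectedErr bool
    }{
        "valid input": {input: "ok"},
        "empty input": {input: "", expectedErr: true},
    }

    for name, tc := range tests {
        s.Run(name, func() {
            // Any per-case expectations or fixtures are set up here,
            // inside the subtest, rather than globally in SetupTest.
            err := doThing(tc.input)
            if tc.expectedErr {
                s.Error(err)
            } else {
                s.NoError(err)
            }
        })
    }
}

func TestExampleSuite(t *testing.T) {
    suite.Run(t, new(exampleSuite))
}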
dkrotx committed Apr 29, 2024
1 parent 9fa0f38 commit d1c9a27
Showing 1 changed file with 58 additions and 9 deletions.

service/history/handler/handler_test.go
@@ -38,7 +38,9 @@ import (
     "github.com/uber/cadence/common/log/testlogger"
     "github.com/uber/cadence/common/metrics"
     "github.com/uber/cadence/common/metrics/mocks"
+    "github.com/uber/cadence/common/persistence"
     "github.com/uber/cadence/common/quotas"
+    "github.com/uber/cadence/common/service"
     "github.com/uber/cadence/common/types"
     "github.com/uber/cadence/service/history/config"
     "github.com/uber/cadence/service/history/constants"
@@ -89,7 +91,6 @@ func (s *handlerSuite) SetupTest() {
     s.mockResource.Logger = testlogger.New(s.Suite.T())
     s.mockShardController = shard.NewMockController(s.controller)
     s.mockEngine = engine.NewMockEngine(s.controller)
-    s.mockShardController.EXPECT().GetEngineForShard(gomock.Any()).Return(s.mockEngine, nil).AnyTimes()
     s.mockWFCache = workflowcache.NewMockWFCache(s.controller)
     internalRequestRateLimitingEnabledConfig := func(domainName string) bool { return false }
     s.handler = NewHandler(s.mockResource, config.NewForTest(), s.mockWFCache, internalRequestRateLimitingEnabledConfig).(*handlerImpl)
@@ -341,13 +342,11 @@ func (s *handlerSuite) TestRecordDecisionTaskStarted() {
 
 func (s *handlerSuite) TestRecordDecisionTaskStarted() {
     testInput := map[string]struct {
-        caseName      string
         input         *types.RecordDecisionTaskStartedRequest
         expected      *types.RecordDecisionTaskStartedResponse
         expectedError bool
     }{
         "valid input": {
-            caseName: "valid input",
             input: &types.RecordDecisionTaskStartedRequest{
                 DomainUUID: testDomainID,
                 WorkflowExecution: &types.WorkflowExecution{
@@ -369,7 +368,6 @@ func (s *handlerSuite) TestRecordDecisionTaskStarted() {
             expectedError: false,
         },
         "empty domainID": {
-            caseName: "empty domainID",
             input: &types.RecordDecisionTaskStartedRequest{
                 DomainUUID: "",
                 WorkflowExecution: &types.WorkflowExecution{
@@ -381,7 +379,6 @@ func (s *handlerSuite) TestRecordDecisionTaskStarted() {
             expectedError: true,
         },
         "ratelimit exceeded": {
-            caseName: "ratelimit exceeded",
             input: &types.RecordDecisionTaskStartedRequest{
                 DomainUUID: testDomainID,
                 WorkflowExecution: &types.WorkflowExecution{
@@ -398,7 +395,6 @@ func (s *handlerSuite) TestRecordDecisionTaskStarted() {
             expectedError: true,
         },
         "get engine error": {
-            caseName: "get engine error",
             input: &types.RecordDecisionTaskStartedRequest{
                 DomainUUID: testDomainID,
                 WorkflowExecution: &types.WorkflowExecution{
@@ -415,7 +411,22 @@ func (s *handlerSuite) TestRecordDecisionTaskStarted() {
             expectedError: true,
         },
         "engine error": {
-            caseName: "engine error",
+            input: &types.RecordDecisionTaskStartedRequest{
+                DomainUUID: testDomainID,
+                WorkflowExecution: &types.WorkflowExecution{
+                    WorkflowID: testWorkflowID,
+                    RunID:      testValidUUID,
+                },
+                PollRequest: &types.PollForDecisionTaskRequest{
+                    TaskList: &types.TaskList{
+                        Name: "test-task-list",
+                    },
+                },
+            },
+            expected:      nil,
+            expectedError: true,
+        },
+        "engine error with ShardOwnershipLost": {
             input: &types.RecordDecisionTaskStartedRequest{
                 DomainUUID: testDomainID,
                 WorkflowExecution: &types.WorkflowExecution{
@@ -432,7 +443,6 @@ func (s *handlerSuite) TestRecordDecisionTaskStarted() {
             expectedError: true,
         },
         "empty poll request": {
-            caseName: "empty poll request",
             input: &types.RecordDecisionTaskStartedRequest{
                 DomainUUID: testDomainID,
                 WorkflowExecution: &types.WorkflowExecution{
@@ -447,7 +457,7 @@ func (s *handlerSuite) TestRecordDecisionTaskStarted() {
 
     for name, input := range testInput {
         s.Run(name, func() {
-            switch input.caseName {
+            switch name {
             case "valid input":
                 s.mockShardController.EXPECT().GetEngine(gomock.Any()).Return(s.mockEngine, nil).Times(1)
                 s.mockEngine.EXPECT().RecordDecisionTaskStarted(gomock.Any(), input.input).Return(input.expected, nil).Times(1)
@@ -462,9 +472,15 @@ func (s *handlerSuite) TestRecordDecisionTaskStarted() {
                 s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(s.mockEngine, nil).Times(1)
                 s.mockEngine.EXPECT().RecordDecisionTaskStarted(gomock.Any(), input.input).Return(nil, errors.New("error")).Times(1)
                 s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1)
+            case "engine error with ShardOwnershipLost":
+                s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(s.mockEngine, nil).Times(1)
+                s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1)
+                s.mockEngine.EXPECT().RecordDecisionTaskStarted(gomock.Any(), input.input).Return(nil, &persistence.ShardOwnershipLostError{ShardID: 123}).Times(1)
+                s.mockResource.MembershipResolver.EXPECT().Lookup(service.History, string(rune(123)))
             case "empty poll request":
                 s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1)
             }
+
             response, err := s.handler.RecordDecisionTaskStarted(context.Background(), input.input)
             s.Equal(input.expected, response)
             if input.expectedError {
@@ -754,6 +770,8 @@ func (s *handlerSuite) TestGetCrossClusterTasks() {
     var shardIDs []int32
     numSucceeded := int32(0)
     numTasksPerShard := rand.Intn(10)
+
+    s.mockShardController.EXPECT().GetEngineForShard(gomock.Any()).Return(s.mockEngine, nil).Times(numShards)
     s.mockEngine.EXPECT().GetCrossClusterTasks(gomock.Any(), targetCluster).DoAndReturn(
         func(_ context.Context, _ string) ([]*types.CrossClusterTaskRequest, error) {
             succeeded := rand.Intn(2) == 0
@@ -764,6 +782,7 @@ func (s *handlerSuite) TestGetCrossClusterTasks() {
             return nil, errors.New("some random error")
         },
     ).MaxTimes(numShards)
+
     for i := 0; i != numShards; i++ {
         shardIDs = append(shardIDs, int32(i))
     }
@@ -783,6 +802,35 @@ func (s *handlerSuite) TestGetCrossClusterTasks() {
     }
 }
 
+func (s *handlerSuite) TestGetCrossClusterTasksFails_IfGetEngineFails() {
+    numShards := 10
+    targetCluster := cluster.TestAlternativeClusterName
+    var shardIDs []int32
+
+    for i := 0; i != numShards; i++ {
+        shardIDs = append(shardIDs, int32(i))
+        s.mockShardController.EXPECT().GetEngineForShard(i).
+            Return(nil, errors.New("failed to get engine"))
+
+        // in response to the failure above, the handler looks up the current owner of the shard
+        s.mockResource.MembershipResolver.EXPECT().Lookup(service.History, string(rune(i)))
+    }
+
+    request := &types.GetCrossClusterTasksRequest{
+        ShardIDs:      shardIDs,
+        TargetCluster: targetCluster,
+    }
+
+    response, err := s.handler.GetCrossClusterTasks(context.Background(), request)
+    s.NoError(err)
+    s.NotNil(response)
+
+    s.Len(response.FailedCauseByShard, numShards, "we fail GetEngineForShard every time")
+    for _, failure := range response.FailedCauseByShard {
+        s.IsType(types.GetTaskFailedCauseShardOwnershipLost, failure)
+    }
+}
+
 func (s *handlerSuite) TestRespondCrossClusterTaskCompleted_FetchNewTask() {
     s.testRespondCrossClusterTaskCompleted(true)
 }
Expand All @@ -802,6 +850,7 @@ func (s *handlerSuite) testRespondCrossClusterTaskCompleted(
TaskResponses: make([]*types.CrossClusterTaskResponse, numTasks),
FetchNewTasks: fetchNewTask,
}
s.mockShardController.EXPECT().GetEngineForShard(0).Return(s.mockEngine, nil)
s.mockEngine.EXPECT().RespondCrossClusterTasksCompleted(gomock.Any(), targetCluster, request.TaskResponses).Return(nil).Times(1)
if fetchNewTask {
s.mockEngine.EXPECT().GetCrossClusterTasks(gomock.Any(), targetCluster).Return(make([]*types.CrossClusterTaskRequest, numTasks), nil).Times(1)