diff --git a/common/persistence/visibility/store/elasticsearch/visibility_store.go b/common/persistence/visibility/store/elasticsearch/visibility_store.go index 3b6957b6911..2d7d7f6db33 100644 --- a/common/persistence/visibility/store/elasticsearch/visibility_store.go +++ b/common/persistence/visibility/store/elasticsearch/visibility_store.go @@ -31,6 +31,7 @@ import ( "encoding/json" "errors" "fmt" + "math" "strconv" "strings" "time" @@ -56,17 +57,6 @@ const ( delimiter = "~" ) -// Default sort by uses the sorting order defined in the index template, so no -// additional sorting is needed during query. -var defaultSorter = []elastic.Sorter{ - elastic.NewFieldSort(searchattribute.CloseTime).Desc().Missing("_first"), - elastic.NewFieldSort(searchattribute.StartTime).Desc().Missing("_first"), -} - -var docSorter = []elastic.Sorter{ - elastic.SortByDoc{}, -} - type ( visibilityStore struct { esClient client.Client @@ -82,12 +72,45 @@ type ( visibilityPageToken struct { SearchAfter []interface{} } + + fieldSort struct { + name string + desc bool + missing_first bool + } ) var _ store.VisibilityStore = (*visibilityStore)(nil) var ( errUnexpectedJSONFieldType = errors.New("unexpected JSON field type") + + // Default sorter uses the sorting order defined in the index template. + // It is indirectly built so buildPaginationQuery can have access to + // the fields names to build the page query from the token. + defaultSorterFields = []fieldSort{ + {searchattribute.CloseTime, true, true}, + {searchattribute.StartTime, true, true}, + } + + defaultSorter = func() []elastic.Sorter { + ret := make([]elastic.Sorter, 0, len(defaultSorterFields)) + for _, item := range defaultSorterFields { + fs := elastic.NewFieldSort(item.name) + if item.desc { + fs.Desc() + } + if item.missing_first { + fs.Missing("_first") + } + ret = append(ret, fs) + } + return ret + }() + + docSorter = []elastic.Sorter{ + elastic.SortByDoc{}, + } ) // NewVisibilityStore create a visibility store connecting to ElasticSearch @@ -439,15 +462,6 @@ func (s *visibilityStore) ListWorkflowExecutions( return nil, err } - token, err := s.deserializePageToken(request.NextPageToken) - if err != nil { - return nil, err - } - - if token != nil && len(token.SearchAfter) > 0 { - p.SearchAfter = token.SearchAfter - } - searchResult, err := s.esClient.Search(ctx, p) if err != nil { return nil, convertElasticsearchClientError("ListWorkflowExecutions failed", err) @@ -465,15 +479,6 @@ func (s *visibilityStore) ScanWorkflowExecutions( return nil, err } - token, err := s.deserializePageToken(request.NextPageToken) - if err != nil { - return nil, err - } - - if token != nil && len(token.SearchAfter) > 0 { - p.SearchAfter = token.SearchAfter - } - searchResult, err := s.esClient.Search(ctx, p) if err != nil { return nil, convertElasticsearchClientError("ScanWorkflowExecutions failed", err) @@ -588,7 +593,6 @@ func (s *visibilityStore) buildSearchParametersV2( request *manager.ListWorkflowExecutionsRequestV2, getFieldSorter func([]*elastic.FieldSort) ([]elastic.Sorter, error), ) (*client.SearchParameters, error) { - boolQuery, fieldSorts, err := s.convertQuery( request.Namespace, request.NamespaceID, @@ -619,9 +623,63 @@ func (s *visibilityStore) buildSearchParametersV2( Sorter: sorter, } + pageToken, err := s.deserializePageToken(request.NextPageToken) + if err != nil { + return nil, err + } + err = s.processPageToken(params, pageToken) + if err != nil { + return nil, err + } + return params, nil } +func (s *visibilityStore) processPageToken( + params *client.SearchParameters, + pageToken *visibilityPageToken, +) error { + if pageToken == nil || len(pageToken.SearchAfter) == 0 { + return nil + } + if len(pageToken.SearchAfter) != len(params.Sorter) { + return serviceerror.NewInvalidArgument(fmt.Sprintf( + "Invalid page token for given sort fields: expected %d fields, got %d", + len(params.Sorter), + len(pageToken.SearchAfter), + )) + } + if !isDefaultSorter(params.Sorter) { + params.SearchAfter = pageToken.SearchAfter + return nil + } + + boolQuery, ok := params.Query.(*elastic.BoolQuery) + if !ok { + return serviceerror.NewInternal(fmt.Sprintf( + "Unexpected query type: expected *elastic.BoolQuery, got %T", + params.Query, + )) + } + + saTypeMap, err := s.searchAttributesProvider.GetSearchAttributes(s.index, false) + if err != nil { + return serviceerror.NewUnavailable( + fmt.Sprintf("Unable to read search attribute types: %v", err), + ) + } + + // build pagination search query for default sorter + shouldQueries, err := buildPaginationQuery(defaultSorterFields, pageToken.SearchAfter, saTypeMap) + if err != nil { + return err + } + + boolQuery.Should(shouldQueries...) + boolQuery.MinimumNumberShouldMatch(1) + return nil +} + func (s *visibilityStore) convertQuery( namespace namespace.Name, namespaceID namespace.ID, @@ -998,3 +1056,165 @@ func detailedErrorMessage(err error) string { } return sb.String() } + +func isDefaultSorter(sorter []elastic.Sorter) bool { + if len(sorter) != len(defaultSorter) { + return false + } + for i := 0; i < len(defaultSorter); i++ { + if &sorter[i] != &defaultSorter[i] { + return false + } + } + return true +} + +// buildPaginationQuery builds the Elasticsearch conditions for the next page based on searchAfter. +// +// For example, if sorterFields = [A, B, C] and searchAfter = [lastA, lastB, lastC], +// it will build the following conditions (assuming all values are non-null and orders are desc): +// - k = 0: A < lastA +// - k = 1: A = lastA AND B < lastB +// - k = 2: A = lastA AND B = lastB AND C < lastC +// +//nolint:revive // cyclomatic complexity +func buildPaginationQuery( + sorterFields []fieldSort, + searchAfter []any, + saTypeMap searchattribute.NameTypeMap, +) ([]elastic.Query, error) { + n := len(sorterFields) + if len(sorterFields) != len(searchAfter) { + return nil, serviceerror.NewInvalidArgument(fmt.Sprintf( + "Invalid page token for given sort fields: expected %d fields, got %d", + len(sorterFields), + len(searchAfter), + )) + } + + parsedSearchAfter := make([]any, n) + for i := 0; i < n; i++ { + tp, err := saTypeMap.GetType(sorterFields[i].name) + if err != nil { + return nil, err + } + parsedSearchAfter[i], err = parsePageTokenValue(sorterFields[i].name, searchAfter[i], tp) + if err != nil { + return nil, err + } + } + + // Last field of sorter must be a tie breaker, and thus cannot contain null value. + if parsedSearchAfter[len(parsedSearchAfter)-1] == nil { + return nil, serviceerror.NewInternal(fmt.Sprintf( + "Last field of sorter cannot be a nullable field: %q has null values", + sorterFields[len(sorterFields)-1].name, + )) + } + + shouldQueries := make([]elastic.Query, 0, len(sorterFields)) + for k := 0; k < len(sorterFields); k++ { + bq := elastic.NewBoolQuery() + for i := 0; i <= k; i++ { + field := sorterFields[i] + value := parsedSearchAfter[i] + if i == k { + if value == nil { + bq.Filter(elastic.NewExistsQuery(field.name)) + } else if field.desc { + bq.Filter(elastic.NewRangeQuery(field.name).Lt(value)) + } else { + bq.Filter(elastic.NewRangeQuery(field.name).Gt(value)) + } + } else { + if value == nil { + bq.MustNot(elastic.NewExistsQuery(field.name)) + } else { + bq.Filter(elastic.NewTermQuery(field.name, value)) + } + } + } + shouldQueries = append(shouldQueries, bq) + } + return shouldQueries, nil +} + +// parsePageTokenValue parses the page token values to be used in the search query. +// The page token comes from the `sort` field from the previous response from Elasticsearch. +// Depending on the type of the field, the null value is represented differently: +// - integer, bool, and datetime: MaxInt64 (desc) or MinInt64 (asc) +// - double: "Infinity" (desc) or "-Infinity" (asc) +// - keyword: nil +// +// Furthermore, for bool and datetime, they need to be converted to boolean or the RFC3339Nano +// formats respectively. +// +//nolint:revive // cyclomatic complexity +func parsePageTokenValue( + fieldName string, jsonValue any, + tp enumspb.IndexedValueType, +) (any, error) { + switch tp { + case enumspb.INDEXED_VALUE_TYPE_INT, + enumspb.INDEXED_VALUE_TYPE_BOOL, + enumspb.INDEXED_VALUE_TYPE_DATETIME: + jsonNumber, ok := jsonValue.(json.Number) + if !ok { + return nil, serviceerror.NewInvalidArgument(fmt.Sprintf( + "Invalid page token: expected interger type, got %q", jsonValue)) + } + num, err := jsonNumber.Int64() + if err != nil { + return nil, serviceerror.NewInvalidArgument(fmt.Sprintf( + "Invalid page token: expected interger type, got %v", jsonValue)) + } + if num == math.MaxInt64 || num == math.MinInt64 { + return nil, nil + } + if tp == enumspb.INDEXED_VALUE_TYPE_BOOL { + return num != 0, nil + } + if tp == enumspb.INDEXED_VALUE_TYPE_DATETIME { + return time.Unix(0, num).UTC().Format(time.RFC3339Nano), nil + } + return num, nil + + case enumspb.INDEXED_VALUE_TYPE_DOUBLE: + switch v := jsonValue.(type) { + case json.Number: + num, err := v.Float64() + if err != nil { + return nil, serviceerror.NewInvalidArgument(fmt.Sprintf( + "Invalid page token: expected float type, got %v", jsonValue)) + } + return num, nil + case string: + // it can be the string representation of infinity + if _, err := strconv.ParseFloat(v, 64); err != nil { + return nil, serviceerror.NewInvalidArgument(fmt.Sprintf( + "Invalid page token: expected float type, got %q", jsonValue)) + } + return nil, nil + default: + // it should never reach here + return nil, serviceerror.NewInvalidArgument(fmt.Sprintf( + "Invalid page token: expected float type, got %#v", jsonValue)) + } + + case enumspb.INDEXED_VALUE_TYPE_KEYWORD: + if jsonValue == nil { + return nil, nil + } + if _, ok := jsonValue.(string); !ok { + return nil, serviceerror.NewInvalidArgument(fmt.Sprintf( + "Invalid page token: expected string type, got %v", jsonValue)) + } + return jsonValue, nil + + default: + return nil, serviceerror.NewInvalidArgument(fmt.Sprintf( + "Invalid field type in sorter: cannot order by %q", + fieldName, + )) + } +} diff --git a/common/persistence/visibility/store/elasticsearch/visibility_store_read_test.go b/common/persistence/visibility/store/elasticsearch/visibility_store_read_test.go index b1ee67c1504..2898dbfcef4 100644 --- a/common/persistence/visibility/store/elasticsearch/visibility_store_read_test.go +++ b/common/persistence/visibility/store/elasticsearch/visibility_store_read_test.go @@ -1135,7 +1135,7 @@ func (s *ESVisibilitySuite) TestScanWorkflowExecutions() { Hits: []*elastic.SearchHit{ { Source: source, - Sort: []interface{}{json.Number("123"), "runId"}, + Sort: []interface{}{json.Number("123")}, }, }, }, @@ -1156,7 +1156,7 @@ func (s *ESVisibilitySuite) TestScanWorkflowExecutions() { request.Query = `ExecutionStatus = "Terminated"` s.mockESClient.EXPECT().Search(gomock.Any(), gomock.Any()).Return(searchResult, nil) - token := &visibilityPageToken{SearchAfter: []interface{}{json.Number("1528358645123456789"), "qwe"}} + token := &visibilityPageToken{SearchAfter: []interface{}{json.Number("1528358645123456789")}} tokenBytes, err := s.visibilityStore.serializePageToken(token) s.NoError(err) request.NextPageToken = tokenBytes @@ -1164,7 +1164,7 @@ func (s *ESVisibilitySuite) TestScanWorkflowExecutions() { s.NoError(err) responseToken, err := s.visibilityStore.deserializePageToken(result.NextPageToken) s.NoError(err) - s.Equal([]interface{}{json.Number("123"), "runId"}, responseToken.SearchAfter) + s.Equal([]interface{}{json.Number("123")}, responseToken.SearchAfter) // test last page searchResult = &elastic.SearchResult{ @@ -1209,7 +1209,7 @@ func (s *ESVisibilitySuite) TestScanWorkflowExecutions_OldPageToken() { Hits: []*elastic.SearchHit{ { Source: source, - Sort: []interface{}{json.Number("123"), "runId"}, + Sort: []interface{}{json.Number("123")}, }, }, }, @@ -1224,7 +1224,7 @@ func (s *ESVisibilitySuite) TestScanWorkflowExecutions_OldPageToken() { ScrollID string PointInTimeID string }{ - SearchAfter: []interface{}{json.Number("1528358645123456789"), "qwe"}, + SearchAfter: []interface{}{json.Number("1528358645123456789")}, ScrollID: "random-scroll", PointInTimeID: "random-pit", } @@ -1236,7 +1236,7 @@ func (s *ESVisibilitySuite) TestScanWorkflowExecutions_OldPageToken() { s.NoError(err) responseToken, err := s.visibilityStore.deserializePageToken(result.NextPageToken) s.NoError(err) - s.Equal([]interface{}{json.Number("123"), "runId"}, responseToken.SearchAfter) + s.Equal([]interface{}{json.Number("123")}, responseToken.SearchAfter) } func (s *ESVisibilitySuite) TestCountWorkflowExecutions() { @@ -1386,3 +1386,292 @@ func (s *ESVisibilitySuite) Test_detailedErrorMessage() { } s.Equal("elastic: Error 500 (Internal Server Error): some reason [type=some type], root causes: some other reason1 [type=some other type1], some other reason2 [type=some other type2]", detailedErrorMessage(err)) } + +func (s *ESVisibilitySuite) Test_buildPaginationQuery() { + startTime := time.Now().UTC() + closeTime := startTime.Add(1 * time.Minute) + datetimeNull := json.Number(fmt.Sprintf("%d", math.MaxInt64)) + + testCases := []struct { + name string + sorterFields []fieldSort + searchAfter []any + res []elastic.Query + err error + }{ + { + name: "one field", + sorterFields: []fieldSort{{searchattribute.StartTime, true, true}}, + searchAfter: []any{json.Number(fmt.Sprintf("%d", startTime.UnixNano()))}, + res: []elastic.Query{ + elastic.NewBoolQuery().Filter( + elastic.NewRangeQuery(searchattribute.StartTime).Lt(startTime.Format(time.RFC3339Nano)), + ), + }, + err: nil, + }, + { + name: "two fields one null", + sorterFields: []fieldSort{ + {searchattribute.CloseTime, true, true}, + {searchattribute.StartTime, true, true}, + }, + searchAfter: []any{ + datetimeNull, + json.Number(fmt.Sprintf("%d", startTime.UnixNano())), + }, + res: []elastic.Query{ + elastic.NewBoolQuery().Filter(elastic.NewExistsQuery(searchattribute.CloseTime)), + elastic.NewBoolQuery(). + MustNot(elastic.NewExistsQuery(searchattribute.CloseTime)). + Filter( + elastic.NewRangeQuery(searchattribute.StartTime).Lt(startTime.Format(time.RFC3339Nano)), + ), + }, + err: nil, + }, + { + name: "two fields no null", + sorterFields: []fieldSort{ + {searchattribute.CloseTime, true, true}, + {searchattribute.StartTime, true, true}, + }, + searchAfter: []any{ + json.Number(fmt.Sprintf("%d", closeTime.UnixNano())), + json.Number(fmt.Sprintf("%d", startTime.UnixNano())), + }, + res: []elastic.Query{ + elastic.NewBoolQuery().Filter( + elastic.NewRangeQuery(searchattribute.CloseTime).Lt(closeTime.Format(time.RFC3339Nano)), + ), + elastic.NewBoolQuery(). + Filter( + elastic.NewTermQuery(searchattribute.CloseTime, closeTime.Format(time.RFC3339Nano)), + elastic.NewRangeQuery(searchattribute.StartTime).Lt(startTime.Format(time.RFC3339Nano)), + ), + }, + err: nil, + }, + { + name: "three fields", + sorterFields: []fieldSort{ + {searchattribute.CloseTime, true, true}, + {searchattribute.StartTime, true, true}, + {searchattribute.RunID, false, true}, + }, + searchAfter: []any{ + json.Number(fmt.Sprintf("%d", closeTime.UnixNano())), + json.Number(fmt.Sprintf("%d", startTime.UnixNano())), + "random-run-id", + }, + res: []elastic.Query{ + elastic.NewBoolQuery().Filter( + elastic.NewRangeQuery(searchattribute.CloseTime).Lt(closeTime.Format(time.RFC3339Nano)), + ), + elastic.NewBoolQuery(). + Filter( + elastic.NewTermQuery(searchattribute.CloseTime, closeTime.Format(time.RFC3339Nano)), + elastic.NewRangeQuery(searchattribute.StartTime).Lt(startTime.Format(time.RFC3339Nano)), + ), + elastic.NewBoolQuery(). + Filter( + elastic.NewTermQuery(searchattribute.CloseTime, closeTime.Format(time.RFC3339Nano)), + elastic.NewTermQuery(searchattribute.StartTime, startTime.Format(time.RFC3339Nano)), + elastic.NewRangeQuery(searchattribute.RunID).Gt("random-run-id"), + ), + }, + err: nil, + }, + { + name: "invalid token: wrong size", + sorterFields: []fieldSort{ + {searchattribute.CloseTime, true, true}, + {searchattribute.StartTime, true, true}, + {searchattribute.RunID, false, true}, + }, + searchAfter: []any{ + json.Number(fmt.Sprintf("%d", closeTime.UnixNano())), + json.Number(fmt.Sprintf("%d", startTime.UnixNano())), + }, + res: nil, + err: serviceerror.NewInvalidArgument("Invalid page token for given sort fields: expected 3 fields, got 2"), + }, + { + name: "invalid token: last value null", + sorterFields: []fieldSort{ + {searchattribute.CloseTime, true, true}, + }, + searchAfter: []any{datetimeNull}, + res: nil, + err: serviceerror.NewInternal("Last field of sorter cannot be a nullable field: \"CloseTime\" has null values"), + }, + } + + for _, tc := range testCases { + s.T().Run(tc.name, func(t *testing.T) { + res, err := buildPaginationQuery(tc.sorterFields, tc.searchAfter, searchattribute.TestNameTypeMap) + s.Equal(tc.err, err) + s.Equal(tc.res, res) + }) + } +} + +func (s *ESVisibilitySuite) Test_parsePageTokenValue() { + testCases := []struct { + name string + value any + tp enumspb.IndexedValueType + res any + err error + }{ + { + name: "IntField", + value: 123, + tp: enumspb.INDEXED_VALUE_TYPE_INT, + res: int64(123), + err: nil, + }, + { + name: "NullMaxIntField", + value: math.MaxInt64, + tp: enumspb.INDEXED_VALUE_TYPE_INT, + res: nil, + err: nil, + }, + { + name: "NullMinIntField", + value: math.MinInt64, + tp: enumspb.INDEXED_VALUE_TYPE_INT, + res: nil, + err: nil, + }, + { + name: "BoolFieldTrue", + value: 1, + tp: enumspb.INDEXED_VALUE_TYPE_BOOL, + res: true, + err: nil, + }, + { + name: "BoolFieldFalse", + value: 0, + tp: enumspb.INDEXED_VALUE_TYPE_BOOL, + res: false, + err: nil, + }, + { + name: "NullMaxBoolField", + value: math.MaxInt64, + tp: enumspb.INDEXED_VALUE_TYPE_BOOL, + res: nil, + err: nil, + }, + { + name: "NullMinBoolField", + value: math.MinInt64, + tp: enumspb.INDEXED_VALUE_TYPE_BOOL, + res: nil, + err: nil, + }, + { + name: "DatetimeField", + value: 1683221689123456789, + tp: enumspb.INDEXED_VALUE_TYPE_DATETIME, + res: "2023-05-04T17:34:49.123456789Z", + err: nil, + }, + { + name: "NullMaxDatetimeField", + value: math.MaxInt64, + tp: enumspb.INDEXED_VALUE_TYPE_DATETIME, + res: nil, + err: nil, + }, + { + name: "NullMinDatetimeField", + value: math.MinInt64, + tp: enumspb.INDEXED_VALUE_TYPE_DATETIME, + res: nil, + err: nil, + }, + { + name: "DoubleField", + value: 3.14, + tp: enumspb.INDEXED_VALUE_TYPE_DOUBLE, + res: float64(3.14), + err: nil, + }, + { + name: "NullMaxDoubleField", + value: "Infinity", + tp: enumspb.INDEXED_VALUE_TYPE_DOUBLE, + res: nil, + err: nil, + }, + { + name: "NullMinDoubleField", + value: "-Infinity", + tp: enumspb.INDEXED_VALUE_TYPE_DOUBLE, + res: nil, + err: nil, + }, + { + name: "KeywordField", + value: "foo", + tp: enumspb.INDEXED_VALUE_TYPE_KEYWORD, + res: "foo", + err: nil, + }, + { + name: "NullKeywordField", + value: nil, + tp: enumspb.INDEXED_VALUE_TYPE_KEYWORD, + res: nil, + err: nil, + }, + { + name: "IntFieldError", + value: "123", + tp: enumspb.INDEXED_VALUE_TYPE_INT, + res: nil, + err: serviceerror.NewInvalidArgument("Invalid page token: expected interger type, got \"123\""), + }, + { + name: "DoubleFieldError", + value: "foo", + tp: enumspb.INDEXED_VALUE_TYPE_DOUBLE, + res: nil, + err: serviceerror.NewInvalidArgument("Invalid page token: expected float type, got \"foo\""), + }, + { + name: "KeywordFieldError", + value: 123, + tp: enumspb.INDEXED_VALUE_TYPE_KEYWORD, + res: nil, + err: serviceerror.NewInvalidArgument("Invalid page token: expected string type, got 123"), + }, + { + name: "TextFieldError", + value: "foo", + tp: enumspb.INDEXED_VALUE_TYPE_TEXT, + res: nil, + err: serviceerror.NewInvalidArgument("Invalid field type in sorter: cannot order by \"TextFieldError\""), + }, + } + + pageToken := &visibilityPageToken{} + for _, tc := range testCases { + pageToken.SearchAfter = append(pageToken.SearchAfter, tc.value) + } + jsonToken, _ := json.Marshal(pageToken) + pageToken, err := s.visibilityStore.deserializePageToken(jsonToken) + s.NoError(err) + s.Equal(len(testCases), len(pageToken.SearchAfter)) + for i, tc := range testCases { + s.T().Run(tc.name, func(t *testing.T) { + res, err := parsePageTokenValue(tc.name, pageToken.SearchAfter[i], tc.tp) + s.Equal(tc.err, err) + s.Equal(tc.res, res) + }) + } +}