Skip to content

Commit

Permalink
Manually build ES pagination query for default sorter (#4271)
Browse files Browse the repository at this point in the history
  • Loading branch information
rodrigozhou committed May 11, 2023
1 parent d040cdf commit a1fe0e6
Show file tree
Hide file tree
Showing 2 changed files with 545 additions and 36 deletions.
280 changes: 250 additions & 30 deletions common/persistence/visibility/store/elasticsearch/visibility_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
"strconv"
"strings"
"time"
Expand All @@ -56,17 +57,6 @@ const (
delimiter = "~"
)

// Default sort by uses the sorting order defined in the index template, so no
// additional sorting is needed during query.
var defaultSorter = []elastic.Sorter{
elastic.NewFieldSort(searchattribute.CloseTime).Desc().Missing("_first"),
elastic.NewFieldSort(searchattribute.StartTime).Desc().Missing("_first"),
}

var docSorter = []elastic.Sorter{
elastic.SortByDoc{},
}

type (
visibilityStore struct {
esClient client.Client
Expand All @@ -82,12 +72,45 @@ type (
visibilityPageToken struct {
SearchAfter []interface{}
}

fieldSort struct {
name string
desc bool
missing_first bool
}
)

var _ store.VisibilityStore = (*visibilityStore)(nil)

var (
errUnexpectedJSONFieldType = errors.New("unexpected JSON field type")

// Default sorter uses the sorting order defined in the index template.
// It is indirectly built so buildPaginationQuery can have access to
// the fields names to build the page query from the token.
defaultSorterFields = []fieldSort{
{searchattribute.CloseTime, true, true},
{searchattribute.StartTime, true, true},
}

defaultSorter = func() []elastic.Sorter {
ret := make([]elastic.Sorter, 0, len(defaultSorterFields))
for _, item := range defaultSorterFields {
fs := elastic.NewFieldSort(item.name)
if item.desc {
fs.Desc()
}
if item.missing_first {
fs.Missing("_first")
}
ret = append(ret, fs)
}
return ret
}()

docSorter = []elastic.Sorter{
elastic.SortByDoc{},
}
)

// NewVisibilityStore create a visibility store connecting to ElasticSearch
Expand Down Expand Up @@ -439,15 +462,6 @@ func (s *visibilityStore) ListWorkflowExecutions(
return nil, err
}

token, err := s.deserializePageToken(request.NextPageToken)
if err != nil {
return nil, err
}

if token != nil && len(token.SearchAfter) > 0 {
p.SearchAfter = token.SearchAfter
}

searchResult, err := s.esClient.Search(ctx, p)
if err != nil {
return nil, convertElasticsearchClientError("ListWorkflowExecutions failed", err)
Expand All @@ -465,15 +479,6 @@ func (s *visibilityStore) ScanWorkflowExecutions(
return nil, err
}

token, err := s.deserializePageToken(request.NextPageToken)
if err != nil {
return nil, err
}

if token != nil && len(token.SearchAfter) > 0 {
p.SearchAfter = token.SearchAfter
}

searchResult, err := s.esClient.Search(ctx, p)
if err != nil {
return nil, convertElasticsearchClientError("ScanWorkflowExecutions failed", err)
Expand Down Expand Up @@ -588,7 +593,6 @@ func (s *visibilityStore) buildSearchParametersV2(
request *manager.ListWorkflowExecutionsRequestV2,
getFieldSorter func([]*elastic.FieldSort) ([]elastic.Sorter, error),
) (*client.SearchParameters, error) {

boolQuery, fieldSorts, err := s.convertQuery(
request.Namespace,
request.NamespaceID,
Expand Down Expand Up @@ -619,9 +623,63 @@ func (s *visibilityStore) buildSearchParametersV2(
Sorter: sorter,
}

pageToken, err := s.deserializePageToken(request.NextPageToken)
if err != nil {
return nil, err
}
err = s.processPageToken(params, pageToken)
if err != nil {
return nil, err
}

return params, nil
}

func (s *visibilityStore) processPageToken(
params *client.SearchParameters,
pageToken *visibilityPageToken,
) error {
if pageToken == nil || len(pageToken.SearchAfter) == 0 {
return nil
}
if len(pageToken.SearchAfter) != len(params.Sorter) {
return serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token for given sort fields: expected %d fields, got %d",
len(params.Sorter),
len(pageToken.SearchAfter),
))
}
if !isDefaultSorter(params.Sorter) {
params.SearchAfter = pageToken.SearchAfter
return nil
}

boolQuery, ok := params.Query.(*elastic.BoolQuery)
if !ok {
return serviceerror.NewInternal(fmt.Sprintf(
"Unexpected query type: expected *elastic.BoolQuery, got %T",
params.Query,
))
}

saTypeMap, err := s.searchAttributesProvider.GetSearchAttributes(s.index, false)
if err != nil {
return serviceerror.NewUnavailable(
fmt.Sprintf("Unable to read search attribute types: %v", err),
)
}

// build pagination search query for default sorter
shouldQueries, err := buildPaginationQuery(defaultSorterFields, pageToken.SearchAfter, saTypeMap)
if err != nil {
return err
}

boolQuery.Should(shouldQueries...)
boolQuery.MinimumNumberShouldMatch(1)
return nil
}

func (s *visibilityStore) convertQuery(
namespace namespace.Name,
namespaceID namespace.ID,
Expand Down Expand Up @@ -998,3 +1056,165 @@ func detailedErrorMessage(err error) string {
}
return sb.String()
}

func isDefaultSorter(sorter []elastic.Sorter) bool {
if len(sorter) != len(defaultSorter) {
return false
}
for i := 0; i < len(defaultSorter); i++ {
if &sorter[i] != &defaultSorter[i] {
return false
}
}
return true
}

// buildPaginationQuery builds the Elasticsearch conditions for the next page based on searchAfter.
//
// For example, if sorterFields = [A, B, C] and searchAfter = [lastA, lastB, lastC],
// it will build the following conditions (assuming all values are non-null and orders are desc):
// - k = 0: A < lastA
// - k = 1: A = lastA AND B < lastB
// - k = 2: A = lastA AND B = lastB AND C < lastC
//
//nolint:revive // cyclomatic complexity
func buildPaginationQuery(
sorterFields []fieldSort,
searchAfter []any,
saTypeMap searchattribute.NameTypeMap,
) ([]elastic.Query, error) {
n := len(sorterFields)
if len(sorterFields) != len(searchAfter) {
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token for given sort fields: expected %d fields, got %d",
len(sorterFields),
len(searchAfter),
))
}

parsedSearchAfter := make([]any, n)
for i := 0; i < n; i++ {
tp, err := saTypeMap.GetType(sorterFields[i].name)
if err != nil {
return nil, err
}
parsedSearchAfter[i], err = parsePageTokenValue(sorterFields[i].name, searchAfter[i], tp)
if err != nil {
return nil, err
}
}

// Last field of sorter must be a tie breaker, and thus cannot contain null value.
if parsedSearchAfter[len(parsedSearchAfter)-1] == nil {
return nil, serviceerror.NewInternal(fmt.Sprintf(
"Last field of sorter cannot be a nullable field: %q has null values",
sorterFields[len(sorterFields)-1].name,
))
}

shouldQueries := make([]elastic.Query, 0, len(sorterFields))
for k := 0; k < len(sorterFields); k++ {
bq := elastic.NewBoolQuery()
for i := 0; i <= k; i++ {
field := sorterFields[i]
value := parsedSearchAfter[i]
if i == k {
if value == nil {
bq.Filter(elastic.NewExistsQuery(field.name))
} else if field.desc {
bq.Filter(elastic.NewRangeQuery(field.name).Lt(value))
} else {
bq.Filter(elastic.NewRangeQuery(field.name).Gt(value))
}
} else {
if value == nil {
bq.MustNot(elastic.NewExistsQuery(field.name))
} else {
bq.Filter(elastic.NewTermQuery(field.name, value))
}
}
}
shouldQueries = append(shouldQueries, bq)
}
return shouldQueries, nil
}

// parsePageTokenValue parses the page token values to be used in the search query.
// The page token comes from the `sort` field from the previous response from Elasticsearch.
// Depending on the type of the field, the null value is represented differently:
// - integer, bool, and datetime: MaxInt64 (desc) or MinInt64 (asc)
// - double: "Infinity" (desc) or "-Infinity" (asc)
// - keyword: nil
//
// Furthermore, for bool and datetime, they need to be converted to boolean or the RFC3339Nano
// formats respectively.
//
//nolint:revive // cyclomatic complexity
func parsePageTokenValue(
fieldName string, jsonValue any,
tp enumspb.IndexedValueType,
) (any, error) {
switch tp {
case enumspb.INDEXED_VALUE_TYPE_INT,
enumspb.INDEXED_VALUE_TYPE_BOOL,
enumspb.INDEXED_VALUE_TYPE_DATETIME:
jsonNumber, ok := jsonValue.(json.Number)
if !ok {
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token: expected interger type, got %q", jsonValue))
}
num, err := jsonNumber.Int64()
if err != nil {
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token: expected interger type, got %v", jsonValue))
}
if num == math.MaxInt64 || num == math.MinInt64 {
return nil, nil
}
if tp == enumspb.INDEXED_VALUE_TYPE_BOOL {
return num != 0, nil
}
if tp == enumspb.INDEXED_VALUE_TYPE_DATETIME {
return time.Unix(0, num).UTC().Format(time.RFC3339Nano), nil
}
return num, nil

case enumspb.INDEXED_VALUE_TYPE_DOUBLE:
switch v := jsonValue.(type) {
case json.Number:
num, err := v.Float64()
if err != nil {
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token: expected float type, got %v", jsonValue))
}
return num, nil
case string:
// it can be the string representation of infinity
if _, err := strconv.ParseFloat(v, 64); err != nil {
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token: expected float type, got %q", jsonValue))
}
return nil, nil
default:
// it should never reach here
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token: expected float type, got %#v", jsonValue))
}

case enumspb.INDEXED_VALUE_TYPE_KEYWORD:
if jsonValue == nil {
return nil, nil
}
if _, ok := jsonValue.(string); !ok {
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token: expected string type, got %v", jsonValue))
}
return jsonValue, nil

default:
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid field type in sorter: cannot order by %q",
fieldName,
))
}
}
Loading

0 comments on commit a1fe0e6

Please sign in to comment.