From 677235bf31dac4898b4f46c5722d33d494b4f8b0 Mon Sep 17 00:00:00 2001 From: Kaniu Date: Sat, 30 Dec 2023 09:14:04 +0800 Subject: [PATCH] full support for minimum_should_match (#940) full support for min should match --- pkg/meta/query_dsl.go | 16 ++-- pkg/uquery/query/bool.go | 29 +++--- pkg/uquery/query/match.go | 63 ++++++++++++- pkg/uquery/query/multi_match.go | 28 ++---- pkg/zutils/min_should.go | 152 ++++++++++++++++++++++++++++++++ pkg/zutils/min_should_test.go | 53 +++++++++++ 6 files changed, 294 insertions(+), 47 deletions(-) create mode 100644 pkg/zutils/min_should.go create mode 100644 pkg/zutils/min_should_test.go diff --git a/pkg/meta/query_dsl.go b/pkg/meta/query_dsl.go index c32e81c77..41d0ee5f5 100644 --- a/pkg/meta/query_dsl.go +++ b/pkg/meta/query_dsl.go @@ -102,7 +102,7 @@ type BoolQuery struct { Must interface{} `json:"must,omitempty"` // query, [query1, query2] MustNot interface{} `json:"must_not,omitempty"` // query, [query1, query2] Filter interface{} `json:"filter,omitempty"` // query, [query1, query2] - MinimumShouldMatch float64 `json:"minimum_should_match,omitempty"` // only for should + MinimumShouldMatch interface{} `json:"minimum_should_match,omitempty"` // only for should } type BoolQueryForSDK struct { @@ -151,13 +151,13 @@ type MatchPhrasePrefixQuery struct { } type MultiMatchQuery struct { - Query string `json:"query,omitempty"` - Analyzer string `json:"analyzer,omitempty"` - Fields []string `json:"fields,omitempty"` - Boost float64 `json:"boost,omitempty"` - Type string `json:"type,omitempty"` // best_fields(default), most_fields, cross_fields, phrase, phrase_prefix, bool_prefix - Operator string `json:"operator,omitempty"` // or(default), and - MinimumShouldMatch float64 `json:"minimum_should_match,omitempty"` + Query string `json:"query,omitempty"` + Analyzer string `json:"analyzer,omitempty"` + Fields []string `json:"fields,omitempty"` + Boost float64 `json:"boost,omitempty"` + Type string `json:"type,omitempty"` // best_fields(default), most_fields, cross_fields, phrase, phrase_prefix, bool_prefix + Operator string `json:"operator,omitempty"` // or(default), and + MinimumShouldMatch interface{} `json:"minimum_should_match,omitempty"` } type CombinedFieldsQuery struct { diff --git a/pkg/uquery/query/bool.go b/pkg/uquery/query/bool.go index 8d362568c..68638debd 100644 --- a/pkg/uquery/query/bool.go +++ b/pkg/uquery/query/bool.go @@ -17,7 +17,6 @@ package query import ( "fmt" - "strconv" "strings" "github.com/blugelabs/bluge" @@ -25,10 +24,12 @@ import ( "github.com/zincsearch/zincsearch/pkg/errors" "github.com/zincsearch/zincsearch/pkg/meta" + "github.com/zincsearch/zincsearch/pkg/zutils" ) func BoolQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers map[string]*analysis.Analyzer) (bluge.Query, error) { boolQuery := bluge.NewBooleanQuery() + var minimumShouldMatch interface{} for k, v := range query { k := strings.ToLower(k) switch k { @@ -111,27 +112,19 @@ func BoolQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers } boolQuery.AddMust(filterQuery) case "minimum_should_match": - switch v := v.(type) { - case string: - if strings.Contains(v, "%") || strings.Contains(v, "<") { - return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[bool] %s value only support integer", k)) - } - vi, err := strconv.ParseInt(v, 10, 64) - if err != nil { - return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[bool] %s type string convert to int error: %s", k, err)) - } - boolQuery.SetMinShould(int(vi)) // lgtm[go/hardcoded-credentials] - case int: - boolQuery.SetMinShould(v) - case float64: - boolQuery.SetMinShould(int(v)) - default: - return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[bool] %s doesn't support values of type: %T", k, v)) - } + minimumShouldMatch = v default: return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[bool] unknown field [%s]", k)) } } + if minimumShouldMatch != nil { + minValue, err := zutils.CalculateMin(len(boolQuery.Shoulds()), minimumShouldMatch) + if err != nil { + return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[bool] unsupported MinimumShouldMatch value: %v", err)) + } + boolQuery.SetMinShould(minValue) + } + return boolQuery, nil } diff --git a/pkg/uquery/query/match.go b/pkg/uquery/query/match.go index 304cd9caa..8481f4148 100644 --- a/pkg/uquery/query/match.go +++ b/pkg/uquery/query/match.go @@ -21,10 +21,10 @@ import ( "github.com/blugelabs/bluge" "github.com/blugelabs/bluge/analysis" - "github.com/zincsearch/zincsearch/pkg/errors" "github.com/zincsearch/zincsearch/pkg/meta" zincanalysis "github.com/zincsearch/zincsearch/pkg/uquery/analysis" + zincanalyzer "github.com/zincsearch/zincsearch/pkg/uquery/analysis/analyzer" "github.com/zincsearch/zincsearch/pkg/zutils" ) @@ -36,6 +36,7 @@ func MatchQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers field := "" value := new(meta.MatchQuery) value.Boost = -1.0 + var minimumShouldMatch interface{} for k, v := range query { field = k switch v := v.(type) { @@ -57,6 +58,8 @@ func MatchQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers value.PrefixLength, _ = zutils.ToFloat64(v) case "boost": value.Boost, _ = zutils.ToFloat64(v) + case "minimum_should_match": + minimumShouldMatch = v default: // return nil, errors.New(errors.ErrorTypeParsingException, fmt.Sprintf("[match] unknown field [%s]", k)) } @@ -83,6 +86,11 @@ func MatchQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers } } + // only "OR" supports minimum should match + if minimumShouldMatch != nil && (value.Operator == "" || strings.ToUpper(value.Operator) == "OR") { + return genQueryWithMinimumShouldMatch(zer, field, value, minimumShouldMatch) + } + subq := bluge.NewMatchQuery(value.Query).SetField(field) if zer != nil { subq.SetAnalyzer(zer) @@ -115,3 +123,56 @@ func MatchQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers return subq, nil } + +func genQueryWithMinimumShouldMatch(ana *analysis.Analyzer, field string, value *meta.MatchQuery, minimumShouldMatch interface{}) (bluge.Query, error) { + if ana == nil { + ana, _ = zincanalyzer.NewStandardAnalyzer(nil) + } + + var fuzziness int + if value.Fuzziness != nil { + if value.Fuzziness != nil { + v := ParseFuzziness(value.Fuzziness, value.Query, ana) + if v > 0 { + fuzziness = v + } + } + } + + var boost float64 = 1 + if value.Boost >= 0 { + boost = value.Boost + } + + tokens := ana.Analyze([]byte(value.Query)) + if len(tokens) > 0 { + tqs := make([]bluge.Query, len(tokens)) + if fuzziness != 0 { + for i, token := range tokens { + query := bluge.NewFuzzyQuery(string(token.Term)) + query.SetFuzziness(fuzziness) + query.SetPrefix(int(value.PrefixLength)) + query.SetField(field) + query.SetBoost(boost) + tqs[i] = query + } + } else { + for i, token := range tokens { + tq := bluge.NewTermQuery(string(token.Term)) + tq.SetField(field) + tq.SetBoost(boost) + tqs[i] = tq + } + } + minValue, err := zutils.CalculateMin(len(tokens), minimumShouldMatch) + if err != nil { + return nil, err + } + booleanQuery := bluge.NewBooleanQuery() + booleanQuery.AddShould(tqs...) + booleanQuery.SetMinShould(minValue) + booleanQuery.SetBoost(boost) + return booleanQuery, nil + } + return bluge.NewMatchNoneQuery(), nil +} diff --git a/pkg/uquery/query/multi_match.go b/pkg/uquery/query/multi_match.go index 4f4d7cacb..d57eb7501 100644 --- a/pkg/uquery/query/multi_match.go +++ b/pkg/uquery/query/multi_match.go @@ -17,7 +17,6 @@ package query import ( "fmt" - "strconv" "strings" "github.com/blugelabs/bluge" @@ -26,6 +25,7 @@ import ( "github.com/zincsearch/zincsearch/pkg/errors" "github.com/zincsearch/zincsearch/pkg/meta" zincanalysis "github.com/zincsearch/zincsearch/pkg/uquery/analysis" + "github.com/zincsearch/zincsearch/pkg/zutils" ) func MultiMatchQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers map[string]*analysis.Analyzer) (bluge.Query, error) { @@ -51,23 +51,7 @@ func MultiMatchQuery(query map[string]interface{}, mappings *meta.Mappings, anal case "operator": value.Operator = v.(string) case "minimum_should_match": - switch v := v.(type) { - case string: - if strings.Contains(v, "%") || strings.Contains(v, "<") { - return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[multi_match] %s value only support integer", k)) - } - vi, err := strconv.ParseInt(v, 10, 64) - if err != nil { - return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[multi_match] %s type string convert to int error: %s", k, err)) - } - value.MinimumShouldMatch = float64(vi) - case int: - value.MinimumShouldMatch = float64(v) - case float64: - value.MinimumShouldMatch = v - default: - return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[multi_match] %s doesn't support values of type: %T", k, v)) - } + value.MinimumShouldMatch = v default: // return nil, errors.New(errors.ErrorTypeParsingException, fmt.Sprintf("[multi_match] unknown field [%s]", k)) } @@ -92,8 +76,12 @@ func MultiMatchQuery(query map[string]interface{}, mappings *meta.Mappings, anal } subq := bluge.NewBooleanQuery() - if value.MinimumShouldMatch > 0 { - subq.SetMinShould(int(value.MinimumShouldMatch)) // lgtm[go/hardcoded-credentials] + if value.MinimumShouldMatch != nil { + minValue, err := zutils.CalculateMin(len(value.Fields), value.MinimumShouldMatch) + if err != nil { + return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[multi_match] unsupported MinimumShouldMatch value: %v", err)) + } + subq.SetMinShould(minValue) // lgtm[go/hardcoded-credentials] } if value.Boost >= 0 { subq.SetBoost(value.Boost) diff --git a/pkg/zutils/min_should.go b/pkg/zutils/min_should.go new file mode 100644 index 000000000..942f767ce --- /dev/null +++ b/pkg/zutils/min_should.go @@ -0,0 +1,152 @@ +package zutils + +import ( + "fmt" + "math" + "regexp" + "sort" + "strconv" + "strings" +) + +type combination struct { + condition int + count int +} + +var regex = regexp.MustCompile(`(\d+)<([-+]?\d+%?)`) + +// CalculateMin +// calculate the MinimumShouldMatch value with given expr and sub query count. +func CalculateMin(subCount int, v interface{}) (res int, err error) { + if subCount == 0 { + return 1, nil + } + defer func() { + if err != nil { + return + } + if res <= 1 { + res = 1 + return + } + if res >= subCount { + res = subCount + return + } + }() + switch x := v.(type) { + case int64, int, float64: + m := 0 + switch val := v.(type) { + case int: + m = val + case int64: + m = int(val) + case float64: + m = int(math.Floor(val)) + } + if m < 0 { + m = subCount + m + } + return m, nil + case []string: + conditions := make([]combination, len(x)) + for i, str := range x { + match := regex.FindStringSubmatch(str) + if match != nil { + leftPart := match[1] + rightPart := match[2] + condition, err := strconv.ParseInt(leftPart, 10, 64) + if err != nil { + return 0, fmt.Errorf("cannot parse the condition value: %w", err) + } + count, err := getPartValue(subCount, rightPart) + if err != nil { + return 0, fmt.Errorf("cannot parse the clauses count: %w", err) + } + conditions[i] = combination{ + condition: int(condition), + count: count, + } + } else { + return 0, fmt.Errorf("invalid MinimumShould value: %s", x) + } + } + + sort.Slice(conditions, func(i, j int) bool { + return conditions[i].condition < conditions[j].condition + }) + + for i, condition := range conditions { + // only match first + if subCount <= condition.condition { + return subCount, nil + } + // we are the last one, matched + if i == len(conditions)-1 { + return condition.count, nil + } + // less than next, we matched + if subCount <= conditions[i+1].condition { + return condition.count, nil + } + } + return 0, fmt.Errorf("invalid MinimumShould value: %v", x) + case string: + combinations := strings.Split(x, " ") + if len(combinations) > 1 { + return CalculateMin(subCount, combinations) + } + // simple expr + if res, err := getPartValue(subCount, x); err == nil { + return res, nil + } + + // complex, we use regex + match := regex.FindStringSubmatch(x) + if match != nil { + leftPart := match[1] + rightPart := match[2] + condition, err := strconv.ParseInt(leftPart, 10, 64) + if err != nil { + return 0, fmt.Errorf("cannot parse the condition value: %w", err) + } + count, err := getPartValue(subCount, rightPart) + if err != nil { + return 0, fmt.Errorf("cannot parse the clauses count: %w", err) + } + if subCount <= int(condition) { + return subCount, nil + } + return count, nil + } else { + return 0, fmt.Errorf("invalid MinimumShould value: %s", x) + } + default: + return 0, fmt.Errorf("invalid MinimumShouldMatch value") + } +} + +func getPartValue(termCount int, part string) (int, error) { + if strings.Contains(part, "%") { + proportion, err := strconv.ParseInt(part[0:len(part)-1], 10, 64) + if err != nil { + return 0, fmt.Errorf("cannot parse a percent value: %w", err) + } + if proportion < 0 { + count := float64(termCount) * float64(-proportion) / 100 + return termCount - int(count), nil + } + count := float64(termCount) * float64(proportion) / 100 + return int(count), nil + } + count, err := strconv.ParseInt(part, 10, 64) + if err != nil { + return 0, fmt.Errorf("cannot parse a int value: %w", err) + } + if count < 0 { + return termCount + int(count), nil + } + return int(count), nil +} diff --git a/pkg/zutils/min_should_test.go b/pkg/zutils/min_should_test.go new file mode 100644 index 000000000..33b19650e --- /dev/null +++ b/pkg/zutils/min_should_test.go @@ -0,0 +1,53 @@ +package zutils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCalculateMin(t *testing.T) { + cases := []struct { + subCount int + value interface{} + want int + }{ + // Simple Integer + {subCount: 5, value: 3, want: 3}, + {subCount: 10, value: "2", want: 2}, + {subCount: 8, value: int64(7), want: 7}, + {subCount: 9, value: 3.0, want: 3}, + {subCount: 5, value: 5.7, want: 5}, + + {subCount: 5, value: -10, want: 1}, + {subCount: 3, value: 5, want: 3}, + + // Negative Integer + {subCount: 10, value: -2, want: 8}, + {subCount: 8, value: "-5", want: 3}, + {subCount: 15, value: -3.0, want: 12}, + {subCount: 9, value: -3.5, want: 5}, + + // percent + {subCount: 10, value: "80%", want: 8}, + {subCount: 10, value: "-20%", want: 8}, + {subCount: 5, value: "75%", want: 3}, + {subCount: 5, value: "-25%", want: 4}, + + // combination + {subCount: 4, value: "5<90%", want: 4}, + {subCount: 5, value: "5<90%", want: 5}, + {subCount: 7, value: "5<3", want: 3}, + {subCount: 5, value: "2<-25%", want: 4}, + + // multi combinations + {subCount: 2, value: "2<-25% 9<-3", want: 2}, + {subCount: 5, value: "4<-25% 9<-3", want: 4}, + {subCount: 10, value: "4<-40% 9<-3", want: 7}, + } + for _, c := range cases { + v, err := CalculateMin(c.subCount, c.value) + assert.Nil(t, err) + assert.Equal(t, c.want, v) + } +}