Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

full support for minimum_should_match #940

Merged
merged 1 commit into from
Dec 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 8 additions & 8 deletions pkg/meta/query_dsl.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ type BoolQuery struct {
Must interface{} `json:"must,omitempty"` // query, [query1, query2]
MustNot interface{} `json:"must_not,omitempty"` // query, [query1, query2]
Filter interface{} `json:"filter,omitempty"` // query, [query1, query2]
MinimumShouldMatch float64 `json:"minimum_should_match,omitempty"` // only for should
MinimumShouldMatch interface{} `json:"minimum_should_match,omitempty"` // only for should
}

type BoolQueryForSDK struct {
Expand Down Expand Up @@ -151,13 +151,13 @@ type MatchPhrasePrefixQuery struct {
}

type MultiMatchQuery struct {
Query string `json:"query,omitempty"`
Analyzer string `json:"analyzer,omitempty"`
Fields []string `json:"fields,omitempty"`
Boost float64 `json:"boost,omitempty"`
Type string `json:"type,omitempty"` // best_fields(default), most_fields, cross_fields, phrase, phrase_prefix, bool_prefix
Operator string `json:"operator,omitempty"` // or(default), and
MinimumShouldMatch float64 `json:"minimum_should_match,omitempty"`
Query string `json:"query,omitempty"`
Analyzer string `json:"analyzer,omitempty"`
Fields []string `json:"fields,omitempty"`
Boost float64 `json:"boost,omitempty"`
Type string `json:"type,omitempty"` // best_fields(default), most_fields, cross_fields, phrase, phrase_prefix, bool_prefix
Operator string `json:"operator,omitempty"` // or(default), and
MinimumShouldMatch interface{} `json:"minimum_should_match,omitempty"`
}

type CombinedFieldsQuery struct {
Expand Down
29 changes: 11 additions & 18 deletions pkg/uquery/query/bool.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,19 @@ package query

import (
"fmt"
"strconv"
"strings"

"github.com/blugelabs/bluge"
"github.com/blugelabs/bluge/analysis"

"github.com/zincsearch/zincsearch/pkg/errors"
"github.com/zincsearch/zincsearch/pkg/meta"
"github.com/zincsearch/zincsearch/pkg/zutils"
)

func BoolQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers map[string]*analysis.Analyzer) (bluge.Query, error) {
boolQuery := bluge.NewBooleanQuery()
var minimumShouldMatch interface{}
for k, v := range query {
k := strings.ToLower(k)
switch k {
Expand Down Expand Up @@ -111,27 +112,19 @@ func BoolQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers
}
boolQuery.AddMust(filterQuery)
case "minimum_should_match":
switch v := v.(type) {
case string:
if strings.Contains(v, "%") || strings.Contains(v, "<") {
return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[bool] %s value only support integer", k))
}
vi, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[bool] %s type string convert to int error: %s", k, err))
}
boolQuery.SetMinShould(int(vi)) // lgtm[go/hardcoded-credentials]
case int:
boolQuery.SetMinShould(v)
case float64:
boolQuery.SetMinShould(int(v))
default:
return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[bool] %s doesn't support values of type: %T", k, v))
}
minimumShouldMatch = v
default:
return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[bool] unknown field [%s]", k))
}
}

if minimumShouldMatch != nil {
minValue, err := zutils.CalculateMin(len(boolQuery.Shoulds()), minimumShouldMatch)
if err != nil {
return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[bool] unsupported MinimumShouldMatch value: %v", err))
}
boolQuery.SetMinShould(minValue)
}

return boolQuery, nil
}
63 changes: 62 additions & 1 deletion pkg/uquery/query/match.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ import (

"github.com/blugelabs/bluge"
"github.com/blugelabs/bluge/analysis"

"github.com/zincsearch/zincsearch/pkg/errors"
"github.com/zincsearch/zincsearch/pkg/meta"
zincanalysis "github.com/zincsearch/zincsearch/pkg/uquery/analysis"
zincanalyzer "github.com/zincsearch/zincsearch/pkg/uquery/analysis/analyzer"
"github.com/zincsearch/zincsearch/pkg/zutils"
)

Expand All @@ -36,6 +36,7 @@ func MatchQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers
field := ""
value := new(meta.MatchQuery)
value.Boost = -1.0
var minimumShouldMatch interface{}
for k, v := range query {
field = k
switch v := v.(type) {
Expand All @@ -57,6 +58,8 @@ func MatchQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers
value.PrefixLength, _ = zutils.ToFloat64(v)
case "boost":
value.Boost, _ = zutils.ToFloat64(v)
case "minimum_should_match":
minimumShouldMatch = v
default:
// return nil, errors.New(errors.ErrorTypeParsingException, fmt.Sprintf("[match] unknown field [%s]", k))
}
Expand All @@ -83,6 +86,11 @@ func MatchQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers
}
}

// only "OR" supports minimum should match
if minimumShouldMatch != nil && (value.Operator == "" || strings.ToUpper(value.Operator) == "OR") {
return genQueryWithMinimumShouldMatch(zer, field, value, minimumShouldMatch)
}

subq := bluge.NewMatchQuery(value.Query).SetField(field)
if zer != nil {
subq.SetAnalyzer(zer)
Expand Down Expand Up @@ -115,3 +123,56 @@ func MatchQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers

return subq, nil
}

// genQueryWithMinimumShouldMatch builds the query for a match clause that
// carries a minimum_should_match setting.
//
// value.Query is analyzed with ana (a standard analyzer is created when ana is
// nil) and every resulting token becomes one optional (should) clause of a
// boolean query: a fuzzy query when ParseFuzziness yields a positive
// fuzziness, a plain term query otherwise. The number of clauses that must
// match is resolved by zutils.CalculateMin from the token count and
// minimumShouldMatch.
//
// It returns a match-none query when analysis produces no tokens, and an error
// when the fallback analyzer cannot be created or minimumShouldMatch cannot be
// resolved.
func genQueryWithMinimumShouldMatch(ana *analysis.Analyzer, field string, value *meta.MatchQuery, minimumShouldMatch interface{}) (bluge.Query, error) {
	if ana == nil {
		// Do not drop the constructor error: continuing with a nil analyzer
		// would fail later at ana.Analyze.
		std, err := zincanalyzer.NewStandardAnalyzer(nil)
		if err != nil {
			return nil, err
		}
		ana = std
	}

	fuzziness := 0
	if value.Fuzziness != nil {
		if v := ParseFuzziness(value.Fuzziness, value.Query, ana); v > 0 {
			fuzziness = v
		}
	}

	// A negative boost means "not set"; fall back to the neutral boost of 1.
	boost := 1.0
	if value.Boost >= 0 {
		boost = value.Boost
	}

	tokens := ana.Analyze([]byte(value.Query))
	if len(tokens) == 0 {
		return bluge.NewMatchNoneQuery(), nil
	}

	// Validate the specification before building any clause.
	// NOTE(review): the bool and multi_match builders wrap this error in an
	// XContentParseException; consider doing the same here for consistency.
	minValue, err := zutils.CalculateMin(len(tokens), minimumShouldMatch)
	if err != nil {
		return nil, err
	}

	clauses := make([]bluge.Query, 0, len(tokens))
	for _, token := range tokens {
		term := string(token.Term)
		if fuzziness != 0 {
			fq := bluge.NewFuzzyQuery(term)
			fq.SetFuzziness(fuzziness)
			fq.SetPrefix(int(value.PrefixLength))
			fq.SetField(field)
			fq.SetBoost(boost)
			clauses = append(clauses, fq)
			continue
		}
		tq := bluge.NewTermQuery(term)
		tq.SetField(field)
		tq.SetBoost(boost)
		clauses = append(clauses, tq)
	}

	booleanQuery := bluge.NewBooleanQuery()
	booleanQuery.AddShould(clauses...)
	booleanQuery.SetMinShould(minValue)
	booleanQuery.SetBoost(boost)
	return booleanQuery, nil
}
28 changes: 8 additions & 20 deletions pkg/uquery/query/multi_match.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package query

import (
"fmt"
"strconv"
"strings"

"github.com/blugelabs/bluge"
Expand All @@ -26,6 +25,7 @@ import (
"github.com/zincsearch/zincsearch/pkg/errors"
"github.com/zincsearch/zincsearch/pkg/meta"
zincanalysis "github.com/zincsearch/zincsearch/pkg/uquery/analysis"
"github.com/zincsearch/zincsearch/pkg/zutils"
)

func MultiMatchQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers map[string]*analysis.Analyzer) (bluge.Query, error) {
Expand All @@ -51,23 +51,7 @@ func MultiMatchQuery(query map[string]interface{}, mappings *meta.Mappings, anal
case "operator":
value.Operator = v.(string)
case "minimum_should_match":
switch v := v.(type) {
case string:
if strings.Contains(v, "%") || strings.Contains(v, "<") {
return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[multi_match] %s value only support integer", k))
}
vi, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[multi_match] %s type string convert to int error: %s", k, err))
}
value.MinimumShouldMatch = float64(vi)
case int:
value.MinimumShouldMatch = float64(v)
case float64:
value.MinimumShouldMatch = v
default:
return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[multi_match] %s doesn't support values of type: %T", k, v))
}
value.MinimumShouldMatch = v
default:
// return nil, errors.New(errors.ErrorTypeParsingException, fmt.Sprintf("[multi_match] unknown field [%s]", k))
}
Expand All @@ -92,8 +76,12 @@ func MultiMatchQuery(query map[string]interface{}, mappings *meta.Mappings, anal
}

subq := bluge.NewBooleanQuery()
if value.MinimumShouldMatch > 0 {
subq.SetMinShould(int(value.MinimumShouldMatch)) // lgtm[go/hardcoded-credentials]
if value.MinimumShouldMatch != nil {
minValue, err := zutils.CalculateMin(len(value.Fields), value.MinimumShouldMatch)
if err != nil {
return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[multi_match] unsupported MinimumShouldMatch value: %v", err))
}
subq.SetMinShould(minValue) // lgtm[go/hardcoded-credentials]
}
if value.Boost >= 0 {
subq.SetBoost(value.Boost)
Expand Down
152 changes: 152 additions & 0 deletions pkg/zutils/min_should.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package zutils

import (
"fmt"
"math"
"regexp"
"sort"
"strconv"
"strings"
)

// combination is one parsed conditional rule of a minimum-should-match
// expression written as "<condition><<count>".
type combination struct {
	// condition is the clause-count threshold parsed from the left of '<';
	// count is the already resolved number of clauses parsed from the right.
	condition, count int
}

var regex = regexp.MustCompile(`(\d+)<([-+]?\d+%?)`)

// CalculateMin
// calculate the MinimumShouldMatch value with given expr and sub query count.
func CalculateMin(subCount int, v interface{}) (res int, err error) {
if subCount == 0 {
return 1, nil
}
defer func() {
if err != nil {
return
}
if res <= 1 {
res = 1
return
}
if res >= subCount {
res = subCount
return
}
}()
switch x := v.(type) {
case int64, int, float64:
m := 0
switch val := v.(type) {
case int:
m = val
case int64:
m = int(val)
case float64:
m = int(math.Floor(val))
}
if m < 0 {
m = subCount + m
}
return m, nil
case []string:
conditions := make([]combination, len(x))
for i, str := range x {
match := regex.FindStringSubmatch(str)
if match != nil {
leftPart := match[1]
rightPart := match[2]
condition, err := strconv.ParseInt(leftPart, 10, 64)
if err != nil {
return 0, fmt.Errorf("cannot parse the condition value: %w", err)
}
count, err := getPartValue(subCount, rightPart)
if err != nil {
return 0, fmt.Errorf("cannot parse the clauses count: %w", err)
}
conditions[i] = combination{
condition: int(condition),
count: count,
}
} else {
return 0, fmt.Errorf("invalid MinimumShould value: %s", x)
}
}

sort.Slice(conditions, func(i, j int) bool {
return conditions[i].condition < conditions[j].condition
})

for i, condition := range conditions {
// only match first
if subCount <= condition.condition {
return subCount, nil
}
// we are the last one, matched
if i == len(conditions)-1 {
return condition.count, nil
}
// less than next, we matched
if subCount <= conditions[i+1].condition {
return condition.count, nil
}
}
return 0, fmt.Errorf("invalid MinimumShould value: %v", x)
case string:
combinations := strings.Split(x, " ")
if len(combinations) > 1 {
return CalculateMin(subCount, combinations)
}
// simple expr
if res, err := getPartValue(subCount, x); err == nil {
return res, nil
}

// complex, we use regex
match := regex.FindStringSubmatch(x)
if match != nil {
leftPart := match[1]
rightPart := match[2]
condition, err := strconv.ParseInt(leftPart, 10, 64)
if err != nil {
return 0, fmt.Errorf("cannot parse the condition value: %w", err)
}
count, err := getPartValue(subCount, rightPart)
if err != nil {
return 0, fmt.Errorf("cannot parse the clauses count: %w", err)
}
if subCount <= int(condition) {
return subCount, nil
}
return count, nil
} else {
return 0, fmt.Errorf("invalid MinimumShould value: %s", x)
}
default:
return 0, fmt.Errorf("invalid MinimumShouldMatch value")
}
}

// getPartValue turns a single count expression into a number of clauses out
// of termCount.
//
// An expression containing '%' is read as an integer percentage of termCount
// (its last character is dropped before parsing) and the product is
// truncated; a negative percentage yields termCount minus that truncated
// amount. Any other expression is read as an integer, where a negative value
// is added to termCount. The result is not range-checked here.
func getPartValue(termCount int, part string) (int, error) {
	if !strings.Contains(part, "%") {
		n, err := strconv.ParseInt(part, 10, 64)
		switch {
		case err != nil:
			return 0, fmt.Errorf("cannot parse a int value: %w", err)
		case n < 0:
			return termCount + int(n), nil
		default:
			return int(n), nil
		}
	}

	percent, err := strconv.ParseInt(part[:len(part)-1], 10, 64)
	if err != nil {
		return 0, fmt.Errorf("cannot parse a percent value: %w", err)
	}
	total := float64(termCount)
	if percent >= 0 {
		return int(total * float64(percent) / 100), nil
	}
	// Negative percentage: this share of the clauses may be missing.
	return termCount - int(total*float64(-percent)/100), nil
}