Skip to content

Commit

Permalink
full support for minimum_should_match (#940)
Browse files Browse the repository at this point in the history
full support for min should match
  • Loading branch information
KaniuBillows committed Dec 30, 2023
1 parent ea7a4bf commit 677235b
Show file tree
Hide file tree
Showing 6 changed files with 294 additions and 47 deletions.
16 changes: 8 additions & 8 deletions pkg/meta/query_dsl.go
Expand Up @@ -102,7 +102,7 @@ type BoolQuery struct {
Must interface{} `json:"must,omitempty"` // query, [query1, query2]
MustNot interface{} `json:"must_not,omitempty"` // query, [query1, query2]
Filter interface{} `json:"filter,omitempty"` // query, [query1, query2]
MinimumShouldMatch float64 `json:"minimum_should_match,omitempty"` // only for should
MinimumShouldMatch interface{} `json:"minimum_should_match,omitempty"` // only for should
}

type BoolQueryForSDK struct {
Expand Down Expand Up @@ -151,13 +151,13 @@ type MatchPhrasePrefixQuery struct {
}

type MultiMatchQuery struct {
Query string `json:"query,omitempty"`
Analyzer string `json:"analyzer,omitempty"`
Fields []string `json:"fields,omitempty"`
Boost float64 `json:"boost,omitempty"`
Type string `json:"type,omitempty"` // best_fields(default), most_fields, cross_fields, phrase, phrase_prefix, bool_prefix
Operator string `json:"operator,omitempty"` // or(default), and
MinimumShouldMatch float64 `json:"minimum_should_match,omitempty"`
Query string `json:"query,omitempty"`
Analyzer string `json:"analyzer,omitempty"`
Fields []string `json:"fields,omitempty"`
Boost float64 `json:"boost,omitempty"`
Type string `json:"type,omitempty"` // best_fields(default), most_fields, cross_fields, phrase, phrase_prefix, bool_prefix
Operator string `json:"operator,omitempty"` // or(default), and
MinimumShouldMatch interface{} `json:"minimum_should_match,omitempty"`
}

type CombinedFieldsQuery struct {
Expand Down
29 changes: 11 additions & 18 deletions pkg/uquery/query/bool.go
Expand Up @@ -17,18 +17,19 @@ package query

import (
"fmt"
"strconv"
"strings"

"github.com/blugelabs/bluge"
"github.com/blugelabs/bluge/analysis"

"github.com/zincsearch/zincsearch/pkg/errors"
"github.com/zincsearch/zincsearch/pkg/meta"
"github.com/zincsearch/zincsearch/pkg/zutils"
)

func BoolQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers map[string]*analysis.Analyzer) (bluge.Query, error) {
boolQuery := bluge.NewBooleanQuery()
var minimumShouldMatch interface{}
for k, v := range query {
k := strings.ToLower(k)
switch k {
Expand Down Expand Up @@ -111,27 +112,19 @@ func BoolQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers
}
boolQuery.AddMust(filterQuery)
case "minimum_should_match":
switch v := v.(type) {
case string:
if strings.Contains(v, "%") || strings.Contains(v, "<") {
return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[bool] %s value only support integer", k))
}
vi, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[bool] %s type string convert to int error: %s", k, err))
}
boolQuery.SetMinShould(int(vi)) // lgtm[go/hardcoded-credentials]
case int:
boolQuery.SetMinShould(v)
case float64:
boolQuery.SetMinShould(int(v))
default:
return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[bool] %s doesn't support values of type: %T", k, v))
}
minimumShouldMatch = v
default:
return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[bool] unknown field [%s]", k))
}
}

if minimumShouldMatch != nil {
minValue, err := zutils.CalculateMin(len(boolQuery.Shoulds()), minimumShouldMatch)
if err != nil {
return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[bool] unsupported MinimumShouldMatch value: %v", err))
}
boolQuery.SetMinShould(minValue)
}

return boolQuery, nil
}
63 changes: 62 additions & 1 deletion pkg/uquery/query/match.go
Expand Up @@ -21,10 +21,10 @@ import (

"github.com/blugelabs/bluge"
"github.com/blugelabs/bluge/analysis"

"github.com/zincsearch/zincsearch/pkg/errors"
"github.com/zincsearch/zincsearch/pkg/meta"
zincanalysis "github.com/zincsearch/zincsearch/pkg/uquery/analysis"
zincanalyzer "github.com/zincsearch/zincsearch/pkg/uquery/analysis/analyzer"
"github.com/zincsearch/zincsearch/pkg/zutils"
)

Expand All @@ -36,6 +36,7 @@ func MatchQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers
field := ""
value := new(meta.MatchQuery)
value.Boost = -1.0
var minimumShouldMatch interface{}
for k, v := range query {
field = k
switch v := v.(type) {
Expand All @@ -57,6 +58,8 @@ func MatchQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers
value.PrefixLength, _ = zutils.ToFloat64(v)
case "boost":
value.Boost, _ = zutils.ToFloat64(v)
case "minimum_should_match":
minimumShouldMatch = v
default:
// return nil, errors.New(errors.ErrorTypeParsingException, fmt.Sprintf("[match] unknown field [%s]", k))
}
Expand All @@ -83,6 +86,11 @@ func MatchQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers
}
}

// only "OR" supports minimum should match
if minimumShouldMatch != nil && (value.Operator == "" || strings.ToUpper(value.Operator) == "OR") {
return genQueryWithMinimumShouldMatch(zer, field, value, minimumShouldMatch)
}

subq := bluge.NewMatchQuery(value.Query).SetField(field)
if zer != nil {
subq.SetAnalyzer(zer)
Expand Down Expand Up @@ -115,3 +123,56 @@ func MatchQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers

return subq, nil
}

// genQueryWithMinimumShouldMatch builds the query used for a match query that
// carries a minimum_should_match option.
//
// value.Query is analyzed with ana (a standard analyzer is created when ana is
// nil) and every resulting token becomes one "should" clause of a boolean
// query: a fuzzy query when a positive fuzziness was requested, a term query
// otherwise. The number of clauses that must match is derived from
// minimumShouldMatch and the token count via zutils.CalculateMin.
//
// The boost (value.Boost when it is >= 0, otherwise 1) is set on every
// per-token query and on the enclosing boolean query.
//
// When analysis yields no tokens a match-none query is returned. An error is
// returned when the fallback analyzer cannot be created or when
// minimumShouldMatch cannot be interpreted.
func genQueryWithMinimumShouldMatch(ana *analysis.Analyzer, field string, value *meta.MatchQuery, minimumShouldMatch interface{}) (bluge.Query, error) {
	if ana == nil {
		// Surface the construction error instead of dereferencing a nil
		// analyzer further down.
		var err error
		ana, err = zincanalyzer.NewStandardAnalyzer(nil)
		if err != nil {
			return nil, err
		}
	}

	var fuzziness int
	if value.Fuzziness != nil {
		if v := ParseFuzziness(value.Fuzziness, value.Query, ana); v > 0 {
			fuzziness = v
		}
	}

	var boost float64 = 1
	if value.Boost >= 0 {
		boost = value.Boost
	}

	tokens := ana.Analyze([]byte(value.Query))
	if len(tokens) == 0 {
		return bluge.NewMatchNoneQuery(), nil
	}

	minValue, err := zutils.CalculateMin(len(tokens), minimumShouldMatch)
	if err != nil {
		// Report the problem the same way the bool and multi_match queries do.
		return nil, errors.New(errors.ErrorTypeXContentParseException, "[match] unsupported MinimumShouldMatch value: "+err.Error())
	}

	tqs := make([]bluge.Query, len(tokens))
	for i, token := range tokens {
		term := string(token.Term)
		if fuzziness != 0 {
			fq := bluge.NewFuzzyQuery(term)
			fq.SetFuzziness(fuzziness)
			fq.SetPrefix(int(value.PrefixLength))
			fq.SetField(field)
			fq.SetBoost(boost)
			tqs[i] = fq
			continue
		}
		tq := bluge.NewTermQuery(term)
		tq.SetField(field)
		tq.SetBoost(boost)
		tqs[i] = tq
	}

	booleanQuery := bluge.NewBooleanQuery()
	booleanQuery.AddShould(tqs...)
	booleanQuery.SetMinShould(minValue)
	booleanQuery.SetBoost(boost)
	return booleanQuery, nil
}
28 changes: 8 additions & 20 deletions pkg/uquery/query/multi_match.go
Expand Up @@ -17,7 +17,6 @@ package query

import (
"fmt"
"strconv"
"strings"

"github.com/blugelabs/bluge"
Expand All @@ -26,6 +25,7 @@ import (
"github.com/zincsearch/zincsearch/pkg/errors"
"github.com/zincsearch/zincsearch/pkg/meta"
zincanalysis "github.com/zincsearch/zincsearch/pkg/uquery/analysis"
"github.com/zincsearch/zincsearch/pkg/zutils"
)

func MultiMatchQuery(query map[string]interface{}, mappings *meta.Mappings, analyzers map[string]*analysis.Analyzer) (bluge.Query, error) {
Expand All @@ -51,23 +51,7 @@ func MultiMatchQuery(query map[string]interface{}, mappings *meta.Mappings, anal
case "operator":
value.Operator = v.(string)
case "minimum_should_match":
switch v := v.(type) {
case string:
if strings.Contains(v, "%") || strings.Contains(v, "<") {
return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[multi_match] %s value only support integer", k))
}
vi, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[multi_match] %s type string convert to int error: %s", k, err))
}
value.MinimumShouldMatch = float64(vi)
case int:
value.MinimumShouldMatch = float64(v)
case float64:
value.MinimumShouldMatch = v
default:
return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[multi_match] %s doesn't support values of type: %T", k, v))
}
value.MinimumShouldMatch = v
default:
// return nil, errors.New(errors.ErrorTypeParsingException, fmt.Sprintf("[multi_match] unknown field [%s]", k))
}
Expand All @@ -92,8 +76,12 @@ func MultiMatchQuery(query map[string]interface{}, mappings *meta.Mappings, anal
}

subq := bluge.NewBooleanQuery()
if value.MinimumShouldMatch > 0 {
subq.SetMinShould(int(value.MinimumShouldMatch)) // lgtm[go/hardcoded-credentials]
if value.MinimumShouldMatch != nil {
minValue, err := zutils.CalculateMin(len(value.Fields), value.MinimumShouldMatch)
if err != nil {
return nil, errors.New(errors.ErrorTypeXContentParseException, fmt.Sprintf("[multi_match] unsupported MinimumShouldMatch value: %v", err))
}
subq.SetMinShould(minValue) // lgtm[go/hardcoded-credentials]
}
if value.Boost >= 0 {
subq.SetBoost(value.Boost)
Expand Down
152 changes: 152 additions & 0 deletions pkg/zutils/min_should.go
@@ -0,0 +1,152 @@
package zutils

import (
"fmt"
"math"
"regexp"
"sort"
"strconv"
"strings"
)

// combination is one parsed "<condition><<count>" clause of a
// minimum_should_match expression such as "3<90%".
type combination struct {
	// condition is the integer on the left of '<'.
	condition int
	// count is the right-hand side already resolved, by getPartValue, to a
	// clause count for the current number of sub queries.
	count int
}

// regex extracts the two sides of a "<condition><<count>" clause. It is not
// anchored, so it finds the clause anywhere inside the string it is given.
var regex = regexp.MustCompile(`(\d+)<([-+]?\d+%?)`)

// CalculateMin calculates the MinimumShouldMatch value for subCount sub
// queries from the expression v.
//
// Supported values of v:
//   - int, int64, float64 (floored): a fixed count; a negative number is
//     counted back from subCount.
//   - string: a count ("3", "-2"), a percentage ("75%", "-25%"), a single
//     combination ("3<90%") or several whitespace separated combinations
//     ("2<-25% 9<-3"). Surrounding and repeated whitespace is ignored.
//   - []string: a list of combinations.
//
// For combinations, all sub queries are required while subCount is less than
// or equal to the smallest condition; otherwise the count of the largest
// condition that is smaller than subCount applies.
//
// A successful result is clamped so that it is at least 1 and, above that, at
// most subCount. When subCount is 0 the result is always 1 and v is not
// inspected. On error the returned count is 0.
func CalculateMin(subCount int, v interface{}) (res int, err error) {
	if subCount == 0 {
		return 1, nil
	}
	res, err = resolveMinShould(subCount, v)
	if err != nil {
		return 0, err
	}
	switch {
	case res <= 1:
		return 1, nil
	case res >= subCount:
		return subCount, nil
	}
	return res, nil
}

// resolveMinShould evaluates v against subCount without clamping the result.
func resolveMinShould(subCount int, v interface{}) (int, error) {
	switch x := v.(type) {
	case int:
		return relativeCount(subCount, x), nil
	case int64:
		return relativeCount(subCount, int(x)), nil
	case float64:
		// Converting NaN or an infinity to int is implementation-defined.
		if math.IsNaN(x) || math.IsInf(x, 0) {
			return 0, fmt.Errorf("invalid MinimumShouldMatch value: %v", x)
		}
		return relativeCount(subCount, int(math.Floor(x))), nil
	case []string:
		return resolveCombinations(subCount, x)
	case string:
		// Fields (rather than Split on " ") tolerates leading, trailing and
		// repeated whitespace.
		parts := strings.Fields(x)
		if len(parts) == 0 {
			return 0, fmt.Errorf("invalid MinimumShould value: %q", x)
		}
		if len(parts) == 1 {
			// simple expr: a plain count or percentage
			if n, err := getPartValue(subCount, parts[0]); err == nil {
				return n, nil
			}
		}
		return resolveCombinations(subCount, parts)
	default:
		return 0, fmt.Errorf("invalid MinimumShouldMatch value")
	}
}

// resolveCombinations parses every "<condition><<count>" clause in exprs and
// picks the count that applies to subCount.
func resolveCombinations(subCount int, exprs []string) (int, error) {
	if len(exprs) == 0 {
		return 0, fmt.Errorf("invalid MinimumShould value: %v", exprs)
	}
	conditions := make([]combination, 0, len(exprs))
	for _, expr := range exprs {
		c, err := parseCombination(subCount, expr)
		if err != nil {
			return 0, err
		}
		conditions = append(conditions, c)
	}

	sort.SliceStable(conditions, func(i, j int) bool {
		return conditions[i].condition < conditions[j].condition
	})

	// Everything is required until subCount exceeds a condition; after that
	// the last exceeded condition decides.
	required := subCount
	for _, c := range conditions {
		if subCount <= c.condition {
			break
		}
		required = c.count
	}
	return required, nil
}

// parseCombination parses a single "<condition><<count>" clause and resolves
// its right-hand side against subCount.
func parseCombination(subCount int, expr string) (combination, error) {
	match := regex.FindStringSubmatch(expr)
	if match == nil {
		return combination{}, fmt.Errorf("invalid MinimumShould value: %s", expr)
	}
	condition, err := strconv.Atoi(match[1])
	if err != nil {
		return combination{}, fmt.Errorf("cannot parse the condition value: %w", err)
	}
	count, err := getPartValue(subCount, match[2])
	if err != nil {
		return combination{}, fmt.Errorf("cannot parse the clauses count: %w", err)
	}
	return combination{condition: condition, count: count}, nil
}

// relativeCount returns n unchanged when it is not negative and total+n
// otherwise.
func relativeCount(total, n int) int {
	if n < 0 {
		return total + n
	}
	return n
}

// getPartValue resolves a plain count ("3", "-2") or a percentage ("75%",
// "-25%") against termCount.
//
// A percentage is converted to a count by truncating termCount*percent/100; a
// negative percentage yields termCount minus that truncated count. A negative
// plain count is added to termCount. The result is not clamped.
func getPartValue(termCount int, part string) (int, error) {
	if strings.HasSuffix(part, "%") {
		proportion, err := strconv.ParseInt(strings.TrimSuffix(part, "%"), 10, 64)
		if err != nil {
			return 0, fmt.Errorf("cannot parse a percent value: %w", err)
		}
		if proportion < 0 {
			missing := float64(termCount) * float64(-proportion) / 100
			return termCount - int(missing), nil
		}
		return int(float64(termCount) * float64(proportion) / 100), nil
	}
	count, err := strconv.ParseInt(part, 10, 64)
	if err != nil {
		return 0, fmt.Errorf("cannot parse a int value: %w", err)
	}
	return relativeCount(termCount, int(count)), nil
}

0 comments on commit 677235b

Please sign in to comment.