Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unlimited vector search by Certainty #1883

Merged
merged 8 commits into from
Apr 1, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
61 changes: 50 additions & 11 deletions adapters/handlers/graphql/local/get/class_builder_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@ import (
"regexp"
"strings"

"github.com/graphql-go/graphql"
"github.com/graphql-go/graphql/language/ast"
"github.com/semi-technologies/weaviate/adapters/handlers/graphql/descriptions"
"github.com/semi-technologies/weaviate/adapters/handlers/graphql/local/common_filters"
"github.com/semi-technologies/weaviate/entities/additional"
"github.com/semi-technologies/weaviate/entities/filters"
"github.com/semi-technologies/weaviate/entities/models"
"github.com/semi-technologies/weaviate/entities/modulecapabilities"
"github.com/semi-technologies/weaviate/entities/schema"
"github.com/semi-technologies/weaviate/entities/search"
"github.com/semi-technologies/weaviate/usecases/traverser"

"github.com/graphql-go/graphql"
"github.com/graphql-go/graphql/language/ast"
)

func (b *classBuilder) primitiveField(propertyType schema.PropertyDataType,
Expand Down Expand Up @@ -322,7 +322,7 @@ func (r *resolver) makeResolveGetClass(className string) graphql.FieldResolveFn
return nil, err
}

filters, err := common_filters.ExtractFilters(p.Args, p.Info.FieldName)
filt, err := common_filters.ExtractFilters(p.Args, p.Info.FieldName)
antas-marcin marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return nil, fmt.Errorf("could not extract filters: %s", err)
}
Expand All @@ -339,12 +339,6 @@ func (r *resolver) makeResolveGetClass(className string) graphql.FieldResolveFn
nearObjectParams = &p
}

var keywordRankingParams *traverser.KeywordRankingParams
if bm25, ok := p.Args["bm25"]; ok {
p := common_filters.ExtractBM25(bm25.(map[string]interface{}))
keywordRankingParams = &p
}

var moduleParams map[string]interface{}
if r.modulesProvider != nil {
extractedParams := r.modulesProvider.ExtractSearchParams(p.Args, className)
Expand All @@ -353,10 +347,16 @@ func (r *resolver) makeResolveGetClass(className string) graphql.FieldResolveFn
}
}

var keywordRankingParams *traverser.KeywordRankingParams
if bm25, ok := p.Args["bm25"]; ok {
p := common_filters.ExtractBM25(bm25.(map[string]interface{}))
keywordRankingParams = &p
}

group := extractGroup(p.Args)

params := traverser.GetParams{
Filters: filters,
Filters: filt,
ClassName: className,
Pagination: pagination,
Properties: properties,
Expand All @@ -368,12 +368,51 @@ func (r *resolver) makeResolveGetClass(className string) graphql.FieldResolveFn
KeywordRanking: keywordRankingParams,
}

// need to perform vector search by distance
// under certain conditions
setLimitBasedOnVectorSearchParams(&params)

return func() (interface{}, error) {
return resolver.GetClass(p.Context, principalFromContext(p.Context), params)
}, nil
}
}

// setLimitBasedOnVectorSearchParams switches the query into "vector search
// by distance" mode whenever a certainty accompanies one of the near*
// options. That mode is signalled through the pagination limit, which is
// set to the flag value filters.LimitFlagSearchByDist.
//
// A certainty counts as provided when any of the following holds:
//   - params.NearVector is set and its Certainty is non-zero
//   - params.NearObject is set and its Certainty is non-zero
//   - any entry of params.ModuleParams implements
//     modulecapabilities.NearParam and reports a non-zero certainty
//
// If none of these hold, params is left untouched. Otherwise
// params.Pagination is created when nil and its Limit is set to the flag.
//
// NOTE(review): an already populated Pagination.Limit is overwritten as
// well, not only a missing one — confirm that replacing a caller-supplied
// limit is intended.
func setLimitBasedOnVectorSearchParams(params *traverser.GetParams) {
	certaintyGiven := (params.NearVector != nil && params.NearVector.Certainty != 0) ||
		(params.NearObject != nil && params.NearObject.Certainty != 0)

	if !certaintyGiven {
		// fall back to module-provided near* params, e.g. nearText
		for _, moduleParam := range params.ModuleParams {
			if near, isNear := moduleParam.(modulecapabilities.NearParam); isNear &&
				near.GetCertainty() != 0 {
				certaintyGiven = true
				break
			}
		}
	}

	if !certaintyGiven {
		return
	}

	if params.Pagination == nil {
		params.Pagination = &filters.Pagination{}
	}
	params.Pagination.Limit = filters.LimitFlagSearchByDist
}

func extractGroup(args map[string]interface{}) *traverser.GroupParams {
group, ok := args["group"]
if !ok {
Expand Down
15 changes: 11 additions & 4 deletions adapters/handlers/graphql/local/get/get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ func TestNearCustomTextRanker(t *testing.T) {

t.Run("for a class that does not have a text2vec module", func(t *testing.T) {
query := `{ Get { CustomVectorClass(nearCustomText: {
concepts: ["c1", "c2", "c3"],
concepts: ["c1", "c2", "c3"],
moveTo: {
concepts:["positive"],
force: 0.5
Expand All @@ -707,7 +707,7 @@ func TestNearCustomTextRanker(t *testing.T) {
concepts:["epic"]
force: 0.25
}
}) { intField } } }`
}) { intField } } }`

res := resolver.Resolve(query)
require.Len(t, res.Errors, 1)
Expand All @@ -716,7 +716,7 @@ func TestNearCustomTextRanker(t *testing.T) {

t.Run("for things with optional certainty set", func(t *testing.T) {
query := `{ Get { SomeThing(nearCustomText: {
concepts: ["c1", "c2", "c3"],
concepts: ["c1", "c2", "c3"],
certainty: 0.4,
moveTo: {
concepts:["positive"],
Expand All @@ -726,11 +726,12 @@ func TestNearCustomTextRanker(t *testing.T) {
concepts:["epic"]
force: 0.25
}
}) { intField } } }`
}) { intField } } }`

expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
ModuleParams: map[string]interface{}{
"nearCustomText": extractNearTextParam(map[string]interface{}{
"concepts": []interface{}{"c1", "c2", "c3"},
Expand Down Expand Up @@ -778,6 +779,7 @@ func TestNearCustomTextRanker(t *testing.T) {
expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
ModuleParams: map[string]interface{}{
"nearCustomText": extractNearTextParam(map[string]interface{}{
"concepts": []interface{}{"c1", "c2", "c3"},
Expand Down Expand Up @@ -852,6 +854,7 @@ func TestNearVectorRanker(t *testing.T) {
expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
NearVector: &traverser.NearVectorParams{
Vector: []float32{0.123, 0.984},
Certainty: 0.4,
Expand Down Expand Up @@ -1113,6 +1116,7 @@ func TestNearObject(t *testing.T) {
expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
NearObject: &traverser.NearObjectParams{
Beacon: "weaviate://localhost/some-other-uuid",
Certainty: 0.7,
Expand Down Expand Up @@ -1153,6 +1157,7 @@ func TestNearObject(t *testing.T) {

expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
NearObject: &traverser.NearObjectParams{
ID: "some-other-uuid",
Expand Down Expand Up @@ -1230,6 +1235,7 @@ func TestNearObjectNoModules(t *testing.T) {

expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
NearObject: &traverser.NearObjectParams{
ID: "some-uuid",
Expand Down Expand Up @@ -1277,6 +1283,7 @@ func TestNearVectorNoModules(t *testing.T) {
expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
NearVector: &traverser.NearVectorParams{
Vector: []float32{0.123, 0.984},
Certainty: 0.4,
Expand Down
5 changes: 5 additions & 0 deletions adapters/handlers/graphql/local/get/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ type nearCustomTextParams struct {
Certainty float64
}

// GetCertainty returns the Certainty value stored on the receiver.
// It implements the modulecapabilities.NearParam interface, so these test
// params can be queried for their certainty through that interface.
func (n *nearCustomTextParams) GetCertainty() float64 {
return n.Certainty
}

type nearExploreMove struct {
Values []string
Force float32
Expand Down
6 changes: 3 additions & 3 deletions adapters/handlers/rest/clusterapi/indices.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ type shards interface {
MultiGetObjects(ctx context.Context, indexName, shardName string,
id []strfmt.UUID) ([]*storobj.Object, error)
Search(ctx context.Context, indexName, shardName string,
vector []float32, limit int, filters *filters.LocalFilter,
vector []float32, certainty float64, limit int, filters *filters.LocalFilter,
additional additional.Properties) ([]*storobj.Object, []float32, error)
Aggregate(ctx context.Context, indexName, shardName string,
params aggregation.Params) (*aggregation.Result, error)
Expand Down Expand Up @@ -460,7 +460,7 @@ func (i *indices) postSearchObjects() http.Handler {
return
}

vector, limit, filters, additional, err := IndicesPayloads.SearchParams.
vector, certainty, limit, filters, additional, err := IndicesPayloads.SearchParams.
Unmarshal(reqPayload)
if err != nil {
http.Error(w, "unmarshal search params from json: "+err.Error(),
Expand All @@ -469,7 +469,7 @@ func (i *indices) postSearchObjects() http.Handler {
}

results, dists, err := i.shards.Search(r.Context(), index, shard,
vector, limit, filters, additional)
vector, certainty, limit, filters, additional)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
Expand Down
5 changes: 3 additions & 2 deletions adapters/handlers/rest/clusterapi/indices_payloads.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,17 +233,18 @@ func (p searchParamsPayload) Marshal(vector []float32, limit int,
return json.Marshal(par)
}

func (p searchParamsPayload) Unmarshal(in []byte) ([]float32, int,
func (p searchParamsPayload) Unmarshal(in []byte) ([]float32, float64, int,
*filters.LocalFilter, additional.Properties, error) {
type searchParametersPayload struct {
SearchVector []float32 `json:"searchVector"`
Certainty float64 `json:"certainty"`
Limit int `json:"limit"`
Filters *filters.LocalFilter `json:"filters"`
Additional additional.Properties `json:"additional"`
}
var par searchParametersPayload
err := json.Unmarshal(in, &par)
return par.SearchVector, par.Limit, par.Filters, par.Additional, err
return par.SearchVector, par.Certainty, par.Limit, par.Filters, par.Additional, err
}

func (p searchParamsPayload) MIME() string {
Expand Down