Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unlimited vector search by Certainty #1883

Merged
merged 8 commits into from
Apr 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
57 changes: 48 additions & 9 deletions adapters/handlers/graphql/local/get/class_builder_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@ import (
"regexp"
"strings"

"github.com/graphql-go/graphql"
"github.com/graphql-go/graphql/language/ast"
"github.com/semi-technologies/weaviate/adapters/handlers/graphql/descriptions"
"github.com/semi-technologies/weaviate/adapters/handlers/graphql/local/common_filters"
"github.com/semi-technologies/weaviate/entities/additional"
"github.com/semi-technologies/weaviate/entities/filters"
"github.com/semi-technologies/weaviate/entities/models"
"github.com/semi-technologies/weaviate/entities/modulecapabilities"
"github.com/semi-technologies/weaviate/entities/schema"
"github.com/semi-technologies/weaviate/entities/search"
"github.com/semi-technologies/weaviate/usecases/traverser"

"github.com/graphql-go/graphql"
"github.com/graphql-go/graphql/language/ast"
)

func (b *classBuilder) primitiveField(propertyType schema.PropertyDataType,
Expand Down Expand Up @@ -339,12 +339,6 @@ func (r *resolver) makeResolveGetClass(className string) graphql.FieldResolveFn
nearObjectParams = &p
}

var keywordRankingParams *traverser.KeywordRankingParams
if bm25, ok := p.Args["bm25"]; ok {
p := common_filters.ExtractBM25(bm25.(map[string]interface{}))
keywordRankingParams = &p
}

var moduleParams map[string]interface{}
if r.modulesProvider != nil {
extractedParams := r.modulesProvider.ExtractSearchParams(p.Args, className)
Expand All @@ -353,6 +347,12 @@ func (r *resolver) makeResolveGetClass(className string) graphql.FieldResolveFn
}
}

var keywordRankingParams *traverser.KeywordRankingParams
if bm25, ok := p.Args["bm25"]; ok {
p := common_filters.ExtractBM25(bm25.(map[string]interface{}))
keywordRankingParams = &p
}

group := extractGroup(p.Args)

params := traverser.GetParams{
Expand All @@ -368,12 +368,51 @@ func (r *resolver) makeResolveGetClass(className string) graphql.FieldResolveFn
KeywordRanking: keywordRankingParams,
}

// need to perform vector search by distance
// under certain conditions
setLimitBasedOnVectorSearchParams(&params)

return func() (interface{}, error) {
return resolver.GetClass(p.Context, principalFromContext(p.Context), params)
}, nil
}
}

// the limit needs to be set according to the vector search parameters.
// for example, if a certainty is provided by any of the near* options,
// and no limit was provided, weaviate will want to execute a vector
// search by distance. it knows to do this by watching for a limit
// flag, specicially filters.LimitFlagSearchByDistance
func setLimitBasedOnVectorSearchParams(params *traverser.GetParams) {
setLimit := func(params *traverser.GetParams) {
if params.Pagination == nil {
params.Pagination = &filters.Pagination{
Limit: filters.LimitFlagSearchByDist,
}
} else {
params.Pagination.Limit = filters.LimitFlagSearchByDist
}
}

if params.NearVector != nil && params.NearVector.Certainty != 0 {
setLimit(params)
return
}

if params.NearObject != nil && params.NearObject.Certainty != 0 {
setLimit(params)
return
}

for _, param := range params.ModuleParams {
nearParam, ok := param.(modulecapabilities.NearParam)
if ok && nearParam.GetCertainty() != 0 {
setLimit(params)
return
}
}
}

func extractGroup(args map[string]interface{}) *traverser.GroupParams {
group, ok := args["group"]
if !ok {
Expand Down
15 changes: 11 additions & 4 deletions adapters/handlers/graphql/local/get/get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ func TestNearCustomTextRanker(t *testing.T) {

t.Run("for a class that does not have a text2vec module", func(t *testing.T) {
query := `{ Get { CustomVectorClass(nearCustomText: {
concepts: ["c1", "c2", "c3"],
concepts: ["c1", "c2", "c3"],
moveTo: {
concepts:["positive"],
force: 0.5
Expand All @@ -707,7 +707,7 @@ func TestNearCustomTextRanker(t *testing.T) {
concepts:["epic"]
force: 0.25
}
}) { intField } } }`
}) { intField } } }`

res := resolver.Resolve(query)
require.Len(t, res.Errors, 1)
Expand All @@ -716,7 +716,7 @@ func TestNearCustomTextRanker(t *testing.T) {

t.Run("for things with optional certainty set", func(t *testing.T) {
query := `{ Get { SomeThing(nearCustomText: {
concepts: ["c1", "c2", "c3"],
concepts: ["c1", "c2", "c3"],
certainty: 0.4,
moveTo: {
concepts:["positive"],
Expand All @@ -726,11 +726,12 @@ func TestNearCustomTextRanker(t *testing.T) {
concepts:["epic"]
force: 0.25
}
}) { intField } } }`
}) { intField } } }`

expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
ModuleParams: map[string]interface{}{
"nearCustomText": extractNearTextParam(map[string]interface{}{
"concepts": []interface{}{"c1", "c2", "c3"},
Expand Down Expand Up @@ -778,6 +779,7 @@ func TestNearCustomTextRanker(t *testing.T) {
expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
ModuleParams: map[string]interface{}{
"nearCustomText": extractNearTextParam(map[string]interface{}{
"concepts": []interface{}{"c1", "c2", "c3"},
Expand Down Expand Up @@ -852,6 +854,7 @@ func TestNearVectorRanker(t *testing.T) {
expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
NearVector: &traverser.NearVectorParams{
Vector: []float32{0.123, 0.984},
Certainty: 0.4,
Expand Down Expand Up @@ -1113,6 +1116,7 @@ func TestNearObject(t *testing.T) {
expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
NearObject: &traverser.NearObjectParams{
Beacon: "weaviate://localhost/some-other-uuid",
Certainty: 0.7,
Expand Down Expand Up @@ -1153,6 +1157,7 @@ func TestNearObject(t *testing.T) {

expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
NearObject: &traverser.NearObjectParams{
ID: "some-other-uuid",
Expand Down Expand Up @@ -1230,6 +1235,7 @@ func TestNearObjectNoModules(t *testing.T) {

expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
NearObject: &traverser.NearObjectParams{
ID: "some-uuid",
Expand Down Expand Up @@ -1277,6 +1283,7 @@ func TestNearVectorNoModules(t *testing.T) {
expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
NearVector: &traverser.NearVectorParams{
Vector: []float32{0.123, 0.984},
Certainty: 0.4,
Expand Down
5 changes: 5 additions & 0 deletions adapters/handlers/graphql/local/get/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ type nearCustomTextParams struct {
Certainty float64
}

// implements the modulecapabilities.NearParam interface
func (n *nearCustomTextParams) GetCertainty() float64 {
return n.Certainty
}

type nearExploreMove struct {
Values []string
Force float32
Expand Down
6 changes: 3 additions & 3 deletions adapters/handlers/rest/clusterapi/indices.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ type shards interface {
MultiGetObjects(ctx context.Context, indexName, shardName string,
id []strfmt.UUID) ([]*storobj.Object, error)
Search(ctx context.Context, indexName, shardName string,
vector []float32, limit int, filters *filters.LocalFilter,
vector []float32, certainty float64, limit int, filters *filters.LocalFilter,
additional additional.Properties) ([]*storobj.Object, []float32, error)
Aggregate(ctx context.Context, indexName, shardName string,
params aggregation.Params) (*aggregation.Result, error)
Expand Down Expand Up @@ -460,7 +460,7 @@ func (i *indices) postSearchObjects() http.Handler {
return
}

vector, limit, filters, additional, err := IndicesPayloads.SearchParams.
vector, certainty, limit, filters, additional, err := IndicesPayloads.SearchParams.
Unmarshal(reqPayload)
if err != nil {
http.Error(w, "unmarshal search params from json: "+err.Error(),
Expand All @@ -469,7 +469,7 @@ func (i *indices) postSearchObjects() http.Handler {
}

results, dists, err := i.shards.Search(r.Context(), index, shard,
vector, limit, filters, additional)
vector, certainty, limit, filters, additional)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
Expand Down
5 changes: 3 additions & 2 deletions adapters/handlers/rest/clusterapi/indices_payloads.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,17 +233,18 @@ func (p searchParamsPayload) Marshal(vector []float32, limit int,
return json.Marshal(par)
}

func (p searchParamsPayload) Unmarshal(in []byte) ([]float32, int,
func (p searchParamsPayload) Unmarshal(in []byte) ([]float32, float64, int,
*filters.LocalFilter, additional.Properties, error) {
type searchParametersPayload struct {
SearchVector []float32 `json:"searchVector"`
Certainty float64 `json:"certainty"`
Limit int `json:"limit"`
Filters *filters.LocalFilter `json:"filters"`
Additional additional.Properties `json:"additional"`
}
var par searchParametersPayload
err := json.Unmarshal(in, &par)
return par.SearchVector, par.Limit, par.Filters, par.Additional, err
return par.SearchVector, par.Certainty, par.Limit, par.Filters, par.Additional, err
}

func (p searchParamsPayload) MIME() string {
Expand Down
142 changes: 142 additions & 0 deletions adapters/repos/db/crud_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,148 @@ func Test_ImportWithoutVector_UpdateWithVectorLater(t *testing.T) {
})
}

func TestVectorSearch_ByCertainty(t *testing.T) {
className := "SomeClass"
var class *models.Class

rand.Seed(time.Now().UnixNano())
dirName := fmt.Sprintf("./testdata/%d", rand.Intn(10000000))
os.MkdirAll(dirName, 0o777)
logger, _ := test.NewNullLogger()
defer func() {
err := os.RemoveAll(dirName)
fmt.Println(err)
}()

schemaGetter := &fakeSchemaGetter{shardState: singleShardState()}
repo := New(logger, Config{
RootPath: dirName,
// this is set really low to ensure that search
// by distance is conducted, which executes
// without regard to this value
QueryMaximumResults: 1,
DiskUseWarningPercentage: config.DefaultDiskUseWarningPercentage,
DiskUseReadOnlyPercentage: config.DefaultDiskUseReadonlyPercentage,
}, &fakeRemoteClient{},
&fakeNodeResolver{})
repo.SetSchemaGetter(schemaGetter)
err := repo.WaitForStartup(testCtx())
require.Nil(t, err)
migrator := NewMigrator(repo, logger)

t.Run("create required schema", func(t *testing.T) {
class = &models.Class{
Class: className,
Properties: []*models.Property{
{
DataType: []string{string(schema.DataTypeInt)},
Name: "int_prop",
},
},
VectorIndexConfig: hnsw.NewDefaultUserConfig(),
InvertedIndexConfig: invertedConfig(),
}
require.Nil(t,
migrator.AddClass(context.Background(), class, schemaGetter.shardState))
})

// update schema getter so it's in sync with class
schemaGetter.schema = schema.Schema{
Objects: &models.Schema{
Classes: []*models.Class{class},
},
}

searchVector := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}
searchObject := strfmt.UUID("fe687bf4-f10f-4c23-948d-0746ea2927b3")

tests := map[strfmt.UUID]struct {
inputVec []float32
expected bool
}{
strfmt.UUID("88460290-03b2-44a3-9adb-9fa3ae11d9e6"): {
inputVec: []float32{1, 2, 3, 4, 5, 6, 98, 99, 100},
expected: true,
},
strfmt.UUID("c99bc97d-7035-4311-94f3-947dc6471f51"): {
inputVec: []float32{1, 2, 3, 4, 5, 6, -98, -99, -100},
expected: false,
},
strfmt.UUID("fe687bf4-f10f-4c23-948d-0746ea2927b3"): {
inputVec: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
expected: true,
},
strfmt.UUID("e7bf6c45-de72-493a-b273-5ef198974d61"): {
inputVec: []float32{-1, -2, -3, -4, -5, -6, -7, -8, 0},
expected: false,
},
strfmt.UUID("0999d109-1d5f-465a-bd8b-e3fbd46f10aa"): {
inputVec: []float32{1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9},
expected: true,
},
}

t.Run("insert test objects", func(t *testing.T) {
for id, props := range tests {
err := repo.PutObject(context.Background(), &models.Object{Class: className, ID: id}, props.inputVec)
require.Nil(t, err)
}
})

t.Run("perform nearVector search by distance", func(t *testing.T) {
results, err := repo.VectorClassSearch(context.Background(), traverser.GetParams{
ClassName: className,
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
NearVector: &traverser.NearVectorParams{
Certainty: 0.9,
},
SearchVector: searchVector,
AdditionalProperties: additional.Properties{Certainty: true},
})
require.Nil(t, err)
require.NotEmpty(t, results)
// ensure that we receive more results than
// the `QueryMaximumResults`, as this should
// only apply to limited vector searches
require.Greater(t, len(results), 1)

for _, res := range results {
if props, ok := tests[res.ID]; !ok {
assert.False(t, props.expected)
} else {
assert.True(t, props.expected)
}
}
})

t.Run("perform nearObject search by distance", func(t *testing.T) {
results, err := repo.VectorClassSearch(context.Background(), traverser.GetParams{
ClassName: className,
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
NearObject: &traverser.NearObjectParams{
Certainty: 0.9,
ID: searchObject.String(),
},
SearchVector: searchVector,
AdditionalProperties: additional.Properties{Certainty: true},
})
require.Nil(t, err)
require.NotEmpty(t, results)
// ensure that we receive more results than
// the `QueryMaximumResults`, as this should
// only apply to limited vector searches
require.Greater(t, len(results), 1)

for _, res := range results {
if props, ok := tests[res.ID]; !ok {
assert.False(t, props.expected)
} else {
assert.True(t, props.expected)
}
}
})
}

func findID(list []search.Result, id strfmt.UUID) (search.Result, bool) {
for _, item := range list {
if item.ID == id {
Expand Down