Skip to content

Commit

Permalink
Merge pull request #1883 from semi-technologies/feature/WEAVIATE-19
Browse files Browse the repository at this point in the history
Unlimited vector search by Certainty
  • Loading branch information
antas-marcin committed Apr 1, 2022
2 parents 658bee1 + 97a82b8 commit c8c97be
Show file tree
Hide file tree
Showing 18 changed files with 576 additions and 98 deletions.
57 changes: 48 additions & 9 deletions adapters/handlers/graphql/local/get/class_builder_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@ import (
"regexp"
"strings"

"github.com/graphql-go/graphql"
"github.com/graphql-go/graphql/language/ast"
"github.com/semi-technologies/weaviate/adapters/handlers/graphql/descriptions"
"github.com/semi-technologies/weaviate/adapters/handlers/graphql/local/common_filters"
"github.com/semi-technologies/weaviate/entities/additional"
"github.com/semi-technologies/weaviate/entities/filters"
"github.com/semi-technologies/weaviate/entities/models"
"github.com/semi-technologies/weaviate/entities/modulecapabilities"
"github.com/semi-technologies/weaviate/entities/schema"
"github.com/semi-technologies/weaviate/entities/search"
"github.com/semi-technologies/weaviate/usecases/traverser"

"github.com/graphql-go/graphql"
"github.com/graphql-go/graphql/language/ast"
)

func (b *classBuilder) primitiveField(propertyType schema.PropertyDataType,
Expand Down Expand Up @@ -339,12 +339,6 @@ func (r *resolver) makeResolveGetClass(className string) graphql.FieldResolveFn
nearObjectParams = &p
}

var keywordRankingParams *traverser.KeywordRankingParams
if bm25, ok := p.Args["bm25"]; ok {
p := common_filters.ExtractBM25(bm25.(map[string]interface{}))
keywordRankingParams = &p
}

var moduleParams map[string]interface{}
if r.modulesProvider != nil {
extractedParams := r.modulesProvider.ExtractSearchParams(p.Args, className)
Expand All @@ -353,6 +347,12 @@ func (r *resolver) makeResolveGetClass(className string) graphql.FieldResolveFn
}
}

var keywordRankingParams *traverser.KeywordRankingParams
if bm25, ok := p.Args["bm25"]; ok {
p := common_filters.ExtractBM25(bm25.(map[string]interface{}))
keywordRankingParams = &p
}

group := extractGroup(p.Args)

params := traverser.GetParams{
Expand All @@ -368,12 +368,51 @@ func (r *resolver) makeResolveGetClass(className string) graphql.FieldResolveFn
KeywordRanking: keywordRankingParams,
}

// need to perform vector search by distance
// under certain conditions
setLimitBasedOnVectorSearchParams(&params)

return func() (interface{}, error) {
return resolver.GetClass(p.Context, principalFromContext(p.Context), params)
}, nil
}
}

// setLimitBasedOnVectorSearchParams adjusts the pagination limit according
// to the vector search parameters. If a certainty is provided by any of the
// near* options (nearVector, nearObject, or a module-provided near param)
// and no usable limit was provided, weaviate will want to execute a vector
// search by distance. It knows to do this by watching for a limit flag,
// specifically filters.LimitFlagSearchByDist.
//
// A non-negative limit that the caller supplied explicitly is left
// untouched, so a certainty threshold can still be combined with a cap on
// the number of results. Negative limits are normalised to the flag value.
func setLimitBasedOnVectorSearchParams(params *traverser.GetParams) {
	setLimit := func(params *traverser.GetParams) {
		switch {
		case params.Pagination == nil:
			// limit was omitted entirely: request a search by distance
			params.Pagination = &filters.Pagination{
				Limit: filters.LimitFlagSearchByDist,
			}
		case params.Pagination.Limit < 0:
			// a negative limit cannot be honoured as a result count, so
			// it is mapped onto the search-by-distance flag
			params.Pagination.Limit = filters.LimitFlagSearchByDist
		}
	}

	if params.NearVector != nil && params.NearVector.Certainty != 0 {
		setLimit(params)
		return
	}

	if params.NearObject != nil && params.NearObject.Certainty != 0 {
		setLimit(params)
		return
	}

	// module params are opaque here; only those implementing
	// modulecapabilities.NearParam can report a certainty
	for _, param := range params.ModuleParams {
		nearParam, ok := param.(modulecapabilities.NearParam)
		if ok && nearParam.GetCertainty() != 0 {
			setLimit(params)
			return
		}
	}
}

func extractGroup(args map[string]interface{}) *traverser.GroupParams {
group, ok := args["group"]
if !ok {
Expand Down
15 changes: 11 additions & 4 deletions adapters/handlers/graphql/local/get/get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ func TestNearCustomTextRanker(t *testing.T) {

t.Run("for a class that does not have a text2vec module", func(t *testing.T) {
query := `{ Get { CustomVectorClass(nearCustomText: {
concepts: ["c1", "c2", "c3"],
concepts: ["c1", "c2", "c3"],
moveTo: {
concepts:["positive"],
force: 0.5
Expand All @@ -707,7 +707,7 @@ func TestNearCustomTextRanker(t *testing.T) {
concepts:["epic"]
force: 0.25
}
}) { intField } } }`
}) { intField } } }`

res := resolver.Resolve(query)
require.Len(t, res.Errors, 1)
Expand All @@ -716,7 +716,7 @@ func TestNearCustomTextRanker(t *testing.T) {

t.Run("for things with optional certainty set", func(t *testing.T) {
query := `{ Get { SomeThing(nearCustomText: {
concepts: ["c1", "c2", "c3"],
concepts: ["c1", "c2", "c3"],
certainty: 0.4,
moveTo: {
concepts:["positive"],
Expand All @@ -726,11 +726,12 @@ func TestNearCustomTextRanker(t *testing.T) {
concepts:["epic"]
force: 0.25
}
}) { intField } } }`
}) { intField } } }`

expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
ModuleParams: map[string]interface{}{
"nearCustomText": extractNearTextParam(map[string]interface{}{
"concepts": []interface{}{"c1", "c2", "c3"},
Expand Down Expand Up @@ -778,6 +779,7 @@ func TestNearCustomTextRanker(t *testing.T) {
expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
ModuleParams: map[string]interface{}{
"nearCustomText": extractNearTextParam(map[string]interface{}{
"concepts": []interface{}{"c1", "c2", "c3"},
Expand Down Expand Up @@ -852,6 +854,7 @@ func TestNearVectorRanker(t *testing.T) {
expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
NearVector: &traverser.NearVectorParams{
Vector: []float32{0.123, 0.984},
Certainty: 0.4,
Expand Down Expand Up @@ -1113,6 +1116,7 @@ func TestNearObject(t *testing.T) {
expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
NearObject: &traverser.NearObjectParams{
Beacon: "weaviate://localhost/some-other-uuid",
Certainty: 0.7,
Expand Down Expand Up @@ -1153,6 +1157,7 @@ func TestNearObject(t *testing.T) {

expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
NearObject: &traverser.NearObjectParams{
ID: "some-other-uuid",
Expand Down Expand Up @@ -1230,6 +1235,7 @@ func TestNearObjectNoModules(t *testing.T) {

expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
NearObject: &traverser.NearObjectParams{
ID: "some-uuid",
Expand Down Expand Up @@ -1277,6 +1283,7 @@ func TestNearVectorNoModules(t *testing.T) {
expectedParams := traverser.GetParams{
ClassName: "SomeThing",
Properties: []search.SelectProperty{{Name: "intField", IsPrimitive: true}},
Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
NearVector: &traverser.NearVectorParams{
Vector: []float32{0.123, 0.984},
Certainty: 0.4,
Expand Down
5 changes: 5 additions & 0 deletions adapters/handlers/graphql/local/get/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ type nearCustomTextParams struct {
Certainty float64
}

// GetCertainty reports the certainty stored on the nearCustomText
// parameters. Having this method lets *nearCustomTextParams be used as a
// modulecapabilities.NearParam.
func (n *nearCustomTextParams) GetCertainty() float64 {
	return n.Certainty
}

type nearExploreMove struct {
Values []string
Force float32
Expand Down
6 changes: 3 additions & 3 deletions adapters/handlers/rest/clusterapi/indices.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ type shards interface {
MultiGetObjects(ctx context.Context, indexName, shardName string,
id []strfmt.UUID) ([]*storobj.Object, error)
Search(ctx context.Context, indexName, shardName string,
vector []float32, limit int, filters *filters.LocalFilter,
vector []float32, certainty float64, limit int, filters *filters.LocalFilter,
additional additional.Properties) ([]*storobj.Object, []float32, error)
Aggregate(ctx context.Context, indexName, shardName string,
params aggregation.Params) (*aggregation.Result, error)
Expand Down Expand Up @@ -460,7 +460,7 @@ func (i *indices) postSearchObjects() http.Handler {
return
}

vector, limit, filters, additional, err := IndicesPayloads.SearchParams.
vector, certainty, limit, filters, additional, err := IndicesPayloads.SearchParams.
Unmarshal(reqPayload)
if err != nil {
http.Error(w, "unmarshal search params from json: "+err.Error(),
Expand All @@ -469,7 +469,7 @@ func (i *indices) postSearchObjects() http.Handler {
}

results, dists, err := i.shards.Search(r.Context(), index, shard,
vector, limit, filters, additional)
vector, certainty, limit, filters, additional)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
Expand Down
5 changes: 3 additions & 2 deletions adapters/handlers/rest/clusterapi/indices_payloads.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,17 +233,18 @@ func (p searchParamsPayload) Marshal(vector []float32, limit int,
return json.Marshal(par)
}

func (p searchParamsPayload) Unmarshal(in []byte) ([]float32, int,
func (p searchParamsPayload) Unmarshal(in []byte) ([]float32, float64, int,
*filters.LocalFilter, additional.Properties, error) {
type searchParametersPayload struct {
SearchVector []float32 `json:"searchVector"`
Certainty float64 `json:"certainty"`
Limit int `json:"limit"`
Filters *filters.LocalFilter `json:"filters"`
Additional additional.Properties `json:"additional"`
}
var par searchParametersPayload
err := json.Unmarshal(in, &par)
return par.SearchVector, par.Limit, par.Filters, par.Additional, err
return par.SearchVector, par.Certainty, par.Limit, par.Filters, par.Additional, err
}

func (p searchParamsPayload) MIME() string {
Expand Down
142 changes: 142 additions & 0 deletions adapters/repos/db/crud_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,148 @@ func Test_ImportWithoutVector_UpdateWithVectorLater(t *testing.T) {
})
}

// TestVectorSearch_ByCertainty imports a handful of objects with known
// vectors and runs nearVector and nearObject searches that carry the
// filters.LimitFlagSearchByDist limit flag together with a certainty of 0.9.
// It checks that more results than QueryMaximumResults come back and that
// every returned object is one of the fixtures marked as expected.
func TestVectorSearch_ByCertainty(t *testing.T) {
	className := "SomeClass"
	var class *models.Class

	rand.Seed(time.Now().UnixNano())
	dirName := fmt.Sprintf("./testdata/%d", rand.Intn(10000000))
	// register the cleanup before creating the directory so that a
	// partially created tree is removed as well
	defer func() {
		if err := os.RemoveAll(dirName); err != nil {
			t.Logf("remove test dir %q: %v", dirName, err)
		}
	}()
	require.Nil(t, os.MkdirAll(dirName, 0o777))
	logger, _ := test.NewNullLogger()

	schemaGetter := &fakeSchemaGetter{shardState: singleShardState()}
	repo := New(logger, Config{
		RootPath: dirName,
		// this is set really low to ensure that search
		// by distance is conducted, which executes
		// without regard to this value
		QueryMaximumResults:       1,
		DiskUseWarningPercentage:  config.DefaultDiskUseWarningPercentage,
		DiskUseReadOnlyPercentage: config.DefaultDiskUseReadonlyPercentage,
	}, &fakeRemoteClient{},
		&fakeNodeResolver{})
	repo.SetSchemaGetter(schemaGetter)
	err := repo.WaitForStartup(testCtx())
	require.Nil(t, err)
	migrator := NewMigrator(repo, logger)

	t.Run("create required schema", func(t *testing.T) {
		class = &models.Class{
			Class: className,
			Properties: []*models.Property{
				{
					DataType: []string{string(schema.DataTypeInt)},
					Name:     "int_prop",
				},
			},
			VectorIndexConfig:   hnsw.NewDefaultUserConfig(),
			InvertedIndexConfig: invertedConfig(),
		}
		require.Nil(t,
			migrator.AddClass(context.Background(), class, schemaGetter.shardState))
	})

	// update schema getter so it's in sync with class
	schemaGetter.schema = schema.Schema{
		Objects: &models.Schema{
			Classes: []*models.Class{class},
		},
	}

	searchVector := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}
	searchObject := strfmt.UUID("fe687bf4-f10f-4c23-948d-0746ea2927b3")

	// fixtures keyed by object ID; expected marks the objects that may
	// appear in the results of the certainty searches below
	tests := map[strfmt.UUID]struct {
		inputVec []float32
		expected bool
	}{
		strfmt.UUID("88460290-03b2-44a3-9adb-9fa3ae11d9e6"): {
			inputVec: []float32{1, 2, 3, 4, 5, 6, 98, 99, 100},
			expected: true,
		},
		strfmt.UUID("c99bc97d-7035-4311-94f3-947dc6471f51"): {
			inputVec: []float32{1, 2, 3, 4, 5, 6, -98, -99, -100},
			expected: false,
		},
		strfmt.UUID("fe687bf4-f10f-4c23-948d-0746ea2927b3"): {
			inputVec: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
			expected: true,
		},
		strfmt.UUID("e7bf6c45-de72-493a-b273-5ef198974d61"): {
			inputVec: []float32{-1, -2, -3, -4, -5, -6, -7, -8, 0},
			expected: false,
		},
		strfmt.UUID("0999d109-1d5f-465a-bd8b-e3fbd46f10aa"): {
			inputVec: []float32{1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9},
			expected: true,
		},
	}

	// assertExpectedResult fails the test when a returned object is not one
	// of the fixtures, or is a fixture that must not be returned. Looking up
	// an unknown ID yields the zero value, so the presence check has to
	// happen explicitly before props.expected is inspected.
	assertExpectedResult := func(t *testing.T, id strfmt.UUID) {
		props, ok := tests[id]
		if !assert.True(t, ok, "unexpected object %q in results", id) {
			return
		}
		assert.True(t, props.expected,
			"object %q must not be part of the results", id)
	}

	t.Run("insert test objects", func(t *testing.T) {
		for id, props := range tests {
			err := repo.PutObject(context.Background(), &models.Object{Class: className, ID: id}, props.inputVec)
			require.Nil(t, err)
		}
	})

	t.Run("perform nearVector search by distance", func(t *testing.T) {
		results, err := repo.VectorClassSearch(context.Background(), traverser.GetParams{
			ClassName:  className,
			Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
			NearVector: &traverser.NearVectorParams{
				Certainty: 0.9,
			},
			SearchVector:         searchVector,
			AdditionalProperties: additional.Properties{Certainty: true},
		})
		require.Nil(t, err)
		require.NotEmpty(t, results)
		// ensure that we receive more results than
		// the `QueryMaximumResults`, as this should
		// only apply to limited vector searches
		require.Greater(t, len(results), 1)

		for _, res := range results {
			assertExpectedResult(t, res.ID)
		}
	})

	t.Run("perform nearObject search by distance", func(t *testing.T) {
		results, err := repo.VectorClassSearch(context.Background(), traverser.GetParams{
			ClassName:  className,
			Pagination: &filters.Pagination{Limit: filters.LimitFlagSearchByDist},
			NearObject: &traverser.NearObjectParams{
				Certainty: 0.9,
				ID:        searchObject.String(),
			},
			SearchVector:         searchVector,
			AdditionalProperties: additional.Properties{Certainty: true},
		})
		require.Nil(t, err)
		require.NotEmpty(t, results)
		// ensure that we receive more results than
		// the `QueryMaximumResults`, as this should
		// only apply to limited vector searches
		require.Greater(t, len(results), 1)

		for _, res := range results {
			assertExpectedResult(t, res.ID)
		}
	})
}

func findID(list []search.Result, id strfmt.UUID) (search.Result, bool) {
for _, item := range list {
if item.ID == id {
Expand Down

0 comments on commit c8c97be

Please sign in to comment.