Skip to content

Commit

Permalink
Fix vector search with named vectors (#4857)
Browse files Browse the repository at this point in the history
  • Loading branch information
antas-marcin committed May 6, 2024
1 parent c835574 commit aacdfdc
Show file tree
Hide file tree
Showing 8 changed files with 277 additions and 7 deletions.
4 changes: 2 additions & 2 deletions adapters/repos/db/shard.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,8 @@ func (s *Shard) initVectorIndex(ctx context.Context,
ShardName: s.name,
ClassName: s.index.Config.ClassName.String(),
PrometheusMetrics: s.promMetrics,
VectorForIDThunk: s.vectorByIndexID,
TempVectorForIDThunk: s.readVectorByIndexIDIntoSlice,
VectorForIDThunk: hnsw.NewVectorForIDThunk(targetVector, s.vectorByIndexID),
TempVectorForIDThunk: hnsw.NewTempVectorForIDThunk(targetVector, s.readVectorByIndexIDIntoSlice),
DistanceProvider: distProv,
MakeCommitLoggerThunk: func() (hnsw.CommitLogger, error) {
return hnsw.NewCommitLogger(s.path(), vecIdxID,
Expand Down
8 changes: 4 additions & 4 deletions adapters/repos/db/shard_read.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,12 @@ func (s *Shard) objectByIndexID(ctx context.Context, indexID uint64, acceptDelet
return obj, nil
}

func (s *Shard) vectorByIndexID(ctx context.Context, indexID uint64) ([]float32, error) {
func (s *Shard) vectorByIndexID(ctx context.Context, indexID uint64, targetVector string) ([]float32, error) {
keyBuf := make([]byte, 8)
return s.readVectorByIndexIDIntoSlice(ctx, indexID, &common.VectorSlice{Buff8: keyBuf})
return s.readVectorByIndexIDIntoSlice(ctx, indexID, &common.VectorSlice{Buff8: keyBuf}, targetVector)
}

func (s *Shard) readVectorByIndexIDIntoSlice(ctx context.Context, indexID uint64, container *common.VectorSlice) ([]float32, error) {
func (s *Shard) readVectorByIndexIDIntoSlice(ctx context.Context, indexID uint64, container *common.VectorSlice, targetVector string) ([]float32, error) {
binary.LittleEndian.PutUint64(container.Buff8, indexID)

bytes, newBuff, err := s.store.Bucket(helpers.ObjectsBucketLSM).
Expand All @@ -158,7 +158,7 @@ func (s *Shard) readVectorByIndexIDIntoSlice(ctx context.Context, indexID uint64
}

container.Buff = newBuff
return storobj.VectorFromBinary(bytes, container.Slice)
return storobj.VectorFromBinary(bytes, container.Slice, targetVector)
}

func (s *Shard) ObjectSearch(ctx context.Context, limit int, filters *filters.LocalFilter,
Expand Down
18 changes: 18 additions & 0 deletions adapters/repos/db/vector/common/vector_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,24 @@ type (
MultiVectorForID func(ctx context.Context, ids []uint64) ([][]float32, []error)
)

// TargetVectorForID pairs a target vector name with a lookup function that
// expects that name as an extra argument. Its VectorForID method exposes the
// pair as a plain (ctx, id) lookup.
type TargetVectorForID[T float32 | byte | uint64] struct {
	// TargetVector is the name handed to VectorForIDThunk on every call.
	TargetVector string
	// VectorForIDThunk performs the actual lookup for an id and target
	// vector name.
	VectorForIDThunk func(ctx context.Context, id uint64, targetVector string) ([]T, error)
}

// VectorForID calls VectorForIDThunk with the given ctx and id, supplying the
// stored TargetVector as the last argument, and returns its result unchanged.
func (tv TargetVectorForID[T]) VectorForID(ctx context.Context, id uint64) ([]T, error) {
	lookup, name := tv.VectorForIDThunk, tv.TargetVector
	return lookup(ctx, id, name)
}

// TargetTempVectorForID pairs a target vector name with a container-based
// lookup function that expects that name as an extra argument. Its
// TempVectorForID method exposes the pair as a (ctx, id, container) lookup.
type TargetTempVectorForID struct {
	// TargetVector is the name handed to TempVectorForIDThunk on every call.
	TargetVector string
	// TempVectorForIDThunk performs the actual lookup for an id, using the
	// supplied container and target vector name.
	TempVectorForIDThunk func(ctx context.Context, id uint64, container *VectorSlice, targetVector string) ([]float32, error)
}

// TempVectorForID calls TempVectorForIDThunk with the given ctx, id and
// container, supplying the stored TargetVector as the last argument, and
// returns its result unchanged.
func (tv TargetTempVectorForID) TempVectorForID(ctx context.Context, id uint64, container *VectorSlice) ([]float32, error) {
	lookup, name := tv.TempVectorForIDThunk, tv.TargetVector
	return lookup(ctx, id, container, name)
}

type TempVectorsPool struct {
pool *sync.Pool
}
Expand Down
18 changes: 18 additions & 0 deletions adapters/repos/db/vector/hnsw/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
package hnsw

import (
"context"

"github.com/sirupsen/logrus"
"github.com/weaviate/weaviate/adapters/repos/db/vector/common"
"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer"
Expand Down Expand Up @@ -65,3 +67,19 @@ func (c Config) Validate() error {

return ec.ToError()
}

// NewVectorForIDThunk adapts fn, which takes a target vector name as its last
// argument, into a common.VectorForID[float32] that always passes targetVector
// for that argument.
func NewVectorForIDThunk(targetVector string, fn func(ctx context.Context, id uint64, targetVector string) ([]float32, error)) common.VectorForID[float32] {
	return common.TargetVectorForID[float32]{
		TargetVector:     targetVector,
		VectorForIDThunk: fn,
	}.VectorForID
}

// NewTempVectorForIDThunk adapts fn, which takes a target vector name as its
// last argument, into a common.TempVectorForID that always passes targetVector
// for that argument.
func NewTempVectorForIDThunk(targetVector string, fn func(ctx context.Context, indexID uint64, container *common.VectorSlice, targetVector string) ([]float32, error)) common.TempVectorForID {
	return common.TargetTempVectorForID{
		TargetVector:         targetVector,
		TempVectorForIDThunk: fn,
	}.TempVectorForID
}
32 changes: 31 additions & 1 deletion entities/storobj/storage_object.go
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ func unmarshalTargetVectors(rw *byteops.ReadWriter) (map[string][]float32, error
return nil, nil
}

func VectorFromBinary(in []byte, buffer []float32) ([]float32, error) {
func VectorFromBinary(in []byte, buffer []float32, targetVector string) ([]float32, error) {
if len(in) == 0 {
return nil, nil
}
Expand All @@ -861,6 +861,36 @@ func VectorFromBinary(in []byte, buffer []float32) ([]float32, error) {
return nil, errors.Errorf("unsupported marshaller version %d", version)
}

if targetVector != "" {
startPos := uint64(1 + 8 + 1 + 16 + 8 + 8) // elements at the start
rw := byteops.NewReadWriter(in, byteops.WithPosition(startPos))

vectorLength := uint64(rw.ReadUint16())
rw.MoveBufferPositionForward(vectorLength * 4)

classnameLength := uint64(rw.ReadUint16())
rw.MoveBufferPositionForward(classnameLength)

schemaLength := uint64(rw.ReadUint32())
rw.MoveBufferPositionForward(schemaLength)

metaLength := uint64(rw.ReadUint32())
rw.MoveBufferPositionForward(metaLength)

vectorWeightsLength := uint64(rw.ReadUint32())
rw.MoveBufferPositionForward(vectorWeightsLength)

targetVectors, err := unmarshalTargetVectors(&rw)
if err != nil {
return nil, errors.Errorf("unable to unmarshal vector for target vector: %s", targetVector)
}
vector, ok := targetVectors[targetVector]
if !ok {
return nil, errors.Errorf("vector not found for target vector: %s", targetVector)
}
return vector, nil
}

// since we know the version and know that the blob is not len(0), we can
// assume that we can directly access the vector length field. The only
// situation where this is not accessible would be on corrupted data - where
Expand Down
76 changes: 76 additions & 0 deletions entities/storobj/storage_object_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -624,3 +624,79 @@ func TestStorageMaxVectorDimensionsObjectMarshalling(t *testing.T) {
})
}
}

// TestVectorFromBinary marshals an object that carries a legacy vector plus
// three named vectors and verifies that VectorFromBinary returns the matching
// named vector for each target name, and an error for an unknown target name.
func TestVectorFromBinary(t *testing.T) {
	vector1 := []float32{1, 2, 3}
	vector2 := []float32{4, 5, 6}
	vector3 := []float32{7, 8, 9}
	before := FromObject(
		&models.Object{
			Class:              "MyFavoriteClass",
			CreationTimeUnix:   123456,
			LastUpdateTimeUnix: 56789,
			ID:                 strfmt.UUID("73f2eb5f-5abf-447a-81ca-74b1dd168247"),
			// Additional, Properties and the legacy vector are populated so the
			// marshalled object contains more than just the named vectors.
			Additional: models.AdditionalProperties{
				"classification": &additional.Classification{
					BasedOn: []string{"some", "fields"},
				},
				"interpretation": map[string]interface{}{
					"Source": []interface{}{
						map[string]interface{}{
							"concept":    "foo",
							"occurrence": float64(7),
							"weight":     float64(3),
						},
					},
				},
				"group": &additional.Group{
					ID: 100,
					GroupedBy: &additional.GroupedBy{
						Value: "group-by-some-property",
						Path:  []string{"property-path"},
					},
					MaxDistance: 0.1,
					MinDistance: 0.2,
					Count:       200,
					Hits: []map[string]interface{}{
						{
							"property1": "value1",
							"_additional": &additional.GroupHitAdditional{
								ID:       "2c76ca18-2073-4c48-aa52-7f444d2f5b80",
								Distance: 0.24,
							},
						},
						{
							"property1": "value2",
						},
					},
				},
			},
			Properties: map[string]interface{}{
				"name": "MyName",
				"foo":  float64(17),
			},
		},
		[]float32{1, 2, 0.7},
		models.Vectors{
			"vector1": vector1,
			"vector2": vector2,
			"vector3": vector3,
		},
	)
	before.DocID = 7

	asBinary, err := before.MarshalBinary()
	require.NoError(t, err)

	tests := []struct {
		targetVector string
		expected     []float32
	}{
		{targetVector: "vector1", expected: vector1},
		{targetVector: "vector2", expected: vector2},
		{targetVector: "vector3", expected: vector3},
	}
	for _, tt := range tests {
		t.Run(tt.targetVector, func(t *testing.T) {
			got, err := VectorFromBinary(asBinary, nil, tt.targetVector)
			require.NoError(t, err)
			assert.Equal(t, tt.expected, got)
		})
	}

	// A target vector name that was never stored must yield an error rather
	// than silently returning some other vector.
	t.Run("unknown target vector", func(t *testing.T) {
		got, err := VectorFromBinary(asBinary, nil, "does-not-exist")
		require.Error(t, err)
		assert.Nil(t, got)
	})
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//

package named_vectors_tests

import (
"acceptance_tests_with_client/fixtures"
"context"
"testing"

"github.com/go-openapi/strfmt"
"github.com/stretchr/testify/require"
wvt "github.com/weaviate/weaviate-go-client/v4/weaviate"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/test/docker"
)

// testRestart returns a test that creates a class with named vectors, batch
// imports the book fixtures, and checks that a nearText search against every
// target vector returns the expected vectors both before and after the
// Weaviate container is stopped and started again through compose.
func testRestart(compose *docker.DockerCompose) func(t *testing.T) {
	return func(t *testing.T) {
		ctx := context.Background()
		host := compose.GetWeaviate().URI()
		client, err := wvt.NewClient(wvt.Config{Scheme: "http", Host: host})
		require.Nil(t, err)

		// cleanup deletes the whole schema so the test starts from a clean state.
		cleanup := func() {
			err := client.Schema().AllDeleter().Do(context.Background())
			require.Nil(t, err)
		}

		// assertNearText runs a nearText query per book and per target vector
		// using the given client and validates the vectors that come back.
		assertNearText := func(t *testing.T, c *wvt.Client) {
			for id, book := range fixtures.Books() {
				for _, targetVector := range targetVectors {
					nearText := c.GraphQL().NearTextArgBuilder().
						WithConcepts([]string{book.Title}).
						WithTargetVectors(targetVector)
					resultVectors := getVectorsWithNearText(t, c, className, id, nearText, targetVectors...)
					checkTargetVectors(t, resultVectors)
				}
			}
		}

		t.Run("multiple named vectors", func(t *testing.T) {
			cleanup()

			t.Run("create schema", func(t *testing.T) {
				createNamedVectorsClass(t, client)
			})

			t.Run("batch create objects", func(t *testing.T) {
				objs := []*models.Object{}
				for id, book := range fixtures.Books() {
					objs = append(objs, &models.Object{
						Class: className,
						ID:    strfmt.UUID(id),
						Properties: map[string]interface{}{
							"text": book.Description,
						},
					})
				}

				resp, err := client.Batch().ObjectsBatcher().
					WithObjects(objs...).
					Do(ctx)
				require.NoError(t, err)
				require.NotNil(t, resp)
			})

			t.Run("check existence", func(t *testing.T) {
				for id := range fixtures.Books() {
					exists, err := client.Data().Checker().
						WithID(id).
						WithClassName(className).
						Do(ctx)
					require.NoError(t, err)
					require.True(t, exists)
				}
			})

			t.Run("GraphQL get vectors", func(t *testing.T) {
				for id := range fixtures.Books() {
					resultVectors := getVectors(t, client, className, id, targetVectors...)
					checkTargetVectors(t, resultVectors)
				}
			})

			t.Run("GraphQL near<Media> check", func(t *testing.T) {
				assertNearText(t, client)
			})

			t.Run("GraphQL near<Media> check after restart", func(t *testing.T) {
				weaviateName := compose.GetWeaviate().Name()
				err := compose.Stop(ctx, weaviateName, nil)
				require.NoError(t, err)

				err = compose.Start(ctx, compose.GetWeaviate().Name())
				require.NoError(t, err)

				// The URI is read again and a fresh client is built after the
				// container has been started.
				restartedHost := compose.GetWeaviate().URI()
				restartedClient, err := wvt.NewClient(wvt.Config{Scheme: "http", Host: restartedHost})
				require.Nil(t, err)

				assertNearText(t, restartedClient)
			})
		})
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ func allTests(endpoint string) func(t *testing.T) {
}
}

// TestNamedVectors_SingleNode_Restart creates a single-node environment, runs
// the restart test against it, and terminates the environment when the test
// function returns.
func TestNamedVectors_SingleNode_Restart(t *testing.T) {
	ctx := context.Background()

	compose, err := createSingleNodeEnvironment(ctx)
	require.NoError(t, err)

	defer func() {
		terminateErr := compose.Terminate(ctx)
		require.NoError(t, terminateErr)
	}()

	t.Run("restart", testRestart(compose))
}

func createSingleNodeEnvironment(ctx context.Context) (compose *docker.DockerCompose, err error) {
compose, err = composeModules().
WithWeaviate().
Expand Down

0 comments on commit aacdfdc

Please sign in to comment.