diff --git a/internal/core/algorithm/ngt/ngt.go b/internal/core/algorithm/ngt/ngt.go index 28a40b2bfc..99d73fac3d 100644 --- a/internal/core/algorithm/ngt/ngt.go +++ b/internal/core/algorithm/ngt/ngt.go @@ -25,7 +25,6 @@ package ngt import "C" import ( - "context" "reflect" "sync" "unsafe" @@ -41,7 +40,7 @@ type ( // NGT is core interface. NGT interface { // Search returns search result as []SearchResult - Search(ctx context.Context, vec []float32, size int, epsilon, radius float32) ([]SearchResult, error) + Search(vec []float32, size int, epsilon, radius float32) ([]SearchResult, error) // Linear Search returns linear search result as []SearchResult LinearSearch(vec []float32, size int) ([]SearchResult, error) @@ -367,7 +366,7 @@ func (n *ngt) loadObjectSpace() error { } // Search returns search result as []SearchResult. -func (n *ngt) Search(ctx context.Context, vec []float32, size int, epsilon, radius float32) (result []SearchResult, err error) { +func (n *ngt) Search(vec []float32, size int, epsilon, radius float32) (result []SearchResult, err error) { if len(vec) != int(n.dimension) { return nil, errors.ErrIncompatibleDimensionSize(len(vec), int(n.dimension)) } @@ -416,12 +415,6 @@ func (n *ngt) Search(ctx context.Context, vec []float32, size int, epsilon, radi result = make([]SearchResult, rsize) for i := range result { - select { - case <-ctx.Done(): - n.PutErrorBuffer(ebuf) - return result[:i], nil - default: - } d := C.ngt_get_result(results, C.uint32_t(i), ebuf) if d.id == 0 && d.distance == 0 { result[i] = SearchResult{0, 0, n.newGoError(ebuf)} diff --git a/internal/core/algorithm/ngt/ngt_test.go b/internal/core/algorithm/ngt/ngt_test.go index 70156d88ff..5b7c071276 100644 --- a/internal/core/algorithm/ngt/ngt_test.go +++ b/internal/core/algorithm/ngt/ngt_test.go @@ -18,7 +18,6 @@ package ngt import ( - "context" "io/fs" "math" "os" @@ -231,11 +230,11 @@ func TestLoad(t *testing.T) { name string args args want want - checkFunc func(context.Context, want, NGT, error) error + checkFunc func(want, NGT, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, NGT) error } - defaultCheckFunc := func(_ context.Context, w want, got NGT, err error) error { + defaultCheckFunc := func(w want, got NGT, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } @@ -297,8 +296,8 @@ func TestLoad(t *testing.T) { mu: &sync.RWMutex{}, }, }, - checkFunc: func(ctx context.Context, w want, n NGT, e error) error { - if err := defaultCheckFunc(ctx, w, n, e); err != nil { + checkFunc: func(w want, n NGT, e error) error { + if err := defaultCheckFunc(w, n, e); err != nil { return err } @@ -309,7 +308,7 @@ func TestLoad(t *testing.T) { } // check no vector can be searched - vs, err := n.Search(ctx, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, 10, 0, 0) + vs, err := n.Search([]float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, 10, 0, 0) if err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { return err } @@ -362,8 +361,8 @@ func TestLoad(t *testing.T) { mu: &sync.RWMutex{}, }, }, - checkFunc: func(ctx context.Context, w want, n NGT, e error) error { - if err := defaultCheckFunc(ctx, w, n, e); err != nil { + checkFunc: func(w want, n NGT, e error) error { + if err := defaultCheckFunc(w, n, e); err != nil { return err } @@ -377,7 +376,7 @@ func TestLoad(t *testing.T) { } // check inserted vector can be searched - vs, err := n.Search(ctx, vec, 10, 0, 0) + vs, err := n.Search(vec, 10, 0, 0) if err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { return err } @@ -427,8 +426,8 @@ func TestLoad(t *testing.T) { mu: &sync.RWMutex{}, }, }, - checkFunc: func(ctx context.Context, w want, n NGT, e error) error { - if err := defaultCheckFunc(ctx, w, n, e); err != nil { + checkFunc: func(w want, n NGT, e error) error { + if err := defaultCheckFunc(w, n, e); err != nil { return err } @@ -439,7 +438,7 @@ func TestLoad(t *testing.T) { } // check no vector can be searched - vs, err := n.Search(ctx, []float32{0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}, 10, 0, 0) + vs, err := n.Search([]float32{0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}, 10, 0, 0) if err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { return err } @@ -492,8 +491,8 @@ func TestLoad(t *testing.T) { mu: &sync.RWMutex{}, }, }, - checkFunc: func(ctx context.Context, w want, n NGT, e error) error { - if err := defaultCheckFunc(ctx, w, n, e); err != nil { + checkFunc: func(w want, n NGT, e error) error { + if err := defaultCheckFunc(w, n, e); err != nil { return err } @@ -507,7 +506,7 @@ func TestLoad(t *testing.T) { } // check inserted vector can be searched - vs, err := n.Search(ctx, vec, 10, 0, 0) + vs, err := n.Search(vec, 10, 0, 0) if err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { return err } @@ -538,8 +537,8 @@ func TestLoad(t *testing.T) { want: nil, err: errors.ErrIndexFileNotFound, }, - checkFunc: func(ctx context.Context, w want, n NGT, e error) error { - if err := defaultCheckFunc(ctx, w, n, e); err != nil { + checkFunc: func(w want, n NGT, e error) error { + if err := defaultCheckFunc(w, n, e); err != nil { return err } @@ -552,7 +551,7 @@ func TestLoad(t *testing.T) { } // check no vector can be searched - vs, err := n.Search(ctx, vec, 10, 0, 0) + vs, err := n.Search(vec, 10, 0, 0) if err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { return err } @@ -588,7 +587,7 @@ func TestLoad(t *testing.T) { t.Error(err) } }, - checkFunc: func(_ context.Context, w want, n NGT, e error) error { + checkFunc: func(w want, n NGT, e error) error { if e != nil && !errors.As(e, w.err) { t.Error(e) return e @@ -621,10 +620,7 @@ func TestLoad(t *testing.T) { tt.Error(err) } }() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - if err := checkFunc(ctx, test.want, got, err); err != nil { + if err := checkFunc(test.want, got, err); err != nil { tt.Errorf("error = %v", err) } }) @@ -644,11 +640,11 @@ func Test_gen(t *testing.T) { name string args args want want - checkFunc func(context.Context, want, NGT, error) error + checkFunc func(want, NGT, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, NGT) error } - defaultCheckFunc := func(_ context.Context, w want, got NGT, err error) error { + defaultCheckFunc := func(w want, got NGT, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } @@ -731,8 +727,8 @@ func Test_gen(t *testing.T) { mu: &sync.RWMutex{}, }, }, - checkFunc: func(ctx context.Context, w want, n NGT, e error) error { - if err := defaultCheckFunc(ctx, w, n, e); err != nil { + checkFunc: func(w want, n NGT, e error) error { + if err := defaultCheckFunc(w, n, e); err != nil { return err } @@ -746,7 +742,7 @@ func Test_gen(t *testing.T) { } // check inserted vector can be searched - vs, err := n.Search(ctx, vec, 10, 0, 0) + vs, err := n.Search(vec, 10, 0, 0) if err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { return err } @@ -793,10 +789,7 @@ func Test_gen(t *testing.T) { tt.Error(err) } }() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - if err := checkFunc(ctx, test.want, got, err); err != nil { + if err := checkFunc(test.want, got, err); err != nil { tt.Errorf("error = %v", err) } }) @@ -1315,7 +1308,6 @@ func Test_ngt_loadObjectSpace(t *testing.T) { func Test_ngt_Search(t *testing.T) { type args struct { - ctx context.Context vec []float32 size int epsilon float32 @@ -1389,7 +1381,6 @@ func Test_ngt_Search(t *testing.T) { { name: "return vector id after the same vector inserted (uint8)", args: args{ - ctx: context.Background(), vec: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, size: 5, epsilon: 0, @@ -1419,7 +1410,6 @@ func Test_ngt_Search(t *testing.T) { { name: "resturn vector id after the nearby vector inserted (uint8)", args: args{ - ctx: context.Background(), vec: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, size: 5, }, @@ -1447,7 +1437,6 @@ func Test_ngt_Search(t *testing.T) { { name: "return vector ids after insert with multiple vectors (uint8)", args: args{ - ctx: context.Background(), vec: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, size: 5, }, @@ -1481,7 +1470,6 @@ func Test_ngt_Search(t *testing.T) { { name: "return limited result after insert 10 vectors with limited size 3 (uint8)", args: args{ - ctx: context.Background(), vec: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, size: 3, }, @@ -1522,7 +1510,6 @@ func Test_ngt_Search(t *testing.T) { { name: "return most accurate result after insert 10 vectors with limited size 5 (uint8)", args: args{ - ctx: context.Background(), vec: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, size: 5, }, @@ -1566,7 +1553,6 @@ func Test_ngt_Search(t *testing.T) { { name: "return vector id after the same vector inserted (float)", args: args{ - ctx: context.Background(), vec: []float32{0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}, size: 5, epsilon: 0, @@ -1596,7 +1582,6 @@ func Test_ngt_Search(t *testing.T) { { name: "resturn vector id after the nearby vector inserted (float)", args: args{ - ctx: context.Background(), vec: []float32{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.91}, size: 5, }, @@ -1624,7 +1609,6 @@ func Test_ngt_Search(t *testing.T) { { name: "return vector ids after insert with multiple vectors (float)", args: args{ - ctx: context.Background(), vec: []float32{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}, size: 5, }, @@ -1657,7 +1641,6 @@ func Test_ngt_Search(t *testing.T) { { name: "return limited result after insert 10 vectors with limited size 3 (float)", args: args{ - ctx: context.Background(), vec: []float32{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}, size: 3, }, @@ -1698,7 +1681,6 @@ func Test_ngt_Search(t *testing.T) { { name: "return most accurate result after insert 10 vectors with limited size 5 (float)", args: args{ - ctx: context.Background(), vec: []float32{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}, size: 5, }, @@ -1742,7 +1724,6 @@ func Test_ngt_Search(t *testing.T) { { name: "return nothing if the search dimension is less than the inserted vector", args: args{ - ctx: context.Background(), vec: []float32{0, 1, 2, 3, 4, 5, 6, 7}, size: 5, epsilon: 0, @@ -1770,7 +1751,6 @@ func Test_ngt_Search(t *testing.T) { { name: "return nothing if the search dimension is more than the inserted vector", args: args{ - ctx: context.Background(), vec: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, size: 5, epsilon: 0, @@ -1798,32 +1778,6 @@ func Test_ngt_Search(t *testing.T) { { name: "return ErrEmptySearchResult error if there is no inserted vector", args: args{ - ctx: context.Background(), - vec: []float32{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}, - size: 3, - }, - fields: fields{ - inMemory: false, - idxPath: "/tmp/ngt-813", - bulkInsertChunkSize: 100, - dimension: 9, - objectType: Float, - radius: float32(-1.0), - epsilon: float32(0.1), - }, - createFunc: defaultCreateFunc, - want: want{ - err: errors.ErrEmptySearchResult, - }, - }, - { - name: "return ErrEmptySearchResult error if the context is canceled", - args: args{ - ctx: func() context.Context { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - return ctx - }(), vec: []float32{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}, size: 3, }, @@ -1847,9 +1801,6 @@ func Test_ngt_Search(t *testing.T) { test := tc t.Run(test.name, func(tt *testing.T) { tt.Parallel() - ctx, cancel := context.WithCancel(test.args.ctx) - defer cancel() - defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) if test.beforeFunc != nil { test.beforeFunc(test.args) @@ -1870,7 +1821,7 @@ func Test_ngt_Search(t *testing.T) { tt.Fatal(err) } - got, err := n.Search(ctx, test.args.vec, test.args.size, test.args.epsilon, test.args.radius) + got, err := n.Search(test.args.vec, test.args.size, test.args.epsilon, test.args.radius) if err := checkFunc(test.want, got, n, err); err != nil { tt.Errorf("error = %v", err) } @@ -1906,7 +1857,7 @@ func Test_ngt_Insert(t *testing.T) { fields fields createFunc func(*testing.T, fields) (NGT, error) want want - checkFunc func(context.Context, want, uint, NGT, args, error) error + checkFunc func(want, uint, NGT, args, error) error beforeFunc func(args) afterFunc func(*testing.T, NGT) error } @@ -1924,7 +1875,7 @@ func Test_ngt_Insert(t *testing.T) { WithDimension(fields.dimension), ) } - defaultCheckFunc := func(ctx context.Context, w want, got uint, n NGT, args args, err error) error { + defaultCheckFunc := func(w want, got uint, n NGT, args args, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } @@ -1937,7 +1888,7 @@ func Test_ngt_Insert(t *testing.T) { } // search before indexing, it should return nothing - r, err := n.Search(ctx, args.vec, 5, 0, 0) + r, err := n.Search(args.vec, 5, 0, 0) if err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { return err } @@ -1949,7 +1900,7 @@ func Test_ngt_Insert(t *testing.T) { if err := n.CreateIndex(1); err != nil { return err } - r, err = n.Search(ctx, args.vec, 5, 0, 0) + r, err = n.Search(args.vec, 5, 0, 0) if err != nil { return err } @@ -2135,11 +2086,9 @@ func Test_ngt_Insert(t *testing.T) { if err != nil { tt.Fatal(err) } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() got, err := n.Insert(test.args.vec) - if err := checkFunc(ctx, test.want, got, n, test.args, err); err != nil { + if err := checkFunc(test.want, got, n, test.args, err); err != nil { tt.Errorf("error = %v", err) } @@ -2175,7 +2124,7 @@ func Test_ngt_InsertCommit(t *testing.T) { fields fields createFunc func(*testing.T, fields) (NGT, error) want want - checkFunc func(context.Context, want, uint, NGT, args, error) error + checkFunc func(want, uint, NGT, args, error) error beforeFunc func(args) afterFunc func(*testing.T, NGT) error } @@ -2193,7 +2142,7 @@ func Test_ngt_InsertCommit(t *testing.T) { WithDimension(fields.dimension), ) } - defaultCheckFunc := func(ctx context.Context, w want, got uint, n NGT, args args, err error) error { + defaultCheckFunc := func(w want, got uint, n NGT, args args, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } @@ -2204,8 +2153,7 @@ func Test_ngt_InsertCommit(t *testing.T) { if got == 0 { return nil } - - r, err := n.Search(ctx, args.vec, 5, 0, 0) + r, err := n.Search(args.vec, 5, 0, 0) if err != nil { return err } @@ -2391,11 +2339,9 @@ func Test_ngt_InsertCommit(t *testing.T) { if err != nil { tt.Fatal(err) } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() got, err := n.InsertCommit(test.args.vec, test.args.poolSize) - if err := checkFunc(ctx, test.want, got, n, test.args, err); err != nil { + if err := checkFunc(test.want, got, n, test.args, err); err != nil { tt.Errorf("error = %v", err) } @@ -2430,7 +2376,7 @@ func Test_ngt_BulkInsert(t *testing.T) { fields fields createFunc func(*testing.T, fields) (NGT, error) want want - checkFunc func(context.Context, want, []uint, NGT, fields, args, []error) error + checkFunc func(want, []uint, NGT, fields, args, []error) error beforeFunc func(args) afterFunc func(*testing.T, NGT) error } @@ -2448,7 +2394,7 @@ func Test_ngt_BulkInsert(t *testing.T) { WithDimension(fields.dimension), ) } - defaultCheckFunc := func(ctx context.Context, w want, got []uint, n NGT, fields fields, args args, got1 []error) error { + defaultCheckFunc := func(w want, got []uint, n NGT, fields fields, args args, got1 []error) error { if diff := comparator.Diff(w.want1, got1, comparator.ErrorComparer); diff != "" { return errors.New(diff) } @@ -2472,7 +2418,7 @@ func Test_ngt_BulkInsert(t *testing.T) { if len(vec) != fields.dimension { continue } - r, err := n.Search(ctx, vec, 1, 0, 0) + r, err := n.Search(vec, 1, 0, 0) if err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { return err } @@ -2489,7 +2435,7 @@ func Test_ngt_BulkInsert(t *testing.T) { if len(vec) != fields.dimension { continue } - r, err := n.Search(ctx, vec, 1, 0, 0) + r, err := n.Search(vec, 1, 0, 0) if err != nil { return err } @@ -2709,11 +2655,9 @@ func Test_ngt_BulkInsert(t *testing.T) { tt.Error(err) } }() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() got, got1 := n.BulkInsert(test.args.vecs) - if err := checkFunc(ctx, test.want, got, n, test.fields, test.args, got1); err != nil { + if err := checkFunc(test.want, got, n, test.fields, test.args, got1); err != nil { tt.Errorf("error = %v", err) } }) @@ -2745,7 +2689,7 @@ func Test_ngt_BulkInsertCommit(t *testing.T) { fields fields createFunc func(*testing.T, fields) (NGT, error) want want - checkFunc func(context.Context, want, []uint, NGT, fields, args, []error) error + checkFunc func(want, []uint, NGT, fields, args, []error) error beforeFunc func(args) afterFunc func(*testing.T, NGT) error } @@ -2763,7 +2707,7 @@ func Test_ngt_BulkInsertCommit(t *testing.T) { WithDimension(fields.dimension), ) } - defaultCheckFunc := func(ctx context.Context, w want, got []uint, n NGT, fields fields, args args, got1 []error) error { + defaultCheckFunc := func(w want, got []uint, n NGT, fields fields, args args, got1 []error) error { if diff := comparator.Diff(w.want1, got1, comparator.ErrorComparer); diff != "" { return errors.New(diff) } @@ -2785,7 +2729,7 @@ func Test_ngt_BulkInsertCommit(t *testing.T) { if len(vec) != fields.dimension { continue } - r, err := n.Search(ctx, vec, 1, 0, 0) + r, err := n.Search(vec, 1, 0, 0) if err != nil { return err } @@ -3005,11 +2949,9 @@ func Test_ngt_BulkInsertCommit(t *testing.T) { tt.Error(err) } }() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() got, got1 := n.BulkInsertCommit(test.args.vecs, test.args.poolSize) - if err := checkFunc(ctx, test.want, got, n, test.fields, test.args, got1); err != nil { + if err := checkFunc(test.want, got, n, test.fields, test.args, got1); err != nil { tt.Errorf("error = %v", err) } }) @@ -3039,7 +2981,7 @@ func Test_ngt_CreateAndSaveIndex(t *testing.T) { fields fields createFunc func(*testing.T, fields) (NGT, error) want want - checkFunc func(context.Context, want, NGT, args, error) error + checkFunc func(want, NGT, args, error) error beforeFunc func(args) afterFunc func(*testing.T, NGT) error } @@ -3057,7 +2999,7 @@ func Test_ngt_CreateAndSaveIndex(t *testing.T) { WithDimension(fields.dimension), ) } - defaultCheckFunc := func(_ context.Context, w want, n NGT, args args, got error) error { + defaultCheckFunc := func(w want, n NGT, args args, got error) error { if diff := comparator.Diff(w.err, got); diff != "" { return errors.New(diff) } @@ -3144,14 +3086,14 @@ func Test_ngt_CreateAndSaveIndex(t *testing.T) { return ngt, err }, - checkFunc: func(ctx context.Context, w want, n NGT, a args, e error) error { - if err := defaultCheckFunc(ctx, w, n, a, e); err != nil { + checkFunc: func(w want, n NGT, a args, e error) error { + if err := defaultCheckFunc(w, n, a, e); err != nil { return err } // search the inserted vector exists after create index for _, v := range ivs { - if rs, err := n.Search(ctx, v, 1, 0, 0); err != nil { + if rs, err := n.Search(v, 1, 0, 0); err != nil { if rs[0].Distance != 0 { return errors.Errorf("vector distance is invalid, got: %d, want: %d", rs[0].Distance, 0) } @@ -3203,14 +3145,14 @@ func Test_ngt_CreateAndSaveIndex(t *testing.T) { return ngt, err }, - checkFunc: func(ctx context.Context, w want, n NGT, a args, e error) error { - if err := defaultCheckFunc(ctx, w, n, a, e); err != nil { + checkFunc: func(w want, n NGT, a args, e error) error { + if err := defaultCheckFunc(w, n, a, e); err != nil { return err } // search the inserted vector exists after create index for _, v := range ivs { - if rs, err := n.Search(ctx, v, 1, 0, 0); err != nil { + if rs, err := n.Search(v, 1, 0, 0); err != nil { if rs[0].Distance != 0 { return errors.Errorf("vector distance is invalid, got: %d, want: %d", rs[0].Distance, 0) } @@ -3264,11 +3206,9 @@ func Test_ngt_CreateAndSaveIndex(t *testing.T) { tt.Error(err) } }() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() err = n.CreateAndSaveIndex(test.args.poolSize) - if err := checkFunc(ctx, test.want, n, test.args, err); err != nil { + if err := checkFunc(test.want, n, test.args, err); err != nil { tt.Errorf("error = %v", err) } }) @@ -3298,7 +3238,7 @@ func Test_ngt_CreateIndex(t *testing.T) { fields fields createFunc func(*testing.T, fields) (NGT, error) want want - checkFunc func(context.Context, want, NGT, args, error) error + checkFunc func(want, NGT, args, error) error beforeFunc func(args) afterFunc func(*testing.T, NGT) error } @@ -3316,7 +3256,7 @@ func Test_ngt_CreateIndex(t *testing.T) { WithDimension(fields.dimension), ) } - defaultCheckFunc := func(_ context.Context, w want, n NGT, args args, got error) error { + defaultCheckFunc := func(w want, n NGT, args args, got error) error { if diff := comparator.Diff(w.err, got); diff != "" { return errors.New(diff) } @@ -3403,14 +3343,14 @@ func Test_ngt_CreateIndex(t *testing.T) { return ngt, err }, - checkFunc: func(ctx context.Context, w want, n NGT, a args, e error) error { - if err := defaultCheckFunc(ctx, w, n, a, e); err != nil { + checkFunc: func(w want, n NGT, a args, e error) error { + if err := defaultCheckFunc(w, n, a, e); err != nil { return err } // search the inserted vector exists after create index for _, v := range ivs { - if rs, err := n.Search(ctx, v, 1, 0, 0); err != nil { + if rs, err := n.Search(v, 1, 0, 0); err != nil { if rs[0].Distance != 0 { return errors.Errorf("vector distance is invalid, got: %d, want: %d", rs[0].Distance, 0) } @@ -3462,14 +3402,14 @@ func Test_ngt_CreateIndex(t *testing.T) { return ngt, err }, - checkFunc: func(ctx context.Context, w want, n NGT, a args, e error) error { - if err := defaultCheckFunc(ctx, w, n, a, e); err != nil { + checkFunc: func(w want, n NGT, a args, e error) error { + if err := defaultCheckFunc(w, n, a, e); err != nil { return err } // search the inserted vector exists after create index for _, v := range ivs { - if rs, err := n.Search(ctx, v, 1, 0, 0); err != nil { + if rs, err := n.Search(v, 1, 0, 0); err != nil { if rs[0].Distance != 0 { return errors.Errorf("vector distance is invalid, got: %d, want: %d", rs[0].Distance, 0) } @@ -3523,11 +3463,9 @@ func Test_ngt_CreateIndex(t *testing.T) { tt.Error(err) } }() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() err = n.CreateIndex(test.args.poolSize) - if err := checkFunc(ctx, test.want, n, test.args, err); err != nil { + if err := checkFunc(test.want, n, test.args, err); err != nil { tt.Errorf("error = %v", err) } }) @@ -3557,7 +3495,7 @@ func Test_ngt_SaveIndex(t *testing.T) { fields fields createFunc func(*testing.T, fields) (NGT, error) want want - checkFunc func(context.Context, want, NGT, args, error) error + checkFunc func(want, NGT, args, error) error beforeFunc func(args) afterFunc func(*testing.T, NGT) error } @@ -3575,7 +3513,7 @@ func Test_ngt_SaveIndex(t *testing.T) { WithDimension(fields.dimension), ) } - defaultCheckFunc := func(_ context.Context, w want, n NGT, args args, e error) error { + defaultCheckFunc := func(w want, n NGT, args args, e error) error { if ngt, ok := n.(*ngt); ok { _, err := os.Stat(ngt.idxPath) // if ngt is in-memory mode, the index file should not be created @@ -3659,14 +3597,14 @@ func Test_ngt_SaveIndex(t *testing.T) { return ngt, err }, - checkFunc: func(ctx context.Context, w want, n NGT, a args, e error) error { - if err := defaultCheckFunc(ctx, w, n, a, e); err != nil { + checkFunc: func(w want, n NGT, a args, e error) error { + if err := defaultCheckFunc(w, n, a, e); err != nil { return err } // search the inserted vector exists after create index for _, v := range ivs { - if rs, err := n.Search(ctx, v, 1, 0, 0); err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { + if rs, err := n.Search(v, 1, 0, 0); err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { if rs[0].Distance != 0 { return errors.Errorf("vector distance is invalid, got: %d, want: %d", rs[0].Distance, 0) } @@ -3718,14 +3656,14 @@ func Test_ngt_SaveIndex(t *testing.T) { return ngt, err }, - checkFunc: func(ctx context.Context, w want, n NGT, a args, e error) error { - if err := defaultCheckFunc(ctx, w, n, a, e); err != nil { + checkFunc: func(w want, n NGT, a args, e error) error { + if err := defaultCheckFunc(w, n, a, e); err != nil { return err } // search the inserted vector exists after create index for _, v := range ivs { - if rs, err := n.Search(ctx, v, 1, 0, 0); err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { + if rs, err := n.Search(v, 1, 0, 0); err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { if rs[0].Distance != 0 { return errors.Errorf("vector distance is invalid, got: %d, want: %d", rs[0].Distance, 0) } @@ -3779,11 +3717,9 @@ func Test_ngt_SaveIndex(t *testing.T) { tt.Error(err) } }() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() err = n.SaveIndex() - if err := checkFunc(ctx, test.want, n, test.args, err); err != nil { + if err := checkFunc(test.want, n, test.args, err); err != nil { tt.Errorf("error = %v", err) } }) diff --git a/pkg/agent/core/ngt/handler/grpc/search.go b/pkg/agent/core/ngt/handler/grpc/search.go index 2a9c9d6fdc..328806f7e7 100644 --- a/pkg/agent/core/ngt/handler/grpc/search.go +++ b/pkg/agent/core/ngt/handler/grpc/search.go @@ -69,7 +69,7 @@ func (s *server) Search(ctx context.Context, req *payload.Search_Request) (res * return nil, err } res, err = toSearchResponse( - s.ngt.Search(ctx, + s.ngt.Search( req.GetVector(), req.GetConfig().GetNum(), req.GetConfig().GetEpsilon(), @@ -195,7 +195,7 @@ func (s *server) SearchByID(ctx context.Context, req *payload.Search_IDRequest) } return nil, err } - vec, dst, err := s.ngt.SearchByID(ctx, + vec, dst, err := s.ngt.SearchByID( uuid, req.GetConfig().GetNum(), req.GetConfig().GetEpsilon(), diff --git a/pkg/agent/core/ngt/service/ngt.go b/pkg/agent/core/ngt/service/ngt.go index 56377b38a3..f83da3826c 100644 --- a/pkg/agent/core/ngt/service/ngt.go +++ b/pkg/agent/core/ngt/service/ngt.go @@ -50,8 +50,8 @@ import ( type NGT interface { Start(ctx context.Context) <-chan error - Search(ctx context.Context, vec []float32, size uint32, epsilon, radius float32) ([]model.Distance, error) - SearchByID(ctx context.Context, uuid string, size uint32, epsilon, radius float32) ([]float32, []model.Distance, error) + Search(vec []float32, size uint32, epsilon, radius float32) ([]model.Distance, error) + SearchByID(uuid string, size uint32, epsilon, radius float32) ([]float32, []model.Distance, error) LinearSearch(vec []float32, size uint32) ([]model.Distance, error) LinearSearchByID(uuid string, size uint32) ([]float32, []model.Distance, error) Insert(uuid string, vec []float32) (err error) @@ -869,11 +869,11 @@ func (n *ngt) Start(ctx context.Context) <-chan error { return ech } -func (n *ngt) Search(ctx context.Context, vec []float32, size uint32, epsilon, radius float32) ([]model.Distance, error) { +func (n *ngt) Search(vec []float32, size uint32, epsilon, radius float32) ([]model.Distance, error) { if n.IsIndexing() { return nil, errors.ErrCreateIndexingIsInProgress } - sr, err := n.core.Search(ctx, vec, int(size), epsilon, radius) + sr, err := n.core.Search(vec, int(size), epsilon, radius) if err != nil { if n.IsIndexing() { return nil, errors.ErrCreateIndexingIsInProgress @@ -888,11 +888,6 @@ func (n *ngt) Search(ctx context.Context, vec []float32, size uint32, epsilon, r ds := make([]model.Distance, 0, len(sr)) for _, d := range sr { - select { - case <-ctx.Done(): - return ds, nil - default: - } if err = d.Error; d.ID == 0 && err != nil { log.Warnf("an error occurred while searching: %s", err) continue @@ -911,7 +906,7 @@ func (n *ngt) Search(ctx context.Context, vec []float32, size uint32, epsilon, r return ds, nil } -func (n *ngt) SearchByID(ctx context.Context, uuid string, size uint32, epsilon, radius float32) (vec []float32, dst []model.Distance, err error) { +func (n *ngt) SearchByID(uuid string, size uint32, epsilon, radius float32) (vec []float32, dst []model.Distance, err error) { if n.IsIndexing() { return nil, nil, errors.ErrCreateIndexingIsInProgress } @@ -919,7 +914,7 @@ func (n *ngt) SearchByID(ctx context.Context, uuid string, size uint32, epsilon, if err != nil { return nil, nil, err } - dst, err = n.Search(ctx, vec, size, epsilon, radius) + dst, err = n.Search(vec, size, epsilon, radius) if err != nil { return vec, nil, err } diff --git a/pkg/agent/core/ngt/service/ngt_stateful_test.go b/pkg/agent/core/ngt/service/ngt_stateful_test.go index a3a1bd6e61..fb3da107b3 100644 --- a/pkg/agent/core/ngt/service/ngt_stateful_test.go +++ b/pkg/agent/core/ngt/service/ngt_stateful_test.go @@ -511,10 +511,9 @@ var ( RunFunc: func( systemUnderTest commands.SystemUnderTest, ) commands.Result { - ngtSys := systemUnderTest.(*ngtSystem) - ngt := ngtSys.ngt + ngt := systemUnderTest.(*ngtSystem).ngt - res, err := ngt.Search(ngtSys.ctx, []float32{0.1, 0.1, 0.1}, 3, 0.1, -1.0) + res, err := ngt.Search([]float32{0.1, 0.1, 0.1}, 3, 0.1, -1.0) return &resultContainer{ err: err, results: res, @@ -584,10 +583,9 @@ var ( RunFunc: func( systemUnderTest commands.SystemUnderTest, ) commands.Result { - ngtSys := systemUnderTest.(*ngtSystem) - ngt := ngtSys.ngt + ngt := systemUnderTest.(*ngtSystem).ngt - _, res, err := ngt.SearchByID(ngtSys.ctx, idA, 3, 0.1, -1.0) + _, res, err := ngt.SearchByID(idA, 3, 0.1, -1.0) return &resultContainer{ err: err, results: res, @@ -664,10 +662,9 @@ var ( RunFunc: func( systemUnderTest commands.SystemUnderTest, ) commands.Result { - ngtSys := systemUnderTest.(*ngtSystem) - ngt := ngtSys.ngt + ngt := systemUnderTest.(*ngtSystem).ngt - _, res, err := ngt.SearchByID(ngtSys.ctx, idB, 3, 0.1, -1.0) + _, res, err := ngt.SearchByID(idB, 3, 0.1, -1.0) return &resultContainer{ err: err, results: res, @@ -744,10 +741,9 @@ var ( RunFunc: func( systemUnderTest commands.SystemUnderTest, ) commands.Result { - ngtSys := systemUnderTest.(*ngtSystem) - ngt := ngtSys.ngt + ngt := systemUnderTest.(*ngtSystem).ngt - _, res, err := ngt.SearchByID(ngtSys.ctx, idC, 3, 0.1, -1.0) + _, res, err := ngt.SearchByID(idC, 3, 0.1, -1.0) return &resultContainer{ err: err, results: res,