Skip to content

Commit

Permalink
Test: Oversubscribed search
Browse files Browse the repository at this point in the history
Relates to #393
  • Loading branch information
ashvardanian committed Apr 11, 2024
1 parent 92e0b94 commit 7db0c39
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
19 changes: 19 additions & 0 deletions cpp/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,25 @@ void test_cosine(index_at& index, std::vector<std::vector<scalar_at>> const& vec
expect(matched_keys[0] == key_first);
expect(std::abs(matched_distances[0]) < 0.01);

// Check over-sampling beyond the size of the collection
{
std::size_t max_possible_matches = vectors.size();
std::size_t count_requested = max_possible_matches * 4;
std::vector<vector_key_t> matched_keys(count_requested);
std::vector<distance_t> matched_distances(count_requested);

matched_count = index //
.search(vector_first, count_requested, args...) //
.dump_to(matched_keys.data(), matched_distances.data());
expect(matched_count <= max_possible_matches);
expect(matched_keys[0] == key_first);
expect(std::abs(matched_distances[0]) < 0.01);

// Check that the distances are sorted in non-decreasing order
for (std::size_t i = 1; i < matched_count; i++)
expect(matched_distances[i - 1] <= matched_distances[i]);
}

if constexpr (punned_ak) {
std::vector<scalar_t> vec_recovered_from_view(dimensions);
index.get(key_second, vec_recovered_from_view.data());
Expand Down
21 changes: 21 additions & 0 deletions python/scripts/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,27 @@ def test_index_contains_remove_rename(batch_size):
assert np.sum(index.count(removed_keys)) == len(index)


@pytest.mark.parametrize("batch_size", [3, 17, 33])
@pytest.mark.parametrize("threads", [1, 4])
def test_index_oversubscribed_search(batch_size: int, threads: int):
    """Search for more neighbors than the index contains.

    Fills a single-entry-per-key index with `batch_size` random vectors,
    then queries with those same vectors while requesting ten times as
    many results as there are entries. For every query the test asserts
    that the first returned key is the query's own key and that exactly
    `batch_size` keys come back.

    :param batch_size: number of vectors added to the index and queried.
    :param threads: thread count forwarded to both `add` and `search`.
    """
    # Nothing meaningful to over-request from with fewer than two entries.
    if batch_size <= 1:
        return

    dimensions = 8
    index = Index(ndim=dimensions, multi=False)
    keys = np.arange(batch_size)
    vectors = random_vectors(count=batch_size, ndim=dimensions)

    # Populate the index and make sure each key is present exactly once.
    index.add(keys, vectors, threads=threads)
    assert np.all(index.contains(keys))
    assert np.all(index.count(keys) == np.ones(batch_size))

    # Ask for far more matches than the index holds.
    requested_count = batch_size * 10
    results: BatchMatches = index.search(vectors, requested_count, threads=threads)
    for query_idx, query_matches in enumerate(results):
        found_keys = query_matches.keys
        assert query_idx == found_keys[0]
        assert len(found_keys) == batch_size


@pytest.mark.parametrize("ndim", [3, 97, 256])
@pytest.mark.parametrize("metric", [MetricKind.Cos, MetricKind.L2sq])
@pytest.mark.parametrize("batch_size", [500, 1024])
Expand Down

0 comments on commit 7db0c39

Please sign in to comment.