Skip to content

Commit

Permalink
Add: Multi-Index lookups
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Jul 25, 2023
1 parent c04a5cc commit c5b7ccd
Show file tree
Hide file tree
Showing 4 changed files with 308 additions and 123 deletions.
21 changes: 21 additions & 0 deletions include/usearch/index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2083,6 +2083,27 @@ class index_gt {
node_t node = index_.node_with_id_(candidate.id);
return {member_cref_t{node.label(), node.vector_view(), candidate.id}, candidate.distance};
}
/**
 *  @brief  Merges the matches of this result into an externally stored,
 *          distance-sorted list of at most `max_count` best matches.
 *
 *  `labels` and `distances` are parallel arrays. Their first `old_count`
 *  entries must already be sorted by ascending distance, as the insertion
 *  point is located with a binary search. Both arrays must have room for
 *  `max_count` entries, and `old_count` is expected not to exceed `max_count`.
 *
 *  @param labels       Labels of the matches merged so far; updated in place.
 *  @param distances    Distances of the matches merged so far; updated in place.
 *  @param old_count    Number of valid entries in both arrays before the call.
 *  @param max_count    Capacity of both arrays, i.e. the most matches to keep.
 *  @return Number of valid entries in both arrays after the merge.
 */
inline std::size_t merge_into( //
    label_t* labels, distance_t* distances, //
    std::size_t old_count, std::size_t max_count) const noexcept {

    std::size_t merged_count = old_count;
    for (std::size_t i = 0; i != count; ++i) {
        match_t result = operator[](i);
        distance_t* merged_end = distances + merged_count;
        std::size_t offset =
            static_cast<std::size_t>(std::lower_bound(distances, merged_end, result.distance) - distances);

        // The list is full and the candidate is not closer than any kept match.
        if (offset == max_count)
            continue;

        // Shift the tail one slot to the right to open a gap at `offset`.
        // When the list is already full, the last (worst) entry is dropped instead of shifted.
        std::size_t count_worse = merged_count - offset - (max_count == merged_count);
        std::memmove(labels + offset + 1, labels + offset, count_worse * sizeof(label_t));
        std::memmove(distances + offset + 1, distances + offset, count_worse * sizeof(distance_t));

        // Place the candidate into the gap, not at the end, so the list stays sorted.
        labels[offset] = result.member.label;
        distances[offset] = result.distance;

        // Grow only while there is spare capacity; a full list stays at `max_count`.
        merged_count += merged_count != max_count;
    }
    return merged_count;
}
inline std::size_t dump_to(label_t* labels, distance_t* distances) const noexcept {
for (std::size_t i = 0; i != count; ++i) {
match_t result = operator[](i);
Expand Down
110 changes: 93 additions & 17 deletions python/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,26 @@ struct dense_index_py_t : public dense_index_t {
dense_index_py_t(native_t&& base) : native_t(std::move(base)) {}
};

/**
 *  @brief  A group of dense indexes ("shards") that share ownership with
 *          their Python-side handles and are queried together.
 */
struct dense_indexes_py_t {
    /// Shards in the order they were added.
    std::vector<std::shared_ptr<dense_index_py_t>> shards_;

    /// Appends a shard. The by-value argument is moved to avoid an extra reference-count update.
    void add(std::shared_ptr<dense_index_py_t> shard) { shards_.push_back(std::move(shard)); }

    /// Forwards to the first shard's `scalar_words()`, or returns zero if there are no shards.
    std::size_t scalar_words() const noexcept { return shards_.empty() ? 0 : shards_[0]->scalar_words(); }

    /// Builds limits from the combined size and the largest `std::size_t` value.
    /// NOTE(review): the meaning of each `index_limits_t` field is defined elsewhere - confirm the order.
    index_limits_t limits() const noexcept { return {size(), std::numeric_limits<std::size_t>::max()}; }

    /// Sum of `size()` over all shards. This is NOT the number of shards.
    std::size_t size() const noexcept {
        std::size_t result = 0;
        for (auto const& shard : shards_)
            result += shard->size();
        return result;
    }

    /// Ignores the requested limits and asks every shard to reserve `{its current size, 1}`.
    /// NOTE(review): presumably the second value is a thread count - verify against `index_limits_t`.
    void reserve(index_limits_t) {
        for (auto const& shard : shards_)
            shard->reserve({shard->size(), 1});
    }
};

using set_member_t = std::uint32_t;
using set_view_t = span_gt<set_member_t const>;
using sparse_index_t = index_gt<jaccard_gt<set_member_t>, label_t, id_t>;
Expand Down Expand Up @@ -194,8 +214,9 @@ static void add_typed_to_index( //
});
}

static void add_many_to_index( //
dense_index_py_t& index, py::buffer labels, py::buffer vectors, //
template <typename index_at>
static void add_many_to_index( //
index_at& index, py::buffer labels, py::buffer vectors, //
bool copy, std::size_t threads) {

py::buffer_info labels_info = labels.request();
Expand Down Expand Up @@ -267,6 +288,49 @@ static void search_typed( //
});
}

/**
 *  @brief  Searches every shard of a multi-index for every query vector and
 *          merges the per-shard matches into shared output arrays.
 *
 *  One task is scheduled per shard. Each task runs all queries against its
 *  shard and merges the matches into the row of the respective query, guarded
 *  by a per-query mutex, since several shards may update the same row.
 *
 *  @param indexes          Collection of shards to search.
 *  @param vectors_info     Buffer with the queries; rows are addressed via `strides[0]`.
 *  @param wanted           Maximum number of matches to keep per query.
 *  @param exact            Forwarded to the search configuration of every shard.
 *  @param threads          Number of threads for the executor; zero selects `hardware_concurrency()`.
 *  @param labels_py        2D output with labels; one row per query.
 *  @param distances_py     2D output with distances; one row per query.
 *  @param counts_py        1D output with the number of matches found per query.
 */
template <typename scalar_at>
static void search_typed( //
    dense_indexes_py_t const& indexes, py::buffer_info& vectors_info, //
    std::size_t wanted, bool exact, std::size_t threads, //
    py::array_t<label_t>& labels_py, py::array_t<distance_t>& distances_py, py::array_t<Py_ssize_t>& counts_py) {

    auto labels_py2d = labels_py.template mutable_unchecked<2>();
    auto distances_py2d = distances_py.template mutable_unchecked<2>();
    auto counts_py1d = counts_py.template mutable_unchecked<1>();

    std::size_t const vectors_count = static_cast<std::size_t>(vectors_info.shape[0]);
    char const* vectors_data = reinterpret_cast<char const*>(vectors_info.ptr);

    // Every row starts empty; shards then merge their matches in one by one.
    for (std::size_t vector_idx = 0; vector_idx != vectors_count; ++vector_idx)
        counts_py1d(vector_idx) = 0;

    if (!threads)
        threads = std::thread::hardware_concurrency();

    std::vector<std::mutex> vectors_mutexes(vectors_count);

    // Schedule one task per shard. `indexes.size()` would be wrong here: it is the total
    // number of stored vectors, while `task_idx` is used to address `shards_`.
    executor_default_t{threads}.execute_bulk(indexes.shards_.size(), [&](std::size_t, std::size_t task_idx) {
        dense_index_py_t const& index = *indexes.shards_[task_idx].get();

        search_config_t config;
        config.thread = 0;
        config.exact = exact;
        for (std::size_t vector_idx = 0; vector_idx != vectors_count; ++vector_idx) {
            scalar_at const* vector =
                reinterpret_cast<scalar_at const*>(vectors_data + vector_idx * vectors_info.strides[0]);
            dense_search_result_t result = index.search(vector, wanted, config);
            result.error.raise();
            {
                // Other shards may be merging into the same output row concurrently.
                std::unique_lock<std::mutex> lock(vectors_mutexes[vector_idx]);
                counts_py1d(vector_idx) = static_cast<Py_ssize_t>(result.merge_into( //
                    &labels_py2d(vector_idx, 0),                                     //
                    &distances_py2d(vector_idx, 0),                                  //
                    static_cast<std::size_t>(counts_py1d(vector_idx)),               //
                    wanted));
            }
            // NOTE(review): this may throw from inside an executor task - confirm the executor propagates it.
            if (PyErr_CheckSignals() != 0)
                throw py::error_already_set();
        }
    });
}

/**
* @param vectors Matrix of vectors to search for.
* @param wanted Number of matches per request.
Expand All @@ -276,8 +340,9 @@ static void search_typed( //
* 2. matrix of distances,
* 3. array with match counts.
*/
template <typename index_at>
static py::tuple search_many_in_index( //
dense_index_py_t& index, py::buffer vectors, std::size_t wanted, bool exact, std::size_t threads) {
index_at& index, py::buffer vectors, std::size_t wanted, bool exact, std::size_t threads) {

if (wanted == 0)
return py::tuple(3);
Expand Down Expand Up @@ -560,7 +625,7 @@ PYBIND11_MODULE(compiled, m) {
h.def_readonly("bytes_for_vectors", &file_head_result_t::bytes_for_vectors);
h.def_readonly("bytes_checksum", &file_head_result_t::bytes_checksum);

auto i = py::class_<dense_index_py_t>(m, "Index");
auto i = py::class_<dense_index_py_t, std::shared_ptr<dense_index_py_t>>(m, "Index");

i.def(py::init(&make_index), //
py::kw_only(), //
Expand All @@ -575,21 +640,21 @@ PYBIND11_MODULE(compiled, m) {
py::arg("tune") = false //
);

i.def( //
"add", &add_many_to_index, //
py::arg("labels"), //
py::arg("vectors"), //
py::kw_only(), //
py::arg("copy") = true, //
py::arg("threads") = 0 //
i.def( //
"add", &add_many_to_index<dense_index_py_t>, //
py::arg("labels"), //
py::arg("vectors"), //
py::kw_only(), //
py::arg("copy") = true, //
py::arg("threads") = 0 //
);

i.def( //
"search", &search_many_in_index, //
py::arg("query"), //
py::arg("count") = 10, //
py::arg("exact") = false, //
py::arg("threads") = 0 //
i.def( //
"search", &search_many_in_index<dense_index_py_t>, //
py::arg("query"), //
py::arg("count") = 10, //
py::arg("exact") = false, //
py::arg("threads") = 0 //
);

i.def(
Expand Down Expand Up @@ -677,6 +742,17 @@ PYBIND11_MODULE(compiled, m) {
i.def_property_readonly("levels_stats", &compute_stats<dense_index_py_t>);
i.def("level_stats", &compute_level_stats<dense_index_py_t>, py::arg("level"));

auto is = py::class_<dense_indexes_py_t>(m, "Indexes");
is.def(py::init());
is.def("add", &dense_indexes_py_t::add);
is.def( //
"search", &search_many_in_index<dense_indexes_py_t>, //
py::arg("query"), //
py::arg("count") = 10, //
py::arg("exact") = false, //
py::arg("threads") = 0 //
);

auto si = py::class_<sparse_index_py_t>(m, "SparseIndex");

si.def( //
Expand Down
15 changes: 15 additions & 0 deletions python/scripts/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from usearch.index import (
Index,
Indexes,
SparseIndex,
MetricKind,
ScalarKind,
Expand Down Expand Up @@ -247,6 +248,20 @@ def test_exact_recall(
assert man == woman, "Stable marriage failed"


def test_indexes():
    """Searching two single-entry indexes through `Indexes` yields two matches."""
    vector_ndim = 10
    first_shard = Index(ndim=vector_ndim)
    second_shard = Index(ndim=vector_ndim)

    # Three random vectors: one stored in each shard, the last one used as the query.
    samples = random_vectors(count=3, ndim=vector_ndim)
    for label, shard, sample in (
        (42, first_shard, samples[0]),
        (43, second_shard, samples[1]),
    ):
        shard.add(label, sample)

    combined = Indexes([first_shard, second_shard])
    found = combined.search(samples[2], 10)
    assert len(found) == 2


@pytest.mark.parametrize("bits", dimensions)
@pytest.mark.parametrize("metric", hash_metrics)
@pytest.mark.parametrize("connectivity", connectivity_options)
Expand Down
Loading

0 comments on commit c5b7ccd

Please sign in to comment.