Skip to content

Commit

Permalink
Improve: Exposing search stats to users
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Aug 1, 2023
1 parent fa70779 commit 2779ffc
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 26 deletions.
27 changes: 15 additions & 12 deletions include/usearch/index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1854,7 +1854,7 @@ class index_gt {
struct add_result_t {
error_t error{};
std::size_t new_size{};
std::size_t cycles{};
std::size_t lookups{};
std::size_t measurements{};
std::size_t slot{};

Expand All @@ -1881,8 +1881,11 @@ class index_gt {
: nodes_(index.nodes_), top_(&top) {}

public:
/** @brief Number of search results found. */
std::size_t count{};
std::size_t cycles{};
/** @brief Number of graph nodes traversed. */
std::size_t lookups{};
/** @brief Number of times the distances were computed. */
std::size_t measurements{};
error_t error{};

Expand Down Expand Up @@ -2021,7 +2024,7 @@ class index_gt {

// Pull stats
result.measurements = context.measurements_count;
result.cycles = context.iteration_cycles;
result.lookups = context.iteration_cycles;

connect_node_across_levels_( //
new_slot, value, metric, //
Expand All @@ -2030,7 +2033,7 @@ class index_gt {

// Normalize stats
result.measurements = context.measurements_count - result.measurements;
result.cycles = context.iteration_cycles - result.cycles;
result.lookups = context.iteration_cycles - result.lookups;

// Updating the entry point if needed
if (target_level > max_level_copy) {
Expand Down Expand Up @@ -2100,7 +2103,7 @@ class index_gt {

// Pull stats
result.measurements = context.measurements_count;
result.cycles = context.iteration_cycles;
result.lookups = context.iteration_cycles;

connect_node_across_levels_( //
old_slot, value, metric, //
Expand All @@ -2110,7 +2113,7 @@ class index_gt {

// Normalize stats
result.measurements = context.measurements_count - result.measurements;
result.cycles = context.iteration_cycles - result.cycles;
result.lookups = context.iteration_cycles - result.lookups;
result.slot = old_slot;

callback(at(old_slot));
Expand Down Expand Up @@ -2139,7 +2142,7 @@ class index_gt {

// Go down the level, tracking only the closest match
result.measurements = context.measurements_count;
result.cycles = context.iteration_cycles;
result.lookups = context.iteration_cycles;

if (config.exact) {
if (!top.reserve(wanted))
Expand All @@ -2165,7 +2168,7 @@ class index_gt {

// Normalize stats
result.measurements = context.measurements_count - result.measurements;
result.cycles = context.iteration_cycles - result.cycles;
result.lookups = context.iteration_cycles - result.lookups;
result.count = top.size();
return result;
}
Expand Down Expand Up @@ -2835,7 +2838,7 @@ struct join_result_t {
error_t error{};
std::size_t intersection_size{};
std::size_t engagements{};
std::size_t cycles{};
std::size_t lookups{};
std::size_t measurements{};

explicit operator bool() const noexcept { return !error; }
Expand Down Expand Up @@ -2946,7 +2949,7 @@ static join_result_t join( //
std::atomic<std::size_t> rounds{0};
std::atomic<std::size_t> engagements{0};
std::atomic<std::size_t> measurements{0};
std::atomic<std::size_t> cycles{0};
std::atomic<std::size_t> lookups{0};

// Concurrently process all the men
executor.execute_bulk([&](std::size_t thread_idx) {
Expand Down Expand Up @@ -2979,7 +2982,7 @@ static join_result_t join( //
// Find the closest woman, to whom this man hasn't proposed yet.
++free_man_proposals;
auto candidates = women.search(men_values[free_man_slot], free_man_proposals, women_metric, search_config);
cycles += candidates.cycles;
lookups += candidates.lookups;
measurements += candidates.measurements;
if (!candidates) {
// TODO:
Expand Down Expand Up @@ -3042,7 +3045,7 @@ static join_result_t join( //
result.engagements = engagements;
result.intersection_size = intersection_size;
result.measurements = measurements;
result.cycles = cycles;
result.lookups = lookups;
return result;
}

Expand Down
42 changes: 28 additions & 14 deletions python/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ template <typename scalar_at>
static void search_typed( //
dense_index_py_t& index, py::buffer_info& vectors_info, //
std::size_t wanted, bool exact, std::size_t threads, //
py::array_t<key_t>& keys_py, py::array_t<distance_t>& distances_py, py::array_t<Py_ssize_t>& counts_py) {
py::array_t<key_t>& keys_py, py::array_t<distance_t>& distances_py, py::array_t<Py_ssize_t>& counts_py,
std::atomic<std::size_t>& stats_lookups, std::atomic<std::size_t>& stats_measurements) {

auto keys_py2d = keys_py.template mutable_unchecked<2>();
auto distances_py2d = distances_py.template mutable_unchecked<2>();
Expand All @@ -267,6 +268,9 @@ static void search_typed( //
result.error.raise();
counts_py1d(task_idx) =
static_cast<Py_ssize_t>(result.dump_to(&keys_py2d(task_idx, 0), &distances_py2d(task_idx, 0)));

stats_lookups += result.lookups;
stats_measurements += result.measurements;
if (PyErr_CheckSignals() != 0)
throw py::error_already_set();
});
Expand All @@ -276,7 +280,8 @@ template <typename scalar_at>
static void search_typed( //
dense_indexes_py_t& indexes, py::buffer_info& vectors_info, //
std::size_t wanted, bool exact, std::size_t threads, //
py::array_t<key_t>& keys_py, py::array_t<distance_t>& distances_py, py::array_t<Py_ssize_t>& counts_py) {
py::array_t<key_t>& keys_py, py::array_t<distance_t>& distances_py, py::array_t<Py_ssize_t>& counts_py,
std::atomic<std::size_t>& stats_lookups, std::atomic<std::size_t>& stats_measurements) {

auto keys_py2d = keys_py.template mutable_unchecked<2>();
auto distances_py2d = distances_py.template mutable_unchecked<2>();
Expand Down Expand Up @@ -317,6 +322,9 @@ static void search_typed( //
static_cast<std::size_t>(counts_py1d(vector_idx)), //
wanted));
}

stats_lookups += result.lookups;
stats_measurements += result.measurements;
if (PyErr_CheckSignals() != 0)
throw py::error_already_set();
}
Expand Down Expand Up @@ -351,23 +359,29 @@ static py::tuple search_many_in_index( //
if (vectors_dimensions != static_cast<Py_ssize_t>(index.scalar_words()))
throw std::invalid_argument("The number of vector dimensions doesn't match!");

py::array_t<key_t> ls({vectors_count, static_cast<Py_ssize_t>(wanted)});
py::array_t<distance_t> ds({vectors_count, static_cast<Py_ssize_t>(wanted)});
py::array_t<Py_ssize_t> cs(vectors_count);
py::array_t<key_t> keys_py({vectors_count, static_cast<Py_ssize_t>(wanted)});
py::array_t<distance_t> distances_py({vectors_count, static_cast<Py_ssize_t>(wanted)});
py::array_t<Py_ssize_t> counts_py(vectors_count);
std::atomic<std::size_t> stats_lookups(0);
std::atomic<std::size_t> stats_measurements(0);

// clang-format off
switch (numpy_string_to_kind(vectors_info.format)) {
case scalar_kind_t::b1x8_k: search_typed<b1x8_t>(index, vectors_info, wanted, exact, threads, ls, ds, cs); break;
case scalar_kind_t::f8_k: search_typed<f8_bits_t>(index, vectors_info, wanted, exact, threads, ls, ds, cs); break;
case scalar_kind_t::f16_k: search_typed<f16_t>(index, vectors_info, wanted, exact, threads, ls, ds, cs); break;
case scalar_kind_t::f32_k: search_typed<f32_t>(index, vectors_info, wanted, exact, threads, ls, ds, cs); break;
case scalar_kind_t::f64_k: search_typed<f64_t>(index, vectors_info, wanted, exact, threads, ls, ds, cs); break;
case scalar_kind_t::b1x8_k: search_typed<b1x8_t>(index, vectors_info, wanted, exact, threads, keys_py, distances_py, counts_py, stats_lookups, stats_measurements); break;
case scalar_kind_t::f8_k: search_typed<f8_bits_t>(index, vectors_info, wanted, exact, threads, keys_py, distances_py, counts_py, stats_lookups, stats_measurements); break;
case scalar_kind_t::f16_k: search_typed<f16_t>(index, vectors_info, wanted, exact, threads, keys_py, distances_py, counts_py, stats_lookups, stats_measurements); break;
case scalar_kind_t::f32_k: search_typed<f32_t>(index, vectors_info, wanted, exact, threads, keys_py, distances_py, counts_py, stats_lookups, stats_measurements); break;
case scalar_kind_t::f64_k: search_typed<f64_t>(index, vectors_info, wanted, exact, threads, keys_py, distances_py, counts_py, stats_lookups, stats_measurements); break;
default: throw std::invalid_argument("Incompatible scalars in the query matrix: " + vectors_info.format);
}
// clang-format on

py::tuple results(3);
results[0] = ls;
results[1] = ds;
results[2] = cs;
py::tuple results(5);
results[0] = keys_py;
results[1] = distances_py;
results[2] = counts_py;
results[3] = stats_lookups.load();
results[4] = stats_measurements.load();
return results;
}

Expand Down
8 changes: 8 additions & 0 deletions python/usearch/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ class Matches:
keys: np.ndarray
distances: np.ndarray

lookups: int
measurements: int

def __len__(self) -> int:
return len(self.keys)

Expand Down Expand Up @@ -281,6 +284,9 @@ class BatchMatches:
distances: np.ndarray
counts: np.ndarray

lookups: int
measurements: int

def __len__(self) -> int:
return len(self.counts)

Expand All @@ -289,6 +295,8 @@ def __getitem__(self, index: int) -> Matches:
return Matches(
keys=self.keys[index, : self.counts[index]],
distances=self.distances[index, : self.counts[index]],
lookups=self.lookups // len(self),
measurements=self.measurements // len(self),
)
else:
raise IndexError(f"`index` must be an integer under {len(self)}")
Expand Down

0 comments on commit 2779ffc

Please sign in to comment.