Refactor: Keep only batch requests in CPython
ashvardanian committed Jul 24, 2023
1 parent 1f89e0a commit 44c0318

Showing 3 changed files with 25 additions and 109 deletions.
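
The net effect, before the diffs: the CPython binding now exposes only batch `add`/`search` entry points, and the pure-Python `Index` wrapper normalizes single-vector calls by reshaping 1-D input into a batch of one. A minimal usage sketch of that contract (hedged: the `ndim` constructor argument and the dtypes are assumptions, not shown in this diff):

import numpy as np
from usearch.index import Index, Matches, BatchMatches

index = Index(ndim=3)  # `ndim=3` is an assumed constructor argument

# A 1-D vector is reshaped to a (1, 3) batch before reaching the
# batch-only native kernels, so single-vector calls keep working:
index.add(42, np.array([0.1, 0.2, 0.3], dtype=np.float32))

one: Matches = index.search(np.array([0.1, 0.2, 0.3], dtype=np.float32), 10)
many: BatchMatches = index.search(np.random.rand(5, 3).astype(np.float32), 10)
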
python/lib.cpp (85 changes: 0 additions & 85 deletions)

@@ -172,34 +172,6 @@ scalar_kind_t numpy_string_to_kind(std::string const& name) {
     return scalar_kind_t::unknown_k;
 }
 
-static void add_one_to_index(dense_index_py_t& index, label_t label, py::buffer vector, bool copy, std::size_t) {
-
-    py::buffer_info vector_info = vector.request();
-    if (vector_info.ndim != 1)
-        throw std::invalid_argument("Expects a vector, not a higher-rank tensor!");
-
-    Py_ssize_t vector_dimensions = vector_info.shape[0];
-    char const* vector_data = reinterpret_cast<char const*>(vector_info.ptr);
-    if (vector_dimensions != static_cast<Py_ssize_t>(index.scalar_words()))
-        throw std::invalid_argument("The number of vector dimensions doesn't match!");
-
-    if (!index.reserve(ceil2(index.size() + 1)))
-        throw std::invalid_argument("Out of memory!");
-
-    add_config_t config;
-    config.store_vector = copy;
-
-    switch (numpy_string_to_kind(vector_info.format)) {
-    case scalar_kind_t::b1x8_k: index.add(label, (b1x8_t const*)(vector_data), config).error.raise(); break;
-    case scalar_kind_t::f8_k: index.add(label, (f8_bits_t const*)(vector_data), config).error.raise(); break;
-    case scalar_kind_t::f16_k: index.add(label, (f16_t const*)(vector_data), config).error.raise(); break;
-    case scalar_kind_t::f32_k: index.add(label, (f32_t const*)(vector_data), config).error.raise(); break;
-    case scalar_kind_t::f64_k: index.add(label, (f64_t const*)(vector_data), config).error.raise(); break;
-    case scalar_kind_t::unknown_k:
-        throw std::invalid_argument("Incompatible scalars in the vector: " + vector_info.format);
-    }
-}
-
 template <typename scalar_at>
 static void add_typed_to_index( //
     dense_index_py_t& index, //
@@ -263,52 +235,6 @@ static void add_many_to_index( //
     }
 }
 
-static py::tuple search_one_in_index(dense_index_py_t& index, py::buffer vector, std::size_t wanted, bool exact) {
-
-    py::buffer_info vector_info = vector.request();
-    Py_ssize_t vector_dimensions = vector_info.shape[0];
-    char const* vector_data = reinterpret_cast<char const*>(vector_info.ptr);
-    if (vector_dimensions != static_cast<Py_ssize_t>(index.scalar_words()))
-        throw std::invalid_argument("The number of vector dimensions doesn't match!");
-
-    constexpr Py_ssize_t vectors_count = 1;
-    py::array_t<label_t> labels_py({vectors_count, static_cast<Py_ssize_t>(wanted)});
-    py::array_t<distance_t> distances_py({vectors_count, static_cast<Py_ssize_t>(wanted)});
-    py::array_t<Py_ssize_t> counts_py(vectors_count);
-    std::size_t count{};
-    auto labels_py2d = labels_py.template mutable_unchecked<2>();
-    auto distances_py2d = distances_py.template mutable_unchecked<2>();
-    auto counts_py1d = counts_py.template mutable_unchecked<1>();
-
-    search_config_t config;
-    config.exact = exact;
-
-    auto raise_and_dump = [&](dense_search_result_t result) {
-        result.error.raise();
-        count = result.dump_to(&labels_py2d(0, 0), &distances_py2d(0, 0));
-    };
-
-    switch (numpy_string_to_kind(vector_info.format)) {
-    case scalar_kind_t::b1x8_k: raise_and_dump(index.search((b1x8_t const*)(vector_data), wanted, config)); break;
-    case scalar_kind_t::f8_k: raise_and_dump(index.search((f8_bits_t const*)(vector_data), wanted, config)); break;
-    case scalar_kind_t::f16_k: raise_and_dump(index.search((f16_t const*)(vector_data), wanted, config)); break;
-    case scalar_kind_t::f32_k: raise_and_dump(index.search((f32_t const*)(vector_data), wanted, config)); break;
-    case scalar_kind_t::f64_k: raise_and_dump(index.search((f64_t const*)(vector_data), wanted, config)); break;
-    case scalar_kind_t::unknown_k:
-        throw std::invalid_argument("Incompatible scalars in the query vector: " + vector_info.format);
-    }
-
-    labels_py.resize(py_shape_t{vectors_count, static_cast<Py_ssize_t>(count)});
-    distances_py.resize(py_shape_t{vectors_count, static_cast<Py_ssize_t>(count)});
-    counts_py1d[0] = static_cast<Py_ssize_t>(count);
-
-    py::tuple results(3);
-    results[0] = labels_py;
-    results[1] = distances_py;
-    results[2] = counts_py;
-    return results;
-}
-
 template <typename scalar_at>
 static void search_typed( //
     dense_index_py_t& index, py::buffer_info& vectors_info, //
@@ -360,8 +286,6 @@ static py::tuple search_many_in_index( //
         throw std::invalid_argument("Can't use that many threads!");
 
     py::buffer_info vectors_info = vectors.request();
-    if (vectors_info.ndim == 1)
-        return search_one_in_index(index, vectors, wanted, exact);
     if (vectors_info.ndim != 2)
         throw std::invalid_argument("Expects a matrix of vectors to add!");
 
@@ -660,15 +584,6 @@ PYBIND11_MODULE(compiled, m) {
         py::arg("threads") = 0 //
     );
 
-    i.def( //
-        "add", &add_one_to_index, //
-        py::arg("label"), //
-        py::arg("vector"), //
-        py::kw_only(), //
-        py::arg("copy") = true, //
-        py::arg("threads") = 0 //
-    );
-
     i.def( //
         "search", &search_many_in_index, //
         py::arg("query"), //
python/scripts/test.py (15 changes: 9 additions & 6 deletions)

@@ -11,7 +11,9 @@
     SparseIndex,
     MetricKind,
     ScalarKind,
+    Match,
     Matches,
+    BatchMatches,
 )
 from usearch.index import (
     DEFAULT_CONNECTIVITY,
@@ -177,9 +179,9 @@ def test_index_batch(
     assert len(index) == batch_size
     assert np.allclose(index.get_vectors(labels).astype(numpy_type), vectors, atol=0.1)
 
-    matches: Matches = index.search(vectors, 10, threads=2)
+    matches: BatchMatches = index.search(vectors, 10, threads=2)
     assert matches.labels.shape[0] == matches.distances.shape[0]
-    assert matches.counts.shape[0] == batch_size
+    assert len(matches) == batch_size
     assert np.all(np.sort(index.labels) == np.sort(labels))
 
     if batch_size > 1:
@@ -232,10 +234,11 @@ def test_exact_recall(
         assert found_labels[0] == i
 
     # Search the whole batch
-    matches: Matches = index.search(vectors, 10, exact=True)
-    found_labels = matches.labels
-    for i in range(batch_size):
-        assert found_labels[i, 0] == i
+    if batch_size > 1:
+        matches: BatchMatches = index.search(vectors, 10, exact=True)
+        found_labels = matches.labels
+        for i in range(batch_size):
+            assert found_labels[i, 0] == i
 
     # Match entries aginst themselves
     index_copy: Index = index.copy()
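
The updated assertions pin down the new result-type contract: a batch query yields a `BatchMatches` whose length equals the number of queries, and indexing into it gives the per-query view, which is what `distil_batch()` in index.py below relies on. A short sketch of that contract, assuming `index` and `vectors` as in the fixtures above:

matches: BatchMatches = index.search(vectors, 10, threads=2)
assert len(matches) == vectors.shape[0]  # one result set per query
assert matches.labels.shape[0] == matches.distances.shape[0]
first: Matches = matches[0]  # per-query view, as used by distil_batch()
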
python/usearch/index.py (34 changes: 16 additions & 18 deletions)

@@ -102,6 +102,7 @@ def _normalize_metric(metric):
 
     return metric
 
+
 @dataclass
 class Match:
     label: int
@@ -353,27 +354,22 @@ def add(
         """
        assert isinstance(vectors, np.ndarray), "Expects a NumPy array"
         assert vectors.ndim == 1 or vectors.ndim == 2, "Expects a matrix or vector"
-        count_vectors = vectors.shape[0] if vectors.ndim == 2 else 1
-        is_multiple = count_vectors > 1
-        if not is_multiple:
-            vectors = vectors.flatten()
+        if vectors.ndim == 1:
+            vectors = vectors.reshape(1, len(vectors))
 
-        # Validate of generate teh labels
+        # Validate or generate the labels
+        count_vectors = vectors.shape[0]
         generate_labels = labels is None
         if generate_labels:
             start_id = len(self._compiled)
-            if is_multiple:
-                labels = np.arange(start_id, start_id + count_vectors, dtype=Label)
-            else:
-                labels = start_id
+            labels = np.arange(start_id, start_id + count_vectors, dtype=Label)
         else:
-            if isinstance(labels, Iterable):
-                if not is_multiple:
-                    labels = int(labels[0])
-                else:
-                    labels = np.array(labels).astype(Label)
-            count_labels = len(labels) if isinstance(labels, Iterable) else 1
-            assert count_labels == count_vectors
+            if not isinstance(labels, Iterable):
+                assert count_vectors == 1, "Each vector must have a label"
+                labels = [labels]
+            labels = np.array(labels).astype(Label)
 
+        assert len(labels) == count_vectors
+
         # If logging is requested, and batch size is undefined, set it to grow 1% at a time:
         if log and batch_size == 0:
@@ -440,10 +436,12 @@ def search(
 
         assert isinstance(vectors, np.ndarray), "Expects a NumPy array"
         assert vectors.ndim == 1 or vectors.ndim == 2, "Expects a matrix or vector"
-        count_vectors = vectors.shape[0] if vectors.ndim == 2 else 1
+        if vectors.ndim == 1:
+            vectors = vectors.reshape(1, len(vectors))
+        count_vectors = vectors.shape[0]
 
         def distil_batch(batch_matches: BatchMatches) -> Union[BatchMatches, Matches]:
-            return batch_matches if vectors.ndim == 2 else batch_matches[0]
+            return batch_matches[0] if count_vectors == 1 else batch_matches
 
         if log and batch_size == 0:
             batch_size = int(math.ceil(count_vectors / 100))
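
Taken together, the reworked `add()` treats every call as a batch: omitted labels are generated sequentially from the current index size, a scalar label is accepted only alongside a single vector, and the label count must match the vector count. A hedged sketch (assuming labels are passed positionally and `None` requests auto-generation, as the `labels is None` branch suggests):

index.add(None, np.random.rand(4, 3).astype(np.float32))  # labels become arange(len(index), len(index) + 4)
index.add(7, np.random.rand(3).astype(np.float32))        # scalar label wraps into a one-element batch
index.add([8, 9], np.random.rand(2, 3).astype(np.float32))  # explicit labels; length must equal batch size
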
