Skip to content

Commit

Permalink
Add: filtered_search in Rust
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Apr 7, 2024
1 parent 5bb38aa commit e1b24e1
Show file tree
Hide file tree
Showing 3 changed files with 606 additions and 73 deletions.
85 changes: 57 additions & 28 deletions rust/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,40 @@ using search_result_t = typename index_t::search_result_t;
using labeling_result_t = typename index_t::labeling_result_t;
using vector_key_t = typename index_dense_t::vector_key_t;

template <typename scalar_at> Matches search_(index_dense_t& index, scalar_at const* vec, size_t count) {
metric_kind_t rust_to_cpp_metric(MetricKind value) {
switch (value) {
case MetricKind::IP: return metric_kind_t::ip_k;
case MetricKind::L2sq: return metric_kind_t::l2sq_k;
case MetricKind::Cos: return metric_kind_t::cos_k;
case MetricKind::Pearson: return metric_kind_t::pearson_k;
case MetricKind::Haversine: return metric_kind_t::haversine_k;
case MetricKind::Divergence: return metric_kind_t::divergence_k;
case MetricKind::Hamming: return metric_kind_t::hamming_k;
case MetricKind::Tanimoto: return metric_kind_t::tanimoto_k;
case MetricKind::Sorensen: return metric_kind_t::sorensen_k;
default: return metric_kind_t::unknown_k;
}
}

scalar_kind_t rust_to_cpp_scalar(ScalarKind value) {
switch (value) {
case ScalarKind::I8: return scalar_kind_t::i8_k;
case ScalarKind::F16: return scalar_kind_t::f16_k;
case ScalarKind::F32: return scalar_kind_t::f32_k;
case ScalarKind::F64: return scalar_kind_t::f64_k;
case ScalarKind::B1: return scalar_kind_t::b1x8_k;
default: return scalar_kind_t::unknown_k;
}
}

template <typename scalar_at, typename predicate_at = dummy_predicate_t>
Matches search_(index_dense_t& index, scalar_at const* vec, size_t count, predicate_at&& predicate = predicate_at{}) {
Matches matches;
matches.keys.reserve(count);
matches.distances.reserve(count);
for (size_t i = 0; i != count; ++i)
matches.keys.push_back(0), matches.distances.push_back(0);
search_result_t result = index.search(vec, count);
search_result_t result = index.filtered_search(vec, count, std::forward<predicate_at>(predicate));
result.error.raise();
count = result.dump_to(matches.keys.data(), matches.distances.data());
matches.keys.truncate(count);
Expand All @@ -26,6 +53,14 @@ template <typename scalar_at> Matches search_(index_dense_t& index, scalar_at co

NativeIndex::NativeIndex(std::unique_ptr<index_t> index) : index_(std::move(index)) {}

auto make_predicate(uptr_t metric, uptr_t metric_state) {
return [=](vector_key_t key) {
auto func = reinterpret_cast<bool (*)(uptr_t, vector_key_t)>(metric);
auto state = reinterpret_cast<uptr_t>(metric_state);
return func(key, state);
};
}

// clang-format off
void NativeIndex::add_i8(vector_key_t key, rust::Slice<int8_t const> vec) const { index_->add(key, vec.data()).error.raise(); }
void NativeIndex::add_f16(vector_key_t key, rust::Slice<uint16_t const> vec) const { index_->add(key, (f16_t const*)vec.data()).error.raise(); }
Expand All @@ -37,6 +72,11 @@ Matches NativeIndex::search_f16(rust::Slice<uint16_t const> vec, size_t count) c
Matches NativeIndex::search_f32(rust::Slice<float const> vec, size_t count) const { return search_(*index_, vec.data(), count); }
Matches NativeIndex::search_f64(rust::Slice<double const> vec, size_t count) const { return search_(*index_, vec.data(), count); }

Matches NativeIndex::filtered_search_i8(rust::Slice<int8_t const> vec, size_t count, uptr_t metric, uptr_t metric_state) const { return search_(*index_, vec.data(), count, make_predicate(metric, metric_state)); }
Matches NativeIndex::filtered_search_f16(rust::Slice<uint16_t const> vec, size_t count, uptr_t metric, uptr_t metric_state) const { return search_(*index_, (f16_t const*)vec.data(), count, make_predicate(metric, metric_state)); }
Matches NativeIndex::filtered_search_f32(rust::Slice<float const> vec, size_t count, uptr_t metric, uptr_t metric_state) const { return search_(*index_, vec.data(), count, make_predicate(metric, metric_state)); }
Matches NativeIndex::filtered_search_f64(rust::Slice<double const> vec, size_t count, uptr_t metric, uptr_t metric_state) const { return search_(*index_, vec.data(), count, make_predicate(metric, metric_state)); }

size_t NativeIndex::get_i8(vector_key_t key, rust::Slice<int8_t> vec) const { if (vec.size() % dimensions()) throw std::invalid_argument("Vector length must be a multiple of index dimensionality"); return index_->get(key, vec.data(), vec.size() / dimensions()); }
size_t NativeIndex::get_f16(vector_key_t key, rust::Slice<uint16_t> vec) const { if (vec.size() % dimensions()) throw std::invalid_argument("Vector length must be a multiple of index dimensionality"); return index_->get(key, (f16_t*)vec.data(), vec.size() / dimensions()); }
size_t NativeIndex::get_f32(vector_key_t key, rust::Slice<float> vec) const { if (vec.size() % dimensions()) throw std::invalid_argument("Vector length must be a multiple of index dimensionality"); return index_->get(key, vec.data(), vec.size() / dimensions()); }
Expand All @@ -48,6 +88,21 @@ size_t NativeIndex::expansion_search() const { return index_->expansion_search()
void NativeIndex::change_expansion_add(size_t n) const { index_->change_expansion_add(n); }
void NativeIndex::change_expansion_search(size_t n) const { index_->change_expansion_search(n); }

void NativeIndex::change_metric(uptr_t metric, uptr_t state) const {
index_->change_metric(metric_punned_t::statefull( //
reinterpret_cast<std::uintptr_t>(metric), //
reinterpret_cast<std::uintptr_t>(state), //
index_->metric().metric_kind(), //
index_->scalar_kind()));
}

void NativeIndex::change_metric_kind(MetricKind metric) const {
index_->change_metric(metric_punned_t::builtin( //
index_->dimensions(), //
rust_to_cpp_metric(metric), //
index_->scalar_kind()));
}

size_t NativeIndex::remove(vector_key_t key) const {
labeling_result_t result = index_->remove(key);
result.error.raise();
Expand Down Expand Up @@ -101,32 +156,6 @@ std::unique_ptr<NativeIndex> wrap(index_t&& index) {
return result;
}

metric_kind_t rust_to_cpp_metric(MetricKind value) {
switch (value) {
case MetricKind::IP: return metric_kind_t::ip_k;
case MetricKind::L2sq: return metric_kind_t::l2sq_k;
case MetricKind::Cos: return metric_kind_t::cos_k;
case MetricKind::Pearson: return metric_kind_t::pearson_k;
case MetricKind::Haversine: return metric_kind_t::haversine_k;
case MetricKind::Divergence: return metric_kind_t::divergence_k;
case MetricKind::Hamming: return metric_kind_t::hamming_k;
case MetricKind::Tanimoto: return metric_kind_t::tanimoto_k;
case MetricKind::Sorensen: return metric_kind_t::sorensen_k;
default: return metric_kind_t::unknown_k;
}
}

scalar_kind_t rust_to_cpp_scalar(ScalarKind value) {
switch (value) {
case ScalarKind::I8: return scalar_kind_t::i8_k;
case ScalarKind::F16: return scalar_kind_t::f16_k;
case ScalarKind::F32: return scalar_kind_t::f32_k;
case ScalarKind::F64: return scalar_kind_t::f64_k;
case ScalarKind::B1: return scalar_kind_t::b1x8_k;
default: return scalar_kind_t::unknown_k;
}
}

std::unique_ptr<NativeIndex> new_native_index(IndexOptions const& options) {
metric_kind_t metric_kind = rust_to_cpp_metric(options.metric);
scalar_kind_t scalar_kind = rust_to_cpp_scalar(options.quantization);
Expand Down
33 changes: 25 additions & 8 deletions rust/lib.hpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
#pragma once
#include "rust/cxx.h"

#include <memory> // `std::shared_ptr`

#include <usearch/index_dense.hpp>

// We don't have to forward decalre all of those:
struct Matches;
struct IndexOptions;
enum class MetricKind;
enum class ScalarKind;

#include <usearch/index_dense.hpp> // `unum::usearch::index_dense_t`

#include <memory> // `std::unique_ptr`

using uptr_t = size_t;

class NativeIndex {
public:
Expand All @@ -26,10 +31,19 @@ class NativeIndex {
void add_f32(vector_key_t key, rust::Slice<float const> vector) const;
void add_f64(vector_key_t key, rust::Slice<double const> vector) const;

Matches search_i8(rust::Slice<int8_t const> vector, size_t count) const;
Matches search_f16(rust::Slice<uint16_t const> vector, size_t count) const;
Matches search_f32(rust::Slice<float const> vector, size_t count) const;
Matches search_f64(rust::Slice<double const> vector, size_t count) const;
Matches search_i8(rust::Slice<int8_t const> query, size_t count) const;
Matches search_f16(rust::Slice<uint16_t const> query, size_t count) const;
Matches search_f32(rust::Slice<float const> query, size_t count) const;
Matches search_f64(rust::Slice<double const> query, size_t count) const;

Matches filtered_search_i8(rust::Slice<int8_t const> query, size_t count, //
uptr_t filter_function, uptr_t filter_state) const;
Matches filtered_search_f16(rust::Slice<uint16_t const> query, size_t count, //
uptr_t filter_function, uptr_t filter_state) const;
Matches filtered_search_f32(rust::Slice<float const> query, size_t count, //
uptr_t filter_function, uptr_t filter_state) const;
Matches filtered_search_f64(rust::Slice<double const> query, size_t count, //
uptr_t filter_function, uptr_t filter_state) const;

size_t get_i8(vector_key_t key, rust::Slice<int8_t> vector) const;
size_t get_f16(vector_key_t key, rust::Slice<uint16_t> vector) const;
Expand All @@ -41,6 +55,9 @@ class NativeIndex {
void change_expansion_add(size_t n) const;
void change_expansion_search(size_t n) const;

void change_metric(uptr_t metric, uptr_t state) const;
void change_metric_kind(MetricKind metric) const;

size_t count(vector_key_t key) const;
size_t remove(vector_key_t key) const;
size_t rename(vector_key_t from, vector_key_t to) const;
Expand Down
Loading

0 comments on commit e1b24e1

Please sign in to comment.