From e1b24e1950f25119471f037c90a5685e544897e0 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 7 Apr 2024 23:06:55 +0000 Subject: [PATCH] Add: `filtered_search` in Rust --- rust/lib.cpp | 85 +++++--- rust/lib.hpp | 33 ++- rust/lib.rs | 561 +++++++++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 606 insertions(+), 73 deletions(-) diff --git a/rust/lib.cpp b/rust/lib.cpp index f43477f7..630390c3 100644 --- a/rust/lib.cpp +++ b/rust/lib.cpp @@ -10,13 +10,40 @@ using search_result_t = typename index_t::search_result_t; using labeling_result_t = typename index_t::labeling_result_t; using vector_key_t = typename index_dense_t::vector_key_t; -template Matches search_(index_dense_t& index, scalar_at const* vec, size_t count) { +metric_kind_t rust_to_cpp_metric(MetricKind value) { + switch (value) { + case MetricKind::IP: return metric_kind_t::ip_k; + case MetricKind::L2sq: return metric_kind_t::l2sq_k; + case MetricKind::Cos: return metric_kind_t::cos_k; + case MetricKind::Pearson: return metric_kind_t::pearson_k; + case MetricKind::Haversine: return metric_kind_t::haversine_k; + case MetricKind::Divergence: return metric_kind_t::divergence_k; + case MetricKind::Hamming: return metric_kind_t::hamming_k; + case MetricKind::Tanimoto: return metric_kind_t::tanimoto_k; + case MetricKind::Sorensen: return metric_kind_t::sorensen_k; + default: return metric_kind_t::unknown_k; + } +} + +scalar_kind_t rust_to_cpp_scalar(ScalarKind value) { + switch (value) { + case ScalarKind::I8: return scalar_kind_t::i8_k; + case ScalarKind::F16: return scalar_kind_t::f16_k; + case ScalarKind::F32: return scalar_kind_t::f32_k; + case ScalarKind::F64: return scalar_kind_t::f64_k; + case ScalarKind::B1: return scalar_kind_t::b1x8_k; + default: return scalar_kind_t::unknown_k; + } +} + +template +Matches search_(index_dense_t& index, scalar_at const* vec, size_t count, predicate_at&& predicate = predicate_at{}) { Matches matches; matches.keys.reserve(count); matches.distances.reserve(count); for (size_t i = 0; i != count; ++i) matches.keys.push_back(0), matches.distances.push_back(0); - search_result_t result = index.search(vec, count); + search_result_t result = index.filtered_search(vec, count, std::forward(predicate)); result.error.raise(); count = result.dump_to(matches.keys.data(), matches.distances.data()); matches.keys.truncate(count); @@ -26,6 +53,14 @@ template Matches search_(index_dense_t& index, scalar_at co NativeIndex::NativeIndex(std::unique_ptr index) : index_(std::move(index)) {} +auto make_predicate(uptr_t metric, uptr_t metric_state) { + return [=](vector_key_t key) { + auto func = reinterpret_cast(metric); + auto state = reinterpret_cast(metric_state); + return func(key, state); + }; +} + // clang-format off void NativeIndex::add_i8(vector_key_t key, rust::Slice vec) const { index_->add(key, vec.data()).error.raise(); } void NativeIndex::add_f16(vector_key_t key, rust::Slice vec) const { index_->add(key, (f16_t const*)vec.data()).error.raise(); } @@ -37,6 +72,11 @@ Matches NativeIndex::search_f16(rust::Slice vec, size_t count) c Matches NativeIndex::search_f32(rust::Slice vec, size_t count) const { return search_(*index_, vec.data(), count); } Matches NativeIndex::search_f64(rust::Slice vec, size_t count) const { return search_(*index_, vec.data(), count); } +Matches NativeIndex::filtered_search_i8(rust::Slice vec, size_t count, uptr_t metric, uptr_t metric_state) const { return search_(*index_, vec.data(), count, make_predicate(metric, metric_state)); } +Matches NativeIndex::filtered_search_f16(rust::Slice vec, size_t count, uptr_t metric, uptr_t metric_state) const { return search_(*index_, (f16_t const*)vec.data(), count, make_predicate(metric, metric_state)); } +Matches NativeIndex::filtered_search_f32(rust::Slice vec, size_t count, uptr_t metric, uptr_t metric_state) const { return search_(*index_, vec.data(), count, make_predicate(metric, metric_state)); } +Matches NativeIndex::filtered_search_f64(rust::Slice vec, size_t count, uptr_t metric, uptr_t metric_state) const { return search_(*index_, vec.data(), count, make_predicate(metric, metric_state)); } + size_t NativeIndex::get_i8(vector_key_t key, rust::Slice vec) const { if (vec.size() % dimensions()) throw std::invalid_argument("Vector length must be a multiple of index dimensionality"); return index_->get(key, vec.data(), vec.size() / dimensions()); } size_t NativeIndex::get_f16(vector_key_t key, rust::Slice vec) const { if (vec.size() % dimensions()) throw std::invalid_argument("Vector length must be a multiple of index dimensionality"); return index_->get(key, (f16_t*)vec.data(), vec.size() / dimensions()); } size_t NativeIndex::get_f32(vector_key_t key, rust::Slice vec) const { if (vec.size() % dimensions()) throw std::invalid_argument("Vector length must be a multiple of index dimensionality"); return index_->get(key, vec.data(), vec.size() / dimensions()); } @@ -48,6 +88,21 @@ size_t NativeIndex::expansion_search() const { return index_->expansion_search() void NativeIndex::change_expansion_add(size_t n) const { index_->change_expansion_add(n); } void NativeIndex::change_expansion_search(size_t n) const { index_->change_expansion_search(n); } +void NativeIndex::change_metric(uptr_t metric, uptr_t state) const { + index_->change_metric(metric_punned_t::statefull( // + reinterpret_cast(metric), // + reinterpret_cast(state), // + index_->metric().metric_kind(), // + index_->scalar_kind())); +} + +void NativeIndex::change_metric_kind(MetricKind metric) const { + index_->change_metric(metric_punned_t::builtin( // + index_->dimensions(), // + rust_to_cpp_metric(metric), // + index_->scalar_kind())); +} + size_t NativeIndex::remove(vector_key_t key) const { labeling_result_t result = index_->remove(key); result.error.raise(); @@ -101,32 +156,6 @@ std::unique_ptr wrap(index_t&& index) { return result; } -metric_kind_t rust_to_cpp_metric(MetricKind value) { - switch (value) { - case MetricKind::IP: return metric_kind_t::ip_k; - case MetricKind::L2sq: return metric_kind_t::l2sq_k; - case MetricKind::Cos: return metric_kind_t::cos_k; - case MetricKind::Pearson: return metric_kind_t::pearson_k; - case MetricKind::Haversine: return metric_kind_t::haversine_k; - case MetricKind::Divergence: return metric_kind_t::divergence_k; - case MetricKind::Hamming: return metric_kind_t::hamming_k; - case MetricKind::Tanimoto: return metric_kind_t::tanimoto_k; - case MetricKind::Sorensen: return metric_kind_t::sorensen_k; - default: return metric_kind_t::unknown_k; - } -} - -scalar_kind_t rust_to_cpp_scalar(ScalarKind value) { - switch (value) { - case ScalarKind::I8: return scalar_kind_t::i8_k; - case ScalarKind::F16: return scalar_kind_t::f16_k; - case ScalarKind::F32: return scalar_kind_t::f32_k; - case ScalarKind::F64: return scalar_kind_t::f64_k; - case ScalarKind::B1: return scalar_kind_t::b1x8_k; - default: return scalar_kind_t::unknown_k; - } -} - std::unique_ptr new_native_index(IndexOptions const& options) { metric_kind_t metric_kind = rust_to_cpp_metric(options.metric); scalar_kind_t scalar_kind = rust_to_cpp_scalar(options.quantization); diff --git a/rust/lib.hpp b/rust/lib.hpp index eca7d06c..304cb515 100644 --- a/rust/lib.hpp +++ b/rust/lib.hpp @@ -1,12 +1,17 @@ #pragma once #include "rust/cxx.h" -#include // `std::shared_ptr` - -#include - +// We don't have to forward decalre all of those: struct Matches; struct IndexOptions; +enum class MetricKind; +enum class ScalarKind; + +#include // `unum::usearch::index_dense_t` + +#include // `std::unique_ptr` + +using uptr_t = size_t; class NativeIndex { public: @@ -26,10 +31,19 @@ class NativeIndex { void add_f32(vector_key_t key, rust::Slice vector) const; void add_f64(vector_key_t key, rust::Slice vector) const; - Matches search_i8(rust::Slice vector, size_t count) const; - Matches search_f16(rust::Slice vector, size_t count) const; - Matches search_f32(rust::Slice vector, size_t count) const; - Matches search_f64(rust::Slice vector, size_t count) const; + Matches search_i8(rust::Slice query, size_t count) const; + Matches search_f16(rust::Slice query, size_t count) const; + Matches search_f32(rust::Slice query, size_t count) const; + Matches search_f64(rust::Slice query, size_t count) const; + + Matches filtered_search_i8(rust::Slice query, size_t count, // + uptr_t filter_function, uptr_t filter_state) const; + Matches filtered_search_f16(rust::Slice query, size_t count, // + uptr_t filter_function, uptr_t filter_state) const; + Matches filtered_search_f32(rust::Slice query, size_t count, // + uptr_t filter_function, uptr_t filter_state) const; + Matches filtered_search_f64(rust::Slice query, size_t count, // + uptr_t filter_function, uptr_t filter_state) const; size_t get_i8(vector_key_t key, rust::Slice vector) const; size_t get_f16(vector_key_t key, rust::Slice vector) const; @@ -41,6 +55,9 @@ class NativeIndex { void change_expansion_add(size_t n) const; void change_expansion_search(size_t n) const; + void change_metric(uptr_t metric, uptr_t state) const; + void change_metric_kind(MetricKind metric) const; + size_t count(vector_key_t key) const; size_t remove(vector_key_t key) const; size_t rename(vector_key_t from, vector_key_t to) const; diff --git a/rust/lib.rs b/rust/lib.rs index c341fe4c..f42c439a 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -1,34 +1,79 @@ +use std::boxed::Box; + +/// The key type used to identify vectors in the index. +/// It is a 64-bit unsigned integer. +pub type Key = u64; + +/// The distance type used to represent the similarity between vectors. +/// It is a 32-bit floating-point number. +pub type Distance = f32; + +/// Callback signature for custom metric functions, defined in the Rust layer and used in the C++ layer. +pub type StatefullMetric = unsafe extern "C" fn( + *const std::ffi::c_void, + *const std::ffi::c_void, + *mut std::ffi::c_void, +) -> Distance; + +/// Callback signature for custom predicate functions, defined in the Rust layer and used in the C++ layer. +pub type StatefullPredicate = unsafe extern "C" fn(Key, *mut std::ffi::c_void) -> bool; + #[cxx::bridge] pub mod ffi { - // Shared structs with fields visible to both languages. - #[derive(Debug)] - struct Matches { - keys: Vec, - distances: Vec, - } + /// The metric kind used to differentiate built-in distance functions. #[derive(Debug)] + #[repr(i32)] enum MetricKind { + Unknown, + /// The Inner Product metric, defined as `IP = 1 - sum(a[i] * b[i])`. IP, + /// The squared Euclidean Distance metric, defined as `L2 = sum((a[i] - b[i])^2)`. L2sq, + /// The Cosine Similarity metric, defined as `Cos = 1 - sum(a[i] * b[i]) / (sqrt(sum(a[i]^2) * sqrt(sum(b[i]^2)))`. Cos, + /// The Pearson Correlation metric. Pearson, + /// The Haversine (Great Circle) Distance metric. Haversine, + /// The Jensen Shannon Divergence metric. Divergence, + /// The bit-level Hamming Distance metric, defined as the number of differing bits. Hamming, + /// The bit-level Tanimoto (Jaccard) metric, defined as the number of intersecting bits divided by the number of union bits. Tanimoto, + /// The bit-level Sorensen metric. Sorensen, } + /// The scalar kind used to differentiate built-in vector element types. #[derive(Debug)] + #[repr(i32)] enum ScalarKind { + Unknown, + /// 64-bit double-precision IEEE 754 floating-point number. F64, + /// 32-bit single-precision IEEE 754 floating-point number. F32, + /// 16-bit half-precision IEEE 754 floating-point number (different from `bf16`). F16, + /// 8-bit signed integer. I8, + /// 1-bit binary value, packed 8 per byte. B1, } + /// The resulting matches from a search operation. + /// It contains the keys and distances of the closest vectors. + #[derive(Debug)] + struct Matches { + keys: Vec, + distances: Vec, + } + + /// The index options used to configure the dense index during creation. + /// It contains the number of dimensions, the metric kind, the scalar kind, + /// the connectivity, the expansion values, and the multi-flag. #[derive(Debug, PartialEq)] struct IndexOptions { dimensions: usize, @@ -49,8 +94,16 @@ pub mod ffi { pub fn expansion_add(self: &NativeIndex) -> usize; pub fn expansion_search(self: &NativeIndex) -> usize; - pub fn change_expansion_add(self: &NativeIndex, n: usize) -> Result<()>; - pub fn change_expansion_search(self: &NativeIndex, n: usize) -> Result<()>; + pub fn change_expansion_add(self: &NativeIndex, n: usize); + pub fn change_expansion_search(self: &NativeIndex, n: usize); + pub fn change_metric_kind(self: &NativeIndex, metric: MetricKind); + + /// Changes the metric function used to calculate the distance between vectors. + /// Avoids the `std::ffi::c_void` type and the `StatefullMetric` type, that the FFI + /// does not support, replacing them with basic pointer-sized integer types. + /// The first two arguments are the pointers to the vectors to compare, and the third + /// argument is the `metric_state` propagated from the Rust layer. + pub fn change_metric(self: &NativeIndex, metric: usize, metric_state: usize); pub fn new_native_index(options: &IndexOptions) -> Result>; pub fn reserve(self: &NativeIndex, capacity: usize) -> Result<()>; @@ -70,6 +123,35 @@ pub mod ffi { pub fn search_f32(self: &NativeIndex, query: &[f32], count: usize) -> Result; pub fn search_f64(self: &NativeIndex, query: &[f64], count: usize) -> Result; + pub fn filtered_search_i8( + self: &NativeIndex, + query: &[i8], + count: usize, + filter: usize, + filter_state: usize, + ) -> Result; + pub fn filtered_search_f16( + self: &NativeIndex, + query: &[u16], + count: usize, + filter: usize, + filter_state: usize, + ) -> Result; + pub fn filtered_search_f32( + self: &NativeIndex, + query: &[f32], + count: usize, + filter: usize, + filter_state: usize, + ) -> Result; + pub fn filtered_search_f64( + self: &NativeIndex, + query: &[f64], + count: usize, + filter: usize, + filter_state: usize, + ) -> Result; + pub fn get_i8(self: &NativeIndex, key: u64, buffer: &mut [i8]) -> Result; pub fn get_f16(self: &NativeIndex, key: u64, buffer: &mut [u16]) -> Result; pub fn get_f32(self: &NativeIndex, key: u64, buffer: &mut [f32]) -> Result; @@ -93,8 +175,51 @@ pub mod ffi { } } +pub enum MetricFunction { + I8Metric(Box Distance + Send + Sync>), + F16Metric(Box Distance + Send + Sync>), + F32Metric(Box Distance + Send + Sync>), + F64Metric(Box Distance + Send + Sync>), +} + +/// Approximate Nearest Neighbors search index for dense vectors. +/// +/// The `Index` struct provides an abstraction over a dense vector space, allowing +/// for efficient addition, search, and management of high-dimensional vectors. +/// It supports various distance metrics and vector types through generic interfaces. +/// +/// # Examples +/// +/// Basic usage: +/// +/// ```rust +/// use usearch::{Index, IndexOptions, MetricKind, ScalarKind}; +/// +/// // Create an index with specific options +/// let mut options = IndexOptions::default(); +/// options.dimensions = 256; // Set the number of dimensions for vectors +/// options.metric = MetricKind::Cos; // Use cosine similarity for distance measurement +/// options.quantization = ScalarKind::F32; // Use 32-bit floating point numbers +/// +/// let index = Index::new(&options).expect("Failed to create index."); +/// +/// // Add vectors to the index +/// let vector1: Vec = vec![0.0, 1.0, 0.0, 1.0, ...]; +/// let vector2: Vec = vec![1.0, 0.0, 1.0, 0.0, ...]; +/// index.add(1, &vector1).expect("Failed to add vector1."); +/// index.add(2, &vector2).expect("Failed to add vector2."); +/// +/// // Search for the nearest neighbors to a query vector +/// let query: Vec = vec![0.5, 0.5, 0.5, 0.5, ...]; +/// let results = index.search(&query, 5).expect("Search failed."); +/// for (key, distance) in results.keys.iter().zip(results.distances.iter()) { +/// println!("Key: {}, Distance: {}", key, distance); +/// } +/// ``` +/// pub struct Index { inner: cxx::UniquePtr, + metric_fn: Option, } impl Default for ffi::IndexOptions { @@ -127,39 +252,227 @@ impl Clone for ffi::IndexOptions { } } +/// The `VectorType` trait defines operations for managing and querying vectors +/// in an index. It supports generic operations on vectors of different types, +/// allowing for the addition, retrieval, and search of vectors within an index. pub trait VectorType { - fn add(index: &Index, key: u64, vector: &[Self]) -> Result<(), cxx::Exception> + /// Adds a vector to the index under the specified key. + /// + /// # Parameters + /// - `index`: A reference to the `Index` where the vector is to be added. + /// - `key`: The key under which the vector should be stored. + /// - `vector`: A slice representing the vector to be added. + /// + /// # Returns + /// - `Ok(())` if the vector was successfully added to the index. + /// - `Err(cxx::Exception)` if an error occurred during the operation. + fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> where Self: Sized; - fn get(index: &Index, key: u64, buffer: &mut [Self]) -> Result + + /// Retrieves a vector from the index by its key. + /// + /// # Parameters + /// - `index`: A reference to the `Index` from which the vector is to be retrieved. + /// - `key`: The key of the vector to retrieve. + /// - `buffer`: A mutable slice where the retrieved vector will be stored. The size of the + /// buffer determines the maximum number of elements that can be retrieved. + /// + /// # Returns + /// - `Ok(usize)` indicating the number of elements actually written into the `buffer`. + /// - `Err(cxx::Exception)` if an error occurred during the operation. + fn get(index: &Index, key: Key, buffer: &mut [Self]) -> Result where Self: Sized; + + /// Performs a search in the index using the given query vector, returning + /// up to `count` closest matches. + /// + /// # Parameters + /// - `index`: A reference to the `Index` where the search is to be performed. + /// - `query`: A slice representing the query vector. + /// - `count`: The maximum number of matches to return. + /// + /// # Returns + /// - `Ok(ffi::Matches)` containing the matches found. + /// - `Err(cxx::Exception)` if an error occurred during the search operation. fn search(index: &Index, query: &[Self], count: usize) -> Result where Self: Sized; + + /// Performs a filtered search in the index using a query vector and a custom + /// filter function, returning up to `count` matches that satisfy the filter. + /// + /// # Parameters + /// - `index`: A reference to the `Index` where the search is to be performed. + /// - `query`: A slice representing the query vector. + /// - `count`: The maximum number of matches to return. + /// - `filter`: A closure that takes a `Key` and returns `true` if the corresponding + /// vector should be included in the search results, or `false` otherwise. + /// + /// # Returns + /// - `Ok(ffi::Matches)` containing the matches that satisfy the filter. + /// - `Err(cxx::Exception)` if an error occurred during the filtered search operation. + fn filtered_search( + index: &Index, + query: &[Self], + count: usize, + filter: F, + ) -> Result + where + Self: Sized, + F: Fn(Key) -> bool; + + /// Changes the metric used for distance calculations within the index. + /// + /// # Parameters + /// - `index`: A mutable reference to the `Index` for which the metric is to be changed. + /// - `metric`: A boxed closure that defines the new metric for distance calculation. The + /// closure must take two pointers to elements of type `Self` and return a `Distance`. + /// + /// # Returns + /// - `Ok(())` if the metric was successfully changed. + /// - `Err(cxx::Exception)` if an error occurred during the operation. + fn change_metric( + index: &mut Index, + metric: Box Distance + Send + Sync>, + ) -> Result<(), cxx::Exception> + where + Self: Sized; +} + +impl VectorType for f32 { + fn search(index: &Index, query: &[Self], count: usize) -> Result { + index.inner.search_f32(query, count) + } + fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result { + index.inner.get_f32(key, vector) + } + fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> { + index.inner.add_f32(key, vector) + } + fn filtered_search( + index: &Index, + query: &[Self], + count: usize, + filter: F, + ) -> Result + where + Self: Sized, + F: Fn(Key) -> bool, + { + // Trampoline is the function that knows how to call the Rust closure. + extern "C" fn trampoline bool>(key: u64, closure_address: usize) -> bool { + let closure = closure_address as *const F; + unsafe { (*closure)(key) } + } + + // Temporarily cast the closure to a raw pointer for passing. + unsafe { + let trampoline_fn: usize = std::mem::transmute(trampoline:: as *const ()); + let closure_address: usize = &filter as *const F as usize; + index + .inner + .filtered_search_f32(query, count, trampoline_fn, closure_address) + } + } + + fn change_metric( + index: &mut Index, + metric: Box Distance + Send + Sync>, + ) -> Result<(), cxx::Exception> { + // Store the metric function in the Index. + type MetricFn = fn(*const f32, *const f32) -> Distance; + index.metric_fn = Some(MetricFunction::F32Metric(metric)); + + // Trampoline is the function that knows how to call the Rust closure. + // The `first` is a pointer to the first vector, `second` is a pointer to the second vector, + // and `index_wrapper` is a pointer to the `index` itself, from which we can infer the metric function + // and the number of dimensions. + extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance { + let first_ptr = first as *const f32; + let second_ptr = second as *const f32; + let closure: MetricFn = unsafe { std::mem::transmute(closure_address) }; + closure(first_ptr, second_ptr) + } + + unsafe { + let trampoline_fn: usize = std::mem::transmute(trampoline as *const ()); + let closure_address = match index.metric_fn { + Some(MetricFunction::F32Metric(ref metric)) => metric as *const _ as usize, + _ => panic!("Expected F32Metric"), + }; + index.inner.change_metric(trampoline_fn, closure_address) + } + + Ok(()) + } } impl VectorType for i8 { fn search(index: &Index, query: &[Self], count: usize) -> Result { index.inner.search_i8(query, count) } - fn get(index: &Index, key: u64, vector: &mut [Self]) -> Result { + fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result { index.inner.get_i8(key, vector) } - fn add(index: &Index, key: u64, vector: &[Self]) -> Result<(), cxx::Exception> { + fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> { index.inner.add_i8(key, vector) } -} + fn filtered_search( + index: &Index, + query: &[Self], + count: usize, + filter: F, + ) -> Result + where + Self: Sized, + F: Fn(Key) -> bool, + { + // Trampoline is the function that knows how to call the Rust closure. + extern "C" fn trampoline bool>(key: u64, closure_address: usize) -> bool { + let closure = closure_address as *const F; + unsafe { (*closure)(key) } + } -impl VectorType for f32 { - fn search(index: &Index, query: &[Self], count: usize) -> Result { - index.inner.search_f32(query, count) - } - fn get(index: &Index, key: u64, vector: &mut [Self]) -> Result { - index.inner.get_f32(key, vector) + // Temporarily cast the closure to a raw pointer for passing. + unsafe { + let trampoline_fn: usize = std::mem::transmute(trampoline:: as *const ()); + let closure_address: usize = &filter as *const F as usize; + index + .inner + .filtered_search_i8(query, count, trampoline_fn, closure_address) + } } - fn add(index: &Index, key: u64, vector: &[Self]) -> Result<(), cxx::Exception> { - index.inner.add_f32(key, vector) + fn change_metric( + index: &mut Index, + metric: Box Distance + Send + Sync>, + ) -> Result<(), cxx::Exception> { + // Store the metric function in the Index. + type MetricFn = fn(*const i8, *const i8) -> Distance; + index.metric_fn = Some(MetricFunction::I8Metric(metric)); + + // Trampoline is the function that knows how to call the Rust closure. + // The `first` is a pointer to the first vector, `second` is a pointer to the second vector, + // and `index_wrapper` is a pointer to the `index` itself, from which we can infer the metric function + // and the number of dimensions. + extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance { + let first_ptr = first as *const i8; + let second_ptr = second as *const i8; + let closure: MetricFn = unsafe { std::mem::transmute(closure_address) }; + closure(first_ptr, second_ptr) + } + + unsafe { + let trampoline_fn: usize = std::mem::transmute(trampoline as *const ()); + let closure_address = match index.metric_fn { + Some(MetricFunction::I8Metric(ref metric)) => metric as *const _ as usize, + _ => panic!("Expected I8Metric"), + }; + index.inner.change_metric(trampoline_fn, closure_address) + } + + Ok(()) } } @@ -167,18 +480,76 @@ impl VectorType for f64 { fn search(index: &Index, query: &[Self], count: usize) -> Result { index.inner.search_f64(query, count) } - fn get(index: &Index, key: u64, vector: &mut [Self]) -> Result { + fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result { index.inner.get_f64(key, vector) } - fn add(index: &Index, key: u64, vector: &[Self]) -> Result<(), cxx::Exception> { + fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> { index.inner.add_f64(key, vector) } + fn filtered_search( + index: &Index, + query: &[Self], + count: usize, + filter: F, + ) -> Result + where + Self: Sized, + F: Fn(Key) -> bool, + { + // Trampoline is the function that knows how to call the Rust closure. + extern "C" fn trampoline bool>(key: u64, closure_address: usize) -> bool { + let closure = closure_address as *const F; + unsafe { (*closure)(key) } + } + + // Temporarily cast the closure to a raw pointer for passing. + unsafe { + let trampoline_fn: usize = std::mem::transmute(trampoline:: as *const ()); + let closure_address: usize = &filter as *const F as usize; + index + .inner + .filtered_search_f64(query, count, trampoline_fn, closure_address) + } + } + fn change_metric( + index: &mut Index, + metric: Box Distance + Send + Sync>, + ) -> Result<(), cxx::Exception> { + // Store the metric function in the Index. + type MetricFn = fn(*const f64, *const f64) -> Distance; + index.metric_fn = Some(MetricFunction::F64Metric(metric)); + + // Trampoline is the function that knows how to call the Rust closure. + // The `first` is a pointer to the first vector, `second` is a pointer to the second vector, + // and `index_wrapper` is a pointer to the `index` itself, from which we can infer the metric function + // and the number of dimensions. + extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance { + let first_ptr = first as *const f64; + let second_ptr = second as *const f64; + let closure: MetricFn = unsafe { std::mem::transmute(closure_address) }; + closure(first_ptr, second_ptr) + } + + unsafe { + let trampoline_fn: usize = std::mem::transmute(trampoline as *const ()); + let closure_address = match index.metric_fn { + Some(MetricFunction::F64Metric(ref metric)) => metric as *const _ as usize, + _ => panic!("Expected F64Metric"), + }; + index.inner.change_metric(trampoline_fn, closure_address) + } + + Ok(()) + } } impl Index { pub fn new(options: &ffi::IndexOptions) -> Result { match ffi::new_native_index(options) { - Ok(inner) => Result::Ok(Self { inner }), + Ok(inner) => Result::Ok(Self { + inner, + metric_fn: None, + }), Err(err) => Err(err), } } @@ -194,15 +565,28 @@ impl Index { } /// Updates the expansion value used during index creation. Rarely used. - pub fn change_expansion_add(self: &Index, n: usize) -> Result<(), cxx::Exception> { + pub fn change_expansion_add(self: &Index, n: usize) { self.inner.change_expansion_add(n) } /// Updates the expansion value used during search operations. - pub fn change_expansion_search(self: &Index, n: usize) -> Result<(), cxx::Exception> { + pub fn change_expansion_search(self: &Index, n: usize) { self.inner.change_expansion_search(n) } + /// Changes the metric kind used to calculate the distance between vectors. + pub fn change_metric_kind(self: &Index, metric: ffi::MetricKind) { + self.inner.change_metric_kind(metric) + } + + /// Overrides the metric function used to calculate the distance between vectors. + pub fn change_metric( + self: &mut Index, + metric: Box Distance + Send + Sync>, + ) { + T::change_metric(self, metric).unwrap(); + } + /// Retrieves the hardware acceleration information. pub fn hardware_acceleration(&self) -> String { use core::ffi::CStr; @@ -230,13 +614,37 @@ impl Index { T::search(self, query, count) } + /// Performs k-Approximate Nearest Neighbors (kANN) Search for closest vectors to the provided query + /// satisfying a custom filter function. + /// + /// # Arguments + /// + /// * `query` - A slice containing the query vector data. + /// * `count` - The maximum number of neighbors to search for. + /// * `filter` - A closure that takes a `Key` and returns `true` if the corresponding vector should be included in the search results, or `false` otherwise. + /// + /// # Returns + /// + /// A `Result` containing the matches found. + pub fn filtered_search( + self: &Index, + query: &[T], + count: usize, + filter: F, + ) -> Result + where + F: Fn(Key) -> bool, + { + T::filtered_search(self, query, count, filter) + } + /// Adds a vector with a specified key to the index. /// /// # Arguments /// /// * `key` - The key associated with the vector. /// * `vector` - A slice containing the vector data. - pub fn add(self: &Index, key: u64, vector: &[T]) -> Result<(), cxx::Exception> { + pub fn add(self: &Index, key: Key, vector: &[T]) -> Result<(), cxx::Exception> { T::add(self, key, vector) } @@ -253,7 +661,7 @@ impl Index { /// * `vector` - A slice containing the vector data. pub fn get( self: &Index, - key: u64, + key: Key, vector: &mut [T], ) -> Result { T::get(self, key, vector) @@ -268,7 +676,7 @@ impl Index { /// * `vector` - A mutable vector containing the vector data. pub fn export( self: &Index, - key: u64, + key: Key, vector: &mut Vec, ) -> Result { let dim = self.dimensions(); @@ -325,7 +733,7 @@ impl Index { /// # Returns /// /// `true` if the vector is successfully removed, `false` otherwise. - pub fn remove(self: &Index, key: u64) -> Result { + pub fn remove(self: &Index, key: Key) -> Result { self.inner.remove(key) } @@ -339,7 +747,7 @@ impl Index { /// # Returns /// /// `true` if the vector is renamed, `false` otherwise. - pub fn rename(self: &Index, from: u64, to: u64) -> Result { + pub fn rename(self: &Index, from: Key, to: Key) -> Result { self.inner.rename(from, to) } @@ -352,7 +760,7 @@ impl Index { /// # Returns /// /// `true` if the index contains the vector with the given key, `false` otherwise. - pub fn contains(self: &Index, key: u64) -> bool { + pub fn contains(self: &Index, key: Key) -> bool { self.inner.contains(key) } @@ -365,7 +773,7 @@ impl Index { /// # Returns /// /// Number of vectors found. - pub fn count(self: &Index, key: u64) -> usize { + pub fn count(self: &Index, key: Key) -> usize { self.inner.count(key) } @@ -446,7 +854,9 @@ mod tests { use crate::ffi::ScalarKind; use crate::new_index; + use crate::Distance; use crate::Index; + use crate::Key; use std::env; @@ -527,6 +937,7 @@ mod tests { fn test_add_get_vector() { let mut options = IndexOptions::default(); options.dimensions = 5; + options.quantization = ScalarKind::F32; let index = Index::new(&options).unwrap(); assert!(index.reserve(10).is_ok()); @@ -625,10 +1036,10 @@ mod tests { index.memory_usage(), index.capacity(), ); - assert!(index.change_expansion_add(10).is_ok()); + index.change_expansion_add(10); assert_eq!(index.expansion_add(), 10); assert!(index.add(42, &first).is_ok()); - assert!(index.change_expansion_add(12).is_ok()); + index.change_expansion_add(12); assert_eq!(index.expansion_add(), 12); assert!(index.add(43, &second).is_ok()); assert_eq!(index.size(), 2); @@ -640,14 +1051,14 @@ mod tests { index.capacity(), ); - assert!(index.change_expansion_search(10).is_ok()); + index.change_expansion_search(10); assert_eq!(index.expansion_search(), 10); // Read back the tags let results = index.search(&first, 10).unwrap(); println!("{:?}", results); assert_eq!(results.keys.len(), 2); - assert!(index.change_expansion_search(12).is_ok()); + index.change_expansion_search(12); assert_eq!(index.expansion_search(), 12); let results = index.search(&first, 10).unwrap(); println!("{:?}", results); @@ -694,4 +1105,80 @@ mod tests { assert_ne!(opts.metric, options.metric); assert!(new_index(&opts).is_ok()); } + + #[test] + fn test_search_with_stateless_filter() { + let mut options = IndexOptions::default(); + options.dimensions = 5; + let index = Index::new(&options).unwrap(); + index.reserve(10).unwrap(); + + // Adding sample vectors to the index + let first: [f32; 5] = [0.2, 0.1, 0.2, 0.1, 0.3]; + let second: [f32; 5] = [0.3, 0.2, 0.4, 0.0, 0.1]; + index.add(1, &first).unwrap(); + index.add(2, &second).unwrap(); + + // Stateless filter: checks if the key is odd + let stateless_filter = |key: Key| key % 2 == 1; + + let query = vec![0.2, 0.1, 0.2, 0.1, 0.3]; // Example query vector + let results = index.filtered_search(&query, 10, stateless_filter).unwrap(); + assert!( + results.keys.iter().all(|&key| key % 2 == 1), + "All keys must be odd" + ); + } + + #[test] + fn test_search_with_stateful_filter() { + use std::collections::HashSet; + + let mut options = IndexOptions::default(); + options.dimensions = 5; + let index = Index::new(&options).unwrap(); + index.reserve(10).unwrap(); + + // Adding sample vectors to the index + let first: [f32; 5] = [0.2, 0.1, 0.2, 0.1, 0.3]; + index.add(1, &first).unwrap(); + index.add(2, &first).unwrap(); + + let allowed_keys = vec![1, 2, 3].into_iter().collect::>(); + // Clone `allowed_keys` for use in the closure + let filter_keys = allowed_keys.clone(); + let stateful_filter = move |key: Key| filter_keys.contains(&key); + + let query = vec![0.2, 0.1, 0.2, 0.1, 0.3]; // Example query vector + let results = index.filtered_search(&query, 10, stateful_filter).unwrap(); + + // Use the original `allowed_keys` for assertion + assert!( + results.keys.iter().all(|&key| allowed_keys.contains(&key)), + "All keys must be in the allowed set" + ); + } + + #[test] + fn test_change_distance_function() { + let mut options = IndexOptions::default(); + options.dimensions = 2; // Adjusted for simplicity in creating test vectors + let mut index = Index::new(&options).unwrap(); + index.reserve(10).unwrap(); + + // Adding a simple vector to test the distance function changes + let vector: [f32; 2] = [1.0, 0.0]; + index.add(1, &vector).unwrap(); + + // Stateless distance function: simply returns the difference in the first element + let stateless_distance = + Box::new(|a: *const f32, b: *const f32| unsafe { (*a - *b).abs() }); + index.change_metric(stateless_distance); + + // Now changing to a stateful distance function: scales the difference by a factor + let scale_factor = 2.0; + let stateful_distance = + Box::new(move |a: *const f32, b: *const f32| unsafe { (*a - *b).abs() * scale_factor }); + index.change_metric(stateful_distance); + } }