Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions rust/cuvs/src/cagra/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,48 @@ impl Index {
))
}
}

/// Perform a filtered Approximate Nearest Neighbors search on the Index
///
/// Like [`search`](Self::search), but accepts a bitset filter to exclude
/// vectors during graph traversal. Filtered vectors are never visited,
/// giving better recall than post-filtering.
///
/// # Arguments
///
/// * `res` - Resources to use
/// * `params` - Parameters to use in searching the index
/// * `queries` - A matrix in device memory to query for
/// * `neighbors` - Matrix in device memory that receives the indices of the nearest neighbors
/// * `distances` - Matrix in device memory that receives the distances of the nearest neighbors
/// * `bitset` - A 1-D `uint32` device tensor with `ceil(n_rows / 32)` elements.
/// Each bit corresponds to a dataset row: bit 1 = include, bit 0 = exclude.
pub fn search_with_filter(
&self,
res: &Resources,
params: &SearchParams,
queries: &ManagedTensor,
neighbors: &ManagedTensor,
distances: &ManagedTensor,
bitset: &ManagedTensor,
) -> Result<()> {
unsafe {
let prefilter = ffi::cuvsFilter {
addr: bitset.as_ptr() as usize,
type_: ffi::cuvsFilterType::BITSET,
};

check_cuvs(ffi::cuvsCagraSearch(
res.0,
params.0,
self.0,
queries.as_ptr(),
neighbors.as_ptr(),
distances.as_ptr(),
prefilter,
))
}
}
}

impl Drop for Index {
Expand Down Expand Up @@ -168,6 +210,76 @@ mod tests {
test_cagra(build_params);
}

/// Test bitset-filtered search: exclude odd-indexed rows, verify they don't appear.
#[test]
fn test_cagra_search_with_filter() {
let res = Resources::new().unwrap();
let build_params = IndexParams::new().unwrap();

let n_datapoints = 256;
let n_features = 16;
let dataset =
ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));

let index =
Index::build(&res, &build_params, &dataset).expect("failed to create cagra index");

// Build a bitset that includes only even-indexed rows
let n_words = (n_datapoints + 31) / 32;
let mut bitset_host = ndarray::Array::<u32, _>::zeros(ndarray::Ix1(n_words));
for i in 0..n_datapoints {
if i % 2 == 0 {
bitset_host[i / 32] |= 1u32 << (i % 32);
}
}
let bitset = ManagedTensor::from(&bitset_host).to_device(&res).unwrap();

// Query with the first 4 even-indexed rows
let n_queries = 4;
let queries = dataset.slice(s![0..n_queries * 2;2, ..]); // rows 0, 2, 4, 6
let queries = ManagedTensor::from(&queries).to_device(&res).unwrap();

let k = 10;
let mut neighbors_host = ndarray::Array::<u32, _>::zeros((n_queries, k));
let neighbors = ManagedTensor::from(&neighbors_host)
.to_device(&res)
.unwrap();
let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
let distances = ManagedTensor::from(&distances_host)
.to_device(&res)
.unwrap();

let search_params = SearchParams::new().unwrap();

index
.search_with_filter(
&res,
&search_params,
&queries,
&neighbors,
&distances,
&bitset,
)
.unwrap();

neighbors.to_host(&res, &mut neighbors_host).unwrap();

// All returned neighbors must be even-indexed (odd rows are filtered out).
for q in 0..n_queries {
for n in 0..k {
let neighbor_id = neighbors_host[[q, n]];
assert_eq!(
neighbor_id % 2,
0,
"query {q}, neighbor {n}: got odd index {neighbor_id}, expected only even"
);
}
}

// First query (row 0) should find itself as the nearest neighbor.
assert_eq!(neighbors_host[[0, 0]], 0);
}

/// Test that an index can be searched multiple times without rebuilding.
/// This validates that search() takes &self instead of self.
#[test]
Expand Down
Loading