From 9b3dcf3ca0660a49bf1d106176793d4bd817f6b9 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 10 Mar 2023 11:56:07 +0100 Subject: [PATCH 1/4] Update and standardize IVF indexes API --- cpp/bench/neighbors/knn.cuh | 12 +- cpp/include/raft/neighbors/ivf_flat.cuh | 22 ++-- cpp/include/raft/neighbors/ivf_pq.cuh | 120 +++++++++++++----- .../raft/neighbors/specializations/ivf_pq.cuh | 86 ++++++++----- .../raft/spatial/knn/detail/ann_quantized.cuh | 9 +- cpp/include/raft_runtime/neighbors/ivf_pq.hpp | 77 ++++++----- cpp/src/distance/neighbors/ivfpq_build.cu | 27 ++-- .../neighbors/ivfpq_search_float_uint64_t.cu | 20 ++- .../neighbors/ivfpq_search_int8_t_uint64_t.cu | 20 ++- .../ivfpq_search_uint8_t_uint64_t.cu | 20 ++- .../ivfpq_build_float_uint64_t.cu | 10 +- .../ivfpq_build_int8_t_uint64_t.cu | 10 +- .../ivfpq_build_uint8_t_uint64_t.cu | 10 +- .../ivfpq_extend_float_uint64_t.cu | 23 ++-- .../ivfpq_extend_int8_t_uint64_t.cu | 23 ++-- .../ivfpq_extend_uint8_t_uint64_t.cu | 23 ++-- .../ivfpq_search_float_uint64_t.cu | 5 +- .../ivfpq_search_int8_t_uint64_t.cu | 5 +- .../ivfpq_search_uint8_t_uint64_t.cu | 5 +- cpp/test/neighbors/ann_ivf_flat.cuh | 9 +- cpp/test/neighbors/ann_ivf_pq.cuh | 6 +- .../neighbors/ivf_pq/cpp/c_ivf_pq.pxd | 78 ++++++++---- .../pylibraft/neighbors/ivf_pq/ivf_pq.pyx | 27 ++-- 23 files changed, 369 insertions(+), 278 deletions(-) diff --git a/cpp/bench/neighbors/knn.cuh b/cpp/bench/neighbors/knn.cuh index ed3c6db909..e011aeb706 100644 --- a/cpp/bench/neighbors/knn.cuh +++ b/cpp/bench/neighbors/knn.cuh @@ -179,9 +179,8 @@ struct ivf_pq_knn { { index_params.n_lists = 4096; index_params.metric = raft::distance::DistanceType::L2Expanded; - - auto data_view = raft::make_device_matrix_view(data, ps.n_samples, ps.n_dims); - index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view)); + index.emplace(raft::neighbors::ivf_pq::build( + handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims))); } void search(const raft::device_resources& handle, @@ -190,13 +189,8 @@ struct ivf_pq_knn { IdxT* out_idxs) { search_params.n_probes = 20; - - auto queries_view = - raft::make_device_matrix_view(search_items, ps.n_queries, ps.n_dims); - auto idxs_view = raft::make_device_matrix_view(out_idxs, ps.n_queries, ps.k); - auto dists_view = raft::make_device_matrix_view(out_dists, ps.n_queries, ps.k); raft::neighbors::ivf_pq::search( - handle, search_params, *index, queries_view, ps.k, idxs_view, dists_view); + handle, search_params, *index, search_items, ps.n_queries, ps.k, out_idxs, out_dists); } }; diff --git a/cpp/include/raft/neighbors/ivf_flat.cuh b/cpp/include/raft/neighbors/ivf_flat.cuh index f18611b9f1..34080038f5 100644 --- a/cpp/include/raft/neighbors/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/ivf_flat.cuh @@ -379,9 +379,9 @@ void search(raft::device_resources const& handle, * ivf_flat::search_params search_params; * // Use the same allocator across multiple searches to reduce the number of * // cuda memory allocations - * ivf_flat::search(handle, index, queries1, out_inds1, out_dists1, search_params, K); - * ivf_flat::search(handle, index, queries2, out_inds2, out_dists2, search_params, K); - * ivf_flat::search(handle, index, queries3, out_inds3, out_dists3, search_params, K); + * ivf_flat::search(handle, index, queries1, out_inds1, out_dists1, search_params); + * ivf_flat::search(handle, index, queries2, out_inds2, out_dists2, search_params); + * ivf_flat::search(handle, index, queries3, out_inds3, out_dists3, search_params); * ... * @endcode * @@ -397,37 +397,35 @@ void search(raft::device_resources const& handle, * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] * @param[in] params configure the search - * @param[in] k the number of neighbors to find for each query. */ -template +template void search(raft::device_resources const& handle, const index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - const search_params& params, - int_t k) + const search_params& params) { RAFT_EXPECTS( queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), "Number of rows in output neighbors and distances matrices must equal the number of queries."); - RAFT_EXPECTS( - neighbors.extent(1) == distances.extent(1) && neighbors.extent(1) == static_cast(k), - "Number of columns in output neighbors and distances matrices must equal k"); + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must equal k"); RAFT_EXPECTS(queries.extent(1) == index.dim(), "Number of query dimensions should equal number of dimensions in the index."); + std::uint32_t k = neighbors.extent(1); return search(handle, params, index, queries.data_handle(), static_cast(queries.extent(0)), - static_cast(k), + k, neighbors.data_handle(), distances.data_handle(), - nullptr); + handle.get_workspace_resource()); } /** @} */ diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq.cuh index 4bb617b526..fe25e806d5 100644 --- a/cpp/include/raft/neighbors/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/ivf_pq.cuh @@ -45,21 +45,50 @@ namespace raft::neighbors::ivf_pq { * @tparam IdxT type of the indices in the source dataset * * @param handle - * @param params configure the index building * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] + * @param params configure the index building * * @return the constructed ivf-pq index */ template -inline auto build(raft::device_resources const& handle, - const index_params& params, - raft::device_matrix_view dataset) -> index +auto build(raft::device_resources const& handle, + raft::device_matrix_view dataset, + const index_params& params) -> index { IdxT n_rows = dataset.extent(0); IdxT dim = dataset.extent(1); return detail::build(handle, params, dataset.data_handle(), n_rows, dim); } +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param handle + * @param[inout] index + * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] + * @param params configure the index building + * + * @return the constructed ivf-pq index + */ +template +void build(raft::device_resources const& handle, + index* index, + raft::device_matrix_view dataset, + const index_params& params) +{ + IdxT n_rows = dataset.extent(0); + IdxT dim = dataset.extent(1); + *index = detail::build(handle, params, dataset.data_handle(), n_rows, dim); +} + /** * @brief Build a new index containing the data of the original plus new extra vectors. * @@ -80,18 +109,26 @@ inline auto build(raft::device_resources const& handle, * @return the constructed extended ivf-pq index */ template -inline auto extend(raft::device_resources const& handle, - const index& orig_index, - raft::device_matrix_view new_vectors, - raft::device_matrix_view new_indices) -> index +auto extend(raft::device_resources const& handle, + const index& orig_index, + raft::device_matrix_view new_vectors, + std::optional> new_indices = + std::nullopt) -> index { - IdxT n_rows = new_vectors.extent(0); - ASSERT(n_rows == new_indices.extent(0), - "new_vectors and new_indices have different number of rows"); ASSERT(new_vectors.extent(1) == orig_index.dim(), "new_vectors should have the same dimension as the index"); - return detail::extend( - handle, orig_index, new_vectors.data_handle(), new_indices.data_handle(), n_rows); + + IdxT n_rows = new_vectors.extent(0); + if (new_indices.has_value()) { + ASSERT(n_rows == new_indices.value().extent(0), + "new_vectors and new_indices have different number of rows"); + } + + return detail::extend(handle, + orig_index, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + n_rows); } /** @@ -108,12 +145,26 @@ inline auto extend(raft::device_resources const& handle, * here to imply a continuous range `[0...n_rows)`. */ template -inline void extend(raft::device_resources const& handle, - index* index, - raft::device_matrix_view new_vectors, - raft::device_matrix_view new_indices) +void extend( + raft::device_resources const& handle, + index* index, + raft::device_matrix_view new_vectors, + std::optional> new_indices = std::nullopt) { - *index = extend(handle, *index, new_vectors, new_indices); + ASSERT(new_vectors.extent(1) == index->dim(), + "new_vectors should have the same dimension as the index"); + + IdxT n_rows = new_vectors.extent(0); + if (new_indices.has_value()) { + ASSERT(n_rows == new_indices.value().extent(0), + "new_vectors and new_indices have different number of rows"); + } + + *index = extend(handle, + *index, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + n_rows); } /** @@ -133,33 +184,38 @@ inline void extend(raft::device_resources const& handle, * @tparam IdxT type of the indices * * @param handle - * @param params configure the search * @param index ivf-pq constructed index * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param k the number of neighbors to find for each query. * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] + * @param params configure the search */ template -inline void search(raft::device_resources const& handle, - const search_params& params, - const index& index, - raft::device_matrix_view queries, - uint32_t k, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) +void search(raft::device_resources const& handle, + const index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + const search_params& params) { - IdxT n_queries = queries.extent(0); - bool check_n_rows = (n_queries == neighbors.extent(0)) && (n_queries == distances.extent(0)); - ASSERT(check_n_rows, - "queries, neighbors and distances parameters have inconsistent number of rows"); + RAFT_EXPECTS( + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must equal k"); + + RAFT_EXPECTS(queries.extent(1) == index.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + std::uint32_t k = neighbors.extent(1); return detail::search(handle, params, index, queries.data_handle(), - n_queries, + static_cast(queries.extent(0)), k, neighbors.data_handle(), distances.data_handle(), diff --git a/cpp/include/raft/neighbors/specializations/ivf_pq.cuh b/cpp/include/raft/neighbors/specializations/ivf_pq.cuh index 3ff99fb4da..164e3277eb 100644 --- a/cpp/include/raft/neighbors/specializations/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/specializations/ivf_pq.cuh @@ -22,37 +22,59 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_INST(T, IdxT) \ - extern template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim) \ - ->index; \ - extern template auto extend(raft::device_resources const& handle, \ - const index& orig_index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows) \ - ->index; \ - extern template void extend(raft::device_resources const& handle, \ - index* index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows); \ - extern template void search(raft::device_resources const&, \ - const search_params&, \ - const index&, \ - const T*, \ - uint32_t, \ - uint32_t, \ - IdxT*, \ - float*, \ - rmm::mr::device_memory_resource*); -RAFT_INST(float, uint64_t); -RAFT_INST(int8_t, uint64_t); -RAFT_INST(uint8_t, uint64_t); - -#undef RAFT_INST +#ifdef RAFT_DECL_BUILD_EXTEND +#undef RAFT_DECL_BUILD_EXTEND +#endif + +#ifdef RAFT_DECL_SEARCH +#undef RAFT_DECL_SEARCH +#endif + +// We define overloads for build and extend with void return type. This is used in the Cython +// wrappers, where exception handling is not compatible with return type that has nontrivial +// constructor. +#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \ + extern template auto build(raft::device_resources const&, \ + raft::device_matrix_view, \ + const raft::neighbors::ivf_pq::index_params&) \ + ->raft::neighbors::ivf_pq::index; \ + \ + extern template void build(raft::device_resources const&, \ + index*, \ + raft::device_matrix_view, \ + const raft::neighbors::ivf_pq::index_params&); \ + \ + extern template auto extend( \ + raft::device_resources const&, \ + const index&, \ + raft::device_matrix_view, \ + std::optional>) \ + ->raft::neighbors::ivf_pq::index; \ + \ + extern template void extend( \ + raft::device_resources const&, \ + index*, \ + raft::device_matrix_view, \ + std::optional>); + +RAFT_DECL_BUILD_EXTEND(float, uint64_t) +RAFT_DECL_BUILD_EXTEND(int8_t, uint64_t) +RAFT_DECL_BUILD_EXTEND(uint8_t, uint64_t) + +#undef RAFT_DECL_BUILD_EXTEND + +#define RAFT_DECL_SEARCH(T, IdxT) \ + extern template void search(raft::device_resources const&, \ + const index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ + const search_params&); + +RAFT_DECL_SEARCH(float, uint64_t); +RAFT_DECL_SEARCH(int8_t, uint64_t); +RAFT_DECL_SEARCH(uint8_t, uint64_t); + +#undef RAFT_DECL_SEARCH } // namespace raft::neighbors::ivf_pq diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index 066dcaaa6b..79539c3729 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -81,7 +81,7 @@ void approx_knn_build_index(raft::device_resources const& handle, auto index_view = raft::make_device_matrix_view(index_array, n, D); index->ivf_pq = std::make_unique>( - neighbors::ivf_pq::build(handle, params, index_view)); + neighbors::ivf_pq::build(handle, index_view, params)); } else { RAFT_FAIL("Unrecognized index type."); } @@ -112,13 +112,8 @@ void approx_knn_search(raft::device_resources const& handle, } else if (index->ivf_pq) { neighbors::ivf_pq::search_params params; params.n_probes = index->nprobe; - - auto query_view = - raft::make_device_matrix_view(query_array, n, index->ivf_pq->dim()); - auto indices_view = raft::make_device_matrix_view(indices, n, k); - auto distances_view = raft::make_device_matrix_view(distances, n, k); neighbors::ivf_pq::search( - handle, params, *index->ivf_pq, query_view, k, indices_view, distances_view); + handle, params, *index->ivf_pq, query_array, n, k, indices, distances); } else { RAFT_FAIL("The model is not trained"); } diff --git a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp index e4c228effe..2b967db452 100644 --- a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp +++ b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp @@ -20,51 +20,48 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_INST_SEARCH(T, IdxT) \ - void search(raft::device_resources const&, \ - const raft::neighbors::ivf_pq::search_params&, \ - const raft::neighbors::ivf_pq::index&, \ - raft::device_matrix_view, \ - uint32_t, \ - raft::device_matrix_view, \ - raft::device_matrix_view); - -RAFT_INST_SEARCH(float, uint64_t); -RAFT_INST_SEARCH(int8_t, uint64_t); -RAFT_INST_SEARCH(uint8_t, uint64_t); - -#undef RAFT_INST_SEARCH - // We define overloads for build and extend with void return type. This is used in the Cython // wrappers, where exception handling is not compatible with return type that has nontrivial // constructor. -#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ - auto build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - raft::device_matrix_view dataset) \ - ->raft::neighbors::ivf_pq::index; \ - \ - auto extend(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index& orig_index, \ - raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices) \ - ->raft::neighbors::ivf_pq::index; \ - \ - void build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - raft::device_matrix_view dataset, \ - raft::neighbors::ivf_pq::index* idx); \ - \ - void extend(raft::device_resources const& handle, \ - raft::neighbors::ivf_pq::index* idx, \ - raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices); +#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \ + auto build(raft::device_resources const& handle, \ + raft::device_matrix_view dataset, \ + const raft::neighbors::ivf_pq::index_params& params); \ + \ + void build(raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* index, \ + raft::device_matrix_view dataset, \ + const raft::neighbors::ivf_pq::index_params& params); \ + \ + auto extend(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& orig_index, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices); \ + \ + void extend(raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* index, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices); + +RAFT_DECL_BUILD_EXTEND(float, uint64_t) +RAFT_DECL_BUILD_EXTEND(int8_t, uint64_t) +RAFT_DECL_BUILD_EXTEND(uint8_t, uint64_t) + +#undef RAFT_DECL_BUILD_EXTEND + +#define RAFT_DECL_SEARCH(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const raft::neighbors::ivf_pq::search_params& params); -RAFT_INST_BUILD_EXTEND(float, uint64_t) -RAFT_INST_BUILD_EXTEND(int8_t, uint64_t) -RAFT_INST_BUILD_EXTEND(uint8_t, uint64_t) +RAFT_DECL_SEARCH(float, uint64_t); +RAFT_DECL_SEARCH(int8_t, uint64_t); +RAFT_DECL_SEARCH(uint8_t, uint64_t); -#undef RAFT_INST_BUILD_EXTEND +#undef RAFT_DECL_SEARCH /** * Save the index to file. diff --git a/cpp/src/distance/neighbors/ivfpq_build.cu b/cpp/src/distance/neighbors/ivfpq_build.cu index 96ba349d1d..7a787386f9 100644 --- a/cpp/src/distance/neighbors/ivfpq_build.cu +++ b/cpp/src/distance/neighbors/ivfpq_build.cu @@ -22,32 +22,29 @@ namespace raft::runtime::neighbors::ivf_pq { #define RAFT_INST_BUILD_EXTEND(T, IdxT) \ auto build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - raft::device_matrix_view dataset) \ - ->raft::neighbors::ivf_pq::index \ + raft::device_matrix_view dataset, \ + const raft::neighbors::ivf_pq::index_params& params) \ + { \ + return raft::neighbors::ivf_pq::build(handle, dataset, params); \ + } \ + void build(raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + raft::device_matrix_view dataset, \ + const raft::neighbors::ivf_pq::index_params& params) \ { \ - return raft::neighbors::ivf_pq::build(handle, params, dataset); \ + raft::neighbors::ivf_pq::build(handle, idx, dataset, params); \ } \ auto extend(raft::device_resources const& handle, \ const raft::neighbors::ivf_pq::index& orig_index, \ raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices) \ - ->raft::neighbors::ivf_pq::index \ + std::optional> new_indices) \ { \ return raft::neighbors::ivf_pq::extend(handle, orig_index, new_vectors, new_indices); \ - } \ - \ - void build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - raft::device_matrix_view dataset, \ - raft::neighbors::ivf_pq::index* idx) \ - { \ - *idx = raft::neighbors::ivf_pq::build(handle, params, dataset); \ } \ void extend(raft::device_resources const& handle, \ raft::neighbors::ivf_pq::index* idx, \ raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices) \ + std::optional> new_indices) \ { \ raft::neighbors::ivf_pq::extend(handle, idx, new_vectors, new_indices); \ } diff --git a/cpp/src/distance/neighbors/ivfpq_search_float_uint64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_float_uint64_t.cu index 9bd750a2e2..caae4b6bb0 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_float_uint64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_float_uint64_t.cu @@ -20,17 +20,15 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_SEARCH_INST(T, IdxT) \ - void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - raft::device_matrix_view queries, \ - uint32_t k, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - raft::neighbors::ivf_pq::search( \ - handle, params, idx, queries, k, neighbors, distances); \ +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const raft::neighbors::ivf_pq::search_params& params) \ + { \ + raft::neighbors::ivf_pq::search(handle, idx, queries, neighbors, distances, params); \ } RAFT_SEARCH_INST(float, uint64_t); diff --git a/cpp/src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu index 303c7009cf..fd2f3d0ef8 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu @@ -20,17 +20,15 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_SEARCH_INST(T, IdxT) \ - void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - raft::device_matrix_view queries, \ - uint32_t k, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - raft::neighbors::ivf_pq::search( \ - handle, params, idx, queries, k, neighbors, distances); \ +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const raft::neighbors::ivf_pq::search_params& params) \ + { \ + raft::neighbors::ivf_pq::search(handle, idx, queries, neighbors, distances, params); \ } RAFT_SEARCH_INST(int8_t, uint64_t); diff --git a/cpp/src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu index c057abd22e..7f203158a4 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu @@ -20,17 +20,15 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_SEARCH_INST(T, IdxT) \ - void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - raft::device_matrix_view queries, \ - uint32_t k, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - raft::neighbors::ivf_pq::search( \ - handle, params, idx, queries, k, neighbors, distances); \ +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const raft::neighbors::ivf_pq::search_params& params) \ + { \ + raft::neighbors::ivf_pq::search(handle, idx, queries, neighbors, distances, params); \ } RAFT_SEARCH_INST(uint8_t, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_build_float_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_build_float_uint64_t.cu index 9563ea8a88..b6ca4275d8 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_build_float_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_build_float_uint64_t.cu @@ -20,9 +20,13 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; + raft::device_matrix_view dataset, \ + const index_params& params) \ + ->index; \ + template void build(raft::device_resources const& handle, \ + index* index, \ + raft::device_matrix_view dataset, \ + const index_params& params); RAFT_MAKE_INSTANCE(float, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_uint64_t.cu index 40c84d2a73..fcb7f1d467 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_uint64_t.cu @@ -20,9 +20,13 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; + raft::device_matrix_view dataset, \ + const index_params& params) \ + ->index; \ + template void build(raft::device_resources const& handle, \ + index* index, \ + raft::device_matrix_view dataset, \ + const index_params& params); RAFT_MAKE_INSTANCE(int8_t, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_uint64_t.cu index 8d406542e8..4e795b4301 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_uint64_t.cu @@ -20,9 +20,13 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; + raft::device_matrix_view dataset, \ + const index_params& params) \ + ->index; \ + template void build(raft::device_resources const& handle, \ + index* index, \ + raft::device_matrix_view dataset, \ + const index_params& params); RAFT_MAKE_INSTANCE(uint8_t, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_uint64_t.cu index 3a0690a2f1..96cb6c35fc 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_uint64_t.cu @@ -18,17 +18,18 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - const index& orig_index, \ - raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices) \ - ->index; \ - template void extend( \ - raft::device_resources const& handle, \ - index* index, \ - raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices); +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template auto extend( \ + raft::device_resources const& handle, \ + const index& orig_index, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices) \ + ->index; \ + template void extend( \ + raft::device_resources const& handle, \ + index* index, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices); RAFT_MAKE_INSTANCE(float, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_uint64_t.cu index 83cb2d14e9..f1dbc47fa0 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_uint64_t.cu @@ -18,17 +18,18 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - const index& orig_index, \ - raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices) \ - ->index; \ - template void extend( \ - raft::device_resources const& handle, \ - index* index, \ - raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices); +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template auto extend( \ + raft::device_resources const& handle, \ + const index& orig_index, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices) \ + ->index; \ + template void extend( \ + raft::device_resources const& handle, \ + index* index, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices); RAFT_MAKE_INSTANCE(int8_t, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_uint64_t.cu index 0b218dbc6f..3f45558e0e 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_uint64_t.cu @@ -18,17 +18,18 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - const index& orig_index, \ - raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices) \ - ->index; \ - template void extend( \ - raft::device_resources const& handle, \ - index* index, \ - raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices); +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template auto extend( \ + raft::device_resources const& handle, \ + const index& orig_index, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices) \ + ->index; \ + template void extend( \ + raft::device_resources const& handle, \ + index* index, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices); RAFT_MAKE_INSTANCE(uint8_t, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_search_float_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_search_float_uint64_t.cu index f28e854554..95ded932ef 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_search_float_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_search_float_uint64_t.cu @@ -20,12 +20,11 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template void search(raft::device_resources const& handle, \ - const search_params& params, \ const index& index, \ raft::device_matrix_view queries, \ - uint32_t k, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); + raft::device_matrix_view distances, \ + const search_params& params); RAFT_MAKE_INSTANCE(float, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_uint64_t.cu index 230001df75..1a6ed041b6 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_uint64_t.cu @@ -20,12 +20,11 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template void search(raft::device_resources const& handle, \ - const search_params& params, \ const index& index, \ raft::device_matrix_view queries, \ - uint32_t k, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); + raft::device_matrix_view distances, \ + const search_params& params); RAFT_MAKE_INSTANCE(int8_t, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_uint64_t.cu index c6ff5097dc..65bfd1af00 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_uint64_t.cu @@ -20,12 +20,11 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template void search(raft::device_resources const& handle, \ - const search_params& params, \ const index& index, \ raft::device_matrix_view queries, \ - uint32_t k, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); + raft::device_matrix_view distances, \ + const search_params& params); RAFT_MAKE_INSTANCE(uint8_t, uint64_t); diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 26b8301cb1..02314a2278 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -186,11 +186,11 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { std::make_optional>( new_half_of_data_indices_view)); - auto search_queries_view = raft::make_device_matrix_view( + auto search_queries_view = raft::make_device_matrix_view( search_queries.data(), ps.num_queries, ps.dim); - auto indices_out_view = raft::make_device_matrix_view( + auto indices_out_view = raft::make_device_matrix_view( indices_ivfflat_dev.data(), ps.num_queries, ps.k); - auto dists_out_view = raft::make_device_matrix_view( + auto dists_out_view = raft::make_device_matrix_view( distances_ivfflat_dev.data(), ps.num_queries, ps.k); raft::spatial::knn::ivf_flat::detail::serialize(handle_, "ivf_flat_index", index_2); @@ -202,8 +202,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { search_queries_view, indices_out_view, dists_out_view, - search_params, - ps.k); + search_params); update_host(distances_ivfflat.data(), distances_ivfflat_dev.data(), queries_size, stream_); update_host(indices_ivfflat.data(), indices_ivfflat_dev.data(), queries_size, stream_); diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index e2a938aef8..b2d272d730 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -185,7 +185,7 @@ class ivf_pq_test : public ::testing::TestWithParam { auto index_view = raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); - return ivf_pq::build(handle_, ipams, index_view); + return ivf_pq::build(handle_, index_view, ipams); } auto build_2_extends() @@ -207,7 +207,7 @@ class ivf_pq_test : public ::testing::TestWithParam { auto database_view = raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); - auto index = ivf_pq::build(handle_, ipams, database_view); + auto index = ivf_pq::build(handle_, database_view, ipams); auto vecs_2_view = raft::make_device_matrix_view(vecs_2, size_2, ps.dim); auto inds_2_view = raft::make_device_matrix_view(inds_2, size_2, 1); @@ -244,7 +244,7 @@ class ivf_pq_test : public ::testing::TestWithParam { raft::make_device_matrix_view(distances_ivf_pq_dev.data(), ps.num_queries, ps.k); ivf_pq::search( - handle_, ps.search_params, index, query_view, ps.k, inds_view, dists_view); + handle_, index, query_view, inds_view, dists_view, ps.search_params); update_host(distances_ivf_pq.data(), distances_ivf_pq_dev.data(), queries_size, stream_); update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_); diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd index ca35f5b8ca..bf31403822 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd @@ -41,6 +41,39 @@ from pylibraft.common.handle cimport device_resources from pylibraft.distance.distance_type cimport DistanceType +cdef extern from "" namespace "std" nogil: + cdef cppclass nullopt_t: + nullopt_t() + + cdef nullopt_t nullopt + + cdef cppclass optional[T]: + ctypedef T value_type + optional() + optional(nullopt_t) + optional(optional&) except + + optional(T&) except + + bool has_value() + T& value() + T& value_or[U](U& default_value) + void swap(optional&) + void reset() + T& emplace(...) + T& operator*() + optional& operator=(optional&) + optional& operator=[U](U&) + bool operator bool() + bool operator!() + bool operator==[U](optional&, U&) + bool operator!=[U](optional&, U&) + bool operator<[U](optional&, U&) + bool operator>[U](optional&, U&) + bool operator<=[U](optional&, U&) + bool operator>=[U](optional&, U&) + + optional[T] make_optional[T](...) except + + + cdef extern from "library_types.h": ctypedef enum cudaDataType_t: CUDA_R_32F "CUDA_R_32F" # float @@ -113,66 +146,63 @@ cdef extern from "raft_runtime/neighbors/ivf_pq.hpp" \ cdef void build( const device_resources& handle, - const index_params& params, - device_matrix_view[float, uint64_t, row_major] dataset, - index[uint64_t]* index) except + + index[uint64_t]* index, + device_matrix_view[const float, uint64_t, row_major] dataset, + const index_params& params) except + cdef void build( const device_resources& handle, - const index_params& params, - device_matrix_view[int8_t, uint64_t, row_major] dataset, - index[uint64_t]* index) except + + index[uint64_t]* index, + device_matrix_view[const int8_t, uint64_t, row_major] dataset, + const index_params& params) except + cdef void build( const device_resources& handle, - const index_params& params, - device_matrix_view[uint8_t, uint64_t, row_major] dataset, - index[uint64_t]* index) except + + index[uint64_t]* index, + device_matrix_view[const uint8_t, uint64_t, row_major] dataset, + const index_params& params) except + cdef void extend( const device_resources& handle, index[uint64_t]* index, - device_matrix_view[float, uint64_t, row_major] new_vectors, - device_matrix_view[uint64_t, uint64_t, row_major] new_indices) except + + device_matrix_view[const float, uint64_t, row_major] new_vectors, + optional[device_matrix_view[const uint64_t, uint64_t, row_major]] new_indices) except + # noqa: E501 cdef void extend( const device_resources& handle, index[uint64_t]* index, - device_matrix_view[int8_t, uint64_t, row_major] new_vectors, - device_matrix_view[uint64_t, uint64_t, row_major] new_indices) except + + device_matrix_view[const int8_t, uint64_t, row_major] new_vectors, + optional[device_matrix_view[const uint64_t, uint64_t, row_major]] new_indices) except + # noqa: E501 cdef void extend( const device_resources& handle, index[uint64_t]* index, - device_matrix_view[uint8_t, uint64_t, row_major] new_vectors, - device_matrix_view[uint64_t, uint64_t, row_major] new_indices) except + + device_matrix_view[const uint8_t, uint64_t, row_major] new_vectors, + optional[device_matrix_view[const uint64_t, uint64_t, row_major]] new_indices) except + # noqa: E501 cdef void search( const device_resources& handle, - const search_params& params, const index[uint64_t]& index, device_matrix_view[float, uint64_t, row_major] queries, - uint32_t k, device_matrix_view[uint64_t, uint64_t, row_major] neighbors, - device_matrix_view[float, uint64_t, row_major] distances) except + + device_matrix_view[float, uint64_t, row_major] distances, + const search_params& params) except + cdef void search( const device_resources& handle, - const search_params& params, const index[uint64_t]& index, device_matrix_view[int8_t, uint64_t, row_major] queries, - uint32_t k, device_matrix_view[uint64_t, uint64_t, row_major] neighbors, - device_matrix_view[float, uint64_t, row_major] distances) except + + device_matrix_view[float, uint64_t, row_major] distances, + const search_params& params) except + cdef void search( const device_resources& handle, - const search_params& params, const index[uint64_t]& index, device_matrix_view[uint8_t, uint64_t, row_major] queries, - uint32_t k, device_matrix_view[uint64_t, uint64_t, row_major] neighbors, - device_matrix_view[float, uint64_t, row_major] distances) except + + device_matrix_view[float, uint64_t, row_major] distances, + const search_params& params) except + cdef void serialize(const device_resources& handle, const string& filename, diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index 47d8e94e5f..4276514a47 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -409,23 +409,23 @@ def build(IndexParams index_params, dataset, handle=None): if dataset_dt == np.float32: with cuda_interruptible(): c_ivf_pq.build(deref(handle_), - index_params.params, + idx.index, get_dmv_float(dataset_cai, check_shape=True), - idx.index) + index_params.params) idx.trained = True elif dataset_dt == np.byte: with cuda_interruptible(): c_ivf_pq.build(deref(handle_), - index_params.params, + idx.index, get_dmv_int8(dataset_cai, check_shape=True), - idx.index) + index_params.params) idx.trained = True elif dataset_dt == np.ubyte: with cuda_interruptible(): c_ivf_pq.build(deref(handle_), - index_params.params, + idx.index, get_dmv_uint8(dataset_cai, check_shape=True), - idx.index) + index_params.params) idx.trained = True else: raise TypeError("dtype %s not supported" % dataset_dt) @@ -720,30 +720,27 @@ def search(SearchParams search_params, if queries_dt == np.float32: with cuda_interruptible(): c_ivf_pq.search(deref(handle_), - params, deref(index.index), get_dmv_float(queries_cai, check_shape=True), - k, get_dmv_uint64(neighbors_cai, check_shape=True), - get_dmv_float(distances_cai, check_shape=True)) + get_dmv_float(distances_cai, check_shape=True), + params) elif queries_dt == np.byte: with cuda_interruptible(): c_ivf_pq.search(deref(handle_), - params, deref(index.index), get_dmv_int8(queries_cai, check_shape=True), - k, get_dmv_uint64(neighbors_cai, check_shape=True), - get_dmv_float(distances_cai, check_shape=True)) + get_dmv_float(distances_cai, check_shape=True), + params) elif queries_dt == np.ubyte: with cuda_interruptible(): c_ivf_pq.search(deref(handle_), - params, deref(index.index), get_dmv_uint8(queries_cai, check_shape=True), - k, get_dmv_uint64(neighbors_cai, check_shape=True), - get_dmv_float(distances_cai, check_shape=True)) + get_dmv_float(distances_cai, check_shape=True), + params) else: raise ValueError("query dtype %s not supported" % queries_dt) From 6e34d8ad297aee4359552d685d0d2ae5d10e942b Mon Sep 17 00:00:00 2001 From: viclafargue Date: Tue, 14 Mar 2023 12:04:30 +0100 Subject: [PATCH 2/4] addressing reviews --- cpp/bench/neighbors/knn.cuh | 10 +- cpp/include/raft/neighbors/ivf_flat.cuh | 410 +++++++++--------- cpp/include/raft/neighbors/ivf_pq.cuh | 124 ++---- .../raft/neighbors/specializations/ivf_pq.cuh | 55 ++- .../raft/spatial/knn/detail/ann_quantized.cuh | 9 +- cpp/include/raft_runtime/neighbors/ivf_pq.hpp | 52 +-- cpp/src/distance/neighbors/ivfpq_build.cu | 56 +-- .../neighbors/ivfpq_search_float_int64_t.cu | 6 +- .../neighbors/ivfpq_search_int8_t_int64_t.cu | 6 +- .../neighbors/ivfpq_search_uint8_t_int64_t.cu | 6 +- .../ivfpq_build_float_int64_t.cu | 10 +- .../ivfpq_build_int8_t_int64_t.cu | 10 +- .../ivfpq_build_uint8_t_int64_t.cu | 10 +- .../ivfpq_extend_float_int64_t.cu | 8 +- .../ivfpq_extend_int8_t_int64_t.cu | 8 +- .../ivfpq_extend_uint8_t_int64_t.cu | 8 +- .../ivfpq_search_float_int64_t.cu | 6 +- .../ivfpq_search_int8_t_int64_t.cu | 6 +- .../ivfpq_search_uint8_t_int64_t.cu | 6 +- cpp/test/neighbors/ann_ivf_flat.cuh | 38 +- cpp/test/neighbors/ann_ivf_pq.cuh | 27 +- .../neighbors/ivf_pq/cpp/c_ivf_pq.pxd | 36 +- .../pylibraft/neighbors/ivf_pq/ivf_pq.pyx | 36 +- 23 files changed, 444 insertions(+), 499 deletions(-) diff --git a/cpp/bench/neighbors/knn.cuh b/cpp/bench/neighbors/knn.cuh index 37d4471852..fe8c2c10d8 100644 --- a/cpp/bench/neighbors/knn.cuh +++ b/cpp/bench/neighbors/knn.cuh @@ -178,8 +178,8 @@ struct ivf_pq_knn { { index_params.n_lists = 4096; index_params.metric = raft::distance::DistanceType::L2Expanded; - index.emplace(raft::neighbors::ivf_pq::build( - handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims))); + auto data_view = raft::make_device_matrix_view(data, ps.n_samples, ps.n_dims); + index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view)); } void search(const raft::device_resources& handle, @@ -188,8 +188,12 @@ struct ivf_pq_knn { IdxT* out_idxs) { search_params.n_probes = 20; + auto queries_view = + raft::make_device_matrix_view(search_items, ps.n_queries, ps.n_dims); + auto idxs_view = raft::make_device_matrix_view(out_idxs, ps.n_queries, ps.k); + auto dists_view = raft::make_device_matrix_view(out_dists, ps.n_queries, ps.k); raft::neighbors::ivf_pq::search( - handle, search_params, *index, search_items, ps.n_queries, ps.k, out_idxs, out_dists); + handle, search_params, *index, queries_view, idxs_view, dists_view); } }; diff --git a/cpp/include/raft/neighbors/ivf_flat.cuh b/cpp/include/raft/neighbors/ivf_flat.cuh index 34080038f5..dd16813737 100644 --- a/cpp/include/raft/neighbors/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/ivf_flat.cuh @@ -28,6 +28,13 @@ namespace raft::neighbors::ivf_flat { +namespace detail = raft::spatial::knn::ivf_flat::detail; + +/** + * @defgroup ivf_flat IVF Flat Algorithm + * @{ + */ + /** * @brief Build the index from the dataset for efficient search. * @@ -42,11 +49,11 @@ namespace raft::neighbors::ivf_flat { * // use default index parameters * ivf_flat::index_params index_params; * // create and fill the index from a [N, D] dataset - * auto index = ivf_flat::build(handle, index_params, dataset, N, D); + * auto index = ivf_flat::build(handle, dataset, index_params); * // use default search parameters * ivf_flat::search_params search_params; * // search K nearest neighbours for each of the N queries - * ivf_flat::search(handle, search_params, index, queries, N, K, out_inds, out_dists); + * ivf_flat::search(handle, index, queries, out_inds, out_dists, search_params, k); * @endcode * * @tparam T data element type @@ -55,78 +62,68 @@ namespace raft::neighbors::ivf_flat { * @param[in] handle * @param[in] params configure the index building * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] - * @param[in] n_rows the number of samples - * @param[in] dim the dimensionality of the data * * @return the constructed ivf-flat index */ template auto build(raft::device_resources const& handle, const index_params& params, - const T* dataset, - IdxT n_rows, - uint32_t dim) -> index + raft::device_matrix_view dataset) -> index { - return raft::spatial::knn::ivf_flat::detail::build(handle, params, dataset, n_rows, dim); + IdxT n_rows = dataset.extent(0); + IdxT dim = dataset.extent(1); + return detail::build(handle, params, dataset.data_handle(), n_rows, dim); } /** - * @defgroup ivf_flat IVF Flat Algorithm - * @{ - */ - -/** - * @brief Build the index from the dataset for efficient search. - * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct + * @brief Extend the index in-place with the new data. * * Usage example: * @code{.cpp} * using namespace raft::neighbors; - * // use default index parameters * ivf_flat::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = ivf_flat::build(handle, dataset, index_params); - * // use default search parameters - * ivf_flat::search_params search_params; - * // search K nearest neighbours for each of the N queries - * ivf_flat::search(handle, index, queries, out_inds, out_dists, search_params, k); + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset); + * // fill the index with the data + * ivf_flat::extend(handle, index_empty, dataset); * @endcode * - * @tparam value_t data element type - * @tparam idx_t type of the indices in the source dataset - * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset * * @param[in] handle - * @param[in] params configure the index building - * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] - * - * @return the constructed ivf-flat index + * @param[inout] idx + * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices a device pointer to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. */ -template -auto build(raft::device_resources const& handle, - raft::device_matrix_view dataset, - const index_params& params) -> index +template +index extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + index& idx) { - return raft::spatial::knn::ivf_flat::detail::build(handle, - params, - dataset.data_handle(), - static_cast(dataset.extent(0)), - static_cast(dataset.extent(1))); -} + ASSERT(new_vectors.extent(1) == idx.dim(), + "new_vectors should have the same dimension as the index"); -/** @} */ + IdxT n_rows = new_vectors.extent(0); + if (new_indices.has_value()) { + ASSERT(n_rows == new_indices.value().extent(0), + "new_vectors and new_indices have different number of rows"); + } + + return detail::extend(handle, + idx, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + n_rows); +} /** - * @brief Build a new index containing the data of the original plus new extra vectors. - * - * Implementation note: - * The new data is clustered according to existing kmeans clusters, then the cluster - * centers are adjusted to match the newly labeled data. + * @brief Extend the index in-place with the new data. * * Usage example: * @code{.cpp} @@ -135,92 +132,158 @@ auto build(raft::device_resources const& handle, * index_params.add_data_on_build = false; // don't populate index on build * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); + * auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset); * // fill the index with the data - * auto index = ivf_flat::extend(handle, index_empty, dataset, nullptr, N); + * ivf_flat::extend(handle, index_empty, dataset); * @endcode * * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * * @param[in] handle - * @param[in] orig_index original index + * @param[inout] idx * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] * @param[in] new_indices a device pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` * here to imply a continuous range `[0...n_rows)`. - * @param[in] n_rows number of rows in `new_vectors` - * - * @return the constructed extended ivf-flat index */ template -auto extend(raft::device_resources const& handle, - const index& orig_index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index +void extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + index* idx) { - return raft::spatial::knn::ivf_flat::detail::extend( - handle, orig_index, new_vectors, new_indices, n_rows); + ASSERT(new_vectors.extent(1) == idx->dim(), + "new_vectors should have the same dimension as the index"); + + IdxT n_rows = new_vectors.extent(0); + if (new_indices.has_value()) { + ASSERT(n_rows == new_indices.value().extent(0), + "new_vectors and new_indices have different number of rows"); + } + + *idx = detail::extend(handle, + *idx, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + n_rows); } /** - * @ingroup ivf_flat - * @{ + * @brief Search ANN using the constructed index. + * + * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // use default search parameters + * ivf_flat::search_params search_params; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_flat::search(handle, index, queries1, out_inds1, out_dists1, search_params); + * ivf_flat::search(handle, index, queries2, out_inds2, out_dists2, search_params); + * ivf_flat::search(handle, index, queries3, out_inds3, out_dists3, search_params); + * ... + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] idx ivf-flat constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] */ +template +void search(raft::device_resources const& handle, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + RAFT_EXPECTS( + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must equal k"); + + RAFT_EXPECTS(queries.extent(1) == idx.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + IdxT n_queries = queries.extent(0); + uint32_t k = neighbors.extent(1); + return detail::search(handle, + params, + idx, + queries.data_handle(), + n_queries, + k, + neighbors.data_handle(), + distances.data_handle(), + handle.get_workspace_resource()); +} + +/** @} */ /** - * @brief Build a new index containing the data of the original plus new extra vectors. + * @brief Build the index from the dataset for efficient search. * - * Implementation note: - * The new data is clustered according to existing kmeans clusters, then the cluster - * centers are adjusted to match the newly labeled data. + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct * * Usage example: * @code{.cpp} * using namespace raft::neighbors; + * // use default index parameters * ivf_flat::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset); - * // fill the index with the data - * auto index = ivf_flat::extend(handle, index_empty, dataset); + * // create and fill the index from a [N, D] dataset + * auto index = ivf_flat::build(handle, index_params, dataset, N, D); + * // use default search parameters + * ivf_flat::search_params search_params; + * // search K nearest neighbours for each of the N queries + * ivf_flat::search(handle, search_params, index, queries, N, K, out_inds, out_dists); * @endcode * - * @tparam value_t data element type - * @tparam idx_t type of the indices in the source dataset - * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset * * @param[in] handle - * @param[in] orig_index original index - * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices a device pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` - * here to imply a continuous range `[0...n_rows)`. + * @param[in] params configure the index building + * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] + * @param[in] n_rows the number of samples + * @param[in] dim the dimensionality of the data * - * @return the constructed extended ivf-flat index + * @return the constructed ivf-flat index */ -template -auto extend(raft::device_resources const& handle, - const index& orig_index, - raft::device_matrix_view new_vectors, - std::optional> new_indices = std::nullopt) - -> index +template +auto build(raft::device_resources const& handle, + const index_params& params, + const T* dataset, + IdxT n_rows, + uint32_t dim) -> index { - return extend( - handle, - orig_index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - new_vectors.extent(0)); + return detail::build(handle, params, dataset, n_rows, dim); } -/** @} */ +/** @} */ // end group ivf_flat /** - * @brief Extend the index in-place with the new data. + * @brief Build a new index containing the data of the original plus new extra vectors. + * + * Implementation note: + * The new data is clustered according to existing kmeans clusters, then the cluster + * centers are adjusted to match the newly labeled data. * * Usage example: * @code{.cpp} @@ -231,35 +294,32 @@ auto extend(raft::device_resources const& handle, * // train the index from a [N, D] dataset * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); * // fill the index with the data - * ivf_flat::extend(handle, index_empty, dataset, nullptr, N); + * auto index = ivf_flat::extend(handle, index_empty, dataset, nullptr, N); * @endcode * * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * - * @param handle - * @param[inout] index + * @param[in] handle + * @param[in] idx original index * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] * @param[in] new_indices a device pointer to a vector of indices [n_rows]. * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. - * @param[in] n_rows the number of samples + * @param[in] n_rows number of rows in `new_vectors` + * + * @return the constructed extended ivf-flat index */ template -void extend(raft::device_resources const& handle, - index* index, +auto extend(raft::device_resources const& handle, + const index& idx, const T* new_vectors, const IdxT* new_indices, - IdxT n_rows) + IdxT n_rows) -> index { - *index = extend(handle, *index, new_vectors, new_indices, n_rows); + return detail::extend(handle, idx, new_vectors, new_indices, n_rows); } -/** - * @ingroup ivf_flat - * @{ - */ - /** * @brief Extend the index in-place with the new data. * @@ -270,38 +330,32 @@ void extend(raft::device_resources const& handle, * index_params.add_data_on_build = false; // don't populate index on build * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset); + * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); * // fill the index with the data - * ivf_flat::extend(handle, index_empty, dataset); + * ivf_flat::extend(handle, index_empty, dataset, nullptr, N); * @endcode * - * @tparam value_t data element type - * @tparam idx_t type of the indices in the source dataset - * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset * - * @param[in] handle - * @param[inout] index + * @param handle + * @param[inout] idx * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] * @param[in] new_indices a device pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` + * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. + * @param[in] n_rows the number of samples */ -template +template void extend(raft::device_resources const& handle, - index* index, - raft::device_matrix_view new_vectors, - std::optional> new_indices = std::nullopt) + index& idx, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) { - *index = extend(handle, - *index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - static_cast(new_vectors.extent(0))); + idx = detail::extend(handle, idx, new_vectors, new_indices, n_rows); } -/** @} */ - /** * @brief Search ANN using the constructed index. * @@ -332,22 +386,22 @@ void extend(raft::device_resources const& handle, * @tparam T data element type * @tparam IdxT type of the indices * - * @param[in] handle - * @param[in] params configure the search - * @param[in] index ivf-flat constructed index + * @param handle + * @param params configure the search + * @param idx ivf-pq constructed index * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param[in] n_queries the batch size - * @param[in] k the number of neighbors to find for each query. + * @param n_queries the batch size + * @param k the number of neighbors to find for each query. * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] mr an optional memory resource to use across the searches (you can provide a large - * enough memory pool here to avoid memory allocations within search). + * @param mr an optional memory resource to use across the searches (you can provide a large enough + * memory pool here to avoid memory allocations within search). */ template void search(raft::device_resources const& handle, const search_params& params, - const index& index, + const index& idx, const T* queries, uint32_t n_queries, uint32_t k, @@ -355,79 +409,7 @@ void search(raft::device_resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr) { - return raft::spatial::knn::ivf_flat::detail::search( - handle, params, index, queries, n_queries, k, neighbors, distances, mr); + return detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances, mr); } -/** - * @ingroup ivf_flat - * @{ - */ - -/** - * @brief Search ANN using the constructed index. - * - * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`: - * @code{.cpp} - * ... - * // use default search parameters - * ivf_flat::search_params search_params; - * // Use the same allocator across multiple searches to reduce the number of - * // cuda memory allocations - * ivf_flat::search(handle, index, queries1, out_inds1, out_dists1, search_params); - * ivf_flat::search(handle, index, queries2, out_inds2, out_dists2, search_params); - * ivf_flat::search(handle, index, queries3, out_inds3, out_dists3, search_params); - * ... - * @endcode - * - * @tparam value_t data element type - * @tparam idx_t type of the indices - * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type - * - * @param[in] handle - * @param[in] index ivf-flat constructed index - * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] params configure the search - */ -template -void search(raft::device_resources const& handle, - const index& index, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - const search_params& params) -{ - RAFT_EXPECTS( - queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must equal k"); - - RAFT_EXPECTS(queries.extent(1) == index.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - std::uint32_t k = neighbors.extent(1); - return search(handle, - params, - index, - queries.data_handle(), - static_cast(queries.extent(0)), - k, - neighbors.data_handle(), - distances.data_handle(), - handle.get_workspace_resource()); -} - -/** @} */ - } // namespace raft::neighbors::ivf_flat diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq.cuh index 549bf606a4..fd293672de 100644 --- a/cpp/include/raft/neighbors/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/ivf_pq.cuh @@ -46,15 +46,15 @@ namespace raft::neighbors::ivf_pq { * @tparam IdxT type of the indices in the source dataset * * @param handle - * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] * @param params configure the index building + * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] * * @return the constructed ivf-pq index */ template auto build(raft::device_resources const& handle, - raft::device_matrix_view dataset, - const index_params& params) -> index + const index_params& params, + raft::device_matrix_view dataset) -> index { IdxT n_rows = dataset.extent(0); IdxT dim = dataset.extent(1); @@ -62,61 +62,25 @@ auto build(raft::device_resources const& handle, } /** - * @brief Build the index from the dataset for efficient search. - * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param handle - * @param[inout] index - * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] - * @param params configure the index building - * - * @return the constructed ivf-pq index - */ -template -void build(raft::device_resources const& handle, - index* index, - raft::device_matrix_view dataset, - const index_params& params) -{ - IdxT n_rows = dataset.extent(0); - IdxT dim = dataset.extent(1); - *index = detail::build(handle, params, dataset.data_handle(), n_rows, dim); -} - -/** - * @brief Build a new index containing the data of the original plus new extra vectors. - * - * Implementation note: - * The new data is clustered according to existing kmeans clusters, then the cluster - * centers are unchanged. - * + * @brief Extend the index with the new data. + * * * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * * @param handle - * @param orig_index original index + * @param[inout] idx * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()] * @param[in] new_indices a device matrix view to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * If the original index is empty (`index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. - * - * @return the constructed extended ivf-pq index */ template -auto extend(raft::device_resources const& handle, - const index& orig_index, - raft::device_matrix_view new_vectors, - std::optional> new_indices = - std::nullopt) -> index +index extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const index& idx) { - ASSERT(new_vectors.extent(1) == orig_index.dim(), + ASSERT(new_vectors.extent(1) == idx.dim(), "new_vectors should have the same dimension as the index"); IdxT n_rows = new_vectors.extent(0); @@ -126,7 +90,7 @@ auto extend(raft::device_resources const& handle, } return detail::extend(handle, - orig_index, + idx, new_vectors.data_handle(), new_indices.has_value() ? new_indices.value().data_handle() : nullptr, n_rows); @@ -139,20 +103,19 @@ auto extend(raft::device_resources const& handle, * @tparam IdxT type of the indices in the source dataset * * @param handle - * @param[inout] index + * @param[inout] idx * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()] * @param[in] new_indices a device matrix view to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * If the original index is empty (`index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. */ template -void extend( - raft::device_resources const& handle, - index* index, - raft::device_matrix_view new_vectors, - std::optional> new_indices = std::nullopt) +void extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + index* idx) { - ASSERT(new_vectors.extent(1) == index->dim(), + ASSERT(new_vectors.extent(1) == idx->dim(), "new_vectors should have the same dimension as the index"); IdxT n_rows = new_vectors.extent(0); @@ -161,11 +124,11 @@ void extend( "new_vectors and new_indices have different number of rows"); } - *index = extend(handle, - *index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - n_rows); + *idx = detail::extend(handle, + *idx, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + n_rows); } /** @@ -185,21 +148,21 @@ void extend( * @tparam IdxT type of the indices * * @param handle - * @param index ivf-pq constructed index + * @param params configure the search + * @param idx ivf-pq constructed index * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] - * @param params configure the search */ template void search(raft::device_resources const& handle, - const index& index, + const search_params& params, + const index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - const search_params& params) + raft::device_matrix_view distances) { RAFT_EXPECTS( queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), @@ -208,13 +171,13 @@ void search(raft::device_resources const& handle, RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Number of columns in output neighbors and distances matrices must equal k"); - RAFT_EXPECTS(queries.extent(1) == index.dim(), + RAFT_EXPECTS(queries.extent(1) == idx.dim(), "Number of query dimensions should equal number of dimensions in the index."); std::uint32_t k = neighbors.extent(1); return detail::search(handle, params, - index, + idx, queries.data_handle(), static_cast(queries.extent(0)), k, @@ -290,10 +253,10 @@ auto build(raft::device_resources const& handle, * @tparam IdxT type of the indices in the source dataset * * @param handle - * @param orig_index original index + * @param idx original index * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()] * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * If the original index is empty (`index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. * @param n_rows the number of samples * @@ -301,12 +264,12 @@ auto build(raft::device_resources const& handle, */ template auto extend(raft::device_resources const& handle, - const index& orig_index, + const index& idx, const T* new_vectors, const IdxT* new_indices, IdxT n_rows) -> index { - return detail::extend(handle, orig_index, new_vectors, new_indices, n_rows); + return detail::extend(handle, idx, new_vectors, new_indices, n_rows); } /** @@ -316,21 +279,22 @@ auto extend(raft::device_resources const& handle, * @tparam IdxT type of the indices in the source dataset * * @param handle - * @param[inout] index + * @param[inout] idx * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()] * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * If the original index is empty (`index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx * @param n_rows the number of samples */ template void extend(raft::device_resources const& handle, - index* index, + index* idx, const T* new_vectors, const IdxT* new_indices, IdxT n_rows) { - detail::extend(handle, index, new_vectors, new_indices, n_rows); + detail::extend(handle, idx, new_vectors, new_indices, n_rows); } /** @@ -365,7 +329,7 @@ void extend(raft::device_resources const& handle, * * @param handle * @param params configure the search - * @param index ivf-pq constructed index + * @param idx ivf-pq constructed index * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] * @param n_queries the batch size * @param k the number of neighbors to find for each query. @@ -378,7 +342,7 @@ void extend(raft::device_resources const& handle, template void search(raft::device_resources const& handle, const search_params& params, - const index& index, + const index& idx, const T* queries, uint32_t n_queries, uint32_t k, @@ -386,7 +350,7 @@ void search(raft::device_resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr) { - return detail::search(handle, params, index, queries, n_queries, k, neighbors, distances, mr); + return detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances, mr); } } // namespace raft::neighbors::ivf_pq diff --git a/cpp/include/raft/neighbors/specializations/ivf_pq.cuh b/cpp/include/raft/neighbors/specializations/ivf_pq.cuh index 352c75bc89..55a7cd5858 100644 --- a/cpp/include/raft/neighbors/specializations/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/specializations/ivf_pq.cuh @@ -35,29 +35,24 @@ namespace raft::neighbors::ivf_pq { // We define overloads for build and extend with void return type. This is used in the Cython // wrappers, where exception handling is not compatible with return type that has nontrivial // constructor. -#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \ - extern template auto build(raft::device_resources const&, \ - raft::device_matrix_view, \ - const raft::neighbors::ivf_pq::index_params&) \ - ->raft::neighbors::ivf_pq::index; \ - \ - extern template void build(raft::device_resources const&, \ - index*, \ - raft::device_matrix_view, \ - const raft::neighbors::ivf_pq::index_params&); \ - \ - extern template auto extend( \ - raft::device_resources const&, \ - const index&, \ - raft::device_matrix_view, \ - std::optional>) \ - ->raft::neighbors::ivf_pq::index; \ - \ - extern template void extend( \ - raft::device_resources const&, \ - index*, \ - raft::device_matrix_view, \ - std::optional>); +#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \ + extern template auto build(raft::device_resources const&, \ + const raft::neighbors::ivf_pq::index_params&, \ + raft::device_matrix_view) \ + ->raft::neighbors::ivf_pq::index; \ + \ + extern template auto extend( \ + raft::device_resources const&, \ + raft::device_matrix_view, \ + std::optional>, \ + const raft::neighbors::ivf_pq::index&) \ + ->raft::neighbors::ivf_pq::index; \ + \ + extern template void extend( \ + raft::device_resources const&, \ + raft::device_matrix_view, \ + std::optional>, \ + raft::neighbors::ivf_pq::index*); RAFT_DECL_BUILD_EXTEND(float, int64_t) RAFT_DECL_BUILD_EXTEND(int8_t, int64_t) @@ -65,13 +60,13 @@ RAFT_DECL_BUILD_EXTEND(uint8_t, int64_t) #undef RAFT_DECL_BUILD_EXTEND -#define RAFT_DECL_SEARCH(T, IdxT) \ - extern template void search(raft::device_resources const&, \ - const index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ - const search_params&); +#define RAFT_DECL_SEARCH(T, IdxT) \ + extern template void search(raft::device_resources const&, \ + const raft::neighbors::ivf_pq::search_params&, \ + const raft::neighbors::ivf_pq::index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ + raft::device_matrix_view); RAFT_DECL_SEARCH(float, int64_t); RAFT_DECL_SEARCH(int8_t, int64_t); diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index 9c511c4acf..cc95b32cee 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -83,7 +83,7 @@ void approx_knn_build_index(raft::device_resources const& handle, auto index_view = raft::make_device_matrix_view(index_array, n, D); index->ivf_pq = std::make_unique>( - neighbors::ivf_pq::build(handle, index_view, params)); + neighbors::ivf_pq::build(handle, params, index_view)); } else { RAFT_FAIL("Unrecognized index type."); } @@ -114,8 +114,13 @@ void approx_knn_search(raft::device_resources const& handle, } else if (index->ivf_pq) { neighbors::ivf_pq::search_params params; params.n_probes = index->nprobe; + + auto query_view = + raft::make_device_matrix_view(query_array, n, index->ivf_pq->dim()); + auto indices_view = raft::make_device_matrix_view(indices, n, k); + auto distances_view = raft::make_device_matrix_view(distances, n, k); neighbors::ivf_pq::search( - handle, params, *index->ivf_pq, query_array, n, k, indices, distances); + handle, params, *index->ivf_pq, query_view, indices_view, distances_view); } else { RAFT_FAIL("The model is not trained"); } diff --git a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp index 04664716f0..fb22d7657e 100644 --- a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp +++ b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp @@ -23,39 +23,41 @@ namespace raft::runtime::neighbors::ivf_pq { // We define overloads for build and extend with void return type. This is used in the Cython // wrappers, where exception handling is not compatible with return type that has nontrivial // constructor. -#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \ - auto build(raft::device_resources const& handle, \ - raft::device_matrix_view dataset, \ - const raft::neighbors::ivf_pq::index_params& params); \ - \ - void build(raft::device_resources const& handle, \ - raft::neighbors::ivf_pq::index* index, \ - raft::device_matrix_view dataset, \ - const raft::neighbors::ivf_pq::index_params& params); \ - \ - auto extend(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index& orig_index, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices); \ - \ - void extend(raft::device_resources const& handle, \ - raft::neighbors::ivf_pq::index* index, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices); +#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \ + [[nodiscard]] raft::neighbors::ivf_pq::index build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + void build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + [[nodiscard]] raft::neighbors::ivf_pq::index extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_pq::index& idx); \ + \ + void extend(raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_pq::index* idx); -RAFT_DECL_BUILD_EXTEND(float, int64_t) -RAFT_DECL_BUILD_EXTEND(int8_t, int64_t) -RAFT_DECL_BUILD_EXTEND(uint8_t, int64_t) +RAFT_DECL_BUILD_EXTEND(float, int64_t); +RAFT_DECL_BUILD_EXTEND(int8_t, int64_t); +RAFT_DECL_BUILD_EXTEND(uint8_t, int64_t); #undef RAFT_DECL_BUILD_EXTEND #define RAFT_DECL_SEARCH(T, IdxT) \ void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index& index, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const raft::neighbors::ivf_pq::search_params& params); + raft::device_matrix_view distances); RAFT_DECL_SEARCH(float, int64_t); RAFT_DECL_SEARCH(int8_t, int64_t); diff --git a/cpp/src/distance/neighbors/ivfpq_build.cu b/cpp/src/distance/neighbors/ivfpq_build.cu index dbd877401e..8759ca2587 100644 --- a/cpp/src/distance/neighbors/ivfpq_build.cu +++ b/cpp/src/distance/neighbors/ivfpq_build.cu @@ -20,33 +20,35 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ - auto build(raft::device_resources const& handle, \ - raft::device_matrix_view dataset, \ - const raft::neighbors::ivf_pq::index_params& params) \ - { \ - return raft::neighbors::ivf_pq::build(handle, dataset, params); \ - } \ - void build(raft::device_resources const& handle, \ - raft::neighbors::ivf_pq::index* idx, \ - raft::device_matrix_view dataset, \ - const raft::neighbors::ivf_pq::index_params& params) \ - { \ - raft::neighbors::ivf_pq::build(handle, idx, dataset, params); \ - } \ - auto extend(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index& orig_index, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - return raft::neighbors::ivf_pq::extend(handle, orig_index, new_vectors, new_indices); \ - } \ - void extend(raft::device_resources const& handle, \ - raft::neighbors::ivf_pq::index* idx, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - raft::neighbors::ivf_pq::extend(handle, idx, new_vectors, new_indices); \ +#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ + raft::neighbors::ivf_pq::index build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset) \ + { \ + return raft::neighbors::ivf_pq::build(handle, params, dataset); \ + } \ + void build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_pq::index* idx) \ + { \ + *idx = raft::neighbors::ivf_pq::build(handle, params, dataset); \ + } \ + raft::neighbors::ivf_pq::index extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_pq::index& idx) \ + { \ + return raft::neighbors::ivf_pq::extend(handle, new_vectors, new_indices, idx); \ + } \ + void extend(raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_pq::index* idx) \ + { \ + raft::neighbors::ivf_pq::extend(handle, new_vectors, new_indices, idx); \ } RAFT_INST_BUILD_EXTEND(float, int64_t); diff --git a/cpp/src/distance/neighbors/ivfpq_search_float_int64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_float_int64_t.cu index 00392be8a7..91093d3a39 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_float_int64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_float_int64_t.cu @@ -23,13 +23,13 @@ namespace raft::runtime::neighbors::ivf_pq { #define RAFT_SEARCH_INST(T, IdxT) \ void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ const raft::neighbors::ivf_pq::index& idx, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const raft::neighbors::ivf_pq::search_params& params) \ + raft::device_matrix_view distances) \ { \ - raft::neighbors::ivf_pq::search(handle, idx, queries, neighbors, distances, params); \ + raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ } RAFT_SEARCH_INST(float, int64_t); diff --git a/cpp/src/distance/neighbors/ivfpq_search_int8_t_int64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_int8_t_int64_t.cu index 01a26b78b3..e1552c0b27 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_int8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_int8_t_int64_t.cu @@ -23,13 +23,13 @@ namespace raft::runtime::neighbors::ivf_pq { #define RAFT_SEARCH_INST(T, IdxT) \ void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ const raft::neighbors::ivf_pq::index& idx, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const raft::neighbors::ivf_pq::search_params& params) \ + raft::device_matrix_view distances) \ { \ - raft::neighbors::ivf_pq::search(handle, idx, queries, neighbors, distances, params); \ + raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ } RAFT_SEARCH_INST(int8_t, int64_t); diff --git a/cpp/src/distance/neighbors/ivfpq_search_uint8_t_int64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_uint8_t_int64_t.cu index 5b99b0df9f..85195a7551 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_uint8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_uint8_t_int64_t.cu @@ -23,13 +23,13 @@ namespace raft::runtime::neighbors::ivf_pq { #define RAFT_SEARCH_INST(T, IdxT) \ void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ const raft::neighbors::ivf_pq::index& idx, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const raft::neighbors::ivf_pq::search_params& params) \ + raft::device_matrix_view distances) \ { \ - raft::neighbors::ivf_pq::search(handle, idx, queries, neighbors, distances, params); \ + raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ } RAFT_SEARCH_INST(uint8_t, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_build_float_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_build_float_int64_t.cu index 6818fa665d..d559291b93 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_build_float_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_build_float_int64_t.cu @@ -21,13 +21,9 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template auto build(raft::device_resources const& handle, \ - raft::device_matrix_view dataset, \ - const index_params& params) \ - ->index; \ - template void build(raft::device_resources const& handle, \ - index* index, \ - raft::device_matrix_view dataset, \ - const index_params& params); + const index_params& params, \ + raft::device_matrix_view dataset) \ + ->index; RAFT_MAKE_INSTANCE(float, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu index feee5eaba2..c8b31e1fff 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu @@ -21,13 +21,9 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template auto build(raft::device_resources const& handle, \ - raft::device_matrix_view dataset, \ - const index_params& params) \ - ->index; \ - template void build(raft::device_resources const& handle, \ - index* index, \ - raft::device_matrix_view dataset, \ - const index_params& params); + const index_params& params, \ + raft::device_matrix_view dataset) \ + ->index; RAFT_MAKE_INSTANCE(int8_t, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu index 963cc23f57..5fc62969f0 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu @@ -21,13 +21,9 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template auto build(raft::device_resources const& handle, \ - raft::device_matrix_view dataset, \ - const index_params& params) \ - ->index; \ - template void build(raft::device_resources const& handle, \ - index* index, \ - raft::device_matrix_view dataset, \ - const index_params& params); + const index_params& params, \ + raft::device_matrix_view dataset) \ + ->index; RAFT_MAKE_INSTANCE(uint8_t, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_int64_t.cu index 70ef1a3acf..4cc616f32d 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_int64_t.cu @@ -22,15 +22,15 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template auto extend( \ raft::device_resources const& handle, \ - const index& orig_index, \ raft::device_matrix_view new_vectors, \ - std::optional> new_indices) \ + std::optional> new_indices, \ + const index& idx) \ ->index; \ template void extend( \ raft::device_resources const& handle, \ - index* index, \ raft::device_matrix_view new_vectors, \ - std::optional> new_indices); + std::optional> new_indices, \ + index* idx); RAFT_MAKE_INSTANCE(float, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu index a284bec9f3..a3117aae0f 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu @@ -22,15 +22,15 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template auto extend( \ raft::device_resources const& handle, \ - const index& orig_index, \ raft::device_matrix_view new_vectors, \ - std::optional> new_indices) \ + std::optional> new_indices, \ + const index& idx) \ ->index; \ template void extend( \ raft::device_resources const& handle, \ - index* index, \ raft::device_matrix_view new_vectors, \ - std::optional> new_indices); + std::optional> new_indices, \ + index* idx); RAFT_MAKE_INSTANCE(int8_t, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu index 2ef568885f..a5e3d68569 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu @@ -22,15 +22,15 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template auto extend( \ raft::device_resources const& handle, \ - const index& orig_index, \ raft::device_matrix_view new_vectors, \ - std::optional> new_indices) \ + std::optional> new_indices, \ + const index& idx) \ ->index; \ template void extend( \ raft::device_resources const& handle, \ - index* index, \ raft::device_matrix_view new_vectors, \ - std::optional> new_indices); + std::optional> new_indices, \ + index* idx); RAFT_MAKE_INSTANCE(uint8_t, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_search_float_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_search_float_int64_t.cu index 43f2d3898e..92a4d89e6b 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_search_float_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_search_float_int64_t.cu @@ -21,11 +21,11 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template void search(raft::device_resources const& handle, \ - const index& index, \ + const search_params& params, \ + const index& idx, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const search_params& params); + raft::device_matrix_view distances); RAFT_MAKE_INSTANCE(float, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu index fd8c727853..62a8b48ad5 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu @@ -21,11 +21,11 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template void search(raft::device_resources const& handle, \ - const index& index, \ + const search_params& params, \ + const index& idx, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const search_params& params); + raft::device_matrix_view distances); RAFT_MAKE_INSTANCE(int8_t, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu index 0717a7462d..3bcf134a22 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu @@ -21,11 +21,11 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template void search(raft::device_resources const& handle, \ - const index& index, \ + const search_params& params, \ + const index& idx, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const search_params& params); + raft::device_matrix_view distances); RAFT_MAKE_INSTANCE(uint8_t, int64_t); diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index cdd6570562..b78bd872f7 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -156,7 +156,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.num_db_vecs, ps.dim); - auto index = ivf_flat::build(handle_, database_view, index_params); + index idx = ivf_flat::build(handle_, index_params, database_view); rmm::device_uvector vector_indices(ps.num_db_vecs, stream_); thrust::sequence(handle_.get_thrust_policy(), @@ -169,7 +169,8 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { auto half_of_data_view = raft::make_device_matrix_view( (const DataT*)database.data(), half_of_data, ps.dim); - auto index_2 = ivf_flat::extend(handle_, index, half_of_data_view); + const std::optional> no_opt = std::nullopt; + index idx_2 = ivf_flat::extend(handle_, half_of_data_view, no_opt, idx); auto new_half_of_data_view = raft::make_device_matrix_view( database.data() + half_of_data * ps.dim, IdxT(ps.num_db_vecs) - half_of_data, ps.dim); @@ -178,10 +179,10 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { vector_indices.data() + half_of_data, IdxT(ps.num_db_vecs) - half_of_data); ivf_flat::extend(handle_, - &index_2, new_half_of_data_view, std::make_optional>( - new_half_of_data_indices_view)); + new_half_of_data_indices_view), + &idx_2); auto search_queries_view = raft::make_device_matrix_view( search_queries.data(), ps.num_queries, ps.dim); @@ -189,47 +190,46 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { indices_ivfflat_dev.data(), ps.num_queries, ps.k); auto dists_out_view = raft::make_device_matrix_view( distances_ivfflat_dev.data(), ps.num_queries, ps.k); - raft::spatial::knn::ivf_flat::detail::serialize(handle_, "ivf_flat_index", index_2); + raft::spatial::knn::ivf_flat::detail::serialize(handle_, "ivf_flat_index", idx_2); auto index_loaded = raft::spatial::knn::ivf_flat::detail::deserialize(handle_, "ivf_flat_index"); ivf_flat::search(handle_, + search_params, index_loaded, search_queries_view, indices_out_view, - dists_out_view, - search_params); + dists_out_view); update_host(distances_ivfflat.data(), distances_ivfflat_dev.data(), queries_size, stream_); update_host(indices_ivfflat.data(), indices_ivfflat_dev.data(), queries_size, stream_); handle_.sync_stream(stream_); // Test the centroid invariants - if (index_2.adaptive_centers()) { + if (idx_2.adaptive_centers()) { // The centers must be up-to-date with the corresponding data - std::vector list_sizes(index_2.n_lists()); - std::vector list_offsets(index_2.n_lists()); + std::vector list_sizes(idx_2.n_lists()); + std::vector list_offsets(idx_2.n_lists()); rmm::device_uvector centroid(ps.dim, stream_); + raft::copy(list_sizes.data(), idx_2.list_sizes().data_handle(), idx_2.n_lists(), stream_); raft::copy( - list_sizes.data(), index_2.list_sizes().data_handle(), index_2.n_lists(), stream_); - raft::copy( - list_offsets.data(), index_2.list_offsets().data_handle(), index_2.n_lists(), stream_); + list_offsets.data(), idx_2.list_offsets().data_handle(), idx_2.n_lists(), stream_); handle_.sync_stream(stream_); - for (uint32_t l = 0; l < index_2.n_lists(); l++) { + for (uint32_t l = 0; l < idx_2.n_lists(); l++) { rmm::device_uvector cluster_data(list_sizes[l] * ps.dim, stream_); raft::spatial::knn::detail::utils::copy_selected( (IdxT)list_sizes[l], (IdxT)ps.dim, database.data(), - index_2.indices().data_handle() + list_offsets[l], + idx_2.indices().data_handle() + list_offsets[l], (IdxT)ps.dim, cluster_data.data(), (IdxT)ps.dim, stream_); raft::stats::mean( centroid.data(), cluster_data.data(), ps.dim, list_sizes[l], false, true, stream_); - ASSERT_TRUE(raft::devArrMatch(index_2.centers().data_handle() + ps.dim * l, + ASSERT_TRUE(raft::devArrMatch(idx_2.centers().data_handle() + ps.dim * l, centroid.data(), ps.dim, raft::CompareApprox(0.001), @@ -237,9 +237,9 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { } } else { // The centers must be immutable - ASSERT_TRUE(raft::devArrMatch(index_2.centers().data_handle(), - index.centers().data_handle(), - index_2.centers().size(), + ASSERT_TRUE(raft::devArrMatch(idx_2.centers().data_handle(), + idx.centers().data_handle(), + idx_2.centers().size(), raft::Compare(), stream_)); } diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index f07a241b95..c368192b03 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -178,17 +178,17 @@ class ivf_pq_test : public ::testing::TestWithParam { handle_.sync_stream(stream_); } - auto build_only() + index build_only() { auto ipams = ps.index_params; ipams.add_data_on_build = true; auto index_view = raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); - return ivf_pq::build(handle_, index_view, ipams); + return ivf_pq::build(handle_, ipams, index_view); } - auto build_2_extends() + index build_2_extends() { rmm::device_uvector db_indices(ps.num_db_vecs, stream_); thrust::sequence(handle_.get_thrust_policy(), @@ -207,18 +207,21 @@ class ivf_pq_test : public ::testing::TestWithParam { auto database_view = raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); - auto index = ivf_pq::build(handle_, database_view, ipams); + auto idx = ivf_pq::build(handle_, ipams, database_view); auto vecs_2_view = raft::make_device_matrix_view(vecs_2, size_2, ps.dim); auto inds_2_view = raft::make_device_matrix_view(inds_2, size_2, 1); - ivf_pq::extend(handle_, &index, vecs_2_view, inds_2_view); - - auto vecs_1_view = raft::make_device_matrix_view(vecs_1, size_1, ps.dim); - auto inds_1_view = raft::make_device_matrix_view(inds_1, size_1, 1); - return ivf_pq::extend(handle_, index, vecs_1_view, inds_1_view); + ivf_pq::extend(handle_, vecs_2_view, inds_2_view, &idx); + + auto vecs_1_view = + raft::make_device_matrix_view(vecs_1, size_1, ps.dim); + auto inds_1_view = + raft::make_device_matrix_view(inds_1, size_1, 1); + ivf_pq::extend(handle_, vecs_1_view, inds_1_view, &idx); + return idx; } - auto build_serialize() + index build_serialize() { ivf_pq::serialize(handle_, "ivf_pq_index", build_only()); return ivf_pq::deserialize(handle_, "ivf_pq_index"); @@ -227,7 +230,7 @@ class ivf_pq_test : public ::testing::TestWithParam { template void run(BuildIndex build_index) { - auto index = build_index(); + index index = build_index(); size_t queries_size = ps.num_queries * ps.k; std::vector indices_ivf_pq(queries_size); @@ -244,7 +247,7 @@ class ivf_pq_test : public ::testing::TestWithParam { raft::make_device_matrix_view(distances_ivf_pq_dev.data(), ps.num_queries, ps.k); ivf_pq::search( - handle_, index, query_view, inds_view, dists_view, ps.search_params); + handle_, ps.search_params, index, query_view, inds_view, dists_view); update_host(distances_ivf_pq.data(), distances_ivf_pq_dev.data(), queries_size, stream_); update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_); diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd index 7d951ae56a..d04d833f3b 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd @@ -107,63 +107,63 @@ cdef extern from "raft_runtime/neighbors/ivf_pq.hpp" \ cdef void build( const device_resources& handle, - index[int64_t]* index, + const index_params& params, device_matrix_view[float, int64_t, row_major] dataset, - const index_params& params) except + + index[int64_t]* index) except + cdef void build( const device_resources& handle, - index[int64_t]* index, + const index_params& params, device_matrix_view[int8_t, int64_t, row_major] dataset, - const index_params& params) except + + index[int64_t]* index) except + cdef void build( const device_resources& handle, - index[int64_t]* index, + const index_params& params, device_matrix_view[uint8_t, int64_t, row_major] dataset, - const index_params& params) except + + index[int64_t]* index) except + cdef void extend( const device_resources& handle, - index[int64_t]* index, device_matrix_view[float, int64_t, row_major] new_vectors, - optional[device_matrix_view[int64_t, int64_t, row_major]] new_indices) except + # noqa: E501 + optional[device_matrix_view[int64_t, int64_t, row_major]] new_indices, + index[int64_t]* index) except + cdef void extend( const device_resources& handle, - index[int64_t]* index, device_matrix_view[int8_t, int64_t, row_major] new_vectors, - optional[device_matrix_view[int64_t, int64_t, row_major]] new_indices) except + # noqa: E501 + optional[device_matrix_view[int64_t, int64_t, row_major]] new_indices, + index[int64_t]* index) except + cdef void extend( const device_resources& handle, - index[int64_t]* index, device_matrix_view[uint8_t, int64_t, row_major] new_vectors, - optional[device_matrix_view[int64_t, int64_t, row_major]] new_indices) except + # noqa: E501 + optional[device_matrix_view[int64_t, int64_t, row_major]] new_indices, + index[int64_t]* index) except + cdef void search( const device_resources& handle, + const search_params& params, const index[int64_t]& index, device_matrix_view[float, int64_t, row_major] queries, device_matrix_view[int64_t, int64_t, row_major] neighbors, - device_matrix_view[float, int64_t, row_major] distances, - const search_params& params) except + + device_matrix_view[float, int64_t, row_major] distances) except + cdef void search( const device_resources& handle, + const search_params& params, const index[int64_t]& index, device_matrix_view[int8_t, int64_t, row_major] queries, device_matrix_view[int64_t, int64_t, row_major] neighbors, - device_matrix_view[float, int64_t, row_major] distances, - const search_params& params) except + + device_matrix_view[float, int64_t, row_major] distances) except + cdef void search( const device_resources& handle, + const search_params& params, const index[int64_t]& index, device_matrix_view[uint8_t, int64_t, row_major] queries, device_matrix_view[int64_t, int64_t, row_major] neighbors, - device_matrix_view[float, int64_t, row_major] distances, - const search_params& params) except + + device_matrix_view[float, int64_t, row_major] distances) except + cdef void serialize(const device_resources& handle, const string& filename, diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index 75de0aba82..4f4d2c75a4 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -410,23 +410,23 @@ def build(IndexParams index_params, dataset, handle=None): if dataset_dt == np.float32: with cuda_interruptible(): c_ivf_pq.build(deref(handle_), - idx.index, + index_params.params, get_dmv_float(dataset_cai, check_shape=True), - index_params.params) + idx.index) idx.trained = True elif dataset_dt == np.byte: with cuda_interruptible(): c_ivf_pq.build(deref(handle_), - idx.index, + index_params.params, get_dmv_int8(dataset_cai, check_shape=True), - index_params.params) + idx.index) idx.trained = True elif dataset_dt == np.ubyte: with cuda_interruptible(): c_ivf_pq.build(deref(handle_), - idx.index, + index_params.params, get_dmv_uint8(dataset_cai, check_shape=True), - index_params.params) + idx.index) idx.trained = True else: raise TypeError("dtype %s not supported" % dataset_dt) @@ -520,21 +520,21 @@ def extend(Index index, new_vectors, new_indices, handle=None): if vecs_dt == np.float32: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), - index.index, get_dmv_float(vecs_cai, check_shape=True), - create_optional(get_dmv_int64(idx_cai, check_shape=False))) # noqa: E501 + create_optional(get_dmv_int64(idx_cai, check_shape=False)), # noqa: E501 + index.index) elif vecs_dt == np.int8: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), - index.index, get_dmv_int8(vecs_cai, check_shape=True), - create_optional(get_dmv_int64(idx_cai, check_shape=False))) # noqa: E501 + create_optional(get_dmv_int64(idx_cai, check_shape=False)), # noqa: E501 + index.index) elif vecs_dt == np.uint8: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), - index.index, get_dmv_uint8(vecs_cai, check_shape=True), - create_optional(get_dmv_int64(idx_cai, check_shape=False))) # noqa: E501 + create_optional(get_dmv_int64(idx_cai, check_shape=False)), # noqa: E501 + index.index) else: raise TypeError("query dtype %s not supported" % vecs_dt) @@ -721,27 +721,27 @@ def search(SearchParams search_params, if queries_dt == np.float32: with cuda_interruptible(): c_ivf_pq.search(deref(handle_), + params, deref(index.index), get_dmv_float(queries_cai, check_shape=True), get_dmv_int64(neighbors_cai, check_shape=True), - get_dmv_float(distances_cai, check_shape=True), - params) + get_dmv_float(distances_cai, check_shape=True)) elif queries_dt == np.byte: with cuda_interruptible(): c_ivf_pq.search(deref(handle_), + params, deref(index.index), get_dmv_int8(queries_cai, check_shape=True), get_dmv_int64(neighbors_cai, check_shape=True), - get_dmv_float(distances_cai, check_shape=True), - params) + get_dmv_float(distances_cai, check_shape=True)) elif queries_dt == np.ubyte: with cuda_interruptible(): c_ivf_pq.search(deref(handle_), + params, deref(index.index), get_dmv_uint8(queries_cai, check_shape=True), get_dmv_int64(neighbors_cai, check_shape=True), - get_dmv_float(distances_cai, check_shape=True), - params) + get_dmv_float(distances_cai, check_shape=True)) else: raise ValueError("query dtype %s not supported" % queries_dt) From a1d8cee9b836004d0bec0f1238ba2f8a86948039 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Tue, 14 Mar 2023 16:43:01 +0100 Subject: [PATCH 3/4] Fix doc --- cpp/include/raft/neighbors/ivf_flat.cuh | 10 ++++------ cpp/include/raft/neighbors/ivf_pq.cuh | 7 +++---- cpp/include/raft/neighbors/ivf_pq_serialize.cuh | 2 -- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/neighbors/ivf_flat.cuh b/cpp/include/raft/neighbors/ivf_flat.cuh index dd16813737..66c8e4840c 100644 --- a/cpp/include/raft/neighbors/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/ivf_flat.cuh @@ -66,9 +66,9 @@ namespace detail = raft::spatial::knn::ivf_flat::detail; * @return the constructed ivf-flat index */ template -auto build(raft::device_resources const& handle, - const index_params& params, - raft::device_matrix_view dataset) -> index +index build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset) { IdxT n_rows = dataset.extent(0); IdxT dim = dataset.extent(1); @@ -232,7 +232,7 @@ void search(raft::device_resources const& handle, handle.get_workspace_resource()); } -/** @} */ +/** @} */ // end group ivf_flat /** * @brief Build the index from the dataset for efficient search. @@ -276,8 +276,6 @@ auto build(raft::device_resources const& handle, return detail::build(handle, params, dataset, n_rows, dim); } -/** @} */ // end group ivf_flat - /** * @brief Build a new index containing the data of the original plus new extra vectors. * diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq.cuh index fd293672de..f1506497b3 100644 --- a/cpp/include/raft/neighbors/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/ivf_pq.cuh @@ -52,9 +52,9 @@ namespace raft::neighbors::ivf_pq { * @return the constructed ivf-pq index */ template -auto build(raft::device_resources const& handle, - const index_params& params, - raft::device_matrix_view dataset) -> index +index build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset) { IdxT n_rows = dataset.extent(0); IdxT dim = dataset.extent(1); @@ -284,7 +284,6 @@ auto extend(raft::device_resources const& handle, * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. * If the original index is empty (`index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. - * @param[inout] idx * @param n_rows the number of samples */ template diff --git a/cpp/include/raft/neighbors/ivf_pq_serialize.cuh b/cpp/include/raft/neighbors/ivf_pq_serialize.cuh index 98b59fd5e1..2dd9d39d73 100644 --- a/cpp/include/raft/neighbors/ivf_pq_serialize.cuh +++ b/cpp/include/raft/neighbors/ivf_pq_serialize.cuh @@ -47,7 +47,6 @@ namespace raft::neighbors::ivf_pq { * @param[in] os output stream * @param[in] index IVF-PQ index * - * @return raft::neighbors::ivf_pq::index */ template void serialize(raft::device_resources const& handle, std::ostream& os, const index& index) @@ -77,7 +76,6 @@ void serialize(raft::device_resources const& handle, std::ostream& os, const ind * @param[in] filename the file name for saving the index * @param[in] index IVF-PQ index * - * @return raft::neighbors::ivf_pq::index */ template void serialize(raft::device_resources const& handle, From b9fe92c6d675142bb2b509a080312b0c9bf244e9 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 15 Mar 2023 12:48:21 +0100 Subject: [PATCH 4/4] update doc --- cpp/include/raft/neighbors/ivf_flat.cuh | 2 +- cpp/include/raft/neighbors/ivf_pq.cuh | 66 ++++++++++++------------- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/cpp/include/raft/neighbors/ivf_flat.cuh b/cpp/include/raft/neighbors/ivf_flat.cuh index 821d1934b0..f42bfe66c7 100644 --- a/cpp/include/raft/neighbors/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/ivf_flat.cuh @@ -427,4 +427,4 @@ void search(raft::device_resources const& handle, /** @} */ -} // namespace raft::neighbors::ivf_flat \ No newline at end of file +} // namespace raft::neighbors::ivf_flat diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq.cuh index f1506497b3..4a12ca72a4 100644 --- a/cpp/include/raft/neighbors/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/ivf_pq.cuh @@ -45,8 +45,8 @@ namespace raft::neighbors::ivf_pq { * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * - * @param handle - * @param params configure the index building + * @param[in] handle + * @param[in] params configure the index building * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] * * @return the constructed ivf-pq index @@ -67,12 +67,12 @@ index build(raft::device_resources const& handle, * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * - * @param handle - * @param[inout] idx - * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()] + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] * @param[in] new_indices a device matrix view to a vector of indices [n_rows]. - * If the original index is empty (`index.size() == 0`), you can pass `nullptr` + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx */ template index extend(raft::device_resources const& handle, @@ -102,12 +102,12 @@ index extend(raft::device_resources const& handle, * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * - * @param handle - * @param[inout] idx - * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()] + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] * @param[in] new_indices a device matrix view to a vector of indices [n_rows]. - * If the original index is empty (`index.size() == 0`), you can pass `nullptr` + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx */ template void extend(raft::device_resources const& handle, @@ -147,9 +147,9 @@ void extend(raft::device_resources const& handle, * @tparam T data element type * @tparam IdxT type of the indices * - * @param handle - * @param params configure the search - * @param idx ivf-pq constructed index + * @param[in] handle + * @param[in] params configure the search + * @param[in] idx ivf-pq constructed index * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset * [n_queries, k] @@ -212,11 +212,11 @@ void search(raft::device_resources const& handle, * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * - * @param handle - * @param params configure the index building + * @param[in] handle + * @param[in] params configure the index building * @param[in] dataset a device/host pointer to a row-major matrix [n_rows, dim] - * @param n_rows the number of samples - * @param dim the dimensionality of the data + * @param[in] n_rows the number of samples + * @param[in] dim the dimensionality of the data * * @return the constructed ivf-pq index */ @@ -252,13 +252,13 @@ auto build(raft::device_resources const& handle, * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * - * @param handle - * @param idx original index - * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()] + * @param[in] handle + * @param[inout] idx original index + * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx.dim()] * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. - * If the original index is empty (`index.size() == 0`), you can pass `nullptr` + * If the original index is empty (`idx.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. - * @param n_rows the number of samples + * @param[in] n_rows the number of samples * * @return the constructed extended ivf-pq index */ @@ -278,13 +278,13 @@ auto extend(raft::device_resources const& handle, * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * - * @param handle + * @param[in] handle * @param[inout] idx - * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()] + * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx.dim()] * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. - * If the original index is empty (`index.size() == 0`), you can pass `nullptr` + * If the original index is empty (`idx.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. - * @param n_rows the number of samples + * @param[in] n_rows the number of samples */ template void extend(raft::device_resources const& handle, @@ -326,17 +326,17 @@ void extend(raft::device_resources const& handle, * @tparam T data element type * @tparam IdxT type of the indices * - * @param handle - * @param params configure the search - * @param idx ivf-pq constructed index + * @param[in] handle + * @param[in] params configure the search + * @param[in] idx ivf-pq constructed index * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param n_queries the batch size - * @param k the number of neighbors to find for each query. + * @param[in] n_queries the batch size + * @param[in] k the number of neighbors to find for each query. * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param mr an optional memory resource to use across the searches (you can provide a large enough - * memory pool here to avoid memory allocations within search). + * @param[in] mr an optional memory resource to use across the searches (you can provide a large + * enough memory pool here to avoid memory allocations within search). */ template void search(raft::device_resources const& handle,