Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update and standardize IVF indexes API #1328

Merged
merged 9 commits into from
Mar 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions cpp/bench/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ struct ivf_pq_knn {
{
index_params.n_lists = 4096;
index_params.metric = raft::distance::DistanceType::L2Expanded;

auto data_view = raft::make_device_matrix_view<const ValT, IdxT>(data, ps.n_samples, ps.n_dims);
index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view));
}
Expand All @@ -189,13 +188,12 @@ struct ivf_pq_knn {
IdxT* out_idxs)
{
search_params.n_probes = 20;

auto queries_view =
raft::make_device_matrix_view<const ValT, IdxT>(search_items, ps.n_queries, ps.n_dims);
auto idxs_view = raft::make_device_matrix_view<IdxT, IdxT>(out_idxs, ps.n_queries, ps.k);
auto dists_view = raft::make_device_matrix_view<dist_t, IdxT>(out_dists, ps.n_queries, ps.k);
raft::neighbors::ivf_pq::search(
handle, search_params, *index, queries_view, ps.k, idxs_view, dists_view);
handle, search_params, *index, queries_view, idxs_view, dists_view);
}
};

Expand Down
175 changes: 97 additions & 78 deletions cpp/include/raft/neighbors/ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,54 +45,55 @@ namespace raft::neighbors::ivf_pq {
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param params configure the index building
* @param[in] handle
* @param[in] params configure the index building
* @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim]
*
* @return the constructed ivf-pq index
*/
template <typename T, typename IdxT = uint32_t>
inline auto build(raft::device_resources const& handle,
index<IdxT> build(raft::device_resources const& handle,
const index_params& params,
raft::device_matrix_view<const T, IdxT, row_major> dataset) -> index<IdxT>
raft::device_matrix_view<const T, IdxT, row_major> dataset)
{
IdxT n_rows = dataset.extent(0);
IdxT dim = dataset.extent(1);
return detail::build(handle, params, dataset.data_handle(), n_rows, dim);
}

/**
* @brief Build a new index containing the data of the original plus new extra vectors.
*
* Implementation note:
* The new data is clustered according to existing kmeans clusters, then the cluster
* centers are unchanged.
*
* @brief Extend the index with the new data.
* *
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param orig_index original index
* @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()]
* @param[in] handle
* @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()]
* @param[in] new_indices a device matrix view to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt`
* here to imply a continuous range `[0...n_rows)`.
*
* @return the constructed extended ivf-pq index
* @param[inout] idx
*/
template <typename T, typename IdxT>
inline auto extend(raft::device_resources const& handle,
const index<IdxT>& orig_index,
index<IdxT> extend(raft::device_resources const& handle,
raft::device_matrix_view<const T, IdxT, row_major> new_vectors,
raft::device_matrix_view<const IdxT, IdxT, row_major> new_indices) -> index<IdxT>
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>> new_indices,
const index<IdxT>& idx)
{
IdxT n_rows = new_vectors.extent(0);
ASSERT(n_rows == new_indices.extent(0),
"new_vectors and new_indices have different number of rows");
ASSERT(new_vectors.extent(1) == orig_index.dim(),
ASSERT(new_vectors.extent(1) == idx.dim(),
"new_vectors should have the same dimension as the index");
return detail::extend(
handle, orig_index, new_vectors.data_handle(), new_indices.data_handle(), n_rows);

IdxT n_rows = new_vectors.extent(0);
if (new_indices.has_value()) {
ASSERT(n_rows == new_indices.value().extent(0),
"new_vectors and new_indices have different number of rows");
}

return detail::extend(handle,
idx,
new_vectors.data_handle(),
new_indices.has_value() ? new_indices.value().data_handle() : nullptr,
n_rows);
}

/**
Expand All @@ -101,20 +102,33 @@ inline auto extend(raft::device_resources const& handle,
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param[inout] index
* @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()]
* @param[in] handle
* @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()]
* @param[in] new_indices a device matrix view to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt`
* here to imply a continuous range `[0...n_rows)`.
* @param[inout] idx
*/
template <typename T, typename IdxT>
inline void extend(raft::device_resources const& handle,
index<IdxT>* index,
raft::device_matrix_view<const T, IdxT, row_major> new_vectors,
raft::device_matrix_view<const IdxT, IdxT, row_major> new_indices)
void extend(raft::device_resources const& handle,
raft::device_matrix_view<const T, IdxT, row_major> new_vectors,
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>> new_indices,
index<IdxT>* idx)
{
*index = extend(handle, *index, new_vectors, new_indices);
ASSERT(new_vectors.extent(1) == idx->dim(),
"new_vectors should have the same dimension as the index");

IdxT n_rows = new_vectors.extent(0);
if (new_indices.has_value()) {
ASSERT(n_rows == new_indices.value().extent(0),
"new_vectors and new_indices have different number of rows");
}

*idx = detail::extend(handle,
*idx,
new_vectors.data_handle(),
new_indices.has_value() ? new_indices.value().data_handle() : nullptr,
n_rows);
}

/**
Expand All @@ -133,34 +147,39 @@ inline void extend(raft::device_resources const& handle,
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param handle
* @param params configure the search
* @param index ivf-pq constructed index
* @param[in] handle
* @param[in] params configure the search
* @param[in] idx ivf-pq constructed index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param k the number of neighbors to find for each query.
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
*/
template <typename T, typename IdxT>
inline void search(raft::device_resources const& handle,
const search_params& params,
const index<IdxT>& index,
raft::device_matrix_view<const T, IdxT, row_major> queries,
uint32_t k,
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::device_matrix_view<float, IdxT, row_major> distances)
void search(raft::device_resources const& handle,
const search_params& params,
const index<IdxT>& idx,
raft::device_matrix_view<const T, IdxT, row_major> queries,
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::device_matrix_view<float, IdxT, row_major> distances)
{
IdxT n_queries = queries.extent(0);
bool check_n_rows = (n_queries == neighbors.extent(0)) && (n_queries == distances.extent(0));
ASSERT(check_n_rows,
"queries, neighbors and distances parameters have inconsistent number of rows");
RAFT_EXPECTS(
queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0),
"Number of rows in output neighbors and distances matrices must equal the number of queries.");

RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1),
"Number of columns in output neighbors and distances matrices must equal k");

RAFT_EXPECTS(queries.extent(1) == idx.dim(),
"Number of query dimensions should equal number of dimensions in the index.");

std::uint32_t k = neighbors.extent(1);
return detail::search(handle,
params,
index,
idx,
queries.data_handle(),
n_queries,
static_cast<std::uint32_t>(queries.extent(0)),
k,
neighbors.data_handle(),
distances.data_handle(),
Expand Down Expand Up @@ -193,11 +212,11 @@ inline void search(raft::device_resources const& handle,
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param params configure the index building
* @param[in] handle
* @param[in] params configure the index building
* @param[in] dataset a device/host pointer to a row-major matrix [n_rows, dim]
* @param n_rows the number of samples
* @param dim the dimensionality of the data
* @param[in] n_rows the number of samples
* @param[in] dim the dimensionality of the data
*
* @return the constructed ivf-pq index
*/
Expand Down Expand Up @@ -233,24 +252,24 @@ auto build(raft::device_resources const& handle,
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param orig_index original index
* @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()]
* @param[in] handle
* @param[inout] idx original index
* @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx.dim()]
* @param[in] new_indices a device/host pointer to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* If the original index is empty (`idx.size() == 0`), you can pass `nullptr`
* here to imply a continuous range `[0...n_rows)`.
* @param n_rows the number of samples
* @param[in] n_rows the number of samples
*
* @return the constructed extended ivf-pq index
*/
template <typename T, typename IdxT>
auto extend(raft::device_resources const& handle,
const index<IdxT>& orig_index,
const index<IdxT>& idx,
const T* new_vectors,
const IdxT* new_indices,
IdxT n_rows) -> index<IdxT>
{
return detail::extend(handle, orig_index, new_vectors, new_indices, n_rows);
return detail::extend(handle, idx, new_vectors, new_indices, n_rows);
}

/**
Expand All @@ -259,22 +278,22 @@ auto extend(raft::device_resources const& handle,
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param[inout] index
* @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()]
* @param[in] handle
* @param[inout] idx
* @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx.dim()]
* @param[in] new_indices a device/host pointer to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* If the original index is empty (`idx.size() == 0`), you can pass `nullptr`
* here to imply a continuous range `[0...n_rows)`.
* @param n_rows the number of samples
* @param[in] n_rows the number of samples
*/
template <typename T, typename IdxT>
void extend(raft::device_resources const& handle,
index<IdxT>* index,
index<IdxT>* idx,
const T* new_vectors,
const IdxT* new_indices,
IdxT n_rows)
{
detail::extend(handle, index, new_vectors, new_indices, n_rows);
detail::extend(handle, idx, new_vectors, new_indices, n_rows);
}

/**
Expand Down Expand Up @@ -307,30 +326,30 @@ void extend(raft::device_resources const& handle,
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param handle
* @param params configure the search
* @param index ivf-pq constructed index
* @param[in] handle
* @param[in] params configure the search
* @param[in] idx ivf-pq constructed index
* @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()]
* @param n_queries the batch size
* @param k the number of neighbors to find for each query.
* @param[in] n_queries the batch size
* @param[in] k the number of neighbors to find for each query.
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
* @param mr an optional memory resource to use across the searches (you can provide a large enough
* memory pool here to avoid memory allocations within search).
* @param[in] mr an optional memory resource to use across the searches (you can provide a large
* enough memory pool here to avoid memory allocations within search).
*/
template <typename T, typename IdxT>
void search(raft::device_resources const& handle,
const search_params& params,
const index<IdxT>& index,
const index<IdxT>& idx,
const T* queries,
uint32_t n_queries,
uint32_t k,
IdxT* neighbors,
float* distances,
rmm::mr::device_memory_resource* mr = nullptr)
{
return detail::search(handle, params, index, queries, n_queries, k, neighbors, distances, mr);
return detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances, mr);
}

} // namespace raft::neighbors::ivf_pq
2 changes: 0 additions & 2 deletions cpp/include/raft/neighbors/ivf_pq_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ namespace raft::neighbors::ivf_pq {
* @param[in] os output stream
* @param[in] index IVF-PQ index
*
* @return raft::neighbors::ivf_pq::index<IdxT>
*/
template <typename IdxT>
void serialize(raft::device_resources const& handle, std::ostream& os, const index<IdxT>& index)
Expand Down Expand Up @@ -77,7 +76,6 @@ void serialize(raft::device_resources const& handle, std::ostream& os, const ind
* @param[in] filename the file name for saving the index
* @param[in] index IVF-PQ index
*
* @return raft::neighbors::ivf_pq::index<IdxT>
*/
template <typename IdxT>
void serialize(raft::device_resources const& handle,
Expand Down
Loading