ivf_flat::index: hide implementation details #747

Merged
merged 147 commits into from
Aug 24, 2022
35ab60d
initial commit and formatting cleanup
achirkin May 13, 2022
24e8c4d
update save/load index function to work with cuann benchmark suite, s…
achirkin May 16, 2022
cb8bcd2
Added benchmarks.
achirkin May 16, 2022
884723c
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 17, 2022
8c4a0a0
Add a missing parameter docs
achirkin May 17, 2022
070fd05
Adapt to the changes in the warpsort api
achirkin May 17, 2022
83b6630
cleanup: use WarpSize constant
achirkin May 17, 2022
3a2703c
cleanup: remove unnecessary helpers
achirkin May 17, 2022
31bbaec
Use a more efficient warp_sort_filtered
achirkin May 17, 2022
4b40181
Recover files that have only non-relevant changes to reduce the size …
achirkin May 17, 2022
7e3041c
wip: replacing explicit allocations with rmm buffers
achirkin May 17, 2022
f6556b7
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 17, 2022
f75761f
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 18, 2022
dd558b4
Update cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh
achirkin May 18, 2022
94b3cbe
Update cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh
achirkin May 18, 2022
2be45a9
wip: replace cudaMemcpy with raft::copy
achirkin May 18, 2022
30c32a9
Simplified some cudaMemcpy invocations
achirkin May 18, 2022
c8e7b4d
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 19, 2022
150a438
Refactoring with helper functions
achirkin May 19, 2022
ddfb8cc
Make the scratch buf 3x L2 cache size
achirkin May 19, 2022
b788e2e
Remove serialization code for now
achirkin May 19, 2022
3e1c14d
remove obsolete comment
achirkin May 19, 2022
a001999
Add a missing sync
achirkin May 19, 2022
2d08271
Rename ann_quantized_faiss
achirkin May 19, 2022
0f88aaa
wip from manual allocations to rmm: updated some parts with pointer r…
achirkin May 19, 2022
363dfc9
wip from manual allocations to rmm
achirkin May 19, 2022
e5399f8
fix style
achirkin May 19, 2022
306f5bf
Set minimum memory pool size in radix_topk to 256 bytes
achirkin May 20, 2022
fd7d2ba
wip malloc-to-rmm: removed most of the manual allocations
achirkin May 20, 2022
403667a
misc cleanup
achirkin May 20, 2022
4c6d563
Refactoring; used raft::handle in place of cublas handle everywhere
achirkin May 20, 2022
3ae52ea
Fix the value type at runtime (use templates instead of runtime dtype)
achirkin May 20, 2022
6fecd7f
ceildiv
achirkin May 20, 2022
174854f
Use rmm's memory pool in place of explicitly allocated buffers
achirkin May 20, 2022
b45b14c
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 20, 2022
ca1aaad
Use raft logging
achirkin May 24, 2022
4228a02
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 24, 2022
70d84ec
Updated logging and nvtx markers
achirkin May 24, 2022
f9c12f8
clang-format
achirkin May 24, 2022
17968e4
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 24, 2022
957ac94
Use the recommended logger header
achirkin May 24, 2022
ccfbccc
Use warpsort for smaller k
achirkin May 25, 2022
7819397
Using raft helpers
achirkin May 25, 2022
510c467
Determine the template parameters Capacity and Veclen recursively
achirkin May 25, 2022
c5087be
wip: refactoring and reducing duplicate calls
achirkin May 26, 2022
f850a4a
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 27, 2022
c5f1c89
Refactor and document ann_ivf_flat_kernel
achirkin May 27, 2022
7b2b9ff
Documenting and refactoring the kernel
achirkin May 27, 2022
913edfb
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 30, 2022
b1208ed
Add a case of high dimensionality
achirkin May 31, 2022
a30ade5
Add more sync into the test to detect device errors
achirkin May 31, 2022
84db732
Add more sync into the test to detect device errors
achirkin May 31, 2022
346afb2
Allow large batch sizes and document more functions
achirkin May 31, 2022
fc201b5
Add a lower bound on expected recall
achirkin May 31, 2022
4021ea2
Compute required memory dynamically
achirkin May 31, 2022
ea8b1c4
readability quickfix
achirkin May 31, 2022
d8a034a
Correct the smem size for the warpsort and add launch bounds
achirkin May 31, 2022
d97d248
Add a couple of checks against floating point exceptions
achirkin Jun 1, 2022
2e64037
Don't run kmeans on empty dataset
achirkin Jun 2, 2022
9ed50ac
Order all ops by a cuda stream
achirkin Jun 2, 2022
1f9352c
Update comments
achirkin Jun 2, 2022
c048af2
Suggest replacing _cuann_sqsum
achirkin Jun 2, 2022
96f39a8
wip: refactoring utils
achirkin Jun 2, 2022
888daeb
minor comments
achirkin Jun 2, 2022
e6ff267
ann_utils refactoring, docs, and clang-tidy
achirkin Jun 3, 2022
426f713
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin Jun 7, 2022
bacb402
Refactor tests and reduce their memory footprint
achirkin Jun 7, 2022
4042b28
Refactored and documented ann_kmeans_balanced
achirkin Jun 7, 2022
bb5726b
Use memory_resource for temp data in kmeans
achirkin Jun 7, 2022
810c26b
Address clang-tidy and other refactoring suggestions
achirkin Jun 8, 2022
042c410
Move part of the index building onto gpu
achirkin Jun 8, 2022
7ace0fb
Document the index building kernel
achirkin Jun 15, 2022
e9c0d49
Merge branch 'branch-22.08' into fea-knn-ivf-flat
achirkin Jun 15, 2022
3515715
Added a dims padding todo
achirkin Jun 15, 2022
6bd6560
Move kmeans-related allocations and routines to ann_kmeans_balanced.cuh
achirkin Jun 15, 2022
2811814
Add documentation to the build_optimized_kmeans
achirkin Jun 15, 2022
fc3e46e
Using mdarrays and structured index
achirkin Jun 16, 2022
fb8c4b1
Fixed a memory leak and introduced a few assertions to check pointer …
achirkin Jun 17, 2022
f3b2cb2
Merge branch 'branch-22.08' into fea-knn-ivf-flat
cjnolet Jun 17, 2022
092d428
Refactoring build_optimized_kmeans
achirkin Jun 17, 2022
fbcb16b
A few smaller refactorings for kmeans
achirkin Jun 17, 2022
29ca199
Add docs to public methods of the handle
achirkin Jun 20, 2022
38b3cec
Made the metric be a part of the index struct and set the greater_ = …
achirkin Jun 21, 2022
d19bb5f
Do not persist grid_dim_x between searches
achirkin Jun 21, 2022
9094707
Refactor names according to clang-tidy
achirkin Jun 21, 2022
325e201
Refactor the usage of stream and params
achirkin Jun 21, 2022
2a3eb33
Refactor api to have symmetric index/search params
achirkin Jun 21, 2022
867beca
refactor away ivf_flat_index
achirkin Jun 22, 2022
059a6c0
Add the memory resource argument to warp_sort_topk
achirkin Jun 22, 2022
df17b5b
update docs
achirkin Jun 22, 2022
fe9ced1
Allow empty mesoclusters
achirkin Jun 23, 2022
91fdcbb
Add low-dimensional and non-veclen-aligned-dimensional test cases
achirkin Jun 23, 2022
be14c63
Refactor and document loadAndComputeDist
achirkin Jun 23, 2022
eeb4601
Minor renamings
achirkin Jun 23, 2022
025e5a5
Add 8bit int types to knn benchmarks
achirkin Jun 23, 2022
3821366
Fix incorrect data mapping for int8 types
achirkin Jun 24, 2022
d596842
Merge branch 'branch-22.08' into fea-knn-ivf-flat
achirkin Jun 24, 2022
a29baa7
Introduce kIndexGroupSize constant
achirkin Jun 27, 2022
546bef8
Cleanup ann_quantized
achirkin Jun 27, 2022
32d0d2e
Add several type aliases and helpers for creating mdarrays
achirkin Jun 27, 2022
5f427c0
Remove unnecessary inlines and fix docs
achirkin Jun 28, 2022
c581fe2
More refactoring and a few forceinlines
achirkin Jun 28, 2022
805e78c
Add a helper for creating pool_memory_resource when it makes sense
achirkin Jun 29, 2022
a4973e6
Force move the mdarrays when creating index to avoid copying them
achirkin Jun 29, 2022
68c267e
Minor refactorings
achirkin Jun 29, 2022
f2b8ed8
Add nvtx annotations to the outermost ANN calls for better performanc…
achirkin Jun 29, 2022
f91c7f7
Add a few more test cases and annotations for them
achirkin Jun 29, 2022
84b1c5b
Fix a typo
achirkin Jun 29, 2022
afc1f6a
Move ensure_integral_extents to the detail folder
achirkin Jun 30, 2022
3a10f86
Lift the requirement to have query pointers aligned with Veclen
achirkin Jun 30, 2022
9f5c64c
Merge branch 'branch-22.08' into enh-mdarray-helpers
achirkin Jun 30, 2022
1afd667
Use move semantics for the index everywhere, but try to keep it const…
achirkin Jun 30, 2022
73ce9e1
Update documentation
achirkin Jun 30, 2022
2a45645
Remove the debug path USE_FAISS
achirkin Jun 30, 2022
75a48b4
Add a type trait for checking if the conversion between two numeric t…
achirkin Jul 1, 2022
ed25cae
Merge branch 'branch-22.08' into fea-knn-ivf-flat
achirkin Jul 1, 2022
388200c
Support 32bit and unsigned indices in bruteforce KNN
achirkin Jul 1, 2022
f08df83
Merge branch 'enh-mdarray-helpers' into fea-knn-ivf-flat
achirkin Jul 1, 2022
9200886
Merge branch 'enh-knn-bruteforce-uint32' into fea-knn-ivf-flat
achirkin Jul 1, 2022
14bfe02
Make index type a template parameter
achirkin Jul 1, 2022
1283cbe
Revert the api changes as much as possible and deprecate the old api
achirkin Jul 1, 2022
e73b259
Remove the stream argument from the public API
achirkin Jul 4, 2022
8e7ffb8
Merge branch 'branch-22.08' into fea-knn-ivf-flat
achirkin Jul 5, 2022
5f5dc0d
Merge branch 'branch-22.08' into fea-knn-ivf-flat
achirkin Jul 5, 2022
03ebbe0
Simplify kmeans::predict a little bit
achirkin Jul 6, 2022
cde7f97
Factor out predict from the other ops in kmeans for use outside of th…
achirkin Jul 7, 2022
305bbcd
Add new function extend(index, new_vecs, new_inds) to ivf_flat
achirkin Jul 20, 2022
76c383f
Merge branch 'branch-22.08' into fea-knn-ivf-flat
achirkin Jul 21, 2022
7f640a9
Improve the docs
achirkin Jul 21, 2022
2e9eda5
Fix using non-existing log function
achirkin Jul 21, 2022
dc62a0f
Hide all data components from ivf_flat::index and expose immutable views
achirkin Jul 21, 2022
fb841c3
Replace thrust::exclusive_scan with thrust::inclusive_scan to avoid a…
achirkin Jul 22, 2022
04bb5dc
Merge branch 'fea-knn-ivf-flat' into enh-knn-ivf-flat-hide-impl
achirkin Jul 22, 2022
c95ea85
ann_common.h: remove deps on cuda code, so that the file can be inclu…
achirkin Jul 22, 2022
0c72ee8
ann_common.h: remove deps on cuda code, so that the file can be inclu…
achirkin Jul 22, 2022
0196695
Make helper overloads inline for linking in cuml
achirkin Jul 22, 2022
eb15639
Split processing.hpp into *.cuh and *.hpp to avoid incomplete types
achirkin Jul 22, 2022
e4b2b39
WIP: investigating segmentation fault in cuml test
achirkin Jul 25, 2022
6bc0fcb
Revert the wip-changes from the last commit
achirkin Jul 26, 2022
f599aaf
Merge remote-tracking branch 'origin/fea-knn-ivf-flat' into enh-knn-i…
achirkin Jul 26, 2022
a191410
Merge branch 'branch-22.08' into enh-knn-ivf-flat-hide-impl
achirkin Jul 28, 2022
317ddf3
Enhance documentation
achirkin Jul 28, 2022
114fb63
Fix a couple of typos in docs
achirkin Jul 28, 2022
1d283ae
Change the data indexing to size_t to make sure the total size (size*…
achirkin Jul 28, 2022
a9bd2d6
Merge branch 'branch-22.08' into enh-knn-ivf-flat-hide-impl
achirkin Aug 2, 2022
f9d55a7
Make ivf_flat::index look a little bit more like knn::sparse api
achirkin Aug 2, 2022
fef6dac
Test both overloads of
achirkin Aug 2, 2022
13 changes: 10 additions & 3 deletions cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh
@@ -131,9 +131,16 @@ constexpr auto calc_minibatch_size(uint32_t n_clusters, size_t n_rows) -> uint32
/**
* @brief Given the data and labels, calculate cluster centers and sizes in one sweep.
*
* Let S_i = {x_k | x_k \in dataset & labels[k] == i} be the vectors in the dataset with label i.
* On exit centers_i = normalize(\sum_{x \in S_i} x), where `normalize` depends on the distance
* type.
* Let `S_i = {x_k | x_k \in dataset & labels[k] == i}` be the vectors in the dataset with label i.
*
* On exit,
* `centers_i = (\sum_{x \in S_i} x + w_i * center_i) / (|S_i| + w_i)`,
* where `w_i = reset_counters ? 0 : cluster_size[i]`.
*
* In other words, the updated cluster centers are a weighted average of the existing cluster
* center, and the coordinates of the points labeled with i. _This allows calling this function
* multiple times with different datasets with the same effect as if calling this function once
* on the combined dataset_.
*
* NB: `centers` and `cluster_sizes` must be accessible on GPU due to
* divide_along_rows/normalize_rows. The rest can be both, under assumption that all pointers are
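The weighted-average update described in the doc comment above can be checked with a small host-side sketch. This is not the RAFT kernel: `update_centers` and its signature are hypothetical, the data lives in `std::vector` instead of device memory, and the handling of empty clusters (they end up as zero vectors) is an assumption made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side reference for
//   centers_i = (sum_{x in S_i} x + w_i * center_i) / (|S_i| + w_i),
//   w_i = reset_counters ? 0 : cluster_sizes[i].
// Names and signature are illustrative only, not RAFT's API.
void update_centers(std::vector<float>& centers,           // [n_clusters * dim], row-major
                    std::vector<uint32_t>& cluster_sizes,   // [n_clusters]
                    const std::vector<float>& dataset,      // [n_rows * dim], row-major
                    const std::vector<uint32_t>& labels,    // [n_rows]
                    std::size_t dim,
                    bool reset_counters)
{
  const std::size_t n_clusters = cluster_sizes.size();
  const std::size_t n_rows     = labels.size();
  assert(centers.size() == n_clusters * dim);
  assert(dataset.size() == n_rows * dim);

  // Step 1: turn every existing center back into a weighted sum, w_i * center_i.
  for (std::size_t i = 0; i < n_clusters; ++i) {
    if (reset_counters) { cluster_sizes[i] = 0; }
    const float w = static_cast<float>(cluster_sizes[i]);
    for (std::size_t j = 0; j < dim; ++j) {
      float& c = centers[i * dim + j];
      c        = (w == 0.0f) ? 0.0f : c * w;
    }
  }

  // Step 2: add the new points to the sum of their cluster and count them.
  for (std::size_t k = 0; k < n_rows; ++k) {
    const uint32_t i = labels[k];
    assert(i < n_clusters);
    for (std::size_t j = 0; j < dim; ++j) {
      centers[i * dim + j] += dataset[k * dim + j];
    }
    ++cluster_sizes[i];
  }

  // Step 3: divide by |S_i| + w_i; clusters that are still empty stay zero.
  for (std::size_t i = 0; i < n_clusters; ++i) {
    if (cluster_sizes[i] == 0) { continue; }
    const float n = static_cast<float>(cluster_sizes[i]);
    for (std::size_t j = 0; j < dim; ++j) {
      centers[i * dim + j] /= n;
    }
  }
}
```

Calling the sketch once on the points {1, 3, 5, 7} with `reset_counters = true` gives the same center as calling it on {1, 3} with `reset_counters = true` and then on {5, 7} with `reset_counters = false`, which is the property the comment emphasizes.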
18 changes: 9 additions & 9 deletions cpp/include/raft/spatial/knn/detail/ann_utils.cuh
@@ -368,33 +368,33 @@ __global__ void map_along_rows_kernel(
}

/**
* @brief Divide matrix values along rows by an integer value, skipping rows if the corresponding
* divisor is zero.
* @brief Map a binary function over a matrix and a vector element-wise, broadcasting the vector
* values along rows: `m[i, j] = op(m[i,j], v[i])`
*
* NB: device-only function
*
* @tparam Lambda
*
* @param n_rows
* @param n_cols
* @param[inout] a device pointer to a row-major matrix [n_rows, n_cols]
* @param[in] d device pointer to a vector [n_rows]
* @param map the binary operation to apply on every element of matrix rows and of the vector
* @param[inout] m device pointer to a row-major matrix [n_rows, n_cols]
* @param[in] v device pointer to a vector [n_rows]
* @param op the binary operation to apply on every element of matrix rows and of the vector
*/
template <typename Lambda>
inline void map_along_rows(uint32_t n_rows,
uint32_t n_cols,
float* a,
const uint32_t* d,
Lambda map,
float* m,
const uint32_t* v,
Lambda op,
rmm::cuda_stream_view stream)
{
dim3 threads(128, 1, 1);
dim3 blocks(
ceildiv<uint64_t>(static_cast<uint64_t>(n_rows) * static_cast<uint64_t>(n_cols), threads.x),
1,
1);
map_along_rows_kernel<<<blocks, threads, 0, stream>>>(n_rows, n_cols, a, d, map);
map_along_rows_kernel<<<blocks, threads, 0, stream>>>(n_rows, n_cols, m, v, op);
}

template <typename T>
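For readers who want to check the semantics of `m[i, j] = op(m[i, j], v[i])` without a GPU, here is a minimal host-side sketch. `map_along_rows_host` is a hypothetical name, and the flat-index loop assumes one logical thread per matrix element, which is what the launch configuration above suggests but which the kernel body (not shown in this hunk) may implement differently.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference for m[i, j] = op(m[i, j], v[i]).
// Illustrative only; the real code runs map_along_rows_kernel on the device.
template <typename Lambda>
void map_along_rows_host(uint32_t n_rows,
                         uint32_t n_cols,
                         std::vector<float>& m,           // row-major [n_rows, n_cols]
                         const std::vector<uint32_t>& v,  // [n_rows]
                         Lambda op)
{
  assert(m.size() == static_cast<std::size_t>(n_rows) * n_cols);
  assert(v.size() == n_rows);
  // 64-bit product, as in the launch code, so large matrices do not overflow.
  const uint64_t n = static_cast<uint64_t>(n_rows) * static_cast<uint64_t>(n_cols);
  for (uint64_t gid = 0; gid < n; ++gid) {
    const uint64_t row = gid / n_cols;  // the row decides which vector value is broadcast
    m[gid]             = op(m[gid], v[row]);
  }
}
```

The previous doc comment described one particular use, dividing each row by an integer and skipping rows whose divisor is zero. With the sketch above that would be the lambda `[](float a, uint32_t d) { return d == 0 ? a : a / static_cast<float>(d); }`.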
146 changes: 48 additions & 98 deletions cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh
@@ -108,39 +108,38 @@ inline auto extend(const handle_t& handle,
const index<T, IdxT>& orig_index,
const T* new_vectors,
const IdxT* new_indices,
IdxT n_rows,
rmm::cuda_stream_view stream) -> index<T, IdxT>
IdxT n_rows) -> index<T, IdxT>
{
auto n_lists = orig_index.n_lists;
auto dim = orig_index.dim;
auto stream = handle.get_stream();
auto n_lists = orig_index.n_lists();
auto dim = orig_index.dim();
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"ivf_flat::extend(%zu, %u)", size_t(n_rows), dim);

RAFT_EXPECTS(new_indices != nullptr || orig_index.size == 0,
RAFT_EXPECTS(new_indices != nullptr || orig_index.size() == 0,
"You must pass data indices when the index is non-empty.");

rmm::device_uvector<uint32_t> new_labels(n_rows, stream);
kmeans::predict(handle,
orig_index.centers.data(),
orig_index.centers().data_handle(),
n_lists,
dim,
new_vectors,
n_rows,
new_labels.data(),
orig_index.metric,
orig_index.metric(),
stream);

auto&& list_sizes = rmm::device_uvector<uint32_t>(n_lists, stream);
auto&& list_offsets = rmm::device_uvector<IdxT>(n_lists + 1, stream);
auto list_sizes_ptr = list_sizes.data();
auto list_offsets_ptr = list_offsets.data();
index<T, IdxT> ext_index(handle, orig_index.metric(), n_lists, dim);

auto&& centers = rmm::device_uvector<float>(size_t(n_lists) * size_t(dim), stream);
auto centers_ptr = centers.data();
auto list_sizes_ptr = ext_index.list_sizes().data_handle();
auto list_offsets_ptr = ext_index.list_offsets().data_handle();
auto centers_ptr = ext_index.centers().data_handle();

// Calculate the centers and sizes on the new data, starting from the original values
raft::copy(centers_ptr, orig_index.centers.data(), centers.size(), stream);
raft::copy(list_sizes_ptr, orig_index.list_sizes.data(), list_sizes.size(), stream);
raft::copy(centers_ptr, orig_index.centers().data_handle(), ext_index.centers().size(), stream);
raft::copy(
list_sizes_ptr, orig_index.list_sizes().data_handle(), ext_index.list_sizes().size(), stream);

kmeans::calc_centers_and_sizes(centers_ptr,
list_sizes_ptr,
@@ -160,146 +159,97 @@ inline auto extend(const handle_t& handle,
list_sizes_ptr,
list_sizes_ptr + n_lists,
list_offsets_ptr + 1,
[] __device__(IdxT s, uint32_t l) { return s + Pow2<WarpSize>::roundUp(l); });
[] __device__(IdxT s, uint32_t l) { return s + Pow2<kIndexGroupSize>::roundUp(l); });
update_host(&index_size, list_offsets_ptr + n_lists, 1, stream);
handle.sync_stream(stream);

auto&& data = rmm::device_uvector<T>(index_size * IdxT(dim), stream);
auto&& indices = rmm::device_uvector<IdxT>(index_size, stream);
ext_index.allocate(
handle, index_size, ext_index.metric() == raft::distance::DistanceType::L2Expanded);

// Populate index with the old data
if (orig_index.size > 0) {
utils::block_copy(orig_index.list_offsets.data(),
if (orig_index.size() > 0) {
utils::block_copy(orig_index.list_offsets().data_handle(),
list_offsets_ptr,
IdxT(n_lists),
orig_index.data.data(),
data.data(),
orig_index.data().data_handle(),
ext_index.data().data_handle(),
IdxT(dim),
stream);

utils::block_copy(orig_index.list_offsets.data(),
utils::block_copy(orig_index.list_offsets().data_handle(),
list_offsets_ptr,
IdxT(n_lists),
orig_index.indices.data(),
indices.data(),
orig_index.indices().data_handle(),
ext_index.indices().data_handle(),
IdxT(1),
stream);
}

// Copy the old sizes, so we can start from the current state of the index;
// we'll rebuild the `list_sizes_ptr` in the following kernel, using it as an atomic counter.
raft::copy(list_sizes_ptr, orig_index.list_sizes.data(), list_sizes.size(), stream);
raft::copy(
list_sizes_ptr, orig_index.list_sizes().data_handle(), ext_index.list_sizes().size(), stream);

const dim3 block_dim(256);
const dim3 grid_dim(raft::ceildiv<IdxT>(n_rows, block_dim.x));
build_index_kernel<<<grid_dim, block_dim, 0, stream>>>(new_labels.data(),
list_offsets_ptr,
new_vectors,
new_indices,
data.data(),
indices.data(),
ext_index.data().data_handle(),
ext_index.indices().data_handle(),
list_sizes_ptr,
n_rows,
dim,
orig_index.veclen);
ext_index.veclen());
RAFT_CUDA_TRY(cudaPeekAtLastError());

// Precompute the centers vector norms for L2Expanded distance
auto compute_norms = [&]() {
auto&& r = rmm::device_uvector<float>(n_lists, stream);
utils::dots_along_rows(n_lists, dim, centers.data(), r.data(), stream);
RAFT_LOG_TRACE_VEC(r.data(), 20);
return std::move(r);
};
auto&& center_norms = orig_index.metric == raft::distance::DistanceType::L2Expanded
? std::optional(compute_norms())
: std::nullopt;
if (ext_index.center_norms().has_value()) {
utils::dots_along_rows(n_lists,
dim,
ext_index.centers().data_handle(),
ext_index.center_norms()->data_handle(),
stream);
RAFT_LOG_TRACE_VEC(ext_index.center_norms()->data_handle(), std::min<uint32_t>(dim, 20));
}

// assemble the index
index<T, IdxT> new_index{{},
orig_index.veclen,
orig_index.metric,
index_size,
orig_index.dim,
orig_index.n_lists,
std::move(data),
std::move(indices),
std::move(list_sizes),
std::move(list_offsets),
std::move(centers),
std::move(center_norms)};

// check index invariants
new_index.check_consistency();

return new_index;
return ext_index;
}
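The scan lambda in `extend` pads every list to a multiple of the group size before accumulating offsets. A sequential host sketch of that padded prefix sum is below. The value 32 for the group size is an assumption (the earlier code used `WarpSize`, and the value of `kIndexGroupSize` is not shown in this diff), the function names are hypothetical, and the loop does not try to reproduce the evaluation order of `thrust::inclusive_scan`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Assumed group size, for illustration only.
constexpr uint32_t kGroupSize = 32;

// Round n up to the next multiple of kGroupSize (a power of two).
constexpr uint64_t round_up_to_group(uint64_t n)
{
  static_assert((kGroupSize & (kGroupSize - 1)) == 0, "group size must be a power of two");
  return (n + kGroupSize - 1) & ~static_cast<uint64_t>(kGroupSize - 1);
}

// offsets[0] = 0 and offsets[i + 1] = offsets[i] + round_up(list_sizes[i]),
// so every list starts at a group-aligned position and the last entry
// is the total padded size of the index.
std::vector<uint64_t> padded_list_offsets(const std::vector<uint32_t>& list_sizes)
{
  std::vector<uint64_t> offsets(list_sizes.size() + 1, 0);
  for (std::size_t i = 0; i < list_sizes.size(); ++i) {
    offsets[i + 1] = offsets[i] + round_up_to_group(list_sizes[i]);
  }
  return offsets;
}
```

With list sizes {0, 1, 32, 33} the sketch yields offsets {0, 0, 32, 64, 128}: an empty list takes no space, and a list of 33 entries occupies two full groups.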

/** See raft::spatial::knn::ivf_flat::build docs */
template <typename T, typename IdxT>
inline auto build(const handle_t& handle,
const index_params& params,
const T* dataset,
IdxT n_rows,
uint32_t dim,
rmm::cuda_stream_view stream) -> index<T, IdxT>
inline auto build(
const handle_t& handle, const index_params& params, const T* dataset, IdxT n_rows, uint32_t dim)
-> index<T, IdxT>
{
auto stream = handle.get_stream();
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"ivf_flat::build(%zu, %u)", size_t(n_rows), dim);
static_assert(std::is_same_v<T, float> || std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>,
"unsupported data type");
RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset");

// TODO: consider padding the dimensions and fixing veclen to its maximum possible value as a
// template parameter (https://github.com/rapidsai/raft/issues/711)
uint32_t veclen = 16 / sizeof(T);
while (dim % veclen != 0) {
veclen = veclen >> 1;
}
auto n_lists = static_cast<uint32_t>(params.n_lists);

// kmeans cluster ids for the dataset
auto&& centers = rmm::device_uvector<float>(size_t(n_lists) * size_t(dim), stream);
index<T, IdxT> index(handle, params, dim);
utils::memzero(index.list_sizes().data_handle(), index.list_sizes().size(), stream);
utils::memzero(index.list_offsets().data_handle(), index.list_offsets().size(), stream);

// Predict labels of the whole dataset
kmeans::build_optimized_kmeans(handle,
params.kmeans_n_iters,
dim,
dataset,
n_rows,
centers.data(),
n_lists,
index.centers().data_handle(),
params.n_lists,
params.kmeans_trainset_fraction,
params.metric,
stream);

auto&& data = rmm::device_uvector<T>(0, stream);
auto&& indices = rmm::device_uvector<IdxT>(0, stream);
auto&& list_sizes = rmm::device_uvector<uint32_t>(n_lists, stream);
auto&& list_offsets = rmm::device_uvector<IdxT>(n_lists + 1, stream);
utils::memzero(list_sizes.data(), list_sizes.size(), stream);
utils::memzero(list_offsets.data(), list_offsets.size(), stream);

// assemble the index
index<T, IdxT> index{{},
veclen,
params.metric,
IdxT(0),
dim,
n_lists,
std::move(data),
std::move(indices),
std::move(list_sizes),
std::move(list_offsets),
std::move(centers),
std::nullopt};

// check index invariants
index.check_consistency();

// add the data if necessary
if (params.add_data_on_build) {
return extend<T, IdxT>(handle, index, dataset, nullptr, n_rows, stream);
return detail::extend<T, IdxT>(handle, index, dataset, nullptr, n_rows);
} else {
return index;
}
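The `veclen` loop that appears in the `build` hunk picks the widest vectorized load, up to 16 bytes, whose element count divides `dim` evenly. The same logic as a standalone function, under the hypothetical name `pick_veclen` (where this computation lives after the refactoring is not visible in this excerpt):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the loop in the diff: start from the number of
// elements that fit into 16 bytes and halve until it divides `dim`.
template <typename T>
constexpr uint32_t pick_veclen(uint32_t dim)
{
  static_assert(sizeof(T) <= 16, "element type must fit into a 16-byte load");
  uint32_t veclen = 16 / sizeof(T);
  while (dim % veclen != 0) {
    veclen >>= 1;  // terminates at 1 at the latest, because dim % 1 == 0
  }
  return veclen;
}
```

For `float` data this gives 4 when `dim` is a multiple of four, 2 for `dim = 6`, and 1 for odd dimensions; for 8-bit types it starts at 16. Non-aligned dimensions therefore fall back to narrower loads, which is what the padding TODO in the diff refers to.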
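The call sites above switch from reading public fields (`orig_index.n_lists`, `orig_index.centers.data()`) to calling accessors (`orig_index.n_lists()`, `orig_index.centers().data_handle()`). The class definition itself is not part of this excerpt, so the following is only a toy sketch of the general pattern: private storage, scalar getters, and views whose constness follows the constness of the object. `toy_index` and `view` are hypothetical stand-ins, not RAFT's `ivf_flat::index` or its mdspan types.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Minimal non-owning view, standing in for an mdspan-like type (hypothetical).
template <typename T>
struct view {
  T* ptr;
  std::size_t len;
  T* data_handle() const noexcept { return ptr; }
  std::size_t size() const noexcept { return len; }
};

// Toy index: the data members are private and reachable only through accessors.
class toy_index {
 public:
  toy_index(uint32_t n_lists, uint32_t dim)
    : n_lists_(n_lists), dim_(dim), centers_(static_cast<std::size_t>(n_lists) * dim, 0.0f)
  {
  }

  uint32_t n_lists() const noexcept { return n_lists_; }
  uint32_t dim() const noexcept { return dim_; }

  // A const index hands out a read-only view ...
  view<const float> centers() const noexcept { return {centers_.data(), centers_.size()}; }
  // ... while a mutable index (e.g. one still being built) allows writes.
  view<float> centers() noexcept { return {centers_.data(), centers_.size()}; }

 private:
  uint32_t n_lists_;
  uint32_t dim_;
  std::vector<float> centers_;  // [n_lists, dim], row-major
};
```

Because the views have a fixed length, callers can fill or read the buffers but cannot resize them, so the relation between `n_lists`, `dim`, and the buffer sizes stays under the control of the class.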