Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
7cb2842
first commit
tarang-jain Mar 31, 2026
0c7bed4
Merge branch 'main' of https://github.com/rapidsai/cuvs into hierarch…
tarang-jain Apr 8, 2026
8a81241
Merge branch 'main' into hierarchical-helpers
tarang-jain Apr 9, 2026
9f979c3
fix style
tarang-jain Apr 9, 2026
6d6d8d1
re-add comments that were removed
tarang-jain Apr 9, 2026
7cc9dad
re-add comments that were removed
tarang-jain Apr 9, 2026
20a25ff
fix comment for args
tarang-jain Apr 9, 2026
2d6baa5
Update cpp/src/cluster/detail/minClusterDistanceCompute.cu
tarang-jain Apr 9, 2026
b46e471
move batch size loop entirely to else condition
tarang-jain Apr 10, 2026
bdb6fd9
Merge branch 'main' into hierarchical-helpers
tarang-jain Apr 10, 2026
e93af81
rm comments
tarang-jain Apr 10, 2026
e7ea00f
Merge branch 'hierarchical-helpers' of https://github.com/tarang-jain…
tarang-jain Apr 10, 2026
e1a9355
add include
tarang-jain Apr 10, 2026
079b1bb
style
tarang-jain Apr 10, 2026
1a6a145
Merge branch 'main' into hierarchical-helpers
tarang-jain Apr 10, 2026
7c8fa6e
correct norm computation for cosine
tarang-jain Apr 13, 2026
39770f4
Merge branch 'hierarchical-helpers' of https://github.com/tarang-jain…
tarang-jain Apr 13, 2026
925c4ac
add cosine tests for balaned
tarang-jain Apr 13, 2026
f42ea30
Merge branch 'main' of https://github.com/rapidsai/cuvs into hierarch…
tarang-jain Apr 13, 2026
a028bec
Merge branch 'main' into hierarchical-helpers
tarang-jain Apr 14, 2026
fe8db92
Merge branch 'main' into hierarchical-helpers
tarang-jain Apr 21, 2026
40338a0
pre-normalize centers
tarang-jain Apr 21, 2026
7be77b3
Merge branch 'hierarchical-helpers' of https://github.com/tarang-jain…
tarang-jain Apr 21, 2026
707c6d9
int workspace
tarang-jain Apr 22, 2026
c605b97
Merge branch 'main' of https://github.com/rapidsai/cuvs into hierarch…
tarang-jain Apr 22, 2026
c0b2376
rm cosine norms in kmeans.cuh
tarang-jain Apr 23, 2026
4462178
revert unit norm assumption
tarang-jain Apr 23, 2026
a8e281f
empty commit
tarang-jain Apr 23, 2026
940c76d
fix sqrt_op norm
tarang-jain Apr 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 18 additions & 68 deletions cpp/src/cluster/detail/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please edit the docstrings of predict overloads taking a balanced_params parameters in cpp/include/cuvs/cluster/kmeans.hpp to clearly state that the centroids should be normalized when using the cosine metric. A user may have trained centroids elsewhere and may be attempting to GPU-accelerate the prediction.

#pragma once

#include "../../distance/fused_distance_nn.cuh"
#include "kmeans_common.cuh"
#include <cuvs/cluster/kmeans.hpp>

Expand Down Expand Up @@ -88,80 +87,31 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core(
auto stream = raft::resource::get_cuda_stream(handle);
switch (params.metric) {
case cuvs::distance::DistanceType::L2Expanded:
case cuvs::distance::DistanceType::L2SqrtExpanded: {
auto workspace = raft::make_device_mdarray<char, IdxT>(
handle, mr, raft::make_extents<IdxT>((sizeof(int)) * n_rows));

auto minClusterAndDistance = raft::make_device_mdarray<raft::KeyValuePair<IdxT, MathT>, IdxT>(
handle, mr, raft::make_extents<IdxT>(n_rows));
raft::KeyValuePair<IdxT, MathT> initial_value(0, std::numeric_limits<MathT>::max());
raft::matrix::fill(handle, minClusterAndDistance.view(), initial_value);

auto centroidsNorm =
raft::make_device_mdarray<MathT, IdxT>(handle, mr, raft::make_extents<IdxT>(n_clusters));
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
handle,
raft::make_device_matrix_view<const MathT, IdxT, raft::row_major>(centers, n_clusters, dim),
centroidsNorm.view());

cuvs::distance::fusedDistanceNNMinReduce<MathT, raft::KeyValuePair<IdxT, MathT>, IdxT>(
minClusterAndDistance.data_handle(),
dataset,
centers,
dataset_norm,
centroidsNorm.data_handle(),
n_rows,
n_clusters,
dim,
(void*)workspace.data_handle(),
(params.metric == cuvs::distance::DistanceType::L2Expanded) ? false : true,
false,
true,
params.metric,
0.0f,
stream);

// todo(lsugy): use KVP + iterator in caller.
// Copy keys to output labels
raft::linalg::map(handle,
raft::make_const_mdspan(minClusterAndDistance.view()),
raft::make_device_vector_view<LabelT, IdxT>(labels, n_rows),
raft::compose_op<raft::cast_op<LabelT>, raft::key_op>());
break;
}
case cuvs::distance::DistanceType::L2SqrtExpanded:
case cuvs::distance::DistanceType::CosineExpanded: {
auto workspace = raft::make_device_mdarray<char, IdxT>(
handle, mr, raft::make_extents<IdxT>((sizeof(int)) * n_rows));
rmm::device_uvector<MathT> L2NormBuf_OR_DistBuf(0, stream, mr);
rmm::device_uvector<char> workspace(0, stream, mr);

auto X_view = raft::make_device_matrix_view<const MathT, IdxT>(dataset, n_rows, dim);
auto centroids_view =
raft::make_device_matrix_view<const MathT, IdxT>(centers, n_clusters, dim);
auto X_norm_view = raft::make_device_vector_view<const MathT, IdxT>(dataset_norm, n_rows);

auto minClusterAndDistance = raft::make_device_mdarray<raft::KeyValuePair<IdxT, MathT>, IdxT>(
handle, mr, raft::make_extents<IdxT>(n_rows));
raft::KeyValuePair<IdxT, MathT> initial_value(0, std::numeric_limits<MathT>::max());
raft::matrix::fill(handle, minClusterAndDistance.view(), initial_value);

auto centroidsNorm =
raft::make_device_mdarray<MathT, IdxT>(handle, mr, raft::make_extents<IdxT>(n_clusters));
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute<MathT, IdxT>(
handle,
raft::make_device_matrix_view<const MathT, IdxT, raft::row_major>(centers, n_clusters, dim),
centroidsNorm.view(),
raft::sqrt_op{});

cuvs::distance::fusedDistanceNNMinReduce<MathT, raft::KeyValuePair<IdxT, MathT>, IdxT>(
minClusterAndDistance.data_handle(),
dataset,
centers,
dataset_norm,
centroidsNorm.data_handle(),
n_rows,
n_clusters,
dim,
(void*)workspace.data_handle(),
false,
false,
true,
X_view,
centroids_view,
minClusterAndDistance.view(),
X_norm_view,
L2NormBuf_OR_DistBuf,
params.metric,
0.0f,
stream);
0, // batch_samples (unused for fused reduction)
0, // batch_centroids (unused for fused reduction)
workspace);

// Copy keys to output labels
raft::linalg::map(handle,
raft::make_const_mdspan(minClusterAndDistance.view()),
Expand Down
1 change: 0 additions & 1 deletion cpp/src/cluster/detail/kmeans_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#pragma once

#include "../../distance/distance.cuh"
#include "../../distance/fused_distance_nn.cuh"
#include <cstdint>
#include <cuvs/cluster/kmeans.hpp>
#include <cuvs/distance/distance.hpp>
Expand Down
Loading
Loading