From 7cb28425fef1b118b5cef55495c2e3e100e29e95 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 31 Mar 2026 16:08:01 -0700 Subject: [PATCH 01/18] first commit --- cpp/src/cluster/detail/kmeans_balanced.cuh | 87 +++----------- .../detail/minClusterDistanceCompute.cu | 112 +++++++++--------- 2 files changed, 75 insertions(+), 124 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index f5dc759725..66ac41fe5c 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -5,7 +5,6 @@ #pragma once -#include "../../distance/fused_distance_nn.cuh" #include "kmeans_common.cuh" #include @@ -88,80 +87,32 @@ inline std::enable_if_t> predict_core( auto stream = raft::resource::get_cuda_stream(handle); switch (params.metric) { case cuvs::distance::DistanceType::L2Expanded: - case cuvs::distance::DistanceType::L2SqrtExpanded: { - auto workspace = raft::make_device_mdarray( - handle, mr, raft::make_extents((sizeof(int)) * n_rows)); - - auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( - handle, mr, raft::make_extents(n_rows)); - raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - raft::matrix::fill(handle, minClusterAndDistance.view(), initial_value); - - auto centroidsNorm = - raft::make_device_mdarray(handle, mr, raft::make_extents(n_clusters)); - raft::linalg::norm( - handle, - raft::make_device_matrix_view(centers, n_clusters, dim), - centroidsNorm.view()); - - cuvs::distance::fusedDistanceNNMinReduce, IdxT>( - minClusterAndDistance.data_handle(), - dataset, - centers, - dataset_norm, - centroidsNorm.data_handle(), - n_rows, - n_clusters, - dim, - (void*)workspace.data_handle(), - (params.metric == cuvs::distance::DistanceType::L2Expanded) ? false : true, - false, - true, - params.metric, - 0.0f, - stream); - - // todo(lsugy): use KVP + iterator in caller. 
- // Copy keys to output labels - raft::linalg::map(handle, - raft::make_const_mdspan(minClusterAndDistance.view()), - raft::make_device_vector_view(labels, n_rows), - raft::compose_op, raft::key_op>()); - break; - } + case cuvs::distance::DistanceType::L2SqrtExpanded: case cuvs::distance::DistanceType::CosineExpanded: { - auto workspace = raft::make_device_mdarray( - handle, mr, raft::make_extents((sizeof(int)) * n_rows)); + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream, mr); + rmm::device_uvector workspace(0, stream, mr); + + auto X_view = raft::make_device_matrix_view(dataset, n_rows, dim); + auto centroids_view = + raft::make_device_matrix_view(centers, n_clusters, dim); + auto X_norm_view = + raft::make_device_vector_view(dataset_norm, n_rows); auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( handle, mr, raft::make_extents(n_rows)); - raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - raft::matrix::fill(handle, minClusterAndDistance.view(), initial_value); - auto centroidsNorm = - raft::make_device_mdarray(handle, mr, raft::make_extents(n_clusters)); - raft::linalg::norm( + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( handle, - raft::make_device_matrix_view(centers, n_clusters, dim), - centroidsNorm.view(), - raft::sqrt_op{}); - - cuvs::distance::fusedDistanceNNMinReduce, IdxT>( - minClusterAndDistance.data_handle(), - dataset, - centers, - dataset_norm, - centroidsNorm.data_handle(), - n_rows, - n_clusters, - dim, - (void*)workspace.data_handle(), - false, - false, - true, + X_view, + centroids_view, + minClusterAndDistance.view(), + X_norm_view, + L2NormBuf_OR_DistBuf, params.metric, - 0.0f, - stream); + 0, // batch_samples (0 = default) + 0, // batch_centroids (0 = default) + workspace); + // Copy keys to output labels raft::linalg::map(handle, raft::make_const_mdspan(minClusterAndDistance.view()), diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu 
b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index 8370ff922f..5da3b1efb2 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -29,76 +29,76 @@ void minClusterAndDistanceCompute( auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = centroids.extent(0); - // todo(lsugy): change batch size computation when using fusedL2NN! bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded; + metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::CosineExpanded; auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(batch_samples, n_samples); auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); if (is_fused) { L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - raft::linalg::norm( - handle, - centroids, - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm( + handle, + centroids, + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters), + raft::sqrt_op{}); + } else { + raft::linalg::norm( + handle, + centroids, + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); + } + + auto centroidsNorm = + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); + + raft::KeyValuePair initial_value(0, std::numeric_limits::max()); + raft::matrix::fill(handle, minClusterAndDistance, initial_value); + + workspace.resize((sizeof(int)) * n_samples, stream); + + cuvs::distance::fusedDistanceNNMinReduce, IndexT>( + minClusterAndDistance.data_handle(), + X.data_handle(), + centroids.data_handle(), + L2NormX.data_handle(), + centroidsNorm.data_handle(), + n_samples, + n_clusters, + n_features, + (void*)workspace.data(), + metric != cuvs::distance::DistanceType::L2Expanded, + false, + 
true, + metric, + 0.0f, + stream); } else { - // TODO: Unless pool allocator is used, passing in a workspace for this - // isn't really increasing performance because this needs to do a re-allocation - // anyways. ref https://github.com/rapidsai/raft/issues/930 L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); - } - // Note - pairwiseDistance and centroidsNorm share the same buffer - // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer - auto pairwiseDistance = raft::make_device_matrix_view( - L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); + // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer + auto pairwiseDistance = raft::make_device_matrix_view( + L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); - raft::KeyValuePair initial_value(0, std::numeric_limits::max()); + raft::KeyValuePair initial_value(0, std::numeric_limits::max()); + raft::matrix::fill(handle, minClusterAndDistance, initial_value); - raft::matrix::fill(handle, minClusterAndDistance, initial_value); + // tile over the input dataset + for (IndexT dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { + // # of samples for the current batch + auto ns = std::min((IndexT)dataBatchSize, n_samples - dIdx); - // tile over the input dataset - for (IndexT dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { - // # of samples for the current batch - auto ns = std::min((IndexT)dataBatchSize, n_samples - dIdx); + // datasetView [ns x n_features] - view representing the current batch of + // input dataset + auto datasetView = raft::make_device_matrix_view( + X.data_handle() + (dIdx * n_features), ns, n_features); - // datasetView [ns x n_features] - view representing the current batch of - // input dataset - auto datasetView = 
raft::make_device_matrix_view( - X.data_handle() + (dIdx * n_features), ns, n_features); + // minClusterAndDistanceView [ns x n_clusters] + auto minClusterAndDistanceView = + raft::make_device_vector_view, IndexT>( + minClusterAndDistance.data_handle() + dIdx, ns); - // minClusterAndDistanceView [ns x n_clusters] - auto minClusterAndDistanceView = - raft::make_device_vector_view, IndexT>( - minClusterAndDistance.data_handle() + dIdx, ns); - - auto L2NormXView = - raft::make_device_vector_view(L2NormX.data_handle() + dIdx, ns); - - if (is_fused) { - workspace.resize((sizeof(int)) * ns, stream); - - // todo(lsugy): remove cIdx - cuvs::distance::fusedDistanceNNMinReduce, IndexT>( - minClusterAndDistanceView.data_handle(), - datasetView.data_handle(), - centroids.data_handle(), - L2NormXView.data_handle(), - centroidsNorm.data_handle(), - ns, - n_clusters, - n_features, - (void*)workspace.data(), - metric != cuvs::distance::DistanceType::L2Expanded, - false, - true, - metric, - 0.0f, - stream); - } else { // tile over the centroids for (IndexT cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { // # of centroids for the current batch From 9f979c30072a338af3faef601baaa9a4d1ad66ca Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 8 Apr 2026 17:50:43 -0700 Subject: [PATCH 02/18] fix style --- cpp/src/cluster/detail/kmeans_balanced.cuh | 3 +-- cpp/src/cluster/detail/minClusterDistanceCompute.cu | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 66ac41fe5c..c142393dc9 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -95,8 +95,7 @@ inline std::enable_if_t> predict_core( auto X_view = raft::make_device_matrix_view(dataset, n_rows, dim); auto centroids_view = raft::make_device_matrix_view(centers, n_clusters, dim); - auto X_norm_view = - raft::make_device_vector_view(dataset_norm, n_rows); + auto 
X_norm_view = raft::make_device_vector_view(dataset_norm, n_rows); auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( handle, mr, raft::make_extents(n_rows)); diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index 5da3b1efb2..52692fad98 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -29,7 +29,7 @@ void minClusterAndDistanceCompute( auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = centroids.extent(0); - bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || + bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded || metric == cuvs::distance::DistanceType::CosineExpanded; auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(batch_samples, n_samples); From 6d6d8d199a9c293b14df4f009a0598e4977a5dbc Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 8 Apr 2026 17:58:14 -0700 Subject: [PATCH 03/18] re-add comments that were removed --- cpp/src/cluster/detail/minClusterDistanceCompute.cu | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index 52692fad98..ac3cc69ddf 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -29,6 +29,7 @@ void minClusterAndDistanceCompute( auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = centroids.extent(0); + // todo(lsugy): change batch size computation when using fusedL2NN! 
bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded || metric == cuvs::distance::DistanceType::CosineExpanded; @@ -75,6 +76,9 @@ void minClusterAndDistanceCompute( 0.0f, stream); } else { + // TODO: Unless pool allocator is used, passing in a workspace for this + // isn't really increasing performance because this needs to do a re-allocation + // anyways. ref https://github.com/rapidsai/raft/issues/930 L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer @@ -227,6 +231,7 @@ void minClusterDistanceCompute(raft::resources const& handle, if (is_fused) { workspace.resize((sizeof(IndexT)) * ns, stream); + // todo(lsugy): remove cIdx cuvs::distance::fusedDistanceNNMinReduce( minClusterDistanceView.data_handle(), datasetView.data_handle(), From 7cc9dadc67d68629f8ed76a857e49648b3fc5a3e Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 8 Apr 2026 18:08:14 -0700 Subject: [PATCH 04/18] re-add comments that were removed --- cpp/src/cluster/detail/minClusterDistanceCompute.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index ac3cc69ddf..bbbf634184 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -59,6 +59,7 @@ void minClusterAndDistanceCompute( workspace.resize((sizeof(int)) * n_samples, stream); + // todo(lsugy): remove cIdx cuvs::distance::fusedDistanceNNMinReduce, IndexT>( minClusterAndDistance.data_handle(), X.data_handle(), @@ -231,7 +232,6 @@ void minClusterDistanceCompute(raft::resources const& handle, if (is_fused) { workspace.resize((sizeof(IndexT)) * ns, stream); - // todo(lsugy): remove cIdx cuvs::distance::fusedDistanceNNMinReduce( minClusterDistanceView.data_handle(), datasetView.data_handle(), 
From 20a25ffa67ff48b56d99b692ece1a93cd86bc865 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 8 Apr 2026 18:11:00 -0700 Subject: [PATCH 05/18] fix comment for args --- cpp/src/cluster/detail/kmeans_balanced.cuh | 4 ++-- cpp/src/cluster/detail/minClusterDistanceCompute.cu | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index c142393dc9..7f51f91d70 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -108,8 +108,8 @@ inline std::enable_if_t> predict_core( X_norm_view, L2NormBuf_OR_DistBuf, params.metric, - 0, // batch_samples (0 = default) - 0, // batch_centroids (0 = default) + 0, // batch_samples (unused for fused reduction) + 0, // batch_centroids (unused for fused reduction) workspace); // Copy keys to output labels diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index bbbf634184..7d4c0c4cec 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -30,7 +30,7 @@ void minClusterAndDistanceCompute( auto n_features = X.extent(1); auto n_clusters = centroids.extent(0); // todo(lsugy): change batch size computation when using fusedL2NN! - bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || + bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded || metric == cuvs::distance::DistanceType::CosineExpanded; auto dataBatchSize = is_fused ? 
(IndexT)n_samples : getDataBatchSize(batch_samples, n_samples); From 2d6baa5a1b0cd218a68e0f907c3a4a90c370525b Mon Sep 17 00:00:00 2001 From: Tarang Jain <40517122+tarang-jain@users.noreply.github.com> Date: Thu, 9 Apr 2026 11:09:13 -0700 Subject: [PATCH 06/18] Update cpp/src/cluster/detail/minClusterDistanceCompute.cu Co-authored-by: Jinsol Park --- cpp/src/cluster/detail/minClusterDistanceCompute.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index 7d4c0c4cec..fb3aae9052 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -38,22 +38,22 @@ void minClusterAndDistanceCompute( if (is_fused) { L2NormBuf_OR_DistBuf.resize(n_clusters, stream); +auto centroidsNorm = + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); + if (metric == cuvs::distance::DistanceType::CosineExpanded) { raft::linalg::norm( handle, centroids, - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters), + centroidsNorm, raft::sqrt_op{}); } else { raft::linalg::norm( handle, centroids, - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); + centroidsNorm; } - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - raft::KeyValuePair initial_value(0, std::numeric_limits::max()); raft::matrix::fill(handle, minClusterAndDistance, initial_value); From b46e4712c731c78dba40264a00205a6d32019843 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 9 Apr 2026 19:03:32 -0700 Subject: [PATCH 07/18] move batch size loop entirely to else condition --- cpp/src/cluster/detail/kmeans_common.cuh | 1 - .../detail/minClusterDistanceCompute.cu | 118 ++++++++---------- 2 files changed, 50 insertions(+), 69 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh 
index 250563dd12..e42d868dd9 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -5,7 +5,6 @@ #pragma once #include "../../distance/distance.cuh" -#include "../../distance/fused_distance_nn.cuh" #include #include #include diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index fb3aae9052..4afea0c124 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -33,25 +33,20 @@ void minClusterAndDistanceCompute( bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded || metric == cuvs::distance::DistanceType::CosineExpanded; - auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(batch_samples, n_samples); - auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); if (is_fused) { L2NormBuf_OR_DistBuf.resize(n_clusters, stream); -auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - + auto centroidsNorm = + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); + if (metric == cuvs::distance::DistanceType::CosineExpanded) { raft::linalg::norm( - handle, - centroids, - centroidsNorm, - raft::sqrt_op{}); + handle, centroids, centroidsNorm, raft::sqrt_op{}); } else { raft::linalg::norm( handle, centroids, - centroidsNorm; + centroidsNorm); } raft::KeyValuePair initial_value(0, std::numeric_limits::max()); @@ -77,6 +72,9 @@ auto centroidsNorm = 0.0f, stream); } else { + auto dataBatchSize = getDataBatchSize(batch_samples, n_samples); + auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); + // TODO: Unless pool allocator is used, passing in a workspace for this // isn't really increasing performance because this needs to do a re-allocation // anyways. 
ref https://github.com/rapidsai/raft/issues/930 @@ -186,86 +184,70 @@ void minClusterDistanceCompute(raft::resources const& handle, auto n_clusters = centroids.extent(0); bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded; - auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(batch_samples, n_samples); - auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); + metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::CosineExpanded; + + raft::matrix::fill(handle, minClusterDistance, std::numeric_limits::max()); if (is_fused) { L2NormBuf_OR_DistBuf.resize(n_clusters, stream); + auto centroidsNorm = + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); + raft::linalg::norm( handle, raft::make_device_matrix_view( centroids.data_handle(), centroids.extent(0), centroids.extent(1)), - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); + centroidsNorm); + + workspace.resize((sizeof(IndexT)) * n_samples, stream); + + cuvs::distance::fusedDistanceNNMinReduce( + minClusterDistance.data_handle(), + X.data_handle(), + centroids.data_handle(), + L2NormX.data_handle(), + centroidsNorm.data_handle(), + n_samples, + n_clusters, + n_features, + (void*)workspace.data(), + metric != cuvs::distance::DistanceType::L2Expanded, + false, + true, + metric, + 0.0f, + stream); } else { + auto dataBatchSize = getDataBatchSize(batch_samples, n_samples); + auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); + L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); - } - // Note - pairwiseDistance and centroidsNorm share the same buffer - // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - // pairwiseDistance[ns x nc] - tensor wrapper 
around the distance buffer - auto pairwiseDistance = raft::make_device_matrix_view( - L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); + auto pairwiseDistance = raft::make_device_matrix_view( + L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); - raft::matrix::fill(handle, minClusterDistance, std::numeric_limits::max()); + // tile over the input data and calculate distance matrix [n_samples x + // n_clusters] + for (IndexT dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { + auto ns = std::min((IndexT)dataBatchSize, n_samples - dIdx); + + auto datasetView = raft::make_device_matrix_view( + X.data_handle() + dIdx * n_features, ns, n_features); + + auto minClusterDistanceView = + raft::make_device_vector_view(minClusterDistance.data_handle() + dIdx, ns); - // tile over the input data and calculate distance matrix [n_samples x - // n_clusters] - for (IndexT dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { - // # of samples for the current batch - auto ns = std::min((IndexT)dataBatchSize, n_samples - dIdx); - - // datasetView [ns x n_features] - view representing the current batch of - // input dataset - auto datasetView = raft::make_device_matrix_view( - X.data_handle() + dIdx * n_features, ns, n_features); - - // minClusterDistanceView [ns x n_clusters] - auto minClusterDistanceView = - raft::make_device_vector_view(minClusterDistance.data_handle() + dIdx, ns); - - auto L2NormXView = - raft::make_device_vector_view(L2NormX.data_handle() + dIdx, ns); - - if (is_fused) { - workspace.resize((sizeof(IndexT)) * ns, stream); - - cuvs::distance::fusedDistanceNNMinReduce( - minClusterDistanceView.data_handle(), - datasetView.data_handle(), - centroids.data_handle(), - L2NormXView.data_handle(), - centroidsNorm.data_handle(), - ns, - n_clusters, - n_features, - (void*)workspace.data(), - metric != cuvs::distance::DistanceType::L2Expanded, - false, - true, - metric, - 0.0f, - stream); - } else { // tile over the centroids for (IndexT cIdx = 
0; cIdx < n_clusters; cIdx += centroidsBatchSize) { - // # of centroids for the current batch auto nc = std::min((IndexT)centroidsBatchSize, n_clusters - cIdx); - // centroidsView [nc x n_features] - view representing the current batch - // of centroids auto centroidsView = raft::make_device_matrix_view( centroids.data_handle() + cIdx * n_features, nc, n_features); - // pairwiseDistanceView [ns x nc] - view representing the pairwise - // distance for current batch auto pairwiseDistanceView = raft::make_device_matrix_view(pairwiseDistance.data_handle(), ns, nc); - // calculate pairwise distance between current tile of cluster centroids - // and input dataset pairwise_distance_kmeans( handle, datasetView, centroidsView, pairwiseDistanceView, metric); From e93af814414b998b8c020cfdadec66cc59585b00 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 9 Apr 2026 19:09:05 -0700 Subject: [PATCH 08/18] rm comments --- cpp/src/cluster/detail/minClusterDistanceCompute.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index 4afea0c124..0e24342f7e 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -29,7 +29,6 @@ void minClusterAndDistanceCompute( auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = centroids.extent(0); - // todo(lsugy): change batch size computation when using fusedL2NN! 
bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded || metric == cuvs::distance::DistanceType::CosineExpanded; @@ -54,7 +53,6 @@ void minClusterAndDistanceCompute( workspace.resize((sizeof(int)) * n_samples, stream); - // todo(lsugy): remove cIdx cuvs::distance::fusedDistanceNNMinReduce, IndexT>( minClusterAndDistance.data_handle(), X.data_handle(), From e1a93551d85272418272ddbac2b3e6d387eae66c Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 9 Apr 2026 19:14:00 -0700 Subject: [PATCH 09/18] add include --- cpp/src/cluster/detail/minClusterDistanceCompute.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index 0e24342f7e..6a2f0b619a 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -4,6 +4,7 @@ */ #include "kmeans_common.cuh" +#include "../../distance/fused_distance_nn.cuh" #include From 079b1bbacabc70302f167674a1e3ee4276dfdc1b Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 9 Apr 2026 19:17:40 -0700 Subject: [PATCH 10/18] style --- cpp/src/cluster/detail/minClusterDistanceCompute.cu | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index 6a2f0b619a..74fcc77522 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include "kmeans_common.cuh" #include "../../distance/fused_distance_nn.cuh" +#include "kmeans_common.cuh" #include @@ -30,7 +30,7 @@ void minClusterAndDistanceCompute( auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = centroids.extent(0); - bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || + 
bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded || metric == cuvs::distance::DistanceType::CosineExpanded; @@ -44,9 +44,7 @@ void minClusterAndDistanceCompute( handle, centroids, centroidsNorm, raft::sqrt_op{}); } else { raft::linalg::norm( - handle, - centroids, - centroidsNorm); + handle, centroids, centroidsNorm); } raft::KeyValuePair initial_value(0, std::numeric_limits::max()); From 7c8fa6e1e323253c5612a3433110b12d7e58f40d Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 13 Apr 2026 13:06:17 -0700 Subject: [PATCH 11/18] correct norm computation for cosine --- .../detail/minClusterDistanceCompute.cu | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index 74fcc77522..3a93f9362e 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -191,11 +191,20 @@ void minClusterDistanceCompute(raft::resources const& handle, auto centroidsNorm = raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - raft::linalg::norm( - handle, - raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)), - centroidsNorm); + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)), + centroidsNorm, + raft::sqrt_op{}); + } else { + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)), + centroidsNorm); + } workspace.resize((sizeof(IndexT)) * n_samples, stream); From 925c4aca2b72f1407ab35ec5198b2e755ec2c8f2 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 13 Apr 2026 15:37:49 -0700 Subject: [PATCH 12/18] add cosine tests for 
balanced --- cpp/src/cluster/detail/kmeans.cuh | 32 +++++++++++++++++----------- cpp/tests/cluster/kmeans_balanced.cu | 27 +++++++++++++++++++++++ 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 5a35f203b3..a485ad346c 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -130,11 +130,13 @@ void kmeansPlusPlus(raft::resources const& handle, raft::device_matrix_view candidates_view( centroidCandidates.data_handle(), n_trials, n_features); - // L2 norm of X: ||c||^2 auto L2NormX = raft::make_device_vector(handle, n_samples); - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm( + handle, X, L2NormX.view(), raft::sqrt_op{}); + } else if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { raft::linalg::norm(handle, X, L2NormX.view()); } @@ -342,13 +344,15 @@ void kmeans_fit_main(raft::resources const& handle, rmm::device_scalar clusterCostD(stream); - // L2 norm of X: ||x||^2 auto L2NormX = raft::make_device_vector(handle, n_samples); auto l2normx_view = raft::make_device_vector_view(L2NormX.data_handle(), n_samples); - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm( + handle, X, L2NormX.view(), raft::sqrt_op{}); + } else if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { raft::linalg::norm(handle, X, L2NormX.view()); } @@ -523,10 +527,12 @@ void initScalableKMeansPlusPlus(raft::resources const& handle, // destructor releases the resource rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - // L2 norm of X: 
||x||^2 auto L2NormX = raft::make_device_vector(handle, n_samples); - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm( + handle, X, L2NormX.view(), raft::sqrt_op{}); + } else if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { raft::linalg::norm(handle, X, L2NormX.view()); } @@ -933,10 +939,12 @@ void kmeans_predict(raft::resources const& handle, raft::make_device_vector, IndexT>(handle, n_samples); rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - // L2 norm of X: ||x||^2 auto L2NormX = raft::make_device_vector(handle, n_samples); - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm( + handle, X, L2NormX.view(), raft::sqrt_op{}); + } else if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { raft::linalg::norm(handle, X, L2NormX.view()); } diff --git a/cpp/tests/cluster/kmeans_balanced.cu b/cpp/tests/cluster/kmeans_balanced.cu index b84ab5a7ff..f1e12e09dc 100644 --- a/cpp/tests/cluster/kmeans_balanced.cu +++ b/cpp/tests/cluster/kmeans_balanced.cu @@ -179,10 +179,34 @@ std::vector> get_kmeans_balanced_inputs() return out; } +template +std::vector> get_kmeans_balanced_cosine_inputs() +{ + std::vector> out; + KmeansBalancedInputs p; + p.kb_params.n_iters = 20; + p.kb_params.metric = cuvs::distance::DistanceType::CosineExpanded; + p.tol = MathT{0.0001}; + std::vector> row_cols_k = { + {1000, 32, 5}, + {1000, 100, 20}, + {10000, 32, 10}, + {10000, 100, 50}, + }; + for (auto& rck : row_cols_k) { + p.n_rows = static_cast(std::get<0>(rck)); + p.n_cols = static_cast(std::get<1>(rck)); + p.n_clusters = static_cast(std::get<2>(rck)); + 
out.push_back(p); + } + return out; +} + const auto inputsf_i32 = get_kmeans_balanced_inputs(); // const auto inputsd_i32 = get_kmeans_balanced_inputs(); const auto inputsf_i64 = get_kmeans_balanced_inputs(); // const auto inputsd_i64 = get_kmeans_balanced_inputs(); +const auto inputsf_cosine_i32 = get_kmeans_balanced_cosine_inputs(); #define KB_TEST(test_type, test_name, test_inputs) \ typedef RAFT_DEPAREN(test_type) test_name; \ @@ -223,6 +247,9 @@ KB_TEST((KmeansBalancedTest // KB_TEST((KmeansBalancedTest), // KmeansBalancedTestFFI64I64, // inputsf_i64); +KB_TEST((KmeansBalancedTest), + KmeansBalancedTestCosineFFU32I32, + inputsf_cosine_i32); /* * Second set of tests: integer dataset with conversion From 40338a03ad5712171f5c526c9ebe9f8d47b75897 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 21 Apr 2026 16:22:55 -0700 Subject: [PATCH 13/18] pre-normalize centers --- cpp/src/cluster/detail/kmeans_balanced.cuh | 10 ++++++++++ cpp/src/cluster/detail/minClusterDistanceCompute.cu | 12 ++++-------- cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh | 3 --- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 7f51f91d70..3c63274720 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -706,6 +706,16 @@ void balancing_em_iters(const raft::resources& handle, mapping_op, device_memory); } + + // For cosine, ensure the returned centers are unit-norm so downstream + // `predict` can skip the per-call centroid normalization. 
+ if (params.metric == cuvs::distance::DistanceType::CosineExpanded) { + auto clusters_in_view = raft::make_device_matrix_view( + cluster_centers, n_clusters, dim); + auto clusters_out_view = + raft::make_device_matrix_view(cluster_centers, n_clusters, dim); + raft::linalg::row_normalize(handle, clusters_in_view, clusters_out_view); + } } /** Randomly initialize cluster centers and then call `balancing_em_iters`. */ diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index 3a93f9362e..6e77ca4977 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -40,8 +40,8 @@ void minClusterAndDistanceCompute( raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); if (metric == cuvs::distance::DistanceType::CosineExpanded) { - raft::linalg::norm( - handle, centroids, centroidsNorm, raft::sqrt_op{}); + // Centroids are L2-normalized for cosine metric + raft::matrix::fill(handle, centroidsNorm, DataT{1}); } else { raft::linalg::norm( handle, centroids, centroidsNorm); @@ -192,12 +192,8 @@ void minClusterDistanceCompute(raft::resources const& handle, raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); if (metric == cuvs::distance::DistanceType::CosineExpanded) { - raft::linalg::norm( - handle, - raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)), - centroidsNorm, - raft::sqrt_op{}); + // Centroids are L2-normalized for cosine metric + raft::matrix::fill(handle, centroidsNorm, DataT{1}); } else { raft::linalg::norm( handle, diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index 05c0176faa..6f33a05678 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -1328,9 +1328,6 @@ auto build(raft::resources const& handle, rmm::device_uvector labels(n_rows_train, stream, 
big_memory_resource); auto centers_const_view = raft::make_device_matrix_view( cluster_centers, impl->n_lists(), impl->dim()); - if (impl->metric() == distance::DistanceType::CosineExpanded) { - raft::linalg::row_normalize(handle, centers_const_view, centers_view); - } auto labels_view = raft::make_device_vector_view(labels.data(), n_rows_train); cuvs::cluster::kmeans::predict( From 707c6d9d0d04ae7c3cbd6ce7192405c74d01e424 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 22 Apr 2026 15:22:18 -0700 Subject: [PATCH 14/18] int workspace --- cpp/src/cluster/detail/minClusterDistanceCompute.cu | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index 6e77ca4977..a27bce22d5 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -12,7 +12,9 @@ namespace cuvs::cluster::kmeans::detail { // Calculates a pair for every sample in input 'X' where key is an // index to an sample in 'centroids' (index of the nearest centroid) and 'value' -// is the distance between the sample and the 'centroid[key]' +// is the distance between the sample and the 'centroids[key]'. +// +// NB: (CosineExpanded): `centroids` rows must be L2-normalized when the cosine metric is used. template void minClusterAndDistanceCompute( raft::resources const& handle, @@ -163,6 +165,10 @@ INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(double, int) #undef INSTANTIATE_MIN_CLUSTER_AND_DISTANCE +/** + * NB: (CosineExpanded): `centroids` rows must be L2-normalized when the cosine metric is used. + * Non-unit-norm centroids will silently produce incorrect distances. 
+ */ template void minClusterDistanceCompute(raft::resources const& handle, raft::device_matrix_view X, @@ -202,7 +208,7 @@ void minClusterDistanceCompute(raft::resources const& handle, centroidsNorm); } - workspace.resize((sizeof(IndexT)) * n_samples, stream); + workspace.resize(sizeof(int) * n_samples, stream); cuvs::distance::fusedDistanceNNMinReduce( minClusterDistance.data_handle(), From c0b2376e1e52fd75464bb37aa17e616af5f17161 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 22 Apr 2026 17:13:08 -0700 Subject: [PATCH 15/18] rm cosine norms in kmeans.cuh --- cpp/src/cluster/detail/kmeans.cuh | 32 ++++++++++++------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index a485ad346c..5a35f203b3 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -130,13 +130,11 @@ void kmeansPlusPlus(raft::resources const& handle, raft::device_matrix_view candidates_view( centroidCandidates.data_handle(), n_trials, n_features); + // L2 norm of X: ||c||^2 auto L2NormX = raft::make_device_vector(handle, n_samples); - if (metric == cuvs::distance::DistanceType::CosineExpanded) { - raft::linalg::norm( - handle, X, L2NormX.view(), raft::sqrt_op{}); - } else if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { raft::linalg::norm(handle, X, L2NormX.view()); } @@ -344,15 +342,13 @@ void kmeans_fit_main(raft::resources const& handle, rmm::device_scalar clusterCostD(stream); + // L2 norm of X: ||x||^2 auto L2NormX = raft::make_device_vector(handle, n_samples); auto l2normx_view = raft::make_device_vector_view(L2NormX.data_handle(), n_samples); - if (metric == cuvs::distance::DistanceType::CosineExpanded) { - raft::linalg::norm( - handle, X, L2NormX.view(), 
raft::sqrt_op{}); - } else if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { raft::linalg::norm(handle, X, L2NormX.view()); } @@ -527,12 +523,10 @@ void initScalableKMeansPlusPlus(raft::resources const& handle, // destructor releases the resource rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + // L2 norm of X: ||x||^2 auto L2NormX = raft::make_device_vector(handle, n_samples); - if (metric == cuvs::distance::DistanceType::CosineExpanded) { - raft::linalg::norm( - handle, X, L2NormX.view(), raft::sqrt_op{}); - } else if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { raft::linalg::norm(handle, X, L2NormX.view()); } @@ -939,12 +933,10 @@ void kmeans_predict(raft::resources const& handle, raft::make_device_vector, IndexT>(handle, n_samples); rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + // L2 norm of X: ||x||^2 auto L2NormX = raft::make_device_vector(handle, n_samples); - if (metric == cuvs::distance::DistanceType::CosineExpanded) { - raft::linalg::norm( - handle, X, L2NormX.view(), raft::sqrt_op{}); - } else if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { raft::linalg::norm(handle, X, L2NormX.view()); } From 44621789ea49bb0d2227018dd057387a9aae1fa1 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 23 Apr 2026 09:00:30 -0700 Subject: [PATCH 16/18] revert unit norm assumption --- cpp/src/cluster/detail/kmeans_balanced.cuh | 10 ------- .../detail/minClusterDistanceCompute.cu | 30 
+++++-------------- 2 files changed, 7 insertions(+), 33 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index bf80dddb89..b40c6fbcf0 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -706,16 +706,6 @@ void balancing_em_iters(const raft::resources& handle, mapping_op, device_memory); } - - // For cosine, ensure the returned centers are unit-norm so downstream - // `predict` can skip the per-call centroid normalization. - if (params.metric == cuvs::distance::DistanceType::CosineExpanded) { - auto clusters_in_view = raft::make_device_matrix_view( - cluster_centers, n_clusters, dim); - auto clusters_out_view = - raft::make_device_matrix_view(cluster_centers, n_clusters, dim); - raft::linalg::row_normalize(handle, clusters_in_view, clusters_out_view); - } } /** Randomly initialize cluster centers and then call `balancing_em_iters`. */ diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index a27bce22d5..b4ce4aa550 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -13,8 +13,6 @@ namespace cuvs::cluster::kmeans::detail { // Calculates a pair for every sample in input 'X' where key is an // index to an sample in 'centroids' (index of the nearest centroid) and 'value' // is the distance between the sample and the 'centroids[key]'. -// -// NB: (CosineExpanded): `centroids` rows must be L2-normalized when the cosine metric is used. 
template void minClusterAndDistanceCompute( raft::resources const& handle, @@ -41,13 +39,8 @@ void minClusterAndDistanceCompute( auto centroidsNorm = raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - if (metric == cuvs::distance::DistanceType::CosineExpanded) { - // Centroids are L2-normalized for cosine metric - raft::matrix::fill(handle, centroidsNorm, DataT{1}); - } else { - raft::linalg::norm( - handle, centroids, centroidsNorm); - } + raft::linalg::norm( + handle, centroids, centroidsNorm); raft::KeyValuePair initial_value(0, std::numeric_limits::max()); raft::matrix::fill(handle, minClusterAndDistance, initial_value); @@ -165,10 +158,6 @@ INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(double, int) #undef INSTANTIATE_MIN_CLUSTER_AND_DISTANCE -/** - * NB: (CosineExpanded): `centroids` rows must be L2-normalized when the cosine metric is used. - * Non-unit-norm centroids will silently produce incorrect distances. - */ template void minClusterDistanceCompute(raft::resources const& handle, raft::device_matrix_view X, @@ -197,16 +186,11 @@ void minClusterDistanceCompute(raft::resources const& handle, auto centroidsNorm = raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - if (metric == cuvs::distance::DistanceType::CosineExpanded) { - // Centroids are L2-normalized for cosine metric - raft::matrix::fill(handle, centroidsNorm, DataT{1}); - } else { - raft::linalg::norm( - handle, - raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)), - centroidsNorm); - } + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)), + centroidsNorm); workspace.resize(sizeof(int) * n_samples, stream); From a8e281fb3b6beaf60e5dc938639b228fe1d441e9 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 23 Apr 2026 09:59:39 -0700 Subject: [PATCH 17/18] empty commit From 940c76de100cac1f2af9fdc98e407db5b2cd46a2 Mon Sep 17 00:00:00 
2001 From: Tarang Jain Date: Thu, 23 Apr 2026 17:11:09 -0700 Subject: [PATCH 18/18] fix sqrt_op norm --- .../detail/minClusterDistanceCompute.cu | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index b4ce4aa550..e271a861e8 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -39,8 +39,13 @@ void minClusterAndDistanceCompute( auto centroidsNorm = raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - raft::linalg::norm( - handle, centroids, centroidsNorm); + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm( + handle, centroids, centroidsNorm, raft::sqrt_op{}); + } else { + raft::linalg::norm( + handle, centroids, centroidsNorm); + } raft::KeyValuePair initial_value(0, std::numeric_limits::max()); raft::matrix::fill(handle, minClusterAndDistance, initial_value); @@ -186,11 +191,20 @@ void minClusterDistanceCompute(raft::resources const& handle, auto centroidsNorm = raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - raft::linalg::norm( - handle, - raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)), - centroidsNorm); + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)), + centroidsNorm, + raft::sqrt_op{}); + } else { + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)), + centroidsNorm); + } workspace.resize(sizeof(int) * n_samples, stream);