Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion cpp/include/cuvs/cluster/kmeans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,13 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 };
* on the host. Data is processed in GPU-sized batches, streaming from host to device.
* The batch size is controlled by params.streaming_batch_size.
*
* Multi-GPU dispatch is selected automatically based on the handle state:
* - If `raft::resource::is_multi_gpu(handle)` (cuVS SNMG): the full dataset X
* is split across GPUs internally with an OpenMP parallel region and NCCL.
* - If `raft::resource::comms_initialized(handle)` (Dask/Ray): X is treated as
* this worker's partition, and RAFT communicators are used for collectives.
* - Otherwise: single-GPU batched k-means.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/cluster/kmeans.hpp>
Expand Down Expand Up @@ -196,7 +203,8 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 };
* raft::make_host_scalar_view(&n_iter));
* @endcode
*
* @param[in] handle The raft handle.
* @param[in] handle The raft handle. When a multi-GPU resource is
* attached, multi-GPU dispatch is used automatically.
* @param[in] params Parameters for KMeans model. Batch size is read from
* params.streaming_batch_size.
* @param[in] X Training instances on HOST memory. The data must
Expand Down
239 changes: 235 additions & 4 deletions cpp/src/cluster/detail/kmeans_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/add.cuh>
#include <raft/linalg/map.cuh>
#include <raft/linalg/map_then_reduce.cuh>
#include <raft/linalg/matrix_vector_op.cuh>
Expand All @@ -43,6 +44,7 @@
#include <cuda.h>
#include <cuda/iterator>
#include <thrust/for_each.h>
#include <thrust/iterator/transform_iterator.h>

#include <algorithm>
#include <cmath>
Expand Down Expand Up @@ -367,7 +369,9 @@ void minClusterAndDistanceCompute(
cuvs::distance::DistanceType metric,
int batch_samples,
int batch_centroids,
rmm::device_uvector<char>& workspace);
rmm::device_uvector<char>& workspace,
std::optional<raft::device_vector_view<const DataT, IndexT>> precomputed_centroid_norms =
std::nullopt);

#define EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(DataT, IndexT) \
extern template void minClusterAndDistanceCompute<DataT, IndexT>( \
Expand All @@ -380,7 +384,8 @@ void minClusterAndDistanceCompute(
cuvs::distance::DistanceType metric, \
int batch_samples, \
int batch_centroids, \
rmm::device_uvector<char>& workspace);
rmm::device_uvector<char>& workspace, \
std::optional<raft::device_vector_view<const DataT, IndexT>> precomputed_centroid_norms);

EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(float, int64_t)
EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(float, int)
Expand All @@ -399,7 +404,9 @@ void minClusterDistanceCompute(raft::resources const& handle,
cuvs::distance::DistanceType metric,
int batch_samples,
int batch_centroids,
rmm::device_uvector<char>& workspace);
rmm::device_uvector<char>& workspace,
std::optional<raft::device_vector_view<const DataT, IndexT>>
precomputed_centroid_norms = std::nullopt);

#define EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(DataT, IndexT) \
extern template void minClusterDistanceCompute<DataT, IndexT>( \
Expand All @@ -412,7 +419,8 @@ void minClusterDistanceCompute(raft::resources const& handle,
cuvs::distance::DistanceType metric, \
int batch_samples, \
int batch_centroids, \
rmm::device_uvector<char>& workspace);
rmm::device_uvector<char>& workspace, \
std::optional<raft::device_vector_view<const DataT, IndexT>> precomputed_centroid_norms);

EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(float, int64_t)
EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(double, int64_t)
Expand Down Expand Up @@ -528,6 +536,229 @@ void compute_centroid_adjustments(
n_clusters,
stream);
}

/**
* @brief Process one local chunk in a Lloyd iteration.
*
* For the given @p batch_data and @p batch_weights, this helper:
* 1. assigns each sample to its nearest centroid,
* 2. computes weighted per-cluster sums and per-cluster weight totals
* into the per-batch scratch buffers @p batch_sums / @p batch_counts,
* 3. adds those scratch values into the running accumulators
* @p centroid_sums / @p weight_per_cluster,
* 4. optionally adds the weighted clustering cost of this batch into
* the running @p clustering_cost scalar (via the scratch
* @p batch_clustering_cost).
*
* Preconditions (caller responsibility):
* - @p centroid_sums and @p weight_per_cluster must be zero-initialized
* before the first call of a Lloyd iteration (they accumulate across batches).
* - @p clustering_cost must be zero-initialized before the first call of a
* Lloyd iteration when @p compute_inertia is true.
* - @p batch_sums, @p batch_counts, and @p batch_clustering_cost do NOT need
* to be zeroed — they are overwritten on each call.
* - @p batch_workspace may be empty; it is grown to `batch_data.extent(0)` bytes
* by compute_centroid_adjustments.
*
* @param[in] handle RAFT resources handle.
* @param[in] batch_data Input samples [n_rows x n_features].
* @param[in] batch_weights Per-sample weights [n_rows].
* @param[in] centroids Current centroids [n_clusters x n_features].
* @param[in] metric Distance metric.
* @param[in] tile_samples Sample-tiling size forwarded to
* minClusterAndDistanceCompute.
* @param[in] tile_centroids Centroid-tiling size forwarded to
* minClusterAndDistanceCompute.
* @param[out] minClusterAndDistance Per-sample (label, distance) pairs [n_rows].
* @param[in] L2NormBatch Pre-computed ||x||^2 for @p batch_data
* (required only for L2Expanded-family metrics).
* @param[inout] L2NormBuf_OR_DistBuf Workspace for centroid norms / pairwise distances.
* @param[inout] workspace Generic workspace buffer.
* @param[inout] batch_workspace Workspace for reduce_rows_by_key (sized internally).
* @param[inout] centroid_sums Running weighted-sum accumulator
* [n_clusters x n_features].
* @param[inout] weight_per_cluster Running per-cluster weight accumulator [n_clusters].
* @param[out] batch_sums Per-batch weighted-sum scratch [n_clusters x n_features].
* @param[out] batch_counts Per-batch weight scratch [n_clusters].
* @param[inout] clustering_cost Running weighted clustering cost accumulator (scalar).
* Unused when @p compute_inertia is false.
* @param[out] batch_clustering_cost Per-batch cost scratch (scalar).
* Unused when @p compute_inertia is false.
* @param[in] compute_inertia Whether to accumulate the weighted clustering cost.
* @param[in] centroid_norms Optional pre-computed centroid norms.
* When provided, avoids recomputing them inside
* minClusterAndDistanceCompute; useful when the
* same centroids are reused across many batches.
*/
template <typename DataT, typename IndexT>
void process_batch(
raft::resources const& handle,
raft::device_matrix_view<const DataT, IndexT> batch_data,
raft::device_vector_view<const DataT, IndexT> batch_weights,
raft::device_matrix_view<const DataT, IndexT> centroids,
cuvs::distance::DistanceType metric,
int tile_samples,
int tile_centroids,
raft::device_vector_view<raft::KeyValuePair<IndexT, DataT>, IndexT> minClusterAndDistance,
raft::device_vector_view<const DataT, IndexT> L2NormBatch,
rmm::device_uvector<DataT>& L2NormBuf_OR_DistBuf,
rmm::device_uvector<char>& workspace,
rmm::device_uvector<char>& batch_workspace,
raft::device_matrix_view<DataT, IndexT> centroid_sums,
raft::device_vector_view<DataT, IndexT> weight_per_cluster,
raft::device_matrix_view<DataT, IndexT> batch_sums,
raft::device_vector_view<DataT, IndexT> batch_counts,
raft::device_scalar_view<DataT> clustering_cost,
raft::device_scalar_view<DataT> batch_clustering_cost,
bool compute_inertia,
std::optional<raft::device_vector_view<const DataT, IndexT>> centroid_norms = std::nullopt)
{
cudaStream_t stream = raft::resource::get_cuda_stream(handle);

minClusterAndDistanceCompute<DataT, IndexT>(handle,
batch_data,
centroids,
minClusterAndDistance,
L2NormBatch,
L2NormBuf_OR_DistBuf,
metric,
tile_samples,
tile_centroids,
workspace,
centroid_norms);

KeyValueIndexOp<IndexT, DataT> conversion_op;
thrust::transform_iterator<KeyValueIndexOp<IndexT, DataT>,
const raft::KeyValuePair<IndexT, DataT>*>
labels_itr(minClusterAndDistance.data_handle(), conversion_op);

compute_centroid_adjustments(handle,
batch_data,
batch_weights,
labels_itr,
static_cast<IndexT>(centroid_sums.extent(0)),
batch_sums,
batch_counts,
batch_workspace);

raft::linalg::add(centroid_sums.data_handle(),
centroid_sums.data_handle(),
batch_sums.data_handle(),
centroid_sums.size(),
stream);

raft::linalg::add(weight_per_cluster.data_handle(),
weight_per_cluster.data_handle(),
batch_counts.data_handle(),
weight_per_cluster.size(),
stream);

if (!compute_inertia) { return; }

// Note: batch_clustering_cost does not need to be zero-initialized; computeClusterCost
// writes (via cub::DeviceReduce::Reduce) rather than accumulating into it.
raft::linalg::map(
handle,
minClusterAndDistance,
[=] __device__(const raft::KeyValuePair<IndexT, DataT> kvp, DataT wt) {
raft::KeyValuePair<IndexT, DataT> res;
res.value = kvp.value * wt;
res.key = kvp.key;
return res;
},
raft::make_const_mdspan(minClusterAndDistance),
batch_weights);

computeClusterCost(handle,
minClusterAndDistance,
workspace,
batch_clustering_cost,
raft::value_op{},
raft::add_op{});

raft::linalg::add(clustering_cost.data_handle(),
clustering_cost.data_handle(),
batch_clustering_cost.data_handle(),
1,
stream);
}

/**
* @brief Compute the weighted clustering cost (inertia) of @p X against @p centroids.
*
* This helper mirrors the inertia branch of ::process_batch: it assigns each sample
* to its nearest centroid using @p metric (so L2SqrtExpanded yields sqrt distances),
* multiplies per-sample distances by @p weights, and sum-reduces the result into
* @p inertia_out. Unlike the public ::cuvs::cluster::kmeans::cluster_cost(X, centroids, …)
* overload (which is hardcoded to L2Expanded), this respects the configured metric and
* can reuse a pre-computed @p L2NormX and optional pre-computed @p centroid_norms,
* avoiding redundant norm passes when called repeatedly (e.g. per streaming batch).
*
* The output scalar is overwritten (not accumulated) by the internal cub reduction,
* so callers do not need to zero-initialize it. For multi-batch use, write each batch
* result to a scratch scalar and add it into the running total.
*
* @param[in] handle RAFT resources handle.
* @param[in] X Input samples [n_rows x n_features].
* @param[in] weights Per-sample weights [n_rows]. Pass a vector of ones for
* unweighted inertia.
* @param[in] centroids Centroids [n_clusters x n_features].
* @param[in] metric Distance metric (must match the training-time metric).
* @param[in] tile_samples Sample-tiling size for minClusterAndDistanceCompute.
* @param[in] tile_centroids Centroid-tiling size for minClusterAndDistanceCompute.
* @param[out] scratch_kvp Scratch buffer for per-sample (label, distance) pairs
* [n_rows]. Overwritten.
* @param[in] L2NormX Pre-computed ||x||^2 for @p X
* (required only for L2Expanded-family metrics; for other
* metrics the buffer is not read).
* @param[inout] L2NormBuf_OR_DistBuf Scratch for centroid norms / pairwise distances.
* @param[inout] workspace Generic workspace buffer.
* @param[out] inertia_out Scalar to receive the weighted cost. Overwritten.
* @param[in] centroid_norms Optional pre-computed centroid norms to avoid recomputation.
*/
template <typename DataT, typename IndexT>
void compute_weighted_inertia(
raft::resources const& handle,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_vector_view<const DataT, IndexT> weights,
raft::device_matrix_view<const DataT, IndexT> centroids,
cuvs::distance::DistanceType metric,
int tile_samples,
int tile_centroids,
raft::device_vector_view<raft::KeyValuePair<IndexT, DataT>, IndexT> scratch_kvp,
raft::device_vector_view<const DataT, IndexT> L2NormX,
rmm::device_uvector<DataT>& L2NormBuf_OR_DistBuf,
rmm::device_uvector<char>& workspace,
raft::device_scalar_view<DataT> inertia_out,
std::optional<raft::device_vector_view<const DataT, IndexT>> centroid_norms = std::nullopt)
{
minClusterAndDistanceCompute<DataT, IndexT>(handle,
X,
centroids,
scratch_kvp,
L2NormX,
L2NormBuf_OR_DistBuf,
metric,
tile_samples,
tile_centroids,
workspace,
centroid_norms);

raft::linalg::map(
handle,
scratch_kvp,
[=] __device__(const raft::KeyValuePair<IndexT, DataT> kvp, DataT wt) {
raft::KeyValuePair<IndexT, DataT> res;
res.value = kvp.value * wt;
res.key = kvp.key;
return res;
},
raft::make_const_mdspan(scratch_kvp),
weights);

computeClusterCost(handle, scratch_kvp, workspace, inertia_out, raft::value_op{}, raft::add_op{});
}

/**
* @brief Finalize centroids by dividing accumulated sums by counts.
*
Expand Down
Loading
Loading