Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] membership_vector for HDBSCAN #5247

Merged
merged 33 commits into from
Mar 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
e49d06a
membership_vector initial commit
tarang-jain Feb 18, 2023
436b180
Further updates to membership_vector
tarang-jain Feb 22, 2023
48030b8
Merge branch 'branch-23.04' into fea-membership-vector
tarang-jain Feb 22, 2023
7912dba
Initial testing membership_vector
tarang-jain Feb 23, 2023
4b41edb
Debug statements
tarang-jain Feb 23, 2023
fe0fd34
Merge branch 'fea-membership-vector' of https://github.com/tarang-jai…
tarang-jain Feb 23, 2023
9d5badc
debugging membership_vector
tarang-jain Feb 24, 2023
19f9dd8
membership_vector first working impl
tarang-jain Feb 28, 2023
a4b565c
GoogleTest intermediate commit
tarang-jain Feb 28, 2023
1f4bf78
GTest working
tarang-jain Feb 28, 2023
fdf100b
working tests and styling changes
tarang-jain Feb 28, 2023
e18096a
replace with raft mdspan primitives and add FastIntDiv
tarang-jain Mar 1, 2023
c2aa77e
Merge branch 'branch-23.04' into fea-membership-vector
tarang-jain Mar 1, 2023
182ba31
cpu support
tarang-jain Mar 1, 2023
366ef26
Fix failing pytest
tarang-jain Mar 7, 2023
b60d869
Merge branch 'branch-23.04' into fea-membership-vector
tarang-jain Mar 7, 2023
6bfaae2
modification after merge
tarang-jain Mar 7, 2023
c4e0bf1
Update softmax with raft::linalg reduction
tarang-jain Mar 8, 2023
fb634e4
Remove sync stream
tarang-jain Mar 9, 2023
a49ba87
memory study commit (to be reversed)
tarang-jain Mar 11, 2023
4ed9fd7
Merge branch 'branch-23.04' of github.com:rapidsai/cuml into fea-memb…
tarang-jain Mar 11, 2023
fa7b44e
Style fix
tarang-jain Mar 17, 2023
45f8ca4
Merge branch 'branch-23.04' of github.com:rapidsai/cuml into fea-memb…
tarang-jain Mar 17, 2023
367de04
Remove print debug statements
tarang-jain Mar 17, 2023
980b1f7
Merge branch 'branch-23.04' of github.com:rapidsai/cuml into fea-memb…
tarang-jain Mar 20, 2023
98aa237
Merge branch 'branch-23.04' into fea-membership-vector
tarang-jain Mar 20, 2023
d387026
Merge branch 'branch-23.04' into fea-membership-vector
tarang-jain Mar 27, 2023
ed40e22
Updates after PR reviews
tarang-jain Mar 28, 2023
387cde8
Merge branch 'fea-membership-vector' of https://github.com/tarang-jai…
tarang-jain Mar 28, 2023
092b3f8
Merge branch 'branch-23.04' of github.com:rapidsai/cuml into fea-memb…
tarang-jain Mar 28, 2023
ef85fd3
Update height_argmax
tarang-jain Mar 28, 2023
17de9ec
Merge branch 'branch-23.04' into fea-membership-vector
tarang-jain Mar 30, 2023
38208ec
Merge branch 'branch-23.04' into fea-membership-vector
tarang-jain Mar 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions cpp/include/cuml/cluster/hdbscan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,16 @@ void compute_all_points_membership_vectors(
raft::distance::DistanceType metric,
float* membership_vec);

/**
 * Declaration of the HDBSCAN membership-vector entry point for a set of prediction points.
 *
 * NOTE(review): only the prototype is visible here. The parameter notes below are inferred
 * from the parameter names and types; confirm them against the definition before relying
 * on them.
 *
 * @param[in] handle raft handle
 * @param[in] condensed_tree condensed hierarchy (int indices, float values); passed by
 *            non-const reference
 * @param[in] prediction_data prediction data object (int indices, float values); passed by
 *            non-const reference
 * @param[in] X presumably the training data the model was fit on -- TODO confirm layout/size
 * @param[in] points_to_predict presumably the points to compute membership vectors for
 *            -- TODO confirm layout/size
 * @param[in] n_prediction_points presumably the number of rows in points_to_predict
 * @param[in] min_samples presumably the neighborhood size used during training -- TODO confirm
 * @param[in] metric distance metric
 * @param[out] membership_vec presumably the output buffer for the membership vectors
 *             -- TODO confirm the expected size and memory space
 */
void compute_membership_vector(const raft::handle_t& handle,
HDBSCAN::Common::CondensedHierarchy<int, float>& condensed_tree,
HDBSCAN::Common::PredictionData<int, float>& prediction_data,
const float* X,
const float* points_to_predict,
size_t n_prediction_points,
int min_samples,
raft::distance::DistanceType metric,
float* membership_vec);

void out_of_sample_predict(const raft::handle_t& handle,
HDBSCAN::Common::CondensedHierarchy<int, float>& condensed_tree,
HDBSCAN::Common::PredictionData<int, float>& prediction_data,
Expand Down
60 changes: 42 additions & 18 deletions cpp/src/hdbscan/detail/kernels/soft_clustering.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/util/fast_int_div.cuh>

namespace ML {
namespace HDBSCAN {
namespace detail {
Expand All @@ -28,6 +29,7 @@ __global__ void merge_height_kernel(value_t* heights,
value_idx* parents,
size_t m,
value_idx n_selected_clusters,
raft::util::FastIntDiv n,
value_idx* selected_clusters)
{
value_idx idx = blockDim.x * blockIdx.x + threadIdx.x;
Expand Down Expand Up @@ -62,25 +64,47 @@ __global__ void merge_height_kernel(value_t* heights,
}
}

template <typename value_idx, typename value_t>
__global__ void prob_in_some_cluster_kernel(value_t* heights,
value_t* height_argmax,
value_t* deaths,
value_idx* index_into_children,
value_idx* selected_clusters,
value_t* lambdas,
value_t* prob_in_some_cluster,
value_idx n_selected_clusters,
value_idx n_leaves,
size_t m)
template <typename value_idx, typename value_t, int tpb = 256>
__global__ void merge_height_kernel(value_t* heights,
value_t* lambdas,
value_t* prediction_lambdas,
value_idx* min_mr_indices,
value_idx* index_into_children,
value_idx* parents,
size_t n_prediction_points,
value_idx n_selected_clusters,
raft::util::FastIntDiv n,
value_idx* selected_clusters)
{
value_idx idx = blockDim.x * blockIdx.x + threadIdx.x;
if (idx < (value_idx)m) {
value_t max_lambda = max(lambdas[index_into_children[idx]],
deaths[selected_clusters[(int)height_argmax[idx]] - n_leaves]);
prob_in_some_cluster[idx] =
heights[idx * n_selected_clusters + (int)height_argmax[idx]] / max_lambda;
return;
if (idx < value_idx(n_prediction_points * n_selected_clusters)) {
value_idx row = idx / n;
value_idx col = idx % n;
value_idx right_cluster = selected_clusters[col];
value_idx left_cluster = parents[index_into_children[min_mr_indices[row]]];
bool took_right_parent = false;
bool took_left_parent = false;
value_idx last_cluster;

while (left_cluster != right_cluster) {
if (left_cluster > right_cluster) {
took_left_parent = true;
last_cluster = left_cluster;
left_cluster = parents[index_into_children[left_cluster]];
} else {
took_right_parent = true;
last_cluster = right_cluster;
right_cluster = parents[index_into_children[right_cluster]];
}
}

if (took_left_parent && took_right_parent) {
heights[idx] = lambdas[index_into_children[last_cluster]];
}

else {
heights[idx] = prediction_lambdas[row];
}
}
}

Expand Down
109 changes: 66 additions & 43 deletions cpp/src/hdbscan/detail/predict.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ namespace Predict {
Find the nearest mutual reachability neighbor of a point, and compute
the associated lambda value for the point, given the mutual reachability
distance to a nearest neighbor.
*
* @tparam value_idx
* @tparam value_t
* @tparam tpb
Expand Down Expand Up @@ -149,6 +148,62 @@ void _find_cluster_and_probability(const raft::handle_t& handle,
out_labels,
out_probabilities);
}

/**
 * Shared front-end for the prediction routines: for every prediction point, run a KNN
 * query against the training data, derive the point's core distance from the KNN result,
 * and then hand everything to _find_neighbor_and_lambda, which fills in the nearest
 * mutual reachability neighbor and the associated lambda for each point.
 *
 * All intermediate buffers (KNN indices/distances and the prediction core distances) are
 * allocated on the handle's stream and are local to this call.
 *
 * @tparam value_idx index type
 * @tparam value_t value type
 * @tparam tpb kept in the template signature; not referenced inside this function body
 *
 * @param[in] handle raft handle; its stream is used for allocations and core_distances
 * @param[in] prediction_data source of n_rows, n_cols and the training core distances
 * @param[in] X first data pointer forwarded to Reachability::compute_knn together with
 *            n_rows / n_cols taken from prediction_data
 * @param[in] points_to_predict query points forwarded to Reachability::compute_knn
 * @param[in] min_samples forwarded to Reachability::core_distances; the KNN search width
 *            is (min_samples - 1) * 2
 * @param[in] n_prediction_points number of prediction points
 * @param[out] min_mr_inds output pointer forwarded to _find_neighbor_and_lambda
 * @param[out] prediction_lambdas output pointer forwarded to _find_neighbor_and_lambda
 * @param[in] metric distance metric forwarded to Reachability::compute_knn
 */
template <typename value_idx, typename value_t, int tpb = 256>
void _compute_knn_and_nearest_neighbor(const raft::handle_t& handle,
                                       Common::PredictionData<value_idx, value_t>& prediction_data,
                                       const value_t* X,
                                       const value_t* points_to_predict,
                                       int min_samples,
                                       size_t n_prediction_points,
                                       value_idx* min_mr_inds,
                                       value_t* prediction_lambdas,
                                       raft::distance::DistanceType metric)
{
  auto cuda_stream = handle.get_stream();

  // Shape of the training data and its precomputed core distances.
  size_t n_train_rows       = prediction_data.n_rows;
  size_t n_features         = prediction_data.n_cols;
  value_t* train_core_dists = prediction_data.get_core_dists();

  // Number of neighbors requested from the KNN query for each prediction point.
  int knn_width = (min_samples - 1) * 2;

  // Scratch space: one row of knn_width entries per prediction point, plus one core
  // distance per prediction point.
  rmm::device_uvector<value_idx> knn_indices(knn_width * n_prediction_points, cuda_stream);
  rmm::device_uvector<value_t> knn_dists(knn_width * n_prediction_points, cuda_stream);
  rmm::device_uvector<value_t> query_core_dists(n_prediction_points, cuda_stream);

  // Step 1: KNN of the prediction points against the training data.
  Reachability::compute_knn(handle,
                            X,
                            knn_indices.data(),
                            knn_dists.data(),
                            n_train_rows,
                            n_features,
                            points_to_predict,
                            n_prediction_points,
                            knn_width,
                            metric);

  // Step 2: slice out the core distances (distance to the kth nearest neighbor). The
  // neighbor index used is consistent with Scikit-learn Contrib.
  Reachability::core_distances<value_idx>(knn_dists.data(),
                                          min_samples,
                                          knn_width,
                                          n_prediction_points,
                                          query_core_dists.data(),
                                          cuda_stream);

  // Step 3: nearest mutual reachability neighbor and lambda for every prediction point.
  _find_neighbor_and_lambda(handle,
                            train_core_dists,
                            query_core_dists.data(),
                            knn_dists.data(),
                            knn_indices.data(),
                            n_prediction_points,
                            knn_width,
                            min_mr_inds,
                            prediction_lambdas);
}

/**
* Predict the cluster label and the probability of the label for new points.
* The returned labels are those of the original clustering,
Expand All @@ -165,7 +220,7 @@ void _find_cluster_and_probability(const raft::handle_t& handle,
* @param[in] labels converted monotonic labels of the input data points
* @param[in] points_to_predict input prediction points (size n_prediction_points * n)
* @param[in] n_prediction_points number of prediction points
* @param[in] metric distance metric to use
* @param[in] metric distance metric
* @param[in] min_samples neighborhood size during training (includes self-loop)
* @param[out] out_labels output cluster labels
* @param[out] out_probabilities output probabilities
Expand All @@ -189,50 +244,18 @@ void approximate_predict(const raft::handle_t& handle,
auto stream = handle.get_stream();
auto exec_policy = handle.get_thrust_policy();

size_t m = prediction_data.n_rows;
size_t n = prediction_data.n_cols;
value_t* input_core_dists = prediction_data.get_core_dists();

// this is the neighborhood of prediction points for which MR distances are computed
int neighborhood = (min_samples - 1) * 2;

rmm::device_uvector<value_idx> inds(neighborhood * n_prediction_points, stream);
rmm::device_uvector<value_t> dists(neighborhood * n_prediction_points, stream);
rmm::device_uvector<value_t> prediction_core_dists(n_prediction_points, stream);

// perform knn
Reachability::compute_knn(handle,
X,
inds.data(),
dists.data(),
m,
n,
points_to_predict,
n_prediction_points,
neighborhood,
metric);

// Slice core distances (distances to kth nearest neighbor). The index of the neighbor is
// consistent with Scikit-learn Contrib
Reachability::core_distances<value_idx>(dists.data(),
min_samples,
neighborhood,
n_prediction_points,
prediction_core_dists.data(),
stream);

// Obtain lambdas for each prediction point using the closest point in mutual reachability space
rmm::device_uvector<value_t> prediction_lambdas(n_prediction_points, stream);
rmm::device_uvector<value_idx> min_mr_inds(n_prediction_points, stream);
_find_neighbor_and_lambda(handle,
input_core_dists,
prediction_core_dists.data(),
dists.data(),
inds.data(),
n_prediction_points,
neighborhood,
min_mr_inds.data(),
prediction_lambdas.data());
_compute_knn_and_nearest_neighbor(handle,
prediction_data,
X,
points_to_predict,
min_samples,
n_prediction_points,
min_mr_inds.data(),
prediction_lambdas.data(),
metric);

// Using the nearest neighbor indices, find the assigned cluster label and probability
_find_cluster_and_probability(handle,
Expand Down
Loading