[FEA] approximate_predict function for HDBSCAN #4872

Merged
Changes from 87 commits
Commits
91 commits
8614cbd
Exemplar indices obtained
tarang-jain Jul 1, 2022
dbb7d48
Further additions to distance membership
tarang-jain Jul 5, 2022
8e25363
Cleanup dist membership vector
tarang-jain Jul 6, 2022
a6cb20f
Include changes
tarang-jain Jul 6, 2022
d04dff2
testing
tarang-jain Jul 6, 2022
2cfdae6
Further changes
tarang-jain Jul 6, 2022
f13e587
Further changes to distance based membership (clean build)
tarang-jain Jul 7, 2022
8421082
Reuse label map, replace unique_by_key with sorted_coo_to_csr
tarang-jain Jul 7, 2022
1ab28c0
Outlier based membership initial commit (unclean build)
tarang-jain Jul 11, 2022
66c815a
Restructuring functions (unclean build)
tarang-jain Jul 12, 2022
abff747
Intermediate commits
tarang-jain Jul 12, 2022
b979a5b
Corrections in exemplar computation and outlier membership
tarang-jain Jul 15, 2022
7bf26d7
All point membership vector all parts working
tarang-jain Jul 19, 2022
a7ab49e
Initial commit for Prediction Data class
tarang-jain Jul 20, 2022
bffd7ca
initial commit
tarang-jain Jul 25, 2022
ed59da5
Staged changes
tarang-jain Jul 25, 2022
3ed7ba3
Circling back to implementation without PredictionData
tarang-jain Jul 25, 2022
199695f
PredictionData finally added (errors in cython)
tarang-jain Jul 26, 2022
17cf426
Clean build with somee debug statements
tarang-jain Jul 26, 2022
e5bc693
Python API created and working
tarang-jain Jul 26, 2022
4f4136b
convert output dtype to cupy
tarang-jain Jul 26, 2022
b130ebe
Debugging exemplar_idx with large number of clusters
tarang-jain Jul 27, 2022
51cc6c8
correction cache function and moving to .cu file
tarang-jain Jul 28, 2022
336d638
Allow size_t as size of dataset (note: beware of limits as laterconve…
tarang-jain Jul 28, 2022
cca5f15
Save data (self.X_m), resolving compute-sanitizer errors
tarang-jain Jul 29, 2022
d5264f9
Resolving compute-santiizer error
tarang-jain Jul 29, 2022
441a397
Added gtest
tarang-jain Aug 1, 2022
2c00591
Further changes to pytest
tarang-jain Aug 1, 2022
a5cdf70
Add pytest, add support for distance metrics
tarang-jain Aug 2, 2022
b971aa5
Styling changes
tarang-jain Aug 2, 2022
26f7ab0
Resolved build failure
tarang-jain Aug 2, 2022
f46b3f4
Styling and copyright changes
tarang-jain Aug 2, 2022
2364455
Added prediction_data to get_param_names
tarang-jain Aug 2, 2022
291aa8d
Remove debug and sync statements from runner.h
tarang-jain Aug 9, 2022
cfeb6d3
Merge branch 'branch-22.10' of github.com:rapidsai/cuml into fea-all-…
tarang-jain Aug 10, 2022
e9c6ca4
Updated docs for CI failure
tarang-jain Aug 10, 2022
dfca0e4
Docs changes for failing CI
tarang-jain Aug 11, 2022
daed99b
Rename namespace to Predict
tarang-jain Aug 11, 2022
0d311ea
some changes after PR review
tarang-jain Aug 11, 2022
0a577e5
debugging failing gtest
tarang-jain Aug 11, 2022
589194a
Updates after PR changes (failing GTest handled)
tarang-jain Aug 12, 2022
ca81e87
Merge branch 'branch-22.10' of github.com:rapidsai/cuml into fea-all-…
tarang-jain Aug 12, 2022
43de057
Merge branch 'rapidsai:branch-22.10' into fea-all-points-membership-v…
tarang-jain Aug 12, 2022
dc8cbb4
Adding new separate hdbscan() function for prediction_data aftter PR …
tarang-jain Aug 15, 2022
db324f4
Merge branch 'fea-all-points-membership-vector-hdbscan' of github.com…
tarang-jain Aug 15, 2022
4f49a7d
Updates to docs
tarang-jain Aug 16, 2022
1768625
Update python docs
tarang-jain Aug 16, 2022
7e8b8e0
include count header for failing CI
tarang-jain Aug 16, 2022
c214898
New branch to account for latest membership vector changes
tarang-jain Aug 19, 2022
48b84c5
Fixing bug related to unconverted labels
tarang-jain Aug 19, 2022
3416c8b
Merge branch 'fea-all-points-membership-vector-hdbscan' of github.com…
tarang-jain Aug 19, 2022
464e3d3
Remove debug statements and formatting
tarang-jain Aug 19, 2022
8c62c4c
Debugging, Added google test
tarang-jain Aug 22, 2022
053fa3a
Added core_dists arg to build_linkage
tarang-jain Aug 22, 2022
10a71ba
Added pytest
tarang-jain Aug 23, 2022
ec168b4
fixing docs and copyright
tarang-jain Aug 23, 2022
d303396
Update names and docs after PR Review
tarang-jain Aug 24, 2022
8857aa5
Handling edge case when there are no clusters in the cluster tree
tarang-jain Aug 25, 2022
cb7ab05
Merge branch 'branch-22.10' of github.com:rapidsai/cuml into fea-appr…
tarang-jain Aug 25, 2022
ffd81ef
Update pytest
tarang-jain Aug 25, 2022
d353db7
Update pytest
tarang-jain Aug 25, 2022
dd12323
Further changes to pytest
tarang-jain Aug 25, 2022
3f90e47
Merge branch 'branch-22.10' of github.com:rapidsai/cuml into fea-appr…
tarang-jain Aug 25, 2022
ba47465
Handled failing gtest
tarang-jain Aug 26, 2022
b9a17de
Merge branch 'branch-22.10' of github.com:rapidsai/cuml into fea-appr…
tarang-jain Aug 26, 2022
297ca31
Merge branch 'branch-22.10' of github.com:rapidsai/cuml into fea-appr…
tarang-jain Aug 26, 2022
ac68d02
Edge case handling on python side
tarang-jain Aug 26, 2022
19bc57e
Merge with upstream (resolve conflicts)
tarang-jain Aug 26, 2022
687c73c
Resolving style changes
tarang-jain Aug 26, 2022
ba73c1f
Resolving merge conflicts, style issues
tarang-jain Aug 26, 2022
d7cba42
Refactoring approximate_predict (simplify)
tarang-jain Aug 29, 2022
d4e1caf
Update docs
tarang-jain Aug 29, 2022
dd408a6
Debug gtest after refactoring
tarang-jain Aug 29, 2022
adefc61
Resolved failing gtest
tarang-jain Aug 29, 2022
18ca7f9
Added index_into_children to PredictionData, avoid duplicate computat…
tarang-jain Aug 29, 2022
91d044e
Update docs
tarang-jain Aug 29, 2022
34c6533
Rename prediction_data_ attribute
tarang-jain Aug 29, 2022
55f7861
Check for dimensions mismatch
tarang-jain Aug 29, 2022
207e565
Little doc update from PR review
tarang-jain Aug 29, 2022
01d2830
Refactoring: added prediction utilities to prediction.pyx
tarang-jain Aug 29, 2022
b24215d
Style fix
tarang-jain Aug 30, 2022
18f68f6
Fix discrepancy in core distance computation
tarang-jain Aug 30, 2022
b2a418e
Final fix for discrepancy in core distance computation
tarang-jain Aug 30, 2022
0a99f3c
Merge branch 'branch-22.10' of github.com:rapidsai/cuml into fea-appr…
tarang-jain Aug 30, 2022
f7325d0
Update gtest
tarang-jain Aug 30, 2022
30fceb4
Update pytest
tarang-jain Sep 1, 2022
254d5e4
Merge branch 'branch-22.10' of github.com:rapidsai/cuml into fea-appr…
tarang-jain Sep 1, 2022
c3aad75
Adding pytests for moons and circles to test hdbscan approximate pred…
cjnolet Sep 2, 2022
6826568
Fixing line width
cjnolet Sep 2, 2022
f27e454
Add digits pytest
tarang-jain Sep 2, 2022
e17b4c2
Updates after PR reviews
tarang-jain Sep 2, 2022
51 changes: 37 additions & 14 deletions cpp/include/cuml/cluster/hdbscan.hpp
@@ -308,7 +308,8 @@ template class CondensedHierarchy<int, float>;

 /**
  * Container object for computing and storing intermediate information needed later for computing
- * membership vectors and approximate_predict.
+ * membership vectors and approximate predict. Users are only expected to create an instance of this
+ * object, the hdbscan method will do the rest.
  * @tparam value_idx
  * @tparam value_t
  */
@@ -322,6 +323,8 @@ class PredictionData {
       n_selected_clusters(0),
       selected_clusters(0, handle.get_stream()),
       deaths(0, handle.get_stream()),
+      core_dists(m, handle.get_stream()),
+      index_into_children(0, handle.get_stream()),
       n_exemplars(0),
       n_rows(m),
       n_cols(n)
@@ -339,21 +342,26 @@ class PredictionData {
   value_idx* get_exemplar_label_offsets() { return exemplar_label_offsets.data(); }
   value_idx* get_selected_clusters() { return selected_clusters.data(); }
   value_t* get_deaths() { return deaths.data(); }
+  value_t* get_core_dists() { return core_dists.data(); }
+  value_idx* get_index_into_children() { return index_into_children.data(); }

   /**
-   * Resize buffers to the required sizes for storing data
-   * @param handle raft handle for ordering cuda operations
-   * @param n_exemplars_ number of exemplar points
-   * @param n_selected_clusters_ number of clusters selected
+   * Resizes the buffers in the PredictionData object.
+   *
+   * @param[in] handle raft handle for resource reuse
+   * @param[in] n_exemplars_ number of exemplar points
+   * @param[in] n_selected_clusters_ number of selected clusters in the final clustering
+   * @param[in] n_edges_ number of edges in the condensed hierarchy
    */
   void allocate(const raft::handle_t& handle,
                 value_idx n_exemplars_,
-                value_idx n_selected_clusters_);
+                value_idx n_selected_clusters_,
+                value_idx n_edges_);

   /**
    * Resize buffers for cluster deaths to n_clusters
    * @param handle raft handle for ordering cuda operations
-   * @param n_clusters_
+   * @param n_clusters_ number of clusters
    */
   void set_n_clusters(const raft::handle_t& handle, value_idx n_clusters_)
   {
@@ -368,6 +376,8 @@ class PredictionData {
   value_idx n_selected_clusters;
   rmm::device_uvector<value_idx> selected_clusters;
   rmm::device_uvector<value_t> deaths;
+  rmm::device_uvector<value_t> core_dists;
+  rmm::device_uvector<value_idx> index_into_children;
 };

 template class PredictionData<int, float>;
@@ -412,7 +422,7 @@ void hdbscan(const raft::handle_t& handle,
              HDBSCAN::Common::hdbscan_output<int, float>& out);

 /**
- * Executes HDBSCAN clustering on an mxn-dimensional input array, X and builds the PredictionData
+ * Executes HDBSCAN clustering on an mxn-dimensional input array, X, then builds the PredictionData
  * object which computes and stores information needed later for prediction algorithms.
  *
  * @param[in] handle raft handle for resource reuse
@@ -456,10 +466,23 @@ void _extract_clusters(const raft::handle_t& handle,
                        int max_cluster_size,
                        float cluster_selection_epsilon);

-void _all_points_membership_vectors(const raft::handle_t& handle,
-                                    HDBSCAN::Common::CondensedHierarchy<int, float>& condensed_tree,
-                                    HDBSCAN::Common::PredictionData<int, float>& prediction_data,
-                                    float* membership_vec,
-                                    const float* X,
-                                    raft::distance::DistanceType metric);
+void compute_all_points_membership_vectors(
+  const raft::handle_t& handle,
+  HDBSCAN::Common::CondensedHierarchy<int, float>& condensed_tree,
+  HDBSCAN::Common::PredictionData<int, float>& prediction_data,
+  const float* X,
+  raft::distance::DistanceType metric,
+  float* membership_vec);
+
+void out_of_sample_predict(const raft::handle_t& handle,
+                           HDBSCAN::Common::CondensedHierarchy<int, float>& condensed_tree,
+                           HDBSCAN::Common::PredictionData<int, float>& prediction_data,
+                           const float* X,
+                           int* labels,
+                           const float* points_to_predict,
+                           size_t n_prediction_points,
+                           raft::distance::DistanceType metric,
+                           int min_samples,
+                           int* out_labels,
+                           float* out_probabilities);
 } // END namespace ML
102 changes: 102 additions & 0 deletions cpp/src/hdbscan/detail/kernels/predict.cuh
@@ -0,0 +1,102 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

namespace ML {
namespace HDBSCAN {
namespace detail {
namespace Predict {

template <typename value_idx, typename value_t>
__global__ void min_mutual_reachability_kernel(value_t* input_core_dists,
                                               value_t* prediction_core_dists,
                                               value_t* pairwise_dists,
                                               value_idx* neighbor_indices,
                                               size_t n_prediction_points,
                                               value_idx min_samples,
                                               value_t* min_mr_dists,
                                               value_idx* min_mr_indices)
{
  value_idx idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx < value_idx(n_prediction_points)) {
    value_t min_mr_dist = std::numeric_limits<value_t>::max();
    value_idx min_mr_ind = -1;
    for (int i = 0; i < 2 * min_samples; i++) {
      value_t mr_dist = prediction_core_dists[idx];
      if (input_core_dists[neighbor_indices[idx * 2 * min_samples + i]] > mr_dist) {
        mr_dist = input_core_dists[neighbor_indices[idx * 2 * min_samples + i]];
      }
      if (pairwise_dists[idx * 2 * min_samples + i] > mr_dist) {
        mr_dist = pairwise_dists[idx * 2 * min_samples + i];
      }
      if (min_mr_dist > mr_dist) {
        min_mr_dist = mr_dist;
        min_mr_ind = neighbor_indices[idx * 2 * min_samples + i];
      }
    }
    min_mr_dists[idx] = min_mr_dist;
    min_mr_indices[idx] = min_mr_ind;
  }
  return;
}
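
Reviewer-style note: the kernel above can be checked without a GPU using a small host-side reference. The following is only a sketch in plain C++, not code from this PR. `MinMrResult` and `min_mutual_reachability_host` are hypothetical names introduced for illustration, and the sketch assumes the same row-major layout the kernel indexes into, with `2 * min_samples` neighbors stored per prediction point.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical host-side reference, not part of the PR.
struct MinMrResult {
  std::vector<float> dists;  // smallest mutual reachability distance per point
  std::vector<int> indices;  // training-set index of the neighbor that achieves it
};

MinMrResult min_mutual_reachability_host(const std::vector<float>& input_core_dists,
                                         const std::vector<float>& prediction_core_dists,
                                         const std::vector<float>& pairwise_dists,
                                         const std::vector<int>& neighbor_indices,
                                         int min_samples)
{
  const std::size_t n_points = prediction_core_dists.size();
  const std::size_t k        = 2 * static_cast<std::size_t>(min_samples);
  // Assumed layout: k neighbors per prediction point, stored row-major.
  assert(pairwise_dists.size() == n_points * k);
  assert(neighbor_indices.size() == n_points * k);

  MinMrResult out{std::vector<float>(n_points, std::numeric_limits<float>::max()),
                  std::vector<int>(n_points, -1)};
  for (std::size_t p = 0; p < n_points; ++p) {
    for (std::size_t i = 0; i < k; ++i) {
      const int nbr = neighbor_indices[p * k + i];
      // Mutual reachability: max of both core distances and the pairwise distance.
      const float mr =
        std::max({prediction_core_dists[p], input_core_dists[nbr], pairwise_dists[p * k + i]});
      // Strict comparison, so the first neighbor wins ties, as in the kernel.
      if (mr < out.dists[p]) {
        out.dists[p]   = mr;
        out.indices[p] = nbr;
      }
    }
  }
  return out;
}
```

Comparing the output of this loop against `min_mr_dists` and `min_mr_indices` copied back from the device would be one way to unit-test the kernel on small inputs.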

template <typename value_idx, typename value_t>
__global__ void cluster_probability_kernel(value_idx* min_mr_indices,
                                           value_t* prediction_lambdas,
                                           value_idx* index_into_children,
                                           value_idx* labels,
                                           value_t* deaths,
                                           value_idx* selected_clusters,
                                           value_idx* parents,
                                           value_t* lambdas,
                                           value_idx n_leaves,
                                           size_t n_prediction_points,
                                           value_idx* predicted_labels,
                                           value_t* cluster_probabilities)
{
  value_idx idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx < value_idx(n_prediction_points)) {
    value_idx cluster_label = labels[min_mr_indices[idx]];

    if (cluster_label >= 0 && selected_clusters[cluster_label] > n_leaves &&
        lambdas[index_into_children[selected_clusters[cluster_label]]] < prediction_lambdas[idx]) {
      predicted_labels[idx] = cluster_label;
    } else if (cluster_label >= 0 && selected_clusters[cluster_label] == n_leaves) {
      predicted_labels[idx] = cluster_label;
    } else {
      predicted_labels[idx] = -1;
    }
    if (predicted_labels[idx] >= 0) {
      value_t max_lambda = deaths[selected_clusters[cluster_label] - n_leaves];
      if (max_lambda > 0) {
        cluster_probabilities[idx] =
          (max_lambda < prediction_lambdas[idx] ? max_lambda : prediction_lambdas[idx]) /
          max_lambda;
      } else {
        cluster_probabilities[idx] = 1.0;
      }
    } else {
      cluster_probabilities[idx] = 0.0;
    }
  }
  return;
}
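
Reviewer-style note: the per-point decision made by `cluster_probability_kernel` can be restated as a scalar host function, which may make the branch structure easier to follow. This is only a sketch and not code from the PR. `Prediction` and `predict_one_host` are hypothetical names. The parameter `cluster_lambda` stands in for `lambdas[index_into_children[selected_clusters[cluster_label]]]`, which is assumed here to be the lambda at which the selected cluster splits off from its parent, and `max_lambda` stands in for `deaths[selected_clusters[cluster_label] - n_leaves]`.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical host-side restatement of one kernel thread, not part of the PR.
struct Prediction {
  int label;          // -1 means the point is treated as noise
  float probability;  // 0 for noise, otherwise in (0, 1]
};

Prediction predict_one_host(int neighbor_label,      // label of the nearest neighbor by mutual reachability
                            int selected_cluster,    // condensed-tree id of that label's cluster
                            int n_leaves,            // number of points in the training set
                            float cluster_lambda,    // assumed: lambda at which the cluster is born
                            float prediction_lambda, // lambda of the new point
                            float max_lambda)        // death lambda of the selected cluster
{
  assert(n_leaves >= 0);
  bool accepted = false;
  if (neighbor_label >= 0) {
    if (selected_cluster > n_leaves) {
      // Non-root cluster: the point must join after the cluster has formed.
      accepted = cluster_lambda < prediction_lambda;
    } else if (selected_cluster == n_leaves) {
      // Root cluster: accepted unconditionally, as in the kernel.
      accepted = true;
    }
  }
  if (!accepted) { return {-1, 0.0f}; }

  // Probability is the point's lambda, capped at the cluster's death lambda,
  // normalized by the death lambda. A non-positive death lambda yields 1.
  const float probability =
    max_lambda > 0.0f ? std::min(max_lambda, prediction_lambda) / max_lambda : 1.0f;
  return {neighbor_label, probability};
}
```

For example, a point whose nearest neighbor has label 3, with `cluster_lambda = 0.5`, `prediction_lambda = 1.0`, and `max_lambda = 2.0`, is accepted with probability 0.5 under this reading of the kernel.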

}; // namespace Predict
}; // namespace detail
}; // namespace HDBSCAN
}; // namespace ML