-
Notifications
You must be signed in to change notification settings - Fork 523
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FEA] membership_vector for HDBSCAN #5247
Changes from 27 commits
e49d06a
436b180
48030b8
7912dba
4b41edb
fe0fd34
9d5badc
19f9dd8
a4b565c
1f4bf78
fdf100b
e18096a
c2aa77e
182ba31
366ef26
b60d869
6bfaae2
c4e0bf1
fb634e4
a49ba87
4ed9fd7
fa7b44e
45f8ca4
367de04
980b1f7
98aa237
d387026
ed40e22
387cde8
092b3f8
ef85fd3
17de9ec
38208ec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ __global__ void merge_height_kernel(value_t* heights, | |
value_idx* parents, | ||
size_t m, | ||
value_idx n_selected_clusters, | ||
MLCommon::FastIntDiv n, | ||
value_idx* selected_clusters) | ||
{ | ||
value_idx idx = blockDim.x * blockIdx.x + threadIdx.x; | ||
|
@@ -54,6 +55,7 @@ __global__ void merge_height_kernel(value_t* heights, | |
|
||
if (took_left_parent && took_right_parent) { | ||
heights[idx] = lambdas[index_into_children[last_cluster]]; | ||
// printf("%d %d %d %f\n", row, col, last_cluster, heights[idx]); | ||
} | ||
|
||
else { | ||
|
@@ -62,9 +64,53 @@ __global__ void merge_height_kernel(value_t* heights, | |
} | ||
} | ||
|
||
/**
 * Fills `heights` with one merge height per (row, col) pair, where the flat
 * thread index is split into `row = idx / n` and `col = idx % n`.
 *
 * For each pair, the kernel starts from two nodes of the tree described by
 * `parents` / `index_into_children`:
 *   - the parent of the entry `min_mr_indices[row]`, and
 *   - the cluster `selected_clusters[col]`,
 * and repeatedly replaces whichever of the two has the larger id by its parent
 * until both ids are equal. If both sides had to move at least once, the
 * height is the lambda of the last node that was left behind; otherwise it is
 * `prediction_lambdas[row]`.
 *
 * Launch layout: 1-D grid of 1-D blocks, one thread per output element;
 * threads with `idx >= n_prediction_points * n_selected_clusters` do nothing.
 *
 * @param[out] heights             output, indexed by the flat thread index
 * @param[in]  lambdas             lambda values, indexed through index_into_children
 * @param[in]  prediction_lambdas  one lambda per row
 * @param[in]  min_mr_indices      one starting tree entry per row
 * @param[in]  index_into_children maps a node id to its position in parents / lambdas
 * @param[in]  parents             parent id for each position
 * @param[in]  n_prediction_points number of rows
 * @param[in]  n_selected_clusters number of columns
 * @param[in]  n                   divisor used to split the flat index into
 *                                 (row, col); presumably wraps
 *                                 n_selected_clusters -- confirm at launch site
 * @param[in]  selected_clusters   one cluster id per column
 */
template <typename value_idx, typename value_t, int tpb = 256>
__global__ void merge_height_kernel(value_t* heights,
                                    value_t* lambdas,
                                    value_t* prediction_lambdas,
                                    value_idx* min_mr_indices,
                                    value_idx* index_into_children,
                                    value_idx* parents,
                                    size_t n_prediction_points,
                                    value_idx n_selected_clusters,
                                    MLCommon::FastIntDiv n,
                                    value_idx* selected_clusters)
{
  value_idx idx = blockDim.x * blockIdx.x + threadIdx.x;

  // Threads past the end of the output have nothing to do.
  if (!(idx < value_idx(n_prediction_points * n_selected_clusters))) { return; }

  value_idx row = idx / n;
  value_idx col = idx % n;

  // The two nodes that are walked towards each other.
  value_idx cluster_side = selected_clusters[col];
  value_idx point_side   = parents[index_into_children[min_mr_indices[row]]];

  bool climbed_cluster_side = false;
  bool climbed_point_side   = false;
  value_idx last_visited;  // only read when both sides climbed at least once

  while (point_side != cluster_side) {
    if (point_side > cluster_side) {
      // The point side holds the larger id: step it up to its parent.
      climbed_point_side = true;
      last_visited       = point_side;
      point_side         = parents[index_into_children[point_side]];
      continue;
    }

    // Otherwise the cluster side holds the larger id: step that one up.
    climbed_cluster_side = true;
    last_visited         = cluster_side;
    cluster_side         = parents[index_into_children[cluster_side]];
  }

  // Both sides moved: use the lambda of the last node left behind.
  // Otherwise fall back to the row's own prediction lambda.
  heights[idx] = (climbed_point_side && climbed_cluster_side)
                   ? lambdas[index_into_children[last_visited]]
                   : prediction_lambdas[row];
}
|
||
template <typename value_idx, typename value_t> | ||
__global__ void prob_in_some_cluster_kernel(value_t* heights, | ||
value_t* height_argmax, | ||
value_idx* height_argmax, | ||
value_t* deaths, | ||
value_idx* index_into_children, | ||
value_idx* selected_clusters, | ||
|
@@ -77,9 +123,34 @@ __global__ void prob_in_some_cluster_kernel(value_t* heights, | |
value_idx idx = blockDim.x * blockIdx.x + threadIdx.x; | ||
if (idx < (value_idx)m) { | ||
value_t max_lambda = max(lambdas[index_into_children[idx]], | ||
deaths[selected_clusters[(int)height_argmax[idx]] - n_leaves]); | ||
deaths[selected_clusters[height_argmax[idx]] - n_leaves]); | ||
prob_in_some_cluster[idx] = | ||
heights[idx * n_selected_clusters + height_argmax[idx]] / max_lambda; | ||
return; | ||
} | ||
} | ||
|
||
template <typename value_idx, typename value_t> | ||
__global__ void prob_in_some_cluster_kernel(value_t* heights, | ||
value_idx* height_argmax, | ||
value_t* prediction_lambdas, | ||
value_t* deaths, | ||
value_idx* index_into_children, | ||
value_idx* min_mr_indices, | ||
value_idx* selected_clusters, | ||
value_t* lambdas, | ||
value_t* prob_in_some_cluster, | ||
value_idx n_selected_clusters, | ||
value_idx n_leaves, | ||
size_t n_prediction_points) | ||
{ | ||
value_idx idx = blockDim.x * blockIdx.x + threadIdx.x; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This looks like another candidate for |
||
if (idx < (value_idx)n_prediction_points) { | ||
value_t max_lambda = | ||
max(prediction_lambdas[idx], deaths[selected_clusters[height_argmax[idx]] - n_leaves]) + | ||
1e-8; | ||
prob_in_some_cluster[idx] = | ||
heights[idx * n_selected_clusters + (int)height_argmax[idx]] / max_lambda; | ||
heights[idx * n_selected_clusters + height_argmax[idx]] / max_lambda; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please compute |
||
return; | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you use the version of this in `raft::util`, please? This should really be removed from cuML altogether now that we have a version in RAFT.