Add MG weighted k-means #3959
Conversation
Conflicts: CHANGELOG.md
Conflicts: python/cuml/test/dask/test_kmeans.py python/cuml/test/test_kmeans.py
Thanks for picking this one up. Overall it looks great but we do still have an issue to fix (see comment in the review).
@@ -620,6 +657,9 @@ void fit(const raft::handle_t &handle, const KMeansParams &params,
   MLCommon::device_buffer<char> workspace(handle.get_device_allocator(),
                                           stream);

   // check if weights sum up to n_samples
   checkWeights(handle, workspace, weight, stream);
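To illustrate what a checkWeights-style step does, here is a minimal NumPy sketch: if the sample weights do not sum to n_samples, rescale them so they do. This is a hypothetical host-side illustration; the actual cuML helper runs on the GPU using a device workspace, and the function name and signature below are not the real API.

```python
import numpy as np

def check_weights(weights, n_samples):
    """Rescale sample weights so they sum to n_samples.

    Hypothetical NumPy sketch of a checkWeights-style helper; the real
    cuML implementation operates on device memory via a workspace buffer.
    """
    total = weights.sum()
    if total != n_samples:
        # Scale every weight by the same factor so the sum becomes n_samples
        weights = weights * (n_samples / total)
    return weights

w = check_weights(np.array([1.0, 2.0, 3.0]), n_samples=3)
```

After this step, `w.sum()` equals `n_samples`, so downstream weighted updates behave as if the weights were normalized counts.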
Great you pulled this over from cumlprims! IIRC, the one remaining issue is that the single-GPU k-means normalizes the weights in predict; since the multi-GPU version is embarrassingly parallel, that causes each partition to normalize its weights individually. The weights are already being normalized globally in the Dask-based predict, but the single-GPU predict will then re-normalize them locally. The more straightforward fix might be to have the C++ predict() function accept a normalize_weights argument which defaults to true, and have the multi-GPU predict function flip it off. The goal here is to eliminate the need for predict() to use the comms, because then it would no longer be able to execute embarrassingly parallel.
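The pattern described above can be sketched in NumPy: the single-GPU predict normalizes weights only when asked, and the multi-GPU path normalizes once globally, then calls each partition with normalization turned off. All names, signatures, and shapes here are hypothetical illustrations, not the cuML C++ or Dask API.

```python
import numpy as np

def predict(centroids, X, weights, normalize_weights=True):
    """Single-GPU style predict (hypothetical sketch).

    Optionally normalizes weights locally so they sum to len(X),
    then assigns each sample to its nearest centroid.
    """
    if normalize_weights:
        weights = weights * (len(X) / weights.sum())
    # Squared Euclidean distance from every sample to every centroid
    dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return dists.argmin(axis=1), weights

def mg_predict(centroids, partitions, partition_weights):
    """Multi-GPU style predict (hypothetical sketch).

    Normalizes the weights globally once, then runs each partition
    embarrassingly parallel with normalize_weights=False so no
    partition re-normalizes locally (and no comms are needed).
    """
    n_total = sum(len(X) for X in partitions)
    total_w = sum(w.sum() for w in partition_weights)
    scale = n_total / total_w
    labels = []
    for X, w in zip(partitions, partition_weights):
        part_labels, _ = predict(centroids, X, w * scale,
                                 normalize_weights=False)
        labels.append(part_labels)
    return np.concatenate(labels)
```

The design point is that the only cross-partition quantity, the global weight sum, is computed up front; once each partition holds globally scaled weights, predict needs no communication.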
rerun tests
LGTM pending CI
rerun tests
1 similar comment
rerun tests
Docstring fix identified in CI:
@gpucibot merge
Codecov Report

@@           Coverage Diff            @@
##           branch-21.08    #3959   +/-   ##
===============================================
  Coverage          ?   85.46%
===============================================
  Files             ?      230
  Lines             ?    18116
  Branches          ?        0
===============================================
  Hits              ?    15482
  Misses            ?     2634
  Partials          ?        0
===============================================

Flags with carried forward coverage won't be shown. Continue to review full report at Codecov.
This PR adds support for MG weighted k-means and is a continuation of @akkamesh's and @cjnolet's work on PR rapidsai#2126.

Authors:
- Micka (https://github.com/lowener)
- https://github.com/akkamesh
- Corey J. Nolet (https://github.com/cjnolet)

Approvers:
- Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#3959