-
Notifications
You must be signed in to change notification settings - Fork 194
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Expose cluster_cost to python #1028
Expose cluster_cost to python #1028
Conversation
Add cython bindings for the cluster_cost function, to allow computing inertia from python. Closes rapidsai#972
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes LGTM. Just need assertions in the pytest.
# cython: embedsignature = True | ||
# cython: language_level = 3 | ||
|
||
from pylibraft.common.handle cimport handle_t |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm liking this convention of putting the wrapper cython definitions in a cpp
directory.
centroids = X[:n_clusters] | ||
centroids_device = device_ndarray(centroids) | ||
|
||
# TODO: compute inertia naively, make sure is close |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can use our pairwise distances to compute this naively :-P
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done - thanks for the pointer!
I've marked this ready to review - even though I don't think the results being returned here are actually correct because of #1036 ( and probably shouldn't merge this until we figure that out ) |
[] __device__(const raft::KeyValuePair<IndexType, ElementType>& a) { return a.value; }); | ||
|
||
rmm::device_scalar<ElementType> device_cost(0, handle.get_stream()); | ||
raft::cluster::kmeans::cluster_cost( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we're doing an additional computation here that can be avoided by running the fusedL2NNMinReduce
w/ sqrt=False
. If we have the squared distances already then we should just be able sum them up to get the inertia
score. I'm thinking this might also help us avoid running into #1036 in this PR (assuming my assumption is correct that the issue is related to sqrt=True
).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could do a simple L1 norm computation for the reduction, I guess.
raft/cpp/include/raft/linalg/norm.cuh
Line 109 in d6df557
void norm(const raft::handle_t& handle, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thats a good point - and I've made the change to fusedL2NNMinReduce w/ sqrt=False. I checked the cluster cost, and its already only doing a sum reduction - so using sqrt=True was wrong here.
This also lets us avoid the issue in #1036 - afaict you are right and this issue is only with sqrt=True (all the tests pass at least now =) ).
inertia is sum of squared distances, we were doing sum of euclidean distances
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Just a very small comment but I'm pre-approving since it's pretty minor and we could do it in a follow-on if needed.
@@ -88,3 +88,31 @@ def test_compute_new_centroids( | |||
actual_centers = new_centroids_device.copy_to_host() | |||
|
|||
assert np.allclose(expected_centers, actual_centers, rtol=1e-6) | |||
|
|||
|
|||
@pytest.mark.parametrize("n_rows", [8]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might be a good idea to test this with a slightly larger number like 100.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice catch! I had reduced it for debugging -
its back at 100 in the last commit
@gpucibot merge |
cc @betatim |
@gpucibot merge |
Add cython bindings for the cluster_cost function, to allow computing inertia from python.
Closes #972