new tests for mean_shift algo #13179

Merged
merged 2 commits into from
Apr 25, 2019

rajdeepd
Contributor

Reference Issues/PRs

none

What does this implement/fix? Explain your changes.

Add test cases to cover un-tested portions of mean_shift.py

Any other comments?

no other comments

@rajdeepd
Contributor Author

@ogrisel can help review this

def test_mean_shift_negative_bandwidth():
    bandwidth = -1
    ms = MeanShift(bandwidth=bandwidth)
    msg = \
Member

Use parentheses to enclose expressions and split them over multiple lines rather than using \ for line continuation

Contributor Author

@jnothman this comment is not clear to me. Will the following statement work?

msg = "bandwidth needs to be greater than zero or None,"
" got -1.000000"

Member

This will:

msg = ("bandwidth needs to be greater than zero or None," 
       " got -1.000000")

Contributor Author

changed
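
The parenthesized form suggested above works because Python joins adjacent string literals inside one expression. The sketch below contrasts it with the unparenthesized version from the question; the variable name `broken` is only illustrative.

```python
# Inside parentheses, adjacent string literals are concatenated
# into a single string, so no backslash is needed.
msg = ("bandwidth needs to be greater than zero or None,"
       " got -1.000000")

# Without parentheses the assignment ends at the end of the first line.
# The second literal is a separate expression statement whose value
# is simply discarded.
broken = "bandwidth needs to be greater than zero or None,"
" got -1.000000"

print(msg)
print(broken)
```

Only `msg` contains the full message; `broken` stops at the comma.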

sklearn/cluster/tests/test_mean_shift.py (outdated)

def test_seeds():
    ms = MeanShift(seeds=None)
    _ = ms.fit(X).labels_
Member

Why do you get labels_?

Contributor Author

removed

    assert_raise_message(ValueError, msg, ms.fit, X)


def test_seeds():
Member

I don't get what this is testing. Checking that parameters are maintained should usually be covered by the common tests, not by tests for each specific estimator.

Contributor Author

removed

    labels = ms.fit(X).labels_
    labels_unique = np.unique(labels)
    n_clusters_ = len(labels_unique)
    assert_equal(n_clusters_ > n_clusters, True)
Member

Use bare assert as with seeds above

Contributor Author

done

    n_clusters_ = len(labels_unique)
    assert_equal(n_clusters_ > n_clusters, True)

    cluster_centers, labels = mean_shift(X, bandwidth=bandwidth,
Member

Rather than repeat the code, please use pytest.mark.parametrize to test multiple settings of bandwidth.

Contributor Author

changed to use pytest.mark.parametrize
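
For readers unfamiliar with the decorator requested in this thread, here is a minimal sketch of how it replaces repeated test bodies. The helper `count_clusters` and the parameter values are hypothetical and are not taken from the PR. Note that the decorator is spelled `parametrize`.

```python
import pytest


def count_clusters(labels):
    """Hypothetical helper: number of distinct labels in a labelling."""
    return len(set(labels))


# One test body, run once per (labels, expected) tuple.
@pytest.mark.parametrize("labels, expected", [
    ([0, 0, 1, 1], 2),
    ([0, 1, 2, -1], 4),
])
def test_count_clusters(labels, expected):
    assert count_clusters(labels) == expected
```

When run under pytest, each tuple is reported as a separate test case, so a failure identifies the exact setting that broke.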

Contributor Author

@jnothman please review

Contributor Author

@jnothman @ogrisel please review

rajdeepd force-pushed the test_mean_shift branch 2 times, most recently from b018e99 to 4cf6413, on March 1, 2019 14:36
Member

jnothman left a comment

I confirm this covers untested lines.

    bandwidth = -1
    ms = MeanShift(bandwidth=bandwidth)
    msg = ("bandwidth needs to be greater than zero or None,"
           " got -1.000000")
Member

This whitespace looks like an error in the code raising the message. Please change the code to have a single space between the comma and "got"

Member

This is unresolved. Please fix the error message in mean_shift_.py

(1.2, True, 3),
(1.2, False, 4)
])
def test_eval(bandwidth, cluster_all, expected):
Member

What do you mean by calling this "eval"? Can't we just parametrize test_mean_shift above, rather than adding a new test?

Member

But ideally we should also test that cluster_all=False is actually effective at allowing some points to be left unclustered. Create a dataset where a point will be left with label -1 to test this properly.

Contributor Author

@jnothman fixed as suggested
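
As an illustration of the request above, the following is a minimal sketch of a dataset in which one point should end up with label -1. It is not the test that was added in this PR. The data and the single explicit seed are assumptions chosen for illustration: with default seeding an isolated point may simply seed its own cluster, so the sketch fixes the seed near the dense group.

```python
import numpy as np
from sklearn.cluster import MeanShift

# Four points close together plus one far-away point (illustrative data).
X = np.array([[0.0, 0.0], [0.2, 0.0], [0.0, 0.2], [0.2, 0.2],
              [10.0, 10.0]])

# Assumption: a single explicit seed near the dense group, so that
# the far-away point cannot start a cluster of its own.
seeds = np.array([[0.0, 0.0]])

ms = MeanShift(bandwidth=1.0, seeds=seeds, cluster_all=False).fit(X)
labels = ms.labels_

# With cluster_all=False, points outside the bandwidth of every
# cluster center are left unclustered and labelled -1.
print(labels)
```

Setting `cluster_all=True` on the same data would instead assign the far-away point to the nearest cluster.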

@jnothman
Member

Please merge the current master

def test_mean_shift():
@pytest.mark.parametrize("bandwidth, cluster_all, expected, "
                         "first_cluster_label",
                         [(1.2, True, 3, 0), (1.2, False, 4, -1)])
Member

Much clearer, thanks!

    bandwidth = -1
    ms = MeanShift(bandwidth=bandwidth)
    msg = ("bandwidth needs to be greater than zero or None,"
           " got -1.000000")
Member

This is unresolved. Please fix the error message in mean_shift_.py

@rajdeepd
Contributor Author

rajdeepd commented Apr 2, 2019

@jnothman fixed the comments

@jnothman
Member

jnothman commented Apr 2, 2019

Thanks @rajdeepd

@rajdeepd
Contributor Author

rajdeepd commented Apr 5, 2019

@jnothman how do we get this pull request merged into master?

@jnothman
Member

jnothman commented Apr 6, 2019

4 days is not long to wait for a second review, @rajdeepd... hopefully one will come soon.


cluster_centers, labels = mean_shift(X, bandwidth=bandwidth)
Member

Removing this means we are not testing the mean_shift function directly anymore.

Contributor Author

we are testing using
ms = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all)
labels = ms.fit(X).labels_

Member

The testing of mean_shift should be independent of ms.fit. At the moment, ms.fit calls mean_shift, but we do not know how the code base will change.

Contributor Author

@thomasjpfan do we need another test for mean_shift?

Member

Leaving the original test here will sufficiently test mean_shift.

Contributor Author

@thomasjpfan added test for mean_shift as well

    ms = MeanShift(bandwidth=bandwidth)
    msg = ("bandwidth needs to be greater than zero or None,"
           " got -1.000000")
    assert_raise_message(ValueError, msg, ms.fit, X)
Member

We are moving to using pytest.raises:

msg = (r"bandwidth needs to be greater than zero or None,"
       r" got -1\.000000")
with pytest.raises(ValueError, match=msg):
    ms.fit(X)

Contributor Author

@thomasjpfan fixed
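
The `match` argument of `pytest.raises` is treated as a regular expression and searched for in the string form of the exception, which is why the dot in the suggested pattern is escaped. Below is a self-contained sketch of the pattern; `check_bandwidth` is a hypothetical stand-in for the estimator's validation, not scikit-learn code.

```python
import pytest


def check_bandwidth(bandwidth):
    """Hypothetical stand-in for the bandwidth validation in fit."""
    if bandwidth is not None and bandwidth <= 0:
        raise ValueError("bandwidth needs to be greater than zero or None,"
                         " got %f" % bandwidth)


def test_negative_bandwidth():
    # Raw strings keep the backslash; "\." matches a literal dot only.
    msg = (r"bandwidth needs to be greater than zero or None,"
           r" got -1\.000000")
    with pytest.raises(ValueError, match=msg):
        check_bandwidth(-1)
```

An unescaped dot would still match here, but it would also accept any other character in that position, so escaping keeps the assertion strict.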

rajdeepd force-pushed the test_mean_shift branch 2 times, most recently from 71df239 to 1b9f928, on April 21, 2019 08:56
Member

NicolasHug left a comment

LGTM otherwise

    n_clusters_ = len(labels_unique)
    assert_equal(n_clusters_, n_clusters)
    cluster_centers, labels_mean_shift = mean_shift(X, cluster_all=cluster_all)
    print(cluster_centers)
Member

please remove

Contributor Author

removed

@@ -36,23 +37,36 @@ def test_estimate_bandwidth_1sample():
    # Test estimate_bandwidth when n_samples=1 and quantile<1, so that
    # n_neighbors is set to 1.
    bandwidth = estimate_bandwidth(X, n_samples=1, quantile=0.3)
    assert_array_almost_equal(bandwidth, 0., decimal=5)
    assert_equal(bandwidth, 0.)
Member

could just be assert a == b then

Contributor Author

updated @NicolasHug

NicolasHug merged commit 690464b into scikit-learn:master on Apr 25, 2019
@NicolasHug
Member

Thanks @rajdeepd

jeremiedbb pushed a commit to jeremiedbb/scikit-learn that referenced this pull request Apr 25, 2019
xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
koenvandevelde pushed a commit to koenvandevelde/scikit-learn that referenced this pull request Jul 12, 2019
4 participants