Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG+1] Fix bug in StratifiedShuffleSplit for multi-label data with targets having > 1000 labels #9922

Merged

Conversation

@cbrummitt
Copy link
Contributor

@cbrummitt cbrummitt commented Oct 13, 2017

This PR fixes a bug for multi-label targets in StratifiedShuffleSplit. The solution being used now is the "label powerset" method: each sequence of labels is mapped to a string with str(row), which transforms a multi-label problem into a multi-class problem.

To see the source of the problem, note that len(str(np.arange(1000))) returns 4056 while len(str(np.arange(1001))) returns 36. The reason is that arrays with > 1000 elements are truncated with an ellipsis: str(np.arange(1001)) gives '[ 0 1 2 ..., 998 999 1000]'. Thus, for multi-label targets with > 1000 labels, samples are mapped onto the same short string whenever their first three values and last three values are the same, which is not the intended behavior.

The solution proposed here, discussed with @vene in this comment thread, is to use ' '.join(row.astype('str')) to convert each target to a string. We are guaranteed that we can do call .astype('str') on row because y = check_array(y, ensure_2d=False, dtype=None) converts y to a numpy array.

As an added benefit, this approach ends up being several faster than str(row) when len(row) < 1000:

In [1]: import numpy as np
In [2]: row = np.random.randint(0, 2, size=500)
In [3]: %timeit ' '.join(row.astype('str'))
169 µs ± 2.39 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [4]: %timeit str(row)
738 µs ± 23.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

I also considered using sklearn.utils.murmurhash3_32 (suggested by @vene as another option) but concluded it was tricky to coerce all the kinds of labels people might use into an int32 data type.

…ecause str(row) uses an ellipsis when len(row) > 1000
@vene
Copy link
Member

@vene vene commented Oct 13, 2017

Awesome, thank you for catching this and solving it!

Could you also modify the relevant test so that it fails without the patch? Thanks!

@jnothman jnothman added this to the 0.19.1 milestone Oct 15, 2017
@lesteve
Copy link
Member

@lesteve lesteve commented Oct 16, 2017

It would be nice to add a test.

@cbrummitt
Copy link
Contributor Author

@cbrummitt cbrummitt commented Oct 16, 2017

Good idea @vene and @lesteve. I added a test for a y with > 1000 labels. I simply added a new function test_stratified_shuffle_split_multilabel_many_labels to test_split.py that fails on the old method and passes with this bug fix.

Is there anything else that would need to be done to hook up this test?

Copy link
Member

@jnothman jnothman left a comment

LGTM

@jnothman jnothman changed the title Fix bug in StratifiedShuffleSplit for multi-label data with targets having > 1000 labels [MRG+1] Fix bug in StratifiedShuffleSplit for multi-label data with targets having > 1000 labels Oct 17, 2017
@lesteve
Copy link
Member

@lesteve lesteve commented Oct 17, 2017

LGTM, merging, thanks a lot!

@lesteve lesteve merged commit d074e40 into scikit-learn:master Oct 17, 2017
6 checks passed
6 checks passed
ci/circleci Your tests passed on CircleCI!
Details
codecov/patch 100% of diff hit (target 96.16%)
Details
codecov/project 96.17% (+<.01%) compared to 1c77257
Details
continuous-integration/appveyor/pr AppVeyor build succeeded
Details
continuous-integration/travis-ci/pr The Travis CI build passed
Details
lgtm analysis: Python No alert changes
Details
jnothman added a commit to jnothman/scikit-learn that referenced this pull request Oct 17, 2017
…argets having > 1000 labels (scikit-learn#9922)

* Use ' '.join(row) for multi-label targets in StratifiedShuffleSplit because str(row) uses an ellipsis when len(row) > 1000
* Add a new test for multilabel problems with more than a thousand labels
@cbrummitt cbrummitt deleted the cbrummitt:fix-multilabel-StratifiedShuffleSplit branch Oct 19, 2017
maskani-moh added a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017
…argets having > 1000 labels (scikit-learn#9922)

* Use ' '.join(row) for multi-label targets in StratifiedShuffleSplit because str(row) uses an ellipsis when len(row) > 1000
* Add a new test for multilabel problems with more than a thousand labels
jwjohnson314 pushed a commit to jwjohnson314/scikit-learn that referenced this pull request Dec 18, 2017
…argets having > 1000 labels (scikit-learn#9922)

* Use ' '.join(row) for multi-label targets in StratifiedShuffleSplit because str(row) uses an ellipsis when len(row) > 1000
* Add a new test for multilabel problems with more than a thousand labels
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Linked issues

Successfully merging this pull request may close these issues.

None yet

4 participants