WIP Stratified Shuffle Split #1060
Conversation
Can you comment with the previous use case that was failing and that now works? cc @npinto
Example:

```python
y = np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3])
for train, test in sss:
    ...
```

Output (indices):

```
[2, 1, 6, 4, 7, 9] [0, 5, 8]
```

So it appears to work now, where it previously failed badly. If you tweak the test_size parameter you might end up with test sets too small to contain all the classes (e.g. try test_size=.1). The CV scheme itself is not what needs modifying here, I think, but rather the validation function, which has to check extra things (typically that the test and train sets are not smaller than the number of classes).
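For reference, the `sss` iterator above is never shown being constructed; a minimal sketch of how it could be built with the API discussed in this PR (test_size=0.25 is an assumption chosen to match the 3-sample test fold in the output, not a value taken from the comment):

```python
import numpy as np
from sklearn.cross_validation import StratifiedShuffleSplit

y = np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3])
# indices=True makes the iterator yield index arrays like the ones printed above
sss = StratifiedShuffleSplit(y, test_size=0.25, indices=True)
for train, test in sss:
    print("%r %r" % (list(train), list(test)))
```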
Not important since you don't use it in the outer loop, but you are redefining the i variable in the nested loop
Apparently, you didn't address @fabianp's comment yet.
@mblondel: I did, the i variable is no longer defined in the outer loop.
Could you make a unit test out of this? Basically, I think that a good test would cover exactly this case. Something that I don't really like is that the classes are sorted in the resulting train and test indices.
I played with the class for a bit and it seems to do what I want. Regarding the sorted classes: can't we just throw a permutation on the training and test indices?
@kyleabeauchamp Thanks for tackling this. Yes, doing a permutation would be ok. It might be possible to avoid it, but I'd rather go for an easy-to-understand solution, even if we go over the indices once more.
Thanks for the comments. If you all agree, before I make the implementation more efficient and cleaner, I'd rather add the tests and make it rock solid. The other things are less essential and can come later.
I added tests, and permuted the test and train sets as asked. It should be pretty good now. One comment I have is that the implementation itself is not very efficient and starts to be a bit slow when you get to around 1M samples.
It would be great if someone could volunteer to add benchmarks for CV iterators and other utility functions (such as the score functions in the metrics module) to the benchmark suite from scikit-learn-speed. The current benchmark source code lives in this folder: https://github.com/scikit-learn/scikit-learn-speed/blob/master/benchmarks/ and uses these templates: https://github.com/scikit-learn/scikit-learn-speed/blob/master/benchmarks/templates.py. Currently all of those benchmarks use the same template, which is focused on benchmarking classes that implement the fit / predict API, but nothing prevents us from adding other utility functions or classes to the benchmark suite. Maybe @vene could add a new benchmark for a non-fit-predict object, to have a first example for new benchmark contributors.
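Pending a proper scikit-learn-speed benchmark, here is a rough, plain-timeit sketch of the kind of measurement that would be useful for a CV iterator (the sample size, number of classes, and constructor arguments here are assumptions, not part of the scikit-learn-speed templates):

```python
import timeit

import numpy as np
from sklearn.cross_validation import StratifiedShuffleSplit

rng = np.random.RandomState(0)
y = rng.randint(0, 10, size=1000000)  # large, multi-class label vector

def generate_folds():
    cv = StratifiedShuffleSplit(y, test_size=0.1, indices=True)
    for train, test in cv:
        pass  # only the cost of producing the folds is measured

print("%.2f s per call" % timeit.timeit(generate_folds, number=1))
```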
Could you add assert_equal or assert_almost_equal assertions too? (to check that the proportions of each class are roughly respected)
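One way such an assertion could look (a sketch only; the label vector and the test_size value are made up, and the tolerance is a judgment call rather than anything specified in this PR):

```python
import numpy as np
from numpy.testing import assert_array_almost_equal
from sklearn.cross_validation import StratifiedShuffleSplit

y = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2])
cv = StratifiedShuffleSplit(y, test_size=0.25, indices=True)

for train, test in cv:
    p_full = np.bincount(y) / float(len(y))
    p_train = np.bincount(y[train], minlength=len(p_full)) / float(len(train))
    # one decimal place is about as tight as "roughly respected" can be
    # with so few samples per class
    assert_array_almost_equal(p_full, p_train, decimal=1)
```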
Also, I would argue that we don't care if it does a better job than ShuffleSplit. We just care that it does the job as expected. So, I would just remove the above inequality assertions (comparison with ShuffleSplit).
Perhaps we want something like this, where we check that the observed training counts are within ±1 of their desired values. The ±1 is needed because desired_train_counts might be a float, so we can only expect equality to within rounding.

```python
import numpy as np
import sklearn.cross_validation

test_size = 0.25  # fraction of the samples to put in the test set; y is the label array

observed_counts = np.bincount(y, minlength=y.max() + 1)
observed_probs = 1. * observed_counts / observed_counts.sum()

n_test = int(np.ceil(test_size * len(y)))
n_train = len(y) - n_test
desired_train_counts = n_train * observed_probs
desired_test_counts = n_test * observed_probs

cv = sklearn.cross_validation.StratifiedShuffleSplit(y, test_size=test_size,
                                                     indices=True)
for train_ind, test_ind in cv:
    y_train, y_test = y[train_ind], y[test_ind]
    train_counts = np.array([np.sum(y_train == i) for i in range(y.max() + 1)])
    test_counts = np.array([np.sum(y_test == i) for i in range(y.max() + 1)])
    assert np.abs(train_counts - desired_train_counts).max() <= 1.
    assert np.abs(test_counts - desired_test_counts).max() <= 1.
```
@schwarty could you adapt the problem reported by Dan on the mailing list into a non-regression test?
I think the problem reported by Dan actually happens with the previous implementation of StratifiedShuffleSplit. Basically, just check your installation and you should be good!
@schwarty Can you still please add a test?
@amueller: Done. I also added additional validation for corner cases, and the associated tests. And I replaced the comparison to ShuffleSplit with something that should be more relevant.
Thanks. Do you think this is ok now, or should I or someone else have another close look?
Can you also check that the sizes of the two sets together give the total training set size? And that training and test don't overlap?
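A sketch of the kind of checks being asked for here (the label vector and test_size are made up; whether train and test always cover every sample depends on how train_size is set):

```python
import numpy as np
from sklearn.cross_validation import StratifiedShuffleSplit

y = np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3])
cv = StratifiedShuffleSplit(y, test_size=0.25, indices=True)

for train, test in cv:
    # the same sample must never appear in both the training and test set
    assert len(np.intersect1d(train, test)) == 0
    # together, the two index arrays should account for every sample
    # when train_size is the complement of test_size
    assert len(train) + len(test) == len(y)
```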
OK, that's also done. I think @GaelVaroquaux would like to have another look before we merge it, but he'll probably be busy for the next couple of days...
…training and testing sets, and that they don't overlap
sklearn/cross_validation.py
It seems to me that you could use np.minimum, and be more readable. No?
Yep, done
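For illustration, a toy example of the np.minimum suggestion above (the array contents are made up; in the actual code this would cap the per-class counts elementwise instead of using an explicit conditional):

```python
import numpy as np

requested = np.array([3, 5, 2])   # samples wanted from each class
available = np.array([4, 3, 6])   # samples actually present per class

taken = np.minimum(requested, available)   # -> array([3, 3, 2])
```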
sklearn/cross_validation.py
I think that you could code this as:

```python
train = rng.permutation(train)
test = rng.permutation(test)
```
Good to know you can pass a sequence to permutation, thanks for the tip. Done as well.
LGTM, +1 for merge. Good work, @schwarty: you drew almost no complaints from me :)
Did you address the scalability issue? I'm working with a 1-million-example dataset and want to use it.
@ogrisel: that's the line doing it (147)
Oops sorry I missed it.
@mblondel: I didn't change anything regarding the speed. Is it too slow at the moment?
@schwarty Haven't tried yet and won't have time to try before next week. We can merge and optimize later if necessary (correctness is more important).
LGTM, +1 for merging.
@mblondel I agree. FYI, currently it takes around 0.8s per fold.
Looks good, merging. Thanks a lot for the fix!
I get this error:
I don't understand this test (maybe it is too late). This calls the constructor, right? Where does the constructor do input validation?
It shouldn't; that should be done in fit. But I don't see where it happens at all.
Nondeterministic test failures, my favourite -_-
OK, it was just a doctest. Hopefully I fixed it, and I merged.
@amueller did you merge it or just close the PR this time?
It shows up in the commits, so I guess I did what I intended for once ;)
In terms of speed, np.unique(y) should be computed only once: with many samples, recomputing it for every iteration is wasteful.
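A sketch of the suggested micro-optimization (the label vector and the n_iterations loop are hypothetical stand-ins for the iterator's internals): hoist the np.unique(y) call out of the per-iteration work.

```python
import numpy as np

y = np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3])  # example label vector
n_iterations = 10   # hypothetical number of requested splits

# computed once, before generating any splits
classes, y_indices = np.unique(y, return_inverse=True)

for n in range(n_iterations):
    # build the n-th train/test split here, reusing `classes` and
    # `y_indices` instead of calling np.unique(y) again
    pass
```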
@GaelVaroquaux: I fixed that in master. @schwarty: CS 101 "Don't repeat the same computation twice" :)
Arhh, gut! I have a really bad Internet connection, so I am having a hard time keeping up.
Yeah, I really have problems catching up with all that is going on! Crazy :)
@GaelVaroquaux: no, I fixed that thanks to your remark! :) Am I the only one who hates the new notification system on GitHub? I'm flooded with notifications now. I preferred the old system: notifications on mentions, new PRs and commit comments...
Same thing here, it's a nightmare. I have the feeling that it is killing my productivity.
@mblondel yeah it does flood my inbox :-/
It definitely requires further testing; I'm just interested in knowing whether you can find cases where it doesn't behave properly, so that they can be fixed. And I would like to be positively sure it works properly before discussing design considerations.