[MRG+1] fix sampling in stratified shuffle split #6472
I'll remove the asserts etc if we can agree on this implementation (and how the tests need to be changed).
This tries to adjust
Also, it adds the "missing" samples back in based on their class frequency. So in expectation, this should be doing exactly the right thing (I think; I didn't have any coffee today yet).
The ugly bit is

```python
np.bincount(rng.permutation(np.repeat(range(n_classes), left_over_per_class))[:missing_train], minlength=n_classes)
```

which is the only way I knew to say "sample `missing_train` points from the left-over samples, respecting the per-class leftover counts".
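To unpack what that one-liner computes, here is a small standalone run with made-up values for `left_over_per_class` and `missing_train` (both are assumptions for illustration, not taken from the PR):

```python
import numpy as np

rng = np.random.RandomState(0)
n_classes = 3
left_over_per_class = np.array([2, 3, 1])  # hypothetical leftovers per class
missing_train = 2                          # hypothetical shortfall to fill

# Build a pool with each class label repeated by its leftover count,
# shuffle the pool, keep the first `missing_train` labels, and count
# how many of each class we picked.
pool = np.repeat(range(n_classes), left_over_per_class)
extra = np.bincount(rng.permutation(pool)[:missing_train],
                    minlength=n_classes)
print(extra)  # per-class counts of the extra samples; sums to missing_train
```

Because the pool only contains as many copies of each label as there are leftover samples, the draw can never request more samples from a class than are actually available.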
The first test that breaks checks that if something with
ok so I can't spend that much more time on it, but I think the newest version without sampling is ok.
What I'm doing now is computing the most likely draw from the original class distribution.
There is an alternative, which is "draw n_train points from class_counts, draw n_test points from class_counts, and make sure that they don't sum up to more than class_counts". While this behavior might be a bit more "intuitive", the "ensure that they don't sum up to more than class_counts" part is more or less what was buggy in #6121.
So I'd rather stay with the simpler semantics of doing one draw and then the other draw from what's left over.
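The "one draw, then the other from what's left over" semantics can be sketched roughly like this (a hypothetical sketch, not the actual PR code; the floor-of-expectation plus top-up step is my reading of the comments above):

```python
import numpy as np

def two_stage_split(class_counts, n_train, n_test, rng):
    """Sketch: draw per-class train counts first, then draw test counts
    only from the leftovers, so the two can never exceed class_counts."""
    n_classes = len(class_counts)
    n = class_counts.sum()

    # First draw: expected per-class train counts, rounded down
    # (this may undershoot n_train by a few samples).
    n_i = np.floor(class_counts * n_train / n).astype(int)
    # Top up the missing train samples in proportion to the leftovers.
    left = class_counts - n_i
    pool = np.repeat(np.arange(n_classes), left)
    n_i += np.bincount(rng.permutation(pool)[:n_train - n_i.sum()],
                       minlength=n_classes)

    # Second draw: test counts come only from what the train draw left.
    left = class_counts - n_i
    t_i = np.floor(left * n_test / left.sum()).astype(int)
    pool = np.repeat(np.arange(n_classes), left - t_i)
    t_i += np.bincount(rng.permutation(pool)[:n_test - t_i.sum()],
                       minlength=n_classes)
    return n_i, t_i

rng = np.random.RandomState(0)
n_i, t_i = two_stage_split(np.array([10, 10, 5]), n_train=10, n_test=5,
                           rng=rng)
print(n_i, t_i)
```

By construction `n_i + t_i` can never exceed `class_counts`, which is exactly the invariant that the "check afterwards" variant got wrong in #6121.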
I have no idea what to do with the failing tests, though.
The class distributions of the training and test parts are not the same to one digit.
So indeed, the test that is failing passes on master with this configuration:
So it violated the n_train / n_test sizes and put one more in the training set than it should have, to balance the classes.
It is also not clear to me how this block of the previous code preserves the class proportionality. Was it simply not tested enough?
@MechCoder it doesn't. It just samples randomly. There are very strict tests that passed.
My intuition of the problem was: n_train and n_test are specified directly by the user, so these are hard limits. The stratification is a "best effort" kind of thing. The best possible is to pick a mode of the multivariate hypergeometric, but doing that exactly would require quite a bit of code. So I do an approximation that might be off by one per class (I have not proven this bound).
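To illustrate the "off by at most one per class" intuition on a tiny case (the counts below are made up for illustration), one can brute-force the mode of the multivariate hypergeometric and compare it with the floor-of-expectation approximation before the top-up step:

```python
from itertools import product
from math import comb

import numpy as np

class_counts = np.array([5, 3, 2])  # hypothetical per-class counts
n_train = 4
n = class_counts.sum()

# Floor-of-expectation approximation (before topping up to n_train).
approx = np.floor(class_counts * n_train / n).astype(int)

# Brute-force the exact mode: enumerate all draws summing to n_train and
# keep the one with the largest (unnormalized) hypergeometric probability.
best, best_p = None, -1.0
for k in product(*(range(c + 1) for c in class_counts)):
    if sum(k) != n_train:
        continue
    p = np.prod([comb(c, ki) for c, ki in zip(class_counts, k)])
    if p > best_p:
        best, best_p = np.array(k), p

print(approx, best)  # the two differ by at most one in each class here
```

In this example the floor approximation undershoots `n_train` by one sample, and each per-class count is within one of the exact mode, which matches the unproven bound mentioned above (for this case, not in general).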