[MRG+1] SKF raises error if all n_labels for individual classes <n_folds #6182
@rvraghav93 @MechCoder could you please check this? Thanks!
" members, which is too few. The minimum" | ||
" number of labels for any class cannot" | ||
" be less than n_folds=%d." | ||
% (min_labels, self.n_folds)) |
It should be np.all(self.n_folds > label_counts). It is fine if the least populated class has fewer than n_folds members; it is only a problem if every class has fewer than n_folds members.
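A minimal sketch of the distinction, assuming the per-class counts are already computed. check_n_folds is a hypothetical helper, not scikit-learn's API, and its messages are illustrative. It raises only when every class is smaller than n_folds; warning (rather than staying silent) when just the least populated class is too small is an assumption of this sketch.

```python
import warnings

import numpy as np


def check_n_folds(label_counts, n_folds):
    """Hypothetical helper illustrating the review suggestion.

    label_counts: number of samples in each class.
    n_folds: requested number of stratified folds.
    """
    label_counts = np.asarray(label_counts)
    # Fatal: no class has enough members to appear in every fold.
    if np.all(n_folds > label_counts):
        raise ValueError("Every class has fewer than n_folds=%d members."
                         % n_folds)
    # Tolerable (assumed to warn): only some classes are too small.
    min_labels = label_counts.min()
    if n_folds > min_labels:
        warnings.warn("The least populated class has only %d members,"
                      " fewer than n_folds=%d." % (min_labels, n_folds))


check_n_folds([5, 2], n_folds=3)  # warns: only one class is too small
try:
    check_n_folds([2, 1], n_folds=3)  # raises: every class is too small
except ValueError as exc:
    print(exc)
```

Using np.all over the whole count array, instead of comparing against the minimum alone, is what separates "one class is small" from "no class can fill the folds".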
@MechCoder is this fine?
This looks good! Thanks. @MechCoder or @amueller for a second review?
(Squash the commits please)
assert_raises(ValueError, cval.StratifiedKFold, y, 0)
assert_raises(ValueError, cval.StratifiedKFold, y, 1)
assert_raises(ValueError, cval.StratifiedKFold, y2, 0)
assert_raises(ValueError, cval.StratifiedKFold, y2, 1)
Could you maybe check the error message too? (assert_raises_msg)
(The ValueError could sometimes come from a different error.)
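To illustrate why matching the message matters, here is a sketch using only the standard library: unittest's assertRaisesRegex stands in for scikit-learn's own test helper, and validate_folds is a hypothetical function that can raise ValueError for two different reasons.

```python
import unittest


def validate_folds(label_counts, n_folds):
    """Hypothetical validator that can fail for two distinct reasons."""
    if n_folds < 2:
        raise ValueError("n_folds must be at least 2, got %d." % n_folds)
    if all(n_folds > count for count in label_counts):
        raise ValueError("All classes have fewer than n_folds=%d members."
                         % n_folds)


class ValidateFoldsTest(unittest.TestCase):
    def test_too_few_folds(self):
        # A bare assertRaises(ValueError, ...) would accept either error;
        # matching the message pins down which check actually fired.
        with self.assertRaisesRegex(ValueError, "at least 2"):
            validate_folds([5, 5], n_folds=1)

    def test_all_classes_too_small(self):
        with self.assertRaisesRegex(ValueError, "fewer than n_folds=3"):
            validate_folds([2, 1], n_folds=3)
```

Run it with `python -m unittest` from the file's directory. Note that validate_folds([2, 1], n_folds=1) raises the n_folds error, not the class-size one, which a type-only assertion could not distinguish.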
Please update whatsnew.
@rvraghav93 @MechCoder is this fine?
@rvraghav93 @MechCoder how can I restart this build? I think this failing test is not related to my PR :/
Ah, don't mind that. AppVeyor gets wacky at times. (And I've confirmed that the current AppVeyor failure is unrelated to your PR.) Since the commit hash depends on the commit time, you could reset, re-commit, and force-push to get the build restarted.
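Why re-committing yields a new hash: a Git commit ID is the SHA-1 of the commit object, and that object's author and committer lines carry Unix timestamps. The sketch below hashes a hand-built commit object the way Git does for SHA-1 repositories (a `commit <size>\0` header followed by the body). The tree ID is Git's well-known empty tree; the parent ID, identity, timestamps, and message are made up.

```python
import hashlib


def git_commit_id(tree, parent, ident, timestamp, message):
    """Compute a Git-style SHA-1 commit ID from raw fields.

    All arguments are illustrative; real commits come from `git commit`.
    """
    body = (
        "tree %s\n"
        "parent %s\n"
        "author %s %d +0000\n"
        "committer %s %d +0000\n"
        "\n"
        "%s\n" % (tree, parent, ident, timestamp, ident, timestamp, message)
    ).encode("utf-8")
    # Git prefixes the object with its type and byte length, then a NUL.
    header = b"commit " + str(len(body)).encode("ascii") + b"\0"
    return hashlib.sha1(header + body).hexdigest()


EMPTY_TREE = "4b825dc642cb6eb9a060e54bf8d69288fbee4904"
PARENT = "0" * 40  # made-up parent ID
IDENT = "Jane Doe <jane@example.com>"  # made-up identity

first = git_commit_id(EMPTY_TREE, PARENT, IDENT, 1453000000, "Fix SKF check")
retry = git_commit_id(EMPTY_TREE, PARENT, IDENT, 1453000060, "Fix SKF check")
print(first != retry)  # same content, one minute later
```

In practice, `git commit --amend --no-edit` followed by a force-push refreshes the committer timestamp, so the ID changes even though the tree does not, and CI sees a new commit.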
@@ -92,6 +92,9 @@ Enhancements

Bug fixes
.........

- :class:`StratifiedKFold` now raises an error if all n_labels for individual classes are less than n_folds.
  (`#6182 <https://github.com/scikit-learn/scikit-learn/pull/6182>`_) by `Raghav R V`_, `Manoj Kumar`_ and `Devashish Deshpande`_.
Whoa that was generous ;) Please add your name alone :)
Hahaha I wouldn't really like to do that though 🍻
Thanks for addressing the comments. Looks good to me apart from the nitpicks. Let's wait for @jnothman or @MechCoder.
@@ -519,6 +519,12 @@ def __init__(self, y, n_folds=3, shuffle=False,
unique_labels, y_inversed = np.unique(y, return_inverse=True)
label_counts = bincount(y_inversed)
min_labels = np.min(label_counts)
# Raise error when all the n_labels for individual classes
# are less than n_folds
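The first lines of this hunk turn the target vector into per-class counts. A small sketch of that step, using NumPy's np.bincount in place of the bincount the diff calls (assumed here to behave like NumPy's for this input); the label values and n_folds are made up.

```python
import numpy as np

y = np.array(["spam", "ham", "spam", "eggs", "spam", "ham"])

# unique_labels is sorted; y_inversed maps each sample to its class index.
unique_labels, y_inversed = np.unique(y, return_inverse=True)
# Counting class indices gives the number of samples per class.
label_counts = np.bincount(y_inversed)
min_labels = np.min(label_counts)

for label, count in zip(unique_labels, label_counts):
    print(label, count)
print("least populated class size:", min_labels)

n_folds = 3
# Only the smallest class is below n_folds here, so np.all(...) is False.
print("every class too small:", np.all(n_folds > label_counts))
```

Because np.unique sorts the labels, the counts line up with unique_labels by position rather than by order of first appearance in y.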
Umm, I think the code is self-explanatory.
lgtm pending nitpicks
ping @TomDLT to just verify?
Done :)
I don't find the word
@TomDLT yeah, it's concise too. Should I go ahead and make the changes?
Please do!
Done!
Thanks!! 🍹 🍷
No problem :) Thanks for the help!
Addresses #6177.