ENH check_classification_targets raises a warning when unique classes > 50% of n_samples#26335
Conversation
thomasjpfan
left a comment
There was a problem hiding this comment.
I am -1 with this change in check_classification_targets. If we want a quick fix, I'll include the warning in type_of_target.
diff --git a/sklearn/utils/multiclass.py b/sklearn/utils/multiclass.py
index 24e528d10a..35e2752212 100644
--- a/sklearn/utils/multiclass.py
+++ b/sklearn/utils/multiclass.py
@@ -387,7 +387,13 @@ def type_of_target(y, input_name=""):
# Check multiclass
first_row = y[0] if not issparse(y) else y.getrow(0).data
- if xp.unique_values(y).shape[0] > 2 or (y.ndim == 2 and len(first_row) > 1):
+ classes = xp.unique_values(y)
+ if classes.shape[0] > round(0.5 * y.shape[0]):
+ warnings.warn(
+ r"The number of unique classes is greater than 50% of the samples."
+ )
+
+ if classes.shape[0] > 2 or (y.ndim == 2 and len(first_row) > 1):
# [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]
return "multiclass" + suffix
else:This way the classes are reused for the check and do not need to be recomputed.
| ) | ||
|
|
||
| if n_samples is not None: | ||
| if len(np.unique(y)) > round(0.5 * n_samples): |
There was a problem hiding this comment.
This runs counter to the second point in #16399:
check_classification_targets which calls type_of_target triggers a call to _assert_all_finite and np.unique(y) which are redundant with checks done elsewhere
In general, including another np.unique(y) adds another n*log(n) operation. (np.unique sorts the data)
|
Hey, sorry for the late reply. I made the changes. Using |
|
Apologies once again for the late reply, my college exams are going on and I might not get a lot of time this month.
Yes, sure. I had tried the following code earlier: along with this as the test:
DetailsI was taking a look at the documentation to see if there was another way to raise warnings and I came across warnings.formatwarning, which allowed me to execute all the tests without any errors. I do realize the current implementation is a bit of a hacky workaround, with |
I think this is because |
14aa5a8 to
3988f80
Compare
Hey @betatim, I have removed |
Update: I have removed the merge conflicts and modified the tests as per your suggestions. I have used a try/except block to check for any |
glemaitre
left a comment
There was a problem hiding this comment.
I fixed the conflict and modified the logic at the same time. Now, we raise only if we have more than 20 samples in y. For instance, it avoids to raise the warnings for toy dataset where we test with a couple of samples.
It would avoid spurious warning in the documentation.
|
@betatim do you want to have an additional look at this PR? |
|
Btw this PR would be great since it addresses the issue of someone using a classifier for a regression task. |
| # less than 20 samples, no warning should be raised | ||
| y = np.arange(10) | ||
| with warnings.catch_warnings(): | ||
| warnings.simplefilter("error", UserWarning) |
There was a problem hiding this comment.
Does anyone have wisdom on whether we should use warnings.simplefilter("error", UserWarning) or warnings.simplefilter("error")?
As far as I can tell the test passes with both options, but I dont know which is the "better" way of testing that no warnings were raised. Maybe it doesn't matter
There was a problem hiding this comment.
removing UserWarning to make sure nothing's raised.
betatim
left a comment
There was a problem hiding this comment.
LGTM. I think @thomasjpfan's concerns were addressed
…cikit-learn into enh_check_classification_targets
|
Thanks everyone for updating and merging the PR! |
|
Could someone please improve the warning message and mention that it might be a regression problem instead of classification. This part is really missing. |
Reference Issues/PRs
Towards #16399
What does this implement/fix? Explain your changes.
Addresses the first point of #16399 (comment)
Any other comments?