[MRG+1] cross_val_predict handles multi-label predict_proba #8773
Conversation
Modify the `cross_val_predict` and `_fit_and_predict` functions so that they handle multi-label (and multi-class multi-label) classification problems with `predict_proba`, `predict_log_proba`, and `decision_function` methods. There are two different kinds of multi-label outputs from scikit-learn estimators. The `OneVersusRestClassifier` handles multi-label tasks with binary indicator target arrays (no multi-label targets). It outputs 2D arrays from `predict_proba`, etc. methods. The `RandomForestClassifier` handles multi-class multi-label problems. It outputs a list of 2D arrays from `predict_proba`, etc. Recognize the RandomForest-like outputs by type-checking. Lists of 2D arrays require slightly different code for keeping track of indices.
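The shape difference between the two output containers can be sketched with plain NumPy, mocking the per-fold outputs (the arrays and fold sizes here are illustrative, not taken from the PR):

```python
import numpy as np

# OneVsRest-style: each fold yields one 2D array of shape
# (n_fold_samples, n_labels), so folds concatenate along axis 0.
ovr_folds = [np.full((2, 3), 0.1), np.full((2, 3), 0.9)]
ovr_all = np.concatenate(ovr_folds)            # shape (4, 3)

# RandomForest-style: each fold yields a *list* with one
# (n_fold_samples, n_classes) array per output, so concatenation
# must happen per output, across folds.
rf_folds = [[np.full((2, 2), 0.1)] * 3, [np.full((2, 2), 0.9)] * 3]
rf_all = [np.concatenate([fold[i] for fold in rf_folds])
          for i in range(3)]                   # 3 arrays, each (4, 2)
```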
Force-pushed 7e21876 to ca2ee0b
Sparse outputs aren't so relevant to
@@ -419,9 +428,20 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
    # Check for sparse predictions
Remove this, or move it.
Removed.
@@ -419,9 +428,20 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
    # Check for sparse predictions
    if sp.issparse(predictions[0]):
        predictions = sp.vstack(predictions, format=predictions[0].format)
    elif do_manual_encoding and isinstance(predictions[0], list):
        n_labels = y.shape[1]
I think a comment here is deserved to remind us what we've got and where it's going
Added a comment.
    else:
        predictions = np.concatenate(predictions)
    return predictions[inv_test_indices]
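The `inv_test_indices` lookup on the last line above restores the original sample order after per-fold concatenation. A mocked illustration (the prediction values are invented; only the index bookkeeping mirrors the diff):

```python
import numpy as np

# Two folds over 4 samples; test indices concatenated in fold order.
test_indices = np.concatenate([[2, 0], [3, 1]])
inv_test_indices = np.empty(len(test_indices), dtype=int)
inv_test_indices[test_indices] = np.arange(len(test_indices))

# Fold 1 predicted for samples 2 and 0; fold 2 for samples 3 and 1.
fold_preds = [np.array([[20.], [0.]]), np.array([[30.], [10.]])]
predictions = np.concatenate(fold_preds)

# Reindexing puts rows back in sample order 0, 1, 2, 3.
reordered = predictions[inv_test_indices]
```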
    if do_manual_encoding and isinstance(predictions, list):
I don't see why you need do_manual_encoding here...
Ah, you're right. Checking if predictions is a list is enough. Changed.
    if method in ['decision_function', 'predict_proba', 'predict_log_proba']:
        le = LabelEncoder()
        y = le.fit_transform(y)
    do_manual_encoding = method in ['decision_function', 'predict_proba',
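The point of the encoding step is that class values become 0..n_classes-1, so an encoded class value can double as a column index into the prediction array. A minimal sketch using `np.unique` (scikit-learn's `LabelEncoder` is equivalent for this purpose):

```python
import numpy as np

# Arbitrary class values are mapped to 0..n_classes-1; the encoded
# value then names the corresponding predict_proba column directly.
y = np.array([10, 40, 10, 20])
classes, y_encoded = np.unique(y, return_inverse=True)
# classes   -> array([10, 20, 40])
# y_encoded -> array([0, 2, 0, 1]); class 40 now names column 2
```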
I don't like this variable name. Perhaps just encode, encoded or is_encoded would do
Changed to encode.
@jnothman , apologies for taking a long time to respond here. I believe I've addressed your comments.
@jnothman , would you like additional changes here?
Oh, I thought we'd merged this fix at some point :(
I think this is good, but I'm too tired to be sure!
doc/whats_new.rst
Outdated
@@ -178,6 +178,10 @@ Enhancements
   removed by setting it to `None`.
   :issue:`7674` by :user:`Yichuan Liu <yl565>`.

- Added ability for :func:`model_selection.cross_val_predict` to handle multi-label
  (and multi-class multi-label) targets with `predict_proba`-type methods.
double backticks
    elif encode and isinstance(predictions[0], list):
        # `predictions` is a list of method outputs from each fold.
        # If each of those is also a list, then treat this as a
        # multi-class multi-label task. We need to separately concatenate
do you mean multi-output multiclass?
@jnothman , thanks for taking another look at this PR. I added double backticks and changed "multi-class multi-label" to "multioutput-multiclass" (the spelling I found in existing documentation) in that comment and in the What's New.
definitely wasn't a full review yet. All of the class list games are so ugly!
I'm not convinced that the need for _enforce_prediction_order is tested here. Can we make sure this is tested on data where the set of classes for training each fold will vary, such as y = [1,2,3,4,5] and Y = [[0, 1], [1, 2], [0, 3], [1, 4], [0, 5]]? It wouldn't hurt if check_cross_val_predict_with_method explicitly checked that the shape was as expected, either...
        expected_predictions[test] = func(X[test])
        preds = func(X[test])
        if isinstance(predictions, list):
            for i_label in range(y.shape[1]):
Call it output rather than label, please
    func = getattr(est, method)

    # Naive loop (should be same as cross_val_predict):
    for train, test in kfold.split(X, y):
        est.fit(X[train], y[train])
        expected_predictions[test] = func(X[test])
        preds = func(X[test])
I'm a bit confused. Doesn't this mean that sometimes we're going to have mismatched numbers of classes vis-a-vis what we try to solve with _enforce_prediction_order?
Let individual test functions control the for loops.
@jnothman , you're right that the tests weren't covering the case where
    then the output prediction array might not have the same
    columns as other folds. Use the list of class names
    (assumed to be integers) to enforce the correct column order.
    """
Is it worth fast-pathing the classes == arange case? Or is that premature optimisation?
That's a good idea. We can return immediately if all classes were present in the subset of data used to train this fold.
    predictions_ = np.zeros((predictions.shape[0], n_classes),
                            dtype=predictions.dtype)
    if one_col_if_binary and len(classes) == 2:
        predictions_[:, classes[-1]] = predictions
This leaves one of the columns zeroed?
That's what we have to do if n_classes >= 3. But you're right that this is a problem if n_classes = 2. That should have a single column of output. I added a new test to catch this, test_cross_val_predict_binary_decision_function. It's handled now by having the function return predictions directly when len(classes) == n_classes.
def test_cross_val_predict_with_method_multilabel_rf():
    # The RandomForest allows anything for the contents of the labels.
The wording here is unclear. Do you just mean that RF handles multiclass-multioutput and produces predict_proba in that vein?
Changed to "The RandomForest allows multiple classes in each label.".
    X = rng.normal(0, 1, size=(10, 10))
    y = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 2])
    est = LogisticRegression()
    for method in ['predict_proba', 'predict_log_proba']:
decision_function also?
I'd left it off because the test code doesn't handle the decision_function case when there are three classes in the full data but only 2 in one of the folds. I think that case is covered by test_cross_val_predict_class_subset. In this test, I modified y to have 4 classes instead of 3 and added "decision_function" to the list.
Also don't do unnecessary work if number of classes in a cross_val_predict fold equals number of classes in the full data.
@jnothman , comments addressed. Thanks for suggesting that optimization and pointing out a bug in
needs merge fix
Codecov Report
@@ Coverage Diff @@
## master #8773 +/- ##
==========================================
+ Coverage 96.16% 96.17% +<.01%
==========================================
Files 336 335 -1
Lines 62144 61953 -191
==========================================
- Hits 59762 59581 -181
+ Misses 2382 2372 -10
Continue to review full report at Codecov.
Is there anything I can do to help this PR along?
I thought of this PR yesterday. I think it's essential before we merge stacking using cross_val_predict (though there are unresolved conflicts there). @glemaitre? @qinhanmin2014?
I will look at it tomorrow.
@glemaitre, do you have some time to review this?
Force-pushed 1f1fe77 to 02d53ec
does someone wanna review? otherwise I say untag.
LGTM, should we merge this while the other issues/approaches move forward?
I think merging this will be especially valuable if we review and merge StackingClassifier.
Thanks @stephen-hoover, and sorry for the very slow ride!!
Thank you for the merge!
…cikit-learn#8773)" This reverts commit 0eaceb6.
What does this implement/fix? Explain your changes.
Fixes #11058
The following fails under v0.18.1:
The error is
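The failing snippet and traceback didn't survive extraction. A reproduction along these lines (reconstructed from the description, with assumed dataset sizes, not the author's exact code) triggers the pre-fix error, and returns per-label probabilities once the fix is in:

```python
import numpy as np
from sklearn.datasets import make_multilabel_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
from sklearn.multiclass import OneVsRestClassifier

# Multi-label indicator targets: shape (n_samples, n_classes) of 0/1.
X, Y = make_multilabel_classification(n_samples=100, n_classes=3,
                                      random_state=0)
est = OneVsRestClassifier(LogisticRegression())

# Raised an error under v0.18.1; with this PR it returns an
# (n_samples, n_classes) array of cross-validated probabilities.
proba = cross_val_predict(est, X, Y, cv=3, method='predict_proba')
```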
This PR modifies the `cross_val_predict` and `_fit_and_predict` functions so that they handle multi-label (and also multi-class multi-label) classification problems with the `predict_proba`, `predict_log_proba`, and `decision_function` methods.

I've found two different kinds of multi-label outputs from scikit-learn estimators. The `OneVersusRestClassifier` handles multi-label tasks with binary indicator target arrays (no multi-label targets). It outputs 2D arrays from `predict_proba`, etc. methods. The `RandomForestClassifier` handles multi-class multi-label problems. It outputs a list of 2D arrays from `predict_proba`, etc.

I recognize the RandomForest-like outputs by type-checking. Lists of 2D arrays require slightly different code for keeping track of indices than single 2D output arrays.
Any other comments?
I didn't make any modifications to handle sparse outputs for these cases. I don't know if it's necessary. Do any estimators return sparse outputs for multi-class multi-label classification problems?