
[MRG+1] cross_val_predict handles multi-label predict_proba #8773

Merged

Conversation

@stephen-hoover (Contributor) commented Apr 21, 2017

What does this implement/fix? Explain your changes.

Fixes #11058

The following fails under v0.18.1:

from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import cross_val_predict
from sklearn.ensemble import RandomForestClassifier

X, y = make_multilabel_classification(n_samples=100, n_labels=3, n_classes=4, n_features=5, random_state=42)
y[:, 0] += y[:, 1]  # Put three classes in the first column
est = RandomForestClassifier(n_estimators=5, random_state=0)
cross_val_predict(est, X, y, method='predict_proba')

The error is

ValueError: Found input variables with inconsistent numbers of samples: [100, 100, 13]

This PR modifies the cross_val_predict and _fit_and_predict functions so that they handle multi-label (and also multi-class multi-label) classification problems with the predict_proba, predict_log_proba, and decision_function methods.

I've found two different kinds of multi-label outputs from scikit-learn estimators. The OneVsRestClassifier handles multi-label tasks with binary indicator target arrays (each output is binary, never multiclass). It outputs a single 2D array from predict_proba and similar methods. The RandomForestClassifier also handles multi-class multi-label problems, and it outputs a list of 2D arrays, one per output, from predict_proba and similar methods.

I recognize the RandomForest-like outputs by type-checking. Lists of 2D arrays require slightly different code for keeping track of indices than a single 2D output array.
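
For illustration, here is a minimal sketch (not taken from the PR) of the two output shapes, reusing the dataset from the snippet above; the exact per-column shapes assume each target column contains at least two distinct values, which holds for this random seed:

import numpy as np
from sklearn.datasets import make_multilabel_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

X, y = make_multilabel_classification(n_samples=100, n_labels=3, n_classes=4,
                                      n_features=5, random_state=42)

# Binary indicator targets: predict_proba returns a single 2D array.
ovr = OneVsRestClassifier(LogisticRegression()).fit(X, y)
print(ovr.predict_proba(X).shape)  # (100, 4)

# Multioutput-multiclass targets: predict_proba returns a list of 2D
# arrays, one (n_samples, n_classes_i) array per output column.
y[:, 0] += y[:, 1]  # put three classes in the first column
rf = RandomForestClassifier(n_estimators=5, random_state=0).fit(X, y)
proba = rf.predict_proba(X)
print(type(proba), len(proba))   # <class 'list'> 4
print([p.shape for p in proba])  # e.g. [(100, 3), (100, 2), (100, 2), (100, 2)]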

Any other comments?

I didn't make any modifications to handle sparse outputs for these cases. I don't know if it's necessary. Do any estimators return sparse outputs for multi-class multi-label classification problems?

@stephen-hoover changed the title from [WIP] cross_val_predict handles multi-label predict_proba to [MRG] cross_val_predict handles multi-label predict_proba on Apr 21, 2017
@jnothman (Member)

Sparse outputs aren't so relevant to predict_proba

@@ -419,9 +428,20 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
# Check for sparse predictions
Member: Remove this, or move it.

Contributor Author: Removed.

@@ -419,9 +428,20 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
    # Check for sparse predictions
    if sp.issparse(predictions[0]):
        predictions = sp.vstack(predictions, format=predictions[0].format)
    elif do_manual_encoding and isinstance(predictions[0], list):
        n_labels = y.shape[1]
Member: I think a comment here is deserved to remind us what we've got and where it's going.

Contributor Author: Added a comment.

    else:
        predictions = np.concatenate(predictions)
    return predictions[inv_test_indices]

    if do_manual_encoding and isinstance(predictions, list):
Member: I don't see why you need do_manual_encoding here...

Contributor Author: Ah, you're right. Checking if predictions is a list is enough. Changed.
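
As a hedged sketch of what that list check enables at the end of cross_val_predict (the function name here is illustrative and the merged code may differ in detail):

import numpy as np

def reorder_predictions(predictions, inv_test_indices):
    # Reorder concatenated per-fold predictions back to the original sample
    # order. For multioutput estimators, `predictions` is a list with one
    # array per output, so each array is reindexed separately.
    if isinstance(predictions, list):
        return [np.asarray(p)[inv_test_indices] for p in predictions]
    return np.asarray(predictions)[inv_test_indices]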

    if method in ['decision_function', 'predict_proba', 'predict_log_proba']:
        le = LabelEncoder()
        y = le.fit_transform(y)
    do_manual_encoding = method in ['decision_function', 'predict_proba',
Member: I don't like this variable name. Perhaps just encode, encoded or is_encoded would do.

Contributor Author: Changed to encode.

@stephen-hoover (Contributor Author)

@jnothman, apologies for taking a long time to respond here. I believe I've addressed your comments.

@stephen-hoover (Contributor Author)

@jnothman, would you like additional changes here?

@jnothman (Member)

Oh, I thought we'd merged this fix at some point :(

@jnothman (Member) left a comment

I think this is good, but I'm too tired to be sure!

@@ -178,6 +178,10 @@ Enhancements
   removed by setting it to `None`.
   :issue:`7674` by :user:`Yichuan Liu <yl565>`.

- Added ability for :func:`model_selection.cross_val_predict` to handle multi-label
  (and multi-class multi-label) targets with `predict_proba`-type methods.
Member: double backticks

    elif encode and isinstance(predictions[0], list):
        # `predictions` is a list of method outputs from each fold.
        # If each of those is also a list, then treat this as a
        # multi-class multi-label task. We need to separately concatenate
Member: do you mean multi-output multiclass?

@stephen-hoover (Contributor Author)

@jnothman, thanks for taking another look at this PR. I added double backticks and changed "multi-class multi-label" to "multioutput-multiclass" (the spelling used in the existing documentation) in that comment and in the What's New entry.
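
To make the per-output bookkeeping concrete, here is a minimal sketch of the concatenation step for the list-of-arrays case (names such as fold_predictions are illustrative, not taken from the PR):

import numpy as np

def concatenate_multioutput(fold_predictions, n_outputs):
    # fold_predictions holds one entry per CV fold; each entry is the
    # estimator's predict_proba output for that fold, i.e. a list with one
    # 2D array per output column. Concatenate the fold arrays separately
    # for each output; row reordering happens elsewhere.
    return [np.concatenate([fold[i] for fold in fold_predictions])
            for i in range(n_outputs)]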

@jnothman (Member) commented Jul 27, 2017 via email

@jnothman (Member) left a comment

I'm not convinced that the need for _enforce_prediction_order is tested here. Can we make sure this is tested on data where the set of classes for training each fold will vary, such as y = [1,2,3,4,5] and Y = [[0, 1], [1, 2], [0, 3], [1, 4], [0, 5]]?

It wouldn't hurt if check_cross_val_predict_with_method explicitly checked that the shape was as expected, either...

    expected_predictions[test] = func(X[test])
    preds = func(X[test])
    if isinstance(predictions, list):
        for i_label in range(y.shape[1]):
Member: Call it output rather than label, please

    func = getattr(est, method)

    # Naive loop (should be same as cross_val_predict):
    for train, test in kfold.split(X, y):
        est.fit(X[train], y[train])
        expected_predictions[test] = func(X[test])
        preds = func(X[test])
Member: I'm a bit confused. Doesn't this mean that sometimes we're going to have mismatched numbers of classes vis-a-vis what we try to solve with _enforce_prediction_order?

@stephen-hoover (Contributor Author)

@jnothman, you're right that the tests weren't covering the case where _enforce_prediction_order was needed. I added two new tests which have a split with fewer classes than the full dataset. I also added some explicit asserts on the output.
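
A small sketch of the kind of data those tests need, following the reviewer's example above (the actual test code in the PR may differ):

import numpy as np
from sklearn.model_selection import KFold

# With y = [1, 2, 3, 4, 5] and five folds, each training split is missing
# the class held out in its test split, so every fold's predict_proba has
# fewer columns than the full problem and must be realigned.
y = np.array([1, 2, 3, 4, 5])
for train, test in KFold(n_splits=5).split(y):
    print(sorted(set(y[train])))  # a different 4-class subset per fold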

    then the output prediction array might not have the same
    columns as other folds. Use the list of class names
    (assumed to be integers) to enforce the correct column order.
    """
Member: Is it worth fast-pathing the classes == arange case? Or is that premature optimisation?

Contributor Author: That's a good idea. We can return immediately if all classes were present in the subset of data used to train this fold.

    predictions_ = np.zeros((predictions.shape[0], n_classes),
                            dtype=predictions.dtype)
    if one_col_if_binary and len(classes) == 2:
        predictions_[:, classes[-1]] = predictions
Member: This leaves one of the columns zeroed?

Contributor Author: That's what we have to do if n_classes >= 3. But you're right that this is a problem if n_classes = 2, which should have a single column of output. I added a new test to catch this, test_cross_val_predict_binary_decision_function. It's handled now by having the function return predictions directly when len(classes) == n_classes.
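
Putting those two review points together, here is a simplified sketch of the column-alignment idea under discussion (the function name and signature are illustrative; the PR's _enforce_prediction_order differs in detail, e.g. in its handling of single-column binary decision_function outputs):

import numpy as np

def align_fold_columns(predictions, fold_classes, n_classes):
    # predictions: 2D predict_proba-style output for one fold.
    # fold_classes: integer-encoded classes seen while training that fold.
    if len(fold_classes) == n_classes:
        # Fast path: the fold saw every class, so the columns already
        # line up with the full problem.
        return predictions
    # Scatter the fold's columns into their global positions; classes
    # unseen in this fold keep probability zero.
    full = np.zeros((predictions.shape[0], n_classes), dtype=predictions.dtype)
    full[:, fold_classes] = predictions
    return full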



def test_cross_val_predict_with_method_multilabel_rf():
    # The RandomForest allows anything for the contents of the labels.
Member: The wording here is unclear. Do you just mean that RF handles multiclass-multioutput and produces predict_proba in that vein?

Contributor Author: Changed to "The RandomForest allows multiple classes in each label.".

    X = rng.normal(0, 1, size=(10, 10))
    y = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 2])
    est = LogisticRegression()
    for method in ['predict_proba', 'predict_log_proba']:
Member: decision_function also?

Contributor Author: I'd left it off because the test code doesn't handle the decision_function case when there are three classes in the full data but only two in one of the folds. I think that case is covered by test_cross_val_predict_class_subset. In this test, I modified y to have 4 classes instead of 3 and added "decision_function" to the list.
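
The wrinkle being discussed is that decision_function changes shape between binary and multiclass problems; a minimal illustration (not from the PR):

import numpy as np
from sklearn.linear_model import LogisticRegression

X = np.random.RandomState(0).normal(size=(12, 4))

# Binary problem: decision_function returns a 1D array of shape (n_samples,).
clf = LogisticRegression().fit(X, [0, 1] * 6)
print(clf.decision_function(X).shape)  # (12,)

# Three-class problem: it returns a 2D array of shape (n_samples, n_classes),
# so a fold that happens to see only two classes produces a different shape.
clf = LogisticRegression().fit(X, [0, 1, 2] * 4)
print(clf.decision_function(X).shape)  # (12, 3)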

Also don't do unnecessary work if number of classes in a cross_val_predict fold equals number of classes in the full data.
@stephen-hoover (Contributor Author)

@jnothman, comments addressed. Thanks for suggesting that optimization and pointing out a bug in _enforce_prediction_order.

@amueller (Member)

needs merge fix

@codecov (bot) commented Aug 28, 2017

Codecov Report

Merging #8773 into master will increase coverage by <.01%.
The diff coverage is 97.29%.


@@            Coverage Diff             @@
##           master    #8773      +/-   ##
==========================================
+ Coverage   96.16%   96.17%   +<.01%     
==========================================
  Files         336      335       -1     
  Lines       62144    61953     -191     
==========================================
- Hits        59762    59581     -181     
+ Misses       2382     2372      -10
Impacted Files Coverage Δ
sklearn/model_selection/_validation.py 96.73% <94.44%> (+0.27%) ⬆️
sklearn/model_selection/tests/test_validation.py 98.75% <98.66%> (-0.02%) ⬇️
sklearn/utils/tests/test_testing.py 80.61% <0%> (-0.39%) ⬇️
sklearn/ensemble/tests/test_gradient_boosting.py 96.03% <0%> (-0.24%) ⬇️
sklearn/linear_model/logistic.py 96.86% <0%> (-0.21%) ⬇️
sklearn/utils/estimator_checks.py 93.19% <0%> (-0.2%) ⬇️
sklearn/preprocessing/data.py 98.65% <0%> (-0.17%) ⬇️
...rn/semi_supervised/tests/test_label_propagation.py 98.91% <0%> (-0.16%) ⬇️
sklearn/utils/__init__.py 94.48% <0%> (-0.05%) ⬇️
sklearn/linear_model/stochastic_gradient.py 98.13% <0%> (-0.04%) ⬇️
... and 21 more


@glemaitre added this to the 0.20 milestone on Jun 8, 2018
@stephen-hoover (Contributor Author)

Is there anything I can do to help this PR along?

@jnothman (Member) commented Jun 25, 2018 via email

@glemaitre (Member) commented Jun 25, 2018 via email

@jnothman (Member)

@glemaitre, do you have some time to review this?

@ogrisel added this to "PRs tagged" in scikit-learn 0.20 on Jul 16, 2018
@amueller force-pushed the multilabel-cross-val-predict branch 2 times, most recently from 1f1fe77 to 02d53ec on July 20, 2018
@amueller (Member)

does someone wanna review? otherwise I say untag.

@adrinjalali (Member) left a comment

LGTM, should we merge this while the other issues/approaches move forward?

@jnothman (Member) commented Apr 6, 2019

I think merging this will be especially valuable if we review and merge StackingClassifier.

@jnothman (Member) commented Apr 6, 2019

Thanks @stephen-hoover, and sorry for the very slow ride!!

@jnothman merged commit 24df999 into scikit-learn:master on Apr 6, 2019
@stephen-hoover (Contributor Author)

Thank you for the merge!

@stephen-hoover deleted the multilabel-cross-val-predict branch on April 7, 2019
jeremiedbb pushed a commit to jeremiedbb/scikit-learn that referenced this pull request Apr 25, 2019
xhluca pushed commits to xhluca/scikit-learn that referenced this pull request Apr 28, 2019 (3 times)
koenvandevelde pushed a commit to koenvandevelde/scikit-learn that referenced this pull request Jul 12, 2019
Projects: scikit-learn 0.20 (PRs tagged)

Development
Successfully merging this pull request may close these issues.

'cross_val_predict' throws error when estimator is 'OneVsRest Classifier' and method is 'decision_function'
6 participants