Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG + 3] OneVsRestClassifier: don't expose predict_proba and decision_function if base estimator doesn't support them #7812

Merged
merged 3 commits into from Nov 7, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 7 additions & 4 deletions sklearn/multiclass.py
Expand Up @@ -50,7 +50,7 @@
from .utils.multiclass import (_check_partial_fit_first_call,
check_classification_targets,
_ovr_decision_function)
from .utils.metaestimators import _safe_split
from .utils.metaestimators import _safe_split, if_delegate_has_method

from .externals.joblib import Parallel
from .externals.joblib import delayed
Expand Down Expand Up @@ -309,6 +309,7 @@ def predict(self, X):
shape=(n_samples, len(self.estimators_)))
return self.label_binarizer_.inverse_transform(indicator)

@if_delegate_has_method(['_first_estimator', 'estimator'])
def predict_proba(self, X):
"""Probability estimates.

Expand Down Expand Up @@ -347,6 +348,7 @@ def predict_proba(self, X):
Y /= np.sum(Y, axis=1)[:, np.newaxis]
return Y

@if_delegate_has_method(['_first_estimator', 'estimator'])
def decision_function(self, X):
"""Returns the distance of each sample from the decision boundary for
each class. This can only be used with estimators which implement the
Expand All @@ -361,9 +363,6 @@ def decision_function(self, X):
T : array-like, shape = [n_samples, n_classes]
"""
check_is_fitted(self, 'estimators_')
if not hasattr(self.estimators_[0], "decision_function"):
raise AttributeError(
"Base estimator doesn't have a decision_function attribute.")
return np.array([est.decision_function(X).ravel()
for est in self.estimators_]).T

Expand Down Expand Up @@ -400,6 +399,10 @@ def _pairwise(self):
"""Indicate if wrapped estimator is using a precomputed Gram matrix"""
return getattr(self.estimator, "_pairwise", False)

@property
def _first_estimator(self):
return self.estimators_[0]


def _fit_ovo_binary(estimator, X, y, i, j):
"""Fit a single binary estimator (one-vs-one)."""
Expand Down
22 changes: 17 additions & 5 deletions sklearn/tests/test_multiclass.py
Expand Up @@ -314,14 +314,25 @@ def test_ovr_multilabel_predict_proba():
X_test = X[80:]
clf = OneVsRestClassifier(base_clf).fit(X_train, Y_train)

# decision function only estimator. Fails in current implementation.
# Decision function only estimator.
decision_only = OneVsRestClassifier(svm.SVR()).fit(X_train, Y_train)
assert_raises(AttributeError, decision_only.predict_proba, X_test)
assert_false(hasattr(decision_only, 'predict_proba'))
assert_true(hasattr(decision_only, 'decision_function'))

# Estimator with predict_proba disabled, depending on parameters.
decision_only = OneVsRestClassifier(svm.SVC(probability=False))
assert_false(hasattr(decision_only, 'predict_proba'))
decision_only.fit(X_train, Y_train)
assert_raises(AttributeError, decision_only.predict_proba, X_test)
assert_false(hasattr(decision_only, 'predict_proba'))
assert_true(hasattr(decision_only, 'decision_function'))

# Estimator which can get predict_proba enabled after fitting
gs = GridSearchCV(svm.SVC(probability=False),
param_grid={'probability': [True]})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for testing this.

proba_after_fit = OneVsRestClassifier(gs)
assert_false(hasattr(proba_after_fit, 'predict_proba'))
proba_after_fit.fit(X_train, Y_train)
assert_true(hasattr(proba_after_fit, 'predict_proba'))

Y_pred = clf.predict(X_test)
Y_proba = clf.predict_proba(X_test)
Expand All @@ -339,9 +350,10 @@ def test_ovr_single_label_predict_proba():
X_test = X[80:]
clf = OneVsRestClassifier(base_clf).fit(X_train, Y_train)

# decision function only estimator. Fails in current implementation.
# Decision function only estimator.
decision_only = OneVsRestClassifier(svm.SVR()).fit(X_train, Y_train)
assert_raises(AttributeError, decision_only.predict_proba, X_test)
assert_false(hasattr(decision_only, 'predict_proba'))
assert_true(hasattr(decision_only, 'decision_function'))

Y_pred = clf.predict(X_test)
Y_proba = clf.predict_proba(X_test)
Expand Down