Skip to content

Commit

Permalink
FIX OvR/OvO classifier decision_function shape fixes (#9100)
Browse files Browse the repository at this point in the history
* fix OvR classifier edge-case bugs

* add regression tests for OVO and OVR decision function shapes
  • Loading branch information
amueller authored and vene committed Jun 10, 2017
1 parent 56a21ea commit b43c791
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 1 deletion.
5 changes: 5 additions & 0 deletions doc/whats_new.rst
Expand Up @@ -418,6 +418,11 @@ API changes summary
:func:`model_selection.cross_val_predict`.
:issue:`2879` by :user:`Stephen Hoover <stephen-hoover>`.

- The ``decision_function`` output shape for binary classification in
:class:`multiclass.OneVsRestClassifier` and
:class:`multiclass.OneVsOneClassifier` is now ``(n_samples,)`` to conform
to scikit-learn conventions. :issue:`9100` by `Andreas Müller`_.

- Gradient boosting base models are no longer estimators. By `Andreas Müller`_.

- :class:`feature_selection.SelectFromModel` now validates the ``threshold``
Expand Down
7 changes: 6 additions & 1 deletion sklearn/multiclass.py
Expand Up @@ -368,6 +368,8 @@ def decision_function(self, X):
T : array-like, shape = [n_samples, n_classes]
"""
check_is_fitted(self, 'estimators_')
if len(self.estimators_) == 1:
return self.estimators_[0].decision_function(X)
return np.array([est.decision_function(X).ravel()
for est in self.estimators_]).T

Expand Down Expand Up @@ -574,6 +576,8 @@ def predict(self, X):
Predicted multi-class targets.
"""
Y = self.decision_function(X)
if self.n_classes_ == 2:
return self.classes_[(Y > 0).astype(np.int)]
return self.classes_[Y.argmax(axis=1)]

def decision_function(self, X):
Expand Down Expand Up @@ -606,7 +610,8 @@ def decision_function(self, X):
for est, Xi in zip(self.estimators_, Xs)]).T
Y = _ovr_decision_function(predictions,
confidences, len(self.classes_))

if self.n_classes_ == 2:
return Y[:, 1]
return Y

@property
Expand Down
9 changes: 9 additions & 0 deletions sklearn/tests/test_multiclass.py
Expand Up @@ -251,6 +251,9 @@ def conduct_test(base_clf, test_predict_proba=False):
assert_equal(set(clf.classes_), classes)
y_pred = clf.predict(np.array([[0, 0, 4]]))[0]
assert_equal(set(y_pred), set("eggs"))
if hasattr(base_clf, 'decision_function'):
dec = clf.decision_function(X)
assert_equal(dec.shape, (5,))

if test_predict_proba:
X_test = np.array([[0, 0, 4]])
Expand Down Expand Up @@ -524,6 +527,12 @@ def test_ovo_decision_function():
n_samples = iris.data.shape[0]

ovo_clf = OneVsOneClassifier(LinearSVC(random_state=0))
# first binary
ovo_clf.fit(iris.data, iris.target == 0)
decisions = ovo_clf.decision_function(iris.data)
assert_equal(decisions.shape, (n_samples,))

# then multi-class
ovo_clf.fit(iris.data, iris.target)
decisions = ovo_clf.decision_function(iris.data)

Expand Down

0 comments on commit b43c791

Please sign in to comment.