Skip to content

Commit

Permalink
Merge branch 'pr/3710'
Browse files Browse the repository at this point in the history
  • Loading branch information
larsmans committed Sep 29, 2014
2 parents aa66dea + b584ac4 commit 6ea371a
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 8 deletions.
7 changes: 7 additions & 0 deletions doc/whats_new.rst
Expand Up @@ -120,6 +120,13 @@ API changes summary
- The `n_jobs` parameter of the fit method was moved to the constructor of
the LinearRegression class.

- The ``predict_proba`` method of :class:`multiclass.OneVsRestClassifier`
now returns two probabilities per sample in the binary case; this
is consistent with other estimators and with the method's documentation,
but previous versions accidentally returned only the positive
probability. Fixed by Will Lamond and `Lars Buitinck`_.


.. _changes_0_15_2:

0.15.2
Expand Down
12 changes: 9 additions & 3 deletions sklearn/multiclass.py
Expand Up @@ -54,6 +54,7 @@
"OutputCodeClassifier",
]


def _fit_binary(estimator, X, y, classes=None):
"""Fit a single binary estimator."""
unique_y = np.unique(y)
Expand Down Expand Up @@ -299,17 +300,17 @@ def predict(self, X):
else:
thresh = .5

n_samples = _num_samples(X)
if self.label_binarizer_.y_type_ == "multiclass":
maxima = np.empty(X.shape[0], dtype=float)
maxima = np.empty(n_samples, dtype=float)
maxima.fill(-np.inf)
argmaxima = np.zeros(X.shape[0], dtype=int)
argmaxima = np.zeros(n_samples, dtype=int)
for i, e in enumerate(self.estimators_):
pred = _predict_binary(e, X)
np.maximum(maxima, pred, out=maxima)
argmaxima[maxima == pred] = i
return self.label_binarizer_.classes_[np.array(argmaxima.T)]
else:
n_samples = _num_samples(X)
indices = array.array('i')
indptr = array.array('i', [0])
for e in self.estimators_:
Expand Down Expand Up @@ -347,6 +348,11 @@ def predict_proba(self, X):
# In the multi-label case, these are not disjoint.
Y = np.array([e.predict_proba(X)[:, 1] for e in self.estimators_]).T

if len(self.estimators_) == 1:
# Only one estimator, but we still want to return probabilities
# for two classes.
Y = np.concatenate(((1 - Y), Y), axis=1)

if not self.multilabel_:
# Then, probabilities should be normalized to 1.
Y /= np.sum(Y, axis=1)[:, np.newaxis]
Expand Down
23 changes: 18 additions & 5 deletions sklearn/tests/test_multiclass.py
Expand Up @@ -27,7 +27,7 @@

from sklearn.preprocessing import LabelBinarizer

from sklearn.svm import LinearSVC
from sklearn.svm import LinearSVC, SVC
from sklearn.naive_bayes import MultinomialNB
from sklearn.linear_model import (LinearRegression, Lasso, ElasticNet, Ridge,
Perceptron, LogisticRegression)
Expand Down Expand Up @@ -186,20 +186,33 @@ def test_ovr_binary():

classes = set("eggs spam".split())

for base_clf in (MultinomialNB(), LinearSVC(random_state=0),
LinearRegression(), Ridge(),
ElasticNet()):

# Shared checks for one base estimator wrapped in OneVsRestClassifier.
#   base_clf            -- estimator instance handed to OneVsRestClassifier.
#   test_predict_proba  -- when true, also exercise predict_proba.
# X, y, Y and classes come from the enclosing test; their definitions are
# not visible in this hunk (presumably a toy binary "eggs"/"spam" dataset --
# TODO confirm against the full test).
def conduct_test(base_clf, test_predict_proba=False):
clf = OneVsRestClassifier(base_clf).fit(X, y)
# The fitted classifier must expose exactly the expected label set.
assert_equal(set(clf.classes_), classes)
y_pred = clf.predict(np.array([[0, 0, 4]]))[0]
# NOTE(review): set("eggs") is a set of characters, so this compares the
# letters of the predicted label rather than the label as a whole.
assert_equal(set(y_pred), set("eggs"))

if test_predict_proba:
X_test = np.array([[0, 0, 4]])
probabilities = clf.predict_proba(X_test)
# Each row must hold two probabilities, one per class.
assert_equal(2, len(probabilities[0]))
# The most probable class must agree with what predict returns.
assert_equal(clf.classes_[np.argmax(probabilities, axis=1)],
clf.predict(X_test))

# test input as label indicator matrix
clf = OneVsRestClassifier(base_clf).fit(X, Y)
y_pred = clf.predict([[3, 0, 0]])[0]
assert_equal(y_pred, 1)

for base_clf in (LinearSVC(random_state=0), LinearRegression(),
Ridge(), ElasticNet()):
conduct_test(base_clf)

for base_clf in (MultinomialNB(), SVC(probability=True),
LogisticRegression()):
conduct_test(base_clf, test_predict_proba=True)


@ignore_warnings
def test_ovr_multilabel():
# Toy dataset where features correspond directly to labels.
Expand Down

0 comments on commit 6ea371a

Please sign in to comment.