ENH refactor OVO decision function, use it in SVC for sklearn-like decision_function shape
amueller committed Jun 5, 2015
1 parent 4eda9e6 commit 6d4bcda
Showing 13 changed files with 271 additions and 131 deletions.
7 changes: 4 additions & 3 deletions doc/modules/model_persistence.rst
@@ -22,9 +22,10 @@ persistence model, namely `pickle <http://docs.python.org/library/pickle.html>`_
>>> iris = datasets.load_iris()
>>> X, y = iris.data, iris.target
>>> clf.fit(X, y) # doctest: +NORMALIZE_WHITESPACE
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
kernel='rbf', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False)
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
decision_function_shape=None, degree=3, gamma=0.0, kernel='rbf',
max_iter=-1, probability=False, random_state=None, shrinking=True,
tol=0.001, verbose=False)

>>> import pickle
>>> s = pickle.dumps(clf)
12 changes: 6 additions & 6 deletions doc/modules/pipeline.rst
@@ -42,9 +42,9 @@ is an estimator object::
>>> clf # doctest: +NORMALIZE_WHITESPACE
Pipeline(steps=[('reduce_dim', PCA(copy=True, n_components=None,
whiten=False)), ('svm', SVC(C=1.0, cache_size=200, class_weight=None,
coef0=0.0, degree=3, gamma=0.0, kernel='rbf', max_iter=-1,
probability=False, random_state=None, shrinking=True, tol=0.001,
verbose=False))])
coef0=0.0, decision_function_shape=None, degree=3, gamma=0.0,
kernel='rbf', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False))])

The utility function :func:`make_pipeline` is a shorthand
for constructing pipelines;
@@ -76,9 +76,9 @@ Parameters of the estimators in the pipeline can be accessed using the
>>> clf.set_params(svm__C=10) # doctest: +NORMALIZE_WHITESPACE
Pipeline(steps=[('reduce_dim', PCA(copy=True, n_components=None,
whiten=False)), ('svm', SVC(C=10, cache_size=200, class_weight=None,
coef0=0.0, degree=3, gamma=0.0, kernel='rbf', max_iter=-1,
probability=False, random_state=None, shrinking=True, tol=0.001,
verbose=False))])
coef0=0.0, decision_function_shape=None, degree=3, gamma=0.0,
kernel='rbf', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False))])

This is particularly important when doing grid searches.
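
A minimal sketch of the idea (the full example is folded on this page;
the parameter values here are hypothetical, and the commit-era API lives
in ``sklearn.grid_search``)::

    >>> from sklearn.grid_search import GridSearchCV
    >>> params = dict(reduce_dim__n_components=[2, 5, 10], svm__C=[0.1, 10, 100])
    >>> grid_search = GridSearchCV(clf, param_grid=params)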

33 changes: 22 additions & 11 deletions doc/modules/svm.rst
@@ -76,9 +76,10 @@ n_features]`` holding the training samples, and an array y of class labels
>>> y = [0, 1]
>>> clf = svm.SVC()
>>> clf.fit(X, y) # doctest: +NORMALIZE_WHITESPACE
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3,
gamma=0.0, kernel='rbf', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False)
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
decision_function_shape=None, degree=3, gamma=0.0, kernel='rbf',
max_iter=-1, probability=False, random_state=None, shrinking=True,
tol=0.001, verbose=False)

After being fitted, the model can then be used to predict new values::

@@ -109,18 +110,27 @@ Multi-class classification
:class:`SVC` and :class:`NuSVC` implement the "one-against-one"
approach (Knerr et al., 1990) for multi-class classification. If
``n_class`` is the number of classes, then ``n_class * (n_class - 1) / 2``
classifiers are constructed and each one trains data from two classes::
classifiers are constructed and each one is trained on data from two classes.
To provide a consistent interface with other classifiers, the
``decision_function_shape`` option makes it possible to aggregate the results
of the "one-against-one" classifiers into a decision function of shape
``(n_samples, n_classes)``::

>>> X = [[0], [1], [2], [3]]
>>> Y = [0, 1, 2, 3]
>>> clf = svm.SVC()
>>> clf = svm.SVC(decision_function_shape='ovo')
>>> clf.fit(X, Y) # doctest: +NORMALIZE_WHITESPACE
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3,
gamma=0.0, kernel='rbf', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False)
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
decision_function_shape='ovo', degree=3, gamma=0.0, kernel='rbf',
max_iter=-1, probability=False, random_state=None, shrinking=True,
tol=0.001, verbose=False)
>>> dec = clf.decision_function([[1]])
>>> dec.shape[1] # 4 classes: 4*3/2 = 6
6
>>> clf.decision_function_shape = "ovr"
>>> dec = clf.decision_function([[1]])
>>> dec.shape[1] # 4 classes
4
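
The six columns of the ``'ovo'`` decision function are ordered by class
pair, matching the nested ``i < j`` loop of ``_ovr_decision_function``
further down in this diff. A quick sketch of that ordering (plain Python,
not part of the commit)::

    >>> from itertools import combinations
    >>> list(combinations(range(4), 2))  # one column per pair of 4 classes
    [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]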

On the other hand, :class:`LinearSVC` implements the "one-vs-the-rest"
multi-class strategy, thus training n_class models. If there are only
@@ -503,9 +513,10 @@ test vectors must be provided.
>>> # linear kernel computation
>>> gram = np.dot(X, X.T)
>>> clf.fit(gram, y) # doctest: +NORMALIZE_WHITESPACE
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3,
gamma=0.0, kernel='precomputed', max_iter=-1, probability=False,
random_state=None, shrinking=True, tol=0.001, verbose=False)
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
decision_function_shape=None, degree=3, gamma=0.0,
kernel='precomputed', max_iter=-1, probability=False,
random_state=None, shrinking=True, tol=0.001, verbose=False)
>>> # predict on training examples
>>> clf.predict(gram)
array([0, 1])
50 changes: 28 additions & 22 deletions doc/tutorial/basic/tutorial.rst
@@ -176,9 +176,10 @@ which produces a new array that contains all but
the last entry of ``digits.data``::

>>> clf.fit(digits.data[:-1], digits.target[:-1]) # doctest: +NORMALIZE_WHITESPACE
SVC(C=100.0, cache_size=200, class_weight=None, coef0=0.0, degree=3,
gamma=0.001, kernel='rbf', max_iter=-1, probability=False,
random_state=None, shrinking=True, tol=0.001, verbose=False)
SVC(C=100.0, cache_size=200, class_weight=None, coef0=0.0,
decision_function_shape=None, degree=3, gamma=0.001, kernel='rbf',
max_iter=-1, probability=False, random_state=None, shrinking=True,
tol=0.001, verbose=False)

Now you can predict new values. In particular, we can ask the classifier
what the digit of our last image in the ``digits`` dataset is,
@@ -214,9 +215,10 @@ persistence model, namely `pickle <http://docs.python.org/library/pickle.html>`_
>>> iris = datasets.load_iris()
>>> X, y = iris.data, iris.target
>>> clf.fit(X, y) # doctest: +NORMALIZE_WHITESPACE
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
kernel='rbf', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False)
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
decision_function_shape=None, degree=3, gamma=0.0, kernel='rbf',
max_iter=-1, probability=False, random_state=None, shrinking=True,
tol=0.001, verbose=False)

>>> import pickle
>>> s = pickle.dumps(clf)
@@ -287,18 +289,20 @@ maintained::

>>> iris = datasets.load_iris()
>>> clf = SVC()
>>> clf.fit(iris.data, iris.target)
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
kernel='rbf', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False)
>>> clf.fit(iris.data, iris.target) # doctest: +NORMALIZE_WHITESPACE
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
decision_function_shape=None, degree=3, gamma=0.0, kernel='rbf',
max_iter=-1, probability=False, random_state=None, shrinking=True,
tol=0.001, verbose=False)

>>> list(clf.predict(iris.data[:3]))
[0, 0, 0]

>>> clf.fit(iris.data, iris.target_names[iris.target])
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
kernel='rbf', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False)
>>> clf.fit(iris.data, iris.target_names[iris.target]) # doctest: +NORMALIZE_WHITESPACE
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
decision_function_shape=None, degree=3, gamma=0.0, kernel='rbf',
max_iter=-1, probability=False, random_state=None, shrinking=True,
tol=0.001, verbose=False)

>>> list(clf.predict(iris.data[:3])) # doctest: +NORMALIZE_WHITESPACE
['setosa', 'setosa', 'setosa']
@@ -324,17 +328,19 @@ more than once will overwrite what was learned by any previous ``fit()``::
>>> X_test = rng.rand(5, 10)

>>> clf = SVC()
>>> clf.set_params(kernel='linear').fit(X, y)
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
kernel='linear', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False)
>>> clf.set_params(kernel='linear').fit(X, y) # doctest: +NORMALIZE_WHITESPACE
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
decision_function_shape=None, degree=3, gamma=0.0, kernel='linear',
max_iter=-1, probability=False, random_state=None, shrinking=True,
tol=0.001, verbose=False)
>>> clf.predict(X_test)
array([1, 0, 1, 1, 0])

>>> clf.set_params(kernel='rbf').fit(X, y)
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
kernel='rbf', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False)
>>> clf.set_params(kernel='rbf').fit(X, y) # doctest: +NORMALIZE_WHITESPACE
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
decision_function_shape=None, degree=3, gamma=0.0, kernel='rbf',
max_iter=-1, probability=False, random_state=None, shrinking=True,
tol=0.001, verbose=False)
>>> clf.predict(X_test)
array([0, 0, 0, 1, 0])

7 changes: 4 additions & 3 deletions doc/tutorial/statistical_inference/supervised_learning.rst
@@ -453,9 +453,10 @@ classification -- :class:`SVC` (Support Vector Classification).
>>> from sklearn import svm
>>> svc = svm.SVC(kernel='linear')
>>> svc.fit(iris_X_train, iris_y_train) # doctest: +NORMALIZE_WHITESPACE
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
kernel='linear', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False)
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
decision_function_shape=None, degree=3, gamma=0.0, kernel='linear',
max_iter=-1, probability=False, random_state=None, shrinking=True,
tol=0.001, verbose=False)


.. warning:: **Normalizing data**
5 changes: 5 additions & 0 deletions doc/whats_new.rst
@@ -66,6 +66,11 @@ API changes summary
for retrieving the leaf indices samples are predicted as. By
`Daniel Galvez`_ and `Gilles Louppe`_.

- :class:`svm.SVC` and :class:`svm.NuSVC` now have a ``decision_function_shape``
parameter that makes their decision function have shape ``(n_samples, n_classes)``
when ``decision_function_shape='ovr'`` is set. This will be the default behavior
starting in 0.19. By `Andreas Müller`_.

.. _changes_0_1_16:

0.16.1
6 changes: 5 additions & 1 deletion sklearn/base.py
@@ -11,7 +11,11 @@
from .externals import six


###############################################################################
class ChangedBehaviorWarning(UserWarning):
    pass


##############################################################################
def clone(estimator, safe=True):
"""Constructs a new estimator with the same parameters.
13 changes: 5 additions & 8 deletions sklearn/grid_search.py
@@ -20,7 +20,7 @@
import numpy as np

from .base import BaseEstimator, is_classifier, clone
from .base import MetaEstimatorMixin
from .base import MetaEstimatorMixin, ChangedBehaviorWarning
from .cross_validation import _check_cv as check_cv
from .cross_validation import _fit_and_score
from .externals.joblib import Parallel, delayed
@@ -304,10 +304,6 @@ def __repr__(self):
self.parameters)


class ChangedBehaviorWarning(UserWarning):
    pass


class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
                                      MetaEstimatorMixin)):
    """Base class for hyper parameter search with cross-validation."""
@@ -642,9 +638,10 @@ class GridSearchCV(BaseSearchCV):
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
GridSearchCV(cv=None, error_score=...,
estimator=SVC(C=1.0, cache_size=..., class_weight=..., coef0=...,
degree=..., gamma=..., kernel='rbf', max_iter=-1,
probability=False, random_state=None, shrinking=True,
tol=..., verbose=False),
decision_function_shape=None, degree=..., gamma=...,
kernel='rbf', max_iter=-1, probability=False,
random_state=None, shrinking=True, tol=...,
verbose=False),
fit_params={}, iid=..., n_jobs=1,
param_grid=..., pre_dispatch=..., refit=...,
scoring=..., verbose=...)
82 changes: 52 additions & 30 deletions sklearn/multiclass.py
@@ -552,36 +552,58 @@ def decision_function(self, X):
"""
check_is_fitted(self, 'estimators_')

n_samples = X.shape[0]
n_classes = self.classes_.shape[0]
votes = np.zeros((n_samples, n_classes))
sum_of_confidences = np.zeros((n_samples, n_classes))

k = 0
for i in range(n_classes):
for j in range(i + 1, n_classes):
pred = self.estimators_[k].predict(X)
confidence_levels_ij = _predict_binary(self.estimators_[k], X)
sum_of_confidences[:, i] -= confidence_levels_ij
sum_of_confidences[:, j] += confidence_levels_ij
votes[pred == 0, i] += 1
votes[pred == 1, j] += 1
k += 1

max_confidences = sum_of_confidences.max()
min_confidences = sum_of_confidences.min()

if max_confidences == min_confidences:
return votes

# Scale the sum_of_confidences to (-0.5, 0.5) and add it with votes.
# The motivation is to use confidence levels as a way to break ties in
# the votes without switching any decision made based on a difference
# of 1 vote.
eps = np.finfo(sum_of_confidences.dtype).eps
max_abs_confidence = max(abs(max_confidences), abs(min_confidences))
scale = (0.5 - eps) / max_abs_confidence
return votes + sum_of_confidences * scale
predictions = np.vstack([est.predict(X) for est in self.estimators_]).T
confidences = np.vstack([_predict_binary(est, X) for est in self.estimators_]).T
return _ovr_decision_function(predictions, confidences,
len(self.classes_))


def _ovr_decision_function(predictions, confidences, n_classes):
    """Compute a continuous, tie-breaking ovr decision function.

    It is important to include a continuous value, not only votes,
    to make computing AUC or calibration meaningful.

    Parameters
    ----------
    predictions : array-like, shape (n_samples, n_classifiers)
        Predicted classes for each binary classifier.

    confidences : array-like, shape (n_samples, n_classifiers)
        Decision functions or predicted probabilities for the positive class
        for each binary classifier.

    n_classes : int
        Number of classes. n_classifiers must be
        ``n_classes * (n_classes - 1) / 2``.
    """
    n_samples = predictions.shape[0]
    votes = np.zeros((n_samples, n_classes))
    sum_of_confidences = np.zeros((n_samples, n_classes))

    k = 0
    for i in range(n_classes):
        for j in range(i + 1, n_classes):
            sum_of_confidences[:, i] -= confidences[:, k]
            sum_of_confidences[:, j] += confidences[:, k]
            votes[predictions[:, k] == 0, i] += 1
            votes[predictions[:, k] == 1, j] += 1
            k += 1

    max_confidences = sum_of_confidences.max()
    min_confidences = sum_of_confidences.min()

    if max_confidences == min_confidences:
        return votes

    # Scale the sum_of_confidences to (-0.5, 0.5) and add it to the votes.
    # The motivation is to use confidence levels as a way to break ties in
    # the votes without switching any decision made based on a difference
    # of 1 vote.
    eps = np.finfo(sum_of_confidences.dtype).eps
    max_abs_confidence = max(abs(max_confidences), abs(min_confidences))
    scale = (0.5 - eps) / max_abs_confidence
    return votes + sum_of_confidences * scale
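
A minimal smoke test of this helper (illustrative values only, not part of
the commit): with 3 classes there are 3 pairwise classifiers, ordered
(0, 1), (0, 2), (1, 2), and a prediction of 0 votes for the first class of
each pair while 1 votes for the second::

    >>> import numpy as np
    >>> predictions = np.array([[1, 0, 0]])  # pairwise winners: 1, 0, 1
    >>> confidences = np.array([[0.9, -0.4, -0.3]])
    >>> dec = _ovr_decision_function(predictions, confidences, n_classes=3)
    >>> dec.shape
    (1, 3)
    >>> int(dec.argmax(axis=1)[0])  # class 1 wins both of its pairwise duels
    1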


@deprecated("fit_ecoc is deprecated and will be removed in 0.18."
