Skip to content

Commit

Permalink
Fix scikit-learn#10440, will give up as another fix was proposed
Browse files Browse the repository at this point in the history
  • Loading branch information
sdpython committed Feb 14, 2019
1 parent 9ac5793 commit f0a84f0
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 3 deletions.
43 changes: 41 additions & 2 deletions sklearn/svm/base.py
Expand Up @@ -7,7 +7,7 @@
from . import libsvm_sparse
from ..base import BaseEstimator, ClassifierMixin
from ..preprocessing import LabelEncoder
from ..utils.multiclass import _ovr_decision_function
from ..utils.multiclass import _ovr_decision_function, _ovr_decision_function_raw
from ..utils import check_array, check_consistent_length, check_random_state
from ..utils import column_or_1d, check_X_y
from ..utils import compute_class_weight
Expand Down Expand Up @@ -550,9 +550,48 @@ def decision_function(self, X):
"""
dec = self._decision_function(X)
if self.decision_function_shape == 'ovr' and len(self.classes_) > 2:
return _ovr_decision_function(dec < 0, -dec, len(self.classes_))
return _ovr_decision_function(dec < 0, -dec, len(self.classes_), scale=self.votes_scale_)
return dec

def fit(self, X, y, sample_weight=None):
"""Fit the SVM model according to the given training data.
Parameters
----------
X : {array-like, sparse matrix}, shape (n_samples, n_features)
Training vectors, where n_samples is the number of samples
and n_features is the number of features.
For kernel="precomputed", the expected shape of X is
(n_samples, n_samples).
y : array-like, shape (n_samples,)
Target values (class labels in classification, real numbers in
regression)
sample_weight : array-like, shape (n_samples,)
Per-sample weights. Rescale C per sample. Higher weights
force the classifier to put more emphasis on these points.
Returns
-------
self : object
Notes
------
If X and y are not C-ordered and contiguous arrays of np.float64 and
X is not a scipy.sparse.csr_matrix, X and/or y may be copied.
If X is a dense array, then the other methods will not support sparse
matrices as input.
"""
super().fit(X, y, sample_weight=sample_weight)
if self.decision_function_shape == 'ovr' and len(self.classes_) > 2:
dec = self._decision_function(X)
self.votes_scale_ = _ovr_decision_function_raw(dec < 0, -dec, len(self.classes_))[2]
else:
self.votes_scale_ = None
return self

def predict(self, X):
"""Perform classification on samples in X.
Expand Down
5 changes: 5 additions & 0 deletions sklearn/svm/classes.py
Expand Up @@ -566,6 +566,11 @@ class SVC(BaseSVC):
where ``probA_`` and ``probB_`` are learned from the dataset [2]_. For
more information on the multiclass case and training procedure see
section 8 of [1]_.
votes_scale_: float
If *decision_function_shape == 'ovr'* and len(self.classes_) > 2,
method *decision_function* rescale the outputs with these constants
estimated on the training dataset.
Examples
--------
Expand Down
12 changes: 12 additions & 0 deletions sklearn/svm/tests/test_svm.py
Expand Up @@ -1009,3 +1009,15 @@ def test_gamma_scale():
# gamma is not explicitly set.
X, y = [[1, 2], [3, 2 * np.sqrt(6) / 3 + 2]], [0, 1]
assert_no_warnings(clf.fit, X, y)


def test_consistent_SVC_output():
model = svm.SVC()
X = iris.data[:, :3]
Y = iris.target
model.fit(X, Y)
X1 = X[:1] * 0.10
X2 = X[1:2]
X1X2 = np.vstack([X1, X2])
assert_array_equal(model.decision_function(X1),
model.decision_function(X1X2)[:1])
33 changes: 32 additions & 1 deletion sklearn/utils/multiclass.py
Expand Up @@ -397,7 +397,7 @@ def class_distribution(y, sample_weight=None):
return (classes, n_classes, class_prior)


def _ovr_decision_function(predictions, confidences, n_classes):
def _ovr_decision_function_raw(predictions, confidences, n_classes):
"""Compute a continuous, tie-breaking OvR decision function from OvO.
It is important to include a continuous value, not only votes,
Expand Down Expand Up @@ -442,4 +442,35 @@ def _ovr_decision_function(predictions, confidences, n_classes):
eps = np.finfo(sum_of_confidences.dtype).eps
max_abs_confidence = max(abs(max_confidences), abs(min_confidences))
scale = (0.5 - eps) / max_abs_confidence
return votes, sum_of_confidences, scale


def _ovr_decision_function(predictions, confidences, n_classes, scale=None):
"""Compute a continuous, tie-breaking OvR decision function from OvO.
It is important to include a continuous value, not only votes,
to make computing AUC or calibration meaningful.
Parameters
----------
predictions : array-like, shape (n_samples, n_classifiers)
Predicted classes for each binary classifier.
confidences : array-like, shape (n_samples, n_classifiers)
Decision functions or predicted probabilities for positive class
for each binary classifier.
n_classes : int
Number of classes. n_classifiers must be
``n_classes * (n_classes - 1 ) / 2``
scale: float
*sum_of_confidences* are rescaled with a number estimated
on the data itself if *scale* is None or a number estimated
on another dataset such a training dataset.
"""
votes, sum_of_confidences, scale_ = _ovr_decision_function_raw(predictions, confidences, n_classes)
if scale is not None:
scale_ = scale
return votes + sum_of_confidences * scale

0 comments on commit f0a84f0

Please sign in to comment.