fixes in constructor chaining in SVC, doctests, etc
amueller committed May 18, 2015
1 parent 855038f commit b2b33d4
Showing 9 changed files with 141 additions and 72 deletions.
7 changes: 4 additions & 3 deletions doc/modules/model_persistence.rst
@@ -22,9 +22,10 @@ persistence model, namely `pickle <http://docs.python.org/library/pickle.html>`_
>>> iris = datasets.load_iris()
>>> X, y = iris.data, iris.target
>>> clf.fit(X, y) # doctest: +NORMALIZE_WHITESPACE
-SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
-kernel='rbf', max_iter=-1, probability=False, random_state=None,
-shrinking=True, tol=0.001, verbose=False)
+SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
+compact_decision_function=None, degree=3, gamma=0.0, kernel='rbf',
+max_iter=-1, probability=False, random_state=None, shrinking=True, tol=0.001,
+verbose=False)

>>> import pickle
>>> s = pickle.dumps(clf)
12 changes: 6 additions & 6 deletions doc/modules/pipeline.rst
@@ -42,9 +42,9 @@ is an estimator object::
>>> clf # doctest: +NORMALIZE_WHITESPACE
Pipeline(steps=[('reduce_dim', PCA(copy=True, n_components=None,
whiten=False)), ('svm', SVC(C=1.0, cache_size=200, class_weight=None,
-coef0=0.0, degree=3, gamma=0.0, kernel='rbf', max_iter=-1,
-probability=False, random_state=None, shrinking=True, tol=0.001,
-verbose=False))])
+coef0=0.0, compact_decision_function=None, degree=3, gamma=0.0,
+kernel='rbf', max_iter=-1, probability=False, random_state=None,
+shrinking=True, tol=0.001, verbose=False))])

The utility function :func:`make_pipeline` is a shorthand
for constructing pipelines;
@@ -76,9 +76,9 @@ Parameters of the estimators in the pipeline can be accessed using the
>>> clf.set_params(svm__C=10) # doctest: +NORMALIZE_WHITESPACE
Pipeline(steps=[('reduce_dim', PCA(copy=True, n_components=None,
whiten=False)), ('svm', SVC(C=10, cache_size=200, class_weight=None,
-coef0=0.0, degree=3, gamma=0.0, kernel='rbf', max_iter=-1,
-probability=False, random_state=None, shrinking=True, tol=0.001,
-verbose=False))])
+coef0=0.0, compact_decision_function=None, degree=3, gamma=0.0,
+kernel='rbf', max_iter=-1, probability=False, random_state=None,
+shrinking=True, tol=0.001, verbose=False))])

This is particularly important for doing grid searches::

21 changes: 12 additions & 9 deletions doc/modules/svm.rst
@@ -76,9 +76,10 @@ n_features]`` holding the training samples, and an array y of class labels
>>> y = [0, 1]
>>> clf = svm.SVC()
>>> clf.fit(X, y) # doctest: +NORMALIZE_WHITESPACE
-SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3,
-gamma=0.0, kernel='rbf', max_iter=-1, probability=False, random_state=None,
-shrinking=True, tol=0.001, verbose=False)
+SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
+compact_decision_function=None, degree=3, gamma=0.0, kernel='rbf',
+max_iter=-1, probability=False, random_state=None, shrinking=True,
+tol=0.001, verbose=False)

After being fitted, the model can then be used to predict new values::

@@ -115,9 +116,10 @@ classifiers are constructed and each one trains data from two classes::
>>> Y = [0, 1, 2, 3]
>>> clf = svm.SVC()
>>> clf.fit(X, Y) # doctest: +NORMALIZE_WHITESPACE
-SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3,
-gamma=0.0, kernel='rbf', max_iter=-1, probability=False, random_state=None,
-shrinking=True, tol=0.001, verbose=False)
+SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
+compact_decision_function=None, degree=3, gamma=0.0, kernel='rbf',
+max_iter=-1, probability=False, random_state=None, shrinking=True,
+tol=0.001, verbose=False)
>>> dec = clf.decision_function([[1]])
>>> dec.shape[1] # 4 classes: 4*3/2 = 6
6
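
Since one-vs-one trains one classifier per pair of classes, the decision function above has ``n_classes * (n_classes - 1) / 2`` columns. The ``compact_decision_function`` option this commit introduces aggregates those pairwise scores into one column per class. A minimal sketch of the contrast, assuming the new flag behaves as described in this commit (hence the skips):

>>> clf = svm.SVC(compact_decision_function=True).fit(X, Y)   # doctest: +SKIP
>>> clf.decision_function([[1]]).shape[1]  # one column per class  # doctest: +SKIP
4
>>> clf = svm.SVC(compact_decision_function=False).fit(X, Y)  # doctest: +SKIP
>>> clf.decision_function([[1]]).shape[1]  # one column per pair  # doctest: +SKIP
6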
@@ -503,9 +505,10 @@ test vectors must be provided.
>>> # linear kernel computation
>>> gram = np.dot(X, X.T)
>>> clf.fit(gram, y) # doctest: +NORMALIZE_WHITESPACE
-SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3,
-gamma=0.0, kernel='precomputed', max_iter=-1, probability=False,
-random_state=None, shrinking=True, tol=0.001, verbose=False)
+SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
+compact_decision_function=None, degree=3, gamma=0.0, kernel='precomputed',
+max_iter=-1, probability=False, random_state=None, shrinking=True,
+tol=0.001, verbose=False)
>>> # predict on training examples
>>> clf.predict(gram)
array([0, 1])
42 changes: 24 additions & 18 deletions doc/tutorial/basic/tutorial.rst
@@ -176,9 +176,10 @@ which produces a new array that contains all but
the last entry of ``digits.data``::

>>> clf.fit(digits.data[:-1], digits.target[:-1]) # doctest: +NORMALIZE_WHITESPACE
-SVC(C=100.0, cache_size=200, class_weight=None, coef0=0.0, degree=3,
-gamma=0.001, kernel='rbf', max_iter=-1, probability=False,
-random_state=None, shrinking=True, tol=0.001, verbose=False)
+SVC(C=100.0, cache_size=200, class_weight=None, coef0=0.0,
+compact_decision_function=None, degree=3, gamma=0.001, kernel='rbf',
+max_iter=-1, probability=False, random_state=None, shrinking=True,
+tol=0.001, verbose=False)

Now you can predict new values, in particular, we can ask to the
classifier what is the digit of our last image in the ``digits`` dataset,
Expand Down Expand Up @@ -214,9 +215,10 @@ persistence model, namely `pickle <http://docs.python.org/library/pickle.html>`_
>>> iris = datasets.load_iris()
>>> X, y = iris.data, iris.target
>>> clf.fit(X, y) # doctest: +NORMALIZE_WHITESPACE
-SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
-kernel='rbf', max_iter=-1, probability=False, random_state=None,
-shrinking=True, tol=0.001, verbose=False)
+SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
+compact_decision_function=None, degree=3, gamma=0.0, kernel='rbf',
+max_iter=-1, probability=False, random_state=None, shrinking=True,
+tol=0.001, verbose=False)

>>> import pickle
>>> s = pickle.dumps(clf)
@@ -288,17 +290,19 @@ maintained::
>>> iris = datasets.load_iris()
>>> clf = SVC()
>>> clf.fit(iris.data, iris.target)
-SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
-kernel='rbf', max_iter=-1, probability=False, random_state=None,
-shrinking=True, tol=0.001, verbose=False)
+SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
+compact_decision_function=None, degree=3, gamma=0.0, kernel='rbf',
+max_iter=-1, probability=False, random_state=None, shrinking=True,
+tol=0.001, verbose=False)

>>> list(clf.predict(iris.data[:3]))
[0, 0, 0]

>>> clf.fit(iris.data, iris.target_names[iris.target])
-SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
-kernel='rbf', max_iter=-1, probability=False, random_state=None,
-shrinking=True, tol=0.001, verbose=False)
+SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
+compact_decision_function=None, degree=3, gamma=0.0, kernel='rbf',
+max_iter=-1, probability=False, random_state=None, shrinking=True,
+tol=0.001, verbose=False)

>>> list(clf.predict(iris.data[:3])) # doctest: +NORMALIZE_WHITESPACE
['setosa', 'setosa', 'setosa']
@@ -325,16 +329,18 @@ more than once will overwrite what was learned by any previous ``fit()``::

>>> clf = SVC()
>>> clf.set_params(kernel='linear').fit(X, y)
-SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
-kernel='linear', max_iter=-1, probability=False, random_state=None,
-shrinking=True, tol=0.001, verbose=False)
+SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
+compact_decision_function=None, degree=3, gamma=0.0, kernel='linear',
+max_iter=-1, probability=False, random_state=None, shrinking=True,
+tol=0.001, verbose=False)
>>> clf.predict(X_test)
array([1, 0, 1, 1, 0])

>>> clf.set_params(kernel='rbf').fit(X, y)
-SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
-kernel='rbf', max_iter=-1, probability=False, random_state=None,
-shrinking=True, tol=0.001, verbose=False)
+SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
+compact_decision_function=None, degree=3, gamma=0.0, kernel='rbf',
+max_iter=-1, probability=False, random_state=None, shrinking=True,
+tol=0.001, verbose=False)
>>> clf.predict(X_test)
array([0, 0, 0, 1, 0])

7 changes: 4 additions & 3 deletions doc/tutorial/statistical_inference/supervised_learning.rst
@@ -453,9 +453,10 @@ classification --:class:`SVC` (Support Vector Classification).
>>> from sklearn import svm
>>> svc = svm.SVC(kernel='linear')
>>> svc.fit(iris_X_train, iris_y_train) # doctest: +NORMALIZE_WHITESPACE
-SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
-kernel='linear', max_iter=-1, probability=False, random_state=None,
-shrinking=True, tol=0.001, verbose=False)
+SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
+compact_decision_function=None, degree=3, gamma=0.0, kernel='linear',
+max_iter=-1, probability=False, random_state=None, shrinking=True,
+tol=0.001, verbose=False)


.. warning:: **Normalizing data**
14 changes: 7 additions & 7 deletions sklearn/multiclass.py
@@ -549,8 +549,8 @@ def decision_function(self, X):
"""
check_is_fitted(self, 'estimators_')

-predictions = [est.predict(X) for est in self.estimators_]
-confidences = [_predict_binary(est, X) for est in self.estimators_]
+predictions = np.vstack([est.predict(X) for est in self.estimators_]).T
+confidences = np.vstack([_predict_binary(est, X) for est in self.estimators_]).T
return _ovr_decision_function(predictions, confidences,
len(self.classes_))

@@ -574,17 +574,17 @@ def _ovr_decision_function(predictions, confidences, n_classes):
Number of classes. n_classifiers must be
``n_classes * (n_classes - 1 ) / 2``
"""
-n_samples = predictions[0].shape[0]
+n_samples = predictions.shape[0]
votes = np.zeros((n_samples, n_classes))
sum_of_confidences = np.zeros((n_samples, n_classes))

k = 0
for i in range(n_classes):
for j in range(i + 1, n_classes):
-sum_of_confidences[:, i] -= confidences[k]
-sum_of_confidences[:, j] += confidences[k]
-votes[predictions[k] == 0, i] += 1
-votes[predictions[k] == 1, j] += 1
+sum_of_confidences[:, i] -= confidences[:, k]
+sum_of_confidences[:, j] += confidences[:, k]
+votes[predictions[:, k] == 0, i] += 1
+votes[predictions[:, k] == 1, j] += 1
k += 1

max_confidences = sum_of_confidences.max()
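
The vectorized ``_ovr_decision_function`` above turns the stacked one-vs-one outputs, shape ``(n_samples, n_pairs)``, into one score per class: each pair (i, j) casts a vote, and the signed confidences break ties between classes with equal vote counts. A self-contained sketch of the scheme (the tie-breaking scale below is illustrative; the actual normalization sits in the part of the file not shown here):

import numpy as np

def ovr_from_ovo(predictions, confidences, n_classes):
    # predictions, confidences: shape (n_samples, n_pairs), one column
    # per (i, j) pair, ordered with i < j as in the nested loop below.
    n_samples = predictions.shape[0]
    votes = np.zeros((n_samples, n_classes))
    sum_of_confidences = np.zeros((n_samples, n_classes))
    k = 0
    for i in range(n_classes):
        for j in range(i + 1, n_classes):
            # label 0 means class i won the pair, label 1 means class j;
            # a positive confidence likewise favors class j.
            sum_of_confidences[:, i] -= confidences[:, k]
            sum_of_confidences[:, j] += confidences[:, k]
            votes[predictions[:, k] == 0, i] += 1
            votes[predictions[:, k] == 1, j] += 1
            k += 1
    # Illustrative tie-break: shrink confidences into (-1/3, 1/3) so they
    # can reorder classes with equal votes but never overturn a full vote.
    max_abs = np.abs(sum_of_confidences).max()
    if max_abs == 0:
        return votes
    return votes + sum_of_confidences / (3. * max_abs)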
23 changes: 20 additions & 3 deletions sklearn/svm/base.py
@@ -9,6 +9,7 @@
from . import libsvm_sparse
from ..base import BaseEstimator, ClassifierMixin, ChangedBehaviorWarning
from ..preprocessing import LabelEncoder
+from ..multiclass import _ovr_decision_function
from ..utils import check_array, check_random_state, column_or_1d
from ..utils import ConvergenceWarning, compute_class_weight, deprecated
from ..utils.extmath import safe_sparse_dot
@@ -481,8 +482,19 @@ def _get_coef(self):
return safe_sparse_dot(self._dual_coef_, self.support_vectors_)


-class BaseSVC(BaseLibSVM, ClassifierMixin):
+class BaseSVC(six.with_metaclass(ABCMeta, BaseLibSVM, ClassifierMixin)):
"""ABC for LibSVM-based classifiers."""
+@abstractmethod
+def __init__(self, impl, kernel, degree, gamma, coef0, tol, C, nu,
+shrinking, probability, cache_size, class_weight, verbose,
+max_iter, compact_decision_function, random_state):
+self.compact_decision_function = compact_decision_function
+super(BaseSVC, self).__init__(
+impl=impl, kernel=kernel, degree=degree, gamma=gamma, coef0=coef0,
+tol=tol, C=C, nu=nu, epsilon=0., shrinking=shrinking,
+probability=probability, cache_size=cache_size,
+class_weight=class_weight, verbose=verbose, max_iter=max_iter,
+random_state=random_state)

def _validate_targets(self, y):
y_ = column_or_1d(y, warn=True)
@@ -506,16 +518,21 @@ def decision_function(self, X):
Returns
-------
-X : array-like, shape (n_samples, n_class * (n_class-1) / 2)
+X : array-like, shape (n_samples, n_classes * (n_classes-1) / 2)
Returns the decision function of the sample for each class
in the model.
+If compact_decision_function=True, the shape is (n_samples,
+n_classes)
"""
+if self.compact_decision_function is None:
+warnings.warn("The compact_decision_function default value will "
+"change from False to True in 0.19. This will change "
+"the shape of the decision function returned by "
+"SVC.", ChangedBehaviorWarning)
-return self._decision_function(X)
+dec = self._decision_function(X)
+if self.compact_decision_function:
+return _ovr_decision_function(dec < 0, dec, len(self.classes_))
+return dec

def predict(self, X):
"""Perform classification on samples in X.
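
The explicit keyword chaining added above is the "constructor chaining" fix named in the commit message: the old subclasses called ``super().__init__`` with long positional argument lists, so inserting a parameter such as ``compact_decision_function`` into the base signature silently rebound every argument after it. A toy illustration of the failure mode (not sklearn code):

class Base(object):
    def __init__(self, C=1.0, compact_decision_function=None,
                 random_state=None):
        self.C = C
        self.compact_decision_function = compact_decision_function
        self.random_state = random_state

class PositionalChild(Base):
    def __init__(self, C=1.0, random_state=None):
        # Written against the old two-parameter signature: random_state
        # now lands in the compact_decision_function slot, silently.
        super(PositionalChild, self).__init__(C, random_state)

class KeywordChild(Base):
    def __init__(self, C=1.0, random_state=None):
        # Keyword chaining stays correct (or fails loudly) when the base
        # signature grows.
        super(KeywordChild, self).__init__(C=C, random_state=random_state)

# PositionalChild(random_state=42).compact_decision_function == 42   (bug)
# KeywordChild(random_state=42).compact_decision_function is None    (fixed)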
49 changes: 31 additions & 18 deletions sklearn/svm/classes.py
@@ -487,9 +487,10 @@ class frequencies.
>>> from sklearn.svm import SVC
>>> clf = SVC()
>>> clf.fit(X, y) #doctest: +NORMALIZE_WHITESPACE
-SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3,
-gamma=0.0, kernel='rbf', max_iter=-1, probability=False,
-random_state=None, shrinking=True, tol=0.001, verbose=False)
+SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
+compact_decision_function=None, degree=3, gamma=0.0, kernel='rbf',
+max_iter=-1, probability=False, random_state=None, shrinking=True,
+tol=0.001, verbose=False)
>>> print(clf.predict([[-0.8, -1]]))
[1]
@@ -511,9 +512,12 @@ def __init__(self, C=1.0, kernel='rbf', degree=3, gamma=0.0, coef0=0.0,
compact_decision_function=None, random_state=None):

super(SVC, self).__init__(
-'c_svc', kernel, degree, gamma, coef0, tol, C, 0., 0., shrinking,
-probability, cache_size, class_weight, verbose, max_iter,
-compact_decision_function, random_state)
+impl='c_svc', kernel=kernel, degree=degree, gamma=gamma, coef0=coef0,
+tol=tol, C=C, nu=0., shrinking=shrinking,
+probability=probability, cache_size=cache_size,
+class_weight=class_weight, verbose=verbose, max_iter=max_iter,
+compact_decision_function=compact_decision_function,
+random_state=random_state)


class NuSVC(BaseSVC):
@@ -563,6 +567,13 @@ class NuSVC(BaseSVC):
cache_size : float, optional
Specify the size of the kernel cache (in MB).
+class_weight : {dict, 'auto'}, optional
+Set the parameter C of class i to class_weight[i]*C for
+SVC. If not given, all classes are supposed to have
+weight one. The 'auto' mode uses the values of y to
+automatically adjust weights inversely proportional to
+class frequencies.
verbose : bool, default: False
Enable verbose output. Note that this setting takes advantage of a
per-process runtime setting in libsvm that, if enabled, may not work
@@ -616,7 +627,8 @@ class NuSVC(BaseSVC):
>>> from sklearn.svm import NuSVC
>>> clf = NuSVC()
>>> clf.fit(X, y) #doctest: +NORMALIZE_WHITESPACE
-NuSVC(cache_size=200, coef0=0.0, degree=3, gamma=0.0, kernel='rbf',
+NuSVC(cache_size=200, class_weight=None, coef0=0.0,
+compact_decision_function=None, degree=3, gamma=0.0, kernel='rbf',
max_iter=-1, nu=0.5, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False)
>>> print(clf.predict([[-0.8, -1]]))
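
``class_weight`` is new to ``NuSVC`` in this commit. A short sketch of the inverse-frequency idea behind the 'auto' mode described in the docstring above (the exact normalization scikit-learn applies may differ):

import numpy as np

def inverse_frequency_weights(y):
    # Weight each class inversely to its frequency so that rare classes
    # contribute roughly as much total weight as common ones.
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts.astype(float))
    return dict(zip(classes, weights))

# inverse_frequency_weights([0, 0, 0, 1]) -> {0: 0.666..., 1: 2.0}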
@@ -632,15 +644,18 @@ class NuSVC(BaseSVC):
liblinear.
"""

-def __init__(self, nu=0.5, kernel='rbf', degree=3, gamma=0.0,
-coef0=0.0, shrinking=True, probability=False,
-tol=1e-3, cache_size=200, verbose=False, max_iter=-1,
+def __init__(self, nu=0.5, kernel='rbf', degree=3, gamma=0.0, coef0=0.0,
+shrinking=True, probability=False, tol=1e-3, cache_size=200,
+class_weight=None, verbose=False, max_iter=-1,
compact_decision_function=None, random_state=None):

super(NuSVC, self).__init__(
-'nu_svc', kernel, degree, gamma, coef0, tol, 0., nu, 0., shrinking,
-probability, cache_size, None, verbose, max_iter,
-compact_decision_function, random_state)
+impl='nu_svc', kernel=kernel, degree=degree, gamma=gamma,
+coef0=coef0, tol=tol, C=0., nu=nu, shrinking=shrinking,
+probability=probability, cache_size=cache_size,
+class_weight=class_weight, verbose=verbose, max_iter=max_iter,
+compact_decision_function=compact_decision_function,
+random_state=random_state)


class SVR(BaseLibSVM, RegressorMixin):
@@ -749,8 +764,7 @@ def __init__(self, kernel='rbf', degree=3, gamma=0.0, coef0=0.0, tol=1e-3,
'epsilon_svr', kernel=kernel, degree=degree, gamma=gamma,
coef0=coef0, tol=tol, C=C, nu=0., epsilon=epsilon, verbose=verbose,
shrinking=shrinking, probability=False, cache_size=cache_size,
-class_weight=None, max_iter=max_iter,
-compact_decision_function=False, random_state=None)
+class_weight=None, max_iter=max_iter, random_state=None)


class NuSVR(BaseLibSVM, RegressorMixin):
@@ -860,8 +874,7 @@ def __init__(self, nu=0.5, C=1.0, kernel='rbf', degree=3,
'nu_svr', kernel=kernel, degree=degree, gamma=gamma, coef0=coef0,
tol=tol, C=C, nu=nu, epsilon=0., shrinking=shrinking,
probability=False, cache_size=cache_size, class_weight=None,
-verbose=verbose, max_iter=max_iter, compact_decision_function=False,
-random_state=None)
+verbose=verbose, max_iter=max_iter, random_state=None)


class OneClassSVM(BaseLibSVM):
@@ -948,7 +961,7 @@ def __init__(self, kernel='rbf', degree=3, gamma=0.0, coef0=0.0, tol=1e-3,
super(OneClassSVM, self).__init__(
'one_class', kernel, degree, gamma, coef0, tol, 0., nu, 0.,
shrinking, False, cache_size, None, verbose, max_iter,
-False, random_state)
+random_state)

def fit(self, X, y=None, sample_weight=None, **params):
"""
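
Note that ``compact_decision_function`` defaults to ``None`` rather than ``False`` so that ``decision_function`` (see ``sklearn/svm/base.py`` above) can warn users who never set it that the default flips to ``True`` in 0.19, while an explicit ``False`` stays silent. A minimal sketch of that deprecation pattern, with illustrative names:

import warnings

def decision_shape(compact_decision_function=None):
    # None means "the user did not choose": warn, then fall back to the
    # current behavior until the default actually changes.
    if compact_decision_function is None:
        warnings.warn("The compact_decision_function default value will "
                      "change from False to True in 0.19.")
        compact_decision_function = False
    return "ovr" if compact_decision_function else "ovo"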