Skip to content

Commit

Permalink
FIX Issue 379 and use the opportunity to refactor libsvm code
Browse files Browse the repository at this point in the history
Sparse SVC's would try to use the dense decision_function, since
SparseBaseLibSVM did not override that method. Solution: factor out
non-common code to DenseBaseLibSVM. (@diogojc's test program still
crashes, but with an appropriate error message.)

Also, leverage superclass __init__ in SparseBaseLibSVM
  • Loading branch information
larsmans committed Oct 9, 2011
1 parent b1b2cc1 commit 87c8664
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 74 deletions.
68 changes: 35 additions & 33 deletions sklearn/svm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,41 @@ def __init__(self, impl, kernel, degree, gamma, coef0,
self.shrinking = shrinking
self.probability = probability

def predict_log_proba(self, T):
"""Compute the log likehoods each possible outcomes of samples in T.
The model need to have probability information computed at training
time: fit with attribute `probability` set to True.
Parameters
----------
T : array-like, shape = [n_samples, n_features]
Returns
-------
T : array-like, shape = [n_samples, n_classes]
Returns the log-probabilities of the sample for each class in
the model, where classes are ordered by arithmetical
order.
Notes
-----
The probability model is created using cross validation, so
the results can be slightly different than those obtained by
predict. Also, it will meaningless results on very small
datasets.
"""
return np.log(self.predict_proba(T))

@property
def coef_(self):
if self.kernel != 'linear':
raise NotImplementedError('coef_ is only available when using a '
'linear kernel')
return np.dot(self.dual_coef_, self.support_vectors_)


class DenseBaseLibSVM(BaseLibSVM):
def _compute_kernel(self, X):
"""Return the data transformed by a callable kernel"""
if hasattr(self, 'kernel_function'):
Expand Down Expand Up @@ -230,32 +265,6 @@ def predict_proba(self, X):

return pprob

def predict_log_proba(self, T):
"""Compute the log likehoods each possible outcomes of samples in T.
The model need to have probability information computed at training
time: fit with attribute `probability` set to True.
Parameters
----------
T : array-like, shape = [n_samples, n_features]
Returns
-------
T : array-like, shape = [n_samples, n_classes]
Returns the log-probabilities of the sample for each class in
the model, where classes are ordered by arithmetical
order.
Notes
-----
The probability model is created using cross validation, so
the results can be slightly different than those obtained by
predict. Also, it will meaningless results on very small
datasets.
"""
return np.log(self.predict_proba(T))

def decision_function(self, X):
"""Distance of the samples T to the separating hyperplane.
Expand Down Expand Up @@ -290,13 +299,6 @@ def decision_function(self, X):
else:
return dec_func

@property
def coef_(self):
if self.kernel != 'linear':
raise NotImplementedError('coef_ is only available when using a '
'linear kernel')
return np.dot(self.dual_coef_, self.support_vectors_)


class BaseLibLinear(BaseEstimator):
"""Base for classes binding liblinear (dense and sparse versions)"""
Expand Down
41 changes: 20 additions & 21 deletions sklearn/svm/classes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ..base import ClassifierMixin, RegressorMixin
from ..linear_model.base import CoefSelectTransformerMixin
from .base import BaseLibLinear, BaseLibSVM
from .base import BaseLibLinear, DenseBaseLibSVM


class LinearSVC(BaseLibLinear, ClassifierMixin, CoefSelectTransformerMixin):
Expand Down Expand Up @@ -83,7 +83,7 @@ class LinearSVC(BaseLibLinear, ClassifierMixin, CoefSelectTransformerMixin):
pass


class SVC(BaseLibSVM, ClassifierMixin):
class SVC(DenseBaseLibSVM, ClassifierMixin):
"""C-Support Vector Classification.
Parameters
Expand Down Expand Up @@ -161,11 +161,11 @@ def __init__(self, C=1.0, kernel='rbf', degree=3, gamma=0.0,
coef0=0.0, shrinking=True, probability=False,
tol=1e-3):

BaseLibSVM.__init__(self, 'c_svc', kernel, degree, gamma, coef0,
tol, C, 0., 0., shrinking, probability)
DenseBaseLibSVM.__init__(self, 'c_svc', kernel, degree, gamma, coef0,
tol, C, 0., 0., shrinking, probability)


class NuSVC(BaseLibSVM, ClassifierMixin):
class NuSVC(DenseBaseLibSVM, ClassifierMixin):
"""Nu-Support Vector Classification.
Parameters
Expand Down Expand Up @@ -262,12 +262,11 @@ def __init__(self, nu=0.5, kernel='rbf', degree=3, gamma=0.0,
coef0=0.0, shrinking=True, probability=False,
tol=1e-3):

BaseLibSVM.__init__(self, 'nu_svc', kernel, degree, gamma,
coef0, tol, 0., nu, 0.,
shrinking, probability)
DenseBaseLibSVM.__init__(self, 'nu_svc', kernel, degree, gamma,
coef0, tol, 0., nu, 0., shrinking, probability)


class SVR(BaseLibSVM, RegressorMixin):
class SVR(DenseBaseLibSVM, RegressorMixin):
"""epsilon-Support Vector Regression.
The free parameters in the model are C and epsilon.
Expand Down Expand Up @@ -349,9 +348,9 @@ def __init__(self, kernel='rbf', degree=3, gamma=0.0, coef0=0.0,
tol=1e-3, C=1.0, epsilon=0.1, shrinking=True,
probability=False):

BaseLibSVM.__init__(self, 'epsilon_svr', kernel, degree,
gamma, coef0, tol, C, 0.0,
epsilon, shrinking, probability)
DenseBaseLibSVM.__init__(self, 'epsilon_svr', kernel, degree, gamma,
coef0, tol, C, 0., epsilon, shrinking,
probability)

def fit(self, X, y, sample_weight=None, **params):
"""
Expand All @@ -371,10 +370,11 @@ def fit(self, X, y, sample_weight=None, **params):
Returns self.
"""
# we copy this method because SVR does not accept class_weight
return BaseLibSVM.fit(self, X, y, sample_weight=sample_weight, **params)
return DenseBaseLibSVM.fit(self, X, y, sample_weight=sample_weight,
**params)


class NuSVR(BaseLibSVM, RegressorMixin):
class NuSVR(DenseBaseLibSVM, RegressorMixin):
"""Nu Support Vector Regression.
Similar to NuSVC, for regression, uses a parameter nu to control
Expand Down Expand Up @@ -458,9 +458,8 @@ def __init__(self, nu=0.5, C=1.0, kernel='rbf', degree=3,
gamma=0.0, coef0=0.0, shrinking=True,
probability=False, tol=1e-3):

BaseLibSVM.__init__(self, 'nu_svr', kernel, degree,
gamma, coef0, tol, C, nu,
None, shrinking, probability)
DenseBaseLibSVM.__init__(self, 'nu_svr', kernel, degree, gamma, coef0,
tol, C, nu, None, shrinking, probability)

def fit(self, X, y, sample_weight=None, **params):
"""
Expand All @@ -480,10 +479,10 @@ def fit(self, X, y, sample_weight=None, **params):
Returns self.
"""
# we copy this method because SVR does not accept class_weight
return BaseLibSVM.fit(self, X, y, sample_weight=[], **params)
return DenseBaseLibSVM.fit(self, X, y, sample_weight=[], **params)


class OneClassSVM(BaseLibSVM):
class OneClassSVM(DenseBaseLibSVM):
"""Unsupervised Outliers Detection.
Estimate the support of a high-dimensional distribution.
Expand Down Expand Up @@ -539,8 +538,8 @@ class OneClassSVM(BaseLibSVM):
"""
def __init__(self, kernel='rbf', degree=3, gamma=0.0, coef0=0.0,
tol=1e-3, nu=0.5, shrinking=True):
BaseLibSVM.__init__(self, 'one_class', kernel, degree, gamma, coef0,
tol, 0.0, nu, 0.0, shrinking, False)
DenseBaseLibSVM.__init__(self, 'one_class', kernel, degree, gamma,
coef0, tol, 0., nu, 0., shrinking, False)

def fit(self, X, class_weight={}, sample_weight=None, **params):
"""
Expand Down
27 changes: 7 additions & 20 deletions sklearn/svm/sparse/base.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,24 @@
import numpy as np

from ..base import BaseLibSVM, BaseLibLinear, _get_class_weight
from ..base import BaseLibSVM, BaseLibLinear, LIBSVM_IMPL, _get_class_weight
from . import libsvm
from .. import liblinear


class SparseBaseLibSVM(BaseLibSVM):

_kernel_types = ['linear', 'poly', 'rbf', 'sigmoid', 'precomputed']
_svm_types = ['c_svc', 'nu_svc', 'one_class', 'epsilon_svr', 'nu_svr']

def __init__(self, impl, kernel, degree, gamma, coef0,
tol, C, nu, epsilon, shrinking, probability):

assert impl in self._svm_types, \
"impl should be one of %s, %s was given" % (
self._svm_types, impl)

assert kernel in self._kernel_types, \
"kernel should be one of %s, "\
"%s was given." % (self._kernel_types, kernel)

self.kernel = kernel
self.impl = impl
self.degree = degree
self.gamma = gamma
self.coef0 = coef0
self.tol = tol
self.C = C
self.nu = nu
self.epsilon = epsilon
self.shrinking = shrinking
self.probability = probability
super(SparseBaseLibSVM, self).__init__(impl, kernel, degree, gamma,
coef0, tol, C, nu, epsilon,
shrinking, probability)

# container for when we call fit
self._support_data = np.empty(0, dtype=np.float64, order='C')
Expand Down Expand Up @@ -101,7 +88,7 @@ def fit(self, X, y, class_weight=None, sample_weight=None, cache_size=100.):
"Note: Sparse matrices cannot be indexed w/" +
"boolean masks (use `indices=True` in CV).")

solver_type = self._svm_types.index(self.impl)
solver_type = LIBSVM_IMPL.index(self.impl)
kernel_type = self._kernel_types.index(self.kernel)

self.class_weight, self.class_weight_label = \
Expand Down Expand Up @@ -171,7 +158,7 @@ def predict(self, T):
self.support_vectors_.indices,
self.support_vectors_.indptr,
self.dual_coef_.data, self.intercept_,
self._svm_types.index(self.impl), kernel_type,
LIBSVM_IMPL.index(self.impl), kernel_type,
self.degree, self.gamma, self.coef0, self.tol,
self.C, self.class_weight_label, self.class_weight,
self.nu, self.epsilon, self.shrinking,
Expand Down Expand Up @@ -221,7 +208,7 @@ def predict_proba(self, X):
self.support_vectors_.indices,
self.support_vectors_.indptr,
self.dual_coef_.data, self.intercept_,
self._svm_types.index(self.impl), kernel_type,
LIBSVM_IMPL.index(self.impl), kernel_type,
self.degree, self.gamma, self.coef0, self.tol,
self.C, self.class_weight_label, self.class_weight,
self.nu, self.epsilon, self.shrinking,
Expand Down

7 comments on commit 87c8664

@mblondel
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing this! Why does the test program crash? I thought that it should work.

@larsmans
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It needs to construct the sparse.SVC with probability=True. We could make this the default behavior quite easily, but I don't oversee the consequences apart from an extra O(|C|²) memory allocation.

@mblondel
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that libsvm fits some kind of logistic regression on top of the SVM when probability=True (as described by Platt) so there's an overhead. But I'm curious why OneVsRestClassifier doesn't detect that SVC has a decision_function method. Normally, it falls back to predict_proba only when decision_function is not present.

@larsmans
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

decision_function isn't implemented for sparse SVMs, but in the old inheritance tree, it inherited the one from BaseLibSVM. So, OneVsRestClassifier detected the dense decision_function even for sparse SVMs, which didn't work and raised an exception because of missing self._support.

In the new situation, decision_function is in DenseBaseLibSVM so it's no longer inherited by the sparse SVM classes and predict_proba is used instead.

@mblondel
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok thanks for the clarification. I guess decision_function should be implemented for sparse SVMs as well then. Maybe @fabianp as an opinion on this.

@fabianp
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mblondel: sure, go for it!

@larsmans
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 for a sparse decision_function, but may I request that the new inheritance hierarchy be preserved? It's bigger, but also more robust, I hope.

Please sign in to comment.