Skip to content

Commit

Permalink
FIX: inheritance in DenseBaseSVM
Browse files Browse the repository at this point in the history
Sequel for d079dde, including a test
  • Loading branch information
Fabian Pedregosa committed Feb 1, 2012
1 parent d079dde commit beda3f4
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 19 deletions.
41 changes: 22 additions & 19 deletions sklearn/svm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,18 +323,20 @@ def _dense_predict(self, X):
"the number of features at training time" %
(n_features, self.shape_fit_[1]))

params = self.get_params()
if 'scale_C' in params:
del params['scale_C']
if "sparse" in params:
del params["sparse"]
epsilon = self.epsilon
if epsilon == None:
epsilon = 0.1

svm_type = LIBSVM_IMPL.index(self.impl)
return libsvm.predict(
X, self.support_, self.support_vectors_, self.n_support_,
self.dual_coef_, self.intercept_,
self.label_, self.probA_, self.probB_,
svm_type=svm_type, **params)
svm_type=svm_type,
kernel=self.kernel, C=self.C, nu=self.nu,
probability=self.probability, degree=self.degree,
shrinking=self.shrinking, tol=self.tol, cache_size=self.cache_size,
coef0=self.coef0, gamma=self.gamma, epsilon=epsilon)

def _sparse_predict(self, X):
X = sp.csr_matrix(X, dtype=np.float64)
Expand Down Expand Up @@ -393,18 +395,19 @@ def predict_proba(self, X):
def _dense_predict_proba(self, X):
X = self._compute_kernel(X)

params = self.get_params()
if 'scale_C' in params:
del params['scale_C']
if "sparse" in params:
del params["sparse"]
epsilon = self.epsilon
if epsilon == None:
epsilon = 0.1

svm_type = LIBSVM_IMPL.index(self.impl)
pprob = libsvm.predict_proba(
X, self.support_, self.support_vectors_, self.n_support_,
self.dual_coef_, self.intercept_, self.label_,
self.probA_, self.probB_,
svm_type=svm_type, **params)
svm_type=svm_type, kernel=self.kernel, C=self.C, nu=self.nu,
probability=self.probability, degree=self.degree,
shrinking=self.shrinking, tol=self.tol, cache_size=self.cache_size,
coef0=self.coef0, gamma=self.gamma, epsilon=epsilon)

return pprob

Expand Down Expand Up @@ -478,18 +481,18 @@ def decision_function(self, X):

X = array2d(X, dtype=np.float64, order="C")

params = self.get_params()
if 'scale_C' in params:
del params['scale_C']
if "sparse" in params:
del params["sparse"]

epsilon = self.epsilon
if epsilon == None:
epsilon = 0.1
dec_func = libsvm.decision_function(
X, self.support_, self.support_vectors_, self.n_support_,
self.dual_coef_, self.intercept_, self.label_,
self.probA_, self.probB_,
svm_type=LIBSVM_IMPL.index(self.impl),
**params)
kernel=self.kernel, C=self.C, nu=self.nu,
probability=self.probability, degree=self.degree,
shrinking=self.shrinking, tol=self.tol, cache_size=self.cache_size,
coef0=self.coef0, gamma=self.gamma, epsilon=epsilon)

return dec_func

Expand Down
11 changes: 11 additions & 0 deletions sklearn/svm/tests/test_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,17 @@ def test_immutable_coef_property():
assert_raises(AttributeError, clf.__setattr__, 'coef_', np.arange(3))
assert_raises(RuntimeError, clf.coef_.__setitem__, (0, 0), 0)

def test_inheritance():
# check that SVC classes can do inheritance
class ChildSVC(svm.SVC):
def __init__(self, foo=0):
self.foo = foo
svm.SVC.__init__(self)

clf = ChildSVC()
clf.fit(iris.data, iris.target)
clf.predict(iris.data[-1])
clf.decision_function(iris.data[-1])

if __name__ == '__main__':
import nose
Expand Down

0 comments on commit beda3f4

Please sign in to comment.