Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

FIX : callable kernel for prediction

  • Loading branch information...
commit e00044c4e8d05b5d2f96439ab3336b37b3350134 1 parent a13138a
Alexandre Gramfort agramfort authored
Showing with 15 additions and 3 deletions.
  1. +11 −2 sklearn/svm/base.py
  2. +4 −1 sklearn/svm/tests/test_svm.py
13 sklearn/svm/base.py
View
@@ -441,12 +441,16 @@ def _dense_predict_proba(self, X):
C = 0.0 # C is not useful here
+ kernel = self.kernel
+ if hasattr(kernel, '__call__'):
+ kernel = 'precomputed'
+
svm_type = LIBSVM_IMPL.index(self.impl)
pprob = libsvm.predict_proba(
X, self.support_, self.support_vectors_, self.n_support_,
self.dual_coef_, self._intercept_, self.label_,
self.probA_, self.probB_,
- svm_type=svm_type, kernel=self.kernel, C=C, nu=self.nu,
+ svm_type=svm_type, kernel=kernel, C=C, nu=self.nu,
probability=self.probability, degree=self.degree,
shrinking=self.shrinking, tol=self.tol, cache_size=self.cache_size,
coef0=self.coef0, gamma=self.gamma, epsilon=epsilon)
@@ -528,12 +532,17 @@ def decision_function(self, X):
epsilon = self.epsilon
if epsilon == None:
epsilon = 0.1
+
+ kernel = self.kernel
+ if hasattr(kernel, '__call__'):
+ kernel = 'precomputed'
+
dec_func = libsvm.decision_function(
X, self.support_, self.support_vectors_, self.n_support_,
self.dual_coef_, self._intercept_, self.label_,
self.probA_, self.probB_,
svm_type=LIBSVM_IMPL.index(self.impl),
- kernel=self.kernel, C=C, nu=self.nu,
+ kernel=kernel, C=C, nu=self.nu,
probability=self.probability, degree=self.degree,
shrinking=self.shrinking, tol=self.tol, cache_size=self.cache_size,
coef0=self.coef0, gamma=self.gamma, epsilon=epsilon)
5 sklearn/svm/tests/test_svm.py
View
@@ -662,9 +662,12 @@ def test_linearsvc_verbose():
def test_svc_pickle_with_callable_kernel():
- a = svm.SVC(kernel=lambda x, y: np.dot(x, y.T))
+ a = svm.SVC(kernel=lambda x, y: np.dot(x, y.T), probability=True)
b = base.clone(a)
b.fit(X, Y)
+ b.predict(X)
+ b.predict_proba(X)
+ b.decision_function(X)
if __name__ == '__main__':
Please sign in to comment.
Something went wrong with that request. Please try again.