Skip to content
This repository
Browse code

FIX : callable kernel for prediction

  • Loading branch information...
commit e00044c4e8d05b5d2f96439ab3336b37b3350134 1 parent a13138a
Alexandre Gramfort agramfort authored

Showing 2 changed files with 15 additions and 3 deletions. Show diff stats Hide diff stats

  1. +11 2 sklearn/svm/base.py
  2. +4 1 sklearn/svm/tests/test_svm.py
13 sklearn/svm/base.py
@@ -441,12 +441,16 @@ def _dense_predict_proba(self, X):
441 441
442 442 C = 0.0 # C is not useful here
443 443
  444 + kernel = self.kernel
  445 + if hasattr(kernel, '__call__'):
  446 + kernel = 'precomputed'
  447 +
444 448 svm_type = LIBSVM_IMPL.index(self.impl)
445 449 pprob = libsvm.predict_proba(
446 450 X, self.support_, self.support_vectors_, self.n_support_,
447 451 self.dual_coef_, self._intercept_, self.label_,
448 452 self.probA_, self.probB_,
449   - svm_type=svm_type, kernel=self.kernel, C=C, nu=self.nu,
  453 + svm_type=svm_type, kernel=kernel, C=C, nu=self.nu,
450 454 probability=self.probability, degree=self.degree,
451 455 shrinking=self.shrinking, tol=self.tol, cache_size=self.cache_size,
452 456 coef0=self.coef0, gamma=self.gamma, epsilon=epsilon)
@@ -528,12 +532,17 @@ def decision_function(self, X):
528 532 epsilon = self.epsilon
529 533 if epsilon == None:
530 534 epsilon = 0.1
  535 +
  536 + kernel = self.kernel
  537 + if hasattr(kernel, '__call__'):
  538 + kernel = 'precomputed'
  539 +
531 540 dec_func = libsvm.decision_function(
532 541 X, self.support_, self.support_vectors_, self.n_support_,
533 542 self.dual_coef_, self._intercept_, self.label_,
534 543 self.probA_, self.probB_,
535 544 svm_type=LIBSVM_IMPL.index(self.impl),
536   - kernel=self.kernel, C=C, nu=self.nu,
  545 + kernel=kernel, C=C, nu=self.nu,
537 546 probability=self.probability, degree=self.degree,
538 547 shrinking=self.shrinking, tol=self.tol, cache_size=self.cache_size,
539 548 coef0=self.coef0, gamma=self.gamma, epsilon=epsilon)
5 sklearn/svm/tests/test_svm.py
@@ -662,9 +662,12 @@ def test_linearsvc_verbose():
662 662
663 663
664 664 def test_svc_pickle_with_callable_kernel():
665   - a = svm.SVC(kernel=lambda x, y: np.dot(x, y.T))
  665 + a = svm.SVC(kernel=lambda x, y: np.dot(x, y.T), probability=True)
666 666 b = base.clone(a)
667 667 b.fit(X, Y)
  668 + b.predict(X)
  669 + b.predict_proba(X)
  670 + b.decision_function(X)
668 671
669 672
670 673 if __name__ == '__main__':

0 comments on commit e00044c

Please sign in to comment.
Something went wrong with that request. Please try again.