Skip to content

Commit

Permalink
ENH scipy blas for svm kernel function (scikit-learn#16530)
Browse files Browse the repository at this point in the history
  • Loading branch information
jim0421 authored and viclafargue committed Jun 26, 2020
1 parent 79ec4bb commit 4ec06cd
Show file tree
Hide file tree
Showing 11 changed files with 233 additions and 171 deletions.
8 changes: 8 additions & 0 deletions doc/whats_new/v0.24.rst
Expand Up @@ -97,6 +97,14 @@ Changelog
and by validating data out of loops.
:pr:`17038` by :user:`Wenbo Zhao <webber26232>`.

:mod:`sklearn.svm`
....................
- |Enhancement| Invoke the SciPy BLAS API for the SVM kernel function in ``fit``,
``predict`` and related methods of :class:`svm.SVC`, :class:`svm.NuSVC`,
:class:`svm.SVR`, :class:`svm.NuSVR` and :class:`svm.OneClassSVM`.
:pr:`16530` by :user:`Shuhua Fan <jim0421>`.


Code and Documentation Contributors
-----------------------------------

Expand Down
100 changes: 55 additions & 45 deletions sklearn/ensemble/tests/test_bagging.py
Expand Up @@ -4,9 +4,11 @@

# Author: Gilles Louppe
# License: BSD 3 clause
from itertools import product

import numpy as np
import joblib
import pytest

from sklearn.base import BaseEstimator

Expand All @@ -30,7 +32,7 @@
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_diabetes, load_iris, make_hastie_10_2
from sklearn.utils import check_random_state
from sklearn.preprocessing import FunctionTransformer
from sklearn.preprocessing import FunctionTransformer, scale

from scipy.sparse import csc_matrix, csr_matrix

Expand Down Expand Up @@ -74,7 +76,32 @@ def test_classification():
**params).fit(X_train, y_train).predict(X_test)


def test_sparse_classification():
@pytest.mark.parametrize(
'sparse_format, params, method',
product(
[csc_matrix, csr_matrix],
[{
"max_samples": 0.5,
"max_features": 2,
"bootstrap": True,
"bootstrap_features": True
}, {
"max_samples": 1.0,
"max_features": 4,
"bootstrap": True,
"bootstrap_features": True
}, {
"max_features": 2,
"bootstrap": False,
"bootstrap_features": True
}, {
"max_samples": 0.5,
"bootstrap": True,
"bootstrap_features": False
}],
['predict', 'predict_proba',
'predict_log_proba', 'decision_function']))
def test_sparse_classification(sparse_format, params, method):
# Check classification for various parameter settings on sparse input.

class CustomSVC(SVC):
Expand All @@ -86,52 +113,35 @@ def fit(self, X, y):
return self

rng = check_random_state(0)
X_train, X_test, y_train, y_test = train_test_split(iris.data,
X_train, X_test, y_train, y_test = train_test_split(scale(iris.data),
iris.target,
random_state=rng)
parameter_sets = [
{"max_samples": 0.5,
"max_features": 2,
"bootstrap": True,
"bootstrap_features": True},
{"max_samples": 1.0,
"max_features": 4,
"bootstrap": True,
"bootstrap_features": True},
{"max_features": 2,
"bootstrap": False,
"bootstrap_features": True},
{"max_samples": 0.5,
"bootstrap": True,
"bootstrap_features": False},
]

for sparse_format in [csc_matrix, csr_matrix]:
X_train_sparse = sparse_format(X_train)
X_test_sparse = sparse_format(X_test)
for params in parameter_sets:
for f in ['predict', 'predict_proba', 'predict_log_proba', 'decision_function']:
# Trained on sparse format
sparse_classifier = BaggingClassifier(
base_estimator=CustomSVC(decision_function_shape='ovr'),
random_state=1,
**params
).fit(X_train_sparse, y_train)
sparse_results = getattr(sparse_classifier, f)(X_test_sparse)

# Trained on dense format
dense_classifier = BaggingClassifier(
base_estimator=CustomSVC(decision_function_shape='ovr'),
random_state=1,
**params
).fit(X_train, y_train)
dense_results = getattr(dense_classifier, f)(X_test)
assert_array_almost_equal(sparse_results, dense_results)

sparse_type = type(X_train_sparse)
types = [i.data_type_ for i in sparse_classifier.estimators_]

assert all([t == sparse_type for t in types])
X_train_sparse = sparse_format(X_train)
X_test_sparse = sparse_format(X_test)
# Trained on sparse format
sparse_classifier = BaggingClassifier(
base_estimator=CustomSVC(kernel="linear",
decision_function_shape='ovr'),
random_state=1,
**params
).fit(X_train_sparse, y_train)
sparse_results = getattr(sparse_classifier, method)(X_test_sparse)

# Trained on dense format
dense_classifier = BaggingClassifier(
base_estimator=CustomSVC(kernel="linear",
decision_function_shape='ovr'),
random_state=1,
**params
).fit(X_train, y_train)
dense_results = getattr(dense_classifier, method)(X_test)
assert_array_almost_equal(sparse_results, dense_results)

sparse_type = type(X_train_sparse)
types = [i.data_type_ for i in sparse_classifier.estimators_]

assert all([t == sparse_type for t in types])


def test_regression():
Expand Down
15 changes: 10 additions & 5 deletions sklearn/svm/_libsvm.pxi
@@ -1,5 +1,10 @@
################################################################################
# Includes
cdef extern from "_svm_cython_blas_helpers.h":
ctypedef double (*dot_func)(int, double*, int, double*, int)
cdef struct BlasFunctions:
dot_func dot


cdef extern from "svm.h":
cdef struct svm_node
Expand Down Expand Up @@ -32,9 +37,9 @@ cdef extern from "svm.h":
double *W # instance weights

char *svm_check_parameter(svm_problem *, svm_parameter *)
svm_model *svm_train(svm_problem *, svm_parameter *, int *) nogil
svm_model *svm_train(svm_problem *, svm_parameter *, int *, BlasFunctions *) nogil
void svm_free_and_destroy_model(svm_model** model_ptr_ptr)
void svm_cross_validation(svm_problem *, svm_parameter *, int nr_fold, double *target) nogil
void svm_cross_validation(svm_problem *, svm_parameter *, int nr_fold, double *target, BlasFunctions *) nogil


cdef extern from "libsvm_helper.c":
Expand All @@ -54,9 +59,9 @@ cdef extern from "libsvm_helper.c":
void copy_intercept (char *, svm_model *, np.npy_intp *)
void copy_SV (char *, svm_model *, np.npy_intp *)
int copy_support (char *data, svm_model *model)
int copy_predict (char *, svm_model *, np.npy_intp *, char *) nogil
int copy_predict_proba (char *, svm_model *, np.npy_intp *, char *) nogil
int copy_predict_values(char *, svm_model *, np.npy_intp *, char *, int) nogil
int copy_predict (char *, svm_model *, np.npy_intp *, char *, BlasFunctions *) nogil
int copy_predict_proba (char *, svm_model *, np.npy_intp *, char *, BlasFunctions *) nogil
int copy_predict_values(char *, svm_model *, np.npy_intp *, char *, int, BlasFunctions *) nogil
void copy_nSV (char *, svm_model *)
void copy_probA (char *, svm_model *, np.npy_intp *)
void copy_probB (char *, svm_model *, np.npy_intp *)
Expand Down
24 changes: 16 additions & 8 deletions sklearn/svm/_libsvm.pyx
Expand Up @@ -34,6 +34,7 @@ import warnings
import numpy as np
cimport numpy as np
from libc.stdlib cimport free
from ..utils._cython_blas cimport _dot

include "_libsvm.pxi"

Expand Down Expand Up @@ -189,11 +190,12 @@ def fit(
# for SVR: epsilon is called p in libsvm
error_repl = error_msg.decode('utf-8').replace("p < 0", "epsilon < 0")
raise ValueError(error_repl)

cdef BlasFunctions blas_functions
blas_functions.dot = _dot[double]
# this does the real work
cdef int fit_status = 0
with nogil:
model = svm_train(&problem, &param, &fit_status)
model = svm_train(&problem, &param, &fit_status, &blas_functions)

# from here until the end, we just copy the data returned by
# svm_train
Expand Down Expand Up @@ -352,12 +354,13 @@ def predict(np.ndarray[np.float64_t, ndim=2, mode='c'] X,
model = set_model(&param, <int> nSV.shape[0], SV.data, SV.shape,
support.data, support.shape, sv_coef.strides,
sv_coef.data, intercept.data, nSV.data, probA.data, probB.data)

cdef BlasFunctions blas_functions
blas_functions.dot = _dot[double]
#TODO: use check_model
try:
dec_values = np.empty(X.shape[0])
with nogil:
rv = copy_predict(X.data, model, X.shape, dec_values.data)
rv = copy_predict(X.data, model, X.shape, dec_values.data, &blas_functions)
if rv < 0:
raise MemoryError("We've run out of memory")
finally:
Expand Down Expand Up @@ -457,10 +460,12 @@ def predict_proba(
probA.data, probB.data)

cdef np.npy_intp n_class = get_nr(model)
cdef BlasFunctions blas_functions
blas_functions.dot = _dot[double]
try:
dec_values = np.empty((X.shape[0], n_class), dtype=np.float64)
with nogil:
rv = copy_predict_proba(X.data, model, X.shape, dec_values.data)
rv = copy_predict_proba(X.data, model, X.shape, dec_values.data, &blas_functions)
if rv < 0:
raise MemoryError("We've run out of memory")
finally:
Expand Down Expand Up @@ -561,11 +566,12 @@ def decision_function(
else:
n_class = get_nr(model)
n_class = n_class * (n_class - 1) // 2

cdef BlasFunctions blas_functions
blas_functions.dot = _dot[double]
try:
dec_values = np.empty((X.shape[0], n_class), dtype=np.float64)
with nogil:
rv = copy_predict_values(X.data, model, X.shape, dec_values.data, n_class)
rv = copy_predict_values(X.data, model, X.shape, dec_values.data, n_class, &blas_functions)
if rv < 0:
raise MemoryError("We've run out of memory")
finally:
Expand Down Expand Up @@ -704,10 +710,12 @@ def cross_validation(
raise ValueError(error_msg)

cdef np.ndarray[np.float64_t, ndim=1, mode='c'] target
cdef BlasFunctions blas_functions
blas_functions.dot = _dot[double]
try:
target = np.empty((X.shape[0]), dtype=np.float64)
with nogil:
svm_cross_validation(&problem, &param, n_fold, <double *> target.data)
svm_cross_validation(&problem, &param, n_fold, <double *> target.data, &blas_functions)
finally:
free(problem.x)

Expand Down
40 changes: 27 additions & 13 deletions sklearn/svm/_libsvm_sparse.pyx
Expand Up @@ -3,23 +3,27 @@ import numpy as np
cimport numpy as np
from scipy import sparse
from ..exceptions import ConvergenceWarning

from ..utils._cython_blas cimport _dot
np.import_array()


cdef extern from *:
ctypedef char* const_char_p "const char*"

################################################################################
# Includes

cdef extern from "_svm_cython_blas_helpers.h":
ctypedef double (*dot_func)(int, double*, int, double*, int)
cdef struct BlasFunctions:
dot_func dot

cdef extern from "svm.h":
cdef struct svm_csr_node
cdef struct svm_csr_model
cdef struct svm_parameter
cdef struct svm_csr_problem
char *svm_csr_check_parameter(svm_csr_problem *, svm_parameter *)
svm_csr_model *svm_csr_train(svm_csr_problem *, svm_parameter *, int *) nogil
svm_csr_model *svm_csr_train(svm_csr_problem *, svm_parameter *, int *, BlasFunctions *) nogil
void svm_csr_free_and_destroy_model(svm_csr_model** model_ptr_ptr)

cdef extern from "libsvm_sparse_helper.c":
Expand All @@ -39,18 +43,18 @@ cdef extern from "libsvm_sparse_helper.c":
void copy_sv_coef (char *, svm_csr_model *)
void copy_support (char *, svm_csr_model *)
void copy_intercept (char *, svm_csr_model *, np.npy_intp *)
int copy_predict (char *, svm_csr_model *, np.npy_intp *, char *)
int copy_predict (char *, svm_csr_model *, np.npy_intp *, char *, BlasFunctions *)
int csr_copy_predict_values (np.npy_intp *data_size, char *data, np.npy_intp *index_size,
char *index, np.npy_intp *intptr_size, char *size,
svm_csr_model *model, char *dec_values, int nr_class)
svm_csr_model *model, char *dec_values, int nr_class, BlasFunctions *)
int csr_copy_predict (np.npy_intp *data_size, char *data, np.npy_intp *index_size,
char *index, np.npy_intp *intptr_size, char *size,
svm_csr_model *model, char *dec_values) nogil
svm_csr_model *model, char *dec_values, BlasFunctions *) nogil
int csr_copy_predict_proba (np.npy_intp *data_size, char *data, np.npy_intp *index_size,
char *index, np.npy_intp *intptr_size, char *size,
svm_csr_model *model, char *dec_values) nogil
svm_csr_model *model, char *dec_values, BlasFunctions *) nogil

int copy_predict_values(char *, svm_csr_model *, np.npy_intp *, char *, int)
int copy_predict_values(char *, svm_csr_model *, np.npy_intp *, char *, int, BlasFunctions *)
int csr_copy_SV (char *values, np.npy_intp *n_indices,
char *indices, np.npy_intp *n_indptr, char *indptr,
svm_csr_model *model, int n_features)
Expand Down Expand Up @@ -145,11 +149,12 @@ def libsvm_sparse_train ( int n_features,
free_problem(problem)
free_param(param)
raise ValueError(error_msg)

cdef BlasFunctions blas_functions
blas_functions.dot = _dot[double]
# call svm_csr_train, this does the real work
cdef int fit_status = 0
with nogil:
model = svm_csr_train(problem, param, &fit_status)
model = svm_csr_train(problem, param, &fit_status, &blas_functions)

cdef np.npy_intp SV_len = get_l(model)
cdef np.npy_intp n_class = get_nr(model)
Expand Down Expand Up @@ -275,11 +280,14 @@ def libsvm_sparse_predict (np.ndarray[np.float64_t, ndim=1, mode='c'] T_data,
nSV.data, probA.data, probB.data)
#TODO: use check_model
dec_values = np.empty(T_indptr.shape[0]-1)
cdef BlasFunctions blas_functions
blas_functions.dot = _dot[double]
with nogil:
rv = csr_copy_predict(T_data.shape, T_data.data,
T_indices.shape, T_indices.data,
T_indptr.shape, T_indptr.data,
model, dec_values.data)
model, dec_values.data,
&blas_functions)
if rv < 0:
raise MemoryError("We've run out of memory")
# free model and param
Expand Down Expand Up @@ -331,11 +339,14 @@ def libsvm_sparse_predict_proba(
cdef np.npy_intp n_class = get_nr(model)
cdef int rv
dec_values = np.empty((T_indptr.shape[0]-1, n_class), dtype=np.float64)
cdef BlasFunctions blas_functions
blas_functions.dot = _dot[double]
with nogil:
rv = csr_copy_predict_proba(T_data.shape, T_data.data,
T_indices.shape, T_indices.data,
T_indptr.shape, T_indptr.data,
model, dec_values.data)
model, dec_values.data,
&blas_functions)
if rv < 0:
raise MemoryError("We've run out of memory")
# free model and param
Expand Down Expand Up @@ -397,10 +408,13 @@ def libsvm_sparse_decision_function(
n_class = n_class * (n_class - 1) // 2

dec_values = np.empty((T_indptr.shape[0] - 1, n_class), dtype=np.float64)
cdef BlasFunctions blas_functions
blas_functions.dot = _dot[double]
if csr_copy_predict_values(T_data.shape, T_data.data,
T_indices.shape, T_indices.data,
T_indptr.shape, T_indptr.data,
model, dec_values.data, n_class) < 0:
model, dec_values.data, n_class,
&blas_functions) < 0:
raise MemoryError("We've run out of memory")
# free model and param
free_model_SV(model)
Expand Down
2 changes: 1 addition & 1 deletion sklearn/svm/src/libsvm/LIBSVM_CHANGES
Expand Up @@ -6,5 +6,5 @@ This is here mainly as checklist for incorporation of new versions of libsvm.
* Add random_seed support and call to srand in fit function
* Improved random number generator (fix on windows, enhancement on other
platforms). See <https://github.com/scikit-learn/scikit-learn/pull/13511#issuecomment-481729756>

* Invoke the SciPy BLAS API for the SVM kernel function to improve performance, giving a speedup of 1.5x to 2x (dense data only). See <https://github.com/scikit-learn/scikit-learn/pull/16530>
The changes made with respect to upstream are detailed in the heading of svm.cpp

0 comments on commit 4ec06cd

Please sign in to comment.