/
kernel_ridge.pyx
303 lines (253 loc) · 10.5 KB
/
kernel_ridge.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# distutils: language = c++
from cuml.internals.safe_imports import cpu_only_import
import warnings
from cuml.internals.safe_imports import gpu_only_import_from
from cuml.internals.safe_imports import gpu_only_import
from cupyx import lapack, geterr, seterr
from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.internals.base import Base
from cuml.internals.mixins import RegressorMixin
from cuml.common.doc_utils import generate_docstring
from cuml.common import input_to_cuml_array
from cuml.metrics import pairwise_kernels
cp = gpu_only_import('cupy')
linalg = gpu_only_import_from('cupy', 'linalg')
np = cpu_only_import('numpy')
# cholesky solve with fallback to least squares for singular problems
def _safe_solve(K, y):
    """Solve ``K x = y`` via Cholesky factorization (``posv``), falling back
    to a least-squares solution when ``K`` is (numerically) singular.

    Parameters
    ----------
    K : cupy ndarray of shape (n, n)
        Symmetric positive (semi-)definite left-hand-side matrix.
    y : cupy ndarray of shape (n,) or (n, n_targets)
        Right-hand side(s).

    Returns
    -------
    cupy ndarray
        Solution with the same trailing shape as ``y``.
    """
    # cupy's default linalg error mode silently yields NaNs for singular
    # inputs; switch to "raise" so we can detect failure, and restore the
    # caller's mode in *all* paths (the original code leaked linalg="raise"
    # whenever posv raised, because the restore lived inside the try body).
    err_mode = geterr()["linalg"]
    seterr(linalg="raise")
    try:
        dual_coef = lapack.posv(K, y)
        # Perform following check as a workaround for cusolver issue to be
        # fixed in a future CUDA version: posv can return all-NaN without
        # raising even in "raise" mode.
        if cp.all(cp.isnan(dual_coef)):
            raise np.linalg.LinAlgError
    except np.linalg.LinAlgError:
        warnings.warn(
            "Singular matrix in solving dual problem. Using "
            "least-squares solution instead."
        )
        dual_coef = linalg.lstsq(K, y, rcond=None)[0]
    finally:
        seterr(linalg=err_mode)
    return dual_coef
def _solve_cholesky_kernel(K, y, alpha, sample_weight=None):
    """Solve the kernel ridge dual problem: dual_coef = inv(K + alpha*I) y.

    Parameters
    ----------
    K : array-like of shape (n_samples, n_samples)
        (Possibly precomputed) kernel matrix; copied to float64 internally.
    y : cupy ndarray of shape (n_samples, n_targets)
        Target values.
    alpha : scalar or array of shape (n_targets,)
        Ridge penalty; one value, or one value per target.
    sample_weight : array of shape (n_samples,), optional
        Per-sample weights, applied symmetrically to K and to y.

    Returns
    -------
    cupy ndarray of shape (n_samples, n_targets)
        Dual coefficients.
    """
    n_samples = K.shape[0]
    n_targets = y.shape[1]
    # Work on a float64 copy so the in-place diagonal updates below never
    # touch the caller's kernel matrix.
    K = cp.array(K, dtype=np.float64)
    alpha = cp.atleast_1d(alpha)
    has_sw = sample_weight is not None

    if has_sw:
        # Unlike other solvers, sample_weight must be folded in here
        # because K may be a pre-computed kernel: scale both sides by
        # sqrt(w) so the weighted normal equations are solved.
        sw = cp.sqrt(cp.atleast_1d(sample_weight))
        y = y * sw[:, cp.newaxis]
        K *= cp.outer(sw, sw)

    # Flat-index slice addressing the main diagonal of K.
    diag = slice(None, None, n_samples + 1)

    if alpha.size == 1:
        # A single penalty: every target shares one regularized system,
        # so all right-hand sides are solved in one shot.
        K.flat[diag] += alpha[0]
        dual_coef = _safe_solve(K, y)
        if has_sw:
            dual_coef *= sw[:, cp.newaxis]
        return dual_coef

    # One penalty per target: regularize, solve, then undo the diagonal
    # shift so the next target starts from the unshifted K.
    dual_coefs = cp.empty([n_targets, n_samples], K.dtype)
    for coef_row, target, cur_alpha in zip(dual_coefs, y.T, alpha):
        K.flat[diag] += cur_alpha
        coef_row[:] = _safe_solve(K, target).ravel()
        K.flat[diag] -= cur_alpha
    if has_sw:
        dual_coefs *= sw[cp.newaxis, :]
    return dual_coefs.T
class KernelRidge(Base, RegressorMixin):
    """
    Kernel ridge regression (KRR) performs l2 regularised ridge regression
    using the kernel trick. The kernel trick allows the estimator to learn a
    linear function in the space induced by the kernel. This may be a
    non-linear function in the original feature space (when a non-linear
    kernel is used).

    This estimator supports multi-output regression (when y is 2 dimensional).
    See the sklearn user guide for more information.

    Parameters
    ----------
    alpha : float or array-like of shape (n_targets,), default=1.0
        Regularization strength; must be a positive float. Regularization
        improves the conditioning of the problem and reduces the variance of
        the estimates. Larger values specify stronger regularization.
        If an array is passed, penalties are assumed to be specific
        to the targets.
    kernel : str or callable, default="linear"
        Kernel mapping used internally. This parameter is directly passed to
        :class:`~cuml.metrics.pairwise_kernel`.
        If `kernel` is a string, it must be one of the metrics
        in `cuml.metrics.PAIRWISE_KERNEL_FUNCTIONS` or "precomputed".
        If `kernel` is "precomputed", X is assumed to be a kernel matrix.
        `kernel` may be a callable numba device function. If so, is called on
        each pair of instances (rows) and the resulting value recorded.
    gamma : float, default=None
        Gamma parameter for the RBF, laplacian, polynomial, exponential chi2
        and sigmoid kernels. Interpretation of the default value is left to
        the kernel; see the documentation for sklearn.metrics.pairwise.
        Ignored by other kernels.
    degree : float, default=3
        Degree of the polynomial kernel. Ignored by other kernels.
    coef0 : float, default=1
        Zero coefficient for polynomial and sigmoid kernels.
        Ignored by other kernels.
    kernel_params : mapping of str to any, default=None
        Additional parameters (keyword arguments) for kernel function passed
        as callable object.
    output_type : {'input', 'array', 'dataframe', 'series', 'df_obj', \
        'numba', 'cupy', 'numpy', 'cudf', 'pandas'}, default=None
        Return results and set estimator attributes to the indicated output
        type. If None, the output type set at the module level
        (`cuml.global_settings.output_type`) will be used. See
        :ref:`output-data-type-configuration` for more info.
    handle : cuml.Handle
        Specifies the cuml.handle that holds internal CUDA state for
        computations in this model. Most importantly, this specifies the
        CUDA stream that will be used for the model's computations, so
        users can run different models concurrently in different streams
        by creating handles in several streams.
        If it is None, a new one is created.
    verbose : int or boolean, default=False
        Sets logging level. It must be one of `cuml.common.logger.level_*`.
        See :ref:`verbosity-levels` for more info.

    Attributes
    ----------
    dual_coef_ : ndarray of shape (n_samples,) or (n_samples, n_targets)
        Representation of weight vector(s) in kernel space
    X_fit_ : ndarray of shape (n_samples, n_features)
        Training data, which is also required for prediction. If
        kernel == "precomputed" this is instead the precomputed
        training matrix, of shape (n_samples, n_samples).

    Examples
    --------
    .. code-block:: python

        >>> import cupy as cp
        >>> from cuml.kernel_ridge import KernelRidge
        >>> from numba import cuda
        >>> import math

        >>> n_samples, n_features = 10, 5
        >>> rng = cp.random.RandomState(0)
        >>> y = rng.randn(n_samples)
        >>> X = rng.randn(n_samples, n_features)

        >>> model = KernelRidge(kernel="poly").fit(X, y)
        >>> pred = model.predict(X)

        >>> @cuda.jit(device=True)
        ... def custom_rbf_kernel(x, y, gamma=None):
        ...     if gamma is None:
        ...         gamma = 1.0 / len(x)
        ...     sum = 0.0
        ...     for i in range(len(x)):
        ...         sum += (x[i] - y[i]) ** 2
        ...     return math.exp(-gamma * sum)

        >>> model = KernelRidge(kernel=custom_rbf_kernel,
        ...                     kernel_params={"gamma": 2.0}).fit(X, y)
        >>> pred = model.predict(X)
    """

    # Descriptor that converts dual_coef_ to the estimator's configured
    # output type on attribute access (cuml array-descriptor machinery).
    dual_coef_ = CumlArrayDescriptor()

    def __init__(
        self,
        *,
        alpha=1,
        kernel="linear",
        gamma=None,
        degree=3,
        coef0=1,
        kernel_params=None,
        output_type=None,
        handle=None,
        verbose=False
    ):
        super().__init__(handle=handle, verbose=verbose,
                         output_type=output_type)
        # NOTE(review): alpha is stored as a cupy device array, so
        # get_params() returns an array rather than the scalar the user
        # passed — confirm this is intended for clone()/get_params round-trips.
        self.alpha = cp.asarray(alpha)
        self.kernel = kernel
        self.gamma = gamma
        self.degree = degree
        self.coef0 = coef0
        self.kernel_params = kernel_params

    def get_param_names(self):
        # Hyperparameters exposed to get_params()/set_params(), on top of
        # those declared by Base.
        return super().get_param_names() + [
            "alpha",
            "kernel",
            "gamma",
            "degree",
            "coef0",
            "kernel_params",
        ]

    def _get_kernel(self, X, Y=None):
        # Evaluate the pairwise kernel between X and Y (or X with itself
        # when Y is None). String kernels take gamma/degree/coef0 (extra
        # ones are dropped via filter_params); callable kernels take only
        # the user-supplied kernel_params.
        if isinstance(self.kernel, str):
            params = {"gamma": self.gamma,
                      "degree": self.degree, "coef0": self.coef0}
        else:
            params = self.kernel_params or {}
        return pairwise_kernels(X, Y, metric=self.kernel,
                                filter_params=True, **params)

    @generate_docstring()
    def fit(self, X, y, sample_weight=None,
            convert_dtype=True) -> "KernelRidge":
        # Remember whether y came in 1-D so dual_coef_ can be returned with
        # the matching shape after the 2-D solve.
        ravel = False
        if len(y.shape) == 1:
            y = y.reshape(-1, 1)
            ravel = True
        X_m, n_rows, self.n_cols, self.dtype = input_to_cuml_array(
            X, check_dtype=[np.float32, np.float64]
        )
        # y must match X's dtype and row count; optionally converted.
        y_m, _, _, _ = input_to_cuml_array(
            y,
            check_dtype=self.dtype,
            convert_to_dtype=(self.dtype if convert_dtype else None),
            check_rows=n_rows,
        )
        if self.n_cols < 1:
            msg = "X matrix must have at least a column"
            raise TypeError(msg)
        # Gram matrix of the training data (or X itself if precomputed),
        # then solve the regularized dual system for the coefficients.
        K = self._get_kernel(X_m)
        self.dual_coef_ = _solve_cholesky_kernel(
            K, cp.asarray(y_m), self.alpha, sample_weight
        )

        if ravel:
            self.dual_coef_ = self.dual_coef_.ravel()
        # Training data is retained: predict() needs it to build the
        # cross-kernel between new samples and the fitted samples.
        self.X_fit_ = X_m
        return self

    def predict(self, X):
        """
        Predict using the kernel ridge model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples. If kernel == "precomputed" this is instead a
            precomputed kernel matrix, shape = [n_samples,
            n_samples_fitted], where n_samples_fitted is the number of
            samples used in the fitting for this estimator.

        Returns
        -------
        C : array of shape (n_samples,) or (n_samples, n_targets)
            Returns predicted values.
        """
        X_m, _, _, _ = input_to_cuml_array(
            X, check_dtype=[np.float32, np.float64])
        # Prediction is K(X, X_fit_) @ dual_coef_.
        K = self._get_kernel(X_m, self.X_fit_)
        return cp.dot(cp.asarray(K), cp.asarray(self.dual_coef_))