-
-
Notifications
You must be signed in to change notification settings - Fork 25.3k
/
base.py
427 lines (351 loc) · 14.1 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
"""
Generalized Linear models.
"""
# Author: Alexandre Gramfort <alexandre.gramfort@inria.fr>
# Fabian Pedregosa <fabian.pedregosa@inria.fr>
# Olivier Grisel <olivier.grisel@ensta.org>
# Vincent Michel <vincent.michel@inria.fr>
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
# Mathieu Blondel <mathieu@mblondel.org>
# Lars Buitinck <L.J.Buitinck@uva.nl>
#
# License: BSD 3 clause
from __future__ import division
from abc import ABCMeta, abstractmethod
import numbers
import warnings
import numpy as np
import scipy.sparse as sp
from scipy import linalg
from scipy import sparse
from ..externals import six
from ..externals.joblib import Parallel, delayed
from ..base import BaseEstimator, ClassifierMixin, RegressorMixin
from ..utils import as_float_array, check_array
from ..utils.extmath import safe_sparse_dot
from ..utils.sparsefuncs import mean_variance_axis, inplace_column_scale
from ..utils.fixes import sparse_lsqr
###
### TODO: intercept for all models
### We should define a common function to center data instead of
### repeating the same code inside each fit method.
### TODO: bayesian_ridge_regression and bayesian_regression_ard
### should be squashed into their respective objects.
def sparse_center_data(X, y, fit_intercept, normalize=False):
    """
    Gather the centering statistics for a sparse design matrix.

    X itself is never shifted to zero mean (that would destroy its
    sparsity); only its column means are reported. When ``normalize`` is
    set, the columns of X are rescaled. y is centered whenever an
    intercept is fitted.

    Parameters
    ----------
    X : sparse matrix
        Design matrix. With ``fit_intercept`` it is converted to a float64
        CSR matrix (if it already is CSR) or CSC matrix (otherwise).
    y : ndarray
        Target values, 1D or 2D.
    fit_intercept : boolean
        If False, X and y are returned untouched together with neutral
        statistics (zero means, unit scales).
    normalize : boolean, optional, default False
        If True (and ``fit_intercept`` is True), scale the columns of a
        copy of X in place.

    Returns
    -------
    X, y, X_mean, y_mean, X_std
    """
    if not fit_intercept:
        # Nothing to estimate: hand back neutral statistics.
        n_features = X.shape[1]
        if y.ndim == 1:
            neutral_y_mean = 0.
        else:
            neutral_y_mean = np.zeros(y.shape[1], dtype=X.dtype)
        return X, y, np.zeros(n_features), neutral_y_mean, np.ones(n_features)

    # Keep CSR input as CSR, turn anything else into CSC. A copy is only
    # taken when the columns are going to be rescaled below. The dtype is
    # forced to float64 because mean_variance_axis accepts it that way.
    already_csr = sp.isspmatrix(X) and X.getformat() == 'csr'
    as_sparse = sp.csr_matrix if already_csr else sp.csc_matrix
    X = as_sparse(X, copy=normalize, dtype=np.float64)

    X_mean, X_var = mean_variance_axis(X, axis=0)
    if not normalize:
        X_std = np.ones(X.shape[1])
    else:
        # XXX: currently scaled to variance=n_samples to match center_data
        X_var *= X.shape[0]
        # the square root overwrites the variance buffer
        X_std = np.sqrt(X_var, X_var)
        del X_var
        # constant columns keep their values instead of dividing by zero
        X_std[X_std == 0] = 1
        inplace_column_scale(X, 1. / X_std)

    y_mean = y.mean(axis=0)
    return X, y - y_mean, X_mean, y_mean, X_std
def center_data(X, y, fit_intercept, normalize=False, copy=True,
                sample_weight=None):
    """
    Center X and y along axis 0, optionally rescaling the columns of X.

    Nearly every linear model wants centered data, hence this shared
    helper. When ``sample_weight`` is an array, it is the *weighted* mean
    of X and y that becomes zero, not the plain mean.

    Sparse X is left as is and reported with zero means and unit scales;
    y is still centered in that case.

    Returns
    -------
    X, y, X_mean, y_mean, X_std
    """
    X = as_float_array(X, copy)

    if not fit_intercept:
        # No intercept: return neutral statistics and leave the data alone.
        X_mean = np.zeros(X.shape[1])
        X_std = np.ones(X.shape[1])
        y_mean = 0. if y.ndim == 1 else np.zeros(y.shape[1], dtype=X.dtype)
        return X, y, X_mean, y_mean, X_std

    # A single scalar weight is treated as "no weights".
    if isinstance(sample_weight, numbers.Number):
        sample_weight = None

    if sp.issparse(X):
        X_mean = np.zeros(X.shape[1])
        X_std = np.ones(X.shape[1])
    else:
        X_mean = np.average(X, axis=0, weights=sample_weight)
        X -= X_mean
        if not normalize:
            X_std = np.ones(X.shape[1])
        else:
            # XXX: currently scaled to variance=n_samples
            X_std = np.sqrt(np.sum(X ** 2, axis=0))
            # constant columns are left unscaled
            X_std[X_std == 0] = 1
            X /= X_std

    y_mean = np.average(y, axis=0, weights=sample_weight)
    return X, y - y_mean, X_mean, y_mean, X_std
class LinearModel(six.with_metaclass(ABCMeta, BaseEstimator)):
    """Abstract base class for linear models.

    Subclasses implement ``fit``; prediction is the affine map
    ``X . coef_.T + intercept_``.
    """

    @abstractmethod
    def fit(self, X, y):
        """Fit the model (to be implemented by subclasses)."""

    def decision_function(self, X):
        """Evaluate the linear model on X.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape = (n_samples, n_features)
            Samples.

        Returns
        -------
        C : array, shape = (n_samples,)
            Returns predicted values.
        """
        X = check_array(X, accept_sparse=['csr', 'csc', 'coo'])
        linear_term = safe_sparse_dot(X, self.coef_.T, dense_output=True)
        return linear_term + self.intercept_

    def predict(self, X):
        """Predict using the linear model.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape = (n_samples, n_features)
            Samples.

        Returns
        -------
        C : array, shape = (n_samples,)
            Returns predicted values.
        """
        return self.decision_function(X)

    # Exposed on the class so subclasses can call self._center_data(...)
    _center_data = staticmethod(center_data)

    def _set_intercept(self, X_mean, y_mean, X_std):
        """Set ``intercept_`` (and rescale ``coef_``) after fitting.

        With an intercept, ``coef_`` is divided by ``X_std`` and the
        intercept is ``y_mean - X_mean . coef_.T``. Without one, the
        intercept is simply ``0.`` and ``coef_`` is untouched.
        """
        if not self.fit_intercept:
            self.intercept_ = 0.
            return
        self.coef_ = self.coef_ / X_std
        self.intercept_ = y_mean - np.dot(X_mean, self.coef_.T)
# XXX Should this derive from LinearModel? It should be a mixin, not an ABC.
# Maybe the n_features checking can be moved to LinearModel.
class LinearClassifierMixin(ClassifierMixin):
    """Mixin for linear classifiers.

    Handles prediction for sparse and dense X. The methods read the
    ``coef_``, ``intercept_`` and ``classes_`` attributes of the instance.
    """

    def decision_function(self, X):
        """Predict confidence scores for samples.

        The confidence score for a sample is the signed distance of that
        sample to the hyperplane.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape = (n_samples, n_features)
            Samples.

        Returns
        -------
        array, shape=(n_samples,) if n_classes == 2 else (n_samples, n_classes)
            Confidence scores per (sample, class) combination. In the binary
            case, confidence score for self.classes_[1] where >0 means this
            class would be predicted.

        Raises
        ------
        ValueError
            If the number of columns of X differs from ``coef_.shape[1]``.
        """
        X = check_array(X, accept_sparse='csr')

        n_features = self.coef_.shape[1]
        if X.shape[1] != n_features:
            raise ValueError("X has %d features per sample; expecting %d"
                             % (X.shape[1], n_features))

        scores = safe_sparse_dot(X, self.coef_.T,
                                 dense_output=True) + self.intercept_
        # a single score column is flattened to a 1D array
        return scores.ravel() if scores.shape[1] == 1 else scores

    def predict(self, X):
        """Predict class labels for samples in X.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape = [n_samples, n_features]
            Samples.

        Returns
        -------
        C : array, shape = [n_samples]
            Predicted class label per sample.
        """
        scores = self.decision_function(X)
        if len(scores.shape) == 1:
            # Builtin int rather than the np.int alias, which was removed
            # from NumPy; the resulting index dtype is the same.
            indices = (scores > 0).astype(int)
        else:
            indices = scores.argmax(axis=1)
        return self.classes_[indices]

    def _predict_proba_lr(self, X):
        """Probability estimation for OvR logistic regression.

        Positive class probabilities are computed as
        1. / (1. + np.exp(-self.decision_function(X)));
        multiclass is handled by normalizing that over all classes.
        """
        # The sigmoid is evaluated in place on the score array.
        prob = self.decision_function(X)
        prob *= -1
        np.exp(prob, prob)
        prob += 1
        np.reciprocal(prob, prob)
        if len(prob.shape) == 1:
            # 1D scores: columns are [1 - p, p]
            return np.vstack([1 - prob, prob]).T
        else:
            # OvR normalization, like LibLinear's predict_probability
            prob /= prob.sum(axis=1).reshape((prob.shape[0], -1))
            return prob
class SparseCoefMixin(object):
    """Mixin that switches ``coef_`` between dense and CSR storage.

    L1-regularizing estimators should inherit this.
    """

    def densify(self):
        """Turn the coefficient matrix (back) into a dense array.

        ``coef_`` is a numpy.ndarray by default, and fitting requires that
        format, so this only needs to be called on a model that was
        sparsified earlier; for an already dense ``coef_`` nothing changes.

        Returns
        -------
        self: estimator
        """
        is_fitted = hasattr(self, "coef_")
        if not is_fitted:
            raise ValueError("Estimator must be fitted before densifying.")
        coefficients = self.coef_
        if sp.issparse(coefficients):
            self.coef_ = coefficients.toarray()
        return self

    def sparsify(self):
        """Turn the coefficient matrix into a scipy.sparse CSR matrix.

        For L1-regularized models this representation can need far less
        memory and storage than the usual numpy.ndarray.

        The ``intercept_`` member is not converted.

        Notes
        -----
        For non-sparse models, i.e. when there are not many zeros in ``coef_``,
        this may actually *increase* memory usage, so use this method with
        care. A rule of thumb is that the number of zero elements, which can
        be computed with ``(coef_ == 0).sum()``, must be more than 50% for this
        to provide significant benefits.

        After calling this method, further fitting with the partial_fit
        method (if any) will not work until you call densify.

        Returns
        -------
        self: estimator
        """
        is_fitted = hasattr(self, "coef_")
        if not is_fitted:
            raise ValueError("Estimator must be fitted before sparsifying.")
        sparse_coefficients = sp.csr_matrix(self.coef_)
        self.coef_ = sparse_coefficients
        return self
class LinearRegression(LinearModel, RegressorMixin):
    """
    Ordinary least squares Linear Regression.

    Parameters
    ----------
    fit_intercept : boolean, optional
        whether to calculate the intercept for this model. If set
        to false, no intercept will be used in calculations
        (e.g. data is expected to be already centered).

    normalize : boolean, optional, default False
        If True, the regressors X will be normalized before regression.

    copy_X : boolean, optional, default True
        If True, X will be copied; else, it may be overwritten.

    n_jobs : int, optional, default 1
        The number of jobs to use for the computation.
        If -1 all CPUs are used. This will only provide speedup for
        n_targets > 1 and sufficiently large problems.

    Attributes
    ----------
    coef_ : array, shape (n_features, ) or (n_targets, n_features)
        Estimated coefficients for the linear regression problem.
        If multiple targets are passed during the fit (y 2D), this
        is a 2D array of shape (n_targets, n_features), while if only
        one target is passed, this is a 1D array of length n_features.

    intercept_ : array
        Independent term in the linear model.

    Notes
    -----
    From the implementation point of view, this is just plain Ordinary
    Least Squares (scipy.linalg.lstsq) wrapped as a predictor object.
    Sparse input is solved with ``sparse_lsqr`` instead, one target at a
    time.
    """

    def __init__(self, fit_intercept=True, normalize=False, copy_X=True,
                 n_jobs=1):
        self.fit_intercept = fit_intercept
        self.normalize = normalize
        self.copy_X = copy_X
        self.n_jobs = n_jobs

    def fit(self, X, y, n_jobs=1):
        """
        Fit linear model.

        Parameters
        ----------
        X : numpy array or sparse matrix of shape [n_samples,n_features]
            Training data

        y : numpy array of shape [n_samples, n_targets]
            Target values

        n_jobs : int, optional, default 1
            Deprecated; use the constructor parameter instead. Any value
            other than 1 emits a DeprecationWarning and overrides
            ``self.n_jobs`` for this call.

        Returns
        -------
        self : returns an instance of self.
        """
        if n_jobs != 1:
            warnings.warn("The n_jobs parameter in fit is deprecated and will "
                          "be removed in 0.17. It has been moved from the fit "
                          "method to the LinearRegression class constructor.",
                          DeprecationWarning, stacklevel=2)
            n_jobs_ = n_jobs
        else:
            n_jobs_ = self.n_jobs

        X = check_array(X, accept_sparse=['csr', 'csc', 'coo'])
        y = np.asarray(y)

        X, y, X_mean, y_mean, X_std = self._center_data(
            X, y, self.fit_intercept, self.normalize, self.copy_X)

        if sp.issparse(X):
            if y.ndim < 2:
                out = sparse_lsqr(X, y)
                self.coef_ = out[0]
                self.residues_ = out[3]
            else:
                # sparse_lstsq cannot handle y with shape (M, K)
                outs = Parallel(n_jobs=n_jobs_)(
                    delayed(sparse_lsqr)(X, y[:, j].ravel())
                    for j in range(y.shape[1]))
                # np.vstack needs a real sequence; handing it a generator
                # is deprecated and rejected by recent NumPy releases.
                self.coef_ = np.vstack([out[0] for out in outs])
                self.residues_ = np.vstack([out[3] for out in outs])
        else:
            self.coef_, self.residues_, self.rank_, self.singular_ = \
                linalg.lstsq(X, y)
            self.coef_ = self.coef_.T

        if y.ndim == 1:
            self.coef_ = np.ravel(self.coef_)
        self._set_intercept(X_mean, y_mean, X_std)
        return self
def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy):
    """Aux function used at beginning of fit in linear models

    Centers (and optionally normalizes) the data, then settles what the
    Gram matrix ``X.T . X`` and the product ``X.T . y`` should be.

    Parameters
    ----------
    X : array or sparse matrix, shape (n_samples, n_features)
    y : array
    Xy : array or None
        Precomputed ``np.dot(X.T, y)``; only kept when a Gram matrix is
        in use and still valid.
    precompute : True | False | 'auto' | array-like
        Whether to use a Gram matrix, or the Gram matrix itself. Forced
        to False for sparse X.
    normalize, fit_intercept, copy
        Forwarded to ``center_data`` / ``sparse_center_data``
        (``copy`` only for dense X).

    Returns
    -------
    X, y, X_mean, y_mean, X_std, precompute, Xy
        ``precompute`` is either False or a Gram matrix on return.
    """
    n_samples, n_features = X.shape

    if sparse.isspmatrix(X):
        precompute = False
        X, y, X_mean, y_mean, X_std = sparse_center_data(
            X, y, fit_intercept, normalize)
    else:
        # copy was done in fit if necessary
        X, y, X_mean, y_mean, X_std = center_data(
            X, y, fit_intercept, normalize, copy=copy)

    if hasattr(precompute, '__array__'):
        # A user-supplied Gram matrix describes X as it was passed in.
        # It is stale as soon as X was shifted OR rescaled above, so
        # either change alone must trigger a recomputation.
        data_unchanged = (np.allclose(X_mean, np.zeros(n_features))
                          and np.allclose(X_std, np.ones(n_features)))
        if not data_unchanged:
            # recompute Gram
            precompute = 'auto'
            Xy = None

    # precompute if n_samples > n_features
    # (only compare against the string when precompute is not an array:
    # array == 'auto' is an elementwise comparison on recent NumPy)
    if not hasattr(precompute, '__array__') and precompute == 'auto':
        precompute = (n_samples > n_features)

    if precompute is True:
        precompute = np.dot(X.T, X)

    if not hasattr(precompute, '__array__'):
        Xy = None  # cannot use Xy if precompute is not Gram

    if hasattr(precompute, '__array__') and Xy is None:
        Xy = np.dot(X.T, y)

    return X, y, X_mean, y_mean, X_std, precompute, Xy