Skip to content

Commit

Permalink
Merge pull request #3178 from MechCoder/Iss3174
Browse files Browse the repository at this point in the history
[MRG+1] ENetCV and LassoCV now accept np.float32 input
  • Loading branch information
arjoly committed May 22, 2014
2 parents 7c9a353 + a015c4a commit d9cc662
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
6 changes: 4 additions & 2 deletions sklearn/linear_model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@ def sparse_center_data(X, y, fit_intercept, normalize=False):
if fit_intercept:
# we might require not to change the csr matrix sometimes
# store a copy if normalize is True.
# Change dtype to float64 since mean_variance_axis0 accepts
# it that way.
if sp.isspmatrix(X) and X.getformat() == 'csr':
X = sp.csr_matrix(X, copy=normalize)
X = sp.csr_matrix(X, copy=normalize, dtype=np.float64)
else:
X = sp.csc_matrix(X, copy=normalize)
X = sp.csc_matrix(X, copy=normalize, dtype=np.float64)

X_mean, X_var = mean_variance_axis0(X)
if normalize:
Expand Down
19 changes: 18 additions & 1 deletion sklearn/linear_model/tests/test_coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sys import version_info

import numpy as np
from scipy import interpolate
from scipy import interpolate, sparse

from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_almost_equal
Expand Down Expand Up @@ -443,6 +443,23 @@ def test_1d_multioutput_lasso_and_multitask_lasso_cv():
assert_almost_equal(clf.intercept_, clf1.intercept_[0])


def test_sparse_input_dtype_enet_and_lassocv():
X, y, _, _ = build_dataset(n_features=10)
clf = ElasticNetCV(n_alphas=5)
clf.fit(sparse.csr_matrix(X), y)
clf1 = ElasticNetCV(n_alphas=5)
clf1.fit(sparse.csr_matrix(X, dtype=np.float32), y)
assert_almost_equal(clf.alpha_, clf1.alpha_, decimal=6)
assert_almost_equal(clf.coef_, clf1.coef_, decimal=6)

clf = LassoCV(n_alphas=5)
clf.fit(sparse.csr_matrix(X), y)
clf1 = LassoCV(n_alphas=5)
clf1.fit(sparse.csr_matrix(X, dtype=np.float32), y)
assert_almost_equal(clf.alpha_, clf1.alpha_, decimal=6)
assert_almost_equal(clf.coef_, clf1.coef_, decimal=6)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit d9cc662

Please sign in to comment.