
Commit

Merge branch 'master' of github.com:scikit-learn/scikit-learn
jnothman committed Feb 28, 2018
2 parents bf77164 + d9c2122 commit 3df03e5
Showing 17 changed files with 136 additions and 137 deletions.
11 changes: 8 additions & 3 deletions doc/modules/cross_validation.rst
@@ -182,7 +182,8 @@ The ``cross_validate`` function differs from ``cross_val_score`` in two ways -

- It allows specifying multiple metrics for evaluation.

- It returns a dict containing training scores, fit-times and score-times in
- It returns a dict containing fit-times, score-times
(and optionally training scores as well as fitted estimators) in
addition to the test score.

For single metric evaluation, where the scoring parameter is a string,
@@ -196,6 +197,9 @@ following keys -
for all the scorers. If train scores are not needed, this should be set to
``False`` explicitly.

You may also retain the estimator fitted on each training set by setting
``return_estimator=True``.

The multiple metrics can be specified either as a list, tuple or set of
predefined scorer names::

@@ -226,9 +230,10 @@ Or as a dict mapping scorer name to a predefined or custom scoring function::
Here is an example of ``cross_validate`` using a single metric::

>>> scores = cross_validate(clf, iris.data, iris.target,
... scoring='precision_macro')
... scoring='precision_macro',
... return_estimator=True)
>>> sorted(scores.keys())
['fit_time', 'score_time', 'test_score', 'train_score']
['estimator', 'fit_time', 'score_time', 'test_score', 'train_score']
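
As a usage sketch for the new ``return_estimator`` option (an editor's illustration, not part of this commit), the entries stored under the ``estimator`` key are the estimators fitted on each training split, one per CV fold, and can be inspected afterwards; this assumes the 3-fold default of this release::

    from sklearn import datasets, svm
    from sklearn.model_selection import cross_validate

    iris = datasets.load_iris()
    clf = svm.SVC(kernel='linear', C=1)

    scores = cross_validate(clf, iris.data, iris.target,
                            scoring='precision_macro',
                            return_estimator=True)

    # One fitted clone of clf per split, e.g. inspect the first fold's model.
    first_fold_clf = scores['estimator'][0]
    print(len(scores['estimator']))             # 3 with the default 3-fold CV
    print(first_fold_clf.support_vectors_.shape)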


Obtaining predictions by cross-validation
11 changes: 11 additions & 0 deletions doc/whats_new/v0.20.rst
@@ -50,6 +50,9 @@ Classifiers and regressors
via ``n_iter_no_change``, ``validation_fraction`` and ``tol``. :issue:`7071`
by `Raghav RV`_

- :class:`dummy.DummyRegressor` now has a ``return_std`` option in its
``predict`` method. The returned standard deviations will be zeros.

- Added :class:`naive_bayes.ComplementNB`, which implements the Complement
Naive Bayes classifier described in Rennie et al. (2003).
:issue:`8190` by :user:`Michael A. Alcorn <airalcorn2>`.
@@ -164,6 +167,10 @@ Model evaluation and meta-estimators
group-based CV strategies. :issue:`9085` by :user:`Laurent Direr <ldirer>`
and `Andreas Müller`_.

- Added ``return_estimator`` parameter in :func:`model_selection.cross_validate` to
return estimators fitted on each split. :issue:`9686` by :user:`Aurélien Bellet
<bellet>`.

Decomposition and manifold learning

- Speed improvements for both 'exact' and 'barnes_hut' methods in
@@ -253,6 +260,10 @@ Classifiers and regressors
overridden when using parameter ``copy_X=True`` and ``check_input=False``.
:issue:`10581` by :user:`Yacine Mazari <ymazari>`.

- Fixed a bug in :class:`sklearn.linear_model.Lasso`
where the coefficient had wrong shape when ``fit_intercept=False``.
:issue:`10687` by :user:`Martin Hahn <martin-hahn>`.

Decomposition, manifold learning and clustering

- Fix for uninformative error in :class:`decomposition.IncrementalPCA`:
15 changes: 12 additions & 3 deletions sklearn/dummy.py
@@ -445,7 +445,7 @@ def fit(self, X, y, sample_weight=None):
self.constant_ = np.reshape(self.constant_, (1, -1))
return self

def predict(self, X):
def predict(self, X, return_std=False):
"""
Perform classification on test vectors X.
@@ -454,17 +454,26 @@ def predict(self, X):
X : {array-like, object with finite length or shape}
Training data, requires length = n_samples
return_std : boolean, optional
Whether to return the standard deviation of posterior prediction.
All zeros in this case.
Returns
-------
y : array, shape = [n_samples] or [n_samples, n_outputs]
Predicted target values for X.
y_std : array, shape = [n_samples] or [n_samples, n_outputs]
Standard deviation of predictive distribution of query points.
"""
check_is_fitted(self, "constant_")
n_samples = _num_samples(X)

y = np.ones((n_samples, 1)) * self.constant_
y = np.ones((n_samples, self.n_outputs_)) * self.constant_
y_std = np.zeros((n_samples, self.n_outputs_))

if self.n_outputs_ == 1 and not self.output_2d_:
y = np.ravel(y)
y_std = np.ravel(y_std)

return y
return (y, y_std) if return_std else y
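
A minimal usage sketch of the new ``return_std`` flag (an editor's illustration, not part of the diff)::

    import numpy as np
    from sklearn.dummy import DummyRegressor

    X = np.array([[1.0], [2.0], [3.0]])
    y = np.array([2.0, 3.0, 4.0])

    reg = DummyRegressor(strategy='mean').fit(X, y)
    pred, pred_std = reg.predict(X, return_std=True)

    print(pred)      # the constant mean prediction, here [3. 3. 3.]
    print(pred_std)  # all zeros, as documented above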
28 changes: 16 additions & 12 deletions sklearn/impute.py
@@ -151,8 +151,10 @@ def fit(self, X, y=None):
# transform(X), the imputation data will be computed in transform()
# when the imputation is done per sample (i.e., when axis=1).
if self.axis == 0:
X = check_array(X, accept_sparse='csc', dtype=np.float64,
force_all_finite=False)
X = check_array(X, accept_sparse='csc', dtype=FLOAT_DTYPES,
force_all_finite='allow-nan'
if self.missing_values == 'NaN'
or np.isnan(self.missing_values) else True)

if sparse.issparse(X):
self.statistics_ = self._sparse_fit(X,
@@ -249,7 +251,9 @@ def _sparse_fit(self, X, strategy, missing_values, axis):

def _dense_fit(self, X, strategy, missing_values, axis):
"""Fit the transformer on dense data."""
X = check_array(X, force_all_finite=False)
X = check_array(X, force_all_finite='allow-nan'
if self.missing_values == 'NaN'
or np.isnan(self.missing_values) else True)
mask = _get_mask(X, missing_values)
masked_X = ma.masked_array(X, mask=mask)

@@ -264,12 +268,6 @@ def _dense_fit(self, X, strategy, missing_values, axis):

# Median
elif strategy == "median":
if tuple(int(v) for v in np.__version__.split('.')[:2]) < (1, 5):
# In old versions of numpy, calling a median on an array
# containing nans returns nan. This is different is
# recent versions of numpy, which we want to mimic
masked_X.mask = np.logical_or(masked_X.mask,
np.isnan(X))
median_masked = np.ma.median(masked_X, axis=axis)
# Avoid the warning "Warning: converting a masked element to nan."
median = np.ma.getdata(median_masked)
@@ -309,7 +307,10 @@ def transform(self, X):
if self.axis == 0:
check_is_fitted(self, 'statistics_')
X = check_array(X, accept_sparse='csc', dtype=FLOAT_DTYPES,
force_all_finite=False, copy=self.copy)
force_all_finite='allow-nan'
if self.missing_values == 'NaN'
or np.isnan(self.missing_values) else True,
copy=self.copy)
statistics = self.statistics_
if X.shape[1] != statistics.shape[0]:
raise ValueError("X has %d features per sample, expected %d"
@@ -320,7 +321,10 @@
# when the imputation is done per sample
else:
X = check_array(X, accept_sparse='csr', dtype=FLOAT_DTYPES,
force_all_finite=False, copy=self.copy)
force_all_finite='allow-nan'
if self.missing_values == 'NaN'
or np.isnan(self.missing_values) else True,
copy=self.copy)

if sparse.issparse(X):
statistics = self._sparse_fit(X,
@@ -338,7 +342,7 @@ def transform(self, X):
invalid_mask = np.isnan(statistics)
valid_mask = np.logical_not(invalid_mask)
valid_statistics = statistics[valid_mask]
valid_statistics_indexes = np.where(valid_mask)[0]
valid_statistics_indexes = np.flatnonzero(valid_mask)
missing = np.arange(X.shape[not self.axis])[invalid_mask]

if self.axis == 0 and invalid_mask.any():
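
To show what the ``median`` strategy above now relies on directly (with the old-NumPy workaround removed), here is a small standalone sketch of the masked-array median computation; an editor's illustration, not part of the diff::

    import numpy as np
    import numpy.ma as ma

    X = np.array([[1.0, np.nan],
                  [3.0, 4.0],
                  [np.nan, 6.0]])

    # Stand-in for _get_mask(X, missing_values) with missing_values='NaN'.
    mask = np.isnan(X)
    masked_X = ma.masked_array(X, mask=mask)

    # np.ma.median ignores masked (missing) entries per column.
    median_masked = np.ma.median(masked_X, axis=0)
    median = np.ma.getdata(median_masked)
    print(median)    # [2. 5.]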
6 changes: 5 additions & 1 deletion sklearn/linear_model/coordinate_descent.py
@@ -763,8 +763,12 @@ def fit(self, X, y, check_input=True):

if n_targets == 1:
self.n_iter_ = self.n_iter_[0]
self.coef_ = coef_[0]
self.dual_gap_ = dual_gaps_[0]
else:
self.coef_ = coef_
self.dual_gap_ = dual_gaps_

self.coef_, self.dual_gap_ = map(np.squeeze, [coef_, dual_gaps_])
self._set_intercept(X_offset, y_offset, X_scale)

# workaround since _set_intercept will cast self.coef_ into X.dtype
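
A small sketch of the fixed behaviour (an editor's illustration, mirroring the regression test added later in this commit): with a single target and ``fit_intercept=False``, ``coef_`` keeps its ``(n_features,)`` shape instead of being squeezed to a 0-d array::

    import numpy as np
    from sklearn.linear_model import Lasso

    X = np.c_[np.ones(3)]          # shape (3, 1), a single feature
    y = np.ones(3)

    est = Lasso(fit_intercept=False)
    est.fit(X, y)
    print(est.coef_.shape)         # (1,), not ()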
20 changes: 6 additions & 14 deletions sklearn/linear_model/huber.py
@@ -255,24 +255,16 @@ def fit(self, X, y, sample_weight=None):
bounds = np.tile([-np.inf, np.inf], (parameters.shape[0], 1))
bounds[-1][0] = np.finfo(np.float64).eps * 10

# Type Error caused in old versions of SciPy because of no
# maxiter argument ( <= 0.9).
try:
parameters, f, dict_ = optimize.fmin_l_bfgs_b(
_huber_loss_and_gradient, parameters,
args=(X, y, self.epsilon, self.alpha, sample_weight),
maxiter=self.max_iter, pgtol=self.tol, bounds=bounds,
iprint=0)
except TypeError:
parameters, f, dict_ = optimize.fmin_l_bfgs_b(
_huber_loss_and_gradient, parameters,
args=(X, y, self.epsilon, self.alpha, sample_weight),
bounds=bounds)
parameters, f, dict_ = optimize.fmin_l_bfgs_b(
_huber_loss_and_gradient, parameters,
args=(X, y, self.epsilon, self.alpha, sample_weight),
maxiter=self.max_iter, pgtol=self.tol, bounds=bounds,
iprint=0)
if dict_['warnflag'] == 2:
raise ValueError("HuberRegressor convergence failed:"
" l-BFGS-b solver terminated with %s"
% dict_['task'].decode('ascii'))
self.n_iter_ = dict_.get('nit', None)
self.n_iter_ = dict_['nit']
self.scale_ = parameters[-1]
if self.fit_intercept:
self.intercept_ = parameters[-2]
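
For context (an editor's sketch, not part of the diff): the estimator now calls ``fmin_l_bfgs_b`` unconditionally with ``maxiter`` and reads ``dict_['nit']`` directly, so usage is unchanged on supported SciPy versions::

    import numpy as np
    from sklearn.linear_model import HuberRegressor

    rng = np.random.RandomState(0)
    X = rng.normal(size=(50, 2))
    y = np.dot(X, np.array([1.0, 2.0])) + 0.1 * rng.normal(size=50)
    y[:5] += 10.0                  # a few outliers, downweighted by the Huber loss

    huber = HuberRegressor(max_iter=100, tol=1e-5).fit(X, y)
    print(huber.coef_, huber.intercept_, huber.n_iter_)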
23 changes: 7 additions & 16 deletions sklearn/linear_model/omp.py
@@ -18,8 +18,6 @@
from ..model_selection import check_cv
from ..externals.joblib import Parallel, delayed

solve_triangular_args = {'check_finite': False}

premature = """ Orthogonal matching pursuit ended prematurely due to linear
dependence in the dictionary. The requested precision might not have been met.
"""
@@ -85,12 +83,8 @@ def _cholesky_omp(X, y, n_nonzero_coefs, tol=None, copy_X=True,
indices = np.arange(X.shape[1]) # keeping track of swapping

max_features = X.shape[1] if tol is not None else n_nonzero_coefs
if solve_triangular_args:
# new scipy, don't need to initialize because check_finite=False
L = np.empty((max_features, max_features), dtype=X.dtype)
else:
# old scipy, we need the garbage upper triangle to be non-Inf
L = np.zeros((max_features, max_features), dtype=X.dtype)

L = np.empty((max_features, max_features), dtype=X.dtype)

if return_path:
coefs = np.empty_like(L)
@@ -109,7 +103,7 @@
L[n_active, :n_active],
trans=0, lower=1,
overwrite_b=True,
**solve_triangular_args)
check_finite=False)
v = nrm2(L[n_active, :n_active]) ** 2
Lkk = linalg.norm(X[:, lam]) ** 2 - v
if Lkk <= min_float: # selected atoms are dependent
@@ -212,12 +206,9 @@ def _gram_omp(Gram, Xy, n_nonzero_coefs, tol_0=None, tol=None,
n_active = 0

max_features = len(Gram) if tol is not None else n_nonzero_coefs
if solve_triangular_args:
# new scipy, don't need to initialize because check_finite=False
L = np.empty((max_features, max_features), dtype=Gram.dtype)
else:
# old scipy, we need the garbage upper triangle to be non-Inf
L = np.zeros((max_features, max_features), dtype=Gram.dtype)

L = np.empty((max_features, max_features), dtype=Gram.dtype)

L[0, 0] = 1.
if return_path:
coefs = np.empty_like(L)
@@ -234,7 +225,7 @@
L[n_active, :n_active],
trans=0, lower=1,
overwrite_b=True,
**solve_triangular_args)
check_finite=False)
v = nrm2(L[n_active, :n_active]) ** 2
Lkk = Gram[lam, lam] - v
if Lkk <= min_float: # selected atoms are dependent
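
The call pattern used in both helpers above, shown standalone (an editor's sketch, not part of the diff): a lower-triangular solve with ``check_finite=False``, which is why the work array ``L`` can now be left uninitialized via ``np.empty``::

    import numpy as np
    from scipy import linalg

    L = np.array([[2.0, 0.0],
                  [1.0, 3.0]])     # lower-triangular factor
    b = np.array([4.0, 7.0])

    x = linalg.solve_triangular(L, b, trans=0, lower=True,
                                overwrite_b=True, check_finite=False)
    print(x)                       # [2.         1.66666667]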
17 changes: 6 additions & 11 deletions sklearn/linear_model/ridge.py
@@ -244,8 +244,8 @@ def ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
(possibility to set `tol` and `max_iter`).
- 'lsqr' uses the dedicated regularized least-squares routine
scipy.sparse.linalg.lsqr. It is the fastest but may not be available
in old scipy versions. It also uses an iterative procedure.
scipy.sparse.linalg.lsqr. It is the fastest and uses an iterative
procedure.
- 'sag' uses a Stochastic Average Gradient descent, and 'saga' uses
its improved, unbiased version named SAGA. Both methods also use an
@@ -360,11 +360,6 @@ def ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
else:
solver = 'sparse_cg'

elif solver == 'lsqr' and not hasattr(sp_linalg, 'lsqr'):
warnings.warn("""lsqr not available on this machine, falling back
to sparse_cg.""")
solver = 'sparse_cg'

if has_sw:
if np.atleast_1d(sample_weight).ndim > 1:
raise ValueError("Sample weights must be 1D array or scalar")
@@ -578,8 +573,8 @@ class Ridge(_BaseRidge, RegressorMixin):
(possibility to set `tol` and `max_iter`).
- 'lsqr' uses the dedicated regularized least-squares routine
scipy.sparse.linalg.lsqr. It is the fastest but may not be available
in old scipy versions. It also uses an iterative procedure.
scipy.sparse.linalg.lsqr. It is the fastest and uses an iterative
procedure.
- 'sag' uses a Stochastic Average Gradient descent, and 'saga' uses
its improved, unbiased version named SAGA. Both methods also use an
@@ -736,8 +731,8 @@ class RidgeClassifier(LinearClassifierMixin, _BaseRidge):
(possibility to set `tol` and `max_iter`).
- 'lsqr' uses the dedicated regularized least-squares routine
scipy.sparse.linalg.lsqr. It is the fastest but may not be available
in old scipy versions. It also uses an iterative procedure.
scipy.sparse.linalg.lsqr. It is the fastest and uses an iterative
procedure.
- 'sag' uses a Stochastic Average Gradient descent, and 'saga' uses
its unbiased and more flexible version named SAGA. Both methods
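
A usage sketch (an editor's illustration, not part of the diff): with the fallback removed, ``solver='lsqr'`` is always handed to ``scipy.sparse.linalg.lsqr``::

    import numpy as np
    from sklearn.linear_model import Ridge

    rng = np.random.RandomState(0)
    X = rng.normal(size=(100, 5))
    y = np.dot(X, rng.normal(size=5)) + 0.01 * rng.normal(size=100)

    ridge = Ridge(alpha=1.0, solver='lsqr', tol=1e-4, max_iter=1000)
    ridge.fit(X, y)
    print(ridge.coef_)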
6 changes: 6 additions & 0 deletions sklearn/linear_model/tests/test_coordinate_descent.py
@@ -803,3 +803,9 @@ def test_enet_l1_ratio():
est.fit(X, y[:, None])
est_desired.fit(X, y[:, None])
assert_array_almost_equal(est.coef_, est_desired.coef_, decimal=5)


def test_coef_shape_not_zero():
est_no_intercept = Lasso(fit_intercept=False)
est_no_intercept.fit(np.c_[np.ones(3)], np.ones(3))
assert est_no_intercept.coef_.shape == (1,)
