BUG Fix instability issue of ARDRegression (with speedup) (scikit-lea…
NicolasHug authored and viclafargue committed Jun 26, 2020
1 parent c6e9116 commit 442e319
Showing 5 changed files with 83 additions and 156 deletions.
6 changes: 6 additions & 0 deletions doc/whats_new/v0.23.rst
@@ -303,6 +303,12 @@ Changelog
of strictly inferior for maximum of `absgrad` and `tol` in `utils.optimize._newton_cg`.
:pr:`16266` by :user:`Rushabh Vasani <rushabh-v>`.

- |Fix| |Efficiency| :class:`linear_model.ARDRegression` is more stable and
much faster when `n_samples > n_features`. It can now scale to hundreds of
thousands of samples. The stability fix might imply changes in the number
of non-zero coefficients and in the predicted output. :pr:`16849` by
`Nicolas Hug`_.

- |Enhancement| :class:`linear_model.LassoLars` and
:class:`linear_model.Lars` now support a `jitter` parameter that adds
random noise to the target. This might help with stability in some edge
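The ARDRegression entry above targets the `n_samples > n_features` regime. As a rough standalone sketch (not part of this commit; the shapes and noise level below are arbitrary), a tall problem of this kind can now be fit directly, since only a small matrix is inverted per iteration:

# Illustrative sketch only, not from the diff.
import numpy as np
from sklearn.linear_model import ARDRegression

rng = np.random.RandomState(0)
n_samples, n_features = 100_000, 10
X = rng.randn(n_samples, n_features)
y = X @ rng.randn(n_features) + 0.01 * rng.randn(n_samples)

reg = ARDRegression().fit(X, y)  # inverts at most 10x10 matrices internally
print(reg.coef_.shape)           # (10,)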
118 changes: 0 additions & 118 deletions sklearn/externals/_scipy_linalg.py

This file was deleted.

59 changes: 41 additions & 18 deletions sklearn/linear_model/_bayes.py
@@ -12,7 +12,7 @@
from ._base import LinearModel, _rescale_data
from ..base import RegressorMixin
from ..utils.extmath import fast_logdet
from ..utils.fixes import pinvh
from scipy.linalg import pinvh
from ..utils.validation import _check_sample_weight
from ..utils.validation import _deprecate_positional_args

@@ -554,27 +554,16 @@ def fit(self, X, y):
self.scores_ = list()
coef_old_ = None

# Compute sigma and mu (using Woodbury matrix identity)
def update_sigma(X, alpha_, lambda_, keep_lambda, n_samples):
sigma_ = pinvh(np.eye(n_samples) / alpha_ +
np.dot(X[:, keep_lambda] *
np.reshape(1. / lambda_[keep_lambda], [1, -1]),
X[:, keep_lambda].T))
sigma_ = np.dot(sigma_, X[:, keep_lambda] *
np.reshape(1. / lambda_[keep_lambda], [1, -1]))
sigma_ = - np.dot(np.reshape(1. / lambda_[keep_lambda], [-1, 1]) *
X[:, keep_lambda].T, sigma_)
sigma_.flat[::(sigma_.shape[1] + 1)] += 1. / lambda_[keep_lambda]
return sigma_

def update_coeff(X, y, coef_, alpha_, keep_lambda, sigma_):
coef_[keep_lambda] = alpha_ * np.dot(
sigma_, np.dot(X[:, keep_lambda].T, y))
return coef_

update_sigma = (self._update_sigma if n_samples >= n_features
else self._update_sigma_woodbury)
# Iterative procedure of ARDRegression
for iter_ in range(self.n_iter):
sigma_ = update_sigma(X, alpha_, lambda_, keep_lambda, n_samples)
sigma_ = update_sigma(X, alpha_, lambda_, keep_lambda)
coef_ = update_coeff(X, y, coef_, alpha_, keep_lambda, sigma_)

# Update alpha and lambda
@@ -606,9 +595,15 @@ def update_coeff(X, y, coef_, alpha_, keep_lambda, sigma_):
break
coef_old_ = np.copy(coef_)

# update sigma and mu using updated parameters from the last iteration
sigma_ = update_sigma(X, alpha_, lambda_, keep_lambda, n_samples)
coef_ = update_coeff(X, y, coef_, alpha_, keep_lambda, sigma_)
if not keep_lambda.any():
break

if keep_lambda.any():
# update sigma and mu using updated params from the last iteration
sigma_ = update_sigma(X, alpha_, lambda_, keep_lambda)
coef_ = update_coeff(X, y, coef_, alpha_, keep_lambda, sigma_)
else:
sigma_ = np.array([]).reshape(0, 0)

self.coef_ = coef_
self.alpha_ = alpha_
@@ -617,6 +612,34 @@ def update_coeff(X, y, coef_, alpha_, keep_lambda, sigma_):
self._set_intercept(X_offset_, y_offset_, X_scale_)
return self

def _update_sigma_woodbury(self, X, alpha_, lambda_, keep_lambda):
# See slides as referenced in the docstring note
# this function is used when n_samples < n_features and will invert
# a matrix of shape (n_samples, n_samples) making use of the
# woodbury formula:
# https://en.wikipedia.org/wiki/Woodbury_matrix_identity
n_samples = X.shape[0]
X_keep = X[:, keep_lambda]
inv_lambda = 1 / lambda_[keep_lambda].reshape(1, -1)
sigma_ = pinvh(
np.eye(n_samples) / alpha_ + np.dot(X_keep * inv_lambda, X_keep.T)
)
sigma_ = np.dot(sigma_, X_keep * inv_lambda)
sigma_ = - np.dot(inv_lambda.reshape(-1, 1) * X_keep.T, sigma_)
sigma_[np.diag_indices(sigma_.shape[1])] += 1. / lambda_[keep_lambda]
return sigma_

def _update_sigma(self, X, alpha_, lambda_, keep_lambda):
# See slides as referenced in the docstring note
# this function is used when n_samples >= n_features and will
# invert a matrix of shape (n_features, n_features)
X_keep = X[:, keep_lambda]
gram = np.dot(X_keep.T, X_keep)
eye = np.eye(gram.shape[0])
sigma_inv = lambda_[keep_lambda] * eye + alpha_ * gram
sigma_ = pinvh(sigma_inv)
return sigma_

def predict(self, X, return_std=False):
"""Predict using the linear model.
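Before the test changes below, note how the two new helpers relate: `_update_sigma` inverts the (n_features, n_features) matrix `diag(lambda_) + alpha_ * X.T @ X` directly, while `_update_sigma_woodbury` applies the Woodbury identity so that only an (n_samples, n_samples) matrix is inverted. A minimal standalone numpy sketch (not part of the commit, and ignoring the `keep_lambda` masking) checking that the two forms agree:

# Standalone sketch, not from the diff: compare the direct inversion with
# the Woodbury-based one for the ARD posterior covariance.
import numpy as np
from scipy.linalg import pinvh

rng = np.random.RandomState(0)
n_samples, n_features = 6, 6
X = rng.randn(n_samples, n_features)
alpha = 2.0
lmbda = np.arange(1.0, n_features + 1)
inv_lmbda = 1.0 / lmbda

# Direct form: sigma = (diag(lambda) + alpha * X.T X)^-1
sigma_direct = pinvh(np.diag(lmbda) + alpha * X.T @ X)

# Woodbury form: A^-1 - A^-1 X.T (I/alpha + X A^-1 X.T)^-1 X A^-1, with A = diag(lambda)
K = pinvh(np.eye(n_samples) / alpha + (X * inv_lmbda) @ X.T)
sigma_woodbury = np.diag(inv_lmbda) - (inv_lmbda[:, None] * X.T) @ K @ (X * inv_lmbda)

np.testing.assert_allclose(sigma_direct, sigma_woodbury, rtol=1e-6, atol=1e-8)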
48 changes: 36 additions & 12 deletions sklearn/linear_model/tests/test_bayes.py
@@ -7,6 +7,8 @@

import numpy as np
from scipy.linalg import pinvh
import pytest


from sklearn.utils._testing import assert_array_almost_equal
from sklearn.utils._testing import assert_almost_equal
@@ -159,7 +161,7 @@ def test_std_bayesian_ridge_ard_with_constant_input():
# Test BayesianRidge and ARDRegression standard dev. for edge case of
# constant target vector
# The standard dev. should be relatively small (< 0.01 is tested here)
n_samples = 4
n_samples = 10
n_features = 5
random_state = check_random_state(42)
constant_value = random_state.rand()
@@ -181,9 +183,9 @@ def test_update_of_sigma_in_ard():
y = np.array([0, 0])
clf = ARDRegression(n_iter=1)
clf.fit(X, y)
# With the inputs above, ARDRegression prunes one of the two coefficients
# in the first iteration. Hence, the expected shape of `sigma_` is (1, 1).
assert clf.sigma_.shape == (1, 1)
# With the inputs above, ARDRegression prunes both of the two coefficients
# in the first iteration. Hence, the expected shape of `sigma_` is (0, 0).
assert clf.sigma_.shape == (0, 0)
# Ensure that no error is thrown at prediction stage
clf.predict(X, return_std=True)

@@ -200,22 +202,19 @@ def test_toy_ard_object():
assert_array_almost_equal(clf.predict(test), [1, 3, 4], 2)


def test_ard_accuracy_on_easy_problem():
@pytest.mark.parametrize('seed', range(100))
@pytest.mark.parametrize('n_samples, n_features', ((10, 100), (100, 10)))
def test_ard_accuracy_on_easy_problem(seed, n_samples, n_features):
# Check that ARD converges with reasonable accuracy on an easy problem
# (Github issue #14055)
# This particular seed seems to converge poorly in the failure-case
# (scipy==1.3.0, sklearn==0.21.2)
seed = 45
X = np.random.RandomState(seed=seed).normal(size=(250, 3))
y = X[:, 1]

regressor = ARDRegression(n_iter=600)
regressor = ARDRegression()
regressor.fit(X, y)

abs_coef_error = np.abs(1 - regressor.coef_[1])
# Expect an accuracy of better than 1E-4 in most cases -
# Failure-case produces 0.16!
assert abs_coef_error < 0.01
assert abs_coef_error < 1e-10


def test_return_std():
@@ -248,3 +247,28 @@ def f_noise(X, noise_mult):
m2.fit(X, y)
y_mean2, y_std2 = m2.predict(X_test, return_std=True)
assert_array_almost_equal(y_std2, noise_mult, decimal=decimal)


@pytest.mark.parametrize('seed', range(10))
def test_update_sigma(seed):
# make sure the two update_sigma() helpers are equivalent. The woodbury
# formula is used when n_samples < n_features, and the other one is used
# otherwise.

rng = np.random.RandomState(seed)

# set n_samples == n_features to avoid instability issues when inverting
# the matrices. Using the woodbury formula would be unstable when
# n_samples > n_features
n_samples = n_features = 10
X = rng.randn(n_samples, n_features)
alpha = 1
lmbda = np.arange(1, n_features + 1)
keep_lambda = np.array([True] * n_features)

reg = ARDRegression()

sigma = reg._update_sigma(X, alpha, lmbda, keep_lambda)
sigma_woodbury = reg._update_sigma_woodbury(X, alpha, lmbda, keep_lambda)

np.testing.assert_allclose(sigma, sigma_woodbury)
8 changes: 0 additions & 8 deletions sklearn/utils/fixes.py
@@ -42,14 +42,6 @@ def _parse_version(version_string):
# mypy error: Name 'lobpcg' already defined (possibly by an import)
from ..externals._lobpcg import lobpcg # type: ignore # noqa

if sp_version >= (1, 3):
# Preserves earlier default choice of pinvh cutoff `cond` value.
# Can be removed once issue #14055 is fully addressed.
from ..externals._scipy_linalg import pinvh
else:
# mypy error: Name 'pinvh' already defined (possibly by an import)
from scipy.linalg import pinvh # type: ignore # noqa


def _object_dtype_isnan(X):
return X != X
