FIX Use cho_solve when return_std=True for GaussianProcessRegressor (#…
iwhalvic authored and glemaitre committed Apr 28, 2021
1 parent 84969bb commit e25c9b1
Showing 3 changed files with 50 additions and 41 deletions.
9 changes: 8 additions & 1 deletion doc/whats_new/v0.24.rst
@@ -45,6 +45,13 @@ Changelog
:mod:`sklearn.gaussian_process`
...............................

+- |Fix| Avoid explicitly forming the inverse covariance matrix in
+  :class:`gaussian_process.GaussianProcessRegressor` when returning the
+  standard deviation. For certain covariance matrices this inverse is
+  numerically unstable to compute explicitly, and solving through the
+  Cholesky factorization instead mitigates the issue.
+  :pr:`19939` by :user:`Ian Halvic <iwhalvic>`.

- |Fix| Avoid division by zero when scaling constant target in
  :class:`gaussian_process.GaussianProcessRegressor`. It was due to a std. dev.
  equal to 0. Now, such a case is detected and the std. dev. is set to 1
@@ -59,7 +66,7 @@ Changelog
- |Fix|: Fixed a bug in :class:`linear_model.LogisticRegression`: the
sample_weight object is not modified anymore. :pr:`19182` by
:user:`Yosuke KOBAYASHI <m7142yosuke>`.

:mod:`sklearn.metrics`
......................

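To make the first |Fix| entry above concrete, here is a small self-contained sketch, not taken from the commit, contrasting the two strategies on a deliberately ill-conditioned matrix: explicitly inverting K from its Cholesky factor with solve_triangular, as the old code did, versus calling cho_solve directly, as the fix does. The matrix, sizes, and jitter are illustrative only.

    import numpy as np
    from scipy.linalg import cholesky, cho_solve, solve_triangular

    rng = np.random.RandomState(0)

    # Deliberately ill-conditioned "covariance": nearly rank one plus a tiny jitter.
    u = rng.randn(50, 1)
    K = 1e6 * (u @ u.T) + 1e-5 * np.eye(50)

    # A smooth right-hand side, similar in spirit to the kernel vectors K(X_test, X_train).T.
    b = K @ rng.randn(50, 3)

    L = cholesky(K, lower=True)

    # Old strategy: build K^-1 explicitly from its Cholesky factor, then multiply.
    L_inv = solve_triangular(L.T, np.eye(L.shape[0]))
    K_inv = L_inv @ L_inv.T
    x_explicit = K_inv @ b

    # New strategy: solve K @ x = b directly from the Cholesky factor.
    x_solve = cho_solve((L, True), b)

    # The direct solve's residual is typically orders of magnitude smaller here.
    print("explicit inverse, relative residual:",
          np.linalg.norm(K @ x_explicit - b) / np.linalg.norm(b))
    print("cho_solve, relative residual:       ",
          np.linalg.norm(K @ x_solve - b) / np.linalg.norm(b))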
24 changes: 9 additions & 15 deletions sklearn/gaussian_process/_gpr.py
@@ -8,7 +8,7 @@
from operator import itemgetter

import numpy as np
-from scipy.linalg import cholesky, cho_solve, solve_triangular
+from scipy.linalg import cholesky, cho_solve
import scipy.optimize

from ..base import BaseEstimator, RegressorMixin, clone
@@ -271,8 +271,6 @@ def obj_func(theta, eval_gradient=True):
K[np.diag_indices_from(K)] += self.alpha
try:
self.L_ = cholesky(K, lower=True) # Line 2
-# self.L_ changed, self._K_inv needs to be recomputed
-self._K_inv = None
except np.linalg.LinAlgError as exc:
exc.args = ("The kernel, %s, is not returning a "
"positive definite matrix. Try gradually "
@@ -345,31 +343,27 @@ def predict(self, X, return_std=False, return_cov=False):
else: # Predict based on GP posterior
K_trans = self.kernel_(X, self.X_train_)
y_mean = K_trans.dot(self.alpha_) # Line 4 (y_mean = f_star)

# undo normalisation
y_mean = self._y_train_std * y_mean + self._y_train_mean

if return_cov:
-v = cho_solve((self.L_, True), K_trans.T) # Line 5
-y_cov = self.kernel_(X) - K_trans.dot(v) # Line 6
+# Solve K @ V = K_trans.T
+V = cho_solve((self.L_, True), K_trans.T) # Line 5
+y_cov = self.kernel_(X) - K_trans.dot(V) # Line 6

# undo normalisation
y_cov = y_cov * self._y_train_std**2

return y_mean, y_cov
elif return_std:
-# cache result of K_inv computation
-if self._K_inv is None:
-    # compute inverse K_inv of K based on its Cholesky
-    # decomposition L and its inverse L_inv
-    L_inv = solve_triangular(self.L_.T,
-                             np.eye(self.L_.shape[0]))
-    self._K_inv = L_inv.dot(L_inv.T)
+# Solve K @ V = K_trans.T
+V = cho_solve((self.L_, True), K_trans.T) # Line 5

# Compute variance of predictive distribution
+# Use einsum to avoid explicitly forming the large matrix
+# K_trans @ V just to extract its diagonal afterward.
y_var = self.kernel_.diag(X)
-y_var -= np.einsum("ij,ij->i",
-                   np.dot(K_trans, self._K_inv), K_trans)
+y_var -= np.einsum("ij,ji->i", K_trans, V)

# Check if any of the variances is negative because of
# numerical issues. If yes: set the variance to 0.
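For reference, this standalone sketch, assuming a toy RBF kernel and made-up data rather than anything from the scikit-learn sources, mirrors what the new return_std path above computes: solve K @ V = K_trans.T with the stored Cholesky factor, then take only the diagonal of K_trans @ V via einsum.

    import numpy as np
    from scipy.linalg import cholesky, cho_solve


    def rbf(A, B, length_scale=1.0):
        # Toy squared-exponential kernel, only used to build example matrices.
        d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
        return np.exp(-0.5 * d2 / length_scale ** 2)


    rng = np.random.RandomState(0)
    X_train = rng.uniform(-2, 2, size=(20, 2))
    X_test = rng.uniform(-2, 2, size=(5, 2))

    K = rbf(X_train, X_train)
    K[np.diag_indices_from(K)] += 1e-10      # jitter, like alpha in the estimator
    L = cholesky(K, lower=True)              # plays the role of self.L_ above

    K_trans = rbf(X_test, X_train)           # K(X_test, X_train)

    # Solve K @ V = K_trans.T instead of forming K_inv explicitly.
    V = cho_solve((L, True), K_trans.T)

    # diag(K_trans @ V) without materializing the full product.
    y_var = rbf(X_test, X_test).diagonal() - np.einsum("ij,ji->i", K_trans, V)

    # The einsum really is the diagonal of the matrix product.
    assert np.allclose(np.einsum("ij,ji->i", K_trans, V), np.diag(K_trans @ V))
    print(y_var)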
58 changes: 33 additions & 25 deletions sklearn/gaussian_process/tests/test_gpr.py
@@ -19,11 +19,12 @@
from sklearn.gaussian_process.tests._mini_sequence_kernel import MiniSeqKernel
from sklearn.exceptions import ConvergenceWarning

-from sklearn.utils._testing \
-    import (assert_array_less,
-            assert_almost_equal, assert_raise_message,
-            assert_array_almost_equal, assert_array_equal,
-            assert_allclose, assert_warns_message)
+from sklearn.utils._testing import (
+    assert_array_less,
+    assert_almost_equal,
+    assert_array_almost_equal,
+    assert_allclose
+)


def f(x):
@@ -185,7 +186,8 @@ def test_no_optimizer():


@pytest.mark.parametrize('kernel', kernels)
-def test_predict_cov_vs_std(kernel):
+@pytest.mark.parametrize("target", [y, np.ones(X.shape[0], dtype=np.float64)])
+def test_predict_cov_vs_std(kernel, target):
if sys.maxsize <= 2 ** 32 and sys.version_info[:2] == (3, 6):
pytest.xfail("This test may fail on 32bit Py3.6")

@@ -452,25 +454,6 @@ def test_no_fit_default_predict():
assert_array_almost_equal(y_cov1, y_cov2)


-@pytest.mark.parametrize('kernel', kernels)
-def test_K_inv_reset(kernel):
-    y2 = f(X2).ravel()
-
-    # Test that self._K_inv is reset after a new fit
-    gpr = GaussianProcessRegressor(kernel=kernel).fit(X, y)
-    assert hasattr(gpr, '_K_inv')
-    assert gpr._K_inv is None
-    gpr.predict(X, return_std=True)
-    assert gpr._K_inv is not None
-    gpr.fit(X2, y2)
-    assert gpr._K_inv is None
-    gpr.predict(X2, return_std=True)
-    gpr2 = GaussianProcessRegressor(kernel=kernel).fit(X2, y2)
-    gpr2.predict(X2, return_std=True)
-    # the value of K_inv should be independent of the first fit
-    assert_array_equal(gpr._K_inv, gpr2._K_inv)


def test_warning_bounds():
kernel = RBF(length_scale_bounds=[1e-5, 1e-3])
gpr = GaussianProcessRegressor(kernel=kernel)
@@ -566,3 +549,28 @@ def test_constant_target(kernel):
assert_allclose(y_pred, y_constant)
# set atol because we compare to zero
assert_allclose(np.diag(y_cov), 0., atol=1e-9)


+def test_gpr_consistency_std_cov_non_invertible_kernel():
+    """Check the consistency between the returned std. dev. and the covariance.
+
+    Non-regression test for:
+    https://github.com/scikit-learn/scikit-learn/issues/19936
+    Inconsistencies were observed when the kernel cannot be inverted (or its
+    inversion is not numerically stable).
+    """
+    kernel = (C(8.98576054e+05, (1e-12, 1e12)) *
+              RBF([5.91326520e+02, 1.32584051e+03], (1e-12, 1e12)) +
+              WhiteKernel(noise_level=1e-5))
+    gpr = GaussianProcessRegressor(kernel=kernel, alpha=0, optimizer=None)
+    X_train = np.array([[0., 0.], [1.54919334, -0.77459667], [-1.54919334, 0.],
+                        [0., -1.54919334], [0.77459667, 0.77459667],
+                        [-0.77459667, 1.54919334]])
+    y_train = np.array([[-2.14882017e-10], [-4.66975823e+00], [4.01823986e+00],
+                        [-1.30303674e+00], [-1.35760156e+00],
+                        [3.31215668e+00]])
+    gpr.fit(X_train, y_train)
+    X_test = np.array([[-1.93649167, -1.93649167], [1.93649167, -1.93649167],
+                       [-1.93649167, 1.93649167], [1.93649167, 1.93649167]])
+    pred1, std = gpr.predict(X_test, return_std=True)
+    pred2, cov = gpr.predict(X_test, return_cov=True)
+    assert_allclose(std, np.sqrt(np.diagonal(cov)), rtol=1e-5)
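For a rough sense of why this particular kernel matrix is troublesome, an illustration rather than part of the commit: the RBF length scales are in the hundreds while the training inputs only span roughly ±2, so every entry of the training covariance is nearly the same constant and the matrix is close to singular. A quick standalone check of its conditioning:

    import numpy as np
    from sklearn.gaussian_process.kernels import RBF, WhiteKernel
    from sklearn.gaussian_process.kernels import ConstantKernel as C

    kernel = (C(8.98576054e+05, (1e-12, 1e12)) *
              RBF([5.91326520e+02, 1.32584051e+03], (1e-12, 1e12)) +
              WhiteKernel(noise_level=1e-5))
    X_train = np.array([[0., 0.], [1.54919334, -0.77459667], [-1.54919334, 0.],
                        [0., -1.54919334], [0.77459667, 0.77459667],
                        [-0.77459667, 1.54919334]])

    K = kernel(X_train)       # 6 x 6 training covariance matrix
    print(np.linalg.cond(K))  # very large: K is numerically close to singular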
