Multi-target GPR predicts only 1 std when normalize_y=False #22174

Closed
Tenavi opened this issue Jan 10, 2022 · 1 comment · Fixed by #22199

Tenavi commented Jan 10, 2022

Describe the bug

Wasn't this supposed to have been fixed in #20761?

See #22199

When a GPR model is fit on multi-target data without setting normalize_y=True, the predicted standard deviation has shape (n_samples,) instead of (n_samples, n_targets). The predicted covariance is affected in the same way.

Steps/Code to Reproduce

import numpy as np
import sklearn
from sklearn.gaussian_process import GaussianProcessRegressor as GPR

print(sklearn.__version__)

# 7 training samples, 3 features, 2 targets
X_train = np.random.rand(7, 3)
Y_train = np.random.randn(7, 2)

# 4 test samples
X_test = np.random.rand(4, 3)

# ---- WORKING CODE ---- #
model = GPR(normalize_y=True)
model.fit(X_train, Y_train)
Y_pred, Y_std = model.predict(X_test, return_std=True)
print(Y_pred.shape, Y_std.shape)  # Y_std has one column per target

# ---- BROKEN CODE ---- #
model = GPR()  # normalize_y defaults to False
model.fit(X_train, Y_train)
Y_pred, Y_std = model.predict(X_test, return_std=True)
print(Y_pred.shape, Y_std.shape)  # Y_std loses the target dimension

Expected Results

Y_std.shape should be (n_samples, n_targets) = (4, 2).

Actual Results

Instead, Y_std.shape is (n_samples,) = (4,).
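Until a fix lands, a possible user-side workaround is to expand the 1-D standard deviation to match the prediction's shape. The sketch below is only an illustration: `expand_std_to_targets` is a hypothetical helper, not part of scikit-learn. It assumes that, without target normalization, every target shares the same predictive standard deviation, because the posterior variance of a GP with a single shared kernel depends on the inputs and the kernel but not on the target values.

```python
import numpy as np


def expand_std_to_targets(y_pred, y_std):
    """Return y_std with shape (n_samples, n_targets) when y_pred is 2-D.

    Hypothetical helper (not a scikit-learn API). Assumes a 1-D y_std
    applies equally to every target column of y_pred.
    """
    y_pred = np.asarray(y_pred)
    y_std = np.asarray(y_std)
    if y_pred.ndim == 2 and y_std.ndim == 1:
        # Copy the per-sample std once for each target column.
        return np.repeat(y_std[:, np.newaxis], y_pred.shape[1], axis=1)
    # Already matches (single target, or std is already per-target).
    return y_std


# Stand-in arrays with the shapes reported above: 4 samples, 2 targets.
Y_pred_demo = np.zeros((4, 2))
Y_std_demo = np.array([0.1, 0.2, 0.3, 0.4])
print(expand_std_to_targets(Y_pred_demo, Y_std_demo).shape)  # (4, 2)
```

With the reproduction script, this would be called as `expand_std_to_targets(Y_pred, Y_std)` right after `model.predict(X_test, return_std=True)`. A standard deviation that is already 2-D is returned unchanged, so the call is harmless on versions where the shape is correct.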

Versions

System:
python: 3.9.5 | packaged by conda-forge | (default, Jun 19 2021, 00:27:35) [Clang 11.1.0 ]
executable: /Users/tnakam10/opt/anaconda3/envs/aerofusion/bin/python
machine: macOS-11.6.1-x86_64-i386-64bit

Python dependencies:
pip: 21.3.1
setuptools: 60.5.0
sklearn: 1.0.2
numpy: 1.19.5
scipy: 1.7.3
Cython: None
pandas: 1.3.5
matplotlib: 3.5.1
joblib: 1.1.0
threadpoolctl: 3.0.0

Built with OpenMP: True

Tenavi added the Bug and Needs Triage labels on Jan 10, 2022
glemaitre removed the Needs Triage label on Jan 10, 2022
glemaitre (Member) commented

Yes, apparently we did not solve the problem.
