When a GaussianProcessRegressor (GPR) model is fit on multi-target data without normalize_y=True, the predicted standard deviation has shape (n_samples,) instead of (n_samples, n_targets). The predicted covariance is affected in the same way.
Steps/Code to Reproduce
import numpy as np
import sklearn
from sklearn.gaussian_process import GaussianProcessRegressor as GPR

print(sklearn.__version__)

# 7 training samples with 3 features and 2 targets; 4 test samples
X_train = np.random.rand(7, 3)
Y_train = np.random.randn(7, 2)
X_test = np.random.rand(4, 3)

# ---- WORKING CODE ---- #
model = GPR(normalize_y=True)
model.fit(X_train, Y_train)
Y_pred, Y_std = model.predict(X_test, return_std=True)
print(Y_pred.shape, Y_std.shape)  # Y_std has shape (4, 2), as expected

# ---- BROKEN CODE ---- #
model = GPR()  # normalize_y defaults to False
model.fit(X_train, Y_train)
Y_pred, Y_std = model.predict(X_test, return_std=True)
print(Y_pred.shape, Y_std.shape)  # Y_std has shape (4,): the target axis is missing
Expected Results
Y_std.shape should be (n_samples, n_targets) = (4, 2).
Actual Results
Y_std.shape is (n_samples,) = (4,).
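Until this is fixed, one possible workaround is to restore the missing target axis by hand. The sketch below is only an illustration: expand_std is a hypothetical helper, not part of scikit-learn, and it assumes that a 1-D standard deviation holds one value per sample that is shared by every target (which seems plausible when normalize_y is off, since all targets share a single kernel, but I have not verified it against the scikit-learn source).

```python
import numpy as np


def expand_std(y_std, n_targets):
    """Return y_std with shape (n_samples, n_targets).

    Hypothetical helper, not part of scikit-learn. Assumes a 1-D y_std
    holds one standard deviation per sample, shared by every target.
    """
    y_std = np.asarray(y_std)
    if y_std.ndim == 1:
        # Repeat the per-sample value along a new target axis.
        return np.repeat(y_std[:, np.newaxis], n_targets, axis=1)
    if y_std.ndim != 2 or y_std.shape[1] != n_targets:
        raise ValueError(
            f"expected shape (n_samples, {n_targets}), got {y_std.shape}"
        )
    # Already (n_samples, n_targets): return unchanged.
    return y_std


# Stand-in for the (n_samples,) array returned by the broken call.
Y_std = np.array([0.9, 0.8, 0.95, 0.7])
print(expand_std(Y_std, n_targets=2).shape)  # (4, 2)
```

With the reproduction code, this would be called as Y_std = expand_std(Y_std, Y_train.shape[1]). Arrays that already have the target axis pass through untouched, so the same call should remain safe once predict returns the expected shape.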
Supposed to have been fixed in #20761?
See #22199
Versions
System:
python: 3.9.5 | packaged by conda-forge | (default, Jun 19 2021, 00:27:35) [Clang 11.1.0 ]
executable: /Users/tnakam10/opt/anaconda3/envs/aerofusion/bin/python
machine: macOS-11.6.1-x86_64-i386-64bit
Python dependencies:
pip: 21.3.1
setuptools: 60.5.0
sklearn: 1.0.2
numpy: 1.19.5
scipy: 1.7.3
Cython: None
pandas: 1.3.5
matplotlib: 3.5.1
joblib: 1.1.0
threadpoolctl: 3.0.0
Built with OpenMP: True