From a6974d5f26e375f310320dae9071cece92a2de2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Wed, 7 Jun 2017 15:20:29 +0200 Subject: [PATCH] Add test for scale=True --- sklearn/cross_decomposition/tests/test_pls.py | 31 ++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/sklearn/cross_decomposition/tests/test_pls.py b/sklearn/cross_decomposition/tests/test_pls.py index c476b2724792f..08bf2f2b775da 100644 --- a/sklearn/cross_decomposition/tests/test_pls.py +++ b/sklearn/cross_decomposition/tests/test_pls.py @@ -1,4 +1,6 @@ import numpy as np +from numpy.testing import assert_approx_equal + from sklearn.utils.testing import (assert_equal, assert_array_almost_equal, assert_array_equal, assert_true, assert_raise_message) @@ -351,6 +353,7 @@ def test_scale_and_stability(): assert_array_almost_equal(X_s_score, X_score) assert_array_almost_equal(Y_s_score, Y_score) + def test_pls_errors(): d = load_linnerud() X = d.data @@ -358,4 +361,30 @@ def test_pls_errors(): for clf in [pls_.PLSCanonical(), pls_.PLSRegression(), pls_.PLSSVD()]: clf.n_components = 4 - assert_raise_message(ValueError, "Invalid number of components", clf.fit, X, Y) + assert_raise_message(ValueError, "Invalid number of components", + clf.fit, X, Y) + + +def test_pls_scaling(): + # sanity check for scale=True + n_samples = 1000 + n_targets = 5 + n_features = 10 + + rng = np.random.RandomState(0) + + Q = rng.randn(n_targets, n_features) + Y = rng.randn(n_samples, n_targets) + X = np.dot(Y, Q) + 2 * rng.randn(n_samples, n_features) + + scale = 1000. + X_scaled = scale * X + + pls = pls_.PLSRegression(n_components=5, scale=True) + pls.fit(X, Y) + score = pls.score(X, Y) + + pls.fit(X_scaled, Y) + score_scaled = pls.score(X_scaled, Y) + + assert_approx_equal(score, score_scaled)