From c33c2f244019cb7edd359f8aa062d7ada0b294d8 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 17 Feb 2016 07:24:07 -0500 Subject: [PATCH] TST: Test whether the positivity constraint works. --- .../decomposition/tests/test_dict_learning.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/sklearn/decomposition/tests/test_dict_learning.py b/sklearn/decomposition/tests/test_dict_learning.py index 65e3fc99d1742..864278906e65c 100644 --- a/sklearn/decomposition/tests/test_dict_learning.py +++ b/sklearn/decomposition/tests/test_dict_learning.py @@ -34,6 +34,15 @@ def test_dict_learning_overcomplete(): assert_true(dico.components_.shape == (n_components, n_features)) +def test_dict_learning_positivity(): + n_components = 5 + dico = DictionaryLearning( + n_components, transform_algorithm='lasso_lars', random_state=0, + positive=True).fit(X) + code = dico.transform(X) + assert_true((code >= 0).all()) + + def test_dict_learning_reconstruction(): n_components = 12 dico = DictionaryLearning(n_components, transform_algorithm='omp', @@ -111,6 +120,15 @@ def test_dict_learning_online_shapes(): assert_equal(np.dot(code, dictionary).shape, X.shape) +def test_dict_learning_online_positivity(): + rng = np.random.RandomState(0) + n_components = 8 + code, dictionary = dict_learning_online(X, n_components=n_components, + alpha=1, random_state=rng, + positive=True) + assert_true((code >= 0).all()) + + def test_dict_learning_online_verbosity(): n_components = 5 # test verbosity @@ -191,6 +209,20 @@ def test_sparse_encode_shapes(): assert_equal(code.shape, (n_samples, n_components)) +def test_sparse_encode_positivity(): + n_components = 12 + rng = np.random.RandomState(0) + V = rng.randn(n_components, n_features) # random init + V /= np.sum(V ** 2, axis=1)[:, np.newaxis] + for algo in ('lasso_lars', 'lasso_cd', 'lars', 'threshold'): + code = sparse_encode(X, V, algorithm=algo, positive=True) + assert_true((code >= 0).all()) + + assert_raises( + ValueError, sparse_encode, X, V, algorithm='omp', positive=True + ) + + def test_sparse_encode_input(): n_components = 100 rng = np.random.RandomState(0)