Skip to content

Commit

Permalink
FIX Ensure determinism of SVD init in dict_learning (#18433)
Browse files Browse the repository at this point in the history
  • Loading branch information
brcharron committed Feb 3, 2021
1 parent 50d3aaa commit 4fd851c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
7 changes: 6 additions & 1 deletion doc/whats_new/v1.0.rst
Expand Up @@ -65,13 +65,18 @@ Changelog
- |API| In :class:`decomposition.DictionaryLearning`,
:class:`decomposition.MiniBatchDictionaryLearning`,
:func:`dict_learning` and :func:`dict_learning_online`,
`transform_alpha` will be equal to `alpha` instead of 1.0 by default
`transform_alpha` will be equal to `alpha` instead of 1.0 by default
starting from version 1.2
:pr:`19159` by :user:`Benoît Malézieux <bmalezieux>`.

- |Fix| Fixes incorrect multiple data-conversion warnings when clustering
boolean data. :pr:`19046` by :user:`Surya Prakash <jdsurya>`.

- |Fix| Fixed :func:`dict_learning`, used by :class:`DictionaryLearning`, to
ensure determinism of the output. Achieved by flipping signs of the SVD
output which is used to initialize the code.
:pr:`18433` by :user:`Bruno Charron <brcharron>`.

:mod:`sklearn.ensemble`
.......................

Expand Down
4 changes: 3 additions & 1 deletion sklearn/decomposition/_dict_learning.py
Expand Up @@ -18,7 +18,7 @@
from ..utils import deprecated
from ..utils import (check_array, check_random_state, gen_even_slices,
gen_batches)
from ..utils.extmath import randomized_svd, row_norms
from ..utils.extmath import randomized_svd, row_norms, svd_flip
from ..utils.validation import check_is_fitted, _deprecate_positional_args
from ..utils.fixes import delayed
from ..linear_model import Lasso, orthogonal_mp_gram, LassoLars, Lars
Expand Down Expand Up @@ -567,6 +567,8 @@ def dict_learning(X, n_components, *, alpha, max_iter=100, tol=1e-8,
dictionary = dict_init
else:
code, S, dictionary = linalg.svd(X, full_matrices=False)
# flip the initial code's sign to enforce deterministic output
code, dictionary = svd_flip(code, dictionary)
dictionary = S[:, np.newaxis] * dictionary
r = len(dictionary)
if n_components <= r: # True even if n_components=None
Expand Down

0 comments on commit 4fd851c

Please sign in to comment.