Skip to content

Commit

Permalink
[MRG+1] Make partial_fit ignore n_iter in MiniBatchDictionaryLearning (
Browse files Browse the repository at this point in the history
…scikit-learn#17433)

* Make partial_fit ignore n_iter.

* Revert whiteline.

* Fix self.iter_offset.

* Add test on iter_offset_

* Add what's new entry

* Apply thomasjpfan suggestions.
  • Loading branch information
cmarmo authored and viclafargue committed Jun 26, 2020
1 parent 4ad9c06 commit f1138d1
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 2 deletions.
8 changes: 8 additions & 0 deletions doc/whats_new/v0.24.rst
Expand Up @@ -44,6 +44,14 @@ Changelog
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
where 123456 is the *pull request* number, not the issue number.
:mod:`sklearn.decomposition`
............................

- |Fix| Fixed a bug in
:func:`decomposition.MiniBatchDictionaryLearning.partial_fit` which should
update the dictionary by iterating only once over a mini-batch.
:pr:`17433` by :user:`Chiara Marmo <cmarmo>`

:mod:`sklearn.ensemble`
.......................

Expand Down
4 changes: 2 additions & 2 deletions sklearn/decomposition/_dict_learning.py
Expand Up @@ -1490,7 +1490,7 @@ def partial_fit(self, X, y=None, iter_offset=None):
iter_offset = getattr(self, 'iter_offset_', 0)
U, (A, B) = dict_learning_online(
X, self.n_components, alpha=self.alpha,
n_iter=self.n_iter, method=self.fit_algorithm,
n_iter=1, method=self.fit_algorithm,
method_max_iter=self.transform_max_iter,
n_jobs=self.n_jobs, dict_init=dict_init,
batch_size=len(X), shuffle=False,
Expand All @@ -1504,5 +1504,5 @@ def partial_fit(self, X, y=None, iter_offset=None):
# Keep track of the state of the algorithm to be able to do
# some online fitting (partial_fit)
self.inner_stats_ = (A, B)
self.iter_offset_ = iter_offset + self.n_iter
self.iter_offset_ = iter_offset + 1
return self
17 changes: 17 additions & 0 deletions sklearn/decomposition/tests/test_dict_learning.py
Expand Up @@ -391,6 +391,23 @@ def test_dict_learning_online_partial_fit():
decimal=2)


def test_dict_learning_iter_offset():
n_components = 12
rng = np.random.RandomState(0)
V = rng.randn(n_components, n_features)
dict1 = MiniBatchDictionaryLearning(n_components, n_iter=10,
dict_init=V, random_state=0,
shuffle=False)
dict2 = MiniBatchDictionaryLearning(n_components, n_iter=10,
dict_init=V, random_state=0,
shuffle=False)
dict1.fit(X)
for sample in X:
dict2.partial_fit(sample[np.newaxis, :])

assert dict1.iter_offset_ == dict2.iter_offset_


def test_sparse_encode_shapes():
n_components = 12
rng = np.random.RandomState(0)
Expand Down

0 comments on commit f1138d1

Please sign in to comment.