Dictionary learning #221

Merged
merged 190 commits into from Sep 19, 2011


@vene
Member
vene commented Jun 25, 2011

Pull request contains: (UPDATED)

BaseDictionaryLearning object implementing transform methods
DictionaryLearning and OnlineDictionaryLearning implementing fit in different ways

Dictionary learning example
Image denoising example

@agramfort

avoid lambdas; they won't pickle.

@agramfort

I would rename D to X (standard in the linear_model module)

@agramfort

The OMP ref is [Mallat 93].

@agramfort

y should be (n_samples, n_targets) so y[i] is still sample i.

@agramfort

you don't need the \

@agramfort

doc formatting problem: it should have Parameters and Returns sections. A reference to [Mallat 93] should be added too.

@agramfort

the Returns section should be there; also, the coefs should be named w for consistency

vene and others added some commits Jun 15, 2011
@vene vene Renaming, some transposing 0cb2e3c
@vene vene Tests and the refactoring they induce 12fcc75
@vene vene PEP8 7b4cdb0
@vene vene Added signal recovery test 09e17cf
@vene vene rigurous pep8 69ccac9
@vene vene Added the example b3e6f81
@vene vene Cosmetized the example 0264056
@vene vene Added Olivier's patch extractor with enhancements 873256b
@vene vene cleanup 6ee443a
@vene vene Tests for various cases 06067f6
@vene vene PEP8, renaming, removed image size from params f1df50f
@vene vene Merge branch 'sparsepca' into sc 898777a
@vene vene Merge branch 'sparsepca' into sc 1de74af
@vene vene Revert "FIX: update_V without warm restart"
This reverts commit 38235cf.
e1ef751
@vene vene Merge branch 'sparsepca' into sc 2cef229
@ogrisel ogrisel FIX: make the dataset doctest fixture modular 00c976e
@ogrisel ogrisel typo 2c425fa
@larsmans larsmans document placement new in SVMlight reader 7ee358f
@mblondel mblondel Documentation fixes. 7158ee3
@vene vene Initial integration of Orthogonal MP 23704df
@vene vene Renaming, some transposing 170cc2f
@vene vene Tests and the refactoring they induce aa3bd39
@vene vene PEP8 487a8e1
@vene vene Added signal recovery test 6c74b15
@vene vene rigurous pep8 9f5f0c3
@vene vene Added the example 26fad8d
@vene vene Cosmetized the example 37dc0c5
@vene vene Added Olivier's patch extractor with enhancements b8ed9cd
@vene vene cleanup afd860a
@vene vene Tests for various cases dde98c0
@vene vene PEP8, renaming, removed image size from params 057b7f3
@vene vene FIX: weird branching accident 77b6612
@vene vene Revert "FIX: update_V without warm restart"
This reverts commit 38235cf.
c02d81e
@vene vene Revert "Revert "FIX: update_V without warm restart""
This reverts commit a557817.
7aa4926
@vene vene Merge branch 'sc' of github.com:vene/scikit-learn into sc cfd5b34
@vene vene FIX: update_V without warm restart 1598c40
@vene vene Added dictionary learning example 34c6585
@agramfort agramfort ENH : prettify dict learn example on image patches 43acd7f
@agramfort agramfort pep8 ad5b4ca
@vene vene Merge pull request #3 from agramfort/sc
Sc
2a256ff
@vene vene renaming for consistency, tests for PatchExtractor 849ff59
@vene vene Initial shape of dictionary learning object 7af69d7
@vene vene Added DictionaryLearning to __init__.py bf613b0
@vene vene FIX: silly bugs so that the example runs 8550f10
@vene vene ENH: Tweaked the example a bit 620b907
@vene vene PEP8 0c51fba
@agramfort agramfort FIX : using product form utils.fixes for python 2.5 2fd33fd
@agramfort agramfort pep8 badddc8
@agramfort agramfort MISC : fix docstring, cosmit in image.py a9ea545
@agramfort agramfort FIX; missing import in dict_learning.py (OMP in transform in not tested 98f592d
@GaelVaroquaux

I'd like the instantiation of the object to be done in a separate step from the fit, as some people might be confused.
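For illustration, a minimal sketch of that two-step pattern, instantiation first and fit as a separate call (the class name is taken from this pull request; the exact constructor signature and the random stand-in data are assumptions):

```python
import numpy as np
from scikits.learn.decomposition import DictionaryLearning

# Random data stands in for image patches so the snippet runs on its own.
patches = np.random.randn(100, 64)

# Instantiate first, then fit in a separate, explicit step.
dico = DictionaryLearning(n_atoms=12, alpha=1)
dico.fit(patches)
code = dico.transform(patches)   # sparse codes w.r.t. the learned dictionary
```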

@GaelVaroquaux

seed should be called random_state. That remark applies throughout the package.

@GaelVaroquaux
Member

There are unused imports in your code. Please run pyflakes on all the files (I use a shell for loop to do this).

@GaelVaroquaux
Member

As discussed by mail, sparse_pca should be turned into a dict_learning function that works on the transposed problem. Only objects should be exposed, as SparsePCA. This will mean that you will need to rename the logging messages.

EDIT: Indeed, after rereading the codebase, a clean-up is required to make sure that the vocabulary is consistent and the import paths make sense.

@GaelVaroquaux

I have been running this example, and noticed that the results are much more consistent with my expectations when using higher values for max_patches. Once we have the online version working, we will need to bring this parameter back up.

@fabianp
Member

Man! You just created a branch sc on origin; maybe you pushed to the wrong repo?

Member

haha. yes, happens to me all the time.

I think the fix is something like git push origin :sc (which of course makes a lot of sense ...)

@larsmans larsmans and 2 others commented on an outdated diff Jun 26, 2011
scikits/learn/linear_model/omp.py
+ Whether to perform precomputations. Improves performance when n_targets
+ or n_samples is very large.
+
+ Returns:
+ --------
+ coef: array of shape: n_features or (n_features, n_targets)
+ Coefficients of the OMP solution
+ """
+ X = np.asanyarray(X)
+ y = np.asanyarray(y)
+ if y.ndim == 1:
+ y = y[:, np.newaxis]
+ if n_atoms == None and eps == None:
+ raise ValueError('OMP needs either a target number of atoms (n_atoms) \
+ or a target residual error (eps)')
+ if eps != None and eps < 0:
@larsmans
larsmans Jun 26, 2011 scikit-learn member

If the value must be positive, shouldn't the check be eps <= 0 then? Same question for n_atoms, below.

@vene
vene Jun 28, 2011 scikit-learn member

About n_atoms you are right, but I think we should accept eps = 0 for when you want perfect reconstruction... do you agree?

@larsmans
larsmans Jun 28, 2011 scikit-learn member

I'm not even remotely familiar with dictionary learning, so don't ask for comments on the actual algorithm :)
It was only the exception message that I was concerned about. Looks alright now.

@GaelVaroquaux
GaelVaroquaux Jun 28, 2011 scikit-learn member

I think we should accept eps = 0 for when you want perfect reconstruction... do you agree?

Yes
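For illustration, a sketch of the agreed-upon checks wrapped in a hypothetical helper (argument names follow the omp.py excerpt above; the helper itself and the messages are not from this pull request):

```python
def _check_omp_params(n_atoms, eps):
    # n_atoms must be strictly positive; eps == 0 stays allowed so that
    # perfect reconstruction can be requested.
    if n_atoms is None and eps is None:
        raise ValueError('OMP needs either a target number of atoms '
                         '(n_atoms) or a target residual error (eps)')
    if eps is not None and eps < 0:
        raise ValueError('Epsilon cannot be negative')
    if eps is None and n_atoms <= 0:
        raise ValueError('The number of atoms must be positive')
```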

larsmans and others added some commits Jun 26, 2011
@larsmans larsmans Copyedit SparsePCA docs c7365e2
@vene vene Merge pull request #5 from agramfort/sc
Sc
24c3a68
@vene vene Merge branch 'master' into sc 9847371
@vene vene Merge branch 'sparsepca' into sc 1b76e96
@vene vene Merge branch 'sc' of git://github.com/larsmans/scikit-learn into sc c01ea0e
@vene vene Renamed online dict_learning appropriately f68b50a
@vene vene Merge branch 'sparsepca' into sc 61ee071
@vene vene Renaming part three b579582
@vene vene Fixed dico learning example d014aad
@vene vene Merge branch 'sparsepca' into sc b99397a
@vene vene FIX: checks in orthogonal_mp df89fab
@vene vene Cleanup orthogonal_mp docstrings 5c3bafd
@vene vene OMP docs, a little broken for now 484d9bd
@vene vene DOC: omp documentation improved c1e234e
@vene vene DOC: omp documentation fixes a48b00e
@vene vene DOC: dict_learning docs 8fb39dd
@vene vene dictionary learning tests 786ce12
@vene vene Fixed overcomplete case and updated dl example 598aee0
@vene vene online dictionary learning object 2f741ac
@vene vene factored base dico object ff671f5
@vene vene Merge branch 'sparsepca' into sc
Conflicts:
	scikits/learn/decomposition/sparse_pca.py
e3ef711
@vene vene pep8 16ac0ec
@vene vene more transform methods, split_sign db08066
@vene vene OMP dictionary must have normalized columns. 7826f94
@vene vene Merge branch 'master' into sc 8a553f5
@vene vene DOC: improved dict learning docs d08d58b
@vene vene Tweaked the dico example f65dec3
@vene vene exposed dict learning online in init 75773ff
@vene vene working on partial fit b06f121
@vene vene denoising example bad0431
@vene vene Annotate the example 5ae3c91
@vene vene partial fit iteration tracking, test still fails 7b42c64
@vene vene FIX: typo, s/treshold/threshold 3bdc425
@vene vene Tweak denoise example spacing 1fd8277
@vene vene pep8 examples ef41e46
@GaelVaroquaux
Member

I know I keep coming up with new comments :$, but you could simply set vmin=-.5, vmax=.5 in the imshow of the difference: -1, 1 is never achieved, and it kills the visual contrast.

In addition, I wonder if giving the norm of the difference in the title would be useful: it would enable comparison on a numerical basis.

Finally, I think that you should explain a bit more what you are doing in the docstring, and comment on the results. For instance, the fact that lars (l1-penalized regression) induces a bias in the coefficients can be seen in the difference image, which is reminiscent of the local intensity values.
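A hedged sketch of what the two plotting suggestions could look like in the example (the variable names and the random placeholder data are assumptions, not the example's actual code):

```python
import numpy as np
import pylab as pl

# Placeholders for the distorted and reconstructed patches of the example.
noisy = np.random.randn(64, 64)
reconstructed = noisy + 0.1 * np.random.randn(64, 64)
difference = noisy - reconstructed

# Clamp the color range so unreachable values don't kill the contrast, and
# report the norm of the difference in the title for numerical comparison.
pl.imshow(difference, vmin=-0.5, vmax=0.5, cmap=pl.cm.gray,
          interpolation='nearest')
pl.title('Difference (norm: %.2f)' % np.sqrt(np.sum(difference ** 2)))
pl.show()
```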

@vene
Member
vene commented Aug 24, 2011

I think this is getting close to merge-quality. Would anybody care to take a look?

@GaelVaroquaux
Member

I think this is getting close to merge-quality. Would anybody care to take a look?

I have to sleep...

@agramfort
Member

What if you used a stopping condition on the norm of the residual rather than hard-coding the number of atoms when using OMP? Since you know the noise level, it should be easy to set.

Otherwise the code looks cleaner and the API is now simple.

I would reference the JMLR paper for the online Dict learning.

Pyflakes report :

scikits/learn/decomposition/dict_learning.py:22
'euclidean_distances' imported but unused
scikits/learn/decomposition/dict_learning.py:92
local variable 'overwrite_gram' is assigned to but never used
scikits/learn/decomposition/dict_learning.py:95
local variable 'overwrite_cov' is assigned to but never used

scikits/learn/decomposition/tests/test_dict_learning.py:4
'SkipTest' imported but unused
scikits/learn/decomposition/tests/test_dict_learning.py:6
'make_sparse_coded_signal' imported but unused

we're almost there ! :)

@ogrisel
Member
ogrisel commented Aug 28, 2011

+1 for a stopping criterion on the scaled residuals rather than hard coding the number of iterations.

@ogrisel
Member
ogrisel commented Aug 28, 2011

Actually, using the scaled norm of the change of the dictionary weights is probably more stable than the residuals (as we do in coordinate descent).

@agramfort
Member

Actually, using the scaled norm of the change of the dictionary weights is probably more stable than the residuals (as we do in coordinate descent).

I don't follow. If you know that || noise || = sigma then you should
make sure that || data - reconstructed_data || \approx sigma. That's
all I am saying.

@ogrisel
Member
ogrisel commented Aug 28, 2011

The issue is that you don't have access to the complete dataset in the online / minibatch setting. Hence the use of the change on the dictionary weights as a measure of convergence.
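A minimal sketch of such a criterion (this helper is hypothetical, not code from the pull request):

```python
import numpy as np

def dictionary_change(old_dictionary, new_dictionary):
    """Scaled norm of the dictionary update between two mini-batch steps.

    A value close to zero suggests the updates have stabilized; comparing
    it to a tolerance would give the stopping condition discussed here.
    """
    delta = new_dictionary - old_dictionary
    return np.sqrt(np.sum(delta ** 2)) / np.sqrt(np.sum(old_dictionary ** 2))
```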

@vene
Member
vene commented Aug 28, 2011

Wow! Not only do I have internet access in Tarragona, but it's faster
than in Paris :P

If I don't party too hard I will push some commits tonight.

I have a nice idea in my head on how to implement the stopping
condition (based on my older thought that there is no use in having a
number of iterations n that is not a multiple of the number of
batches).

Also I will push the pyflakes fix before going to bed.
Best,
Vlad

@ogrisel ogrisel and 1 other commented on an outdated diff Sep 3, 2011
scikits/learn/decomposition/dict_learning.py
+ Pseudo number generator state used for random sampling.
+
+ Attributes
+ ----------
+ components_: array, [n_atoms, n_features]
+ components extracted from the data
+
+ References
+ ----------
+ J. Mairal, F. Bach, J. Ponce, G. Sapiro, 2009: Online dictionary learning
+ for sparse coding (http://www.di.ens.fr/sierra/pdfs/icml09.pdf)
+
+
+ See also
+ --------
+ SparsePCA
@ogrisel
ogrisel Sep 3, 2011 scikit-learn member

Explain here the relationship between dictionary learning and sparse PCA (i.e. that one is solving the transposed problem of the other).

@vene
vene Sep 3, 2011 scikit-learn member

What is the correct sphinx format for adding descriptions to "See also" entries?

@ogrisel
ogrisel Sep 3, 2011 scikit-learn member

I don't know, just grep on the existing doc and copy one that looks good :)

@vene
Member
vene commented Sep 10, 2011

I have merged the denoising enhancements from Alex. I've been toying with an exponentially weighted average tracking the amount of change in the dictionary, and it seems to get pretty close to zero for real data (digits), but for a randn array I found it oscillating around 50. We could have a max_iter param and warn that convergence was not reached (toying with the value of alpha can improve results, etc.). Do you think we should put this in now or after merging?

@agramfort
Member

Stopping criteria are hard for general problems, especially non-convex ones like this one. Also, it's frequent to see methods converge faster on real data where structure is present. Let's add this after merging. What I'd really like to see added is the constraint on the residual for the denoising: as you know the noise variance, you should do better than fixing the number of atoms, which could then be data dependent. This should improve the denoising result.
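A hedged sketch of that residual-based stopping for OMP, assuming the eps keyword from the omp.py excerpt earlier in this thread (the import path, shapes, and random stand-in data are assumptions):

```python
import numpy as np
from scikits.learn.linear_model import orthogonal_mp

# Random data stands in for a dictionary and a noisy patch.
n_samples, n_atoms = 64, 256
D = np.random.randn(n_samples, n_atoms)
D /= np.sqrt(np.sum(D ** 2, axis=0))      # OMP needs normalized columns
y = np.random.randn(n_samples)
sigma = 0.1                               # known noise standard deviation

# Stop adding atoms once the residual matches the expected noise energy,
# instead of hard-coding the number of nonzero coefficients.
w = orthogonal_mp(D, y, eps=n_samples * sigma ** 2)
```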

@ogrisel ogrisel commented on the diff Sep 12, 2011
doc/modules/decomposition.rst
@@ -347,3 +347,105 @@ of the data.
matrix factorization"
<http://www.cs.rpi.edu/~boutsc/files/nndsvd.pdf>`_
C. Boutsidis, E. Gallopoulos, 2008
+
+
+
+.. _DictionaryLearning:
+
+Dictionary Learning
+===================
+
+Generic dictionary learning
+-------------------------
@ogrisel
ogrisel Sep 12, 2011 scikit-learn member

missing 2 -

@ogrisel ogrisel commented on the diff Sep 12, 2011
doc/modules/decomposition.rst
+--------------------------
+
+:class:`DictionaryLearningOnline` implements a faster, but less accurate
+version of the dictionary learning algorithm that is better suited for large
+datasets.
+
+By default, :class:`DictionaryLearningOnline` divides the data into
+mini-batches and optimizes in an online manner by cycling over the mini-batches
+for the specified number of iterations. However, at the moment it does not
+implement a stopping condition.
+
+The estimator also implements `partial_fit`, which updates the dictionary by
+iterating only once over a mini-batch. This can be used for online learning
+when the data is not readily available from the start, or for when the data
+does not fit into the memory.
+
@ogrisel
ogrisel Sep 12, 2011 scikit-learn member

Could you please update the face decomposition example to include OnlineDictionaryLearning and insert the matching plots as figure here?

@ogrisel
ogrisel Sep 12, 2011 scikit-learn member

It would also be a good way to tell that {Online}SparsePCA and {Online}DictionaryLearning are using the same underlying implementation but one is putting the sparse penalty on the dictionary atoms while the other is putting it on the dictionary loadings.
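For illustration, a possible use of partial_fit on streaming mini-batches (the class name DictionaryLearningOnline follows the documentation excerpt above and may yet be renamed; the constructor signature and the random data are assumptions):

```python
import numpy as np
from scikits.learn.decomposition import DictionaryLearningOnline

dico = DictionaryLearningOnline(n_atoms=32, alpha=1)

# Feed the data one mini-batch at a time, as it becomes available.
for batch in np.array_split(np.random.randn(1000, 64), 20):
    dico.partial_fit(batch)

code = dico.transform(np.random.randn(5, 64))   # sparse codes of new samples
```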

@ogrisel
Member
ogrisel commented Sep 12, 2011

Ok for putting the stopping criterion after merging but we should not forget about it.

@vene
Member
vene commented Sep 16, 2011

OK is this ready for merging?

@ogrisel
Member
ogrisel commented Sep 16, 2011

The documentation on dictionary learning should have the math formula of the objective function (as for sparse PCA). Also, I would move the DL section right after the section on SparsePCA to make it clear that this is the same algo, but in one case we put the sparsity penalty on the dictionary components while in the other case on the dictionary loadings (a.k.a. the sparse code).

Maybe you should also include the figure obtained from the decomposition of the faces using DL (but it is arguably less interesting than the Sparse PCA one since the components look like the originals: maybe alpha is too strong?).
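For reference, a sketch of the objective as I understand it, mirroring the SparsePCA formulation with the roles of code and components swapped:

```latex
% Dictionary learning: sparsity penalty on the code U
(U^*, V^*) = \underset{U, V}{\operatorname{arg\,min}}\;
             \tfrac{1}{2} \|X - U V\|_2^2 + \alpha \|U\|_1
             \quad \text{s.t. } \|V_k\|_2 = 1 \ \forall k

% Sparse PCA: same factorization, penalty on the components V instead
(U^*, V^*) = \underset{U, V}{\operatorname{arg\,min}}\;
             \tfrac{1}{2} \|X - U V\|_2^2 + \alpha \|V\|_1
             \quad \text{s.t. } \|U_k\|_2 = 1 \ \forall k
```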

@ogrisel
Member
ogrisel commented Sep 16, 2011

Also maybe the class DictionaryLearningOnline should be renamed to MiniBatchDictionaryLearning for consistency with MiniBatchKMeans and MiniBatchSparsePCA. WDYT?

@vene
Member
vene commented Sep 16, 2011

Also maybe the class DictionaryLearningOnline should be renamed to MiniBatchDictionaryLearning for consistency with MiniBatchKMeans and MiniBatchSparsePCA. WDYT?

That makes sense, but the key thing to note is that online dict
learning supports partial fit, whereas the minibatch sparse pca
doesn't. Minibatch k-means does, however, so I'm not sure what my
opinion is. We probably should rename it.

The documentation on dictionary learning [...]

I will address everything you said tonight when I get back to
Bucharest. Thank you for the input.

@ogrisel
Member
ogrisel commented Sep 16, 2011

+1 for renaming everything pseudo-online to MiniBatch[Class] even if in the Sparse PCA case it's not in the usual n_samples axis and does not have a partial_fit method.

vene and others added some commits Sep 16, 2011
@vene vene Merge branch 'master' into sc ca92354
@vene vene Merge branch 'vene-sc' of git://github.com/ogrisel/scikit-learn into sc 5841984
@GaelVaroquaux GaelVaroquaux DOC: larger lena size in denoising example
Large size work better because they give a better training set to the
dictionary_learning algorithm. This is a tradeoff between computation
time and quality of example
f0b0cfe
@GaelVaroquaux GaelVaroquaux commented on the diff Sep 17, 2011
sklearn/decomposition/dict_learning.py
+ new_code = np.sign(cov) * np.maximum(np.abs(cov) - alpha, 0)
+
+ elif algorithm == 'omp':
+ if n_nonzero_coefs is None and alpha is None:
+ n_nonzero_coefs = n_features / 10
+ norms_squared = np.sum((Y ** 2), axis=0)
+ new_code = orthogonal_mp_gram(gram, cov, n_nonzero_coefs, alpha,
+ norms_squared, overwrite_Xy=overwrite_cov
+ )
+ else:
+ raise NotImplemented('Sparse coding method %s not implemented' %
+ algorithm)
+ return new_code
+
+
+def sparse_encode_parallel(X, Y, gram=None, cov=None, algorithm='lasso_lars',
@GaelVaroquaux
GaelVaroquaux Sep 17, 2011 scikit-learn member

I believe that this should be moved to linear_models and renamed something like 'multivariate_lasso'. I am worried that in the current situation, people starting from the lasso solver will not find it.

@ogrisel
ogrisel Sep 17, 2011 scikit-learn member

I don't think we should call it "multivariate lasso" as this is not restricted to lasso but also works for OMP and simple thresholding too. I find the current function name much more explicit.

@vene
vene Sep 17, 2011 scikit-learn member

We could add See alsos?

@agramfort
agramfort Sep 19, 2011 scikit-learn member

more multitask_lasso than multivariate_lasso, but I feel it's too much jargon.
+1 for see also, and maybe later refactor lasso_lars and lasso_cd to support multiple inputs.

@GaelVaroquaux GaelVaroquaux and 1 other commented on an outdated diff Sep 17, 2011
sklearn/decomposition/dict_learning.py
+ X: array of shape (n_samples, n_features)
+ Data matrix.
+
+ n_atoms: int,
+ Number of dictionary atoms to extract.
+
+ alpha: int,
+ Sparsity controlling parameter.
+
+ max_iter: int,
+ Maximum number of iterations to perform.
+
+ tol: float,
+ Tolerance for the stopping condition.
+
+ method: {'lasso_lars', 'lasso_cd'}
@GaelVaroquaux
GaelVaroquaux Sep 17, 2011 scikit-learn member

I guess this should be renamed to 'lars' or 'cd'.

@agramfort
agramfort Sep 19, 2011 scikit-learn member

+1 for 'lars' and 'cd' as only lasso makes sense here.

@GaelVaroquaux
Member

I made a couple of comments in the github diff. In addition, I made a pull request. Once these are done, I am +1 for merge.

@vene
Member
vene commented Sep 17, 2011

I think I addressed everything. I left the function name dictionary_learning_online; it could also be renamed minibatch_dictionary_learning or mini_batch_dictionary_learning. Do you think that should be done?

@ogrisel
Member
ogrisel commented Sep 18, 2011

+0 for minibatch_dictionary_learning or online_dictionary_learning. Running the tests / doc right now.

@ogrisel
Member
ogrisel commented Sep 18, 2011

It seems that the new figure in the MiniBatchDictionaryLearning section of the doc is pointing to the wrong image (MiniBatchSparsePCA), and the alignment is weird: the lena patches are centered while the figure for the faces decomposition below is on the left. I would rather move the faces decomposition figure up, right after the mathematical formulation explanation and before the paragraph on sparse coding and the image denoising application.

@agramfort agramfort commented on the diff Sep 19, 2011
sklearn/decomposition/dict_learning.py
+ lasso_cd: uses the coordinate descent method to compute the
+ Lasso solution (linear_model.Lasso). lasso_lars will be faster if
+ the estimated components are sparse.
+ omp: uses orthogonal matching pursuit to estimate the sparse solution
+ threshold: squashes to zero all coefficients less than alpha from
+ the projection X.T * Y
+
+ n_nonzero_coefs: int, 0.1 * n_features by default
+ Number of nonzero coefficients to target in each column of the
+ solution. This is only used by `algorithm='lars'` and `algorithm='omp'`
+ and is overridden by `alpha` in the `omp` case.
+
+ alpha: float, 1. by default
+ If `algorithm='lasso_lars'` or `algorithm='lasso_cd'`, `alpha` is the
+ penalty applied to the L1 norm.
+ If `algorithm='threhold'`, `alpha` is the absolute value of the
@agramfort
agramfort Sep 19, 2011 scikit-learn member

s/threhold/threshold

@agramfort agramfort commented on the diff Sep 19, 2011
sklearn/decomposition/dict_learning.py
+ threshold: squashes to zero all coefficients less than alpha from
+ the projection X.T * Y
+
+ n_nonzero_coefs: int, 0.1 * n_features by default
+ Number of nonzero coefficients to target in each column of the
+ solution. This is only used by `algorithm='lars'` and `algorithm='omp'`
+ and is overridden by `alpha` in the `omp` case.
+
+ alpha: float, 1. by default
+ If `algorithm='lasso_lars'` or `algorithm='lasso_cd'`, `alpha` is the
+ penalty applied to the L1 norm.
+ If `algorithm='threhold'`, `alpha` is the absolute value of the
+ threshold below which coefficients will be squashed to zero.
+ If `algorithm='omp'`, `alpha` is the tolerance parameter: the value of
+ the reconstruction error targeted. In this case, it overrides
+ `n_nonzero_coefs`.
@agramfort
agramfort Sep 19, 2011 scikit-learn member

that makes me think we should add the constraint on the l2 reconstruction error, e.g. an omp_rec_error? @ogrisel, @GaelVaroquaux, thoughts?

@agramfort agramfort commented on an outdated diff Sep 19, 2011
sklearn/decomposition/dict_learning.py
+ verbose:
+ degree of output the procedure will print
+
+ shuffle: boolean,
+ whether to shuffle the data before splitting it in batches
+
+ n_jobs: int,
+ number of parallel jobs to run, or -1 to autodetect.
+
+ method: {'lasso_lars', 'lasso_cd'}
+ lasso_lars: uses the least angle regression method
+ (linear_model.lars_path)
+ lasso_cd: uses the coordinate descent method to compute the
+ Lasso solution (linear_model.Lasso). Lars will be faster if
+ the estimated components are sparse.
+
@agramfort
agramfort Sep 19, 2011 scikit-learn member

lars and cd here also

@agramfort agramfort and 3 others commented on an outdated diff Sep 19, 2011
sklearn/decomposition/sparse_pca.py
@@ -553,9 +36,10 @@ class SparsePCA(BaseEstimator, TransformerMixin):
tol: float,
Tolerance for the stopping condition.
- method: {'lars', 'cd'}
- lars: uses the least angle regression method (linear_model.lars_path)
- cd: uses the coordinate descent method to compute the
+ method: {'lasso_lars', 'lasso_cd'}
+ lasso_lars: uses the least angle regression method
@agramfort
agramfort Sep 19, 2011 scikit-learn member

here also lars or cd

@vene
vene Sep 19, 2011 scikit-learn member

How about also using algorithm instead of method? Or even fit_algorithm, for consistency with the dictionary learning classes?

@ogrisel
ogrisel Sep 19, 2011 scikit-learn member

Yes, we already had this discussion in the previous comments. I am still in favor of using "algorithm" pervasively, but that would require updating lars_path and LocallyLinearEmbedding and maybe others.

What do people think? If we do so one should not forget to update the API section of the whats_new.rst doc.

@vene
vene Sep 19, 2011 scikit-learn member

OTOH, SparsePCA and MiniBatchSparsePCA can be safely changed now, as
they weren't featured in any release, right? Or would you like to wait
and change everything at the same time?

@GaelVaroquaux
GaelVaroquaux Sep 19, 2011 scikit-learn member

On Mon, Sep 19, 2011 at 02:22:03AM -0700, Olivier Grisel wrote:

What do people think? If we do so one should not forget to update the API section of the whats_new.rst doc.

If we break the API, I'd like a backward-compatible mode for one release
(e.g. using **kwargs).

G
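A hypothetical sketch of what such a one-release shim could look like (the class, parameter names, and warning text are illustrative assumptions, not the agreed-upon API):

```python
import warnings


class SparsePCA(object):
    def __init__(self, n_components=None, alpha=1.0, algorithm='lars',
                 **kwargs):
        if 'method' in kwargs:
            # Accept the old keyword for one release, with a warning.
            warnings.warn("'method' was renamed to 'algorithm' and will be "
                          "removed in the next release.", DeprecationWarning)
            algorithm = kwargs.pop('method')
        if kwargs:
            raise TypeError('Unexpected keyword arguments: %r'
                            % list(kwargs.keys()))
        self.n_components = n_components
        self.alpha = alpha
        self.algorithm = algorithm
```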

@agramfort
Member

Besides my comments, I would rename plot_img_denoising.py to plot_image_denoising.py.

I'd also love to see the omp_rec_error added for a better result in plot_img_denoising.py, but I don't want to block the merge.

@vene vene and 1 other commented on an outdated diff Sep 19, 2011
sklearn/decomposition/dict_learning.py
+
+ Parameters
+ ----------
+ n_atoms: int,
+ number of dictionary elements to extract
+
+ alpha: int,
+ sparsity controlling parameter
+
+ max_iter: int,
+ maximum number of iterations to perform
+
+ tol: float,
+ tolerance for numerical error
+
+ fit_algorithm: {'lasso_lars', 'lasso_cd'}
@vene
vene Sep 19, 2011 scikit-learn member

And here too, lars and cd, right? and in the rest of the objects.

@ogrisel
ogrisel Sep 19, 2011 scikit-learn member

Yes, but only where the lasso_ part is not mandatory, e.g. not for the transform_algorithm.

@ogrisel
Member
ogrisel commented Sep 19, 2011

@agramfort: I am not familiar with what omp_rec_error is all about. So I agree we should defer that discussion to another pull request.

@ogrisel
Member
ogrisel commented Sep 19, 2011

About the doc, this part of my comment was not addressed: "I would rather move the faces decomposition figure up, right after the mathematical formulation explanation and before the paragraph on sparse coding and image denoising application."

Also I think it should be compared to the PCA output as done for all other methods in this chapter so as to keep the chapter consistent.

@ogrisel
Member
ogrisel commented Sep 19, 2011

Also, when I run the decomp example I often get 2 or 4 atoms that are not white noise, and the non-noisy components look almost duplicated, or one is a near negative of another. Maybe the L1 reg is too strong or the algorithm is not stable on this data for such small dictionaries (6 atoms only; this is far from overcomplete in this regime...).

@vene
Member
vene commented Sep 19, 2011

Addressed now, sorry about that. You're right, this is much better.

@vene
Member
vene commented Sep 19, 2011

Decomp example is also fixed, and random state has been fixed. Looks
more informative now, but also a lot creepier. shivers

@GaelVaroquaux
Member

So I agree we should defer that discussion to another pull request.

+1

@ogrisel
Member
ogrisel commented Sep 19, 2011

Ok, this looks good to me. +1 for merge.

@GaelVaroquaux
Member

+1 me too

@ogrisel ogrisel merged commit d56281e into scikit-learn:master Sep 19, 2011
@ogrisel
Member
ogrisel commented Sep 19, 2011

Merged. Thanks again for your work.
