Skip to content
Browse files

factored base dico object

  • Loading branch information...
1 parent 2f741ac commit ff671f55537e00a436d0f2d03540ec7f69671893 @vene committed Jul 1, 2011
Showing with 72 additions and 68 deletions.
  1. +59 −66 scikits/learn/decomposition/dict_learning.py
  2. +13 −2 scikits/learn/decomposition/tests/test_dict_learning.py
View
125 scikits/learn/decomposition/dict_learning.py
@@ -12,7 +12,46 @@
from ..linear_model import orthogonal_mp
-class DictionaryLearning(BaseEstimator, TransformerMixin):
+class BaseDictionaryLearning(BaseEstimator, TransformerMixin):
+    """Dictionary learning base class.
+
+    Holds the parameters shared by the dictionary learning estimators
+    (`n_atoms`, `transform_method`) and provides the common `transform`
+    method. `transform` reads `self.components_`, which this class does not
+    set itself.
+    """
+    def __init__(self, n_atoms, transform_method='omp'):
+        self.n_atoms = n_atoms
+        self.transform_method = transform_method
+
+    def transform(self, X, y=None, **kwargs):
+        """Encode the data as a sparse combination of the learned dictionary
+        atoms.
+
+        Coding method is determined by the object parameter `transform_method`.
+
+        Parameters
+        ----------
+        X: array of shape (n_samples, n_features)
+            Test data to be transformed, must have the same number of
+            features as the data used to train the model.
+
+        y: ignored
+            Not used by this method.
+
+        **kwargs:
+            Passed through unchanged to the selected coding routine.
+            TODO: document kwargs for each possible coding method
+
+        Returns
+        -------
+        X_new: array, shape (n_samples, n_components)
+            Transformed data
+
+        Raises
+        ------
+        NotImplementedError
+            If `transform_method` is not one of 'omp', 'lasso_cd' or
+            'lasso_lars'.
+        """
+        # XXX: parameters should be made explicit so we can have defaults
+        if self.transform_method == 'omp':
+            return orthogonal_mp(self.components_.T, X.T, **kwargs).T
+        elif self.transform_method in ('lasso_cd', 'lasso_lars'):
+            return _update_code_parallel(self.components_.T, X.T, **kwargs).T
+        # XXX: add thresholding and others
+        else:
+            # NotImplementedError is the exception class; the NotImplemented
+            # constant is not callable and would fail with a TypeError here.
+            raise NotImplementedError('Coding method %s is not implemented' %
+                                      self.transform_method)
+
+
+class DictionaryLearning(BaseDictionaryLearning):
""" Dictionary learning
Finds a dictionary (a set of atoms) that can best be used to represent data
@@ -37,26 +76,29 @@ class DictionaryLearning(BaseEstimator, TransformerMixin):
tol: float,
tolerance for numerical error
- method: 'batch'
- algorithm to use
+ transform_method: 'lasso_lars' | 'lasso_cd' | 'omp'
+ method to use for transforming the data after the dictionary has been
+ learned
coding_method: 'lars' | 'cd',
method to use for solving the lasso problem
n_jobs: int,
number of parallel jobs to run
- U_init: array of shape (n_samples, n_atoms),
- V_init: array of shape (n_atoms, n_features),
- initial values for the decomposition for warm restart scenarios
+ code_init: array of shape (n_samples, n_atoms),
+ initial value for the code, for warm restart
+
+ dict_init: array of shape (n_atoms, n_features),
+ initial values for the dictionary, for warm restart
verbose:
degree of verbosity of the printed output
Attributes
----------
components_: array, [n_atoms, n_features]
- components extracted from the data
+ dictionary atoms extracted from the data
error_: array
vector of errors at each iteration
@@ -73,13 +115,13 @@ class DictionaryLearning(BaseEstimator, TransformerMixin):
"""
def __init__(self, n_atoms, alpha=1, max_iter=1000, tol=1e-8,
- method='batch', coding_method='lars', n_jobs=1,
+ transform_method='omp', coding_method='lars', n_jobs=1,
code_init=None, dict_init=None, verbose=False):
self.n_atoms = n_atoms
self.alpha = alpha
self.max_iter = max_iter
self.tol = tol
- self.method = method
+ self.transform_method = transform_method
self.coding_method = coding_method
self.n_jobs = n_jobs
self.code_init = code_init
@@ -132,36 +174,9 @@ def fit(self, X, y=None, **params):
self.fit_transform(X, y, **params)
return self
- def transform(self, X, y=None, method='omp', **kwargs):
- """Apply the projection onto the learned sparse components
- to new data.
-
- Parameters
- ----------
- X: array of shape (n_samples, n_features)
- Test data to be transformed, must have the same number of
- features as the data used to train the model.
-
- method: 'omp' | 'lars' | 'cd'
- Sparse coding method to use. Additional parameters are passed
- to the corresponding solver.
- Returns
- -------
- X_new array, shape (n_samples, n_components)
- Transformed data
- """
- # XXX : kwargs is not documented
-
- # XXX: parameters should be made explicit so we can have defaults
- if method == 'omp':
- return orthogonal_mp(self.components_, X.T, **kwargs)
- else:
- return _update_code_parallel(self.components_.T, X.T, **kwargs).T
-
-
-class DictionaryLearningOnline():
-""" Online dictionary learning
+class DictionaryLearningOnline(BaseDictionaryLearning):
+ """ Online dictionary learning
Finds a dictionary (a set of atoms) that can best be used to represent data
using a sparse code.
@@ -182,6 +197,10 @@ class DictionaryLearningOnline():
n_iter: int,
total number of iterations to perform
+ transform_method: 'lasso_lars' | 'lasso_cd' | 'omp'
+ method to use for transforming the data after the dictionary has been
+ learned
+
coding_method: 'lars' | 'cd',
method to use for solving the lasso problem
@@ -218,11 +237,12 @@ class DictionaryLearningOnline():
"""
def __init__(self, n_atoms, alpha=1, max_iter=1000, coding_method='lars',
n_jobs=1, chunk_size=3, shuffle=True, dict_init=None,
- verbose=False):
+ transform_method='omp', verbose=False):
self.n_atoms = n_atoms
self.alpha = alpha
self.n_iter = n_iter
self.coding_method = coding_method
+ self.transform_method = transform_method
self.n_jobs = n_jobs
self.dict_init = dict_init
self.verbose = verbose
@@ -254,30 +274,3 @@ def fit(self, X, y=None, **params):
return_code=False)
self.components_ = U
return self
-
- def transform(self, X, y=None, method='omp', **kwargs):
- """Apply the projection onto the learned sparse components
- to new data.
-
- Parameters
- ----------
- X: array of shape (n_samples, n_features)
- Test data to be transformed, must have the same number of
- features as the data used to train the model.
-
- method: 'omp' | 'lars' | 'cd'
- Sparse coding method to use. Additional parameters are passed
- to the corresponding solver.
-
- Returns
- -------
- X_new array, shape (n_samples, n_components)
- Transformed data
- """
- # XXX : kwargs is not documented
-
- # XXX: parameters should be made explicit so we can have defaults
- if method == 'omp':
- return orthogonal_mp(self.components_, X.T, **kwargs)
- else:
- return _update_code_parallel(self.components_.T, X.T, **kwargs).T
View
15 scikits/learn/decomposition/tests/test_dict_learning.py
@@ -19,7 +19,9 @@ def test_dict_learning_overcomplete():
dico = DictionaryLearning(n_atoms).fit(X)
assert dico.components_.shape == (n_atoms, n_features)
+
def test_dict_learning_reconstruction():
+ np.random.seed(0)
n_samples, n_features = 10, 8
n_atoms = 5
U = np.zeros((n_samples, n_atoms)).ravel()
@@ -28,7 +30,16 @@ def test_dict_learning_reconstruction():
V = np.random.randn(n_atoms, n_features)
X = np.dot(U, V)
- dico = DictionaryLearning(n_atoms)
- code = dico.fit(X).transform(X, method='lars', alpha=0.01)
+ dico = DictionaryLearning(n_atoms, transform_method='omp')
+ code = dico.fit(X).transform(X, eps=0.01)
+
+ assert_array_almost_equal(np.dot(code, dico.components_), X)
+ dico.transform_method = 'lasso_lars'
+ code = dico.transform(X, alpha=0.01)
+ # decimal=1 because lars is sensitive to round-off errors
assert_array_almost_equal(np.dot(code, dico.components_), X, decimal=1)
+
+
+def test_dict_learning_online():
+ # Placeholder: this test body is empty and performs no checks yet.
+ pass

0 comments on commit ff671f5

Please sign in to comment.
Something went wrong with that request. Please try again.