[MRG+2] Added n_components parameter to LatentDirichletAllocation to replace … #8922
```diff
@@ -143,17 +143,17 @@ class LatentDirichletAllocation(BaseEstimator, TransformerMixin):

     Parameters
     ----------
-    n_topics : int, optional (default=10)
+    n_components : int, optional (default=10)
         Number of topics.

     doc_topic_prior : float, optional (default=None)
         Prior of document topic distribution `theta`. If the value is None,
-        defaults to `1 / n_topics`.
+        defaults to `1 / n_components`.
         In the literature, this is called `alpha`.

     topic_word_prior : float, optional (default=None)
         Prior of topic word distribution `beta`. If the value is None, defaults
-        to `1 / n_topics`.
+        to `1 / n_components`.
         In the literature, this is called `eta`.

     learning_method : 'batch' | 'online', default='online'
```
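To make the renamed parameter concrete, here is a minimal usage sketch; the random count matrix is a stand-in for any document-word matrix, and the sizes and `learning_method` choice are illustrative:

```python
import numpy as np
from sklearn.decomposition import LatentDirichletAllocation

# Stand-in document-word count matrix: 100 documents, 500 vocabulary terms.
X = np.random.RandomState(0).randint(0, 5, size=(100, 500))

# With this PR, the number of topics is requested via n_components; leaving
# doc_topic_prior and topic_word_prior at None gives both 1 / n_components.
lda = LatentDirichletAllocation(n_components=10, learning_method='batch',
                                random_state=0)
lda.fit(X)
print(lda.components_.shape)  # (10, 500), i.e. (n_components, n_features)
```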
```diff
@@ -227,7 +227,7 @@ class LatentDirichletAllocation(BaseEstimator, TransformerMixin):

     Attributes
     ----------
-    components_ : array, [n_topics, n_features]
+    components_ : array, [n_components, n_features]
         Variational parameters for topic word distribution. Since the complete
         conditional for topic word distribution is a Dirichlet,
         ``components_[i, j]`` can be viewed as pseudocount that represents the
```
```diff
@@ -241,6 +241,9 @@ class LatentDirichletAllocation(BaseEstimator, TransformerMixin):

     n_iter_ : int
         Number of passes over the dataset.

+    n_topics : int, optional (default=10)
+        Same as n_components, kept for backward compatibility
+
     References
     ----------
```
```diff
@@ -255,13 +258,13 @@ class LatentDirichletAllocation(BaseEstimator, TransformerMixin):

     """

-    def __init__(self, n_topics=10, doc_topic_prior=None,
+    def __init__(self, n_components=10, doc_topic_prior=None,
                  topic_word_prior=None, learning_method=None,
                  learning_decay=.7, learning_offset=10., max_iter=10,
                  batch_size=128, evaluate_every=-1, total_samples=1e6,
                  perp_tol=1e-1, mean_change_tol=1e-3, max_doc_update_iter=100,
-                 n_jobs=1, verbose=0, random_state=None):
-        self.n_topics = n_topics
+                 n_jobs=1, verbose=0, random_state=None, n_topics=10):
+        self.n_components = n_components
         self.doc_topic_prior = doc_topic_prior
         self.topic_word_prior = topic_word_prior
         self.learning_method = learning_method
```
```diff
@@ -277,13 +280,18 @@ def __init__(self, n_topics=10, doc_topic_prior=None,

         self.n_jobs = n_jobs
         self.verbose = verbose
         self.random_state = random_state
+        if n_components == 10 and n_topics != 10:
```

Review comment: Our convention is to do this with fit so it also applies when set_params is used.

Review comment: Although you should use an additional variable rather than overwrite the parameter. A bit bureaucratic, sorry...

```diff
+            self.n_components = n_topics
+            warnings.warn("n_topics has been deprecated in favor of n_components", DeprecationWarning)
```

Review comment: Indicate when it will be removed. See the developers' docs.

```diff

     def _check_params(self):
         """Check model parameters."""
-        if self.n_topics <= 0:
-            raise ValueError("Invalid 'n_topics' parameter: %r"
-                             % self.n_topics)
+        if self.n_components <= 0:
+            raise ValueError("Invalid 'n_components' parameter: %r"
+                             % self.n_components)

         if self.total_samples <= 0:
             raise ValueError("Invalid 'total_samples' parameter: %r"
```
```diff
@@ -305,20 +313,20 @@ def _init_latent_vars(self, n_features):

         self.n_iter_ = 0

         if self.doc_topic_prior is None:
-            self.doc_topic_prior_ = 1. / self.n_topics
+            self.doc_topic_prior_ = 1. / self.n_components
         else:
             self.doc_topic_prior_ = self.doc_topic_prior

         if self.topic_word_prior is None:
-            self.topic_word_prior_ = 1. / self.n_topics
+            self.topic_word_prior_ = 1. / self.n_components
         else:
             self.topic_word_prior_ = self.topic_word_prior

         init_gamma = 100.
         init_var = 1. / init_gamma
         # In the literature, this is called `lambda`
         self.components_ = self.random_state_.gamma(
-            init_gamma, init_var, (self.n_topics, n_features))
+            init_gamma, init_var, (self.n_components, n_features))

         # In the literature, this is `exp(E[log(beta)])`
         self.exp_dirichlet_component_ = np.exp(
```
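For intuition about the initialization above: Gamma draws with shape 100 and scale 1/100 have mean 1 and variance 1/100, so every pseudocount in `components_` (`lambda` in the literature) starts close to 1. A standalone numpy sketch with illustrative sizes:

```python
import numpy as np

n_components, n_features = 10, 500
rng = np.random.RandomState(0)

init_gamma = 100.           # shape parameter
init_var = 1. / init_gamma  # scale parameter
# Mean = shape * scale = 1, variance = shape * scale**2 = 0.01, so the
# variational topic-word parameters start as near-uniform pseudocounts.
components = rng.gamma(init_gamma, init_var, (n_components, n_features))
print(components.shape, components.mean())  # (10, 500), ~1.0
```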
```diff
@@ -409,7 +417,7 @@ def _em_step(self, X, total_samples, batch_update, parallel=None):

         Returns
         -------
-        doc_topic_distr : array, shape=(n_samples, n_topics)
+        doc_topic_distr : array, shape=(n_samples, n_components)
             Unnormalized document topic distribution.
         """
```
```diff
@@ -569,7 +577,7 @@ def _unnormalized_transform(self, X):

         Returns
         -------
-        doc_topic_distr : shape=(n_samples, n_topics)
+        doc_topic_distr : shape=(n_samples, n_components)
             Document topic distribution for X.
         """
         if not hasattr(self, 'components_'):
```
```diff
@@ -603,7 +611,7 @@ def transform(self, X):

         Returns
         -------
-        doc_topic_distr : shape=(n_samples, n_topics)
+        doc_topic_distr : shape=(n_samples, n_components)
             Document topic distribution for X.
         """
         doc_topic_distr = self._unnormalized_transform(X)
```
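As the updated docstrings state, `transform` returns one row per document with `n_components` columns; unlike `_unnormalized_transform`, each row is normalized into a proper distribution. A quick self-contained check (sizes illustrative):

```python
import numpy as np
from sklearn.decomposition import LatentDirichletAllocation

X = np.random.RandomState(0).randint(0, 5, size=(20, 50))
lda = LatentDirichletAllocation(n_components=5, learning_method='batch',
                                random_state=0).fit(X)

doc_topic_distr = lda.transform(X)
print(doc_topic_distr.shape)            # (20, 5) == (n_samples, n_components)
print(doc_topic_distr.sum(axis=1)[:3])  # each row sums to ~1.0
```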
```diff
@@ -622,7 +630,7 @@ def _approx_bound(self, X, doc_topic_distr, sub_sampling):

         X : array-like or sparse matrix, shape=(n_samples, n_features)
             Document word matrix.

-        doc_topic_distr : array, shape=(n_samples, n_topics)
+        doc_topic_distr : array, shape=(n_samples, n_components)
             Document topic distribution. In the literature, this is called
             gamma.
```
```diff
@@ -644,7 +652,7 @@ def _loglikelihood(prior, distr, dirichlet_distr, size):

             return score

         is_sparse_x = sp.issparse(X)
-        n_samples, n_topics = doc_topic_distr.shape
+        n_samples, n_components = doc_topic_distr.shape
         n_features = self.components_.shape[1]
         score = 0
```
```diff
@@ -673,7 +681,7 @@ def _loglikelihood(prior, distr, dirichlet_distr, size):

         # compute E[log p(theta | alpha) - log q(theta | gamma)]
         score += _loglikelihood(doc_topic_prior, doc_topic_distr,
-                                dirichlet_doc_topic, self.n_topics)
+                                dirichlet_doc_topic, self.n_components)

         # Compensate for the subsampling of the population of documents
         if sub_sampling:
```
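The `_loglikelihood` helper whose call site changes here computes E[log p(theta | alpha) - log q(theta | gamma)], a difference of Dirichlet log-likelihoods. Its body is not shown in this diff; the following is a hedged reconstruction consistent with the signature, where `distr` holds the variational Dirichlet parameters (one row per sample), `dirichlet_distr` holds E[log theta] under q, and `size` is the number of components:

```python
import numpy as np
from scipy.special import gammaln


def _loglikelihood_sketch(prior, distr, dirichlet_distr, size):
    # Expected sufficient-statistics term: sum (alpha - gamma) * E[log theta].
    score = np.sum((prior - distr) * dirichlet_distr)
    # Per-element log-normalizer pieces: gammaln(gamma) - gammaln(alpha),
    # summed over all rows and components.
    score += np.sum(gammaln(distr) - gammaln(prior))
    # Per-row totals: gammaln(size * alpha) - gammaln(sum_k gamma_k).
    score += np.sum(gammaln(prior * size) - gammaln(np.sum(distr, 1)))
    return score
```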
```diff
@@ -717,7 +725,7 @@ def _perplexity_precomp_distr(self, X, doc_topic_distr=None,

         X : array-like or sparse matrix, [n_samples, n_features]
             Document word matrix.

-        doc_topic_distr : None or array, shape=(n_samples, n_topics)
+        doc_topic_distr : None or array, shape=(n_samples, n_components)
             Document topic distribution.
             If it is None, it will be generated by applying transform on X.
```
```diff
@@ -736,12 +744,12 @@ def _perplexity_precomp_distr(self, X, doc_topic_distr=None,

         if doc_topic_distr is None:
             doc_topic_distr = self._unnormalized_transform(X)
         else:
-            n_samples, n_topics = doc_topic_distr.shape
+            n_samples, n_components = doc_topic_distr.shape
             if n_samples != X.shape[0]:
                 raise ValueError("Number of samples in X and doc_topic_distr"
                                  " do not match.")

-            if n_topics != self.n_topics:
+            if n_components != self.n_components:
                 raise ValueError("Number of topics does not match.")

         current_samples = X.shape[0]
```
```diff
@@ -769,7 +777,7 @@ def perplexity(self, X, doc_topic_distr='deprecated', sub_sampling=False):

         X : array-like or sparse matrix, [n_samples, n_features]
             Document word matrix.

-        doc_topic_distr : None or array, shape=(n_samples, n_topics)
+        doc_topic_distr : None or array, shape=(n_samples, n_components)
             Document topic distribution.
             This argument is deprecated and is currently being ignored.
```
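Per the docstring above, `doc_topic_distr` is now deprecated and ignored, so callers just pass the document-word matrix. A minimal usage sketch with illustrative data:

```python
import numpy as np
from sklearn.decomposition import LatentDirichletAllocation

rng = np.random.RandomState(0)
X_train = rng.randint(0, 5, size=(80, 50))
X_test = rng.randint(0, 5, size=(20, 50))

lda = LatentDirichletAllocation(n_components=5, learning_method='batch',
                                random_state=0).fit(X_train)
# perplexity now recomputes the doc-topic distribution internally;
# there is no need to pass doc_topic_distr.
print(lda.perplexity(X_test))  # lower is better on held-out documents
```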
Review comment: Use the Sphinx deprecated markup.
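For reference, the markup the reviewer means is the `deprecated` directive used in numpydoc docstrings; applied to the `n_topics` entry it might look like the following (the version number and exact wording are illustrative, not the PR's final text):

```rst
n_topics : int, optional (default=None)
    This parameter has been renamed to ``n_components`` and will
    be removed in a later release.

    .. deprecated:: 0.19
```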