diff --git a/gensim/models/hdpmodel.py b/gensim/models/hdpmodel.py index 09239fa605..6d0bfbce56 100755 --- a/gensim/models/hdpmodel.py +++ b/gensim/models/hdpmodel.py @@ -10,27 +10,37 @@ # -""" -This module encapsulates functionality for the online Hierarchical Dirichlet Process algorithm. +"""Module for `online Hierarchical Dirichlet Processing +`_. -It allows both model estimation from a training corpus and inference of topic -distribution on new, unseen documents. +The core estimation code is directly adapted from the `blei-lab/online-hdp `_ +from `Wang, Paisley, Blei: "Online Variational Inference for the Hierarchical Dirichlet Process", JMLR (2011) +`_. -The core estimation code is directly adapted from the `onlinelhdp.py` script -by C. Wang see -**Wang, Paisley, Blei: Online Variational Inference for the Hierarchical Dirichlet -Process, JMLR (2011).** +Examples +-------- -http://jmlr.csail.mit.edu/proceedings/papers/v15/wang11a/wang11a.pdf +Train :class:`~gensim.models.hdpmodel.HdpModel` -The algorithm: +>>> from gensim.test.utils import common_corpus, common_dictionary +>>> from gensim.models import HdpModel +>>> +>>> hdp = HdpModel(common_corpus, common_dictionary) - * is **streamed**: training documents come in sequentially, no random access, - * runs in **constant memory** w.r.t. the number of documents: size of the - training corpus does not affect memory footprint +You can then infer topic distributions on new, unseen documents, with -""" +>>> unseen_document = [(1, 3.), (2, 4)] +>>> doc_hdp = hdp[unseen_document] + +To print 20 topics with top 10 most probable words. 
+ +>>> topic_info = hdp.print_topics(num_topics=20, num_words=10) +The model can be updated (trained) with new documents via + +>>> hdp.update([[(1, 2)], [(1, 1), (4, 5)]]) + +""" from __future__ import with_statement import logging @@ -45,6 +55,8 @@ from gensim.matutils import dirichlet_expectation from gensim.models import basemodel, ldamodel +from gensim.utils import deprecated + logger = logging.getLogger(__name__) meanchangethresh = 0.00001 @@ -52,8 +64,18 @@ def expect_log_sticks(sticks): - """ - For stick-breaking hdp, return the E[log(sticks)] + """For stick-breaking hdp, get the :math:`\mathbb{E}[log(sticks)]`. + + Parameters + ---------- + sticks : numpy.ndarray + Array of values for stick. + + Returns + ------- + numpy.ndarray + Computed :math:`\mathbb{E}[log(sticks)]`. + """ dig_sum = psi(np.sum(sticks, 0)) ElogW = psi(sticks[0]) - dig_sum @@ -67,6 +89,27 @@ def expect_log_sticks(sticks): def lda_e_step(doc_word_ids, doc_word_counts, alpha, beta, max_iter=100): + """Performs EM-iteration on a single document for calculation of likelihood for a maximum iteration of `max_iter`. + + Parameters + ---------- + doc_word_ids : int + Id of corresponding words in a document. + doc_word_counts : int + Count of words in a single document. + alpha : numpy.ndarray + Lda equivalent value of alpha. + beta : numpy.ndarray + Lda equivalent value of beta. + max_iter : int, optional + Maximum number of times the expectation will be maximised. + + Returns + ------- + (numpy.ndarray, numpy.ndarray) + Computed (:math:`likelihood`, :math:`\\gamma`). 
+ + """ gamma = np.ones(len(alpha)) expElogtheta = np.exp(dirichlet_expectation(gamma)) betad = beta[:, doc_word_ids] @@ -80,7 +123,7 @@ def lda_e_step(doc_word_ids, doc_word_counts, alpha, beta, max_iter=100): expElogtheta = np.exp(Elogtheta) phinorm = np.dot(expElogtheta, betad) + 1e-100 meanchange = np.mean(abs(gamma - lastgamma)) - if (meanchange < meanchangethresh): + if meanchange < meanchangethresh: break likelihood = np.sum(counts * np.log(phinorm)) @@ -88,57 +131,205 @@ def lda_e_step(doc_word_ids, doc_word_counts, alpha, beta, max_iter=100): likelihood += np.sum(gammaln(gamma) - gammaln(alpha)) likelihood += gammaln(np.sum(alpha)) - gammaln(np.sum(gamma)) - return (likelihood, gamma) + return likelihood, gamma class SuffStats(object): + """Stores sufficient statistics for the current chunk of document(s) whenever Hdp model is updated with new corpus. + These stats are used when updating lambda and top level sticks. The statistics include number of documents in the + chunk, length of words in the documents and top level truncation level. + + """ def __init__(self, T, Wt, Dt): + """ + + Parameters + ---------- + T : int + Top level truncation level. + Wt : int + Length of words in the documents. + Dt : int + Chunk size. + + """ self.m_chunksize = Dt self.m_var_sticks_ss = np.zeros(T) self.m_var_beta_ss = np.zeros((T, Wt)) def set_zero(self): + """Fill the sticks and beta array with 0 scalar value.""" self.m_var_sticks_ss.fill(0.0) self.m_var_beta_ss.fill(0.0) class HdpModel(interfaces.TransformationABC, basemodel.BaseTopicModel): - """ - The constructor estimates Hierachical Dirichlet Process model parameters based - on a training corpus: - - >>> hdp = HdpModel(corpus, id2word) - - You can infer topic distributions on new, unseen documents with - - >>> doc_hdp = hdp[doc_bow] - - Inference on new documents is based on the approximately LDA-equivalent topics. 
- - To print 20 topics with top 10 most probable words - - >>> hdp.print_topics(num_topics=20, num_words=10) - - Model persistency is achieved through its `load`/`save` methods. + """`Hierarchical Dirichlet Process model `_ + + Topic models promise to help summarize and organize large archives of texts that cannot be easily analyzed by hand. + Hierarchical Dirichlet process (HDP) is a powerful mixed-membership model for the unsupervised analysis of grouped + data. Unlike its finite counterpart, latent Dirichlet allocation, the HDP topic model infers the number of topics + from the data. Here we have used Online HDP, which provides the speed of online variational Bayes with the modeling + flexibility of the HDP. The idea behind Online variational Bayes in general is to optimize the variational + objective function with stochastic optimization. The challenge we face is that the existing coordinate ascent + variational Bayes algorithms for the HDP require complicated approximation methods or numerical optimization. This + model utilises stick breaking construction of Hdp which enables it to allow for coordinate-ascent variational Bayes + without numerical approximation. + + **Stick breaking construction** + + To understand the HDP model we need to understand how it is modelled using the stick breaking construction. A very + good analogy to understand the stick breaking construction is `chinese restaurant franchise + `_. + + + For this assume that there is a restaurant franchise (`corpus`) which has a large number of restaurants + (`documents`, `j`) under it. They have a global menu of dishes (`topics`, :math:`\Phi_{k}`) which they serve. + Also, a single dish (`topic`, :math:`\Phi_{k}`) is only served at a single table `t` for all the customers + (`words`, :math:`\\theta_{j,i}`) who sit at that table. + So, when a customer enters the restaurant he/she has the choice to make where he/she wants to sit.
+ He/she can choose to sit at a table where some customers are already sitting, or he/she can choose to sit + at a new table. Here the probability of choosing each option is not the same. + + Now, in this the global menu of dishes corresponds to the global atoms :math:`\Phi_{k}`, and each restaurant + corresponds to a single document `j`. So the number of dishes served in a particular restaurant corresponds to the + number of topics in a particular document. And the number of people sitting at each table corresponds to the number + of words belonging to each topic inside the document `j`. + + Now, coming on to the stick breaking construction, the concept understood from the chinese restaurant franchise is + easily carried over to the stick breaking construction for hdp (`"Figure 1" from "Online Variational Inference + for the Hierarchical Dirichlet Process" `_). + + A two level hierarchical dirichlet process is a collection of dirichlet processes :math:`G_{j}`, one for each + group, which share a base distribution :math:`G_{0}`, which is also a dirichlet process. Also, all :math:`G_{j}` + share the same set of atoms, :math:`\Phi_{k}`, and only the atom weights :math:`\pi _{jt}` differ. + + There will be multiple document-level atoms :math:`\psi_{jt}` which map to the same corpus-level atom + :math:`\Phi_{k}`. Here, the :math:`\\beta` signify the weights given to each of the topics globally. Also, each + factor :math:`\\theta_{j,i}` is distributed according to :math:`G_{j}`, i.e., it takes on the value of + :math:`\Phi_{k}` with probability :math:`\pi _{jt}`. :math:`C_{j,t}` is an indicator variable whose value `k` + signifies the index of :math:`\Phi`. This helps to map :math:`\psi_{jt}` to :math:`\Phi_{k}`. + + The top level (`corpus` level) stick proportions correspond to the values of :math:`\\beta`, + bottom level (`document` level) stick proportions correspond to the values of :math:`\pi`.
+ The truncation level for the corpus (`T`) and document (`K`) corresponds to the number of :math:`\\beta` + and :math:`\pi` which are in existence. + + Now, whenever coordinate ascent updates are to be performed, they happen at two levels: the document level as well + as the corpus level. + + At document level, we update the following: + + #. The parameters to the document level sticks, i.e., the a and b parameters of :math:`\\beta` distribution of the + variable :math:`\pi _{jt}`. + #. The parameters to per word topic indicators, :math:`Z_{j,n}`. Here :math:`Z_{j,n}` selects topic parameter + :math:`\psi_{jt}`. + #. The parameters to per document topic indices :math:`\Phi_{jtk}`. + + At corpus level, we update the following: + + #. The parameters to the top level sticks, i.e., the parameters of the :math:`\\beta` distribution for the + corpus level :math:`\\beta`, which signify the topic distribution at corpus level. + #. The parameters to the topics :math:`\Phi_{k}`. + + Now coming on to the steps involved, procedure for online variational inference for the Hdp model is as follows: + + 1. We initialise the corpus level parameters, topic parameters randomly and set current time to 1. + 2. Fetch a random document j from the corpus. + 3. Compute all the parameters required for document level updates. + 4. Compute natural gradients of corpus level parameters. + 5. Initialise the learning rate as a function of kappa, tau and current time. Also, increment current time by 1 + each time it reaches this step. + 6. Update corpus level parameters. + + Repeat 2 to 6 until the stopping condition is met. + + Here the stopping condition corresponds to + + * time limit expired + * chunk limit reached + * whole corpus processed + + Attributes + ---------- + lda_alpha : numpy.ndarray + Same as :math:`\\alpha` from :class:`gensim.models.ldamodel.LdaModel`. + lda_beta : numpy.ndarray + Same as :math:`\\beta` from :class:`gensim.models.ldamodel.LdaModel`.
+ m_D : int + Number of documents in the corpus. + m_Elogbeta : numpy.ndarray + Stores value of dirichlet expectation, i.e., compute :math:`E[log \\theta]` for a vector + :math:`\\theta \sim Dir(\\alpha)`. + m_lambda : {numpy.ndarray, float} + Drawn samples from the parameterized gamma distribution. + m_lambda_sum : {numpy.ndarray, float} + An array with the same shape as `m_lambda`, with the specified axis (1) removed. + m_num_docs_processed : int + Number of documents finished processing. This is incremented in size of chunks. + m_r : list + Acts as normaliser in lazy updating of `m_lambda` attribute. + m_rhot : float + Assigns weight to the information obtained from the mini-chunk and its value is between 0 and 1. + m_status_up_to_date : bool + Flag to indicate whether `lambda` and :math:`E[log \\theta]` have been updated if True, otherwise - not. + m_timestamp : numpy.ndarray + Helps to keep track and perform lazy updates on lambda. + m_updatect : int + Keeps track of current time and is incremented every time :meth:`~gensim.models.hdpmodel.HdpModel.update_lambda` + is called. + m_var_sticks : numpy.ndarray + Array of values for stick. + m_varphi_ss : numpy.ndarray + Used to update top level sticks. + m_W : int + Length of dictionary for the input corpus.
""" - def __init__(self, corpus, id2word, max_chunks=None, max_time=None, chunksize=256, kappa=1.0, tau=64.0, K=15, T=150, alpha=1, gamma=1, eta=0.01, scale=1.0, var_converge=0.0001, outputdir=None, random_state=None): """ - `gamma`: first level concentration - `alpha`: second level concentration - `eta`: the topic Dirichlet - `T`: top level truncation level - `K`: second level truncation level - `kappa`: learning rate - `tau`: slow down parameter - `max_time`: stop training after this many seconds - `max_chunks`: stop after having processed this many chunks (wrap around - corpus beginning in another corpus pass, if there are not enough chunks - in the corpus) + + Parameters + ---------- + corpus : iterable of list of (int, float) + Corpus in BoW format. + id2word : :class:`~gensim.corpora.dictionary.Dictionary` + Dictionary for the input corpus. + max_chunks : int, optional + Upper bound on how many chunks to process. It wraps around corpus beginning in another corpus pass, + if there are not enough chunks in the corpus. + max_time : int, optional + Upper bound on time (in seconds) for which model will be trained. + chunksize : int, optional + Number of documents in one chunk. + kappa : float, optional + Learning parameter which acts as exponential decay factor to influence extent of learning from each batch. + tau : float, optional + Learning parameter which down-weights early iterations of documents. + K : int, optional + Second level truncation level + T : int, optional + Top level truncation level + alpha : int, optional + Second level concentration + gamma : int, optional + First level concentration + eta : float, optional + The topic Dirichlet + scale : float, optional + Weights information from the mini-chunk of corpus to calculate rhot. + var_converge : float, optional + Lower bound on the right side of convergence. Used when updating variational parameters for a + single document.
+ outputdir : str, optional + Stores topic and options information in the specified directory. + random_state : {None, int, array_like, :class:`~np.random.RandomState`}, optional + Adds a little random jitter to randomize results around same alpha when trying to fetch a closest + corresponding lda model from :meth:`~gensim.models.hdpmodel.HdpModel.suggested_lda_model`. + """ self.corpus = corpus self.id2word = id2word @@ -192,6 +383,24 @@ def __init__(self, corpus, id2word, max_chunks=None, max_time=None, self.update(corpus) def inference(self, chunk): + """Infers the gamma value based on `chunk`. + + Parameters + ---------- + chunk : iterable of list of (int, float) + Corpus in BoW format. + + Returns + ------- + numpy.ndarray + First level concentration, i.e., Gamma value. + + Raises + ------ + RuntimeError + If the model has not been trained yet. + + """ if self.lda_alpha is None or self.lda_beta is None: raise RuntimeError("model must be trained to perform inference") chunk = list(chunk) @@ -208,6 +417,22 @@ def inference(self, chunk): return gamma def __getitem__(self, bow, eps=0.01): + """Accessor method for generating topic distribution of given document. + + Parameters + ---------- + bow : {iterable of list of (int, float), list of (int, float)} + BoW representation of the document/corpus to get topics for. + eps : float, optional + Ignore topics with probability below `eps`.
+ + Returns + ------- + list of (int, float) **or** :class:`gensim.interfaces.TransformedCorpus` + Topic distribution for the given document/corpus `bow`, as a list of `(topic_id, topic_probability)` or + transformed corpus. + + """ is_corpus, corpus = utils.is_corpus(bow) if is_corpus: return self._apply(corpus) @@ -217,6 +442,18 @@ def __getitem__(self, bow, eps=0.01): return [(topicid, topicvalue) for topicid, topicvalue in enumerate(topic_dist) if topicvalue >= eps] def update(self, corpus): + """Train the model with new documents, by EM-iterating over `corpus` until any of the conditions is satisfied. + + * time limit expired + * chunk limit reached + * whole corpus processed + + Parameters + ---------- + corpus : iterable of list of (int, float) + Corpus in BoW format. + + """ save_freq = max(1, int(10000 / self.chunksize)) # save every 10k docs, roughly chunks_processed = 0 start_time = time.clock() @@ -244,6 +481,25 @@ def update(self, corpus): logger.info('PROGRESS: finished document %i of %i', self.m_num_docs_processed, self.m_D) def update_finished(self, start_time, chunks_processed, docs_processed): + """Flag to determine whether the model has been updated with the new corpus or not. + + Parameters + ---------- + start_time : float + Indicates the current processor time as a floating point number expressed in seconds. + The resolution is typically better on Windows than on Unix by one microsecond due to differing + implementation of underlying function calls. + chunks_processed : int + Indicates progress of the update in terms of the number of chunks processed. + docs_processed : int + Indicates number of documents finished processing. This is incremented in size of chunks. + + Returns + ------- + bool + If True - model is updated, False otherwise.
+ + """ return ( # chunk limit reached (self.max_chunks and chunks_processed == self.max_chunks) or @@ -255,6 +511,24 @@ def update_finished(self, start_time, chunks_processed, docs_processed): (not self.max_chunks and not self.max_time and docs_processed >= self.m_D)) def update_chunk(self, chunk, update=True, opt_o=True): + """Performs lazy update on necessary columns of lambda and variational inference for documents in the chunk. + + Parameters + ---------- + chunk : iterable of list of (int, float) + Corpus in BoW format. + update : bool, optional + If True - call :meth:`~gensim.models.hdpmodel.HdpModel.update_lambda`. + opt_o : bool, optional + Passed as argument to :meth:`~gensim.models.hdpmodel.HdpModel.update_lambda`. + If True then the topics will be ordered, False otherwise. + + Returns + ------- + (float, int) + A tuple of likelihood and sum of all the word counts from each document in the corpus. + + """ # Find the unique words in this chunk... unique_words = dict() word_list = [] @@ -297,8 +571,29 @@ def update_chunk(self, chunk, update=True, opt_o=True): return score, count def doc_e_step(self, ss, Elogsticks_1st, unique_words, doc_word_ids, doc_word_counts, var_converge): - """ - e step for a single doc + """Performs E step for a single doc. + + Parameters + ---------- + ss : :class:`~gensim.models.hdpmodel.SuffStats` + Stats for all document(s) in the chunk. + Elogsticks_1st : numpy.ndarray + Computed Elogsticks value by stick-breaking process. + unique_words : dict of (int, int) + Mapping from the id of each unique word in the chunk to its chunk-level id. + doc_word_ids : iterable of int + Word ids for a single document. + doc_word_counts : iterable of int + Word counts of all words in a single document. + var_converge : float + Lower bound on the right side of convergence. Used when updating variational parameters for a single + document. + + Returns + ------- + float + Computed value of likelihood for a single document.
+ """ chunkids = [unique_words[id] for id in doc_word_ids] @@ -382,6 +677,18 @@ def doc_e_step(self, ss, Elogsticks_1st, unique_words, doc_word_ids, doc_word_co return likelihood def update_lambda(self, sstats, word_list, opt_o): + """Update appropriate columns of lambda and top level sticks based on documents. + + Parameters + ---------- + sstats : :class:`~gensim.models.hdpmodel.SuffStats` + Statistic for all document(s) in the chunk. + word_list : list of int + Contains word id of all the unique words in the chunk of documents on which update is being performed. + opt_o : bool, optional + If True - invokes a call to :meth:`~gensim.models.hdpmodel.HdpModel.optimal_ordering` to order the topics. + + """ self.m_status_up_to_date = False # rhot will be between 0 and 1, and says how much to weight # the information we got from this mini-chunk. @@ -412,9 +719,7 @@ def update_lambda(self, sstats, word_list, opt_o): self.m_var_sticks[1] = np.flipud(np.cumsum(var_phi_sum)) + self.m_gamma def optimal_ordering(self): - """ - ordering the topics - """ + """Performs ordering on the topics.""" idx = matutils.argsort(self.m_lambda_sum, reverse=True) self.m_varphi_ss = self.m_varphi_ss[idx] self.m_lambda = self.m_lambda[idx, :] @@ -422,12 +727,10 @@ def optimal_ordering(self): self.m_Elogbeta = self.m_Elogbeta[idx, :] def update_expectations(self): - """ - Since we're doing lazy updates on lambda, at any given moment - the current state of lambda may not be accurate. This function - updates all of the elements of lambda and Elogbeta - so that if (for example) we want to print out the - topics we've learned we'll get the correct behavior. + """Since we're doing lazy updates on lambda, at any given moment the current state of lambda may not be + accurate. This function updates all of the elements of lambda and Elogbeta so that if (for example) we want to + print out the topics we've learned we'll get the correct behavior. 
+ """ for w in xrange(self.m_W): self.m_lambda[:, w] *= np.exp(self.m_r[-1] - self.m_r[self.m_timestamp[w]]) @@ -438,11 +741,29 @@ def update_expectations(self): self.m_status_up_to_date = True def show_topic(self, topic_id, topn=20, log=False, formatted=False, num_words=None): - """ - Print the `num_words` most probable words for topic `topic_id`. - - Set `formatted=True` to return the topics as a list of strings, or - `False` as lists of (weight, word) pairs. + """Print the `num_words` most probable words for topic `topic_id`. + + Parameters + ---------- + topic_id : int + Acts as a representative index for a particular topic. + topn : int, optional + Number of most probable words to show from given `topic_id`. + log : bool, optional + If True - logs a message with level INFO on the logger object. + formatted : bool, optional + If True - get the topics as a list of strings, otherwise - get the topics as lists of (weight, word) pairs. + num_words : int, optional + DEPRECATED, USE `topn` INSTEAD. + + Warnings + -------- + The parameter `num_words` is deprecated, will be removed in 4.0.0, please use `topn` instead. + + Returns + ------- + list of (str, numpy.float) **or** list of str + Topic terms output displayed whose format depends on `formatted` parameter. """ if num_words is not None: # deprecated num_words is used @@ -458,21 +779,35 @@ def show_topic(self, topic_id, topn=20, log=False, formatted=False, num_words=No return hdp_formatter.show_topic(topic_id, topn, log, formatted) def get_topics(self): - """ - Returns: - np.ndarray: `num_topics` x `vocabulary_size` array of floats which represents - the term topic matrix learned during inference. + """Get the term topic matrix learned during inference. 
+ + Returns + ------- + np.ndarray + `num_topics` x `vocabulary_size` array of floats + """ topics = self.m_lambda + self.m_eta return topics / topics.sum(axis=1)[:, None] def show_topics(self, num_topics=20, num_words=20, log=False, formatted=True): - """ - Print the `num_words` most probable words for `num_topics` number of topics. - Set `num_topics=-1` to print all topics. - - Set `formatted=True` to return the topics as a list of strings, or - `False` as lists of (weight, word) pairs. + """Print the `num_words` most probable words for `num_topics` number of topics. + + Parameters + ---------- + num_topics : int, optional + Number of topics for which most probable `num_words` words will be fetched, if -1 - print all topics. + num_words : int, optional + Number of most probable words to show from `num_topics` number of topics. + log : bool, optional + If True - log a message with level INFO on the logger object. + formatted : bool, optional + If True - get the topics as a list of strings, otherwise - get the topics as lists of (weight, word) pairs. + + Returns + ------- + list of (str, numpy.float) **or** list of str + Output format for topic terms depends on the value of `formatted` parameter. """ if not self.m_status_up_to_date: @@ -481,8 +816,20 @@ def show_topics(self, num_topics=20, num_words=20, log=False, formatted=True): hdp_formatter = HdpTopicFormatter(self.id2word, betas) return hdp_formatter.show_topics(num_topics, num_words, log, formatted) + @deprecated("This method will be removed in 4.0.0, use `save` instead.") def save_topics(self, doc_count=None): - """legacy method; use `self.save()` instead""" + """Save discovered topics. + + Warnings + -------- + This method is deprecated, use :meth:`~gensim.models.hdpmodel.HdpModel.save` instead. + + Parameters + ---------- + doc_count : int, optional + Indicates number of documents finished processing and are to be saved. 
+ + """ if not self.outputdir: logger.error("cannot store topics without having specified an output directory") @@ -495,8 +842,15 @@ def save_topics(self, doc_count=None): betas = self.m_lambda + self.m_eta np.savetxt(fname, betas) + @deprecated("This method will be removed in 4.0.0, use `save` instead.") def save_options(self): - """legacy method; use `self.save()` instead""" + """Writes all the values of the attributes for the current model in "options.dat" file. + + Warnings + -------- + This method is deprecated, use :meth:`~gensim.models.hdpmodel.HdpModel.save` instead. + + """ if not self.outputdir: logger.error("cannot store options without having specified an output directory") return @@ -515,8 +869,13 @@ def save_options(self): fout.write('gamma: %s\n' % str(self.m_gamma)) def hdp_to_lda(self): - """ - Compute the LDA almost equivalent HDP. + """Get corresponding alpha and beta values of a LDA almost equivalent to current HDP. + + Returns + ------- + (numpy.ndarray, numpy.ndarray) + Alpha and Beta arrays. + """ # alpha sticks = self.m_var_sticks[0] / (self.m_var_sticks[0] + self.m_var_sticks[1]) @@ -534,10 +893,15 @@ def hdp_to_lda(self): return alpha, beta def suggested_lda_model(self): - """ - Returns closest corresponding ldamodel object corresponding to current hdp model. - The hdp_to_lda method only returns corresponding alpha, beta values, and this method returns a trained ldamodel. - The num_topics is m_T (default is 150) so as to preserve the matrice shapes when we assign alpha and beta. + """Get a trained ldamodel object which is closest to the current hdp model. + + The `num_topics=m_T`, so as to preserve the matrices shapes when we assign alpha and beta. + + Returns + ------- + :class:`~gensim.models.ldamodel.LdaModel` + Closest corresponding LdaModel to current HdpModel. 
+ """ alpha, beta = self.hdp_to_lda() ldam = ldamodel.LdaModel( @@ -547,6 +911,19 @@ def suggested_lda_model(self): return ldam def evaluate_test_corpus(self, corpus): + """Evaluates the model on test corpus. + + Parameters + ---------- + corpus : iterable of list of (int, float) + Test corpus in BoW format. + + Returns + ------- + float + The value of total likelihood obtained by evaluating the model for all documents in the test corpus. + + """ logger.info('TEST: evaluating test corpus') if self.lda_alpha is None or self.lda_beta is None: self.lda_alpha, self.lda_beta = self.hdp_to_lda() @@ -571,9 +948,30 @@ def evaluate_test_corpus(self, corpus): class HdpTopicFormatter(object): + """Helper class for :class:`gensim.models.hdpmodel.HdpModel` to format the output of topics.""" (STYLE_GENSIM, STYLE_PRETTY) = (1, 2) def __init__(self, dictionary=None, topic_data=None, topic_file=None, style=None): + """Initialise the :class:`gensim.models.hdpmodel.HdpTopicFormatter` and store topic data in sorted order. + + Parameters + ---------- + dictionary : :class:`~gensim.corpora.dictionary.Dictionary`,optional + Dictionary for the input corpus. + topic_data : numpy.ndarray, optional + The term topic matrix. + topic_file : {file-like object, str, pathlib.Path} + File, filename, or generator to read. If the filename extension is .gz or .bz2, the file is first + decompressed. Note that generators should return byte strings for Python 3k. + style : bool, optional + If True - get the topics as a list of strings, otherwise - get the topics as lists of (word, weight) pairs. + + Raises + ------ + ValueError + Either dictionary is None or both `topic_data` and `topic_file` is None. 
+ + """ if dictionary is None: raise ValueError('no dictionary!') @@ -597,9 +995,44 @@ def __init__(self, dictionary=None, topic_data=None, topic_file=None, style=None self.style = style def print_topics(self, num_topics=10, num_words=10): + """Give the most probable `num_words` words from `num_topics` topics. + Alias for :meth:`~gensim.models.hdpmodel.HdpTopicFormatter.show_topics`. + + Parameters + ---------- + num_topics : int, optional + Top `num_topics` to be printed. + num_words : int, optional + Top `num_words` most probable words to be printed from each topic. + + Returns + ------- + list of (str, numpy.float) **or** list of str + Output format for `num_words` words from `num_topics` topics depends on the value of `self.style` attribute. + + """ return self.show_topics(num_topics, num_words, True) def show_topics(self, num_topics=10, num_words=10, log=False, formatted=True): + """Give the most probable `num_words` words from `num_topics` topics. + + Parameters + ---------- + num_topics : int, optional + Top `num_topics` to be printed. + num_words : int, optional + Top `num_words` most probable words to be printed from each topic. + log : bool, optional + If True - log a message with level INFO on the logger object. + formatted : bool, optional + If True - get the topics as a list of strings, otherwise as lists of (word, weight) pairs. + + Returns + ------- + list of (int, list of (str, numpy.float) **or** list of str) + Output format for terms from `num_topics` topics depends on the value of `self.style` attribute. + + """ shown = [] if num_topics < 0: num_topics = len(self.data) @@ -628,6 +1061,27 @@ def show_topics(self, num_topics=10, num_words=10, log=False, formatted=True): return shown def print_topic(self, topic_id, topn=None, num_words=None): + """Print the `topn` most probable words from topic id `topic_id`. + + Warnings + -------- + The parameter `num_words` is deprecated, will be removed in 4.0.0, please use `topn` instead. 
+ + Parameters + ---------- + topic_id : int + Acts as a representative index for a particular topic. + topn : int, optional + Number of most probable words to show from given `topic_id`. + num_words : int, optional + DEPRECATED, USE `topn` INSTEAD. + + Returns + ------- + list of (str, numpy.float) **or** list of str + Output format for terms from a single topic depends on the value of `formatted` parameter. + + """ if num_words is not None: # deprecated num_words is used warnings.warn( "The parameter `num_words` is deprecated, will be removed in 4.0.0, please use `topn` instead." @@ -637,6 +1091,32 @@ def print_topic(self, topic_id, topn=None, num_words=None): return self.show_topic(topic_id, topn, formatted=True) def show_topic(self, topic_id, topn=20, log=False, formatted=False, num_words=None,): + """Give the most probable `num_words` words for the id `topic_id`. + + Warnings + -------- + The parameter `num_words` is deprecated, will be removed in 4.0.0, please use `topn` instead. + + Parameters + ---------- + topic_id : int + Acts as a representative index for a particular topic. + topn : int, optional + Number of most probable words to show from given `topic_id`. + log : bool, optional + If True logs a message with level INFO on the logger object, False otherwise. + formatted : bool, optional + If True return the topics as a list of strings, False as lists of + (word, weight) pairs. + num_words : int, optional + DEPRECATED, USE `topn` INSTEAD. + + Returns + ------- + list of (str, numpy.float) **or** list of str + Output format for terms from a single topic depends on the value of `self.style` attribute. + + """ if num_words is not None: # deprecated num_words is used warnings.warn( "The parameter `num_words` is deprecated, will be removed in 4.0.0, please use `topn` instead." 
@@ -664,9 +1144,39 @@ def show_topic(self, topic_id, topn=20, log=False, formatted=False, num_words=No return topic[1] def show_topic_terms(self, topic_data, num_words): + """Give the topic terms along with their probabilities for a single topic data. + + Parameters + ---------- + topic_data : list of (numpy.float, int) + Contains (weight, word id) pairs, i.e. the probability of each word id belonging to a single topic. + num_words : int + Number of words for which probabilities are to be extracted from the given single topic data. + + Returns + ------- + list of (str, numpy.float) + A sequence of topic terms and their probabilities. + + """ return [(self.dictionary[wid], weight) for (weight, wid) in topic_data[:num_words]] def format_topic(self, topic_id, topic_terms): + """Format the display for a single topic in two different ways. + + Parameters + ---------- + topic_id : int + Acts as a representative index for a particular topic. + topic_terms : list of (str, numpy.float) + Contains the most probable words from a single topic. + + Returns + ------- + list of (str, numpy.float) **or** list of str + Output format for topic terms depends on the value of `self.style` attribute. + + """ if self.STYLE_GENSIM == self.style: fmt = ' + '.join(['%.3f*%s' % (weight, word) for (word, weight) in topic_terms]) else: diff --git a/gensim/models/lda_dispatcher.py b/gensim/models/lda_dispatcher.py index db7a33468b..c6865981ab 100755 --- a/gensim/models/lda_dispatcher.py +++ b/gensim/models/lda_dispatcher.py @@ -4,13 +4,59 @@ # Copyright (C) 2010 Radim Rehurek # Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html -""" -USAGE: %(program)s SIZE_OF_JOBS_QUEUE +""":class:`~gensim.models.lda_dispatcher.Dispatcher` process which orchestrates +distributed :class:`~gensim.models.ldamodel.LdaModel` computations. +Run this script only once, on the master node in your cluster. + +Notes +----- +The dispatcher expects to find worker scripts already running.
Make sure you run as many workers as you like on +your machines **before** launching the dispatcher. + +Warnings +-------- +Requires installed `Pyro4 `_. + + +How to use distributed :class:`~gensim.models.ldamodel.LdaModel` +---------------------------------------------------------------- + + +#. Install needed dependencies (Pyro4) :: + + pip install gensim[distributed] + +#. Setup serialization (on each machine) :: + + export PYRO_SERIALIZERS_ACCEPTED=pickle + export PYRO_SERIALIZER=pickle + +#. Run nameserver :: + + python -m Pyro4.naming -n 0.0.0.0 & + +#. Run workers (on each machine) :: + + python -m gensim.models.lda_worker & - Dispatcher process which orchestrates distributed LDA computations. Run this \ -script only once, on any node in your cluster. +#. Run dispatcher :: + + python -m gensim.models.lda_dispatcher & + +#. Run :class:`~gensim.models.ldamodel.LdaModel` in distributed mode :: + + >>> from gensim.test.utils import common_corpus,common_dictionary + >>> from gensim.models import LdaModel + >>> + >>> model = LdaModel(common_corpus, id2word=common_dictionary, distributed=True) + + +Command line arguments +---------------------- + +.. program-output:: python -m gensim.models.lda_dispatcher --help + :ellipsis: 0, -7 -Example: python -m gensim.models.lda_dispatcher """ @@ -50,33 +96,54 @@ class Dispatcher(object): - """ - Dispatcher object that communicates and coordinates individual workers. + """Dispatcher object that communicates and coordinates individual workers. + Warnings + -------- There should never be more than one dispatcher running at any one time. + """ def __init__(self, maxsize=MAX_JOBS_QUEUE, ns_conf=None): - """ - Note that the constructor does not fully initialize the dispatcher; - use the `initialize()` function to populate it with workers etc. + """Partly initializes the dispatcher. 
+ + A full initialization (including initialization of the workers) requires a call to + :meth:`~gensim.models.lda_dispatcher.Dispatcher.initialize`. + + Parameters + ---------- + maxsize : int, optional + Maximum number of jobs to be kept pre-fetched in the queue. + ns_conf : dict of (str, object), optional + Sets up the name server configuration for the pyro daemon server of dispatcher. + This also helps to keep track of your objects in your network by using logical object names + instead of exact object name (or id) and its location. + """ self.maxsize = maxsize - self.callback = None # a pyro proxy to this object (unknown at init time, but will be set later) + self.callback = None self.ns_conf = ns_conf if ns_conf is not None else {} @Pyro4.expose def initialize(self, **model_params): - """ - `model_params` are parameters used to initialize individual workers (gets - handed all the way down to `worker.initialize()`). + """Fully initializes the dispatcher and all its workers. + + Parameters + ---------- + **model_params + Keyword parameters used to initialize individual workers, see :class:`~gensim.models.ldamodel.LdaModel`. + + Raises + ------ + RuntimeError + When no workers are found (the :mod:`gensim.models.lda_worker` script must be run beforehand).
+ """ self.jobs = Queue(maxsize=self.maxsize) self.lock_update = threading.Lock() self._jobsdone = 0 self._jobsreceived = 0 - # locate all available workers and store their proxies, for subsequent RMI calls self.workers = {} with utils.getNS(**self.ns_conf) as ns: self.callback = Pyro4.Proxy(ns.list(prefix=LDA_DISPATCHER_PREFIX)[LDA_DISPATCHER_PREFIX]) @@ -89,7 +156,7 @@ def initialize(self, **model_params): worker.initialize(workerid, dispatcher=self.callback, **model_params) self.workers[workerid] = worker except Pyro4.errors.PyroError: - logger.warning("unresponsive worker at %s, deleting it from the name server", uri) + logger.warning("unresponsive worker at %s,deleting it from the name server", uri) ns.remove(name) if not self.workers: @@ -97,13 +164,31 @@ def initialize(self, **model_params): @Pyro4.expose def getworkers(self): - """ - Return pyro URIs of all registered workers. + """Return pyro URIs of all registered workers. + + Returns + ------- + list of URIs + The pyro URIs for each worker. + """ return [worker._pyroUri for worker in itervalues(self.workers)] @Pyro4.expose def getjob(self, worker_id): + """Atomically pops a job from the queue. + + Parameters + ---------- + worker_id : int + The worker that requested the job. + + Returns + ------- + iterable of list of (int, float) + The corpus in BoW format. + + """ logger.info("worker #%i requesting a new job", worker_id) job = self.jobs.get(block=True, timeout=1) logger.info("worker #%i got a new job (%i left)", worker_id, self.jobs.qsize()) @@ -111,14 +196,27 @@ def getjob(self, worker_id): @Pyro4.expose def putjob(self, job): + """Atomically add a job to the queue. + + Parameters + ---------- + job : iterable of list of (int, float) + The corpus in BoW format. 
+ + """ self._jobsreceived += 1 self.jobs.put(job, block=True, timeout=HUGE_TIMEOUT) logger.info("added a new job (len(queue)=%i items)", self.jobs.qsize()) @Pyro4.expose def getstate(self): - """ - Merge states from across all workers and return the result. + """Merge states from across all workers and return the result. + + Returns + ------- + :class:`~gensim.models.ldamodel.LdaState` + Merged resultant state. + """ logger.info("end of input, assigning all remaining jobs") logger.debug("jobs done: %s, jobs received: %s", self._jobsdone, self._jobsreceived) @@ -144,8 +242,13 @@ def getstate(self): @Pyro4.expose def reset(self, state): - """ - Initialize all workers for a new EM iterations. + """Reinitializes all workers for a new EM iteration. + + Parameters + ---------- + state : :class:`~gensim.models.ldamodel.LdaState` + State of :class:`~gensim.models.ldamodel.LdaModel`. + """ for workerid, worker in iteritems(self.workers): logger.info("resetting worker %s", workerid) @@ -158,36 +261,46 @@ def reset(self, state): @Pyro4.oneway @utils.synchronous('lock_update') def jobdone(self, workerid): - """ - A worker has finished its job. Log this event and then asynchronously - transfer control back to the worker. + """Callback used by workers to notify when their job is done. + + The job done event is logged and then control is asynchronously transferred back to the worker + (who can then request another job). In this way, control flow basically oscillates between + :meth:`gensim.models.lda_dispatcher.Dispatcher.jobdone` and :meth:`gensim.models.lda_worker.Worker.requestjob`. + + Parameters + ---------- + workerid : int + The ID of the worker that finished the job (used for logging). - In this way, control flow basically oscillates between `dispatcher.jobdone()` - and `worker.requestjob()`.
""" self._jobsdone += 1 logger.info("worker #%s finished job #%i", workerid, self._jobsdone) self.workers[workerid].requestjob() # tell the worker to ask for another job, asynchronously (one-way) def jobsdone(self): - """Wrap self._jobsdone, needed for remote access through Pyro proxies""" + """Wrap :attr:`~gensim.models.lda_dispatcher.Dispatcher._jobsdone` needed for remote access through proxies. + + Returns + ------- + int + Number of jobs already completed. + + """ return self._jobsdone @Pyro4.oneway def exit(self): - """ - Terminate all registered workers and then the dispatcher. - """ + """Terminate all workers and then the dispatcher.""" for workerid, worker in iteritems(self.workers): logger.info("terminating worker %s", workerid) worker.exit() logger.info("terminating dispatcher") os._exit(0) # exit the whole process (not just this thread ala sys.exit()) -# endclass Dispatcher def main(): - parser = argparse.ArgumentParser(description=__doc__) + """Set up argument parser,logger and launches pyro daemon.""" + parser = argparse.ArgumentParser(description=__doc__[:-135], formatter_class=argparse.RawTextHelpFormatter) parser.add_argument( "--maxsize", help="How many jobs (=chunks of N documents) to keep 'pre-fetched' in a queue (default: %(default)s)", diff --git a/gensim/models/lda_worker.py b/gensim/models/lda_worker.py index 16110486d6..56314e8388 100755 --- a/gensim/models/lda_worker.py +++ b/gensim/models/lda_worker.py @@ -4,16 +4,57 @@ # Copyright (C) 2011 Radim Rehurek # Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html -""" -Worker ("slave") process used in computing distributed LDA. Run this script \ -on every node in your cluster. If you wish, you may even run it multiple times \ -on a single machine, to make better use of multiple cores (just beware that \ -memory footprint increases accordingly). 
+""":class:`~gensim.models.lda_worker.Worker` ("slave") process used in computing +distributed :class:`~gensim.models.ldamodel.LdaModel`. -Example: python -m gensim.models.lda_worker -""" +Run this script on every node in your cluster. If you wish, you may even run it multiple times on a single machine, +to make better use of multiple cores (just beware that memory footprint increases accordingly). + +Warnings +-------- +Requires installed `Pyro4 `_. + + +How to use distributed :class:`~gensim.models.ldamodel.LdaModel` +---------------------------------------------------------------- + + +#. Install needed dependencies (Pyro4) :: + + pip install gensim[distributed] + +#. Setup serialization (on each machine) :: + + export PYRO_SERIALIZERS_ACCEPTED=pickle + export PYRO_SERIALIZER=pickle + +#. Run nameserver :: + + python -m Pyro4.naming -n 0.0.0.0 & + +#. Run workers (on each machine) :: + + python -m gensim.models.lda_worker & + +#. Run dispatcher :: + + python -m gensim.models.lda_dispatcher & + +#. Run :class:`~gensim.models.ldamodel.LdaModel` in distributed mode :: + + >>> from gensim.test.utils import common_corpus,common_dictionary + >>> from gensim.models import LdaModel + >>> + >>> model = LdaModel(common_corpus, id2word=common_dictionary, distributed=True) + + +Command line arguments +---------------------- +.. program-output:: python -m gensim.models.lda_worker --help + :ellipsis: 0, -7 +""" from __future__ import with_statement import os import sys @@ -40,11 +81,30 @@ class Worker(object): + """Used as a Pyro4 class with exposed methods. + + Exposes every non-private method and property of the class automatically to be available for remote access. + + """ + def __init__(self): + """Partly initializes the model.""" self.model = None @Pyro4.expose def initialize(self, myid, dispatcher, **model_params): + """Fully initializes the worker. + + Parameters + ---------- + myid : int + An ID number used to identify this worker in the dispatcher object. 
+ dispatcher : :class:`~gensim.models.lda_dispatcher.Dispatcher` + The dispatcher responsible for scheduling this worker. + **model_params + Keyword parameters to initialize the inner LDA model, see :class:`~gensim.models.ldamodel.LdaModel`. + + """ self.lock_update = threading.Lock() self.jobsdone = 0 # how many jobs has this worker completed? # id of this worker in the dispatcher; just a convenience var for easy access/logging TODO remove? @@ -57,8 +117,14 @@ def initialize(self, myid, dispatcher, **model_params): @Pyro4.expose @Pyro4.oneway def requestjob(self): - """ - Request jobs from the dispatcher, in a perpetual loop until `getstate()` is called. + """Request jobs from the dispatcher, in a perpetual loop until :meth:`~gensim.models.lda_worker.Worker.getstate` + is called. + + Raises + ------ + RuntimeError + If `self.model` is None (i.e. worker not initialized). + """ if self.model is None: raise RuntimeError("worker must be initialized before receiving jobs") @@ -79,6 +145,14 @@ def requestjob(self): @utils.synchronous('lock_update') def processjob(self, job): + """Incrementally processes the job and potentially logs progress. + + Parameters + ---------- + job : iterable of list of (int, float) + Corpus in BoW format. + + """ logger.debug("starting to process job #%i", self.jobsdone) self.model.do_estep(job) self.jobsdone += 1 @@ -89,11 +163,20 @@ def processjob(self, job): @Pyro4.expose def ping(self): + """Test the connectivity with Worker.""" return True @Pyro4.expose @utils.synchronous('lock_update') def getstate(self): + """Log and get the LDA model's current state. + + Returns + ------- + result : :class:`~gensim.models.ldamodel.LdaState` + The current state.
+ + """ logger.info("worker #%i returning its state after %s jobs", self.myid, self.jobsdone) result = self.model.state assert isinstance(result, ldamodel.LdaState) @@ -104,6 +187,14 @@ def getstate(self): @Pyro4.expose @utils.synchronous('lock_update') def reset(self, state): + """Reset the worker by setting sufficient stats to 0. + + Parameters + ---------- + state : :class:`~gensim.models.ldamodel.LdaState` + Encapsulates information for distributed computation of LdaModel objects. + + """ assert state is not None logger.info("resetting worker #%i", self.myid) self.model.state = state @@ -113,12 +204,13 @@ def reset(self, state): @Pyro4.oneway def exit(self): + """Terminate the worker.""" logger.info("terminating worker #%i", self.myid) os._exit(0) def main(): - parser = argparse.ArgumentParser(description=__doc__) + parser = argparse.ArgumentParser(description=__doc__[:-130], formatter_class=argparse.RawTextHelpFormatter) parser.add_argument("--host", help="Nameserver hostname (default: %(default)s)", default=None) parser.add_argument("--port", help="Nameserver port (default: %(default)s)", default=None, type=int) parser.add_argument( diff --git a/gensim/models/lsi_dispatcher.py b/gensim/models/lsi_dispatcher.py index e4c06ef307..af435999e2 100755 --- a/gensim/models/lsi_dispatcher.py +++ b/gensim/models/lsi_dispatcher.py @@ -127,7 +127,7 @@ def initialize(self, **model_params): Raises ------ RuntimeError - When no workers are found (the `gensim.scripts.lsi_worker` script must be ran beforehand). + When no workers are found (the :mod:`gensim.models.lsi_worker` script must be run beforehand). """ self.jobs = Queue(maxsize=self.maxsize) @@ -192,7 +192,7 @@ def putjob(self, job): Parameters ---------- - job : iterable of iterable of (int, float) + job : iterable of list of (int, float) The corpus in BoW format.
""" @@ -246,8 +246,7 @@ def jobdone(self, workerid): The job done event is logged and then control is asynchronously transfered back to the worker (who can then request another job). In this way, control flow basically oscillates between - :meth:`gensim.models.lsi_dispatcher.Dispatcher.jobdone` and - :meth:`gensim.models.lsi_worker.Worker.requestjob`. + :meth:`gensim.models.lsi_dispatcher.Dispatcher.jobdone` and :meth:`gensim.models.lsi_worker.Worker.requestjob`. Parameters ---------- @@ -273,7 +272,7 @@ def jobsdone(self): @Pyro4.oneway def exit(self): - """Terminate all registered workers and then the dispatcher.""" + """Terminate all workers and then the dispatcher.""" for workerid, worker in iteritems(self.workers): logger.info("terminating worker %s", workerid) worker.exit() diff --git a/gensim/models/lsi_worker.py b/gensim/models/lsi_worker.py index 5f4ccc5c2f..2a4a66bb9e 100755 --- a/gensim/models/lsi_worker.py +++ b/gensim/models/lsi_worker.py @@ -112,8 +112,13 @@ def initialize(self, myid, dispatcher, **model_params): @Pyro4.expose @Pyro4.oneway def requestjob(self): - """Request jobs from the dispatcher, in a perpetual loop until - :meth:`~gensim.models.lsi_worker.Worker.getstate()` is called. + """Request jobs from the dispatcher, in a perpetual loop until :meth:`~gensim.models.lsi_worker.Worker.getstate` + is called. + + Raises + ------ + RuntimeError + If `self.model` is None (i.e. worker non initialized). """ if self.model is None: