Skip to content

Commit

Permalink
Return the static LDA topic for a term in the corpus vocabulary. (#706)
Browse files Browse the repository at this point in the history
  • Loading branch information
bhargavvader authored and tmylk committed May 31, 2016
1 parent edffd8e commit 97c8455
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.txt
@@ -1,6 +1,7 @@
Changes
=======
0.12.5, 2016
* added term-topics API returning the most probable topics for a word in the vocabulary. (@bhargavvader, #706)
* build_vocab takes progress_per parameter for smaller output (@zer0n, #624)
* Control whether to use lowercase for computing word2vec accuracy. (@alantian, #607)
* Easy import of GloVe vectors using Gensim (Manas Ranjan Kar, #625)
Expand Down
21 changes: 21 additions & 0 deletions gensim/models/ldamodel.py
Expand Up @@ -905,6 +905,27 @@ def get_document_topics(self, bow, minimum_probability=None):
return [(topicid, topicvalue) for topicid, topicvalue in enumerate(topic_dist)
if topicvalue >= minimum_probability]

def get_term_topics(self, word_id, minimum_probability=None):
    """
    Return the most likely topics for a particular word in the vocabulary.

    `word_id` is either the integer id of the term or the term itself as a
    string; a string is translated to its id through `self.id2word.doc2bow`.

    `minimum_probability` is the threshold a topic's weight for the term must
    reach to be reported. When it is None, `self.minimum_probability` is used.
    The threshold is never allowed to drop below 1e-8.

    The result is a list of `(topic_id, weight)` 2-tuples in increasing
    `topic_id` order, where `weight` is `self.expElogbeta[topic_id][word_id]`.
    A string that is not present in the vocabulary yields an empty list.
    """
    if minimum_probability is None:
        minimum_probability = self.minimum_probability
    minimum_probability = max(minimum_probability, 1e-8)  # never allow zero values in sparse output

    # if user enters word instead of id in vocab, change to get id
    if isinstance(word_id, str):
        bow = self.id2word.doc2bow([word_id])
        if not bow:
            # doc2bow returned nothing for this word, so there is no id to
            # look up; report "no topics" instead of failing on bow[0][0].
            return []
        word_id = bow[0][0]

    # Read each weight once, then keep only the topics that reach the threshold.
    weights = (
        (topic_id, self.expElogbeta[topic_id][word_id])
        for topic_id in range(self.num_topics)
    )
    return [(topic_id, weight) for topic_id, weight in weights if weight >= minimum_probability]


def __getitem__(self, bow, eps=None):
"""
Return topic distribution for the given document `bow`, as a list of
Expand Down
16 changes: 16 additions & 0 deletions gensim/test/test_ldamodel.py
Expand Up @@ -265,6 +265,22 @@ def testGetDocumentTopics(self):
self.assertTrue(isinstance(k, int))
self.assertTrue(isinstance(v, float))

def testTermTopics(self):
    """Check get_term_topics for a term given by its id and by its word."""
    numpy.random.seed(0)
    model = self.class_(self.corpus, id2word=dictionary, num_topics=2, passes=100)

    # Both lookups below are expected to report the same leading topic.
    expected_topic, expected_weight = 1, 0.1066

    # look the term up by its integer id
    by_id = model.get_term_topics(2)
    self.assertEqual(by_id[0][0], expected_topic)
    self.assertAlmostEqual(by_id[0][1], expected_weight, places=2)

    # look the same term up by its word form
    by_word = model.get_term_topics(str(model.id2word[2]))
    self.assertEqual(by_word[0][0], expected_topic)
    self.assertAlmostEqual(by_word[0][1], expected_weight, places=2)

def testPasses(self):
# long message includes the original error message with a custom one
Expand Down

0 comments on commit 97c8455

Please sign in to comment.