diff --git a/gensim/corpora/dictionary.py b/gensim/corpora/dictionary.py
index 1ff89a5b31..a0b3f8d73e 100644
--- a/gensim/corpora/dictionary.py
+++ b/gensim/corpora/dictionary.py
@@ -24,13 +24,13 @@
 from gensim import utils
 
-if sys.version_info[0] >= 3:
-    unicode = str
-
 from six import PY3, iteritems, iterkeys, itervalues, string_types
 from six.moves import xrange
 from six.moves import zip as izip
 
+if sys.version_info[0] >= 3:
+    unicode = str
+
 
 logger = logging.getLogger('gensim.corpora.dictionary')
 
 
@@ -180,7 +180,7 @@ def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000, keep_tokens=N
         2. more than `no_above` documents (fraction of total corpus size, *not*
            absolute number).
         3. if tokens are given in keep_tokens (list of strings), they will be kept regardless of
-           the `no_below` and `no_above` settings 
+           the `no_below` and `no_above` settings
         4. after (1), (2) and (3), keep only the first `keep_n` most frequent tokens (or
            keep all if `None`).
 
@@ -196,8 +196,7 @@ def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000, keep_tokens=N
             keep_ids = [self.token2id[v] for v in keep_tokens if v in self.token2id]
             good_ids = (
                 v for v in itervalues(self.token2id)
-                if no_below <= self.dfs.get(v, 0) <= no_above_abs
-                or v in keep_ids
+                if no_below <= self.dfs.get(v, 0) <= no_above_abs or v in keep_ids
             )
         else:
             good_ids = (
@@ -232,7 +231,7 @@ def filter_n_most_frequent(self, remove_n):
         # do the actual filtering, then rebuild dictionary to remove gaps in ids
         most_frequent_words = [(self[id], self.dfs.get(id, 0)) for id in most_frequent_ids]
         logger.info("discarding %i tokens: %s...", len(most_frequent_ids), most_frequent_words[:10])
-        
+
         self.filter_tokens(bad_ids=most_frequent_ids)
         logger.info("resulting dictionary: %s" % self)
 
@@ -282,6 +281,7 @@ def compactify(self):
     def save_as_text(self, fname, sort_by_word=True):
         """
         Save this Dictionary to a text file, in format:
+        `num_docs`
         `id[TAB]word_utf8[TAB]document frequency[NEWLINE]`. Sorted by word,
         or by decreasing word frequency.
 
@@ -290,6 +290,8 @@ def save_as_text(self, fname, sort_by_word=True):
         """
         logger.info("saving dictionary mapping to %s", fname)
         with utils.smart_open(fname, 'wb') as fout:
+            numdocs_line = "%d\n" % self.num_docs
+            fout.write(utils.to_utf8(numdocs_line))
             if sort_by_word:
                 for token, tokenid in sorted(iteritems(self.token2id)):
                     line = "%i\t%s\t%i\n" % (tokenid, token, self.dfs.get(tokenid, 0))
@@ -354,6 +356,13 @@ def load_from_text(fname):
         with utils.smart_open(fname) as f:
             for lineno, line in enumerate(f):
                 line = utils.to_unicode(line)
+                if lineno == 0:
+                    if line.strip().isdigit():
+                        # Older versions of save_as_text may not write num_docs on first line.
+                        result.num_docs = int(line.strip())
+                        continue
+                    else:
+                        logging.warning("Text does not contain num_docs on the first line.")
                 try:
                     wordid, word, docfreq = line[:-1].split('\t')
                 except Exception:
diff --git a/gensim/test/test_corpora_dictionary.py b/gensim/test/test_corpora_dictionary.py
index 16c499b245..210ff94548 100644
--- a/gensim/test/test_corpora_dictionary.py
+++ b/gensim/test/test_corpora_dictionary.py
@@ -120,35 +120,34 @@ def testFilter(self):
         d.filter_extremes(no_below=2, no_above=1.0, keep_n=4)
         expected = {0: 3, 1: 3, 2: 3, 3: 3}
         self.assertEqual(d.dfs, expected)
-        
+
     def testFilterKeepTokens_keepTokens(self):
         # provide keep_tokens argument, keep the tokens given
         d = Dictionary(self.texts)
         d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['human', 'survey'])
         expected = set(['graph', 'trees', 'human', 'system', 'user', 'survey'])
         self.assertEqual(set(d.token2id.keys()), expected)
-        
+
     def testFilterKeepTokens_unchangedFunctionality(self):
         # do not provide keep_tokens argument, filter_extremes functionality is unchanged
         d = Dictionary(self.texts)
         d.filter_extremes(no_below=3, no_above=1.0)
         expected = set(['graph', 'trees', 'system', 'user'])
         self.assertEqual(set(d.token2id.keys()), expected)
-        
+
     def testFilterKeepTokens_unseenToken(self):
         # do provide keep_tokens argument with unseen tokens, filter_extremes functionality is unchanged
         d = Dictionary(self.texts)
         d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['unknown_token'])
         expected = set(['graph', 'trees', 'system', 'user'])
-        self.assertEqual(set(d.token2id.keys()), expected)        
+        self.assertEqual(set(d.token2id.keys()), expected)
 
     def testFilterMostFrequent(self):
-        d = Dictionary(self.texts)        
-        d.filter_n_most_frequent(4)        
-        expected = {0: 2, 1: 2, 2: 2, 3: 2, 4: 2, 5: 2, 6: 2, 7: 2}        
-        self.assertEqual(d.dfs, expected)        
-        
-        
+        d = Dictionary(self.texts)
+        d.filter_n_most_frequent(4)
+        expected = {0: 2, 1: 2, 2: 2, 3: 2, 4: 2, 5: 2, 6: 2, 7: 2}
+        self.assertEqual(d.dfs, expected)
+
     def testFilterTokens(self):
         self.maxDiff = 10000
         d = Dictionary(self.texts)
@@ -156,9 +155,10 @@ def testFilterTokens(self):
         removed_word = d[0]
         d.filter_tokens([0])
 
-        expected = {'computer': 0, 'eps': 8, 'graph': 10, 'human': 1,
-                    'interface': 2, 'minors': 11, 'response': 3, 'survey': 4,
-                    'system': 5, 'time': 6, 'trees': 9, 'user': 7}
+        expected = {
+            'computer': 0, 'eps': 8, 'graph': 10, 'human': 1,
+            'interface': 2, 'minors': 11, 'response': 3, 'survey': 4,
+            'system': 5, 'time': 6, 'trees': 9, 'user': 7}
         del expected[removed_word]
         self.assertEqual(sorted(d.token2id.keys()), sorted(expected.keys()))
 
@@ -166,7 +166,6 @@ def testFilterTokens(self):
         d.add_documents([[removed_word]])
         self.assertEqual(sorted(d.token2id.keys()), sorted(expected.keys()))
 
-
     def test_doc2bow(self):
         d = Dictionary([["žluťoučký"], ["žluťoučký"]])
 
@@ -179,6 +178,66 @@ def test_doc2bow(self):
         # unicode must be converted to utf8
         self.assertEqual(d.doc2bow([u'\u017elu\u0165ou\u010dk\xfd']), [(0, 1)])
 
+    def test_saveAsText(self):
+        """`Dictionary` can be saved as textfile.
+        """
+        tmpf = get_tmpfile('save_dict_test.txt')
+        small_text = [
+            ["prvé", "slovo"],
+            ["slovo", "druhé"],
+            ["druhé", "slovo"]]
+
+        d = Dictionary(small_text)
+
+        d.save_as_text(tmpf)
+        with open(tmpf) as file:
+            serialized_lines = file.readlines()
+            self.assertEqual(serialized_lines[0], "3\n")
+            self.assertEqual(len(serialized_lines), 4)
+            # We do not know, which word will have which index
+            self.assertEqual(serialized_lines[1][1:], "\tdruhé\t2\n")
+            self.assertEqual(serialized_lines[2][1:], "\tprvé\t1\n")
+            self.assertEqual(serialized_lines[3][1:], "\tslovo\t3\n")
+
+        d.save_as_text(tmpf, sort_by_word=False)
+        with open(tmpf) as file:
+            serialized_lines = file.readlines()
+            self.assertEqual(serialized_lines[0], "3\n")
+            self.assertEqual(len(serialized_lines), 4)
+            self.assertEqual(serialized_lines[1][1:], "\tslovo\t3\n")
+            self.assertEqual(serialized_lines[2][1:], "\tdruhé\t2\n")
+            self.assertEqual(serialized_lines[3][1:], "\tprvé\t1\n")
+
+    def test_loadFromText_legacy(self):
+        """
+        `Dictionary` can be loaded from textfile in legacy format.
+        Legacy format does not have num_docs on the first line.
+        """
+        tmpf = get_tmpfile('load_dict_test_legacy.txt')
+        no_num_docs_serialization = "1\tprvé\t1\n2\tslovo\t2\n"
+        with open(tmpf, "w") as file:
+            file.write(no_num_docs_serialization)
+
+        d = Dictionary.load_from_text(tmpf)
+        self.assertEqual(d.token2id[u"prvé"], 1)
+        self.assertEqual(d.token2id[u"slovo"], 2)
+        self.assertEqual(d.dfs[1], 1)
+        self.assertEqual(d.dfs[2], 2)
+        self.assertEqual(d.num_docs, 0)
+
+    def test_loadFromText(self):
+        """`Dictionary` can be loaded from textfile."""
+        tmpf = get_tmpfile('load_dict_test.txt')
+        no_num_docs_serialization = "2\n1\tprvé\t1\n2\tslovo\t2\n"
+        with open(tmpf, "w") as file:
+            file.write(no_num_docs_serialization)
+
+        d = Dictionary.load_from_text(tmpf)
+        self.assertEqual(d.token2id[u"prvé"], 1)
+        self.assertEqual(d.token2id[u"slovo"], 2)
+        self.assertEqual(d.dfs[1], 1)
+        self.assertEqual(d.dfs[2], 2)
+        self.assertEqual(d.num_docs, 2)
+
     def test_saveAsText_and_loadFromText(self):
         """`Dictionary` can be saved as textfile and loaded again from textfile.
         """
         tmpf = get_tmpfile('dict_test.txt')
@@ -194,24 +253,25 @@ def test_saveAsText_and_loadFromText(self):
 
     def test_from_corpus(self):
         """build `Dictionary` from an existing corpus"""
-        documents = ["Human machine interface for lab abc computer applications",
-                     "A survey of user opinion of computer system response time",
-                     "The EPS user interface management system",
-                     "System and human system engineering testing of EPS",
-                     "Relation of user perceived response time to error measurement",
-                     "The generation of random binary unordered trees",
-                     "The intersection graph of paths in trees",
-                     "Graph minors IV Widths of trees and well quasi ordering",
-                     "Graph minors A survey"]
+        documents = [
+            "Human machine interface for lab abc computer applications",
+            "A survey of user opinion of computer system response time",
+            "The EPS user interface management system",
+            "System and human system engineering testing of EPS",
+            "Relation of user perceived response time to error measurement",
+            "The generation of random binary unordered trees",
+            "The intersection graph of paths in trees",
+            "Graph minors IV Widths of trees and well quasi ordering",
+            "Graph minors A survey"]
         stoplist = set('for a of the and to in'.split())
-        texts = [[word for word in document.lower().split() if word not in stoplist]
-                 for document in documents]
+        texts = [
+            [word for word in document.lower().split() if word not in stoplist]
+            for document in documents]
 
         # remove words that appear only once
         all_tokens = sum(texts, [])
         tokens_once = set(word for word in set(all_tokens) if all_tokens.count(word) == 1)
-        texts = [[word for word in text if word not in tokens_once]
-                 for text in texts]
+        texts = [[word for word in text if word not in tokens_once] for text in texts]
 
         dictionary = Dictionary(texts)
         corpus = [dictionary.doc2bow(text) for text in texts]
@@ -260,7 +320,7 @@ def test_dict_interface(self):
         self.assertTrue(isinstance(d.keys(), list))
         self.assertTrue(isinstance(d.values(), list))
 
-#endclass TestDictionary
+# endclass TestDictionary
 
 
 if __name__ == '__main__':