From 6eafa239183b6f484825781ffa3a610afe8f32de Mon Sep 17 00:00:00 2001 From: aquatiko Date: Sat, 27 Oct 2018 16:35:20 +0530 Subject: [PATCH] svmlightcorpus.py: Add sequence serialization of corpus The current version of serialization supports only lists, but this adds support for any type of sequence. Closes: https://github.com/RaRe-Technologies/gensim/issues/2113 --- gensim/corpora/svmlightcorpus.py | 3 +++ gensim/test/test_corpora.py | 13 ++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/gensim/corpora/svmlightcorpus.py b/gensim/corpora/svmlightcorpus.py index 459274cfae..153bd973e0 100644 --- a/gensim/corpora/svmlightcorpus.py +++ b/gensim/corpora/svmlightcorpus.py @@ -111,6 +111,9 @@ def save_corpus(fname, corpus, id2word=None, labels=False, metadata=False): """ logger.info("converting corpus to SVMlight format: %s", fname) + if labels is not False: + # Cast any sequence (incl. a numpy array) to a list, to simplify the processing below. + labels = list(labels) offsets = [] with utils.smart_open(fname, 'wb') as fout: for docno, doc in enumerate(corpus): diff --git a/gensim/test/test_corpora.py b/gensim/test/test_corpora.py index 8eb10faa0e..34959c717c 100644 --- a/gensim/test/test_corpora.py +++ b/gensim/test/test_corpora.py @@ -23,7 +23,7 @@ ucicorpus, malletcorpus, textcorpus, indexedcorpus, wikicorpus) from gensim.interfaces import TransformedCorpus from gensim.utils import to_unicode -from gensim.test.utils import datapath, get_tmpfile +from gensim.test.utils import datapath, get_tmpfile, common_corpus class DummyTransformer(object): @@ -382,6 +382,17 @@ def setUp(self): self.corpus_class = svmlightcorpus.SvmLightCorpus self.file_extension = '.svmlight' + def test_serialization(self): + path = get_tmpfile("svml.corpus") + labels = [1] * len(common_corpus) + second_corpus = [(0, 1.0), (3, 1.0), (4, 1.0), (5, 1.0), (6, 1.0), (7, 1.0)] + self.corpus_class.serialize(path, common_corpus, labels=labels) + serialized_corpus = 
self.corpus_class(path) + self.assertEqual(serialized_corpus[1], second_corpus) + self.corpus_class.serialize(path, common_corpus, labels=np.array(labels)) + serialized_corpus = self.corpus_class(path) + self.assertEqual(serialized_corpus[1], second_corpus) + class TestBleiCorpus(CorpusTestCase): def setUp(self):