Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SvmLightCorpus.serialize if labels instance of numpy.ndarray #2243

Merged
merged 2 commits into from Jan 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions gensim/corpora/svmlightcorpus.py
Expand Up @@ -111,6 +111,9 @@ def save_corpus(fname, corpus, id2word=None, labels=False, metadata=False):
"""
logger.info("converting corpus to SVMlight format: %s", fname)

if labels is not False:
# Cast any sequence (incl. a numpy array) to a list, to simplify the processing below.
labels = list(labels)
offsets = []
with utils.smart_open(fname, 'wb') as fout:
for docno, doc in enumerate(corpus):
Expand Down
13 changes: 12 additions & 1 deletion gensim/test/test_corpora.py
Expand Up @@ -23,7 +23,7 @@
ucicorpus, malletcorpus, textcorpus, indexedcorpus, wikicorpus)
from gensim.interfaces import TransformedCorpus
from gensim.utils import to_unicode
from gensim.test.utils import datapath, get_tmpfile
from gensim.test.utils import datapath, get_tmpfile, common_corpus


class DummyTransformer(object):
Expand Down Expand Up @@ -382,6 +382,17 @@ def setUp(self):
self.corpus_class = svmlightcorpus.SvmLightCorpus
self.file_extension = '.svmlight'

def test_serialization(self):
path = get_tmpfile("svml.corpus")
labels = [1] * len(common_corpus)
second_corpus = [(0, 1.0), (3, 1.0), (4, 1.0), (5, 1.0), (6, 1.0), (7, 1.0)]
self.corpus_class.serialize(path, common_corpus, labels=labels)
serialized_corpus = self.corpus_class(path)
self.assertEqual(serialized_corpus[1], second_corpus)
self.corpus_class.serialize(path, common_corpus, labels=np.array(labels))
serialized_corpus = self.corpus_class(path)
self.assertEqual(serialized_corpus[1], second_corpus)


class TestBleiCorpus(CorpusTestCase):
def setUp(self):
Expand Down