Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handling for iterables without 0-th element, fixes #2556 #2629

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion gensim/sklearn_api/d2vmodel.py
Expand Up @@ -159,7 +159,7 @@ def fit(self, X, y=None):
The trained model.

"""
if isinstance(X[0], doc2vec.TaggedDocument):
if isinstance([i for i in X[:1]][0], doc2vec.TaggedDocument):
d2v_sentences = X
else:
d2v_sentences = [doc2vec.TaggedDocument(words, [i]) for i, words in enumerate(X)]
Expand Down
57 changes: 57 additions & 0 deletions gensim/test/test_d2vmodel.py
@@ -0,0 +1,57 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2010 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Automated tests for checking D2VTransformer class.
"""

import unittest
import logging
from gensim.sklearn_api import D2VTransformer
from gensim.test.utils import common_texts


class IteratorForIterable:
"""Iterator capable of folding into list."""
def __init__(self, iterable):
self._data = iterable
self._index = 0

def __next__(self):
if len(self._data) > self._index:
result = self._data[self._index]
self._index += 1
return result
raise StopIteration


class IterableWithoutZeroElement:
"""
Iterable, emulating pandas.Series behaviour without 0-th element.
Equivalent to calling `series.index += 1`.
"""
def __init__(self, data):
self.data = data

def __getitem__(self, key):
if key == 0:
raise KeyError("Emulation of absence of item with key 0.")
return self.data[key]

def __iter__(self):
return IteratorForIterable(self.data)


class TestD2VTransformer(unittest.TestCase):
def TestWorksWithIterableNotHavingElementWithZeroIndex(self):
a = IterableWithoutZeroElement(common_texts)
transformer = D2VTransformer(min_count=1, size=5)
transformer.fit(a)


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
unittest.main()