Skip to content

Commit

Permalink
Sparse2Corpus: update __getitem__ to work on slices, lists and ellips…
Browse files Browse the repository at this point in the history
…is (#3247)

* Sparse2Corpus: update __getitem__ to work on slices, ellipsis, and iterable

* Sparse2Corpus: update __getitem__ to work on slices, ellipsis, and iterable

* Update CHANGELOG.md

Co-authored-by: Michael Penkov <m@penkov.dev>
  • Loading branch information
PrimozGodec and mpenkov committed Dec 4, 2021
1 parent fa2d1b1 commit 2f182d7
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Changes
* [#3194](https://github.com/RaRe-Technologies/gensim/pull/3194): Added random_seed parameter to make LsiModel reproducible, by [@parashardhapola](https://github.com/parashardhapola)
* [#3251](https://github.com/RaRe-Technologies/gensim/pull/3251): Apply new convention of delimiting instance params in str function, by [@menshikh-iv](https://github.com/menshikh-iv)
* [#3227](https://github.com/RaRe-Technologies/gensim/pull/3227): Fix FastText doc-comment example for `build_vocab` and `train` to use correct argument names, by [@HLasse](https://github.com/HLasse)
* [#3247](https://github.com/RaRe-Technologies/gensim/pull/3247): Sparse2Corpus: update `__getitem__` to work on slices, lists and ellipsis, by [@PrimozGodec](https://github.com/PrimozGodec)
* [#3250](https://github.com/RaRe-Technologies/gensim/pull/3250): Make negative ns_exponent work correctly, by [@menshikh-iv](https://github.com/menshikh-iv)
* [#3258](https://github.com/RaRe-Technologies/gensim/pull/3258): Adding another check to _check_corpus_sanity for compressed files, adding test, by [@dchaplinsky](https://github.com/dchaplinsky)

Expand Down
27 changes: 17 additions & 10 deletions gensim/matutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,23 +597,30 @@ def __iter__(self):
def __len__(self):
    """Return the size of the second axis (the column count) of the wrapped sparse matrix."""
    num_columns = self.sparse.shape[1]
    return num_columns

def __getitem__(self, key):
    """Retrieve a document vector or a subset of the corpus by key.

    Parameters
    ----------
    key : int, ellipsis, slice, iterable object
        Index of the document to retrieve. Negative integers count from the end
        of the corpus, as with Python sequences.
        Less commonly, the key can also be a slice, ellipsis, or an iterable
        to retrieve multiple documents.

    Returns
    -------
    list of (int, number), Sparse2Corpus
        Document in BoW format when `key` is an integer. Otherwise :class:`~gensim.matutils.Sparse2Corpus`.

    Raises
    ------
    IndexError
        If `key` is an integer outside the range ``[-len(self), len(self))``.

    """
    # Imported locally so this method does not depend on the module's import block.
    import numbers

    sparse = self.sparse
    # numbers.Integral also matches NumPy integer scalars (e.g. np.int64), which a plain
    # `isinstance(key, int)` check misses; those must yield a single BoW document too.
    if isinstance(key, numbers.Integral):
        num_docs = sparse.shape[1]
        index = int(key)
        if index < 0:
            # Without this shift, indptr[key] / indptr[key + 1] would wrap around and
            # silently select the wrong column (or an empty one for key == -1).
            index += num_docs
        if not 0 <= index < num_docs:
            raise IndexError("document index %d is out of range for corpus of %d documents" % (key, num_docs))
        start, stop = sparse.indptr[index], sparse.indptr[index + 1]
        return list(zip(sparse.indices[start:stop], sparse.data[start:stop]))

    # Any other key selects columns (documents) and is wrapped back into a corpus.
    return Sparse2Corpus(sparse[:, key])


def veclen(vec):
Expand Down
40 changes: 40 additions & 0 deletions gensim/test/test_matutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import logging
import unittest
import numpy as np
from numpy.testing import assert_array_equal
from scipy import sparse
from scipy.sparse import csc_matrix
from scipy.special import psi # gamma function utils

import gensim.matutils as matutils
Expand Down Expand Up @@ -266,6 +268,44 @@ def test_return_norm_zero_vector_gensim_sparse(self):
self.assertEqual(norm, 1.0)


class TestSparse2Corpus(unittest.TestCase):
    """Check ``Sparse2Corpus.__getitem__`` for integer, slice, list, ndarray, range and ellipsis keys."""

    def setUp(self):
        # Documents are the matrix columns: document 1 is the column (2, 5, 8).
        self.orig_array = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        self.s2c = matutils.Sparse2Corpus(csc_matrix(self.orig_array))

    def _assert_subset(self, key, expected):
        """Assert that indexing the corpus with `key` yields a corpus whose dense matrix equals `expected`."""
        subset = self.s2c[key]
        assert_array_equal(subset.sparse.toarray(), expected)

    def test_getitem_slice(self):
        self._assert_subset(slice(None, 2), self.orig_array[:, :2])
        self._assert_subset(slice(1, 3), self.orig_array[:, 1:3])

    def test_getitem_index(self):
        expected_document = [(0, 2), (1, 5), (2, 8)]
        self.assertListEqual(self.s2c[1], expected_document)

    def test_getitem_list_of_indices(self):
        self._assert_subset([1, 2], self.orig_array[:, [1, 2]])
        self._assert_subset([1], self.orig_array[:, [1]])

    def test_getitem_ndarray(self):
        self._assert_subset(np.array([1, 2]), self.orig_array[:, [1, 2]])
        self._assert_subset(np.array([1]), self.orig_array[:, [1]])

    def test_getitem_range(self):
        self._assert_subset(range(1, 3), self.orig_array[:, [1, 2]])
        self._assert_subset(range(1, 2), self.orig_array[:, [1]])

    def test_getitem_ellipsis(self):
        self._assert_subset(..., self.orig_array)


if __name__ == '__main__':
    # Emit debug-level log records while the tests run as a script.
    logging.basicConfig(
        format='%(asctime)s : %(levelname)s : %(message)s',
        level=logging.DEBUG,
    )
    unittest.main()

0 comments on commit 2f182d7

Please sign in to comment.