Skip to content

Commit

Permalink
Sparse2Corpus: update __getitem__ to work on slices, ellipsis, and it…
Browse files Browse the repository at this point in the history
…erable
  • Loading branch information
PrimozGodec committed Oct 27, 2021
1 parent 3d72896 commit 2279428
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 10 deletions.
28 changes: 18 additions & 10 deletions gensim/matutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import logging
import math
import numbers
from typing import Union, Tuple, List, Iterable

from gensim import utils

Expand Down Expand Up @@ -597,23 +598,30 @@ def __iter__(self):
def __len__(self):
    """Return the size of the wrapped sparse matrix along its second axis."""
    second_axis_size = self.sparse.shape[1]
    return second_axis_size

def __getitem__(
    self, key: Union[int, type(...), slice, Iterable]
) -> Union["Sparse2Corpus", List[Tuple[int, int]]]:
    """Retrieve a single document or a sub-corpus, depending on the type of `key`.

    Parameters
    ----------
    key : int, slice, Ellipsis or iterable of int
        An integer (built-in or NumPy scalar) selects one document; negative
        values count from the end. Any other key is handed to the underlying
        sparse matrix as its column selector and picks a subset of documents.

    Returns
    -------
    list of (int, number)
        The document in BoW format, when `key` is an integer.
    :class:`Sparse2Corpus`
        A new corpus wrapping the selected columns, for every other kind of key.

    Raises
    ------
    IndexError
        If an integer `key` lies outside ``[-len(self), len(self))``.

    """
    sparse = self.sparse

    # numbers.Integral also matches NumPy integer scalars, which a plain
    # `isinstance(key, int)` check would misroute to the sub-corpus branch.
    if isinstance(key, numbers.Integral):
        num_docs = sparse.shape[1]
        doc_index = int(key)
        if doc_index < 0:
            # Indexing `indptr` with a raw negative value silently yields a wrong
            # (usually empty) slice, so translate it to the equivalent offset first.
            doc_index += num_docs
        if not 0 <= doc_index < num_docs:
            raise IndexError(
                "document index %s is out of range for a corpus of %d documents" % (key, num_docs)
            )
        iprev = sparse.indptr[doc_index]
        inow = sparse.indptr[doc_index + 1]
        return list(zip(sparse.indices[iprev:inow], sparse.data[iprev:inow]))

    # Everything else selects columns: keep all rows, let the matrix interpret `key`.
    return Sparse2Corpus(sparse[:, key])


def veclen(vec):
Expand Down
40 changes: 40 additions & 0 deletions gensim/test/test_matutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import logging
import unittest
import numpy as np
from numpy.testing import assert_array_equal
from scipy import sparse
from scipy.sparse import csc_matrix
from scipy.special import psi # gamma function utils

import gensim.matutils as matutils
Expand Down Expand Up @@ -266,6 +268,44 @@ def test_return_norm_zero_vector_gensim_sparse(self):
self.assertEqual(norm, 1.0)


class TestSparse2Corpus(unittest.TestCase):
    """Exercise ``Sparse2Corpus.__getitem__`` with each kind of key it accepts.

    The expectations treat every column of the wrapped matrix as one document,
    so a subset of documents is compared with the same column selection taken
    from the dense source array.
    """

    def setUp(self) -> None:
        # 3x3 fixture with distinct values, so a row/column mix-up cannot go unnoticed.
        self.orig_array = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        self.s2c = matutils.Sparse2Corpus(csc_matrix(self.orig_array))

    def _check_subset(self, key, expected):
        """Assert that ``self.s2c[key]`` wraps a matrix equal to `expected`."""
        subset = self.s2c[key]
        assert_array_equal(subset.sparse.toarray(), expected)

    def test_slice(self):
        self._check_subset(slice(None, 2), self.orig_array[:, :2])
        self._check_subset(slice(1, 3), self.orig_array[:, 1:3])

    def test_index(self):
        # An integer key yields one document as (row index, value) pairs.
        document = self.s2c[1]
        self.assertListEqual(document, [(0, 2), (1, 5), (2, 8)])

    def test_list_of_indices(self):
        self._check_subset([1, 2], self.orig_array[:, [1, 2]])
        self._check_subset([1], self.orig_array[:, [1]])

    def test_ndarray(self):
        self._check_subset(np.array([1, 2]), self.orig_array[:, [1, 2]])
        self._check_subset(np.array([1]), self.orig_array[:, [1]])

    def test_range(self):
        self._check_subset(range(1, 3), self.orig_array[:, [1, 2]])
        self._check_subset(range(1, 2), self.orig_array[:, [1]])

    def test_elipsis(self):
        # Ellipsis selects every document, i.e. the whole matrix.
        self._check_subset(..., self.orig_array)


if __name__ == '__main__':
    # Run directly: enable verbose logging, then hand control to the unittest runner.
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s : %(levelname)s : %(message)s',
    )
    unittest.main()

0 comments on commit 2279428

Please sign in to comment.