Skip to content

Commit

Permalink
Sparse2Corpus: update __getitem__ to work on slices, ellipsis, and it…
Browse files Browse the repository at this point in the history
…erable
  • Loading branch information
PrimozGodec committed Oct 28, 2021
1 parent 3d72896 commit 23cf0c4
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 10 deletions.
27 changes: 17 additions & 10 deletions gensim/matutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import logging
import math
from typing import Union, Tuple, List, Iterable

from gensim import utils

Expand Down Expand Up @@ -597,23 +598,29 @@ def __iter__(self):
def __len__(self):
return self.sparse.shape[1]

def __getitem__(self, document_index):
"""Retrieve a document vector from the corpus by its index.
def __getitem__(self, key):
"""
Retrieve a document vector or subset from the corpus by key.
Parameters
----------
document_index : int
Index of document
key: int, ellipsis, slice, iterable object
Index of document or slice, ellipsis, iterable with subset of documents
Returns
-------
list of (int, number)
Document in BoW format.
list of (int, number), Sparse2Corpus
Document in BoW format when key is int or Sparse2Corpus with corpus subset
if key demand subset (not single document)
"""
indprev = self.sparse.indptr[document_index]
indnow = self.sparse.indptr[document_index + 1]
return list(zip(self.sparse.indices[indprev:indnow], self.sparse.data[indprev:indnow]))
sparse = self.sparse
if isinstance(key, int):
iprev = self.sparse.indptr[key]
inow = self.sparse.indptr[key + 1]
return list(zip(sparse.indices[iprev:inow], sparse.data[iprev:inow]))

sparse = self.sparse.__getitem__((slice(None, None, None), key))
return Sparse2Corpus(sparse)


def veclen(vec):
Expand Down
40 changes: 40 additions & 0 deletions gensim/test/test_matutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import logging
import unittest
import numpy as np
from numpy.testing import assert_array_equal
from scipy import sparse
from scipy.sparse import csc_matrix
from scipy.special import psi # gamma function utils

import gensim.matutils as matutils
Expand Down Expand Up @@ -266,6 +268,44 @@ def test_return_norm_zero_vector_gensim_sparse(self):
self.assertEqual(norm, 1.0)


class TestSparse2Corpus(unittest.TestCase):
def setUp(self) -> None:
self.orig_array = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
self.s2c = matutils.Sparse2Corpus(csc_matrix(self.orig_array))

def test_getitem_slice(self):
assert_array_equal(self.s2c[:2].sparse.toarray(), self.orig_array[:, :2])
assert_array_equal(self.s2c[1:3].sparse.toarray(), self.orig_array[:, 1:3])

def test_getitem_index(self):
self.assertListEqual(self.s2c[1], [(0, 2), (1, 5), (2, 8)])

def test_getitem_list_of_indices(self):
assert_array_equal(
self.s2c[[1, 2]].sparse.toarray(), self.orig_array[:, [1, 2]]
)
assert_array_equal(self.s2c[[1]].sparse.toarray(), self.orig_array[:, [1]])

def test_getitem_ndarray(self):
assert_array_equal(
self.s2c[np.array([1, 2])].sparse.toarray(), self.orig_array[:, [1, 2]]
)
assert_array_equal(
self.s2c[np.array([1])].sparse.toarray(), self.orig_array[:, [1]]
)

def test_getitem_range(self):
assert_array_equal(
self.s2c[range(1, 3)].sparse.toarray(), self.orig_array[:, [1, 2]]
)
assert_array_equal(
self.s2c[range(1, 2)].sparse.toarray(), self.orig_array[:, [1]]
)

def test_getitem_ellipsis(self):
assert_array_equal(self.s2c[...].sparse.toarray(), self.orig_array)


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
unittest.main()

0 comments on commit 23cf0c4

Please sign in to comment.