New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Modified behavior of __getitem__ for arrays #17226
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,8 @@ | |
from sympy.core.sympify import _sympify | ||
from sympy.tensor.array.mutable_ndim_array import MutableNDimArray | ||
from sympy.tensor.array.ndim_array import NDimArray, ImmutableNDimArray | ||
from sympy.core.numbers import Integer | ||
from sympy.core.compatibility import SYMPY_INTS | ||
|
||
import functools | ||
|
||
|
@@ -28,9 +30,9 @@ def __getitem__(self, index): | |
>>> a[1, 1] | ||
3 | ||
>>> a[0] | ||
0 | ||
>>> a[2] | ||
2 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Print the new output here. |
||
[0, 1] | ||
>>> a[1] | ||
[2, 3] | ||
|
||
Symbolic indexing: | ||
|
||
|
@@ -48,6 +50,12 @@ def __getitem__(self, index): | |
if syindex is not None: | ||
return syindex | ||
|
||
if isinstance(index, (SYMPY_INTS, Integer)): | ||
index = (index, ) | ||
if not isinstance(index, slice) and len(index) < self.rank(): | ||
index = tuple([i for i in index] + \ | ||
[slice(None) for i in range(len(index), self.rank())]) | ||
|
||
# `index` is a tuple with one or more slices: | ||
if isinstance(index, tuple) and any([isinstance(i, slice) for i in index]): | ||
sl_factors, eindices = self._get_slice_data_for_array_access(index) | ||
|
@@ -101,7 +109,7 @@ def tomatrix(self): | |
def __iter__(self): | ||
def iterator(): | ||
for i in range(self._loop_size): | ||
yield self[i] | ||
yield self[self._get_tuple_index(i)] | ||
return iterator() | ||
|
||
def reshape(self, *newshape): | ||
|
@@ -119,7 +127,7 @@ def __new__(cls, iterable=None, shape=None, **kwargs): | |
shape, flat_list = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs) | ||
shape = Tuple(*map(_sympify, shape)) | ||
cls._check_special_bounds(flat_list, shape) | ||
loop_size = functools.reduce(lambda x,y: x*y, shape) if shape else 0 | ||
loop_size = functools.reduce(lambda x,y: x*y, shape) if shape else len(flat_list) | ||
|
||
# Sparse array: | ||
if isinstance(flat_list, (dict, Dict)): | ||
|
@@ -156,7 +164,7 @@ def __new__(cls, iterable=None, shape=None, **kwargs): | |
self = object.__new__(cls) | ||
self._shape = shape | ||
self._rank = len(shape) | ||
self._loop_size = functools.reduce(lambda x,y: x*y, shape) if shape else 0 | ||
self._loop_size = functools.reduce(lambda x,y: x*y, shape) if shape else len(flat_list) | ||
|
||
# Sparse array: | ||
if isinstance(flat_list, (dict, Dict)): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -108,6 +108,23 @@ def test_iterator(): | |
j += 1 | ||
|
||
|
||
def test_getitem(): | ||
for ArrayType in [MutableDenseNDimArray, MutableSparseNDimArray]: | ||
array = ArrayType(range(24)).reshape(2, 3, 4) | ||
assert array.tolist() == [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] | ||
assert array[0] == ArrayType([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]) | ||
assert array[0, 0] == ArrayType([0, 1, 2, 3]) | ||
value = 0 | ||
for i in range(2): | ||
for j in range(3): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe add some subtests? |
||
for k in range(4): | ||
assert array[i, j, k] == value | ||
value += 1 | ||
|
||
raises(ValueError, lambda: array[3, 4, 5]) | ||
raises(ValueError, lambda: array[3, 4, 5, 6]) | ||
|
||
|
||
def test_sparse(): | ||
sparse_array = MutableSparseNDimArray([0, 0, 0, 1], (2, 2)) | ||
assert len(sparse_array) == 2 * 2 | ||
|
@@ -184,7 +201,7 @@ def test_ndim_array_converting(): | |
assert (isinstance(matrix, Matrix)) | ||
|
||
for i in range(len(dense_array)): | ||
assert dense_array[i] == matrix[i] | ||
assert dense_array[dense_array._get_tuple_index(i)] == matrix[i] | ||
assert matrix.shape == dense_array.shape | ||
|
||
assert MutableDenseNDimArray(matrix) == dense_array | ||
|
@@ -200,7 +217,7 @@ def test_ndim_array_converting(): | |
assert(isinstance(matrix, SparseMatrix)) | ||
|
||
for i in range(len(sparse_array)): | ||
assert sparse_array[i] == matrix[i] | ||
assert sparse_array[sparse_array._get_tuple_index(i)] == matrix[i] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe we need new tests for this feature. |
||
assert matrix.shape == sparse_array.shape | ||
|
||
assert MutableSparseNDimArray(matrix) == sparse_array | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is because in the old way, an Array with one item will have a
_loop_size
equal to 0. I have opened an issue about it: #17230. Maybe there is a reason for keeping the_loop_size
as 0? If so, I would be happy to know it. :-)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If
shape == ()
, maybe loop size should be 1? (i.e. it's a scalar)