Skip to content

Commit

Permalink
[bug] Support indexing via np.integer for field
Browse files Browse the repository at this point in the history
  • Loading branch information
Ailing Zhang committed Aug 10, 2022
1 parent ca5bf7d commit fc440ff
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
3 changes: 2 additions & 1 deletion python/taichi/lang/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,9 @@ def __getitem__(self, key):
# Check for potential slicing behaviour
# for instance: x[0, :]
padded_key = self._pad_key(key)
import numpy as np # pylint: disable=C0415
for key in padded_key:
if not isinstance(key, int):
if not isinstance(key, (int, np.integer)):
raise TypeError(
f"Detected illegal element of type: {type(key)}. "
f"Please be aware that slicing a ti.field is not supported so far."
Expand Down
2 changes: 1 addition & 1 deletion python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _linearize_entry_id(self, *args):
args = args + (0, )
# TODO(#1004): See if it's possible to support indexing at runtime
for i, a in enumerate(args):
if not isinstance(a, int):
if not isinstance(a, (int, np.integer)):
raise TaichiSyntaxError(
f'The {i}-th index of a Matrix/Vector must be a compile-time constant '
f'integer, got {type(a)}.\n'
Expand Down
22 changes: 22 additions & 0 deletions tests/python/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
To test our new `ti.field` API is functional (#1500)
'''

import numpy as np
import pytest
from taichi.lang import impl
from taichi.lang.misc import get_host_arch_list
Expand Down Expand Up @@ -282,6 +283,27 @@ def test_invalid_slicing():
val[0, :]


@test_utils.test()
def test_indexing_with_np_int():
val = ti.field(ti.i32, shape=(2))
idx = np.int32(0)
val[idx]


@test_utils.test()
def test_indexing_vec_field_with_np_int():
val = ti.Vector.field(2, ti.i32, shape=(2))
idx = np.int32(0)
val[idx][idx]


@test_utils.test()
def test_indexing_mat_field_with_np_int():
val = ti.Matrix.field(2, 2, ti.i32, shape=(2))
idx = np.int32(0)
val[idx][idx, idx]


@test_utils.test(exclude=[ti.cc], debug=True)
def test_field_fill():
x = ti.field(int, shape=(3, 3))
Expand Down

0 comments on commit fc440ff

Please sign in to comment.