diff --git a/python/taichi/lang/field.py b/python/taichi/lang/field.py index e36d88f43aafd..cfea4d10069f4 100644 --- a/python/taichi/lang/field.py +++ b/python/taichi/lang/field.py @@ -352,8 +352,9 @@ def __getitem__(self, key): # Check for potential slicing behaviour # for instance: x[0, :] padded_key = self._pad_key(key) + import numpy as np # pylint: disable=C0415 for key in padded_key: - if not isinstance(key, int): + if not isinstance(key, (int, np.integer)): raise TypeError( f"Detected illegal element of type: {type(key)}. " f"Please be aware that slicing a ti.field is not supported so far." diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 5fb7c98534179..dba878e00c840 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -125,7 +125,7 @@ def _linearize_entry_id(self, *args): args = args + (0, ) # TODO(#1004): See if it's possible to support indexing at runtime for i, a in enumerate(args): - if not isinstance(a, int): + if not isinstance(a, (int, np.integer)): raise TaichiSyntaxError( f'The {i}-th index of a Matrix/Vector must be a compile-time constant ' f'integer, got {type(a)}.\n' diff --git a/tests/python/test_field.py b/tests/python/test_field.py index 97f3db484e127..b7af5e70c21b4 100644 --- a/tests/python/test_field.py +++ b/tests/python/test_field.py @@ -2,6 +2,7 @@ To test our new `ti.field` API is functional (#1500) ''' +import numpy as np import pytest from taichi.lang import impl from taichi.lang.misc import get_host_arch_list @@ -282,6 +283,27 @@ def test_invalid_slicing(): val[0, :] +@test_utils.test() +def test_indexing_with_np_int(): + val = ti.field(ti.i32, shape=(2)) + idx = np.int32(0) + val[idx] + + +@test_utils.test() +def test_indexing_vec_field_with_np_int(): + val = ti.Vector.field(2, ti.i32, shape=(2)) + idx = np.int32(0) + val[idx][idx] + + +@test_utils.test() +def test_indexing_mat_field_with_np_int(): + val = ti.Matrix.field(2, 2, ti.i32, shape=(2)) + idx = np.int32(0) + val[idx][idx, idx] + + @test_utils.test(exclude=[ti.cc], debug=True) def test_field_fill(): x = ti.field(int, shape=(3, 3))