From 7ed0a025bfb1e5a2605d6a621c077db82e2f4f31 Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Wed, 10 Aug 2022 16:45:23 +0800 Subject: [PATCH] [bug] Support indexing via np.integer --- python/taichi/lang/field.py | 3 ++- tests/python/test_field.py | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/python/taichi/lang/field.py b/python/taichi/lang/field.py index e36d88f43aafd7..cfea4d10069f4c 100644 --- a/python/taichi/lang/field.py +++ b/python/taichi/lang/field.py @@ -352,8 +352,9 @@ def __getitem__(self, key): # Check for potential slicing behaviour # for instance: x[0, :] padded_key = self._pad_key(key) + import numpy as np # pylint: disable=C0415 for key in padded_key: - if not isinstance(key, int): + if not isinstance(key, (int, np.integer)): raise TypeError( f"Detected illegal element of type: {type(key)}. " f"Please be aware that slicing a ti.field is not supported so far." diff --git a/tests/python/test_field.py b/tests/python/test_field.py index 97f3db484e1276..e22cdc8b546c42 100644 --- a/tests/python/test_field.py +++ b/tests/python/test_field.py @@ -2,6 +2,7 @@ To test our new `ti.field` API is functional (#1500) ''' +import numpy as np import pytest from taichi.lang import impl from taichi.lang.misc import get_host_arch_list @@ -282,6 +283,13 @@ def test_invalid_slicing(): val[0, :] +@test_utils.test() +def test_slicing_with_np_int(): + val = ti.field(ti.i32, shape=(2)) + idx = np.int32(0) + val[idx] + + @test_utils.test(exclude=[ti.cc], debug=True) def test_field_fill(): x = ti.field(int, shape=(3, 3))