From 3bc0341e4f3e84b6c133e2bb0942d3a5ab8ba3dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BD=AD=E4=BA=8E=E6=96=8C?= <1931127624@qq.com> Date: Wed, 17 Feb 2021 14:55:49 +0800 Subject: [PATCH] [Bug] [lang] Fix AST not being transformed inside ti.ndrange (#2187) --- python/taichi/lang/transformer.py | 19 ++++++++++--------- tests/python/test_ndrange.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index 8d27e5cc52963..e3f49853d109d 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -421,30 +421,31 @@ def visit_ndrange_for(self, node): # for i, j in ti.ndrange(n) template = f''' if ti.static(1): - __ndrange = ti.static(0) + __ndrange{id(node)} = 0 for __ndrange_I{id(node)} in range(0): __I = __ndrange_I{id(node)} ''' t = ast.parse(template).body[0] - t.body[0].value.args[0] = node.iter + t.body[0].value = node.iter t_loop = t.body[1] - t_loop.iter.args[0] = self.parse_expr('__ndrange.acc_dimensions[0]') + t_loop.iter.args[0] = self.parse_expr( + f'__ndrange{id(node)}.acc_dimensions[0]') targets = self.get_targets(node) targets_tmp = ['__' + name for name in targets] loop_body = t_loop.body for i in range(len(targets)): if i + 1 < len(targets): - stmt = '{} = __I // __ndrange.acc_dimensions[{}]'.format( - targets_tmp[i], i + 1) + stmt = '{} = __I // __ndrange{}.acc_dimensions[{}]'.format( + targets_tmp[i], id(node), i + 1) else: stmt = '{} = __I'.format(targets_tmp[i]) loop_body.append(self.parse_stmt(stmt)) - stmt = '{} = {} + __ndrange.bounds[{}][0]'.format( - targets[i], targets_tmp[i], i) + stmt = '{} = {} + __ndrange{}.bounds[{}][0]'.format( + targets[i], targets_tmp[i], id(node), i) loop_body.append(self.parse_stmt(stmt)) if i + 1 < len(targets): - stmt = '__I = __I - {} * __ndrange.acc_dimensions[{}]'.format( - targets_tmp[i], i + 1) + stmt = '__I = __I - {} * __ndrange{}.acc_dimensions[{}]'.format( + targets_tmp[i], id(node), i + 1) loop_body.append(self.parse_stmt(stmt)) loop_body += node.body diff --git a/tests/python/test_ndrange.py b/tests/python/test_ndrange.py index d108cc3557e99..cc32561bc6765 100644 --- a/tests/python/test_ndrange.py +++ b/tests/python/test_ndrange.py @@ -160,3 +160,33 @@ def init(): for l in range(n): r = i * n**3 + j * n**2 + k * n + l assert A[i, j, k, l] == r + + +@ti.test(ti.cpu) +def test_ndrange_ast_transform(): + n, u, v = 4, 3, 2 + + a = ti.field(ti.i32, ()) + b = ti.field(ti.i32, ()) + A = ti.field(ti.i32, (n, n)) + + @ti.kernel + def func(): + # `__getitem__ cannot be called from Python-scope` will be raised if + # `a[None]` is not transformed to `ti.subscript(a, None)` in ti.ndrange: + for i, j in ti.ndrange(a[None], b[None]): + r = i * n + j + 1 + A[i, j] = r + + a[None] = u + b[None] = v + + func() + + for i in range(n): + for j in range(n): + if i < u and j < v: + r = i * n + j + 1 + else: + r = 0 + assert A[i, j] == r