Skip to content

Commit

Permalink
[Bug] [lang] Fix AST not being transformed inside ti.ndrange (#2187)
Browse files Browse the repository at this point in the history
  • Loading branch information
archibate committed Feb 17, 2021
1 parent 58feee3 commit 3bc0341
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 9 deletions.
19 changes: 10 additions & 9 deletions python/taichi/lang/transformer.py
Expand Up @@ -421,30 +421,31 @@ def visit_ndrange_for(self, node):
# for i, j in ti.ndrange(n)
template = f'''
if ti.static(1):
__ndrange = ti.static(0)
__ndrange{id(node)} = 0
for __ndrange_I{id(node)} in range(0):
__I = __ndrange_I{id(node)}
'''
t = ast.parse(template).body[0]
t.body[0].value.args[0] = node.iter
t.body[0].value = node.iter
t_loop = t.body[1]
t_loop.iter.args[0] = self.parse_expr('__ndrange.acc_dimensions[0]')
t_loop.iter.args[0] = self.parse_expr(
f'__ndrange{id(node)}.acc_dimensions[0]')
targets = self.get_targets(node)
targets_tmp = ['__' + name for name in targets]
loop_body = t_loop.body
for i in range(len(targets)):
if i + 1 < len(targets):
stmt = '{} = __I // __ndrange.acc_dimensions[{}]'.format(
targets_tmp[i], i + 1)
stmt = '{} = __I // __ndrange{}.acc_dimensions[{}]'.format(
targets_tmp[i], id(node), i + 1)
else:
stmt = '{} = __I'.format(targets_tmp[i])
loop_body.append(self.parse_stmt(stmt))
stmt = '{} = {} + __ndrange.bounds[{}][0]'.format(
targets[i], targets_tmp[i], i)
stmt = '{} = {} + __ndrange{}.bounds[{}][0]'.format(
targets[i], targets_tmp[i], id(node), i)
loop_body.append(self.parse_stmt(stmt))
if i + 1 < len(targets):
stmt = '__I = __I - {} * __ndrange.acc_dimensions[{}]'.format(
targets_tmp[i], i + 1)
stmt = '__I = __I - {} * __ndrange{}.acc_dimensions[{}]'.format(
targets_tmp[i], id(node), i + 1)
loop_body.append(self.parse_stmt(stmt))
loop_body += node.body

Expand Down
30 changes: 30 additions & 0 deletions tests/python/test_ndrange.py
Expand Up @@ -160,3 +160,33 @@ def init():
for l in range(n):
r = i * n**3 + j * n**2 + k * n + l
assert A[i, j, k, l] == r


@ti.test(ti.cpu)
def test_ndrange_ast_transform():
n, u, v = 4, 3, 2

a = ti.field(ti.i32, ())
b = ti.field(ti.i32, ())
A = ti.field(ti.i32, (n, n))

@ti.kernel
def func():
# `__getitem__ cannot be called from Python-scope` will be raised if
# `a[None]` is not transformed to `ti.subscript(a, None)` in ti.ndrange:
for i, j in ti.ndrange(a[None], b[None]):
r = i * n + j + 1
A[i, j] = r

a[None] = u
b[None] = v

func()

for i in range(n):
for j in range(n):
if i < u and j < v:
r = i * n + j + 1
else:
r = 0
assert A[i, j] == r

0 comments on commit 3bc0341

Please sign in to comment.