Skip to content

Commit

Permalink
Merge pull request #1296 from spcl/attribute-replacements-new-state
Browse files Browse the repository at this point in the history
New State for Attribute Replacements
  • Loading branch information
alexnick83 committed Jul 4, 2023
2 parents 3e83820 + 1ee3764 commit 484630a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
7 changes: 6 additions & 1 deletion dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4612,7 +4612,12 @@ def visit_Attribute(self, node: ast.Attribute):
# Try to find sub-SDFG attribute
func = oprepo.Replacements.get_attribute(type(arr), node.attr)
if func is not None:
return func(self, self.sdfg, self.last_state, result)
# A new state is likely needed here, e.g., for transposition (ndarray.T)
self._add_state('%s_%d' % (type(node).__name__, node.lineno))
self.last_state.set_default_lineinfo(self.current_lineinfo)
result = func(self, self.sdfg, self.last_state, result)
self.last_state.set_default_lineinfo(None)
return result

# Otherwise, try to find compile-time attribute (such as shape)
try:
Expand Down
22 changes: 22 additions & 0 deletions tests/numpy/attribute_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,28 @@ def test_attribute_in_ranged_loop_symbolic():
assert np.allclose(a, regression)


def test_attribute_new_state():

N, F_in, F_out, heads = 2, 3, 4, 5

@dace.program
def fn(a: dace.float64[N, F_in], b: dace.float64[N, heads, F_out], c: dace.float64[heads * F_out, F_in]):
tmp = a.T @ np.reshape(b, (N, heads * F_out))
c[:] = tmp.T

rng = np.random.default_rng(42)

a = rng.random((N, F_in))
b = rng.random((N, heads, F_out))
c_expected = np.zeros((heads * F_out, F_in))
c = np.zeros((heads * F_out, F_in))

fn.f(a, b, c_expected)
fn(a, b, c)
assert np.allclose(c, c_expected)


if __name__ == '__main__':
test_attribute_in_ranged_loop()
test_attribute_in_ranged_loop_symbolic()
test_attribute_new_state()

0 comments on commit 484630a

Please sign in to comment.