Skip to content

Commit

Permalink
Merge pull request #1287 from spcl/fix-writes-to-nested-definitions
Browse files Browse the repository at this point in the history
Fix writes to nested definitions
  • Loading branch information
acalotoiu committed Jun 28, 2023
2 parents dcc284d + 3e5eff4 commit e0dbdf6
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 11 deletions.
4 changes: 2 additions & 2 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3067,11 +3067,11 @@ def _add_write_access(self,
arr_type: data.Data = None):

if name in self.sdfg.arrays:
return (name, None)
return (name, rng)
if (name, rng, 'w') in self.accesses:
return self.accesses[(name, rng, 'w')]
elif name in self.variables:
return (self.variables[name], None)
return (self.variables[name], rng)
elif (name, rng, 'r') in self.accesses or name in self.scope_vars:
return self._add_access(name, rng, 'w', target, new_name, arr_type)
else:
Expand Down
85 changes: 76 additions & 9 deletions tests/python_frontend/nested_name_accesses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_nested_name_accesses():
diff_norm = np.linalg.norm(dc_out - np_out)
ref_norm = np.linalg.norm(np_out)
rel_err = diff_norm / ref_norm
assert (rel_err < 1e-7)
assert rel_err < 1e-7


def test_nested_offset_access():
Expand All @@ -42,7 +42,7 @@ def nested_offset_access(inp: dc.float64[6, 5, 5]):
inp = np.reshape(np.arange(6 * 5 * 5, dtype=np.float64), (6, 5, 5)).copy()
out = nested_offset_access(inp)
ref = nested_offset_access.f(inp)
assert (np.allclose(out, ref))
assert np.allclose(out, ref)


def test_nested_offset_access_dappy():
Expand All @@ -62,7 +62,7 @@ def nested_offset_access(inp: dc.float64[6, 5, 5]):
inp = np.reshape(np.arange(6 * 5 * 5, dtype=np.float64), (6, 5, 5)).copy()
out = nested_offset_access(inp)
ref = nested_offset_access.f(inp)
assert (np.allclose(out, ref))
assert np.allclose(out, ref)


def test_nested_multi_offset_access():
Expand All @@ -79,7 +79,7 @@ def nested_offset_access(inp: dc.float64[6, 5, 10]):
inp = np.reshape(np.arange(6 * 5 * 10, dtype=np.float64), (6, 5, 10)).copy()
out = nested_offset_access(inp)
ref = nested_offset_access.f(inp)
assert (np.allclose(out, ref))
assert np.allclose(out, ref)


def test_nested_multi_offset_access_dappy():
Expand All @@ -100,7 +100,7 @@ def nested_offset_access(inp: dc.float64[6, 5, 10]):
inp = np.reshape(np.arange(6 * 5 * 10, dtype=np.float64), (6, 5, 10)).copy()
out = nested_offset_access(inp)
ref = nested_offset_access.f(inp)
assert (np.allclose(out, ref))
assert np.allclose(out, ref)


def test_nested_dec_offset_access():
Expand All @@ -116,7 +116,7 @@ def nested_offset_access(inp: dc.float64[6, 5, 5]):
inp = np.reshape(np.arange(6 * 5 * 5, dtype=np.float64), (6, 5, 5)).copy()
out = nested_offset_access(inp)
ref = nested_offset_access.f(inp)
assert (np.allclose(out, ref))
assert np.allclose(out, ref)


def test_nested_dec_offset_access_dappy():
Expand All @@ -136,7 +136,7 @@ def nested_offset_access(inp: dc.float64[6, 5, 5]):
inp = np.reshape(np.arange(6 * 5 * 5, dtype=np.float64), (6, 5, 5)).copy()
out = nested_offset_access(inp)
ref = nested_offset_access.f(inp)
assert (np.allclose(out, ref))
assert np.allclose(out, ref)


def test_nested_offset_access_nested_dependency():
Expand All @@ -157,7 +157,7 @@ def nested_offset_access_nested_dep(inp: dc.float64[6, 5, 5]):
out = nested_offset_access_nested_dep(inp)
os.environ['DACE_testing_serialization'] = last_value
ref = nested_offset_access_nested_dep.f(inp)
assert (np.allclose(out, ref))
assert np.allclose(out, ref)


def test_nested_offset_access_nested_dependency_dappy():
Expand All @@ -178,9 +178,74 @@ def nested_offset_access_nested_dep(inp: dc.float64[6, 5, 10]):
inp = np.reshape(np.arange(6 * 5 * 10, dtype=np.float64), (6, 5, 10)).copy()
out = nested_offset_access_nested_dep(inp)
ref = nested_offset_access_nested_dep.f(inp)
assert (np.allclose(out, ref))
assert np.allclose(out, ref)


def test_access_to_nested_transient():

KLEV = 3
KLON = 4
NBLOCKS = 5

@dc.program
def small_wip(inp: dc.float64[KLEV+1, KLON, NBLOCKS], out: dc.float64[KLEV, KLON, NBLOCKS]):
for jn in dc.map[0:NBLOCKS]:
tmp = np.zeros([KLEV+1, KLON])
for jl in range(KLON):
for jk in range(KLEV):
tmp[jk, jl] = inp[jk, jl, jn] + inp[jk+1, jl, jn]

for jl in range(KLON):
for jk in range(KLEV):
out[jk, jl, jn] = tmp[jk, jl] + tmp[jk+1, jl]

rng = np.random.default_rng(42)
inp = rng.random((KLEV+1, KLON, NBLOCKS))
ref = np.zeros((KLEV, KLON, NBLOCKS))
val = np.zeros((KLEV, KLON, NBLOCKS))

small_wip(inp, val)
small_wip.f(inp, ref)

assert np.allclose(val, ref)


def test_access_to_nested_transient_dappy():

KLEV = 3
KLON = 4
NBLOCKS = 5

@dc.program
def small_wip_dappy(inp: dc.float64[KLEV+1, KLON, NBLOCKS], out: dc.float64[KLEV, KLON, NBLOCKS]):
for jn in dc.map[0:NBLOCKS]:
tmp = np.zeros([KLEV+1, KLON])
for jl in range(KLON):
for jk in range(KLEV):
with dc.tasklet():
in1 << inp[jk, jl, jn]
in2 << inp[jk+1, jl, jn]
out1 >> tmp[jk, jl]
out1 = in1 + in2

for jl in range(KLON):
for jk in range(KLEV):
with dc.tasklet():
in1 << tmp[jk, jl]
in2 << tmp[jk+1, jl]
out1 >> out[jk, jl, jn]
out1 = in1 + in2

rng = np.random.default_rng(42)
inp = rng.random((KLEV+1, KLON, NBLOCKS))
ref = np.zeros((KLEV, KLON, NBLOCKS))
val = np.zeros((KLEV, KLON, NBLOCKS))

small_wip_dappy(inp, val)
small_wip_dappy.f(inp, ref)

assert np.allclose(val, ref)


if __name__ == "__main__":
test_nested_name_accesses()
Expand All @@ -192,3 +257,5 @@ def nested_offset_access_nested_dep(inp: dc.float64[6, 5, 10]):
test_nested_dec_offset_access_dappy()
test_nested_offset_access_nested_dependency()
test_nested_offset_access_nested_dependency_dappy()
test_access_to_nested_transient()
test_access_to_nested_transient_dappy()

0 comments on commit e0dbdf6

Please sign in to comment.