Skip to content

Commit

Permalink
OTFMapFusion: Bugfix for tasklets with None connectors (#1415)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukastruemper committed Dec 2, 2023
1 parent a7b3a75 commit f90cf5b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
9 changes: 3 additions & 6 deletions dace/transformation/dataflow/otf_map_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,10 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
else:
out_connector = edge.src_conn

if out_connector not in second_map_entry.out_connectors:
second_map_entry.add_out_connector(out_connector)
if out_connector not in self.second_map_entry.out_connectors:
self.second_map_entry.add_out_connector(out_connector)
else:
out_connector = edge.src_conn

if out_connector not in self.second_map_entry.out_connectors:
self.second_map_entry.add_out_connector(out_connector)
out_connector = None

graph.add_edge(second_map_entry, out_connector, node, edge.dst_conn, memlet)
graph.remove_edge(edge)
Expand Down
31 changes: 31 additions & 0 deletions tests/transformations/otf_map_fusion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,36 @@ def test_trivial_fusion_nested_sdfg():
assert (res == res_fused).all()


@dace.program
def trivial_fusion_none_connectors(B: dace.float64[10, 20]):
tmp = dace.define_local([10, 20], dtype=B.dtype)
for i, j in dace.map[0:10, 0:20]:
with dace.tasklet:
b >> tmp[i, j]
b = 0

for i, j in dace.map[0:10, 0:20]:
with dace.tasklet:
a << tmp[i, j]
b >> B[i, j]
b = a + 2


def test_trivial_fusion_none_connectors():
sdfg = trivial_fusion_none_connectors.to_sdfg()
sdfg.simplify()
assert count_maps(sdfg) == 2

sdfg.apply_transformations(OTFMapFusion)
assert count_maps(sdfg) == 1

B = np.zeros((10, 20))
ref = np.zeros((10, 20)) + 2

sdfg(B=B)
assert np.allclose(B, ref)


@dace.program
def undefined_subset(A: dace.float64[10], B: dace.float64[10]):
tmp = dace.define_local([10], dtype=A.dtype)
Expand Down Expand Up @@ -703,6 +733,7 @@ def test_hdiff():
test_trivial_fusion_permute()
test_trivial_fusion_not_remove_map()
test_trivial_fusion_nested_sdfg()
test_trivial_fusion_none_connectors()

# Defined subsets
test_undefined_subset()
Expand Down

0 comments on commit f90cf5b

Please sign in to comment.