Skip to content

Commit

Permalink
OTFMapFusion: Check for state dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
lukastruemper committed Nov 25, 2023
1 parent ff9e2e2 commit 87373ae
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 11 deletions.
46 changes: 41 additions & 5 deletions dace/transformation/dataflow/otf_map_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,50 @@ def expressions(cls):
return [sdutil.node_path_graph(cls.first_map_exit, cls.array, cls.second_map_entry)]

def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
# General conditions
if not sdfg.arrays[self.array.data].transient:
return False
# Unique paths
if graph.out_degree(self.first_map_exit) > 1:
return False
if graph.in_degree(self.array) > 1:
return False

# Check if array is overwritten afterwards
subset = next(graph.out_edges(self.first_map_exit).__iter__()).data.subset
is_overwritten = False
current = graph
while len(sdfg.successors(current)) == 1:
succ = next(sdfg.successors(current).__iter__())
covered = []
for dnode in succ.data_nodes():
if dnode.data != self.array.data:
continue

Check warning on line 57 in dace/transformation/dataflow/otf_map_fusion.py

View check run for this annotation

Codecov / codecov/patch

dace/transformation/dataflow/otf_map_fusion.py#L57

Added line #L57 was not covered by tests

if succ.in_degree(dnode) == 0:
covered.append(False)
break

Check warning on line 61 in dace/transformation/dataflow/otf_map_fusion.py

View check run for this annotation

Codecov / codecov/patch

dace/transformation/dataflow/otf_map_fusion.py#L60-L61

Added lines #L60 - L61 were not covered by tests

for inedge in succ.in_edges(dnode):
if not inedge.data.subset.covers(subset):
covered.append(False)
break

Check warning on line 66 in dace/transformation/dataflow/otf_map_fusion.py

View check run for this annotation

Codecov / codecov/patch

dace/transformation/dataflow/otf_map_fusion.py#L65-L66

Added lines #L65 - L66 were not covered by tests

covered.append(True)

if covered:
is_overwritten = all(covered)
break

current = succ

if not is_overwritten:
if not sdfg.arrays[self.array.data].transient:
return False

Check warning on line 78 in dace/transformation/dataflow/otf_map_fusion.py

View check run for this annotation

Codecov / codecov/patch

dace/transformation/dataflow/otf_map_fusion.py#L78

Added line #L78 was not covered by tests

# If not used anywhere, continue
for state in sdfg.states():
for dnode in state.data_nodes():
if self.array != dnode and dnode.data == self.array.data:
return False

# No non-transients in scope of first map
first_map_entry = graph.entry_node(self.first_map_exit)
subgraph = graph.scope_subgraph(first_map_entry, include_entry=True, include_exit=True)
Expand Down Expand Up @@ -256,8 +292,8 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
first_accesses = tuple(first_memlet.subset.ranges)
for second_accesses in consume_memlets[array]:
# Step 1: Infer index access of second map to new inputs with respect to original first map
mapping = OTFMapFusion.solve(first_map_entry.map.params, first_accesses,
second_map_entry.map.params, second_accesses)
mapping = OTFMapFusion.solve(first_map_entry.map.params, first_accesses, second_map_entry.map.params,
second_accesses)

# Step 2: Add Temporary buffer
tmp_name = sdfg.temp_data_name()
Expand Down
92 changes: 86 additions & 6 deletions tests/transformations/otf_map_fusion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,15 +718,91 @@ def test_hdiff():
assert np.allclose(out_field, out_field_)


def test_state_dependencies_raw():

@dace.program
def state_dependencies_raw(A: dace.float64[10, 10], B: dace.float64[10, 10]):
tmp = dace.define_local([10, 10], dtype=A.dtype)
for i, j in dace.map[0:10, 0:10]:
with dace.tasklet:
a << A[i, j]
b >> tmp[i, j]

b = a + 1

for i, j in dace.map[0:10, 0:10]:
with dace.tasklet:
a << tmp[i, j]
b >> B[i, j]

b = a + 2

for i in range(1, 10):
tmp[0, i] = tmp[0, i - 1]

B[0, 9] = tmp[0, 9]

sdfg = state_dependencies_raw.to_sdfg()
sdfg.simplify()

applied = sdfg.apply_transformations(OTFMapFusion)
assert applied == 0


def test_state_dependencies_waw():

@dace.program
def state_dependencies_waw(A: dace.float64[10, 10], B: dace.float64[10, 10]):
tmp = dace.define_local([10, 10], dtype=A.dtype)
for i, j in dace.map[0:10, 0:10]:
with dace.tasklet:
a << A[i, j]
b >> tmp[i, j]

b = a + 1

for i, j in dace.map[0:10, 0:10]:
with dace.tasklet:
a << tmp[i, j]
b >> B[i, j]

b = a + 2

for i, j in dace.map[0:10, 0:10]:
with dace.tasklet:
a >> tmp[i, j]
a = 0

for i in range(1, 10):
A[0, i] = A[0, i - 1] + tmp[0, i]

sdfg = state_dependencies_waw.to_sdfg()
sdfg.simplify()

A = np.random.rand(10, 10).astype(np.float64)
A_ = np.copy(A)
B = np.zeros_like(A)
sdfg(A=A, B=B)

applied = sdfg.apply_transformations(OTFMapFusion)
assert applied == 1

B_ = np.zeros_like(A_)
sdfg(A=A_, B=B_)

assert np.allclose(B, B_)
assert np.allclose(A, A_)


if __name__ == '__main__':
# Solver
# # Solver
test_solve()
test_solve_permute()
test_solve_constant()
test_solve_constant2()
test_solve_unsolvable()

# Trivial fusion
# # Trivial fusion
test_trivial_fusion()
test_trivial_fusion_rename()
test_trivial_fusion_flip()
Expand All @@ -735,19 +811,23 @@ def test_hdiff():
test_trivial_fusion_nested_sdfg()
test_trivial_fusion_none_connectors()

# Defined subsets
# # Defined subsets
test_undefined_subset()
test_defined_subset()
test_undefined_subset_step()
test_defined_subset_step()

# Recomputation
# # Recomputation
test_recomputation_fusion()

# Local buffer
# # Local buffer
test_local_storage_fusion()
test_local_storage_fusion_nested_map()

# Applications
# # State dependencies
test_state_dependencies_raw()
test_state_dependencies_waw()

# # Applications
test_matmuls()
test_hdiff()

0 comments on commit 87373ae

Please sign in to comment.