Skip to content

Commit

Permalink
Merge pull request #1212 from spcl/fix-intermediate-nodes
Browse files Browse the repository at this point in the history
Fix-intermediate-nodes
  • Loading branch information
alexnick83 committed Mar 13, 2023
2 parents c88c0ca + 476503b commit b7325ea
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 1 deletion.
2 changes: 1 addition & 1 deletion dace/transformation/dataflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# Complexity reduction
from .dedup_access import DeduplicateAccess
from .redundant_array import (RedundantArray, RedundantSecondArray, SqueezeViewRemove, UnsqueezeViewRemove,
RedundantReadSlice, RedundantWriteSlice, RemoveSliceView)
RedundantReadSlice, RedundantWriteSlice, RemoveSliceView, RemoveIntermediateWrite)
from .redundant_array_copying import (RedundantArrayCopyingIn, RedundantArrayCopying, RedundantArrayCopying2,
RedundantArrayCopying3)
from .merge_arrays import InMergeArrays, OutMergeArrays, MergeSourceSinkArrays
Expand Down
50 changes: 50 additions & 0 deletions dace/transformation/dataflow/redundant_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1608,3 +1608,53 @@ def _offset_subset(self, mapping: Dict[int, int], subset: subsets.Range, edge_su
new_subset[adim] = (rb, re, rs)

return subsets.Range(new_subset)


class RemoveIntermediateWrite(pm.SingleStateTransformation):
    """
    Moves an intermediate write located inside a Map's subgraph to the outside of that Map.

    The matched pattern is ``AccessNode -> MapExit``. Currently, only the case where every edge between the
    two nodes carries an empty Memlet is supported.
    """

    write = pm.PatternNode(nodes.AccessNode)
    map_exit = pm.PatternNode(nodes.MapExit)

    @classmethod
    def expressions(cls):
        """ Returns the single pattern to match: the write node connected directly to the Map exit. """
        pattern = sdutil.node_path_graph(cls.write, cls.map_exit)
        return [pattern]

    def can_be_applied(self, state: SDFGState, _: int, sdfg: SDFG, permissive=False):
        """
        Checks whether the matched write can be moved outside the Map.

        :param state: The state containing the matched nodes.
        :param _: Unused by this check.
        :param sdfg: Unused by this check.
        :param permissive: Unused by this check.
        :return: True if all edges from the write to the Map exit are empty and every non-WCR incoming edge of
                 the write mentions all Map parameters among its free symbols; False otherwise.
        """
        # Any edge between the write and the Map exit that carries data disqualifies the match.
        for connecting_edge in state.edges_between(self.write, self.map_exit):
            if not connecting_edge.data.is_empty():
                return False

        # Each incoming edge must either have WCR or reference every Map parameter.
        for incoming in state.in_edges(self.write):
            if incoming.data.wcr:
                continue
            symbol_names = [str(symbol) for symbol in incoming.data.free_symbols]
            if not all(param in symbol_names for param in self.map_exit.map.params):
                return False

        return True

    def apply(self, state: SDFGState, sdfg: SDFG):
        """
        Replaces the inner write with a new access node of the same data placed after the Map exit.

        :param state: The state containing the matched nodes; it is modified in place.
        :param sdfg: Unused by this method.
        """
        # Both lookups are performed before the graph is modified.
        inner_entry = state.entry_node(self.map_exit)
        scopes = state.scope_dict()

        # Reroute every producer of the inner write through the Map exit to the new outer access node.
        relocated_write = state.add_access(self.write.data)
        for incoming in state.in_edges(self.write):
            state.add_memlet_path(incoming.src,
                                  self.map_exit,
                                  relocated_write,
                                  memlet=incoming.data,
                                  src_conn=incoming.src_conn)
        state.remove_node(self.write)

        parent_entry = scopes[inner_entry]
        if parent_entry is None:
            # The Map is not nested in another scope; nothing more to connect.
            return

        # The Map is nested: connect the new write to the enclosing scope's exit with an empty Memlet.
        parent_exit = state.exit_node(parent_entry)
        state.add_nedge(relocated_write, parent_exit, mm.Memlet())
        # Clean-up edges with empty Memlets between the two exits since they are not needed anymore.
        for candidate in state.edges_between(self.map_exit, parent_exit):
            if candidate.data.is_empty():
                state.remove_edge(candidate)
119 changes: 119 additions & 0 deletions tests/transformations/remove_intermediate_write_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import copy
import dace
import numpy as np

from dace.transformation.dataflow import RemoveIntermediateWrite


def test_write_before_map_exit():
    """
    Copies `A` to `B` element-wise inside a single Map, where the write to `B` sits inside the Map and is
    connected to the Map exit with an empty Memlet. Verifies the result both before and after applying
    RemoveIntermediateWrite.
    """
    shape = (10, )

    sdfg = dace.SDFG('test_write_before_map_exit')
    for array_name in ('A', 'B'):
        sdfg.add_array(array_name, shape, dace.int32)

    state = sdfg.add_state('state')
    map_entry, map_exit = state.add_map('map', {'i': '0:10'})
    read_a = state.add_read('A')
    write_b = state.add_write('B')
    copy_tasklet = state.add_tasklet('tasklet', {'__inp'}, {'__out'}, '__out = __inp')
    state.add_memlet_path(read_a,
                          map_entry,
                          copy_tasklet,
                          memlet=dace.Memlet(data='A', subset='i'),
                          dst_conn='__inp')
    state.add_edge(copy_tasklet, '__out', write_b, None, dace.Memlet(data='B', subset='i'))
    # The write stays inside the Map: only an empty Memlet links it to the Map exit.
    state.add_edge(write_b, None, map_exit, None, dace.Memlet())

    inp = np.arange(10, dtype=np.int32)
    expected = inp

    out_before = np.empty(shape, dtype=np.int32)
    out_after = np.empty(shape, dtype=np.int32)

    # Run an untouched copy first to establish that the original graph is correct.
    untransformed = copy.deepcopy(sdfg)
    untransformed(A=inp, B=out_before)
    assert np.allclose(out_before, expected)

    sdfg.apply_transformations_repeated(RemoveIntermediateWrite)
    sdfg(A=inp, B=out_after)
    assert np.allclose(out_after, expected)


def test_write_before_nested_map_exit():
    """
    Copies a 10x10 array `A` to `B` inside two nested Maps, where the write to `B` sits inside the inner Map
    and the exits are chained with empty Memlets. Verifies the result both before and after applying
    RemoveIntermediateWrite.
    """
    shape = (10, 10)

    sdfg = dace.SDFG('test_write_before_nested_map_exit')
    for array_name in ('A', 'B'):
        sdfg.add_array(array_name, shape, dace.int32)

    state = sdfg.add_state('state')
    outer_entry, outer_exit = state.add_map('map', {'i': '0:10'})
    inner_entry, inner_exit = state.add_map('map2', {'j': '0:10'})
    read_a = state.add_read('A')
    write_b = state.add_write('B')
    copy_tasklet = state.add_tasklet('tasklet', {'__inp'}, {'__out'}, '__out = __inp')
    state.add_memlet_path(read_a,
                          outer_entry,
                          inner_entry,
                          copy_tasklet,
                          memlet=dace.Memlet(data='A', subset='i, j'),
                          dst_conn='__inp')
    state.add_edge(copy_tasklet, '__out', write_b, None, dace.Memlet(data='B', subset='i, j'))
    # Empty Memlets: write -> inner exit -> outer exit.
    state.add_nedge(write_b, inner_exit, dace.Memlet())
    state.add_nedge(inner_exit, outer_exit, dace.Memlet())

    inp = np.arange(100, dtype=np.int32).reshape(shape).copy()
    expected = inp

    out_before = np.empty(shape, dtype=np.int32)
    out_after = np.empty(shape, dtype=np.int32)

    # Run an untouched copy first to establish that the original graph is correct.
    untransformed = copy.deepcopy(sdfg)
    untransformed(A=inp, B=out_before)
    assert np.allclose(out_before, expected)

    sdfg.apply_transformations_repeated(RemoveIntermediateWrite)
    sdfg(A=inp, B=out_after)
    assert np.allclose(out_after, expected)


def test_write_before_nested_map_exit_2():
    """
    Same nested-Map copy of `A` to `B` as the previous test, with an additional tasklet writing to a transient
    `C` indexed only by the inner Map parameter `j`. After applying RemoveIntermediateWrite, checks where the
    single remaining `C` access node is connected, and that `B` is still computed correctly.
    """
    shape = (10, 10)

    sdfg = dace.SDFG('test_write_before_nested_map_exit_2')
    for array_name in ('A', 'B'):
        sdfg.add_array(array_name, shape, dace.int32)
    sdfg.add_array('C', (10, ), dace.int32, transient=True)

    state = sdfg.add_state('state')
    outer_entry, outer_exit = state.add_map('map', {'i': '0:10'})
    inner_entry, inner_exit = state.add_map('map2', {'j': '0:10'})
    read_a = state.add_read('A')
    write_b = state.add_write('B')
    write_c = state.add_write('C')
    b_tasklet = state.add_tasklet('tasklet0', {'__inp'}, {'__out'}, '__out = __inp')
    c_tasklet = state.add_tasklet('tasklet1', {'__inp'}, {'__out'}, '__out = __inp')
    for consumer in (b_tasklet, c_tasklet):
        state.add_memlet_path(read_a,
                              outer_entry,
                              inner_entry,
                              consumer,
                              memlet=dace.Memlet(data='A', subset='i, j'),
                              dst_conn='__inp')
    state.add_edge(b_tasklet, '__out', write_b, None, dace.Memlet(data='B', subset='i, j'))
    state.add_edge(c_tasklet, '__out', write_c, None, dace.Memlet(data='C', subset='j'))
    # Empty Memlets: both writes -> inner exit -> outer exit.
    state.add_nedge(write_b, inner_exit, dace.Memlet())
    state.add_nedge(write_c, inner_exit, dace.Memlet())
    state.add_nedge(inner_exit, outer_exit, dace.Memlet())

    inp = np.arange(100, dtype=np.int32).reshape(shape).copy()
    expected = inp

    out_before = np.empty(shape, dtype=np.int32)
    out_after = np.empty(shape, dtype=np.int32)

    # Run an untouched copy first to establish that the original graph is correct.
    untransformed = copy.deepcopy(sdfg)
    untransformed(A=inp, B=out_before)
    assert np.allclose(out_before, expected)

    sdfg.apply_transformations_repeated(RemoveIntermediateWrite)

    # Exactly one access node for C must remain, fed by the inner Map exit and linked to the outer Map exit.
    remaining_c = [node for node in state.data_nodes() if node.data == 'C']
    assert len(remaining_c) == 1
    c_node = remaining_c[0]
    assert len(state.edges_between(c_tasklet, c_node)) == 0
    assert len(state.edges_between(c_node, inner_exit)) == 0
    assert len(state.edges_between(inner_exit, c_node)) == 1
    assert len(state.edges_between(c_node, outer_exit)) == 1
    assert len(state.edges_between(outer_exit, c_node)) == 0

    sdfg(A=inp, B=out_after)
    assert np.allclose(out_after, expected)


if __name__ == '__main__':
    # Run every test in this file, in definition order, when executed as a script.
    for _test_case in (
            test_write_before_map_exit,
            test_write_before_nested_map_exit,
            test_write_before_nested_map_exit_2,
    ):
        _test_case()

0 comments on commit b7325ea

Please sign in to comment.