Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix-intermediate-nodes #1212

Merged
merged 5 commits into from
Mar 13, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion dace/transformation/dataflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# Complexity reduction
from .dedup_access import DeduplicateAccess
from .redundant_array import (RedundantArray, RedundantSecondArray, SqueezeViewRemove, UnsqueezeViewRemove,
RedundantReadSlice, RedundantWriteSlice, RemoveSliceView)
RedundantReadSlice, RedundantWriteSlice, RemoveSliceView, RemoveIntermediateWrite)
from .redundant_array_copying import (RedundantArrayCopyingIn, RedundantArrayCopying, RedundantArrayCopying2,
RedundantArrayCopying3)
from .merge_arrays import InMergeArrays, OutMergeArrays, MergeSourceSinkArrays
Expand Down
50 changes: 50 additions & 0 deletions dace/transformation/dataflow/redundant_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1608,3 +1608,53 @@ def _offset_subset(self, mapping: Dict[int, int], subset: subsets.Range, edge_su
new_subset[adim] = (rb, re, rs)

return subsets.Range(new_subset)


class RemoveIntermediateWrite(pm.SingleStateTransformation):
""" Moves intermediate writes insde a Map's subgraph outside the Map.

Currently, the transformation supports only the case `WriteAccess -> MapExit`, where the edge has an empty Memlet.
"""

write = pm.PatternNode(nodes.AccessNode)
map_exit = pm.PatternNode(nodes.MapExit)


@classmethod
def expressions(cls):
return [sdutil.node_path_graph(cls.write, cls.map_exit)]

def can_be_applied(self, state: SDFGState, _: int, sdfg: SDFG, permissive=False):

# The output edges must have empty Memlets
edges = state.edges_between(self.write, self.map_exit)
if any(not e.data.is_empty() for e in edges):
return False

# The input edges must either depend on all the Map parameters or have WCR.
for edge in state.in_edges(self.write):
if edge.data.wcr:
continue
fsymbols = [str(s) for s in edge.data.free_symbols]
if any(p not in fsymbols for p in self.map_exit.map.params):
return False

return True

def apply(self, state: SDFGState, sdfg: SDFG):

entry_node = state.entry_node(self.map_exit)
scope_dict = state.scope_dict()

outer_write = state.add_access(self.write.data)
for edge in state.in_edges(self.write):
state.add_memlet_path(edge.src, self.map_exit, outer_write, memlet=edge.data, src_conn=edge.src_conn)
state.remove_node(self.write)

if scope_dict[entry_node] is not None:
exit_node = state.exit_node(scope_dict[entry_node])
state.add_nedge(outer_write, exit_node, mm.Memlet())
# Clean-up edges with empty Memlets since they are not needed anymore.
for edge in state.edges_between(self.map_exit, exit_node):
if edge.data.is_empty():
state.remove_edge(edge)
119 changes: 119 additions & 0 deletions tests/transformations/remove_intermediate_write_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import copy
import dace
import numpy as np

from dace.transformation.dataflow import RemoveIntermediateWrite


def test_write_before_map_exit():

sdfg = dace.SDFG('test_write_before_map_exit')
sdfg.add_array('A', (10, ), dace.int32)
sdfg.add_array('B', (10, ), dace.int32)

state = sdfg.add_state('state')
me, mx = state.add_map('map', dict(i='0:10'))
a_access = state.add_read('A')
b_access = state.add_write('B')
tasklet = state.add_tasklet('tasklet', {'__inp'}, {'__out'}, '__out = __inp')
state.add_memlet_path(a_access, me, tasklet, memlet=dace.Memlet(data='A', subset='i'), dst_conn='__inp')
state.add_edge(tasklet, '__out', b_access, None, dace.Memlet(data='B', subset='i'))
state.add_edge(b_access, None, mx, None, dace.Memlet())

A = np.arange(10, dtype=np.int32)
ref = A

before_val = np.empty((10, ), dtype=np.int32)
after_val = np.empty((10, ), dtype=np.int32)

sdfg_before = copy.deepcopy(sdfg)
sdfg_before(A=A, B=before_val)
assert np.allclose(before_val, ref)

sdfg.apply_transformations_repeated(RemoveIntermediateWrite)
sdfg(A=A, B=after_val)
assert np.allclose(after_val, ref)


def test_write_before_nested_map_exit():

sdfg = dace.SDFG('test_write_before_nested_map_exit')
sdfg.add_array('A', (10, 10), dace.int32)
sdfg.add_array('B', (10, 10), dace.int32)

state = sdfg.add_state('state')
me0, mx0 = state.add_map('map', dict(i='0:10'))
me1, mx1 = state.add_map('map2', dict(j='0:10'))
a_access = state.add_read('A')
b_access = state.add_write('B')
tasklet = state.add_tasklet('tasklet', {'__inp'}, {'__out'}, '__out = __inp')
state.add_memlet_path(a_access, me0, me1, tasklet, memlet=dace.Memlet(data='A', subset='i, j'), dst_conn='__inp')
state.add_edge(tasklet, '__out', b_access, None, dace.Memlet(data='B', subset='i, j'))
state.add_nedge(b_access, mx1, dace.Memlet())
state.add_nedge(mx1, mx0, dace.Memlet())

A = np.arange(100, dtype=np.int32).reshape((10, 10)).copy()
ref = A

before_val = np.empty((10, 10), dtype=np.int32)
after_val = np.empty((10, 10), dtype=np.int32)

sdfg_before = copy.deepcopy(sdfg)
sdfg_before(A=A, B=before_val)
assert np.allclose(before_val, ref)

sdfg.apply_transformations_repeated(RemoveIntermediateWrite)
sdfg(A=A, B=after_val)
assert np.allclose(after_val, ref)


def test_write_before_nested_map_exit_2():

sdfg = dace.SDFG('test_write_before_nested_map_exit_2')
sdfg.add_array('A', (10, 10), dace.int32)
sdfg.add_array('B', (10, 10), dace.int32)
sdfg.add_array('C', (10, ), dace.int32, transient=True)

state = sdfg.add_state('state')
me0, mx0 = state.add_map('map', dict(i='0:10'))
me1, mx1 = state.add_map('map2', dict(j='0:10'))
a_access = state.add_read('A')
b_access = state.add_write('B')
c_access = state.add_write('C')
tasklet0 = state.add_tasklet('tasklet0', {'__inp'}, {'__out'}, '__out = __inp')
tasklet1 = state.add_tasklet('tasklet1', {'__inp'}, {'__out'}, '__out = __inp')
state.add_memlet_path(a_access, me0, me1, tasklet0, memlet=dace.Memlet(data='A', subset='i, j'), dst_conn='__inp')
state.add_memlet_path(a_access, me0, me1, tasklet1, memlet=dace.Memlet(data='A', subset='i, j'), dst_conn='__inp')
state.add_edge(tasklet0, '__out', b_access, None, dace.Memlet(data='B', subset='i, j'))
state.add_edge(tasklet1, '__out', c_access, None, dace.Memlet(data='C', subset='j'))
state.add_nedge(b_access, mx1, dace.Memlet())
state.add_nedge(c_access, mx1, dace.Memlet())
state.add_nedge(mx1, mx0, dace.Memlet())

A = np.arange(100, dtype=np.int32).reshape((10, 10)).copy()
ref = A

before_val = np.empty((10, 10), dtype=np.int32)
after_val = np.empty((10, 10), dtype=np.int32)

sdfg_before = copy.deepcopy(sdfg)
sdfg_before(A=A, B=before_val)
assert np.allclose(before_val, ref)

sdfg.apply_transformations_repeated(RemoveIntermediateWrite)
c_nodes = [n for n in state.data_nodes() if n.data == 'C']
assert len(c_nodes) == 1
assert len(state.edges_between(tasklet1, c_nodes[0])) == 0
assert len(state.edges_between(c_nodes[0], mx1)) == 0
assert len(state.edges_between(mx1, c_nodes[0])) == 1
assert len(state.edges_between(c_nodes[0], mx0)) == 1
assert len(state.edges_between(mx0, c_nodes[0])) == 0
sdfg(A=A, B=after_val)
assert np.allclose(after_val, ref)


if __name__ == '__main__':
test_write_before_map_exit()
test_write_before_nested_map_exit()
test_write_before_nested_map_exit_2()