Skip to content

Commit

Permalink
Fixes for TaskletFusion, AugAssignToWCR and MapExpansion (#1432)
Browse files Browse the repository at this point in the history
- The PR fixes two minor bugs for corner cases of the AugAssignToWCR and
TaskletFusion which are reflected in additional test cases:
- TaskletFusion: Should not remove array from SDFG, since it could be
used elsewhere
- AugAssignToWCR: Handle tasklets where all inputs come from same array
- The PR re-writes MapExpansion to create only one memlet path per out
connector to be more efficient. I experienced MapExpansion running for
literally hours because it uses add_memlet_path for each edge to a
tasklet. This is too expensive for >4 dimensional stencils with >50
edges
  • Loading branch information
lukastruemper committed Nov 25, 2023
1 parent 4139ddf commit 6d53e24
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 60 deletions.
31 changes: 24 additions & 7 deletions dace/transformation/dataflow/map_expansion.py
Expand Up @@ -3,13 +3,15 @@

from dace.sdfg.utils import consolidate_edges
from typing import Dict, List
import copy
import dace
from dace import dtypes, subsets, symbolic
from dace.properties import EnumProperty, make_properties
from dace.sdfg import nodes
from dace.sdfg import utils as sdutil
from dace.sdfg.graph import OrderedMultiDiConnectorGraph
from dace.transformation import transformation as pm
from dace.sdfg.propagation import propagate_memlets_scope


@make_properties
Expand Down Expand Up @@ -66,14 +68,28 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
# 1. If there are no edges coming from the outside, use empty memlets
# 2. Edges with IN_* connectors replicate along the maps
# 3. Edges for dynamic map ranges replicate until reaching range(s)
for edge in graph.out_edges(map_entry):
for edge in list(graph.out_edges(map_entry)):
if edge.src_conn is not None and edge.src_conn not in entries[-1].out_connectors:
entries[-1].add_out_connector(edge.src_conn)

graph.add_edge(entries[-1], edge.src_conn, edge.dst, edge.dst_conn, memlet=copy.deepcopy(edge.data))
graph.remove_edge(edge)
graph.add_memlet_path(map_entry,
*entries,
edge.dst,
src_conn=edge.src_conn,
memlet=edge.data,
dst_conn=edge.dst_conn)

if graph.in_degree(map_entry) == 0:
graph.add_memlet_path(map_entry, *entries, memlet=dace.Memlet())
else:
for edge in graph.in_edges(map_entry):
if not edge.dst_conn.startswith("IN_"):
continue

in_conn = edge.dst_conn
out_conn = "OUT_" + in_conn[3:]
if in_conn not in entries[-1].in_connectors:
graph.add_memlet_path(map_entry,
*entries,
memlet=copy.deepcopy(edge.data),
src_conn=out_conn,
dst_conn=in_conn)

# Modify dynamic map ranges
dynamic_edges = dace.sdfg.dynamic_map_inputs(graph, map_entry)
Expand Down Expand Up @@ -116,6 +132,7 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
else:
raise ValueError('Cannot find scope in state')

propagate_memlets_scope(sdfg, state=graph, scopes=scope)
consolidate_edges(sdfg, scope)

return [map_entry] + entries
2 changes: 1 addition & 1 deletion dace/transformation/dataflow/tasklet_fusion.py
Expand Up @@ -267,5 +267,5 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
graph.remove_node(t1)
if data is not None:
graph.remove_node(data)
sdfg.remove_data(data.data, True)

graph.remove_node(t2)
28 changes: 13 additions & 15 deletions dace/transformation/dataflow/wcr_conversion.py
Expand Up @@ -75,6 +75,12 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):

outedge = graph.edges_between(tasklet, mx)[0]

# If in map, only match if the subset is independent of any
# map indices (otherwise no conflict)
if not permissive and len(outedge.data.subset.free_symbols & set(me.map.params)) == len(
me.map.params):
return False

# Get relevant output connector
outconn = outedge.src_conn

Expand Down Expand Up @@ -131,17 +137,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
if edge.data.subset != outedge.data.subset:
continue

# If in map, only match if the subset is independent of any
# map indices (otherwise no conflict)
if expr_index == 1:
if not permissive and len(outedge.data.subset.free_symbols & set(me.map.params)) == len(
me.map.params):
continue

return True
else:
# Only Python/C++ tasklets supported
return False

return False

Expand Down Expand Up @@ -182,11 +178,13 @@ def apply(self, state: SDFGState, sdfg: SDFG):
rhs: ast.BinOp = ast_node.value
op = AugAssignToWCR._PYOP_MAP[type(rhs.op)]
inconns = list(edge.dst_conn for edge in inedges)
for n in (rhs.left, rhs.right):
if isinstance(n, ast.Name) and n.id in inconns:
inedge = inedges[inconns.index(n.id)]
else:
new_rhs = n
if isinstance(rhs.left, ast.Name) and rhs.left.id in inconns:
inedge = inedges[inconns.index(rhs.left.id)]
new_rhs = rhs.right
else:
inedge = inedges[inconns.index(rhs.right.id)]
new_rhs = rhs.left

new_node = ast.copy_location(ast.Assign(targets=[lhs], value=new_rhs), ast_node)
tasklet.code.code = [new_node]

Expand Down
37 changes: 0 additions & 37 deletions tests/expansion_dynamic_range_test.py

This file was deleted.

119 changes: 119 additions & 0 deletions tests/transformations/map_expansion_test.py
@@ -0,0 +1,119 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import dace
import numpy as np
from dace.transformation.dataflow import MapExpansion

def test_expand_with_inputs():
@dace.program
def toexpand(A: dace.float64[4, 2], B: dace.float64[2, 2]):
for i, j in dace.map[1:3, 0:2]:
with dace.tasklet:
a1 << A[i, j]
a2 << A[i + 1, j]
a3 << A[i - 1, j]
b >> B[i-1, j]
b = a1 + a2 + a3

sdfg = toexpand.to_sdfg()
sdfg.simplify()

# Init conditions
sdfg.validate()
assert len([node for node in sdfg.start_state.nodes() if isinstance(node, dace.nodes.MapEntry)]) == 1
assert len([node for node in sdfg.start_state.nodes() if isinstance(node, dace.nodes.MapExit)]) == 1

# Expansion
assert sdfg.apply_transformations_repeated(MapExpansion) == 1
sdfg.validate()

map_entries = set()
state = sdfg.start_state
for node in state.nodes():
if not isinstance(node, dace.nodes.MapEntry):
continue

# (Fast) MapExpansion should not add memlet paths for each memlet to a tasklet
if sdfg.start_state.entry_node(node) is None:
assert state.in_degree(node) == 1
assert state.out_degree(node) == 1
assert len(node.out_connectors) == 1
else:
assert state.in_degree(node) == 1
assert state.out_degree(node) == 3
assert len(node.out_connectors) == 1

map_entries.add(node)

assert len(map_entries) == 2

def test_expand_without_inputs():
@dace.program
def toexpand(B: dace.float64[4, 4]):
for i, j in dace.map[0:4, 0:4]:
with dace.tasklet:
b >> B[i, j]
b = 0

sdfg = toexpand.to_sdfg()
sdfg.simplify()

# Init conditions
sdfg.validate()
assert len([node for node in sdfg.start_state.nodes() if isinstance(node, dace.nodes.MapEntry)]) == 1
assert len([node for node in sdfg.start_state.nodes() if isinstance(node, dace.nodes.MapExit)]) == 1

# Expansion
assert sdfg.apply_transformations_repeated(MapExpansion) == 1
sdfg.validate()

map_entries = set()
state = sdfg.start_state
for node in state.nodes():
if not isinstance(node, dace.nodes.MapEntry):
continue

# (Fast) MapExpansion should not add memlet paths for each memlet to a tasklet
if sdfg.start_state.entry_node(node) is None:
assert state.in_degree(node) == 0
assert state.out_degree(node) == 1
assert len(node.out_connectors) == 0
else:
assert state.in_degree(node) == 1
assert state.out_degree(node) == 1
assert len(node.out_connectors) == 0

map_entries.add(node)

assert len(map_entries) == 2

def test_expand_without_dynamic_inputs():
@dace.program
def expansion(A: dace.float32[20, 30, 5], rng: dace.int32[2]):
@dace.map
def mymap(i: _[0:20], j: _[rng[0]:rng[1]], k: _[0:5]):
a << A[i, j, k]
b >> A[i, j, k]
b = a * 2

A = np.random.rand(20, 30, 5).astype(np.float32)
b = np.array([5, 10], dtype=np.int32)
expected = A.copy()
expected[:, 5:10, :] *= 2

sdfg = expansion.to_sdfg()
sdfg(A=A, rng=b)
diff = np.linalg.norm(A - expected)
print('Difference (before transformation):', diff)

sdfg.apply_transformations(MapExpansion)

sdfg(A=A, rng=b)
expected[:, 5:10, :] *= 2
diff2 = np.linalg.norm(A - expected)
print('Difference:', diff2)
assert (diff <= 1e-5) and (diff2 <= 1e-5)

if __name__ == '__main__':
test_expand_with_inputs()
test_expand_without_inputs()
test_expand_without_dynamic_inputs()
29 changes: 29 additions & 0 deletions tests/transformations/tasklet_fusion_test.py
Expand Up @@ -3,6 +3,7 @@
import dace
from dace import dtypes
from dace.transformation.dataflow import TaskletFusion, MapFusion
from dace.transformation.optimizer import Optimizer
import pytest

datatype = dace.float32
Expand Down Expand Up @@ -257,6 +258,33 @@ def sdfg_none_connector(A: dace.float32[32], B: dace.float32[32]):
assert sdfg.start_state.out_degree(map_entry) == 1
assert len([edge.src_conn for edge in sdfg.start_state.out_edges(map_entry) if edge.src_conn is None]) == 0


def test_intermediate_transients():
@dace.program
def sdfg_intermediate_transients(A: dace.float32[10], B: dace.float32[10]):
tmp = dace.define_local_scalar(dace.float32)

# Use tmp twice to test removal of data
tmp = A[0] + 1
tmp = tmp * 2
B[0] = tmp


sdfg = sdfg_intermediate_transients.to_sdfg(simplify=True)
assert len([node for node in sdfg.start_state.data_nodes() if node.data == "tmp"]) == 2

xforms = Optimizer(sdfg=sdfg).get_pattern_matches(patterns=(TaskletFusion,))
applied = False
for xform in xforms:
if xform.data.data == "tmp":
xform.apply(sdfg.start_state, sdfg)
applied = True
break

assert applied
assert len([node for node in sdfg.start_state.data_nodes() if node.data == "tmp"]) == 1
assert "tmp" in sdfg.arrays

if __name__ == '__main__':
test_basic()
test_same_name()
Expand All @@ -268,3 +296,4 @@ def sdfg_none_connector(A: dace.float32[32], B: dace.float32[32]):
test_map_with_tasklets(language='CPP', with_data=False)
test_map_with_tasklets(language='CPP', with_data=True)
test_none_connector()
test_intermediate_transients()
18 changes: 18 additions & 0 deletions tests/transformations/wcr_conversion_test.py
Expand Up @@ -245,3 +245,21 @@ def sdfg_free_map_permissive(A: dace.float64[32], B: dace.float64[32]):

applied = sdfg.apply_transformations_repeated(AugAssignToWCR, permissive=True)
assert applied == 1

def test_aug_assign_same_inconns():

@dace.program
def sdfg_aug_assign_same_inconns(A: dace.float64[32]):
for i in dace.map[0:31]:
with dace.tasklet(language=dace.Language.Python):
a << A[i]
b << A[i+1]
c >> A[i]

c = a * b

sdfg = sdfg_aug_assign_same_inconns.to_sdfg()
sdfg.simplify()

applied = sdfg.apply_transformations_repeated(AugAssignToWCR, permissive=True)
assert applied == 1

0 comments on commit 6d53e24

Please sign in to comment.