Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for TaskletFusion, AugAssignToWCR and MapExpansion #1432

Merged
merged 7 commits into from
Nov 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 24 additions & 7 deletions dace/transformation/dataflow/map_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@

from dace.sdfg.utils import consolidate_edges
from typing import Dict, List
import copy
import dace
from dace import dtypes, subsets, symbolic
from dace.properties import EnumProperty, make_properties
from dace.sdfg import nodes
from dace.sdfg import utils as sdutil
from dace.sdfg.graph import OrderedMultiDiConnectorGraph
from dace.transformation import transformation as pm
from dace.sdfg.propagation import propagate_memlets_scope


@make_properties
Expand Down Expand Up @@ -66,14 +68,28 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
# 1. If there are no edges coming from the outside, use empty memlets
# 2. Edges with IN_* connectors replicate along the maps
# 3. Edges for dynamic map ranges replicate until reaching range(s)
for edge in graph.out_edges(map_entry):
for edge in list(graph.out_edges(map_entry)):
if edge.src_conn is not None and edge.src_conn not in entries[-1].out_connectors:
entries[-1].add_out_connector(edge.src_conn)

graph.add_edge(entries[-1], edge.src_conn, edge.dst, edge.dst_conn, memlet=copy.deepcopy(edge.data))
graph.remove_edge(edge)
graph.add_memlet_path(map_entry,
*entries,
edge.dst,
src_conn=edge.src_conn,
memlet=edge.data,
dst_conn=edge.dst_conn)

if graph.in_degree(map_entry) == 0:
graph.add_memlet_path(map_entry, *entries, memlet=dace.Memlet())
else:
for edge in graph.in_edges(map_entry):
if not edge.dst_conn.startswith("IN_"):
continue

in_conn = edge.dst_conn
out_conn = "OUT_" + in_conn[3:]
if in_conn not in entries[-1].in_connectors:
graph.add_memlet_path(map_entry,
*entries,
memlet=copy.deepcopy(edge.data),
src_conn=out_conn,
dst_conn=in_conn)

# Modify dynamic map ranges
dynamic_edges = dace.sdfg.dynamic_map_inputs(graph, map_entry)
Expand Down Expand Up @@ -116,6 +132,7 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
else:
raise ValueError('Cannot find scope in state')

propagate_memlets_scope(sdfg, state=graph, scopes=scope)
consolidate_edges(sdfg, scope)

return [map_entry] + entries
2 changes: 1 addition & 1 deletion dace/transformation/dataflow/tasklet_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,5 +267,5 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
graph.remove_node(t1)
if data is not None:
graph.remove_node(data)
sdfg.remove_data(data.data, True)

graph.remove_node(t2)
28 changes: 13 additions & 15 deletions dace/transformation/dataflow/wcr_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):

outedge = graph.edges_between(tasklet, mx)[0]

# If in map, only match if the subset is independent of any
# map indices (otherwise no conflict)
if not permissive and len(outedge.data.subset.free_symbols & set(me.map.params)) == len(
me.map.params):
return False

# Get relevant output connector
outconn = outedge.src_conn

Expand Down Expand Up @@ -131,17 +137,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
if edge.data.subset != outedge.data.subset:
continue

# If in map, only match if the subset is independent of any
# map indices (otherwise no conflict)
if expr_index == 1:
if not permissive and len(outedge.data.subset.free_symbols & set(me.map.params)) == len(
me.map.params):
continue

return True
else:
# Only Python/C++ tasklets supported
return False

return False

Expand Down Expand Up @@ -182,11 +178,13 @@ def apply(self, state: SDFGState, sdfg: SDFG):
rhs: ast.BinOp = ast_node.value
op = AugAssignToWCR._PYOP_MAP[type(rhs.op)]
inconns = list(edge.dst_conn for edge in inedges)
for n in (rhs.left, rhs.right):
if isinstance(n, ast.Name) and n.id in inconns:
inedge = inedges[inconns.index(n.id)]
else:
new_rhs = n
if isinstance(rhs.left, ast.Name) and rhs.left.id in inconns:
inedge = inedges[inconns.index(rhs.left.id)]
new_rhs = rhs.right
else:
inedge = inedges[inconns.index(rhs.right.id)]
new_rhs = rhs.left

new_node = ast.copy_location(ast.Assign(targets=[lhs], value=new_rhs), ast_node)
tasklet.code.code = [new_node]

Expand Down
37 changes: 0 additions & 37 deletions tests/expansion_dynamic_range_test.py

This file was deleted.

119 changes: 119 additions & 0 deletions tests/transformations/map_expansion_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import dace
import numpy as np
from dace.transformation.dataflow import MapExpansion

def test_expand_with_inputs():
    """Expand a two-parameter map whose tasklet reads three elements of ``A``.

    After a single application of MapExpansion the state must contain two
    map entries, and neither of them may expose more than one out-connector.
    """

    @dace.program
    def toexpand(A: dace.float64[4, 2], B: dace.float64[2, 2]):
        for i, j in dace.map[1:3, 0:2]:
            with dace.tasklet:
                a1 << A[i, j]
                a2 << A[i + 1, j]
                a3 << A[i - 1, j]
                b >> B[i-1, j]
                b = a1 + a2 + a3

    sdfg = toexpand.to_sdfg()
    sdfg.simplify()

    # Preconditions: the SDFG is valid and holds exactly one map scope
    sdfg.validate()
    assert sum(1 for n in sdfg.start_state.nodes() if isinstance(n, dace.nodes.MapEntry)) == 1
    assert sum(1 for n in sdfg.start_state.nodes() if isinstance(n, dace.nodes.MapExit)) == 1

    # Apply the transformation; it must match exactly once
    num_applied = sdfg.apply_transformations_repeated(MapExpansion)
    assert num_applied == 1
    sdfg.validate()

    state = sdfg.start_state
    found_entries = [n for n in state.nodes() if isinstance(n, dace.nodes.MapEntry)]
    for entry in found_entries:
        # (Fast) MapExpansion must not create one memlet path per tasklet
        # input: a map entry without an enclosing scope has a single
        # outgoing edge, the other one fans out to the three tasklet inputs.
        has_no_parent_scope = state.entry_node(entry) is None
        expected_out_degree = 1 if has_no_parent_scope else 3

        assert state.in_degree(entry) == 1
        assert state.out_degree(entry) == expected_out_degree
        assert len(entry.out_connectors) == 1

    assert len(set(found_entries)) == 2

def test_expand_without_inputs():
    """Expand a two-parameter map whose tasklet only writes to ``B``.

    After a single application of MapExpansion the state must contain two
    map entries, none of which has any out-connector.
    """

    @dace.program
    def toexpand(B: dace.float64[4, 4]):
        for i, j in dace.map[0:4, 0:4]:
            with dace.tasklet:
                b >> B[i, j]
                b = 0

    sdfg = toexpand.to_sdfg()
    sdfg.simplify()

    # Preconditions: the SDFG is valid and holds exactly one map scope
    sdfg.validate()
    assert sum(1 for n in sdfg.start_state.nodes() if isinstance(n, dace.nodes.MapEntry)) == 1
    assert sum(1 for n in sdfg.start_state.nodes() if isinstance(n, dace.nodes.MapExit)) == 1

    # Apply the transformation; it must match exactly once
    num_applied = sdfg.apply_transformations_repeated(MapExpansion)
    assert num_applied == 1
    sdfg.validate()

    state = sdfg.start_state
    found_entries = [n for n in state.nodes() if isinstance(n, dace.nodes.MapEntry)]
    for entry in found_entries:
        # (Fast) MapExpansion must not create one memlet path per tasklet
        # memlet: a map entry without an enclosing scope has no incoming
        # edge, the other one is reached through exactly one edge.
        has_no_parent_scope = state.entry_node(entry) is None
        expected_in_degree = 0 if has_no_parent_scope else 1

        assert state.in_degree(entry) == expected_in_degree
        assert state.out_degree(entry) == 1
        assert len(entry.out_connectors) == 0

    assert len(set(found_entries)) == 2

def test_expand_without_dynamic_inputs():
    """Run a map whose ``j`` range is read from the ``rng`` array, before and
    after MapExpansion, and compare both results against a NumPy reference.
    """

    @dace.program
    def expansion(A: dace.float32[20, 30, 5], rng: dace.int32[2]):
        @dace.map
        def mymap(i: _[0:20], j: _[rng[0]:rng[1]], k: _[0:5]):
            a << A[i, j, k]
            b >> A[i, j, k]
            b = a * 2

    data = np.random.rand(20, 30, 5).astype(np.float32)
    bounds = np.array([5, 10], dtype=np.int32)

    # Reference: only the slice selected by `bounds` is doubled
    reference = data.copy()
    reference[:, 5:10, :] *= 2

    sdfg = expansion.to_sdfg()

    # First run: untransformed SDFG (updates `data` in place)
    sdfg(A=data, rng=bounds)
    diff = np.linalg.norm(data - reference)
    print('Difference (before transformation):', diff)

    sdfg.apply_transformations(MapExpansion)

    # Second run: expanded SDFG doubles the same slice once more
    sdfg(A=data, rng=bounds)
    reference[:, 5:10, :] *= 2
    diff2 = np.linalg.norm(data - reference)
    print('Difference:', diff2)

    assert diff <= 1e-5
    assert diff2 <= 1e-5

if __name__ == '__main__':
    # Allow running this test module directly, without a test runner.
    for run_test in (
            test_expand_with_inputs,
            test_expand_without_inputs,
            test_expand_without_dynamic_inputs,
    ):
        run_test()
29 changes: 29 additions & 0 deletions tests/transformations/tasklet_fusion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dace
from dace import dtypes
from dace.transformation.dataflow import TaskletFusion, MapFusion
from dace.transformation.optimizer import Optimizer
import pytest

datatype = dace.float32
Expand Down Expand Up @@ -257,6 +258,33 @@ def sdfg_none_connector(A: dace.float32[32], B: dace.float32[32]):
assert sdfg.start_state.out_degree(map_entry) == 1
assert len([edge.src_conn for edge in sdfg.start_state.out_edges(map_entry) if edge.src_conn is None]) == 0


def test_intermediate_transients():
    """Fuse the tasklets around one ``tmp`` access node while ``tmp`` is used
    a second time; the descriptor must remain registered in the SDFG.
    """

    @dace.program
    def sdfg_intermediate_transients(A: dace.float32[10], B: dace.float32[10]):
        tmp = dace.define_local_scalar(dace.float32)

        # Use tmp twice to test removal of data
        tmp = A[0] + 1
        tmp = tmp * 2
        B[0] = tmp

    sdfg = sdfg_intermediate_transients.to_sdfg(simplify=True)

    def count_tmp_nodes():
        # Number of access nodes in the start state that refer to "tmp"
        return sum(1 for n in sdfg.start_state.data_nodes() if n.data == "tmp")

    assert count_tmp_nodes() == 2

    # Pick the first TaskletFusion match whose intermediate data is "tmp"
    matches = Optimizer(sdfg=sdfg).get_pattern_matches(patterns=(TaskletFusion,))
    chosen = next((m for m in matches if m.data.data == "tmp"), None)
    assert chosen is not None
    chosen.apply(sdfg.start_state, sdfg)

    # One access node is gone, but the array itself must still exist
    assert count_tmp_nodes() == 1
    assert "tmp" in sdfg.arrays

if __name__ == '__main__':
test_basic()
test_same_name()
Expand All @@ -268,3 +296,4 @@ def sdfg_none_connector(A: dace.float32[32], B: dace.float32[32]):
test_map_with_tasklets(language='CPP', with_data=False)
test_map_with_tasklets(language='CPP', with_data=True)
test_none_connector()
test_intermediate_transients()
18 changes: 18 additions & 0 deletions tests/transformations/wcr_conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,21 @@ def sdfg_free_map_permissive(A: dace.float64[32], B: dace.float64[32]):

applied = sdfg.apply_transformations_repeated(AugAssignToWCR, permissive=True)
assert applied == 1

def test_aug_assign_same_inconns():
    """AugAssignToWCR (permissive) must apply once when the tasklet reads the
    written array ``A`` through two input connectors (``A[i]`` and ``A[i+1]``).
    """

    @dace.program
    def sdfg_aug_assign_same_inconns(A: dace.float64[32]):
        for i in dace.map[0:31]:
            with dace.tasklet(language=dace.Language.Python):
                a << A[i]
                b << A[i+1]
                c >> A[i]

                c = a * b

    sdfg = sdfg_aug_assign_same_inconns.to_sdfg()
    sdfg.simplify()

    num_applied = sdfg.apply_transformations_repeated(AugAssignToWCR, permissive=True)
    assert num_applied == 1