Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AugAssignToWCR: Support for more cases and increased test coverage #1359

Merged
10 commits merged on
Nov 3, 2023
152 changes: 85 additions & 67 deletions dace/transformation/dataflow/wcr_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
""" Transformations to convert subgraphs to write-conflict resolutions. """
import ast
import re
from dace import registry, nodes, dtypes
import copy
from dace import registry, nodes, dtypes, Memlet
from dace.transformation import transformation, helpers as xfh
from dace.sdfg import graph as gr, utils as sdutil
from dace import SDFG, SDFGState
from dace.sdfg.state import StateSubgraphView
from dace.transformation import helpers
from dace.sdfg.propagation import propagate_memlets_state


class AugAssignToWCR(transformation.SingleStateTransformation):
Expand All @@ -20,13 +24,15 @@ class AugAssignToWCR(transformation.SingleStateTransformation):
map_exit = transformation.PatternNode(nodes.MapExit)

_EXPRESSIONS = ['+', '-', '*', '^', '%'] #, '/']
_FUNCTIONS = ['min', 'max']
_EXPR_MAP = {'-': ('+', '-({expr})'), '/': ('*', '((decltype({expr}))1)/({expr})')}
_PYOP_MAP = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.BitXor: '^', ast.Mod: '%', ast.Div: '/'}

@classmethod
def expressions(cls):
    """Return the list of pattern graphs this transformation matches.

    Index 0 is the path ``input -> tasklet -> output`` (free tasklet);
    index 1 is the path
    ``input -> map_entry -> tasklet -> map_exit -> output`` (free map).
    The position in the list is the ``expr_index`` checked elsewhere in
    this class.
    """
    # Pattern 0: tasklet directly between the two access nodes.
    free_tasklet_pattern = sdutil.node_path_graph(cls.input, cls.tasklet, cls.output)
    # Pattern 1: the same tasklet wrapped in a map scope.
    mapped_tasklet_pattern = sdutil.node_path_graph(cls.input, cls.map_entry, cls.tasklet, cls.map_exit,
                                                    cls.output)
    return [free_tasklet_pattern, mapped_tasklet_pattern]

def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
Expand All @@ -38,7 +44,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):

# Free tasklet
if expr_index == 0:
# Only free tasklets supported for now
if graph.entry_node(tasklet) is not None:
return False

Expand All @@ -49,8 +54,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
# Make sure augmented assignment can be fissioned as necessary
if any(not isinstance(e.src, nodes.AccessNode) for e in graph.in_edges(tasklet)):
return False
if graph.in_degree(inarr) > 0 and graph.out_degree(outarr) > 0:
return False

outedge = graph.edges_between(tasklet, outarr)[0]
else: # Free map
Expand All @@ -65,19 +68,18 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
if len(graph.edges_between(tasklet, mx)) > 1:
return False

# Currently no fission is supported
# Make sure augmented assignment can be fissioned as necessary
if any(e.src is not me and not isinstance(e.src, nodes.AccessNode)
for e in graph.in_edges(me) + graph.in_edges(tasklet)):
return False
if graph.in_degree(inarr) > 0:
[Review thread: lukastruemper marked this conversation as resolved.]
return False

outedge = graph.edges_between(tasklet, mx)[0]

# Get relevant output connector
outconn = outedge.src_conn

ops = '[%s]' % ''.join(re.escape(o) for o in AugAssignToWCR._EXPRESSIONS)
funcs = '|'.join(re.escape(o) for o in AugAssignToWCR._FUNCTIONS)

if tasklet.language is dtypes.Language.Python:
# Match a single assignment with a binary operation as RHS
Expand Down Expand Up @@ -108,18 +110,33 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
# Try to match a single C assignment that can be converted to WCR
inconn = edge.dst_conn
lhs = r'^\s*%s\s*=\s*%s\s*%s.*;$' % (re.escape(outconn), re.escape(inconn), ops)
rhs = r'^\s*%s\s*=\s*.*%s\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn))
if re.match(lhs, cstr) is None:
continue
# rhs: a = (...) op b
rhs = r'^\s*%s\s*=\s*\(.*\)\s*%s\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn))
func_lhs = r'^\s*%s\s*=\s*(%s)\(\s*%s\s*,.*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn))
func_rhs = r'^\s*%s\s*=\s*(%s)\(.*,\s*%s\s*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn))
if re.match(lhs, cstr) is None and re.match(rhs, cstr) is None:
if re.match(func_lhs, cstr) is None and re.match(func_rhs, cstr) is None:
inconns = list(self.tasklet.in_connectors)
if len(inconns) != 2:
continue

# Special case: a = <other> op b
other_inconn = inconns[0] if inconns[0] != inconn else inconns[1]
rhs2 = r'^\s*%s\s*=\s*%s\s*%s\s*%s;$' % (re.escape(outconn), re.escape(other_inconn), ops,
re.escape(inconn))
if re.match(rhs2, cstr) is None:
continue

# Same memlet
if edge.data.subset != outedge.data.subset:
continue

# If in map, only match if the subset does not depend on every
# map index (if it depends on all of them, there is no conflict)
if (expr_index == 1 and len(outedge.data.subset.free_symbols
& set(me.map.params)) == len(me.map.params)):
continue
if expr_index == 1:
if not permissive and len(outedge.data.subset.free_symbols & set(me.map.params)) == len(
me.map.params):
continue

return True
else:
Expand All @@ -132,57 +149,30 @@ def apply(self, state: SDFGState, sdfg: SDFG):
input: nodes.AccessNode = self.input
tasklet: nodes.Tasklet = self.tasklet
output: nodes.AccessNode = self.output
if self.expr_index == 1:
me = self.map_entry
mx = self.map_exit

# If state fission is necessary to keep semantics, do it first
if (self.expr_index == 0 and state.in_degree(input) > 0 and state.out_degree(output) == 0):
newstate = sdfg.add_state_after(state)
newstate.add_node(tasklet)
new_input, new_output = None, None

# Keep old edges for after we remove tasklet from the original state
in_edges = list(state.in_edges(tasklet))
out_edges = list(state.out_edges(tasklet))

for e in in_edges:
r = newstate.add_read(e.src.data)
newstate.add_edge(r, e.src_conn, e.dst, e.dst_conn, e.data)
if e.src is input:
new_input = r
for e in out_edges:
w = newstate.add_write(e.dst.data)
newstate.add_edge(e.src, e.src_conn, w, e.dst_conn, e.data)
if e.dst is output:
new_output = w

# Remove tasklet and resulting isolated nodes
state.remove_node(tasklet)
for e in in_edges:
if state.degree(e.src) == 0:
state.remove_node(e.src)
for e in out_edges:
if state.degree(e.dst) == 0:
state.remove_node(e.dst)

# Reset state and nodes for rest of transformation
input = new_input
output = new_output
state = newstate
# End of state fission
if state.in_degree(input) > 0:
subgraph_nodes = set([e.src for e in state.bfs_edges(input, reverse=True)])
subgraph_nodes.add(input)

subgraph = StateSubgraphView(state, subgraph_nodes)
helpers.state_fission(sdfg, subgraph)

if self.expr_index == 0:
inedges = state.edges_between(input, tasklet)
outedge = state.edges_between(tasklet, output)[0]
else:
me = self.map_entry
mx = self.map_exit

inedges = state.edges_between(me, tasklet)
outedge = state.edges_between(tasklet, mx)[0]

# Get relevant output connector
outconn = outedge.src_conn

ops = '[%s]' % ''.join(re.escape(o) for o in AugAssignToWCR._EXPRESSIONS)
funcs = '|'.join(re.escape(o) for o in AugAssignToWCR._FUNCTIONS)

# Change tasklet code
if tasklet.language is dtypes.Language.Python:
Expand All @@ -206,13 +196,40 @@ def apply(self, state: SDFGState, sdfg: SDFG):
inconn = edge.dst_conn
match = re.match(r'^\s*%s\s*=\s*%s\s*(%s)(.*);$' % (re.escape(outconn), re.escape(inconn), ops), cstr)
if match is None:
# match = re.match(
# r'^\s*%s\s*=\s*(.*)\s*(%s)\s*%s;$' %
# (re.escape(outconn), ops, re.escape(inconn)), cstr)
# if match is None:
continue
# op = match.group(2)
# expr = match.group(1)
match = re.match(
r'^\s*%s\s*=\s*\((.*)\)\s*(%s)\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn)), cstr)
if match is None:
func_rhs = r'^\s*%s\s*=\s*(%s)\((.*),\s*%s\s*\)\s*;$' % (re.escape(outconn), funcs,
re.escape(inconn))
match = re.match(func_rhs, cstr)
if match is None:
func_lhs = r'^\s*%s\s*=\s*(%s)\(\s*%s\s*,(.*)\)\s*;$' % (re.escape(outconn), funcs,
re.escape(inconn))
match = re.match(func_lhs, cstr)
if match is None:
inconns = list(self.tasklet.in_connectors)
if len(inconns) != 2:
continue

# Special case: a = <other> op b
other_inconn = inconns[0] if inconns[0] != inconn else inconns[1]
rhs2 = r'^\s*%s\s*=\s*(%s)\s*(%s)\s*%s;$' % (
re.escape(outconn), re.escape(other_inconn), ops, re.escape(inconn))
match = re.match(rhs2, cstr)
if match is None:
continue
else:
op = match.group(2)
expr = match.group(1)
else:
op = match.group(1)
expr = match.group(2)
else:
op = match.group(1)
expr = match.group(2)
else:
op = match.group(2)
expr = match.group(1)
else:
op = match.group(1)
expr = match.group(2)
Expand All @@ -232,16 +249,14 @@ def apply(self, state: SDFGState, sdfg: SDFG):
raise NotImplementedError

# Change output edge
outedge.data.wcr = f'lambda a,b: a {op} b'

if self.expr_index == 0:
# Remove input node and connector
state.remove_edge_and_connectors(inedge)
if state.degree(input) == 0:
state.remove_node(input)
if op in AugAssignToWCR._FUNCTIONS:
outedge.data.wcr = f'lambda a,b: {op}(a, b)'
else:
# Remove input edge and dst connector, but not necessarily src
state.remove_memlet_path(inedge)
outedge.data.wcr = f'lambda a,b: a {op} b'

# Remove input node and connector
state.remove_memlet_path(inedge)
propagate_memlets_state(sdfg, state)

# If outedge leads to non-transient, and this is a nested SDFG,
# propagate outwards
Expand All @@ -252,6 +267,9 @@ def apply(self, state: SDFGState, sdfg: SDFG):
sd = sd.parent_sdfg
outedge = next(iter(nstate.out_edges_by_connector(nsdfg, outedge.data.data)))
for outedge in nstate.memlet_path(outedge):
outedge.data.wcr = f'lambda a,b: a {op} b'
if op in AugAssignToWCR._FUNCTIONS:
outedge.data.wcr = f'lambda a,b: {op}(a, b)'
else:
outedge.data.wcr = f'lambda a,b: a {op} b'
# At this point we are leading to an access node again and can
# traverse further up