Skip to content

Commit

Permalink
SpMV optimization (#830)
Browse files Browse the repository at this point in the history
* Add new transient to ProgramVisitor.variables.

* Added support for (some) Python tasklets.

* Added capability to select the iteration variable for the transformation.

* `extract_map_dims` should now not fail when the input map has only 1 dimension.

* Added StartStateElimination transformation.

* InlineSDFG fissions the state if necessary.

* Input connectors are not pruned when they appear in symbol mappings.

* Try to find an entry node in the parent SDFG if not found in the current SDFG.

* Improved search for outer scope.

* Added TrivialTaskletElimination transformation.

* Updated WarpTiling.

* Updated map range

* Enabled setzero for local transient.

* Added init state for thread-local variable.

* Don't apply on streams.

* Added tutorial

* Small refactor for error reporting.

* yapf

* Made condition more robust to nested SDFGs with more than one WCC.

* Updated tutorial.

* Fixed indentation error.

* Fixed missing code.
  • Loading branch information
alexnick83 committed Aug 22, 2021
1 parent 8f3a6c2 commit 7a35b4a
Show file tree
Hide file tree
Showing 13 changed files with 794 additions and 36 deletions.
9 changes: 9 additions & 0 deletions dace/codegen/targets/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1874,6 +1874,15 @@ def generate_devicelevel_scope(self, sdfg, dfg_scope, state_id,
sdfg, state_id, scope_entry)

outer_scope = sdfg.nodes()[state_id].entry_node(scope_entry)
current_sdfg = sdfg
while not outer_scope and current_sdfg:
current_state = current_sdfg.parent
nsdfg_node = current_sdfg.parent_nsdfg_node
outer_scope = current_state.entry_node(nsdfg_node)
current_sdfg = current_state.parent
if not outer_scope:
raise ValueError(
f'Failed to find the outer scope of {scope_entry}')
callsite_stream.write(
'if ({} < {}) {{'.format(
outer_scope.map.params[0],
Expand Down
1 change: 1 addition & 0 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4732,6 +4732,7 @@ def _add_read_slice(self, array: str, node: ast.Subscript,
else:
tmp, tmparr = self.sdfg.add_temp_transient(
other_subset.size(), arrobj.dtype, arrobj.storage)
self.variables[tmp] = tmp
wnode = self.last_state.add_write(tmp,
debuginfo=self.current_lineinfo)
self.last_state.add_nedge(
Expand Down
1 change: 1 addition & 0 deletions dace/transformation/dataflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .prune_connectors import PruneConnectors, PruneSymbols
from .wcr_conversion import AugAssignToWCR
from .tasklet_fusion import SimpleTaskletFusion
from .trivial_tasklet_elimination import TrivialTaskletElimination

# Device-related
from .copy_to_device import CopyToDevice
Expand Down
8 changes: 8 additions & 0 deletions dace/transformation/dataflow/prune_connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ def can_be_applied(graph: Union[SDFG, SDFGState],
prune_in = nsdfg.in_connectors.keys() - read_set
prune_out = nsdfg.out_connectors.keys() - write_set

# Take into account symbol mappings
strs = tuple(nsdfg.symbol_mapping.values())
syms = tuple(symbolic.pystr_to_symbolic(s) for s in strs)
symnames = tuple(s.name if hasattr(s, 'name') else '' for s in syms)
for conn in list(prune_in):
if conn in syms or conn in symnames:
prune_in.remove(conn)

# Add WCR outputs to "do not prune" input list
for e in graph.out_edges(nsdfg):
if e.data.wcr is not None and e.src_conn in prune_in:
Expand Down
74 changes: 74 additions & 0 deletions dace/transformation/dataflow/trivial_tasklet_elimination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
""" Contains classes that implement the trivial-tasklet-elimination transformation. """

from dace import data, registry
from dace.sdfg import nodes
from dace.sdfg import utils as sdutil
from dace.transformation import transformation
from dace.properties import make_properties


@registry.autoregister_params(singlestate=True)
@make_properties
class TrivialTaskletElimination(transformation.Transformation):
    """ Implements the Trivial-Tasklet Elimination pattern.

        Trivial-Tasklet Elimination removes tasklets that just copy the input
        to the output without WCR, and replaces them with a direct edge from
        the source access node to the destination access node.
    """

    read = transformation.PatternNode(nodes.AccessNode)
    tasklet = transformation.PatternNode(nodes.Tasklet)
    write = transformation.PatternNode(nodes.AccessNode)

    @staticmethod
    def expressions():
        """ Returns the pattern to match: the path read -> tasklet -> write. """
        return [
            sdutil.node_path_graph(TrivialTaskletElimination.read,
                                   TrivialTaskletElimination.tasklet,
                                   TrivialTaskletElimination.write)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched tasklet is a plain copy that can be
            removed.

            :param graph: The state in which the pattern was matched.
            :param candidate: Mapping from pattern nodes to node IDs in
                              ``graph``.
            :param expr_index: Index of the matched expression (unused).
            :param sdfg: The SDFG that owns the data descriptors.
            :param strict: Strict-mode flag (unused).
            :return: True if the tasklet can be eliminated, False otherwise.
        """
        read = graph.nodes()[candidate[TrivialTaskletElimination.read]]
        tasklet = graph.nodes()[candidate[TrivialTaskletElimination.tasklet]]
        write = graph.nodes()[candidate[TrivialTaskletElimination.write]]

        read_desc = sdfg.arrays[read.data]
        write_desc = sdfg.arrays[write.data]

        # Do not apply on Streams
        if isinstance(read_desc, data.Stream):
            return False
        if isinstance(write_desc, data.Stream):
            return False

        # The tasklet must have exactly one incoming and one outgoing edge
        if len(graph.in_edges(tasklet)) != 1:
            return False
        if len(graph.out_edges(tasklet)) != 1:
            return False

        # A write-conflict resolution on the output cannot be expressed by
        # the plain copy edge that replaces the tasklet
        if graph.edges_between(tasklet, write)[0].data.wcr:
            return False

        if len(tasklet.in_connectors) != 1:
            return False
        if len(tasklet.out_connectors) != 1:
            return False

        # The code must be exactly a single assignment "out = in"
        in_conn = list(tasklet.in_connectors.keys())[0]
        out_conn = list(tasklet.out_connectors.keys())[0]
        if tasklet.code.as_string != f'{out_conn} = {in_conn}':
            return False

        # If the containers have different data types, the tasklet performs
        # an implicit type cast; removing it would drop that conversion
        if read_desc.dtype != write_desc.dtype:
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the label of the matched tasklet. """
        tasklet = graph.nodes()[candidate[TrivialTaskletElimination.tasklet]]
        return tasklet.label

    def apply(self, sdfg):
        """ Removes the matched tasklet and connects the read access node
            directly to the write access node.

            :param sdfg: The SDFG that contains the matched state.
        """
        graph = sdfg.nodes()[self.state_id]
        read = graph.nodes()[self.subgraph[TrivialTaskletElimination.read]]
        tasklet = graph.nodes()[self.subgraph[
            TrivialTaskletElimination.tasklet]]
        write = graph.nodes()[self.subgraph[TrivialTaskletElimination.write]]

        in_edge = graph.edges_between(read, tasklet)[0]
        out_edge = graph.edges_between(tasklet, write)[0]
        graph.remove_edge(in_edge)
        graph.remove_edge(out_edge)

        # Reuse the output memlet for the direct edge, and record the subset
        # that was read from the source as its other subset
        out_edge.data.other_subset = in_edge.data.subset
        graph.add_nedge(read, write, out_edge.data)
        graph.remove_node(tasklet)
34 changes: 26 additions & 8 deletions dace/transformation/dataflow/warp_tiling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
import copy
from dace import registry, properties, nodes, dtypes, symbolic
from dace.sdfg.graph import SubgraphView
from dace import registry, properties, nodes, dtypes, subsets, symbolic
from dace import Memlet, SDFG, SDFGState
from dace.frontend.operations import detect_reduction_type
from dace.transformation import transformation as xf, helpers as xfh
Expand Down Expand Up @@ -70,7 +71,7 @@ def apply(self, sdfg: SDFG) -> nodes.MapEntry:
nsdfg_node.symbol_mapping['__tid'] = __tid
if '__tid' not in nsdfg.symbols:
nsdfg.add_symbol('__tid', dtypes.int32)
nmap.range[-1] = (nmap.range[-1][0], nmap.range[-1][1],
nmap.range[-1] = (nmap.range[-1][0], nmap.range[-1][1] - __tid,
nmap.range[-1][2] * self.warp_size)
subgraph = nstate.scope_subgraph(nmap)
subgraph.replace(nmap.params[-1], f'{nmap.params[-1]} + __tid')
Expand Down Expand Up @@ -124,20 +125,37 @@ def apply(self, sdfg: SDFG) -> nodes.MapEntry:
credtype = ('dace::ReductionType::' +
str(redtype)[str(redtype).find('.') + 1:])

# Add local access between thread-locan and warp reduction
newnode = nstate.add_access(out_edge.data.data)
# Add local access between thread-local and warp reduction
name = nsdfg._find_new_name(out_edge.data.data)
nsdfg.add_scalar(name,
nsdfg.arrays[out_edge.data.data].dtype,
transient=True)

# Initialize thread-local to global value
read = nstate.add_read(out_edge.data.data)
write = nstate.add_write(name)
edge = nstate.add_nedge(read, write,
copy.deepcopy(out_edge.data))
edge.data.wcr = None
xfh.state_fission(nsdfg,
SubgraphView(nstate, [read, write]))

newnode = nstate.add_access(name)
nstate.remove_edge(out_edge)
nstate.add_edge(out_edge.src, out_edge.src_conn, newnode,
None, copy.deepcopy(out_edge.data))
edge = nstate.add_edge(out_edge.src, out_edge.src_conn,
newnode, None,
copy.deepcopy(out_edge.data))
for e in nstate.memlet_path(edge):
e.data.data = name
e.data.subset = subsets.Range([(0, 0, 1)])

if out_edge.data.subset.num_elements(
) == 1: # One element: tasklet
wrt = nstate.add_tasklet(
'warpreduce', {'__a'}, {'__out'},
f'__out = dace::warpReduce<{credtype}, {ctype}>::reduce(__a);',
dtypes.Language.CPP)
nstate.add_edge(newnode, None, wrt, '__a',
Memlet(out_edge.data.data))
nstate.add_edge(newnode, None, wrt, '__a', Memlet(name))
out_edge.data.wcr = None
nstate.add_edge(wrt, '__out', out_edge.dst, None,
out_edge.data)
Expand Down
49 changes: 46 additions & 3 deletions dace/transformation/dataflow/wcr_conversion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
""" Transformations to convert subgraphs to write-conflict resolutions. """
import ast
import re
from dace import registry, nodes, dtypes
from dace.transformation import transformation, helpers as xfh
Expand All @@ -24,6 +25,14 @@ class AugAssignToWCR(transformation.Transformation):
'-': ('+', '-({expr})'),
'/': ('*', '((decltype({expr}))1)/({expr})')
}
_PYOP_MAP = {
ast.Add: '+',
ast.Sub: '-',
ast.Mult: '*',
ast.BitXor: '^',
ast.Mod: '%',
ast.Div: '/'
}

@staticmethod
def expressions():
Expand Down Expand Up @@ -86,8 +95,28 @@ def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
re.escape(o) for o in AugAssignToWCR._EXPRESSIONS)

if tasklet.language is dtypes.Language.Python:
# Expect ast.Assign(ast.Expr())
return False
# Match a single assignment with a binary operation as RHS
if len(tasklet.code.code) > 1:
return False
if not isinstance(tasklet.code.code[0], ast.Assign):
return False
ast_node: ast.Assign = tasklet.code.code[0]
if len(ast_node.targets) > 1:
return False
if not isinstance(ast_node.targets[0], ast.Name):
return False
lhs: ast.Name = ast_node.targets[0]
if lhs.id != outconn:
return False
if not isinstance(ast_node.value, ast.BinOp):
return False
rhs: ast.BinOp = ast_node.value
if not isinstance(rhs.op, tuple(AugAssignToWCR._PYOP_MAP.keys())):
return False
inconns = tuple(edge.dst_conn for edge in inedges)
for n in (rhs.left, rhs.right):
if isinstance(n, ast.Name) and n.id in inconns:
return True
elif tasklet.language is dtypes.Language.CPP:
cstr = tasklet.code.as_string.strip()
for edge in inedges:
Expand Down Expand Up @@ -178,7 +207,21 @@ def apply(self, sdfg: SDFG):

# Change tasklet code
if tasklet.language is dtypes.Language.Python:
raise NotImplementedError
# Match a single assignment with a binary operation as RHS
ast_node: ast.Assign = tasklet.code.code[0]
lhs: ast.Name = ast_node.targets[0]
rhs: ast.BinOp = ast_node.value
op = AugAssignToWCR._PYOP_MAP[type(rhs.op)]
inconns = list(edge.dst_conn for edge in inedges)
for n in (rhs.left, rhs.right):
if isinstance(n, ast.Name) and n.id in inconns:
inedge = inedges[inconns.index(n.id)]
else:
new_rhs = n
new_node = ast.copy_location(
ast.Assign(targets=[lhs], value=new_rhs), ast_node)
tasklet.code.code = [new_node]

elif tasklet.language is dtypes.Language.CPP:
cstr = tasklet.code.as_string.strip()
for edge in inedges:
Expand Down
43 changes: 24 additions & 19 deletions dace/transformation/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,25 +726,30 @@ def extract_map_dims(sdfg: SDFG, map_entry: nodes.MapEntry,
map_entry,
dims + [i for i in range(len(map_entry.map.params)) if i not in dims])
# Expand map
entries = MapExpansion.apply_to(sdfg, map_entry=map_entry)

# Collapse extracted maps
extracted_map = entries[0]
for idx in range(len(dims) - 1):
extracted_map, _ = MapCollapse.apply_to(
sdfg,
_outer_map_entry=extracted_map,
_inner_map_entry=entries[idx + 1],
)

# Collapse remaining maps
map_to_collapse = entries[len(dims)]
for idx in range(len(dims), len(entries) - 1):
map_to_collapse, _ = MapCollapse.apply_to(
sdfg,
_outer_map_entry=map_to_collapse,
_inner_map_entry=entries[idx + 1],
)
if len(map_entry.map.params) > 1:
entries = MapExpansion.apply_to(sdfg, map_entry=map_entry)

# Collapse extracted maps
extracted_map = entries[0]
for idx in range(len(dims) - 1):
extracted_map, _ = MapCollapse.apply_to(
sdfg,
_outer_map_entry=extracted_map,
_inner_map_entry=entries[idx + 1],
)

# Collapse remaining maps
map_to_collapse = entries[len(dims)]
for idx in range(len(dims), len(entries) - 1):
map_to_collapse, _ = MapCollapse.apply_to(
sdfg,
_outer_map_entry=map_to_collapse,
_inner_map_entry=entries[idx + 1],
)
else:
extracted_map = map_entry
map_to_collapse = map_entry


return extracted_map, map_to_collapse

Expand Down
4 changes: 3 additions & 1 deletion dace/transformation/interstate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
""" This module initializes the inter-state transformations package."""

from .state_fusion import StateFusion
from .state_elimination import EndStateElimination, StateAssignElimination, HoistState
from .state_elimination import (EndStateElimination, StartStateElimination,
StateAssignElimination, SymbolAliasPromotion,
HoistState)
from .fpga_transform_state import FPGATransformState
from .fpga_transform_sdfg import FPGATransformSDFG
from .gpu_transform_sdfg import GPUTransformSDFG
Expand Down
18 changes: 14 additions & 4 deletions dace/transformation/interstate/loop_to_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,22 @@ def _check_range(subset, a, itersym, b, step):


@registry.autoregister
@make_properties
class LoopToMap(DetectLoop):
"""Convert a control flow loop into a dataflow map. Currently only supports
the simple case where there is no overlap between inputs and outputs in
the body of the loop, and where the loop body only consists of a single
state.
"""
@staticmethod
def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):

itervar = Property(
dtype=str,
allow_none=True,
default=None,
desc='The name of the iteration variable (optional).',
)

def can_be_applied(self, graph, candidate, expr_index, sdfg, strict=False):
# Is this even a loop
if not DetectLoop.can_be_applied(graph, candidate, expr_index, sdfg,
strict):
Expand All @@ -72,7 +80,8 @@ def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
return False

# If loop cannot be detected, fail
found = find_for_loop(graph, guard, begin)
found = find_for_loop(graph, guard, begin,
itervar=self.itervar)
if not found:
return False

Expand Down Expand Up @@ -218,7 +227,8 @@ def apply(self, sdfg: sd.SDFG):
after: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._exit_state])

# Obtain iteration variable, range, and stride
itervar, (start, end, step), (_, body_end) = find_for_loop(sdfg, guard, body)
itervar, (start, end, step), (_, body_end) = find_for_loop(
sdfg, guard, body, itervar=self.itervar)

# Find all loop-body states
states = set([body_end])
Expand Down
11 changes: 10 additions & 1 deletion dace/transformation/interstate/sdfg_nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from dace import memlet, registry, sdfg as sd, Memlet, symbolic, dtypes, subsets
from dace.frontend.python import astutils
from dace.sdfg import nodes, propagation
from dace.sdfg import nodes, propagation, utils
from dace.sdfg.graph import MultiConnectorEdge, SubgraphView
from dace.sdfg import SDFG, SDFGState
from dace.sdfg import utils as sdutil, infer_types, propagation
Expand Down Expand Up @@ -541,6 +541,10 @@ def apply(self, sdfg: SDFG):
'(reconnecting inputs)')
state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
edge.data)
# Fission state if necessary
cc = utils.weakly_connected_component(state, node)
if not any(n in cc for n in subgraph.nodes()):
helpers.state_fission(state.parent, cc)
for edge in removed_out_edges:
# Find last access node that refers to this edge
try:
Expand All @@ -553,6 +557,11 @@ def apply(self, sdfg: SDFG):
'(reconnecting outputs)')
state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,
edge.data)
# Fission state if necessary
cc = utils.weakly_connected_component(state, node)
if not any(n in cc for n in subgraph.nodes()):
cc2 = SubgraphView([n for n in state.nodes() if n not in cc])
state = helpers.state_fission(sdfg, cc2)

#######################################################
# Remove nested SDFG node
Expand Down
Loading

0 comments on commit 7a35b4a

Please sign in to comment.