Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide passes to detect shadowed scalar reads and fission them into separate scalars #1198

Merged
merged 8 commits into from
Feb 28, 2023
2 changes: 1 addition & 1 deletion dace/codegen/targets/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def _compute_pool_release(self, top_sdfg: SDFG):
# Lazily compute reachability and access nodes
if reachability is None:
reachability = ap.StateReachability().apply_pass(top_sdfg, {})
access_nodes = ap.FindAccessNodes().apply_pass(top_sdfg, {})
access_nodes = ap.FindAccessStates().apply_pass(top_sdfg, {})

reachable = reachability[sdfg.sdfg_id]
access_sets = access_nodes[sdfg.sdfg_id]
Expand Down
2 changes: 1 addition & 1 deletion dace/transformation/passes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .analysis import StateReachability, AccessSets, FindAccessNodes
from .analysis import StateReachability, AccessSets, FindAccessStates
from .array_elimination import ArrayElimination
from .consolidate_edges import ConsolidateEdges
from .constant_propagation import ConstantPropagation
Expand Down
153 changes: 150 additions & 3 deletions dace/transformation/passes/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@

from collections import defaultdict
from dace.transformation import pass_pipeline as ppl
from dace import SDFG, SDFGState, properties
from typing import Dict, Set, Tuple
from dace import SDFG, SDFGState, properties, InterstateEdge
from dace.sdfg import nodes as nd
from typing import Dict, Set, Tuple, Any, Optional, Union
import networkx as nx
from networkx.algorithms import shortest_paths as nxsp

WriteScopeDict = Dict[str, Dict[Optional[Tuple[SDFGState, nd.AccessNode]],
Set[Tuple[SDFGState, Union[nd.AccessNode, InterstateEdge]]]]]


@properties.make_properties
Expand Down Expand Up @@ -80,7 +85,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Tuple[Set[s


@properties.make_properties
class FindAccessNodes(ppl.Pass):
class FindAccessStates(ppl.Pass):
"""
For each data descriptor, creates a set of states in which access nodes of that data are used.
"""
Expand Down Expand Up @@ -115,3 +120,145 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[SDFGState]]]:

top_result[sdfg.sdfg_id] = result
return top_result


@properties.make_properties
class FindAccessNodes(ppl.Pass):
    """
    For each data descriptor, creates a dictionary mapping states to all read and write access nodes with the given
    data descriptor.
    """

    CATEGORY: str = 'Analysis'

    def modifies(self) -> ppl.Modifies:
        # Pure analysis pass; nothing in the SDFG is changed.
        return ppl.Modifies.Nothing

    def should_reapply(self, modified: ppl.Modifies) -> bool:
        # Results depend only on access nodes, so only re-run if those changed.
        return modified & ppl.Modifies.AccessNodes

    def apply_pass(self, top_sdfg: SDFG,
                   _) -> Dict[int, Dict[str, Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]]]]:
        """
        :return: A dictionary keyed by SDFG ID, mapping each data descriptor name to a dictionary keyed by state,
                 where each state maps to a ``(read_nodes, write_nodes)`` tuple of access-node sets.
        """
        # NOTE: The annotation previously claimed Dict[int, Dict[str, Set[...]]], which did not match the
        # declared return type; it now mirrors the actual nested structure. The per-state container is also
        # a tuple now (was a list), matching the advertised Tuple type; consumers only index and mutate the
        # contained sets, so this is behavior-compatible.
        top_result: Dict[int, Dict[str, Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]]]] = dict()

        for sdfg in top_sdfg.all_sdfgs_recursive():
            result: Dict[str, Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]]] = defaultdict(
                lambda: defaultdict(lambda: (set(), set())))
            for state in sdfg.nodes():
                for anode in state.data_nodes():
                    # Incoming (dataflow) edges mean this access node is written to -> index 1.
                    if state.in_degree(anode) > 0:
                        result[anode.data][state][1].add(anode)
                    # Outgoing edges mean this access node is read from -> index 0.
                    if state.out_degree(anode) > 0:
                        result[anode.data][state][0].add(anode)
            top_result[sdfg.sdfg_id] = result
        return top_result


@properties.make_properties
class ScalarWriteShadowScopes(ppl.Pass):
    """
    For each scalar or array of size 1, create a dictionary mapping each write to that data container
    to the set of reads that are shadowed / dominated by that write.
    """

    CATEGORY: str = 'Analysis'

    def modifies(self) -> ppl.Modifies:
        # Pure analysis pass; nothing in the SDFG is changed.
        return ppl.Modifies.Nothing

    def should_reapply(self, modified: ppl.Modifies) -> bool:
        # Re-run whenever the state machine structure was modified, since dominators may change.
        return modified & ppl.Modifies.States

    def depends_on(self):
        # Needs per-state read/write name sets and per-state read/write access nodes.
        return {AccessSets, FindAccessNodes}

    def _find_dominating_write(
        self, desc: str, state: SDFGState, read: Union[nd.AccessNode, InterstateEdge],
        access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]],
        state_idom: Dict[SDFGState, SDFGState], access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]]
    ) -> Optional[Tuple[SDFGState, nd.AccessNode]]:
        """
        Find the write (state, access node) that dominates a given read of a data container.

        :param desc: Name of the data descriptor being read.
        :param state: State containing the read (or the state whose outgoing edge performs the read).
        :param read: The reading access node, or the interstate edge that reads `desc` symbolically.
        :param access_nodes: Per-state (read, write) access-node sets for `desc` (from FindAccessNodes).
        :param state_idom: Immediate-dominator map of the SDFG's state machine.
        :param access_sets: Per-state (read, write) data-name sets (from AccessSets).
        :return: The dominating (state, write access node) pair, or None if no write dominates the read.
        """
        write_state = None

        if isinstance(read, nd.AccessNode):
            # If the read is also a write, it shadows itself.
            iedges = state.in_edges(read)
            if len(iedges) > 0 and any(not e.data.is_empty() for e in iedges):
                return (state, read)

            # Find a dominating write within the same state: among write nodes with a dataflow path
            # to the read, keep the 'closest' one (reachable from all other candidates).
            # TODO: Can this be done more efficiently?
            closest_candidate = None
            write_nodes = access_nodes[desc][state][1]
            for cand in write_nodes:
                if nxsp.has_path(state._nx, cand, read):
                    if closest_candidate is None or nxsp.has_path(state._nx, closest_candidate, cand):
                        closest_candidate = cand
            if closest_candidate is not None:
                return (state, closest_candidate)

            # No write in this state; walk up the dominator tree until a state that writes `desc` is found.
            # The guard `state_idom[s] != s` stops at the start state (its own immediate dominator).
            write_state = None
            nstate = state_idom[state] if state_idom[state] != state else None
            while nstate is not None and write_state is None:
                if desc in access_sets[nstate][1]:
                    write_state = nstate
                nstate = state_idom[nstate] if state_idom[nstate] != nstate else None
        elif isinstance(read, InterstateEdge):
            # Consider the current state as the write state, since the read is happening on an outgoing interstate edge.
            write_state = state

        # Find a dominating write in the write state, i.e., the 'last' write to the data container.
        if write_state is not None:
            closest_candidate = None
            for cand in access_nodes[desc][write_state][1]:
                # A sink write node (no outgoing edges) is definitely the last write in the state.
                if write_state.out_degree(cand) == 0:
                    closest_candidate = cand
                    break
                elif closest_candidate is None or nxsp.has_path(write_state._nx, closest_candidate, cand):
                    closest_candidate = cand
            if closest_candidate is not None:
                return (write_state, closest_candidate)

        return None

    def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[int, WriteScopeDict]:
        """
        :return: A dictionary mapping each data descriptor name to a dictionary, where writes to that data descriptor
                 and the states they are contained in are mapped to the set of reads (and their states) that are in the
                 scope of that write.
        """
        top_result: Dict[int, WriteScopeDict] = dict()

        for sdfg in top_sdfg.all_sdfgs_recursive():
            result: WriteScopeDict = defaultdict(lambda: defaultdict(lambda: set()))
            # Immediate dominators of the state machine drive the shadowing analysis.
            idom = nx.immediate_dominators(sdfg.nx, sdfg.start_state)
            access_sets: Dict[SDFGState, Tuple[Set[str],
                                               Set[str]]] = pipeline_results[AccessSets.__name__][sdfg.sdfg_id]
            access_nodes: Dict[str, Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]]] = pipeline_results[
                FindAccessNodes.__name__][sdfg.sdfg_id]

            anames = sdfg.arrays.keys()
            for desc in sdfg.arrays:
                # Map each in-graph read of `desc` to its dominating write (None if undominated).
                desc_states_with_nodes = set(access_nodes[desc].keys())
                for state in desc_states_with_nodes:
                    for read_node in access_nodes[desc][state][0]:
                        write = self._find_dominating_write(desc, state, read_node, access_nodes, idom, access_sets)
                        result[desc][write].add((state, read_node))
                # Ensure accesses to interstate edges are also considered.
                for state, accesses in access_sets.items():
                    if desc in accesses[0]:
                        out_edges = sdfg.out_edges(state)
                        for oedge in out_edges:
                            # Only edges whose condition/assignments actually reference `desc`.
                            syms = oedge.data.free_symbols & anames
                            if desc in syms:
                                write = self._find_dominating_write(
                                    desc, state, oedge.data, access_nodes, idom, access_sets
                                )
                                result[desc][write].add((state, oedge.data))
            top_result[sdfg.sdfg_id] = result
        return top_result
6 changes: 3 additions & 3 deletions dace/transformation/passes/array_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:
return modified & ppl.Modifies.AccessNodes

def depends_on(self):
return {ap.StateReachability, ap.FindAccessNodes}
return {ap.StateReachability, ap.FindAccessStates}

def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Set[str]]:
"""
Expand All @@ -41,9 +41,9 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[S
:return: A set of removed data descriptor names, or None if nothing changed.
"""
result: Set[str] = set()
reachable: Dict[SDFGState, Set[SDFGState]] = pipeline_results['StateReachability'][sdfg.sdfg_id]
reachable: Dict[SDFGState, Set[SDFGState]] = pipeline_results[ap.StateReachability.__name__][sdfg.sdfg_id]
# Get access nodes and modify set as pass continues
access_sets: Dict[str, Set[SDFGState]] = pipeline_results['FindAccessNodes'][sdfg.sdfg_id]
access_sets: Dict[str, Set[SDFGState]] = pipeline_results[ap.FindAccessStates.__name__][sdfg.sdfg_id]

# Traverse SDFG backwards
try:
Expand Down
89 changes: 89 additions & 0 deletions dace/transformation/passes/scalar_fission.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
from collections import defaultdict
from typing import Any, Dict, Optional, Set

from dace import SDFG, InterstateEdge
from dace.sdfg import nodes as nd
from dace.transformation import pass_pipeline as ppl
from dace.transformation.passes import analysis as ap


class ScalarFission(ppl.Pass):
    """
    Fission transient scalars or arrays of size 1 that are dominated by a write into separate data containers.
    """

    CATEGORY: str = 'Optimization Preparation'

    def modifies(self) -> ppl.Modifies:
        # Adds new data descriptors and rewires access nodes to them.
        return ppl.Modifies.Descriptors | ppl.Modifies.AccessNodes

    def should_reapply(self, modified: ppl.Modifies) -> bool:
        return modified & ppl.Modifies.AccessNodes

    def depends_on(self):
        return {ap.ScalarWriteShadowScopes}

    @staticmethod
    def _rename_access(state, node: nd.AccessNode, old_name: str, new_name: str) -> None:
        """Point an access node and all memlets on its adjacent edges at `new_name` instead of `old_name`."""
        node.data = new_name
        for iedge in state.in_edges(node):
            if iedge.data.data == old_name:
                iedge.data.data = new_name
        for oedge in state.out_edges(node):  # fixed misspelled local 'oeade'
            if oedge.data.data == old_name:
                oedge.data.data = new_name

    def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Dict[str, Set[str]]]:
        """
        Rename scalars and arrays of size 1 based on dominated scopes.

        :param sdfg: The SDFG to modify.
        :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass
                                 results as ``{Pass subclass name: returned object from pass}``. If not run in a
                                 pipeline, an empty dictionary is expected.
        :return: A dictionary mapping the original name to a set of all new names created for each data container.
        """
        results: Dict[str, Set[str]] = defaultdict(lambda: set())

        shadow_scope_dict: ap.WriteScopeDict = pipeline_results[ap.ScalarWriteShadowScopes.__name__][sdfg.sdfg_id]

        for name, write_scope_dict in shadow_scope_dict.items():
            desc = sdfg.arrays[name]

            # If this isn't a scalar or an array of size 1, don't do anything.
            if desc.total_size != 1:
                continue

            # If there is only one scope, don't do anything.
            if len(write_scope_dict) <= 1:
                continue

            # Don't rename anything that's not transient, as it may be used externally.
            if not desc.transient:
                continue

            for write, shadowed_reads in write_scope_dict.items():
                # A `None` key collects reads with no dominating write; those keep the original name.
                if write is not None:
                    newdesc = desc.clone()
                    newname = sdfg.add_datadesc(name, newdesc, find_new_name=True)

                    # Replace the write and any connected memlets with writes to the new data container.
                    self._rename_access(write[0], write[1], name, newname)

                    # Replace all dominated reads and connected memlets.
                    for read in shadowed_reads:
                        if isinstance(read[1], nd.AccessNode):
                            self._rename_access(read[0], read[1], name, newname)
                        elif isinstance(read[1], InterstateEdge):
                            # Interstate-edge reads are symbolic; rename the symbol in the edge.
                            read[1].replace_dict({name: newname})

                    results[name].add(newname)
        return results

    def report(self, pass_retval: Any) -> Optional[str]:
        return f'Renamed {len(pass_retval)} scalars: {pass_retval}.'