Skip to content

Commit

Permalink
Merge branch 'master' into wcr
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun committed Nov 2, 2023
2 parents adeca97 + dff301c commit ebd8015
Show file tree
Hide file tree
Showing 16 changed files with 1,337 additions and 509 deletions.
6 changes: 5 additions & 1 deletion dace/cli/sdfv.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None):
"""
# If vscode is open, try to open it inside vscode
if filename is None:
if 'VSCODE_IPC_HOOK_CLI' in os.environ or 'VSCODE_GIT_IPC_HANDLE' in os.environ:
if (
'VSCODE_IPC_HOOK' in os.environ
or 'VSCODE_IPC_HOOK_CLI' in os.environ
or 'VSCODE_GIT_IPC_HANDLE' in os.environ
):
filename = tempfile.mktemp(suffix='.sdfg')
sdfg.save(filename)
os.system(f'code {filename}')
Expand Down
13 changes: 8 additions & 5 deletions dace/codegen/instrumentation/papi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from dace.sdfg.graph import SubgraphView
from dace.memlet import Memlet
from dace.sdfg import scope_contains_scope
from dace.sdfg.state import StateGraphView
from dace.sdfg.state import DataflowGraphView

import sympy as sp
import os
Expand Down Expand Up @@ -392,7 +392,7 @@ def should_instrument_entry(map_entry: EntryNode) -> bool:
return cond

@staticmethod
def has_surrounding_perfcounters(node, dfg: StateGraphView):
def has_surrounding_perfcounters(node, dfg: DataflowGraphView):
""" Returns true if there is a possibility that this node is part of a
section that is profiled. """
parent = dfg.entry_node(node)
Expand Down Expand Up @@ -605,7 +605,7 @@ def get_memlet_byte_size(sdfg: dace.SDFG, memlet: Memlet):
return memlet.volume * memdata.dtype.bytes

@staticmethod
def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: StateGraphView):
def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: DataflowGraphView):
scope_dict = sdfg.node(state_id).scope_dict()

out_costs = 0
Expand Down Expand Up @@ -636,7 +636,10 @@ def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg:
return out_costs

@staticmethod
def get_tasklet_byte_accesses(tasklet: nodes.CodeNode, dfg: StateGraphView, sdfg: dace.SDFG, state_id: int) -> str:
def get_tasklet_byte_accesses(tasklet: nodes.CodeNode,
dfg: DataflowGraphView,
sdfg: dace.SDFG,
state_id: int) -> str:
""" Get the amount of bytes processed by `tasklet`. The formula is
sum(inedges * size) + sum(outedges * size) """
in_accum = []
Expand Down Expand Up @@ -693,7 +696,7 @@ def get_memory_input_size(node, sdfg, state_id) -> str:
return sym2cpp(input_size)

@staticmethod
def accumulate_byte_movement(outermost_node, node, dfg: StateGraphView, sdfg, state_id):
def accumulate_byte_movement(outermost_node, node, dfg: DataflowGraphView, sdfg, state_id):

itvars = dict() # initialize an empty dict

Expand Down
2 changes: 1 addition & 1 deletion dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def remove_name_collisions(sdfg: SDFG):
# Rename duplicate states
for state in nsdfg.nodes():
if state.label in state_names_seen:
state.set_label(data.find_new_name(state.label, state_names_seen))
state.label = data.find_new_name(state.label, state_names_seen)
state_names_seen.add(state.label)

replacements: Dict[str, str] = {}
Expand Down
5 changes: 2 additions & 3 deletions dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,8 @@ def label(self):
def __label__(self, sdfg, state):
return self.data

def desc(self, sdfg):
from dace.sdfg import SDFGState, ScopeSubgraphView
if isinstance(sdfg, (SDFGState, ScopeSubgraphView)):
def desc(self, sdfg: Union['dace.sdfg.SDFG', 'dace.sdfg.SDFGState', 'dace.sdfg.ScopeSubgraphView']):
if isinstance(sdfg, (dace.sdfg.SDFGState, dace.sdfg.ScopeSubgraphView)):
sdfg = sdfg.parent
return sdfg.arrays[self.data]

Expand Down
29 changes: 15 additions & 14 deletions dace/sdfg/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,17 +175,18 @@ def replace_datadesc_names(sdfg, repl: Dict[str, str]):
sdfg.constants_prop[repl[aname]] = sdfg.constants_prop[aname]
del sdfg.constants_prop[aname]

# Replace in interstate edges
for e in sdfg.edges():
e.data.replace_dict(repl, replace_keys=False)

for state in sdfg.nodes():
# Replace in access nodes
for node in state.data_nodes():
if node.data in repl:
node.data = repl[node.data]

# Replace in memlets
for edge in state.edges():
if edge.data.data in repl:
edge.data.data = repl[edge.data.data]
for cf in sdfg.all_control_flow_regions():
# Replace in interstate edges
for e in cf.edges():
e.data.replace_dict(repl, replace_keys=False)

for block in cf.nodes():
if isinstance(block, dace.SDFGState):
# Replace in access nodes
for node in block.data_nodes():
if node.data in repl:
node.data = repl[node.data]
# Replace in memlets
for edge in block.edges():
if edge.data.data in repl:
edge.data.data = repl[edge.data.data]

0 comments on commit ebd8015

Please sign in to comment.