Skip to content

Commit

Permalink
Merge pull request #1333 from spcl/validate-nsdfg-inout
Browse files Browse the repository at this point in the history
Validate NestedSDFG inout connectors
  • Loading branch information
alexnick83 committed Aug 3, 2023
2 parents 8adcea6 + bacc292 commit dc70efe
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 1 deletion.
18 changes: 18 additions & 0 deletions dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,24 @@ def validate(self, sdfg, state, references: Optional[Set[int]] = None, **context
raise NameError('Data descriptor "%s" not found in nested SDFG connectors' % dname)
if dname in connectors and desc.transient:
raise NameError('"%s" is a connector but its corresponding array is transient' % dname)

# Validate inout connectors
from dace.sdfg import utils # Avoids circular import
inout_connectors = self.in_connectors.keys() & self.out_connectors.keys()
for conn in inout_connectors:
inputs = set()
outputs = set()
for edge in state.in_edges_by_connector(self, conn):
src = utils.get_global_memlet_path_src(sdfg, state, edge)
if isinstance(src, AccessNode):
inputs.add(src.data)
for edge in state.out_edges_by_connector(self, conn):
dst = utils.get_global_memlet_path_dst(sdfg, state, edge)
if isinstance(dst, AccessNode):
outputs.add(dst.data)
if len(inputs - outputs) > 0:
raise ValueError(f"Inout connector {conn} is connected to different input ({inputs}) and "
f"output ({outputs}) arrays")

# Validate undefined symbols
symbols = set(k for k in self.sdfg.free_symbols if k not in connectors)
Expand Down
42 changes: 42 additions & 0 deletions dace/sdfg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1797,3 +1797,45 @@ def get_thread_local_data(sdfg: SDFG) -> List[str]:
if not sdfg.arrays[name].transient:
warnings.warn(f'Found thread-local data "{name}" that is not transient.')
return result


def get_global_memlet_path_src(sdfg: SDFG, state: SDFGState, edge: MultiConnectorEdge) -> nd.Node:
"""
Finds the global source node of an edge/memlet path, crossing nested SDFG scopes.
:param sdfg: The SDFG containing the edge.
:param state: The state containing the edge.
:param edge: The edge to find the global source node for.
:return: The global source node of the edge.
"""
src = state.memlet_path(edge)[0].src
if isinstance(src, nd.AccessNode) and not sdfg.arrays[src.data].transient and sdfg.parent is not None:
psdfg = sdfg.parent_sdfg
pstate = sdfg.parent
pnode = sdfg.parent_nsdfg_node
pedges = list(pstate.in_edges_by_connector(pnode, src.data))
if len(pedges) > 0:
pedge = pedges[0]
return get_global_memlet_path_src(psdfg, pstate, pedge)
return src


def get_global_memlet_path_dst(sdfg: SDFG, state: SDFGState, edge: MultiConnectorEdge) -> nd.Node:
"""
Finds the global destination node of an edge/memlet path, crossing nested SDFG scopes.
:param sdfg: The SDFG containing the edge.
:param state: The state containing the edge.
:param edge: The edge to find the global destination node for.
:return: The global destination node of the edge.
"""
dst = state.memlet_path(edge)[-1].dst
if isinstance(dst, nd.AccessNode) and not sdfg.arrays[dst.data].transient and sdfg.parent is not None:
psdfg = sdfg.parent_sdfg
pstate = sdfg.parent
pnode = sdfg.parent_nsdfg_node
pedges = list(pstate.out_edges_by_connector(pnode, dst.data))
if len(pedges) > 0:
pedge = pedges[0]
return get_global_memlet_path_dst(psdfg, pstate, pedge)
return dst
7 changes: 6 additions & 1 deletion dace/transformation/subgraph/subgraph_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,10 +1146,15 @@ def change_data(transient_array, shape, strides, total_size, offset, lifetime, s
# by reconnecting their adjacent edges to nodes outside the subgraph.
# NOTE: Currently limited to cases where there is a single source and sink
# if there are multiple intermediate accesses for the same data.
# NOTE: Currently limited to intermediate data that do not have a separate output node

# Filter out outputs
output_data = set([n.data for n in out_nodes])
true_intermediate_nodes = set([n for n in intermediate_nodes if n.data not in output_data])

# Sort intermediate nodes by data name
intermediate_data = dict()
for acc in intermediate_nodes:
for acc in true_intermediate_nodes:
if acc.data in intermediate_data:
intermediate_data[acc.data].append(acc)
else:
Expand Down
69 changes: 69 additions & 0 deletions tests/sdfg/validation/nested_sdfg_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import dace


def test_inout_connector_validation_success():

sdfg = dace.SDFG("test_inout_connector_validation_success")
sdfg.add_array("A", [1], dace.int32)
sdfg.add_array("B", [1], dace.int32)

nsdfg = dace.SDFG("nested_sdfg")
nsdfg.add_array("C", [1], dace.int32)

nstate = nsdfg.add_state()
read_c = nstate.add_access("C")
write_c = nstate.add_access("C")
tasklet = nstate.add_tasklet("tasklet", {"__inp"}, {"__out"}, "__out = __inp + 5")
nstate.add_edge(read_c, None, tasklet, '__inp', dace.Memlet.from_array('C', nsdfg.arrays['C']))
nstate.add_edge(tasklet, '__out', write_c, None, dace.Memlet.from_array('C', nsdfg.arrays['C']))

state = sdfg.add_state()
read_b = state.add_access("B")
write_b = state.add_access("B")
tasklet = state.add_nested_sdfg(nsdfg, sdfg, {"C"}, {"C"})
state.add_edge(read_b, None, tasklet, 'C', dace.Memlet.from_array('B', sdfg.arrays['B']))
state.add_edge(tasklet, 'C', write_b, None, dace.Memlet.from_array('B', sdfg.arrays['B']))

try:
sdfg.validate()
except dace.sdfg.InvalidSDFGError:
assert False, "SDFG should validate"

return


def test_inout_connector_validation_fail():

sdfg = dace.SDFG("test_inout_connector_validation_fail")
sdfg.add_array("A", [1], dace.int32)
sdfg.add_array("B", [1], dace.int32)

nsdfg = dace.SDFG("nested_sdfg")
nsdfg.add_array("C", [1], dace.int32)

nstate = nsdfg.add_state()
read_c = nstate.add_access("C")
write_c = nstate.add_access("C")
tasklet = nstate.add_tasklet("tasklet", {"__inp"}, {"__out"}, "__out = __inp + 5")
nstate.add_edge(read_c, None, tasklet, '__inp', dace.Memlet.from_array('C', nsdfg.arrays['C']))
nstate.add_edge(tasklet, '__out', write_c, None, dace.Memlet.from_array('C', nsdfg.arrays['C']))

state = sdfg.add_state()
read_a = state.add_access("A")
write_b = state.add_access("B")
tasklet = state.add_nested_sdfg(nsdfg, sdfg, {"C"}, {"C"})
state.add_edge(read_a, None, tasklet, 'C', dace.Memlet.from_array('A', sdfg.arrays['A']))
state.add_edge(tasklet, 'C', write_b, None, dace.Memlet.from_array('B', sdfg.arrays['B']))

try:
sdfg.validate()
except dace.sdfg.InvalidSDFGError:
return

assert False, "SDFG should not validate"


if __name__ == "__main__":
test_inout_connector_validation_success()
test_inout_connector_validation_fail()

0 comments on commit dc70efe

Please sign in to comment.