Skip to content

Commit

Permalink
PruneConnectors: Fission into separate states before pruning
Browse files Browse the repository at this point in the history
  • Loading branch information
lukastruemper committed Nov 29, 2023
1 parent 8f229bc commit 50befc8
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 38 deletions.
87 changes: 53 additions & 34 deletions dace/transformation/dataflow/prune_connectors.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
from os import stat
from typing import Any, AnyStr, Dict, Optional, Set, Tuple, Union
from typing import Set, Tuple
import re

from dace import dtypes, registry, SDFG, SDFGState, symbolic, properties, data as dt
from dace import dtypes, SDFG, SDFGState, symbolic, properties, data as dt
from dace.transformation import transformation as pm, helpers
from dace.sdfg import nodes, utils
from dace.sdfg.analysis import cfg
from dace.sdfg.state import StateSubgraphView


@properties.make_properties
Expand Down Expand Up @@ -46,60 +46,79 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi
# Add WCR outputs to "do not prune" input list
for e in graph.out_edges(nsdfg):
if e.data.wcr is not None and e.src_conn in prune_in:
if (graph.in_degree(next(iter(graph.in_edges_by_connector(nsdfg, e.src_conn))).src) > 0):
prune_in.remove(e.src_conn)
has_before = all(
graph.in_degree(graph.memlet_path(e)[0].src) > 0 for e in graph.in_edges(nsdfg) if e.dst_conn in prune_in)
has_after = all(
graph.out_degree(graph.memlet_path(e)[-1].dst) > 0 for e in graph.out_edges(nsdfg)
if e.src_conn in prune_out)
if has_before and has_after:
prune_in.remove(e.src_conn)

if not prune_in and not prune_out:
return False
if len(prune_in) > 0 or len(prune_out) > 0:
return True

return False
return True

def apply(self, state: SDFGState, sdfg: SDFG):
nsdfg = self.nsdfg

# Fission subgraph around nsdfg into its own state to avoid data races
predecessors = set()
for inedge in state.in_edges(nsdfg):
if inedge.data is None:
continue

pred = state.memlet_path(inedge)[0].src
if state.in_degree(pred) == 0:
continue

predecessors.add(pred)
for e in state.bfs_edges(pred, reverse=True):
predecessors.add(e.src)

subgraph = StateSubgraphView(state, predecessors)
pred_state = helpers.state_fission(sdfg, subgraph)

subgraph_nodes = set()
subgraph_nodes.add(nsdfg)
for inedge in state.in_edges(nsdfg):
if inedge.data is None:
continue
path = state.memlet_path(inedge)
for edge in path:
subgraph_nodes.add(edge.src)

for oedge in state.out_edges(nsdfg):
if oedge.data is None:
continue
path = state.memlet_path(oedge)
for edge in path:
subgraph_nodes.add(edge.dst)

subgraph = StateSubgraphView(state, subgraph_nodes)
nsdfg_state = helpers.state_fission(sdfg, subgraph)

read_set, write_set = nsdfg.sdfg.read_and_write_sets()
prune_in = nsdfg.in_connectors.keys() - read_set
prune_out = nsdfg.out_connectors.keys() - write_set

# Detect which nodes are used, so we can delete unused nodes after the
# connectors have been pruned
all_data_used = read_set | write_set

# Add WCR outputs to "do not prune" input list
for e in state.out_edges(nsdfg):
for e in nsdfg_state.out_edges(nsdfg):
if e.data.wcr is not None and e.src_conn in prune_in:
if (state.in_degree(next(iter(state.in_edges_by_connector(nsdfg, e.src_conn))).src) > 0):
prune_in.remove(e.src_conn)
do_not_prune = set()
prune_in.remove(e.src_conn)

for conn in prune_in:
if any(
state.in_degree(state.memlet_path(e)[0].src) > 0 for e in state.in_edges(nsdfg)
if e.dst_conn == conn):
do_not_prune.add(conn)
continue
for e in state.in_edges_by_connector(nsdfg, conn):
state.remove_memlet_path(e, remove_orphans=True)
for e in nsdfg_state.in_edges_by_connector(nsdfg, conn):
nsdfg_state.remove_memlet_path(e, remove_orphans=True)

for conn in prune_out:
if any(
state.out_degree(state.memlet_path(e)[-1].dst) > 0 for e in state.out_edges(nsdfg)
if e.src_conn == conn):
do_not_prune.add(conn)
continue
for e in state.out_edges_by_connector(nsdfg, conn):
state.remove_memlet_path(e, remove_orphans=True)
for e in nsdfg_state.out_edges_by_connector(nsdfg, conn):
nsdfg_state.remove_memlet_path(e, remove_orphans=True)

for conn in prune_in:
if conn in nsdfg.sdfg.arrays and conn not in all_data_used and conn not in do_not_prune:
if conn in nsdfg.sdfg.arrays and conn not in all_data_used:
# If the data is now unused, we can purge it from the SDFG
nsdfg.sdfg.remove_data(conn)
for conn in prune_out:
if conn in nsdfg.sdfg.arrays and conn not in all_data_used and conn not in do_not_prune:
if conn in nsdfg.sdfg.arrays and conn not in all_data_used:
# If the data is now unused, we can purge it from the SDFG
nsdfg.sdfg.remove_data(conn)

Expand Down
14 changes: 10 additions & 4 deletions tests/npbench/polybench/floyd_warshall_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
import argparse
from dace.fpga_testing import fpga_test
from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG
from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG, StateFusion
from dace.transformation.dataflow import StreamingMemory, MapFusion, StreamingComposition, PruneConnectors
from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt

Expand Down Expand Up @@ -91,15 +91,21 @@ def run_floyd_warshall(device_type: dace.dtypes.DeviceType):
}])

assert pruned_conns == 1
sdfg.apply_transformations_repeated(StateFusion)

fpga_auto_opt.fpga_rr_interleave_containers_to_banks(sdfg)

# In this case, we want to generate the top-level state as an host-based state,
# not an FPGA kernel. We need to explicitly indicate that
sdfg.states()[0].location["is_FPGA_kernel"] = False
sdfg.start_state.location["is_FPGA_kernel"] = False

# we need to specialize both the top-level SDFG and the nested SDFG
sdfg.specialize(dict(N=N))
sdfg.states()[0].nodes()[0].sdfg.specialize(dict(N=N))
for state in sdfg.states():
for node in state.nodes():
if isinstance(node, dace.nodes.NestedSDFG):
node.sdfg.specialize(dict(N=N))

# run program
sdfg(path=path)

Expand All @@ -126,7 +132,7 @@ def test_fpga():
if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform')
parser.add_argument("-t", "--target", default='fpga', choices=['cpu', 'gpu', 'fpga'], help='Target platform')

args = vars(parser.parse_args())
target = args["target"]
Expand Down
82 changes: 82 additions & 0 deletions tests/transformations/prune_connectors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
import argparse
import numpy as np
import os
import copy
import pytest
import dace
from dace.transformation.dataflow import PruneConnectors
from dace.transformation.helpers import nest_state_subgraph
from dace.sdfg.state import StateSubgraphView


def make_sdfg():
Expand Down Expand Up @@ -237,6 +240,84 @@ def test_unused_retval_2():
assert np.allclose(a, 1)


def test_prune_connectors_with_dependencies():
    """Prune unused nested-SDFG connectors that have producers and consumers.

    Builds a single state with three element-wise "+ 1" maps (A -> B, C -> C,
    B -> D), nests the C -> C map into a nested SDFG, and then attaches two
    connectors to that nested SDFG (``B1`` in, ``B2`` out) whose data the
    nested SDFG never touches. The SDFG is executed once before and once
    after applying ``PruneConnectors``; the test asserts that exactly one
    application happened, that the SDFG ends up with three states, that both
    extra connectors are gone, and that both runs yield the same arrays.
    """
    sdfg = dace.SDFG('tester')
    sdfg.add_array('A', [4], dace.float64)
    _, desc_b = sdfg.add_array('B', [4], dace.float64)
    sdfg.add_array('C', [4], dace.float64)
    sdfg.add_array('D', [4], dace.float64)

    state = sdfg.add_state()
    access_a = state.add_access("A")
    access_b_written = state.add_access("B")
    access_b_read = state.add_access("B")
    access_c_read = state.add_access("C")
    access_c_written = state.add_access("C")
    access_d = state.add_access("D")

    def add_increment_map(label, src_name, src_node, dst_name, dst_node):
        # Create the mapped tasklet first, then wire the outer edges by hand
        # (connectors are filled in later by fill_scope_connectors).
        tasklet, entry, exit_ = state.add_mapped_tasklet(label,
                                                         map_ranges={"i": "0:4"},
                                                         inputs={"_in": dace.Memlet(data=src_name, subset='i')},
                                                         outputs={"_out": dace.Memlet(data=dst_name, subset='i')},
                                                         code="_out = _in + 1")
        state.add_edge(src_node, None, entry, None, dace.Memlet(data=src_name, subset="0:4"))
        state.add_edge(exit_, None, dst_node, None, dace.Memlet(data=dst_name, subset="0:4"))
        return tasklet, entry, exit_

    add_increment_map("a", "A", access_a, "B", access_b_written)
    tasklet_c, map_entry_c, map_exit_c = add_increment_map("c", "C", access_c_read, "C", access_c_written)
    add_increment_map("d", "B", access_b_read, "D", access_d)

    sdfg.fill_scope_connectors()

    # Nest only the C -> C map; its nested SDFG is the pruning target.
    nested_view = StateSubgraphView(state, subgraph_nodes=[map_entry_c, map_exit_c, tasklet_c])
    nsdfg_node = nest_state_subgraph(sdfg, state, subgraph=nested_view)

    # Register two non-transient containers inside the nested SDFG that it
    # never reads or writes.
    for inner_name in ("B1", "B2"):
        nsdfg_node.sdfg.add_datadesc(inner_name, datadesc=copy.deepcopy(desc_b))
        nsdfg_node.sdfg.arrays[inner_name].transient = False

    # Hook them up so the nested SDFG sits between the producer and the
    # consumer of B.
    nsdfg_node.add_in_connector("B1")
    state.add_edge(access_b_written, None, nsdfg_node, "B1", dace.Memlet.from_array(dataname="B", datadesc=desc_b))
    nsdfg_node.add_out_connector("B2")
    state.add_edge(nsdfg_node, "B2", access_b_read, None, dace.Memlet.from_array(dataname="B", datadesc=desc_b))

    array_names = ("A", "B", "C", "D")
    reference = {name: np.random.random(4) for name in array_names}
    pruned = {name: np.copy(values) for name, values in reference.items()}

    # Baseline run on the untransformed SDFG.
    sdfg(**reference)

    applied = sdfg.apply_transformations_repeated(PruneConnectors)
    assert applied == 1
    assert len(sdfg.states()) == 3
    assert "B1" not in nsdfg_node.in_connectors
    assert "B2" not in nsdfg_node.out_connectors

    # Same inputs through the transformed SDFG must give the same outputs.
    sdfg(**pruned)
    for name in array_names:
        assert np.allclose(reference[name], pruned[name])


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--N", default=64)
Expand All @@ -248,3 +329,4 @@ def test_unused_retval_2():
test_prune_connectors(True, n=n)
test_unused_retval()
test_unused_retval_2()
test_prune_connectors_with_dependencies()

0 comments on commit 50befc8

Please sign in to comment.