Skip to content

Commit

Permalink
PruneConnectors: Fission into separate states before pruning
Browse files Browse the repository at this point in the history
  • Loading branch information
lukastruemper committed Nov 27, 2023
1 parent 8f229bc commit d0d8722
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 34 deletions.
87 changes: 53 additions & 34 deletions dace/transformation/dataflow/prune_connectors.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
from os import stat
from typing import Any, AnyStr, Dict, Optional, Set, Tuple, Union
from typing import Set, Tuple
import re

from dace import dtypes, registry, SDFG, SDFGState, symbolic, properties, data as dt
from dace import dtypes, SDFG, SDFGState, symbolic, properties, data as dt
from dace.transformation import transformation as pm, helpers
from dace.sdfg import nodes, utils
from dace.sdfg.analysis import cfg
from dace.sdfg.state import StateSubgraphView


@properties.make_properties
Expand Down Expand Up @@ -46,60 +46,79 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi
# Add WCR outputs to "do not prune" input list
for e in graph.out_edges(nsdfg):
if e.data.wcr is not None and e.src_conn in prune_in:
if (graph.in_degree(next(iter(graph.in_edges_by_connector(nsdfg, e.src_conn))).src) > 0):
prune_in.remove(e.src_conn)
has_before = all(
graph.in_degree(graph.memlet_path(e)[0].src) > 0 for e in graph.in_edges(nsdfg) if e.dst_conn in prune_in)
has_after = all(
graph.out_degree(graph.memlet_path(e)[-1].dst) > 0 for e in graph.out_edges(nsdfg)
if e.src_conn in prune_out)
if has_before and has_after:
prune_in.remove(e.src_conn)

if not prune_in and not prune_out:
return False
if len(prune_in) > 0 or len(prune_out) > 0:
return True

return False
return True

def apply(self, state: SDFGState, sdfg: SDFG):
nsdfg = self.nsdfg

# Fission subgraph around nsdfg into its own state to avoid data races
predecessors = set()
for inedge in state.in_edges(nsdfg):
if inedge.data is None:
continue

pred = state.memlet_path(inedge)[0].src
if state.in_degree(pred) == 0:
continue

predecessors.add(pred)
for e in state.bfs_edges(pred, reverse=True):
predecessors.add(e.src)

subgraph = StateSubgraphView(state, predecessors)
pred_state = helpers.state_fission(sdfg, subgraph)

subgraph_nodes = set()
subgraph_nodes.add(nsdfg)
for inedge in state.in_edges(nsdfg):
if inedge.data is None:
continue
path = state.memlet_path(inedge)
for edge in path:
subgraph_nodes.add(edge.src)

for oedge in state.out_edges(nsdfg):
if oedge.data is None:
continue
path = state.memlet_path(oedge)
for edge in path:
subgraph_nodes.add(edge.dst)

subgraph = StateSubgraphView(state, subgraph_nodes)
nsdfg_state = helpers.state_fission(sdfg, subgraph)

read_set, write_set = nsdfg.sdfg.read_and_write_sets()
prune_in = nsdfg.in_connectors.keys() - read_set
prune_out = nsdfg.out_connectors.keys() - write_set

# Detect which nodes are used, so we can delete unused nodes after the
# connectors have been pruned
all_data_used = read_set | write_set

# Add WCR outputs to "do not prune" input list
for e in state.out_edges(nsdfg):
for e in nsdfg_state.out_edges(nsdfg):
if e.data.wcr is not None and e.src_conn in prune_in:
if (state.in_degree(next(iter(state.in_edges_by_connector(nsdfg, e.src_conn))).src) > 0):
prune_in.remove(e.src_conn)
do_not_prune = set()
prune_in.remove(e.src_conn)

for conn in prune_in:
if any(
state.in_degree(state.memlet_path(e)[0].src) > 0 for e in state.in_edges(nsdfg)
if e.dst_conn == conn):
do_not_prune.add(conn)
continue
for e in state.in_edges_by_connector(nsdfg, conn):
state.remove_memlet_path(e, remove_orphans=True)
for e in nsdfg_state.in_edges_by_connector(nsdfg, conn):
nsdfg_state.remove_memlet_path(e, remove_orphans=True)

for conn in prune_out:
if any(
state.out_degree(state.memlet_path(e)[-1].dst) > 0 for e in state.out_edges(nsdfg)
if e.src_conn == conn):
do_not_prune.add(conn)
continue
for e in state.out_edges_by_connector(nsdfg, conn):
state.remove_memlet_path(e, remove_orphans=True)
for e in nsdfg_state.out_edges_by_connector(nsdfg, conn):
nsdfg_state.remove_memlet_path(e, remove_orphans=True)

for conn in prune_in:
if conn in nsdfg.sdfg.arrays and conn not in all_data_used and conn not in do_not_prune:
if conn in nsdfg.sdfg.arrays and conn not in all_data_used:
# If the data is now unused, we can purge it from the SDFG
nsdfg.sdfg.remove_data(conn)
for conn in prune_out:
if conn in nsdfg.sdfg.arrays and conn not in all_data_used and conn not in do_not_prune:
if conn in nsdfg.sdfg.arrays and conn not in all_data_used:
# If the data is now unused, we can purge it from the SDFG
nsdfg.sdfg.remove_data(conn)

Expand Down
82 changes: 82 additions & 0 deletions tests/transformations/prune_connectors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
import argparse
import numpy as np
import os
import copy
import pytest
import dace
from dace.transformation.dataflow import PruneConnectors
from dace.transformation.helpers import nest_state_subgraph
from dace.sdfg.state import StateSubgraphView


def make_sdfg():
Expand Down Expand Up @@ -237,6 +240,84 @@ def test_unused_retval_2():
assert np.allclose(a, 1)


def test_prune_connectors_with_dependencies():
sdfg = dace.SDFG('tester')
A, A_desc = sdfg.add_array('A', [4], dace.float64)
B, B_desc = sdfg.add_array('B', [4], dace.float64)
C, C_desc = sdfg.add_array('C', [4], dace.float64)
D, D_desc = sdfg.add_array('D', [4], dace.float64)

state = sdfg.add_state()
a = state.add_access("A")
b1 = state.add_access("B")
b2 = state.add_access("B")
c1 = state.add_access("C")
c2 = state.add_access("C")
d = state.add_access("D")

_, map_entry_a, map_exit_a = state.add_mapped_tasklet("a",
map_ranges={"i": "0:4"},
inputs={"_in": dace.Memlet(data="A", subset='i')},
outputs={"_out": dace.Memlet(data="B", subset='i')},
code="_out = _in + 1")
state.add_edge(a, None, map_entry_a, None, dace.Memlet(data="A", subset="0:4"))
state.add_edge(map_exit_a, None, b1, None, dace.Memlet(data="B", subset="0:4"))

tasklet_c, map_entry_c, map_exit_c = state.add_mapped_tasklet("c",
map_ranges={"i": "0:4"},
inputs={"_in": dace.Memlet(data="C", subset='i')},
outputs={"_out": dace.Memlet(data="C", subset='i')},
code="_out = _in + 1")
state.add_edge(c1, None, map_entry_c, None, dace.Memlet(data="C", subset="0:4"))
state.add_edge(map_exit_c, None, c2, None, dace.Memlet(data="C", subset="0:4"))

_, map_entry_d, map_exit_d = state.add_mapped_tasklet("d",
map_ranges={"i": "0:4"},
inputs={"_in": dace.Memlet(data="B", subset='i')},
outputs={"_out": dace.Memlet(data="D", subset='i')},
code="_out = _in + 1")
state.add_edge(b2, None, map_entry_d, None, dace.Memlet(data="B", subset="0:4"))
state.add_edge(map_exit_d, None, d, None, dace.Memlet(data="D", subset="0:4"))

sdfg.fill_scope_connectors()

subgraph = StateSubgraphView(state, subgraph_nodes=[map_entry_c, map_exit_c, tasklet_c])
nsdfg_node = nest_state_subgraph(sdfg, state, subgraph=subgraph)

nsdfg_node.sdfg.add_datadesc("B1", datadesc=copy.deepcopy(B_desc))
nsdfg_node.sdfg.arrays["B1"].transient = False
nsdfg_node.sdfg.add_datadesc("B2", datadesc=copy.deepcopy(B_desc))
nsdfg_node.sdfg.arrays["B2"].transient = False

nsdfg_node.add_in_connector("B1")
state.add_edge(b1, None, nsdfg_node, "B1", dace.Memlet.from_array(dataname="B", datadesc=B_desc))
nsdfg_node.add_out_connector("B2")
state.add_edge(nsdfg_node, "B2", b2, None, dace.Memlet.from_array(dataname="B", datadesc=B_desc))

np_a = np.random.random(4)
np_a_ = np.copy(np_a)
np_b = np.random.random(4)
np_b_ = np.copy(np_b)
np_c = np.random.random(4)
np_c_ = np.copy(np_c)
np_d = np.random.random(4)
np_d_ = np.copy(np_d)

sdfg(A=np_a, B=np_b, C=np_c, D=np_d)

applied = sdfg.apply_transformations_repeated(PruneConnectors)
assert applied == 1
assert len(sdfg.states()) == 3
assert "B1" not in nsdfg_node.in_connectors
assert "B2" not in nsdfg_node.out_connectors

sdfg(A=np_a_, B=np_b_, C=np_c_, D=np_d_)
assert np.allclose(np_a, np_a_)
assert np.allclose(np_b, np_b_)
assert np.allclose(np_c, np_c_)
assert np.allclose(np_d, np_d_)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--N", default=64)
Expand All @@ -248,3 +329,4 @@ def test_unused_retval_2():
test_prune_connectors(True, n=n)
test_unused_retval()
test_unused_retval_2()
test_prune_connectors_with_dependencies()

0 comments on commit d0d8722

Please sign in to comment.