Skip to content

Commit

Permalink
OTFMapFusion: Minor bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
lukastruemper committed Nov 25, 2023
1 parent 4139ddf commit ff9e2e2
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions dace/transformation/dataflow/otf_map_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,27 +132,28 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
intermediate_access_node = self.array
first_map_exit = self.first_map_exit
first_map_entry = graph.entry_node(first_map_exit)
second_map_entry = self.second_map_entry

# Prepare: Make first and second map parameters disjoint
# This avoids mutual matching: i -> j, j -> i
subgraph = graph.scope_subgraph(first_map_entry, include_entry=True, include_exit=True)
for param in first_map_entry.map.params:
i = 0
new_param = f"_i{i}"
while new_param in self.second_map_entry.map.params or new_param in first_map_entry.map.params:
while new_param in second_map_entry.map.params or new_param in first_map_entry.map.params:
i = i + 1
new_param = f"_i{i}"

advanced_replace(subgraph, param, new_param)

# Prepare: Preemptively rename params defined by second map in scope of first
# This avoids that local variables (e.g., in nested SDFG) have collisions with new map scope
for param in self.second_map_entry.map.params:
for param in second_map_entry.map.params:
new_param = param + "_local"
advanced_replace(subgraph, param, new_param)

# Add local buffers for array-like OTFs
for edge in graph.out_edges(self.second_map_entry):
for edge in graph.out_edges(second_map_entry):
if edge.data is None or edge.data.data != intermediate_access_node.data:
continue

Expand Down Expand Up @@ -208,18 +209,18 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
save=False)

# Phase 1: Add new access nodes to second map
for edge in graph.edges_between(intermediate_access_node, self.second_map_entry):
for edge in graph.edges_between(intermediate_access_node, second_map_entry):
graph.remove_edge_and_connectors(edge)

connector_mapping = {}
for edge in graph.in_edges(first_map_entry):
new_in_connector = self.second_map_entry.next_connector(edge.dst_conn[3:])
new_in_connector = second_map_entry.next_connector(edge.dst_conn[3:])
new_in_connector = "IN_" + new_in_connector
if not self.second_map_entry.add_in_connector(new_in_connector):
if not second_map_entry.add_in_connector(new_in_connector):
raise ValueError("Failed to add new in connector")

memlet = copy.deepcopy(edge.data)
graph.add_edge(edge.src, edge.src_conn, self.second_map_entry, new_in_connector, memlet)
graph.add_edge(edge.src, edge.src_conn, second_map_entry, new_in_connector, memlet)

connector_mapping[edge.dst_conn] = new_in_connector

Expand All @@ -231,7 +232,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG):

# Group by same access scheme
consume_memlets = {}
for edge in graph.out_edges(self.second_map_entry):
for edge in graph.out_edges(second_map_entry):
memlet = edge.data
if memlet.data not in produce_memlets:
continue
Expand All @@ -246,7 +247,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
consume_memlets[memlet.data][accesses].append(edge)

# And remove from second map
self.second_map_entry.remove_out_connector(edge.src_conn)
second_map_entry.remove_out_connector(edge.src_conn)
graph.remove_edge(edge)

# Phase 3: OTF - copy content of first map for each memlet of second according to matches
Expand All @@ -256,7 +257,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
for second_accesses in consume_memlets[array]:
# Step 1: Infer index access of second map to new inputs with respect to original first map
mapping = OTFMapFusion.solve(first_map_entry.map.params, first_accesses,
self.second_map_entry.map.params, second_accesses)
second_map_entry.map.params, second_accesses)

# Step 2: Add Temporary buffer
tmp_name = sdfg.temp_data_name()
Expand Down Expand Up @@ -296,16 +297,16 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
else:
out_connector = edge.src_conn

if out_connector not in self.second_map_entry.out_connectors:
self.second_map_entry.add_out_connector(out_connector)
if out_connector not in second_map_entry.out_connectors:
second_map_entry.add_out_connector(out_connector)
else:
out_connector = None

graph.add_edge(self.second_map_entry, out_connector, node, edge.dst_conn, memlet)
graph.add_edge(second_map_entry, out_connector, node, edge.dst_conn, memlet)
graph.remove_edge(edge)

# Step 4: Rename all symbols of first map in copied content my matched symbol of second map
otf_nodes.append(self.second_map_entry)
otf_nodes.append(second_map_entry)
otf_subgraph = StateSubgraphView(graph, otf_nodes)
for param in mapping:
if isinstance(param, tuple):
Expand All @@ -316,14 +317,9 @@ def apply(self, graph: SDFGState, sdfg: SDFG):

# Check if first_map is still consumed by some node
if graph.out_degree(intermediate_access_node) == 0:
del sdfg.arrays[intermediate_access_node.data]
graph.remove_node(intermediate_access_node)

subgraph = graph.scope_subgraph(first_map_entry, include_entry=True, include_exit=True)
for dnode in subgraph.data_nodes():
if dnode.data in sdfg.arrays:
del sdfg.arrays[dnode.data]

obsolete_nodes = graph.all_nodes_between(first_map_entry,
first_map_exit) | {first_map_entry, first_map_exit}
graph.remove_nodes_from(obsolete_nodes)
Expand Down

0 comments on commit ff9e2e2

Please sign in to comment.