Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix-nested-sdfg-deepcopy #1221

Merged
merged 14 commits into from
Mar 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 10 additions & 0 deletions dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,16 @@ def __init__(self,
self.symbol_mapping = symbol_mapping or {}
self.schedule = schedule
self.debuginfo = debuginfo

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
setattr(result, k, dcpy(v, memo))
if result._sdfg is not None:
result._sdfg.parent_nsdfg_node = result
return result

@staticmethod
def from_json(json_obj, context=None):
Expand Down
32 changes: 31 additions & 1 deletion dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,34 @@ def __init__(self,
self._orig_name = name
self._num = 0

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# Skip derivative attributes
if k in ('_cached_start_state', '_edges', '_nodes', '_parent', '_parent_sdfg', '_parent_nsdfg_node',
'_sdfg_list', '_transformation_hist'):
continue
setattr(result, k, copy.deepcopy(v, memo))
# Copy edges and nodes
result._edges = copy.deepcopy(self._edges, memo)
result._nodes = copy.deepcopy(self._nodes, memo)
result._cached_start_state = copy.deepcopy(self._cached_start_state, memo)
# Copy parent attributes
for k in ('_parent', '_parent_sdfg', '_parent_nsdfg_node'):
if id(getattr(self, k)) in memo:
setattr(result, k, memo[id(getattr(self, k))])
else:
setattr(result, k, None)
# Copy SDFG list and transformation history
if hasattr(self, '_transformation_hist'):
setattr(result, '_transformation_hist', copy.deepcopy(self._transformation_hist, memo))
result._sdfg_list = []
if self._parent_sdfg is None:
result._sdfg_list = result.reset_sdfg_list()
return result

@property
def sdfg_id(self):
"""
Expand Down Expand Up @@ -520,6 +548,7 @@ def hash_sdfg(self, jsondict: Optional[Dict[str, Any]] = None) -> str:
:param jsondict: If not None, uses given JSON dictionary as input.
:return: The hash (in SHA-256 format).
"""

def keyword_remover(json_obj: Any, last_keyword=""):
# Makes non-unique in SDFG hierarchy v2
# Recursively remove attributes from the SDFG which are not used in
Expand Down Expand Up @@ -1910,7 +1939,8 @@ def add_datadesc(self, name: str, datadesc: dt.Data, find_new_name=False) -> str
if find_new_name:
name = self._find_new_name(name)
else:
raise NameError('Array or Stream with name "%s" already exists ' "in SDFG" % name)
raise NameError('Array or Stream with name "%s" already exists '
"in SDFG" % name)
self._arrays[name] = datadesc

# Add free symbols to the SDFG global symbol storage
Expand Down
21 changes: 21 additions & 0 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,22 @@ def __init__(self, label=None, sdfg=None, debuginfo=None, location=None):
self.nosync = False
self.location = location if location is not None else {}
self._default_lineinfo = None

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
setattr(result, k, copy.deepcopy(v, memo))
for node in result.nodes():
if isinstance(node, nd.NestedSDFG):
try:
node.sdfg.parent = result
except AttributeError:
# NOTE: There are cases where a NestedSDFG does not have `sdfg` attribute.
# TODO: Investigate why this happens.
pass
return result

@property
def parent(self):
Expand Down Expand Up @@ -819,6 +835,11 @@ def all_edges_and_connectors(self, *nodes):
def add_node(self, node):
if not isinstance(node, nd.Node):
raise TypeError("Expected Node, got " + type(node).__name__ + " (" + str(node) + ")")
# Correct nested SDFG's parent attributes
if isinstance(node, nd.NestedSDFG):
node.sdfg.parent = self
node.sdfg.parent_sdfg = self.parent
node.sdfg.parent_nsdfg_node = node
self._clear_scopedict_cache()
return super(SDFGState, self).add_node(node)

Expand Down
16 changes: 16 additions & 0 deletions dace/sdfg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1629,3 +1629,19 @@ def map_view_to_array(vdesc: dt.View, adesc: dt.Array,
squeezed.append(i)

return dimension_mapping, unsqueezed, squeezed


def check_sdfg(sdfg: SDFG):
""" Checks that the parent attributes of an SDFG are correct.

:param sdfg: The SDFG to check.
:raises AssertionError: If any of the parent attributes are incorrect.
"""
for state in sdfg.nodes():
for node in state.nodes():
if isinstance(node, dace.nodes.NestedSDFG):
assert node.sdfg.parent_nsdfg_node is node
assert node.sdfg.parent is state
assert node.sdfg.parent_sdfg is sdfg
assert node.sdfg.parent.parent is sdfg
check_sdfg(node.sdfg)
2 changes: 2 additions & 0 deletions dace/transformation/interstate/multistate_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,8 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG):
# Remove nested SDFG and state
sdfg.remove_node(outer_state)

sdfg._sdfg_list = sdfg.reset_sdfg_list()

return nsdfg.nodes()

# def _modify_access_to_access(
Expand Down
2 changes: 2 additions & 0 deletions dace/transformation/interstate/sdfg_nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,8 @@ def apply(self, state: SDFGState, sdfg: SDFG):
for dnode in state.data_nodes():
if state.degree(dnode) == 0 and dnode not in isolated_nodes:
state.remove_node(dnode)

sdfg._sdfg_list = sdfg.reset_sdfg_list()

def _modify_access_to_access(self,
input_edges: Dict[nodes.Node, MultiConnectorEdge],
Expand Down
1 change: 1 addition & 0 deletions dace/transformation/subgraph/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool:
# deepcopy
graph_indices = [i for (i, n) in enumerate(graph.nodes()) if n in subgraph]
sdfg_copy = copy.deepcopy(sdfg)
sdfg_copy.reset_sdfg_list()
graph_copy = sdfg_copy.nodes()[sdfg.nodes().index(graph)]
subgraph_copy = SubgraphView(graph_copy, [graph_copy.nodes()[i] for i in graph_indices])
expansion.sdfg_id = sdfg_copy.sdfg_id
Expand Down
153 changes: 153 additions & 0 deletions tests/sdfg/nested_sdfg_deepcopy_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
""" Tests deepcopying (nested) SDFGs. """
import copy
import dace
import numpy as np


def test_deepcopy_same_state():

sdfg = dace.SDFG('deepcopy_nested_sdfg')
state = sdfg.add_state('state')

nsdfg = dace.SDFG('nested')
nsdfg_node = state.add_nested_sdfg(nsdfg, None, {}, {})

copy_nsdfg = copy.deepcopy(nsdfg_node)
assert copy_nsdfg.sdfg.parent_nsdfg_node is copy_nsdfg
assert copy_nsdfg.sdfg.parent is None
assert copy_nsdfg.sdfg.parent_sdfg is None

state.add_node(copy_nsdfg)
assert copy_nsdfg.sdfg.parent is state
assert copy_nsdfg.sdfg.parent_sdfg is sdfg


def test_deepcopy_same_state_edge():

sdfg = dace.SDFG('deepcopy_nested_sdfg')
state = sdfg.add_state('state')

nsdfg = dace.SDFG('nested')
nsdfg_node = state.add_nested_sdfg(nsdfg, None, {}, {})

copy_nsdfg = copy.deepcopy(nsdfg_node)
assert copy_nsdfg.sdfg.parent_nsdfg_node is copy_nsdfg
assert copy_nsdfg.sdfg.parent is None
assert copy_nsdfg.sdfg.parent_sdfg is None

state.add_edge(nsdfg_node, None, copy_nsdfg, None, dace.Memlet())
assert copy_nsdfg.sdfg.parent is state
assert copy_nsdfg.sdfg.parent_sdfg is sdfg


def test_deepcopy_diff_state():

sdfg = dace.SDFG('deepcopy_nested_sdfg')
state_0 = sdfg.add_state('state_0')
state_1 = sdfg.add_state('state_1')

nsdfg = dace.SDFG('nested')
nsdfg_node = state_0.add_nested_sdfg(nsdfg, None, {}, {})

copy_nsdfg = copy.deepcopy(nsdfg_node)
assert copy_nsdfg.sdfg.parent_nsdfg_node is copy_nsdfg
assert copy_nsdfg.sdfg.parent is None
assert copy_nsdfg.sdfg.parent_sdfg is None

state_1.add_node(copy_nsdfg)
assert copy_nsdfg.sdfg.parent is state_1
assert copy_nsdfg.sdfg.parent_sdfg is sdfg


def test_deepcopy_diff_state_edge():

sdfg = dace.SDFG('deepcopy_nested_sdfg')
sdfg.add_array('A', [1], dace.int32)
state_0 = sdfg.add_state('state_0')
state_1 = sdfg.add_state('state_1')

nsdfg = dace.SDFG('nested')
nsdfg_node = state_0.add_nested_sdfg(nsdfg, None, {}, {})

copy_nsdfg = copy.deepcopy(nsdfg_node)
assert copy_nsdfg.sdfg.parent_nsdfg_node is copy_nsdfg
assert copy_nsdfg.sdfg.parent is None
assert copy_nsdfg.sdfg.parent_sdfg is None

a = state_1.add_access('A')
state_1.add_edge(a, None, copy_nsdfg, None, dace.Memlet())
assert copy_nsdfg.sdfg.parent is state_1
assert copy_nsdfg.sdfg.parent_sdfg is sdfg


def test_deepcopy_diff_sdfg():

sdfg_0 = dace.SDFG('deepcopy_nested_sdfg_0')
state_0 = sdfg_0.add_state('state_0')

nsdfg = dace.SDFG('nested')
nsdfg_node = state_0.add_nested_sdfg(nsdfg, None, {}, {})

copy_nsdfg = copy.deepcopy(nsdfg_node)
assert copy_nsdfg.sdfg.parent_nsdfg_node is copy_nsdfg
assert copy_nsdfg.sdfg.parent is None
assert copy_nsdfg.sdfg.parent_sdfg is None

sdfg_1 = dace.SDFG('deepcopy_nested_sdfg_1')
state_1 = sdfg_1.add_state('state_1')

state_1.add_node(copy_nsdfg)
assert copy_nsdfg.sdfg.parent is state_1
assert copy_nsdfg.sdfg.parent_sdfg is sdfg_1


def test_deepcopy_diff_sdfg_edge():

sdfg_0 = dace.SDFG('deepcopy_nested_sdfg_0')
state_0 = sdfg_0.add_state('state_0')

nsdfg = dace.SDFG('nested')
nsdfg_node = state_0.add_nested_sdfg(nsdfg, None, {}, {})

copy_nsdfg = copy.deepcopy(nsdfg_node)
assert copy_nsdfg.sdfg.parent_nsdfg_node is copy_nsdfg
assert copy_nsdfg.sdfg.parent is None
assert copy_nsdfg.sdfg.parent_sdfg is None

sdfg_1 = dace.SDFG('deepcopy_nested_sdfg_1')
sdfg_1.add_array('A', [1], dace.int32)
state_1 = sdfg_1.add_state('state_1')

a = state_1.add_access('A')
state_1.add_edge(a, None, copy_nsdfg, None, dace.Memlet())
assert copy_nsdfg.sdfg.parent is state_1
assert copy_nsdfg.sdfg.parent_sdfg is sdfg_1


def test_deepcopy_top_level():

sdfg = dace.SDFG('deepcopy_nested_sdfg')
state = sdfg.add_state('state')

nsdfg = dace.SDFG('nested')
nsdfg_node = state.add_nested_sdfg(nsdfg, None, {}, {})

copy_sdfg = copy.deepcopy(sdfg)
copy_state = copy_sdfg.states()[0]
copy_nsdfg_node = copy_state.nodes()[0]
for sd in copy_sdfg.all_sdfgs_recursive():
if sd is copy_sdfg:
continue
assert sd.parent_nsdfg_node is copy_nsdfg_node
assert sd.parent is copy_state
assert sd.parent_sdfg is copy_sdfg


if __name__ == '__main__':
test_deepcopy_same_state()
test_deepcopy_same_state_edge()
test_deepcopy_diff_state()
test_deepcopy_diff_state_edge()
test_deepcopy_diff_sdfg()
test_deepcopy_top_level()