Skip to content

Commit

Permalink
Merge pull request #1341 from spcl/fix-scalar-return-validation
Browse files Browse the repository at this point in the history
Validation improvements for Scalars.
  • Loading branch information
alexnick83 committed Aug 3, 2023
2 parents 350cff4 + 7ad6176 commit 7ce5fdb
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 11 deletions.
4 changes: 0 additions & 4 deletions dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,10 +636,6 @@ def validate(self, sdfg, state, references: Optional[Set[int]] = None, **context
f'Connector "{conn}" was given but is not a registered data descriptor in the nested SDFG. '
'Example: parameter passed to a function without a matching array within it.')
for dname, desc in self.sdfg.arrays.items():
# TODO(later): Disallow scalars without access nodes (so that this
# check passes for them too).
if isinstance(desc, data.Scalar):
continue
if not desc.transient and dname not in connectors:
raise NameError('Data descriptor "%s" not found in nested SDFG connectors' % dname)
if dname in connectors and desc.transient:
Expand Down
47 changes: 47 additions & 0 deletions tests/sdfg/validation/nested_sdfg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,53 @@ def test_inout_connector_validation_fail():
assert False, "SDFG should not validate"


def test_nested_sdfg_with_transient_connector():

sdfg = dace.SDFG('nested_main')
sdfg.add_array('A', [2], dace.float32)

def mystate(state, src, dst):
src_node = state.add_read(src)
dst_node = state.add_write(dst)
tasklet = state.add_tasklet('aaa2', {'a'}, {'b'}, 'b = a + 1')

# input path (src->tasklet[a])
state.add_memlet_path(src_node, tasklet, dst_conn='a', memlet=dace.Memlet(data=src, subset='0'))
# output path (tasklet[b]->dst)
state.add_memlet_path(tasklet, dst_node, src_conn='b', memlet=dace.Memlet(data=dst, subset='0'))


sub_sdfg = dace.SDFG('nested_sub')
sub_sdfg.add_scalar('sA', dace.float32)
sub_sdfg.add_scalar('sB', dace.float32, transient=True)
sub_sdfg.add_scalar('sC', dace.float32, transient=True)

state0 = sub_sdfg.add_state('subs0')
mystate(state0, 'sA', 'sB')
state1 = sub_sdfg.add_state('subs1')
mystate(state1, 'sB', 'sC')

sub_sdfg.add_edge(state0, state1, dace.InterstateEdge())


state = sdfg.add_state('s0')
me, mx = state.add_map('mymap', dict(k='0:2'))
nsdfg = state.add_nested_sdfg(sub_sdfg, sdfg, {'sA'}, {'sC'})
Ain = state.add_read('A')
Aout = state.add_write('A')

state.add_memlet_path(Ain, me, nsdfg, memlet=dace.Memlet(data='A', subset='k'), dst_conn='sA')
state.add_memlet_path(nsdfg, mx, Aout, memlet=dace.Memlet(data='A', subset='k'), src_conn='sC')

try:
sdfg.validate()
except dace.sdfg.InvalidSDFGError:
return

assert False, "SDFG should not validate"


if __name__ == "__main__":
test_inout_connector_validation_success()
test_inout_connector_validation_fail()
test_nested_sdfg_with_transient_connector()
41 changes: 34 additions & 7 deletions tests/transformations/refine_nested_access_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,13 @@ def inner_sdfg(A: dace.int32[5, 5], B: dace.int32[5, 5], select: dace.bool[5, 5]
assert np.allclose(B, lower.T + lower - diag)


def test_free_sybmols_only_by_indices():
def test_free_symbols_only_by_indices():
i = dace.symbol('i')
idx_a = dace.symbol('idx_a')
idx_b = dace.symbol('idx_b')
sdfg = dace.SDFG('refine_free_symbols_only_by_indices')
sdfg.add_array('A', [5], dace.int32)
sdfg.add_array('B', [5, 5], dace.int32)
sdfg.add_scalar('idx_a', dace.int64)
sdfg.add_scalar('idx_b', dace.int64)

@dace.program
def inner_sdfg(A: dace.int32[5], B: dace.int32[5, 5], idx_a: int, idx_b: int):
Expand All @@ -116,10 +116,22 @@ def inner_sdfg(A: dace.int32[5], B: dace.int32[5, 5], idx_a: int, idx_b: int):
state = sdfg.add_state()
A = state.add_access('A')
B = state.add_access('B')
ia = state.add_access('idx_a')
ib = state.add_access('idx_b')
map_entry, map_exit = state.add_map('map', dict(i='0:5'))
nsdfg = state.add_nested_sdfg(inner_sdfg.to_sdfg(simplify=False), sdfg, {'A'}, {'B'}, {'i': 'i'})
state.add_memlet_path(A, map_entry, nsdfg, dst_conn='A', memlet=dace.Memlet.from_array('A', sdfg.arrays['A']))
state.add_memlet_path(nsdfg, map_exit, B, src_conn='B', memlet=dace.Memlet.from_array('B', sdfg.arrays['B']))
nsdfg = state.add_nested_sdfg(inner_sdfg.to_sdfg(simplify=False), sdfg, {'A', 'idx_a', 'idx_b'}, {'B'}, {'i': 'i'})
state.add_memlet_path(A, map_entry, nsdfg, dst_conn='A', memlet=dace.Memlet.from_array('A', sdfg.arrays['A']))
state.add_memlet_path(nsdfg, map_exit, B, src_conn='B', memlet=dace.Memlet.from_array('B', sdfg.arrays['B']))
state.add_memlet_path(ia,
map_entry,
nsdfg,
dst_conn='idx_a',
memlet=dace.Memlet.from_array('idx_a', sdfg.arrays['idx_a']))
state.add_memlet_path(ib,
map_entry,
nsdfg,
dst_conn='idx_b',
memlet=dace.Memlet.from_array('idx_b', sdfg.arrays['idx_b']))

num = sdfg.apply_transformations_repeated(RefineNestedAccess)
assert num == 1
Expand All @@ -128,8 +140,23 @@ def inner_sdfg(A: dace.int32[5], B: dace.int32[5, 5], idx_a: int, idx_b: int):
edge = state.in_edges(map_exit)[0]
assert edge.data.subset == dace.subsets.Range([(i, i, 1), (0, 4, 1)])

A = np.array([0, 1, 0, 1, 0], dtype=np.int32)
ref = np.zeros((5, 5), dtype=np.int32)
val = np.zeros((5, 5), dtype=np.int32)
ia = 3
ib = 2

for i in range(5):
if A[i] > 0.5:
ref[i, ia] = 1
else:
ref[i, ib] = 0
sdfg(A=A, B=val, idx_a=ia, idx_b=ib)

assert np.allclose(ref, val)


if __name__ == '__main__':
test_refine_dataflow()
test_refine_interstate()
test_free_sybmols_only_by_indices()
test_free_symbols_only_by_indices()

0 comments on commit 7ce5fdb

Please sign in to comment.