Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fortran fix memlet indices #1342

Merged
merged 5 commits into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions dace/frontend/fortran/fortran_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node,
if i.type == "ALL":
shape.append(array.shape[indices])
mysize = mysize * array.shape[indices]
index_list.append(None)
else:
raise NotImplementedError("Index in ParDecl should be ALL")
else:
Expand Down
50 changes: 50 additions & 0 deletions tests/fortran/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dace.frontend.fortran import fortran_parser
from fparser.two.symbol_table import SymbolTable
from dace.sdfg import utils as sdutil
from dace.sdfg.nodes import AccessNode

import dace.frontend.fortran.ast_components as ast_components
import dace.frontend.fortran.ast_transforms as ast_transforms
Expand Down Expand Up @@ -167,10 +168,59 @@ def test_fortran_frontend_input_output_connector():
assert (a[1, 2] == 0)


def test_fortran_frontend_memlet_in_map_test():
"""
Tests that no assumption is made where the iteration variable is inside a memlet subset
"""
test_string = """
PROGRAM memlet_range_test
implicit None
REAL INP(100, 10)
REAL OUT(100, 10)
CALL memlet_range_test_routine(INP, OUT)
END PROGRAM

SUBROUTINE memlet_range_test_routine(INP, OUT)
REAL INP(100, 10)
REAL OUT(100, 10)
DO I=1,100
CALL inner_loops(INP(I, :), OUT(I, :))
ENDDO
END SUBROUTINE memlet_range_test_routine

SUBROUTINE inner_loops(INP, OUT)
REAL INP(10)
REAL OUT(10)
DO J=1,10
OUT(J) = INP(J) + 1
ENDDO
END SUBROUTINE inner_loops

"""
sdfg = fortran_parser.create_sdfg_from_string(test_string, "memlet_range_test")
sdfg.simplify()
# Expect that start is begin of for loop -> only one out edge to guard defining iterator variable
assert len(sdfg.out_edges(sdfg.start_state)) == 1
iter_var = symbolic.symbol(list(sdfg.out_edges(sdfg.start_state)[0].data.assignments.keys())[0])

for state in sdfg.states():
if len(state.nodes()) > 1:
for node in state.nodes():
if isinstance(node, AccessNode) and node.data in ['INP', 'OUT']:
edges = [*state.in_edges(node), *state.out_edges(node)]
# There should be only one edge in/to the access node
assert len(edges) == 1
memlet = edges[0].data
# Check that the correct memlet has the iteration variable
assert memlet.subset[0] == (iter_var, iter_var, 1)
assert memlet.subset[1] == (1, 10, 1)


if __name__ == "__main__":

test_fortran_frontend_array_3dmap()
test_fortran_frontend_array_access()
test_fortran_frontend_input_output_connector()
test_fortran_frontend_array_ranges()
test_fortran_frontend_twoconnector()
test_fortran_frontend_memlet_in_map_test()