Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Numpy fill accepts also variables #1420

Merged
merged 17 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
0d09081
Made some small changes to the code.
philip-paul-mueller Nov 1, 2023
59fe6d0
Modified how `fill()` works, it now also accepts non variables.
philip-paul-mueller Nov 1, 2023
e755422
Added a test programm for the issue.
philip-paul-mueller Nov 1, 2023
8cfe431
Fixed a bug in the `fill()` checks.
philip-paul-mueller Nov 1, 2023
ef2e5f0
Added tests to for the `fill()` function.
philip-paul-mueller Nov 1, 2023
57413e9
Renamed a demonstartion function function to better reflect what it d…
philip-paul-mueller Nov 1, 2023
35c72cc
This commit fixes the issue in `A.fill(locVar)` where `locVar` is a l…
philip-paul-mueller Nov 2, 2023
1a35909
Added a completly new implementation of `fill()`.
philip-paul-mueller Nov 2, 2023
a29ae4a
Undid most of my changes to `_elementwise()` also removed some old fi…
philip-paul-mueller Nov 2, 2023
6e8cb02
Merge branch 'master' into diff_parsing_i1389
philip-paul-mueller Nov 2, 2023
0bd35bc
Merge branch 'master' into diff_parsing_i1389
philip-paul-mueller Nov 3, 2023
c58806f
Merge branch 'master' into diff_parsing_i1389
acalotoiu Nov 6, 2023
daf9856
Merge branch 'master' into diff_parsing_i1389
philip-paul-mueller Nov 7, 2023
70ddd7c
Incooperated Ben's suggesten which where correct.
philip-paul-mueller Nov 16, 2023
e425053
Merge branch 'master' into diff_parsing_i1389
philip-paul-mueller Nov 16, 2023
4ad4ae6
Update dace/frontend/python/replacements.py
philip-paul-mueller Nov 17, 2023
3b15322
Update dace/frontend/python/replacements.py
philip-paul-mueller Nov 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 1 addition & 6 deletions dace/frontend/common/op_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@ def _get_all_bases(class_or_name: Union[str, Type]) -> List[str]:
"""
if isinstance(class_or_name, str):
return [class_or_name]

classes = [class_or_name.__name__]
for base in class_or_name.__bases__:
classes.extend(_get_all_bases(base))

return deduplicate(classes)
return [base.__name__ for base in class_or_name.__mro__]


class Replacements(object):
Expand Down
25 changes: 14 additions & 11 deletions dace/frontend/python/astutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,10 @@ class ExtNodeTransformer(ast.NodeTransformer):
bodies in order to discern DaCe statements from others.
"""
def visit_TopLevel(self, node):
clsname = type(node).__name__
if getattr(self, "visit_TopLevel" + clsname, False):
return getattr(self, "visit_TopLevel" + clsname)(node)
visitor_name = "visit_TopLevel" + type(node).__name__
if hasattr(self, visitor_name):
visitor = getattr(self, visitor_name)
return visitor(node)
else:
return self.visit(node)

Expand Down Expand Up @@ -480,21 +481,23 @@ class ExtNodeVisitor(ast.NodeVisitor):
top-level expressions in bodies in order to discern DaCe statements
from others. """
def visit_TopLevel(self, node):
clsname = type(node).__name__
if getattr(self, "visit_TopLevel" + clsname, False):
getattr(self, "visit_TopLevel" + clsname)(node)
visitor_name = "visit_TopLevel" + type(node).__name__
if hasattr(self, visitor_name):
visitor = getattr(self, visitor_name)
return visitor(node)
else:
self.visit(node)
return self.visit(node)

def generic_visit(self, node):
for field, old_value in ast.iter_fields(node):
if isinstance(old_value, list):
for value in old_value:
if isinstance(value, ast.AST):
if (field == 'body' or field == 'orelse'):
clsname = type(value).__name__
if getattr(self, "visit_TopLevel" + clsname, False):
getattr(self, "visit_TopLevel" + clsname)(value)
if field == 'body' or field == 'orelse':
visitor_name = "visit_TopLevel" + type(value).__name__
if hasattr(self, visitor_name):
visitor = getattr(self, visitor_name)
visitor(value)
else:
self.visit(value)
else:
Expand Down
45 changes: 37 additions & 8 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,11 +605,10 @@ def _elementwise(pv: 'ProgramVisitor',
else:
state.add_mapped_tasklet(
name="_elementwise_",
map_ranges={'__i%d' % i: '0:%s' % n
for i, n in enumerate(inparr.shape)},
inputs={'__inp': Memlet.simple(in_array, ','.join(['__i%d' % i for i in range(len(inparr.shape))]))},
map_ranges={f'__i{dim}': f'0:{N}' for dim, N in enumerate(inparr.shape)},
inputs={'__inp': Memlet.simple(in_array, ','.join([f'__i{dim}' for dim in range(len(inparr.shape))]))},
code=code,
outputs={'__out': Memlet.simple(out_array, ','.join(['__i%d' % i for i in range(len(inparr.shape))]))},
outputs={'__out': Memlet.simple(out_array, ','.join([f'__i{dim}' for dim in range(len(inparr.shape))]))},
BenWeber42 marked this conversation as resolved.
Show resolved Hide resolved
external_edges=True)

return out_array
Expand Down Expand Up @@ -4232,10 +4231,40 @@ def _ndarray_copy(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str) ->
@oprepo.replaces_method('Array', 'fill')
@oprepo.replaces_method('Scalar', 'fill')
@oprepo.replaces_method('View', 'fill')
def _ndarray_fill(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, value: Number) -> str:
if not isinstance(value, (Number, np.bool_)):
raise mem_parser.DaceSyntaxError(pv, None, "Fill value {f} must be a number!".format(f=value))
return _elementwise(pv, sdfg, state, "lambda x: {}".format(value), arr, arr)
def _ndarray_fill(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, value: Union[str, Number,
sp.Expr]) -> str:
assert arr in sdfg.arrays

if isinstance(value, sp.Expr):
raise NotImplementedError(
f"`{arr}.fill` is not implemented for symbolic expressions (`{value}`).") # Look at `full`.

if isinstance(value, (Number, np.bool_)):
body = value
inputs = {}
elif isinstance(value, str) and value in sdfg.arrays:
value_array = sdfg.arrays[value]
if not isinstance(value_array, data.Scalar):
raise mem_parser.DaceSyntaxError(
pv, None, f"`{arr}.fill` requires a scalar argument, but `{type(value_array)}` was given.")
body = '__inp'
inputs = {'__inp': dace.Memlet(data=value, subset='0')}
else:
raise mem_parser.DaceSyntaxError(pv, None, f"Unsupported argument `{value}` for `{arr}.fill`.")
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved

shape = sdfg.arrays[arr].shape
state.add_mapped_tasklet(
'_numpy_fill_',
map_ranges={
f"__i{dim}": f"0:{s}"
for dim, s in enumerate(shape)
},
inputs=inputs,
code=f"__out = {body}",
outputs={'__out': dace.Memlet.simple(arr, ",".join([f"__i{dim}" for dim in range(len(shape))]))},
external_edges=True)

return arr


@oprepo.replaces_method('Array', 'reshape')
Expand Down
14 changes: 14 additions & 0 deletions tests/numpy/ndarray_attributes_methods_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@ def test_fill(A: dace.int32[M, N]):
return A # return A.fill(5) doesn't work because A is not copied


@compare_numpy_output()
def test_fill2(A: dace.int32[M, N], a: dace.int32):
A.fill(a)
return A # return A.fill(5) doesn't work because A is not copied


@compare_numpy_output()
def test_fill3(A: dace.int32[M, N], a: dace.int32):
A.fill(a + 1)
return A


@compare_numpy_output()
def test_reshape(A: dace.float32[N, N]):
return A.reshape([1, N * N])
Expand Down Expand Up @@ -124,6 +136,8 @@ def test_any():
test_copy()
test_astype()
test_fill()
test_fill2()
test_fill3()
test_reshape()
test_transpose1()
test_transpose2()
Expand Down