Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support attributes in symbolic expressions #1369

Merged
merged 14 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dace/codegen/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,8 @@ def as_cpp(self, codegen, symbols) -> str:
expr = f'{preinit}\nfor ({init}; {cond}; {update}) {{\n'
expr += _clean_loop_body(self.body.as_cpp(codegen, symbols))
expr += '\n}\n'
# TODO: Check that the dot is used to access struct members
expr = expr.replace('.', '->')
return expr

@property
Expand Down
3 changes: 3 additions & 0 deletions dace/codegen/tools/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,9 @@ def _infer_dtype(t: Union[ast.Name, ast.Attribute]):

def _Attribute(t, symbols, inferred_symbols):
inferred_type = _dispatch(t.value, symbols, inferred_symbols)
if (isinstance(inferred_type, dtypes.pointer) and isinstance(inferred_type.base_type, dtypes.struct) and
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alexnick83 what about plain structs (i.e., not pointers)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we have decided that Structures will always resolve to a pointer to a struct, I think it is better to not handle that case (now), since it might be something unrelated to nested data and cause further confusion.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: This assumption was valid for the time being, but I am now arguing to change that assumption. For cases where the type of an expression is being evaluated, where the t.value type of an Attribute expression is not merely a structure, but an array of structures (indexed through a Slice), the returned inferred_type is no longer a pointer but directly a plain struct.

Example: my_array_of_structs[i].my_attribute

Commit e2561ed implements a handle for this additional case.

t.attr in inferred_type.base_type.fields):
return inferred_type.base_type.fields[t.attr]
return inferred_type


Expand Down
112 changes: 67 additions & 45 deletions dace/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,7 @@ def eval(cls, x, y):
def _eval_is_boolean(self):
return True


class IfExpr(sympy.Function):

@classmethod
Expand Down Expand Up @@ -724,6 +725,19 @@ class IsNot(sympy.Function):
pass


class Attr(sympy.Function):
"""
Represents a get-attribute call on a function, equivalent to ``a.b`` in Python.
"""

@property
def free_symbols(self):
return {sympy.Symbol(str(self))}

def __str__(self):
return f'{self.args[0]}.{self.args[1]}'


def sympy_intdiv_fix(expr):
""" Fix for SymPy printing out reciprocal values when they should be
integral in "ceiling/floor" sympy functions.
Expand Down Expand Up @@ -927,10 +941,9 @@ def _process_is(elem: Union[Is, IsNot]):
return expr


class SympyBooleanConverter(ast.NodeTransformer):
class PythonOpToSympyConverter(ast.NodeTransformer):
"""
Replaces boolean operations with the appropriate SymPy functions to avoid
non-symbolic evaluation.
Replaces various operations with the appropriate SymPy functions to avoid non-symbolic evaluation.
"""
_ast_to_sympy_comparators = {
ast.Eq: 'Eq',
Expand All @@ -946,12 +959,37 @@ class SympyBooleanConverter(ast.NodeTransformer):
ast.NotIn: 'NotIn',
}

_ast_to_sympy_functions = {
ast.BitAnd: 'BitwiseAnd',
ast.BitOr: 'BitwiseOr',
ast.BitXor: 'BitwiseXor',
ast.Invert: 'BitwiseNot',
ast.LShift: 'LeftShift',
ast.RShift: 'RightShift',
ast.FloorDiv: 'int_floor',
}

def visit_UnaryOp(self, node):
if isinstance(node.op, ast.Not):
func_node = ast.copy_location(ast.Name(id=type(node.op).__name__, ctx=ast.Load()), node)
new_node = ast.Call(func=func_node, args=[self.visit(node.operand)], keywords=[])
return ast.copy_location(new_node, node)
return node
elif isinstance(node.op, ast.Invert):
func_node = ast.copy_location(ast.Name(id=self._ast_to_sympy_functions[type(node.op)], ctx=ast.Load()),
node)
new_node = ast.Call(func=func_node, args=[self.visit(node.operand)], keywords=[])
return ast.copy_location(new_node, node)
return self.generic_visit(node)

def visit_BinOp(self, node):
if type(node.op) in self._ast_to_sympy_functions:
func_node = ast.copy_location(ast.Name(id=self._ast_to_sympy_functions[type(node.op)], ctx=ast.Load()),
node)
new_node = ast.Call(func=func_node,
args=[self.visit(value) for value in (node.left, node.right)],
keywords=[])
return ast.copy_location(new_node, node)
return self.generic_visit(node)

def visit_BoolOp(self, node):
func_node = ast.copy_location(ast.Name(id=type(node.op).__name__, ctx=ast.Load()), node)
Expand All @@ -971,8 +1009,7 @@ def visit_Compare(self, node: ast.Compare):
raise NotImplementedError
op = node.ops[0]
arguments = [node.left, node.comparators[0]]
func_node = ast.copy_location(
ast.Name(id=SympyBooleanConverter._ast_to_sympy_comparators[type(op)], ctx=ast.Load()), node)
func_node = ast.copy_location(ast.Name(id=self._ast_to_sympy_comparators[type(op)], ctx=ast.Load()), node)
new_node = ast.Call(func=func_node, args=[self.visit(arg) for arg in arguments], keywords=[])
return ast.copy_location(new_node, node)

Expand All @@ -985,41 +1022,28 @@ def visit_NameConstant(self, node):
return self.visit_Constant(node)

def visit_IfExp(self, node):
new_node = ast.Call(func=ast.Name(id='IfExpr', ctx=ast.Load), args=[node.test, node.body, node.orelse], keywords=[])
new_node = ast.Call(func=ast.Name(id='IfExpr', ctx=ast.Load),
args=[self.visit(node.test),
self.visit(node.body),
self.visit(node.orelse)],
keywords=[])
return ast.copy_location(new_node, node)

class BitwiseOpConverter(ast.NodeTransformer):
"""
Replaces C/C++ bitwise operations with functions to avoid sympification to boolean operations.
"""
_ast_to_sympy_functions = {
ast.BitAnd: 'BitwiseAnd',
ast.BitOr: 'BitwiseOr',
ast.BitXor: 'BitwiseXor',
ast.Invert: 'BitwiseNot',
ast.LShift: 'LeftShift',
ast.RShift: 'RightShift',
ast.FloorDiv: 'int_floor',
}

def visit_UnaryOp(self, node):
if isinstance(node.op, ast.Invert):
func_node = ast.copy_location(
ast.Name(id=BitwiseOpConverter._ast_to_sympy_functions[type(node.op)], ctx=ast.Load()), node)
new_node = ast.Call(func=func_node, args=[self.visit(node.operand)], keywords=[])
return ast.copy_location(new_node, node)
return self.generic_visit(node)

def visit_BinOp(self, node):
if type(node.op) in BitwiseOpConverter._ast_to_sympy_functions:
func_node = ast.copy_location(
ast.Name(id=BitwiseOpConverter._ast_to_sympy_functions[type(node.op)], ctx=ast.Load()), node)
new_node = ast.Call(func=func_node,
args=[self.visit(value) for value in (node.left, node.right)],

def visit_Subscript(self, node):
if isinstance(node.value, ast.Attribute):
attr = ast.Subscript(value=ast.Name(id=node.value.attr, ctx=ast.Load()), slice=node.slice, ctx=ast.Load())
new_node = ast.Call(func=ast.Name(id='Attr', ctx=ast.Load),
args=[self.visit(node.value.value), self.visit(attr)],
keywords=[])
return ast.copy_location(new_node, node)
return self.generic_visit(node)

def visit_Attribute(self, node):
new_node = ast.Call(func=ast.Name(id='Attr', ctx=ast.Load),
args=[self.visit(node.value), ast.Name(id=node.attr, ctx=ast.Load)],
keywords=[])
return ast.copy_location(new_node, node)


@lru_cache(maxsize=16384)
def pystr_to_symbolic(expr, symbol_map=None, simplify=None) -> sympy.Basic:
Expand Down Expand Up @@ -1071,21 +1095,17 @@ def pystr_to_symbolic(expr, symbol_map=None, simplify=None) -> sympy.Basic:
'int_ceil': int_ceil,
'IfExpr': IfExpr,
'Mod': sympy.Mod,
'Attr': Attr,
}
# _clash1 enables all one-letter variables like N as symbols
# _clash also allows pi, beta, zeta and other common greek letters
locals.update(_sympy_clash)

if isinstance(expr, str):
# Sympy processes "not/and/or" as direct evaluation. Replace with
# And/Or(x, y), Not(x)
if re.search(r'\bnot\b|\band\b|\bor\b|\bNone\b|==|!=|\bis\b|\bif\b', expr):
expr = unparse(SympyBooleanConverter().visit(ast.parse(expr).body[0]))

# NOTE: If the expression contains bitwise operations, replace them with user-functions.
# NOTE: Sympy does not support bitwise operations and converts them to boolean operations.
if re.search('[&]|[|]|[\^]|[~]|[<<]|[>>]|[//]', expr):
expr = unparse(BitwiseOpConverter().visit(ast.parse(expr).body[0]))
# Sympy processes "not/and/or" as direct evaluation. Replace with And/Or(x, y), Not(x)
# Also replaces bitwise operations with user-functions since SymPy does not support bitwise operations.
if re.search(r'\bnot\b|\band\b|\bor\b|\bNone\b|==|!=|\bis\b|\bif\b|[&]|[|]|[\^]|[~]|[<<]|[>>]|[//]|[\.]', expr):
expr = unparse(PythonOpToSympyConverter().visit(ast.parse(expr).body[0]))

# TODO: support SymExpr over-approximated expressions
try:
Expand Down Expand Up @@ -1126,6 +1146,8 @@ def _print_Function(self, expr):
return f'(({self._print(expr.args[0])}) and ({self._print(expr.args[1])}))'
if str(expr.func) == 'OR':
return f'(({self._print(expr.args[0])}) or ({self._print(expr.args[1])}))'
if str(expr.func) == 'Attr':
return f'{self._print(expr.args[0])}.{self._print(expr.args[1])}'
return super()._print_Function(expr)

def _print_Mod(self, expr):
Expand Down
47 changes: 47 additions & 0 deletions tests/sdfg/data/structure_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,52 @@ def test_direct_read_structure():
assert np.allclose(B, ref)


def test_direct_read_structure_loops():

M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz'))
csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]),
name='CSRMatrix')

sdfg = dace.SDFG('csr_to_dense_direct_loops')

sdfg.add_datadesc('A', csr_obj)
sdfg.add_array('B', [M, N], dace.float32)

state = sdfg.add_state()

indices = state.add_access('A.indices')
data = state.add_access('A.data')
B = state.add_access('B')

t = state.add_tasklet('indirection', {'j', '__val'}, {'__out'}, '__out[i, j] = __val')
state.add_edge(indices, None, t, 'j', dace.Memlet(data='A.indices', subset='idx'))
state.add_edge(data, None, t, '__val', dace.Memlet(data='A.data', subset='idx'))
state.add_edge(t, '__out', B, None, dace.Memlet(data='B', subset='0:M, 0:N', volume=1))

idx_before, idx_guard, idx_after = sdfg.add_loop(None, state, None, 'idx', 'A.indptr[i]', 'idx < A.indptr[i+1]', 'idx + 1')
i_before, i_guard, i_after = sdfg.add_loop(None, idx_before, None, 'i', '0', 'i < M', 'i + 1', loop_end_state=idx_after)

func = sdfg.compile()

rng = np.random.default_rng(42)
A = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng)
B = np.zeros((20, 20), dtype=np.float32)

inpA = csr_obj.dtype._typeclass.as_ctypes()(indptr=A.indptr.__array_interface__['data'][0],
indices=A.indices.__array_interface__['data'][0],
data=A.data.__array_interface__['data'][0],
rows=A.shape[0],
cols=A.shape[1],
M=A.shape[0],
N=A.shape[1],
nnz=A.nnz)

func(A=inpA, B=B, M=20, N=20, nnz=A.nnz)
ref = A.toarray()

assert np.allclose(B, ref)


def test_direct_read_nested_structure():
M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz'))
csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]),
Expand Down Expand Up @@ -505,3 +551,4 @@ def test_direct_read_nested_structure():
test_write_nested_structure()
test_direct_read_structure()
test_direct_read_nested_structure()
test_direct_read_structure_loops()
Loading