Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Subsetlist and covers_precise #1412

Merged
2 changes: 1 addition & 1 deletion dace/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,7 @@ def allow_none(self):
def __set__(self, obj, val):
if isinstance(val, str):
val = self.from_string(val)
if (val is not None and not isinstance(val, sbs.Range) and not isinstance(val, sbs.Indices)):
if (val is not None and not isinstance(val, sbs.Range) and not isinstance(val, sbs.Indices) and not isinstance(val, sbs.SubsetUnion)):
raise TypeError("Subset property must be either Range or Indices: got {}".format(type(val).__name__))
super(SubsetProperty, self).__set__(obj, val)

Expand Down
268 changes: 238 additions & 30 deletions dace/subsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,52 @@
from dace.config import Config


def nng(expr):
# When dealing with set sizes, assume symbols are non-negative
try:
# TODO: Fix in symbol definition, not here
for sym in list(expr.free_symbols):
expr = expr.subs({sym: sp.Symbol(sym.name, nonnegative=True)})
return expr
except AttributeError: # No free_symbols in expr
return expr

def bounding_box_cover_exact(subset_a, subset_b) -> bool:
return all([(symbolic.simplify_ext(nng(rb)) <= symbolic.simplify_ext(nng(orb))) == True
and (symbolic.simplify_ext(nng(re)) >= symbolic.simplify_ext(nng(ore))) == True
for rb, re, orb, ore in zip(subset_a.min_element(), subset_a.max_element(),
subset_b.min_element(), subset_b.max_element())])

def bounding_box_symbolic_positive(subset_a, subset_b, approximation = False)-> bool:
min_elements_a = subset_a.min_element_approx() if approximation else subset_a.min_element()
max_elements_a = subset_a.max_element_approx() if approximation else subset_a.max_element()
min_elements_b = subset_b.min_element_approx() if approximation else subset_b.min_element()
max_elements_b = subset_b.max_element_approx() if approximation else subset_b.max_element()

for rb, re, orb, ore in zip(min_elements_a, max_elements_a,
min_elements_b, max_elements_b):
# NOTE: We first test for equality, which always returns True or False. If the equality test returns
# False, then we test for less-equal and greater-equal, which may return an expression, leading to
# TypeError. This is a workaround for the case where two expressions are the same or equal and
# SymPy confirms this but fails to return True when testing less-equal and greater-equal.

# lower bound: first check whether symbolic positive condition applies
if not (len(rb.free_symbols) == 0 and len(orb.free_symbols) == 1):
if not (symbolic.simplify_ext(nng(rb)) == symbolic.simplify_ext(nng(orb)) or
symbolic.simplify_ext(nng(rb)) <= symbolic.simplify_ext(nng(orb))):
return False
# upper bound: first check whether symbolic positive condition applies
if not (len(re.free_symbols) == 1 and len(ore.free_symbols) == 0):
if not (symbolic.simplify_ext(nng(re)) == symbolic.simplify_ext(nng(ore)) or
symbolic.simplify_ext(nng(re)) >= symbolic.simplify_ext(nng(ore))):
return False
return True

class Subset(object):
""" Defines a subset of a data descriptor. """
def covers(self, other):
""" Returns True if this subset covers (using a bounding box) another
subset. """
def nng(expr):
# When dealing with set sizes, assume symbols are non-negative
try:
# TODO: Fix in symbol definition, not here
for sym in list(expr.free_symbols):
expr = expr.subs({sym: sp.Symbol(sym.name, nonnegative=True)})
return expr
except AttributeError: # No free_symbols in expr
return expr

symbolic_positive = Config.get('optimizer', 'symbolic_positive')

if not symbolic_positive:
Expand All @@ -38,28 +69,65 @@ def nng(expr):

else:
try:
for rb, re, orb, ore in zip(self.min_element_approx(), self.max_element_approx(),
other.min_element_approx(), other.max_element_approx()):
# NOTE: We first test for equality, which always returns True or False. If the equality test returns
# False, then we test for less-equal and greater-equal, which may return an expression, leading to
# TypeError. This is a workaround for the case where two expressions are the same or equal and
# SymPy confirms this but fails to return True when testing less-equal and greater-equal.

# lower bound: first check whether symbolic positive condition applies
if not (len(rb.free_symbols) == 0 and len(orb.free_symbols) == 1):
if not (symbolic.simplify_ext(nng(rb)) == symbolic.simplify_ext(nng(orb)) or
symbolic.simplify_ext(nng(rb)) <= symbolic.simplify_ext(nng(orb))):
return False

# upper bound: first check whether symbolic positive condition applies
if not (len(re.free_symbols) == 1 and len(ore.free_symbols) == 0):
if not (symbolic.simplify_ext(nng(re)) == symbolic.simplify_ext(nng(ore)) or
symbolic.simplify_ext(nng(re)) >= symbolic.simplify_ext(nng(ore))):
return False
if not bounding_box_symbolic_positive(self, other, True):
return False
except TypeError:
return False

return True

def covers_precise(self, other):
""" Returns True if self contains all the elements in other. """

# If self does not cover other with a bounding box union, return false.
symbolic_positive = Config.get('optimizer', 'symbolic_positive')
try:
bounding_box_cover = bounding_box_cover_exact(self, other) if symbolic_positive else bounding_box_symbolic_positive(self, other)
if not bounding_box_cover:
return False
except TypeError:
return False

try:
# if self is an index no further distinction is needed
if isinstance(self, Indices):
return True

elif isinstance(self, Range):
# other is an index so we need to check if the step of self is such that other is covered
# self.start % self.step == other.index % self.step
if isinstance(other, Indices):
try:
return all(
[(symbolic.simplify_ext(nng(start)) % symbolic.simplify_ext(nng(step)) ==
symbolic.simplify_ext(nng(i)) % symbolic.simplify_ext(nng(step))) == True
for (start, _, step), i in zip(self.ranges, other.indices)])
except:
return False
if isinstance(other, Range):
# other is a range so in every dimension self.step has to divide other.step and
# self.start % self.step = other.start % other.step
try:
self_steps = [r[2] for r in self.ranges]
other_steps = [r[2] for r in other.ranges]
for start, step, ostart, ostep in zip(self.min_element(), self_steps, other.min_element(),
other_steps):
if not (ostep % step == 0 and
((symbolic.simplify_ext(nng(start)) == symbolic.simplify_ext(nng(ostart))) or
(symbolic.simplify_ext(nng(start)) % symbolic.simplify_ext(
nng(step)) == symbolic.simplify_ext(nng(ostart)) % symbolic.simplify_ext(
nng(ostep))) == True)):
return False
except:
return False
return True
# unknown type
else:
raise TypeError

except TypeError:
return False


def __repr__(self):
return '%s (%s)' % (type(self).__name__, self.__str__())
Expand Down Expand Up @@ -973,6 +1041,111 @@ def intersection(self, other: 'Indices'):
return self
return None

class SubsetUnion(Subset):
"""
Wrapper subset type that stores multiple Subsets in a list.
"""

def __init__(self, subset):
self.subset_list: list[Subset] = []
if isinstance(subset, SubsetUnion):
self.subset_list = subset.subset_list
elif isinstance(subset, list):
for subset in subset:
if not subset:
break
if isinstance(subset, (Range, Indices)):
self.subset_list.append(subset)
else:
raise NotImplementedError
elif isinstance(subset, (Range, Indices)):
self.subset_list = [subset]

def covers(self, other):
"""
Returns True if this SubsetUnion covers another subset (using a bounding box).
If other is another SubsetUnion then self and other will
only return true if self is other. If other is a different type of subset
true is returned when one of the subsets in self is equal to other.
"""

if isinstance(other, SubsetUnion):
for subset in self.subset_list:
# check if ther is a subset in self that covers every subset in other
if all(subset.covers(s) for s in other.subset_list):
return True
# return False if that's not the case for any of the subsets in self
return False
else:
return any(s.covers(other) for s in self.subset_list)

def covers_precise(self, other):
"""
Returns True if this SubsetUnion covers another
subset. If other is another SubsetUnion then self and other will
only return true if self is other. If other is a different type of subset
true is returned when one of the subsets in self is equal to other
"""

if isinstance(other, SubsetUnion):
for subset in self.subset_list:
# check if ther is a subset in self that covers every subset in other
if all(subset.covers_precise(s) for s in other.subset_list):
return True
# return False if that's not the case for any of the subsets in self
return False
else:
return any(s.covers_precise(other) for s in self.subset_list)

def __str__(self):
string = ''
for subset in self.subset_list:
if not string == '':
string += " "
string += subset.__str__()
return string

def dims(self):
if not self.subset_list:
return 0
return next(iter(self.subset_list)).dims()

def union(self, other: Subset):
"""In place union of self with another Subset"""
try:
if isinstance(other, SubsetUnion):
self.subset_list += other.subset_list
elif isinstance(other, Indices) or isinstance(other, Range):
self.subset_list.append(other)
else:
raise TypeError
except TypeError: # cannot determine truth value of Relational
return None

@property
def free_symbols(self) -> Set[str]:
result = set()
for subset in self.subset_list:
result |= subset.free_symbols
return result

def replace(self, repl_dict):
for subset in self.subset_list:
subset.replace(repl_dict)

def num_elements(self):
# TODO: write something more meaningful here
min = 0
for subset in self.subset_list:
try:
if subset.num_elements() < min or min ==0:
min = subset.num_elements()
except:
continue

return min



def _union_special_cases(arb: symbolic.SymbolicType, brb: symbolic.SymbolicType, are: symbolic.SymbolicType,
bre: symbolic.SymbolicType):
Expand Down Expand Up @@ -1038,6 +1211,8 @@ def bounding_box_union(subset_a: Subset, subset_b: Subset) -> Range:
return Range(result)




def union(subset_a: Subset, subset_b: Subset) -> Subset:
""" Compute the union of two Subset objects.
If the subsets are not of the same type, degenerates to bounding-box
Expand All @@ -1056,6 +1231,9 @@ def union(subset_a: Subset, subset_b: Subset) -> Subset:
return subset_b
elif subset_a is None and subset_b is None:
raise TypeError('Both subsets cannot be None')
elif isinstance(subset_a, SubsetUnion) or isinstance(
subset_b, SubsetUnion):
return list_union(subset_a, subset_b)
elif type(subset_a) != type(subset_b):
return bounding_box_union(subset_a, subset_b)
elif isinstance(subset_a, Indices):
Expand All @@ -1066,13 +1244,43 @@ def union(subset_a: Subset, subset_b: Subset) -> Subset:
# TODO(later): More involved Strided-Tiled Range union
return bounding_box_union(subset_a, subset_b)
else:
warnings.warn('Unrecognized Subset type %s in union, degenerating to'
' bounding box' % type(subset_a).__name__)
warnings.warn(
'Unrecognized Subset type %s in union, degenerating to'
' bounding box' % type(subset_a).__name__)
return bounding_box_union(subset_a, subset_b)
except TypeError: # cannot determine truth value of Relational
return None


def list_union(subset_a: Subset, subset_b: Subset) -> Subset:
"""
Returns the union of two Subset lists.

:param subset_a: The first subset.
:param subset_b: The second subset.
:return: A SubsetUnion object that contains all elements of subset_a and subset_b.
"""
# TODO(later): Merge subsets in both lists if possible
try:
if subset_a is not None and subset_b is None:
return subset_a
elif subset_b is not None and subset_a is None:
return subset_b
elif subset_a is None and subset_b is None:
raise TypeError('Both subsets cannot be None')
elif type(subset_a) != type(subset_b):
if isinstance(subset_b, SubsetUnion):
return SubsetUnion(subset_b.subset_list.append(subset_a))
else:
return SubsetUnion(subset_a.subset_list.append(subset_b))
elif isinstance(subset_a, SubsetUnion):
return SubsetUnion(subset_a.subset_list + subset_b.subset_list)
else:
return SubsetUnion([subset_a, subset_b])

except TypeError:
return None

def intersects(subset_a: Subset, subset_b: Subset) -> Union[bool, None]:
"""
Returns True if two subsets intersect, False if they do not, or
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
},
include_package_data=True,
install_requires=[
'numpy', 'networkx >= 2.5', 'astunparse', 'sympy<=1.9', 'pyyaml', 'ply', 'websockets', 'requests', 'flask',
'numpy', 'networkx >= 2.5', 'astunparse', 'sympy>=1.12', 'pyyaml', 'ply', 'websockets', 'requests', 'flask',
'fparser >= 0.1.3', 'aenum >= 3.1', 'dataclasses; python_version < "3.7"', 'dill',
'pyreadline;platform_system=="Windows"', 'typing-compat; python_version < "3.8"'
] + cmake_requires,
Expand Down