Skip to content

Commit

Permalink
Merge pull request #1412 from matteonu/implement-SubsetList-and-cover…
Browse files Browse the repository at this point in the history
…s_precise

Implement Subsetlist and covers_precise
  • Loading branch information
acalotoiu committed Nov 3, 2023
2 parents 9430e87 + 26c17a8 commit 9a0eafd
Show file tree
Hide file tree
Showing 3 changed files with 451 additions and 31 deletions.
2 changes: 1 addition & 1 deletion dace/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,7 @@ def allow_none(self):
def __set__(self, obj, val):
if isinstance(val, str):
val = self.from_string(val)
if (val is not None and not isinstance(val, sbs.Range) and not isinstance(val, sbs.Indices)):
if (val is not None and not isinstance(val, sbs.Range) and not isinstance(val, sbs.Indices) and not isinstance(val, sbs.SubsetUnion)):
raise TypeError("Subset property must be either Range or Indices: got {}".format(type(val).__name__))
super(SubsetProperty, self).__set__(obj, val)

Expand Down
268 changes: 238 additions & 30 deletions dace/subsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,52 @@
from dace.config import Config


def nng(expr):
# When dealing with set sizes, assume symbols are non-negative
try:
# TODO: Fix in symbol definition, not here
for sym in list(expr.free_symbols):
expr = expr.subs({sym: sp.Symbol(sym.name, nonnegative=True)})
return expr
except AttributeError: # No free_symbols in expr
return expr

def bounding_box_cover_exact(subset_a, subset_b) -> bool:
return all([(symbolic.simplify_ext(nng(rb)) <= symbolic.simplify_ext(nng(orb))) == True
and (symbolic.simplify_ext(nng(re)) >= symbolic.simplify_ext(nng(ore))) == True
for rb, re, orb, ore in zip(subset_a.min_element(), subset_a.max_element(),
subset_b.min_element(), subset_b.max_element())])

def bounding_box_symbolic_positive(subset_a, subset_b, approximation = False)-> bool:
min_elements_a = subset_a.min_element_approx() if approximation else subset_a.min_element()
max_elements_a = subset_a.max_element_approx() if approximation else subset_a.max_element()
min_elements_b = subset_b.min_element_approx() if approximation else subset_b.min_element()
max_elements_b = subset_b.max_element_approx() if approximation else subset_b.max_element()

for rb, re, orb, ore in zip(min_elements_a, max_elements_a,
min_elements_b, max_elements_b):
# NOTE: We first test for equality, which always returns True or False. If the equality test returns
# False, then we test for less-equal and greater-equal, which may return an expression, leading to
# TypeError. This is a workaround for the case where two expressions are the same or equal and
# SymPy confirms this but fails to return True when testing less-equal and greater-equal.

# lower bound: first check whether symbolic positive condition applies
if not (len(rb.free_symbols) == 0 and len(orb.free_symbols) == 1):
if not (symbolic.simplify_ext(nng(rb)) == symbolic.simplify_ext(nng(orb)) or
symbolic.simplify_ext(nng(rb)) <= symbolic.simplify_ext(nng(orb))):
return False
# upper bound: first check whether symbolic positive condition applies
if not (len(re.free_symbols) == 1 and len(ore.free_symbols) == 0):
if not (symbolic.simplify_ext(nng(re)) == symbolic.simplify_ext(nng(ore)) or
symbolic.simplify_ext(nng(re)) >= symbolic.simplify_ext(nng(ore))):
return False
return True

class Subset(object):
""" Defines a subset of a data descriptor. """
def covers(self, other):
""" Returns True if this subset covers (using a bounding box) another
subset. """
def nng(expr):
# When dealing with set sizes, assume symbols are non-negative
try:
# TODO: Fix in symbol definition, not here
for sym in list(expr.free_symbols):
expr = expr.subs({sym: sp.Symbol(sym.name, nonnegative=True)})
return expr
except AttributeError: # No free_symbols in expr
return expr

symbolic_positive = Config.get('optimizer', 'symbolic_positive')

if not symbolic_positive:
Expand All @@ -38,28 +69,65 @@ def nng(expr):

else:
try:
for rb, re, orb, ore in zip(self.min_element_approx(), self.max_element_approx(),
other.min_element_approx(), other.max_element_approx()):
# NOTE: We first test for equality, which always returns True or False. If the equality test returns
# False, then we test for less-equal and greater-equal, which may return an expression, leading to
# TypeError. This is a workaround for the case where two expressions are the same or equal and
# SymPy confirms this but fails to return True when testing less-equal and greater-equal.

# lower bound: first check whether symbolic positive condition applies
if not (len(rb.free_symbols) == 0 and len(orb.free_symbols) == 1):
if not (symbolic.simplify_ext(nng(rb)) == symbolic.simplify_ext(nng(orb)) or
symbolic.simplify_ext(nng(rb)) <= symbolic.simplify_ext(nng(orb))):
return False

# upper bound: first check whether symbolic positive condition applies
if not (len(re.free_symbols) == 1 and len(ore.free_symbols) == 0):
if not (symbolic.simplify_ext(nng(re)) == symbolic.simplify_ext(nng(ore)) or
symbolic.simplify_ext(nng(re)) >= symbolic.simplify_ext(nng(ore))):
return False
if not bounding_box_symbolic_positive(self, other, True):
return False
except TypeError:
return False

return True

def covers_precise(self, other):
""" Returns True if self contains all the elements in other. """

# If self does not cover other with a bounding box union, return false.
symbolic_positive = Config.get('optimizer', 'symbolic_positive')
try:
bounding_box_cover = bounding_box_cover_exact(self, other) if symbolic_positive else bounding_box_symbolic_positive(self, other)
if not bounding_box_cover:
return False
except TypeError:
return False

try:
# if self is an index no further distinction is needed
if isinstance(self, Indices):
return True

elif isinstance(self, Range):
# other is an index so we need to check if the step of self is such that other is covered
# self.start % self.step == other.index % self.step
if isinstance(other, Indices):
try:
return all(
[(symbolic.simplify_ext(nng(start)) % symbolic.simplify_ext(nng(step)) ==
symbolic.simplify_ext(nng(i)) % symbolic.simplify_ext(nng(step))) == True
for (start, _, step), i in zip(self.ranges, other.indices)])
except:
return False
if isinstance(other, Range):
# other is a range so in every dimension self.step has to divide other.step and
# self.start % self.step = other.start % other.step
try:
self_steps = [r[2] for r in self.ranges]
other_steps = [r[2] for r in other.ranges]
for start, step, ostart, ostep in zip(self.min_element(), self_steps, other.min_element(),
other_steps):
if not (ostep % step == 0 and
((symbolic.simplify_ext(nng(start)) == symbolic.simplify_ext(nng(ostart))) or
(symbolic.simplify_ext(nng(start)) % symbolic.simplify_ext(
nng(step)) == symbolic.simplify_ext(nng(ostart)) % symbolic.simplify_ext(
nng(ostep))) == True)):
return False
except:
return False
return True
# unknown type
else:
raise TypeError

except TypeError:
return False


def __repr__(self):
return '%s (%s)' % (type(self).__name__, self.__str__())
Expand Down Expand Up @@ -973,6 +1041,111 @@ def intersection(self, other: 'Indices'):
return self
return None

class SubsetUnion(Subset):
"""
Wrapper subset type that stores multiple Subsets in a list.
"""

def __init__(self, subset):
self.subset_list: list[Subset] = []
if isinstance(subset, SubsetUnion):
self.subset_list = subset.subset_list
elif isinstance(subset, list):
for subset in subset:
if not subset:
break
if isinstance(subset, (Range, Indices)):
self.subset_list.append(subset)
else:
raise NotImplementedError
elif isinstance(subset, (Range, Indices)):
self.subset_list = [subset]

def covers(self, other):
"""
Returns True if this SubsetUnion covers another subset (using a bounding box).
If other is another SubsetUnion then self and other will
only return true if self is other. If other is a different type of subset
true is returned when one of the subsets in self is equal to other.
"""

if isinstance(other, SubsetUnion):
for subset in self.subset_list:
# check if ther is a subset in self that covers every subset in other
if all(subset.covers(s) for s in other.subset_list):
return True
# return False if that's not the case for any of the subsets in self
return False
else:
return any(s.covers(other) for s in self.subset_list)

def covers_precise(self, other):
"""
Returns True if this SubsetUnion covers another
subset. If other is another SubsetUnion then self and other will
only return true if self is other. If other is a different type of subset
true is returned when one of the subsets in self is equal to other
"""

if isinstance(other, SubsetUnion):
for subset in self.subset_list:
# check if ther is a subset in self that covers every subset in other
if all(subset.covers_precise(s) for s in other.subset_list):
return True
# return False if that's not the case for any of the subsets in self
return False
else:
return any(s.covers_precise(other) for s in self.subset_list)

def __str__(self):
string = ''
for subset in self.subset_list:
if not string == '':
string += " "
string += subset.__str__()
return string

def dims(self):
if not self.subset_list:
return 0
return next(iter(self.subset_list)).dims()

def union(self, other: Subset):
"""In place union of self with another Subset"""
try:
if isinstance(other, SubsetUnion):
self.subset_list += other.subset_list
elif isinstance(other, Indices) or isinstance(other, Range):
self.subset_list.append(other)
else:
raise TypeError
except TypeError: # cannot determine truth value of Relational
return None

@property
def free_symbols(self) -> Set[str]:
result = set()
for subset in self.subset_list:
result |= subset.free_symbols
return result

def replace(self, repl_dict):
for subset in self.subset_list:
subset.replace(repl_dict)

def num_elements(self):
# TODO: write something more meaningful here
min = 0
for subset in self.subset_list:
try:
if subset.num_elements() < min or min ==0:
min = subset.num_elements()
except:
continue

return min



def _union_special_cases(arb: symbolic.SymbolicType, brb: symbolic.SymbolicType, are: symbolic.SymbolicType,
bre: symbolic.SymbolicType):
Expand Down Expand Up @@ -1038,6 +1211,8 @@ def bounding_box_union(subset_a: Subset, subset_b: Subset) -> Range:
return Range(result)




def union(subset_a: Subset, subset_b: Subset) -> Subset:
""" Compute the union of two Subset objects.
If the subsets are not of the same type, degenerates to bounding-box
Expand All @@ -1056,6 +1231,9 @@ def union(subset_a: Subset, subset_b: Subset) -> Subset:
return subset_b
elif subset_a is None and subset_b is None:
raise TypeError('Both subsets cannot be None')
elif isinstance(subset_a, SubsetUnion) or isinstance(
subset_b, SubsetUnion):
return list_union(subset_a, subset_b)
elif type(subset_a) != type(subset_b):
return bounding_box_union(subset_a, subset_b)
elif isinstance(subset_a, Indices):
Expand All @@ -1066,13 +1244,43 @@ def union(subset_a: Subset, subset_b: Subset) -> Subset:
# TODO(later): More involved Strided-Tiled Range union
return bounding_box_union(subset_a, subset_b)
else:
warnings.warn('Unrecognized Subset type %s in union, degenerating to'
' bounding box' % type(subset_a).__name__)
warnings.warn(
'Unrecognized Subset type %s in union, degenerating to'
' bounding box' % type(subset_a).__name__)
return bounding_box_union(subset_a, subset_b)
except TypeError: # cannot determine truth value of Relational
return None


def list_union(subset_a: Subset, subset_b: Subset) -> Subset:
"""
Returns the union of two Subset lists.
:param subset_a: The first subset.
:param subset_b: The second subset.
:return: A SubsetUnion object that contains all elements of subset_a and subset_b.
"""
# TODO(later): Merge subsets in both lists if possible
try:
if subset_a is not None and subset_b is None:
return subset_a
elif subset_b is not None and subset_a is None:
return subset_b
elif subset_a is None and subset_b is None:
raise TypeError('Both subsets cannot be None')
elif type(subset_a) != type(subset_b):
if isinstance(subset_b, SubsetUnion):
return SubsetUnion(subset_b.subset_list.append(subset_a))
else:
return SubsetUnion(subset_a.subset_list.append(subset_b))
elif isinstance(subset_a, SubsetUnion):
return SubsetUnion(subset_a.subset_list + subset_b.subset_list)
else:
return SubsetUnion([subset_a, subset_b])

except TypeError:
return None

def intersects(subset_a: Subset, subset_b: Subset) -> Union[bool, None]:
"""
Returns True if two subsets intersect, False if they do not, or
Expand Down

0 comments on commit 9a0eafd

Please sign in to comment.