Skip to content

Commit

Permalink
Merge pull request #425 from uwescience/bmyerz/improve_testing
Browse files Browse the repository at this point in the history
Fix some testing issues
  • Loading branch information
domoritz committed May 15, 2015
2 parents a870129 + 68be426 commit 950c602
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 30 deletions.
27 changes: 15 additions & 12 deletions c_test_environment/grappalang_myrial_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,18 @@
import logging
logging.basicConfig(level=logging.DEBUG)

def is_skipping():
return not ('RACO_GRAPPA_TESTS' in os.environ
and int(os.environ['RACO_GRAPPA_TESTS']) == 1)

def raise_skip_test(query=None):
if 'RACO_GRAPPA_TESTS' in os.environ:
if int(os.environ['RACO_GRAPPA_TESTS']) == 1:
return
if not is_skipping():
return None

if query is not None:
raise SkipTest(query)
else:
raise SkipTest()
if query is not None:
raise SkipTest(query)
else:
raise SkipTest()


class MyriaLGrappaTest(MyriaLPlatformTestHarness, MyriaLPlatformTests):
Expand Down Expand Up @@ -60,18 +62,19 @@ def check(self, query, name, **kwargs):
with open(fname, 'w') as f:
f.write(code)

#raise Exception()
raise_skip_test(query)

with Chdir("c_test_environment") as d:
checkquery(name, GrappalangRunner())

def setUp(self):
raise_skip_test()
super(MyriaLGrappaTest, self).setUp()
with Chdir("c_test_environment") as d:
targetpath = os.path.join(os.environ.copy()['GRAPPA_HOME'], 'build/Make+Release/applications/join')
if need_generate(targetpath):
generate_default(targetpath)
if not is_skipping():
with Chdir("c_test_environment") as d:
targetpath = os.path.join(os.environ.copy()['GRAPPA_HOME'], 'build/Make+Release/applications/join')
if need_generate(targetpath):
generate_default(targetpath)

def _uda_def(self):
uda_def_path = os.path.join("c_test_environment", "testqueries", "argmax.myl")
Expand Down
11 changes: 7 additions & 4 deletions raco/language/clang.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,9 +574,6 @@ def opt_rules(self, **kwargs):
# rules.FreeMemory()
# ]

# disable specified rules
rules.Rule.set_global_rule_flags(*kwargs.keys())

# sequence that works for myrial
rule_grps_sequence = [
rules.remove_trivial_sequences,
Expand All @@ -591,4 +588,10 @@ def opt_rules(self, **kwargs):
if kwargs.get('SwapJoinSides'):
rule_grps_sequence.insert(0, [rules.SwapJoinSides()])

return list(itertools.chain(*rule_grps_sequence))
# flatten the rules lists
rule_list = list(itertools.chain(*rule_grps_sequence))

# disable specified rules
rules.Rule.apply_disable_flags(rule_list, *kwargs.keys())

return rule_list
2 changes: 2 additions & 0 deletions raco/language/clangcommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,7 @@ class BreakHashJoinConjunction(rules.Rule):
def __init__(self, select_clazz, join_clazz):
self.select_clazz = select_clazz
self.join_clazz = join_clazz
super(BreakHashJoinConjunction, self).__init__()

def fire(self, expr):
if isinstance(expr, self.join_clazz) \
Expand Down Expand Up @@ -616,6 +617,7 @@ def __init__(self, emit_print, subclass):
assert issubclass(subclass, CBaseStore), \
"%s is not a subclass of %s" % (subclass, CBaseStore)
self.subclass = subclass
super(StoreToBaseCStore, self).__init__()

def fire(self, expr):
if isinstance(expr, algebra.Store):
Expand Down
11 changes: 7 additions & 4 deletions raco/language/grappalang.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,9 +1059,6 @@ def opt_rules(self, **kwargs):
# rules.FreeMemory()
# ]

# disable specified rules
rules.Rule.set_global_rule_flags(*kwargs.keys())

join_type = kwargs.get('join_type', GrappaHashJoin)

# sequence that works for myrial
Expand All @@ -1077,4 +1074,10 @@ def opt_rules(self, **kwargs):
if kwargs.get('SwapJoinSides'):
rule_grps_sequence.insert(0, [rules.SwapJoinSides()])

return list(itertools.chain(*rule_grps_sequence))
# flatten the rules lists
rule_list = list(itertools.chain(*rule_grps_sequence))

# disable specified rules
rules.Rule.apply_disable_flags(rule_list, *kwargs.keys())

return rule_list
24 changes: 19 additions & 5 deletions raco/language/myrialang.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,7 @@ class HCShuffleBeforeNaryJoin(rules.Rule):
def __init__(self, catalog):
assert isinstance(catalog, Catalog)
self.catalog = catalog
super(HCShuffleBeforeNaryJoin, self).__init__()

@staticmethod
def reversed_index(child_schemes, conditions):
Expand Down Expand Up @@ -1324,6 +1325,7 @@ def fire(self, op):
class PushIntoSQL(rules.Rule):
def __init__(self, dialect=None):
self.dialect = dialect or postgresql.dialect()
super(PushIntoSQL, self).__init__()

def fire(self, expr):
if isinstance(expr, (algebra.Scan, algebra.ScanTemp)):
Expand Down Expand Up @@ -1442,6 +1444,7 @@ class GetCardinalities(rules.Rule):
def __init__(self, catalog):
assert isinstance(catalog, Catalog)
self.catalog = catalog
super(GetCardinalities, self).__init__()

def fire(self, expr):
# if not Zeroary operator, who cares?
Expand Down Expand Up @@ -1535,9 +1538,6 @@ class MyriaAlgebra(Algebra):
class MyriaLeftDeepTreeAlgebra(MyriaAlgebra):
"""Myria physical algebra using left deep tree pipeline and 1-D shuffle"""
def opt_rules(self, **kwargs):
# disable specified rules
rules.Rule.set_global_rule_flags(*kwargs.keys())

opt_grps_sequence = [
rules.remove_trivial_sequences,
[
Expand Down Expand Up @@ -1572,7 +1572,14 @@ def opt_rules(self, **kwargs):
compile_grps_sequence.append([BreakSplit()])

rule_grps_sequence = opt_grps_sequence + compile_grps_sequence
return list(itertools.chain(*rule_grps_sequence))

# flatten the rules lists
rule_list = list(itertools.chain(*rule_grps_sequence))

# disable specified rules
rules.Rule.apply_disable_flags(rule_list, *kwargs.keys())

return rule_list


class MyriaHyperCubeAlgebra(MyriaAlgebra):
Expand Down Expand Up @@ -1624,7 +1631,14 @@ def opt_rules(self, **kwargs):
compile_grps_sequence.append([BreakSplit()])

rule_grps_sequence = opt_grps_sequence + compile_grps_sequence
return list(itertools.chain(*rule_grps_sequence))

# flatten the rules lists
rule_list = list(itertools.chain(*rule_grps_sequence))

# disable specified rules
rules.Rule.apply_disable_flags(rule_list, *kwargs.keys())

return rule_list

def __init__(self, catalog=None):
self.catalog = catalog
Expand Down
16 changes: 11 additions & 5 deletions raco/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,29 @@ class Rule(object):

__metaclass__ = ABCMeta

_disabled_rules = set()
_flag_pattern = re.compile(r'no_([A-Za-z_]+)') # e.g., no_MergeSelects

def __init__(self):
self._disabled = False

def __call__(self, expr):
# if the rule is in the set of disabled, then don't allow it to fire
if self.__class__.__name__ in self._disabled_rules:
if self._disabled:
return expr
else:
return self.fire(expr)

@classmethod
def set_global_rule_flags(cls, *args):
def apply_disable_flags(cls, rule_list, *args):
disabled_rules = set()
# Automatically create a flag to disable any rule by name
# e.g., to disable MergeSelects, pass the arg "no_MergeSelects"
for a in args:
mat = re.match(cls._flag_pattern, a)
if mat:
cls._disabled_rules.add(mat.group(1))
disabled_rules.add(mat.group(1))

for r in rule_list:
r._disabled = r.__class__.__name__ in disabled_rules

@abstractmethod
def fire(self, expr):
Expand Down Expand Up @@ -74,6 +79,7 @@ class OneToOne(Rule):
def __init__(self, opfrom, opto):
self.opfrom = opfrom
self.opto = opto
super(OneToOne, self).__init__()

def fire(self, expr):
if isinstance(expr, self.opfrom):
Expand Down

0 comments on commit 950c602

Please sign in to comment.