diff --git a/.travis.yml b/.travis.yml index f1206fc..b1f4a2d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,7 +7,7 @@ install: - pip install . - pip install -r requirements.txt -script: nosetests --with-coverage tests +script: nosetests --with-coverage --nologcapture tests after_success: coveralls diff --git a/docs/source/api.rst b/docs/source/api.rst index e524a95..919f70f 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -29,6 +29,16 @@ Python 2 .. automethod:: __init__ +Pragma +====== + +.. autofunction:: miniutils.pragma.unroll + +.. autofunction:: miniutils.pragma.collapse_literals + +.. autofunction:: miniutils.pragma.deindex + + Miscellaneous ============= diff --git a/miniutils/magic_contract.py b/miniutils/magic_contract.py index 5f9ea90..0ee497c 100644 --- a/miniutils/magic_contract.py +++ b/miniutils/magic_contract.py @@ -1,8 +1,16 @@ -from contracts import contract, new_contract +from contracts import * from contracts.library import Extension as _Ext from miniutils.opt_decorator import optional_argument_decorator +# TODO: Figure out efficient mechanism to only enable contracts during testing or debug modes + + +def safe_new_contract(name, *args, **kwargs): + if name not in _Ext.registrar: + new_contract(name, *args, **kwargs) + + @optional_argument_decorator def magic_contract(*args, **kwargs): """Drop-in replacement for ``pycontracts.contract`` decorator, except that it supports locally-visible types @@ -13,8 +21,8 @@ def magic_contract(*args, **kwargs): """ def inner_decorator(f): for name, val in f.__globals__.items(): - if not name.startswith('_') and name not in _Ext.registrar and isinstance(val, type): - new_contract(name, val) + if not name.startswith('_') and isinstance(val, type): + safe_new_contract(name, val) return contract(*args, **kwargs)(f) return inner_decorator diff --git a/miniutils/pragma.py b/miniutils/pragma.py index b058c8f..0821b4e 100644 --- a/miniutils/pragma.py +++ b/miniutils/pragma.py @@ -9,12 +9,42 @@ import warnings from miniutils.opt_decorator import optional_argument_decorator +from miniutils.magic_contract import magic_contract, safe_new_contract +from contracts import ContractNotRespected + +try: + import numpy + + num_types = (int, float, numpy.number) + float_types = (float, numpy.floating) +except ImportError: # pragma: nocover + num_types = (int, float) + float_types = (float,) + + +def is_iterable(x): + try: + iter(x) + return True + except Exception: + return False + + +safe_new_contract('function', lambda x: callable(x)) +safe_new_contract('iterable', is_iterable) +safe_new_contract('literal', 'int|float|str|bool|tuple|list|None') +for name, tp in inspect.getmembers(ast, inspect.isclass): + safe_new_contract(name, tp) # Astor tries to get fancy by failing nicely, but in doing so they fail when traversing non-AST type node properties. # By deleting this custom handler, it'll fall back to the default ast visit pattern, which skips these missing # properties. Everything else seems to be implemented, so this will fail silently if it hits an AST node that isn't # supported but should be. -del astor.node_util.ExplicitNodeVisitor.visit +try: + del astor.node_util.ExplicitNodeVisitor.visit +except AttributeError: + # visit isn't defined in this version of astor + pass class DictStack: @@ -28,6 +58,7 @@ def __init__(self, *base): self.constants = [True] + [False] * len(base) def __setitem__(self, key, value): + # print("SETTING {} = {}".format(key, value)) self.dicts[-1][key] = value def __getitem__(self, item): @@ -47,7 +78,7 @@ def __delitem__(self, item): def __contains__(self, item): try: - self[item] + _ = self[item] return True except KeyError: return False @@ -72,10 +103,15 @@ def pop(self): return self.dicts.pop() +@magic_contract def _function_ast(f): - """Returns ast for the given function. Gives a tuple of (ast_module, function_body, function_file""" - assert callable(f) - + """ + Returns ast for the given function. Gives a tuple of (ast_module, function_body, function_file + :param f: The function to parse + :type f: function + :return: The relevant AST code: A module including only the function definition; the func body; the func file + :rtype: tuple(Module, list(AST), str) + """ try: f_file = sys.modules[f.__module__].__file__ except (KeyError, AttributeError): @@ -85,7 +121,17 @@ def _function_ast(f): return root, root.body[0].body, f_file -def can_have_side_effect(node, ctxt): +@magic_contract +def _can_have_side_effect(node, ctxt): + """ + Checks whether or not copying the given AST node could cause side effects in the resulting function + :param node: The AST node to be checked + :type node: AST + :param ctxt: The environment stack to use when running the check + :type ctxt: DictStack + :return: Whether or not duplicating this node could cause side effects + :rtype: bool + """ if isinstance(node, ast.AST): # print("Can {} have side effects?".format(node)) if isinstance(node, ast.Call): @@ -94,9 +140,9 @@ def can_have_side_effect(node, ctxt): else: for field, old_value in ast.iter_fields(node): if isinstance(old_value, list): - return any(can_have_side_effect(n, ctxt) for n in old_value if isinstance(n, ast.AST)) + return any(_can_have_side_effect(n, ctxt) for n in old_value if isinstance(n, ast.AST)) elif isinstance(old_value, ast.AST): - return can_have_side_effect(old_value, ctxt) + return _can_have_side_effect(old_value, ctxt) else: # print(" No!") return False @@ -104,8 +150,20 @@ def can_have_side_effect(node, ctxt): return False +@magic_contract def _constant_iterable(node, ctxt, avoid_side_effects=True): - """If the given node is a known iterable of some sort, return the list of its elements.""" + """ + If the given node is a known iterable of some sort, return the list of its elements. + :param node: The AST node to be checked + :type node: AST + :param ctxt: The environment stack to use when running the check + :type ctxt: DictStack + :param avoid_side_effects: Whether or not to avoid unwrapping side effect-causing AST nodes + :type avoid_side_effects: bool + :return: The iterable if possible, else None + :rtype: iterable|None + """ + # TODO: Support zipping # TODO: Support sets/dicts? # TODO: Support for reversed, enumerate, etc. @@ -114,7 +172,7 @@ def _constant_iterable(node, ctxt, avoid_side_effects=True): def wrap(return_node, name, idx): if not avoid_side_effects: return return_node - if can_have_side_effect(return_node, ctxt): + if _can_have_side_effect(return_node, ctxt): return ast.Subscript(name, ast.Index(idx)) return _make_ast_from_literal(return_node) @@ -127,13 +185,13 @@ def wrap(return_node, name, idx): return None elif isinstance(node, (ast.List, ast.Tuple)): return [_collapse_literal(e, ctxt) for e in node.elts] - #return [_resolve_name_or_attribute(e, ctxt) for e in node.elts] + # return [_resolve_name_or_attribute(e, ctxt) for e in node.elts] # Can't yet support sets and lists, since you need to compute what the unique values would be # elif isinstance(node, ast.Dict): # return node.keys elif isinstance(node, (ast.Name, ast.Attribute, ast.NameConstant)): res = _resolve_name_or_attribute(node, ctxt) - #print("Trying to resolve '{}' as list, got {}".format(astor.to_source(node), res)) + # print("Trying to resolve '{}' as list, got {}".format(astor.to_source(node), res)) if isinstance(res, ast.AST) and not isinstance(res, (ast.Name, ast.Attribute, ast.NameConstant)): res = _constant_iterable(res, ctxt) if not isinstance(res, ast.AST): @@ -147,21 +205,26 @@ def wrap(return_node, name, idx): return None -def _resolve_name_or_attribute(node, ctxt_or_obj): - """If the given name of attribute is defined in the current context, return its value. Else, returns the node""" +@magic_contract +def _resolve_name_or_attribute(node, ctxt): + """ + If the given name of attribute is defined in the current context, return its value. Else, returns the node + :param node: The node to try to resolve + :type node: AST + :param ctxt: The environment stack to use when running the check + :type ctxt: DictStack + :return: The object if the name was found, else the original node + :rtype: * + """ if isinstance(node, ast.Name): - if isinstance(ctxt_or_obj, DictStack): - if node.id in ctxt_or_obj: - #print("Resolved '{}' to {}".format(node.id, ctxt_or_obj[node.id])) - return ctxt_or_obj[node.id] - else: - return node + if node.id in ctxt: + return ctxt[node.id] else: - return getattr(ctxt_or_obj, node.id, node) + return node elif isinstance(node, ast.NameConstant): return node.value elif isinstance(node, ast.Attribute): - base_obj = _resolve_name_or_attribute(node.value, ctxt_or_obj) + base_obj = _resolve_name_or_attribute(node.value, ctxt) if not isinstance(base_obj, ast.AST): return getattr(base_obj, node.attr, node) else: @@ -203,28 +266,75 @@ def _resolve_name_or_attribute(node, ctxt_or_obj): } +@magic_contract def _make_ast_from_literal(lit): - """Converts literals into their AST equivalent""" - if isinstance(lit, (list, tuple)): + """ + Converts literals into their AST equivalent + :param lit: The literal to attempt to turn into an AST + :type lit: * + :return: The AST version of the literal, or the original AST node if one was given + :rtype: * + """ + if isinstance(lit, ast.AST): + return lit + elif isinstance(lit, (list, tuple)): res = [_make_ast_from_literal(e) for e in lit] tp = ast.List if isinstance(lit, list) else ast.Tuple return tp(elts=res) - elif isinstance(lit, (int, float)): - return ast.Num(lit) + elif isinstance(lit, num_types): + if isinstance(lit, float_types): + lit2 = float(lit) + else: + lit2 = int(lit) + if lit2 != lit: + raise AssertionError("({}){} != ({}){}".format(type(lit), lit, type(lit2), lit2)) + return ast.Num(lit2) elif isinstance(lit, str): return ast.Str(lit) elif isinstance(lit, bool): return ast.NameConstant(lit) else: + # warnings.warn("'{}' of type {} is not able to be made into an AST node".format(lit, type(lit))) return lit +@magic_contract +def _is_wrappable(lit): + """ + Checks if the given object either is or can be made into a known AST node + :param lit: The object to try to wrap + :type lit: * + :return: Whether or not this object can be wrapped as an AST node + :rtype: bool + """ + return isinstance(_make_ast_from_literal(lit), ast.AST) + + +@magic_contract def __collapse_literal(node, ctxt): - """Collapses literal expressions. Returns literals if they're available, AST nodes otherwise""" + """ + Collapses literal expressions. Returns literals if they're available, AST nodes otherwise + :param node: The AST node to be checked + :type node: AST + :param ctxt: The environment stack to use when running the check + :type ctxt: DictStack + :return: The given AST node with literal operations collapsed as much as possible + :rtype: * + """ + # try: + # print("Trying to collapse {}".format(astor.to_source(node))) + # except: + # print("Trying to collapse (source not possible) {}".format(astor.dump_tree(node))) + if isinstance(node, (ast.Name, ast.Attribute, ast.NameConstant)): res = _resolve_name_or_attribute(node, ctxt) if isinstance(res, ast.AST) and not isinstance(res, (ast.Name, ast.Attribute, ast.NameConstant)): - res = __collapse_literal(res, ctxt) + new_res = __collapse_literal(res, ctxt) + if _is_wrappable(new_res): + # print("{} can be replaced by more specific literal {}".format(res, new_res)) + res = new_res + # else: + # print("{} is an AST node, but can't safely be made more specific".format(res)) return res elif isinstance(node, ast.Num): return node.n @@ -245,33 +355,54 @@ def __collapse_literal(node, ctxt): if isinstance(slc, ast.AST): return node # print("Value at {}[{}] = {}".format(lst, slc, lst[slc])) - return lst[slc] - elif isinstance(node, (ast.UnaryOp, ast.BinOp, ast.BoolOp)): - if isinstance(node, ast.UnaryOp): - operands = [__collapse_literal(node.operand, ctxt)] + val = lst[slc] + if isinstance(val, ast.AST): + new_val = __collapse_literal(val, ctxt) + if _is_wrappable(new_val): + # print("{} can be replaced by more specific literal {}".format(val, new_val)) + val = new_val + # else: + # print("{} is an AST node, but can't safely be made more specific".format(val)) + # print("Final value at {}[{}] = {}".format(lst, slc, val)) + return val + elif isinstance(node, ast.UnaryOp): + operand = __collapse_literal(node.operand, ctxt) + if isinstance(operand, ast.AST): + return node else: - operands = [__collapse_literal(o, ctxt) for o in [node.left, node.right]] - #print("({} {})".format(repr(node.op), ", ".join(repr(o) for o in operands))) - is_literal = [not isinstance(opr, ast.AST) for opr in operands] - if all(is_literal): try: - val = _collapse_map[type(node.op)](*operands) - return val + return _collapse_map[node.op](operand) + except: + warnings.warn( + "Unary op collapse failed. Collapsing skipped, but executing this function will likely fail." + " Error was:\n{}".format(traceback.format_exc())) + elif isinstance(node, ast.BinOp): + left = __collapse_literal(node.left, ctxt) + right = __collapse_literal(node.right, ctxt) + # print("({} {})".format(repr(node.op), ", ".join(repr(o) for o in operands))) + lliteral = not isinstance(left, ast.AST) + rliteral = not isinstance(right, ast.AST) + if lliteral and rliteral: + # print("Both operands {} and {} are literals, attempting to collapse".format(left, right)) + try: + return _collapse_map[type(node.op)](left, right) except: - warnings.warn("Literal collapse failed. Collapsing skipped, but executing this function will likely fail." - " Error was:\n{}".format(traceback.format_exc())) + warnings.warn( + "Binary op collapse failed. Collapsing skipped, but executing this function will likely fail." + " Error was:\n{}".format(traceback.format_exc())) return node else: - if isinstance(node, ast.UnaryOp): - return ast.UnaryOp(operand=_make_ast_from_literal(operands[0]), op=node.op) - else: - return type(node)(left=_make_ast_from_literal(operands[0]), - right=_make_ast_from_literal(operands[1]), - op=node.op) + left = _make_ast_from_literal(left) + left = left if isinstance(left, ast.AST) else node.left + + right = _make_ast_from_literal(right) + right = right if isinstance(right, ast.AST) else node.right + # print("Attempting to combine {} and {} ({} op)".format(left, right, node.op)) + return ast.BinOp(left=left, right=right, op=node.op) elif isinstance(node, ast.Compare): operands = [__collapse_literal(o, ctxt) for o in [node.left] + node.comparators] if all(not isinstance(opr, ast.AST) for opr in operands): - return all(_collapse_map[type(cmp_func)](operands[i-1], operands[i]) + return all(_collapse_map[type(cmp_func)](operands[i - 1], operands[i]) for i, cmp_func in zip(range(1, len(operands)), node.ops)) else: return node @@ -279,22 +410,46 @@ def __collapse_literal(node, ctxt): return node -def _collapse_literal(node, ctxt): - """Collapse literal expressions in the given node. Returns the node with the collapsed literals""" - return _make_ast_from_literal(__collapse_literal(node, ctxt)) +@magic_contract +def _collapse_literal(node, ctxt, give_raw_result=False): + """ + Collapse literal expressions in the given node. Returns the node with the collapsed literals + :param node: The AST node to be checked + :type node: AST + :param ctxt: The environment stack to use when running the check + :type ctxt: DictStack + :return: The given AST node with literal operations collapsed as much as possible + :rtype: * + """ + result = __collapse_literal(node, ctxt) + if give_raw_result: + return result + result = _make_ast_from_literal(result) + if not isinstance(result, ast.AST): + return node + return result +@magic_contract def _assign_names(node): - """Gets names from a assign-to tuple in flat form, just to know what's affected + """ + Gets names from a assign-to tuple in flat form, just to know what's affected "x=3" -> "x" "a,b=4,5" -> ["a", "b"] "(x,(y,z)),(a,) = something" -> ["x", "y", "z", "a"] + + :param node: The AST node to resolve to a list of names + :type node: Name|Tuple + :return: The flattened list of names referenced in this node + :rtype: iterable """ if isinstance(node, ast.Name): yield node.id elif isinstance(node, ast.Tuple): for e in node.elts: yield from _assign_names(e) + elif isinstance(node, ast.Subscript): + raise NotImplemented() # noinspection PyPep8Naming @@ -307,32 +462,35 @@ def __init__(self, ctxt=None): # orig_node = copy.deepcopy(node) # new_node = super().visit(node) # - # orig_node_code = astor.to_source(orig_node).strip() # try: + # orig_node_code = astor.to_source(orig_node).strip() # if new_node is None: # print("Deleted >>> {} <<<".format(orig_node_code)) # elif isinstance(new_node, ast.AST): # print("Converted >>> {} <<< to >>> {} <<<".format(orig_node_code, astor.to_source(new_node).strip())) # elif isinstance(new_node, list): # print("Converted >>> {} <<< to [[[ {} ]]]".format(orig_node_code, ", ".join(astor.to_source(n).strip() for n in new_node))) - # except AssertionError as ex: - # raise AssertionError("Failed on {} >>> {}".format(orig_node_code, astor.dump_tree(new_node))) from ex + # except Exception as ex: + # raise AssertionError("Failed on {} >>> {}".format(astor.dump_tree(orig_node), astor.dump_tree(new_node))) from ex + # # print("Failed on {} >>> {}".format(astor.dump_tree(orig_node), astor.dump_tree(new_node))) + # # return orig_node # # return new_node def visit_Assign(self, node): node.value = self.visit(node.value) - #print(node.value) + # print(node.value) # TODO: Support tuple assignments if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): + nvalue = copy.deepcopy(node.value) var = node.targets[0].id - val = _constant_iterable(node.value, self.ctxt) + val = _constant_iterable(nvalue, self.ctxt) if val is not None: - #print("Setting {} = {}".format(var, val)) + # print("Setting {} = {}".format(var, val)) self.ctxt[var] = val else: - val = _collapse_literal(node.value, self.ctxt) - #print("Setting {} = {}".format(var, val)) + val = _collapse_literal(nvalue, self.ctxt) + # print("Setting {} = {}".format(var, val)) self.ctxt[var] = val else: for targ in node.targets: @@ -340,6 +498,11 @@ def visit_Assign(self, node): self.ctxt[assgn] = None return node + def visit_AugAssign(self, node): + for assgn in _assign_names(node.target): + self.ctxt[assgn] = None + return super().generic_visit(node) + # noinspection PyPep8Naming class UnrollTransformer(TrackedContextTransformer): @@ -381,6 +544,9 @@ def visit_Name(self, node): # noinspection PyPep8Naming class CollapseTransformer(TrackedContextTransformer): + def visit_Name(self, node): + return _collapse_literal(node, self.ctxt) + def visit_BinOp(self, node): return _collapse_literal(node, self.ctxt) @@ -396,17 +562,45 @@ def visit_Compare(self, node): def visit_Subscript(self, node): return _collapse_literal(node, self.ctxt) + def visit_If(self, node): + cond = _collapse_literal(node.test, self.ctxt, True) + # print("Attempting to collapse IF conditioned on {}".format(cond)) + if not isinstance(cond, ast.AST): + # print("Yes, this IF can be consolidated, condition is {}".format(bool(cond))) + body = node.body if cond else node.orelse + result = [] + for subnode in body: + res = self.visit(subnode) + if res is None: + pass + elif isinstance(res, list): + result += res + else: + result.append(res) + return result + else: + # print("No, this IF cannot be consolidated") + return super().generic_visit(node) + def _make_function_transformer(transformer_type, name, description): @optional_argument_decorator + @magic_contract def transform(return_source=False, save_source=True, function_globals=None, **kwargs): """ :param return_source: Returns the unrolled function's source code instead of compiling it + :type return_source: bool :param save_source: Saves the function source code to a tempfile to make it inspectable + :type save_source: bool + :param function_globals: Overridden global name assignments to use when processing the function + :type function_globals: dict|None :param kwargs: Any other environmental variables to provide during unrolling + :type kwargs: dict :return: The unrolled function, or its source code if requested + :rtype: function """ + @magic_contract(f='function', returns='function|str') def inner(f): f_mod, f_body, f_file = _function_ast(f) # Grab function globals @@ -428,6 +622,8 @@ def inner(f): except ImportError: # pragma: nocover raise ImportError("miniutils.pragma.{name} requires 'astor' to be installed to obtain source code" .format(name=name)) + except Exception as ex: # pragma: nocover + raise RuntimeError(astor.dump_tree(f_mod)) from ex else: source = None @@ -448,6 +644,7 @@ def inner(f): return func return inner + transform.__name__ = name transform.__doc__ = '\n'.join([description, transform.__doc__]) return transform @@ -460,15 +657,21 @@ def inner(f): collapse_literals = _make_function_transformer(CollapseTransformer, 'collapse_literals', "Collapses literal expressions in the decorated function into single literals") + # Directly reference elements of constant list, removing literal indexing into that list within a function +@magic_contract def deindex(iterable, iterable_name, *args, **kwargs): """ :param iterable: The list to deindex in the target function + :type iterable: iterable :param iterable_name: The list's name (must be unique if deindexing multiple lists) - :param return_source: Returns the unrolled function's source code instead of compiling it - :param save_source: Saves the function source code to a tempfile to make it inspectable + :type iterable_name: str + :param args: Other command line arguments (see :func:`unroll` for documentation) + :type args: tuple :param kwargs: Any other environmental variables to provide during unrolling + :type kwargs: dict :return: The unrolled function, or its source code if requested + :rtype: function """ if hasattr(iterable, 'items'): # Support dicts and the like @@ -483,4 +686,7 @@ def deindex(iterable, iterable_name, *args, **kwargs): return collapse_literals(*args, function_globals=mapping, **kwargs) # Inline functions? - +# You could do something like: +# args, kwargs = (args_in), (kwargs_in) +# function_body +# result = return_expr diff --git a/tests/test_magic_contract.py b/tests/test_magic_contract.py index 2793fa5..fcb5a1d 100644 --- a/tests/test_magic_contract.py +++ b/tests/test_magic_contract.py @@ -1,5 +1,8 @@ from unittest import TestCase + +import contracts +contracts.enable_all() import functools from miniutils.magic_contract import magic_contract from contracts.interface import ContractNotRespected @@ -15,6 +18,7 @@ def fib(n): :return: The fibonnaci number :rtype: int,>=0 """ + assert n >= 0; if n == 0: return 0 elif n == 1: @@ -44,6 +48,10 @@ def sample_func(a): class TestMagicContract(TestCase): + def setUp(self): + import contracts + contracts.enable_all() + def test_magic_contract_1(self): fib(3) fib(5) diff --git a/tests/test_pragma.py b/tests/test_pragma.py index 1609259..ae6d561 100644 --- a/tests/test_pragma.py +++ b/tests/test_pragma.py @@ -4,7 +4,15 @@ import inspect -class TestUnroll(TestCase): +class PragmaTest(TestCase): + def setUp(self): + pass + # # This is a quick hack to disable contracts for testing if needed + # import contracts + # contracts.enable_all() + + +class TestUnroll(PragmaTest): def test_unroll_range(self): @pragma.unroll def f(): @@ -251,7 +259,7 @@ def f(): self.assertEqual(f.strip(), result.strip()) -class TestCollapseLiterals(TestCase): +class TestCollapseLiterals(PragmaTest): def test_full_run(self): def f(y): x = 3 @@ -265,16 +273,12 @@ def f(y): return r import inspect - print(inspect.getsource(f)) - print(pragma.collapse_literals(return_source=True)(f)) deco_f = pragma.collapse_literals(f) self.assertEqual(f(0), deco_f(0)) self.assertEqual(f(1), deco_f(1)) self.assertEqual(f(5), deco_f(5)) self.assertEqual(f(-1), deco_f(-1)) - print(inspect.getsource(f)) - print(pragma.collapse_literals(return_source=True)(pragma.unroll(f))) deco_f = pragma.collapse_literals(pragma.unroll(f)) self.assertEqual(f(0), deco_f(0)) self.assertEqual(f(1), deco_f(1)) @@ -371,7 +375,70 @@ def f(): self.assertTrue(issubclass(w[-1].category, UserWarning)) -class TestDeindex(TestCase): + # TODO: implement the features to get this test to work + # def test_conditional_erasure(self): + # @pragma.collapse_literals(return_source=True) + # def f(y): + # x = 0 + # if y == x: + # x = 1 + # return x + # + # result = dedent(''' + # def f(y): + # x = 0 + # if y == 0: + # x = 1 + # return x + # ''') + # self.assertEqual(f.strip(), result.strip()) + + def test_constant_conditional_erasure(self): + @pragma.collapse_literals(return_source=True) + def f(y): + x = 0 + if x <= 0: + x = 1 + return x + + result = dedent(''' + def f(y): + x = 0 + x = 1 + return 1 + ''') + self.assertEqual(f.strip(), result.strip()) + + def fn(): + if x == 0: + x = 'a' + elif x == 1: + x = 'b' + else: + x = 'c' + return x + + result0 = dedent(''' + def fn(): + x = 'a' + return 'a' + ''') + result1 = dedent(''' + def fn(): + x = 'b' + return 'b' + ''') + result2 = dedent(''' + def fn(): + x = 'c' + return 'c' + ''') + self.assertEqual(pragma.collapse_literals(return_source=True, x=0)(fn).strip(), result0.strip()) + self.assertEqual(pragma.collapse_literals(return_source=True, x=1)(fn).strip(), result1.strip()) + self.assertEqual(pragma.collapse_literals(return_source=True, x=2)(fn).strip(), result2.strip()) + + +class TestDeindex(PragmaTest): def test_with_literals(self): v = [1, 2, 3] @pragma.collapse_literals(return_source=True) @@ -395,7 +462,7 @@ def f(): def f(): return v_0 + v_1 + v_2 ''') - self.assertEqual(f.strip(), result.strip()) + self.assertEqual(result.strip(), f.strip()) def test_with_unroll(self): v = [None, None, None] @@ -458,8 +525,6 @@ def run_func(i, x): if i == j: return funcs[j](x) - print(inspect.getsource(run_func)) - self.assertEqual(run_func(0, 5), 5) self.assertEqual(run_func(1, 5), 25) self.assertEqual(run_func(2, 5), 125) @@ -476,7 +541,7 @@ def run_func(i, x): self.assertEqual(inspect.getsource(run_func).strip(), result.strip()) -class TestDictStack(TestCase): +class TestDictStack(PragmaTest): def test_most(self): stack = pragma.DictStack() stack.push({'x': 3})