From e21cef27c7e3ebc880330123734985de599b42c9 Mon Sep 17 00:00:00 2001 From: scnerd Date: Sun, 26 Nov 2017 14:05:07 -0500 Subject: [PATCH 1/6] Added basic numpy literal support --- miniutils/pragma.py | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/miniutils/pragma.py b/miniutils/pragma.py index b058c8f..c001199 100644 --- a/miniutils/pragma.py +++ b/miniutils/pragma.py @@ -10,11 +10,23 @@ from miniutils.opt_decorator import optional_argument_decorator +try: + import numpy + num_types = (int, float, numpy.number) + float_types = (float, numpy.floating) +except ImportError: # pragma: nocover + num_types = (int, float) + float_types = (float,) + # Astor tries to get fancy by failing nicely, but in doing so they fail when traversing non-AST type node properties. # By deleting this custom handler, it'll fall back to the default ast visit pattern, which skips these missing # properties. Everything else seems to be implemented, so this will fail silently if it hits an AST node that isn't # supported but should be. -del astor.node_util.ExplicitNodeVisitor.visit +try: + del astor.node_util.ExplicitNodeVisitor.visit +except AttributeError: + # visit isn't defined in this version of astor + pass class DictStack: @@ -209,14 +221,22 @@ def _make_ast_from_literal(lit): res = [_make_ast_from_literal(e) for e in lit] tp = ast.List if isinstance(lit, list) else ast.Tuple return tp(elts=res) - elif isinstance(lit, (int, float)): - return ast.Num(lit) + elif isinstance(lit, num_types): + if isinstance(lit, float_types): + lit2 = float(lit) + else: + lit2 = int(lit) + if lit2 != lit: + raise AssertionError("({}){} != ({}){}".format(type(lit), lit, type(lit2), lit2)) + return ast.Num(lit2) elif isinstance(lit, str): return ast.Str(lit) elif isinstance(lit, bool): return ast.NameConstant(lit) - else: + elif isinstance(lit, ast.AST): return lit + else: + raise AssertionError("'{}' of type {} is not able to be made into an AST node".format(lit, type(lit))) def __collapse_literal(node, ctxt): @@ -325,13 +345,14 @@ def visit_Assign(self, node): #print(node.value) # TODO: Support tuple assignments if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): + nvalue = copy.deepcopy(node.value) var = node.targets[0].id - val = _constant_iterable(node.value, self.ctxt) + val = _constant_iterable(nvalue, self.ctxt) if val is not None: #print("Setting {} = {}".format(var, val)) self.ctxt[var] = val else: - val = _collapse_literal(node.value, self.ctxt) + val = _collapse_literal(nvalue, self.ctxt) #print("Setting {} = {}".format(var, val)) self.ctxt[var] = val else: @@ -428,6 +449,8 @@ def inner(f): except ImportError: # pragma: nocover raise ImportError("miniutils.pragma.{name} requires 'astor' to be installed to obtain source code" .format(name=name)) + except AssertionError as ex: # pragma: nocover + raise RuntimeError(astor.dump_tree(f_mod)) from ex else: source = None From b913fa65268e67e394d80664fa8c0a468fdaa7a1 Mon Sep 17 00:00:00 2001 From: scnerd Date: Sun, 26 Nov 2017 15:44:38 -0500 Subject: [PATCH 2/6] Initial effort to add contracts to pragma functionality --- docs/source/api.rst | 10 ++ miniutils/magic_contract.py | 12 ++- miniutils/pragma.py | 200 +++++++++++++++++++++++++++--------- 3 files changed, 172 insertions(+), 50 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index e524a95..919f70f 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -29,6 +29,16 @@ Python 2 .. automethod:: __init__ +Pragma +====== + +.. autofunction:: miniutils.pragma.unroll + +.. autofunction:: miniutils.pragma.collapse_literals + +.. autofunction:: miniutils.pragma.deindex + + Miscellaneous ============= diff --git a/miniutils/magic_contract.py b/miniutils/magic_contract.py index 5f9ea90..65e7139 100644 --- a/miniutils/magic_contract.py +++ b/miniutils/magic_contract.py @@ -3,6 +3,14 @@ from miniutils.opt_decorator import optional_argument_decorator +# TODO: Figure out efficient mechanism to only enable contracts during testing or debug modes + + +def safe_new_contract(name, *args, **kwargs): + if name not in _Ext.registrar: + new_contract(name, *args, **kwargs) + + @optional_argument_decorator def magic_contract(*args, **kwargs): """Drop-in replacement for ``pycontracts.contract`` decorator, except that it supports locally-visible types @@ -13,8 +21,8 @@ def magic_contract(*args, **kwargs): """ def inner_decorator(f): for name, val in f.__globals__.items(): - if not name.startswith('_') and name not in _Ext.registrar and isinstance(val, type): - new_contract(name, val) + if not name.startswith('_') and isinstance(val, type): + safe_new_contract(name, val) return contract(*args, **kwargs)(f) return inner_decorator diff --git a/miniutils/pragma.py b/miniutils/pragma.py index c001199..78c369c 100644 --- a/miniutils/pragma.py +++ b/miniutils/pragma.py @@ -9,15 +9,33 @@ import warnings from miniutils.opt_decorator import optional_argument_decorator +from miniutils.magic_contract import magic_contract, safe_new_contract +from contracts import ContractNotRespected try: import numpy + num_types = (int, float, numpy.number) float_types = (float, numpy.floating) except ImportError: # pragma: nocover num_types = (int, float) float_types = (float,) + +def is_iterable(x): + try: + iter(x) + return True + except Exception: + return False + + +safe_new_contract('function', lambda x: callable(x)) +safe_new_contract('iterable', is_iterable) +safe_new_contract('literal', 'int|float|str|bool|tuple|list|None') +for name, tp in inspect.getmembers(ast, inspect.isclass): + safe_new_contract(name, tp) + # Astor tries to get fancy by failing nicely, but in doing so they fail when traversing non-AST type node properties. # By deleting this custom handler, it'll fall back to the default ast visit pattern, which skips these missing # properties. Everything else seems to be implemented, so this will fail silently if it hits an AST node that isn't @@ -84,10 +102,15 @@ def pop(self): return self.dicts.pop() +@magic_contract def _function_ast(f): - """Returns ast for the given function. Gives a tuple of (ast_module, function_body, function_file""" - assert callable(f) - + """ + Returns ast for the given function. Gives a tuple of (ast_module, function_body, function_file + :param f: The function to parse + :type f: function + :return: The relevant AST code: A module including only the function definition; the func body; the func file + :rtype: tuple(Module, list(AST), str) + """ try: f_file = sys.modules[f.__module__].__file__ except (KeyError, AttributeError): @@ -97,7 +120,17 @@ def _function_ast(f): return root, root.body[0].body, f_file -def can_have_side_effect(node, ctxt): +@magic_contract +def _can_have_side_effect(node, ctxt): + """ + Checks whether or not copying the given AST node could cause side effects in the resulting function + :param node: The AST node to be checked + :type node: AST + :param ctxt: The environment stack to use when running the check + :type ctxt: DictStack + :return: Whether or not duplicating this node could cause side effects + :rtype: bool + """ if isinstance(node, ast.AST): # print("Can {} have side effects?".format(node)) if isinstance(node, ast.Call): @@ -106,9 +139,9 @@ def can_have_side_effect(node, ctxt): else: for field, old_value in ast.iter_fields(node): if isinstance(old_value, list): - return any(can_have_side_effect(n, ctxt) for n in old_value if isinstance(n, ast.AST)) + return any(_can_have_side_effect(n, ctxt) for n in old_value if isinstance(n, ast.AST)) elif isinstance(old_value, ast.AST): - return can_have_side_effect(old_value, ctxt) + return _can_have_side_effect(old_value, ctxt) else: # print(" No!") return False @@ -116,8 +149,20 @@ def can_have_side_effect(node, ctxt): return False +@magic_contract def _constant_iterable(node, ctxt, avoid_side_effects=True): - """If the given node is a known iterable of some sort, return the list of its elements.""" + """ + If the given node is a known iterable of some sort, return the list of its elements. + :param node: The AST node to be checked + :type node: AST + :param ctxt: The environment stack to use when running the check + :type ctxt: DictStack + :param avoid_side_effects: Whether or not to avoid unwrapping side effect-causing AST nodes + :type avoid_side_effects: bool + :return: The iterable if possible, else None + :rtype: iterable|None + """ + # TODO: Support zipping # TODO: Support sets/dicts? # TODO: Support for reversed, enumerate, etc. @@ -126,7 +171,7 @@ def _constant_iterable(node, ctxt, avoid_side_effects=True): def wrap(return_node, name, idx): if not avoid_side_effects: return return_node - if can_have_side_effect(return_node, ctxt): + if _can_have_side_effect(return_node, ctxt): return ast.Subscript(name, ast.Index(idx)) return _make_ast_from_literal(return_node) @@ -139,13 +184,13 @@ def wrap(return_node, name, idx): return None elif isinstance(node, (ast.List, ast.Tuple)): return [_collapse_literal(e, ctxt) for e in node.elts] - #return [_resolve_name_or_attribute(e, ctxt) for e in node.elts] + # return [_resolve_name_or_attribute(e, ctxt) for e in node.elts] # Can't yet support sets and lists, since you need to compute what the unique values would be # elif isinstance(node, ast.Dict): # return node.keys elif isinstance(node, (ast.Name, ast.Attribute, ast.NameConstant)): res = _resolve_name_or_attribute(node, ctxt) - #print("Trying to resolve '{}' as list, got {}".format(astor.to_source(node), res)) + # print("Trying to resolve '{}' as list, got {}".format(astor.to_source(node), res)) if isinstance(res, ast.AST) and not isinstance(res, (ast.Name, ast.Attribute, ast.NameConstant)): res = _constant_iterable(res, ctxt) if not isinstance(res, ast.AST): @@ -159,21 +204,26 @@ def wrap(return_node, name, idx): return None -def _resolve_name_or_attribute(node, ctxt_or_obj): - """If the given name of attribute is defined in the current context, return its value. Else, returns the node""" +@magic_contract +def _resolve_name_or_attribute(node, ctxt): + """ + If the given name of attribute is defined in the current context, return its value. Else, returns the node + :param node: The node to try to resolve + :type node: AST + :param ctxt: The environment stack to use when running the check + :type ctxt: DictStack + :return: The object if the name was found, else the original node + :rtype: * + """ if isinstance(node, ast.Name): - if isinstance(ctxt_or_obj, DictStack): - if node.id in ctxt_or_obj: - #print("Resolved '{}' to {}".format(node.id, ctxt_or_obj[node.id])) - return ctxt_or_obj[node.id] - else: - return node + if isinstance(ctxt, DictStack): + return ctxt[node.id] else: - return getattr(ctxt_or_obj, node.id, node) + return getattr(ctxt, node.id, node) elif isinstance(node, ast.NameConstant): return node.value elif isinstance(node, ast.Attribute): - base_obj = _resolve_name_or_attribute(node.value, ctxt_or_obj) + base_obj = _resolve_name_or_attribute(node.value, ctxt) if not isinstance(base_obj, ast.AST): return getattr(base_obj, node.attr, node) else: @@ -215,8 +265,15 @@ def _resolve_name_or_attribute(node, ctxt_or_obj): } +@magic_contract def _make_ast_from_literal(lit): - """Converts literals into their AST equivalent""" + """ + Converts literals into their AST equivalent + :param lit: The literal to attempt to turn into an AST + :type lit: literal|AST + :return: The AST version of the literal, or the original AST node if one was given + :rtype: AST + """ if isinstance(lit, (list, tuple)): res = [_make_ast_from_literal(e) for e in lit] tp = ast.List if isinstance(lit, list) else ast.Tuple @@ -233,14 +290,24 @@ def _make_ast_from_literal(lit): return ast.Str(lit) elif isinstance(lit, bool): return ast.NameConstant(lit) - elif isinstance(lit, ast.AST): - return lit - else: - raise AssertionError("'{}' of type {} is not able to be made into an AST node".format(lit, type(lit))) + # elif isinstance(lit, ast.AST): + # return lit + # else: + # raise AssertionError("'{}' of type {} is not able to be made into an AST node".format(lit, type(lit))) + return lit +@magic_contract def __collapse_literal(node, ctxt): - """Collapses literal expressions. Returns literals if they're available, AST nodes otherwise""" + """ + Collapses literal expressions. Returns literals if they're available, AST nodes otherwise + :param node: The AST node to be checked + :type node: AST + :param ctxt: The environment stack to use when running the check + :type ctxt: DictStack + :return: The given AST node with literal operations collapsed as much as possible + :rtype: literal|AST + """ if isinstance(node, (ast.Name, ast.Attribute, ast.NameConstant)): res = _resolve_name_or_attribute(node, ctxt) if isinstance(res, ast.AST) and not isinstance(res, (ast.Name, ast.Attribute, ast.NameConstant)): @@ -266,32 +333,40 @@ def __collapse_literal(node, ctxt): return node # print("Value at {}[{}] = {}".format(lst, slc, lst[slc])) return lst[slc] - elif isinstance(node, (ast.UnaryOp, ast.BinOp, ast.BoolOp)): - if isinstance(node, ast.UnaryOp): - operands = [__collapse_literal(node.operand, ctxt)] + elif isinstance(node, ast.UnaryOp): + operand = __collapse_literal(node.operand, ctxt) + if isinstance(operand, ast.AST): + return node else: - operands = [__collapse_literal(o, ctxt) for o in [node.left, node.right]] - #print("({} {})".format(repr(node.op), ", ".join(repr(o) for o in operands))) + return _collapse_map[node.op](operand) + elif isinstance(node, (ast.BinOp, ast.BoolOp)): + operands = [__collapse_literal(o, ctxt) for o in [node.left, node.right]] + # print("({} {})".format(repr(node.op), ", ".join(repr(o) for o in operands))) is_literal = [not isinstance(opr, ast.AST) for opr in operands] if all(is_literal): try: val = _collapse_map[type(node.op)](*operands) return val except: - warnings.warn("Literal collapse failed. Collapsing skipped, but executing this function will likely fail." - " Error was:\n{}".format(traceback.format_exc())) + warnings.warn( + "Literal collapse failed. Collapsing skipped, but executing this function will likely fail." + " Error was:\n{}".format(traceback.format_exc())) + return node + elif any(is_literal): + try: + if isinstance(node, ast.UnaryOp): + return ast.UnaryOp(operand=_make_ast_from_literal(operands[0]), op=node.op) + else: + return type(node)(left=_make_ast_from_literal(operands[0]), + right=_make_ast_from_literal(operands[1]), + op=node.op) + except (AssertionError, ContractNotRespected): + warnings.warn("Unable to re-pack {tp} with {ops}".format(tp=type(node), ops=operands)) return node - else: - if isinstance(node, ast.UnaryOp): - return ast.UnaryOp(operand=_make_ast_from_literal(operands[0]), op=node.op) - else: - return type(node)(left=_make_ast_from_literal(operands[0]), - right=_make_ast_from_literal(operands[1]), - op=node.op) elif isinstance(node, ast.Compare): operands = [__collapse_literal(o, ctxt) for o in [node.left] + node.comparators] if all(not isinstance(opr, ast.AST) for opr in operands): - return all(_collapse_map[type(cmp_func)](operands[i-1], operands[i]) + return all(_collapse_map[type(cmp_func)](operands[i - 1], operands[i]) for i, cmp_func in zip(range(1, len(operands)), node.ops)) else: return node @@ -299,16 +374,32 @@ def __collapse_literal(node, ctxt): return node +@magic_contract def _collapse_literal(node, ctxt): - """Collapse literal expressions in the given node. Returns the node with the collapsed literals""" + """ + Collapse literal expressions in the given node. Returns the node with the collapsed literals + :param node: The AST node to be checked + :type node: AST + :param ctxt: The environment stack to use when running the check + :type ctxt: DictStack + :return: The given AST node with literal operations collapsed as much as possible + :rtype: AST + """ return _make_ast_from_literal(__collapse_literal(node, ctxt)) +@magic_contract def _assign_names(node): - """Gets names from a assign-to tuple in flat form, just to know what's affected + """ + Gets names from a assign-to tuple in flat form, just to know what's affected "x=3" -> "x" "a,b=4,5" -> ["a", "b"] "(x,(y,z)),(a,) = something" -> ["x", "y", "z", "a"] + + :param node: The AST node to resolve to a list of names + :type node: Name|Tuple + :return: The flattened list of names referenced in this node + :rtype: iterable """ if isinstance(node, ast.Name): yield node.id @@ -342,18 +433,18 @@ def __init__(self, ctxt=None): def visit_Assign(self, node): node.value = self.visit(node.value) - #print(node.value) + # print(node.value) # TODO: Support tuple assignments if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): nvalue = copy.deepcopy(node.value) var = node.targets[0].id val = _constant_iterable(nvalue, self.ctxt) if val is not None: - #print("Setting {} = {}".format(var, val)) + # print("Setting {} = {}".format(var, val)) self.ctxt[var] = val else: val = _collapse_literal(nvalue, self.ctxt) - #print("Setting {} = {}".format(var, val)) + # print("Setting {} = {}".format(var, val)) self.ctxt[var] = val else: for targ in node.targets: @@ -420,14 +511,22 @@ def visit_Subscript(self, node): def _make_function_transformer(transformer_type, name, description): @optional_argument_decorator + @magic_contract def transform(return_source=False, save_source=True, function_globals=None, **kwargs): """ :param return_source: Returns the unrolled function's source code instead of compiling it + :type return_source: bool :param save_source: Saves the function source code to a tempfile to make it inspectable + :type save_source: bool + :param function_globals: Overridden global name assignments to use when processing the function + :type function_globals: dict|None :param kwargs: Any other environmental variables to provide during unrolling + :type kwargs: dict :return: The unrolled function, or its source code if requested + :rtype: function """ + @magic_contract(f='function', returns='function|str') def inner(f): f_mod, f_body, f_file = _function_ast(f) # Grab function globals @@ -449,7 +548,7 @@ def inner(f): except ImportError: # pragma: nocover raise ImportError("miniutils.pragma.{name} requires 'astor' to be installed to obtain source code" .format(name=name)) - except AssertionError as ex: # pragma: nocover + except Exception as ex: # pragma: nocover raise RuntimeError(astor.dump_tree(f_mod)) from ex else: source = None @@ -471,6 +570,7 @@ def inner(f): return func return inner + transform.__name__ = name transform.__doc__ = '\n'.join([description, transform.__doc__]) return transform @@ -483,15 +583,20 @@ def inner(f): collapse_literals = _make_function_transformer(CollapseTransformer, 'collapse_literals', "Collapses literal expressions in the decorated function into single literals") + # Directly reference elements of constant list, removing literal indexing into that list within a function def deindex(iterable, iterable_name, *args, **kwargs): """ :param iterable: The list to deindex in the target function + :type iterable: iterable :param iterable_name: The list's name (must be unique if deindexing multiple lists) + :type iterable_name: str :param return_source: Returns the unrolled function's source code instead of compiling it :param save_source: Saves the function source code to a tempfile to make it inspectable :param kwargs: Any other environmental variables to provide during unrolling + :type kwargs: dict :return: The unrolled function, or its source code if requested + :rtype: function """ if hasattr(iterable, 'items'): # Support dicts and the like @@ -506,4 +611,3 @@ def deindex(iterable, iterable_name, *args, **kwargs): return collapse_literals(*args, function_globals=mapping, **kwargs) # Inline functions? - From dce4042f238425a0c818a59691de49c533861e05 Mon Sep 17 00:00:00 2001 From: scnerd Date: Sun, 26 Nov 2017 16:15:49 -0500 Subject: [PATCH 3/6] [WIP] fixing up binary operation error --- miniutils/__init__.py | 3 ++ miniutils/pragma.py | 65 ++++++++++++++++++++++--------------------- 2 files changed, 36 insertions(+), 32 deletions(-) diff --git a/miniutils/__init__.py b/miniutils/__init__.py index 0e2b7dd..baaae4c 100644 --- a/miniutils/__init__.py +++ b/miniutils/__init__.py @@ -1,3 +1,6 @@ +import contracts +contracts.disable_all() + from .caching import CachedProperty from .magic_contract import magic_contract from .opt_decorator import optional_argument_decorator diff --git a/miniutils/pragma.py b/miniutils/pragma.py index 78c369c..90f59d5 100644 --- a/miniutils/pragma.py +++ b/miniutils/pragma.py @@ -77,7 +77,7 @@ def __delitem__(self, item): def __contains__(self, item): try: - self[item] + _ = self[item] return True except KeyError: return False @@ -216,10 +216,10 @@ def _resolve_name_or_attribute(node, ctxt): :rtype: * """ if isinstance(node, ast.Name): - if isinstance(ctxt, DictStack): + if node.id in ctxt: return ctxt[node.id] else: - return getattr(ctxt, node.id, node) + return node elif isinstance(node, ast.NameConstant): return node.value elif isinstance(node, ast.Attribute): @@ -339,29 +339,30 @@ def __collapse_literal(node, ctxt): return node else: return _collapse_map[node.op](operand) - elif isinstance(node, (ast.BinOp, ast.BoolOp)): - operands = [__collapse_literal(o, ctxt) for o in [node.left, node.right]] + elif isinstance(node, ast.BinOp): + left = __collapse_literal(node.left, ctxt) + right = __collapse_literal(node.right, ctxt) # print("({} {})".format(repr(node.op), ", ".join(repr(o) for o in operands))) - is_literal = [not isinstance(opr, ast.AST) for opr in operands] - if all(is_literal): + lliteral = not isinstance(left, ast.AST) + rliteral = not isinstance(right, ast.AST) + if lliteral and rliteral: try: - val = _collapse_map[type(node.op)](*operands) + val = _collapse_map[type(node.op)](left, right) return val except: warnings.warn( "Literal collapse failed. Collapsing skipped, but executing this function will likely fail." " Error was:\n{}".format(traceback.format_exc())) return node - elif any(is_literal): + elif lliteral or rliteral: try: - if isinstance(node, ast.UnaryOp): - return ast.UnaryOp(operand=_make_ast_from_literal(operands[0]), op=node.op) - else: - return type(node)(left=_make_ast_from_literal(operands[0]), - right=_make_ast_from_literal(operands[1]), - op=node.op) + print((left, _make_ast_from_literal(left))) + print((right, _make_ast_from_literal(right))) + return ast.BinOp(left=_make_ast_from_literal(left), + right=_make_ast_from_literal(right), + op=node.op) except (AssertionError, ContractNotRespected): - warnings.warn("Unable to re-pack {tp} with {ops}".format(tp=type(node), ops=operands)) + warnings.warn("Unable to re-pack {tp} with {l}, {r}".format(tp=type(node), l=left, r=right)) return node elif isinstance(node, ast.Compare): operands = [__collapse_literal(o, ctxt) for o in [node.left] + node.comparators] @@ -414,22 +415,22 @@ def __init__(self, ctxt=None): self.ctxt = ctxt or DictStack() super().__init__() - # def visit(self, node): - # orig_node = copy.deepcopy(node) - # new_node = super().visit(node) - # - # orig_node_code = astor.to_source(orig_node).strip() - # try: - # if new_node is None: - # print("Deleted >>> {} <<<".format(orig_node_code)) - # elif isinstance(new_node, ast.AST): - # print("Converted >>> {} <<< to >>> {} <<<".format(orig_node_code, astor.to_source(new_node).strip())) - # elif isinstance(new_node, list): - # print("Converted >>> {} <<< to [[[ {} ]]]".format(orig_node_code, ", ".join(astor.to_source(n).strip() for n in new_node))) - # except AssertionError as ex: - # raise AssertionError("Failed on {} >>> {}".format(orig_node_code, astor.dump_tree(new_node))) from ex - # - # return new_node + def visit(self, node): + orig_node = copy.deepcopy(node) + new_node = super().visit(node) + + orig_node_code = astor.to_source(orig_node).strip() + try: + if new_node is None: + print("Deleted >>> {} <<<".format(orig_node_code)) + elif isinstance(new_node, ast.AST): + print("Converted >>> {} <<< to >>> {} <<<".format(orig_node_code, astor.to_source(new_node).strip())) + elif isinstance(new_node, list): + print("Converted >>> {} <<< to [[[ {} ]]]".format(orig_node_code, ", ".join(astor.to_source(n).strip() for n in new_node))) + except AssertionError as ex: + raise AssertionError("Failed on {} >>> {}".format(orig_node_code, astor.dump_tree(new_node))) from ex + + return new_node def visit_Assign(self, node): node.value = self.visit(node.value) From 058d7d44509593614330c301730f89e400ce9458 Mon Sep 17 00:00:00 2001 From: scnerd Date: Mon, 4 Dec 2017 21:49:05 -0500 Subject: [PATCH 4/6] A bit more bug fixing, still haven't found the main issue causing op collapsing to fail --- miniutils/pragma.py | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/miniutils/pragma.py b/miniutils/pragma.py index 90f59d5..9135fac 100644 --- a/miniutils/pragma.py +++ b/miniutils/pragma.py @@ -322,23 +322,31 @@ def __collapse_literal(node, ctxt): elif isinstance(node, (ast.Slice, ast.ExtSlice)): raise NotImplemented() elif isinstance(node, ast.Subscript): - # print("Attempting to subscript {}".format(astor.to_source(node))) + print("Attempting to subscript {}".format(astor.to_source(node))) lst = _constant_iterable(node.value, ctxt) - # print("Can I subscript {}?".format(lst)) + print("Can I subscript {}?".format(lst)) if lst is None: return node slc = __collapse_literal(node.slice, ctxt) - # print("Getting subscript at {}".format(slc)) + print("Getting subscript at {}".format(slc)) if isinstance(slc, ast.AST): return node - # print("Value at {}[{}] = {}".format(lst, slc, lst[slc])) - return lst[slc] + print("Value at {}[{}] = {}".format(lst, slc, lst[slc])) + val = lst[slc] + if isinstance(val, ast.AST): + val = __collapse_literal(val, ctxt) + return val elif isinstance(node, ast.UnaryOp): operand = __collapse_literal(node.operand, ctxt) if isinstance(operand, ast.AST): return node else: - return _collapse_map[node.op](operand) + try: + return _collapse_map[node.op](operand) + except: + warnings.warn( + "Unary op collapse failed. Collapsing skipped, but executing this function will likely fail." + " Error was:\n{}".format(traceback.format_exc())) elif isinstance(node, ast.BinOp): left = __collapse_literal(node.left, ctxt) right = __collapse_literal(node.right, ctxt) @@ -347,23 +355,24 @@ def __collapse_literal(node, ctxt): rliteral = not isinstance(right, ast.AST) if lliteral and rliteral: try: - val = _collapse_map[type(node.op)](left, right) - return val + return _collapse_map[type(node.op)](left, right) except: warnings.warn( - "Literal collapse failed. Collapsing skipped, but executing this function will likely fail." + "Binary op collapse failed. Collapsing skipped, but executing this function will likely fail." " Error was:\n{}".format(traceback.format_exc())) return node elif lliteral or rliteral: try: - print((left, _make_ast_from_literal(left))) - print((right, _make_ast_from_literal(right))) + print(('left', left, _make_ast_from_literal(left))) + print(('right', right, _make_ast_from_literal(right))) return ast.BinOp(left=_make_ast_from_literal(left), right=_make_ast_from_literal(right), op=node.op) except (AssertionError, ContractNotRespected): warnings.warn("Unable to re-pack {tp} with {l}, {r}".format(tp=type(node), l=left, r=right)) return node + else: + return node elif isinstance(node, ast.Compare): operands = [__collapse_literal(o, ctxt) for o in [node.left] + node.comparators] if all(not isinstance(opr, ast.AST) for opr in operands): @@ -386,7 +395,10 @@ def _collapse_literal(node, ctxt): :return: The given AST node with literal operations collapsed as much as possible :rtype: AST """ - return _make_ast_from_literal(__collapse_literal(node, ctxt)) + result = _make_ast_from_literal(__collapse_literal(node, ctxt)) + if not isinstance(result, ast.AST): + return node + return result @magic_contract From f0c56337e5bf496126649f77932f9630903b54e5 Mon Sep 17 00:00:00 2001 From: scnerd Date: Tue, 5 Dec 2017 14:18:43 -0500 Subject: [PATCH 5/6] Fixed pragma errors, contracts are mostly implemented and working with tests. --- .travis.yml | 2 +- miniutils/__init__.py | 3 - miniutils/magic_contract.py | 2 +- miniutils/pragma.py | 122 ++++++++++++++++++++++++++--------- tests/test_magic_contract.py | 8 +++ tests/test_pragma.py | 53 +++++++++++++-- 6 files changed, 148 insertions(+), 42 deletions(-) diff --git a/.travis.yml b/.travis.yml index f1206fc..b1f4a2d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,7 +7,7 @@ install: - pip install . - pip install -r requirements.txt -script: nosetests --with-coverage tests +script: nosetests --with-coverage --nologcapture tests after_success: coveralls diff --git a/miniutils/__init__.py b/miniutils/__init__.py index baaae4c..0e2b7dd 100644 --- a/miniutils/__init__.py +++ b/miniutils/__init__.py @@ -1,6 +1,3 @@ -import contracts -contracts.disable_all() - from .caching import CachedProperty from .magic_contract import magic_contract from .opt_decorator import optional_argument_decorator diff --git a/miniutils/magic_contract.py b/miniutils/magic_contract.py index 65e7139..0ee497c 100644 --- a/miniutils/magic_contract.py +++ b/miniutils/magic_contract.py @@ -1,4 +1,4 @@ -from contracts import contract, new_contract +from contracts import * from contracts.library import Extension as _Ext from miniutils.opt_decorator import optional_argument_decorator diff --git a/miniutils/pragma.py b/miniutils/pragma.py index 9135fac..0920754 100644 --- a/miniutils/pragma.py +++ b/miniutils/pragma.py @@ -58,6 +58,7 @@ def __init__(self, *base): self.constants = [True] + [False] * len(base) def __setitem__(self, key, value): + print("SETTING {} = {}".format(key, value)) self.dicts[-1][key] = value def __getitem__(self, item): @@ -270,11 +271,13 @@ def _make_ast_from_literal(lit): """ Converts literals into their AST equivalent :param lit: The literal to attempt to turn into an AST - :type lit: literal|AST + :type lit: * :return: The AST version of the literal, or the original AST node if one was given - :rtype: AST + :rtype: * """ - if isinstance(lit, (list, tuple)): + if isinstance(lit, ast.AST): + return lit + elif isinstance(lit, (list, tuple)): res = [_make_ast_from_literal(e) for e in lit] tp = ast.List if isinstance(lit, list) else ast.Tuple return tp(elts=res) @@ -290,11 +293,21 @@ def _make_ast_from_literal(lit): return ast.Str(lit) elif isinstance(lit, bool): return ast.NameConstant(lit) - # elif isinstance(lit, ast.AST): - # return lit - # else: - # raise AssertionError("'{}' of type {} is not able to be made into an AST node".format(lit, type(lit))) - return lit + else: + # warnings.warn("'{}' of type {} is not able to be made into an AST node".format(lit, type(lit))) + return lit + + +@magic_contract +def _is_wrappable(lit): + """ + Checks if the given object either is or can be made into a known AST node + :param lit: The object to try to wrap + :type lit: * + :return: Whether or not this object can be wrapped as an AST node + :rtype: bool + """ + return isinstance(_make_ast_from_literal(lit), ast.AST) @magic_contract @@ -306,12 +319,22 @@ def __collapse_literal(node, ctxt): :param ctxt: The environment stack to use when running the check :type ctxt: DictStack :return: The given AST node with literal operations collapsed as much as possible - :rtype: literal|AST + :rtype: * """ + try: + print("Trying to collapse {}".format(astor.to_source(node))) + except: + print("Trying to collapse (source not possible) {}".format(astor.dump_tree(node))) + if isinstance(node, (ast.Name, ast.Attribute, ast.NameConstant)): res = _resolve_name_or_attribute(node, ctxt) if isinstance(res, ast.AST) and not isinstance(res, (ast.Name, ast.Attribute, ast.NameConstant)): - res = __collapse_literal(res, ctxt) + new_res = __collapse_literal(res, ctxt) + if _is_wrappable(new_res): + print("{} can be replaced by more specific literal {}".format(res, new_res)) + res = new_res + else: + print("{} is an AST node, but can't safely be made more specific".format(res)) return res elif isinstance(node, ast.Num): return node.n @@ -322,19 +345,25 @@ def __collapse_literal(node, ctxt): elif isinstance(node, (ast.Slice, ast.ExtSlice)): raise NotImplemented() elif isinstance(node, ast.Subscript): - print("Attempting to subscript {}".format(astor.to_source(node))) + # print("Attempting to subscript {}".format(astor.to_source(node))) lst = _constant_iterable(node.value, ctxt) - print("Can I subscript {}?".format(lst)) + # print("Can I subscript {}?".format(lst)) if lst is None: return node slc = __collapse_literal(node.slice, ctxt) - print("Getting subscript at {}".format(slc)) + # print("Getting subscript at {}".format(slc)) if isinstance(slc, ast.AST): return node - print("Value at {}[{}] = {}".format(lst, slc, lst[slc])) + # print("Value at {}[{}] = {}".format(lst, slc, lst[slc])) val = lst[slc] if isinstance(val, ast.AST): - val = __collapse_literal(val, ctxt) + new_val = __collapse_literal(val, ctxt) + if _is_wrappable(new_val): + print("{} can be replaced by more specific literal {}".format(val, new_val)) + val = new_val + else: + print("{} is an AST node, but can't safely be made more specific".format(val)) + print("Final value at {}[{}] = {}".format(lst, slc, val)) return val elif isinstance(node, ast.UnaryOp): operand = __collapse_literal(node.operand, ctxt) @@ -354,6 +383,7 @@ def __collapse_literal(node, ctxt): lliteral = not isinstance(left, ast.AST) rliteral = not isinstance(right, ast.AST) if lliteral and rliteral: + print("Both operands {} and {} are literals, attempting to collapse".format(left, right)) try: return _collapse_map[type(node.op)](left, right) except: @@ -361,18 +391,14 @@ def __collapse_literal(node, ctxt): "Binary op collapse failed. Collapsing skipped, but executing this function will likely fail." " Error was:\n{}".format(traceback.format_exc())) return node - elif lliteral or rliteral: - try: - print(('left', left, _make_ast_from_literal(left))) - print(('right', right, _make_ast_from_literal(right))) - return ast.BinOp(left=_make_ast_from_literal(left), - right=_make_ast_from_literal(right), - op=node.op) - except (AssertionError, ContractNotRespected): - warnings.warn("Unable to re-pack {tp} with {l}, {r}".format(tp=type(node), l=left, r=right)) - return node else: - return node + left = _make_ast_from_literal(left) + left = left if isinstance(left, ast.AST) else node.left + + right = _make_ast_from_literal(right) + right = right if isinstance(right, ast.AST) else node.right + print("Attempting to combine {} and {} ({} op)".format(left, right, node.op)) + return ast.BinOp(left=left, right=right, op=node.op) elif isinstance(node, ast.Compare): operands = [__collapse_literal(o, ctxt) for o in [node.left] + node.comparators] if all(not isinstance(opr, ast.AST) for opr in operands): @@ -385,7 +411,7 @@ def __collapse_literal(node, ctxt): @magic_contract -def _collapse_literal(node, ctxt): +def _collapse_literal(node, ctxt, give_raw_result=False): """ Collapse literal expressions in the given node. Returns the node with the collapsed literals :param node: The AST node to be checked @@ -393,9 +419,12 @@ def _collapse_literal(node, ctxt): :param ctxt: The environment stack to use when running the check :type ctxt: DictStack :return: The given AST node with literal operations collapsed as much as possible - :rtype: AST + :rtype: * """ - result = _make_ast_from_literal(__collapse_literal(node, ctxt)) + result = __collapse_literal(node, ctxt) + if give_raw_result: + return result + result = _make_ast_from_literal(result) if not isinstance(result, ast.AST): return node return result @@ -419,6 +448,8 @@ def _assign_names(node): elif isinstance(node, ast.Tuple): for e in node.elts: yield from _assign_names(e) + elif isinstance(node, ast.Subscript): + raise NotImplemented() # noinspection PyPep8Naming @@ -431,16 +462,18 @@ def visit(self, node): orig_node = copy.deepcopy(node) new_node = super().visit(node) - orig_node_code = astor.to_source(orig_node).strip() try: + orig_node_code = astor.to_source(orig_node).strip() if new_node is None: print("Deleted >>> {} <<<".format(orig_node_code)) elif isinstance(new_node, ast.AST): print("Converted >>> {} <<< to >>> {} <<<".format(orig_node_code, astor.to_source(new_node).strip())) elif isinstance(new_node, list): print("Converted >>> {} <<< to [[[ {} ]]]".format(orig_node_code, ", ".join(astor.to_source(n).strip() for n in new_node))) - except AssertionError as ex: - raise AssertionError("Failed on {} >>> {}".format(orig_node_code, astor.dump_tree(new_node))) from ex + except Exception as ex: + raise AssertionError("Failed on {} >>> {}".format(astor.dump_tree(orig_node), astor.dump_tree(new_node))) from ex + # print("Failed on {} >>> {}".format(astor.dump_tree(orig_node), astor.dump_tree(new_node))) + # return orig_node return new_node @@ -465,6 +498,11 @@ def visit_Assign(self, node): self.ctxt[assgn] = None return node + def visit_AugAssign(self, node): + for assgn in _assign_names(node.target): + self.ctxt[assgn] = None + return super().generic_visit(node) + # noinspection PyPep8Naming class UnrollTransformer(TrackedContextTransformer): @@ -506,6 +544,9 @@ def visit_Name(self, node): # noinspection PyPep8Naming class CollapseTransformer(TrackedContextTransformer): + def visit_Name(self, node): + return _collapse_literal(node, self.ctxt) + def visit_BinOp(self, node): return _collapse_literal(node, self.ctxt) @@ -521,6 +562,19 @@ def visit_Compare(self, node): def visit_Subscript(self, node): return _collapse_literal(node, self.ctxt) + def visit_If(self, node): + cond = _collapse_literal(node.test, self.ctxt, True) + print("Attempting to collapse IF conditioned on {}".format(cond)) + if not isinstance(cond, ast.AST): + print("Yes, this IF can be consolidated, condition is {}".format(bool(cond))) + if cond: + return [self.visit(b) for b in node.body] + else: + return [self.visit(b) for b in node.orelse] + else: + print("No, this IF cannot be consolidated") + return super().generic_visit(node) + def _make_function_transformer(transformer_type, name, description): @optional_argument_decorator @@ -624,3 +678,7 @@ def deindex(iterable, iterable_name, *args, **kwargs): return collapse_literals(*args, function_globals=mapping, **kwargs) # Inline functions? +# You could do something like: +# args, kwargs = (args_in), (kwargs_in) +# function_body +# result = return_expr diff --git a/tests/test_magic_contract.py b/tests/test_magic_contract.py index 2793fa5..fcb5a1d 100644 --- a/tests/test_magic_contract.py +++ b/tests/test_magic_contract.py @@ -1,5 +1,8 @@ from unittest import TestCase + +import contracts +contracts.enable_all() import functools from miniutils.magic_contract import magic_contract from contracts.interface import ContractNotRespected @@ -15,6 +18,7 @@ def fib(n): :return: The fibonnaci number :rtype: int,>=0 """ + assert n >= 0; if n == 0: return 0 elif n == 1: @@ -44,6 +48,10 @@ def sample_func(a): class TestMagicContract(TestCase): + def setUp(self): + import contracts + contracts.enable_all() + def test_magic_contract_1(self): fib(3) fib(5) diff --git a/tests/test_pragma.py b/tests/test_pragma.py index 1609259..27e8b81 100644 --- a/tests/test_pragma.py +++ b/tests/test_pragma.py @@ -4,7 +4,15 @@ import inspect -class TestUnroll(TestCase): +class PragmaTest(TestCase): + def setUp(self): + pass + # # This is a quick hack to disable contracts for testing if needed + # import contracts + # contracts.enable_all() + + +class TestUnroll(PragmaTest): def test_unroll_range(self): @pragma.unroll def f(): @@ -251,7 +259,7 @@ def f(): self.assertEqual(f.strip(), result.strip()) -class TestCollapseLiterals(TestCase): +class TestCollapseLiterals(PragmaTest): def test_full_run(self): def f(y): x = 3 @@ -371,7 +379,42 @@ def f(): self.assertTrue(issubclass(w[-1].category, UserWarning)) -class TestDeindex(TestCase): + # TODO: implement the features to get this test to work + # def test_conditional_erasure(self): + # @pragma.collapse_literals(return_source=True) + # def f(y): + # x = 0 + # if y == x: + # x = 1 + # return x + # + # result = dedent(''' + # def f(y): + # x = 0 + # if y == 0: + # x = 1 + # return x + # ''') + # self.assertEqual(f.strip(), result.strip()) + + def test_constant_conditional_erasure(self): + @pragma.collapse_literals(return_source=True) + def f(y): + x = 0 + if x <= 0: + x = 1 + return x + + result = dedent(''' + def f(y): + x = 0 + x = 1 + return 1 + ''') + self.assertEqual(f.strip(), result.strip()) + + +class TestDeindex(PragmaTest): def test_with_literals(self): v = [1, 2, 3] @pragma.collapse_literals(return_source=True) @@ -395,7 +438,7 @@ def f(): def f(): return v_0 + v_1 + v_2 ''') - self.assertEqual(f.strip(), result.strip()) + self.assertEqual(result.strip(), f.strip()) def test_with_unroll(self): v = [None, None, None] @@ -476,7 +519,7 @@ def run_func(i, x): self.assertEqual(inspect.getsource(run_func).strip(), result.strip()) -class TestDictStack(TestCase): +class TestDictStack(PragmaTest): def test_most(self): stack = pragma.DictStack() stack.push({'x': 3}) From 16dc6924b4de33fb32ac6bfdc0fe9fce6fa6913d Mon Sep 17 00:00:00 2001 From: scnerd Date: Tue, 5 Dec 2017 14:33:40 -0500 Subject: [PATCH 6/6] More bug fixes, removed debug printouts --- miniutils/pragma.py | 90 ++++++++++++++++++++++++-------------------- tests/test_pragma.py | 34 ++++++++++++++--- 2 files changed, 77 insertions(+), 47 deletions(-) diff --git a/miniutils/pragma.py b/miniutils/pragma.py index 0920754..0821b4e 100644 --- a/miniutils/pragma.py +++ b/miniutils/pragma.py @@ -58,7 +58,7 @@ def __init__(self, *base): self.constants = [True] + [False] * len(base) def __setitem__(self, key, value): - print("SETTING {} = {}".format(key, value)) + # print("SETTING {} = {}".format(key, value)) self.dicts[-1][key] = value def __getitem__(self, item): @@ -321,20 +321,20 @@ def __collapse_literal(node, ctxt): :return: The given AST node with literal operations collapsed as much as possible :rtype: * """ - try: - print("Trying to collapse {}".format(astor.to_source(node))) - except: - print("Trying to collapse (source not possible) {}".format(astor.dump_tree(node))) + # try: + # print("Trying to collapse {}".format(astor.to_source(node))) + # except: + # print("Trying to collapse (source not possible) {}".format(astor.dump_tree(node))) if isinstance(node, (ast.Name, ast.Attribute, ast.NameConstant)): res = _resolve_name_or_attribute(node, ctxt) if isinstance(res, ast.AST) and not isinstance(res, (ast.Name, ast.Attribute, ast.NameConstant)): new_res = __collapse_literal(res, ctxt) if _is_wrappable(new_res): - print("{} can be replaced by more specific literal {}".format(res, new_res)) + # print("{} can be replaced by more specific literal {}".format(res, new_res)) res = new_res - else: - print("{} is an AST node, but can't safely be made more specific".format(res)) + # else: + # print("{} is an AST node, but can't safely be made more specific".format(res)) return res elif isinstance(node, ast.Num): return node.n @@ -359,11 +359,11 @@ def __collapse_literal(node, ctxt): if isinstance(val, ast.AST): new_val = __collapse_literal(val, ctxt) if _is_wrappable(new_val): - print("{} can be replaced by more specific literal {}".format(val, new_val)) + # print("{} can be replaced by more specific literal {}".format(val, new_val)) val = new_val - else: - print("{} is an AST node, but can't safely be made more specific".format(val)) - print("Final value at {}[{}] = {}".format(lst, slc, val)) + # else: + # print("{} is an AST node, but can't safely be made more specific".format(val)) + # print("Final value at {}[{}] = {}".format(lst, slc, val)) return val elif isinstance(node, ast.UnaryOp): operand = __collapse_literal(node.operand, ctxt) @@ -383,7 +383,7 @@ def __collapse_literal(node, ctxt): lliteral = not isinstance(left, ast.AST) rliteral = not isinstance(right, ast.AST) if lliteral and rliteral: - print("Both operands {} and {} are literals, attempting to collapse".format(left, right)) + # print("Both operands {} and {} are literals, attempting to collapse".format(left, right)) try: return _collapse_map[type(node.op)](left, right) except: @@ -397,7 +397,7 @@ def __collapse_literal(node, ctxt): right = _make_ast_from_literal(right) right = right if isinstance(right, ast.AST) else node.right - print("Attempting to combine {} and {} ({} op)".format(left, right, node.op)) + # print("Attempting to combine {} and {} ({} op)".format(left, right, node.op)) return ast.BinOp(left=left, right=right, op=node.op) elif isinstance(node, ast.Compare): operands = [__collapse_literal(o, ctxt) for o in [node.left] + node.comparators] @@ -458,24 +458,24 @@ def __init__(self, ctxt=None): self.ctxt = ctxt or DictStack() super().__init__() - def visit(self, node): - orig_node = copy.deepcopy(node) - new_node = super().visit(node) - - try: - orig_node_code = astor.to_source(orig_node).strip() - if new_node is None: - print("Deleted >>> {} <<<".format(orig_node_code)) - elif isinstance(new_node, ast.AST): - print("Converted >>> {} <<< to >>> {} <<<".format(orig_node_code, astor.to_source(new_node).strip())) - elif isinstance(new_node, list): - print("Converted >>> {} <<< to [[[ {} ]]]".format(orig_node_code, ", ".join(astor.to_source(n).strip() for n in new_node))) - except Exception as ex: - raise AssertionError("Failed on {} >>> {}".format(astor.dump_tree(orig_node), astor.dump_tree(new_node))) from ex - # print("Failed on {} >>> {}".format(astor.dump_tree(orig_node), astor.dump_tree(new_node))) - # return orig_node - - return new_node + # def visit(self, node): + # orig_node = copy.deepcopy(node) + # new_node = super().visit(node) + # + # try: + # orig_node_code = astor.to_source(orig_node).strip() + # if new_node is None: + # print("Deleted >>> {} <<<".format(orig_node_code)) + # elif isinstance(new_node, ast.AST): + # print("Converted >>> {} <<< to >>> {} <<<".format(orig_node_code, astor.to_source(new_node).strip())) + # elif isinstance(new_node, list): + # print("Converted >>> {} <<< to [[[ {} ]]]".format(orig_node_code, ", ".join(astor.to_source(n).strip() for n in new_node))) + # except Exception as ex: + # raise AssertionError("Failed on {} >>> {}".format(astor.dump_tree(orig_node), astor.dump_tree(new_node))) from ex + # # print("Failed on {} >>> {}".format(astor.dump_tree(orig_node), astor.dump_tree(new_node))) + # # return orig_node + # + # return new_node def visit_Assign(self, node): node.value = self.visit(node.value) @@ -564,15 +564,22 @@ def visit_Subscript(self, node): def visit_If(self, node): cond = _collapse_literal(node.test, self.ctxt, True) - print("Attempting to collapse IF conditioned on {}".format(cond)) + # print("Attempting to collapse IF conditioned on {}".format(cond)) if not isinstance(cond, ast.AST): - print("Yes, this IF can be consolidated, condition is {}".format(bool(cond))) - if cond: - return [self.visit(b) for b in node.body] - else: - return [self.visit(b) for b in node.orelse] + # print("Yes, this IF can be consolidated, condition is {}".format(bool(cond))) + body = node.body if cond else node.orelse + result = [] + for subnode in body: + res = self.visit(subnode) + if res is None: + pass + elif isinstance(res, list): + result += res + else: + result.append(res) + return result else: - print("No, this IF cannot be consolidated") + # print("No, this IF cannot be consolidated") return super().generic_visit(node) @@ -652,14 +659,15 @@ def inner(f): # Directly reference elements of constant list, removing literal indexing into that list within a function +@magic_contract def deindex(iterable, iterable_name, *args, **kwargs): """ :param iterable: The list to deindex in the target function :type iterable: iterable :param iterable_name: The list's name (must be unique if deindexing multiple lists) :type iterable_name: str - :param return_source: Returns the unrolled function's source code instead of compiling it - :param save_source: Saves the function source code to a tempfile to make it inspectable + :param args: Other command line arguments (see :func:`unroll` for documentation) + :type args: tuple :param kwargs: Any other environmental variables to provide during unrolling :type kwargs: dict :return: The unrolled function, or its source code if requested diff --git a/tests/test_pragma.py b/tests/test_pragma.py index 27e8b81..ae6d561 100644 --- a/tests/test_pragma.py +++ b/tests/test_pragma.py @@ -273,16 +273,12 @@ def f(y): return r import inspect - print(inspect.getsource(f)) - print(pragma.collapse_literals(return_source=True)(f)) deco_f = pragma.collapse_literals(f) self.assertEqual(f(0), deco_f(0)) self.assertEqual(f(1), deco_f(1)) self.assertEqual(f(5), deco_f(5)) self.assertEqual(f(-1), deco_f(-1)) - print(inspect.getsource(f)) - print(pragma.collapse_literals(return_source=True)(pragma.unroll(f))) deco_f = pragma.collapse_literals(pragma.unroll(f)) self.assertEqual(f(0), deco_f(0)) self.assertEqual(f(1), deco_f(1)) @@ -413,6 +409,34 @@ def f(y): ''') self.assertEqual(f.strip(), result.strip()) + def fn(): + if x == 0: + x = 'a' + elif x == 1: + x = 'b' + else: + x = 'c' + return x + + result0 = dedent(''' + def fn(): + x = 'a' + return 'a' + ''') + result1 = dedent(''' + def fn(): + x = 'b' + return 'b' + ''') + result2 = dedent(''' + def fn(): + x = 'c' + return 'c' + ''') + self.assertEqual(pragma.collapse_literals(return_source=True, x=0)(fn).strip(), result0.strip()) + self.assertEqual(pragma.collapse_literals(return_source=True, x=1)(fn).strip(), result1.strip()) + self.assertEqual(pragma.collapse_literals(return_source=True, x=2)(fn).strip(), result2.strip()) + class TestDeindex(PragmaTest): def test_with_literals(self): @@ -501,8 +525,6 @@ def run_func(i, x): if i == j: return funcs[j](x) - print(inspect.getsource(run_func)) - self.assertEqual(run_func(0, 5), 5) self.assertEqual(run_func(1, 5), 25) self.assertEqual(run_func(2, 5), 125)