From 7b4eb9545bc9472b2ce02ce854691335f5eb7966 Mon Sep 17 00:00:00 2001 From: scnerd Date: Fri, 17 Nov 2017 19:25:00 -0500 Subject: [PATCH 1/5] Most of literal collapsing implemented --- docs/source/misc.rst | 5 + miniutils/opt_decorator.py | 7 +- miniutils/pragma.py | 320 ++++++++++++++++++++++++++++++------- tests/test_pragma.py | 122 +++++++++++++- 4 files changed, 386 insertions(+), 68 deletions(-) diff --git a/docs/source/misc.rst b/docs/source/misc.rst index f2f8d80..51ff136 100644 --- a/docs/source/misc.rst +++ b/docs/source/misc.rst @@ -348,3 +348,8 @@ Currently not-yet-supported features include: - ``zip``, ``reversed``, and other known operators, when performed on definition-time constant iterables .. autofunction:: miniutils.pragma.unroll + +Collapse Literals +----------------- + + diff --git a/miniutils/opt_decorator.py b/miniutils/opt_decorator.py index 2fdaa1c..72949d8 100644 --- a/miniutils/opt_decorator.py +++ b/miniutils/opt_decorator.py @@ -15,12 +15,9 @@ def inner_decorator_make(*args, **kwargs): decorator = _decorator(*args, **kwargs) - def inner_decorator_maker(_func): - return decorator(_func) - if func: - return inner_decorator_maker(func) + return decorator(func) else: - return inner_decorator_maker + return decorator return inner_decorator_make diff --git a/miniutils/pragma.py b/miniutils/pragma.py index 0e8a36b..fc2c0ab 100644 --- a/miniutils/pragma.py +++ b/miniutils/pragma.py @@ -1,7 +1,7 @@ import ast, inspect, sys, copy from miniutils.opt_decorator import optional_argument_decorator import textwrap -import astor +import warnings, traceback, tempfile class DictStack: @@ -16,6 +16,8 @@ def __setitem__(self, key, value): def __getitem__(self, item): for dct in self.dicts[::-1]: if item in dct: + if dct[item] is None: + raise KeyError("Found '{}', but it was set to an unknown value".format(item)) return dct[item] raise KeyError("Can't find '{}' anywhere in the function's context".format(item)) @@ -27,7 +29,11 @@ def __delitem__(self, item): 
raise KeyError() def __contains__(self, item): - return any(item in dct for dct in self.dicts[::-1]) + try: + self[item] + return True + except: + return False def items(self): items = [] @@ -62,17 +68,23 @@ def _function_ast(f): def _constant_iterable(node, ctxt): + # TODO: Support zipping + # TODO: Support sets/dicts? # Check for range(*constants) if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and ctxt[node.func.id] == range and all( isinstance(arg, ast.Num) for arg in node.args): return [ast.Num(n) for n in range(*[arg.n for arg in node.args])] elif isinstance(node, (ast.List, ast.Tuple)): - return [resolve_name_or_attribute(e, ctxt) for e in node.elts] + return [_resolve_name_or_attribute(e, ctxt) for e in node.elts] # Can't yet support sets and lists, since you need to compute what the unique values would be # elif isinstance(node, ast.Dict): # return node.keys - elif isinstance(node, (ast.Name, ast.Attribute)): - res = resolve_name_or_attribute(node, ctxt) + elif isinstance(node, (ast.Name, ast.Attribute, ast.NameConstant)): + res = _resolve_name_or_attribute(node, ctxt) + import astor + #print("Trying to resolve '{}' as list, got {}".format(astor.to_source(node), res)) + if isinstance(res, ast.AST) and not isinstance(res, (ast.Name, ast.Attribute, ast.NameConstant)): + res = _constant_iterable(res, ctxt) if not isinstance(res, ast.AST): try: iter(res) @@ -82,17 +94,20 @@ def _constant_iterable(node, ctxt): return None -def resolve_name_or_attribute(node, ctxt_or_obj): +def _resolve_name_or_attribute(node, ctxt_or_obj): if isinstance(node, ast.Name): if isinstance(ctxt_or_obj, DictStack): if node.id in ctxt_or_obj: + #print("Resolved '{}' to {}".format(node.id, ctxt_or_obj[node.id])) return ctxt_or_obj[node.id] else: return node else: return getattr(ctxt_or_obj, node.id, node) + elif isinstance(node, ast.NameConstant): + return node.value elif isinstance(node, ast.Attribute): - base_obj = resolve_name_or_attribute(node.value, ctxt_or_obj) + 
base_obj = _resolve_name_or_attribute(node.value, ctxt_or_obj) if not isinstance(base_obj, ast.AST): return getattr(base_obj, node.attr, node) else: @@ -101,12 +116,185 @@ def resolve_name_or_attribute(node, ctxt_or_obj): return node -class UnrollTransformer(ast.NodeTransformer): +# slice = Slice(expr? lower, expr? upper, expr? step) +# | ExtSlice(slice* dims) +# | Index(expr value) +# +# boolop = And | Or +# +# operator = Add | Sub | Mult | MatMult | Div | Mod | Pow | LShift +# | RShift | BitOr | BitXor | BitAnd | FloorDiv +# +# unaryop = Invert | Not | UAdd | USub +# +# cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot | In | NotIn + +_collapse_map = { + ast.Add: lambda a, b: a + b, + ast.Sub: lambda a, b: a - b, + ast.Mult: lambda a, b: a * b, + ast.Div: lambda a, b: a / b, + ast.FloorDiv: lambda a, b: a // b, + + ast.Mod: lambda a, b: a % b, + ast.Pow: lambda a, b: a ** b, + ast.LShift: lambda a, b: a << b, + ast.RShift: lambda a, b: a >> b, + ast.MatMult: lambda a, b: a @ b, + + ast.BitAnd: lambda a, b: a & b, + ast.BitOr: lambda a, b: a | b, + ast.BitXor: lambda a, b: a ^ b, + ast.And: lambda a, b: a and b, + ast.Or: lambda a, b: a or b, + ast.Invert: lambda a: ~a, + ast.Not: lambda a: not a, + + ast.UAdd: lambda a: a, + ast.USub: lambda a: -a, + + ast.Eq: lambda a, b: a == b, + ast.NotEq: lambda a, b: a != b, + ast.Lt: lambda a, b: a < b, + ast.LtE: lambda a, b: a <= b, + ast.Gt: lambda a, b: a > b, + ast.GtE: lambda a, b: a >= b, +} + + +def _make_ast_from_literal(lit): + if isinstance(lit, (list, tuple)): + res = [_make_ast_from_literal(e) for e in lit] + tp = ast.List if isinstance(lit, list) else ast.Tuple + return tp(elts=res) + elif isinstance(lit, (int, float)): + return ast.Num(lit) + elif isinstance(lit, str): + return ast.Str(lit) + elif isinstance(lit, bool): + return ast.NameConstant(lit) + else: + return lit + + +def __collapse_literal(node, ctxt): + if isinstance(node, (ast.Name, ast.Attribute, ast.NameConstant)): + res = 
_resolve_name_or_attribute(node, ctxt) + if isinstance(res, ast.AST) and not isinstance(res, (ast.Name, ast.Attribute, ast.NameConstant)): + res = __collapse_literal(res, ctxt) + return res + elif isinstance(node, ast.Num): + return node.n + elif isinstance(node, ast.Str): + return node.s + elif isinstance(node, ast.Index): + return __collapse_literal(node.value) + elif isinstance(node, ast.Subscript): + print("SUBSCRIPT") + lst = _constant_iterable(node.value) + print(lst) + if lst is None: + return node + slc = __collapse_literal(node.slice) + print(slc) + if isinstance(slc, ast.AST): + return node + return lst[slc] + elif isinstance(node, (ast.UnaryOp, ast.BinOp, ast.BoolOp)): + if isinstance(node, ast.UnaryOp): + operands = [__collapse_literal(node.operand, ctxt)] + else: + operands = [__collapse_literal(o, ctxt) for o in [node.left, node.right]] + #print("({} {})".format(repr(node.op), ", ".join(repr(o) for o in operands))) + is_literal = [not isinstance(opr, ast.AST) for opr in operands] + if all(is_literal): + try: + val = _collapse_map[type(node.op)](*operands) + return val + except: + warnings.warn("Literal collapse failed. Collapsing skipped, but executing this function will likely fail." + " Error was:\n{}".format(traceback.format_exc())) + return node + else: + if any(is_literal): + # Note that we know that it wasn't a unary op, else it would've succeded... 
so it was a binary op + return type(node)(left=_make_ast_from_literal(operands[0]), + right=_make_ast_from_literal(operands[1]), + op=node.op) + return node + elif isinstance(node, ast.Compare): + operands = [__collapse_literal(o, ctxt) for o in [node.left] + node.comparators] + if all(not isinstance(opr, ast.AST) for opr in operands): + return all(_collapse_map[type(cmp_func)](operands[i-1], operands[i]) + for i, cmp_func in zip(range(1, len(operands)), node.ops)) + else: + return node + else: + return node + + +def _collapse_literal(node, ctxt): + return _make_ast_from_literal(__collapse_literal(node, ctxt)) + + +def _assign_names(node): + if isinstance(node, ast.Name): + yield node.id + elif isinstance(node, ast.Tuple): + yield from [_assign_names(e) for e in node.elts] + else: + pass + + +# noinspection PyPep8Naming +class TrackedContextTransformer(ast.NodeTransformer): def __init__(self, ctxt=None, *args, **kwargs): self.ctxt = ctxt or DictStack() - self.loop_vars = set() super().__init__(*args, **kwargs) + def visit(self, node): + import astor + orig_node = node + new_node = super().visit(node) + #print("Converted >>> {} <<< to >>> {} <<<".format(astor.to_source(orig_node).strip(), + # astor.to_source(new_node).strip())) + return new_node + + def visit_Assign(self, node): + node.value = self.visit(node.value) + #print(node.value) + # TODO: Support tuple assignments + if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): + var = node.targets[0].id + val = _constant_iterable(node.value, self.ctxt) + if val is not None: + #print("Setting {} = {}".format(var, val)) + self.ctxt[var] = val + else: + val = _collapse_literal(node.value, self.ctxt) + #print("Setting {} = {}".format(var, val)) + self.ctxt[var] = val + else: + for assgn in _assign_names(node.targets): + self.ctxt[assgn] = None + return node + + # def visit_AugAssign(self, node): + # if isinstance(node.target, ast.Name): + # self.ctxt.push({node.target.id: None}) + # res = 
self.visit(ast.Assign(targets=[node.target], + # value=ast.BinOp(left=ast.Name(node.target.id, ast.Load()), + # right=node.value, op=node.op))) + # self.ctxt.pop() + # return res + + +# noinspection PyPep8Naming +class UnrollTransformer(TrackedContextTransformer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.loop_vars = set() + def visit_For(self, node): result = [node] iterable = _constant_iterable(node.iter, self.ctxt) @@ -114,7 +302,7 @@ def visit_For(self, node): result = [] loop_var = node.target.id orig_loop_vars = self.loop_vars - #print("Unrolling 'for {} in {}'".format(loop_var, list(iterable))) + # print("Unrolling 'for {} in {}'".format(loop_var, list(iterable))) for val in iterable: self.ctxt.push({loop_var: val}) self.loop_vars = orig_loop_vars | {loop_var} @@ -126,7 +314,7 @@ def visit_For(self, node): continue else: result.append(res) - #result.extend([self.visit(body_node) for body_node in copy.deepcopy(node.body)]) + # result.extend([self.visit(body_node) for body_node in copy.deepcopy(node.body)]) self.ctxt.pop() self.loop_vars = orig_loop_vars return result @@ -138,60 +326,72 @@ def visit_Name(self, node): raise NameError("'{}' not defined in context".format(node.id)) return node - def visit_Assign(self, node): - if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): - var = node.targets[0].id - val = _constant_iterable(node.value, self.ctxt) - if val is not None: - self.ctxt[var] = val - else: - self.ctxt[var] = node.value - return node +# noinspection PyPep8Naming +class CollapseTransformer(TrackedContextTransformer): + def visit_BinOp(self, node): + return _collapse_literal(node, self.ctxt) -@optional_argument_decorator -def unroll(return_source=False, **kwargs): - """Unrolls constant loops in the decorated function + def visit_UnaryOp(self, node): + return _collapse_literal(node, self.ctxt) - :param return_source: Returns the unrolled function's source code instead of compiling it - :param 
kwargs: Any other environmental variables to provide during unrolling - :return: The unrolled function, or its source code if requested - """ - # TODO: Support zipping - # TODO: Support sets/dicts - def inner(f): - f_mod, f_body, f_file = _function_ast(f) - glbls = f.__globals__ - print(glbls) - trans = UnrollTransformer(DictStack(glbls, kwargs)) - f_mod.body[0].decorator_list = [] - f_mod = trans.visit(f_mod) - if return_source: - try: - import astor - return astor.to_source(f_mod) - except ImportError: - raise ImportError("miniutils.pragma.unroll requires 'astor' to be installed to return source code") - else: - #func_source = astor.to_source(f_mod) - f_mod = ast.fix_missing_locations(f_mod) - exec(compile(f_mod, f_file, 'exec'), glbls) - return glbls[f_mod.body[0].name] + def visit_BoolOp(self, node): + return _collapse_literal(node, self.ctxt) - return inner + def visit_Compare(self, node): + return _collapse_literal(node, self.ctxt) -# # Given a set of functions that are called, inline their code into the decorated function -# def inline(*inline_fs): -# def inner(f): -# return f -# -# return inner +def _make_function_transformer(transformer_type, name): + @optional_argument_decorator + def transform(return_source=False, save_source=True, **kwargs): + """Unrolls constant loops in the decorated function + :param return_source: Returns the unrolled function's source code instead of compiling it + :param save_source: Saves the function source code to a tempfile to make it inspectable + :param kwargs: Any other environmental variables to provide during unrolling + :return: The unrolled function, or its source code if requested + """ -# # Collapse defined literal values, and operations thereof, where possible -# def collapse_constants(): -# def inner(f): -# return f -# -# return inner + def inner(f): + f_mod, f_body, f_file = _function_ast(f) + glbls = f.__globals__ + trans = transformer_type(DictStack(glbls, kwargs)) + f_mod.body[0].decorator_list = [] + f_mod = 
trans.visit(f_mod) + if return_source or save_source: + try: + import astor + source = astor.to_source(f_mod) + except ImportError: + raise ImportError("miniutils.pragma.{name} requires 'astor' to be installed to obtain source code" + .format(name=name)) + else: + source = None + + if return_source: + return source + else: + # func_source = astor.to_source(f_mod) + f_mod = ast.fix_missing_locations(f_mod) + if save_source: + temp = tempfile.NamedTemporaryFile('w', delete=True) + f_file = temp.name + print(astor.dump_tree(f_mod)) + exec(compile(f_mod, f_file, 'exec'), glbls) + func = glbls[f_mod.body[0].name] + if save_source: + func.__tempfile__ = temp + temp.write(source) + temp.flush() + return func + + inner.__name__ = name + return inner + return transform + + +unroll = _make_function_transformer(UnrollTransformer, 'unroll') + +# Collapse defined literal values, and operations thereof, where possible +collapse_literals = _make_function_transformer(CollapseTransformer, 'collapse_literals') diff --git a/tests/test_pragma.py b/tests/test_pragma.py index e80e5b8..bf5a2a4 100644 --- a/tests/test_pragma.py +++ b/tests/test_pragma.py @@ -14,15 +14,15 @@ def f(): def test_unroll_various(self): g = lambda: None - g.a = [1,2,3] + g.a = [1, 2, 3] g.b = 6 @pragma.unroll(return_source=True) def f(x): y = 5 a = range(3) - b = [1,2,4] - c = (1,2,5) + b = [1, 2, 4] + c = (1, 2, 5) d = reversed(a) e = [x, x, x] f = [y, y, y] @@ -194,6 +194,24 @@ def f(): ''') self.assertEqual(f.strip(), result.strip()) + def test_unroll_2list_source(self): + @pragma.unroll(return_source=True) + def f(): + for i in [[1, 2, 3], [4, 5], [6]]: + for j in i: + yield j + + result = dedent(''' + def f(): + yield 1 + yield 2 + yield 3 + yield 4 + yield 5 + yield 6 + ''') + self.assertEqual(f.strip(), result.strip()) + def test_external_definition(self): # Known bug: this works when defined as a kwarg, but not as an external variable, but ONLY in unittests... 
# External variables work in practice @@ -211,6 +229,104 @@ def f(): self.assertEqual(f.strip(), result.strip()) +class TestCollapseLiterals(TestCase): + def test_full_run(self): + def f(y): + x = 3 + r = 1 + x + for z in range(2): + r *= 1 + 2 * 3 + for abc in range(x): + for a in range(abc): + for b in range(y): + r += 1 + 2 + y + return r + + import inspect + print(inspect.getsource(f)) + print(pragma.collapse_literals(return_source=True)(f)) + deco_f = pragma.collapse_literals(f) + self.assertEqual(f(0), deco_f(0)) + self.assertEqual(f(1), deco_f(1)) + self.assertEqual(f(5), deco_f(5)) + self.assertEqual(f(-1), deco_f(-1)) + + import inspect + print(inspect.getsource(f)) + print(pragma.collapse_literals(return_source=True)(pragma.unroll(f))) + deco_f = pragma.collapse_literals(pragma.unroll(f)) + self.assertEqual(f(0), deco_f(0)) + self.assertEqual(f(1), deco_f(1)) + self.assertEqual(f(5), deco_f(5)) + self.assertEqual(f(-1), deco_f(-1)) + + def test_basic(self): + @pragma.collapse_literals(return_source=True) + def f(): + return 1 + 1 + + result = dedent(''' + def f(): + return 2 + ''') + self.assertEqual(f.strip(), result.strip()) + + def test_vars(self): + @pragma.collapse_literals(return_source=True) + def f(): + x = 3 + y = 2 + return x + y + + result = dedent(''' + def f(): + x = 3 + y = 2 + return 5 + ''') + self.assertEqual(f.strip(), result.strip()) + + def test_partial(self): + @pragma.collapse_literals(return_source=True) + def f(y): + x = 3 + return x + 2 + y + + result = dedent(''' + def f(y): + x = 3 + return 5 + y + ''') + self.assertEqual(f.strip(), result.strip()) + + def test_constant_index(self): + @pragma.collapse_literals(return_source=True) + def f(): + x = [1,2,3] + return x[0] + + result = dedent(''' + def f(y): + x = [1, 2, 3] + return 1 + ''') + self.assertEqual(f.strip(), result.strip()) + + def test_with_unroll(self): + @pragma.collapse_literals(return_source=True) + @pragma.unroll + def f(): + for i in range(3): + print(i + 2) + + 
result = dedent(''' + def f(): + print(2) + print(3) + print(4) + ''') + self.assertEqual(f.strip(), result.strip()) + class TestDictStack(TestCase): def test_most(self): From 09b22847863773947d1b402870550d1a6b24febd Mon Sep 17 00:00:00 2001 From: scnerd Date: Sat, 18 Nov 2017 00:21:35 -0500 Subject: [PATCH 2/5] Finalized initial version of collapsing literals --- docs/source/misc.rst | 2 ++ miniutils/opt_decorator.py | 1 + miniutils/pragma.py | 39 +++++++++++++++++++------------------- tests/test_pragma.py | 2 +- 4 files changed, 24 insertions(+), 20 deletions(-) diff --git a/docs/source/misc.rst b/docs/source/misc.rst index 51ff136..4f68b4d 100644 --- a/docs/source/misc.rst +++ b/docs/source/misc.rst @@ -352,4 +352,6 @@ Currently not-yet-supported features include: Collapse Literals ----------------- +Collapse literal operations in code to their results, e.g. ``x = 1 + 2`` gets converted to ``x = 3``. +.. autofunction:: miniutils.pragma.collapse_literals diff --git a/miniutils/opt_decorator.py b/miniutils/opt_decorator.py index 72949d8..aad880a 100644 --- a/miniutils/opt_decorator.py +++ b/miniutils/opt_decorator.py @@ -5,6 +5,7 @@ def optional_argument_decorator(_decorator): """Decorate your decorator with this to allow it to always receive *args and **kwargs, making @deco equivalent to @deco()""" + @functools.wraps(_decorator) def inner_decorator_make(*args, **kwargs): if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): func = args[0] diff --git a/miniutils/pragma.py b/miniutils/pragma.py index fc2c0ab..ae5a3b0 100644 --- a/miniutils/pragma.py +++ b/miniutils/pragma.py @@ -32,7 +32,7 @@ def __contains__(self, item): try: self[item] return True - except: + except KeyError: return False def items(self): @@ -188,14 +188,14 @@ def __collapse_literal(node, ctxt): elif isinstance(node, ast.Str): return node.s elif isinstance(node, ast.Index): - return __collapse_literal(node.value) + return __collapse_literal(node.value, ctxt) elif isinstance(node, 
ast.Subscript): print("SUBSCRIPT") - lst = _constant_iterable(node.value) + lst = _constant_iterable(node.value, ctxt) print(lst) if lst is None: return node - slc = __collapse_literal(node.slice) + slc = __collapse_literal(node.slice, ctxt) print(slc) if isinstance(slc, ast.AST): return node @@ -248,16 +248,23 @@ def _assign_names(node): # noinspection PyPep8Naming class TrackedContextTransformer(ast.NodeTransformer): - def __init__(self, ctxt=None, *args, **kwargs): + def __init__(self, ctxt=None): self.ctxt = ctxt or DictStack() - super().__init__(*args, **kwargs) + super().__init__() def visit(self, node): import astor - orig_node = node + orig_node = copy.deepcopy(node) new_node = super().visit(node) - #print("Converted >>> {} <<< to >>> {} <<<".format(astor.to_source(orig_node).strip(), - # astor.to_source(new_node).strip())) + + orig_node_code = astor.to_source(orig_node).strip() + if new_node is None: + print("Deleted >>> {} <<<".format(orig_node_code)) + elif isinstance(new_node, ast.AST): + print("Converted >>> {} <<< to >>> {} <<<".format(orig_node_code, astor.to_source(new_node).strip())) + elif isinstance(new_node, list): + print("Converted >>> {} <<< to [[[ {} ]]]".format(orig_node_code, ", ".join(astor.to_source(n).strip() for n in new_node))) + return new_node def visit_Assign(self, node): @@ -279,15 +286,6 @@ def visit_Assign(self, node): self.ctxt[assgn] = None return node - # def visit_AugAssign(self, node): - # if isinstance(node.target, ast.Name): - # self.ctxt.push({node.target.id: None}) - # res = self.visit(ast.Assign(targets=[node.target], - # value=ast.BinOp(left=ast.Name(node.target.id, ast.Load()), - # right=node.value, op=node.op))) - # self.ctxt.pop() - # return res - # noinspection PyPep8Naming class UnrollTransformer(TrackedContextTransformer): @@ -341,6 +339,9 @@ def visit_BoolOp(self, node): def visit_Compare(self, node): return _collapse_literal(node, self.ctxt) + def visit_Subscript(self, node): + return _collapse_literal(node, 
self.ctxt) + def _make_function_transformer(transformer_type, name): @optional_argument_decorator @@ -377,7 +378,7 @@ def inner(f): if save_source: temp = tempfile.NamedTemporaryFile('w', delete=True) f_file = temp.name - print(astor.dump_tree(f_mod)) + #print(astor.dump_tree(f_mod)) exec(compile(f_mod, f_file, 'exec'), glbls) func = glbls[f_mod.body[0].name] if save_source: diff --git a/tests/test_pragma.py b/tests/test_pragma.py index bf5a2a4..256c609 100644 --- a/tests/test_pragma.py +++ b/tests/test_pragma.py @@ -306,7 +306,7 @@ def f(): return x[0] result = dedent(''' - def f(y): + def f(): x = [1, 2, 3] return 1 ''') From 6d3e93b870cdf9ac566b3cefbaa5cfdd6d2ce8d0 Mon Sep 17 00:00:00 2001 From: scnerd Date: Mon, 20 Nov 2017 17:39:06 -0500 Subject: [PATCH 3/5] Added de-indexing, improved pragma module with bugfixes. --- docs/source/caching.rst | 2 +- docs/source/index.rst | 7 +- docs/source/misc.rst | 154 +-------------------- docs/source/pragma.rst | 256 +++++++++++++++++++++++++++++++++++ docs/source/progress_bar.rst | 8 +- docs/source/python2.rst | 2 +- miniutils/opt_decorator.py | 13 ++ miniutils/pragma.py | 180 +++++++++++++++++------- tests/test_pragma.py | 131 +++++++++++++++++- 9 files changed, 544 insertions(+), 209 deletions(-) create mode 100644 docs/source/pragma.rst diff --git a/docs/source/caching.rst b/docs/source/caching.rst index 0ad7db1..db5a0d1 100644 --- a/docs/source/caching.rst +++ b/docs/source/caching.rst @@ -1,5 +1,5 @@ Property Cache -++++++++++++++ +============== In some cases, an object has properties that don't need to be computed until necessary, and once computed are generally static and could just be cached. This could be accomplished using the following simple recipe:: diff --git a/docs/source/index.rst b/docs/source/index.rst index f24e1e0..983d3c8 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -23,12 +23,13 @@ Welcome to miniutils's documentation! 
progress_bar caching python2 + pragma misc api Overview --------- +++++++++ This module provides numerous helper utilities for Python3.X code to add functionality with minimal code footprint. It has tools for the following tasks: @@ -38,7 +39,7 @@ This module provides numerous helper utilities for Python3.X code to add functio - More intuitive contract decorator (leveraging ``pycontracts``) Installation ------------- +++++++++++++ As usual, you can install the latest code version directly from Github:: @@ -49,7 +50,7 @@ Or you can ``pip`` install the latest release from PyPi:: pip install miniutils Examples --------- +++++++++ To get started, you can import your desired utilities directly from ``miniutils``. For example, to use the ``CachedProperty`` decorator:: diff --git a/docs/source/misc.rst b/docs/source/misc.rst index 4f68b4d..1299682 100644 --- a/docs/source/misc.rst +++ b/docs/source/misc.rst @@ -1,8 +1,8 @@ Miscellaneous -+++++++++++++ +============= Code Contracts -============== +++++++++++++++ Code contracting seems like a great way to define and document your code's expected behavior, easily integrate bounds checking, and just generally write code that tries to avoid bugs. The `pycontracts `_ package provides this capability within python, but as soon as I started using it I realized that it was meant primarily to be robust, not concise. For example, consider the following code:: @@ -51,7 +51,7 @@ And now the function works like you'd expect. If you want to do something more c .. autofunction:: miniutils.magic_contract.magic_contract Simplifying Decorators -====================== +++++++++++++++++++++++ When writing a decorator that could be used like ``@deco`` or ``@deco()``, there's a little code I've found necessary in order to make both cases function identically. I've isolated this code into another decorator (meta-decorator?) to keep my other decorators simple (since, let's be honest, decorators are usually convoluted enough as is). 
@@ -115,7 +115,7 @@ This makes sense, but is somewhat annoying when parameters aren't required, such .. autofunction:: miniutils.opt_decorator.optional_argument_decorator Logging Made Easy -================= ++++++++++++++++++ The standard ``logging`` module provides a lot of great functionality, but there are a few simplifications missing: @@ -158,7 +158,7 @@ The ``coloredlogs`` module didn't quite work as expected when I tried to use it. Timing -====== +++++++ Simple ``printf``-like timing utilities when proper profiling won't quite work. @@ -212,146 +212,4 @@ Use ``tic``/``toc`` to time and report the run times of different chunks of code This utility is just less verbose than tracking various times yourself. The output is printed to the log for later review. It can also accept a custom print format string, including information about the code calling ``toc()`` and runtimes since the last ``tic``/``toc``. -.. autofunction:: miniutils.timing.tic - -Pragma -====== - -When Python code is being executed abnormally, or being replaced entirely (e.g., by ``numba.jit``), it's sometimes highly relevant how your code is written. However, writing it that way isn't always practical, or you might want the code itself to be dependant on runtime data. In these cases, basic code templating or modification can be useful. This sub-module provides some simple utilities to perform Python code modification at runtime, similar to compiler directives in C. - -Unroll ------- - -Unroll constant loops. If the `for`-loop iterator is a known value at function definition time, then replace it with its body duplicated for each value. For example:: - - def f(): - for i in [1, 2, 4]: - yield i - -could be identically replaced by:: - - def f(): - yield 1 - yield 2 - yield 4 - -The ``unroll`` decorator accomplishes this by parsing the input function, performing the unrolling transformation on the function's AST, then compiling and returning the defined function. 
- -If using a transformational decorator of some sort, such as ``numba.jit`` or ``tangent.grad``, if that function isn't yet able to unwrap loops like this, then using this function might yield cleaner results on constant-length loops. - -``unroll`` is currently smart enough to notice singly-defined variables and literals, as well as able to unroll the ``range`` function and unroll nested loops:: - - @pragma.unroll - def summation(x=0): - a = [x, x, x] - v = 0 - for _a in a: - v += _a - return v - - # ... Becomes ... - - def summation(x=0): - a = [x, x, x] - v = 0 - v += x - v += x - v += x - return v - - # ... But ... - - @pragma.unroll - def f(): - x = 3 - for i in [x, x, x]: - yield i - x = 4 - a = [x, x, x] - for i in a: - yield i - - # ... Becomes ... - - def f(): - x = 3 - yield 3 - yield 3 - yield 3 - x = 4 - a = [x, x, x] - yield 4 - yield 4 - yield 4 - - # Even nested loops and ranges work! - - @pragma.unroll - def f(): - for i in range(3): - for j in range(3): - yield i + j - - # ... Becomes ... 
- - def f(): - yield 0 + 0 - yield 0 + 1 - yield 0 + 2 - yield 1 + 0 - yield 1 + 1 - yield 1 + 2 - yield 2 + 0 - yield 2 + 1 - yield 2 + 2 - -You can also request to get the function source code instead of the compiled callable by using ``return_source=True``:: - - In [1]: @pragma.unroll(return_source=True) - ...: def f(): - ...: for i in range(3): - ...: print(i) - ...: - - In [2]: print(f) - def f(): - print(0) - print(1) - print(2) - -It also supports limited recognition of externally and internally defined values:: - - @pragma.unroll(a=range) - def f(): - for b in a(3): - print(b) - - # Is equivalent to: - - a = range - @pragma.unroll - def f(): - for b in a(3): - print(b) - - # Both of which become: - - def f(): - print(0) - print(1) - print(2) - -Currently not-yet-supported features include: - -- Handling constant sets and dictionaries (since the values contained in the AST's, not the AST nodes themselves, must be uniquely identified) -- Tuple assignments (``a, b = 3, 4``) -- ``zip``, ``reversed``, and other known operators, when performed on definition-time constant iterables - -.. autofunction:: miniutils.pragma.unroll - -Collapse Literals ------------------ - -Collapse literal operations in code to their results, e.g. ``x = 1 + 2`` gets converted to ``x = 3``. - -.. autofunction:: miniutils.pragma.collapse_literals +.. autofunction:: miniutils.timing.tic \ No newline at end of file diff --git a/docs/source/pragma.rst b/docs/source/pragma.rst new file mode 100644 index 0000000..f200dff --- /dev/null +++ b/docs/source/pragma.rst @@ -0,0 +1,256 @@ +Pragma +++++++ + +When Python code is being executed abnormally, or being replaced entirely (e.g., by ``numba.jit``), it's sometimes highly relevant how your code is written. However, writing it that way isn't always practical, or you might want the code itself to be dependant on runtime data. In these cases, basic code templating or modification can be useful. 
This sub-module provides some simple utilities to perform Python code modification at runtime, similar to compiler directives in C. + +These functions are designed as decorators that can be stacked together. Each one modifies the provided function's AST, and then re-compiles the function with identical context to the original. A side effect of accomplishing this is that source code is (optionally) made available for each function, either as a return value (replace the function with a string of its modified source code) or, more usefully, by saving it to a temporary file so that ``inspect.getsource`` works correctly on it. + +Because Python is an interpreted language and functions are first-class objects, it's possible to use these functions to perform runtime-based code "optimization" or "templating". As a simple example of this, let's consider ``numba.cuda.jit``, which imposes numerous ``nopython`` limitations on what your function can do. One such limitation is that a ``numba.cuda`` kernel can't treat functions as first-class objects. It must know, at function definition time, which function it's calling. Take the following example:: + + funcs = [lambda x: x, lambda x: x ** 2, lambda x: x ** 3] + + def run_func(i, x): + return funcs[i](x) + +How could we re-define this function such that it both: + +1) Is dynamic to a list that's constant at function definition-time +2) Doesn't actually index that list in its definition + +We'll start by defining the function as an ``if`` check for the index, and call the appropriate function:: + + funcs = [lambda x: x, lambda x: x ** 2, lambda x: x ** 3] + + def run_func(i, x): + for j in range(len(funcs)): + if i == j: + return funcs[j](x) + +The ``miniutils.pragma`` module enables us to go from here to accomplish our goal above by re-writing a function's AST and re-compiling it as a closure, while making certain modifications to its syntax and environment.
While each function will be fully described below, the example above can be succinctly solved by unrolling the loop (whose length is known at function definition time) and by assigning the elements of the list to individual variables and swapping out their indexed references with de-indexed references:: + + funcs = [lambda x: x, lambda x: x ** 2, lambda x: x ** 3] + + @pragma.deindex(funcs, 'funcs') + @pragma.unroll(lf=len(funcs)) + def run_func(i, x): + for j in range(lf): + if i == j: + return funcs[j](x) + + # ... gets transformed at definition time into the below code ... + + funcs = [lambda x: x, lambda x: x ** 2, lambda x: x ** 3] + funcs_0 = funcs[0] + funcs_1 = funcs[1] + funcs_2 = funcs[2] + + def run_func(i, x): + if i == 0: + return funcs_0(x) + if i == 1: + return funcs_1(x) + if i == 2: + return funcs_2(x) + +Unroll +------ + +Unroll constant loops. If the ``for``-loop iterator is a known value at function definition time, then replace it with its body duplicated for each value. For example:: + + def f(): + for i in [1, 2, 4]: + yield i + +could be identically replaced by:: + + def f(): + yield 1 + yield 2 + yield 4 + +The ``unroll`` decorator accomplishes this by parsing the input function, performing the unrolling transformation on the function's AST, then compiling and returning the defined function. + +If you are using a transformational decorator of some sort, such as ``numba.jit`` or ``tangent.grad``, and that decorator isn't yet able to unroll loops like this, then using ``unroll`` might yield cleaner results on constant-length loops. + +``unroll`` is currently smart enough to notice singly-defined variables and literals, as well as able to unroll the ``range`` function and unroll nested loops:: + + @pragma.unroll + def summation(x=0): + a = [x, x, x] + v = 0 + for _a in a: + v += _a + return v + + # ... Becomes ... + + def summation(x=0): + a = [x, x, x] + v = 0 + v += x + v += x + v += x + return v + + # ... But ... 
+ + @pragma.unroll + def f(): + x = 3 + for i in [x, x, x]: + yield i + x = 4 + a = [x, x, x] + for i in a: + yield i + + # ... Becomes ... + + def f(): + x = 3 + yield 3 + yield 3 + yield 3 + x = 4 + a = [x, x, x] + yield 4 + yield 4 + yield 4 + + # Even nested loops and ranges work! + + @pragma.unroll + def f(): + for i in range(3): + for j in range(3): + yield i + j + + # ... Becomes ... + + def f(): + yield 0 + 0 + yield 0 + 1 + yield 0 + 2 + yield 1 + 0 + yield 1 + 1 + yield 1 + 2 + yield 2 + 0 + yield 2 + 1 + yield 2 + 2 + +You can also request to get the function source code instead of the compiled callable by using ``return_source=True``:: + + In [1]: @pragma.unroll(return_source=True) + ...: def f(): + ...: for i in range(3): + ...: print(i) + ...: + + In [2]: print(f) + def f(): + print(0) + print(1) + print(2) + +It also supports limited recognition of externally and internally defined values:: + + @pragma.unroll(a=range) + def f(): + for b in a(3): + print(b) + + # Is equivalent to: + + a = range + @pragma.unroll + def f(): + for b in a(3): + print(b) + + # Both of which become: + + def f(): + print(0) + print(1) + print(2) + +Currently not-yet-supported features include: + +- Handling constant sets and dictionaries (since the values contained in the AST's, not the AST nodes themselves, must be uniquely identified) +- Tuple assignments (``a, b = 3, 4``) +- ``zip``, ``reversed``, and other known operators, when performed on definition-time constant iterables + +.. autofunction:: miniutils.pragma.unroll + +Collapse Literals +----------------- + +Collapse literal operations in code to their results, e.g. ``x = 1 + 2`` gets converted to ``x = 3``. + +For example:: + + @pragma.collapse_literals + def f(y): + x = 3 + return x + 2 + y + + # ... Becomes ... + + def f(y): + x = 3 + return 5 + y + + +.. 
autofunction:: miniutils.pragma.collapse_literals + +De-index Arrays +--------------- + +Convert literal indexing operations for a given array into named value references. The new value names are de-indexed and stashed in the function's closure so that the resulting code both uses no literal indices and still behaves as if it did. Variable indices are unaffected. + +For example:: + + v = [object(), object(), object()] + + @pragma.deindex(v, 'v') + def f(x): + yield v[0] + yield v[x] + + # ... f becomes ... + + def f(x): + yield v_0 # This is defined as v_0 = v[0] by the function's closure + yield v[x] + + # We can check that this works correctly + assert list(f(2)) == [v[0], v[2]] + +This can be easily stacked with :func:`miniutils.pragma.unroll` to unroll iterables in a function when their values are known at function definition time:: + + funcs = [lambda x: x, lambda x: x ** 2, lambda x: x ** 3] + + @pragma.deindex(funcs, 'funcs') + @pragma.unroll(lf=len(funcs)) + def run_func(i, x): + for j in range(lf): + if i == j: + return funcs[j](x) + + # ... Becomes ... + + def run_func(i, x): + if i == 0: + return funcs_0(x) + if i == 1: + return funcs_1(x) + if i == 2: + return funcs_2(x) + +This could be used, for example, in a case where dynamically calling functions isn't supported, such as in ``numba.jit`` or ``numba.cuda.jit``. + +Note that because the array being de-indexed is passed to the decorator, the value of the constant-defined variables (e.g. ``v_0`` in the code above) is "compiled" into the code of the function, and won't update if the array is updated. Again, variable-indexed calls remain unaffected. + +Since names are (and must) be used as references to the consi + +.. 
autofunction:: miniutils.pragma.deindex diff --git a/docs/source/progress_bar.rst b/docs/source/progress_bar.rst index 5fe1a12..f0d22b3 100644 --- a/docs/source/progress_bar.rst +++ b/docs/source/progress_bar.rst @@ -1,10 +1,10 @@ Progress Bars -+++++++++++++ +============= Three progress bar utilities are provided, all leveraging the excellent `tqdm `_ library. progbar -======= ++++++++ A simple iterable wrapper, much like the default ``tqdm`` wrapper. It can be used on any iterable to display a progress bar as it gets iterated:: @@ -16,7 +16,7 @@ However, unlike the standard ``tqdm`` function, this code has two additional, us .. autofunction:: miniutils.progress_bar.progbar parallel_progbar -================ +++++++++++++++++ A parallel mapper based on ``multiprocessing`` that replaces ``Pool.map``. In attempting to use ``Pool.map``, I've had issues with unintuitive errors and, of course, wanting a progress bar of my map job's progress. Both of these are solved in ``parallel_progbar``:: @@ -40,7 +40,7 @@ It also supports runtime disabling, limited number of parallel processes, shuffl .. autofunction:: miniutils.progress_bar.parallel_progbar iparallel_progbar -================= ++++++++++++++++++ This has the exact same behavior as ``parallel_progbar``, but produces an unordered generator instead of a list, yielding results as soon as they're available. It also permits a ``max_cache`` argument that allows you to limit the number of computed results available to the generator. :: diff --git a/docs/source/python2.rst b/docs/source/python2.rst index 35fb4f3..7dcca8f 100644 --- a/docs/source/python2.rst +++ b/docs/source/python2.rst @@ -1,5 +1,5 @@ Nesting Python 2 -++++++++++++++++ +================ In `very` rare situations, the standard means of Python2 compatibility within Python3 (such as ``six``, ``2to3``, or ``__futures__``) might simply be insufficient. Sometimes, you just need to run Python2 wholesale to get the correct behavior. 
diff --git a/miniutils/opt_decorator.py b/miniutils/opt_decorator.py index aad880a..4349a24 100644 --- a/miniutils/opt_decorator.py +++ b/miniutils/opt_decorator.py @@ -22,3 +22,16 @@ def inner_decorator_make(*args, **kwargs): return decorator return inner_decorator_make + + +# class PipedFunction: +# def __init__(self, f, *args, reversed_args=[]): +# self.f = f +# self.args = args +# self.reversed_args = reversed_args +# +# def __ror__(self, other): +# return PipedFunction(f, self.args, self.reversed_args + [other]) +# +# +# def function_piping(func) diff --git a/miniutils/pragma.py b/miniutils/pragma.py index ae5a3b0..5d26593 100644 --- a/miniutils/pragma.py +++ b/miniutils/pragma.py @@ -1,10 +1,27 @@ -import ast, inspect, sys, copy -from miniutils.opt_decorator import optional_argument_decorator +import ast +import astor +import copy +import inspect +import sys +import tempfile import textwrap -import warnings, traceback, tempfile +import traceback +import warnings + +from miniutils.opt_decorator import optional_argument_decorator + +# Astor tries to get fancy by failing nicely, but in doing so they fail when traversing non-AST type node properties. +# By deleting this custom handler, it'll fall back to the default ast visit pattern, which skips these missing +# properties. Everything else seems to be implemented, so this will fail silently if it hits an AST node that isn't +# supported but should be. +del astor.node_util.ExplicitNodeVisitor.visit class DictStack: + """ + Creates a stack of dictionaries to roughly emulate closures and variable environments + """ + def __init__(self, *base): import builtins self.dicts = [dict(builtins.__dict__)] + [dict(d) for d in base] @@ -56,6 +73,7 @@ def pop(self): def _function_ast(f): + """Returns ast for the given function. 
Gives a tuple of (ast_module, function_body, function_file""" assert callable(f) try: @@ -67,34 +85,70 @@ def _function_ast(f): return root, root.body[0].body, f_file -def _constant_iterable(node, ctxt): +def can_have_side_effect(node, ctxt): + if isinstance(node, ast.AST): + print("Can {} have side effects?".format(node)) + if isinstance(node, ast.Call): + print(" Yes!") + return True + else: + for field, old_value in ast.iter_fields(node): + if isinstance(old_value, list): + return any(can_have_side_effect(n, ctxt) for n in old_value if isinstance(n, ast.AST)) + elif isinstance(old_value, ast.AST): + return can_have_side_effect(old_value, ctxt) + else: + print(" No!") + return False + else: + return False + + +def _constant_iterable(node, ctxt, avoid_side_effects=True): + """If the given node is a known iterable of some sort, return the list of its elements.""" # TODO: Support zipping # TODO: Support sets/dicts? + # TODO: Support for reversed, enumerate, etc. + # TODO: Support len, in, etc. 
# Check for range(*constants) - if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and ctxt[node.func.id] == range and all( - isinstance(arg, ast.Num) for arg in node.args): - return [ast.Num(n) for n in range(*[arg.n for arg in node.args])] + def wrap(return_node, name, idx): + if not avoid_side_effects: + return return_node + if can_have_side_effect(return_node, ctxt): + return ast.Subscript(name, ast.Index(idx)) + return _make_ast_from_literal(return_node) + + if isinstance(node, ast.Call): + if _resolve_name_or_attribute(node.func, ctxt) == range: + args = [_collapse_literal(arg, ctxt) for arg in node.args] + if all(isinstance(arg, ast.Num) for arg in args): + return [ast.Num(n) for n in range(*[arg.n for arg in args])] + + return None elif isinstance(node, (ast.List, ast.Tuple)): - return [_resolve_name_or_attribute(e, ctxt) for e in node.elts] + return [_collapse_literal(e, ctxt) for e in node.elts] + #return [_resolve_name_or_attribute(e, ctxt) for e in node.elts] # Can't yet support sets and lists, since you need to compute what the unique values would be # elif isinstance(node, ast.Dict): # return node.keys elif isinstance(node, (ast.Name, ast.Attribute, ast.NameConstant)): res = _resolve_name_or_attribute(node, ctxt) - import astor #print("Trying to resolve '{}' as list, got {}".format(astor.to_source(node), res)) if isinstance(res, ast.AST) and not isinstance(res, (ast.Name, ast.Attribute, ast.NameConstant)): res = _constant_iterable(res, ctxt) if not isinstance(res, ast.AST): try: - iter(res) - return list(res) + if hasattr(res, 'items'): + return dict([(k, wrap(_make_ast_from_literal(v), node, k)) for k, v in res.items()]) + else: + return [wrap(_make_ast_from_literal(res_node), node, i) for i, res_node in enumerate(res)] except TypeError: pass return None def _resolve_name_or_attribute(node, ctxt_or_obj): + """If the given name of attribute is defined in the current context, return its value. 
Else, returns the node""" if isinstance(node, ast.Name): if isinstance(ctxt_or_obj, DictStack): if node.id in ctxt_or_obj: @@ -116,19 +170,6 @@ def _resolve_name_or_attribute(node, ctxt_or_obj): return node -# slice = Slice(expr? lower, expr? upper, expr? step) -# | ExtSlice(slice* dims) -# | Index(expr value) -# -# boolop = And | Or -# -# operator = Add | Sub | Mult | MatMult | Div | Mod | Pow | LShift -# | RShift | BitOr | BitXor | BitAnd | FloorDiv -# -# unaryop = Invert | Not | UAdd | USub -# -# cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot | In | NotIn - _collapse_map = { ast.Add: lambda a, b: a + b, ast.Sub: lambda a, b: a - b, @@ -163,6 +204,7 @@ def _resolve_name_or_attribute(node, ctxt_or_obj): def _make_ast_from_literal(lit): + """Converts literals into their AST equivalent""" if isinstance(lit, (list, tuple)): res = [_make_ast_from_literal(e) for e in lit] tp = ast.List if isinstance(lit, list) else ast.Tuple @@ -178,6 +220,7 @@ def _make_ast_from_literal(lit): def __collapse_literal(node, ctxt): + """Collapses literal expressions. 
Returns literals if they're available, AST nodes otherwise""" if isinstance(node, (ast.Name, ast.Attribute, ast.NameConstant)): res = _resolve_name_or_attribute(node, ctxt) if isinstance(res, ast.AST) and not isinstance(res, (ast.Name, ast.Attribute, ast.NameConstant)): @@ -189,16 +232,19 @@ def __collapse_literal(node, ctxt): return node.s elif isinstance(node, ast.Index): return __collapse_literal(node.value, ctxt) + elif isinstance(node, (ast.Slice, ast.ExtSlice)): + raise NotImplemented() elif isinstance(node, ast.Subscript): - print("SUBSCRIPT") + # print("Attempting to subscript {}".format(astor.to_source(node))) lst = _constant_iterable(node.value, ctxt) - print(lst) + # print("Can I subscript {}?".format(lst)) if lst is None: return node slc = __collapse_literal(node.slice, ctxt) - print(slc) + # print("Getting subscript at {}".format(slc)) if isinstance(slc, ast.AST): return node + # print("Value at {}[{}] = {}".format(lst, slc, lst[slc])) return lst[slc] elif isinstance(node, (ast.UnaryOp, ast.BinOp, ast.BoolOp)): if isinstance(node, ast.UnaryOp): @@ -216,12 +262,12 @@ def __collapse_literal(node, ctxt): " Error was:\n{}".format(traceback.format_exc())) return node else: - if any(is_literal): - # Note that we know that it wasn't a unary op, else it would've succeded... so it was a binary op + if isinstance(node, ast.UnaryOp): + return ast.UnaryOp(operand=_make_ast_from_literal(operands[0]), op=node.op) + else: return type(node)(left=_make_ast_from_literal(operands[0]), right=_make_ast_from_literal(operands[1]), op=node.op) - return node elif isinstance(node, ast.Compare): operands = [__collapse_literal(o, ctxt) for o in [node.left] + node.comparators] if all(not isinstance(opr, ast.AST) for opr in operands): @@ -234,10 +280,16 @@ def __collapse_literal(node, ctxt): def _collapse_literal(node, ctxt): + """Collapse literal expressions in the given node. 
Returns the node with the collapsed literals""" return _make_ast_from_literal(__collapse_literal(node, ctxt)) def _assign_names(node): + """Gets names from a assign-to tuple in flat form, just to know what's affected + "x=3" -> "x" + "a,b=4,5" -> ["a", "b"] + "(x,(y,z)),(a,) = something" -> ["x", "y", "z", "a"] + """ if isinstance(node, ast.Name): yield node.id elif isinstance(node, ast.Tuple): @@ -253,17 +305,19 @@ def __init__(self, ctxt=None): super().__init__() def visit(self, node): - import astor orig_node = copy.deepcopy(node) new_node = super().visit(node) orig_node_code = astor.to_source(orig_node).strip() - if new_node is None: - print("Deleted >>> {} <<<".format(orig_node_code)) - elif isinstance(new_node, ast.AST): - print("Converted >>> {} <<< to >>> {} <<<".format(orig_node_code, astor.to_source(new_node).strip())) - elif isinstance(new_node, list): - print("Converted >>> {} <<< to [[[ {} ]]]".format(orig_node_code, ", ".join(astor.to_source(n).strip() for n in new_node))) + try: + if new_node is None: + print("Deleted >>> {} <<<".format(orig_node_code)) + elif isinstance(new_node, ast.AST): + print("Converted >>> {} <<< to >>> {} <<<".format(orig_node_code, astor.to_source(new_node).strip())) + elif isinstance(new_node, list): + print("Converted >>> {} <<< to [[[ {} ]]]".format(orig_node_code, ", ".join(astor.to_source(n).strip() for n in new_node))) + except AssertionError as ex: + raise AssertionError("Failed on {} >>> {}".format(orig_node_code, astor.dump_tree(new_node))) from ex return new_node @@ -343,11 +397,10 @@ def visit_Subscript(self, node): return _collapse_literal(node, self.ctxt) -def _make_function_transformer(transformer_type, name): +def _make_function_transformer(transformer_type, name, description): @optional_argument_decorator - def transform(return_source=False, save_source=True, **kwargs): - """Unrolls constant loops in the decorated function - + def transform(return_source=False, save_source=True, function_globals=None, 
**kwargs): + """ :param return_source: Returns the unrolled function's source code instead of compiling it :param save_source: Saves the function source code to a tempfile to make it inspectable :param kwargs: Any other environmental variables to provide during unrolling @@ -356,13 +409,21 @@ def transform(return_source=False, save_source=True, **kwargs): def inner(f): f_mod, f_body, f_file = _function_ast(f) + # Grab function globals glbls = f.__globals__ + # Grab function closure variables + if isinstance(f.__closure__, tuple): + glbls.update({k: v.cell_contents for k, v in zip(f.__code__.co_freevars, f.__closure__)}) + # Apply manual globals override + if function_globals is not None: + glbls.update(function_globals) + print({k: v for k, v in glbls.items() if k not in globals()}) trans = transformer_type(DictStack(glbls, kwargs)) f_mod.body[0].decorator_list = [] f_mod = trans.visit(f_mod) + print(astor.dump_tree(f_mod)) if return_source or save_source: try: - import astor source = astor.to_source(f_mod) except ImportError: raise ImportError("miniutils.pragma.{name} requires 'astor' to be installed to obtain source code" @@ -378,7 +439,6 @@ def inner(f): if save_source: temp = tempfile.NamedTemporaryFile('w', delete=True) f_file = temp.name - #print(astor.dump_tree(f_mod)) exec(compile(f_mod, f_file, 'exec'), glbls) func = glbls[f_mod.body[0].name] if save_source: @@ -387,12 +447,40 @@ def inner(f): temp.flush() return func - inner.__name__ = name return inner + transform.__name__ = name + transform.__doc__ = '\n'.join([description, transform.__doc__]) return transform -unroll = _make_function_transformer(UnrollTransformer, 'unroll') +# Unroll literal loops +unroll = _make_function_transformer(UnrollTransformer, 'unroll', "Unrolls constant loops in the decorated function") # Collapse defined literal values, and operations thereof, where possible -collapse_literals = _make_function_transformer(CollapseTransformer, 'collapse_literals') +collapse_literals = 
_make_function_transformer(CollapseTransformer, 'collapse_literals', + "Collapses literal expressions in the decorated function into single literals") + +# Directly reference elements of constant list, removing literal indexing into that list within a function +def deindex(iterable, iterable_name, *args, **kwargs): + """ + :param iterable: The list to deindex in the target function + :param iterable_name: The list's name (must be unique if deindexing multiple lists) + :param return_source: Returns the unrolled function's source code instead of compiling it + :param save_source: Saves the function source code to a tempfile to make it inspectable + :param kwargs: Any other environmental variables to provide during unrolling + :return: The unrolled function, or its source code if requested + """ + + if hasattr(iterable, 'items'): # Support dicts and the like + internal_iterable = {k: '{}_{}'.format(iterable_name, k) for k, val in iterable.items()} + mapping = {internal_iterable[k]: val for k, val in iterable.items()} + else: # Support lists, tuples, and the like + internal_iterable = {i: '{}_{}'.format(iterable_name, i) for i, val in enumerate(iterable)} + mapping = {internal_iterable[i]: val for i, val in enumerate(iterable)} + + kwargs[iterable_name] = {k: ast.Name(id=name, ctx=ast.Load()) for k, name in internal_iterable.items()} + + return collapse_literals(*args, function_globals=mapping, **kwargs) + +# Inline functions? 
+ diff --git a/tests/test_pragma.py b/tests/test_pragma.py index 256c609..14e606c 100644 --- a/tests/test_pragma.py +++ b/tests/test_pragma.py @@ -1,6 +1,7 @@ from unittest import TestCase from miniutils import pragma from textwrap import dedent +import inspect class TestUnroll(TestCase): @@ -69,11 +70,12 @@ def f(x): yield 5 yield 5 yield 5 - for i in g.a: - yield i - yield g.b + 0 - yield g.b + 1 - yield g.b + 2 + yield 1 + yield 2 + yield 3 + yield 6 + yield 7 + yield 8 ''') self.assertEqual(f.strip(), result.strip()) @@ -251,7 +253,6 @@ def f(y): self.assertEqual(f(5), deco_f(5)) self.assertEqual(f(-1), deco_f(-1)) - import inspect print(inspect.getsource(f)) print(pragma.collapse_literals(return_source=True)(pragma.unroll(f))) deco_f = pragma.collapse_literals(pragma.unroll(f)) @@ -327,6 +328,124 @@ def f(): ''') self.assertEqual(f.strip(), result.strip()) + def test_with_objects(self): + @pragma.collapse_literals(return_source=True) + def f(): + v = [object(), object()] + return v[0] + + result = dedent(''' + def f(): + v = [object(), object()] + return v[0] + ''') + self.assertEqual(f.strip(), result.strip()) + + +class TestDeindex(TestCase): + def test_with_literals(self): + v = [1, 2, 3] + @pragma.collapse_literals(return_source=True) + @pragma.deindex(v, 'v') + def f(): + return v[0] + v[1] + v[2] + + result = dedent(''' + def f(): + return 6 + ''') + self.assertEqual(f.strip(), result.strip()) + + def test_with_objects(self): + v = [object(), object(), object()] + @pragma.deindex(v, 'v', return_source=True) + def f(): + return v[0] + v[1] + v[2] + + result = dedent(''' + def f(): + return v_0 + v_1 + v_2 + ''') + self.assertEqual(f.strip(), result.strip()) + + def test_with_unroll(self): + v = [None, None, None] + + @pragma.deindex(v, 'v', return_source=True) + @pragma.unroll(lv=len(v)) + def f(): + for i in range(lv): + yield v[i] + + result = dedent(''' + def f(): + yield v_0 + yield v_1 + yield v_2 + ''') + self.assertEqual(f.strip(), result.strip()) 
+ + def test_with_literals_run(self): + v = [1, 2, 3] + @pragma.collapse_literals + @pragma.deindex(v, 'v') + def f(): + return v[0] + v[1] + v[2] + + self.assertEqual(f(), sum(v)) + + def test_with_objects_run(self): + v = [object(), object(), object()] + @pragma.deindex(v, 'v') + def f(): + return v[0] + + self.assertEqual(f(), v[0]) + + def test_with_variable_indices(self): + v = [object(), object(), object()] + @pragma.deindex(v, 'v', return_source=True) + def f(x): + yield v[0] + yield v[x] + + result = dedent(''' + def f(x): + yield v_0 + yield v[x] + ''') + self.assertEqual(f.strip(), result.strip()) + + def test_dynamic_function_calls(self): + funcs = [lambda x: x, lambda x: x ** 2, lambda x: x ** 3] + + # TODO: Support enumerate transparently + # TODO: Support tuple assignment in loop transparently + + @pragma.deindex(funcs, 'funcs') + @pragma.unroll(lf=len(funcs)) + def run_func(i, x): + for j in range(lf): + if i == j: + return funcs[j](x) + + print(inspect.getsource(run_func)) + + self.assertEqual(run_func(0, 5), 5) + self.assertEqual(run_func(1, 5), 25) + self.assertEqual(run_func(2, 5), 125) + + result = dedent(''' + def run_func(i, x): + if i == 0: + return funcs_0(x) + if i == 1: + return funcs_1(x) + if i == 2: + return funcs_2(x) + ''') + self.assertEqual(inspect.getsource(run_func).strip(), result.strip()) + class TestDictStack(TestCase): def test_most(self): From 0854d0e6ed47bf37b29ac014143d04d39e3ec4cf Mon Sep 17 00:00:00 2001 From: scnerd Date: Mon, 20 Nov 2017 17:43:12 -0500 Subject: [PATCH 4/5] Removed pragma print statements --- .gitignore | 3 ++- miniutils/pragma.py | 42 +++++++++++++++++++++--------------------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index b5322b8..9ed17e4 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ build .coverage *.log htmlcov -*.ipynb_checkpoints \ No newline at end of file +*.ipynb_checkpoints +.coverage.* \ No newline at end of file diff --git 
a/miniutils/pragma.py b/miniutils/pragma.py index 5d26593..3650586 100644 --- a/miniutils/pragma.py +++ b/miniutils/pragma.py @@ -87,9 +87,9 @@ def _function_ast(f): def can_have_side_effect(node, ctxt): if isinstance(node, ast.AST): - print("Can {} have side effects?".format(node)) + # print("Can {} have side effects?".format(node)) if isinstance(node, ast.Call): - print(" Yes!") + # print(" Yes!") return True else: for field, old_value in ast.iter_fields(node): @@ -98,7 +98,7 @@ def can_have_side_effect(node, ctxt): elif isinstance(old_value, ast.AST): return can_have_side_effect(old_value, ctxt) else: - print(" No!") + # print(" No!") return False else: return False @@ -304,22 +304,22 @@ def __init__(self, ctxt=None): self.ctxt = ctxt or DictStack() super().__init__() - def visit(self, node): - orig_node = copy.deepcopy(node) - new_node = super().visit(node) - - orig_node_code = astor.to_source(orig_node).strip() - try: - if new_node is None: - print("Deleted >>> {} <<<".format(orig_node_code)) - elif isinstance(new_node, ast.AST): - print("Converted >>> {} <<< to >>> {} <<<".format(orig_node_code, astor.to_source(new_node).strip())) - elif isinstance(new_node, list): - print("Converted >>> {} <<< to [[[ {} ]]]".format(orig_node_code, ", ".join(astor.to_source(n).strip() for n in new_node))) - except AssertionError as ex: - raise AssertionError("Failed on {} >>> {}".format(orig_node_code, astor.dump_tree(new_node))) from ex - - return new_node + # def visit(self, node): + # orig_node = copy.deepcopy(node) + # new_node = super().visit(node) + # + # orig_node_code = astor.to_source(orig_node).strip() + # try: + # if new_node is None: + # print("Deleted >>> {} <<<".format(orig_node_code)) + # elif isinstance(new_node, ast.AST): + # print("Converted >>> {} <<< to >>> {} <<<".format(orig_node_code, astor.to_source(new_node).strip())) + # elif isinstance(new_node, list): + # print("Converted >>> {} <<< to [[[ {} ]]]".format(orig_node_code, ", 
".join(astor.to_source(n).strip() for n in new_node))) + # except AssertionError as ex: + # raise AssertionError("Failed on {} >>> {}".format(orig_node_code, astor.dump_tree(new_node))) from ex + # + # return new_node def visit_Assign(self, node): node.value = self.visit(node.value) @@ -417,11 +417,11 @@ def inner(f): # Apply manual globals override if function_globals is not None: glbls.update(function_globals) - print({k: v for k, v in glbls.items() if k not in globals()}) + # print({k: v for k, v in glbls.items() if k not in globals()}) trans = transformer_type(DictStack(glbls, kwargs)) f_mod.body[0].decorator_list = [] f_mod = trans.visit(f_mod) - print(astor.dump_tree(f_mod)) + # print(astor.dump_tree(f_mod)) if return_source or save_source: try: source = astor.to_source(f_mod) From 716475213a349fcddf6f82cce8448c85e25fbc6f Mon Sep 17 00:00:00 2001 From: scnerd Date: Tue, 21 Nov 2017 10:19:39 -0500 Subject: [PATCH 5/5] Removed multiprocessing from travis testing --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 2a2ae71..c490d9d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,7 +8,7 @@ install: - pip install . - pip install -r requirements.txt -script: nosetests --processes=4 --with-coverage tests +script: nosetests --with-coverage tests after_success: coveralls