diff --git a/.coveragerc b/.coveragerc index f17bd80..35cfba8 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,9 +1,12 @@ [run] +source=miniutils concurrency = multiprocessing -include = miniutils/*.py omit = miniutils/py2_template.py *__init__* [report] +include=miniutils/*.py +omit = + *__init__* show_missing = true diff --git a/miniutils/pragma/__init__.py b/miniutils/pragma/__init__.py deleted file mode 100644 index 6bae598..0000000 --- a/miniutils/pragma/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .collapse_literals import collapse_literals -from .core import * -from .deindex import deindex -from .unroll import unroll -from .inline import inline diff --git a/miniutils/pragma/collapse_literals.py b/miniutils/pragma/collapse_literals.py deleted file mode 100644 index b70b71e..0000000 --- a/miniutils/pragma/collapse_literals.py +++ /dev/null @@ -1,49 +0,0 @@ -import ast - -from .core import TrackedContextTransformer, make_function_transformer, resolve_literal - - -# noinspection PyPep8Naming -class CollapseTransformer(TrackedContextTransformer): - def visit_Name(self, node): - return resolve_literal(node, self.ctxt) - - def visit_BinOp(self, node): - return resolve_literal(node, self.ctxt) - - def visit_UnaryOp(self, node): - return resolve_literal(node, self.ctxt) - - def visit_BoolOp(self, node): - return resolve_literal(node, self.ctxt) - - def visit_Compare(self, node): - return resolve_literal(node, self.ctxt) - - def visit_Subscript(self, node): - return resolve_literal(node, self.ctxt) - - def visit_If(self, node): - cond = resolve_literal(node.test, self.ctxt, True) - # print("Attempting to collapse IF conditioned on {}".format(cond)) - if not isinstance(cond, ast.AST): - # print("Yes, this IF can be consolidated, condition is {}".format(bool(cond))) - body = node.body if cond else node.orelse - result = [] - for subnode in body: - res = self.visit(subnode) - if res is None: - pass - elif isinstance(res, list): - result += res - else: - result.append(res) - return result - else: - # print("No, this IF cannot be consolidated") - return super().generic_visit(node) - - -# Collapse defined literal values, and operations thereof, where possible -collapse_literals = make_function_transformer(CollapseTransformer, 'collapse_literals', - "Collapses literal expressions in the decorated function into single literals") diff --git a/miniutils/pragma/core/__init__.py b/miniutils/pragma/core/__init__.py deleted file mode 100644 index 6f6a53e..0000000 --- a/miniutils/pragma/core/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -import ast -import inspect - -import astor - -from miniutils.magic_contract import safe_new_contract - - -def is_iterable(x): - try: - iter(x) - return True - except Exception: - return False - - -safe_new_contract('function', lambda x: callable(x)) -# safe_new_contract('function', type(lambda: None)) -safe_new_contract('iterable', is_iterable) -safe_new_contract('literal', 'int|float|str|bool|tuple|list|None') -for name, tp in inspect.getmembers(ast, inspect.isclass): - safe_new_contract(name, tp) - -# Astor tries to get fancy by failing nicely, but in doing so they fail when traversing non-AST type node properties. -# By deleting this custom handler, it'll fall back to the default ast visit pattern, which skips these missing -# properties. Everything else seems to be implemented, so this will fail silently if it hits an AST node that isn't -# supported but should be. -try: - del astor.node_util.ExplicitNodeVisitor.visit -except AttributeError: - # visit isn't defined in this version of astor - pass - -from .stack import DictStack -from .resolve import * -from .transformer import * diff --git a/miniutils/pragma/core/resolve.py b/miniutils/pragma/core/resolve.py deleted file mode 100644 index 2cce0a5..0000000 --- a/miniutils/pragma/core/resolve.py +++ /dev/null @@ -1,337 +0,0 @@ -import ast -import traceback -import warnings - -from miniutils.magic_contract import magic_contract -from .stack import DictStack - -try: - import numpy - - num_types = (int, float, numpy.number) - float_types = (float, numpy.floating) -except ImportError: # pragma: nocover - numpy = None - num_types = (int, float) - float_types = (float,) - -_collapse_map = { - ast.Add: lambda a, b: a + b, - ast.Sub: lambda a, b: a - b, - ast.Mult: lambda a, b: a * b, - ast.Div: lambda a, b: a / b, - ast.FloorDiv: lambda a, b: a // b, - - ast.Mod: lambda a, b: a % b, - ast.Pow: lambda a, b: a ** b, - ast.LShift: lambda a, b: a << b, - ast.RShift: lambda a, b: a >> b, - ast.MatMult: lambda a, b: a @ b, - - ast.BitAnd: lambda a, b: a & b, - ast.BitOr: lambda a, b: a | b, - ast.BitXor: lambda a, b: a ^ b, - ast.And: lambda a, b: a and b, - ast.Or: lambda a, b: a or b, - ast.Invert: lambda a: ~a, - ast.Not: lambda a: not a, - - ast.UAdd: lambda a: a, - ast.USub: lambda a: -a, - - ast.Eq: lambda a, b: a == b, - ast.NotEq: lambda a, b: a != b, - ast.Lt: lambda a, b: a < b, - ast.LtE: lambda a, b: a <= b, - ast.Gt: lambda a, b: a > b, - ast.GtE: lambda a, b: a >= b, -} - - -@magic_contract -def can_have_side_effect(node, ctxt): - """ - Checks whether or not copying the given AST node could cause side effects in the resulting function - :param node: The AST node to be checked - :type node: AST - :param ctxt: The environment stack to use when running the check - :type ctxt: DictStack - :return: Whether or not duplicating this node could cause side effects - :rtype: bool - """ - if isinstance(node, ast.AST): - # print("Can {} have side effects?".format(node)) - if isinstance(node, ast.Call): - # print(" Yes!") - return True - else: - for field, old_value in ast.iter_fields(node): - if isinstance(old_value, list): - return any(can_have_side_effect(n, ctxt) for n in old_value if isinstance(n, ast.AST)) - elif isinstance(old_value, ast.AST): - return can_have_side_effect(old_value, ctxt) - else: - # print(" No!") - return False - else: - return False - - -@magic_contract -def constant_iterable(node, ctxt, avoid_side_effects=True): - """ - If the given node is a known iterable of some sort, return the list of its elements. - :param node: The AST node to be checked - :type node: AST - :param ctxt: The environment stack to use when running the check - :type ctxt: DictStack - :param avoid_side_effects: Whether or not to avoid unwrapping side effect-causing AST nodes - :type avoid_side_effects: bool - :return: The iterable if possible, else None - :rtype: iterable|None - """ - - # TODO: Support zipping - # TODO: Support sets/dicts? - # TODO: Support for reversed, enumerate, etc. - # TODO: Support len, in, etc. - # Check for range(*constants) - def wrap(return_node, name, idx): - if not avoid_side_effects: - return return_node - if can_have_side_effect(return_node, ctxt): - return ast.Subscript(name, ast.Index(idx)) - return make_ast_from_literal(return_node) - - if isinstance(node, ast.Call): - if resolve_name_or_attribute(node.func, ctxt) == range: - args = [resolve_literal(arg, ctxt) for arg in node.args] - if all(isinstance(arg, ast.Num) for arg in args): - return [ast.Num(n) for n in range(*[arg.n for arg in args])] - - return None - elif isinstance(node, (ast.List, ast.Tuple)): - return [resolve_literal(e, ctxt) for e in node.elts] - # return [_resolve_name_or_attribute(e, ctxt) for e in node.elts] - # Can't yet support sets and lists, since you need to compute what the unique values would be - # elif isinstance(node, ast.Dict): - # return node.keys - elif isinstance(node, (ast.Name, ast.Attribute, ast.NameConstant)): - res = resolve_name_or_attribute(node, ctxt) - # print("Trying to resolve '{}' as list, got {}".format(astor.to_source(node), res)) - if isinstance(res, ast.AST) and not isinstance(res, (ast.Name, ast.Attribute, ast.NameConstant)): - res = constant_iterable(res, ctxt) - if not isinstance(res, ast.AST): - try: - if hasattr(res, 'items'): - return dict([(k, wrap(make_ast_from_literal(v), node, k)) for k, v in res.items()]) - else: - return [wrap(make_ast_from_literal(res_node), node, i) for i, res_node in enumerate(res)] - except TypeError: - pass - return None - - -# @magic_contract -def constant_dict(node, ctxt): - if isinstance(node, (ast.Name, ast.NameConstant, ast.Attribute)): - res = resolve_name_or_attribute(node, ctxt) - if hasattr(res, 'items'): - return dict(res.items()) - return None - - -@magic_contract -def resolve_name_or_attribute(node, ctxt): - """ - If the given name of attribute is defined in the current context, return its value. Else, returns the node - :param node: The node to try to resolve - :type node: AST - :param ctxt: The environment stack to use when running the check - :type ctxt: DictStack - :return: The object if the name was found, else the original node - :rtype: * - """ - if isinstance(node, ast.Name): - if node.id in ctxt: - try: - return ctxt[node.id] - except KeyError: - # This occurs if we know that the name was assigned, but we don't know what to... just return the node - return node - else: - return node - elif isinstance(node, ast.NameConstant): - return node.value - elif isinstance(node, ast.Attribute): - base_obj = resolve_name_or_attribute(node.value, ctxt) - if not isinstance(base_obj, ast.AST): - return getattr(base_obj, node.attr, node) - else: - return node - else: - return node - - -@magic_contract -def make_ast_from_literal(lit): - """ - Converts literals into their AST equivalent - :param lit: The literal to attempt to turn into an AST - :type lit: * - :return: The AST version of the literal, or the original AST node if one was given - :rtype: * - """ - if isinstance(lit, ast.AST): - return lit - elif isinstance(lit, (list, tuple)): - res = [make_ast_from_literal(e) for e in lit] - tp = ast.List if isinstance(lit, list) else ast.Tuple - return tp(elts=res) - elif isinstance(lit, num_types): - if isinstance(lit, float_types): - lit2 = float(lit) - else: - lit2 = int(lit) - if lit2 != lit: - raise AssertionError("({}){} != ({}){}".format(type(lit), lit, type(lit2), lit2)) - return ast.Num(lit2) - elif isinstance(lit, str): - return ast.Str(lit) - elif isinstance(lit, bool): - return ast.NameConstant(lit) - else: - # warnings.warn("'{}' of type {} is not able to be made into an AST node".format(lit, type(lit))) - return lit - - -@magic_contract -def is_wrappable(lit): - """ - Checks if the given object either is or can be made into a known AST node - :param lit: The object to try to wrap - :type lit: * - :return: Whether or not this object can be wrapped as an AST node - :rtype: bool - """ - return isinstance(make_ast_from_literal(lit), ast.AST) - - -@magic_contract -def _resolve_literal(node, ctxt): - """ - Collapses literal expressions. Returns literals if they're available, AST nodes otherwise - :param node: The AST node to be checked - :type node: AST - :param ctxt: The environment stack to use when running the check - :type ctxt: DictStack - :return: The given AST node with literal operations collapsed as much as possible - :rtype: * - """ - # try: - # print("Trying to collapse {}".format(astor.to_source(node))) - # except: - # print("Trying to collapse (source not possible) {}".format(astor.dump_tree(node))) - - if isinstance(node, (ast.Name, ast.Attribute, ast.NameConstant)): - res = resolve_name_or_attribute(node, ctxt) - if isinstance(res, ast.AST) and not isinstance(res, (ast.Name, ast.Attribute, ast.NameConstant)): - new_res = _resolve_literal(res, ctxt) - if is_wrappable(new_res): - # print("{} can be replaced by more specific literal {}".format(res, new_res)) - res = new_res - # else: - # print("{} is an AST node, but can't safely be made more specific".format(res)) - return res - elif isinstance(node, ast.Num): - return node.n - elif isinstance(node, ast.Str): - return node.s - elif isinstance(node, ast.Index): - return _resolve_literal(node.value, ctxt) - elif isinstance(node, (ast.Slice, ast.ExtSlice)): - raise NotImplementedError() - elif isinstance(node, ast.Subscript): - # print("Attempting to subscript {}".format(astor.to_source(node))) - lst = constant_iterable(node.value, ctxt) - # print("Can I subscript {}?".format(lst)) - if lst is None: - return node - slc = _resolve_literal(node.slice, ctxt) - # print("Getting subscript at {}".format(slc)) - if isinstance(slc, ast.AST): - return node - # print("Value at {}[{}] = {}".format(lst, slc, lst[slc])) - val = lst[slc] - if isinstance(val, ast.AST): - new_val = _resolve_literal(val, ctxt) - if is_wrappable(new_val): - # print("{} can be replaced by more specific literal {}".format(val, new_val)) - val = new_val - # else: - # print("{} is an AST node, but can't safely be made more specific".format(val)) - # print("Final value at {}[{}] = {}".format(lst, slc, val)) - return val - elif isinstance(node, ast.UnaryOp): - operand = _resolve_literal(node.operand, ctxt) - if isinstance(operand, ast.AST): - return node - else: - try: - return _collapse_map[type(node.op)](operand) - except: - warnings.warn( - "Unary op collapse failed. Collapsing skipped, but executing this function will likely fail." - " Error was:\n{}".format(traceback.format_exc())) - elif isinstance(node, ast.BinOp): - left = _resolve_literal(node.left, ctxt) - right = _resolve_literal(node.right, ctxt) - # print("({} {})".format(repr(node.op), ", ".join(repr(o) for o in operands))) - lliteral = not isinstance(left, ast.AST) - rliteral = not isinstance(right, ast.AST) - if lliteral and rliteral: - # print("Both operands {} and {} are literals, attempting to collapse".format(left, right)) - try: - return _collapse_map[type(node.op)](left, right) - except: - warnings.warn( - "Binary op collapse failed. Collapsing skipped, but executing this function will likely fail." - " Error was:\n{}".format(traceback.format_exc())) - return node - else: - left = make_ast_from_literal(left) - left = left if isinstance(left, ast.AST) else node.left - - right = make_ast_from_literal(right) - right = right if isinstance(right, ast.AST) else node.right - # print("Attempting to combine {} and {} ({} op)".format(left, right, node.op)) - return ast.BinOp(left=left, right=right, op=node.op) - elif isinstance(node, ast.Compare): - operands = [_resolve_literal(o, ctxt) for o in [node.left] + node.comparators] - if all(not isinstance(opr, ast.AST) for opr in operands): - return all(_collapse_map[type(cmp_func)](operands[i - 1], operands[i]) - for i, cmp_func in zip(range(1, len(operands)), node.ops)) - else: - return node - else: - return node - - -@magic_contract -def resolve_literal(node, ctxt, give_raw_result=False): - """ - Collapse literal expressions in the given node. Returns the node with the collapsed literals - :param node: The AST node to be checked - :type node: AST - :param ctxt: The environment stack to use when running the check - :type ctxt: DictStack - :return: The given AST node with literal operations collapsed as much as possible - :rtype: * - """ - result = _resolve_literal(node, ctxt) - if give_raw_result: - return result - result = make_ast_from_literal(result) - if not isinstance(result, ast.AST): - return node - return result diff --git a/miniutils/pragma/core/stack.py b/miniutils/pragma/core/stack.py deleted file mode 100644 index 3e6e4f4..0000000 --- a/miniutils/pragma/core/stack.py +++ /dev/null @@ -1,53 +0,0 @@ -class DictStack: - """ - Creates a stack of dictionaries to roughly emulate closures and variable environments - """ - - def __init__(self, *base): - import builtins - self.dicts = [dict(builtins.__dict__)] + [dict(d) for d in base] - self.constants = [True] + [False] * len(base) - - def __iter__(self): - return (key for dct in self.dicts for key in dct.keys()) - - def __setitem__(self, key, value): - # print("SETTING {} = {}".format(key, value)) - self.dicts[-1][key] = value - - def __getitem__(self, item): - for dct in self.dicts[::-1]: - if item in dct: - if dct[item] is None: - raise KeyError("Found '{}', but it was set to an unknown value".format(item)) - return dct[item] - raise KeyError("Can't find '{}' anywhere in the function's context".format(item)) - - def __delitem__(self, item): - for dct in self.dicts[::-1]: - if item in dct: - del dct[item] - return - raise KeyError() - - def __contains__(self, item): - return any(item == key for dct in self.dicts for key in dct.keys()) - - def items(self): - items = [] - for dct in self.dicts[::-1]: - for k, v in dct.items(): - if k not in items: - items.append((k, v)) - return items - - def keys(self): - return set().union(*[dct.keys() for dct in self.dicts]) - - def push(self, dct=None, is_constant=False): - self.dicts.append(dct or {}) - self.constants.append(is_constant) - - def pop(self): - self.constants.pop() - return self.dicts.pop() diff --git a/miniutils/pragma/core/transformer.py b/miniutils/pragma/core/transformer.py deleted file mode 100644 index b4dad76..0000000 --- a/miniutils/pragma/core/transformer.py +++ /dev/null @@ -1,234 +0,0 @@ -import ast -import copy -import inspect -import sys -import tempfile -import textwrap - -import astor - -from miniutils.magic_contract import magic_contract -from miniutils.opt_decorator import optional_argument_decorator -from .resolve import * -from .stack import DictStack - - -@magic_contract -def function_ast(f): - """ - Returns ast for the given function. Gives a tuple of (ast_module, function_body, function_file - :param f: The function to parse - :type f: function - :return: The relevant AST code: A module including only the function definition; the func body; the func file - :rtype: tuple(Module, list(AST), str) - """ - try: - f_file = sys.modules[f.__module__].__file__ - except (KeyError, AttributeError): - f_file = '' - - root = ast.parse(textwrap.dedent(inspect.getsource(f)), f_file) - return root, root.body[0].body, f_file - - -@magic_contract -def _assign_names(node): - """ - Gets names from a assign-to tuple in flat form, just to know what's affected - "x=3" -> "x" - "a,b=4,5" -> ["a", "b"] - "(x,(y,z)),(a,) = something" -> ["x", "y", "z", "a"] - - :param node: The AST node to resolve to a list of names - :type node: AST - :return: The flattened list of names referenced in this node - :rtype: iterable - """ - if isinstance(node, ast.Name): - yield node.id - elif isinstance(node, ast.Tuple): - for e in node.elts: - yield from _assign_names(e) - elif isinstance(node, ast.Subscript): - yield from _assign_names(node.value) - - -class DebugTransformerMixin: # pragma: nocover - def visit(self, node): - orig_node_code = astor.to_source(node).strip() - print("Starting to visit >> {} <<".format(orig_node_code)) - - new_node = super().visit(node) - - try: - if new_node is None: - print("Deleted >>> {} <<<".format(orig_node_code)) - elif isinstance(new_node, ast.AST): - print("Converted >>> {} <<< to >>> {} <<<".format(orig_node_code, astor.to_source(new_node).strip())) - elif isinstance(new_node, list): - print("Converted >>> {} <<< to [[[ {} ]]]".format(orig_node_code, ", ".join(astor.to_source(n).strip() for n in new_node))) - except Exception as ex: - raise AssertionError("Failed on {} >>> {}".format(orig_node_code, astor.dump_tree(new_node))) from ex - # print("Failed on {} >>> {}".format(astor.dump_tree(orig_node), astor.dump_tree(new_node))) - # return orig_node - - return new_node - - - -class TrackedContextTransformer(ast.NodeTransformer): - def __init__(self, ctxt=None): - self.ctxt = ctxt or DictStack() - super().__init__() - - def visit_many(self, nodes): - for n in nodes: - n = self.visit(n) - if n is not None: - if isinstance(n, ast.AST): - yield n - else: - yield from n - - def visit_Assign(self, node): - node.value = self.visit(node.value) - erase_targets = True - # print(node.value) - # TODO: Support tuple assignments - if len(node.targets) == 1: - if isinstance(node.targets[0], ast.Name): - nvalue = copy.deepcopy(node.value) - var = node.targets[0].id - val = constant_iterable(nvalue, self.ctxt) - if val is not None: - # print("Setting {} = {}".format(var, val)) - self.ctxt[var] = val - else: - val = resolve_literal(nvalue, self.ctxt) - # print("Setting {} = {}".format(var, val)) - self.ctxt[var] = val - erase_targets = False - # elif isinstance(node.targets[0], ast.Subscript): - # targ = node.targets[0] - # iterable = constant_iterable(targ.value, self.ctxt, False) - # if iterable is None: - # iterable = constant_dict(targ.value, self.ctxt) - # if iterable is None: - # return node - # key = resolve_literal(targ.slice, self.ctxt) - # if isinstance(key, ast.AST): - # return node - # - # nvalue = copy.deepcopy(node.value) - # val = constant_iterable(nvalue, self.ctxt) - # warnings.warn("Iterable assignment not fully implemented yet...") - # if val is not None: - # # print("Setting {} = {}".format(var, val)) - # iterable[key] = val - # else: - # val = resolve_literal(nvalue, self.ctxt) - # # print("Setting {} = {}".format(var, val)) - # iterable[key] = val - # erase_targets = False - - if erase_targets: - for targ in node.targets: - for assgn in _assign_names(targ): - self.ctxt[assgn] = None - return node - - def visit_AugAssign(self, node): - for assgn in _assign_names(node.target): - self.ctxt[assgn] = None - return super().generic_visit(node) - - def visit_Delete(self, node): - for targ in node.targets: - for assgn in _assign_names(targ): - del self.ctxt[assgn] - return super().generic_visit(node) - - def visit_FunctionDef(self, node): - self.ctxt.push({}, False) - node.body = list(self.visit_many(node.body)) - self.ctxt.pop() - return self.generic_visit(node) - - def visit_AsyncFunctionDef(self, node): - self.ctxt.push({}, False) - node.body = list(self.visit_many(node.body)) - self.ctxt.pop() - return self.generic_visit(node) - - def visit_ClassDef(self, node): - self.ctxt.push({}, False) - node.body = list(self.visit_many(node.body)) - self.ctxt.pop() - return self.generic_visit(node) - - - -def make_function_transformer(transformer_type, name, description, **transformer_kwargs): - @optional_argument_decorator - @magic_contract - def transform(return_source=False, save_source=True, function_globals=None, **kwargs): - """ - :param return_source: Returns the unrolled function's source code instead of compiling it - :type return_source: bool - :param save_source: Saves the function source code to a tempfile to make it inspectable - :type save_source: bool - :param function_globals: Overridden global name assignments to use when processing the function - :type function_globals: dict|None - :param kwargs: Any other environmental variables to provide during unrolling - :type kwargs: dict - :return: The unrolled function, or its source code if requested - :rtype: function - """ - - @magic_contract(f='function', returns='function|str') - def inner(f): - f_mod, f_body, f_file = function_ast(f) - # Grab function globals - glbls = f.__globals__ - # Grab function closure variables - if isinstance(f.__closure__, tuple): - glbls.update({k: v.cell_contents for k, v in zip(f.__code__.co_freevars, f.__closure__)}) - # Apply manual globals override - if function_globals is not None: - glbls.update(function_globals) - # print({k: v for k, v in glbls.items() if k not in globals()}) - trans = transformer_type(DictStack(glbls, kwargs), **transformer_kwargs) - f_mod.body[0].decorator_list = [] - f_mod = trans.visit(f_mod) - # print(astor.dump_tree(f_mod)) - if return_source or save_source: - try: - source = astor.to_source(f_mod) - except ImportError: # pragma: nocover - raise ImportError("miniutils.pragma.{name} requires 'astor' to be installed to obtain source code" - .format(name=name)) - except Exception as ex: # pragma: nocover - raise RuntimeError(astor.dump_tree(f_mod)) from ex - else: - source = None - - if return_source: - return source - else: - f_mod = ast.fix_missing_locations(f_mod) - if save_source: - temp = tempfile.NamedTemporaryFile('w', delete=True) - f_file = temp.name - exec(compile(f_mod, f_file, 'exec'), glbls) - func = glbls[f_mod.body[0].name] - if save_source: - func.__tempfile__ = temp - temp.write(source) - temp.flush() - return func - - return inner - - transform.__name__ = name - transform.__doc__ = '\n'.join([description, transform.__doc__]) - return transform diff --git a/miniutils/pragma/deindex.py b/miniutils/pragma/deindex.py deleted file mode 100644 index 3d20d88..0000000 --- a/miniutils/pragma/deindex.py +++ /dev/null @@ -1,33 +0,0 @@ -import ast - -from .collapse_literals import collapse_literals -from .. import magic_contract - - -# Directly reference elements of constant list, removing literal indexing into that list within a function -@magic_contract -def deindex(iterable, iterable_name, *args, **kwargs): - """ - :param iterable: The list to deindex in the target function - :type iterable: iterable - :param iterable_name: The list's name (must be unique if deindexing multiple lists) - :type iterable_name: str - :param args: Other command line arguments (see :func:`collapse_literals` for documentation) - :type args: tuple - :param kwargs: Any other environmental variables to provide during unrolling - :type kwargs: dict - :return: The unrolled function, or its source code if requested - :rtype: function - """ - - if hasattr(iterable, 'items'): # Support dicts and the like - internal_iterable = {k: '{}_{}'.format(iterable_name, k) for k, val in iterable.items()} - mapping = {internal_iterable[k]: val for k, val in iterable.items()} - raise NotImplementedError('Dictionary indices are not yet supported') - else: # Support lists, tuples, and the like - internal_iterable = {i: '{}_{}'.format(iterable_name, i) for i, val in enumerate(iterable)} - mapping = {internal_iterable[i]: val for i, val in enumerate(iterable)} - - kwargs[iterable_name] = {k: ast.Name(id=name, ctx=ast.Load()) for k, name in internal_iterable.items()} - - return collapse_literals(*args, function_globals=mapping, **kwargs) diff --git a/miniutils/pragma/inline.py b/miniutils/pragma/inline.py deleted file mode 100644 index 62affe7..0000000 --- a/miniutils/pragma/inline.py +++ /dev/null @@ -1,413 +0,0 @@ -from .core import * -from .. import magic_contract -from collections import OrderedDict as odict - - -# stmt = FunctionDef(identifier name, arguments args, -# stmt* body, expr* decorator_list, expr? returns) -# | AsyncFunctionDef(identifier name, arguments args, -# stmt* body, expr* decorator_list, expr? returns) -# -# | ClassDef(identifier name, -# expr* bases, -# keyword* keywords, -# stmt* body, -# expr* decorator_list) -# | Return(expr? value) -# -# | Delete(expr* targets) -# | Assign(expr* targets, expr value) -# | AugAssign(expr target, operator op, expr value) -# -- 'simple' indicates that we annotate simple name without parens -# | AnnAssign(expr target, expr annotation, expr? value, int simple) -# -# -- use 'orelse' because else is a keyword in target languages -# | For(expr target, expr iter, stmt* body, stmt* orelse) -# | AsyncFor(expr target, expr iter, stmt* body, stmt* orelse) -# | While(expr test, stmt* body, stmt* orelse) -# | If(expr test, stmt* body, stmt* orelse) -# | With(withitem* items, stmt* body) -# | AsyncWith(withitem* items, stmt* body) -# -# | Raise(expr? exc, expr? cause) -# | Try(stmt* body, excepthandler* handlers, stmt* orelse, stmt* finalbody) -# | Assert(expr test, expr? msg) -# -# | Import(alias* names) -# | ImportFrom(identifier? module, alias* names, int? level) -# -# | Global(identifier* names) -# | Nonlocal(identifier* names) -# | Expr(expr value) -# | Pass | Break | Continue -# -# -- XXX Jython will be different -# -- col_offset is the byte offset in the utf8 string the parser uses -# attributes (int lineno, int col_offset) - -DICT_FMT = "_{fname}_{n}" - -# @magic_contract -def make_name(fname, var, n, ctx=ast.Load): - """ - Create an AST node to represent the given argument name in the given function - :param fname: Function name - :type fname: str - :param var: Argument name - :type var: str - :param ctx: Context of this name (LOAD or STORE) - :type ctx: Load|Store - :param n: The number to append to this name (to allow for finite recursion) - :type n: int - :param fmt: Name format (if not stored in a dictionary) - :type fmt: str - :return: An AST node representing this argument - :rtype: Subscript|Call - """ - return ast.Subscript(value=ast.Name(id=DICT_FMT.format(fname=fname, n=n), ctx=ast.Load()), - slice=ast.Index(ast.Str(var)), - ctx=ctx()) - - -class _InlineBodyTransformer(TrackedContextTransformer): - def __init__(self, func_name, param_names, n): - self.func_name = func_name - # print("Func {} takes parameters {}".format(func_name, param_names)) - self.param_names = param_names - self.in_break_block = False - self.n = n - self.had_return = False - self.had_yield = False - super().__init__() - - def visit_Name(self, node): - # Check if this is a parameter, and hasn't had another value assigned to it - if node.id in self.param_names: - # print("Found parameter reference {}".format(node.id)) - if node.id not in self.ctxt: - # If so, get its value from the argument dictionary - return make_name(self.func_name, node.id, self.n, ctx=type(getattr(node, 'ctx', ast.Load()))) - else: - # print("But it's been overwritten to {} = {}".format(node.id, self.ctxt[node.id])) - pass - return node - - def visit_Return(self, node): - if self.in_break_block: - raise NotImplementedError("miniutils.pragma.inline cannot handle returns from within a loop") - result = [] - if node.value: - result.append(ast.Assign(targets=[make_name(self.func_name, 'return', self.n, ctx=ast.Store)], - value=self.visit(node.value))) - result.append(ast.Break()) - self.had_return = True - return result - - def visit_Yield(self, node): - self.had_yield = True - if node.value: - return ast.Call(func=ast.Attribute(value=make_name(self.func_name, 'yield', self.n, ctx=ast.Load), - attr='append', - ctx=ast.Load), - args=[self.visit(node.value)], - keywords=[]) - return node - - def visit_YieldFrom(self, node): - self.had_yield = True - return ast.Call(func=ast.Attribute(value=make_name(self.func_name, 'yield', self.n, ctx=ast.Load), - attr='extend', - ctx=ast.Load), - args=[self.visit(node.value)], - keywords=[]) - - def visit_For(self, node): - orig_in_break_block = self.in_break_block - self.in_break_block = True - res = self.generic_visit(node) - self.in_break_block = orig_in_break_block - return res - - def visit_While(self, node): - orig_in_break_block = self.in_break_block - self.in_break_block = True - res = self.generic_visit(node) - self.in_break_block = orig_in_break_block - return res - - def visit_FunctionDef(self, node): - return node - - def visit_AsyncFunctionDef(self, node): - return node - - def visit_ClassDef(self, node): - return node - - -class InlineTransformer(TrackedContextTransformer): - def __init__(self, *args, funs=None, max_depth=1, **kwargs): - assert funs is not None - super().__init__(*args, **kwargs) - - self.funs = funs - self.code_blocks = [] - self.max_depth = max_depth - - def nested_visit(self, nodes): - """When we visit a block of statements, create a new "code block" and push statements into it""" - lst = [] - self.code_blocks.append(lst) - for n in nodes: - res = self.visit(n) - if res is None: - continue - elif isinstance(res, list): - lst += res - else: - lst.append(res) - self.code_blocks.pop() - return lst - - def generic_visit_less(self, node, *without): - for field, old_value in ast.iter_fields(node): - if field in without: - continue - elif isinstance(old_value, list): - new_values = [] - for value in old_value: - if isinstance(value, ast.AST): - value = self.visit(value) - if value is None: - continue - elif not isinstance(value, ast.AST): - new_values.extend(value) - continue - new_values.append(value) - old_value[:] = new_values - elif isinstance(old_value, ast.AST): - new_node = self.visit(old_value) - if new_node is None: - delattr(node, field) - else: - setattr(node, field, new_node) - return node - - def visit_Call(self, node): - """When we see a function call, insert the function body into the current code block, then replace the call - with the return expression """ - node = self.generic_visit(node) - node_fun = resolve_name_or_attribute(resolve_literal(node.func, self.ctxt), self.ctxt) - - for (fun, fname, fsig, fbody) in self.funs: - if fun != node_fun: - continue - - n = 0 - for i in range(self.max_depth): - args_dict_name = DICT_FMT.format(fname=fname, n=i) - n = i # This is redundant, but a bit clearer and safer than just referencing i later - if args_dict_name not in self.ctxt: - break - else: - warnings.warn("Inline hit recursion limit, using normal function call") - return node - - func_for_inlining = _InlineBodyTransformer(fname, fsig.parameters, n) - fbody = list(func_for_inlining.visit_many(copy.deepcopy(fbody))) - - print(self.code_blocks) - cur_block = self.code_blocks[-1] - new_code = [] - - # Load arguments into their appropriate variables - args = node.args - flattened_args = [] - for a in args: - if isinstance(a, ast.Starred): - a = constant_iterable(a.value, self.ctxt) - if a: - flattened_args.extend(a) - else: - warnings.warn("Cannot inline function call that uses non-constant star args") - return node - else: - flattened_args.append(a) - - keywords = [(kw.arg, kw.value) for kw in node.keywords if kw.arg is not None] - kw_dict = [kw.value for kw in node.keywords if kw.arg is None] - kw_dict = kw_dict[0] if kw_dict else None - - bound_args = fsig.bind(*flattened_args, **odict(keywords)) - bound_args.apply_defaults() - - # Create args dictionary - final_args = [] - - for arg_name, arg_value in bound_args.arguments.items(): - if isinstance(arg_value, tuple): - arg_value = ast.Tuple(elts=list(arg_value), ctx=ast.Load()) - elif isinstance(arg_value, dict): - keys, values = zip(*list(arg_value.items())) - keys = [ast.Str(k) for k in keys] - values = list(values) - arg_value = ast.Dict(keys=keys, values=values) - # fun_name['param_name'] = param_value - final_args.append((arg_name, arg_value)) - - if kw_dict: - final_args.append((None, kw_dict)) - - if func_for_inlining.had_yield: - final_args.append(('yield', ast.List(elts=[]))) - - # fun_name = {} - dict_call = ast.Call( - func=ast.Name(id='dict', ctx=ast.Load()), - args=[], - keywords=[ast.keyword(arg=name, value=val) for name, val in final_args] - ) - new_code.append(ast.Assign( - targets=[ast.Name(id=args_dict_name, ctx=ast.Store())], - value=dict_call - )) - - # Process assignments before resolving body - cur_block.extend(self.visit_many(new_code)) - - # Inline function code - new_body = list(self.visit_many(fbody)) - - # cur_block.append(self.visit(ast.For(target=ast.Name(id='____', ctx=ast.Store()), - # iter=ast.List(elts=[ast.NameConstant(None)], ctx=ast.Load()), - # body=new_body, - # orelse=[]))) - cur_block.append(ast.For(target=ast.Name(id='____', ctx=ast.Store()), - iter=ast.List(elts=[ast.NameConstant(None)], ctx=ast.Load()), - body=new_body, - orelse=[])) - - # fun_name['return'] - if func_for_inlining.had_yield or func_for_inlining.had_return: - for j in range(100000): - output_name = DICT_FMT.format(fname=fname + '_return', n=j) - if output_name not in self.ctxt: - break - else: - raise RuntimeError("Function {} called and returned too many times during inlining, not able to " - "put the return value into a uniquely named variable".format(fname)) - - if func_for_inlining.had_yield: - cur_block.append(self.visit(ast.Assign(targets=[ast.Name(id=output_name, ctx=ast.Store())], - value=make_name(fname, 'yield', n)))) - elif func_for_inlining.had_return: - get_call = ast.Call( - func=ast.Attribute( - value=ast.Name(id=args_dict_name, ctx=ast.Load()), - attr='get', - ctx=ast.Load()), - args=[ast.Str('return'), ast.NameConstant(None)], - keywords=[] - ) - cur_block.append(self.visit(ast.Assign(targets=[ast.Name(id=output_name, ctx=ast.Store())], - value=get_call))) - - return_node = ast.Name(id=output_name, ctx=ast.Load()) - else: - return_node = ast.NameConstant(None) - - cur_block.append(self.visit(ast.Delete(targets=[ast.Name(id=args_dict_name, ctx=ast.Del())]))) - return return_node - - else: - return node - - - - ################################################### - # From here on down, we just have handlers for ever AST node that has a "code block" (stmt*) - ################################################### - - def visit_FunctionDef(self, node): - self.ctxt.push({}, False) - node.body = self.nested_visit(node.body) - self.ctxt.pop() - return self.generic_visit_less(node, 'body') - - def visit_AsyncFunctionDef(self, node): - self.ctxt.push({}, False) - node.body = self.nested_visit(node.body) - self.ctxt.pop() - return self.generic_visit_less(node, 'body') - - def visit_ClassDef(self, node): - self.ctxt.push({}, False) - node.body = self.nested_visit(node.body) - self.ctxt.pop() - return self.generic_visit_less(node, 'body') - - def visit_For(self, node): - node.body = self.nested_visit(node.body) - node.orelse = self.nested_visit(node.orelse) - return self.generic_visit_less(node, 'body', 'orelse') - - def visit_AsyncFor(self, node): - node.body = self.nested_visit(node.body) - node.orelse = self.nested_visit(node.orelse) - return self.generic_visit_less(node, 'body', 'orelse') - - def visit_While(self, node): - node.body = self.nested_visit(node.body) - node.orelse = self.nested_visit(node.orelse) - return self.generic_visit_less(node, 'body', 'orelse') - - def visit_If(self, node): - node.body = self.nested_visit(node.body) - node.orelse = self.nested_visit(node.orelse) - return self.generic_visit_less(node, 'body', 'orelse') - - def visit_With(self, node): - node.body = self.nested_visit(node.body) - return self.generic_visit_less(node, 'body') - - def visit_AsyncWith(self, node): - node.body = self.nested_visit(node.body) - return self.generic_visit_less(node, 'body') - - def visit_Try(self, node): - node.body = self.nested_visit(node.body) - node.orelse = self.nested_visit(node.orelse) - node.finalbody = self.nested_visit(node.finalbody) - return self.generic_visit_less(node, 'body', 'orelse', 'finalbody') - - def visit_Module(self, node): - node.body = self.nested_visit(node.body) - return self.generic_visit_less(node, 'body') - - def visit_ExceptHandler(self, node): - node.body = self.nested_visit(node.body) - return self.generic_visit_less(node, 'body') - - -# @magic_contract -def inline(*funs_to_inline, max_depth=1, **kwargs): - """ - :param funs_to_inline: The inner called function that should be inlined in the wrapped function - :type funs_to_inline: tuple(function) - :param max_depth: The maximum number of times to inline the provided function (limits recursion) - :type max_depth: int - :return: The unrolled function, or its source code if requested - :rtype: function - """ - funs = [] - for fun_to_inline in funs_to_inline: - fname = fun_to_inline.__name__ - fsig = inspect.signature(fun_to_inline) - _, fbody, _ = function_ast(fun_to_inline) - - funs.append((fun_to_inline, fname, fsig, fbody)) - - return make_function_transformer(InlineTransformer, - 'inline', - 'Inline the specified function within the decorated function', - funs=funs, max_depth=max_depth)(**kwargs) diff --git a/miniutils/pragma/unroll.py b/miniutils/pragma/unroll.py deleted file mode 100644 index 26763f3..0000000 --- a/miniutils/pragma/unroll.py +++ /dev/null @@ -1,78 +0,0 @@ -import copy - -from .core import * - - -def has_break(node): - for field, value in ast.iter_fields(node): - if isinstance(value, list): - for item in value: - if isinstance(item, (ast.Break, ast.Continue)): - return True - if isinstance(item, ast.AST): - if has_break(item): - return True - elif isinstance(value, ast.AST): - if has_break(value): - return True - return False - -# noinspection PyPep8Naming -class UnrollTransformer(TrackedContextTransformer): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.loop_vars = set() - - def visit_For(self, node): - top_level_break = False - for n in node.body: - if isinstance(n, ast.Break): - top_level_break = True - # We don't need to check if there's a break in an inner loop, since that doesn't affect this loop - elif isinstance(n, ast.If) and has_break(n): - # If there's a conditional break, there's not much we can do about that - # TODO: If the conditional is resolvable at unroll time, then do so - return self.generic_visit(node) - - iterable = constant_iterable(node.iter, self.ctxt) - if iterable is None: - return self.generic_visit(node) - - result = [] - loop_var = node.target.id - orig_loop_vars = self.loop_vars - # print("Unrolling 'for {} in {}'".format(loop_var, list(iterable))) - for val in iterable: - self.ctxt.push({loop_var: val}) - self.loop_vars = orig_loop_vars | {loop_var} - for body_node in copy.deepcopy(node.body): - res = self.visit(body_node) - if isinstance(res, list): - result.extend(res) - elif res is None: - continue - else: - result.append(res) - # result.extend([self.visit(body_node) for body_node in copy.deepcopy(node.body)]) - self.ctxt.pop() - if top_level_break: - first_result = result - result = [] - for n in first_result: - if isinstance(n, ast.Break): - break - result.append(n) - break - self.loop_vars = orig_loop_vars - return result - - def visit_Name(self, node): - if node.id in self.loop_vars: - if node.id in self.ctxt: - return self.ctxt[node.id] - raise NameError("'{}' not defined in context".format(node.id)) - return node - - -# Unroll literal loops -unroll = make_function_transformer(UnrollTransformer, 'unroll', "Unrolls constant loops in the decorated function") diff --git a/stress_tests/test_pragma.py b/stress_tests/test_pragma.py deleted file mode 100644 index 2213269..0000000 --- a/stress_tests/test_pragma.py +++ /dev/null @@ -1,48 +0,0 @@ -from unittest import TestCase -from miniutils import pragma -from textwrap import dedent -import inspect - - -class PragmaTest(TestCase): - def setUp(self): - pass - # # This is a quick hack to disable contracts for testing if needed - # import contracts - # contracts.enable_all() - - -class TestUnroll(PragmaTest): - def test_unroll_long_range(self): - @pragma.unroll - def f(): - for i in range(3): - yield i - - self.assertEqual(list(f()), [0, 1, 2]) - - -class TestCollapseLiterals(PragmaTest): - pass - - -class TestDeindex(PragmaTest): - pass - - -class TestInline(PragmaTest): - def test_recursive(self): - def fib(n): - if n <= 0: - return 1 - elif n == 1: - return 1 - else: - return fib(n-1) + fib(n-2) - - from miniutils import tic - toc = tic() - fib_code = pragma.inline(fib, max_depth=1, return_source=True)(fib) - fib_code = pragma.inline(fib, max_depth=2, return_source=True)(fib) - fib_code = pragma.inline(fib, max_depth=3, return_source=True)(fib) - #fib_code = pragma.inline(fib, max_depth=4, return_source=True)(fib) diff --git a/tests/test_pragma.py b/tests/test_pragma.py deleted file mode 100644 index bd4af5a..0000000 --- a/tests/test_pragma.py +++ /dev/null @@ -1,887 +0,0 @@ -from unittest import TestCase -from miniutils import pragma -from textwrap import dedent -import inspect - - -class PragmaTest(TestCase): - def setUp(self): - pass - # # This is a quick hack to disable contracts for testing if needed - # import contracts - # contracts.enable_all() - - -class TestUnroll(PragmaTest): - def test_unroll_range(self): - @pragma.unroll - def f(): - for i in range(3): - yield i - - self.assertEqual(list(f()), [0, 1, 2]) - - def test_unroll_various(self): - g = lambda: None - g.a = [1, 2, 3] - g.b = 6 - - @pragma.unroll(return_source=True) - def f(x): - y = 5 - a = range(3) - b = [1, 2, 4] - c = (1, 2, 5) - d = reversed(a) - e = [x, x, x] - f = [y, y, y] - for i in a: - yield i - for i in b: - yield i - for i in c: - yield i - for i in d: - yield i - for i in e: - yield i - for i in f: - yield i - for i in g.a: - yield i - for i in [g.b + 0, g.b + 1, g.b + 2]: - yield i - - result = dedent(''' - def f(x): - y = 5 - a = range(3) - b = [1, 2, 4] - c = 1, 2, 5 - d = reversed(a) - e = [x, x, x] - f = [y, y, y] - yield 0 - yield 1 - yield 2 - yield 1 - yield 2 - yield 4 - yield 1 - yield 2 - yield 5 - for i in d: - yield i - yield x - yield x - yield x - yield 5 - yield 5 - yield 5 - yield 1 - yield 2 - yield 3 - yield 6 - yield 7 - yield 8 - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_unroll_const_list(self): - @pragma.unroll - def f(): - for i in [1, 2, 4]: - yield i - - self.assertEqual(list(f()), [1, 2, 4]) - - def test_unroll_const_tuple(self): - @pragma.unroll - def f(): - for i in (1, 2, 4): - yield i - - self.assertEqual(list(f()), [1, 2, 4]) - - def test_unroll_range_source(self): - @pragma.unroll(return_source=True) - def f(): - for i in range(3): - yield i - - result = dedent(''' - def f(): - yield 0 - yield 1 - yield 2 - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_unroll_list_source(self): - @pragma.unroll(return_source=True) - def f(): - for i in [1, 2, 4]: - yield i - - result = dedent(''' - def f(): - yield 1 - yield 2 - yield 4 - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_unroll_dyn_list_source(self): - @pragma.unroll(return_source=True) - def f(): - x = 3 - a = [x, x, x] - for i in a: - yield i - x = 4 - a = [x, x, x] - for i in a: - yield i - - result = dedent(''' - def f(): - x = 3 - a = [x, x, x] - yield 3 - yield 3 - yield 3 - x = 4 - a = [x, x, x] - yield 4 - yield 4 - yield 4 - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_unroll_dyn_list(self): - def summation(x=0): - a = [x, x, x] - v = 0 - for _a in a: - v += _a - return v - - summation_source = pragma.unroll(return_source=True)(summation) - summation = pragma.unroll(summation) - - code = dedent(''' - def summation(x=0): - a = [x, x, x] - v = 0 - v += x - v += x - v += x - return v - ''') - self.assertEqual(summation_source.strip(), code.strip()) - self.assertEqual(summation(), 0) - self.assertEqual(summation(1), 3) - self.assertEqual(summation(5), 15) - - def test_unroll_2range_source(self): - @pragma.unroll(return_source=True) - def f(): - for i in range(3): - for j in range(3): - yield i + j - - result = dedent(''' - def f(): - yield 0 + 0 - yield 0 + 1 - yield 0 + 2 - yield 1 + 0 - yield 1 + 1 - yield 1 + 2 - yield 2 + 0 - yield 2 + 1 - yield 2 + 2 - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_unroll_2list_source(self): - @pragma.unroll(return_source=True) - def f(): - for i in [[1, 2, 3], [4, 5], [6]]: - for j in i: - yield j - - result = dedent(''' - def f(): - yield 1 - yield 2 - yield 3 - yield 4 - yield 5 - yield 6 - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_external_definition(self): - # Known bug: this works when defined as a kwarg, but not as an external variable, but ONLY in unittests... - # External variables work in practice - @pragma.unroll(return_source=True, a=range) - def f(): - for i in a(3): - print(i) - - result = dedent(''' - def f(): - print(0) - print(1) - print(2) - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_tuple_assign(self): - # This is still early code, so just make sure that it recognizes when a name is assigned to... we don't get values yet - # TODO: Implement tuple assignment - @pragma.unroll(return_source=True) - def f(): - x = 3 - ((y, x), z) = ((1, 2), 3) - for i in [x,x,x]: - print(i) - - result = dedent(''' - def f(): - x = 3 - (y, x), z = (1, 2), 3 - print(x) - print(x) - print(x) - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_top_break(self): - @pragma.unroll(return_source=True) - def f(): - for i in range(10): - print(i) - break - - result = dedent(''' - def f(): - print(0) - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_inner_break(self): - @pragma.unroll(return_source=True) - def f(y): - for i in range(10): - print(i) - if i == y: - break - - result = dedent(''' - def f(y): - for i in range(10): - print(i) - if i == y: - break - ''') - self.assertEqual(f.strip(), result.strip()) - - -class TestCollapseLiterals(PragmaTest): - def test_full_run(self): - def f(y): - x = 3 - r = 1 + x - for z in range(2): - r *= 1 + 2 * 3 - for abc in range(x): - for a in range(abc): - for b in range(y): - r += 1 + 2 + y - return r - - deco_f = pragma.collapse_literals(f) - self.assertEqual(f(0), deco_f(0)) - self.assertEqual(f(1), deco_f(1)) - self.assertEqual(f(5), deco_f(5)) - self.assertEqual(f(-1), deco_f(-1)) - - deco_f = pragma.collapse_literals(pragma.unroll(f)) - self.assertEqual(f(0), deco_f(0)) - self.assertEqual(f(1), deco_f(1)) - self.assertEqual(f(5), deco_f(5)) - self.assertEqual(f(-1), deco_f(-1)) - - def test_basic(self): - @pragma.collapse_literals(return_source=True) - def f(): - return 1 + 1 - - result = dedent(''' - def f(): - return 2 - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_vars(self): - @pragma.collapse_literals(return_source=True) - def f(): - x = 3 - y = 2 - return x + y - - result = dedent(''' - def f(): - x = 3 - y = 2 - return 5 - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_partial(self): - @pragma.collapse_literals(return_source=True) - def f(y): - x = 3 - return x + 2 + y - - result = dedent(''' - def f(y): - x = 3 - return 5 + y - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_constant_index(self): - @pragma.collapse_literals(return_source=True) - def f(): - x = [1,2,3] - return x[0] - - result = dedent(''' - def f(): - x = [1, 2, 3] - return 1 - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_with_unroll(self): - @pragma.collapse_literals(return_source=True) - @pragma.unroll - def f(): - for i in range(3): - print(i + 2) - - result = dedent(''' - def f(): - print(2) - print(3) - print(4) - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_with_objects(self): - @pragma.collapse_literals(return_source=True) - def f(): - v = [object(), object()] - return v[0] - - result = dedent(''' - def f(): - v = [object(), object()] - return v[0] - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_invalid_collapse(self): - import warnings - warnings.resetwarnings() - with warnings.catch_warnings(record=True) as w: - @pragma.collapse_literals - def f(): - return 1 + "2" - - self.assertIsInstance(w[-1].category(), UserWarning) - - warnings.resetwarnings() - with warnings.catch_warnings(record=True) as w: - @pragma.collapse_literals - def f(): - return -"5" - - self.assertIsInstance(w[-1].category(), UserWarning) - - # TODO: implement the features to get this test to work - # def test_conditional_erasure(self): - # @pragma.collapse_literals(return_source=True) - # def f(y): - # x = 0 - # if y == x: - # x = 1 - # return x - # - # result = dedent(''' - # def f(y): - # x = 0 - # if y == 0: - # x = 1 - # return x - # ''') - # self.assertEqual(f.strip(), result.strip()) - - def test_constant_conditional_erasure(self): - @pragma.collapse_literals(return_source=True) - def f(y): - x = 0 - if x <= 0: - x = 1 - return x - - result = dedent(''' - def f(y): - x = 0 - x = 1 - return 1 - ''') - self.assertEqual(f.strip(), result.strip()) - - def fn(): - if x == 0: - x = 'a' - elif x == 1: - x = 'b' - else: - x = 'c' - return x - - result0 = dedent(''' - def fn(): - x = 'a' - return 'a' - ''') - result1 = dedent(''' - def fn(): - x = 'b' - return 'b' - ''') - result2 = dedent(''' - def fn(): - x = 'c' - return 'c' - ''') - self.assertEqual(pragma.collapse_literals(return_source=True, x=0)(fn).strip(), result0.strip()) - self.assertEqual(pragma.collapse_literals(return_source=True, x=1)(fn).strip(), result1.strip()) - self.assertEqual(pragma.collapse_literals(return_source=True, x=2)(fn).strip(), result2.strip()) - - def test_unary(self): - @pragma.collapse_literals(return_source=True) - def f(): - return 1 + -5 - - result = dedent(''' - def f(): - return -4 - ''') - self.assertEqual(f.strip(), result.strip()) - - -class TestDeindex(PragmaTest): - def test_with_literals(self): - v = [1, 2, 3] - @pragma.collapse_literals(return_source=True) - @pragma.deindex(v, 'v') - def f(): - return v[0] + v[1] + v[2] - - result = dedent(''' - def f(): - return 6 - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_with_objects(self): - v = [object(), object(), object()] - @pragma.deindex(v, 'v', return_source=True) - def f(): - return v[0] + v[1] + v[2] - - result = dedent(''' - def f(): - return v_0 + v_1 + v_2 - ''') - self.assertEqual(result.strip(), f.strip()) - - def test_with_unroll(self): - v = [None, None, None] - - @pragma.deindex(v, 'v', return_source=True) - @pragma.unroll(lv=len(v)) - def f(): - for i in range(lv): - yield v[i] - - result = dedent(''' - def f(): - yield v_0 - yield v_1 - yield v_2 - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_with_literals_run(self): - v = [1, 2, 3] - @pragma.collapse_literals - @pragma.deindex(v, 'v') - def f(): - return v[0] + v[1] + v[2] - - self.assertEqual(f(), sum(v)) - - def test_with_objects_run(self): - v = [object(), object(), object()] - @pragma.deindex(v, 'v') - def f(): - return v[0] - - self.assertEqual(f(), v[0]) - - def test_with_variable_indices(self): - v = [object(), object(), object()] - @pragma.deindex(v, 'v', return_source=True) - def f(x): - yield v[0] - yield v[x] - - result = dedent(''' - def f(x): - yield v_0 - yield v[x] - ''') - self.assertEqual(f.strip(), result.strip()) - - # Not yet supported - def test_dict(self): - d = {'a': 1, 'b': 2} - - def f(x): - yield d['a'] - yield d[x] - - self.assertRaises(NotImplementedError, pragma.deindex, d, 'd') - # result = dedent(''' - # def f(x): - # yield v_a - # yield v[x] - # ''') - # self.assertEqual(f.strip(), result.strip()) - - def test_dynamic_function_calls(self): - funcs = [lambda x: x, lambda x: x ** 2, lambda x: x ** 3] - - # TODO: Support enumerate transparently - # TODO: Support tuple assignment in loop transparently - - @pragma.deindex(funcs, 'funcs') - @pragma.unroll(lf=len(funcs)) - def run_func(i, x): - for j in range(lf): - if i == j: - return funcs[j](x) - - self.assertEqual(run_func(0, 5), 5) - self.assertEqual(run_func(1, 5), 25) - self.assertEqual(run_func(2, 5), 125) - - result = dedent(''' - def run_func(i, x): - if i == 0: - return funcs_0(x) - if i == 1: - return funcs_1(x) - if i == 2: - return funcs_2(x) - ''') - self.assertEqual(inspect.getsource(run_func).strip(), result.strip()) - - -class TestInline(PragmaTest): - def test_basic(self): - def g(x): - return x**2 - - @pragma.inline(g, return_source=True) - def f(y): - return g(y + 3) - - result = dedent(''' - def f(y): - _g_0 = dict(x=y + 3) - for ____ in [None]: - _g_0['return'] = _g_0['x'] ** 2 - break - _g_return_0 = _g_0.get('return', None) - del _g_0 - return _g_return_0 - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_basic_run(self): - def g(x): - return x**2 - - @pragma.inline(g) - def f(y): - return g(y + 3) - - self.assertEqual(f(1), ((1 + 3) ** 2)) - - def test_basic_unroll(self): - def g(x): - return x**2 - - @pragma.unroll(return_source=True) - @pragma.inline(g) - def f(y): - return g(y + 3) - - result = dedent(''' - def f(y): - _g_0 = dict(x=y + 3) - _g_0['return'] = _g_0['x'] ** 2 - _g_return_0 = _g_0.get('return', None) - del _g_0 - return _g_return_0 - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_more_complex(self): - def g(x, *args, y, **kwargs): - print("X = {}".format(x)) - for i, a in enumerate(args): - print("args[{}] = {}".format(i, a)) - print("Y = {}".format(y)) - for k, v in kwargs.items(): - print("{} = {}".format(k, v)) - - def f(): - g(1, 2, 3, 4, y=5, z=6, w=7) - - result = dedent(''' - def f(): - _g_0 = dict(x=1, args=(2, 3, 4), y=5, kwargs={'z': 6, 'w': 7}) - for ____ in [None]: - print('X = {}'.format(_g_0['x'])) - for i, a in enumerate(_g_0['args']): - print('args[{}] = {}'.format(i, a)) - print('Y = {}'.format(_g_0['y'])) - for k, v in _g_0['kwargs'].items(): - print('{} = {}'.format(k, v)) - del _g_0 - None - ''') - self.assertEqual(pragma.inline(g, return_source=True)(f).strip(), result.strip()) - - self.assertEqual(f(), pragma.inline(g)(f)()) - - def test_recursive(self): - def fib(n): - if n <= 0: - return 1 - elif n == 1: - return 1 - else: - return fib(n-1) + fib(n-2) - - from miniutils import tic - toc = tic() - fib_code = pragma.inline(fib, max_depth=1, return_source=True)(fib) - toc("Inlined recursive function to depth 1") - print(fib_code) - # fib_code = pragma.inline(fib, max_depth=3, return_source=True)(fib) - # toc("Inlined recursive function to depth 3") - # print(fib_code) - - fib = pragma.inline(fib, max_depth=2)(fib) - toc("Inlined executable function") - self.assertEqual(fib(0), 1) - toc("Ran fib(0)") - self.assertEqual(fib(1), 1) - toc("Ran fib(1)") - self.assertEqual(fib(2), 2) - toc("Ran fib(2)") - self.assertEqual(fib(3), 3) - toc("Ran fib(3)") - self.assertEqual(fib(4), 5) - toc("Ran fib(4)") - self.assertEqual(fib(5), 8) - toc("Ran fib(5)") - - # def test_failure_cases(self): - # def g_for(x): - # for i in range(5): - # yield x - # - # def f(y): - # return g_for(y) - # - # self.assertRaises(AssertionError, pragma.inline(g_for), f) - - def test_flip_flop(self): - def g(x): - return f(x / 2) - - def f(y): - if y <= 0: - return 0 - return g(y - 1) - - f_code = pragma.inline(g, return_source=True)(f) - - result = dedent(''' - def f(y): - if y <= 0: - return 0 - _g_0 = dict(x=y - 1) - for ____ in [None]: - _g_0['return'] = f(_g_0['x'] / 2) - break - _g_return_0 = _g_0.get('return', None) - del _g_0 - return _g_return_0 - ''') - self.assertEqual(f_code.strip(), result.strip()) - - f_unroll_code = pragma.unroll(return_source=True)(pragma.inline(g)(f)) - - result_unroll = dedent(''' - def f(y): - if y <= 0: - return 0 - _g_0 = dict(x=y - 1) - _g_0['return'] = f(_g_0['x'] / 2) - _g_return_0 = _g_0.get('return', None) - del _g_0 - return _g_return_0 - ''') - self.assertEqual(f_unroll_code.strip(), result_unroll.strip()) - - f2_code = pragma.inline(f, g, return_source=True, f=f)(f) - - result2 = dedent(''' - def f(y): - if y <= 0: - return 0 - _g_0 = dict(x=y - 1) - _f_0 = dict(y=_g_0['x'] / 2) - for ____ in [None]: - if _f_0['y'] <= 0: - _f_0['return'] = 0 - break - _f_0['return'] = g(_f_0['y'] - 1) - break - _f_return_0 = _f_0.get('return', None) - del _f_0 - for ____ in [None]: - _g_0['return'] = _f_return_0 - break - _g_return_0 = _g_0.get('return', None) - del _g_0 - return _g_return_0 - ''') - print(f2_code) - self.assertEqual(f2_code.strip(), result2.strip()) - - def test_generator(self): - def g(y): - for i in range(y): - yield i - yield from range(y) - - @pragma.inline(g, return_source=True) - def f(x): - return sum(g(x)) - - result = dedent(''' - def f(x): - _g_0 = dict(y=x, yield=[]) - for ____ in [None]: - for i in range(_g_0['y']): - _g_0['yield'].append(i) - _g_0['yield'].extend(range(_g_0['y'])) - _g_return_0 = _g_0['yield'] - del _g_0 - return sum(_g_return_0) - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_variable_starargs(self): - def g(a, b, c): - return a + b + c - - @pragma.inline(g, return_source=True) - def f(x): - return g(*x) - - result = dedent(''' - def f(x): - return g(*x) - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_multiple_inline(self): - def a(x): - return x ** 2 - - def b(x): - return x + 2 - - @pragma.unroll(return_source=True) - @pragma.inline(a, b) - def f(x): - return a(x) + b(x) - - result = dedent(''' - def f(x): - _a_0 = dict(x=x) - _a_0['return'] = _a_0['x'] ** 2 - _a_return_0 = _a_0.get('return', None) - del _a_0 - _b_0 = dict(x=x) - _b_0['return'] = _b_0['x'] + 2 - _b_return_0 = _b_0.get('return', None) - del _b_0 - return _a_return_0 + _b_return_0 - ''') - self.assertEqual(f.strip(), result.strip()) - - def test_coverage(self): - def g(y): - while False: - print(y) - - def f(): - try: - g(5) - except: - raise - - print(pragma.inline(g, return_source=True)(f)) - self.assertEqual(f(), pragma.inline(g)(f)()) - - -class TestDictStack(PragmaTest): - def test_most(self): - stack = pragma.DictStack() - stack.push({'x': 3}) - stack.push() - stack['x'] = 4 - self.assertEqual(stack['x'], 4) - res = stack.pop() - self.assertEqual(res['x'], 4) - self.assertEqual(stack['x'], 3) - self.assertIn('x', stack) - stack.items() - stack.keys() - del stack['x'] - self.assertNotIn('x', stack)