Skip to content

Commit

Permalink
Debugged initial inline implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
scnerd committed Dec 9, 2017
1 parent cec2bd4 commit b9d2ab5
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 75 deletions.
2 changes: 1 addition & 1 deletion miniutils/pragma/core/resolve.py
Expand Up @@ -137,7 +137,7 @@ def constant_dict(node, ctxt):
if isinstance(node, (ast.Name, ast.NameConstant, ast.Attribute)):
res = resolve_name_or_attribute(node, ctxt)
if hasattr(res, 'items'):
return res
return dict(res.items())
return None


Expand Down
99 changes: 67 additions & 32 deletions miniutils/pragma/core/transformer.py
Expand Up @@ -9,7 +9,7 @@

from miniutils.magic_contract import magic_contract
from miniutils.opt_decorator import optional_argument_decorator
from .resolve import resolve_literal, constant_iterable
from .resolve import *
from .stack import DictStack


Expand Down Expand Up @@ -53,46 +53,75 @@ def _assign_names(node):
yield from _assign_names(node.value)


class DebugTransformerMixin:
def visit(self, node):
orig_node = copy.deepcopy(node)
new_node = super().visit(node)

try:
orig_node_code = astor.to_source(orig_node).strip()
if new_node is None:
print("Deleted >>> {} <<<".format(orig_node_code))
elif isinstance(new_node, ast.AST):
print("Converted >>> {} <<< to >>> {} <<<".format(orig_node_code, astor.to_source(new_node).strip()))
elif isinstance(new_node, list):
print("Converted >>> {} <<< to [[[ {} ]]]".format(orig_node_code, ", ".join(astor.to_source(n).strip() for n in new_node)))
except Exception as ex:
raise AssertionError("Failed on {} >>> {}".format(astor.dump_tree(orig_node), astor.dump_tree(new_node))) from ex
# print("Failed on {} >>> {}".format(astor.dump_tree(orig_node), astor.dump_tree(new_node)))
# return orig_node

return new_node



class TrackedContextTransformer(ast.NodeTransformer):
def __init__(self, ctxt=None):
self.ctxt = ctxt or DictStack()
super().__init__()

# def visit(self, node):
# orig_node = copy.deepcopy(node)
# new_node = super().visit(node)
#
# try:
# orig_node_code = astor.to_source(orig_node).strip()
# if new_node is None:
# print("Deleted >>> {} <<<".format(orig_node_code))
# elif isinstance(new_node, ast.AST):
# print("Converted >>> {} <<< to >>> {} <<<".format(orig_node_code, astor.to_source(new_node).strip()))
# elif isinstance(new_node, list):
# print("Converted >>> {} <<< to [[[ {} ]]]".format(orig_node_code, ", ".join(astor.to_source(n).strip() for n in new_node)))
# except Exception as ex:
# raise AssertionError("Failed on {} >>> {}".format(astor.dump_tree(orig_node), astor.dump_tree(new_node))) from ex
# # print("Failed on {} >>> {}".format(astor.dump_tree(orig_node), astor.dump_tree(new_node)))
# # return orig_node
#
# return new_node

def visit_Assign(self, node):
node.value = self.visit(node.value)
erase_targets = True
# print(node.value)
# TODO: Support tuple assignments
if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name):
nvalue = copy.deepcopy(node.value)
var = node.targets[0].id
val = constant_iterable(nvalue, self.ctxt)
if val is not None:
# print("Setting {} = {}".format(var, val))
self.ctxt[var] = val
else:
val = resolve_literal(nvalue, self.ctxt)
# print("Setting {} = {}".format(var, val))
self.ctxt[var] = val
else:
if len(node.targets) == 1:
if isinstance(node.targets[0], ast.Name):
nvalue = copy.deepcopy(node.value)
var = node.targets[0].id
val = constant_iterable(nvalue, self.ctxt)
if val is not None:
# print("Setting {} = {}".format(var, val))
self.ctxt[var] = val
else:
val = resolve_literal(nvalue, self.ctxt)
# print("Setting {} = {}".format(var, val))
self.ctxt[var] = val
erase_targets = False
elif isinstance(node.targets[0], ast.Subscript):
targ = node.targets[0]
iterable = constant_iterable(targ.value, self.ctxt, False)
if iterable is None:
iterable = constant_dict(targ.value, self.ctxt)
if iterable is None:
return node
key = resolve_literal(targ.slice, self.ctxt)
if isinstance(key, ast.AST):
return node

nvalue = copy.deepcopy(node.value)
val = constant_iterable(nvalue, self.ctxt)
warnings.warn("Iterable assignment not fully implemented yet...")
if val is not None:
# print("Setting {} = {}".format(var, val))
iterable[key] = val
else:
val = resolve_literal(nvalue, self.ctxt)
# print("Setting {} = {}".format(var, val))
iterable[key] = val
erase_targets = False

if erase_targets:
for targ in node.targets:
for assgn in _assign_names(targ):
self.ctxt[assgn] = None
Expand All @@ -103,6 +132,12 @@ def visit_AugAssign(self, node):
self.ctxt[assgn] = None
return super().generic_visit(node)

def visit_Del(self, node):
for targ in node.targets:
for assgn in _assign_names(targ):
del self.ctxt[assgn]
return super().generic_visit(node)


def make_function_transformer(transformer_type, name, description, **transformer_kwargs):
@optional_argument_decorator
Expand Down
44 changes: 24 additions & 20 deletions miniutils/pragma/inline.py
@@ -1,8 +1,4 @@
import ast
import inspect

from .core import TrackedContextTransformer, function_ast, constant_dict, make_function_transformer, \
resolve_name_or_attribute, resolve_literal
from .core import *
from .. import magic_contract
from collections import OrderedDict as odict

Expand Down Expand Up @@ -50,37 +46,45 @@
# attributes (int lineno, int col_offset)


class _InlineBodyTransformer(ast.NodeTransformer):
class _InlineBodyTransformer(DebugTransformerMixin, TrackedContextTransformer):
def __init__(self, func_name, param_names):
self.func_name = func_name
print("Func {} takes parameters {}".format(func_name, param_names))
self.param_names = param_names
super().__init__()

def visit_Name(self, node):
# Check if this is a parameter, and hasn't had another value assigned to it
if node.id in self.param_names:
return ast.Subscript(value=ast.Name(id=self.func_name),
slice=ast.Index(ast.Name(node.id)),
expr_context=getattr(node, 'expr_context', ast.Load()))
print("Found parameter reference {}".format(node.id))
if node.id not in self.ctxt:
# If so, get its value from the argument dictionary
return ast.Subscript(value=ast.Name(id=self.func_name),
slice=ast.Index(ast.Str(node.id)),
expr_context=getattr(node, 'expr_context', ast.Load()))
else:
print("But it's been overwritten to {} = {}".format(node.id, self.ctxt[node.id]))
return node

def visit_Return(self, node):
result = []
if node.value:
result.append(ast.Assign(targets=[ast.Subscript(value=ast.Name(id=self.func_name),
slice=ast.Name(id='return'),
slice=ast.Str('return'),
expr_context=ast.Store())],
value=node.value))
value=self.visit(node.value)))
result.append(ast.Break())
return result


class InlineTransformer(TrackedContextTransformer):
class InlineTransformer(DebugTransformerMixin, TrackedContextTransformer):
def __init__(self, *args, funs=None, **kwargs):
assert funs is not None
super().__init__(*args, **kwargs)

self.funs = funs
self.code_blocks = []


def nested_visit(self, nodes):
"""When we visit a block of statements, create a new "code block" and push statements into it"""
lst = []
Expand All @@ -100,6 +104,7 @@ def nested_visit(self, nodes):
def visit_Call(self, node):
"""When we see a function call, insert the function body into the current code block, then replace the call
with the return expression """
node = self.generic_visit(node)
node_fun = resolve_name_or_attribute(resolve_literal(node.func, self.ctxt), self.ctxt)

for (fun, fname, fsig, fbody) in self.funs:
Expand Down Expand Up @@ -130,10 +135,9 @@ def visit_Call(self, node):
value=arg_value))

# Inline function code
# This is our opportunity to recurse... please don't yet
cur_block.append(ast.For(target=ast.Name(id='____'),
iter=ast.Call(func=ast.Name(id='range'),
args=[ast.Num(1)],
keywords=[]),
iter=ast.List(elts=[ast.NameConstant(None)]),
body=fbody,
orelse=[]))

Expand Down Expand Up @@ -225,13 +229,13 @@ def inline(*funs_to_inline, **kwargs):
fsig = inspect.signature(fun_to_inline)
_, fbody, _ = function_ast(fun_to_inline)

new_name = '_{fname}_{name}'

import astor
print(astor.dump_tree(fbody))
name_transformer = _InlineBodyTransformer(fname, new_name)
name_transformer = _InlineBodyTransformer(fname, fsig.parameters)
fbody = [name_transformer.visit(stmt) for stmt in fbody]
fbody = [stmt for visited in fbody for stmt in (visited if isinstance(visited, list) else [visited])]
fbody = [stmt for visited in fbody for stmt in (visited if isinstance(visited, list)
else [visited] if visited is not None
else [])]
print(astor.dump_tree(fbody))
funs.append((fun_to_inline, fname, fsig, fbody))

Expand Down
57 changes: 36 additions & 21 deletions miniutils/pragma/unroll.py
@@ -1,8 +1,22 @@
import copy

from .core import TrackedContextTransformer, make_function_transformer, constant_iterable
from .core import *


def has_break(node):
for field, value in ast.iter_fields(node):
if isinstance(value, list):
for item in value:
if isinstance(item, (ast.Break, ast.Continue)):
return True
if isinstance(item, ast.AST):
if has_break(item):
return True
elif isinstance(value, ast.AST):
if has_break(value):
return True
return False

# noinspection PyPep8Naming
class UnrollTransformer(TrackedContextTransformer):
def __init__(self, *args, **kwargs):
Expand All @@ -11,26 +25,27 @@ def __init__(self, *args, **kwargs):

def visit_For(self, node):
result = [node]
iterable = constant_iterable(node.iter, self.ctxt)
if iterable is not None:
result = []
loop_var = node.target.id
orig_loop_vars = self.loop_vars
# print("Unrolling 'for {} in {}'".format(loop_var, list(iterable)))
for val in iterable:
self.ctxt.push({loop_var: val})
self.loop_vars = orig_loop_vars | {loop_var}
for body_node in copy.deepcopy(node.body):
res = self.visit(body_node)
if isinstance(res, list):
result.extend(res)
elif res is None:
continue
else:
result.append(res)
# result.extend([self.visit(body_node) for body_node in copy.deepcopy(node.body)])
self.ctxt.pop()
self.loop_vars = orig_loop_vars
if not any(has_break(n) for n in node.body):
iterable = constant_iterable(node.iter, self.ctxt)
if iterable is not None:
result = []
loop_var = node.target.id
orig_loop_vars = self.loop_vars
# print("Unrolling 'for {} in {}'".format(loop_var, list(iterable)))
for val in iterable:
self.ctxt.push({loop_var: val})
self.loop_vars = orig_loop_vars | {loop_var}
for body_node in copy.deepcopy(node.body):
res = self.visit(body_node)
if isinstance(res, list):
result.extend(res)
elif res is None:
continue
else:
result.append(res)
# result.extend([self.visit(body_node) for body_node in copy.deepcopy(node.body)])
self.ctxt.pop()
self.loop_vars = orig_loop_vars
return result

def visit_Name(self, node):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_pragma.py
Expand Up @@ -554,7 +554,9 @@ def f(y):
def f(y):
g = {}
g['x'] = y + 3
g['return'] = g['x'] ** 2
for ____ in [None]:
g['return'] = g['x'] ** 2
break
return g['return']
''')
self.assertEqual(f.strip(), result.strip())
Expand Down

0 comments on commit b9d2ab5

Please sign in to comment.