Skip to content

Commit

Permalink
Works on basic test cases. Clearly, passing in kwargs into function w…
Browse files Browse the repository at this point in the history
…on't be smooth, but a safe fallback is in place
  • Loading branch information
scnerd committed Dec 12, 2017
1 parent 67706b2 commit 10e1869
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 27 deletions.
59 changes: 40 additions & 19 deletions miniutils/pragma/inline.py
Expand Up @@ -46,6 +46,24 @@
# attributes (int lineno, int col_offset)


# @magic_contract
def make_name(fname, var, ctx=ast.Load, fmt="{fname}_{var}"):
"""
:param fname:
:type fname: str
:param var:
:type var: str
:param ctx:
:type ctx: Load|Store
:param fmt:
:type fmt: str
:return:
:rtype: Name
"""
return ast.Name(id=fmt.format(fname=fname, var=var), ctx=ctx())


class _InlineBodyTransformer(TrackedContextTransformer):
def __init__(self, func_name, param_names):
self.func_name = func_name
Expand All @@ -60,9 +78,7 @@ def visit_Name(self, node):
print("Found parameter reference {}".format(node.id))
if node.id not in self.ctxt:
# If so, get its value from the argument dictionary
return ast.Subscript(value=ast.Name(id='_'+self.func_name, ctx=ast.Load()),
slice=ast.Index(ast.Str(node.id)),
ctx=getattr(node, 'expr_context', ast.Load()))
return make_name(self.func_name, node.id)
else:
print("But it's been overwritten to {} = {}".format(node.id, self.ctxt[node.id]))
return node
Expand All @@ -72,7 +88,7 @@ def visit_Return(self, node):
raise NotImplementedError("miniutils.pragma.inline cannot handle returns from within a loop")
result = []
if node.value:
result.append(ast.Assign(targets=[ast.Name(id='_'+self.func_name, ctx=ast.Store())],
result.append(ast.Assign(targets=[make_name(self.func_name, 'return', ast.Store)],
value=self.visit(node.value)))
result.append(ast.Break())
return result
Expand Down Expand Up @@ -137,25 +153,30 @@ def visit_Call(self, node):
bound_args = fsig.bind(*node.args, **odict(keywords))
bound_args.apply_defaults()

# Create args dictionary
# fun_name = {}
cur_block.append(ast.Assign(targets=[ast.Name(id='_'+fname, ctx=ast.Store())],
value=ast.Dict(keys=[], values=[])))

for arg_name, arg_value in bound_args.arguments.items():
# fun_name['param_name'] = param_value
cur_block.append(ast.Assign(targets=[ast.Subscript(value=ast.Name(id='_'+fname, ctx=ast.Load()),
slice=ast.Index(ast.Str(arg_name)),
ctx=ast.Store())],
cur_block.append(ast.Assign(targets=[make_name(fname, arg_name, ast.Store)],
value=arg_value))

# Up to here is done...
# if kw_dict:
# # fun_name.update(kwargs)
# cur_block.append(ast.Call(func=ast.Attribute(value=ast.Name(id='_'+fname, ctx=ast.Load()),
# attr='update',
# ctx=ast.Load()),
# args=[kw_dict],
# keywords=[]))

# for k, v in kw_dict.items():
# exec(KEY_EQUALS_VALUE)
if kw_dict:
# fun_name.update(kwargs)
cur_block.append(ast.Call(func=ast.Attribute(value=ast.Name(id='_'+fname, ctx=ast.Load()),
attr='update',
ctx=ast.Load()),
args=[kw_dict],
keywords=[]))
loop = ast.parse('''
for k, v in {kw_dict}.items():
assert(isinstance(k, str))
exec("{{}} = {{}}".format(k, v))
'''.format(kw_dict=kw_dict)).body[0]
cur_block.append(loop)


# Inline function code
# This is our opportunity to recurse... please don't yet
Expand All @@ -165,7 +186,7 @@ def visit_Call(self, node):
orelse=[]))

# fun_name['return']
return ast.Name(id='_'+fname, ctx=ast.Load())
return make_name(fname, 'return')

else:
return node
Expand Down
14 changes: 6 additions & 8 deletions tests/test_pragma.py
Expand Up @@ -581,12 +581,11 @@ def f(y):

result = dedent('''
def f(y):
_g = {}
_g['x'] = y + 3
g_x = y + 3
for ____ in [None]:
_g = _g['x'] ** 2
g_return = g_x ** 2
break
return _g
return g_return
''')
self.assertEqual(f.strip(), result.strip())

Expand All @@ -611,10 +610,9 @@ def f(y):

result = dedent('''
def f(y):
_g = {}
_g['x'] = y + 3
_g = _g['x'] ** 2
return _g
g_x = y + 3
g_return = g_x ** 2
return g_return
''')
self.assertEqual(f.strip(), result.strip())

Expand Down

0 comments on commit 10e1869

Please sign in to comment.