Skip to content
Permalink
 
 
Cannot retrieve contributors at this time
import builtins
import inspect
import itertools
import logging as log
import typing
import libcst as cst
import libcst.matchers as m
from libcst.metadata import ExpressionContext
from .common import (SEP, a2s, get_function_locals, make_assign, make_dict,
make_index, make_list, make_string, parse_expr,
parse_statement)
from .contexts import ctx_inliner, ctx_pass
from .visitors import (ExpressionContextProviderBlock, RemoveFunctoolsWraps,
ReplaceReturn, ReplaceSuper, ReplaceYield,
ScopeProviderFunction, collect_imports, rename)
def rename_in_function(f_ast, src, dst):
mod = cst.Module(body=f_ast.body.body)
return f_ast.with_deep_changes(f_ast.body, body=rename(mod, src, dst).body)
# Scope a variable name as unique to the function, and update any
# references to it in the function
def unique_and_rename(f_ast, name):
unique_name = f'{name}{SEP}{f_ast.name.value}'
return rename_in_function(f_ast, name, unique_name), unique_name
def bind_arguments(f_ast, call_expr, new_stmts):
pass_ = ctx_pass.get()
args_def = f_ast.params
def bind_new_argument(k, v):
nonlocal f_ast
# Add a binding from function argument to call argument
f_ast, uniq_k = unique_and_rename(f_ast, k)
stmt = make_assign(cst.Name(uniq_k), v)
new_stmts.append(stmt)
# If function is called with f(*args)
star_arg = next(filter(lambda arg: arg.star == '*', call_expr.args), None)
if star_arg is not None:
star_arg = star_arg.value
# Get the length of the star_arg runtime list
star_arg_obj = pass_.eval(star_arg)
# Generate an indexing expression for each element of the list
call_star_args = [
make_index(star_arg, cst.Integer(str(i)))
for i in range(len(star_arg_obj))
]
else:
star_arg = None
# If function is called with f(**kwargs)
star_kwarg = next(filter(lambda arg: arg.star == '**', call_expr.args),
None)
if star_kwarg is not None:
star_kwarg = star_kwarg.value
star_kwarg_dict = pass_.eval(star_kwarg)
call_star_kwarg = {
key: make_index(star_kwarg, make_string(key))
for key in star_kwarg_dict.keys()
}
# Function's anonymous arguments, e.g. f(1, 2) becomes [1, 2]
call_anon_args = [
arg.value for arg in call_expr.args
if arg.keyword is None and arg.star == ''
]
# Function's keyword arguments, e.g. f(x=1, y=2) becomes {'x': 1, 'y': 2}
call_kwargs = {
arg.keyword.value: arg.value
for arg in call_expr.args if arg.keyword is not None and arg.star == ''
}
# Match up defaults with variable names.
#
# Python convention is that if function has N arguments and K < N defaults, then
# the defaults correspond to arguments N - K .. N.
anon_defaults = {
arg.name.value: arg.default
for arg in args_def.params if arg.default is not None
}
# All keyword-only arguments must have defaults.
#
# kwonlyargs occur if a function definition has args AFTER a *args, e.g.
# the var "y" in `def foo(x, *args, y=1)`
kw_defaults = {
arg.name.value: arg.default
for arg in args_def.kwonly_params
}
# For each non-keyword-only argument, match it up with the corresponding
# syntax from the call expression
for arg in args_def.params:
k = arg.name.value
# First, match with anonymous arguments
if len(call_anon_args) > 0:
v = call_anon_args.pop(0)
# Then use *args if it exists
elif star_arg is not None and len(call_star_args) > 0:
v = call_star_args.pop(0)
# Then use keyword arguments
elif k in call_kwargs:
v = call_kwargs.pop(k)
# Then use **kwargs if it exists
elif star_kwarg is not None and k in call_star_kwarg:
v = call_star_kwarg.pop(k)
# Otherwise use the default value
else:
v = anon_defaults.pop(k)
bind_new_argument(k, v)
# Perform equivalent procedure as above, but for keyword-only arguments
for arg in args_def.kwonly_params:
k = arg.name.value
if k in call_kwargs:
v = call_kwargs.pop(k)
elif star_kwarg is not None and k in call_star_kwarg:
v = call_star_kwarg.pop(k)
else:
v = kw_defaults.pop(k)
bind_new_argument(k, v)
# If function definition uses *args, then assign it to the remaining anonymous
# arguments from the call_expr
if (args_def.star_arg is not cst.MaybeSentinel.DEFAULT
and not isinstance(args_def.star_arg, cst.ParamStar)):
f_ast, k = unique_and_rename(f_ast, args_def.star_arg.name.value)
v = call_anon_args[:]
if star_arg is not None:
v += call_star_args
new_stmts.append(make_assign(cst.Name(k), make_list(v)))
# Similarly for **kwargs in the function definition
if args_def.star_kwarg is not None:
f_ast, k = unique_and_rename(f_ast, args_def.star_kwarg.name.value)
items = call_kwargs.items()
if star_kwarg is not None:
items = itertools.chain(items, call_star_kwarg.items())
new_stmts.append(
make_assign(cst.Name(k),
make_dict([(make_string(k), v) for k, v in items])))
return f_ast
def replace_super(f_ast, cls, call, func_obj, new_stmts):
pass_ = ctx_pass.get()
# If we don't know what the class is, e.g. in Foo.method(foo), then
# eval the LHS of the attribute, e.g. Foo here
if cls is None:
if m.matches(call.func, m.Attribute()):
cls = pass_.eval(call.func.value)
else:
cls = pass_.eval(call.func).__class__
# TODO: support multiple inheritance
# Add import for base class
assert len(cls.__bases__) == 1
base = cls.__bases__[0]
file_imports = collect_imports(func_obj)
imprt = generate_import(base.__name__, base, func_obj, file_imports)
if imprt is not None:
new_stmts.insert(0, imprt)
return f_ast.visit(ReplaceSuper(base))
def generate_imports_for_nonlocals(f_ast, func_obj, call):
# Get all read-position variables
contexts = cst.MetadataWrapper(
f_ast.body,
unsafe_skip_copy=True).resolve(ExpressionContextProviderBlock)
used_names = set(
node.value for node, ctx in contexts.items()
if ctx == ExpressionContext.LOAD and m.matches(node, m.Name()))
closure = {**func_obj.__globals__, **get_function_locals(func_obj)}
file_imports = collect_imports(func_obj)
imports = [
generate_import(name, closure[name], func_obj, file_imports)
for name in used_names if name in closure
]
imports = [i for i in imports if i]
return imports
def inline_decorators(f_ast, call, func_obj, ret_var):
"""
Expand decorator calls to an inlined function.
Example:
@foo
def bar(x):
return x + 1
assert bar(1) == 2
>> becomes >>
def bar(x):
return x + 1
assert foo(bar(1)) == 2
"""
decorators = f_ast.decorators
# TODO
# used_globals = UsedGlobals(func_obj.__globals__)
# used_globals.visit(f_ast)
# used = used_globals.used
# file_imports = collect_imports(func_obj)
# for name, globl in used_globals.used.items():
# imprt = self.generate_imports(name,
# globl,
# func_obj=func_obj,
# file_imports=file_imports)
# if imprt is not None:
# new_stmts.insert(0, imprt)
f_ast = f_ast.with_changes(decorators=[])
new_call = call.with_changes(
func=cst.Call(func=decorators[0].decorator, args=[cst.Arg(f_ast.name)]))
return [f_ast, make_assign(cst.Name(ret_var), new_call)]
def inline_function(func_obj,
call,
ret_var,
cls=None,
f_ast=None,
is_toplevel=False):
log.debug('Inlining {}'.format(a2s(call)))
inliner = ctx_inliner.get()
pass_ = ctx_pass.get()
if f_ast is None:
# Get the source code for the function
try:
f_source = inspect.getsource(func_obj)
except TypeError:
print('Failed to get source of {}'.format(a2s(call)))
raise
# Record statistics about length of inlined source
inliner.length_inlined += len(f_source.split('\n'))
# Then parse the function into an AST
f_ast = parse_statement(f_source)
# Give the function a fresh name so it won't conflict with other calls to
# the same function
f_ast = f_ast.with_changes(name=cst.Name(pass_.fresh_var(f_ast.name.value)))
# TODO
# If function has decorators, deal with those first. Just inline decorator call
# and stop there.
decorators = f_ast.decorators
assert len(decorators) <= 1 # TODO: deal with multiple decorators
if len(decorators) == 1:
d = decorators[0].decorator
builtin_decorator = (
isinstance(d, cst.Name)
and (d.value in ['property', 'classmethod', 'staticmethod']))
derived_decorator = (isinstance(d, cst.Attribute)
and (d.attr.value in ['setter']))
if not (builtin_decorator or derived_decorator):
return inline_decorators(f_ast, call, func_obj, ret_var)
# # If we're inlining a decorator, we need to remove @functools.wraps calls
# # to avoid messing up inspect.getsource
f_ast = f_ast.with_changes(body=f_ast.body.visit(RemoveFunctoolsWraps()))
new_stmts = []
# If the function is a method (which we proxy by first arg being named "self"),
# then we need to replace uses of special "super" keywords.
args_def = f_ast.params
if len(args_def.params) > 0:
first_arg_is_self = m.matches(args_def.params[0],
m.Param(m.Name('self')))
if first_arg_is_self:
f_ast = replace_super(f_ast, cls, call, func_obj, new_stmts)
# Add bindings from arguments in the call expression to arguments in function def
f_ast = bind_arguments(f_ast, call, new_stmts)
scopes = cst.MetadataWrapper(
f_ast, unsafe_skip_copy=True).resolve(ScopeProviderFunction)
func_scope = scopes[f_ast.body]
for assgn in func_scope.assignments:
if m.matches(assgn.node, m.Name()):
var = assgn.node.value
f_ast = unique_and_rename(f_ast, var)
# Add an explicit return None at the end to reify implicit return
f_body = f_ast.body
last_stmt_is_return = m.matches(f_body.body[-1],
m.SimpleStatementLine([m.Return()]))
if (not is_toplevel and # If function return is being assigned
cls is None and # And not an __init__ fn
not last_stmt_is_return):
f_ast = f_ast.with_deep_changes(f_body,
body=list(f_body.body) +
[parse_statement("return None")])
# Replace returns with if statements
f_ast = f_ast.with_changes(body=f_ast.body.visit(ReplaceReturn(ret_var)))
# Inline function body
new_stmts.extend(f_ast.body.body)
# Create imports for non-local variables
imports = generate_imports_for_nonlocals(f_ast, func_obj, call)
new_stmts = imports + new_stmts
if inliner.add_comments:
# Add header comment to first statement
call_str = a2s(call)
header_comment = [
cst.EmptyLine(comment=cst.Comment(f'# {line}'))
for line in call_str.splitlines()
]
first_stmt = new_stmts[0]
new_stmts[0] = first_stmt.with_changes(
leading_lines=[cst.EmptyLine(indent=False)] + header_comment +
list(first_stmt.leading_lines))
return new_stmts
def inline_constructor(func_obj, call, ret_var):
"""
Inlines a class constructor.
Construction has two parts: creating the object with __new__, and
initializing it with __init__. We insert a __new__ call and then
inline the __init__ function.
Example:
class Foo:
def __init__(self):
self.x = 1
f = Foo()
>> becomes >>
f = Foo.__new__(Foo)
self = f
self.x = 1
"""
cls_name = func_obj.__name__
new_stmts = []
# # Add an import for the class
cls_import = generate_import(cls_name, func_obj)
if cls_import is not None:
new_stmts.append(cls_import)
# Create a raw object using __new__
make_obj = make_assign(
cst.Name(ret_var),
cst.parse_expression(f'{cls_name}.__new__({cls_name})'))
new_stmts.append(make_obj)
# Add the object as an explicit argument to the __init__ function
call = call.with_changes(args=[cst.Arg(cst.Name(ret_var))] +
list(call.args))
if func_obj.__init__ is not object.__init__:
# Inline the __init__ function
init_inline = inline_function(func_obj.__init__,
call,
ret_var,
cls=func_obj)
else:
init_inline = []
return new_stmts + init_inline
def inline_method(func_obj, call, ret_var):
"""
Replace bound methods with unbound functions.
Example:
f = Foo()
assert f.bar(0) == 1
>> becomes >>
f = Foo()
assert Foo.bar(f, 0) == 1
"""
pass_ = ctx_pass.get()
# HACK: assume all methods are called syntactically as obj.method()
# as opposed to x = obj.method; x()
assert isinstance(call.func, cst.Attribute)
method_name = call.func.attr.value
# Get the object bound to the method
bound_obj = func_obj.__self__
# If the method is a classmethod, the method is bound to the class
if inspect.isclass(bound_obj):
cls_name = bound_obj.__name__
cls_obj = bound_obj
new_func = f'{cls_name}.{method_name}.__func__'
else:
cls_name = bound_obj.__class__.__name__
cls_obj = bound_obj.__class__
new_func = f'{cls_name}.{method_name}'
new_func = parse_expr(new_func)
new_stmts = []
cls_import = generate_import(cls_name, cls_obj)
if cls_import is not None:
new_stmts.append(cls_import)
exec(a2s(cls_import), pass_.globls, pass_.globls)
# Add the object as explicit self parameter
new_call = call.with_changes(func=new_func,
args=[cst.Arg(call.func.value)] +
list(call.args))
# Go back to general inline dispatcher since e.g. function may
# be a generator method, so we can't directly call inline_function
return new_stmts + inline(func_obj.__func__, new_call, ret_var)
def inline_generator(func_obj, call, ret_var):
"""
Inlines generators (those using yield).
There is no easy way to proxy generator semantics/control flow
without generators, unlike early returns. The simple strategy is to
eagerly materialize the generator into a list. However, this is both
inefficient and does not always preserve semantics, e.g. see
requests.ipynb.
Example:
def foo():
for i in range(10):
yield i
for i in foo():
print(i)
>> becomes >>
l = []
for i in range(10):
l.append(i)
for i in l:
print(i)
"""
f_ast = parse_statement(inspect.getsource(func_obj))
# Initialize the list
new_stmts = [parse_statement(f'{ret_var} = []')]
# Replace all yield statements by appending to the list
f_ast = f_ast.visit(ReplaceYield(ret_var))
# Then inline the function as normal
new_stmts.extend(inline_function(func_obj, call, ret_var, f_ast=f_ast))
return new_stmts
def generate_import(name, obj, func_obj=None, file_imports=None):
"""
Generate an import statement for a (name, runtime object) pair.
"""
inliner = ctx_inliner.get()
# HACK? is this still needed?
if name == 'self':
return None
# If the name is already in scope, don't need to import it
if name in inliner.base_globls:
# TODO: name conflicts? e.g. host imports json as x, and
# another module imports foo as x
return None
# If the name appears directly in an import statement in the object's file,
# then use that import
if file_imports is not None and name in file_imports:
return cst.SimpleStatementLine([file_imports[name]])
# If we're importing a module, then add an import directly
if inspect.ismodule(obj):
mod_name = obj.__name__
return parse_statement(f'import {mod_name} as {name}'
if name != mod_name else f'import {mod_name}')
else:
# Get module where global is defined
mod = inspect.getmodule(obj)
# TODO: When is mod None?
if mod is None or mod is typing or mod.__name__ == '__main__':
return None
# Can't import builtins
elif mod is __builtins__ or mod is builtins:
return None
# If the value is a class or function, then import it from the defining
# module
elif inspect.isclass(obj) or inspect.isfunction(obj):
return parse_statement(f'from {mod.__name__} import {name}')
# Otherwise import it from the module using the global
elif func_obj is not None:
func_mod_name = inspect.getmodule(func_obj).__name__
if func_mod_name == '__main__':
return None
return parse_statement(f'from {func_mod_name} import {name}')
def inline(func_obj, call, ret_var):
if inspect.ismethod(func_obj):
return inline_method(func_obj, call, ret_var)
elif inspect.isgeneratorfunction(func_obj):
return inline_generator(func_obj, call, ret_var)
elif inspect.isclass(func_obj):
return inline_constructor(func_obj, call, ret_var)
elif inspect.isfunction(func_obj):
return inline_function(func_obj, call, ret_var)
else:
raise NotImplementedError