Skip to content
This repository has been archived by the owner on Feb 25, 2018. It is now read-only.

Commit

Permalink
apply constant() recursively to inner functions
Browse files Browse the repository at this point in the history
  • Loading branch information
rfk committed Dec 8, 2009
1 parent 0a7493f commit 1335293
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 11 deletions.
64 changes: 53 additions & 11 deletions promise/__init__.py
Expand Up @@ -245,11 +245,12 @@ class constant(Promise):
will be deferred until the function first executes.
"""

def __init__(self,names):
def __init__(self,names,exclude=[]):
self.names = names
self.exclude = exclude
super(constant,self).__init__()

def _load_name(self,func,nm,op=None):
def _load_name(self,func,name,op=None):
"""Look up the given name in the scope of the given function.
This is an attempt to replicate the name lookup rules of LOAD_NAME,
Expand All @@ -258,14 +259,41 @@ def _load_name(self,func,nm,op=None):
If the name cannot be found, NameError is raised.
"""
# TODO: loading of local or closed-over names.
if op in (None,LOAD_NAME,LOAD_DEREF):
try:
return self._load_name_deref(func,name)
except NameError:
pass
if op in (None,LOAD_NAME,LOAD_GLOBAL):
try:
return self._load_name_global(func,name)
except NameError:
pass
raise NameError(name)


def _load_name_deref(self,func,name):
"""Simulate (LOAD_DEREF,name) on the given function."""
# Determine index of cell matching given name
try:
idx = func.func_code.co_cellvars.index(name)
except ValueError:
try:
return func.func_globals[nm]
idx = func.func_code.co_freevars.index(name)
idx -= len(func.func_code.co_cellvars)
except ValueError:
raise NameError(name)
return func.func_closure[idx].cell_contents

def _load_name_global(self,func,name):
"""Simulate (LOAD_GLOBAL,name) on the given function."""
try:
try:
return func.func_globals[name]
except KeyError:
return __builtins__[nm]
return __builtins__[name]
except KeyError:
raise NameError(nm)
raise NameError(name)

def decorate(self,func):
try:
Expand All @@ -279,8 +307,10 @@ def apply(self,func,code):
missing_names = []
for (i,(op,arg)) in enumerate(code.code):
# Replace LOADs of matching names with LOAD_CONST
if op in (LOAD_GLOBAL,):
if arg in self.names and arg not in missing_names:
if op in (LOAD_GLOBAL,LOAD_DEREF,LOAD_NAME):
if arg in self.names:
if arg in self.exclude or arg in missing_names:
continue
try:
val = new_constants[arg]
except KeyError:
Expand All @@ -293,10 +323,10 @@ def apply(self,func,code):
code.code[i] = (LOAD_CONST,val)
else:
code.code[i] = (LOAD_CONST,val)
# Error out for unsupported load types, for now...
elif op in (LOAD_NAME,LOAD_DEREF,LOAD_FAST):
# Quick check that locals haven't been promised constant
elif op == LOAD_FAST:
if arg in self.names:
raise TypeError("sorry, only global constants are currently supported [%s]" % (arg,))
raise BrokenPromiseError("local names can't be constant: '%s'" % (arg,))
# Quick check that constant names arent munged
elif op in (STORE_NAME,STORE_GLOBAL,STORE_FAST,STORE_DEREF):
if arg in self.names:
Expand All @@ -306,9 +336,21 @@ def apply(self,func,code):
if arg in self.names:
msg = "name '%s' was promised constant, but deleted"
raise BrokenPromiseError(msg % (arg,))
# Track any existing constants for use in the next step
elif op == LOAD_CONST:
if arg not in old_constants:
old_constants.add(arg)
# Recursively apply promise to any inner functions.
# TODO: how can we do deferred promises on inner functions?
if i+1 < len(code.code):
(nextop,nextarg) = code.code[i+1]
if nextop in (MAKE_FUNCTION,MAKE_CLOSURE):
exclude = arg.to_code().co_varnames
p = self.__class__(names=self.names,exclude=exclude)
try:
p.apply(func,arg)
except NameError:
pass
# If any constants define a '_promise_fold_constant' method,
# let them have a crack at the bytecode as well.
for const in new_constants.itervalues():
Expand Down
43 changes: 43 additions & 0 deletions promise/tests/__init__.py
Expand Up @@ -48,6 +48,49 @@ def test(self):
locals()["test_"+nm] = _make_test(nm)


def test_inlining():
"""Test that function inlining works under a variety of circumstances."""
@promise.pure()
def calc(a,b=7):
return 2*a + 3*b
items = [(1,7),(3,7),(5,7)]
# Version using list comprehension produces a for-loop
def aggregate0(items):
return sum([calc(a,b) for (a,b) in items])
assert (LOAD_DEREF,"calc") in Code.from_code(aggregate0.func_code).code
assert (BINARY_ADD,None) not in Code.from_code(aggregate0.func_code).code
# Version using generator comprehension produces a closure
def aggregate1(items):
return sum(calc(a,b) for (a,b) in items)
assert aggregate0(items) == aggregate1(items)
genexp = aggregate1.func_code.co_consts[1]
assert (LOAD_DEREF,"calc") in Code.from_code(genexp).code
assert (BINARY_ADD,None) not in Code.from_code(genexp).code
# Pure function can be folded into the for-loop
@promise.constant(["calc"])
def aggregate2(items):
return sum([calc(a,b) for (a,b) in items])
assert aggregate0(items) == aggregate2(items)
assert (LOAD_DEREF,"calc") not in Code.from_code(aggregate2.func_code).code
assert (BINARY_ADD,None) in Code.from_code(aggregate2.func_code).code
# Pure function can be pushed inside closure
@promise.constant(["calc"])
def aggregate3(items):
return sum(calc(a,b) for (a,b) in items)
assert aggregate0(items) == aggregate3(items)
genexp = aggregate3.func_code.co_consts[1]
assert (LOAD_DEREF,"calc") not in Code.from_code(genexp).code
assert (BINARY_ADD,None) in Code.from_code(genexp).code
# Default arguments are respected
@promise.constant(["calc"])
def aggregate4(items):
return sum(calc(a) for (a,_) in items)
assert aggregate0(items) == aggregate4(items)
genexp = aggregate4.func_code.co_consts[1]
assert (LOAD_DEREF,"calc") not in Code.from_code(genexp).code
assert (BINARY_ADD,None) in Code.from_code(genexp).code



def test_README():
"""Ensure that the README is in sync with the docstring.
Expand Down

0 comments on commit 1335293

Please sign in to comment.