Skip to content

Commit

Permalink
Merge 15863ca into a72e783
Browse files Browse the repository at this point in the history
  • Loading branch information
atait committed Oct 13, 2020
2 parents a72e783 + 15863ca commit 68b1560
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/todo.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ TODO List
.. todo:: Implement decorator to eliminate unused lines of code (assignments to unused values)
.. todo:: Technically, ``x += y`` doesn't have to be the same thing as ``x = x + y``. Handle it as its own operation of the form ``x += y; return x``
.. todo:: Support efficiently inlining simple functions, i.e. where there is no return or only one return as the last line of the function, using pure name substitution without loops, try/except, or anything else fancy
.. todo:: Catch replacement of loop variables that conflict with globals, or throw a more descriptive error when detected. See ``test_iteration_variable``

.. todolist::
2 changes: 1 addition & 1 deletion pragma/core/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def transform(return_source=False, save_source=True, function_globals=None, **kw
def inner(f):
f_mod, f_body, f_file = function_ast(f)
# Grab function globals
glbls = f.__globals__
glbls = f.__globals__.copy()
# Grab function closure variables
if isinstance(f.__closure__, tuple):
glbls.update({k: v.cell_contents for k, v in zip(f.__code__.co_freevars, f.__closure__)})
Expand Down
41 changes: 41 additions & 0 deletions tests/test_collapse_literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,47 @@ def f():

self.assertIsInstance(w[-1].category(), UserWarning)

def test_side_effects_0cause(self):
# This will never fail, but it causes other tests to fail
# if it incorrectly moves 'a' from the closure to the module globals
a = 0
@pragma.collapse_literals
def f():
x = a

def test_side_effects_1effect(self):
@pragma.collapse_literals
def f2():
for a in range(3): # failure occurs when this is interpreted as "for 0 in range(3)"
x = a

def test_iteration_variable(self):
# global glbvar # TODO: Uncommenting should lead to a descriptive error
glbvar = 0

# glbvar in <locals> is recognized as in the __closure__ of f1
@pragma.collapse_literals
def f1():
x = glbvar
result = '''
def f1():
x = 0
'''
self.assertSourceEqual(f1, result)

# glbvar in <locals> is recognized as NOT in the __closure__ of f2
# but, if glbvar is in __globals__, it fails (and maybe should)
@pragma.collapse_literals
def f2():
for glbvar in range(3):
x = glbvar
result = '''
def f2():
for glbvar in range(3):
x = glbvar
'''
self.assertSourceEqual(f2, result)

def test_conditional_erasure(self):
@pragma.collapse_literals
def f(y):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_lambda_lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def f():
import pragma
import sys
return sys.version_info
''')
''', skip_pytest_imports=True)
self.assertSourceEqual(pragma.lift(imports=['sys'])(f), '''
def f():
import sys
Expand All @@ -172,7 +172,7 @@ def g():
import pragma
import sys as pseudo_sys
return pseudo_sys.version_info
''')
''', skip_pytest_imports=True)

def test_docstring(self):
@pragma.lift(imports=True)
Expand All @@ -187,6 +187,6 @@ def f(x):
return x + 1
'''

self.assertSourceEqual(f, result)
self.assertSourceEqual(f, result, skip_pytest_imports=True)


12 changes: 11 additions & 1 deletion tests/test_pragma.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,19 @@ def setUp(self):
# import contracts
# contracts.enable_all()

def assertSourceEqual(self, a, b):
def assertSourceEqual(self, a, b, skip_pytest_imports=False):
if callable(a):
a = dedent(getsource(a))
if skip_pytest_imports:
pytest_imports = [
'import builtins as @py_builtins',
'import _pytest.assertion.rewrite as @pytest_ar'
]
a_builder = []
for line in a.split('\n'):
if line.strip() not in pytest_imports:
a_builder.append(line)
a = '\n'.join(a_builder)
return self.assertEqual(a.strip(), dedent(b).strip())

def assertSourceIn(self, a, *b):
Expand Down

0 comments on commit 68b1560

Please sign in to comment.