Skip to content

Commit

Permalink
Merge c571ffd into fe64e14
Browse files Browse the repository at this point in the history
  • Loading branch information
scnerd committed Oct 17, 2020
2 parents fe64e14 + c571ffd commit 744965d
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 8 deletions.
3 changes: 3 additions & 0 deletions .travis.yml
Expand Up @@ -2,6 +2,9 @@ language: python
python:
- "3.5"
- "3.6"
- "3.7"
- "3.8"
dist: xenial

install:
- pip install .
Expand Down
2 changes: 2 additions & 0 deletions pragma/core/__init__.py
Expand Up @@ -20,6 +20,8 @@ def _is_indexable(x):
safe_new_contract('indexable', _is_indexable)
safe_new_contract('literal', 'int|float|str|bool|tuple|list|None')
for name, tp in inspect.getmembers(ast, inspect.isclass):
if name[0] == '_': # python 3.8 added ast._AST which pycontracts does not like
continue
safe_new_contract(name, tp)

# Astor tries to get fancy by failing nicely, but in doing so they fail when traversing non-AST type node properties.
Expand Down
15 changes: 15 additions & 0 deletions pragma/core/transformer.py
Expand Up @@ -30,6 +30,17 @@ def function_ast(f):
except (KeyError, AttributeError): # pragma: nocover
f_file = ''

try:
found = inspect.findsource(f)
except IndexError as err:
raise IOError((
'Discrepancy in number of decorator @magics expected by '
'inspect vs. __code__.co_firstlineno\n'
'{} in {}.\n'
'Try using the decorators after declaring the function'
'instead of @-magic').format(f, f_file)
) from err

root = ast.parse(textwrap.dedent(inspect.getsource(f)), f_file)
return root, root.body[0].body, f_file

Expand Down Expand Up @@ -388,6 +399,10 @@ def inner(f):
func = glbls[f_mod.body[0].name]
if save_source:
func.__tempfile__ = temp
# When there are other decorators, the co_firstlineno of *some* python distributions gets confused
# and thinks they will be there even when they are not written to the file, causing readline overflow
# So we put some empty lines to make them align
temp.write('\n' * func.__code__.co_firstlineno)
temp.write(source)
temp.flush()
return func
Expand Down
14 changes: 6 additions & 8 deletions pragma/lift.py
Expand Up @@ -143,19 +143,17 @@ def __call__(self, f):
new_kws = [ast.arg(arg=k, annotation=self._annotate(k, v)) for k, v in free_vars]
new_kw_defaults = [self._get_default(k, v) for k, v in free_vars]

# python 3.8 introduced a new signature for ast.arguments.__init__, so use whatever they use
ast_arguments_dict = func_def.args.__dict__
ast_arguments_dict['kwonlyargs'] += new_kws
ast_arguments_dict['kw_defaults'] += new_kw_defaults

new_func_def = ast.FunctionDef(
name=func_def.name,
body=f_body,
decorator_list=[], # func_def.decorator_list,
returns=func_def.returns,
args=ast.arguments(
args=func_def.args.args,
vararg=func_def.args.vararg,
kwarg=func_def.args.kwarg,
defaults=func_def.args.defaults,
kwonlyargs=func_def.args.kwonlyargs + new_kws,
kw_defaults=func_def.args.kw_defaults + new_kw_defaults
)
args=ast.arguments(**ast_arguments_dict)
)

f_mod.body[0] = new_func_def
Expand Down
45 changes: 45 additions & 0 deletions tests/test_collapse_literals.py
Expand Up @@ -40,6 +40,27 @@ def f():
''')
self.assertEqual(f.strip(), result.strip())

def test_repeated_decoration(self):
@pragma.collapse_literals
@pragma.collapse_literals
@pragma.collapse_literals
@pragma.collapse_literals
@pragma.collapse_literals
@pragma.collapse_literals
@pragma.collapse_literals
@pragma.collapse_literals
def f():
return 2
f = pragma.collapse_literals(f)

result = '''
def f():
return 2
'''

self.assertSourceEqual(f, result)
self.assertEqual(f(), 2)

def test_vars(self):
@pragma.collapse_literals(return_source=True)
def f():
Expand Down Expand Up @@ -288,17 +309,41 @@ def test_simple_functions(self):
def f():
print(len(a))
print(sum(a))
print(-a[0])
print(a[0] + a[1])
print(a)

result = '''
def f():
print(4)
print(10)
print(-1)
print(3)
print(a)
'''

self.assertSourceEqual(f, result)

def test_indexable_operations(self):
dct = dict(a=1, b=2, c=3, d=4)

@pragma.collapse_literals
def f():
print(len(dct))
print(-dct['a'])
print(dct['a'] + dct['b'])
print(dct)

result = '''
def f():
print(4)
print(-1)
print(3)
print(dct)
'''

self.assertSourceEqual(f, result)

def test_reduction(self):
a = [1, 2, 3]

Expand Down

0 comments on commit 744965d

Please sign in to comment.