Skip to content

Commit

Permalink
Merge pull request #13 from atait/feature/collapse-opts
Browse files Browse the repository at this point in the history
Feature/collapse options
  • Loading branch information
atait committed Oct 17, 2020
2 parents a032639 + 1ae9ec8 commit 19087c0
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 7 deletions.
9 changes: 8 additions & 1 deletion pragma/collapse_literals.py
@@ -1,17 +1,24 @@
import ast
import logging

from .core import TrackedContextTransformer, make_function_transformer, primitive_ast_types
from .core import TrackedContextTransformer, make_function_transformer, primitive_ast_types, iterable_ast_types

log = logging.getLogger(__name__)


# noinspection PyPep8Naming
class CollapseTransformer(TrackedContextTransformer):
collapse_iterables = False

def visit_Name(self, node):
res = self.resolve_literal(node)
if isinstance(res, primitive_ast_types):
return res
if isinstance(res, iterable_ast_types):
if self.collapse_iterables:
return res
else:
log.debug("Not collapsing iterable {}. Change this setting with collapse_literals(collapse_iterables=True)".format(res))
return node

def visit_BinOp(self, node):
Expand Down
2 changes: 2 additions & 0 deletions pragma/core/resolve/__init__.py
Expand Up @@ -89,11 +89,13 @@ def resolve_name_or_attribute(node, ctxt):
float_types = (float,)

primitive_types = tuple([str, bytes, bool, type(None)] + list(num_types) + list(float_types))
iterable_types = (list, tuple)

try:
primitive_ast_types = (ast.Num, ast.Str, ast.Bytes, ast.NameConstant, ast.Constant, ast.JoinedStr)
except AttributeError: # Python <3.6
primitive_ast_types = (ast.Num, ast.Str, ast.Bytes, ast.NameConstant)
iterable_ast_types = (ast.List, ast.Tuple)


def make_binop(op):
Expand Down
23 changes: 17 additions & 6 deletions pragma/core/transformer.py
Expand Up @@ -358,14 +358,18 @@ def visit_ExceptHandler(self, node):
def make_function_transformer(transformer_type, name, description, **transformer_kwargs):
@optional_argument_decorator
@magic_contract
def transform(return_source=False, save_source=True, function_globals=None, **kwargs):
def transform(return_source=False, save_source=True, function_globals=None, collapse_iterables=False, explicit_only=False, **kwargs):
"""
:param return_source: Returns the transformed function's source code instead of compiling it
:type return_source: bool
:param save_source: Saves the function source code to a tempfile to make it inspectable
:type save_source: bool
:param function_globals: Overridden global name assignments to use when processing the function
:type function_globals: dict|None
:param collapse_iterables: Collapse iterable types
:type collapse_iterables: bool
:param explicit_only: Whether to use global variables or just keyword and function_globals in the replacement context
:type explicit_only: bool
:param kwargs: Any other environmental variables to provide during unrolling
:type kwargs: dict
:return: The transformed function, or its source code if requested
Expand All @@ -375,16 +379,23 @@ def transform(return_source=False, save_source=True, function_globals=None, **kw
@magic_contract(f='Callable', returns='Callable|str')
def inner(f):
f_mod, f_body, f_file = function_ast(f)
# Grab function globals
glbls = f.__globals__.copy()
# Grab function closure variables
if isinstance(f.__closure__, tuple):
glbls.update({k: v.cell_contents for k, v in zip(f.__code__.co_freevars, f.__closure__)})
if not explicit_only:
# Grab function globals
glbls = f.__globals__.copy()
# Grab function closure variables
if isinstance(f.__closure__, tuple):
glbls.update({k: v.cell_contents for k, v in zip(f.__code__.co_freevars, f.__closure__)})
else:
# Initialize empty context
if function_globals is None and len(kwargs) == 0:
log.warning("No global context nor function context. No collapse will occur")
glbls = dict()
# Apply manual globals override
if function_globals is not None:
glbls.update(function_globals)
# print({k: v for k, v in glbls.items() if k not in globals()})
trans = transformer_type(DictStack(glbls, kwargs), **transformer_kwargs)
trans.collapse_iterables = collapse_iterables
f_mod.body[0].decorator_list = []
f_mod = trans.visit(f_mod)
# print(astor.dump_tree(f_mod))
Expand Down
4 changes: 4 additions & 0 deletions tests/pytest.ini
@@ -0,0 +1,4 @@
[pytest]
filterwarnings =
ignore::DeprecationWarning
addopts = -s --log-cli-level 30
38 changes: 38 additions & 0 deletions tests/test_collapse_literals.py
Expand Up @@ -381,6 +381,20 @@ def f():

self.assertSourceEqual(f, result)

def test_iterable_option(self):
a = [1, 2, 3, 4]

@pragma.collapse_literals(collapse_iterables=True)
def f():
x = a

result = '''
def f():
x = [1, 2, 3, 4]
'''

self.assertSourceEqual(f, result)

def test_indexable_operations(self):
dct = dict(a=1, b=2, c=3, d=4)

Expand Down Expand Up @@ -468,3 +482,27 @@ def f():

self.assertSourceEqual(f, result)
self.assertEqual(f(), 4)


def test_explicit_collapse(self):
a = 2
b = 3
@pragma.collapse_literals(explicit_only=True, b=b)
def f():
x = a
y = b
result = '''
def f():
x = a
y = 3
'''
self.assertSourceEqual(f, result)

@pragma.collapse_literals(explicit_only=True)
def f():
x = a
result = '''
def f():
x = a
'''
self.assertSourceEqual(f, result)

0 comments on commit 19087c0

Please sign in to comment.