diff --git a/src/RestrictedPython/PrintCollector.py b/src/RestrictedPython/PrintCollector.py index 56463e4..44ac972 100644 --- a/src/RestrictedPython/PrintCollector.py +++ b/src/RestrictedPython/PrintCollector.py @@ -12,17 +12,13 @@ ############################################################################## from __future__ import print_function -import sys - - -version = sys.version_info - class PrintCollector(object): """Collect written text, and return it when called.""" - def __init__(self): + def __init__(self, _getattr_=None): self.txt = [] + self._getattr_ = _getattr_ def write(self, text): self.txt.append(text) @@ -30,19 +26,10 @@ def write(self, text): def __call__(self): return ''.join(self.txt) + def _call_print(self, *objects, **kwargs): + if kwargs.get('file', None) is None: + kwargs['file'] = self + else: + self._getattr_(kwargs['file'], 'write') -printed = PrintCollector() - - -def safe_print(sep=' ', end='\n', file=printed, flush=False, *objects): - """ - - """ - # TODO: Reorder method args so that *objects is first - # This could first be done if we drop Python 2 support - if file is None or file is sys.stdout or file is sys.stderr: - file = printed - if version >= (3, 3): - print(self, objects, sep=sep, end=end, file=file, flush=flush) - else: - print(self, objects, sep=sep, end=end, file=file) + print(*objects, **kwargs) diff --git a/src/RestrictedPython/transformer.py b/src/RestrictedPython/transformer.py index 5756eae..570ca81 100644 --- a/src/RestrictedPython/transformer.py +++ b/src/RestrictedPython/transformer.py @@ -23,6 +23,7 @@ import ast +import contextlib import sys @@ -161,7 +162,8 @@ ast.Starred, ast.arg, ast.Try, - ast.ExceptHandler + ast.ExceptHandler, + ast.NameConstant ]) if version >= (3, 4): @@ -197,6 +199,26 @@ def copy_locations(new_node, old_node): ast.fix_missing_locations(new_node) +class PrintInfo(object): + def __init__(self): + self.print_used = False + self.printed_used = False + + @contextlib.contextmanager + def new_print_scope(self): + old_print_used = self.print_used + old_printed_used = self.printed_used + + self.print_used = False + self.printed_used = False + + try: + yield + finally: + self.print_used = old_print_used + self.printed_used = old_printed_used + + class RestrictingNodeTransformer(ast.NodeTransformer): def __init__(self, errors=[], warnings=[], used_names=[]): @@ -208,6 +230,8 @@ def __init__(self, errors=[], warnings=[], used_names=[]): # Global counter to construct temporary variable names. self._tmp_idx = 0 + self.print_info = PrintInfo() + def gen_tmp_name(self): # 'check_name' ensures that no variable is prefixed with '_'. # => Its safe to use '_tmp..' as a temporary variable. @@ -424,6 +448,10 @@ def check_name(self, node, name): elif name == "printed": self.error(node, '"printed" is a reserved name.') + elif name == 'print': + # Assignments to 'print' would lead to funny results. + self.error(node, '"print" is a reserved name.') + def check_function_argument_names(self, node): # In python3 arguments are always identifiers. # In python2 the 'Python.asdl' specifies expressions, but @@ -474,6 +502,47 @@ def check_import_names(self, node): return self.generic_visit(node) + def inject_print_collector(self, node, position=0): + print_used = self.print_info.print_used + printed_used = self.print_info.printed_used + + if print_used or printed_used: + # Add '_print = _print_(_getattr_)' add the top of a function/module. + _print = ast.Assign( + targets=[ast.Name('_print', ast.Store())], + value=ast.Call( + func=ast.Name("_print_", ast.Load()), + args=[ast.Name("_getattr_", ast.Load())], + keywords=[])) + + if isinstance(node, ast.Module): + _print.lineno = position + _print.col_offset = position + ast.fix_missing_locations(_print) + else: + copy_locations(_print, node) + + node.body.insert(position, _print) + + if not printed_used: + self.warn(node, "Prints, but never reads 'printed' variable.") + + elif not print_used: + self.warn(node, "Doesn't print, but reads 'printed' variable.") + + def gen_attr_check(self, node, attr_name): + """Check if 'attr_name' is allowed on the object in node. + + It generates (_getattr_(node, attr_name) and node). + """ + + call_getattr = ast.Call( + func=ast.Name('_getattr_', ast.Load()), + args=[node, ast.Str(attr_name)], + keywords=[]) + + return ast.BoolOp(op=ast.And(), values=[call_getattr, node]) + # Special Functions for an ast.NodeTransformer def generic_visit(self, node): @@ -556,11 +625,36 @@ def visit_NameConstant(self, node): # ast for Variables def visit_Name(self, node): - """ + """Prevents access to protected names. + Converts use of the name 'printed' to this expression: '_print()' """ + + node = self.generic_visit(node) + + if isinstance(node.ctx, ast.Load): + if node.id == 'printed': + self.print_info.printed_used = True + new_node = ast.Call( + func=ast.Name("_print", ast.Load()), + args=[], + keywords=[]) + + copy_locations(new_node, node) + return new_node + + elif node.id == 'print': + self.print_info.print_used = True + new_node = ast.Attribute( + value=ast.Name('_print', ast.Load()), + attr="_call_print", + ctx=ast.Load()) + + copy_locations(new_node, node) + return new_node + self.check_name(node, node.id) - return self.generic_visit(node) + return node def visit_Load(self, node): """ @@ -1075,16 +1169,32 @@ def visit_AugAssign(self, node): return node def visit_Print(self, node): + """Checks and mutates a print statement. + + Adds a target to all print statements. 'print foo' becomes + 'print >> _print, foo', where _print is the default print + target defined for this scope. + + Alternatively, if the untrusted code provides its own target, + we have to check the 'write' method of the target. + 'print >> ob, foo' becomes + 'print >> (_getattr_(ob, 'write') and ob), foo'. + Otherwise, it would be possible to call the write method of + templates and scripts; 'write' happens to be the name of the + method that changes them. """ - Fields: - * dest (optional) - * value --> List of Nodes - * nl --> newline (True or False) - """ - if node.dest is not None: - self.error( - node, - 'print statements with destination / chevron are not allowed.') + + self.print_info.print_used = True + + node = self.generic_visit(node) + if node.dest is None: + node.dest = ast.Name('_print', ast.Load()) + else: + # Pre-validate access to the 'write' attribute. + node.dest = self.gen_attr_check(node.dest, 'write') + + copy_locations(node.dest, node) + return node def visit_Raise(self, node): """ @@ -1251,7 +1361,9 @@ def visit_FunctionDef(self, node): self.check_name(node, node.name) self.check_function_argument_names(node) - node = self.generic_visit(node) + with self.print_info.new_print_scope(): + node = self.generic_visit(node) + self.inject_print_collector(node) if version.major == 3: return node @@ -1388,10 +1500,21 @@ def visit_ClassDef(self, node): return self.generic_visit(node) def visit_Module(self, node): - """ + """Adds the print_collector (only if print is used) at the top.""" - """ - return self.generic_visit(node) + node = self.generic_visit(node) + + # Inject the print collector after 'from __future__ import ....' + position = 0 + for position, child in enumerate(node.body): + if not isinstance(child, ast.ImportFrom): + break + + if not child.module == '__future__': + break + + self.inject_print_collector(node, position) + return node # Async und await diff --git a/tests/test_base_example.py b/tests/test_base_example.py deleted file mode 100644 index 648bdfb..0000000 --- a/tests/test_base_example.py +++ /dev/null @@ -1,34 +0,0 @@ -from RestrictedPython import compile_restricted - - -SRC = """\ -def hello_world(): - return "Hello World!" -""" - - -def test_base_example_unrestricted_compile(): - code = compile(SRC, '', 'exec') - locals = {} - exec(code, globals(), locals) - result = locals['hello_world']() - assert result == 'Hello World!' - - -def test_base_example_restricted_compile(): - code = compile_restricted(SRC, '', 'exec') - locals = {} - exec(code, globals(), locals) - assert locals['hello_world']() == 'Hello World!' - - -PRINT_STATEMENT = """\ -print("Hello World!") -""" - - -def test_base_example_catched_stdout(): - from RestrictedPython.PrintCollector import PrintCollector - locals = {'_print_': PrintCollector} - code = compile_restricted(PRINT_STATEMENT, '', 'exec') - exec(code, globals(), locals) diff --git a/tests/test_print.py b/tests/test_print.py deleted file mode 100644 index 838d4c7..0000000 --- a/tests/test_print.py +++ /dev/null @@ -1,66 +0,0 @@ -from RestrictedPython import compile_restricted -from RestrictedPython import compile_restricted_eval -from RestrictedPython import compile_restricted_exec -from RestrictedPython import compile_restricted_function -from RestrictedPython.PrintCollector import PrintCollector -from RestrictedPython.PrintCollector import printed -from RestrictedPython.PrintCollector import safe_print - -import pytest -import sys - - -ALLOWED_PRINT_STATEMENT = """\ -print 'Hello World!' -""" - -ALLOWED_PRINT_STATEMENT_WITH_NL = """\ -print 'Hello World!', -""" - -ALLOWED_MUKTI_PRINT_STATEMENT = """\ -print 'Hello World!', 'Hello Earth!' -""" - -DISSALOWED_PRINT_STATEMENT_WITH_CHEVRON = """\ -print >> stream, 'Hello World!' -""" - -DISSALOWED_PRINT_STATEMENT_WITH_CHEVRON_AND_NL = """\ -print >> stream, 'Hello World!', -""" - -ALLOWED_PRINT_FUNCTION = """\ -print('Hello World!') -""" - -ALLOWED_MULTI_PRINT_FUNCTION = """\ -print('Hello World!', 'Hello Earth!') -""" - -ALLOWED_FUTURE_PRINT_FUNCTION = """\ -from __future import print_function - -print('Hello World!') -""" - -ALLOWED_FUTURE_MULTI_PRINT_FUNCTION = """\ -from __future import print_function - -print('Hello World!', 'Hello Earth!') -""" - -ALLOWED_PRINT_FUNCTION = """\ -print('Hello World!', end='') -""" - -DISALLOWED_PRINT_FUNCTION_WITH_FILE = """\ -print('Hello World!', file=sys.stderr) -""" - - -@pytest.mark.skipif(sys.version_info >= (3, 0), - reason="print statement no longer exists in Python 3") -def test_print__simple_print_statement(): - code, err, warn, use = compile_restricted_exec(ALLOWED_PRINT_STATEMENT, '') - exec(code) diff --git a/tests/test_print_function.py b/tests/test_print_function.py new file mode 100644 index 0000000..2d542e7 --- /dev/null +++ b/tests/test_print_function.py @@ -0,0 +1,360 @@ +from RestrictedPython.PrintCollector import PrintCollector + +import RestrictedPython +import six + + +# The old 'RCompile' has no clue about the print function. +compiler = RestrictedPython.compile.compile_restricted_exec + + +ALLOWED_PRINT_FUNCTION = """ +from __future__ import print_function +print ('Hello World!') +""" + +ALLOWED_PRINT_FUNCTION_WITH_END = """ +from __future__ import print_function +print ('Hello World!', end='') +""" + +ALLOWED_PRINT_FUNCTION_MULTI_ARGS = """ +from __future__ import print_function +print ('Hello World!', 'Hello Earth!') +""" + +ALLOWED_PRINT_FUNCTION_WITH_SEPARATOR = """ +from __future__ import print_function +print ('a', 'b', 'c', sep='|', end='!') +""" + +PRINT_FUNCTION_WITH_NONE_SEPARATOR = """ +from __future__ import print_function +print ('a', 'b', sep=None) +""" + + +PRINT_FUNCTION_WITH_NONE_END = """ +from __future__ import print_function +print ('a', 'b', end=None) +""" + + +PRINT_FUNCTION_WITH_NONE_FILE = """ +from __future__ import print_function +print ('a', 'b', file=None) +""" + + +def test_print_function__simple_prints(): + glb = {'_print_': PrintCollector, '_getattr_': None} + + code, errors = compiler(ALLOWED_PRINT_FUNCTION)[:2] + assert code is not None + assert errors == () + six.exec_(code, glb) + assert glb['_print']() == 'Hello World!\n' + + code, errors = compiler(ALLOWED_PRINT_FUNCTION_WITH_END)[:2] + assert code is not None + assert errors == () + six.exec_(code, glb) + assert glb['_print']() == 'Hello World!' + + code, errors = compiler(ALLOWED_PRINT_FUNCTION_MULTI_ARGS)[:2] + assert code is not None + assert errors == () + six.exec_(code, glb) + assert glb['_print']() == 'Hello World! Hello Earth!\n' + + code, errors = compiler(ALLOWED_PRINT_FUNCTION_WITH_SEPARATOR)[:2] + assert code is not None + assert errors == () + six.exec_(code, glb) + assert glb['_print']() == "a|b|c!" + + code, errors = compiler(PRINT_FUNCTION_WITH_NONE_SEPARATOR)[:2] + assert code is not None + assert errors == () + six.exec_(code, glb) + assert glb['_print']() == "a b\n" + + code, errors = compiler(PRINT_FUNCTION_WITH_NONE_END)[:2] + assert code is not None + assert errors == () + six.exec_(code, glb) + assert glb['_print']() == "a b\n" + + code, errors = compiler(PRINT_FUNCTION_WITH_NONE_FILE)[:2] + assert code is not None + assert errors == () + six.exec_(code, glb) + assert glb['_print']() == "a b\n" + + +ALLOWED_PRINT_FUNCTION_WITH_STAR_ARGS = """ +from __future__ import print_function +to_print = (1, 2, 3) +print(*to_print) +""" + + +def test_print_function_with_star_args(mocker): + _apply_ = mocker.stub() + _apply_.side_effect = lambda func, *args, **kwargs: func(*args, **kwargs) + + glb = { + '_print_': PrintCollector, + '_getattr_': None, + "_apply_": _apply_ + } + + code, errors = compiler(ALLOWED_PRINT_FUNCTION_WITH_STAR_ARGS)[:2] + assert code is not None + assert errors == () + six.exec_(code, glb) + assert glb['_print']() == "1 2 3\n" + _apply_.assert_called_once_with(glb['_print']._call_print, 1, 2, 3) + + +ALLOWED_PRINT_FUNCTION_WITH_KWARGS = """ +from __future__ import print_function +to_print = (1, 2, 3) +kwargs = {'sep': '-', 'end': '!', 'file': None} +print(*to_print, **kwargs) +""" + + +def test_print_function_with_kw_args(mocker): + _apply_ = mocker.stub() + _apply_.side_effect = lambda func, *args, **kwargs: func(*args, **kwargs) + + glb = { + '_print_': PrintCollector, + '_getattr_': None, + "_apply_": _apply_ + } + + code, errors = compiler(ALLOWED_PRINT_FUNCTION_WITH_KWARGS)[:2] + assert code is not None + assert errors == () + six.exec_(code, glb) + assert glb['_print']() == "1-2-3!" + _apply_.assert_called_once_with( + glb['_print']._call_print, + 1, + 2, + 3, + end='!', + file=None, + sep='-') + + +PROTECT_WRITE_ON_FILE = """ +from __future__ import print_function +print ('a', 'b', file=stream) +""" + + +def test_print_function__protect_file(mocker): + _getattr_ = mocker.stub() + _getattr_.side_effect = getattr + stream = mocker.stub() + stream.write = mocker.stub() + + glb = { + '_print_': PrintCollector, + '_getattr_': _getattr_, + 'stream': stream + } + + code, errors = compiler(PROTECT_WRITE_ON_FILE)[:2] + assert code is not None + assert errors == () + + six.exec_(code, glb) + + _getattr_.assert_called_once_with(stream, 'write') + stream.write.assert_has_calls([ + mocker.call('a'), + mocker.call(' '), + mocker.call('b'), + mocker.call('\n') + ]) + + +# 'printed' is scope aware. +# => on a new function scope a new printed is generated. +INJECT_PRINT_COLLECTOR_NESTED = """ +from __future__ import print_function +def f2(): + return 'f2' + +def f1(): + print ('f1') + + def inner(): + print ('inner') + return printed + + return inner() + printed + f2() + +def main(): + print ('main') + return f1() + printed +""" + + +def test_print_function__nested_print_collector(): + code, errors = compiler(INJECT_PRINT_COLLECTOR_NESTED)[:2] + + glb = {"_print_": PrintCollector, '_getattr_': None} + six.exec_(code, glb) + + ret = glb['main']() + assert ret == 'inner\nf1\nf2main\n' + + +WARN_PRINTED_NO_PRINT = """ +def foo(): + return printed +""" + + +def test_print_function__with_printed_no_print(): + code, errors, warnings = compiler(WARN_PRINTED_NO_PRINT)[:3] + + assert code is not None + assert errors == () + assert warnings == ["Line 2: Doesn't print, but reads 'printed' variable."] + + +WARN_PRINTED_NO_PRINT_NESTED = """ +from __future__ import print_function +print ('a') +def foo(): + return printed +printed +""" + + +def test_print_function__with_printed_no_print_nested(): + code, errors, warnings = compiler(WARN_PRINTED_NO_PRINT_NESTED)[:3] + + assert code is not None + assert errors == () + assert warnings == ["Line 4: Doesn't print, but reads 'printed' variable."] + + +WARN_PRINT_NO_PRINTED = """ +from __future__ import print_function +def foo(): + print (1) +""" + + +def test_print_function__with_print_no_printed(): + code, errors, warnings = compiler(WARN_PRINT_NO_PRINTED)[:3] + + assert code is not None + assert errors == () + assert warnings == ["Line 3: Prints, but never reads 'printed' variable."] + + +WARN_PRINT_NO_PRINTED_NESTED = """ +from __future__ import print_function +print ('a') +def foo(): + print ('x') +printed +""" + + +def test_print_function__with_print_no_printed_nested(): + code, errors, warnings = compiler(WARN_PRINT_NO_PRINTED_NESTED)[:3] + + assert code is not None + assert errors == () + assert warnings == ["Line 4: Prints, but never reads 'printed' variable."] + + +# python generates a new frame/scope for: +# modules, functions, class, lambda, all the comprehensions +# For class, lambda and comprehensions *no* new print collector scope should be +# generated. + +NO_PRINT_SCOPES = """ +from __future__ import print_function +def class_scope(): + class A: + print ('a') + return printed + +def lambda_scope(): + func = lambda x: print(x) + func(1) + func(2) + return printed + +def comprehension_scope(): + [print(1) for _ in range(2)] + return printed +""" + + +def test_print_function_no_new_scope(): + code, errors = compiler(NO_PRINT_SCOPES)[:2] + glb = { + '_print_': PrintCollector, + '_getattr_': None, + '_getiter_': lambda ob: ob + } + six.exec_(code, glb) + + ret = glb['class_scope']() + assert ret == 'a\n' + + ret = glb['lambda_scope']() + assert ret == '1\n2\n' + + ret = glb['comprehension_scope']() + assert ret == '1\n1\n' + + +PASS_PRINT_FUNCTION = """ +from __future__ import print_function +def main(): + def do_stuff(func): + func(1) + func(2) + + do_stuff(print) + return printed +""" + + +def test_print_function_pass_print_function(): + code, errors = compiler(PASS_PRINT_FUNCTION)[:2] + glb = {'_print_': PrintCollector, '_getattr_': None} + six.exec_(code, glb) + + ret = glb['main']() + assert ret == '1\n2\n' + + +CONDITIONAL_PRINT = """ +from __future__ import print_function +def func(cond): + if cond: + print(1) + return printed +""" + + +def test_print_function_conditional_print(): + code, errors = compiler(CONDITIONAL_PRINT)[:2] + glb = {'_print_': PrintCollector, '_getattr_': None} + six.exec_(code, glb) + + assert glb['func'](True) == '1\n' + assert glb['func'](False) == '' diff --git a/tests/test_print_stmt.py b/tests/test_print_stmt.py new file mode 100644 index 0000000..4d1a51f --- /dev/null +++ b/tests/test_print_stmt.py @@ -0,0 +1,280 @@ +from RestrictedPython.PrintCollector import PrintCollector + +import pytest +import RestrictedPython +import six +import sys + + +pytestmark = pytest.mark.skipif( + sys.version_info.major == 3, + reason="print statement no longer exists in Python 3") + + +compilers = ('compiler', [RestrictedPython.compile.compile_restricted_exec]) + +if sys.version_info.major == 2: + from RestrictedPython import RCompile + compilers[1].append(RCompile.compile_restricted_exec) + + +ALLOWED_PRINT_STATEMENT = """ +print 'Hello World!' +""" + +ALLOWED_PRINT_STATEMENT_WITH_NO_NL = """ +print 'Hello World!', +""" + +ALLOWED_MULTI_PRINT_STATEMENT = """ +print 'Hello World!', 'Hello Earth!' +""" + +# It looks like a function, but is still a statement in python2.X +ALLOWED_PRINT_TUPLE = """ +print('Hello World!') +""" + + +ALLOWED_PRINT_MULTI_TUPLE = """ +print('Hello World!', 'Hello Earth!') +""" + + +@pytest.mark.parametrize(*compilers) +def test_print_stmt__simple_prints(compiler): + glb = {'_print_': PrintCollector, '_getattr_': None} + + code, errors = compiler(ALLOWED_PRINT_STATEMENT)[:2] + assert code is not None + assert errors == () + six.exec_(code, glb) + assert glb['_print']() == 'Hello World!\n' + + code, errors = compiler(ALLOWED_PRINT_STATEMENT_WITH_NO_NL)[:2] + assert code is not None + assert errors == () + six.exec_(code, glb) + assert glb['_print']() == 'Hello World!' + + code, errors = compiler(ALLOWED_MULTI_PRINT_STATEMENT)[:2] + assert code is not None + assert errors == () + six.exec_(code, glb) + assert glb['_print']() == 'Hello World! Hello Earth!\n' + + code, errors = compiler(ALLOWED_PRINT_TUPLE)[:2] + assert code is not None + assert errors == () + six.exec_(code, glb) + assert glb['_print']() == "Hello World!\n" + + code, errors = compiler(ALLOWED_PRINT_MULTI_TUPLE)[:2] + assert code is not None + assert errors == () + six.exec_(code, glb) + assert glb['_print']() == "('Hello World!', 'Hello Earth!')\n" + + +@pytest.mark.parametrize(*compilers) +def test_print_stmt__fail_with_none_target(compiler, mocker): + code, errors = compiler('print >> None, "test"')[:2] + + assert code is not None + assert errors == () + + glb = {'_getattr_': getattr, '_print_': PrintCollector} + + with pytest.raises(AttributeError) as excinfo: + six.exec_(code, glb) + + assert "'NoneType' object has no attribute 'write'" in str(excinfo.value) + + +PROTECT_PRINT_STATEMENT_WITH_CHEVRON = """ +def print_into_stream(stream): + print >> stream, 'Hello World!' +""" + + +@pytest.mark.parametrize(*compilers) +def test_print_stmt__protect_chevron_print(compiler, mocker): + code, errors = compiler(PROTECT_PRINT_STATEMENT_WITH_CHEVRON)[:2] + + _getattr_ = mocker.stub() + _getattr_.side_effect = getattr + glb = {'_getattr_': _getattr_, '_print_': PrintCollector} + + six.exec_(code, glb) + + stream = mocker.stub() + stream.write = mocker.stub() + glb['print_into_stream'](stream) + + stream.write.assert_has_calls([ + mocker.call('Hello World!'), + mocker.call('\n') + ]) + + _getattr_.assert_called_once_with(stream, 'write') + + +# 'printed' is scope aware. +# => on a new function scope a new printed is generated. +INJECT_PRINT_COLLECTOR_NESTED = """ +def f2(): + return 'f2' + +def f1(): + print 'f1' + + def inner(): + print 'inner' + return printed + + return inner() + printed + f2() + +def main(): + print 'main' + return f1() + printed +""" + + +@pytest.mark.parametrize(*compilers) +def test_print_stmt__nested_print_collector(compiler, mocker): + code, errors = compiler(INJECT_PRINT_COLLECTOR_NESTED)[:2] + + glb = {"_print_": PrintCollector, '_getattr_': None} + six.exec_(code, glb) + + ret = glb['main']() + assert ret == 'inner\nf1\nf2main\n' + + +WARN_PRINTED_NO_PRINT = """ +def foo(): + return printed +""" + + +@pytest.mark.parametrize(*compilers) +def test_print_stmt__with_printed_no_print(compiler): + code, errors, warnings = compiler(WARN_PRINTED_NO_PRINT)[:3] + + assert code is not None + assert errors == () + + if compiler is RestrictedPython.compile.compile_restricted_exec: + assert warnings == [ + "Line 2: Doesn't print, but reads 'printed' variable."] + + if compiler is RestrictedPython.RCompile.compile_restricted_exec: + assert warnings == ["Doesn't print, but reads 'printed' variable."] + + +WARN_PRINTED_NO_PRINT_NESTED = """ +print 'a' +def foo(): + return printed +printed +""" + + +@pytest.mark.parametrize(*compilers) +def test_print_stmt__with_printed_no_print_nested(compiler): + code, errors, warnings = compiler(WARN_PRINTED_NO_PRINT_NESTED)[:3] + + assert code is not None + assert errors == () + + if compiler is RestrictedPython.compile.compile_restricted_exec: + assert warnings == [ + "Line 3: Doesn't print, but reads 'printed' variable."] + + if compiler is RestrictedPython.RCompile.compile_restricted_exec: + assert warnings == ["Doesn't print, but reads 'printed' variable."] + + +WARN_PRINT_NO_PRINTED = """ +def foo(): + print 1 +""" + + +@pytest.mark.parametrize(*compilers) +def test_print_stmt__with_print_no_printed(compiler): + code, errors, warnings = compiler(WARN_PRINT_NO_PRINTED)[:3] + + assert code is not None + assert errors == () + + if compiler is RestrictedPython.compile.compile_restricted_exec: + assert warnings == [ + "Line 2: Prints, but never reads 'printed' variable."] + + if compiler is RestrictedPython.RCompile.compile_restricted_exec: + assert warnings == ["Prints, but never reads 'printed' variable."] + + +WARN_PRINT_NO_PRINTED_NESTED = """ +print 'a' +def foo(): + print 'x' +printed +""" + + +@pytest.mark.parametrize(*compilers) +def test_print_stmt__with_print_no_printed_nested(compiler): + code, errors, warnings = compiler(WARN_PRINT_NO_PRINTED_NESTED)[:3] + + assert code is not None + assert errors == () + + if compiler is RestrictedPython.compile.compile_restricted_exec: + assert warnings == [ + "Line 3: Prints, but never reads 'printed' variable."] + + if compiler is RestrictedPython.RCompile.compile_restricted_exec: + assert warnings == ["Prints, but never reads 'printed' variable."] + + +# python2 generates a new frame/scope for: +# modules, functions, class, lambda +# Since print statement cannot be used in lambda only ensure that no new scope +# for classes is generated. + +NO_PRINT_SCOPES = """ +def class_scope(): + class A: + print 'a' + return printed +""" + + +@pytest.mark.parametrize(*compilers) +def test_print_stmt_no_new_scope(compiler): + code, errors = compiler(NO_PRINT_SCOPES)[:2] + glb = {'_print_': PrintCollector, '_getattr_': None} + six.exec_(code, glb) + + ret = glb['class_scope']() + assert ret == 'a\n' + + +CONDITIONAL_PRINT = """ +def func(cond): + if cond: + print 1 + return printed +""" + + +@pytest.mark.parametrize(*compilers) +def test_print_stmt_conditional_print(compiler): + code, errors = compiler(CONDITIONAL_PRINT)[:2] + glb = {'_print_': PrintCollector, '_getattr_': None} + six.exec_(code, glb) + + assert glb['func'](True) == '1\n' + assert glb['func'](False) == ''