diff --git a/src/RestrictedPython/transformer.py b/src/RestrictedPython/transformer.py index 5756eae..cbcd456 100644 --- a/src/RestrictedPython/transformer.py +++ b/src/RestrictedPython/transformer.py @@ -23,6 +23,7 @@ import ast +import contextlib import sys @@ -197,6 +198,26 @@ def copy_locations(new_node, old_node): ast.fix_missing_locations(new_node) +class PrintInfo(object): + def __init__(self): + self.print_used = False + self.printed_used = False + + @contextlib.contextmanager + def new_print_scope(self): + old_print_used = self.print_used + old_printed_used = self.printed_used + + self.print_used = False + self.printed_used = False + + try: + yield + finally: + self.print_used = old_print_used + self.printed_used = old_printed_used + + class RestrictingNodeTransformer(ast.NodeTransformer): def __init__(self, errors=[], warnings=[], used_names=[]): @@ -208,6 +229,8 @@ def __init__(self, errors=[], warnings=[], used_names=[]): # Global counter to construct temporary variable names. self._tmp_idx = 0 + self.print_info = PrintInfo() + def gen_tmp_name(self): # 'check_name' ensures that no variable is prefixed with '_'. # => Its safe to use '_tmp..' as a temporary variable. @@ -474,6 +497,47 @@ def check_import_names(self, node): return self.generic_visit(node) + def inject_print_collector(self, node): + print_used = self.print_info.print_used + printed_used = self.print_info.printed_used + + if print_used or printed_used: + # Add '_print = _print_()' at the top of a function/module. 
+ _print = ast.Assign( + targets=[ast.Name('_print', ast.Store())], + value=ast.Call( + func=ast.Name("_print_", ast.Load()), + args=[], + keywords=[])) + + if isinstance(node, ast.Module): + _print.lineno = 0 + _print.col_offset = 0 + ast.fix_missing_locations(_print) + else: + copy_locations(_print, node) + + node.body.insert(0, _print) + + if not printed_used: + self.warn(node, "Prints, but never reads 'printed' variable.") + + elif not print_used: + self.warn(node, "Doesn't print, but reads 'printed' variable.") + + def gen_attr_check(self, node, attr_name): + """Check if 'attr_name' is allowed on the object in node. + + It generates (_getattr_(node, attr_name) and node). + """ + + call_getattr = ast.Call( + func=ast.Name('_getattr_', ast.Load()), + args=[node, ast.Str(attr_name)], + keywords=[]) + + return ast.BoolOp(op=ast.And(), values=[call_getattr, node]) + # Special Functions for an ast.NodeTransformer def generic_visit(self, node): @@ -556,11 +620,25 @@ def visit_NameConstant(self, node): # ast for Variables def visit_Name(self, node): - """ + """Prevents access to protected names. + Converts use of the name 'printed' to this expression: '_print()' """ + + node = self.generic_visit(node) + + if node.id == 'printed' and isinstance(node.ctx, ast.Load): + self.print_info.printed_used = True + new_node = ast.Call( + func=ast.Name("_print", ast.Load()), + args=[], + keywords=[]) + + copy_locations(new_node, node) + return new_node + self.check_name(node, node.id) - return self.generic_visit(node) + return node def visit_Load(self, node): """ @@ -1075,16 +1153,32 @@ def visit_AugAssign(self, node): return node def visit_Print(self, node): + """Checks and mutates a print statement. + + Adds a target to all print statements. 'print foo' becomes + 'print >> _print, foo', where _print is the default print + target defined for this scope. + + Alternatively, if the untrusted code provides its own target, + we have to check the 'write' method of the target. 
+ 'print >> ob, foo' becomes + 'print >> (_getattr_(ob, 'write') and ob), foo'. + Otherwise, it would be possible to call the write method of + templates and scripts; 'write' happens to be the name of the + method that changes them. """ - Fields: - * dest (optional) - * value --> List of Nodes - * nl --> newline (True or False) - """ - if node.dest is not None: - self.error( - node, - 'print statements with destination / chevron are not allowed.') + + self.print_info.print_used = True + + node = self.generic_visit(node) + if node.dest is None: + node.dest = ast.Name('_print', ast.Load()) + else: + # Pre-validate access to the 'write' attribute. + node.dest = self.gen_attr_check(node.dest, 'write') + + copy_locations(node.dest, node) + return node def visit_Raise(self, node): """ @@ -1251,7 +1345,9 @@ def visit_FunctionDef(self, node): self.check_name(node, node.name) self.check_function_argument_names(node) - node = self.generic_visit(node) + with self.print_info.new_print_scope(): + node = self.generic_visit(node) + self.inject_print_collector(node) if version.major == 3: return node @@ -1388,10 +1484,11 @@ def visit_ClassDef(self, node): return self.generic_visit(node) def visit_Module(self, node): - """ + """Adds the print_collector at the top if print or printed is used.""" - """ - return self.generic_visit(node) + node = self.generic_visit(node) + self.inject_print_collector(node) + return node # Async und await