Skip to content

Commit

Permalink
Merge pull request #9 from zopefoundation/python3_port_print
Browse files Browse the repository at this point in the history
Python3 port print
  • Loading branch information
stephan-hof committed Nov 29, 2016
2 parents cdb2871 + a383d8f commit a88fd0c
Show file tree
Hide file tree
Showing 6 changed files with 787 additions and 137 deletions.
29 changes: 8 additions & 21 deletions src/RestrictedPython/PrintCollector.py
Expand Up @@ -12,37 +12,24 @@
##############################################################################
from __future__ import print_function

import sys


version = sys.version_info


class PrintCollector(object):
"""Collect written text, and return it when called."""

def __init__(self):
def __init__(self, _getattr_=None):
self.txt = []
self._getattr_ = _getattr_

def write(self, text):
self.txt.append(text)

def __call__(self):
return ''.join(self.txt)

def _call_print(self, *objects, **kwargs):
if kwargs.get('file', None) is None:
kwargs['file'] = self
else:
self._getattr_(kwargs['file'], 'write')

printed = PrintCollector()


def safe_print(sep=' ', end='\n', file=printed, flush=False, *objects):
"""
"""
# TODO: Reorder method args so that *objects is first
# This could first be done if we drop Python 2 support
if file is None or file is sys.stdout or file is sys.stderr:
file = printed
if version >= (3, 3):
print(self, objects, sep=sep, end=end, file=file, flush=flush)
else:
print(self, objects, sep=sep, end=end, file=file)
print(*objects, **kwargs)
155 changes: 139 additions & 16 deletions src/RestrictedPython/transformer.py
Expand Up @@ -23,6 +23,7 @@


import ast
import contextlib
import sys


Expand Down Expand Up @@ -161,7 +162,8 @@
ast.Starred,
ast.arg,
ast.Try,
ast.ExceptHandler
ast.ExceptHandler,
ast.NameConstant
])

if version >= (3, 4):
Expand Down Expand Up @@ -197,6 +199,26 @@ def copy_locations(new_node, old_node):
ast.fix_missing_locations(new_node)


class PrintInfo(object):
def __init__(self):
self.print_used = False
self.printed_used = False

@contextlib.contextmanager
def new_print_scope(self):
old_print_used = self.print_used
old_printed_used = self.printed_used

self.print_used = False
self.printed_used = False

try:
yield
finally:
self.print_used = old_print_used
self.printed_used = old_printed_used


class RestrictingNodeTransformer(ast.NodeTransformer):

def __init__(self, errors=[], warnings=[], used_names=[]):
Expand All @@ -208,6 +230,8 @@ def __init__(self, errors=[], warnings=[], used_names=[]):
# Global counter to construct temporary variable names.
self._tmp_idx = 0

self.print_info = PrintInfo()

def gen_tmp_name(self):
# 'check_name' ensures that no variable is prefixed with '_'.
# => Its safe to use '_tmp..' as a temporary variable.
Expand Down Expand Up @@ -424,6 +448,10 @@ def check_name(self, node, name):
elif name == "printed":
self.error(node, '"printed" is a reserved name.')

elif name == 'print':
# Assignments to 'print' would lead to funny results.
self.error(node, '"print" is a reserved name.')

def check_function_argument_names(self, node):
# In python3 arguments are always identifiers.
# In python2 the 'Python.asdl' specifies expressions, but
Expand Down Expand Up @@ -474,6 +502,47 @@ def check_import_names(self, node):

return self.generic_visit(node)

def inject_print_collector(self, node, position=0):
print_used = self.print_info.print_used
printed_used = self.print_info.printed_used

if print_used or printed_used:
# Add '_print = _print_(_getattr_)' add the top of a function/module.
_print = ast.Assign(
targets=[ast.Name('_print', ast.Store())],
value=ast.Call(
func=ast.Name("_print_", ast.Load()),
args=[ast.Name("_getattr_", ast.Load())],
keywords=[]))

if isinstance(node, ast.Module):
_print.lineno = position
_print.col_offset = position
ast.fix_missing_locations(_print)
else:
copy_locations(_print, node)

node.body.insert(position, _print)

if not printed_used:
self.warn(node, "Prints, but never reads 'printed' variable.")

elif not print_used:
self.warn(node, "Doesn't print, but reads 'printed' variable.")

def gen_attr_check(self, node, attr_name):
"""Check if 'attr_name' is allowed on the object in node.
It generates (_getattr_(node, attr_name) and node).
"""

call_getattr = ast.Call(
func=ast.Name('_getattr_', ast.Load()),
args=[node, ast.Str(attr_name)],
keywords=[])

return ast.BoolOp(op=ast.And(), values=[call_getattr, node])

# Special Functions for an ast.NodeTransformer

def generic_visit(self, node):
Expand Down Expand Up @@ -556,11 +625,36 @@ def visit_NameConstant(self, node):
# ast for Variables

def visit_Name(self, node):
"""
"""Prevents access to protected names.
Converts use of the name 'printed' to this expression: '_print()'
"""

node = self.generic_visit(node)

if isinstance(node.ctx, ast.Load):
if node.id == 'printed':
self.print_info.printed_used = True
new_node = ast.Call(
func=ast.Name("_print", ast.Load()),
args=[],
keywords=[])

copy_locations(new_node, node)
return new_node

elif node.id == 'print':
self.print_info.print_used = True
new_node = ast.Attribute(
value=ast.Name('_print', ast.Load()),
attr="_call_print",
ctx=ast.Load())

copy_locations(new_node, node)
return new_node

self.check_name(node, node.id)
return self.generic_visit(node)
return node

def visit_Load(self, node):
"""
Expand Down Expand Up @@ -1075,16 +1169,32 @@ def visit_AugAssign(self, node):
return node

def visit_Print(self, node):
"""Checks and mutates a print statement.
Adds a target to all print statements. 'print foo' becomes
'print >> _print, foo', where _print is the default print
target defined for this scope.
Alternatively, if the untrusted code provides its own target,
we have to check the 'write' method of the target.
'print >> ob, foo' becomes
'print >> (_getattr_(ob, 'write') and ob), foo'.
Otherwise, it would be possible to call the write method of
templates and scripts; 'write' happens to be the name of the
method that changes them.
"""
Fields:
* dest (optional)
* value --> List of Nodes
* nl --> newline (True or False)
"""
if node.dest is not None:
self.error(
node,
'print statements with destination / chevron are not allowed.')

self.print_info.print_used = True

node = self.generic_visit(node)
if node.dest is None:
node.dest = ast.Name('_print', ast.Load())
else:
# Pre-validate access to the 'write' attribute.
node.dest = self.gen_attr_check(node.dest, 'write')

copy_locations(node.dest, node)
return node

def visit_Raise(self, node):
"""
Expand Down Expand Up @@ -1251,7 +1361,9 @@ def visit_FunctionDef(self, node):
self.check_name(node, node.name)
self.check_function_argument_names(node)

node = self.generic_visit(node)
with self.print_info.new_print_scope():
node = self.generic_visit(node)
self.inject_print_collector(node)

if version.major == 3:
return node
Expand Down Expand Up @@ -1388,10 +1500,21 @@ def visit_ClassDef(self, node):
return self.generic_visit(node)

def visit_Module(self, node):
"""
"""Adds the print_collector (only if print is used) at the top."""

"""
return self.generic_visit(node)
node = self.generic_visit(node)

# Inject the print collector after 'from __future__ import ....'
position = 0
for position, child in enumerate(node.body):
if not isinstance(child, ast.ImportFrom):
break

if not child.module == '__future__':
break

self.inject_print_collector(node, position)
return node

# Async und await

Expand Down
34 changes: 0 additions & 34 deletions tests/test_base_example.py

This file was deleted.

66 changes: 0 additions & 66 deletions tests/test_print.py

This file was deleted.

0 comments on commit a88fd0c

Please sign in to comment.