Skip to content

Commit

Permalink
refactored compiler and improved identifier handling for for-loops
Browse files Browse the repository at this point in the history
--HG--
branch : trunk
  • Loading branch information
mitsuhiko committed May 23, 2008
1 parent 903d168 commit 105f0dc
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 43 deletions.
94 changes: 55 additions & 39 deletions jinja2/compiler.py
Expand Up @@ -533,6 +533,11 @@ def collect_shadowed(self, frame):
self.writeline('%s = l_%s' % (ident, name))
return aliases

def restore_shadowed(self, aliases):
"""Restore all aliases."""
for name, alias in aliases.iteritems():
self.writeline('l_%s = %s' % (name, alias))

def function_scoping(self, node, frame, children=None,
find_special=True):
"""In Jinja a few statements require the help of anonymous
Expand Down Expand Up @@ -870,19 +875,28 @@ def visit_FromImport(self, node, frame):
def visit_For(self, node, frame):
# when calculating the nodes for the inner frame we have to exclude
# the iterator contents from it
children = list(node.iter_child_nodes(exclude=('iter',)))

children = node.iter_child_nodes(exclude=('iter',))
if node.recursive:
loop_frame = self.function_scoping(node, frame, children,
find_special=False)
else:
loop_frame = frame.inner()
loop_frame.inspect(children)

undeclared = find_undeclared(children, ('loop',))
extended_loop = node.recursive or node.else_ or 'loop' in undeclared
if extended_loop:
loop_frame.identifiers.add_special('loop')
# try to figure out if we have an extended loop. An extended loop
# is necessary if the loop is in recursive mode if the special loop
# variable is accessed in the body.
extended_loop = node.recursive or 'loop' in \
find_undeclared(node.iter_child_nodes(
only=('body',)), ('loop',))

# make sure the loop variable is a special one and raise a template
# assertion error if a loop tries to write to loop
loop_frame.identifiers.add_special('loop')
for name in node.find_all(nodes.Name):
if name.ctx == 'store' and name.name == 'loop':
self.fail('Can\'t assign to special loop variable '
'in for-loop target', name.lineno)

# if we don't have an recursive loop we have to find the shadowed
# variables at that point
Expand All @@ -898,22 +912,24 @@ def visit_For(self, node, frame):

self.pull_locals(loop_frame)
if node.else_:
self.writeline('l_loop = None')

self.newline(node)
self.writeline('for ')
iteration_indicator = self.temporary_identifier()
self.writeline('%s = 1' % iteration_indicator)

# Create a fake parent loop if the else or test section of a
# loop is accessing the special loop variable and no parent loop
# exists.
if 'loop' not in aliases and 'loop' in find_undeclared(
node.iter_child_nodes(only=('else_', 'test')), ('loop',)):
self.writeline("l_loop = environment.undefined(%r, name='loop')" %
"'loop' is undefined. the filter section of a loop as well " \
"as the else block doesn't have access to the special 'loop' "
"variable of the current loop. Because there is no parent "
"loop it's undefined.")

self.writeline('for ', node)
self.visit(node.target, loop_frame)
self.write(extended_loop and ', l_loop in LoopContext(' or ' in ')

# the expression pointing to the parent loop. We make the
# undefined a bit more debug friendly at the same time.
parent_loop = 'loop' in aliases and aliases['loop'] \
or "environment.undefined(%r, name='loop')" % "'loop' " \
'is undefined. "the filter section of a loop as well ' \
'as the else block doesn\'t have access to the ' \
"special 'loop' variable of the current loop. " \
"Because there is no parent loop it's undefined."

# if we have an extened loop and a node test, we filter in the
# "outer frame".
if extended_loop and node.test is not None:
Expand All @@ -928,7 +944,6 @@ def visit_For(self, node, frame):
self.visit(node.iter, loop_frame)
self.write(' if (')
test_frame = loop_frame.copy()
self.writeline('l_loop = ' + parent_loop)
self.visit(node.test, test_frame)
self.write('))')

Expand All @@ -954,18 +969,18 @@ def visit_For(self, node, frame):

self.indent()
self.blockvisit(node.body, loop_frame, force_generator=True)
if node.else_:
self.writeline('%s = 0' % iteration_indicator)
self.outdent()

if node.else_:
self.writeline('if l_loop is None:')
self.writeline('if %s:' % iteration_indicator)
self.indent()
self.writeline('l_loop = ' + parent_loop)
self.blockvisit(node.else_, loop_frame, force_generator=False)
self.outdent()

# reset the aliases if there are any.
for name, alias in aliases.iteritems():
self.writeline('l_%s = %s' % (name, alias))
self.restore_shadowed(aliases)

# if the node was recursive we have to return the buffer contents
# and start the iteration code
Expand Down Expand Up @@ -1008,25 +1023,20 @@ def visit_CallBlock(self, node, frame):
self.writeline('caller = ')
self.macro_def(node, call_frame)
self.start_write(frame, node)
self.visit_Call(node.call, call_frame,
extra_kwargs={'caller': 'caller'})
self.visit_Call(node.call, call_frame, forward_caller=True)
self.end_write(frame)

def visit_FilterBlock(self, node, frame):
filter_frame = frame.inner()
filter_frame.inspect(node.iter_child_nodes())

aliases = self.collect_shadowed(filter_frame)
self.pull_locals(filter_frame)
self.buffer(filter_frame)

for child in node.body:
self.visit(child, filter_frame)

self.blockvisit(node.body, filter_frame, force_generator=False)
self.start_write(frame, node)
self.visit_Filter(node.filter, filter_frame, 'concat(%s)'
% filter_frame.buffer)
self.visit_Filter(node.filter, filter_frame)
self.end_write(frame)
self.restore_shadowed(aliases)

def visit_ExprStmt(self, node, frame):
self.newline(node)
Expand Down Expand Up @@ -1283,7 +1293,7 @@ def visit_Slice(self, node, frame):
self.write(':')
self.visit(node.step, frame)

def visit_Filter(self, node, frame, initial=None):
def visit_Filter(self, node, frame):
self.write(self.filters[node.name] + '(')
func = self.environment.filters.get(node.name)
if func is None:
Expand All @@ -1292,10 +1302,15 @@ def visit_Filter(self, node, frame, initial=None):
self.write('context, ')
elif getattr(func, 'environmentfilter', False):
self.write('environment, ')
if isinstance(node.node, nodes.Filter):
self.visit_Filter(node.node, frame, initial)
elif node.node is None:
self.write(initial)

# if the filter node is None we are inside a filter block
# and want to write to the current buffer
if node.node is None:
if self.environment.autoescape:
tmpl = 'Markup(concat(%s))'
else:
tmpl = 'concat(%s)'
self.write(tmpl % frame.buffer)
else:
self.visit(node.node, frame)
self.signature(node, frame)
Expand Down Expand Up @@ -1327,11 +1342,12 @@ def visit_CondExpr(self, node, frame):
self.visit(node.expr2, frame)
self.write(')')

def visit_Call(self, node, frame, extra_kwargs=None):
def visit_Call(self, node, frame, forward_caller=False):
if self.environment.sandboxed:
self.write('environment.call(')
self.visit(node.node, frame)
self.write(self.environment.sandboxed and ', ' or '(')
extra_kwargs = forward_caller and {'caller': 'caller'} or None
self.signature(node, frame, False, extra_kwargs)
self.write(')')

Expand Down
13 changes: 9 additions & 4 deletions jinja2/nodes.py
Expand Up @@ -116,24 +116,26 @@ def __init__(self, *fields, **attributes):
raise TypeError('unknown attribute %r' %
iter(attributes).next())

def iter_fields(self, exclude=()):
def iter_fields(self, exclude=None, only=None):
"""This method iterates over all fields that are defined and yields
``(key, value)`` tuples. Optionally a parameter of ignored fields
can be provided.
"""
for name in self.fields:
if name not in exclude:
if (exclude is only is None) or \
(exclude is not None and name not in exclude) or \
(only is not None and name in only):
try:
yield name, getattr(self, name)
except AttributeError:
pass

def iter_child_nodes(self, exclude=()):
def iter_child_nodes(self, exclude=None, only=None):
"""Iterates over all direct child nodes of the node. This iterates
over all fields and yields the values of they are nodes. If the value
of a field is a list all the nodes in that list are returned.
"""
for field, item in self.iter_fields(exclude):
for field, item in self.iter_fields(exclude, only):
if isinstance(item, list):
for n in item:
if isinstance(n, Node):
Expand Down Expand Up @@ -529,6 +531,9 @@ def as_const(self):
class Filter(Expr):
"""This node applies a filter on an expression. `name` is the name of
the filter, the rest of the fields are the same as for :class:`Call`.
If the `node` of a filter is `None` the contents of the last buffer are
filtered. Buffers are created by macros and filter blocks.
"""
fields = ('node', 'name', 'args', 'kwargs', 'dyn_args', 'dyn_kwargs')

Expand Down
12 changes: 12 additions & 0 deletions tests/test_forloop.py
Expand Up @@ -7,6 +7,7 @@
:license: BSD, see LICENSE for more details.
"""
from py.test import raises
from jinja2.exceptions import UndefinedError


SIMPLE = '''{% for item in seq %}{{ item }}{% endfor %}'''
Expand All @@ -30,6 +31,10 @@
[{{ rowloop.index }}|{{ loop.index }}]
{%- endfor %}
{%- endfor %}'''
LOOPERROR1 = '''\
{% for item in [1] if loop.index == 0 %}...{% endfor %}'''
LOOPERROR2 = '''\
{% for item in [] %}...{% else %}{{ loop }}{% endfor %}'''


def test_simple(env):
Expand Down Expand Up @@ -102,3 +107,10 @@ def test_recursive(env):
def test_looploop(env):
tmpl = env.from_string(LOOPLOOP)
assert tmpl.render(table=['ab', 'cd']) == '[1|1][1|2][2|1][2|2]'


def test_loop_errors(env):
tmpl = env.from_string(LOOPERROR1)
raises(UndefinedError, tmpl.render)
tmpl = env.from_string(LOOPERROR2)
assert tmpl.render() == ''

0 comments on commit 105f0dc

Please sign in to comment.