Skip to content

Commit

Permalink
Prevent leak of StackContexts in repeated gen.engine functions.
Browse files Browse the repository at this point in the history
Internally, StackContexts now return a deactivation callback,
which can be used to prevent that StackContext from propagating
further.  This is used in gen.engine because the decorator doesn't know
which arguments are callbacks that need to be wrapped outside of its
ExceptionStackContext.  This is deliberately undocumented for now.

Closes #507.
  • Loading branch information
bdarnell committed May 21, 2012
1 parent 1be0cc4 commit 57a3f83
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 12 deletions.
9 changes: 6 additions & 3 deletions tornado/gen.py
Expand Up @@ -113,13 +113,14 @@ def handle_exception(typ, value, tb):
if runner is not None: if runner is not None:
return runner.handle_exception(typ, value, tb) return runner.handle_exception(typ, value, tb)
return False return False
with ExceptionStackContext(handle_exception): with ExceptionStackContext(handle_exception) as deactivate:
gen = func(*args, **kwargs) gen = func(*args, **kwargs)
if isinstance(gen, types.GeneratorType): if isinstance(gen, types.GeneratorType):
runner = Runner(gen) runner = Runner(gen, deactivate)
runner.run() runner.run()
return return
assert gen is None, gen assert gen is None, gen
deactivate()
# no yield, so we're done # no yield, so we're done
return wrapper return wrapper


Expand Down Expand Up @@ -285,8 +286,9 @@ class Runner(object):
Maintains information about pending callbacks and their results. Maintains information about pending callbacks and their results.
""" """
def __init__(self, gen): def __init__(self, gen, deactivate_stack_context):
self.gen = gen self.gen = gen
self.deactivate_stack_context = deactivate_stack_context
self.yield_point = _NullYieldPoint() self.yield_point = _NullYieldPoint()
self.pending_callbacks = set() self.pending_callbacks = set()
self.results = {} self.results = {}
Expand Down Expand Up @@ -351,6 +353,7 @@ def run(self):
raise LeakedCallbackError( raise LeakedCallbackError(
"finished without waiting for callbacks %r" % "finished without waiting for callbacks %r" %
self.pending_callbacks) self.pending_callbacks)
self.deactivate_stack_context()
return return
except Exception: except Exception:
self.finished = True self.finished = True
Expand Down
29 changes: 20 additions & 9 deletions tornado/stack_context.py
Expand Up @@ -71,6 +71,7 @@ def die_on_error():
import contextlib import contextlib
import functools import functools
import itertools import itertools
import operator
import sys import sys
import threading import threading


Expand All @@ -95,23 +96,25 @@ class StackContext(object):
with StackContext(my_context): with StackContext(my_context):
''' '''
def __init__(self, context_factory): def __init__(self, context_factory, _active_cell=None):
self.context_factory = context_factory self.context_factory = context_factory
self.active_cell = _active_cell or [True]


# Note that some of this code is duplicated in ExceptionStackContext # Note that some of this code is duplicated in ExceptionStackContext
# below. ExceptionStackContext is more common and doesn't need # below. ExceptionStackContext is more common and doesn't need
# the full generality of this class. # the full generality of this class.
def __enter__(self): def __enter__(self):
self.old_contexts = _state.contexts self.old_contexts = _state.contexts
# _state.contexts is a tuple of (class, arg) pairs # _state.contexts is a tuple of (class, arg, active_cell) tuples
_state.contexts = (self.old_contexts + _state.contexts = (self.old_contexts +
((StackContext, self.context_factory),)) ((StackContext, self.context_factory, self.active_cell),))
try: try:
self.context = self.context_factory() self.context = self.context_factory()
self.context.__enter__() self.context.__enter__()
except Exception: except Exception:
_state.contexts = self.old_contexts _state.contexts = self.old_contexts
raise raise
return lambda: operator.setitem(self.active_cell, 0, False)


def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
try: try:
Expand All @@ -133,13 +136,16 @@ class ExceptionStackContext(object):
If the exception handler returns true, the exception will be If the exception handler returns true, the exception will be
consumed and will not be propagated to other exception handlers. consumed and will not be propagated to other exception handlers.
''' '''
def __init__(self, exception_handler): def __init__(self, exception_handler, _active_cell=None):
self.exception_handler = exception_handler self.exception_handler = exception_handler
self.active_cell = _active_cell or [True]


def __enter__(self): def __enter__(self):
self.old_contexts = _state.contexts self.old_contexts = _state.contexts
_state.contexts = (self.old_contexts + _state.contexts = (self.old_contexts +
((ExceptionStackContext, self.exception_handler),)) ((ExceptionStackContext, self.exception_handler,
self.active_cell),))
return lambda: operator.setitem(self.active_cell, 0, False)


def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
try: try:
Expand Down Expand Up @@ -186,7 +192,9 @@ def wrapped(callback, contexts, *args, **kwargs):
callback(*args, **kwargs) callback(*args, **kwargs)
return return
if not _state.contexts: if not _state.contexts:
new_contexts = [cls(arg) for (cls, arg) in contexts] new_contexts = [cls(arg, active_cell)
for (cls, arg, active_cell) in contexts
if active_cell[0]]
# If we're moving down the stack, _state.contexts is a prefix # If we're moving down the stack, _state.contexts is a prefix
# of contexts. For each element of contexts not in that prefix, # of contexts. For each element of contexts not in that prefix,
# create a new StackContext object. # create a new StackContext object.
Expand All @@ -198,10 +206,13 @@ def wrapped(callback, contexts, *args, **kwargs):
for a, b in itertools.izip(_state.contexts, contexts))): for a, b in itertools.izip(_state.contexts, contexts))):
# contexts have been removed or changed, so start over # contexts have been removed or changed, so start over
new_contexts = ([NullContext()] + new_contexts = ([NullContext()] +
[cls(arg) for (cls, arg) in contexts]) [cls(arg, active_cell)
for (cls, arg, active_cell) in contexts
if active_cell[0]])
else: else:
new_contexts = [cls(arg) new_contexts = [cls(arg, active_cell)
for (cls, arg) in contexts[len(_state.contexts):]] for (cls, arg, active_cell) in contexts[len(_state.contexts):]
if active_cell[0]]
if len(new_contexts) > 1: if len(new_contexts) > 1:
with _nested(*new_contexts): with _nested(*new_contexts):
callback(*args, **kwargs) callback(*args, **kwargs)
Expand Down
18 changes: 18 additions & 0 deletions tornado/test/gen_test.py
Expand Up @@ -249,6 +249,24 @@ def task_func(callback):
self.stop() self.stop()
self.run_gen(f) self.run_gen(f)


def test_stack_context_leak(self):
# regression test: repeated invocations of a gen-based
# function should not result in accumulated stack_contexts
from tornado import stack_context
@gen.engine
def inner(callback):
yield gen.Task(self.io_loop.add_callback)
callback()
@gen.engine
def outer():
for i in xrange(10):
yield gen.Task(inner)
stack_increase = len(stack_context._state.contexts) - initial_stack_depth
self.assertTrue(stack_increase <= 2)
self.stop()
initial_stack_depth = len(stack_context._state.contexts)
self.run_gen(outer)



class GenSequenceHandler(RequestHandler): class GenSequenceHandler(RequestHandler):
@asynchronous @asynchronous
Expand Down
27 changes: 27 additions & 0 deletions tornado/test/stack_context_test.py
Expand Up @@ -93,5 +93,32 @@ def final_callback():
library_function(final_callback) library_function(final_callback)
self.wait() self.wait()


def test_deactivate(self):
deactivate_callbacks = []
def f1():
with StackContext(functools.partial(self.context, 'c1')) as c1:
deactivate_callbacks.append(c1)
self.io_loop.add_callback(f2)
def f2():
with StackContext(functools.partial(self.context, 'c2')) as c2:
deactivate_callbacks.append(c2)
self.io_loop.add_callback(f3)
def f3():
with StackContext(functools.partial(self.context, 'c3')) as c3:
deactivate_callbacks.append(c3)
self.io_loop.add_callback(f4)
def f4():
self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
deactivate_callbacks[1]()
# deactivating a context doesn't remove it immediately,
# but it will be missing from the next iteration
self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
self.io_loop.add_callback(f5)
def f5():
self.assertEqual(self.active_contexts, ['c1', 'c3'])
self.stop()
self.io_loop.add_callback(f1)
self.wait()

if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

0 comments on commit 57a3f83

Please sign in to comment.