Skip to content

Commit

Permalink
Optimized StackContext implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
mrjoes committed Apr 13, 2013
1 parent 8b72824 commit 681e76a
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 127 deletions.
9 changes: 2 additions & 7 deletions tornado/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def handle_exception(typ, value, tb):
if runner is not None:
return runner.handle_exception(typ, value, tb)
return False
with ExceptionStackContext(handle_exception) as deactivate:
with ExceptionStackContext(handle_exception):
try:
result = func(*args, **kwargs)
except (Return, StopIteration) as e:
Expand All @@ -149,15 +149,13 @@ def final_callback(value):
"@gen.engine functions cannot return values: "
"%r" % (value,))
assert value is None
deactivate()
runner = Runner(result, final_callback)
runner.run()
return
if result is not None:
raise ReturnValueIgnoredError(
"@gen.engine functions cannot return values: %r" %
(result,))
deactivate()
# no yield, so we're done
return wrapper

Expand Down Expand Up @@ -210,24 +208,21 @@ def handle_exception(typ, value, tb):
typ, value, tb = sys.exc_info()
future.set_exc_info((typ, value, tb))
return True
with ExceptionStackContext(handle_exception) as deactivate:
with ExceptionStackContext(handle_exception):
try:
result = func(*args, **kwargs)
except (Return, StopIteration) as e:
result = getattr(e, 'value', None)
except Exception:
deactivate()
future.set_exc_info(sys.exc_info())
return future
else:
if isinstance(result, types.GeneratorType):
def final_callback(value):
deactivate()
future.set_result(value)
runner = Runner(result, final_callback)
runner.run()
return future
deactivate()
future.set_result(result)
return future
return wrapper
Expand Down
196 changes: 108 additions & 88 deletions tornado/stack_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,6 @@ def die_on_error():

from __future__ import absolute_import, division, print_function, with_statement

import contextlib
import functools
import operator
import sys
import threading

Expand All @@ -84,7 +81,7 @@ class StackContextInconsistentError(Exception):

class _State(threading.local):
def __init__(self):
self.contexts = ()
self.contexts = (tuple(), None)
_state = _State()


Expand All @@ -108,45 +105,51 @@ class StackContext(object):
context that are currently pending). This is an advanced feature
and not necessary in most applications.
"""
def __init__(self, context_factory, _active_cell=None):
def __init__(self, context_factory):
self.context_factory = context_factory
self.active_cell = _active_cell or [True]
self.contexts = []

# StackContext protocol
def enter(self):
context = self.context_factory()
self.contexts.append(context)
context.__enter__()

def exit(self, type, value, traceback):
context = self.contexts.pop()
context.__exit__(type, value, traceback)

# Note that some of this code is duplicated in ExceptionStackContext
# below. ExceptionStackContext is more common and doesn't need
# the full generality of this class.
def __enter__(self):
self.old_contexts = _state.contexts
# _state.contexts is a tuple of (class, arg, active_cell) tuples
self.new_contexts = (self.old_contexts +
((StackContext, self.context_factory,
self.active_cell),))
self.new_contexts = (self.old_contexts[0] + (self,), self)
_state.contexts = self.new_contexts

try:
self.context = self.context_factory()
self.context.__enter__()
except Exception:
self.enter()
except:
_state.contexts = self.old_contexts
raise
return lambda: operator.setitem(self.active_cell, 0, False)

def __exit__(self, type, value, traceback):
try:
return self.context.__exit__(type, value, traceback)
self.exit(type, value, traceback)
finally:
final_contexts = _state.contexts
_state.contexts = self.old_contexts

# Generator coroutines and with-statements with non-local
# effects interact badly. Check here for signs of
# the stack getting out of sync.
# Note that this check comes after restoring _state.context
# so that if it fails things are left in a (relatively)
# consistent state.
if final_contexts is not self.new_contexts:
if final_contexts != self.new_contexts:
raise StackContextInconsistentError(
'stack_context inconsistency (may be caused by yield '
'within a "with StackContext" block)')
self.old_contexts = self.new_contexts = None


class ExceptionStackContext(object):
Expand All @@ -162,17 +165,17 @@ class ExceptionStackContext(object):
If the exception handler returns true, the exception will be
consumed and will not be propagated to other exception handlers.
"""
def __init__(self, exception_handler, _active_cell=None):
def __init__(self, exception_handler):
self.exception_handler = exception_handler
self.active_cell = _active_cell or [True]

def exit(self, type, value, traceback):
if type is not None:
return self.exception_handler(type, value, traceback)

def __enter__(self):
self.old_contexts = _state.contexts
self.new_contexts = (self.old_contexts +
((ExceptionStackContext, self.exception_handler,
self.active_cell),))
self.new_contexts = (self.old_contexts[0], self)
_state.contexts = self.new_contexts
return lambda: operator.setitem(self.active_cell, 0, False)

def __exit__(self, type, value, traceback):
try:
Expand All @@ -181,11 +184,11 @@ def __exit__(self, type, value, traceback):
finally:
final_contexts = _state.contexts
_state.contexts = self.old_contexts
if final_contexts is not self.new_contexts:

if final_contexts != self.new_contexts:
raise StackContextInconsistentError(
'stack_context inconsistency (may be caused by yield '
'within a "with StackContext" block)')
self.old_contexts = self.new_contexts = None


class NullContext(object):
Expand All @@ -197,16 +200,12 @@ class NullContext(object):
"""
def __enter__(self):
self.old_contexts = _state.contexts
_state.contexts = ()
_state.contexts = (tuple(), None)

def __exit__(self, type, value, traceback):
_state.contexts = self.old_contexts


class _StackContextWrapper(functools.partial):
pass


def wrap(fn):
"""Returns a callable object that will restore the current `StackContext`
when executed.
Expand All @@ -215,64 +214,85 @@ def wrap(fn):
different execution context (either in a different thread or
asynchronously in the same thread).
"""
if fn is None or fn.__class__ is _StackContextWrapper:
# Check if function is already wrapped
if fn is None or hasattr(fn, '_wrapped'):
return fn
# functools.wraps doesn't appear to work on functools.partial objects
#@functools.wraps(fn)

# Capture current stack head
contexts = _state.contexts

#@functools.wraps
def wrapped(*args, **kwargs):
callback, contexts, args = args[0], args[1], args[2:]

if _state.contexts:
new_contexts = [NullContext()]
else:
new_contexts = []
if contexts:
new_contexts.extend(cls(arg, active_cell)
for (cls, arg, active_cell) in contexts
if active_cell[0])
if len(new_contexts) > 1:
with _nested(*new_contexts):
callback(*args, **kwargs)
elif new_contexts:
with new_contexts[0]:
callback(*args, **kwargs)
else:
callback(*args, **kwargs)
return _StackContextWrapper(wrapped, fn, _state.contexts)


@contextlib.contextmanager
def _nested(*managers):
"""Support multiple context managers in a single with-statement.
Copied from the python 2.6 standard library. It's no longer present
in python 3 because the with statement natively supports multiple
context managers, but that doesn't help if the list of context
managers is not known until runtime.
"""
exits = []
vars = []
exc = (None, None, None)
try:
for mgr in managers:
exit = mgr.__exit__
enter = mgr.__enter__
vars.append(enter())
exits.append(exit)
yield vars
except:
exc = sys.exc_info()
finally:
while exits:
exit = exits.pop()
try:
if exit(*exc):
exc = (None, None, None)
except:
exc = sys.exc_info()
if exc != (None, None, None):
# Don't rely on sys.exc_info() still containing
# the right information. Another exception may
# have been raised and caught by an exit method
raise_exc_info(exc)
try:
# Force local state - switch to new stack chain
current_state = _state.contexts
_state.contexts = contexts

# Current exception
exc = (None, None, None)
top = None

# Apply stack contexts
last_ctx = 0
stack = contexts[0]

# Apply state
for n in stack:
try:
n.enter()
last_ctx += 1
except:
# Exception happened. Record exception info and store top-most handler
exc = sys.exc_info()
top = n.old_contexts[1]

# Execute callback if no exception happened while restoring state
if top is None:
try:
fn(*args, **kwargs)
except:
exc = sys.exc_info()
top = contexts[1]

# If there was exception, try to handle it by going through the exception chain
if top is not None:
exc = _handle_exception(top, exc)
else:
# Otherwise take shorter path and run stack contexts in reverse order
for n in xrange(last_ctx - 1, -1, -1):
c = stack[n]

try:
c.exit(*exc)
except:
exc = sys.exc_info()
top = c.old_contexts[1]
break
else:
top = None

# If if exception happened while unrolling, take longer exception handler path
if top is not None:
exc = _handle_exception(top, exc)

# If exception was not handled, raise it
if exc != (None, None, None):
raise_exc_info(exc)
finally:
_state.contexts = current_state

wrapped._wrapped = True
return wrapped


def _handle_exception(tail, exc):
while tail is not None:
try:
if tail.exit(*exc):
exc = (None, None, None)
except:
exc = sys.exc_info()

tail = tail.old_contexts[1]

return exc
3 changes: 3 additions & 0 deletions tornado/test/gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,3 +838,6 @@ def test_coroutine_exception_handler(self):
def test_yield_exception_handler(self):
response = self.fetch('/yield_exception')
self.assertEqual(response.body, b'ok')

if __name__ == '__main__':
unittest.main()
32 changes: 0 additions & 32 deletions tornado/test/stack_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,38 +95,6 @@ def final_callback():
library_function(final_callback)
self.wait()

def test_deactivate(self):
deactivate_callbacks = []

def f1():
with StackContext(functools.partial(self.context, 'c1')) as c1:
deactivate_callbacks.append(c1)
self.io_loop.add_callback(f2)

def f2():
with StackContext(functools.partial(self.context, 'c2')) as c2:
deactivate_callbacks.append(c2)
self.io_loop.add_callback(f3)

def f3():
with StackContext(functools.partial(self.context, 'c3')) as c3:
deactivate_callbacks.append(c3)
self.io_loop.add_callback(f4)

def f4():
self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
deactivate_callbacks[1]()
# deactivating a context doesn't remove it immediately,
# but it will be missing from the next iteration
self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
self.io_loop.add_callback(f5)

def f5():
self.assertEqual(self.active_contexts, ['c1', 'c3'])
self.stop()
self.io_loop.add_callback(f1)
self.wait()

def test_isolation_nonempty(self):
# f2 and f3 are a chain of operations started in context c1.
# f2 is incidentally run under context c2, but that context should
Expand Down

0 comments on commit 681e76a

Please sign in to comment.