Skip to content

Commit

Permalink
Merge pull request #1192 from twisted/9719-contextvars
Browse files Browse the repository at this point in the history
Author: hawkowl

Reviewer: glyph

Fixes: ticket:9719

Support contextvars in ensureDeferred and inlineCallbacks coroutines.
  • Loading branch information
glyph authored May 1, 2020
2 parents bc7fef3 + 10ee73c commit db811a3
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 26 deletions.
64 changes: 45 additions & 19 deletions src/twisted/internet/defer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,25 @@
from twisted.logger import Logger
from twisted.python.deprecate import warnAboutFunction, deprecated

try:
from contextvars import copy_context as _copy_context
_contextvarsSupport = True
except ImportError:
_contextvarsSupport = False

class _NoContext:
@staticmethod
def run(f, *args, **kwargs):
return f(*args, **kwargs)


def _copy_context():
return _NoContext

log = Logger()



class AlreadyCalledError(Exception):
pass

Expand Down Expand Up @@ -262,9 +278,9 @@ def __init__(self, canceller=None):
Initialize a L{Deferred}.
@param canceller: a callable used to stop the pending operation
scheduled by this L{Deferred} when L{Deferred.cancel} is
invoked. The canceller will be passed the deferred whose
cancelation is requested (i.e., self).
scheduled by this L{Deferred} when L{Deferred.cancel} is invoked.
The canceller will be passed the deferred whose cancelation is
requested (i.e., self).
If a canceller is not given, or does not invoke its argument's
C{callback} or C{errback} method, L{Deferred.cancel} will
Expand Down Expand Up @@ -650,6 +666,7 @@ def _runCallbacks(self):
current._runningCallbacks = True
try:
current.result = callback(current.result, *args, **kw)

if current.result is current:
warnAboutFunction(
callback,
Expand Down Expand Up @@ -1400,17 +1417,23 @@ def _inlineCallbacks(result, g, status):
# loop and the waiting variable solve that by manually unfolding the
# recursion.

waiting = [True, # waiting for result?
None] # result
waiting = [True, # waiting for result?
None] # result

# Get the current contextvars Context object.
current_context = _copy_context()

while 1:
try:
# Send the last result back as the result of the yield expression.
isFailure = isinstance(result, failure.Failure)

if isFailure:
result = result.throwExceptionIntoGenerator(g)
result = current_context.run(
result.throwExceptionIntoGenerator, g
)
else:
result = g.send(result)
result = current_context.run(g.send, result)
except StopIteration as e:
# fell off the end, or "return" statement
status.deferred.callback(getattr(e, "value", None))
Expand All @@ -1426,6 +1449,12 @@ def _inlineCallbacks(result, g, status):
# _inlineCallbacks); the next one down should be the application
# code.
appCodeTrace = exc_info()[2].tb_next

# If contextvars support is not present, we also have added a frame
# in the no-op shim, remove that
if not _contextvarsSupport:
appCodeTrace = appCodeTrace.tb_next

if isFailure:
# If we invoked this generator frame by throwing an exception
# into it, then throwExceptionIntoGenerator will consume an
Expand Down Expand Up @@ -1466,8 +1495,7 @@ def gotResult(r):
waiting[0] = False
waiting[1] = r
else:
# We are not waiting for deferred result any more
_inlineCallbacks(r, g, status)
current_context.run(_inlineCallbacks, r, g, status)

result.addBoth(gotResult)
if waiting[0]:
Expand Down Expand Up @@ -2002,13 +2030,11 @@ def _tryLock():



__all__ = ["Deferred", "DeferredList", "succeed", "fail", "FAILURE", "SUCCESS",
"AlreadyCalledError", "TimeoutError", "gatherResults",
"maybeDeferred", "ensureDeferred",
"waitForDeferred", "deferredGenerator", "inlineCallbacks",
"returnValue",
"DeferredLock", "DeferredSemaphore", "DeferredQueue",
"DeferredFilesystemLock", "AlreadyTryingToLockError",
"CancelledError",
]

__all__ = [
"Deferred", "DeferredList", "succeed", "fail", "FAILURE", "SUCCESS",
"AlreadyCalledError", "TimeoutError", "gatherResults",
"maybeDeferred", "ensureDeferred",
"waitForDeferred", "deferredGenerator", "inlineCallbacks", "returnValue",
"DeferredLock", "DeferredSemaphore", "DeferredQueue",
"DeferredFilesystemLock", "AlreadyTryingToLockError", "CancelledError",
]
1 change: 1 addition & 0 deletions src/twisted/newsfragments/9719.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
twisted.internet.defer.inlineCallbacks and ensureDeferred will now associate a contextvars.Context with the coroutines they run, meaning that ContextVar objects will maintain their value within the same coroutine, similarly to asyncio Tasks. This functionality requires Python 3.7+, or the contextvars PyPI backport to be installed for Python 3.5-3.6.
180 changes: 173 additions & 7 deletions src/twisted/test/test_defer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,29 @@

from asyncio import new_event_loop, Future, CancelledError

from twisted.python.reflect import requireModule
from twisted.python import failure, log
from twisted.trial import unittest
from twisted.internet import defer, reactor
from twisted.internet.task import Clock


contextvars = requireModule('contextvars')
if contextvars:
contextvarsSkip = None
else:
contextvarsSkip = "contextvars is not available"



def ensuringDeferred(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
result = f(*args, **kwargs)
return defer.ensureDeferred(result)
return wrapper



class GenericError(Exception):
pass
Expand Down Expand Up @@ -3090,16 +3107,88 @@ def test_fromFutureDeferredCancelled(self):



def ensuringDeferred(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
result = f(*args, **kwargs)
return defer.ensureDeferred(result)
return wrapper
class CoroutineContextVarsTests(unittest.TestCase):

skip = contextvarsSkip

def test_withInlineCallbacks(self):
"""
When an inlineCallbacks function is called, the context is taken from
when it was first called. When it resumes, the same context is applied.
"""
clock = Clock()

var = contextvars.ContextVar("testvar")
var.set(1)

# This Deferred will set its own context to 3 when it is called
mutatingDeferred = defer.Deferred()
mutatingDeferred.addCallback(lambda _: var.set(3))

mutatingDeferredThatFails = defer.Deferred()
mutatingDeferredThatFails.addCallback(lambda _: var.set(4))
mutatingDeferredThatFails.addCallback(lambda _: 1 / 0)

@defer.inlineCallbacks
def yieldingDeferred():
d = defer.Deferred()
clock.callLater(1, d.callback, True)
yield d
var.set(3)

# context is 1 when the function is defined
@defer.inlineCallbacks
def testFunction():

# Expected to be 2
self.assertEqual(var.get(), 2)

# Does not mutate the context
yield defer.succeed(1)

# Expected to be 2
self.assertEqual(var.get(), 2)

# mutatingDeferred mutates it to 3, but only in its Deferred chain
clock.callLater(1, mutatingDeferred.callback, True)
yield mutatingDeferred

# When it resumes, it should still be 2
self.assertEqual(var.get(), 2)

class DeferredTestsAsync(unittest.TestCase):
# mutatingDeferredThatFails mutates it to 3, but only in its
# Deferred chain
clock.callLater(1, mutatingDeferredThatFails.callback, True)
try:
yield mutatingDeferredThatFails
except Exception:
self.assertEqual(var.get(), 2)
else:
raise Exception("???? should have failed")

# IMPLEMENTATION DETAIL: Because inlineCallbacks must be at every
# level, an inlineCallbacks function yielding another
# inlineCallbacks function will NOT mutate the outer one's context,
# as it is copied when the inner one is ran and mutated there.
yield yieldingDeferred()
self.assertEqual(var.get(), 2)

defer.returnValue(True)

# The inlineCallbacks context is 2 when it's called
var.set(2)
d = testFunction()

# Advance the clock so mutatingDeferred triggers
clock.advance(1)

# Advance the clock so that mutatingDeferredThatFails triggers
clock.advance(1)

# Advance the clock so that yieldingDeferred triggers
clock.advance(1)

self.assertEqual(self.successResultOf(d), True)


@ensuringDeferred
Expand Down Expand Up @@ -3157,3 +3246,80 @@ async def test_asyncWithLockException(self):
self.assertTrue(lock.locked)
raise Exception('some specific exception')
self.assertFalse(lock.locked)


def test_contextvarsWithAsyncAwait(self):
"""
When a coroutine is called, the context is taken from when it was first
called. When it resumes, the same context is applied.
"""
clock = Clock()

var = contextvars.ContextVar("testvar")
var.set(1)

# This Deferred will set its own context to 3 when it is called
mutatingDeferred = defer.Deferred()
mutatingDeferred.addCallback(lambda _: var.set(3))

mutatingDeferredThatFails = defer.Deferred()
mutatingDeferredThatFails.addCallback(lambda _: var.set(4))
mutatingDeferredThatFails.addCallback(lambda _: 1 / 0)

async def asyncFuncAwaitingDeferred():
d = defer.Deferred()
clock.callLater(1, d.callback, True)
await d
var.set(3)

# context is 1 when the function is defined
async def testFunction():

# Expected to be 2
self.assertEqual(var.get(), 2)

# Does not mutate the context
await defer.succeed(1)

# Expected to be 2
self.assertEqual(var.get(), 2)

# mutatingDeferred mutates it to 3, but only in its Deferred chain
clock.callLater(0, mutatingDeferred.callback, True)
await mutatingDeferred

# When it resumes, it should still be 2
self.assertEqual(var.get(), 2)

# mutatingDeferredThatFails mutates it to 3, but only in its
# Deferred chain
clock.callLater(1, mutatingDeferredThatFails.callback, True)
try:
await mutatingDeferredThatFails
except Exception:
self.assertEqual(var.get(), 2)
else:
raise Exception("???? should have failed")

# If we await another async def-defined function, it will be able
# to mutate the outer function's context, it is *not* frozen and
# restored inside the function call.
await asyncFuncAwaitingDeferred()
self.assertEqual(var.get(), 3)

return True

# The inlineCallbacks context is 2 when it's called
var.set(2)
d = defer.ensureDeferred(testFunction())

# Advance the clock so mutatingDeferred triggers
clock.advance(1)

# Advance the clock so that mutatingDeferredThatFails triggers
clock.advance(1)

# Advance the clock so that asyncFuncAwaitingDeferred triggers
clock.advance(1)

self.assertEqual(self.successResultOf(d), True)

0 comments on commit db811a3

Please sign in to comment.