Skip to content

Commit

Permalink
fix annotation of maybeDeferred
Browse files Browse the repository at this point in the history
  • Loading branch information
glyph committed Nov 29, 2022
1 parent 6656db0 commit 3b8c7cf
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
16 changes: 14 additions & 2 deletions src/twisted/internet/defer.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,15 @@ def maybeDeferred(
...


@overload
def maybeDeferred(
f: Callable[_P, Coroutine[Deferred[_T], object, _T]],
*args: _P.args,
**kwargs: _P.kwargs,
) -> "Deferred[_T]":
...


@overload
def maybeDeferred(
f: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
Expand All @@ -190,7 +199,9 @@ def maybeDeferred(


def maybeDeferred(
f: Callable[_P, Union[_T, Deferred[_T]]], *args: _P.args, **kwargs: _P.kwargs
f: Callable[_P, Union[Deferred[_T], Coroutine[Deferred[_T], object, _T], _T]],
*args: _P.args,
**kwargs: _P.kwargs,
) -> "Deferred[_T]":
"""
Invoke a function that may or may not return a L{Deferred} or coroutine.
Expand Down Expand Up @@ -246,7 +257,8 @@ def maybeDeferred(
# case. Such values always have exactly one type: CoroutineType.
return Deferred.fromCoroutine(result)
else:
return succeed(result)
returned: _T = result # type: ignore
return succeed(returned)


@deprecated(
Expand Down
16 changes: 11 additions & 5 deletions src/twisted/test/test_defer.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,17 +879,23 @@ def test_maybeDeferredCoroutineSuccess(self) -> None:
L{defer.Deferred} which has the same result as the coroutine returned
by the function.
"""
result = object()

async def f() -> object:
return result
async def f() -> int:
return 7

# Demonstrate that the function itself does not need to be a coroutine
# function to trigger the coroutine-handling behavior.
def g() -> Coroutine:
def g() -> Coroutine[Deferred[int], Any, Any]:
return f()

self.assertIs((self.successResultOf(defer.maybeDeferred(g))), result)
# Provide a simple callback mainly to ensure the type-checking on
# maybeDeferred is correct.
def typedCallback(result: int) -> int:
return result + 1

coroutineDeferred = defer.maybeDeferred(g)
modifiedDeferred = coroutineDeferred.addCallback(typedCallback)
self.assertEqual(self.successResultOf(modifiedDeferred), 8)

def test_maybeDeferredCoroutineFailure(self) -> None:
"""
Expand Down

0 comments on commit 3b8c7cf

Please sign in to comment.