Skip to content

Commit

Permalink
fix DeferredLock/Semaphore.run
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert committed Sep 12, 2023
1 parent 026ef42 commit 5516574
Showing 1 changed file with 53 additions and 36 deletions.
89 changes: 53 additions & 36 deletions src/twisted/internet/defer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@

import attr
from incremental import Version
from typing_extensions import Concatenate, Literal, ParamSpec
from typing_extensions import Concatenate, Literal, ParamSpec, Self

from twisted.internet.interfaces import IDelayedCall, IReactorTime
from twisted.logger import Logger
Expand Down Expand Up @@ -2228,26 +2228,47 @@ def unwindGenerator(*args: _P.args, **kwargs: _P.kwargs) -> Deferred[_T]:
## DeferredLock/DeferredQueue


_ConcurrencyPrimitiveT = TypeVar(
"_ConcurrencyPrimitiveT", bound="_ConcurrencyPrimitive[Any]"
)


class _ConcurrencyPrimitive(ABC, Generic[_SelfResultT]):
def __init__(self: _ConcurrencyPrimitiveT) -> None:
self.waiting: List[Deferred[_ConcurrencyPrimitiveT]] = []
class _ConcurrencyPrimitive(ABC):
def __init__(self: Self) -> None:
self.waiting: List[Deferred[Self]] = []

def _releaseAndReturn(self, r: _T) -> _T:
self.release()
return r

@overload
def run(
self: _ConcurrencyPrimitiveT,
self: Self,
/,
f: Callable[..., _SelfResultT],
*args: object,
**kwargs: object,
) -> Deferred[_SelfResultT]:
f: Callable[_P, Deferred[_T]],
*args: _P.args,
**kwargs: _P.kwargs,
) -> Deferred[_T]:
...

@overload
def run(
self: Self,
/,
f: Callable[_P, Coroutine[Deferred[_T], object, _T]],
*args: _P.args,
**kwargs: _P.kwargs,
) -> Deferred[_T]:
...

@overload
def run(
self: Self, /, f: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> Deferred[_T]:
...

def run(
self: Self,
/,
f: Callable[_P, Union[Deferred[_T], Coroutine[Deferred[_T], object, _T], _T]],
*args: _P.args,
**kwargs: _P.kwargs,
) -> Deferred[_T]:
"""
Acquire, run, release.
Expand All @@ -2262,12 +2283,16 @@ def run(
@return: L{Deferred} of function result.
"""

def execute(ignoredResult: object) -> Deferred[_SelfResultT]:
return maybeDeferred(f, *args, **kwargs).addBoth(self._releaseAndReturn)
def execute(ignoredResult: object) -> Deferred[_T]:
# maybeDeferred arg type requires one of the possible union members
# and won't accept all possible union members
return maybeDeferred(f, *args, **kwargs).addBoth(
self._releaseAndReturn
) # type: ignore[return-value]

return self.acquire().addCallback(execute)

def __aenter__(self: _ConcurrencyPrimitiveT) -> Deferred[_ConcurrencyPrimitiveT]:
def __aenter__(self: Self) -> Deferred[Self]:
"""
We can be used as an asynchronous context manager.
"""
Expand All @@ -2285,18 +2310,15 @@ def __aexit__(
return succeed(False)

@abstractmethod
def acquire(self: _ConcurrencyPrimitiveT) -> Deferred[_ConcurrencyPrimitiveT]:
def acquire(self: Self) -> Deferred[Self]:
pass

@abstractmethod
def release(self) -> None:
pass


_DeferredLockT = TypeVar("_DeferredLockT", bound="DeferredLock")


class DeferredLock(_ConcurrencyPrimitive[Any]):
class DeferredLock(_ConcurrencyPrimitive):
"""
A lock for event driven systems.
Expand All @@ -2307,7 +2329,7 @@ class DeferredLock(_ConcurrencyPrimitive[Any]):

locked = False

def _cancelAcquire(self: _DeferredLockT, d: Deferred[_DeferredLockT]) -> None:
def _cancelAcquire(self: Self, d: Deferred[Self]) -> None:
"""
Remove a deferred d from our waiting list, as the deferred has been
canceled.
Expand All @@ -2321,7 +2343,7 @@ def _cancelAcquire(self: _DeferredLockT, d: Deferred[_DeferredLockT]) -> None:
"""
self.waiting.remove(d)

def acquire(self: _DeferredLockT) -> Deferred[_DeferredLockT]:
def acquire(self: Self) -> Deferred[Self]:
"""
Attempt to acquire the lock. Returns a L{Deferred} that fires on
lock acquisition with the L{DeferredLock} as the value. If the lock
Expand All @@ -2330,15 +2352,15 @@ def acquire(self: _DeferredLockT) -> Deferred[_DeferredLockT]:
@return: a L{Deferred} which fires on lock acquisition.
@rtype: a L{Deferred}
"""
d: Deferred[_DeferredLockT] = Deferred(canceller=self._cancelAcquire)
d: Deferred[Self] = Deferred(canceller=self._cancelAcquire)
if self.locked:
self.waiting.append(d)
else:
self.locked = True
d.callback(self)
return d

def release(self: _DeferredLockT) -> None:
def release(self: Self) -> None:
"""
Release the lock. If there is a waiting list, then the first
L{Deferred} in that waiting list will be called back.
Expand All @@ -2355,10 +2377,7 @@ def release(self: _DeferredLockT) -> None:
d.callback(self)


_DeferredSemaphoreT = TypeVar("_DeferredSemaphoreT", bound="DeferredSemaphore")


class DeferredSemaphore(_ConcurrencyPrimitive[Any]):
class DeferredSemaphore(_ConcurrencyPrimitive):
"""
A semaphore for event driven systems.
Expand All @@ -2382,9 +2401,7 @@ def __init__(self, tokens: int) -> None:
self.tokens = tokens
self.limit = tokens

def _cancelAcquire(
self: _DeferredSemaphoreT, d: Deferred[_DeferredSemaphoreT]
) -> None:
def _cancelAcquire(self: Self, d: Deferred[Self]) -> None:
"""
Remove a deferred d from our waiting list, as the deferred has been
canceled.
Expand All @@ -2398,7 +2415,7 @@ def _cancelAcquire(
"""
self.waiting.remove(d)

def acquire(self: _DeferredSemaphoreT) -> Deferred[_DeferredSemaphoreT]:
def acquire(self: Self) -> Deferred[Self]:
"""
Attempt to acquire the token.
Expand All @@ -2407,15 +2424,15 @@ def acquire(self: _DeferredSemaphoreT) -> Deferred[_DeferredSemaphoreT]:
assert (
self.tokens >= 0
), "Internal inconsistency?? tokens should never be negative"
d: Deferred[_DeferredSemaphoreT] = Deferred(canceller=self._cancelAcquire)
d: Deferred[Self] = Deferred(canceller=self._cancelAcquire)
if not self.tokens:
self.waiting.append(d)
else:
self.tokens = self.tokens - 1
d.callback(self)
return d

def release(self: _DeferredSemaphoreT) -> None:
def release(self: Self) -> None:
"""
Release the token.
Expand Down

0 comments on commit 5516574

Please sign in to comment.