Skip to content

Commit

Permalink
fix: add attr futures in anyio taskgroup
Browse files Browse the repository at this point in the history
  • Loading branch information
phi-friday committed Jul 29, 2023
1 parent 569cb31 commit 3d290d8
Showing 1 changed file with 29 additions and 20 deletions.
49 changes: 29 additions & 20 deletions src/async_wrapper/task_group/_anyio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from asyncio import create_task, wait
from functools import partial, wraps
from asyncio import create_task, ensure_future, wait
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -12,7 +11,7 @@
TypeVar,
final,
)
from weakref import WeakSet
from weakref import WeakSet, WeakValueDictionary

from typing_extensions import ParamSpec, Self, override

Expand All @@ -29,7 +28,7 @@
from typing import Any as _TaskGroup

if TYPE_CHECKING:
from asyncio import Task
from asyncio import Future, Task
from types import TracebackType

from anyio.abc import Semaphore as AnyioSemaphore # type: ignore
Expand All @@ -47,7 +46,11 @@
class TaskGroup(BaseTaskGroup):
def __init__(self) -> None:
self._task_group: _TaskGroup = _get_task_group()
self._tasks: WeakSet[Task[Any]] = WeakSet()
self._futures: WeakSet[Future[Any]] = WeakSet()
self._task_futures: WeakValueDictionary[
Task[Any],
Future[Any],
] = WeakValueDictionary()

@override
def start_soon(
Expand All @@ -57,24 +60,30 @@ def start_soon(
**kwargs: ParamT.kwargs,
) -> SoonValue[ValueT_co]:
value = SoonValue()
wrapped = self._wrap_as_value(func, value)
self._task_group.start_soon(partial(wrapped, **kwargs), *args)
future = self._as_future(func, *args, **kwargs)
value._set_task_or_future(future) # noqa: SLF001
self._task_group.start_soon(self._as_task, future)
return value

def _wrap_as_value(
def _as_future(
self,
func: Callable[ParamT, Coroutine[Any, Any, ValueT_co]],
value: SoonValue[ValueT_co],
) -> Callable[ParamT, Coroutine[None, None, None]]:
@wraps(func)
async def inner(*args: ParamT.args, **kwargs: ParamT.kwargs) -> None:
coro = func(*args, **kwargs)
task = create_task(coro)
value._set_task_or_future(task) # noqa: SLF001
self.tasks.add(task)
await task

return inner
*args: ParamT.args,
**kwargs: ParamT.kwargs,
) -> Future[ValueT_co]:
coro = func(*args, **kwargs)
future = ensure_future(coro)
self._futures.add(future)
return future

async def _as_coro(self, future: Future[ValueT_co]) -> ValueT_co:
return await future

async def _as_task(self, future: Future[ValueT_co]) -> ValueT_co:
coro = self._as_coro(future)
task = create_task(coro)
self._task_futures[task] = future
return await task

@property
@override
Expand All @@ -84,7 +93,7 @@ def is_active(self) -> bool:
@property
@override
def tasks(self) -> WeakSet[Task[Any]]:
return self._tasks
return WeakSet(self._task_futures.keys())

@override
async def __aenter__(self) -> Self:
Expand Down

0 comments on commit 3d290d8

Please sign in to comment.