Skip to content

Migrate work for finished CurrentThreadExecutor to previous executor #494

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 3, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 40 additions & 23 deletions asgiref/current_thread_executor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import queue
import sys
import threading
from collections import deque
from concurrent.futures import Executor, Future
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, TypeVar

if sys.version_info >= (3, 10):
from typing import ParamSpec
@@ -53,10 +53,12 @@ class CurrentThreadExecutor(Executor):
the thread they came from.
"""

def __init__(self) -> None:
def __init__(self, old_executor: "CurrentThreadExecutor | None") -> None:
self._work_thread = threading.current_thread()
self._work_queue: queue.Queue[Union[_WorkItem, "Future[Any]"]] = queue.Queue()
self._broken = False
self._work_ready = threading.Condition(threading.Lock())
self._work_items = deque[_WorkItem]() # synchronized by _work_ready
self._broken = False # synchronized by _work_ready
self._old_executor = old_executor

def run_until_future(self, future: "Future[Any]") -> None:
"""
@@ -68,20 +70,25 @@ def run_until_future(self, future: "Future[Any]") -> None:
raise RuntimeError(
"You cannot run CurrentThreadExecutor from a different thread"
)
future.add_done_callback(self._work_queue.put)
# Keep getting and running work items until we get the future we're waiting for
# back via the future's done callback.
try:
while True:

def done(future: "Future[Any]") -> None:
with self._work_ready:
self._broken = True
self._work_ready.notify()

future.add_done_callback(done)
# Keep getting and running work items until the future we're waiting for
# is done and the queue is empty.
while True:
with self._work_ready:
while not self._work_items and not self._broken:
self._work_ready.wait()
if not self._work_items:
break
# Get a work item and run it
work_item = self._work_queue.get()
if work_item is future:
return
assert isinstance(work_item, _WorkItem)
work_item.run()
del work_item
finally:
self._broken = True
work_item = self._work_items.popleft()
work_item.run()
del work_item

def _submit(
self,
@@ -94,13 +101,23 @@ def _submit(
raise RuntimeError(
"You cannot submit onto CurrentThreadExecutor from its own thread"
)
# Check they're not too late or the executor errored
if self._broken:
raise RuntimeError("CurrentThreadExecutor already quit or is broken")
# Add to work queue
f: "Future[_R]" = Future()
work_item = _WorkItem(f, fn, *args, **kwargs)
self._work_queue.put(work_item)

# Walk up the CurrentThreadExecutor stack to find the closest one still
# running
executor = self
while True:
with executor._work_ready:
if not executor._broken:
# Add to work queue
executor._work_items.append(work_item)
executor._work_ready.notify()
break
if executor._old_executor is None:
raise RuntimeError("CurrentThreadExecutor already quit or is broken")
executor = executor._old_executor

# Return the future
return f

2 changes: 1 addition & 1 deletion asgiref/sync.py
Original file line number Diff line number Diff line change
@@ -196,7 +196,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
# need one for every sync frame, even if there's one above us in the
# same thread.
old_executor = getattr(self.executors, "current", None)
current_executor = CurrentThreadExecutor()
current_executor = CurrentThreadExecutor(old_executor)
self.executors.current = current_executor

# Wrapping context in list so it can be reassigned from within
73 changes: 73 additions & 0 deletions tests/test_sync.py
Original file line number Diff line number Diff line change
@@ -1208,3 +1208,76 @@ def test_function(**kwargs: Any) -> None:

# SyncToAsync.__call__.loop.run_in_executor has a param named `task_context`.
await test_function(task_context=1)


def test_nested_task() -> None:
async def inner() -> asyncio.Task[None]:
return asyncio.create_task(sync_to_async(print)("inner"))

async def main() -> None:
task = await sync_to_async(async_to_sync(inner))()
await task

async_to_sync(main)()


def test_nested_task_later() -> None:
def later(fut: asyncio.Future[asyncio.Task[None]]) -> None:
task = asyncio.create_task(sync_to_async(print)("later"))
fut.set_result(task)

async def inner() -> asyncio.Future[asyncio.Task[None]]:
loop = asyncio.get_running_loop()
fut = loop.create_future()
loop.call_later(0.1, later, fut)
return fut

async def main() -> None:
fut = await sync_to_async(async_to_sync(inner))()
task = await fut
await task

async_to_sync(main)()


def test_double_nested_task() -> None:
async def inner() -> asyncio.Task[None]:
return asyncio.create_task(sync_to_async(print)("inner"))

async def outer() -> asyncio.Task[asyncio.Task[None]]:
return asyncio.create_task(sync_to_async(async_to_sync(inner))())

async def main() -> None:
outer_task = await sync_to_async(async_to_sync(outer))()
inner_task = await outer_task
await inner_task

async_to_sync(main)()


# asyncio.Barrier is new in Python 3.11. Nest definition (rather than using
# skipIf) to avoid mypy error.
if sys.version_info >= (3, 11):

def test_two_nested_tasks_with_asyncio_run() -> None:
barrier = asyncio.Barrier(3)
event = threading.Event()

async def inner() -> None:
task = asyncio.create_task(sync_to_async(event.wait)())
await barrier.wait()
await task

async def outer() -> tuple[asyncio.Task[None], asyncio.Task[None]]:
task0 = asyncio.create_task(inner())
task1 = asyncio.create_task(inner())
await barrier.wait()
event.set()
return task0, task1

async def main() -> None:
task0, task1 = await sync_to_async(async_to_sync(outer))()
await task0
await task1

asyncio.run(main())