Skip to content

Commit 89d2064

Browse files
committed
split out trio/asyncio compat
1 parent 23b57e1 commit 89d2064

File tree

4 files changed

+256
-207
lines changed

4 files changed

+256
-207
lines changed

asgiref/_asyncio.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
__all__ = [
2+
"get_running_loop",
3+
"create_task_threadsafe",
4+
"wrap_task_context",
5+
"run_in_executor",
6+
]
7+
8+
import asyncio
9+
import contextvars
10+
import functools
11+
import sys
12+
from asyncio import get_running_loop
13+
from collections.abc import Callable
14+
from typing import Any, TypeVar
15+
16+
from ._context import restore_context as _restore_context
17+
18+
_R = TypeVar("_R")
19+
20+
21+
def create_task_threadsafe(loop, awaitable) -> None:
22+
loop.call_soon_threadsafe(loop.create_task, awaitable)
23+
24+
25+
async def wrap_task_context(loop, task_context, awaitable):
26+
if task_context is None:
27+
return await awaitable
28+
29+
current_task = asyncio.current_task(loop)
30+
if current_task is None:
31+
return await awaitable
32+
33+
task_context.append(current_task)
34+
try:
35+
return await awaitable
36+
finally:
37+
task_context.remove(current_task)
38+
39+
40+
async def run_in_executor(
41+
*, loop, executor, thread_handler, child: Callable[[], _R]
42+
) -> _R:
43+
context = contextvars.copy_context()
44+
func = context.run
45+
task_context: list[asyncio.Task[Any]] = []
46+
47+
# Run the code in the right thread
48+
exec_coro = loop.run_in_executor(
49+
executor,
50+
functools.partial(
51+
thread_handler,
52+
loop,
53+
sys.exc_info(),
54+
task_context,
55+
func,
56+
child,
57+
),
58+
)
59+
ret: _R
60+
try:
61+
ret = await asyncio.shield(exec_coro)
62+
except asyncio.CancelledError:
63+
cancel_parent = True
64+
try:
65+
task = task_context[0]
66+
task.cancel()
67+
try:
68+
await task
69+
cancel_parent = False
70+
except asyncio.CancelledError:
71+
pass
72+
except IndexError:
73+
pass
74+
if exec_coro.done():
75+
raise
76+
if cancel_parent:
77+
exec_coro.cancel()
78+
ret = await exec_coro
79+
finally:
80+
_restore_context(context)
81+
82+
return ret

asgiref/_context.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import contextvars
2+
3+
4+
def restore_context(context: contextvars.Context) -> None:
5+
# Check for changes in contextvars, and set them to the current
6+
# context for downstream consumers
7+
for cvar in context:
8+
cvalue = context.get(cvar)
9+
try:
10+
if cvar.get() != cvalue:
11+
cvar.set(cvalue)
12+
except LookupError:
13+
cvar.set(cvalue)

asgiref/_trio.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import asyncio
2+
import contextvars
3+
import functools
4+
import sys
5+
from typing import Any
6+
7+
import sniffio
8+
import trio.lowlevel
9+
import trio.to_thread
10+
11+
from . import _asyncio
12+
from ._context import restore_context as _restore_context
13+
14+
15+
class TrioThreadCancelled(BaseException):
16+
pass
17+
18+
19+
def get_running_loop():
20+
try:
21+
asynclib = sniffio.current_async_library()
22+
except sniffio.AsyncLibraryNotFoundError:
23+
return asyncio.get_running_loop()
24+
25+
if asynclib == "asyncio":
26+
return asyncio.get_running_loop()
27+
if asynclib == "trio":
28+
return trio.lowlevel.current_token()
29+
raise RuntimeError(f"unsupported library {asynclib}")
30+
31+
32+
@trio.lowlevel.disable_ki_protection
33+
async def wrap_awaitable(awaitable):
34+
return await awaitable
35+
36+
37+
def create_task_threadsafe(loop, awaitable):
38+
if isinstance(loop, trio.lowlevel.TrioToken):
39+
try:
40+
loop.run_sync_soon(
41+
trio.lowlevel.spawn_system_task,
42+
wrap_awaitable,
43+
awaitable,
44+
)
45+
except trio.RunFinishedError:
46+
raise RuntimeError("trio loop no-longer running")
47+
48+
return _asyncio.create_task_threadsafe(loop, awaitable)
49+
50+
51+
async def run_in_executor(*, loop, executor, thread_handler, child):
52+
if isinstance(loop, trio.lowlevel.TrioToken):
53+
context = contextvars.copy_context()
54+
func = context.run
55+
task_context: list[asyncio.Task[Any]] = []
56+
57+
# Run the code in the right thread
58+
full_func = functools.partial(
59+
thread_handler,
60+
loop,
61+
sys.exc_info(),
62+
task_context,
63+
func,
64+
child,
65+
)
66+
try:
67+
if executor is None:
68+
69+
async def handle_cancel():
70+
try:
71+
await trio.sleep_forever()
72+
except trio.Cancelled:
73+
if task_context:
74+
task_context[0].cancel()
75+
raise
76+
77+
async with trio.open_nursery() as nursery:
78+
nursery.start_soon(handle_cancel)
79+
try:
80+
return await trio.to_thread.run_sync(
81+
thread_handler, func, abandon_on_cancel=False
82+
)
83+
except TrioThreadCancelled:
84+
pass
85+
finally:
86+
nursery.cancel_scope.cancel()
87+
else:
88+
event = trio.Event()
89+
90+
def callback(fut):
91+
loop.run_sync_soon(event.set)
92+
93+
fut = executor.submit(full_func)
94+
fut.add_done_callback(callback)
95+
96+
async def handle_cancel_fut():
97+
try:
98+
await trio.sleep_forever()
99+
except trio.Cancelled:
100+
fut.cancel()
101+
if task_context:
102+
task_context[0].cancel()
103+
raise
104+
105+
async with trio.open_nursery() as nursery:
106+
nursery.start_soon(handle_cancel_fut)
107+
with trio.CancelScope(shield=True):
108+
await event.wait()
109+
nursery.cancel_scope.cancel()
110+
try:
111+
return fut.result()
112+
except TrioThreadCancelled:
113+
pass
114+
finally:
115+
_restore_context(context)
116+
117+
return await _asyncio.run_in_executor(
118+
loop=loop, executor=executor, thread_handler=thread_handler, func=func
119+
)
120+
121+
122+
async def wrap_task_context(loop, task_context, awaitable):
123+
if task_context is None:
124+
return await awaitable
125+
126+
if isinstance(loop, trio.lowlevel.TrioToken):
127+
with trio.CancelScope() as scope:
128+
task_context.append(scope)
129+
try:
130+
return await awaitable
131+
finally:
132+
task_context.remove(scope)
133+
if scope.cancelled_caught:
134+
raise TrioThreadCancelled
135+
136+
return await _asyncio.wrap_task_context(loop, task_context, awaitable)

0 commit comments

Comments
 (0)