-
Notifications
You must be signed in to change notification settings - Fork 150
/
local.py
134 lines (102 loc) · 3.91 KB
/
local.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
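# Stop-gap task-local storage, written while waiting for PEP 550
# (https://www.python.org/dev/peps/pep-0550/). PEP 550 was later withdrawn;
# PEP 567 `contextvars` (Python 3.7+) is the mechanism that landed instead.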
import asyncio
from collections.abc import Coroutine
# Maps every event loop on which task local is enabled to the task factory it
# had before ``enable_task_local`` replaced it (``None`` means the default
# factory). Entries hold strong references to their loops until
# ``disable_task_local`` is called for them.
_original_task_factories = dict()
def get_local():
return getattr(asyncio.Task.current_task(), 'task_local', None)
class TaskWrapper(asyncio.Future):
    """Future that mirrors an inner ``asyncio.Task`` and carries a task-local dict.

    The wrapped coroutine runs in an inner task (``self._task``) created via
    ``loop.create_task``; that inner task starts with a fresh, empty
    ``task_local`` dict. The wrapper settles with the inner task's outcome
    (result, exception or cancellation), and cancelling the wrapper cancels
    the inner task.

    Local inheritance: the first ``add_done_callback`` call made on the
    wrapper after construction (typically asyncio's wakeup callback when a
    task awaits the wrapper) replaces the inner task's ``task_local`` with the
    caller's task-local dict -- shared by reference, not copied -- provided
    that the caller has one, the inner coroutine has not been stepped yet,
    and ``with_local_reset()`` has not been called.

    Attributes not defined on the wrapper are looked up on the inner task.
    """

    class _CoroWrapper(Coroutine):
        """Coroutine proxy that disables local pass-through once stepped.

        Defined once at class level rather than per instance: building a new
        class for every task is slow, and CPython's ``asyncio.iscoroutine``
        caches (up to 100) coroutine types in a plain set, which would keep
        per-instance classes -- and, through their closures, the early
        wrappers and coroutines -- alive indefinitely.
        """

        def __init__(self, coro, owner):
            self._coro = coro
            self._owner = owner

        def __await__(self):
            self._owner._passthru = False
            # Propagate the inner coroutine's return value to the awaiter.
            return (yield from self._coro.__await__())

        def send(self, value):
            self._owner._passthru = False
            return self._coro.send(value)

        def throw(self, typ, val=None, tb=None):
            self._owner._passthru = False
            if val is None and tb is None:
                # asyncio calls throw(exc); forwarding the legacy 3-argument
                # form would emit a DeprecationWarning on Python 3.12+.
                return self._coro.throw(typ)
            return self._coro.throw(typ, val, tb)

        def __getattr__(self, item):
            # Fail cleanly if the proxy's own fields are missing instead of
            # recursing; forward everything else (cr_frame, ...) to coro.
            if item in ('_coro', '_owner'):
                raise AttributeError(item)
            return getattr(self._coro, item)

    def __init__(self, coro, loop):
        """Wrap *coro* in a new inner task on *loop*.

        :param coro: coroutine to run in the inner task.
        :param loop: event loop on which both the inner task and this
            future live.
        """
        self._passthru = False

        # Create the inner task through a proxy that prevents pass-through
        # on a late join (once the coroutine has started running).
        self._task = loop.create_task(self._CoroWrapper(coro, self))
        self._task.task_local = {}
        if getattr(self._task, '_source_traceback', None):
            # Debug mode: trim trailing creation frames that belong to this
            # wrapper machinery (presumably the task factory / this __init__
            # call chain -- confirm).
            del getattr(self._task, '_source_traceback')[-3:]

        # Initialize Future object
        super().__init__(loop=loop)

        # Link cancel: cancelling the wrapper cancels the inner task.
        def check_cancel(fut):
            if fut.cancelled():
                self._task.cancel()

        self.add_done_callback(check_cancel)

        # Further `add_done_callback` calls may trigger local pass-through.
        self._passthru = True

        # Link done callback: settle the wrapper with the inner outcome.
        def task_done(task):
            if self.cancelled():
                return
            assert not self.done()
            if task.cancelled():
                self.cancel()
            else:
                exception = task.exception()
                if exception is not None:
                    self.set_exception(exception)
                else:
                    self.set_result(task.result())

        self._task.add_done_callback(task_done)

    def add_done_callback(self, *args, **kwargs):
        """Register a done callback; the first one may inherit the caller's local.

        On the first call after construction (if pass-through is still
        enabled), the caller's task-local dict, if any, becomes the inner
        task's ``task_local`` (same object, not a copy).
        """
        if self._passthru:
            self._passthru = False
            local = get_local()
            if local is not None:
                self._task.task_local = local
        return super().add_done_callback(*args, **kwargs)

    def with_local_reset(self):
        """Disable local inheritance and return ``self``.

        Only prevents future inheritance; it does not clear a local that
        was already inherited.
        """
        self._passthru = False
        return self

    def __getattr__(self, item):
        # Only reached when normal lookup fails. If `_task` itself is missing
        # (e.g. before __init__ assigned it), fail with AttributeError
        # instead of recursing into this method forever.
        if item == '_task':
            raise AttributeError(item)
        return getattr(self._task, item)
def reset_local(coro_or_future, *, loop=None):
    """Make the given routine keep a fresh, empty task-local dict.

    A ``TaskWrapper`` normally inherits the task-local dict of the first code
    that registers a done callback on it (e.g. the task awaiting it). This
    function opts out of that inheritance, so the routine keeps the empty
    dict it was created with.

    This works for:

    - newly created tasks (``TaskWrapper`` instances not yet running)
    - coroutines
    - awaitables

    Coroutines and awaitables are first wrapped in a task with
    ``asyncio.ensure_future``, on the given loop or, if none is given, the
    current loop.

    The reset silently does not take effect if:

    - task local is disabled for the current/given loop
    - the given task is already running (or has already inherited a local)
    - the given future is not an instance of ``TaskWrapper`` or a subclass

    In those cases the given future (or the task newly created for a
    coroutine/awaitable) is still returned, just without the reset.

    :param coro_or_future: coroutine, awaitable or future to run.
    :param loop: event loop used when a new task has to be created.
    :return: the future running the routine.
    """
    if isinstance(coro_or_future, TaskWrapper):
        return coro_or_future.with_local_reset()
    if asyncio.isfuture(coro_or_future):
        # We don't know how to reset the local of arbitrary Future objects.
        return coro_or_future
    future = asyncio.ensure_future(coro_or_future, loop=loop)
    if isinstance(future, TaskWrapper):
        return future.with_local_reset()
    return future
def task_factory(loop, coro, **kwargs):
    """Task factory installed on event loops by ``enable_task_local``.

    Wraps every new task in a ``TaskWrapper``. While the wrapper creates its
    inner task, the loop's original factory is temporarily put back so that
    the inner ``loop.create_task`` call does not recurse into this function;
    this factory is reinstalled afterwards, even on error.

    :param loop: the event loop creating the task.
    :param coro: the coroutine to run.
    :param kwargs: extra task options (``context`` since Python 3.11,
        arbitrary options such as ``name`` on newer versions) that the event
        loop passes to task factories; forwarded to the wrapper's inner task.
    :return: a ``TaskWrapper`` running *coro*.
    """
    original = _original_task_factories.get(loop)
    if kwargs:
        def forward_kwargs(inner_loop, inner_coro, **inner_kwargs):
            # One-shot: restore the original factory first so that only the
            # wrapper's own inner task receives *kwargs*.
            inner_loop.set_task_factory(original)
            options = {**kwargs, **inner_kwargs}
            if original is None:
                return asyncio.Task(inner_coro, loop=inner_loop, **options)
            return original(inner_loop, inner_coro, **options)

        loop.set_task_factory(forward_kwargs)
    else:
        loop.set_task_factory(original)
    try:
        return TaskWrapper(coro, loop)
    finally:
        loop.set_task_factory(task_factory)
def enable_task_local(loop=None):
    """Install ``task_factory`` on *loop* so new tasks carry a task-local dict.

    Uses ``asyncio.get_event_loop()`` when *loop* is None. The loop's previous
    task factory is remembered so ``disable_task_local`` can restore it.
    Calling this again for an already-enabled loop does nothing.
    """
    target = asyncio.get_event_loop() if loop is None else loop
    if target not in _original_task_factories:
        _original_task_factories[target] = target.get_task_factory()
        target.set_task_factory(task_factory)
def disable_task_local(loop=None):
    """Restore the task factory *loop* had before ``enable_task_local``.

    Uses ``asyncio.get_event_loop()`` when *loop* is None. Does nothing if
    task local is not enabled for that loop.
    """
    if loop is None:
        loop = asyncio.get_event_loop()
    try:
        previous = _original_task_factories.pop(loop)
    except KeyError:
        return
    loop.set_task_factory(previous)