Skip to content

Commit e41f43d

Browse files
committed
Add CPython 3.10's test/support/threading_helper
1 parent d20dbf9 commit e41f43d

File tree

1 file changed

+209
-0
lines changed

1 file changed

+209
-0
lines changed
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
import _thread
2+
import contextlib
3+
import functools
4+
import sys
5+
import threading
6+
import time
7+
8+
from test import support
9+
10+
11+
#=======================================================================
12+
# Threading support to prevent reporting refleaks when running regrtest.py -R
13+
14+
# NOTE: we use thread._count() rather than threading.enumerate() (or the
15+
# moral equivalent thereof) because a threading.Thread object is still alive
16+
# until its __bootstrap() method has returned, even after it has been
17+
# unregistered from the threading module.
18+
# thread._count(), on the other hand, only gets decremented *after* the
19+
# __bootstrap() method has returned, which gives us reliable reference counts
20+
# at the end of a test run.
21+
22+
23+
def threading_setup():
24+
return _thread._count(), threading._dangling.copy()
25+
26+
27+
def threading_cleanup(*original_values):
28+
_MAX_COUNT = 100
29+
30+
for count in range(_MAX_COUNT):
31+
values = _thread._count(), threading._dangling
32+
if values == original_values:
33+
break
34+
35+
if not count:
36+
# Display a warning at the first iteration
37+
support.environment_altered = True
38+
dangling_threads = values[1]
39+
support.print_warning(f"threading_cleanup() failed to cleanup "
40+
f"{values[0] - original_values[0]} threads "
41+
f"(count: {values[0]}, "
42+
f"dangling: {len(dangling_threads)})")
43+
for thread in dangling_threads:
44+
support.print_warning(f"Dangling thread: {thread!r}")
45+
46+
# Don't hold references to threads
47+
dangling_threads = None
48+
values = None
49+
50+
time.sleep(0.01)
51+
support.gc_collect()
52+
53+
54+
def reap_threads(func):
55+
"""Use this function when threads are being used. This will
56+
ensure that the threads are cleaned up even when the test fails.
57+
"""
58+
@functools.wraps(func)
59+
def decorator(*args):
60+
key = threading_setup()
61+
try:
62+
return func(*args)
63+
finally:
64+
threading_cleanup(*key)
65+
return decorator
66+
67+
68+
@contextlib.contextmanager
69+
def wait_threads_exit(timeout=None):
70+
"""
71+
bpo-31234: Context manager to wait until all threads created in the with
72+
statement exit.
73+
74+
Use _thread.count() to check if threads exited. Indirectly, wait until
75+
threads exit the internal t_bootstrap() C function of the _thread module.
76+
77+
threading_setup() and threading_cleanup() are designed to emit a warning
78+
if a test leaves running threads in the background. This context manager
79+
is designed to cleanup threads started by the _thread.start_new_thread()
80+
which doesn't allow to wait for thread exit, whereas thread.Thread has a
81+
join() method.
82+
"""
83+
if timeout is None:
84+
timeout = support.SHORT_TIMEOUT
85+
old_count = _thread._count()
86+
try:
87+
yield
88+
finally:
89+
start_time = time.monotonic()
90+
deadline = start_time + timeout
91+
while True:
92+
count = _thread._count()
93+
if count <= old_count:
94+
break
95+
if time.monotonic() > deadline:
96+
dt = time.monotonic() - start_time
97+
msg = (f"wait_threads() failed to cleanup {count - old_count} "
98+
f"threads after {dt:.1f} seconds "
99+
f"(count: {count}, old count: {old_count})")
100+
raise AssertionError(msg)
101+
time.sleep(0.010)
102+
support.gc_collect()
103+
104+
105+
def join_thread(thread, timeout=None):
106+
"""Join a thread. Raise an AssertionError if the thread is still alive
107+
after timeout seconds.
108+
"""
109+
if timeout is None:
110+
timeout = support.SHORT_TIMEOUT
111+
thread.join(timeout)
112+
if thread.is_alive():
113+
msg = f"failed to join the thread in {timeout:.1f} seconds"
114+
raise AssertionError(msg)
115+
116+
117+
@contextlib.contextmanager
118+
def start_threads(threads, unlock=None):
119+
import faulthandler
120+
threads = list(threads)
121+
started = []
122+
try:
123+
try:
124+
for t in threads:
125+
t.start()
126+
started.append(t)
127+
except:
128+
if support.verbose:
129+
print("Can't start %d threads, only %d threads started" %
130+
(len(threads), len(started)))
131+
raise
132+
yield
133+
finally:
134+
try:
135+
if unlock:
136+
unlock()
137+
endtime = time.monotonic()
138+
for timeout in range(1, 16):
139+
endtime += 60
140+
for t in started:
141+
t.join(max(endtime - time.monotonic(), 0.01))
142+
started = [t for t in started if t.is_alive()]
143+
if not started:
144+
break
145+
if support.verbose:
146+
print('Unable to join %d threads during a period of '
147+
'%d minutes' % (len(started), timeout))
148+
finally:
149+
started = [t for t in started if t.is_alive()]
150+
if started:
151+
faulthandler.dump_traceback(sys.stdout)
152+
raise AssertionError('Unable to join %d threads' % len(started))
153+
154+
155+
class catch_threading_exception:
156+
"""
157+
Context manager catching threading.Thread exception using
158+
threading.excepthook.
159+
160+
Attributes set when an exception is caught:
161+
162+
* exc_type
163+
* exc_value
164+
* exc_traceback
165+
* thread
166+
167+
See threading.excepthook() documentation for these attributes.
168+
169+
These attributes are deleted at the context manager exit.
170+
171+
Usage:
172+
173+
with threading_helper.catch_threading_exception() as cm:
174+
# code spawning a thread which raises an exception
175+
...
176+
177+
# check the thread exception, use cm attributes:
178+
# exc_type, exc_value, exc_traceback, thread
179+
...
180+
181+
# exc_type, exc_value, exc_traceback, thread attributes of cm no longer
182+
# exists at this point
183+
# (to avoid reference cycles)
184+
"""
185+
186+
def __init__(self):
187+
self.exc_type = None
188+
self.exc_value = None
189+
self.exc_traceback = None
190+
self.thread = None
191+
self._old_hook = None
192+
193+
def _hook(self, args):
194+
self.exc_type = args.exc_type
195+
self.exc_value = args.exc_value
196+
self.exc_traceback = args.exc_traceback
197+
self.thread = args.thread
198+
199+
def __enter__(self):
200+
self._old_hook = threading.excepthook
201+
threading.excepthook = self._hook
202+
return self
203+
204+
def __exit__(self, *exc_info):
205+
threading.excepthook = self._old_hook
206+
del self.exc_type
207+
del self.exc_value
208+
del self.exc_traceback
209+
del self.thread

0 commit comments

Comments
 (0)