Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions Lib/test/support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1722,6 +1722,43 @@ def decorator(*args):
threading_cleanup(*key)
return decorator


@contextlib.contextmanager
def wait_threads_exit(timeout=60.0):
"""
bpo-31234: Context manager to wait until all threads created in the with
statement exit.

Use thread.count() to check if threads exited. Indirectly, wait until
threads exit the internal t_bootstrap() C function of the thread module.

threading_setup() and threading_cleanup() are designed to emit a warning
if a test leaves running threads in the background. This context manager
is designed to cleanup threads started by the thread.start_new_thread()
which doesn't allow to wait for thread exit, whereas thread.Thread has a
join() method.
"""
old_count = thread._count()
try:
yield
finally:
start_time = time.time()
deadline = start_time + timeout
while True:
count = thread._count()
if count <= old_count:
break
if time.time() > deadline:
dt = time.time() - start_time
msg = ("wait_threads() failed to cleanup %s "
"threads after %.1f seconds "
"(count: %s, old count: %s)"
% (count - old_count, dt, count, old_count))
raise AssertionError(msg)
time.sleep(0.010)
gc_collect()


def reap_children():
"""Use this function at the end of test_main() whenever sub-processes
are started. This will help ensure that no extra children (zombies)
Expand Down
27 changes: 14 additions & 13 deletions Lib/test/test_asyncore.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,19 +727,20 @@ def test_quick_connect(self):
server = TCPServer()
t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1, count=500))
t.start()
self.addCleanup(t.join)

for x in xrange(20):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.settimeout(.2)
s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
struct.pack('ii', 1, 0))
try:
s.connect(server.address)
except socket.error:
pass
finally:
s.close()
try:
for x in xrange(20):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.settimeout(.2)
s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
struct.pack('ii', 1, 0))
try:
s.connect(server.address)
except socket.error:
pass
finally:
s.close()
finally:
t.join()


class TestAPI_UseSelect(BaseTestAPI):
Expand Down
20 changes: 10 additions & 10 deletions Lib/test/test_hashlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,25 +371,25 @@ def test_threaded_hashing(self):
data = smallest_data*200000
expected_hash = hashlib.sha1(data*num_threads).hexdigest()

def hash_in_chunks(chunk_size, event):
def hash_in_chunks(chunk_size):
index = 0
while index < len(data):
hasher.update(data[index:index+chunk_size])
index += chunk_size
event.set()

events = []
threads = []
for threadnum in xrange(num_threads):
chunk_size = len(data) // (10**threadnum)
assert chunk_size > 0
assert chunk_size % len(smallest_data) == 0
event = threading.Event()
events.append(event)
threading.Thread(target=hash_in_chunks,
args=(chunk_size, event)).start()

for event in events:
event.wait()
thread = threading.Thread(target=hash_in_chunks,
args=(chunk_size,))
threads.append(thread)

for thread in threads:
thread.start()
for thread in threads:
thread.join()

self.assertEqual(expected_hash, hasher.hexdigest())

Expand Down
1 change: 1 addition & 0 deletions Lib/test/test_httpservers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def run(self):

def stop(self):
self.server.shutdown()
self.join()


class BaseTestCase(unittest.TestCase):
Expand Down
4 changes: 3 additions & 1 deletion Lib/test/test_smtplib.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,14 @@ def setUp(self):
self.sock.settimeout(15)
self.port = test_support.bind_port(self.sock)
servargs = (self.evt, self.respdata, self.sock)
threading.Thread(target=server, args=servargs).start()
self.thread = threading.Thread(target=server, args=servargs)
self.thread.start()
self.evt.wait()
self.evt.clear()

def tearDown(self):
self.evt.wait()
self.thread.join()
sys.stdout = self.old_stdout

def testLineTooLong(self):
Expand Down
106 changes: 57 additions & 49 deletions Lib/test/test_thread.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
import unittest
import random
from test import test_support
thread = test_support.import_module('thread')
from test import support
thread = support.import_module('thread')
import time
import sys
import weakref
Expand All @@ -17,7 +17,7 @@

def verbose_print(arg):
"""Helper function for printing out debugging output."""
if test_support.verbose:
if support.verbose:
with _print_mutex:
print arg

Expand All @@ -34,8 +34,8 @@ def setUp(self):
self.running = 0
self.next_ident = 0

key = test_support.threading_setup()
self.addCleanup(test_support.threading_cleanup, *key)
key = support.threading_setup()
self.addCleanup(support.threading_cleanup, *key)


class ThreadRunningTests(BasicThreadTest):
Expand All @@ -60,12 +60,13 @@ def task(self, ident):
self.done_mutex.release()

def test_starting_threads(self):
# Basic test for thread creation.
for i in range(NUMTASKS):
self.newtask()
verbose_print("waiting for tasks to complete...")
self.done_mutex.acquire()
verbose_print("all tasks done")
with support.wait_threads_exit():
# Basic test for thread creation.
for i in range(NUMTASKS):
self.newtask()
verbose_print("waiting for tasks to complete...")
self.done_mutex.acquire()
verbose_print("all tasks done")

def test_stack_size(self):
# Various stack size tests.
Expand Down Expand Up @@ -95,12 +96,13 @@ def test_nt_and_posix_stack_size(self):
verbose_print("trying stack_size = (%d)" % tss)
self.next_ident = 0
self.created = 0
for i in range(NUMTASKS):
self.newtask()
with support.wait_threads_exit():
for i in range(NUMTASKS):
self.newtask()

verbose_print("waiting for all tasks to complete")
self.done_mutex.acquire()
verbose_print("all tasks done")
verbose_print("waiting for all tasks to complete")
self.done_mutex.acquire()
verbose_print("all tasks done")

thread.stack_size(0)

Expand All @@ -110,25 +112,28 @@ def test__count(self):
mut = thread.allocate_lock()
mut.acquire()
started = []

def task():
started.append(None)
mut.acquire()
mut.release()
thread.start_new_thread(task, ())
while not started:
time.sleep(0.01)
self.assertEqual(thread._count(), orig + 1)
# Allow the task to finish.
mut.release()
# The only reliable way to be sure that the thread ended from the
# interpreter's point of view is to wait for the function object to be
# destroyed.
done = []
wr = weakref.ref(task, lambda _: done.append(None))
del task
while not done:
time.sleep(0.01)
self.assertEqual(thread._count(), orig)

with support.wait_threads_exit():
thread.start_new_thread(task, ())
while not started:
time.sleep(0.01)
self.assertEqual(thread._count(), orig + 1)
# Allow the task to finish.
mut.release()
# The only reliable way to be sure that the thread ended from the
# interpreter's point of view is to wait for the function object to be
# destroyed.
done = []
wr = weakref.ref(task, lambda _: done.append(None))
del task
while not done:
time.sleep(0.01)
self.assertEqual(thread._count(), orig)

def test_save_exception_state_on_error(self):
# See issue #14474
Expand All @@ -143,14 +148,13 @@ def mywrite(self, *args):
real_write(self, *args)
c = thread._count()
started = thread.allocate_lock()
with test_support.captured_output("stderr") as stderr:
with support.captured_output("stderr") as stderr:
real_write = stderr.write
stderr.write = mywrite
started.acquire()
thread.start_new_thread(task, ())
started.acquire()
while thread._count() > c:
time.sleep(0.01)
with support.wait_threads_exit():
thread.start_new_thread(task, ())
started.acquire()
self.assertIn("Traceback", stderr.getvalue())


Expand Down Expand Up @@ -182,13 +186,14 @@ def enter(self):
class BarrierTest(BasicThreadTest):

def test_barrier(self):
self.bar = Barrier(NUMTASKS)
self.running = NUMTASKS
for i in range(NUMTASKS):
thread.start_new_thread(self.task2, (i,))
verbose_print("waiting for tasks to end")
self.done_mutex.acquire()
verbose_print("tasks done")
with support.wait_threads_exit():
self.bar = Barrier(NUMTASKS)
self.running = NUMTASKS
for i in range(NUMTASKS):
thread.start_new_thread(self.task2, (i,))
verbose_print("waiting for tasks to end")
self.done_mutex.acquire()
verbose_print("tasks done")

def task2(self, ident):
for i in range(NUMTRIPS):
Expand Down Expand Up @@ -226,8 +231,9 @@ def setUp(self):

@unittest.skipIf(sys.platform.startswith('win'),
"This test is only appropriate for POSIX-like systems.")
@test_support.reap_threads
@support.reap_threads
def test_forkinthread(self):
non_local = {'status': None}
def thread1():
try:
pid = os.fork() # fork in a thread
Expand All @@ -246,11 +252,13 @@ def thread1():
else: # parent
os.close(self.write_fd)
pid, status = os.waitpid(pid, 0)
self.assertEqual(status, 0)
non_local['status'] = status

thread.start_new_thread(thread1, ())
self.assertEqual(os.read(self.read_fd, 2), "OK",
"Unable to fork() in thread")
with support.wait_threads_exit():
thread.start_new_thread(thread1, ())
self.assertEqual(os.read(self.read_fd, 2), "OK",
"Unable to fork() in thread")
self.assertEqual(non_local['status'], 0)

def tearDown(self):
try:
Expand All @@ -265,7 +273,7 @@ def tearDown(self):


def test_main():
test_support.run_unittest(ThreadRunningTests, BarrierTest, LockTests,
support.run_unittest(ThreadRunningTests, BarrierTest, LockTests,
TestForkInThread)

if __name__ == "__main__":
Expand Down