Skip to content

Commit

Permalink
pythonGH-110829: Ensure Thread.join() joins the OS thread
Browse files Browse the repository at this point in the history
  • Loading branch information
pitrou committed Oct 27, 2023
1 parent aa73245 commit 48482e9
Show file tree
Hide file tree
Showing 7 changed files with 378 additions and 23 deletions.
8 changes: 8 additions & 0 deletions Include/pythread.h
Expand Up @@ -20,6 +20,14 @@ PyAPI_FUNC(unsigned long) PyThread_start_new_thread(void (*)(void *), void *);
PyAPI_FUNC(void) _Py_NO_RETURN PyThread_exit_thread(void);
PyAPI_FUNC(unsigned long) PyThread_get_thread_ident(void);

/* Joinable-thread primitives.  The guard below keeps these declarations out
   of the limited API. */
#if !defined(Py_LIMITED_API)
/* Takes a thread entry point `func`, the argument `arg` to pass to it, and an
   out-parameter `handle`; returns an unsigned long like
   PyThread_start_new_thread() above.
   NOTE(review): presumably `*handle` receives the value later accepted by
   PyThread_join_thread() / PyThread_detach_thread() -- confirm against the
   platform implementations. */
PyAPI_FUNC(unsigned long) PyThread_start_joinable_thread(void (*func)(void *),
void *arg,
Py_uintptr_t* handle);
/* Both take a Py_uintptr_t and return an int.
   NOTE(review): the meaning of the int result (error code vs. status) is not
   visible in this header -- confirm against the implementations. */
PyAPI_FUNC(int) PyThread_join_thread(Py_uintptr_t);
PyAPI_FUNC(int) PyThread_detach_thread(Py_uintptr_t);
#endif

#if (defined(__APPLE__) || defined(__linux__) || defined(_WIN32) \
|| defined(__FreeBSD__) || defined(__OpenBSD__) || defined(__NetBSD__) \
|| defined(__DragonFly__) || defined(_AIX))
Expand Down
5 changes: 4 additions & 1 deletion Lib/test/_test_multiprocessing.py
Expand Up @@ -2578,7 +2578,7 @@ def test_async(self):
self.assertTimingAlmostEqual(get.elapsed, TIMEOUT1)

def test_async_timeout(self):
res = self.pool.apply_async(sqr, (6, TIMEOUT2 + support.SHORT_TIMEOUT))
res = self.pool.apply_async(sqr, (6, 5 * TIMEOUT2))
get = TimingWrapper(res.get)
self.assertRaises(multiprocessing.TimeoutError, get, timeout=TIMEOUT2)
self.assertTimingAlmostEqual(get.elapsed, TIMEOUT2)
Expand Down Expand Up @@ -2682,6 +2682,9 @@ def test_make_pool(self):
p.join()

def test_terminate(self):
if self.TYPE == 'threads':
self.skipTest("Threads cannot be terminated")

# Simulate slow tasks which take "forever" to complete
args = [support.LONG_TIMEOUT for i in range(10_000)]
result = self.pool.map_async(time.sleep, args, chunksize=1)
Expand Down
99 changes: 99 additions & 0 deletions Lib/test/test_thread.py
Expand Up @@ -160,6 +160,105 @@ def task():
"Exception ignored in thread started by")
self.assertIsNotNone(cm.unraisable.exc_traceback)

def test_join_thread(self):
    # join_thread() must not return before the joinable thread's task has
    # run to completion, so exactly one completion marker is expected.
    completions = []

    def worker():
        time.sleep(0.05)
        completions.append(None)

    with threading_helper.wait_threads_exit():
        # Fourth positional argument requests a joinable thread.
        tid = thread.start_new_thread(worker, (), {}, True)
        thread.join_thread(tid)
        self.assertEqual(len(completions), 1)

def test_join_thread_already_exited(self):
    # Joining must still succeed when the task has (very likely) finished
    # before join_thread() is called.
    def noop():
        pass

    with threading_helper.wait_threads_exit():
        tid = thread.start_new_thread(noop, (), {}, True)
        # Give the no-op task time to finish before joining.
        time.sleep(0.05)
        thread.join_thread(tid)

def test_join_non_joinable(self):
    # A thread started without the joinable flag cannot be joined:
    # join_thread() must raise ValueError.
    def noop():
        pass

    with threading_helper.wait_threads_exit():
        tid = thread.start_new_thread(noop, ())
        self.assertRaisesRegex(
            ValueError, "not joinable", thread.join_thread, tid)

def test_join_several_times(self):
    # A joinable thread can be joined only once; a second join_thread()
    # on the same ident must raise ValueError.
    def noop():
        pass

    with threading_helper.wait_threads_exit():
        tid = thread.start_new_thread(noop, (), {}, True)
        thread.join_thread(tid)
        self.assertRaisesRegex(
            ValueError, "not joinable", thread.join_thread, tid)

def test_join_from_self(self):
    # A thread that tries to join itself must get RuntimeError, and that
    # failed attempt must not prevent another thread from joining it.
    errors = []
    lock = thread.allocate_lock()
    lock.acquire()

    def task():
        ident = thread.get_ident()
        # Wait for start_new_thread() to return so that this thread's ident
        # is registered as joinable; otherwise ValueError would be raised
        # instead of the RuntimeError under test.
        lock.acquire()
        try:
            thread.join_thread(ident)
        except Exception as e:
            # Hand the exception to the main thread for inspection.
            errors.append(e)

    with threading_helper.wait_threads_exit():
        joinable = True
        ident = thread.start_new_thread(task, (), {}, joinable)
        lock.release()
        time.sleep(0.05)
        # Can still join after join_thread() failed in other thread
        thread.join_thread(ident)

    # Use a unittest assertion rather than a bare ``assert``: the latter is
    # stripped under ``python -O`` and gives a poorer failure message.
    self.assertEqual(len(errors), 1)
    with self.assertRaisesRegex(RuntimeError, "Cannot join current thread"):
        raise errors[0]

def test_detach_then_join(self):
    # Once a joinable thread has been detached it can no longer be joined:
    # join_thread() must raise ValueError.
    lock = thread.allocate_lock()
    lock.acquire()

    def task():
        # Block until the main thread releases the lock.
        lock.acquire()

    with threading_helper.wait_threads_exit():
        joinable = True
        ident = thread.start_new_thread(task, (), {}, joinable)
        try:
            # detach_thread() returns even though the thread is blocked on
            # lock
            thread.detach_thread(ident)
            # join_thread() then cannot be called anymore
            with self.assertRaisesRegex(ValueError, "not joinable"):
                thread.join_thread(ident)
        finally:
            # Always unblock the worker, even if an assertion above fails;
            # otherwise it would stay blocked forever and
            # wait_threads_exit() would obscure the real failure.
            lock.release()

def test_join_then_detach(self):
    # Once a joinable thread has been joined it can no longer be detached:
    # detach_thread() must raise ValueError.
    def noop():
        pass

    with threading_helper.wait_threads_exit():
        tid = thread.start_new_thread(noop, (), {}, True)
        thread.join_thread(tid)
        self.assertRaisesRegex(
            ValueError, "not joinable", thread.detach_thread, tid)


class Barrier:
def __init__(self, num_threads):
Expand Down
35 changes: 33 additions & 2 deletions Lib/threading.py
Expand Up @@ -5,6 +5,7 @@
import _thread
import functools
import warnings
import _weakref

from time import monotonic as _time
from _weakrefset import WeakSet
Expand Down Expand Up @@ -34,6 +35,8 @@

# Rename some stuff so "from threading import *" is safe
_start_new_thread = _thread.start_new_thread
_join_thread = _thread.join_thread
_detach_thread = _thread.detach_thread
_daemon_threads_allowed = _thread.daemon_threads_allowed
_allocate_lock = _thread.allocate_lock
_set_sentinel = _thread._set_sentinel
Expand Down Expand Up @@ -924,6 +927,7 @@ class is implemented.
if _HAVE_THREAD_NATIVE_ID:
self._native_id = None
self._tstate_lock = None
self._join_lock = None
self._started = Event()
self._is_stopped = False
self._initialized = True
Expand All @@ -944,11 +948,14 @@ def _reset_internal_locks(self, is_alive):
if self._tstate_lock is not None:
self._tstate_lock._at_fork_reinit()
self._tstate_lock.acquire()
if self._join_lock is not None:
self._join_lock._at_fork_reinit()
else:
# The thread isn't alive after fork: it doesn't have a tstate
# anymore.
self._is_stopped = True
self._tstate_lock = None
self._join_lock = None

def __repr__(self):
assert self._initialized, "Thread.__init__() was not called"
Expand Down Expand Up @@ -980,15 +987,24 @@ def start(self):
if self._started.is_set():
raise RuntimeError("threads can only be started once")

self._join_lock = _allocate_lock()

with _active_limbo_lock:
_limbo[self] = self
try:
_start_new_thread(self._bootstrap, ())
# Start joinable thread
_start_new_thread(self._bootstrap, (), {}, True)
except Exception:
with _active_limbo_lock:
del _limbo[self]
raise
self._started.wait()
self._started.wait() # Will set ident and native_id

# We need to make sure the OS thread is either explicitly joined or
# detached at some point, otherwise system resources can be leaked.
def _finalizer(wr, _detach_thread=_detach_thread, ident=self._ident):
_detach_thread(ident)
self._non_joined_finalizer = _weakref.ref(self, _finalizer)

def run(self):
"""Method representing the thread's activity.
Expand Down Expand Up @@ -1144,6 +1160,19 @@ def join(self, timeout=None):
# historically .join(timeout=x) for x<0 has acted as if timeout=0
self._wait_for_tstate_lock(timeout=max(timeout, 0))

if self._is_stopped:
self._join_os_thread()

def _join_os_thread(self):
    """Join the underlying OS thread at most once.

    Does nothing if the thread was never started or has already been
    joined (``self._join_lock`` is None in both cases).  Safe to call from
    several threads at once: exactly one caller performs the join, and the
    others wait for it to finish and then return.
    """
    join_lock = self._join_lock
    if join_lock is None:
        return
    with join_lock:
        # Re-check under the lock: another caller may have read the same
        # lock object, joined the thread and cleared self._join_lock while
        # we were waiting.  Joining the same ident a second time would
        # raise ValueError ("not joinable").
        if self._join_lock is not None:
            _join_thread(self._ident)
            self._join_lock = None
            # The thread is joined, so the detach-on-finalize weakref is
            # no longer needed.
            self._non_joined_finalizer = None

def _wait_for_tstate_lock(self, block=True, timeout=-1):
# Issue #18808: wait for the thread state to be gone.
# At the end of the thread's life, after all knowledge of the thread
Expand Down Expand Up @@ -1223,6 +1252,8 @@ def is_alive(self):
if self._is_stopped or not self._started.is_set():
return False
self._wait_for_tstate_lock(False)
if self._is_stopped:
self._join_os_thread()
return not self._is_stopped

@property
Expand Down

0 comments on commit 48482e9

Please sign in to comment.