Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 12 additions & 15 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,19 +1606,24 @@ def dummycallback(sock, servername, ctx, cycle=ctx):
gc.collect()
self.assertIs(wr(), None)

@unittest.skipUnless(support.Py_GIL_DISABLED,
"test is only useful if the GIL is disabled")
@threading_helper.requires_working_threading()
def test_sni_callback_race(self):
# Replacing sni_callback while handshakes are in-flight must not
# Replacing sni_callback while a handshake is in-flight must not
# crash (use-after-free on the callback in free-threaded builds).
#
# Use a single handshake thread: OpenSSL has internal data races
# on shared SSL_CTX state when multiple handshakes run
# concurrently against the same context (gh-150191). Concurrency
# on the *setter* is what exercises the fix from gh-149816, so
# multiple toggler threads race against each other and against
# the one handshake worker.
client_ctx, server_ctx, hostname = testing_context()

server_ctx.sni_callback = lambda *a: None
done = threading.Event()
deadline = time.monotonic() + 0.1

def do_handshakes():
while not done.is_set():
while time.monotonic() < deadline:
c_in = ssl.MemoryBIO()
c_out = ssl.MemoryBIO()
s_in = ssl.MemoryBIO()
Expand All @@ -1645,19 +1650,11 @@ def do_handshakes():
c_in.write(s_out.read())

def toggle_callback():
while not done.is_set():
while time.monotonic() < deadline:
server_ctx.sni_callback = lambda *a: None
server_ctx.sni_callback = None

workers = max(4, (os.cpu_count() or 4) * 2)
threads = [threading.Thread(target=do_handshakes)
for _ in range(workers)]
threads.append(threading.Thread(target=toggle_callback))

with threading_helper.catch_threading_exception() as cm:
with threading_helper.start_threads(threads):
done.set()
self.assertIsNone(cm.exc_value)
threading_helper.run_concurrently([do_handshakes] + 4 * [toggle_callback])

def test_cert_store_stats(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
Expand Down
Loading