Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,6 +1533,59 @@ def dummycallback(sock, servername, ctx, cycle=ctx):
gc.collect()
self.assertIs(wr(), None)

@unittest.skipUnless(support.Py_GIL_DISABLED,
"test is only useful if the GIL is disabled")
@threading_helper.requires_working_threading()
def test_sni_callback_race(self):
# Replacing sni_callback while handshakes are in-flight must not
# crash (use-after-free on the callback in free-threaded builds).
client_ctx, server_ctx, hostname = testing_context()

server_ctx.sni_callback = lambda *a: None
done = threading.Event()

def do_handshakes():
while not done.is_set():
c_in = ssl.MemoryBIO()
c_out = ssl.MemoryBIO()
s_in = ssl.MemoryBIO()
s_out = ssl.MemoryBIO()
client = client_ctx.wrap_bio(
c_in, c_out, server_hostname=hostname)
server = server_ctx.wrap_bio(s_in, s_out, server_side=True)
for _ in range(50):
try:
client.do_handshake()
except ssl.SSLWantReadError:
pass
except ssl.SSLError:
break
if c_out.pending:
s_in.write(c_out.read())
try:
server.do_handshake()
except ssl.SSLWantReadError:
pass
except ssl.SSLError:
break
if s_out.pending:
c_in.write(s_out.read())

def toggle_callback():
while not done.is_set():
server_ctx.sni_callback = lambda *a: None
server_ctx.sni_callback = None

workers = max(4, (os.cpu_count() or 4) * 2)
threads = [threading.Thread(target=do_handshakes)
for _ in range(workers)]
threads.append(threading.Thread(target=toggle_callback))

with threading_helper.catch_threading_exception() as cm:
with threading_helper.start_threads(threads):
done.set()
self.assertIsNone(cm.exc_value)

def test_cert_store_stats(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.assertEqual(ctx.cert_store_stats(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix race condition in :attr:`ssl.SSLContext.sni_callback`
36 changes: 20 additions & 16 deletions Modules/_ssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#define OPENSSL_NO_DEPRECATED 1

#include "Python.h"
#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION()
#include "pycore_fileutils.h" // _PyIsSelectable_fd()
#include "pycore_long.h" // _PyLong_UnsignedLongLong_Converter()
#include "pycore_pyerrors.h" // _PyErr_ChainExceptions1()
Expand Down Expand Up @@ -4669,12 +4670,15 @@ _servername_callback(SSL *s, int *al, void *args)
PyObject *result;
/* The high-level ssl.SSLSocket object */
PyObject *ssl_socket;
PyObject *sni_cb;
const char *servername = SSL_get_servername(s, TLSEXT_NAMETYPE_host_name);
PyGILState_STATE gstate = PyGILState_Ensure();

if (sslctx->set_sni_cb == NULL) {
/* remove race condition in this the call back while if removing the
* callback is in progress */
Py_BEGIN_CRITICAL_SECTION(sslctx);
sni_cb = Py_XNewRef(sslctx->set_sni_cb);
Py_END_CRITICAL_SECTION();

if (sni_cb == NULL) {
PyGILState_Release(gstate);
return SSL_TLSEXT_ERR_OK;
}
Expand All @@ -4701,7 +4705,7 @@ _servername_callback(SSL *s, int *al, void *args)
goto error;

if (servername == NULL) {
result = PyObject_CallFunctionObjArgs(sslctx->set_sni_cb, ssl_socket,
result = PyObject_CallFunctionObjArgs(sni_cb, ssl_socket,
Py_None, sslctx, NULL);
}
else {
Expand All @@ -4728,7 +4732,7 @@ _servername_callback(SSL *s, int *al, void *args)
}
Py_DECREF(servername_bytes);
result = PyObject_CallFunctionObjArgs(
sslctx->set_sni_cb, ssl_socket, servername_str,
sni_cb, ssl_socket, servername_str,
sslctx, NULL);
Py_DECREF(servername_str);
}
Expand All @@ -4738,7 +4742,7 @@ _servername_callback(SSL *s, int *al, void *args)
PyErr_FormatUnraisable("Exception ignored "
"in ssl servername callback "
"while calling set SNI callback %R",
sslctx->set_sni_cb);
sni_cb);
*al = SSL_AD_HANDSHAKE_FAILURE;
ret = SSL_TLSEXT_ERR_ALERT_FATAL;
}
Expand All @@ -4763,11 +4767,13 @@ _servername_callback(SSL *s, int *al, void *args)
Py_DECREF(result);
}

Py_DECREF(sni_cb);
PyGILState_Release(gstate);
return ret;

error:
Py_XDECREF(ssl_socket);
Py_XDECREF(sni_cb);
*al = SSL_AD_INTERNAL_ERROR;
ret = SSL_TLSEXT_ERR_ALERT_FATAL;
PyGILState_Release(gstate);
Expand Down Expand Up @@ -4813,20 +4819,18 @@ _ssl__SSLContext_sni_callback_set_impl(PySSLContext *self, PyObject *value)
"sni_callback cannot be set on TLS_CLIENT context");
return -1;
}
Py_CLEAR(self->set_sni_cb);
if (value == Py_None) {
if (!PyCallable_Check(value)) {
SSL_CTX_set_tlsext_servername_callback(self->ctx, NULL);
}
else {
if (!PyCallable_Check(value)) {
SSL_CTX_set_tlsext_servername_callback(self->ctx, NULL);
PyErr_SetString(PyExc_TypeError,
"not a callable object");
Py_CLEAR(self->set_sni_cb);
if (value != Py_None) {
PyErr_SetString(PyExc_TypeError, "not a callable object");
return -1;
}
self->set_sni_cb = Py_NewRef(value);
SSL_CTX_set_tlsext_servername_callback(self->ctx, _servername_callback);
}
else {
Py_XSETREF(self->set_sni_cb, Py_NewRef(value));
SSL_CTX_set_tlsext_servername_arg(self->ctx, self);
SSL_CTX_set_tlsext_servername_callback(self->ctx, _servername_callback);
}
return 0;
}
Expand Down
Loading