Skip to content

gh-81536: For nonblocking sockets, add SSLSocket.eager_recv to call SSL_read in a loop #31492

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions Doc/library/ssl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1315,6 +1315,20 @@ SSL sockets also have the following additional methods and attributes:

.. versionadded:: 3.2

.. attribute:: SSLSocket.eager_recv

If set to ``True``, a call to :meth:`~socket.socket.recv()` or
:meth:`~socket.socket.recv_into()` on a
:ref:`non-blocking <ssl-nonblocking>` TLS socket
will drop the GIL once to read the entire buffer instead of reading at most
one TLS record (16 KB).

.. note::
Reading the entire buffer can include the TLS EOF segment, which will
close the TLS layer without raising :exc:`SSLEOFError`.

.. versionadded:: 3.12

.. attribute:: SSLSocket.server_side

A boolean which is ``True`` for server-side sockets and ``False`` for
Expand Down
20 changes: 20 additions & 0 deletions Lib/ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,15 @@ def session_reused(self):
"""Was the client session reused during handshake"""
return self._sslobj.session_reused

@property
def eager_recv(self):
"""If data is read from the socket eagerly, ignoring possible TLS EOF packets."""
return self._sslobj.eager_recv

@eager_recv.setter
def eager_recv(self, eager_recv):
self._sslobj.eager_recv = eager_recv

@property
def server_side(self):
"""Whether this is a server-side socket."""
Expand Down Expand Up @@ -1044,6 +1053,17 @@ def session_reused(self):
if self._sslobj is not None:
return self._sslobj.session_reused

@property
@_sslcopydoc
def eager_recv(self):
if self._sslobj is not None:
return self._sslobj.eager_recv

@eager_recv.setter
def eager_recv(self, eager_recv):
if self._sslobj is not None:
self._sslobj.eager_recv = eager_recv

def dup(self):
raise NotImplementedError("Can't dup() %s instances" %
self.__class__.__name__)
Expand Down
58 changes: 52 additions & 6 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2118,6 +2118,49 @@ def test_bio_read_write_data(self):
self.assertEqual(buf, b'foo\n')
self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)

def test_bulk_nonblocking_read(self):
# 65536 bytes divide up into 4 TLS records (16 KB each)
# In nonblocking mode, we should be able to read all four in a single
# drop of the GIL.
size = 65536

client_context, server_context, hostname = testing_context()
server = ThreadedEchoServer(context=server_context, chatty=False,
buffer_size=size)
with server:
sock = socket.create_connection((HOST, server.port))
sock.settimeout(0.0)
s = client_context.wrap_socket(sock, server_hostname=hostname,
do_handshake_on_connect=False)
s.eager_recv = True
with s:
while True:
try:
s.do_handshake()
break
except ssl.SSLWantReadError:
select.select([s], [], [])
except ssl.SSLWantWriteError:
select.select([], [s], [])

s.send(b'\x00' * size)

select.select([s], [], [])

while size > 0:
try:
count = len(s.recv(size))
except ssl.SSLWantReadError:
select.select([s], [], [])
# Give the sender some more time to complete sending.
time.sleep(0.01)
else:
if count > 16384:
return
size -= count

self.fail("All TLS reads were smaller than 16KB")


@support.requires_resource('network')
class NetworkedTests(unittest.TestCase):
Expand Down Expand Up @@ -2177,7 +2220,7 @@ class ConnectionHandler(threading.Thread):
with and without the SSL wrapper around the socket connection, so
that we can test the STARTTLS functionality."""

def __init__(self, server, connsock, addr):
def __init__(self, server, connsock, addr, buffer_size):
self.server = server
self.running = False
self.sock = connsock
Expand All @@ -2186,6 +2229,7 @@ def __init__(self, server, connsock, addr):
self.sslconn = None
threading.Thread.__init__(self)
self.daemon = True
self.buffer_size = buffer_size

def wrap_conn(self):
try:
Expand Down Expand Up @@ -2251,9 +2295,9 @@ def wrap_conn(self):

def read(self):
if self.sslconn:
return self.sslconn.read()
return self.sslconn.read(self.buffer_size)
else:
return self.sock.recv(1024)
return self.sock.recv(self.buffer_size)

def write(self, bytes):
if self.sslconn:
Expand Down Expand Up @@ -2371,8 +2415,8 @@ def run(self):
def __init__(self, certificate=None, ssl_version=None,
certreqs=None, cacerts=None,
chatty=True, connectionchatty=False, starttls_server=False,
alpn_protocols=None,
ciphers=None, context=None):
alpn_protocols=None, ciphers=None, context=None,
buffer_size=1024):
if context:
self.context = context
else:
Expand Down Expand Up @@ -2401,6 +2445,7 @@ def __init__(self, certificate=None, ssl_version=None,
self.conn_errors = []
threading.Thread.__init__(self)
self.daemon = True
self.buffer_size = buffer_size

def __enter__(self):
self.start(threading.Event())
Expand Down Expand Up @@ -2428,7 +2473,8 @@ def run(self):
if support.verbose and self.chatty:
sys.stdout.write(' server: new connection from '
+ repr(connaddr) + '\n')
handler = self.ConnectionHandler(self, newconn, connaddr)
handler = self.ConnectionHandler(self, newconn, connaddr,
self.buffer_size)
handler.start()
handler.join()
except TimeoutError as e:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Added :attr:`ssl.SSLSocket.eager_recv`, if enabled a :ref:`non-blocking <ssl-nonblocking>`
TLS socket will drop the GIL once to read up to the entire buffer instead of reading at
most TLS record (16 KB). Patch by Josh Snyder and Safihre.
38 changes: 35 additions & 3 deletions Modules/_ssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ typedef struct {
PyObject *Socket; /* weakref to socket on which we're layered */
SSL *ssl;
PySSLContext *ctx; /* weakref to SSL context */
int eager_recv;
char shutdown_seen_zero;
enum py_ssl_server_or_client socket_type;
PyObject *owner; /* Python level "owner" passed to servername callback */
Expand Down Expand Up @@ -799,6 +800,7 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
self->ssl = NULL;
self->Socket = NULL;
self->ctx = (PySSLContext*)Py_NewRef(sslctx);
self->eager_recv = 0;
self->shutdown_seen_zero = 0;
self->owner = NULL;
self->server_hostname = NULL;
Expand Down Expand Up @@ -2118,6 +2120,22 @@ static int PySSL_set_context(PySSLSocket *self, PyObject *value,
return 0;
}

static PyObject *
PySSL_get_eager_recv(PySSLSocket *self, void *c)
{
return PyBool_FromLong(self->eager_recv);
}

static int
PySSL_set_eager_recv(PySSLSocket *self, PyObject *arg, void *c)
{
int eager_recv;
if (!PyArg_Parse(arg, "p", &eager_recv))
return -1;
self->eager_recv = eager_recv;
return 0;
}

PyDoc_STRVAR(PySSL_set_context_doc,
"_setter_context(ctx)\n\
\
Expand Down Expand Up @@ -2430,10 +2448,11 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len,
PyObject *dest = NULL;
char *mem;
size_t count = 0;
size_t readbytes = 0;
int retval;
int sockstate;
_PySSLError err;
int nonblocking;
int nonblocking = 0;
PySocketSockObject *sock = GET_SOCKET(self);
_PyTime_t timeout, deadline = 0;
int has_timeout;
Expand Down Expand Up @@ -2493,11 +2512,22 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len,

do {
PySSL_BEGIN_ALLOW_THREADS
retval = SSL_read_ex(self->ssl, mem, (size_t)len, &count);
do {
retval = SSL_read_ex(self->ssl, mem + count, len, &readbytes);
if (retval <= 0) {
break;
}
count += readbytes;
len -= readbytes;
} while (nonblocking && self->eager_recv && len > 0);
err = _PySSL_errno(retval == 0, self->ssl, retval);
PySSL_END_ALLOW_THREADS
self->err = err;

if (count > 0) {
break;
}

if (PyErr_CheckSignals())
goto error;

Expand Down Expand Up @@ -2528,7 +2558,7 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len,
} while (err.ssl == SSL_ERROR_WANT_READ ||
err.ssl == SSL_ERROR_WANT_WRITE);

if (retval == 0) {
if (count == 0) {
PySSL_SetError(self, retval, __FILE__, __LINE__);
goto error;
}
Expand Down Expand Up @@ -2877,6 +2907,8 @@ PyDoc_STRVAR(PySSL_get_session_reused_doc,
static PyGetSetDef ssl_getsetlist[] = {
{"context", (getter) PySSL_get_context,
(setter) PySSL_set_context, PySSL_set_context_doc},
{"eager_recv", (getter) PySSL_get_eager_recv,
(setter) PySSL_set_eager_recv, NULL},
{"server_side", (getter) PySSL_get_server_side, NULL,
PySSL_get_server_side_doc},
{"server_hostname", (getter) PySSL_get_server_hostname, NULL,
Expand Down