From b2ed3c3c26ad8163bc1f56824f4b1db4fc75af75 Mon Sep 17 00:00:00 2001 From: Maksym Kasimov Date: Sun, 30 Nov 2025 19:42:23 +0200 Subject: [PATCH] asyncio.streams: transfer buffered data to SSL layer in start_tls() --- Lib/asyncio/base_events.py | 26 +++ Lib/test/test_asyncio/test_streams.py | 159 ++++++++++++++++++ ...2-06-16-14-18.gh-issue-142352.pW5HLX88.rst | 4 + 3 files changed, 189 insertions(+) create mode 100644 Misc/NEWS.d/next/Library/2025-12-06-16-14-18.gh-issue-142352.pW5HLX88.rst diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 8cbb71f708537f..7a94b4fe3c2d7b 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -1311,6 +1311,30 @@ async def _sendfile_fallback(self, transp, file, offset, count): file.seek(offset + total_sent) await proto.restore() + def _transfer_buffered_data_to_ssl(self, protocol, ssl_protocol): + """Transfer buffered data from StreamReader to SSL incoming BIO. + + When using start_tls() mid-connection (e.g., after reading a + PROXY protocol header), any data already buffered in the + StreamReader would be lost. This transfers that data to the + SSL layer so the handshake can proceed. + + Note: This only works with StreamReaderProtocol (used by the + streams API). Custom Protocol implementations that buffer data + must handle this manually before calling start_tls(). + """ + if not hasattr(protocol, '_stream_reader'): + return + + stream_reader = protocol._stream_reader + if stream_reader is None: + return + + buffer = stream_reader._buffer + if buffer: + ssl_protocol._incoming.write(buffer) + buffer.clear() + async def start_tls(self, transport, protocol, sslcontext, *, server_side=False, server_hostname=None, @@ -1341,6 +1365,8 @@ async def start_tls(self, transport, protocol, sslcontext, *, ssl_shutdown_timeout=ssl_shutdown_timeout, call_connection_made=False) + self._transfer_buffered_data_to_ssl(protocol, ssl_protocol) + # Pause early so that "ssl_protocol.data_received()" doesn't # have a chance to get called before "ssl_protocol.connection_made()". transport.pause_reading() diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index f93ee54abc6469..f704ef706bc0a3 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -819,6 +819,165 @@ async def client(addr): self.assertEqual(msg1, b"hello world 1!\n") self.assertEqual(msg2, b"hello world 2!\n") + def _run_test_start_tls_behind_proxy(self, send_combined): + """Test start_tls() when TLS ClientHello arrives with PROXY header. + + This simulates HAProxy with send-proxy, where the PROXY protocol + header and TLS handshake data may arrive in the same TCP segment. + Without the fix, buffered TLS data would be lost after start_tls(). + """ + + def reverse_message(data): + return data.strip()[::-1] + b'\n' + + test_message = b"hello world\n" + expected_response = reverse_message(test_message) + + class TCPProxyServer: + """A simple TCP proxy server that adds a PROXY protocol header + before forwarding data to the target server.""" + + PROXY_LINE = b"PROXY TCP4 127.0.0.1 127.0.0.1 54321 443\r\n" + + def __init__(self, loop, target_host, target_port): + self.loop = loop + self.target_host = target_host + self.target_port = target_port + self.server = None + + async def _pipe(self, reader, writer): + try: + while True: + data = await reader.read(4096) + if not data: + break + writer.write(data) + await writer.drain() + finally: + writer.close() + await writer.wait_closed() + + async def handle_client(self, client_reader, client_writer): + # Connecting to the target server + remote_reader, remote_writer = await asyncio.open_connection( + self.target_host, self.target_port) + + # Reading data from the client (TLS ClientHello) + tls_data = await client_reader.read(4096) + + if send_combined: + # send everything together: PROXY + TLS data + remote_writer.write(self.PROXY_LINE + tls_data) + await remote_writer.drain() + else: + # send TLS data after the PROXY line + remote_writer.write(self.PROXY_LINE) + await remote_writer.drain() + await asyncio.sleep(0.01) + remote_writer.write(tls_data) + await remote_writer.drain() + + await asyncio.gather( + self._pipe(client_reader, remote_writer), + self._pipe(remote_reader, client_writer), + ) + + def start(self): + sock = socket.create_server(('127.0.0.1', 0)) + self.server = self.loop.run_until_complete( + asyncio.start_server(self.handle_client, sock=sock)) + return sock.getsockname() + + def stop(self): + if self.server: + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) + self.server = None + + class ServerWithSendProxySupport: + """A server that supports the PROXY protocol and starts TLS + after receiving the PROXY header.""" + + def __init__(self, test_case, loop): + self.test = test_case + self.server = None + self.loop = loop + + async def handle_client(self, client_reader, client_writer): + proxy_line = await client_reader.readline() + self.test.assertEqual(proxy_line, TCPProxyServer.PROXY_LINE) + + # Now we can start TLS + self.test.assertIsNone( + client_writer.get_extra_info('sslcontext')) + await client_writer.start_tls( + test_utils.simple_server_sslcontext() + ) + self.test.assertIsNotNone( + client_writer.get_extra_info('sslcontext')) + + data = await client_reader.readline() + client_writer.write(reverse_message(data)) + await client_writer.drain() + client_writer.close() + await client_writer.wait_closed() + + def start(self): + sock = socket.create_server(('127.0.0.1', 0)) + self.server = self.loop.run_until_complete( + asyncio.start_server(self.handle_client, + sock=sock)) + return sock.getsockname() + + def stop(self): + if self.server is not None: + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) + self.server = None + + async def client(addr, test_case): + reader, writer = await asyncio.open_connection(*addr) + + test_case.assertIsNone(writer.get_extra_info('sslcontext')) + await writer.start_tls(test_utils.simple_client_sslcontext()) + test_case.assertIsNotNone(writer.get_extra_info('sslcontext')) + + writer.write(test_message) + await writer.drain() + msgback = await reader.readline() + writer.close() + await writer.wait_closed() + return msgback + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + server = ServerWithSendProxySupport(self, self.loop) + server_addr = server.start() + + proxy = TCPProxyServer(self.loop, *server_addr) + proxy_addr = proxy.start() + + msg = self.loop.run_until_complete( + asyncio.wait_for(client(proxy_addr, self), timeout=5.0) + ) + + proxy.stop() + server.stop() + + self.assertEqual(messages, []) + self.assertEqual(msg, expected_response) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_start_tls_behind_proxy_send_combined(self): + # Test with sending PROXY header and TLS data in one packet + self._run_test_start_tls_behind_proxy(send_combined=True) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_start_tls_behind_proxy_send_separate(self): + # Test with sending PROXY header and TLS data in separate packets + self._run_test_start_tls_behind_proxy(send_combined=False) + def test_streamreader_constructor_without_loop(self): with self.assertRaisesRegex(RuntimeError, 'no current event loop'): asyncio.StreamReader() diff --git a/Misc/NEWS.d/next/Library/2025-12-06-16-14-18.gh-issue-142352.pW5HLX88.rst b/Misc/NEWS.d/next/Library/2025-12-06-16-14-18.gh-issue-142352.pW5HLX88.rst new file mode 100644 index 00000000000000..13e38b118175b4 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-12-06-16-14-18.gh-issue-142352.pW5HLX88.rst @@ -0,0 +1,4 @@ +Fix :meth:`asyncio.StreamWriter.start_tls` to transfer buffered data from +:class:`~asyncio.StreamReader` to the SSL layer, preventing data loss when +upgrading a connection to TLS mid-stream (e.g., when implementing PROXY +protocol support).