From b26c9aee4cf1cb280fa771855ab7be95f1ae1d22 Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Fri, 27 Mar 2026 02:43:36 +0100 Subject: [PATCH 1/2] Raise ConnectionResetError if transport is None (#11761) --- CHANGES/11761.bugfix.rst | 4 ++ aiohttp/web_fileresponse.py | 3 +- aiohttp/web_ws.py | 3 +- tests/test_web_sendfile_functional.py | 63 ++++++++++++++++++++++++++ tests/test_web_websocket_functional.py | 57 ++++++++++++++++++++++- 5 files changed, 127 insertions(+), 3 deletions(-) create mode 100644 CHANGES/11761.bugfix.rst diff --git a/CHANGES/11761.bugfix.rst b/CHANGES/11761.bugfix.rst new file mode 100644 index 00000000000..d4661c6d4a1 --- /dev/null +++ b/CHANGES/11761.bugfix.rst @@ -0,0 +1,4 @@ +Fixed ``AssertionError`` when the transport is ``None`` during WebSocket +preparation or file response sending (e.g. when a client disconnects +immediately after connecting). A ``ConnectionResetError`` is now raised +instead -- by :user:`agners`. diff --git a/aiohttp/web_fileresponse.py b/aiohttp/web_fileresponse.py index eeaa2010f98..f339bec9662 100644 --- a/aiohttp/web_fileresponse.py +++ b/aiohttp/web_fileresponse.py @@ -128,7 +128,8 @@ async def _sendfile( loop = request._loop transport = request.transport - assert transport is not None + if transport is None: + raise ConnectionResetError("Connection lost") try: await loop.sendfile(transport, fobj, offset, count) diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index ca129bb0f30..2aeeb6dec1f 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -361,7 +361,8 @@ def _pre_start(self, request: BaseRequest) -> tuple[str | None, WebSocketWriter] self.force_close() self._compress = compress transport = request._protocol.transport - assert transport is not None + if transport is None: + raise ConnectionResetError("Connection lost") writer = WebSocketWriter( request._protocol, transport, diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index 87be2db182b..1d695d332c4 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -1,5 +1,6 @@ import asyncio import bz2 +import contextlib import gzip import pathlib import socket @@ -15,6 +16,7 @@ from aiohttp.compression_utils import ZLibBackend from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer from aiohttp.typedefs import PathLike +from aiohttp.web_fileresponse import NOSENDFILE try: import brotlicffi as brotli @@ -1156,3 +1158,64 @@ async def handler(request: web.Request) -> web.FileResponse: resp.release() await client.close() + + +@pytest.mark.skipif(NOSENDFILE, reason="OS sendfile not available") +async def test_sendfile_after_client_disconnect( + aiohttp_client: AiohttpClient, tmp_path: pathlib.Path +) -> None: + """Test ConnectionResetError when client disconnects before sendfile. + + Reproduces the race condition where: + - Client sends a GET request for a file + - Handler does async work (e.g. auth check) before returning a FileResponse + - Client disconnects while the handler is busy + - Server then calls sendfile() → ConnectionResetError (not AssertionError) + + _send_headers_immediately is set to False so that super().prepare() + only buffers the headers without writing to the transport. Otherwise + _write() raises ClientConnectionResetError first and _sendfile()'s own + transport check is never reached. + """ + filepath = tmp_path / "test.txt" + filepath.write_bytes(b"x" * 1024) + + handler_started = asyncio.Event() + prepare_done = asyncio.Event() + captured_protocol = None + + async def handler(request: web.Request) -> web.Response: + nonlocal captured_protocol + resp = web.FileResponse(filepath) + resp._send_headers_immediately = False + captured_protocol = request._protocol + handler_started.set() + # Simulate async work (e.g., auth check) during which client disconnects. + await asyncio.sleep(0) + with pytest.raises(ConnectionResetError, match="Connection lost"): + await resp.prepare(request) + prepare_done.set() + return web.Response(status=503) + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + request_task = asyncio.create_task(client.get("/")) + + # Wait until the handler is running but has not yet returned the response. + await handler_started.wait() + assert captured_protocol is not None + + # Simulate the client disconnecting by setting transport to None directly. + # We cannot use force_close() because closing the TCP transport triggers + # connection_lost() which cancels the handler task before it can call + # prepare() and hit the ConnectionResetError in _sendfile(). + captured_protocol.transport = None + + # Wait for the handler to resume, call prepare(), and hit ConnectionResetError. + await asyncio.wait_for(prepare_done.wait(), timeout=1) + + request_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await request_task diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 1e202649c6a..7257c47ba73 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -11,7 +11,7 @@ import pytest import aiohttp -from aiohttp import WSServerHandshakeError, web +from aiohttp import WSServerHandshakeError, hdrs, web from aiohttp.http import WSCloseCode, WSMsgType from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer @@ -1661,3 +1661,58 @@ async def websocket_handler( assert msg.type is aiohttp.WSMsgType.TEXT assert msg.data == "success" await ws.close() + + +async def test_prepare_after_client_disconnect(aiohttp_client: AiohttpClient) -> None: + """Test ConnectionResetError when client disconnects before ws.prepare(). + + Reproduces the race condition where: + - Client connects and sends a WebSocket upgrade request + - Handler starts async work (e.g. authentication) before calling ws.prepare() + - Client disconnects while the handler is busy + - Handler then calls ws.prepare() → ConnectionResetError (not AssertionError) + """ + handler_started = asyncio.Event() + captured_protocol = None + + async def handler(request: web.Request) -> web.Response: + nonlocal captured_protocol + ws = web.WebSocketResponse() + captured_protocol = request._protocol + handler_started.set() + # Simulate async work (e.g., auth check) during which client disconnects. + await asyncio.sleep(0) + with pytest.raises(ConnectionResetError, match="Connection lost"): + await ws.prepare(request) + return web.Response(status=503) + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + request_task = asyncio.create_task( + client.session.get( + client.make_url("/"), + headers={ + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "Upgrade", + hdrs.SEC_WEBSOCKET_KEY: "dGhlIHNhbXBsZSBub25jZQ==", + hdrs.SEC_WEBSOCKET_VERSION: "13", + }, + ) + ) + + # Wait until the handler is running but has not yet called ws.prepare(). + await handler_started.wait() + assert captured_protocol is not None + + # Simulate the client disconnecting abruptly. + captured_protocol.force_close() + + # Yield so the handler can resume and hit the ConnectionResetError. + await asyncio.sleep(0) + + with contextlib.suppress( + aiohttp.ServerDisconnectedError, aiohttp.ClientConnectionResetError + ): + await request_task From 8b10afd473a6805cd2c2b6bf918543942941a869 Mon Sep 17 00:00:00 2001 From: Rodrigo Nogueira Date: Thu, 26 Mar 2026 23:58:34 -0300 Subject: [PATCH 2/2] Fix credential leak on same-host redirects with different ports (#12275) --- CHANGES/5783.feature | 1 - aiohttp/client.py | 11 +------- docs/client_advanced.rst | 11 -------- tests/test_client_functional.py | 48 +++++++++++++-------------------- 4 files changed, 20 insertions(+), 51 deletions(-) delete mode 100644 CHANGES/5783.feature diff --git a/CHANGES/5783.feature b/CHANGES/5783.feature deleted file mode 100644 index 6b5c534f66f..00000000000 --- a/CHANGES/5783.feature +++ /dev/null @@ -1 +0,0 @@ -Started keeping the ``Authorization`` header during HTTP -> HTTPS redirects when the host remains the same. diff --git a/aiohttp/client.py b/aiohttp/client.py index f8bfe3e5bdc..c3e874e650d 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -869,12 +869,6 @@ async def _connect_and_send_request( elif not scheme: parsed_redirect_url = url.join(parsed_redirect_url) - is_same_host_https_redirect = ( - url.host == parsed_redirect_url.host - and parsed_redirect_url.scheme == "https" - and url.scheme == "http" - ) - try: redirect_origin = parsed_redirect_url.origin() except ValueError as origin_val_err: @@ -886,10 +880,7 @@ async def _connect_and_send_request( "Invalid redirect URL origin", ) from origin_val_err - if ( - not is_same_host_https_redirect - and url.origin() != redirect_origin - ): + if url.origin() != redirect_origin: auth = None headers.pop(hdrs.AUTHORIZATION, None) headers.pop(hdrs.COOKIE, None) diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index ebec0eef5a8..dc0dfb3660e 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -114,17 +114,6 @@ In cases where the authentication header value expires periodically, an :mod:`asyncio` task may be used to update the session's default headers in the background. -.. note:: - - The ``Authorization`` header will be removed if you get redirected - to a different host or protocol, except the case when HTTP → HTTPS - redirect is performed on the same host. - -.. versionchanged:: 4.0 - - Started keeping the ``Authorization`` header during HTTP → HTTPS - redirects when the host remains the same. - .. _aiohttp-client-middleware: Client Middleware diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 65d9cd28078..8ee45330bb5 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -3390,61 +3390,51 @@ def create(url: URL, srv: Handler) -> Awaitable[TestServer]: @pytest.mark.parametrize( - ["url_from_s", "url_to_s", "is_drop_header_expected"], - [ - [ - "http://host1.com/path1", - "http://host2.com/path2", - True, - ], - ["http://host1.com/path1", "https://host1.com/path1", False], - ["https://host1.com/path1", "http://host1.com/path2", True], - ], + ("url_from_s", "url_to_s"), + ( + ("http://host1.com/path1", "http://host2.com/path2"), + ("http://host1.com/path1", "https://host1.com/path1"), + ("https://host1.com/path1", "http://host1.com/path2"), + ("http://host1.com/path1", "https://host1.com:9443/path1"), + ), ids=( "entirely different hosts", "http -> https", "https -> http", + "http -> https different port", ), ) async def test_drop_auth_on_redirect_to_other_host( create_server_for_url_and_handler: Callable[[URL, Handler], Awaitable[TestServer]], url_from_s: str, url_to_s: str, - is_drop_header_expected: bool, ) -> None: url_from, url_to = URL(url_from_s), URL(url_to_s) async def srv_from(request: web.Request) -> NoReturn: - assert request.host == url_from.host + assert request.host.split(":")[0] == url_from.host assert request.headers["Authorization"] == "Basic dXNlcjpwYXNz" raise web.HTTPFound(url_to) async def srv_to(request: web.Request) -> web.Response: - assert request.host == url_to.host - if is_drop_header_expected: - assert "Authorization" not in request.headers, "Header wasn't dropped" - assert "Proxy-Authorization" not in request.headers - assert "Cookie" not in request.headers - else: - assert "Authorization" in request.headers, "Header was dropped" - assert "Proxy-Authorization" in request.headers - assert "Cookie" in request.headers + assert request.host.split(":")[0] == url_to.host + assert "Authorization" not in request.headers, "Header wasn't dropped" + assert "Proxy-Authorization" not in request.headers + assert "Cookie" not in request.headers return web.Response() server_from = await create_server_for_url_and_handler(url_from, srv_from) server_to = await create_server_for_url_and_handler(url_to, srv_to) assert ( - url_from.host != url_to.host or server_from.scheme != server_to.scheme - ), "Invalid test case, host or scheme must differ" + url_from.host != url_to.host + or server_from.scheme != server_to.scheme + or url_from.port != url_to.port + ), "Invalid test case, host, scheme, or port must differ" - protocol_port_map = { - "http": 80, - "https": 443, - } etc_hosts = { - (url_from.host, protocol_port_map[server_from.scheme]): server_from, - (url_to.host, protocol_port_map[server_to.scheme]): server_to, + (url_from.host, url_from.port): server_from, + (url_to.host, url_to.port): server_to, } class FakeResolver(AbstractResolver):