Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES/11761.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Fixed ``AssertionError`` when the transport is ``None`` during WebSocket
preparation or file response sending (e.g. when a client disconnects
immediately after connecting). A ``ConnectionResetError`` is now raised
instead -- by :user:`agners`.
1 change: 0 additions & 1 deletion CHANGES/5783.feature

This file was deleted.

11 changes: 1 addition & 10 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,12 +869,6 @@ async def _connect_and_send_request(
elif not scheme:
parsed_redirect_url = url.join(parsed_redirect_url)

is_same_host_https_redirect = (
url.host == parsed_redirect_url.host
and parsed_redirect_url.scheme == "https"
and url.scheme == "http"
)

try:
redirect_origin = parsed_redirect_url.origin()
except ValueError as origin_val_err:
Expand All @@ -886,10 +880,7 @@ async def _connect_and_send_request(
"Invalid redirect URL origin",
) from origin_val_err

if (
not is_same_host_https_redirect
and url.origin() != redirect_origin
):
if url.origin() != redirect_origin:
auth = None
headers.pop(hdrs.AUTHORIZATION, None)
headers.pop(hdrs.COOKIE, None)
Expand Down
3 changes: 2 additions & 1 deletion aiohttp/web_fileresponse.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ async def _sendfile(

loop = request._loop
transport = request.transport
assert transport is not None
if transport is None:
raise ConnectionResetError("Connection lost")

try:
await loop.sendfile(transport, fobj, offset, count)
Expand Down
3 changes: 2 additions & 1 deletion aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,8 @@ def _pre_start(self, request: BaseRequest) -> tuple[str | None, WebSocketWriter]
self.force_close()
self._compress = compress
transport = request._protocol.transport
assert transport is not None
if transport is None:
raise ConnectionResetError("Connection lost")
writer = WebSocketWriter(
request._protocol,
transport,
Expand Down
11 changes: 0 additions & 11 deletions docs/client_advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,6 @@ In cases where the authentication header value expires periodically, an
:mod:`asyncio` task may be used to update the session's default headers in the
background.

.. note::

The ``Authorization`` header will be removed if you get redirected
to a different host or protocol, except the case when HTTP → HTTPS
redirect is performed on the same host.

.. versionchanged:: 4.0

Started keeping the ``Authorization`` header during HTTP → HTTPS
redirects when the host remains the same.

.. _aiohttp-client-middleware:

Client Middleware
Expand Down
48 changes: 19 additions & 29 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3390,61 +3390,51 @@ def create(url: URL, srv: Handler) -> Awaitable[TestServer]:


@pytest.mark.parametrize(
["url_from_s", "url_to_s", "is_drop_header_expected"],
[
[
"http://host1.com/path1",
"http://host2.com/path2",
True,
],
["http://host1.com/path1", "https://host1.com/path1", False],
["https://host1.com/path1", "http://host1.com/path2", True],
],
("url_from_s", "url_to_s"),
(
("http://host1.com/path1", "http://host2.com/path2"),
("http://host1.com/path1", "https://host1.com/path1"),
("https://host1.com/path1", "http://host1.com/path2"),
("http://host1.com/path1", "https://host1.com:9443/path1"),
),
ids=(
"entirely different hosts",
"http -> https",
"https -> http",
"http -> https different port",
),
)
async def test_drop_auth_on_redirect_to_other_host(
create_server_for_url_and_handler: Callable[[URL, Handler], Awaitable[TestServer]],
url_from_s: str,
url_to_s: str,
is_drop_header_expected: bool,
) -> None:
url_from, url_to = URL(url_from_s), URL(url_to_s)

async def srv_from(request: web.Request) -> NoReturn:
assert request.host == url_from.host
assert request.host.split(":")[0] == url_from.host
assert request.headers["Authorization"] == "Basic dXNlcjpwYXNz"
raise web.HTTPFound(url_to)

async def srv_to(request: web.Request) -> web.Response:
assert request.host == url_to.host
if is_drop_header_expected:
assert "Authorization" not in request.headers, "Header wasn't dropped"
assert "Proxy-Authorization" not in request.headers
assert "Cookie" not in request.headers
else:
assert "Authorization" in request.headers, "Header was dropped"
assert "Proxy-Authorization" in request.headers
assert "Cookie" in request.headers
assert request.host.split(":")[0] == url_to.host
assert "Authorization" not in request.headers, "Header wasn't dropped"
assert "Proxy-Authorization" not in request.headers
assert "Cookie" not in request.headers
return web.Response()

server_from = await create_server_for_url_and_handler(url_from, srv_from)
server_to = await create_server_for_url_and_handler(url_to, srv_to)

assert (
url_from.host != url_to.host or server_from.scheme != server_to.scheme
), "Invalid test case, host or scheme must differ"
url_from.host != url_to.host
or server_from.scheme != server_to.scheme
or url_from.port != url_to.port
), "Invalid test case, host, scheme, or port must differ"

protocol_port_map = {
"http": 80,
"https": 443,
}
etc_hosts = {
(url_from.host, protocol_port_map[server_from.scheme]): server_from,
(url_to.host, protocol_port_map[server_to.scheme]): server_to,
(url_from.host, url_from.port): server_from,
(url_to.host, url_to.port): server_to,
}

class FakeResolver(AbstractResolver):
Expand Down
63 changes: 63 additions & 0 deletions tests/test_web_sendfile_functional.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import bz2
import contextlib
import gzip
import pathlib
import socket
Expand All @@ -15,6 +16,7 @@
from aiohttp.compression_utils import ZLibBackend
from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer
from aiohttp.typedefs import PathLike
from aiohttp.web_fileresponse import NOSENDFILE

try:
import brotlicffi as brotli
Expand Down Expand Up @@ -1156,3 +1158,64 @@ async def handler(request: web.Request) -> web.FileResponse:

resp.release()
await client.close()


@pytest.mark.skipif(NOSENDFILE, reason="OS sendfile not available")
async def test_sendfile_after_client_disconnect(
aiohttp_client: AiohttpClient, tmp_path: pathlib.Path
) -> None:
"""Test ConnectionResetError when client disconnects before sendfile.

Reproduces the race condition where:
- Client sends a GET request for a file
- Handler does async work (e.g. auth check) before returning a FileResponse
- Client disconnects while the handler is busy
- Server then calls sendfile() → ConnectionResetError (not AssertionError)

_send_headers_immediately is set to False so that super().prepare()
only buffers the headers without writing to the transport. Otherwise
_write() raises ClientConnectionResetError first and _sendfile()'s own
transport check is never reached.
"""
filepath = tmp_path / "test.txt"
filepath.write_bytes(b"x" * 1024)

handler_started = asyncio.Event()
prepare_done = asyncio.Event()
captured_protocol = None

async def handler(request: web.Request) -> web.Response:
nonlocal captured_protocol
resp = web.FileResponse(filepath)
resp._send_headers_immediately = False
captured_protocol = request._protocol
handler_started.set()
# Simulate async work (e.g., auth check) during which client disconnects.
await asyncio.sleep(0)
with pytest.raises(ConnectionResetError, match="Connection lost"):
await resp.prepare(request)
prepare_done.set()
return web.Response(status=503)

app = web.Application()
app.router.add_get("/", handler)
client = await aiohttp_client(app)

request_task = asyncio.create_task(client.get("/"))

# Wait until the handler is running but has not yet returned the response.
await handler_started.wait()
assert captured_protocol is not None

# Simulate the client disconnecting by setting transport to None directly.
# We cannot use force_close() because closing the TCP transport triggers
# connection_lost() which cancels the handler task before it can call
# prepare() and hit the ConnectionResetError in _sendfile().
captured_protocol.transport = None

# Wait for the handler to resume, call prepare(), and hit ConnectionResetError.
await asyncio.wait_for(prepare_done.wait(), timeout=1)

request_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await request_task
57 changes: 56 additions & 1 deletion tests/test_web_websocket_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pytest

import aiohttp
from aiohttp import WSServerHandshakeError, web
from aiohttp import WSServerHandshakeError, hdrs, web
from aiohttp.http import WSCloseCode, WSMsgType
from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer

Expand Down Expand Up @@ -1661,3 +1661,58 @@ async def websocket_handler(
assert msg.type is aiohttp.WSMsgType.TEXT
assert msg.data == "success"
await ws.close()


async def test_prepare_after_client_disconnect(aiohttp_client: AiohttpClient) -> None:
"""Test ConnectionResetError when client disconnects before ws.prepare().

Reproduces the race condition where:
- Client connects and sends a WebSocket upgrade request
- Handler starts async work (e.g. authentication) before calling ws.prepare()
- Client disconnects while the handler is busy
- Handler then calls ws.prepare() → ConnectionResetError (not AssertionError)
"""
handler_started = asyncio.Event()
captured_protocol = None

async def handler(request: web.Request) -> web.Response:
nonlocal captured_protocol
ws = web.WebSocketResponse()
captured_protocol = request._protocol
handler_started.set()
# Simulate async work (e.g., auth check) during which client disconnects.
await asyncio.sleep(0)
with pytest.raises(ConnectionResetError, match="Connection lost"):
await ws.prepare(request)
return web.Response(status=503)

app = web.Application()
app.router.add_route("GET", "/", handler)
client = await aiohttp_client(app)

request_task = asyncio.create_task(
client.session.get(
client.make_url("/"),
headers={
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "Upgrade",
hdrs.SEC_WEBSOCKET_KEY: "dGhlIHNhbXBsZSBub25jZQ==",
hdrs.SEC_WEBSOCKET_VERSION: "13",
},
)
)

# Wait until the handler is running but has not yet called ws.prepare().
await handler_started.wait()
assert captured_protocol is not None

# Simulate the client disconnecting abruptly.
captured_protocol.force_close()

# Yield so the handler can resume and hit the ConnectionResetError.
await asyncio.sleep(0)

with contextlib.suppress(
aiohttp.ServerDisconnectedError, aiohttp.ClientConnectionResetError
):
await request_task
Loading