diff --git a/.github/workflows/ci-cd.yml b/.github/workflows/ci-cd.yml index 2da43d1bd41..38681d80718 100644 --- a/.github/workflows/ci-cd.yml +++ b/.github/workflows/ci-cd.yml @@ -237,7 +237,8 @@ jobs: PIP_USER: 1 run: >- PATH="${HOME}/Library/Python/3.11/bin:${HOME}/.local/bin:${PATH}" - pytest --junitxml=junit.xml -m 'not dev_mode and not autobahn' + pytest --junitxml=junit.xml --numprocesses=auto --cov=aiohttp/ --cov=tests/ + -m 'not dev_mode and not autobahn' shell: bash - name: Re-run the failing tests with maximum verbosity if: failure() @@ -245,7 +246,7 @@ jobs: COLOR: yes AIOHTTP_NO_EXTENSIONS: ${{ matrix.no-extensions }} run: >- # `exit 1` makes sure that the job remains red with flaky runs - pytest --no-cov --numprocesses=0 -vvvvv --lf && exit 1 + pytest --no-cov -vvvvv --lf && exit 1 shell: bash - name: Run dev_mode tests env: @@ -253,7 +254,7 @@ jobs: AIOHTTP_NO_EXTENSIONS: ${{ matrix.no-extensions }} PIP_USER: 1 PYTHONDEVMODE: 1 - run: pytest -m dev_mode --cov-append --numprocesses=0 + run: pytest -m dev_mode --cov=aiohttp/ --cov=tests/ --cov-append shell: bash - name: Turn coverage into xml env: @@ -345,7 +346,7 @@ jobs: PIP_USER: 1 run: >- PATH="${HOME}/Library/Python/3.11/bin:${HOME}/.local/bin:${PATH}" - pytest --junitxml=junit.xml --numprocesses=0 -m autobahn + pytest --junitxml=junit.xml --cov=aiohttp/ --cov=tests/ -m autobahn shell: bash - name: Turn coverage into xml env: @@ -413,7 +414,7 @@ jobs: uses: CodSpeedHQ/action@v4 with: mode: instrumentation - run: python -Im pytest --no-cov --numprocesses=0 -vvvvv --codspeed + run: python -Im pytest --no-cov -vvvvv --codspeed cython-coverage: @@ -462,7 +463,7 @@ jobs: PIP_USER: 1 run: >- pytest tests/test_client_functional.py tests/test_http_parser.py tests/test_http_writer.py tests/test_web_functional.py tests/test_web_response.py tests/test_websocket_parser.py - --cov-config=.coveragerc-cython.toml + --cov-config=.coveragerc-cython.toml --cov=aiohttp/ --cov=tests/ --numprocesses=auto -m 'not dev_mode and not autobahn' shell: bash - name: Turn coverage into xml diff --git a/CHANGES/10600.bugfix.rst b/CHANGES/10600.bugfix.rst new file mode 100644 index 00000000000..eba47bf56e6 --- /dev/null +++ b/CHANGES/10600.bugfix.rst @@ -0,0 +1,2 @@ +Fixed http parser not rejecting HTTP/1.1 requests that do not have valid Host header. +-- by :user:`Cycloctane`. diff --git a/CHANGES/12364.contrib.rst b/CHANGES/12364.contrib.rst new file mode 100644 index 00000000000..21b9eb1b271 --- /dev/null +++ b/CHANGES/12364.contrib.rst @@ -0,0 +1 @@ +Disabled ``coverage`` and ``xdist`` by default to ease local development -- by :user:`Dreamsorcerer`. diff --git a/aiohttp/_http_parser.pyx b/aiohttp/_http_parser.pyx index 719387493f5..ac942ec1076 100644 --- a/aiohttp/_http_parser.pyx +++ b/aiohttp/_http_parser.pyx @@ -457,6 +457,7 @@ cdef class HttpParser: cdef _on_headers_complete(self): self._process_header() + http_version = self.http_version() should_close = not cparser.llhttp_should_keep_alive(self._cparser) upgrade = self._cparser.upgrade chunked = self._cparser.flags & cparser.F_CHUNKED @@ -465,6 +466,8 @@ cdef class HttpParser: headers = CIMultiDictProxy(CIMultiDict(self._headers)) if self._cparser.type == cparser.HTTP_REQUEST: + if http_version == HttpVersion11 and hdrs.HOST not in headers: + raise BadHttpMessage("Missing 'Host' header in request.") h_upg = headers.get("upgrade", "") allowed = upgrade and h_upg.isascii() and h_upg.lower() in ALLOWED_UPGRADES if allowed or self._cparser.method == cparser.HTTP_CONNECT: @@ -488,11 +491,11 @@ cdef class HttpParser: method = http_method_str(self._cparser.method) msg = _new_request_message( method, self._path, - self.http_version(), headers, raw_headers, + http_version, headers, raw_headers, should_close, encoding, upgrade, chunked, self._url) else: msg = _new_response_message( - self.http_version(), self._cparser.status_code, self._reason, + http_version, self._cparser.status_code, self._reason, headers, raw_headers, should_close, encoding, upgrade, chunked) diff --git a/aiohttp/_websocket/writer.py b/aiohttp/_websocket/writer.py index df89aabbd5b..53ff9b246b3 100644 --- a/aiohttp/_websocket/writer.py +++ b/aiohttp/_websocket/writer.py @@ -9,6 +9,7 @@ from ..base_protocol import BaseProtocol from ..client_exceptions import ClientConnectionResetError from ..compression_utils import ZLibBackend, ZLibCompressor +from ..helpers import DEFAULT_CHUNK_SIZE from .helpers import ( MASK_LEN, MSG_SIZE, @@ -21,8 +22,6 @@ ) from .models import WS_DEFLATE_TRAILING, WSMsgType -DEFAULT_LIMIT: Final[int] = 2**18 - # WebSocket opcode boundary: opcodes 0-7 are data frames, 8-15 are control frames # Control frames (ping, pong, close) are never compressed WS_CONTROL_FRAME_OPCODE: Final[int] = 8 @@ -52,7 +51,7 @@ def __init__( transport: asyncio.Transport, *, use_mask: bool = False, - limit: int = DEFAULT_LIMIT, + limit: int = DEFAULT_CHUNK_SIZE, random: random.Random = random.Random(), compress: int = 0, notakeover: bool = False, diff --git a/aiohttp/client.py b/aiohttp/client.py index 9bd9af10bf2..f11a994be92 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -92,6 +92,7 @@ from .cookiejar import CookieJar from .helpers import ( _SENTINEL, + DEFAULT_CHUNK_SIZE, EMPTY_BODY_METHODS, BasicAuth, TimeoutHandle, @@ -331,7 +332,7 @@ def __init__( trust_env: bool = False, requote_redirect_url: bool = True, trace_configs: list[TraceConfig[object]] | None = None, - read_bufsize: int = 2**18, + read_bufsize: int = DEFAULT_CHUNK_SIZE, max_line_size: int = 8190, max_field_size: int = 8190, max_headers: int = 128, @@ -1226,7 +1227,7 @@ async def _ws_connect( transport = conn.transport assert transport is not None - reader = WebSocketDataQueue(conn_proto, 2**18, loop=self._loop) + reader = WebSocketDataQueue(conn_proto, DEFAULT_CHUNK_SIZE, loop=self._loop) writer = WebSocketWriter( conn_proto, transport, diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index 19bd8564ca6..fe604275959 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -13,6 +13,7 @@ ) from .helpers import ( _EXC_SENTINEL, + DEFAULT_CHUNK_SIZE, EMPTY_BODY_STATUS_CODES, BaseTimerContext, ErrorableProtocol, @@ -231,7 +232,7 @@ def set_response_params( read_until_eof: bool = False, auto_decompress: bool = True, read_timeout: float | None = None, - read_bufsize: int = 2**18, + read_bufsize: int = DEFAULT_CHUNK_SIZE, timeout_ceil_threshold: float = 5, max_line_size: int = 8190, max_field_size: int = 8190, diff --git a/aiohttp/compression_utils.py b/aiohttp/compression_utils.py index 5e12337113c..8c6dd5877b7 100644 --- a/aiohttp/compression_utils.py +++ b/aiohttp/compression_utils.py @@ -34,9 +34,6 @@ MAX_SYNC_CHUNK_SIZE = 4096 -# Matches the max size we receive from sockets: -# https://github.com/python/cpython/blob/1857a40807daeae3a1bf5efb682de9c9ae6df845/Lib/asyncio/selector_events.py#L766 -DEFAULT_MAX_DECOMPRESS_SIZE = 256 * 1024 # Unlimited decompression constants - different libraries use different conventions ZLIB_MAX_LENGTH_UNLIMITED = 0 # zlib uses 0 to mean unlimited diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index f50496e9321..9f27e3860a8 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -66,6 +66,10 @@ __all__ = ("BasicAuth", "ChainMapProxy", "ETag", "frozen_dataclass_decorator", "reify") +# This is the default size/limit for several operations. +# Matches the max size we receive from sockets: +# https://github.com/python/cpython/blob/1857a40807daeae3a1bf5efb682de9c9ae6df845/Lib/asyncio/selector_events.py#L766 +DEFAULT_CHUNK_SIZE = 2**18 # 256 KiB COOKIE_MAX_LENGTH = 4096 _T = TypeVar("_T") diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index 4601f201122..43485354bb5 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -22,7 +22,6 @@ from . import hdrs from .base_protocol import BaseProtocol from .compression_utils import ( - DEFAULT_MAX_DECOMPRESS_SIZE, HAS_BROTLI, HAS_ZSTD, BrotliDecompressor, @@ -32,6 +31,7 @@ from .helpers import ( _EXC_SENTINEL, DEBUG, + DEFAULT_CHUNK_SIZE, EMPTY_BODY_METHODS, EMPTY_BODY_STATUS_CODES, NO_EXTENSIONS, @@ -49,7 +49,7 @@ LineTooLong, TransferEncodingError, ) -from .http_writer import HttpVersion, HttpVersion10 +from .http_writer import HttpVersion, HttpVersion10, HttpVersion11 from .streams import EMPTY_PAYLOAD, StreamReader from .typedefs import RawHeaders @@ -672,6 +672,9 @@ def parse_message(self, lines: list[bytes]) -> RawRequestMessage: chunked, ) = self.parse_headers(lines[1:]) + if version_o == HttpVersion11 and hdrs.HOST not in headers: + raise BadHttpMessage("Missing 'Host' header in request.") + if close is None: # then the headers weren't set in the request if version_o <= HttpVersion10: # HTTP 1.0 must asks to not close close = True @@ -810,7 +813,7 @@ def __init__( max_line_size: int = 8190, max_field_size: int = 8190, max_trailers: int = 128, - limit: int = DEFAULT_MAX_DECOMPRESS_SIZE, + limit: int = DEFAULT_CHUNK_SIZE, ) -> None: self._length = 0 self._paused = False @@ -1061,7 +1064,7 @@ def __init__( self, out: StreamReader, encoding: str | None, - max_decompress_size: int = DEFAULT_MAX_DECOMPRESS_SIZE, + max_decompress_size: int = DEFAULT_CHUNK_SIZE, ) -> None: self.out = out self.size = 0 diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index 5bfce9e4074..9d5e5d27b84 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -14,11 +14,7 @@ from multidict import CIMultiDict, CIMultiDictProxy from .abc import AbstractStreamWriter -from .compression_utils import ( - DEFAULT_MAX_DECOMPRESS_SIZE, - ZLibCompressor, - ZLibDecompressor, -) +from .compression_utils import ZLibCompressor, ZLibDecompressor from .hdrs import ( CONTENT_DISPOSITION, CONTENT_ENCODING, @@ -26,7 +22,7 @@ CONTENT_TRANSFER_ENCODING, CONTENT_TYPE, ) -from .helpers import CHAR, TOKEN, parse_mimetype, reify +from .helpers import CHAR, DEFAULT_CHUNK_SIZE, TOKEN, parse_mimetype, reify from .http import HeadersParser from .http_exceptions import BadHttpMessage from .log import internal_logger @@ -267,7 +263,7 @@ def __init__( *, subtype: str = "mixed", default_charset: str | None = None, - max_decompress_size: int = DEFAULT_MAX_DECOMPRESS_SIZE, + max_decompress_size: int = DEFAULT_CHUNK_SIZE, client_max_size: int = sys.maxsize, max_size_error_cls: type[Exception] = ValueError, ) -> None: @@ -641,7 +637,7 @@ async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> byt async def write(self, writer: AbstractStreamWriter) -> None: field = self._value - while chunk := await field.read_chunk(size=2**18): + while chunk := await field.read_chunk(size=DEFAULT_CHUNK_SIZE): async for d in field.decode_iter(chunk): await writer.write(d) diff --git a/aiohttp/payload.py b/aiohttp/payload.py index 71c015499a6..b435a7ff146 100644 --- a/aiohttp/payload.py +++ b/aiohttp/payload.py @@ -17,6 +17,7 @@ from .abc import AbstractStreamWriter from .helpers import ( _SENTINEL, + DEFAULT_CHUNK_SIZE, content_disposition_header, guess_filename, parse_mimetype, @@ -43,7 +44,6 @@ ) TOO_LARGE_BYTES_BODY: Final[int] = 2**20 # 1 MB -READ_SIZE: Final[int] = 2**18 # 256 KiB _CLOSE_FUTURES: set[asyncio.Future[None]] = set() @@ -489,7 +489,7 @@ def _read_and_available_len( Args: remaining_content_len: Optional limit on how many bytes to read in this operation. - If None, READ_SIZE will be used as the default chunk size. + If None, DEFAULT_CHUNK_SIZE will be used as the default chunk size. Returns: A tuple containing: @@ -504,7 +504,11 @@ def _read_and_available_len( self._set_or_restore_start_position() size = self.size # Call size only once since it does I/O return size, self._value.read( - min(READ_SIZE, size or READ_SIZE, remaining_content_len or READ_SIZE) + min( + DEFAULT_CHUNK_SIZE, + size or DEFAULT_CHUNK_SIZE, + remaining_content_len or DEFAULT_CHUNK_SIZE, + ) ) def _read(self, remaining_content_len: int | None) -> bytes: @@ -513,7 +517,7 @@ def _read(self, remaining_content_len: int | None) -> bytes: Args: remaining_content_len: Optional maximum number of bytes to read. - If None, READ_SIZE will be used as the default chunk size. + If None, DEFAULT_CHUNK_SIZE will be used as the default chunk size. Returns: A chunk of bytes read from the file object, respecting the @@ -523,7 +527,7 @@ def _read(self, remaining_content_len: int | None) -> bytes: the initial _read_and_available_len call has been made. """ - return self._value.read(remaining_content_len or READ_SIZE) # type: ignore[no-any-return] + return self._value.read(remaining_content_len or DEFAULT_CHUNK_SIZE) # type: ignore[no-any-return] @property def size(self) -> int | None: @@ -626,9 +630,9 @@ async def write_with_length( None, self._read, ( - min(READ_SIZE, remaining_content_len) + min(DEFAULT_CHUNK_SIZE, remaining_content_len) if remaining_content_len is not None - else READ_SIZE + else DEFAULT_CHUNK_SIZE ), ) @@ -753,7 +757,7 @@ def _read_and_available_len( Args: remaining_content_len: Optional limit on how many bytes to read in this operation. - If None, READ_SIZE will be used as the default chunk size. + If None, DEFAULT_CHUNK_SIZE will be used as the default chunk size. Returns: A tuple containing: @@ -772,7 +776,11 @@ def _read_and_available_len( self._set_or_restore_start_position() size = self.size chunk = self._value.read( - min(READ_SIZE, size or READ_SIZE, remaining_content_len or READ_SIZE) + min( + DEFAULT_CHUNK_SIZE, + size or DEFAULT_CHUNK_SIZE, + remaining_content_len or DEFAULT_CHUNK_SIZE, + ) ) return size, chunk.encode(self._encoding) if self._encoding else chunk.encode() @@ -782,7 +790,7 @@ def _read(self, remaining_content_len: int | None) -> bytes: Args: remaining_content_len: Optional maximum number of bytes to read. - If None, READ_SIZE will be used as the default chunk size. + If None, DEFAULT_CHUNK_SIZE will be used as the default chunk size. Returns: A chunk of bytes read from the file object and encoded using the payload's @@ -794,7 +802,7 @@ def _read(self, remaining_content_len: int | None) -> bytes: the specified encoding (or UTF-8 if none was provided). """ - chunk = self._value.read(remaining_content_len or READ_SIZE) + chunk = self._value.read(remaining_content_len or DEFAULT_CHUNK_SIZE) return chunk.encode(self._encoding) if self._encoding else chunk.encode() def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: @@ -878,7 +886,7 @@ async def write_with_length( self._set_or_restore_start_position() loop_count = 0 remaining_bytes = content_length - while chunk := self._value.read(READ_SIZE): + while chunk := self._value.read(DEFAULT_CHUNK_SIZE): if loop_count > 0: # Avoid blocking the event loop # if they pass a large BytesIO object diff --git a/aiohttp/streams.py b/aiohttp/streams.py index b9367066291..219d6b7c535 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -7,6 +7,7 @@ from .base_protocol import BaseProtocol from .helpers import ( _EXC_SENTINEL, + DEFAULT_CHUNK_SIZE, BaseTimerContext, TimerNoop, set_exception, @@ -165,7 +166,7 @@ def __repr__(self) -> str: info.append("%d bytes" % self._size) if self._eof: info.append("eof") - if self._low_water != 2**18: # default limit + if self._low_water != DEFAULT_CHUNK_SIZE: info.append("low=%d high=%d" % (self._low_water, self._high_water)) if self._waiter: info.append("w=%r" % self._waiter) diff --git a/aiohttp/web_fileresponse.py b/aiohttp/web_fileresponse.py index f339bec9662..1b4495ea4e1 100644 --- a/aiohttp/web_fileresponse.py +++ b/aiohttp/web_fileresponse.py @@ -13,7 +13,7 @@ from . import hdrs from .abc import AbstractStreamWriter -from .helpers import ETAG_ANY, ETag, must_be_empty_body +from .helpers import DEFAULT_CHUNK_SIZE, ETAG_ANY, ETag, must_be_empty_body from .typedefs import LooseHeaders, PathLike from .web_exceptions import ( HTTPForbidden, @@ -82,7 +82,7 @@ class FileResponse(StreamResponse): def __init__( self, path: PathLike, - chunk_size: int = 256 * 1024, + chunk_size: int = DEFAULT_CHUNK_SIZE, status: int = 200, reason: str | None = None, headers: LooseHeaders | None = None, diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 9785c13fa4f..0245d7b776c 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -15,7 +15,7 @@ from .abc import AbstractAccessLogger, AbstractAsyncAccessLogger, AbstractStreamWriter from .base_protocol import BaseProtocol -from .helpers import ceil_timeout, frozen_dataclass_decorator +from .helpers import DEFAULT_CHUNK_SIZE, ceil_timeout, frozen_dataclass_decorator from .http import ( HttpProcessingError, HttpRequestParser, @@ -202,7 +202,7 @@ def __init__( max_headers: int = 128, max_field_size: int = 8190, lingering_time: float = 10.0, - read_bufsize: int = 2**18, + read_bufsize: int = DEFAULT_CHUNK_SIZE, auto_decompress: bool = True, timeout_ceil_threshold: float = 5, ): diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index b8feae19cec..5871390529c 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -21,6 +21,7 @@ from .abc import AbstractStreamWriter from .helpers import ( _SENTINEL, + DEFAULT_CHUNK_SIZE, ETAG_ANY, LIST_QUOTED_ETAG_RE, ChainMapProxy, @@ -719,7 +720,7 @@ async def post(self) -> "MultiDictProxy[str | bytes | FileField]": tmp = await self._loop.run_in_executor( None, tempfile.TemporaryFile ) - while chunk := await field.read_chunk(size=2**18): + while chunk := await field.read_chunk(size=DEFAULT_CHUNK_SIZE): async for decoded_chunk in field.decode_iter(chunk): await self._loop.run_in_executor( None, tmp.write, decoded_chunk diff --git a/aiohttp/web_urldispatcher.py b/aiohttp/web_urldispatcher.py index b1d060bd0f3..19913f6806e 100644 --- a/aiohttp/web_urldispatcher.py +++ b/aiohttp/web_urldispatcher.py @@ -29,7 +29,7 @@ from . import hdrs from .abc import AbstractMatchInfo, AbstractRouter, AbstractView -from .helpers import DEBUG +from .helpers import DEBUG, DEFAULT_CHUNK_SIZE from .http import HttpVersion11 from .typedefs import Handler, PathLike from .web_exceptions import ( @@ -507,7 +507,7 @@ def __init__( *, name: str | None = None, expect_handler: _ExpectHandler | None = None, - chunk_size: int = 256 * 1024, + chunk_size: int = DEFAULT_CHUNK_SIZE, show_index: bool = False, follow_symlinks: bool = False, append_version: bool = False, @@ -1133,7 +1133,7 @@ def add_static( *, name: str | None = None, expect_handler: _ExpectHandler | None = None, - chunk_size: int = 256 * 1024, + chunk_size: int = DEFAULT_CHUNK_SIZE, show_index: bool = False, follow_symlinks: bool = False, append_version: bool = False, diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 1a7622b8421..9ec478e46f5 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -11,10 +11,10 @@ from . import hdrs from ._websocket.reader import WebSocketDataQueue -from ._websocket.writer import DEFAULT_LIMIT from .abc import AbstractStreamWriter from .client_exceptions import WSMessageTypeError from .helpers import ( + DEFAULT_CHUNK_SIZE, calculate_timeout_when, frozen_dataclass_decorator, set_exception, @@ -107,7 +107,7 @@ def __init__( protocols: Iterable[str] = (), compress: bool = True, max_msg_size: int = 4 * 1024 * 1024, - writer_limit: int = DEFAULT_LIMIT, + writer_limit: int = DEFAULT_CHUNK_SIZE, decode_text: bool = True, ) -> None: super().__init__(status=101) @@ -383,7 +383,9 @@ def _post_start( loop = self._loop assert loop is not None - self._reader = WebSocketDataQueue(request._protocol, 2**18, loop=loop) + self._reader = WebSocketDataQueue( + request._protocol, DEFAULT_CHUNK_SIZE, loop=loop + ) parser = WebSocketReader( self._reader, self._max_msg_size, diff --git a/docs/client_quickstart.rst b/docs/client_quickstart.rst index 25d0c7cc6a6..d80f0393d8b 100644 --- a/docs/client_quickstart.rst +++ b/docs/client_quickstart.rst @@ -122,8 +122,9 @@ that case you can specify multiple values for each key:: expect = 'http://httpbin.org/get?key=value2&key=value1' assert str(r.url) == expect -You can also pass :class:`str` content as param, but beware -- content -is not encoded by library. Note that ``+`` is not encoded:: +You can also pass :class:`str` content as param. The value is used as a +query string, but passing ``params`` does not disable URL +canonicalization. Note that ``+`` is not encoded:: async with session.get('http://httpbin.org/get', params='key=value+1') as r: @@ -149,7 +150,9 @@ is not encoded by library. Note that ``+`` is not encoded:: .. warning:: - Passing *params* overrides ``encoded=True``, never use both options. + Passing *params* overrides ``encoded=True``. Never use both options + if you need to preserve exact query-string bytes. + Build the full URL (including query) instead. Response Content and Status Code ================================ diff --git a/requirements/constraints.txt b/requirements/constraints.txt index 48487702bba..adbe2a66f7f 100644 --- a/requirements/constraints.txt +++ b/requirements/constraints.txt @@ -175,7 +175,7 @@ pyproject-hooks==1.2.0 # via # build # pip-tools -pytest==9.0.2 +pytest==9.0.3 # via # -r requirements/lint.in # -r requirements/test-common.in diff --git a/requirements/dev.txt b/requirements/dev.txt index 459b617147b..aada534b624 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -170,7 +170,7 @@ pyproject-hooks==1.2.0 # via # build # pip-tools -pytest==9.0.2 +pytest==9.0.3 # via # -r requirements/lint.in # -r requirements/test-common.in diff --git a/requirements/lint.txt b/requirements/lint.txt index 8cfa2d2c750..3db11d0a1cc 100644 --- a/requirements/lint.txt +++ b/requirements/lint.txt @@ -83,7 +83,7 @@ pygments==2.20.0 # via # pytest # rich -pytest==9.0.2 +pytest==9.0.3 # via # -r requirements/lint.in # pytest-codspeed diff --git a/requirements/test-common.txt b/requirements/test-common.txt index c0092cff209..e74c0000f40 100644 --- a/requirements/test-common.txt +++ b/requirements/test-common.txt @@ -66,7 +66,7 @@ pygments==2.20.0 # via # pytest # rich -pytest==9.0.2 +pytest==9.0.3 # via # -r requirements/test-common.in # pytest-codspeed diff --git a/requirements/test-ft.txt b/requirements/test-ft.txt index ba5272d77e1..1c0a4482145 100644 --- a/requirements/test-ft.txt +++ b/requirements/test-ft.txt @@ -99,7 +99,7 @@ pygments==2.20.0 # via # pytest # rich -pytest==9.0.2 +pytest==9.0.3 # via # -r requirements/test-common.in # pytest-codspeed diff --git a/requirements/test.txt b/requirements/test.txt index 232e50ac2a9..9d9c38cdbb2 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -99,7 +99,7 @@ pygments==2.20.0 # via # pytest # rich -pytest==9.0.2 +pytest==9.0.3 # via # -r requirements/test-common.in # pytest-codspeed diff --git a/setup.cfg b/setup.cfg index 203c01c3754..e84f57107b3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -41,9 +41,6 @@ exclude_lines = [tool:pytest] addopts = - # `pytest-xdist`: - --numprocesses=auto - # show 10 slowest invocations: --durations=10 @@ -56,11 +53,6 @@ addopts = # show values of the local vars in errors: --showlocals - # `pytest-cov`: - -p pytest_cov - --cov=aiohttp - --cov=tests/ - -m "not dev_mode and not autobahn and not internal" filterwarnings = error diff --git a/tests/test_benchmarks_http_websocket.py b/tests/test_benchmarks_http_websocket.py index 10115c1a2bd..61b23125460 100644 --- a/tests/test_benchmarks_http_websocket.py +++ b/tests/test_benchmarks_http_websocket.py @@ -8,6 +8,7 @@ from aiohttp._websocket.helpers import MSG_SIZE, PACK_LEN3 from aiohttp._websocket.reader import WebSocketDataQueue from aiohttp.base_protocol import BaseProtocol +from aiohttp.helpers import DEFAULT_CHUNK_SIZE from aiohttp.http_websocket import WebSocketReader, WebSocketWriter, WSMsgType @@ -15,8 +16,8 @@ def test_read_large_binary_websocket_messages( loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture ) -> None: """Read one hundred large binary websocket messages.""" - queue = WebSocketDataQueue(BaseProtocol(loop), 2**18, loop=loop) - reader = WebSocketReader(queue, max_msg_size=2**18) + queue = WebSocketDataQueue(BaseProtocol(loop), DEFAULT_CHUNK_SIZE, loop=loop) + reader = WebSocketReader(queue, max_msg_size=DEFAULT_CHUNK_SIZE) # PACK3 has a minimum message length of 2**16 bytes. message = b"x" * ((2**16) + 1) @@ -36,8 +37,8 @@ def test_read_one_hundred_websocket_text_messages( loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture ) -> None: """Benchmark reading 100 WebSocket text messages.""" - queue = WebSocketDataQueue(BaseProtocol(loop), 2**18, loop=loop) - reader = WebSocketReader(queue, max_msg_size=2**18) + queue = WebSocketDataQueue(BaseProtocol(loop), DEFAULT_CHUNK_SIZE, loop=loop) + reader = WebSocketReader(queue, max_msg_size=DEFAULT_CHUNK_SIZE) raw_message = ( b'\x81~\x01!{"id":1,"src":"shellyplugus-c049ef8c30e4","dst":"aios-1453812500' b'8","result":{"name":null,"id":"shellyplugus-c049ef8c30e4","mac":"C049EF8C30E' diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 80e95c29512..51dfc6d44c9 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -52,7 +52,7 @@ TooManyRedirects, ) from aiohttp.client_reqrep import ClientRequest -from aiohttp.compression_utils import DEFAULT_MAX_DECOMPRESS_SIZE +from aiohttp.helpers import DEFAULT_CHUNK_SIZE from aiohttp.payload import ( AsyncIterablePayload, BufferedReaderPayload, @@ -2410,7 +2410,7 @@ async def test_payload_decompress_size_limit(aiohttp_client: AiohttpClient) -> N payload_size = 64 * 2**20 original = b"A" * payload_size compressed = zlib.compress(original) - assert len(original) > DEFAULT_MAX_DECOMPRESS_SIZE + assert len(original) > DEFAULT_CHUNK_SIZE async def handler(request: web.Request) -> web.Response: # Send compressed data with Content-Encoding header @@ -2442,7 +2442,7 @@ async def test_payload_decompress_size_limit_brotli( payload_size = 64 * 2**20 original = b"A" * payload_size compressed = brotli.compress(original) - assert len(original) > DEFAULT_MAX_DECOMPRESS_SIZE + assert len(original) > DEFAULT_CHUNK_SIZE async def handler(request: web.Request) -> web.Response: resp = web.Response(body=compressed) @@ -2474,7 +2474,7 @@ async def test_payload_decompress_size_limit_zstd( original = b"A" * payload_size compressor = ZstdCompressor() compressed = compressor.compress(original) + compressor.flush() - assert len(original) > DEFAULT_MAX_DECOMPRESS_SIZE + assert len(original) > DEFAULT_CHUNK_SIZE async def handler(request: web.Request) -> web.Response: resp = web.Response(body=compressed) diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index 35fce70ef5f..0e66b18c625 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -19,7 +19,7 @@ from aiohttp import http_exceptions, streams from aiohttp.base_protocol import BaseProtocol from aiohttp.client_proto import ResponseHandler -from aiohttp.helpers import NO_EXTENSIONS +from aiohttp.helpers import DEFAULT_CHUNK_SIZE, NO_EXTENSIONS from aiohttp.http_parser import ( DeflateBuffer, HeadersParser, @@ -100,7 +100,7 @@ def parser( parser = request.param( protocol, loop, - 2**18, + DEFAULT_CHUNK_SIZE, max_line_size=8190, max_headers=128, max_field_size=8190, @@ -128,7 +128,7 @@ def response( parser = request.param( protocol, loop, - 2**18, + DEFAULT_CHUNK_SIZE, max_line_size=8190, max_headers=128, max_field_size=8190, @@ -154,6 +154,7 @@ def test_c_parser_loaded() -> None: def test_parse_headers(parser: HttpRequestParser) -> None: text = b"""GET /test HTTP/1.1\r +Host: a\r test: a line\r test2: data\r \r @@ -162,8 +163,16 @@ def test_parse_headers(parser: HttpRequestParser) -> None: assert len(messages) == 1 msg = messages[0][0] - assert list(msg.headers.items()) == [("test", "a line"), ("test2", "data")] - assert msg.raw_headers == ((b"test", b"a line"), (b"test2", b"data")) + assert list(msg.headers.items()) == [ + ("Host", "a"), + ("test", "a line"), + ("test2", "data"), + ] + assert msg.raw_headers == ( + (b"Host", b"a"), + (b"test", b"a line"), + (b"test2", b"data"), + ) assert not msg.should_close assert msg.compression is None assert not msg.upgrade @@ -171,6 +180,7 @@ def test_parse_headers(parser: HttpRequestParser) -> None: def test_reject_obsolete_line_folding(parser: HttpRequestParser) -> None: text = b"""GET /test HTTP/1.1\r +Host: a\r test: line\r Content-Length: 48\r test2: data\r @@ -243,7 +253,7 @@ def test_cve_2023_37276(parser: HttpRequestParser) -> None: def test_bad_header_name( parser: HttpRequestParser, rfc9110_5_6_2_token_delim: str ) -> None: - text = f"POST / HTTP/1.1\r\nhead{rfc9110_5_6_2_token_delim}er: val\r\n\r\n".encode() + text = f"POST / HTTP/1.1\r\nHost: a\r\nhead{rfc9110_5_6_2_token_delim}er: val\r\n\r\n".encode() if rfc9110_5_6_2_token_delim == ":": # Inserting colon into header just splits name/value earlier. parser.feed_data(text) @@ -272,7 +282,7 @@ def test_bad_header_name( ), ) def test_bad_headers(parser: HttpRequestParser, hdr: str) -> None: - text = f"POST / HTTP/1.1\r\n{hdr}\r\n\r\n".encode() + text = f"POST / HTTP/1.1\r\nHost: a\r\n{hdr}\r\n\r\n".encode() with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) @@ -297,7 +307,7 @@ def test_unpaired_surrogate_in_header_py( max_field_size=8190, ) protocol._parser = parser - text = b"POST / HTTP/1.1\r\n\xff\r\n\r\n" + text = b"POST / HTTP/1.1\r\nHost: a\r\n\xff\r\n\r\n" message = None try: parser.feed_data(text) @@ -400,6 +410,12 @@ def test_duplicate_host_header_rejected(parser: HttpRequestParser) -> None: parser.feed_data(text) +def test_missing_host_header_rejected(parser: HttpRequestParser) -> None: + text = b"GET /admin HTTP/1.1\r\n\r\n" + with pytest.raises(http_exceptions.BadHttpMessage, match="Missing 'Host' header"): + parser.feed_data(text) + + @pytest.mark.parametrize( ("hdr1", "hdr2"), ( @@ -450,7 +466,7 @@ def test_bad_chunked(parser: HttpRequestParser) -> None: def test_whitespace_before_header(parser: HttpRequestParser) -> None: - text = b"GET / HTTP/1.1\r\n\tContent-Length: 1\r\n\r\nX" + text = b"GET / HTTP/1.1\r\nHost: a\r\n\tContent-Length: 1\r\n\r\nX" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) @@ -481,7 +497,7 @@ def test_parse_unusual_request_line(parser: HttpRequestParser) -> None: def test_parse(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) assert len(messages) == 1 msg, _ = messages[0] @@ -490,10 +506,11 @@ def test_parse(parser: HttpRequestParser) -> None: assert msg.method == "GET" assert msg.path == "/test" assert msg.version == (1, 1) + assert msg.headers["Host"] == "a" async def test_parse_body(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody" + text = b"GET /test HTTP/1.1\r\nHost: a\r\nContent-Length: 4\r\n\r\nbody" messages, upgrade, tail = parser.feed_data(text) assert len(messages) == 1 _, payload = messages[0] @@ -502,7 +519,7 @@ async def test_parse_body(parser: HttpRequestParser) -> None: async def test_parse_body_with_CRLF(parser: HttpRequestParser) -> None: - text = b"\r\nGET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody" + text = b"\r\nGET /test HTTP/1.1\r\nHost: a\r\nContent-Length: 4\r\n\r\nbody" messages, upgrade, tail = parser.feed_data(text) assert len(messages) == 1 _, payload = messages[0] @@ -511,7 +528,7 @@ async def test_parse_body_with_CRLF(parser: HttpRequestParser) -> None: def test_parse_delayed(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\n" messages, upgrade, tail = parser.feed_data(text) assert len(messages) == 0 assert not upgrade @@ -524,8 +541,9 @@ def test_parse_delayed(parser: HttpRequestParser) -> None: def test_headers_multi_feed(parser: HttpRequestParser) -> None: text1 = b"GET /test HTTP/1.1\r\n" - text2 = b"test: line" - text3 = b" continue\r\n\r\n" + text2 = b"Host: a\r\n" + text3 = b"test: line" + text4 = b" continue\r\n\r\n" messages, upgrade, tail = parser.feed_data(text1) assert len(messages) == 0 @@ -534,18 +552,21 @@ def test_headers_multi_feed(parser: HttpRequestParser) -> None: assert len(messages) == 0 messages, upgrade, tail = parser.feed_data(text3) + assert len(messages) == 0 + + messages, upgrade, tail = parser.feed_data(text4) assert len(messages) == 1 msg = messages[0][0] - assert list(msg.headers.items()) == [("test", "line continue")] - assert msg.raw_headers == ((b"test", b"line continue"),) + assert list(msg.headers.items()) == [("Host", "a"), ("test", "line continue")] + assert msg.raw_headers == ((b"Host", b"a"), (b"test", b"line continue")) assert not msg.should_close assert msg.compression is None assert not msg.upgrade def test_headers_split_field(parser: HttpRequestParser) -> None: - text1 = b"GET /test HTTP/1.1\r\n" + text1 = b"GET /test HTTP/1.1\r\nHost: a\r\n" text2 = b"t" text3 = b"es" text4 = b"t: value\r\n\r\n" @@ -558,8 +579,8 @@ def test_headers_split_field(parser: HttpRequestParser) -> None: assert len(messages) == 1 msg = messages[0][0] - assert list(msg.headers.items()) == [("test", "value")] - assert msg.raw_headers == ((b"test", b"value"),) + assert list(msg.headers.items()) == [("Host", "a"), ("test", "value")] + assert msg.raw_headers == ((b"Host", b"a"), (b"test", b"value")) assert not msg.should_close assert msg.compression is None assert not msg.upgrade @@ -567,7 +588,7 @@ def test_headers_split_field(parser: HttpRequestParser) -> None: def test_parse_headers_multi(parser: HttpRequestParser) -> None: text = ( - b"GET /test HTTP/1.1\r\n" + b"GET /test HTTP/1.1\r\nHost: a\r\n" b"Set-Cookie: c1=cookie1\r\n" b"Set-Cookie: c2=cookie2\r\n\r\n" ) @@ -577,10 +598,12 @@ def test_parse_headers_multi(parser: HttpRequestParser) -> None: msg = messages[0][0] assert list(msg.headers.items()) == [ + ("Host", "a"), ("Set-Cookie", "c1=cookie1"), ("Set-Cookie", "c2=cookie2"), ] assert msg.raw_headers == ( + (b"Host", b"a"), (b"Set-Cookie", b"c1=cookie1"), (b"Set-Cookie", b"c2=cookie2"), ) @@ -596,14 +619,14 @@ def test_conn_default_1_0(parser: HttpRequestParser) -> None: def test_conn_default_1_1(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.should_close def test_conn_close(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\nconnection: close\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\nconnection: close\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.should_close @@ -624,14 +647,14 @@ def test_conn_keep_alive_1_0(parser: HttpRequestParser) -> None: def test_conn_keep_alive_1_1(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\nconnection: keep-alive\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\nconnection: keep-alive\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.should_close def test_conn_close_comma_list(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\nconnection: close, keep-alive\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\nconnection: close, keep-alive\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.should_close @@ -640,6 +663,7 @@ def test_conn_close_comma_list(parser: HttpRequestParser) -> None: def test_conn_close_multiple_headers(parser: HttpRequestParser) -> None: text = ( b"GET /test HTTP/1.1\r\n" + b"Host: a\r\n" b"connection: keep-alive\r\n" b"connection: close\r\n\r\n" ) @@ -656,14 +680,14 @@ def test_conn_other_1_0(parser: HttpRequestParser) -> None: def test_conn_other_1_1(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\nconnection: test\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\nconnection: test\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.should_close def test_request_chunked(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\ntransfer-encoding: chunked\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg, payload = messages[0] assert msg.chunked @@ -673,14 +697,14 @@ def test_request_chunked(parser: HttpRequestParser) -> None: def test_te_header_non_ascii(parser: HttpRequestParser) -> None: # K = Kelvin sign, not valid ascii. - text = "GET /test HTTP/1.1\r\nTransfer-Encoding: chunKed\r\n\r\n" + text = "GET /test HTTP/1.1\r\nHost: a\r\nTransfer-Encoding: chunKed\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text.encode()) def test_upgrade_header_non_ascii(parser: HttpRequestParser) -> None: # K = Kelvin sign, not valid ascii. - text = "GET /test HTTP/1.1\r\nUpgrade: websocKet\r\n\r\n" + text = "GET /test HTTP/1.1\r\nHost: a\r\nUpgrade: websocKet\r\n\r\n" messages, upgrade, tail = parser.feed_data(text.encode()) assert not upgrade @@ -688,6 +712,7 @@ def test_upgrade_header_non_ascii(parser: HttpRequestParser) -> None: def test_request_te_chunked_with_content_length(parser: HttpRequestParser) -> None: text = ( b"GET /test HTTP/1.1\r\n" + b"Host: a\r\n" b"content-length: 1234\r\n" b"transfer-encoding: chunked\r\n\r\n" ) @@ -699,7 +724,7 @@ def test_request_te_chunked_with_content_length(parser: HttpRequestParser) -> No def test_request_te_chunked123(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\ntransfer-encoding: chunked123\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\ntransfer-encoding: chunked123\r\n\r\n" with pytest.raises( http_exceptions.BadHttpMessage, match="Request has invalid `Transfer-Encoding`", @@ -708,14 +733,14 @@ def test_request_te_chunked123(parser: HttpRequestParser) -> None: async def test_request_te_last_chunked(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\nTransfer-Encoding: not, chunked\r\n\r\n1\r\nT\r\n3\r\nest\r\n0\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\nTransfer-Encoding: not, chunked\r\n\r\n1\r\nT\r\n3\r\nest\r\n0\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) # https://www.rfc-editor.org/rfc/rfc9112#section-6.3-2.4.3 assert await messages[0][1].read() == b"Test" def test_request_te_first_chunked(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\nTransfer-Encoding: chunked, not\r\n\r\n1\r\nT\r\n3\r\nest\r\n0\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\nTransfer-Encoding: chunked, not\r\n\r\n1\r\nT\r\n3\r\nest\r\n0\r\n\r\n" # https://www.rfc-editor.org/rfc/rfc9112#section-6.3-2.4.3 with pytest.raises( http_exceptions.BadHttpMessage, @@ -738,6 +763,7 @@ def test_request_te_duplicate_chunked(parser: HttpRequestParser) -> None: def test_conn_upgrade(parser: HttpRequestParser) -> None: text = ( b"GET /test HTTP/1.1\r\n" + b"Host: a\r\n" b"connection: upgrade\r\n" b"upgrade: websocket\r\n\r\n" ) @@ -751,6 +777,7 @@ def test_conn_upgrade(parser: HttpRequestParser) -> None: def test_conn_upgrade_comma_list(parser: HttpRequestParser) -> None: text = ( b"GET /test HTTP/1.1\r\n" + b"host: a\r\n" b"connection: keep-alive, upgrade\r\n" b"upgrade: websocket\r\n\r\n" ) @@ -764,6 +791,7 @@ def test_conn_upgrade_comma_list(parser: HttpRequestParser) -> None: def test_conn_upgrade_multiple_headers(parser: HttpRequestParser) -> None: text = ( b"GET /test HTTP/1.1\r\n" + b"host: a\r\n" b"connection: keep-alive\r\n" b"connection: upgrade\r\n" b"upgrade: websocket\r\n\r\n" @@ -777,7 +805,7 @@ def test_conn_upgrade_multiple_headers(parser: HttpRequestParser) -> None: def test_bad_upgrade(parser: HttpRequestParser) -> None: """Test not upgraded if missing Upgrade header.""" - text = b"GET /test HTTP/1.1\r\nconnection: upgrade\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\nconnection: upgrade\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.upgrade @@ -785,21 +813,21 @@ def test_bad_upgrade(parser: HttpRequestParser) -> None: def test_compression_empty(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\ncontent-encoding: \r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\ncontent-encoding: \r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.compression is None def test_compression_deflate(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\ncontent-encoding: deflate\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\ncontent-encoding: deflate\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.compression == "deflate" def test_compression_gzip(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\ncontent-encoding: gzip\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\ncontent-encoding: gzip\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.compression == "gzip" @@ -807,7 +835,7 @@ def test_compression_gzip(parser: HttpRequestParser) -> None: @pytest.mark.skipif(brotli is None, reason="brotli is not installed") def test_compression_brotli(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\ncontent-encoding: br\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\ncontent-encoding: br\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.compression == "br" @@ -815,7 +843,7 @@ def test_compression_brotli(parser: HttpRequestParser) -> None: @pytest.mark.skipif(zstandard is None, reason="zstandard is not installed") def test_compression_zstd(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\ncontent-encoding: zstd\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\ncontent-encoding: zstd\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.compression == "zstd" @@ -829,7 +857,7 @@ def test_compression_zstd(parser: HttpRequestParser) -> None: ), ) def test_compression_non_ascii(parser: HttpRequestParser, enc: bytes) -> None: - text = b"GET /test HTTP/1.1\r\ncontent-encoding: " + enc + b"\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\ncontent-encoding: " + enc + b"\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] # Non-ascii input should not evaluate to a valid encoding scheme. @@ -837,14 +865,14 @@ def test_compression_non_ascii(parser: HttpRequestParser, enc: bytes) -> None: def test_compression_unknown(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\ncontent-encoding: compress\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\ncontent-encoding: compress\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.compression is None def test_url_connect(parser: HttpRequestParser) -> None: - text = b"CONNECT www.google.com HTTP/1.1\r\ncontent-length: 0\r\n\r\n" + text = b"CONNECT www.google.com HTTP/1.1\r\nHost: a\r\ncontent-length: 0\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg, payload = messages[0] assert upgrade @@ -852,7 +880,7 @@ def test_url_connect(parser: HttpRequestParser) -> None: def test_headers_connect(parser: HttpRequestParser) -> None: - text = b"CONNECT www.google.com HTTP/1.1\r\ncontent-length: 0\r\n\r\n" + text = b"CONNECT www.google.com HTTP/1.1\r\nHost: a\r\ncontent-length: 0\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg, payload = messages[0] assert upgrade @@ -862,6 +890,7 @@ def test_headers_connect(parser: HttpRequestParser) -> None: def test_url_absolute(parser: HttpRequestParser) -> None: text = ( b"GET https://www.google.com/path/to.html HTTP/1.1\r\n" + b"Host: a\r\n" b"content-length: 0\r\n\r\n" ) messages, upgrade, tail = parser.feed_data(text) @@ -872,21 +901,21 @@ def test_url_absolute(parser: HttpRequestParser) -> None: def test_headers_old_websocket_key1(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\nSEC-WEBSOCKET-KEY1: line\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\nSEC-WEBSOCKET-KEY1: line\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) def test_headers_content_length_err_1(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\ncontent-length: line\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\ncontent-length: line\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) def test_headers_content_length_err_2(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\ncontent-length: -1\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\ncontent-length: -1\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) @@ -911,7 +940,7 @@ def test_headers_content_length_err_2(parser: HttpRequestParser) -> None: def test_invalid_header_spacing( parser: HttpRequestParser, pad1: bytes, pad2: bytes, hdr: bytes ) -> None: - text = b"GET /test HTTP/1.1\r\n%s%s%s: value\r\n\r\n" % (pad1, hdr, pad2) + text = b"GET /test HTTP/1.1\r\nHost: a\r\n%s%s%s: value\r\n\r\n" % (pad1, hdr, pad2) if pad1 == pad2 == b"" and hdr != b"": # one entry in param matrix is correct: non-empty name, not padded parser.feed_data(text) @@ -922,19 +951,19 @@ def test_invalid_header_spacing( def test_empty_header_name(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\n:test\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\n:test\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) def test_invalid_header(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\ntest line\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\ntest line\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) def test_invalid_name(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\ntest[]: line\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\ntest[]: line\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) @@ -953,15 +982,15 @@ def test_max_header_field_size(parser: HttpRequestParser, size: int) -> None: def test_max_header_size_under_limit(parser: HttpRequestParser) -> None: name = b"t" * 8185 - text = b"GET /test HTTP/1.1\r\n" + name + b":data\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\n" + name + b":data\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.method == "GET" assert msg.path == "/test" assert msg.version == (1, 1) - assert msg.headers == CIMultiDict({name.decode(): "data"}) - assert msg.raw_headers == ((name, b"data"),) + assert msg.headers == CIMultiDict([("Host", "a"), (name.decode(), "data")]) + assert msg.raw_headers == ((b"Host", b"a"), (name, b"data")) assert not msg.should_close assert msg.compression is None assert not msg.upgrade @@ -993,7 +1022,7 @@ def test_max_header_combined_size(parser: HttpRequestParser) -> None: async def test_max_trailer_size(parser: HttpRequestParser, size: int) -> None: value = b"t" * size text = ( - b"GET /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + b"GET /test HTTP/1.1\r\nHost: a\r\nTransfer-Encoding: chunked\r\n\r\n" + hex(4000)[2:].encode() + b"\r\n" + b"b" * 4000 @@ -1019,7 +1048,7 @@ async def test_max_headers( parser: HttpRequestParser, headers: int, trailers: int ) -> None: text = ( - b"GET /test HTTP/1.1\r\nTransfer-Encoding: chunked" + b"GET /test HTTP/1.1\r\nHost: a\r\nTransfer-Encoding: chunked" + b"".join(b"\r\nHeader-%d: Value" % i for i in range(headers)) + b"\r\n\r\n4\r\ntest\r\n0" + b"".join(b"\r\nTrailer-%d: Value" % i for i in range(trailers)) @@ -1035,15 +1064,15 @@ async def test_max_headers( def test_max_header_value_size_under_limit(parser: HttpRequestParser) -> None: value = b"A" * 8185 - text = b"GET /test HTTP/1.1\r\ndata:" + value + b"\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\ndata:" + value + b"\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.method == "GET" assert msg.path == "/test" assert msg.version == (1, 1) - assert msg.headers == CIMultiDict({"data": value.decode()}) - assert msg.raw_headers == ((b"data", value),) + assert msg.headers == CIMultiDict([("Host", "a"), ("data", value.decode())]) + assert msg.raw_headers == ((b"Host", b"a"), (b"data", value)) assert not msg.should_close assert msg.compression is None assert not msg.upgrade @@ -1053,7 +1082,7 @@ def test_max_header_value_size_under_limit(parser: HttpRequestParser) -> None: async def test_chunk_splits_after_pause(parser: HttpRequestParser) -> None: text = ( - b"GET /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + b"GET /test HTTP/1.1\r\nHost: a\r\nTransfer-Encoding: chunked\r\n\r\n" + b"1\r\nb\r\n" * 50000 + b"0\r\n\r\n" ) @@ -1281,15 +1310,15 @@ def test_max_header_value_size_continuation_under_limit( def test_http_request_parser(parser: HttpRequestParser) -> None: - text = b"GET /path HTTP/1.1\r\n\r\n" + text = b"GET /path HTTP/1.1\r\nHost: a\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.method == "GET" assert msg.path == "/path" assert msg.version == (1, 1) - assert msg.headers == CIMultiDict() - assert msg.raw_headers == () + assert msg.headers == CIMultiDict({"Host": "a"}) + assert msg.raw_headers == ((b"Host", b"a"),) assert not msg.should_close assert msg.compression is None assert not msg.upgrade @@ -1320,7 +1349,7 @@ def test_http_request_bad_status_line(parser: HttpRequestParser) -> None: def test_http_request_bad_status_line_number( parser: HttpRequestParser, nonascii_digit: bytes ) -> None: - text = b"GET /digit HTTP/1." + nonascii_digit + b"\r\n\r\n" + text = b"GET /digit HTTP/1." + nonascii_digit + b"\r\nHost: a\r\n\r\n" with pytest.raises(http_exceptions.BadStatusLine): parser.feed_data(text) @@ -1328,19 +1357,19 @@ def test_http_request_bad_status_line_number( def test_http_request_bad_status_line_separator(parser: HttpRequestParser) -> None: # single code point, old, multibyte NFKC, multibyte NFKD utf8sep = "\N{ARABIC LIGATURE SALLALLAHOU ALAYHE WASALLAM}".encode() - text = b"GET /ligature HTTP/1" + utf8sep + b"1\r\n\r\n" + text = b"GET /ligature HTTP/1" + utf8sep + b"1\r\nHost: a\r\n\r\n" with pytest.raises(http_exceptions.BadStatusLine): parser.feed_data(text) def test_http_request_bad_status_line_whitespace(parser: HttpRequestParser) -> None: - text = b"GET\n/path\fHTTP/1.1\r\n\r\n" + text = b"GET\n/path\fHTTP/1.1\r\nHost: a\r\n\r\n" with pytest.raises(http_exceptions.BadStatusLine): parser.feed_data(text) def test_http_request_message_after_close(parser: HttpRequestParser) -> None: - text = b"GET / HTTP/1.1\r\nConnection: close\r\n\r\nInvalid\r\n\r\n" + text = b"GET / HTTP/1.1\r\nHost: a\r\nConnection: close\r\n\r\nInvalid\r\n\r\n" with pytest.raises( http_exceptions.BadHttpMessage, match="Data after `Connection: close`" ): @@ -1348,7 +1377,7 @@ def test_http_request_message_after_close(parser: HttpRequestParser) -> None: def test_http_request_message_after_close_comma_list(parser: HttpRequestParser) -> None: - text = b"GET / HTTP/1.1\r\nConnection: close, keep-alive\r\n\r\nInvalid\r\n\r\n" + text = b"GET / HTTP/1.1\r\nHost: a\r\nConnection: close, keep-alive\r\n\r\nInvalid\r\n\r\n" with pytest.raises( http_exceptions.BadHttpMessage, match="Data after `Connection: close`" ): @@ -1358,6 +1387,7 @@ def test_http_request_message_after_close_comma_list(parser: HttpRequestParser) def test_http_request_upgrade(parser: HttpRequestParser) -> None: text = ( b"GET /test HTTP/1.1\r\n" + b"Host: a\r\n" b"connection: upgrade\r\n" b"upgrade: websocket\r\n\r\n" b"some raw data" @@ -1373,6 +1403,7 @@ def test_http_request_upgrade(parser: HttpRequestParser) -> None: async def test_http_request_upgrade_unknown(parser: HttpRequestParser) -> None: text = ( b"POST / HTTP/1.1\r\n" + b"Host: a\r\n" b"Connection: Upgrade\r\n" b"Content-Length: 2\r\n" b"Upgrade: unknown\r\n" @@ -1406,7 +1437,7 @@ def xfail_c_parser_url(request: pytest.FixtureRequest) -> None: def test_http_request_parser_utf8_request_line(parser: HttpRequestParser) -> None: messages, upgrade, tail = parser.feed_data( # note the truncated unicode sequence - b"GET /P\xc3\xbcnktchen\xa0\xef\xb7 HTTP/1.1\r\n" + + b"GET /P\xc3\xbcnktchen\xa0\xef\xb7 HTTP/1.1\r\nHost: a\r\n" + # for easier grep: ASCII 0xA0 more commonly known as non-breaking space # note the leading and trailing spaces "sTeP: \N{LATIN SMALL LETTER SHARP S}nek\t\N{NO-BREAK SPACE} " @@ -1417,8 +1448,8 @@ def test_http_request_parser_utf8_request_line(parser: HttpRequestParser) -> Non assert msg.method == "GET" assert msg.path == "/Pünktchen\udca0\udcef\udcb7" assert msg.version == (1, 1) - assert msg.headers == CIMultiDict([("STEP", "ßnek\t\xa0")]) - assert msg.raw_headers == ((b"sTeP", "ßnek\t\xa0".encode()),) + assert msg.headers == CIMultiDict([("Host", "a"), ("STEP", "ßnek\t\xa0")]) + assert msg.raw_headers == ((b"Host", b"a"), (b"sTeP", "ßnek\t\xa0".encode())) assert not msg.should_close assert msg.compression is None assert not msg.upgrade @@ -1429,15 +1460,15 @@ def test_http_request_parser_utf8_request_line(parser: HttpRequestParser) -> Non def test_http_request_parser_utf8(parser: HttpRequestParser) -> None: - text = "GET /path HTTP/1.1\r\nx-test:тест\r\n\r\n".encode() + text = "GET /path HTTP/1.1\r\nHost: a\r\nx-test:тест\r\n\r\n".encode() messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.method == "GET" assert msg.path == "/path" assert msg.version == (1, 1) - assert msg.headers == CIMultiDict([("X-TEST", "тест")]) - assert msg.raw_headers == ((b"x-test", "тест".encode()),) + assert msg.headers == CIMultiDict([("Host", "a"), ("X-TEST", "тест")]) + assert msg.raw_headers == ((b"Host", b"a"), (b"x-test", "тест".encode())) assert not msg.should_close assert msg.compression is None assert not msg.upgrade @@ -1446,16 +1477,19 @@ def test_http_request_parser_utf8(parser: HttpRequestParser) -> None: def test_http_request_parser_non_utf8(parser: HttpRequestParser) -> None: - text = "GET /path HTTP/1.1\r\nx-test:тест\r\n\r\n".encode("cp1251") + text = "GET /path HTTP/1.1\r\nHost: a\r\nx-test:тест\r\n\r\n".encode("cp1251") msg = parser.feed_data(text)[0][0][0] assert msg.method == "GET" assert msg.path == "/path" assert msg.version == (1, 1) assert msg.headers == CIMultiDict( - [("X-TEST", "тест".encode("cp1251").decode("utf8", "surrogateescape"))] + [ + ("Host", "a"), + ("X-TEST", "тест".encode("cp1251").decode("utf8", "surrogateescape")), + ] ) - assert msg.raw_headers == ((b"x-test", "тест".encode("cp1251")),) + assert msg.raw_headers == ((b"Host", b"a"), (b"x-test", "тест".encode("cp1251"))) assert not msg.should_close assert msg.compression is None assert not msg.upgrade @@ -1464,7 +1498,7 @@ def test_http_request_parser_non_utf8(parser: HttpRequestParser) -> None: def test_http_request_parser_two_slashes(parser: HttpRequestParser) -> None: - text = b"GET //path HTTP/1.1\r\n\r\n" + text = b"GET //path HTTP/1.1\r\nHost: a\r\n\r\n" msg = parser.feed_data(text)[0][0][0] assert msg.method == "GET" @@ -1485,17 +1519,19 @@ def test_http_request_parser_bad_method( parser: HttpRequestParser, rfc9110_5_6_2_token_delim: bytes ) -> None: with pytest.raises(http_exceptions.BadHttpMethod): - parser.feed_data(rfc9110_5_6_2_token_delim + b'ET" /get HTTP/1.1\r\n\r\n') + parser.feed_data( + rfc9110_5_6_2_token_delim + b'ET" /get HTTP/1.1\r\nHost: a\r\n\r\n' + ) def test_http_request_parser_bad_version(parser: HttpRequestParser) -> None: with pytest.raises(http_exceptions.BadHttpMessage): - parser.feed_data(b"GET //get HT/11\r\n\r\n") + parser.feed_data(b"GET //get HT/11\r\nHost: a\r\n\r\n") def test_http_request_parser_bad_version_number(parser: HttpRequestParser) -> None: with pytest.raises(http_exceptions.BadHttpMessage): - parser.feed_data(b"GET /test HTTP/1.32\r\n\r\n") + parser.feed_data(b"GET /test HTTP/1.32\r\nHost: a\r\n\r\n") def test_http_request_parser_bad_ascii_uri(parser: HttpRequestParser) -> None: @@ -1519,15 +1555,15 @@ def test_http_request_max_status_line(parser: HttpRequestParser, size: int) -> N def test_http_request_max_status_line_under_limit(parser: HttpRequestParser) -> None: path = b"t" * 8172 messages, upgraded, tail = parser.feed_data( - b"GET /path" + path + b" HTTP/1.1\r\n\r\n" + b"GET /path" + path + b" HTTP/1.1\r\nHost: a\r\n\r\n" ) msg = messages[0][0] assert msg.method == "GET" assert msg.path == "/path" + path.decode() assert msg.version == (1, 1) - assert msg.headers == CIMultiDict() - assert msg.raw_headers == () + assert msg.headers == CIMultiDict({"Host": "a"}) + assert msg.raw_headers == ((b"Host", b"a"),) assert not msg.should_close assert msg.compression is None assert not msg.upgrade @@ -1689,7 +1725,7 @@ async def test_http_response_parser_bad_chunked_strict_py( response = HttpResponseParserPy( protocol, loop, - 2**18, + DEFAULT_CHUNK_SIZE, max_line_size=8190, max_field_size=8190, ) @@ -1776,7 +1812,7 @@ def test_http_response_parser_code_not_ascii( def test_http_request_chunked_payload(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\ntransfer-encoding: chunked\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] assert msg.chunked @@ -1794,12 +1830,13 @@ def test_http_request_chunked_payload(parser: HttpRequestParser) -> None: def test_http_request_chunked_payload_and_next_message( parser: HttpRequestParser, ) -> None: - text = b"GET /test HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\ntransfer-encoding: chunked\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] messages, upgraded, tail = parser.feed_data( b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n" b"POST /test2 HTTP/1.1\r\n" + b"Host: a\r\n" b"transfer-encoding: chunked\r\n\r\n" ) @@ -1817,7 +1854,7 @@ def test_http_request_chunked_payload_and_next_message( def test_http_request_chunked_payload_chunks(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\ntransfer-encoding: chunked\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] parser.feed_data(b"4\r\ndata\r") @@ -1840,7 +1877,7 @@ def test_http_request_chunked_payload_chunks(parser: HttpRequestParser) -> None: def test_parse_chunked_payload_chunk_extension(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\ntransfer-encoding: chunked\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] parser.feed_data(b"4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest: test\r\n\r\n") @@ -1852,7 +1889,7 @@ def test_parse_chunked_payload_chunk_extension(parser: HttpRequestParser) -> Non async def test_request_chunked_with_trailer(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n0\r\ntest: trailer\r\nsecond: test trailer\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n0\r\ntest: trailer\r\nsecond: test trailer\r\n\r\n" messages, upgraded, tail = parser.feed_data(text) assert not tail msg, payload = messages[0] @@ -1862,7 +1899,7 @@ async def test_request_chunked_with_trailer(parser: HttpRequestParser) -> None: async def test_request_chunked_reject_bad_trailer(parser: HttpRequestParser) -> None: - text = b"GET /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n0\r\nbad\ntrailer\r\n\r\n" + text = b"GET /test HTTP/1.1\r\nHost: a\r\nTransfer-Encoding: chunked\r\n\r\n0\r\nbad\ntrailer\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage, match=r"b'bad\\ntrailer'"): parser.feed_data(text) @@ -1873,9 +1910,9 @@ def test_parse_no_length_or_te_on_post( request_cls: type[HttpRequestParser], ) -> None: protocol = RequestHandler(server, loop=loop) - parser = request_cls(protocol, loop, limit=2**18) + parser = request_cls(protocol, loop, limit=DEFAULT_CHUNK_SIZE) protocol._parser = parser - text = b"POST /test HTTP/1.1\r\n\r\n" + text = b"POST /test HTTP/1.1\r\nHost: a\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] assert payload.is_eof() @@ -1908,7 +1945,7 @@ def test_parse_length_payload(response: HttpResponseParser) -> None: def test_parse_no_length_payload(parser: HttpRequestParser) -> None: - text = b"PUT / HTTP/1.1\r\n\r\n" + text = b"PUT / HTTP/1.1\r\nHost: a\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] assert payload.is_eof() @@ -2080,7 +2117,7 @@ async def test_parse_chunked_payload_with_lf_in_extensions( def test_partial_url(parser: HttpRequestParser) -> None: messages, upgrade, tail = parser.feed_data(b"GET /te") assert len(messages) == 0 - messages, upgrade, tail = parser.feed_data(b"st HTTP/1.1\r\n\r\n") + messages, upgrade, tail = parser.feed_data(b"st HTTP/1.1\r\nHost: a\r\n\r\n") assert len(messages) == 1 msg, payload = messages[0] @@ -2105,7 +2142,7 @@ def test_partial_url(parser: HttpRequestParser) -> None: def test_parse_uri_percent_encoded( parser: HttpRequestParser, uri: str, path: str, query: dict[str, str], fragment: str ) -> None: - text = (f"GET {uri} HTTP/1.1\r\n\r\n").encode() + text = (f"GET {uri} HTTP/1.1\r\nHost: a\r\n\r\n").encode() messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] @@ -2119,7 +2156,7 @@ def test_parse_uri_percent_encoded( def test_parse_uri_utf8(parser: HttpRequestParser) -> None: if not isinstance(parser, HttpRequestParserPy): pytest.xfail("Not valid HTTP. Maybe update py-parser to reject later.") - text = ("GET /путь?ключ=знач#фраг HTTP/1.1\r\n\r\n").encode() + text = ("GET /путь?ключ=знач#фраг HTTP/1.1\r\nHost: a\r\n\r\n").encode() messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] @@ -2131,7 +2168,8 @@ def test_parse_uri_utf8(parser: HttpRequestParser) -> None: def test_parse_uri_utf8_percent_encoded(parser: HttpRequestParser) -> None: text = ( - "GET %s HTTP/1.1\r\n\r\n" % quote("/путь?ключ=знач#фраг", safe="/?=#") + "GET %s HTTP/1.1\r\nHost: a\r\n\r\n" + % quote("/путь?ключ=знач#фраг", safe="/?=#") ).encode() messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] @@ -2156,7 +2194,7 @@ def test_parse_bad_method_for_c_parser_raises( parser = HttpRequestParserC( protocol, loop, - 2**18, + DEFAULT_CHUNK_SIZE, max_line_size=8190, max_headers=128, max_field_size=8190, @@ -2178,7 +2216,9 @@ async def test_parse_eof_payload(self, protocol: BaseProtocol) -> None: assert [bytearray(b"data")] == list(out._buffer) async def test_parse_length_payload_eof(self, protocol: BaseProtocol) -> None: - out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) p = HttpPayloadParser(out, length=4, headers_parser=HeadersParser()) p.feed_data(b"da") @@ -2202,7 +2242,9 @@ async def test_parse_chunked_payload_size_data_mismatch( Regression test for #10596. """ - out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser()) # Declared chunk-size is 4 but actual data is "Hello" (5 bytes). # After consuming 4 bytes, remaining starts with "o" not "\r\n". @@ -2217,7 +2259,9 @@ async def test_parse_chunked_payload_size_data_mismatch_too_short( Regression test for #10596. """ - out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser()) # Declared chunk-size is 6 but actual data before CRLF is "Hello" (5 bytes). # Parser reads 6 bytes: "Hello\r", then expects \r\n but sees "\n0\r\n..." @@ -2239,7 +2283,9 @@ async def test_parse_chunked_payload_split_end( async def test_parse_chunked_payload_split_end2( self, protocol: BaseProtocol ) -> None: - out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser()) p.feed_data(b"4\r\nasdf\r\n0\r\n\r") p.feed_data(b"\n") @@ -2250,7 +2296,9 @@ async def test_parse_chunked_payload_split_end2( async def test_parse_chunked_payload_split_end_trailers( self, protocol: BaseProtocol ) -> None: - out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser()) p.feed_data(b"4\r\nasdf\r\n0\r\n") p.feed_data(b"Content-MD5: 912ec803b2ce49e4a541068d495ab570\r\n") @@ -2262,7 +2310,9 @@ async def test_parse_chunked_payload_split_end_trailers( async def test_parse_chunked_payload_split_end_trailers2( self, protocol: BaseProtocol ) -> None: - out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser()) p.feed_data(b"4\r\nasdf\r\n0\r\n") p.feed_data(b"Content-MD5: 912ec803b2ce49e4a541068d495ab570\r\n\r") @@ -2294,7 +2344,9 @@ async def test_parse_chunked_payload_split_end_trailers4( assert b"asdf" == b"".join(out._buffer) async def test_http_payload_parser_length(self, protocol: BaseProtocol) -> None: - out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) p = HttpPayloadParser(out, length=2, headers_parser=HeadersParser()) state, tail = p.feed_data(b"1245") assert state is PayloadState.PAYLOAD_COMPLETE @@ -2307,7 +2359,9 @@ async def test_http_payload_parser_deflate(self, protocol: BaseProtocol) -> None COMPRESSED = b"x\x9cKI,I\x04\x00\x04\x00\x01\x9b" length = len(COMPRESSED) - out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) p = HttpPayloadParser( out, length=length, compression="deflate", headers_parser=HeadersParser() ) @@ -2378,7 +2432,9 @@ async def test_http_payload_parser_deflate_split_err( async def test_http_payload_parser_length_zero( self, protocol: BaseProtocol ) -> None: - out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) p = HttpPayloadParser(out, length=0, headers_parser=HeadersParser()) assert p.done assert out.is_eof() @@ -2386,7 +2442,9 @@ async def test_http_payload_parser_length_zero( @pytest.mark.skipif(brotli is None, reason="brotli is not installed") async def test_http_payload_brotli(self, protocol: BaseProtocol) -> None: compressed = brotli.compress(b"brotli data") - out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) p = HttpPayloadParser( out, length=len(compressed), @@ -2400,7 +2458,9 @@ async def test_http_payload_brotli(self, protocol: BaseProtocol) -> None: @pytest.mark.skipif(zstandard is None, reason="zstandard is not installed") async def test_http_payload_zstandard(self, protocol: BaseProtocol) -> None: compressed = zstandard.compress(b"zstd data") - out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) p = HttpPayloadParser( out, length=len(compressed), @@ -2418,7 +2478,9 @@ async def test_http_payload_zstandard_multi_frame( frame1 = zstandard.compress(b"first") frame2 = zstandard.compress(b"second") payload = frame1 + frame2 - out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) p = HttpPayloadParser( out, length=len(payload), @@ -2435,7 +2497,9 @@ async def test_http_payload_zstandard_multi_frame_chunked( ) -> None: frame1 = zstandard.compress(b"chunk1") frame2 = zstandard.compress(b"chunk2") - out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) p = HttpPayloadParser( out, length=len(frame1) + len(frame2), @@ -2455,7 +2519,9 @@ async def test_http_payload_zstandard_frame_split_mid_chunk( frame2 = zstandard.compress(b"BBBB") combined = frame1 + frame2 split_point = len(frame1) + 3 # 3 bytes into frame2 - out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) p = HttpPayloadParser( out, length=len(combined), @@ -2473,7 +2539,9 @@ async def test_http_payload_zstandard_many_small_frames( ) -> None: parts = [f"part{i}".encode() for i in range(10)] payload = b"".join(zstandard.compress(p) for p in parts) - out = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + out = aiohttp.StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) p = HttpPayloadParser( out, length=len(payload), @@ -2487,7 +2555,9 @@ async def test_http_payload_zstandard_many_small_frames( class TestDeflateBuffer: async def test_feed_data(self, protocol: BaseProtocol) -> None: - buf = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + buf = aiohttp.StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) dbuf = DeflateBuffer(buf, "deflate") dbuf.decompressor = mock.Mock() @@ -2572,7 +2642,9 @@ async def test_feed_eof_no_err_zstandard(self, protocol: BaseProtocol) -> None: assert buf._eof async def test_empty_body(self, protocol: BaseProtocol) -> None: - buf = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + buf = aiohttp.StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) dbuf = DeflateBuffer(buf, "deflate") dbuf.feed_eof() @@ -2597,7 +2669,9 @@ async def test_streaming_decompress_large_payload( original = b"A" * (3 * 2**20) compressed = zlib.compress(original) - buf = aiohttp.StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + buf = aiohttp.StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) dbuf = DeflateBuffer(buf, "deflate") # Feed compressed data in chunks (simulating network streaming) diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index 546ea60cd8b..091731afe5a 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -12,6 +12,7 @@ from aiohttp import ClientConnectionResetError, hdrs, http from aiohttp.base_protocol import BaseProtocol from aiohttp.compression_utils import ZLibBackend +from aiohttp.helpers import DEFAULT_CHUNK_SIZE from aiohttp.http_writer import _serialize_headers @@ -1488,7 +1489,7 @@ async def test_write_drain_condition_with_large_buffer( protocol._drain_helper.reset_mock() # type: ignore[attr-defined] # Write large amount of data with drain=True - large_data = b"x" * (2**18 + 1) # Just over LIMIT + large_data = b"x" * (DEFAULT_CHUNK_SIZE + 1) # Just over LIMIT await msg.write(large_data, drain=True) # Drain should be called because drain=True AND buffer_size > LIMIT @@ -1517,12 +1518,12 @@ async def test_write_no_drain_with_large_buffer( protocol._drain_helper.reset_mock() # type: ignore[attr-defined] # Write large amount of data with drain=False - large_data = b"x" * (2**18 + 1) # Just over LIMIT + large_data = b"x" * (DEFAULT_CHUNK_SIZE + 1) # Just over LIMIT await msg.write(large_data, drain=False) # Drain should NOT be called because drain=False assert not protocol._drain_helper.called # type: ignore[attr-defined] - assert msg.buffer_size == (2**18 + 1) # Buffer not reset + assert msg.buffer_size == (DEFAULT_CHUNK_SIZE + 1) # Buffer not reset assert large_data in buf diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 30659405599..83046ccc034 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -20,7 +20,7 @@ CONTENT_TRANSFER_ENCODING, CONTENT_TYPE, ) -from aiohttp.helpers import parse_mimetype +from aiohttp.helpers import DEFAULT_CHUNK_SIZE, parse_mimetype from aiohttp.multipart import ( BodyPartReader, BodyPartReaderPayload, @@ -674,9 +674,11 @@ async def test_filename(self) -> None: assert "foo.html" == part.filename async def test_reading_long_part(self) -> None: - size = 2 * 2**18 + size = 2 * DEFAULT_CHUNK_SIZE protocol = mock.Mock(_reading_paused=False) - stream = StreamReader(protocol, 2**18, loop=asyncio.get_event_loop()) + stream = StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_event_loop() + ) stream.feed_data(b"0" * size + b"\r\n--:--") stream.feed_eof() d = CIMultiDictProxy[str](CIMultiDict()) diff --git a/tests/test_payload.py b/tests/test_payload.py index e38335f546f..97464c4e474 100644 --- a/tests/test_payload.py +++ b/tests/test_payload.py @@ -13,7 +13,7 @@ from aiohttp import payload from aiohttp.abc import AbstractStreamWriter -from aiohttp.payload import READ_SIZE +from aiohttp.helpers import DEFAULT_CHUNK_SIZE class BufferWriter(AbstractStreamWriter): @@ -328,14 +328,13 @@ def mock_read(size: int | None = None) -> bytes: async def test_bytesio_payload_large_data_multiple_chunks() -> None: """Test BytesIOPayload with large data requiring multiple read chunks.""" - chunk_size = 2**18 # 256KiB (READ_SIZE) - data = b"x" * (chunk_size + 1000) # Slightly larger than READ_SIZE + data = b"x" * (DEFAULT_CHUNK_SIZE + 1000) payload_bytesio = payload.BytesIOPayload(io.BytesIO(data)) writer = MockStreamWriter() await payload_bytesio.write_with_length(writer, None) assert writer.get_written_bytes() == data - assert len(writer.get_written_bytes()) == chunk_size + 1000 + assert len(writer.get_written_bytes()) == DEFAULT_CHUNK_SIZE + 1000 async def test_bytesio_payload_remaining_bytes_exhausted() -> None: @@ -352,21 +351,20 @@ async def test_bytesio_payload_remaining_bytes_exhausted() -> None: async def test_iobase_payload_exact_chunk_size_limit() -> None: """Test IOBasePayload with content length matching exactly one read chunk.""" - chunk_size = 2**18 # 256KiB (READ_SIZE) - data = b"x" * chunk_size + b"extra" # Slightly larger than one read chunk + data = b"x" * DEFAULT_CHUNK_SIZE + b"extra" # Slightly larger than one read chunk p = payload.IOBasePayload(io.BytesIO(data)) writer = MockStreamWriter() - await p.write_with_length(writer, chunk_size) + await p.write_with_length(writer, DEFAULT_CHUNK_SIZE) written = writer.get_written_bytes() - assert len(written) == chunk_size - assert written == data[:chunk_size] + assert len(written) == DEFAULT_CHUNK_SIZE + assert written == data[:DEFAULT_CHUNK_SIZE] async def test_iobase_payload_reads_in_chunks() -> None: - """Test IOBasePayload reads data in chunks of READ_SIZE, not all at once.""" - # Create a large file that's multiple times larger than READ_SIZE - large_data = b"x" * (READ_SIZE * 3 + 1000) # ~192KB + 1000 bytes + """Test IOBasePayload reads data in chunks of default size, not all at once.""" + # Create a large file that's multiple times larger than DEFAULT_CHUNK_SIZE + large_data = b"x" * (DEFAULT_CHUNK_SIZE * 3 + 1000) # ~192KB + 1000 bytes # Mock the file-like object to track read calls mock_file = unittest.mock.Mock(spec=io.BytesIO) @@ -383,11 +381,11 @@ def mock_read(size: int) -> bytes: if call_count == 1: return large_data[:size] elif call_count == 2: - return large_data[READ_SIZE : READ_SIZE + size] + return large_data[DEFAULT_CHUNK_SIZE : DEFAULT_CHUNK_SIZE + size] elif call_count == 3: - return large_data[READ_SIZE * 2 : READ_SIZE * 2 + size] + return large_data[DEFAULT_CHUNK_SIZE * 2 : DEFAULT_CHUNK_SIZE * 2 + size] else: - return large_data[READ_SIZE * 3 :] + return large_data[DEFAULT_CHUNK_SIZE * 3 :] mock_file.read.side_effect = mock_read @@ -397,17 +395,17 @@ def mock_read(size: int) -> bytes: # Write with a large content_length await payload_obj.write_with_length(writer, len(large_data)) - # Verify that reads were limited to READ_SIZE + # Verify that reads were limited to DEFAULT_CHUNK_SIZE assert len(read_sizes) > 1 # Should have multiple reads for read_size in read_sizes: assert ( - read_size <= READ_SIZE - ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}" + read_size <= DEFAULT_CHUNK_SIZE + ), f"Read size {read_size} exceeds DEFAULT_CHUNK_SIZE {DEFAULT_CHUNK_SIZE}" async def test_iobase_payload_large_content_length() -> None: """Test IOBasePayload with very large content_length doesn't read all at once.""" - data = b"x" * (READ_SIZE + 1000) + data = b"x" * (DEFAULT_CHUNK_SIZE + 1000) # Create a custom file-like object that tracks read sizes class TrackingBytesIO(io.BytesIO): @@ -427,20 +425,20 @@ def read(self, size: int | None = -1) -> bytes: large_content_length = 10 * 1024 * 1024 # 10MB await payload_obj.write_with_length(writer, large_content_length) - # Verify no single read exceeded READ_SIZE + # Verify no single read exceeded DEFAULT_CHUNK_SIZE for read_size in tracking_file.read_sizes: assert ( - read_size <= READ_SIZE - ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}" + read_size <= DEFAULT_CHUNK_SIZE + ), f"Read size {read_size} exceeds DEFAULT_CHUNK_SIZE {DEFAULT_CHUNK_SIZE}" # Verify the correct amount of data was written assert writer.get_written_bytes() == data async def test_textio_payload_reads_in_chunks() -> None: - """Test TextIOPayload reads data in chunks of READ_SIZE, not all at once.""" - # Create a large text file that's multiple times larger than READ_SIZE - large_text = "x" * (READ_SIZE * 3 + 1000) # ~192KB + 1000 chars + """Test TextIOPayload reads data in chunks of default size, not all at once.""" + # Create a large text file that's multiple times larger than DEFAULT_CHUNK_SIZE + large_text = "x" * (DEFAULT_CHUNK_SIZE * 3 + 1000) # ~192KB + 1000 chars # Mock the file-like object to track read calls mock_file = unittest.mock.Mock(spec=io.StringIO) @@ -458,11 +456,11 @@ def mock_read(size: int) -> str: if call_count == 1: return large_text[:size] elif call_count == 2: - return large_text[READ_SIZE : READ_SIZE + size] + return large_text[DEFAULT_CHUNK_SIZE : DEFAULT_CHUNK_SIZE + size] elif call_count == 3: - return large_text[READ_SIZE * 2 : READ_SIZE * 2 + size] + return large_text[DEFAULT_CHUNK_SIZE * 2 : DEFAULT_CHUNK_SIZE * 2 + size] else: - return large_text[READ_SIZE * 3 :] + return large_text[DEFAULT_CHUNK_SIZE * 3 :] mock_file.read.side_effect = mock_read @@ -472,17 +470,17 @@ def mock_read(size: int) -> str: # Write with a large content_length await payload_obj.write_with_length(writer, len(large_text.encode("utf-8"))) - # Verify that reads were limited to READ_SIZE + # Verify that reads were limited to DEFAULT_CHUNK_SIZE assert len(read_sizes) > 1 # Should have multiple reads for read_size in read_sizes: assert ( - read_size <= READ_SIZE - ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}" + read_size <= DEFAULT_CHUNK_SIZE + ), f"Read size {read_size} exceeds DEFAULT_CHUNK_SIZE {DEFAULT_CHUNK_SIZE}" async def test_textio_payload_large_content_length() -> None: """Test TextIOPayload with very large content_length doesn't read all at once.""" - text_data = "x" * (READ_SIZE + 1000) + text_data = "x" * (DEFAULT_CHUNK_SIZE + 1000) # Create a custom file-like object that tracks read sizes class TrackingStringIO(io.StringIO): @@ -502,11 +500,11 @@ def read(self, size: int | None = -1) -> str: large_content_length = 10 * 1024 * 1024 # 10MB await payload_obj.write_with_length(writer, large_content_length) - # Verify no single read exceeded READ_SIZE + # Verify no single read exceeded DEFAULT_CHUNK_SIZE for read_size in tracking_file.read_sizes: assert ( - read_size <= READ_SIZE - ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}" + read_size <= DEFAULT_CHUNK_SIZE + ), f"Read size {read_size} exceeds DEFAULT_CHUNK_SIZE {DEFAULT_CHUNK_SIZE}" # Verify the correct amount of data was written assert writer.get_written_bytes() == text_data.encode("utf-8") diff --git a/tests/test_streams.py b/tests/test_streams.py index 6560b4698fb..2680b226266 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -14,6 +14,7 @@ from aiohttp import streams from aiohttp.base_protocol import BaseProtocol +from aiohttp.helpers import DEFAULT_CHUNK_SIZE from aiohttp.http_exceptions import LineTooLong DATA: bytes = b"line1\nline2\nline3\n" @@ -29,7 +30,7 @@ def chunkify(seq: Sequence[_T], n: int) -> Iterator[Sequence[_T]]: async def create_stream() -> streams.StreamReader: loop = asyncio.get_event_loop() protocol = mock.Mock(_reading_paused=False) - stream = streams.StreamReader(protocol, 2**18, loop=loop) + stream = streams.StreamReader(protocol, DEFAULT_CHUNK_SIZE, loop=loop) stream.feed_data(DATA) stream.feed_eof() return stream @@ -75,7 +76,7 @@ def get_memory_usage(obj: object) -> int: class TestStreamReader: DATA: bytes = b"line1\nline2\nline3\n" - def _make_one(self, limit: int = 2**18) -> streams.StreamReader: + def _make_one(self, limit: int = DEFAULT_CHUNK_SIZE) -> streams.StreamReader: loop = asyncio.get_event_loop() return streams.StreamReader(mock.Mock(_reading_paused=False), limit, loop=loop) @@ -1276,7 +1277,7 @@ async def set_err() -> None: async def test_feed_data_waiters(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() - reader = streams.StreamReader(protocol, 2**18, loop=loop) + reader = streams.StreamReader(protocol, DEFAULT_CHUNK_SIZE, loop=loop) waiter = reader._waiter = loop.create_future() eof_waiter = reader._eof_waiter = loop.create_future() @@ -1304,7 +1305,7 @@ async def test_feed_data_completed_waiters(protocol: BaseProtocol) -> None: async def test_feed_eof_waiters(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() - reader = streams.StreamReader(protocol, 2**18, loop=loop) + reader = streams.StreamReader(protocol, DEFAULT_CHUNK_SIZE, loop=loop) waiter = reader._waiter = loop.create_future() eof_waiter = reader._eof_waiter = loop.create_future() @@ -1336,7 +1337,7 @@ async def test_feed_eof_cancelled(protocol: BaseProtocol) -> None: async def test_on_eof(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() - reader = streams.StreamReader(protocol, 2**18, loop=loop) + reader = streams.StreamReader(protocol, DEFAULT_CHUNK_SIZE, loop=loop) on_eof = mock.Mock() reader.on_eof(on_eof) @@ -1357,7 +1358,7 @@ async def test_on_eof_empty_reader() -> None: async def test_on_eof_exc_in_callback(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() - reader = streams.StreamReader(protocol, 2**18, loop=loop) + reader = streams.StreamReader(protocol, DEFAULT_CHUNK_SIZE, loop=loop) on_eof = mock.Mock() on_eof.side_effect = ValueError @@ -1392,7 +1393,7 @@ async def test_on_eof_eof_is_set(protocol: BaseProtocol) -> None: async def test_on_eof_eof_is_set_exception(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() - reader = streams.StreamReader(protocol, 2**18, loop=loop) + reader = streams.StreamReader(protocol, DEFAULT_CHUNK_SIZE, loop=loop) reader.feed_eof() on_eof = mock.Mock() @@ -1438,7 +1439,7 @@ async def test_set_exception_cancelled(protocol: BaseProtocol) -> None: async def test_set_exception_eof_callbacks(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() - reader = streams.StreamReader(protocol, 2**18, loop=loop) + reader = streams.StreamReader(protocol, DEFAULT_CHUNK_SIZE, loop=loop) on_eof = mock.Mock() reader.on_eof(on_eof) diff --git a/tests/test_web_request.py b/tests/test_web_request.py index efeb6b766b0..a204cc1fd48 100644 --- a/tests/test_web_request.py +++ b/tests/test_web_request.py @@ -16,6 +16,7 @@ from aiohttp import ETag, HttpVersion, web from aiohttp.base_protocol import BaseProtocol +from aiohttp.helpers import DEFAULT_CHUNK_SIZE from aiohttp.http_exceptions import BadHttpMessage, LineTooLong from aiohttp.http_parser import RawRequestMessage from aiohttp.pytest_plugin import AiohttpClient @@ -837,7 +838,7 @@ def test_clone_headers_dict() -> None: async def test_cannot_clone_after_read(protocol: BaseProtocol) -> None: - payload = StreamReader(protocol, 2**18, loop=asyncio.get_event_loop()) + payload = StreamReader(protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_event_loop()) payload.feed_data(b"data") payload.feed_eof() req = make_mocked_request("GET", "/path", payload=payload) @@ -860,7 +861,7 @@ async def test_make_too_big_request(protocol: BaseProtocol) -> None: async def test_request_with_wrong_content_type_encoding(protocol: BaseProtocol) -> None: - payload = StreamReader(protocol, 2**18, loop=asyncio.get_event_loop()) + payload = StreamReader(protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_event_loop()) payload.feed_data(b"{}") payload.feed_eof() headers = {"Content-Type": "text/html; charset=test"} @@ -920,7 +921,7 @@ async def test_multipart_formdata(protocol: BaseProtocol) -> None: async def test_multipart_formdata_field_missing_name(protocol: BaseProtocol) -> None: # Ensure ValueError is raised when Content-Disposition has no name - payload = StreamReader(protocol, 2**18, loop=asyncio.get_event_loop()) + payload = StreamReader(protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_event_loop()) payload.feed_data( b"-----------------------------326931944431359\r\n" b"Content-Disposition: form-data\r\n" # Missing name! @@ -972,7 +973,9 @@ async def test_multipart_formdata_headers_too_many(protocol: BaseProtocol) -> No b"--b--\r\n" ) content_type = "multipart/form-data; boundary=b" - payload = StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + payload = StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) payload.feed_data(body) payload.feed_eof() req = make_mocked_request( @@ -999,7 +1002,9 @@ async def test_multipart_formdata_header_too_long(protocol: BaseProtocol) -> Non b"--b--\r\n" ) content_type = "multipart/form-data; boundary=b" - payload = StreamReader(protocol, 2**18, loop=asyncio.get_running_loop()) + payload = StreamReader( + protocol, DEFAULT_CHUNK_SIZE, loop=asyncio.get_running_loop() + ) payload.feed_data(body) payload.feed_eof() req = make_mocked_request( diff --git a/tests/test_web_urldispatcher.py b/tests/test_web_urldispatcher.py index 144bd9cd03e..da6a0c38b37 100644 --- a/tests/test_web_urldispatcher.py +++ b/tests/test_web_urldispatcher.py @@ -248,7 +248,7 @@ async def test_follow_symlink_directory_traversal( # We need to use a raw socket to test this, as the client will normalize # the path before sending it to the server. reader, writer = await asyncio.open_connection(client.host, client.port) - writer.write(b"GET /../private_file HTTP/1.1\r\n\r\n") + writer.write(b"GET /../private_file HTTP/1.1\r\nHost: a\r\n\r\n") response = await reader.readuntil(b"\r\n\r\n") assert b"404 Not Found" in response writer.close() @@ -300,14 +300,14 @@ async def test_follow_symlink_directory_traversal_after_normalization( # We need to use a raw socket to test this, as the client will normalize # the path before sending it to the server. reader, writer = await asyncio.open_connection(client.host, client.port) - writer.write(b"GET /my_symlink/../private_file HTTP/1.1\r\n\r\n") + writer.write(b"GET /my_symlink/../private_file HTTP/1.1\r\nHost: a\r\n\r\n") response = await reader.readuntil(b"\r\n\r\n") assert b"404 Not Found" in response writer.close() await writer.wait_closed() reader, writer = await asyncio.open_connection(client.host, client.port) - writer.write(b"GET /my_symlink/symlink_target_file HTTP/1.1\r\n\r\n") + writer.write(b"GET /my_symlink/symlink_target_file HTTP/1.1\r\nHost: a\r\n\r\n") response = await reader.readuntil(b"\r\n\r\n") assert b"200 OK" in response response = await reader.readuntil(b"readable") diff --git a/tests/test_websocket_writer.py b/tests/test_websocket_writer.py index 3b6bc98b54f..e9d930bc54e 100644 --- a/tests/test_websocket_writer.py +++ b/tests/test_websocket_writer.py @@ -11,6 +11,7 @@ from aiohttp._websocket.reader import WebSocketDataQueue from aiohttp.base_protocol import BaseProtocol from aiohttp.compression_utils import ZLibBackend +from aiohttp.helpers import DEFAULT_CHUNK_SIZE from aiohttp.http import WebSocketReader, WebSocketWriter @@ -158,7 +159,9 @@ async def test_send_compress_cancelled( monkeypatch.setattr("aiohttp._websocket.writer.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", 1024) writer = WebSocketWriter(protocol, transport, compress=15) loop = asyncio.get_running_loop() - queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**18, loop=loop) + queue = WebSocketDataQueue( + mock.Mock(_reading_paused=False), DEFAULT_CHUNK_SIZE, loop=loop + ) reader = WebSocketReader(queue, 50000) # Replace executor with slow one to make race condition reproducible @@ -305,7 +308,9 @@ async def test_concurrent_messages( ): writer = WebSocketWriter(protocol, transport, compress=15) loop = asyncio.get_running_loop() - queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**18, loop=loop) + queue = WebSocketDataQueue( + mock.Mock(_reading_paused=False), DEFAULT_CHUNK_SIZE, loop=loop + ) reader = WebSocketReader(queue, 50000) writers = [] payloads = []