From d6b942c2397d0d06569441273320f3e1be8805df Mon Sep 17 00:00:00 2001 From: pgjones Date: Sat, 28 Aug 2021 17:16:17 +0100 Subject: [PATCH] Add typing and enforce checking via tox/CI This uses the same mypy settings as wsproto. The Sentinel values are problematic, but I've found no good solution that also has the property that type(Sentinel) is Sentinel - so this should suffice for now. Whilst I've been lazy with the tests, I've mostly avoided type ignores in the main code. This should ensure that mypyc can be used if desired. --- h11/__init__.py | 59 +++++++-- h11/_connection.py | 118 +++++++++++------ h11/_headers.py | 71 +++++++--- h11/_readers.py | 75 +++++++---- h11/_receivebuffer.py | 21 +-- h11/_state.py | 47 ++++--- h11/_util.py | 24 ++-- h11/_writers.py | 50 +++++-- h11/tests/helpers.py | 42 ++++-- h11/tests/test_against_stdlib_http.py | 20 +-- h11/tests/test_connection.py | 181 +++++++++++++++----------- h11/tests/test_events.py | 30 +++-- h11/tests/test_headers.py | 28 ++-- h11/tests/test_helpers.py | 17 ++- h11/tests/test_io.py | 86 +++++++----- h11/tests/test_receivebuffer.py | 5 +- h11/tests/test_state.py | 49 +++++-- h11/tests/test_util.py | 24 ++-- setup.cfg | 6 + setup.py | 5 +- tox.ini | 12 +- 21 files changed, 651 insertions(+), 319 deletions(-) diff --git a/h11/__init__.py b/h11/__init__.py index ae39e01..989e92c 100644 --- a/h11/__init__.py +++ b/h11/__init__.py @@ -6,16 +6,57 @@ # semantics to check that what you're asking to write to the wire is sensible, # but at least it gets you out of dealing with the wire itself. -from ._connection import * -from ._events import * -from ._state import * -from ._util import LocalProtocolError, ProtocolError, RemoteProtocolError -from ._version import __version__ +from h11._connection import Connection, NEED_DATA, PAUSED +from h11._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) +from h11._state import ( + CLIENT, + CLOSED, + DONE, + ERROR, + IDLE, + MIGHT_SWITCH_PROTOCOL, + MUST_CLOSE, + SEND_BODY, + SEND_RESPONSE, + SERVER, + SWITCHED_PROTOCOL, +) +from h11._util import LocalProtocolError, ProtocolError, RemoteProtocolError +from h11._version import __version__ PRODUCT_ID = "python-h11/" + __version__ -__all__ = ["ProtocolError", "LocalProtocolError", "RemoteProtocolError"] -__all__ += _events.__all__ -__all__ += _connection.__all__ -__all__ += _state.__all__ +__all__ = ( + "Connection", + "NEED_DATA", + "PAUSED", + "ConnectionClosed", + "Data", + "EndOfMessage", + "Event", + "InformationalResponse", + "Request", + "Response", + "CLIENT", + "CLOSED", + "DONE", + "ERROR", + "IDLE", + "MUST_CLOSE", + "SEND_BODY", + "SEND_RESPONSE", + "SERVER", + "SWITCHED_PROTOCOL", + "ProtocolError", + "LocalProtocolError", + "RemoteProtocolError", +) diff --git a/h11/_connection.py b/h11/_connection.py index bcd3089..0f3ed90 100644 --- a/h11/_connection.py +++ b/h11/_connection.py @@ -1,18 +1,38 @@ # This contains the main Connection class. Everything in h11 revolves around # this. - -from ._events import * # Import all event types +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union + +from ._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) from ._headers import get_comma_header, has_expect_100_continue, set_comma_header -from ._readers import READERS +from ._readers import READERS, ReadersType from ._receivebuffer import ReceiveBuffer -from ._state import * # Import all state sentinels -from ._state import _SWITCH_CONNECT, _SWITCH_UPGRADE, ConnectionState +from ._state import ( + _SWITCH_CONNECT, + _SWITCH_UPGRADE, + CLIENT, + ConnectionState, + DONE, + ERROR, + MIGHT_SWITCH_PROTOCOL, + SEND_BODY, + SERVER, + SWITCHED_PROTOCOL, +) from ._util import ( # Import the internal things we need LocalProtocolError, make_sentinel, RemoteProtocolError, + Sentinel, ) -from ._writers import WRITERS +from ._writers import WRITERS, WritersType # Everything in __all__ gets re-exported as part of the h11 public API. __all__ = ["Connection", "NEED_DATA", "PAUSED"] @@ -44,7 +64,7 @@ # our rule is: # - If someone says Connection: close, we will close # - If someone uses HTTP/1.0, we will close. -def _keep_alive(event): +def _keep_alive(event: Union[Request, Response]) -> bool: connection = get_comma_header(event.headers, b"connection") if b"close" in connection: return False @@ -53,7 +73,9 @@ def _keep_alive(event): return True -def _body_framing(request_method, event): +def _body_framing( + request_method: bytes, event: Union[Request, Response] +) -> Tuple[str, Union[Tuple[()], Tuple[int]]]: # Called when we enter SEND_BODY to figure out framing information for # this body. # @@ -126,8 +148,10 @@ class Connection: """ def __init__( - self, our_role, max_incomplete_event_size=DEFAULT_MAX_INCOMPLETE_EVENT_SIZE - ): + self, + our_role: Sentinel, + max_incomplete_event_size: int = DEFAULT_MAX_INCOMPLETE_EVENT_SIZE, + ) -> None: self._max_incomplete_event_size = max_incomplete_event_size # State and role tracking if our_role not in (CLIENT, SERVER): @@ -155,14 +179,14 @@ def __init__( # These two are only used to interpret framing headers for figuring # out how to read/write response bodies. their_http_version is also # made available as a convenient public API. - self.their_http_version = None - self._request_method = None + self.their_http_version: Optional[bytes] = None + self._request_method: Optional[bytes] = None # This is pure flow-control and doesn't at all affect the set of legal # transitions, so no need to bother ConnectionState with it: self.client_is_waiting_for_100_continue = False @property - def states(self): + def states(self) -> Dict[Sentinel, Sentinel]: """A dictionary like:: {CLIENT: , SERVER: } @@ -173,24 +197,24 @@ def states(self): return dict(self._cstate.states) @property - def our_state(self): + def our_state(self) -> Sentinel: """The current state of whichever role we are playing. See :ref:`state-machine` for details. """ return self._cstate.states[self.our_role] @property - def their_state(self): + def their_state(self) -> Sentinel: """The current state of whichever role we are NOT playing. See :ref:`state-machine` for details. """ return self._cstate.states[self.their_role] @property - def they_are_waiting_for_100_continue(self): + def they_are_waiting_for_100_continue(self) -> bool: return self.their_role is CLIENT and self.client_is_waiting_for_100_continue - def start_next_cycle(self): + def start_next_cycle(self) -> None: """Attempt to reset our connection state for a new request/response cycle. @@ -210,12 +234,12 @@ def start_next_cycle(self): assert not self.client_is_waiting_for_100_continue self._respond_to_state_changes(old_states) - def _process_error(self, role): + def _process_error(self, role: Sentinel) -> None: old_states = dict(self._cstate.states) self._cstate.process_error(role) self._respond_to_state_changes(old_states) - def _server_switch_event(self, event): + def _server_switch_event(self, event: Event) -> Optional[Sentinel]: if type(event) is InformationalResponse and event.status_code == 101: return _SWITCH_UPGRADE if type(event) is Response: @@ -227,7 +251,7 @@ def _server_switch_event(self, event): return None # All events go through here - def _process_event(self, role, event): + def _process_event(self, role: Sentinel, event: Event) -> None: # First, pass the event through the state machine to make sure it # succeeds. old_states = dict(self._cstate.states) @@ -243,16 +267,15 @@ def _process_event(self, role, event): # Then perform the updates triggered by it. - # self._request_method if type(event) is Request: self._request_method = event.method - # self.their_http_version if role is self.their_role and type(event) in ( Request, Response, InformationalResponse, ): + event = cast(Union[Request, Response, InformationalResponse], event) self.their_http_version = event.http_version # Keep alive handling @@ -261,7 +284,9 @@ def _process_event(self, role, event): # shows up on a 1xx InformationalResponse. I think the idea is that # this is not supposed to happen. In any case, if it does happen, we # ignore it. - if type(event) in (Request, Response) and not _keep_alive(event): + if type(event) in (Request, Response) and not _keep_alive( + cast(Union[Request, Response], event) + ): self._cstate.process_keep_alive_disabled() # 100-continue @@ -274,22 +299,31 @@ def _process_event(self, role, event): self._respond_to_state_changes(old_states, event) - def _get_io_object(self, role, event, io_dict): + def _get_io_object( + self, + role: Sentinel, + event: Optional[Event], + io_dict: Union[ReadersType, WritersType], + ) -> Optional[Callable[..., Any]]: # event may be None; it's only used when entering SEND_BODY state = self._cstate.states[role] if state is SEND_BODY: # Special case: the io_dict has a dict of reader/writer factories # that depend on the request/response framing. - framing_type, args = _body_framing(self._request_method, event) - return io_dict[SEND_BODY][framing_type](*args) + framing_type, args = _body_framing( + cast(bytes, self._request_method), cast(Union[Request, Response], event) + ) + return io_dict[SEND_BODY][framing_type](*args) # type: ignore[index] else: # General case: the io_dict just has the appropriate reader/writer # for this state - return io_dict.get((role, state)) + return io_dict.get((role, state)) # type: ignore[return-value] # This must be called after any action that might have caused # self._cstate.states to change. - def _respond_to_state_changes(self, old_states, event=None): + def _respond_to_state_changes( + self, old_states: Dict[Sentinel, Sentinel], event: Optional[Event] = None + ) -> None: # Update reader/writer if self.our_state != old_states[self.our_role]: self._writer = self._get_io_object(self.our_role, event, WRITERS) @@ -297,7 +331,7 @@ def _respond_to_state_changes(self, old_states, event=None): self._reader = self._get_io_object(self.their_role, event, READERS) @property - def trailing_data(self): + def trailing_data(self) -> Tuple[bytes, bool]: """Data that has been received, but not yet processed, represented as a tuple with two elements, where the first is a byte-string containing the unprocessed data itself, and the second is a bool that is True if @@ -307,7 +341,7 @@ def trailing_data(self): """ return (bytes(self._receive_buffer), self._receive_buffer_closed) - def receive_data(self, data): + def receive_data(self, data: bytes) -> None: """Add data to our internal receive buffer. This does not actually do any processing on the data, just stores @@ -353,7 +387,7 @@ def receive_data(self, data): else: self._receive_buffer_closed = True - def _extract_next_receive_event(self): + def _extract_next_receive_event(self) -> Union[Event, Sentinel]: state = self.their_state # We don't pause immediately when they enter DONE, because even in # DONE state we can still process a ConnectionClosed() event. But @@ -372,14 +406,14 @@ def _extract_next_receive_event(self): # return that event, and then the state will change and we'll # get called again to generate the actual ConnectionClosed(). if hasattr(self._reader, "read_eof"): - event = self._reader.read_eof() + event = self._reader.read_eof() # type: ignore[attr-defined] else: event = ConnectionClosed() if event is None: event = NEED_DATA - return event + return event # type: ignore[no-any-return] - def next_event(self): + def next_event(self) -> Union[Event, Sentinel]: """Parse the next event out of our receive buffer, update our internal state, and return it. @@ -424,7 +458,7 @@ def next_event(self): try: event = self._extract_next_receive_event() if event not in [NEED_DATA, PAUSED]: - self._process_event(self.their_role, event) + self._process_event(self.their_role, cast(Event, event)) if event is NEED_DATA: if len(self._receive_buffer) > self._max_incomplete_event_size: # 431 is "Request header fields too large" which is pretty @@ -444,7 +478,7 @@ def next_event(self): else: raise - def send(self, event): + def send(self, event: Event) -> Optional[bytes]: """Convert a high-level event into bytes that can be sent to the peer, while updating our internal state machine. @@ -471,7 +505,7 @@ def send(self, event): else: return b"".join(data_list) - def send_with_data_passthrough(self, event): + def send_with_data_passthrough(self, event: Event) -> Optional[List[bytes]]: """Identical to :meth:`send`, except that in situations where :meth:`send` returns a single :term:`bytes-like object`, this instead returns a list of them -- and when sending a :class:`Data` event, this @@ -497,14 +531,14 @@ def send_with_data_passthrough(self, event): # In any situation where writer is None, process_event should # have raised ProtocolError assert writer is not None - data_list = [] + data_list: List[bytes] = [] writer(event, data_list.append) return data_list except: self._process_error(self.our_role) raise - def send_failed(self): + def send_failed(self) -> None: """Notify the state machine that we failed to send the data it gave us. @@ -529,7 +563,7 @@ def send_failed(self): # This function's *only* responsibility is making sure headers are set up # right -- everything downstream just looks at the headers. There are no # side channels. - def _clean_up_response_headers_for_sending(self, response): + def _clean_up_response_headers_for_sending(self, response: Response) -> Response: assert type(response) is Response headers = response.headers @@ -542,7 +576,7 @@ def _clean_up_response_headers_for_sending(self, response): # we're allowed to leave out the framing headers -- see # https://tools.ietf.org/html/rfc7231#section-4.3.2 . But it's just as # easy to get them right.) - method_for_choosing_headers = self._request_method + method_for_choosing_headers = cast(bytes, self._request_method) if method_for_choosing_headers == b"HEAD": method_for_choosing_headers = b"GET" framing_type, _ = _body_framing(method_for_choosing_headers, response) @@ -572,7 +606,7 @@ def _clean_up_response_headers_for_sending(self, response): if self._request_method != b"HEAD": need_close = True else: - headers = set_comma_header(headers, b"transfer-encoding", ["chunked"]) + headers = set_comma_header(headers, b"transfer-encoding", [b"chunked"]) if not self._cstate.keep_alive or need_close: # Make sure Connection: close is set diff --git a/h11/_headers.py b/h11/_headers.py index e8f24d6..acc4596 100644 --- a/h11/_headers.py +++ b/h11/_headers.py @@ -1,9 +1,18 @@ import re -from collections.abc import Sequence +from typing import AnyStr, cast, List, overload, Sequence, Tuple, TYPE_CHECKING, Union from ._abnf import field_name, field_value from ._util import bytesify, LocalProtocolError, validate +if TYPE_CHECKING: + from ._events import Request + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal # type: ignore + + # Facts # ----- # @@ -63,7 +72,7 @@ _field_value_re = re.compile(field_value.encode("ascii")) -class Headers(Sequence): +class Headers(Sequence[Tuple[bytes, bytes]]): """ A list-like interface that allows iterating over headers as byte-pairs of (lowercased-name, value). @@ -90,30 +99,57 @@ class Headers(Sequence): __slots__ = "_full_items" - def __init__(self, full_items): + def __init__(self, full_items: List[Tuple[bytes, bytes, bytes]]) -> None: self._full_items = full_items - def __bool__(self): + def __bool__(self) -> bool: return bool(self._full_items) - def __eq__(self, other): - return list(self) == list(other) + def __eq__(self, other: object) -> bool: + return list(self) == list(other) # type: ignore - def __len__(self): + def __len__(self) -> int: return len(self._full_items) - def __repr__(self): + def __repr__(self) -> str: return "" % repr(list(self)) - def __getitem__(self, idx): + def __getitem__(self, idx: int) -> Tuple[bytes, bytes]: # type: ignore[override] _, name, value = self._full_items[idx] return (name, value) - def raw_items(self): + def raw_items(self) -> List[Tuple[bytes, bytes]]: return [(raw_name, value) for raw_name, _, value in self._full_items] -def normalize_and_validate(headers, _parsed=False): +HeaderTypes = Union[ + List[Tuple[bytes, bytes]], + List[Tuple[bytes, str]], + List[Tuple[str, bytes]], + List[Tuple[str, str]], +] + + +@overload +def normalize_and_validate(headers: Headers, _parsed: Literal[True]) -> Headers: + ... + + +@overload +def normalize_and_validate(headers: HeaderTypes, _parsed: Literal[False]) -> Headers: + ... + + +@overload +def normalize_and_validate( + headers: Union[Headers, HeaderTypes], _parsed: bool = False +) -> Headers: + ... + + +def normalize_and_validate( + headers: Union[Headers, HeaderTypes], _parsed: bool = False +) -> Headers: new_headers = [] seen_content_length = None saw_transfer_encoding = False @@ -126,6 +162,9 @@ def normalize_and_validate(headers, _parsed=False): value = bytesify(value) validate(_field_name_re, name, "Illegal header name {!r}", name) validate(_field_value_re, value, "Illegal header value {!r}", value) + assert isinstance(name, bytes) + assert isinstance(value, bytes) + raw_name = name name = name.lower() if name == b"content-length": @@ -163,7 +202,7 @@ def normalize_and_validate(headers, _parsed=False): return Headers(new_headers) -def get_comma_header(headers, name): +def get_comma_header(headers: Headers, name: bytes) -> List[bytes]: # Should only be used for headers whose value is a list of # comma-separated, case-insensitive values. # @@ -199,7 +238,7 @@ def get_comma_header(headers, name): # Expect: the only legal value is the literal string # "100-continue". Splitting on commas is harmless. Case insensitive. # - out = [] + out: List[bytes] = [] for _, found_name, found_raw_value in headers._full_items: if found_name == name: found_raw_value = found_raw_value.lower() @@ -210,7 +249,7 @@ def get_comma_header(headers, name): return out -def set_comma_header(headers, name, new_values): +def set_comma_header(headers: Headers, name: bytes, new_values: List[bytes]) -> Headers: # The header name `name` is expected to be lower-case bytes. # # Note that when we store the header we use title casing for the header @@ -220,7 +259,7 @@ def set_comma_header(headers, name, new_values): # here given the cases where we're using `set_comma_header`... # # Connection, Content-Length, Transfer-Encoding. - new_headers = [] + new_headers: List[Tuple[bytes, bytes]] = [] for found_raw_name, found_name, found_raw_value in headers._full_items: if found_name != name: new_headers.append((found_raw_name, found_raw_value)) @@ -229,7 +268,7 @@ def set_comma_header(headers, name, new_values): return normalize_and_validate(new_headers) -def has_expect_100_continue(request): +def has_expect_100_continue(request: "Request") -> bool: # https://tools.ietf.org/html/rfc7231#section-5.1.1 # "A server that receives a 100-continue expectation in an HTTP/1.0 request # MUST ignore that expectation." diff --git a/h11/_readers.py b/h11/_readers.py index 0ead0be..a036d79 100644 --- a/h11/_readers.py +++ b/h11/_readers.py @@ -17,11 +17,22 @@ # - or, for body readers, a dict of per-framing reader factories import re +from typing import Any, Callable, Dict, Iterable, NoReturn, Optional, Tuple, Type, Union from ._abnf import chunk_header, header_field, request_line, status_line -from ._events import * -from ._state import * -from ._util import LocalProtocolError, RemoteProtocolError, validate +from ._events import Data, EndOfMessage, InformationalResponse, Request, Response +from ._receivebuffer import ReceiveBuffer +from ._state import ( + CLIENT, + CLOSED, + DONE, + IDLE, + MUST_CLOSE, + SEND_BODY, + SEND_RESPONSE, + SERVER, +) +from ._util import LocalProtocolError, RemoteProtocolError, Sentinel, validate __all__ = ["READERS"] @@ -32,9 +43,9 @@ obs_fold_re = re.compile(br"[ \t]+") -def _obsolete_line_fold(lines): +def _obsolete_line_fold(lines: Iterable[bytes]) -> Iterable[bytes]: it = iter(lines) - last = None + last: Optional[bytes] = None for line in it: match = obs_fold_re.match(line) if match: @@ -52,7 +63,9 @@ def _obsolete_line_fold(lines): yield last -def _decode_header_lines(lines): +def _decode_header_lines( + lines: Iterable[bytes], +) -> Iterable[Tuple[bytes, bytes]]: for line in _obsolete_line_fold(lines): matches = validate(header_field_re, line, "illegal header line: {!r}", line) yield (matches["field_name"], matches["field_value"]) @@ -61,7 +74,7 @@ def _decode_header_lines(lines): request_line_re = re.compile(request_line.encode("ascii")) -def maybe_read_from_IDLE_client(buf): +def maybe_read_from_IDLE_client(buf: ReceiveBuffer) -> Optional[Request]: lines = buf.maybe_extract_lines() if lines is None: if buf.is_next_line_obviously_invalid_request_line(): @@ -80,7 +93,9 @@ def maybe_read_from_IDLE_client(buf): status_line_re = re.compile(status_line.encode("ascii")) -def maybe_read_from_SEND_RESPONSE_server(buf): +def maybe_read_from_SEND_RESPONSE_server( + buf: ReceiveBuffer, +) -> Union[InformationalResponse, Response, None]: lines = buf.maybe_extract_lines() if lines is None: if buf.is_next_line_obviously_invalid_request_line(): @@ -89,22 +104,29 @@ def maybe_read_from_SEND_RESPONSE_server(buf): if not lines: raise LocalProtocolError("no response line received") matches = validate(status_line_re, lines[0], "illegal status line: {!r}", lines[0]) - # Tolerate missing reason phrases - if matches["reason"] is None: - matches["reason"] = b"" - status_code = matches["status_code"] = int(matches["status_code"]) - class_ = InformationalResponse if status_code < 200 else Response + http_version = ( + b"1.1" if matches["http_version"] is None else matches["http_version"] + ) + reason = b"" if matches["reason"] is None else matches["reason"] + status_code = int(matches["status_code"]) + class_: Union[Type[InformationalResponse], Type[Response]] = ( + InformationalResponse if status_code < 200 else Response + ) return class_( - headers=list(_decode_header_lines(lines[1:])), _parsed=True, **matches + headers=list(_decode_header_lines(lines[1:])), + _parsed=True, + status_code=status_code, + reason=reason, + http_version=http_version, ) class ContentLengthReader: - def __init__(self, length): + def __init__(self, length: int) -> None: self._length = length self._remaining = length - def __call__(self, buf): + def __call__(self, buf: ReceiveBuffer) -> Union[Data, EndOfMessage, None]: if self._remaining == 0: return EndOfMessage() data = buf.maybe_extract_at_most(self._remaining) @@ -113,7 +135,7 @@ def __call__(self, buf): self._remaining -= len(data) return Data(data=data) - def read_eof(self): + def read_eof(self) -> NoReturn: raise RemoteProtocolError( "peer closed connection without sending complete message body " "(received {} bytes, expected {})".format( @@ -126,7 +148,7 @@ def read_eof(self): class ChunkedReader: - def __init__(self): + def __init__(self) -> None: self._bytes_in_chunk = 0 # After reading a chunk, we have to throw away the trailing \r\n; if # this is >0 then we discard that many bytes before resuming regular @@ -134,7 +156,7 @@ def __init__(self): self._bytes_to_discard = 0 self._reading_trailer = False - def __call__(self, buf): + def __call__(self, buf: ReceiveBuffer) -> Union[Data, EndOfMessage, None]: if self._reading_trailer: lines = buf.maybe_extract_lines() if lines is None: @@ -180,7 +202,7 @@ def __call__(self, buf): chunk_end = False return Data(data=data, chunk_start=chunk_start, chunk_end=chunk_end) - def read_eof(self): + def read_eof(self) -> NoReturn: raise RemoteProtocolError( "peer closed connection without sending complete message body " "(incomplete chunked read)" @@ -188,23 +210,28 @@ def read_eof(self): class Http10Reader: - def __call__(self, buf): + def __call__(self, buf: ReceiveBuffer) -> Optional[Data]: data = buf.maybe_extract_at_most(999999999) if data is None: return None return Data(data=data) - def read_eof(self): + def read_eof(self) -> EndOfMessage: return EndOfMessage() -def expect_nothing(buf): +def expect_nothing(buf: ReceiveBuffer) -> None: if buf: raise LocalProtocolError("Got data when expecting EOF") return None -READERS = { +ReadersType = Dict[ + Union[Sentinel, Tuple[Sentinel, Sentinel]], + Union[Callable[..., Any], Dict[str, Callable[..., Any]]], +] + +READERS: ReadersType = { (CLIENT, IDLE): maybe_read_from_IDLE_client, (SERVER, IDLE): maybe_read_from_SEND_RESPONSE_server, (SERVER, SEND_RESPONSE): maybe_read_from_SEND_RESPONSE_server, diff --git a/h11/_receivebuffer.py b/h11/_receivebuffer.py index a3737f3..e5c4e08 100644 --- a/h11/_receivebuffer.py +++ b/h11/_receivebuffer.py @@ -1,5 +1,6 @@ import re import sys +from typing import List, Optional, Union __all__ = ["ReceiveBuffer"] @@ -44,26 +45,26 @@ class ReceiveBuffer: - def __init__(self): + def __init__(self) -> None: self._data = bytearray() self._next_line_search = 0 self._multiple_lines_search = 0 - def __iadd__(self, byteslike): + def __iadd__(self, byteslike: Union[bytes, bytearray]) -> "ReceiveBuffer": self._data += byteslike return self - def __bool__(self): + def __bool__(self) -> bool: return bool(len(self)) - def __len__(self): + def __len__(self) -> int: return len(self._data) # for @property unprocessed_data - def __bytes__(self): + def __bytes__(self) -> bytes: return bytes(self._data) - def _extract(self, count): + def _extract(self, count: int) -> bytearray: # extracting an initial slice of the data buffer and return it out = self._data[:count] del self._data[:count] @@ -73,7 +74,7 @@ def _extract(self, count): return out - def maybe_extract_at_most(self, count): + def maybe_extract_at_most(self, count: int) -> Optional[bytearray]: """ Extract a fixed number of bytes from the buffer. """ @@ -83,7 +84,7 @@ def maybe_extract_at_most(self, count): return self._extract(count) - def maybe_extract_next_line(self): + def maybe_extract_next_line(self) -> Optional[bytearray]: """ Extract the first line, if it is completed in the buffer. """ @@ -100,7 +101,7 @@ def maybe_extract_next_line(self): return self._extract(idx) - def maybe_extract_lines(self): + def maybe_extract_lines(self) -> Optional[List[bytearray]]: """ Extract everything up to the first blank line, and return a list of lines. """ @@ -143,7 +144,7 @@ def maybe_extract_lines(self): # This is especially interesting when peer is messing up with HTTPS and # sent us a TLS stream where we were expecting plain HTTP given all # versions of TLS so far start handshake with a 0x16 message type code. - def is_next_line_obviously_invalid_request_line(self): + def is_next_line_obviously_invalid_request_line(self) -> bool: try: # HTTP header line must not contain non-printable characters # and should not start with a space diff --git a/h11/_state.py b/h11/_state.py index 0f08a09..aed1a33 100644 --- a/h11/_state.py +++ b/h11/_state.py @@ -110,9 +110,10 @@ # tables. But it can't automatically read the transitions that are written # directly in Python code. So if you touch those, you need to also update the # script to keep it in sync! +from typing import cast, Dict, Optional, Set, Tuple, Type, Union from ._events import * -from ._util import LocalProtocolError, make_sentinel +from ._util import LocalProtocolError, make_sentinel, Sentinel # Everything in __all__ gets re-exported as part of the h11 public API. __all__ = [ @@ -148,7 +149,12 @@ _SWITCH_UPGRADE = make_sentinel("_SWITCH_UPGRADE") _SWITCH_CONNECT = make_sentinel("_SWITCH_CONNECT") -EVENT_TRIGGERED_TRANSITIONS = { +EventTransitionType = Dict[ + Sentinel, + Dict[Sentinel, Dict[Union[Type[Event], Tuple[Type[Event], Sentinel]], Sentinel]], +] + +EVENT_TRIGGERED_TRANSITIONS: EventTransitionType = { CLIENT: { IDLE: {Request: SEND_BODY, ConnectionClosed: CLOSED}, SEND_BODY: {Data: SEND_BODY, EndOfMessage: DONE}, @@ -198,7 +204,7 @@ class ConnectionState: - def __init__(self): + def __init__(self) -> None: # Extra bits of state that don't quite fit into the state model. # If this is False then it enables the automatic DONE -> MUST_CLOSE @@ -207,23 +213,29 @@ def __init__(self): # This is a subset of {UPGRADE, CONNECT}, containing the proposals # made by the client for switching protocols. - self.pending_switch_proposals = set() + self.pending_switch_proposals: Set[Sentinel] = set() self.states = {CLIENT: IDLE, SERVER: IDLE} - def process_error(self, role): + def process_error(self, role: Sentinel) -> None: self.states[role] = ERROR self._fire_state_triggered_transitions() - def process_keep_alive_disabled(self): + def process_keep_alive_disabled(self) -> None: self.keep_alive = False self._fire_state_triggered_transitions() - def process_client_switch_proposal(self, switch_event): + def process_client_switch_proposal(self, switch_event: Sentinel) -> None: self.pending_switch_proposals.add(switch_event) self._fire_state_triggered_transitions() - def process_event(self, role, event_type, server_switch_event=None): + def process_event( + self, + role: Sentinel, + event_type: Type[Event], + server_switch_event: Optional[Sentinel] = None, + ) -> None: + _event_type: Union[Type[Event], Tuple[Type[Event], Sentinel]] = event_type if server_switch_event is not None: assert role is SERVER if server_switch_event not in self.pending_switch_proposals: @@ -232,22 +244,27 @@ def process_event(self, role, event_type, server_switch_event=None): server_switch_event ) ) - event_type = (event_type, server_switch_event) - if server_switch_event is None and event_type is Response: + _event_type = (event_type, server_switch_event) + if server_switch_event is None and _event_type is Response: self.pending_switch_proposals = set() - self._fire_event_triggered_transitions(role, event_type) + self._fire_event_triggered_transitions(role, _event_type) # Special case: the server state does get to see Request # events. - if event_type is Request: + if _event_type is Request: assert role is CLIENT self._fire_event_triggered_transitions(SERVER, (Request, CLIENT)) self._fire_state_triggered_transitions() - def _fire_event_triggered_transitions(self, role, event_type): + def _fire_event_triggered_transitions( + self, + role: Sentinel, + event_type: Union[Type[Event], Tuple[Type[Event], Sentinel]], + ) -> None: state = self.states[role] try: new_state = EVENT_TRIGGERED_TRANSITIONS[role][state][event_type] except KeyError: + event_type = cast(Type[Event], event_type) raise LocalProtocolError( "can't handle event type {} when role={} and state={}".format( event_type.__name__, role, self.states[role] @@ -255,7 +272,7 @@ def _fire_event_triggered_transitions(self, role, event_type): ) self.states[role] = new_state - def _fire_state_triggered_transitions(self): + def _fire_state_triggered_transitions(self) -> None: # We apply these rules repeatedly until converging on a fixed point while True: start_states = dict(self.states) @@ -295,7 +312,7 @@ def _fire_state_triggered_transitions(self): # Fixed point reached return - def start_next_cycle(self): + def start_next_cycle(self) -> None: if self.states != {CLIENT: DONE, SERVER: DONE}: raise LocalProtocolError( "not in a reusable state. self.states={}".format(self.states) diff --git a/h11/_util.py b/h11/_util.py index eb1a5cd..451dc7e 100644 --- a/h11/_util.py +++ b/h11/_util.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, NoReturn, Pattern, Union + __all__ = [ "ProtocolError", "LocalProtocolError", @@ -37,7 +39,7 @@ class ProtocolError(Exception): """ - def __init__(self, msg, error_status_hint=400): + def __init__(self, msg: str, error_status_hint: int = 400) -> None: if type(self) is ProtocolError: raise TypeError("tried to directly instantiate ProtocolError") Exception.__init__(self, msg) @@ -56,14 +58,14 @@ def __init__(self, msg, error_status_hint=400): # LocalProtocolError is for local errors and RemoteProtocolError is for # remote errors. class LocalProtocolError(ProtocolError): - def _reraise_as_remote_protocol_error(self): + def _reraise_as_remote_protocol_error(self) -> NoReturn: # After catching a LocalProtocolError, use this method to re-raise it # as a RemoteProtocolError. This method must be called from inside an # except: block. # # An easy way to get an equivalent RemoteProtocolError is just to # modify 'self' in place. - self.__class__ = RemoteProtocolError + self.__class__ = RemoteProtocolError # type: ignore # But the re-raising is somewhat non-trivial -- you might think that # now that we've modified the in-flight exception object, that just # doing 'raise' to re-raise it would be enough. But it turns out that @@ -80,7 +82,9 @@ class RemoteProtocolError(ProtocolError): pass -def validate(regex, data, msg="malformed data", *format_args): +def validate( + regex: Pattern[bytes], data: bytes, msg: str = "malformed data", *format_args: Any +) -> Dict[str, bytes]: match = regex.fullmatch(data) if not match: if format_args: @@ -97,21 +101,21 @@ def validate(regex, data, msg="malformed data", *format_args): # # The bonus property is useful if you want to take the return value from # next_event() and do some sort of dispatch based on type(event). -class _SentinelBase(type): - def __repr__(self): +class Sentinel(type): + def __repr__(self) -> str: return self.__name__ -def make_sentinel(name): - cls = _SentinelBase(name, (_SentinelBase,), {}) - cls.__class__ = cls +def make_sentinel(name: str) -> Sentinel: + cls = Sentinel(name, (Sentinel,), {}) + cls.__class__ = cls # type: ignore return cls # Used for methods, request targets, HTTP versions, header names, and header # values. Accepts ascii-strings, or bytes/bytearray/memoryview/..., and always # returns bytes. -def bytesify(s): +def bytesify(s: Union[bytes, bytearray, memoryview, int, str]) -> bytes: # Fast-path: if type(s) is bytes: return s diff --git a/h11/_writers.py b/h11/_writers.py index cb5e8a8..90a8dc0 100644 --- a/h11/_writers.py +++ b/h11/_writers.py @@ -7,14 +7,19 @@ # - a writer # - or, for body writers, a dict of framin-dependent writer factories -from ._events import Data, EndOfMessage +from typing import Any, Callable, Dict, List, Tuple, Type, Union + +from ._events import Data, EndOfMessage, Event, InformationalResponse, Request, Response +from ._headers import Headers from ._state import CLIENT, IDLE, SEND_BODY, SEND_RESPONSE, SERVER -from ._util import LocalProtocolError +from ._util import LocalProtocolError, Sentinel __all__ = ["WRITERS"] +Writer = Callable[[bytes], Any] + -def write_headers(headers, write): +def write_headers(headers: Headers, write: Writer) -> None: # "Since the Host field-value is critical information for handling a # request, a user agent SHOULD generate Host as the first header field # following the request-line." - RFC 7230 @@ -28,7 +33,7 @@ def write_headers(headers, write): write(b"\r\n") -def write_request(request, write): +def write_request(request: Request, write: Writer) -> None: if request.http_version != b"1.1": raise LocalProtocolError("I only send HTTP/1.1") write(b"%s %s HTTP/1.1\r\n" % (request.method, request.target)) @@ -36,7 +41,9 @@ def write_request(request, write): # Shared between InformationalResponse and Response -def write_any_response(response, write): +def write_any_response( + response: Union[InformationalResponse, Response], write: Writer +) -> None: if response.http_version != b"1.1": raise LocalProtocolError("I only send HTTP/1.1") status_bytes = str(response.status_code).encode("ascii") @@ -53,7 +60,7 @@ def write_any_response(response, write): class BodyWriter: - def __call__(self, event, write): + def __call__(self, event: Event, write: Writer) -> None: if type(event) is Data: self.send_data(event.data, write) elif type(event) is EndOfMessage: @@ -61,6 +68,12 @@ def __call__(self, event, write): else: # pragma: no cover assert False + def send_data(self, data: bytes, write: Writer) -> None: + pass + + def send_eom(self, headers: Headers, write: Writer) -> None: + pass + # # These are all careful not to do anything to 'data' except call len(data) and @@ -69,16 +82,16 @@ def __call__(self, event, write): # sendfile(2). # class ContentLengthWriter(BodyWriter): - def __init__(self, length): + def __init__(self, length: int) -> None: self._length = length - def send_data(self, data, write): + def send_data(self, data: bytes, write: Writer) -> None: self._length -= len(data) if self._length < 0: raise LocalProtocolError("Too much data for declared Content-Length") write(data) - def send_eom(self, headers, write): + def send_eom(self, headers: Headers, write: Writer) -> None: if self._length != 0: raise LocalProtocolError("Too little data for declared Content-Length") if headers: @@ -86,7 +99,7 @@ def send_eom(self, headers, write): class ChunkedWriter(BodyWriter): - def send_data(self, data, write): + def send_data(self, data: bytes, write: Writer) -> None: # if we encoded 0-length data in the naive way, it would look like an # end-of-message. if not data: @@ -95,23 +108,32 @@ def send_data(self, data, write): write(data) write(b"\r\n") - def send_eom(self, headers, write): + def send_eom(self, headers: Headers, write: Writer) -> None: write(b"0\r\n") write_headers(headers, write) class Http10Writer(BodyWriter): - def send_data(self, data, write): + def send_data(self, data: bytes, write: Writer) -> None: write(data) - def send_eom(self, headers, write): + def send_eom(self, headers: Headers, write: Writer) -> None: if headers: raise LocalProtocolError("can't send trailers to HTTP/1.0 client") # no need to close the socket ourselves, that will be taken care of by # Connection: close machinery -WRITERS = { +WritersType = Dict[ + Union[Tuple[Sentinel, Sentinel], Sentinel], + Union[ + Dict[str, Type[BodyWriter]], + Callable[[Union[InformationalResponse, Response], Writer], None], + Callable[[Request, Writer], None], + ], +] + +WRITERS: WritersType = { (CLIENT, IDLE): write_request, (SERVER, IDLE): write_any_response, (SERVER, SEND_RESPONSE): write_any_response, diff --git a/h11/tests/helpers.py b/h11/tests/helpers.py index 5f53457..c9f9721 100644 --- a/h11/tests/helpers.py +++ b/h11/tests/helpers.py @@ -1,29 +1,46 @@ -from .._connection import * -from .._events import * -from .._state import * +from typing import cast, List, Union, ValuesView +from .._connection import Connection, NEED_DATA, PAUSED +from .._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) +from .._state import CLIENT, CLOSED, DONE, MUST_CLOSE, SERVER +from .._util import Sentinel -def get_all_events(conn): +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal # type: ignore + + +def get_all_events(conn: Connection) -> List[Event]: got_events = [] while True: event = conn.next_event() if event in (NEED_DATA, PAUSED): break + event = cast(Event, event) got_events.append(event) if type(event) is ConnectionClosed: break return got_events -def receive_and_get(conn, data): +def receive_and_get(conn: Connection, data: bytes) -> List[Event]: conn.receive_data(data) return get_all_events(conn) # Merges adjacent Data events, converts payloads to bytestrings, and removes # chunk boundaries. -def normalize_data_events(in_events): - out_events = [] +def normalize_data_events(in_events: List[Event]) -> List[Event]: + out_events: List[Event] = [] for event in in_events: if type(event) is Data: event = Data(data=bytes(event.data), chunk_start=False, chunk_end=False) @@ -43,16 +60,21 @@ def normalize_data_events(in_events): # of pushing them through two Connections with a fake network link in # between. class ConnectionPair: - def __init__(self): + def __init__(self) -> None: self.conn = {CLIENT: Connection(CLIENT), SERVER: Connection(SERVER)} self.other = {CLIENT: SERVER, SERVER: CLIENT} @property - def conns(self): + def conns(self) -> ValuesView[Connection]: return self.conn.values() # expect="match" if expect=send_events; expect=[...] to say what expected - def send(self, role, send_events, expect="match"): + def send( + self, + role: Sentinel, + send_events: Union[List[Event], Event], + expect: Union[List[Event], Event, Literal["match"]] = "match", + ) -> bytes: if not isinstance(send_events, list): send_events = [send_events] data = b"" diff --git a/h11/tests/test_against_stdlib_http.py b/h11/tests/test_against_stdlib_http.py index e6c5db4..d2ee131 100644 --- a/h11/tests/test_against_stdlib_http.py +++ b/h11/tests/test_against_stdlib_http.py @@ -5,13 +5,16 @@ import threading from contextlib import closing, contextmanager from http.server import SimpleHTTPRequestHandler +from typing import Callable, Generator from urllib.request import urlopen import h11 @contextmanager -def socket_server(handler): +def socket_server( + handler: Callable[..., socketserver.BaseRequestHandler] +) -> Generator[socketserver.TCPServer, None, None]: httpd = socketserver.TCPServer(("127.0.0.1", 0), handler) thread = threading.Thread( target=httpd.serve_forever, kwargs={"poll_interval": 0.01} @@ -30,23 +33,23 @@ def socket_server(handler): class SingleMindedRequestHandler(SimpleHTTPRequestHandler): - def translate_path(self, path): + def translate_path(self, path: str) -> str: return test_file_path -def test_h11_as_client(): +def test_h11_as_client() -> None: with socket_server(SingleMindedRequestHandler) as httpd: with closing(socket.create_connection(httpd.server_address)) as s: c = h11.Connection(h11.CLIENT) s.sendall( - c.send( + c.send( # type: ignore[arg-type] h11.Request( method="GET", target="/foo", headers=[("Host", "localhost")] ) ) ) - s.sendall(c.send(h11.EndOfMessage())) + s.sendall(c.send(h11.EndOfMessage())) # type: ignore[arg-type] data = bytearray() while True: @@ -67,7 +70,7 @@ def test_h11_as_client(): class H11RequestHandler(socketserver.BaseRequestHandler): - def handle(self): + def handle(self) -> None: with closing(self.request) as s: c = h11.Connection(h11.SERVER) request = None @@ -82,6 +85,7 @@ def handle(self): request = event if type(event) is h11.EndOfMessage: break + assert request is not None info = json.dumps( { "method": request.method.decode("ascii"), @@ -92,12 +96,12 @@ def handle(self): }, } ) - s.sendall(c.send(h11.Response(status_code=200, headers=[]))) + s.sendall(c.send(h11.Response(status_code=200, headers=[]))) # type: ignore[arg-type] s.sendall(c.send(h11.Data(data=info.encode("ascii")))) s.sendall(c.send(h11.EndOfMessage())) -def test_h11_as_server(): +def test_h11_as_server() -> None: with socket_server(H11RequestHandler) as httpd: host, port = httpd.server_address url = "http://{}:{}/some-path".format(host, port) diff --git a/h11/tests/test_connection.py b/h11/tests/test_connection.py index de175de..af84883 100644 --- a/h11/tests/test_connection.py +++ b/h11/tests/test_connection.py @@ -1,13 +1,35 @@ +from typing import Any, cast, Dict, List, Optional, Tuple + import pytest from .._connection import _body_framing, _keep_alive, Connection, NEED_DATA, PAUSED -from .._events import * -from .._state import * -from .._util import LocalProtocolError, RemoteProtocolError +from .._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) +from .._state import ( + CLIENT, + CLOSED, + DONE, + ERROR, + IDLE, + MIGHT_SWITCH_PROTOCOL, + MUST_CLOSE, + SEND_BODY, + SEND_RESPONSE, + SERVER, + SWITCHED_PROTOCOL, +) +from .._util import LocalProtocolError, RemoteProtocolError, Sentinel from .helpers import ConnectionPair, get_all_events, receive_and_get -def test__keep_alive(): +def test__keep_alive() -> None: assert _keep_alive( Request(method="GET", target="/", headers=[("Host", "Example.com")]) ) @@ -26,19 +48,19 @@ def test__keep_alive(): ) ) assert not _keep_alive( - Request(method="GET", target="/", headers=[], http_version="1.0") + Request(method="GET", target="/", headers=[], http_version="1.0") # type: ignore[arg-type] ) - assert _keep_alive(Response(status_code=200, headers=[])) + assert _keep_alive(Response(status_code=200, headers=[])) # type: ignore[arg-type] assert not _keep_alive(Response(status_code=200, headers=[("Connection", "close")])) assert not _keep_alive( Response(status_code=200, headers=[("Connection", "a, b, cLOse, foo")]) ) - assert not _keep_alive(Response(status_code=200, headers=[], http_version="1.0")) + assert not _keep_alive(Response(status_code=200, headers=[], http_version="1.0")) # type: ignore[arg-type] -def test__body_framing(): - def headers(cl, te): +def test__body_framing() -> None: + def headers(cl: Optional[int], te: bool) -> List[Tuple[str, str]]: headers = [] if cl is not None: headers.append(("Content-Length", str(cl))) @@ -46,16 +68,19 @@ def headers(cl, te): headers.append(("Transfer-Encoding", "chunked")) return headers - def resp(status_code=200, cl=None, te=False): + def resp( + status_code: int = 200, cl: Optional[int] = None, te: bool = False + ) -> Response: return Response(status_code=status_code, headers=headers(cl, te)) - def req(cl=None, te=False): + def req(cl: Optional[int] = None, te: bool = False) -> Request: h = headers(cl, te) h += [("Host", "example.com")] return Request(method="GET", target="/", headers=h) # Special cases where the headers are ignored: for kwargs in [{}, {"cl": 100}, {"te": True}, {"cl": 100, "te": True}]: + kwargs = cast(Dict[str, Any], kwargs) for meth, r in [ (b"HEAD", resp(**kwargs)), (b"GET", resp(status_code=204, **kwargs)), @@ -65,21 +90,22 @@ def req(cl=None, te=False): # Transfer-encoding for kwargs in [{"te": True}, {"cl": 100, "te": True}]: - for meth, r in [(None, req(**kwargs)), (b"GET", resp(**kwargs))]: + kwargs = cast(Dict[str, Any], kwargs) + for meth, r in [(None, req(**kwargs)), (b"GET", resp(**kwargs))]: # type: ignore assert _body_framing(meth, r) == ("chunked", ()) # Content-Length - for meth, r in [(None, req(cl=100)), (b"GET", resp(cl=100))]: + for meth, r in [(None, req(cl=100)), (b"GET", resp(cl=100))]: # type: ignore assert _body_framing(meth, r) == ("content-length", (100,)) # No headers - assert _body_framing(None, req()) == ("content-length", (0,)) + assert _body_framing(None, req()) == ("content-length", (0,)) # type: ignore assert _body_framing(b"GET", resp()) == ("http/1.0", ()) -def test_Connection_basics_and_content_length(): +def test_Connection_basics_and_content_length() -> None: with pytest.raises(ValueError): - Connection("CLIENT") + Connection("CLIENT") # type: ignore p = ConnectionPair() assert p.conn[CLIENT].our_role is CLIENT @@ -109,7 +135,7 @@ def test_Connection_basics_and_content_length(): assert p.conn[CLIENT].their_http_version is None assert p.conn[SERVER].their_http_version == b"1.1" - data = p.send(SERVER, InformationalResponse(status_code=100, headers=[])) + data = p.send(SERVER, InformationalResponse(status_code=100, headers=[])) # type: ignore[arg-type] assert data == b"HTTP/1.1 100 \r\n\r\n" data = p.send(SERVER, Response(status_code=200, headers=[("Content-Length", "11")])) @@ -144,7 +170,7 @@ def test_Connection_basics_and_content_length(): assert conn.states == {CLIENT: DONE, SERVER: DONE} -def test_chunked(): +def test_chunked() -> None: p = ConnectionPair() p.send( @@ -175,7 +201,7 @@ def test_chunked(): assert conn.states == {CLIENT: DONE, SERVER: DONE} -def test_chunk_boundaries(): +def test_chunk_boundaries() -> None: conn = Connection(our_role=SERVER) request = ( @@ -214,14 +240,14 @@ def test_chunk_boundaries(): assert conn.next_event() == EndOfMessage() -def test_client_talking_to_http10_server(): +def test_client_talking_to_http10_server() -> None: c = Connection(CLIENT) c.send(Request(method="GET", target="/", headers=[("Host", "example.com")])) c.send(EndOfMessage()) assert c.our_state is DONE # No content-length, so Http10 framing for body assert receive_and_get(c, b"HTTP/1.0 200 OK\r\n\r\n") == [ - Response(status_code=200, headers=[], http_version="1.0", reason=b"OK") + Response(status_code=200, headers=[], http_version="1.0", reason=b"OK") # type: ignore[arg-type] ] assert c.our_state is MUST_CLOSE assert receive_and_get(c, b"12345") == [Data(data=b"12345")] @@ -230,19 +256,19 @@ def test_client_talking_to_http10_server(): assert c.their_state is CLOSED -def test_server_talking_to_http10_client(): +def test_server_talking_to_http10_client() -> None: c = Connection(SERVER) # No content-length, so no body # NB: no host header assert receive_and_get(c, b"GET / HTTP/1.0\r\n\r\n") == [ - Request(method="GET", target="/", headers=[], http_version="1.0"), + Request(method="GET", target="/", headers=[], http_version="1.0"), # type: ignore[arg-type] EndOfMessage(), ] assert c.their_state is MUST_CLOSE # We automatically Connection: close back at them assert ( - c.send(Response(status_code=200, headers=[])) + c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] == b"HTTP/1.1 200 \r\nConnection: close\r\n\r\n" ) @@ -267,7 +293,7 @@ def test_server_talking_to_http10_client(): assert receive_and_get(c, b"") == [ConnectionClosed()] -def test_automatic_transfer_encoding_in_response(): +def test_automatic_transfer_encoding_in_response() -> None: # Check that in responses, the user can specify either Transfer-Encoding: # chunked or no framing at all, and in both cases we automatically select # the right option depending on whether the peer speaks HTTP/1.0 or @@ -279,6 +305,7 @@ def test_automatic_transfer_encoding_in_response(): # because if both are set then Transfer-Encoding wins [("Transfer-Encoding", "chunked"), ("Content-Length", "100")], ]: + user_headers = cast(List[Tuple[str, str]], user_headers) p = ConnectionPair() p.send( CLIENT, @@ -308,7 +335,7 @@ def test_automatic_transfer_encoding_in_response(): assert c.send(Data(data=b"12345")) == b"12345" -def test_automagic_connection_close_handling(): +def test_automagic_connection_close_handling() -> None: p = ConnectionPair() # If the user explicitly sets Connection: close, then we notice and # respect it @@ -329,7 +356,7 @@ def test_automagic_connection_close_handling(): p.send( SERVER, # no header here... - [Response(status_code=204, headers=[]), EndOfMessage()], + [Response(status_code=204, headers=[]), EndOfMessage()], # type: ignore[arg-type] # ...but oh look, it arrived anyway expect=[ Response(status_code=204, headers=[("connection", "close")]), @@ -340,8 +367,8 @@ def test_automagic_connection_close_handling(): assert conn.states == {CLIENT: MUST_CLOSE, SERVER: MUST_CLOSE} -def test_100_continue(): - def setup(): +def test_100_continue() -> None: + def setup() -> ConnectionPair: p = ConnectionPair() p.send( CLIENT, @@ -363,7 +390,7 @@ def setup(): # Disabled by 100 Continue p = setup() - p.send(SERVER, InformationalResponse(status_code=100, headers=[])) + p.send(SERVER, InformationalResponse(status_code=100, headers=[])) # type: ignore[arg-type] for conn in p.conns: assert not conn.client_is_waiting_for_100_continue assert not conn.they_are_waiting_for_100_continue @@ -385,7 +412,7 @@ def setup(): assert not conn.they_are_waiting_for_100_continue -def test_max_incomplete_event_size_countermeasure(): +def test_max_incomplete_event_size_countermeasure() -> None: # Infinitely long headers are definitely not okay c = Connection(SERVER) c.receive_data(b"GET / HTTP/1.0\r\nEndless: ") @@ -444,7 +471,7 @@ def test_max_incomplete_event_size_countermeasure(): # Even more data comes in, still no problem c.receive_data(b"X" * 1000) # We can respond and reuse to get the second pipelined request - c.send(Response(status_code=200, headers=[])) + c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] c.send(EndOfMessage()) c.start_next_cycle() assert get_all_events(c) == [ @@ -454,14 +481,14 @@ def test_max_incomplete_event_size_countermeasure(): # But once we unpause and try to read the next message, and find that it's # incomplete and the buffer is *still* way too large, then *that's* a # problem: - c.send(Response(status_code=200, headers=[])) + c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] c.send(EndOfMessage()) c.start_next_cycle() with pytest.raises(RemoteProtocolError): c.next_event() -def test_reuse_simple(): +def test_reuse_simple() -> None: p = ConnectionPair() p.send( CLIENT, @@ -494,7 +521,7 @@ def test_reuse_simple(): ) -def test_pipelining(): +def test_pipelining() -> None: # Client doesn't support pipelining, so we have to do this by hand c = Connection(SERVER) assert c.next_event() is NEED_DATA @@ -520,7 +547,7 @@ def test_pipelining(): assert c.next_event() is PAUSED - c.send(Response(status_code=200, headers=[])) + c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] c.send(EndOfMessage()) assert c.their_state is DONE assert c.our_state is DONE @@ -537,7 +564,7 @@ def test_pipelining(): EndOfMessage(), ] assert c.next_event() is PAUSED - c.send(Response(status_code=200, headers=[])) + c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] c.send(EndOfMessage()) c.start_next_cycle() @@ -547,7 +574,7 @@ def test_pipelining(): ] # Doesn't pause this time, no trailing data assert c.next_event() is NEED_DATA - c.send(Response(status_code=200, headers=[])) + c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] c.send(EndOfMessage()) # Arrival of more data triggers pause @@ -566,7 +593,7 @@ def test_pipelining(): c.receive_data(b"FDSA") -def test_protocol_switch(): +def test_protocol_switch() -> None: for (req, deny, accept) in [ ( Request( @@ -608,7 +635,7 @@ def test_protocol_switch(): ), ]: - def setup(): + def setup() -> ConnectionPair: p = ConnectionPair() p.send(CLIENT, req) # No switch-related state change stuff yet; the client has to @@ -656,7 +683,7 @@ def setup(): sc.send(EndOfMessage()) sc.start_next_cycle() assert get_all_events(sc) == [ - Request(method="GET", target="/", headers=[], http_version="1.0"), + Request(method="GET", target="/", headers=[], http_version="1.0"), # type: ignore[arg-type] EndOfMessage(), ] @@ -673,7 +700,7 @@ def setup(): p = setup() sc = p.conn[SERVER] - sc.receive_data(b"") == [] + sc.receive_data(b"") assert sc.next_event() is PAUSED sc.send(deny) assert sc.next_event() == ConnectionClosed() @@ -691,12 +718,12 @@ def setup(): p.conn[SERVER].send(Data(data=b"123")) -def test_close_simple(): +def test_close_simple() -> None: # Just immediately closing a new connection without anything having # happened yet. for (who_shot_first, who_shot_second) in [(CLIENT, SERVER), (SERVER, CLIENT)]: - def setup(): + def setup() -> ConnectionPair: p = ConnectionPair() p.send(who_shot_first, ConnectionClosed()) for conn in p.conns: @@ -732,7 +759,7 @@ def setup(): p.conn[who_shot_first].next_event() -def test_close_different_states(): +def test_close_different_states() -> None: req = [ Request(method="GET", target="/foo", headers=[("Host", "a")]), EndOfMessage(), @@ -798,7 +825,7 @@ def test_close_different_states(): # Receive several requests and then client shuts down their side of the # connection; we can respond to each -def test_pipelined_close(): +def test_pipelined_close() -> None: c = Connection(SERVER) # 2 requests then a close c.receive_data( @@ -818,7 +845,7 @@ def test_pipelined_close(): EndOfMessage(), ] assert c.states[CLIENT] is DONE - c.send(Response(status_code=200, headers=[])) + c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] c.send(EndOfMessage()) assert c.states[SERVER] is DONE c.start_next_cycle() @@ -833,21 +860,23 @@ def test_pipelined_close(): ConnectionClosed(), ] assert c.states == {CLIENT: CLOSED, SERVER: SEND_RESPONSE} - c.send(Response(status_code=200, headers=[])) + c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] c.send(EndOfMessage()) assert c.states == {CLIENT: CLOSED, SERVER: MUST_CLOSE} c.send(ConnectionClosed()) assert c.states == {CLIENT: CLOSED, SERVER: CLOSED} -def test_sendfile(): +def test_sendfile() -> None: class SendfilePlaceholder: - def __len__(self): + def __len__(self) -> int: return 10 placeholder = SendfilePlaceholder() - def setup(header, http_version): + def setup( + header: Tuple[str, str], http_version: str + ) -> Tuple[Connection, Optional[List[bytes]]]: c = Connection(SERVER) receive_and_get( c, "GET / HTTP/{}\r\nHost: a\r\n\r\n".format(http_version).encode("ascii") @@ -856,25 +885,25 @@ def setup(header, http_version): if header: headers.append(header) c.send(Response(status_code=200, headers=headers)) - return c, c.send_with_data_passthrough(Data(data=placeholder)) + return c, c.send_with_data_passthrough(Data(data=placeholder)) # type: ignore c, data = setup(("Content-Length", "10"), "1.1") - assert data == [placeholder] + assert data == [placeholder] # type: ignore # Raises an error if the connection object doesn't think we've sent # exactly 10 bytes c.send(EndOfMessage()) _, data = setup(("Transfer-Encoding", "chunked"), "1.1") - assert placeholder in data - data[data.index(placeholder)] = b"x" * 10 - assert b"".join(data) == b"a\r\nxxxxxxxxxx\r\n" + assert placeholder in data # type: ignore + data[data.index(placeholder)] = b"x" * 10 # type: ignore + assert b"".join(data) == b"a\r\nxxxxxxxxxx\r\n" # type: ignore - c, data = setup(None, "1.0") - assert data == [placeholder] + c, data = setup(None, "1.0") # type: ignore + assert data == [placeholder] # type: ignore assert c.our_state is SEND_BODY -def test_errors(): +def test_errors() -> None: # After a receive error, you can't receive for role in [CLIENT, SERVER]: c = Connection(our_role=role) @@ -890,14 +919,14 @@ def test_errors(): # But we can still yell at the client for sending us gibberish if role is SERVER: assert ( - c.send(Response(status_code=400, headers=[])) + c.send(Response(status_code=400, headers=[])) # type: ignore[arg-type] == b"HTTP/1.1 400 \r\nConnection: close\r\n\r\n" ) # After an error sending, you can no longer send # (This is especially important for things like content-length errors, # where there's complex internal state being modified) - def conn(role): + def conn(role: Sentinel) -> Connection: c = Connection(our_role=role) if role is SERVER: # Put it into the state where it *could* send a response... @@ -917,8 +946,8 @@ def conn(role): http_version="1.0", ) elif role is SERVER: - good = Response(status_code=200, headers=[]) - bad = Response(status_code=200, headers=[], http_version="1.0") + good = Response(status_code=200, headers=[]) # type: ignore[arg-type,assignment] + bad = Response(status_code=200, headers=[], http_version="1.0") # type: ignore[arg-type,assignment] # Make sure 'good' actually is good c = conn(role) c.send(good) @@ -944,14 +973,14 @@ def conn(role): assert c.their_state is not ERROR -def test_idle_receive_nothing(): +def test_idle_receive_nothing() -> None: # At one point this incorrectly raised an error for role in [CLIENT, SERVER]: c = Connection(role) assert c.next_event() is NEED_DATA -def test_connection_drop(): +def test_connection_drop() -> None: c = Connection(SERVER) c.receive_data(b"GET /") assert c.next_event() is NEED_DATA @@ -960,7 +989,7 @@ def test_connection_drop(): c.next_event() -def test_408_request_timeout(): +def test_408_request_timeout() -> None: # Should be able to send this spontaneously as a server without seeing # anything from client p = ConnectionPair() @@ -968,7 +997,7 @@ def test_408_request_timeout(): # This used to raise IndexError -def test_empty_request(): +def test_empty_request() -> None: c = Connection(SERVER) c.receive_data(b"\r\n") with pytest.raises(RemoteProtocolError): @@ -976,7 +1005,7 @@ def test_empty_request(): # This used to raise IndexError -def test_empty_response(): +def test_empty_response() -> None: c = Connection(CLIENT) c.send(Request(method="GET", target="/", headers=[("Host", "a")])) c.receive_data(b"\r\n") @@ -992,7 +1021,7 @@ def test_empty_response(): b"\x16\x03\x01\x00\xa5", # Typical start of a TLS Client Hello ], ) -def test_early_detection_of_invalid_request(data): +def test_early_detection_of_invalid_request(data: bytes) -> None: c = Connection(SERVER) # Early detection should occur before even receiving a `\r\n` c.receive_data(data) @@ -1008,7 +1037,7 @@ def test_early_detection_of_invalid_request(data): b"\x16\x03\x03\x00\x31", # Typical start of a TLS Server Hello ], ) -def test_early_detection_of_invalid_response(data): +def test_early_detection_of_invalid_response(data: bytes) -> None: c = Connection(CLIENT) # Early detection should occur before even receiving a `\r\n` c.receive_data(data) @@ -1020,8 +1049,8 @@ def test_early_detection_of_invalid_response(data): # The correct way to handle HEAD is to put whatever headers we *would* have # put if it were a GET -- even though we know that for HEAD, those headers # will be ignored. -def test_HEAD_framing_headers(): - def setup(method, http_version): +def test_HEAD_framing_headers() -> None: + def setup(method: bytes, http_version: bytes) -> Connection: c = Connection(SERVER) c.receive_data( method + b" / HTTP/" + http_version + b"\r\n" + b"Host: example.com\r\n\r\n" @@ -1034,14 +1063,14 @@ def setup(method, http_version): # No Content-Length, HTTP/1.1 peer, should use chunked c = setup(method, b"1.1") assert ( - c.send(Response(status_code=200, headers=[])) == b"HTTP/1.1 200 \r\n" + c.send(Response(status_code=200, headers=[])) == b"HTTP/1.1 200 \r\n" # type: ignore[arg-type] b"Transfer-Encoding: chunked\r\n\r\n" ) # No Content-Length, HTTP/1.0 peer, frame with connection: close c = setup(method, b"1.0") assert ( - c.send(Response(status_code=200, headers=[])) == b"HTTP/1.1 200 \r\n" + c.send(Response(status_code=200, headers=[])) == b"HTTP/1.1 200 \r\n" # type: ignore[arg-type] b"Connection: close\r\n\r\n" ) @@ -1062,7 +1091,7 @@ def setup(method, http_version): ) -def test_special_exceptions_for_lost_connection_in_message_body(): +def test_special_exceptions_for_lost_connection_in_message_body() -> None: c = Connection(SERVER) c.receive_data( b"POST / HTTP/1.1\r\n" b"Host: example.com\r\n" b"Content-Length: 100\r\n\r\n" @@ -1086,7 +1115,7 @@ def test_special_exceptions_for_lost_connection_in_message_body(): assert type(c.next_event()) is Request assert c.next_event() is NEED_DATA c.receive_data(b"8\r\n012345") - assert c.next_event().data == b"012345" + assert c.next_event().data == b"012345" # type: ignore c.receive_data(b"") with pytest.raises(RemoteProtocolError) as excinfo: c.next_event() diff --git a/h11/tests/test_events.py b/h11/tests/test_events.py index 4748c4b..64b6808 100644 --- a/h11/tests/test_events.py +++ b/h11/tests/test_events.py @@ -3,11 +3,19 @@ import pytest from .. import _events -from .._events import * +from .._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) from .._util import LocalProtocolError -def test_events(): +def test_events() -> None: with pytest.raises(LocalProtocolError): # Missing Host: req = Request( @@ -68,9 +76,9 @@ def test_events(): ) # Request target is validated - for bad_char in b"\x00\x20\x7f\xee": + for bad_byte in b"\x00\x20\x7f\xee": target = bytearray(b"/") - target.append(bad_char) + target.append(bad_byte) with pytest.raises(LocalProtocolError): Request( method="GET", target=target, headers=[("Host", "a")], http_version="1.1" @@ -84,19 +92,19 @@ def test_events(): with pytest.raises(LocalProtocolError): InformationalResponse(status_code=200, headers=[("Host", "a")]) - resp = Response(status_code=204, headers=[], http_version="1.0") + resp = Response(status_code=204, headers=[], http_version="1.0") # type: ignore[arg-type] assert resp.status_code == 204 assert resp.headers == [] assert resp.http_version == b"1.0" with pytest.raises(LocalProtocolError): - resp = Response(status_code=100, headers=[], http_version="1.0") + resp = Response(status_code=100, headers=[], http_version="1.0") # type: ignore[arg-type] with pytest.raises(LocalProtocolError): - Response(status_code="100", headers=[], http_version="1.0") + Response(status_code="100", headers=[], http_version="1.0") # type: ignore[arg-type] with pytest.raises(LocalProtocolError): - InformationalResponse(status_code=b"100", headers=[], http_version="1.0") + InformationalResponse(status_code=b"100", headers=[], http_version="1.0") # type: ignore[arg-type] d = Data(data=b"asdf") assert d.data == b"asdf" @@ -108,16 +116,16 @@ def test_events(): assert repr(cc) == "ConnectionClosed()" -def test_intenum_status_code(): +def test_intenum_status_code() -> None: # https://github.com/python-hyper/h11/issues/72 - r = Response(status_code=HTTPStatus.OK, headers=[], http_version="1.0") + r = Response(status_code=HTTPStatus.OK, headers=[], http_version="1.0") # type: ignore[arg-type] assert r.status_code == HTTPStatus.OK assert type(r.status_code) is not type(HTTPStatus.OK) assert type(r.status_code) is int -def test_header_casing(): +def test_header_casing() -> None: r = Request( method="GET", target="/", diff --git a/h11/tests/test_headers.py b/h11/tests/test_headers.py index ff3dc8d..ba53d08 100644 --- a/h11/tests/test_headers.py +++ b/h11/tests/test_headers.py @@ -1,9 +1,17 @@ import pytest -from .._headers import * - - -def test_normalize_and_validate(): +from .._events import Request +from .._headers import ( + get_comma_header, + has_expect_100_continue, + Headers, + normalize_and_validate, + set_comma_header, +) +from .._util import LocalProtocolError + + +def test_normalize_and_validate() -> None: assert normalize_and_validate([("foo", "bar")]) == [(b"foo", b"bar")] assert normalize_and_validate([(b"foo", b"bar")]) == [(b"foo", b"bar")] @@ -84,7 +92,7 @@ def test_normalize_and_validate(): assert excinfo.value.error_status_hint == 501 # Not Implemented -def test_get_set_comma_header(): +def test_get_set_comma_header() -> None: headers = normalize_and_validate( [ ("Connection", "close"), @@ -95,10 +103,10 @@ def test_get_set_comma_header(): assert get_comma_header(headers, b"connection") == [b"close", b"foo", b"bar"] - headers = set_comma_header(headers, b"newthing", ["a", "b"]) + headers = set_comma_header(headers, b"newthing", ["a", "b"]) # type: ignore with pytest.raises(LocalProtocolError): - set_comma_header(headers, b"newthing", [" a", "b"]) + set_comma_header(headers, b"newthing", [" a", "b"]) # type: ignore assert headers == [ (b"connection", b"close"), @@ -108,7 +116,7 @@ def test_get_set_comma_header(): (b"newthing", b"b"), ] - headers = set_comma_header(headers, b"whatever", ["different thing"]) + headers = set_comma_header(headers, b"whatever", ["different thing"]) # type: ignore assert headers == [ (b"connection", b"close"), @@ -119,9 +127,7 @@ def test_get_set_comma_header(): ] -def test_has_100_continue(): - from .._events import Request - +def test_has_100_continue() -> None: assert has_expect_100_continue( Request( method="GET", diff --git a/h11/tests/test_helpers.py b/h11/tests/test_helpers.py index 1477947..c329c76 100644 --- a/h11/tests/test_helpers.py +++ b/h11/tests/test_helpers.py @@ -1,12 +1,21 @@ -from .helpers import * +from .._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) +from .helpers import normalize_data_events -def test_normalize_data_events(): +def test_normalize_data_events() -> None: assert normalize_data_events( [ Data(data=bytearray(b"1")), Data(data=b"2"), - Response(status_code=200, headers=[]), + Response(status_code=200, headers=[]), # type: ignore[arg-type] Data(data=b"3"), Data(data=b"4"), EndOfMessage(), @@ -16,7 +25,7 @@ def test_normalize_data_events(): ] ) == [ Data(data=b"12"), - Response(status_code=200, headers=[]), + Response(status_code=200, headers=[]), # type: ignore[arg-type] Data(data=b"34"), EndOfMessage(), Data(data=b"567"), diff --git a/h11/tests/test_io.py b/h11/tests/test_io.py index 459a627..e9c01bd 100644 --- a/h11/tests/test_io.py +++ b/h11/tests/test_io.py @@ -1,6 +1,16 @@ +from typing import Any, Callable, Generator, List + import pytest -from .._events import * +from .._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) from .._headers import Headers, normalize_and_validate from .._readers import ( _obsolete_line_fold, @@ -10,7 +20,18 @@ READERS, ) from .._receivebuffer import ReceiveBuffer -from .._state import * +from .._state import ( + CLIENT, + CLOSED, + DONE, + IDLE, + MIGHT_SWITCH_PROTOCOL, + MUST_CLOSE, + SEND_BODY, + SEND_RESPONSE, + SERVER, + SWITCHED_PROTOCOL, +) from .._util import LocalProtocolError from .._writers import ( ChunkedWriter, @@ -40,7 +61,7 @@ ), ( (SERVER, SEND_RESPONSE), - Response(status_code=200, headers=[], reason=b"OK"), + Response(status_code=200, headers=[], reason=b"OK"), # type: ignore[arg-type] b"HTTP/1.1 200 OK\r\n\r\n", ), ( @@ -52,36 +73,35 @@ ), ( (SERVER, SEND_RESPONSE), - InformationalResponse(status_code=101, headers=[], reason=b"Upgrade"), + InformationalResponse(status_code=101, headers=[], reason=b"Upgrade"), # type: ignore[arg-type] b"HTTP/1.1 101 Upgrade\r\n\r\n", ), ] -def dowrite(writer, obj): - got_list = [] +def dowrite(writer: Callable[..., None], obj: Any) -> bytes: + got_list: List[bytes] = [] writer(obj, got_list.append) return b"".join(got_list) -def tw(writer, obj, expected): +def tw(writer: Any, obj: Any, expected: Any) -> None: got = dowrite(writer, obj) assert got == expected -def makebuf(data): +def makebuf(data: bytes) -> ReceiveBuffer: buf = ReceiveBuffer() buf += data return buf -def tr(reader, data, expected): - def check(got): +def tr(reader: Any, data: bytes, expected: Any) -> None: + def check(got: Any) -> None: assert got == expected # Headers should always be returned as bytes, not e.g. bytearray # https://github.com/python-hyper/wsproto/pull/54#issuecomment-377709478 for name, value in getattr(got, "headers", []): - print(name, value) assert type(name) is bytes assert type(value) is bytes @@ -104,17 +124,17 @@ def check(got): assert bytes(buf) == b"trailing" -def test_writers_simple(): +def test_writers_simple() -> None: for ((role, state), event, binary) in SIMPLE_CASES: tw(WRITERS[role, state], event, binary) -def test_readers_simple(): +def test_readers_simple() -> None: for ((role, state), event, binary) in SIMPLE_CASES: tr(READERS[role, state], binary, event) -def test_writers_unusual(): +def test_writers_unusual() -> None: # Simple test of the write_headers utility routine tw( write_headers, @@ -145,7 +165,7 @@ def test_writers_unusual(): ) -def test_readers_unusual(): +def test_readers_unusual() -> None: # Reading HTTP/1.0 tr( READERS[CLIENT, IDLE], @@ -162,7 +182,7 @@ def test_readers_unusual(): tr( READERS[CLIENT, IDLE], b"HEAD /foo HTTP/1.0\r\n\r\n", - Request(method="HEAD", target="/foo", headers=[], http_version="1.0"), + Request(method="HEAD", target="/foo", headers=[], http_version="1.0"), # type: ignore[arg-type] ) tr( @@ -305,7 +325,7 @@ def test_readers_unusual(): tr(READERS[CLIENT, IDLE], b"HEAD /foo HTTP/1.1\r\n" b": line\r\n\r\n", None) -def test__obsolete_line_fold_bytes(): +def test__obsolete_line_fold_bytes() -> None: # _obsolete_line_fold has a defensive cast to bytearray, which is # necessary to protect against O(n^2) behavior in case anyone ever passes # in regular bytestrings... but right now we never pass in regular @@ -318,7 +338,9 @@ def test__obsolete_line_fold_bytes(): ] -def _run_reader_iter(reader, buf, do_eof): +def _run_reader_iter( + reader: Any, buf: bytes, do_eof: bool +) -> Generator[Any, None, None]: while True: event = reader(buf) if event is None: @@ -333,12 +355,12 @@ def _run_reader_iter(reader, buf, do_eof): yield reader.read_eof() -def _run_reader(*args): +def _run_reader(*args: Any) -> List[Event]: events = list(_run_reader_iter(*args)) return normalize_data_events(events) -def t_body_reader(thunk, data, expected, do_eof=False): +def t_body_reader(thunk: Any, data: bytes, expected: Any, do_eof: bool = False) -> None: # Simple: consume whole thing print("Test 1") buf = makebuf(data) @@ -361,7 +383,7 @@ def t_body_reader(thunk, data, expected, do_eof=False): assert _run_reader(thunk(), buf, False) == expected -def test_ContentLengthReader(): +def test_ContentLengthReader() -> None: t_body_reader(lambda: ContentLengthReader(0), b"", [EndOfMessage()]) t_body_reader( @@ -371,7 +393,7 @@ def test_ContentLengthReader(): ) -def test_Http10Reader(): +def test_Http10Reader() -> None: t_body_reader(Http10Reader, b"", [EndOfMessage()], do_eof=True) t_body_reader(Http10Reader, b"asdf", [Data(data=b"asdf")], do_eof=False) t_body_reader( @@ -379,7 +401,7 @@ def test_Http10Reader(): ) -def test_ChunkedReader(): +def test_ChunkedReader() -> None: t_body_reader(ChunkedReader, b"0\r\n\r\n", [EndOfMessage()]) t_body_reader( @@ -434,7 +456,7 @@ def test_ChunkedReader(): ) -def test_ContentLengthWriter(): +def test_ContentLengthWriter() -> None: w = ContentLengthWriter(5) assert dowrite(w, Data(data=b"123")) == b"123" assert dowrite(w, Data(data=b"45")) == b"45" @@ -461,7 +483,7 @@ def test_ContentLengthWriter(): dowrite(w, EndOfMessage(headers=[("Etag", "asdf")])) -def test_ChunkedWriter(): +def test_ChunkedWriter() -> None: w = ChunkedWriter() assert dowrite(w, Data(data=b"aaa")) == b"3\r\naaa\r\n" assert dowrite(w, Data(data=b"a" * 20)) == b"14\r\n" + b"a" * 20 + b"\r\n" @@ -476,7 +498,7 @@ def test_ChunkedWriter(): ) -def test_Http10Writer(): +def test_Http10Writer() -> None: w = Http10Writer() assert dowrite(w, Data(data=b"1234")) == b"1234" assert dowrite(w, EndOfMessage()) == b"" @@ -485,12 +507,12 @@ def test_Http10Writer(): dowrite(w, EndOfMessage(headers=[("Etag", "asdf")])) -def test_reject_garbage_after_request_line(): +def test_reject_garbage_after_request_line() -> None: with pytest.raises(LocalProtocolError): tr(READERS[SERVER, SEND_RESPONSE], b"HTTP/1.0 200 OK\x00xxxx\r\n\r\n", None) -def test_reject_garbage_after_response_line(): +def test_reject_garbage_after_response_line() -> None: with pytest.raises(LocalProtocolError): tr( READERS[CLIENT, IDLE], @@ -499,7 +521,7 @@ def test_reject_garbage_after_response_line(): ) -def test_reject_garbage_in_header_line(): +def test_reject_garbage_in_header_line() -> None: with pytest.raises(LocalProtocolError): tr( READERS[CLIENT, IDLE], @@ -508,7 +530,7 @@ def test_reject_garbage_in_header_line(): ) -def test_reject_non_vchar_in_path(): +def test_reject_non_vchar_in_path() -> None: for bad_char in b"\x00\x20\x7f\xee": message = bytearray(b"HEAD /") message.append(bad_char) @@ -518,7 +540,7 @@ def test_reject_non_vchar_in_path(): # https://github.com/python-hyper/h11/issues/57 -def test_allow_some_garbage_in_cookies(): +def test_allow_some_garbage_in_cookies() -> None: tr( READERS[CLIENT, IDLE], b"HEAD /foo HTTP/1.1\r\n" @@ -536,7 +558,7 @@ def test_allow_some_garbage_in_cookies(): ) -def test_host_comes_first(): +def test_host_comes_first() -> None: tw( write_headers, normalize_and_validate([("foo", "bar"), ("Host", "example.com")]), diff --git a/h11/tests/test_receivebuffer.py b/h11/tests/test_receivebuffer.py index 3a61f9d..21a3870 100644 --- a/h11/tests/test_receivebuffer.py +++ b/h11/tests/test_receivebuffer.py @@ -1,11 +1,12 @@ import re +from typing import Tuple import pytest from .._receivebuffer import ReceiveBuffer -def test_receivebuffer(): +def test_receivebuffer() -> None: b = ReceiveBuffer() assert not b assert len(b) == 0 @@ -118,7 +119,7 @@ def test_receivebuffer(): ), ], ) -def test_receivebuffer_for_invalid_delimiter(data): +def test_receivebuffer_for_invalid_delimiter(data: Tuple[bytes]) -> None: b = ReceiveBuffer() for line in data: diff --git a/h11/tests/test_state.py b/h11/tests/test_state.py index efe83f0..bc974e6 100644 --- a/h11/tests/test_state.py +++ b/h11/tests/test_state.py @@ -1,12 +1,33 @@ import pytest -from .._events import * -from .._state import * -from .._state import _SWITCH_CONNECT, _SWITCH_UPGRADE, ConnectionState +from .._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) +from .._state import ( + _SWITCH_CONNECT, + _SWITCH_UPGRADE, + CLIENT, + CLOSED, + ConnectionState, + DONE, + IDLE, + MIGHT_SWITCH_PROTOCOL, + MUST_CLOSE, + SEND_BODY, + SEND_RESPONSE, + SERVER, + SWITCHED_PROTOCOL, +) from .._util import LocalProtocolError -def test_ConnectionState(): +def test_ConnectionState() -> None: cs = ConnectionState() # Basic event-triggered transitions @@ -38,7 +59,7 @@ def test_ConnectionState(): assert cs.states == {CLIENT: MUST_CLOSE, SERVER: CLOSED} -def test_ConnectionState_keep_alive(): +def test_ConnectionState_keep_alive() -> None: # keep_alive = False cs = ConnectionState() cs.process_event(CLIENT, Request) @@ -51,7 +72,7 @@ def test_ConnectionState_keep_alive(): assert cs.states == {CLIENT: MUST_CLOSE, SERVER: MUST_CLOSE} -def test_ConnectionState_keep_alive_in_DONE(): +def test_ConnectionState_keep_alive_in_DONE() -> None: # Check that if keep_alive is disabled when the CLIENT is already in DONE, # then this is sufficient to immediately trigger the DONE -> MUST_CLOSE # transition @@ -63,7 +84,7 @@ def test_ConnectionState_keep_alive_in_DONE(): assert cs.states[CLIENT] is MUST_CLOSE -def test_ConnectionState_switch_denied(): +def test_ConnectionState_switch_denied() -> None: for switch_type in (_SWITCH_CONNECT, _SWITCH_UPGRADE): for deny_early in (True, False): cs = ConnectionState() @@ -107,7 +128,7 @@ def test_ConnectionState_switch_denied(): } -def test_ConnectionState_protocol_switch_accepted(): +def test_ConnectionState_protocol_switch_accepted() -> None: for switch_event in [_SWITCH_UPGRADE, _SWITCH_CONNECT]: cs = ConnectionState() cs.process_client_switch_proposal(switch_event) @@ -125,7 +146,7 @@ def test_ConnectionState_protocol_switch_accepted(): assert cs.states == {CLIENT: SWITCHED_PROTOCOL, SERVER: SWITCHED_PROTOCOL} -def test_ConnectionState_double_protocol_switch(): +def test_ConnectionState_double_protocol_switch() -> None: # CONNECT + Upgrade is legal! Very silly, but legal. So we support # it. Because sometimes doing the silly thing is easier than not. for server_switch in [None, _SWITCH_UPGRADE, _SWITCH_CONNECT]: @@ -144,7 +165,7 @@ def test_ConnectionState_double_protocol_switch(): assert cs.states == {CLIENT: SWITCHED_PROTOCOL, SERVER: SWITCHED_PROTOCOL} -def test_ConnectionState_inconsistent_protocol_switch(): +def test_ConnectionState_inconsistent_protocol_switch() -> None: for client_switches, server_switch in [ ([], _SWITCH_CONNECT), ([], _SWITCH_UPGRADE), @@ -152,14 +173,14 @@ def test_ConnectionState_inconsistent_protocol_switch(): ([_SWITCH_CONNECT], _SWITCH_UPGRADE), ]: cs = ConnectionState() - for client_switch in client_switches: + for client_switch in client_switches: # type: ignore[attr-defined] cs.process_client_switch_proposal(client_switch) cs.process_event(CLIENT, Request) with pytest.raises(LocalProtocolError): cs.process_event(SERVER, Response, server_switch) -def test_ConnectionState_keepalive_protocol_switch_interaction(): +def test_ConnectionState_keepalive_protocol_switch_interaction() -> None: # keep_alive=False + pending_switch_proposals cs = ConnectionState() cs.process_client_switch_proposal(_SWITCH_UPGRADE) @@ -177,7 +198,7 @@ def test_ConnectionState_keepalive_protocol_switch_interaction(): assert cs.states == {CLIENT: MUST_CLOSE, SERVER: SEND_BODY} -def test_ConnectionState_reuse(): +def test_ConnectionState_reuse() -> None: cs = ConnectionState() with pytest.raises(LocalProtocolError): @@ -242,7 +263,7 @@ def test_ConnectionState_reuse(): assert cs.states == {CLIENT: IDLE, SERVER: IDLE} -def test_server_request_is_illegal(): +def test_server_request_is_illegal() -> None: # There used to be a bug in how we handled the Request special case that # made this allowed... cs = ConnectionState() diff --git a/h11/tests/test_util.py b/h11/tests/test_util.py index d851bdc..cc086a6 100644 --- a/h11/tests/test_util.py +++ b/h11/tests/test_util.py @@ -1,18 +1,26 @@ import re import sys import traceback +from typing import NoReturn import pytest -from .._util import * +from .._util import ( + bytesify, + LocalProtocolError, + make_sentinel, + ProtocolError, + RemoteProtocolError, + validate, +) -def test_ProtocolError(): +def test_ProtocolError() -> None: with pytest.raises(TypeError): ProtocolError("abstract base class") -def test_LocalProtocolError(): +def test_LocalProtocolError() -> None: try: raise LocalProtocolError("foo") except LocalProtocolError as e: @@ -25,7 +33,7 @@ def test_LocalProtocolError(): assert str(e) == "foo" assert e.error_status_hint == 418 - def thunk(): + def thunk() -> NoReturn: raise LocalProtocolError("a", error_status_hint=420) try: @@ -42,7 +50,7 @@ def thunk(): assert new_traceback.endswith(orig_traceback) -def test_validate(): +def test_validate() -> None: my_re = re.compile(br"(?P[0-9]+)\.(?P[0-9]+)") with pytest.raises(LocalProtocolError): validate(my_re, b"0.") @@ -57,7 +65,7 @@ def test_validate(): validate(my_re, b"0.1\n") -def test_validate_formatting(): +def test_validate_formatting() -> None: my_re = re.compile(br"foo") with pytest.raises(LocalProtocolError) as excinfo: @@ -73,7 +81,7 @@ def test_validate_formatting(): assert "oops 10 xx" in str(excinfo.value) -def test_make_sentinel(): +def test_make_sentinel() -> None: S = make_sentinel("S") assert repr(S) == "S" assert S == S @@ -87,7 +95,7 @@ def test_make_sentinel(): assert type(S) is not type(S2) -def test_bytesify(): +def test_bytesify() -> None: assert bytesify(b"123") == b"123" assert bytesify(bytearray(b"123")) == b"123" assert bytesify("123") == b"123" diff --git a/setup.cfg b/setup.cfg index 0bd1262..95cb0f4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,3 +8,9 @@ line_length=88 multi_line_output=3 no_lines_before=LOCALFOLDER order_by_type=False + +[mypy] +strict = true +warn_unused_configs = true +warn_unused_ignores = true +show_error_codes = true diff --git a/setup.py b/setup.py index eab298e..1f23e63 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,10 @@ # doesn't look like a source file, so long as it appears in MANIFEST.in: include_package_data=True, python_requires=">=3.6", - install_requires=["dataclasses; python_version < '3.7'"], + install_requires=[ + "dataclasses; python_version < '3.7'", + "typing_extensions; python_version < '3.8'", + ], classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", diff --git a/tox.ini b/tox.ini index d4a2272..f333194 100644 --- a/tox.ini +++ b/tox.ini @@ -1,11 +1,11 @@ [tox] -envlist = format, py36, py37, py38, py39, pypy3 +envlist = format, py36, py37, py38, py39, pypy3, mypy [gh-actions] python = 3.6: py36 3.7: py37 - 3.8: py38, format + 3.8: py38, format, mypy 3.9: py39 pypy3: pypy3 @@ -21,3 +21,11 @@ deps = commands = black --check --diff h11/ bench/ examples/ fuzz/ isort --check --diff h11 bench examples fuzz + +[testenv:mypy] +basepython = python3.8 +deps = + mypy + pytest +commands = + mypy h11/