diff --git a/h11/__init__.py b/h11/__init__.py index ae39e01..989e92c 100644 --- a/h11/__init__.py +++ b/h11/__init__.py @@ -6,16 +6,57 @@ # semantics to check that what you're asking to write to the wire is sensible, # but at least it gets you out of dealing with the wire itself. -from ._connection import * -from ._events import * -from ._state import * -from ._util import LocalProtocolError, ProtocolError, RemoteProtocolError -from ._version import __version__ +from h11._connection import Connection, NEED_DATA, PAUSED +from h11._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) +from h11._state import ( + CLIENT, + CLOSED, + DONE, + ERROR, + IDLE, + MIGHT_SWITCH_PROTOCOL, + MUST_CLOSE, + SEND_BODY, + SEND_RESPONSE, + SERVER, + SWITCHED_PROTOCOL, +) +from h11._util import LocalProtocolError, ProtocolError, RemoteProtocolError +from h11._version import __version__ PRODUCT_ID = "python-h11/" + __version__ -__all__ = ["ProtocolError", "LocalProtocolError", "RemoteProtocolError"] -__all__ += _events.__all__ -__all__ += _connection.__all__ -__all__ += _state.__all__ +__all__ = ( + "Connection", + "NEED_DATA", + "PAUSED", + "ConnectionClosed", + "Data", + "EndOfMessage", + "Event", + "InformationalResponse", + "Request", + "Response", + "CLIENT", + "CLOSED", + "DONE", + "ERROR", + "IDLE", + "MUST_CLOSE", + "SEND_BODY", + "SEND_RESPONSE", + "SERVER", + "SWITCHED_PROTOCOL", + "ProtocolError", + "LocalProtocolError", + "RemoteProtocolError", +) diff --git a/h11/_connection.py b/h11/_connection.py index bcd3089..a967291 100644 --- a/h11/_connection.py +++ b/h11/_connection.py @@ -1,18 +1,38 @@ # This contains the main Connection class. Everything in h11 revolves around # this. - -from ._events import * # Import all event types +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union + +from ._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) from ._headers import get_comma_header, has_expect_100_continue, set_comma_header -from ._readers import READERS +from ._readers import READERS, ReadersType from ._receivebuffer import ReceiveBuffer -from ._state import * # Import all state sentinels -from ._state import _SWITCH_CONNECT, _SWITCH_UPGRADE, ConnectionState +from ._state import ( + _SWITCH_CONNECT, + _SWITCH_UPGRADE, + CLIENT, + ConnectionState, + DONE, + ERROR, + MIGHT_SWITCH_PROTOCOL, + SEND_BODY, + SERVER, + SWITCHED_PROTOCOL, +) from ._util import ( # Import the internal things we need LocalProtocolError, make_sentinel, RemoteProtocolError, + Sentinel, ) -from ._writers import WRITERS +from ._writers import WRITERS, WritersType # Everything in __all__ gets re-exported as part of the h11 public API. __all__ = ["Connection", "NEED_DATA", "PAUSED"] @@ -44,7 +64,7 @@ # our rule is: # - If someone says Connection: close, we will close # - If someone uses HTTP/1.0, we will close. -def _keep_alive(event): +def _keep_alive(event: Union[Request, Response]) -> bool: connection = get_comma_header(event.headers, b"connection") if b"close" in connection: return False @@ -53,7 +73,9 @@ def _keep_alive(event): return True -def _body_framing(request_method, event): +def _body_framing( + request_method: bytes, event: Union[Request, Response] +) -> Tuple[str, Union[Tuple[()], Tuple[int]]]: # Called when we enter SEND_BODY to figure out framing information for # this body. # @@ -126,8 +148,10 @@ class Connection: """ def __init__( - self, our_role, max_incomplete_event_size=DEFAULT_MAX_INCOMPLETE_EVENT_SIZE - ): + self, + our_role: Sentinel, + max_incomplete_event_size: int = DEFAULT_MAX_INCOMPLETE_EVENT_SIZE, + ) -> None: self._max_incomplete_event_size = max_incomplete_event_size # State and role tracking if our_role not in (CLIENT, SERVER): @@ -155,14 +179,14 @@ def __init__( # These two are only used to interpret framing headers for figuring # out how to read/write response bodies. their_http_version is also # made available as a convenient public API. - self.their_http_version = None - self._request_method = None + self.their_http_version: Optional[bytes] = None + self._request_method: Optional[bytes] = None # This is pure flow-control and doesn't at all affect the set of legal # transitions, so no need to bother ConnectionState with it: self.client_is_waiting_for_100_continue = False @property - def states(self): + def states(self) -> Dict[Sentinel, Sentinel]: """A dictionary like:: {CLIENT: , SERVER: } @@ -173,24 +197,24 @@ def states(self): return dict(self._cstate.states) @property - def our_state(self): + def our_state(self) -> Sentinel: """The current state of whichever role we are playing. See :ref:`state-machine` for details. """ return self._cstate.states[self.our_role] @property - def their_state(self): + def their_state(self) -> Sentinel: """The current state of whichever role we are NOT playing. See :ref:`state-machine` for details. """ return self._cstate.states[self.their_role] @property - def they_are_waiting_for_100_continue(self): + def they_are_waiting_for_100_continue(self) -> bool: return self.their_role is CLIENT and self.client_is_waiting_for_100_continue - def start_next_cycle(self): + def start_next_cycle(self) -> None: """Attempt to reset our connection state for a new request/response cycle. @@ -210,12 +234,12 @@ def start_next_cycle(self): assert not self.client_is_waiting_for_100_continue self._respond_to_state_changes(old_states) - def _process_error(self, role): + def _process_error(self, role: Sentinel) -> None: old_states = dict(self._cstate.states) self._cstate.process_error(role) self._respond_to_state_changes(old_states) - def _server_switch_event(self, event): + def _server_switch_event(self, event: Event) -> Optional[Sentinel]: if type(event) is InformationalResponse and event.status_code == 101: return _SWITCH_UPGRADE if type(event) is Response: @@ -227,7 +251,7 @@ def _server_switch_event(self, event): return None # All events go through here - def _process_event(self, role, event): + def _process_event(self, role: Sentinel, event: Event) -> None: # First, pass the event through the state machine to make sure it # succeeds. old_states = dict(self._cstate.states) @@ -243,15 +267,16 @@ def _process_event(self, role, event): # Then perform the updates triggered by it. - # self._request_method - if type(event) is Request: + if isinstance(event, Request): self._request_method = event.method - # self.their_http_version - if role is self.their_role and type(event) in ( - Request, - Response, - InformationalResponse, + if role is self.their_role and isinstance( + event, + ( + Request, + Response, + InformationalResponse, + ), ): self.their_http_version = event.http_version @@ -261,35 +286,44 @@ def _process_event(self, role, event): # shows up on a 1xx InformationalResponse. I think the idea is that # this is not supposed to happen. In any case, if it does happen, we # ignore it. - if type(event) in (Request, Response) and not _keep_alive(event): + if isinstance(event, (Request, Response)) and not _keep_alive(event): self._cstate.process_keep_alive_disabled() # 100-continue - if type(event) is Request and has_expect_100_continue(event): + if isinstance(event, Request) and has_expect_100_continue(event): self.client_is_waiting_for_100_continue = True - if type(event) in (InformationalResponse, Response): + if isinstance(event, (InformationalResponse, Response)): self.client_is_waiting_for_100_continue = False - if role is CLIENT and type(event) in (Data, EndOfMessage): + if role is CLIENT and isinstance(event, (Data, EndOfMessage)): self.client_is_waiting_for_100_continue = False self._respond_to_state_changes(old_states, event) - def _get_io_object(self, role, event, io_dict): + def _get_io_object( + self, + role: Sentinel, + event: Optional[Event], + io_dict: Union[ReadersType, WritersType], + ) -> Optional[Callable[..., Any]]: # event may be None; it's only used when entering SEND_BODY state = self._cstate.states[role] if state is SEND_BODY: # Special case: the io_dict has a dict of reader/writer factories # that depend on the request/response framing. - framing_type, args = _body_framing(self._request_method, event) - return io_dict[SEND_BODY][framing_type](*args) + framing_type, args = _body_framing( + cast(bytes, self._request_method), cast(Union[Request, Response], event) + ) + return io_dict[SEND_BODY][framing_type](*args) # type: ignore[index] else: # General case: the io_dict just has the appropriate reader/writer # for this state - return io_dict.get((role, state)) + return io_dict.get((role, state)) # type: ignore[return-value] # This must be called after any action that might have caused # self._cstate.states to change. - def _respond_to_state_changes(self, old_states, event=None): + def _respond_to_state_changes( + self, old_states: Dict[Sentinel, Sentinel], event: Optional[Event] = None + ) -> None: # Update reader/writer if self.our_state != old_states[self.our_role]: self._writer = self._get_io_object(self.our_role, event, WRITERS) @@ -297,7 +331,7 @@ def _respond_to_state_changes(self, old_states, event=None): self._reader = self._get_io_object(self.their_role, event, READERS) @property - def trailing_data(self): + def trailing_data(self) -> Tuple[bytes, bool]: """Data that has been received, but not yet processed, represented as a tuple with two elements, where the first is a byte-string containing the unprocessed data itself, and the second is a bool that is True if @@ -307,7 +341,7 @@ def trailing_data(self): """ return (bytes(self._receive_buffer), self._receive_buffer_closed) - def receive_data(self, data): + def receive_data(self, data: bytes) -> None: """Add data to our internal receive buffer. This does not actually do any processing on the data, just stores @@ -353,7 +387,7 @@ def receive_data(self, data): else: self._receive_buffer_closed = True - def _extract_next_receive_event(self): + def _extract_next_receive_event(self) -> Union[Event, Sentinel]: state = self.their_state # We don't pause immediately when they enter DONE, because even in # DONE state we can still process a ConnectionClosed() event. But @@ -372,14 +406,14 @@ def _extract_next_receive_event(self): # return that event, and then the state will change and we'll # get called again to generate the actual ConnectionClosed(). if hasattr(self._reader, "read_eof"): - event = self._reader.read_eof() + event = self._reader.read_eof() # type: ignore[attr-defined] else: event = ConnectionClosed() if event is None: event = NEED_DATA - return event + return event # type: ignore[no-any-return] - def next_event(self): + def next_event(self) -> Union[Event, Sentinel]: """Parse the next event out of our receive buffer, update our internal state, and return it. @@ -424,7 +458,7 @@ def next_event(self): try: event = self._extract_next_receive_event() if event not in [NEED_DATA, PAUSED]: - self._process_event(self.their_role, event) + self._process_event(self.their_role, cast(Event, event)) if event is NEED_DATA: if len(self._receive_buffer) > self._max_incomplete_event_size: # 431 is "Request header fields too large" which is pretty @@ -444,7 +478,7 @@ def next_event(self): else: raise - def send(self, event): + def send(self, event: Event) -> Optional[bytes]: """Convert a high-level event into bytes that can be sent to the peer, while updating our internal state machine. @@ -471,7 +505,7 @@ def send(self, event): else: return b"".join(data_list) - def send_with_data_passthrough(self, event): + def send_with_data_passthrough(self, event: Event) -> Optional[List[bytes]]: """Identical to :meth:`send`, except that in situations where :meth:`send` returns a single :term:`bytes-like object`, this instead returns a list of them -- and when sending a :class:`Data` event, this @@ -482,7 +516,7 @@ def send_with_data_passthrough(self, event): if self.our_state is ERROR: raise LocalProtocolError("Can't send data when our state is ERROR") try: - if type(event) is Response: + if isinstance(event, Response): event = self._clean_up_response_headers_for_sending(event) # We want to call _process_event before calling the writer, # because if someone tries to do something invalid then this will @@ -491,20 +525,20 @@ def send_with_data_passthrough(self, event): # change self._writer. So we have to do a little dance: writer = self._writer self._process_event(self.our_role, event) - if type(event) is ConnectionClosed: + if isinstance(event, ConnectionClosed): return None else: # In any situation where writer is None, process_event should # have raised ProtocolError assert writer is not None - data_list = [] + data_list: List[bytes] = [] writer(event, data_list.append) return data_list except: self._process_error(self.our_role) raise - def send_failed(self): + def send_failed(self) -> None: """Notify the state machine that we failed to send the data it gave us. @@ -529,7 +563,7 @@ def send_failed(self): # This function's *only* responsibility is making sure headers are set up # right -- everything downstream just looks at the headers. There are no # side channels. - def _clean_up_response_headers_for_sending(self, response): + def _clean_up_response_headers_for_sending(self, response: Response) -> Response: assert type(response) is Response headers = response.headers @@ -542,7 +576,7 @@ def _clean_up_response_headers_for_sending(self, response): # we're allowed to leave out the framing headers -- see # https://tools.ietf.org/html/rfc7231#section-4.3.2 . But it's just as # easy to get them right.) - method_for_choosing_headers = self._request_method + method_for_choosing_headers = cast(bytes, self._request_method) if method_for_choosing_headers == b"HEAD": method_for_choosing_headers = b"GET" framing_type, _ = _body_framing(method_for_choosing_headers, response) @@ -572,7 +606,7 @@ def _clean_up_response_headers_for_sending(self, response): if self._request_method != b"HEAD": need_close = True else: - headers = set_comma_header(headers, b"transfer-encoding", ["chunked"]) + headers = set_comma_header(headers, b"transfer-encoding", [b"chunked"]) if not self._cstate.keep_alive or need_close: # Make sure Connection: close is set diff --git a/h11/_headers.py b/h11/_headers.py index e8f24d6..4470937 100644 --- a/h11/_headers.py +++ b/h11/_headers.py @@ -1,9 +1,12 @@ import re -from collections.abc import Sequence +from typing import cast, List, Sequence, Tuple, TYPE_CHECKING, Union from ._abnf import field_name, field_value from ._util import bytesify, LocalProtocolError, validate +if TYPE_CHECKING: + from ._events import Request + # Facts # ----- # @@ -63,7 +66,7 @@ _field_value_re = re.compile(field_value.encode("ascii")) -class Headers(Sequence): +class Headers(Sequence[Tuple[bytes, bytes]]): """ A list-like interface that allows iterating over headers as byte-pairs of (lowercased-name, value). @@ -90,30 +93,39 @@ class Headers(Sequence): __slots__ = "_full_items" - def __init__(self, full_items): + def __init__(self, full_items: List[Tuple[bytes, bytes, bytes]]) -> None: self._full_items = full_items - def __bool__(self): + def __bool__(self) -> bool: return bool(self._full_items) - def __eq__(self, other): - return list(self) == list(other) + def __eq__(self, other: object) -> bool: + return list(self) == list(other) # type: ignore - def __len__(self): + def __len__(self) -> int: return len(self._full_items) - def __repr__(self): + def __repr__(self) -> str: return "" % repr(list(self)) - def __getitem__(self, idx): + def __getitem__(self, idx: int) -> Tuple[bytes, bytes]: # type: ignore[override] _, name, value = self._full_items[idx] return (name, value) - def raw_items(self): + def raw_items(self) -> List[Tuple[bytes, bytes]]: return [(raw_name, value) for raw_name, _, value in self._full_items] -def normalize_and_validate(headers, _parsed=False): +HeaderTypes = Union[ + List[Tuple[bytes, bytes]], + List[Tuple[bytes, str]], + List[Tuple[str, bytes]], + List[Tuple[str, str]], + Headers, +] + + +def normalize_and_validate(headers: HeaderTypes, _parsed: bool = False) -> Headers: new_headers = [] seen_content_length = None saw_transfer_encoding = False @@ -126,6 +138,9 @@ def normalize_and_validate(headers, _parsed=False): value = bytesify(value) validate(_field_name_re, name, "Illegal header name {!r}", name) validate(_field_value_re, value, "Illegal header value {!r}", value) + assert isinstance(name, bytes) + assert isinstance(value, bytes) + raw_name = name name = name.lower() if name == b"content-length": @@ -163,7 +178,7 @@ def normalize_and_validate(headers, _parsed=False): return Headers(new_headers) -def get_comma_header(headers, name): +def get_comma_header(headers: Headers, name: bytes) -> List[bytes]: # Should only be used for headers whose value is a list of # comma-separated, case-insensitive values. # @@ -199,7 +214,7 @@ def get_comma_header(headers, name): # Expect: the only legal value is the literal string # "100-continue". Splitting on commas is harmless. Case insensitive. # - out = [] + out: List[bytes] = [] for _, found_name, found_raw_value in headers._full_items: if found_name == name: found_raw_value = found_raw_value.lower() @@ -210,7 +225,7 @@ def get_comma_header(headers, name): return out -def set_comma_header(headers, name, new_values): +def set_comma_header(headers: Headers, name: bytes, new_values: List[bytes]) -> Headers: # The header name `name` is expected to be lower-case bytes. # # Note that when we store the header we use title casing for the header @@ -220,7 +235,7 @@ def set_comma_header(headers, name, new_values): # here given the cases where we're using `set_comma_header`... # # Connection, Content-Length, Transfer-Encoding. - new_headers = [] + new_headers: List[Tuple[bytes, bytes]] = [] for found_raw_name, found_name, found_raw_value in headers._full_items: if found_name != name: new_headers.append((found_raw_name, found_raw_value)) @@ -229,7 +244,7 @@ def set_comma_header(headers, name, new_values): return normalize_and_validate(new_headers) -def has_expect_100_continue(request): +def has_expect_100_continue(request: "Request") -> bool: # https://tools.ietf.org/html/rfc7231#section-5.1.1 # "A server that receives a 100-continue expectation in an HTTP/1.0 request # MUST ignore that expectation." diff --git a/h11/_readers.py b/h11/_readers.py index 0ead0be..a036d79 100644 --- a/h11/_readers.py +++ b/h11/_readers.py @@ -17,11 +17,22 @@ # - or, for body readers, a dict of per-framing reader factories import re +from typing import Any, Callable, Dict, Iterable, NoReturn, Optional, Tuple, Type, Union from ._abnf import chunk_header, header_field, request_line, status_line -from ._events import * -from ._state import * -from ._util import LocalProtocolError, RemoteProtocolError, validate +from ._events import Data, EndOfMessage, InformationalResponse, Request, Response +from ._receivebuffer import ReceiveBuffer +from ._state import ( + CLIENT, + CLOSED, + DONE, + IDLE, + MUST_CLOSE, + SEND_BODY, + SEND_RESPONSE, + SERVER, +) +from ._util import LocalProtocolError, RemoteProtocolError, Sentinel, validate __all__ = ["READERS"] @@ -32,9 +43,9 @@ obs_fold_re = re.compile(br"[ \t]+") -def _obsolete_line_fold(lines): +def _obsolete_line_fold(lines: Iterable[bytes]) -> Iterable[bytes]: it = iter(lines) - last = None + last: Optional[bytes] = None for line in it: match = obs_fold_re.match(line) if match: @@ -52,7 +63,9 @@ def _obsolete_line_fold(lines): yield last -def _decode_header_lines(lines): +def _decode_header_lines( + lines: Iterable[bytes], +) -> Iterable[Tuple[bytes, bytes]]: for line in _obsolete_line_fold(lines): matches = validate(header_field_re, line, "illegal header line: {!r}", line) yield (matches["field_name"], matches["field_value"]) @@ -61,7 +74,7 @@ def _decode_header_lines(lines): request_line_re = re.compile(request_line.encode("ascii")) -def maybe_read_from_IDLE_client(buf): +def maybe_read_from_IDLE_client(buf: ReceiveBuffer) -> Optional[Request]: lines = buf.maybe_extract_lines() if lines is None: if buf.is_next_line_obviously_invalid_request_line(): @@ -80,7 +93,9 @@ def maybe_read_from_IDLE_client(buf): status_line_re = re.compile(status_line.encode("ascii")) -def maybe_read_from_SEND_RESPONSE_server(buf): +def maybe_read_from_SEND_RESPONSE_server( + buf: ReceiveBuffer, +) -> Union[InformationalResponse, Response, None]: lines = buf.maybe_extract_lines() if lines is None: if buf.is_next_line_obviously_invalid_request_line(): @@ -89,22 +104,29 @@ def maybe_read_from_SEND_RESPONSE_server(buf): if not lines: raise LocalProtocolError("no response line received") matches = validate(status_line_re, lines[0], "illegal status line: {!r}", lines[0]) - # Tolerate missing reason phrases - if matches["reason"] is None: - matches["reason"] = b"" - status_code = matches["status_code"] = int(matches["status_code"]) - class_ = InformationalResponse if status_code < 200 else Response + http_version = ( + b"1.1" if matches["http_version"] is None else matches["http_version"] + ) + reason = b"" if matches["reason"] is None else matches["reason"] + status_code = int(matches["status_code"]) + class_: Union[Type[InformationalResponse], Type[Response]] = ( + InformationalResponse if status_code < 200 else Response + ) return class_( - headers=list(_decode_header_lines(lines[1:])), _parsed=True, **matches + headers=list(_decode_header_lines(lines[1:])), + _parsed=True, + status_code=status_code, + reason=reason, + http_version=http_version, ) class ContentLengthReader: - def __init__(self, length): + def __init__(self, length: int) -> None: self._length = length self._remaining = length - def __call__(self, buf): + def __call__(self, buf: ReceiveBuffer) -> Union[Data, EndOfMessage, None]: if self._remaining == 0: return EndOfMessage() data = buf.maybe_extract_at_most(self._remaining) @@ -113,7 +135,7 @@ def __call__(self, buf): self._remaining -= len(data) return Data(data=data) - def read_eof(self): + def read_eof(self) -> NoReturn: raise RemoteProtocolError( "peer closed connection without sending complete message body " "(received {} bytes, expected {})".format( @@ -126,7 +148,7 @@ def read_eof(self): class ChunkedReader: - def __init__(self): + def __init__(self) -> None: self._bytes_in_chunk = 0 # After reading a chunk, we have to throw away the trailing \r\n; if # this is >0 then we discard that many bytes before resuming regular @@ -134,7 +156,7 @@ def __init__(self): self._bytes_to_discard = 0 self._reading_trailer = False - def __call__(self, buf): + def __call__(self, buf: ReceiveBuffer) -> Union[Data, EndOfMessage, None]: if self._reading_trailer: lines = buf.maybe_extract_lines() if lines is None: @@ -180,7 +202,7 @@ def __call__(self, buf): chunk_end = False return Data(data=data, chunk_start=chunk_start, chunk_end=chunk_end) - def read_eof(self): + def read_eof(self) -> NoReturn: raise RemoteProtocolError( "peer closed connection without sending complete message body " "(incomplete chunked read)" @@ -188,23 +210,28 @@ def read_eof(self): class Http10Reader: - def __call__(self, buf): + def __call__(self, buf: ReceiveBuffer) -> Optional[Data]: data = buf.maybe_extract_at_most(999999999) if data is None: return None return Data(data=data) - def read_eof(self): + def read_eof(self) -> EndOfMessage: return EndOfMessage() -def expect_nothing(buf): +def expect_nothing(buf: ReceiveBuffer) -> None: if buf: raise LocalProtocolError("Got data when expecting EOF") return None -READERS = { +ReadersType = Dict[ + Union[Sentinel, Tuple[Sentinel, Sentinel]], + Union[Callable[..., Any], Dict[str, Callable[..., Any]]], +] + +READERS: ReadersType = { (CLIENT, IDLE): maybe_read_from_IDLE_client, (SERVER, IDLE): maybe_read_from_SEND_RESPONSE_server, (SERVER, SEND_RESPONSE): maybe_read_from_SEND_RESPONSE_server, diff --git a/h11/_receivebuffer.py b/h11/_receivebuffer.py index a3737f3..e5c4e08 100644 --- a/h11/_receivebuffer.py +++ b/h11/_receivebuffer.py @@ -1,5 +1,6 @@ import re import sys +from typing import List, Optional, Union __all__ = ["ReceiveBuffer"] @@ -44,26 +45,26 @@ class ReceiveBuffer: - def __init__(self): + def __init__(self) -> None: self._data = bytearray() self._next_line_search = 0 self._multiple_lines_search = 0 - def __iadd__(self, byteslike): + def __iadd__(self, byteslike: Union[bytes, bytearray]) -> "ReceiveBuffer": self._data += byteslike return self - def __bool__(self): + def __bool__(self) -> bool: return bool(len(self)) - def __len__(self): + def __len__(self) -> int: return len(self._data) # for @property unprocessed_data - def __bytes__(self): + def __bytes__(self) -> bytes: return bytes(self._data) - def _extract(self, count): + def _extract(self, count: int) -> bytearray: # extracting an initial slice of the data buffer and return it out = self._data[:count] del self._data[:count] @@ -73,7 +74,7 @@ def _extract(self, count): return out - def maybe_extract_at_most(self, count): + def maybe_extract_at_most(self, count: int) -> Optional[bytearray]: """ Extract a fixed number of bytes from the buffer. """ @@ -83,7 +84,7 @@ def maybe_extract_at_most(self, count): return self._extract(count) - def maybe_extract_next_line(self): + def maybe_extract_next_line(self) -> Optional[bytearray]: """ Extract the first line, if it is completed in the buffer. """ @@ -100,7 +101,7 @@ def maybe_extract_next_line(self): return self._extract(idx) - def maybe_extract_lines(self): + def maybe_extract_lines(self) -> Optional[List[bytearray]]: """ Extract everything up to the first blank line, and return a list of lines. """ @@ -143,7 +144,7 @@ def maybe_extract_lines(self): # This is especially interesting when peer is messing up with HTTPS and # sent us a TLS stream where we were expecting plain HTTP given all # versions of TLS so far start handshake with a 0x16 message type code. - def is_next_line_obviously_invalid_request_line(self): + def is_next_line_obviously_invalid_request_line(self) -> bool: try: # HTTP header line must not contain non-printable characters # and should not start with a space diff --git a/h11/_state.py b/h11/_state.py index 0f08a09..aed1a33 100644 --- a/h11/_state.py +++ b/h11/_state.py @@ -110,9 +110,10 @@ # tables. But it can't automatically read the transitions that are written # directly in Python code. So if you touch those, you need to also update the # script to keep it in sync! +from typing import cast, Dict, Optional, Set, Tuple, Type, Union from ._events import * -from ._util import LocalProtocolError, make_sentinel +from ._util import LocalProtocolError, make_sentinel, Sentinel # Everything in __all__ gets re-exported as part of the h11 public API. __all__ = [ @@ -148,7 +149,12 @@ _SWITCH_UPGRADE = make_sentinel("_SWITCH_UPGRADE") _SWITCH_CONNECT = make_sentinel("_SWITCH_CONNECT") -EVENT_TRIGGERED_TRANSITIONS = { +EventTransitionType = Dict[ + Sentinel, + Dict[Sentinel, Dict[Union[Type[Event], Tuple[Type[Event], Sentinel]], Sentinel]], +] + +EVENT_TRIGGERED_TRANSITIONS: EventTransitionType = { CLIENT: { IDLE: {Request: SEND_BODY, ConnectionClosed: CLOSED}, SEND_BODY: {Data: SEND_BODY, EndOfMessage: DONE}, @@ -198,7 +204,7 @@ class ConnectionState: - def __init__(self): + def __init__(self) -> None: # Extra bits of state that don't quite fit into the state model. # If this is False then it enables the automatic DONE -> MUST_CLOSE @@ -207,23 +213,29 @@ def __init__(self): # This is a subset of {UPGRADE, CONNECT}, containing the proposals # made by the client for switching protocols. - self.pending_switch_proposals = set() + self.pending_switch_proposals: Set[Sentinel] = set() self.states = {CLIENT: IDLE, SERVER: IDLE} - def process_error(self, role): + def process_error(self, role: Sentinel) -> None: self.states[role] = ERROR self._fire_state_triggered_transitions() - def process_keep_alive_disabled(self): + def process_keep_alive_disabled(self) -> None: self.keep_alive = False self._fire_state_triggered_transitions() - def process_client_switch_proposal(self, switch_event): + def process_client_switch_proposal(self, switch_event: Sentinel) -> None: self.pending_switch_proposals.add(switch_event) self._fire_state_triggered_transitions() - def process_event(self, role, event_type, server_switch_event=None): + def process_event( + self, + role: Sentinel, + event_type: Type[Event], + server_switch_event: Optional[Sentinel] = None, + ) -> None: + _event_type: Union[Type[Event], Tuple[Type[Event], Sentinel]] = event_type if server_switch_event is not None: assert role is SERVER if server_switch_event not in self.pending_switch_proposals: @@ -232,22 +244,27 @@ def process_event(self, role, event_type, server_switch_event=None): server_switch_event ) ) - event_type = (event_type, server_switch_event) - if server_switch_event is None and event_type is Response: + _event_type = (event_type, server_switch_event) + if server_switch_event is None and _event_type is Response: self.pending_switch_proposals = set() - self._fire_event_triggered_transitions(role, event_type) + self._fire_event_triggered_transitions(role, _event_type) # Special case: the server state does get to see Request # events. - if event_type is Request: + if _event_type is Request: assert role is CLIENT self._fire_event_triggered_transitions(SERVER, (Request, CLIENT)) self._fire_state_triggered_transitions() - def _fire_event_triggered_transitions(self, role, event_type): + def _fire_event_triggered_transitions( + self, + role: Sentinel, + event_type: Union[Type[Event], Tuple[Type[Event], Sentinel]], + ) -> None: state = self.states[role] try: new_state = EVENT_TRIGGERED_TRANSITIONS[role][state][event_type] except KeyError: + event_type = cast(Type[Event], event_type) raise LocalProtocolError( "can't handle event type {} when role={} and state={}".format( event_type.__name__, role, self.states[role] @@ -255,7 +272,7 @@ def _fire_event_triggered_transitions(self, role, event_type): ) self.states[role] = new_state - def _fire_state_triggered_transitions(self): + def _fire_state_triggered_transitions(self) -> None: # We apply these rules repeatedly until converging on a fixed point while True: start_states = dict(self.states) @@ -295,7 +312,7 @@ def _fire_state_triggered_transitions(self): # Fixed point reached return - def start_next_cycle(self): + def start_next_cycle(self) -> None: if self.states != {CLIENT: DONE, SERVER: DONE}: raise LocalProtocolError( "not in a reusable state. self.states={}".format(self.states) diff --git a/h11/_util.py b/h11/_util.py index eb1a5cd..b444f58 100644 --- a/h11/_util.py +++ b/h11/_util.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, NoReturn, Pattern, Union + __all__ = [ "ProtocolError", "LocalProtocolError", @@ -37,7 +39,7 @@ class ProtocolError(Exception): """ - def __init__(self, msg, error_status_hint=400): + def __init__(self, msg: str, error_status_hint: int = 400) -> None: if type(self) is ProtocolError: raise TypeError("tried to directly instantiate ProtocolError") Exception.__init__(self, msg) @@ -56,14 +58,14 @@ def __init__(self, msg, error_status_hint=400): # LocalProtocolError is for local errors and RemoteProtocolError is for # remote errors. class LocalProtocolError(ProtocolError): - def _reraise_as_remote_protocol_error(self): + def _reraise_as_remote_protocol_error(self) -> NoReturn: # After catching a LocalProtocolError, use this method to re-raise it # as a RemoteProtocolError. This method must be called from inside an # except: block. # # An easy way to get an equivalent RemoteProtocolError is just to # modify 'self' in place. - self.__class__ = RemoteProtocolError + self.__class__ = RemoteProtocolError # type: ignore # But the re-raising is somewhat non-trivial -- you might think that # now that we've modified the in-flight exception object, that just # doing 'raise' to re-raise it would be enough. But it turns out that @@ -80,7 +82,9 @@ class RemoteProtocolError(ProtocolError): pass -def validate(regex, data, msg="malformed data", *format_args): +def validate( + regex: Pattern[bytes], data: bytes, msg: str = "malformed data", *format_args: Any +) -> Dict[str, bytes]: match = regex.fullmatch(data) if not match: if format_args: @@ -97,23 +101,23 @@ def validate(regex, data, msg="malformed data", *format_args): # # The bonus property is useful if you want to take the return value from # next_event() and do some sort of dispatch based on type(event). -class _SentinelBase(type): - def __repr__(self): +class Sentinel(type): + def __repr__(self) -> str: return self.__name__ -def make_sentinel(name): - cls = _SentinelBase(name, (_SentinelBase,), {}) - cls.__class__ = cls +def make_sentinel(name: str) -> Sentinel: + cls = Sentinel(name, (Sentinel,), {}) + cls.__class__ = cls # type: ignore return cls # Used for methods, request targets, HTTP versions, header names, and header # values. Accepts ascii-strings, or bytes/bytearray/memoryview/..., and always # returns bytes. -def bytesify(s): +def bytesify(s: Union[bytes, bytearray, memoryview, int, str]) -> bytes: # Fast-path: - if type(s) is bytes: + if isinstance(s, bytes): return s if isinstance(s, str): s = s.encode("ascii") diff --git a/h11/_writers.py b/h11/_writers.py index cb5e8a8..a0e76df 100644 --- a/h11/_writers.py +++ b/h11/_writers.py @@ -7,14 +7,19 @@ # - a writer # - or, for body writers, a dict of framin-dependent writer factories -from ._events import Data, EndOfMessage +from typing import Any, Callable, Dict, List, Tuple, Type, Union + +from ._events import Data, EndOfMessage, Event, InformationalResponse, Request, Response +from ._headers import Headers from ._state import CLIENT, IDLE, SEND_BODY, SEND_RESPONSE, SERVER -from ._util import LocalProtocolError +from ._util import LocalProtocolError, Sentinel __all__ = ["WRITERS"] +Writer = Callable[[bytes], Any] + -def write_headers(headers, write): +def write_headers(headers: Headers, write: Writer) -> None: # "Since the Host field-value is critical information for handling a # request, a user agent SHOULD generate Host as the first header field # following the request-line." - RFC 7230 @@ -28,7 +33,7 @@ def write_headers(headers, write): write(b"\r\n") -def write_request(request, write): +def write_request(request: Request, write: Writer) -> None: if request.http_version != b"1.1": raise LocalProtocolError("I only send HTTP/1.1") write(b"%s %s HTTP/1.1\r\n" % (request.method, request.target)) @@ -36,7 +41,9 @@ def write_request(request, write): # Shared between InformationalResponse and Response -def write_any_response(response, write): +def write_any_response( + response: Union[InformationalResponse, Response], write: Writer +) -> None: if response.http_version != b"1.1": raise LocalProtocolError("I only send HTTP/1.1") status_bytes = str(response.status_code).encode("ascii") @@ -53,14 +60,20 @@ def write_any_response(response, write): class BodyWriter: - def __call__(self, event, write): - if type(event) is Data: + def __call__(self, event: Event, write: Writer) -> None: + if isinstance(event, Data): self.send_data(event.data, write) - elif type(event) is EndOfMessage: + elif isinstance(event, EndOfMessage): self.send_eom(event.headers, write) else: # pragma: no cover assert False + def send_data(self, data: bytes, write: Writer) -> None: + pass + + def send_eom(self, headers: Headers, write: Writer) -> None: + pass + # # These are all careful not to do anything to 'data' except call len(data) and @@ -69,16 +82,16 @@ def __call__(self, event, write): # sendfile(2). # class ContentLengthWriter(BodyWriter): - def __init__(self, length): + def __init__(self, length: int) -> None: self._length = length - def send_data(self, data, write): + def send_data(self, data: bytes, write: Writer) -> None: self._length -= len(data) if self._length < 0: raise LocalProtocolError("Too much data for declared Content-Length") write(data) - def send_eom(self, headers, write): + def send_eom(self, headers: Headers, write: Writer) -> None: if self._length != 0: raise LocalProtocolError("Too little data for declared Content-Length") if headers: @@ -86,7 +99,7 @@ def send_eom(self, headers, write): class ChunkedWriter(BodyWriter): - def send_data(self, data, write): + def send_data(self, data: bytes, write: Writer) -> None: # if we encoded 0-length data in the naive way, it would look like an # end-of-message. if not data: @@ -95,23 +108,32 @@ def send_data(self, data, write): write(data) write(b"\r\n") - def send_eom(self, headers, write): + def send_eom(self, headers: Headers, write: Writer) -> None: write(b"0\r\n") write_headers(headers, write) class Http10Writer(BodyWriter): - def send_data(self, data, write): + def send_data(self, data: bytes, write: Writer) -> None: write(data) - def send_eom(self, headers, write): + def send_eom(self, headers: Headers, write: Writer) -> None: if headers: raise LocalProtocolError("can't send trailers to HTTP/1.0 client") # no need to close the socket ourselves, that will be taken care of by # Connection: close machinery -WRITERS = { +WritersType = Dict[ + Union[Tuple[Sentinel, Sentinel], Sentinel], + Union[ + Dict[str, Type[BodyWriter]], + Callable[[Union[InformationalResponse, Response], Writer], None], + Callable[[Request, Writer], None], + ], +] + +WRITERS: WritersType = { (CLIENT, IDLE): write_request, (SERVER, IDLE): write_any_response, (SERVER, SEND_RESPONSE): write_any_response, diff --git a/h11/tests/helpers.py b/h11/tests/helpers.py index 5f53457..c737520 100644 --- a/h11/tests/helpers.py +++ b/h11/tests/helpers.py @@ -1,31 +1,43 @@ -from .._connection import * -from .._events import * -from .._state import * +from typing import cast, List, Union, ValuesView +from .._connection import Connection, NEED_DATA, PAUSED +from .._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) +from .._state import CLIENT, CLOSED, DONE, MUST_CLOSE, SERVER +from .._util import Sentinel -def get_all_events(conn): + +def get_all_events(conn: Connection) -> List[Event]: got_events = [] while True: event = conn.next_event() if event in (NEED_DATA, PAUSED): break + event = cast(Event, event) got_events.append(event) if type(event) is ConnectionClosed: break return got_events -def receive_and_get(conn, data): +def receive_and_get(conn: Connection, data: bytes) -> List[Event]: conn.receive_data(data) return get_all_events(conn) # Merges adjacent Data events, converts payloads to bytestrings, and removes # chunk boundaries. -def normalize_data_events(in_events): - out_events = [] +def normalize_data_events(in_events: List[Event]) -> List[Event]: + out_events: List[Event] = [] for event in in_events: - if type(event) is Data: + if isinstance(event, Data): event = Data(data=bytes(event.data), chunk_start=False, chunk_end=False) if out_events and type(out_events[-1]) is type(event) is Data: out_events[-1] = Data( @@ -43,16 +55,21 @@ def normalize_data_events(in_events): # of pushing them through two Connections with a fake network link in # between. class ConnectionPair: - def __init__(self): + def __init__(self) -> None: self.conn = {CLIENT: Connection(CLIENT), SERVER: Connection(SERVER)} self.other = {CLIENT: SERVER, SERVER: CLIENT} @property - def conns(self): + def conns(self) -> ValuesView[Connection]: return self.conn.values() # expect="match" if expect=send_events; expect=[...] to say what expected - def send(self, role, send_events, expect="match"): + def send( + self, + role: Sentinel, + send_events: Union[List[Event], Event], + expect: Union[List[Event], Event, str] = "match", + ) -> bytes: if not isinstance(send_events, list): send_events = [send_events] data = b"" @@ -74,6 +91,6 @@ def send(self, role, send_events, expect="match"): if expect == "match": expect = send_events if not isinstance(expect, list): - expect = [expect] + expect = [expect] # type: ignore assert got_events == expect return data diff --git a/h11/tests/test_against_stdlib_http.py b/h11/tests/test_against_stdlib_http.py index e6c5db4..d2ee131 100644 --- a/h11/tests/test_against_stdlib_http.py +++ b/h11/tests/test_against_stdlib_http.py @@ -5,13 +5,16 @@ import threading from contextlib import closing, contextmanager from http.server import SimpleHTTPRequestHandler +from typing import Callable, Generator from urllib.request import urlopen import h11 @contextmanager -def socket_server(handler): +def socket_server( + handler: Callable[..., socketserver.BaseRequestHandler] +) -> Generator[socketserver.TCPServer, None, None]: httpd = socketserver.TCPServer(("127.0.0.1", 0), handler) thread = threading.Thread( target=httpd.serve_forever, kwargs={"poll_interval": 0.01} @@ -30,23 +33,23 @@ def socket_server(handler): class SingleMindedRequestHandler(SimpleHTTPRequestHandler): - def translate_path(self, path): + def translate_path(self, path: str) -> str: return test_file_path -def test_h11_as_client(): +def test_h11_as_client() -> None: with socket_server(SingleMindedRequestHandler) as httpd: with closing(socket.create_connection(httpd.server_address)) as s: c = h11.Connection(h11.CLIENT) s.sendall( - c.send( + c.send( # type: ignore[arg-type] h11.Request( method="GET", target="/foo", headers=[("Host", "localhost")] ) ) ) - s.sendall(c.send(h11.EndOfMessage())) + s.sendall(c.send(h11.EndOfMessage())) # type: ignore[arg-type] data = bytearray() while True: @@ -67,7 +70,7 @@ def test_h11_as_client(): class H11RequestHandler(socketserver.BaseRequestHandler): - def handle(self): + def handle(self) -> None: with closing(self.request) as s: c = h11.Connection(h11.SERVER) request = None @@ -82,6 +85,7 @@ def handle(self): request = event if type(event) is h11.EndOfMessage: break + assert request is not None info = json.dumps( { "method": request.method.decode("ascii"), @@ -92,12 +96,12 @@ def handle(self): }, } ) - s.sendall(c.send(h11.Response(status_code=200, headers=[]))) + s.sendall(c.send(h11.Response(status_code=200, headers=[]))) # type: ignore[arg-type] s.sendall(c.send(h11.Data(data=info.encode("ascii")))) s.sendall(c.send(h11.EndOfMessage())) -def test_h11_as_server(): +def test_h11_as_server() -> None: with socket_server(H11RequestHandler) as httpd: host, port = httpd.server_address url = "http://{}:{}/some-path".format(host, port) diff --git a/h11/tests/test_connection.py b/h11/tests/test_connection.py index de175de..af84883 100644 --- a/h11/tests/test_connection.py +++ b/h11/tests/test_connection.py @@ -1,13 +1,35 @@ +from typing import Any, cast, Dict, List, Optional, Tuple + import pytest from .._connection import _body_framing, _keep_alive, Connection, NEED_DATA, PAUSED -from .._events import * -from .._state import * -from .._util import LocalProtocolError, RemoteProtocolError +from .._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) +from .._state import ( + CLIENT, + CLOSED, + DONE, + ERROR, + IDLE, + MIGHT_SWITCH_PROTOCOL, + MUST_CLOSE, + SEND_BODY, + SEND_RESPONSE, + SERVER, + SWITCHED_PROTOCOL, +) +from .._util import LocalProtocolError, RemoteProtocolError, Sentinel from .helpers import ConnectionPair, get_all_events, receive_and_get -def test__keep_alive(): +def test__keep_alive() -> None: assert _keep_alive( Request(method="GET", target="/", headers=[("Host", "Example.com")]) ) @@ -26,19 +48,19 @@ def test__keep_alive(): ) ) assert not _keep_alive( - Request(method="GET", target="/", headers=[], http_version="1.0") + Request(method="GET", target="/", headers=[], http_version="1.0") # type: ignore[arg-type] ) - assert _keep_alive(Response(status_code=200, headers=[])) + assert _keep_alive(Response(status_code=200, headers=[])) # type: ignore[arg-type] assert not _keep_alive(Response(status_code=200, headers=[("Connection", "close")])) assert not _keep_alive( Response(status_code=200, headers=[("Connection", "a, b, cLOse, foo")]) ) - assert not _keep_alive(Response(status_code=200, headers=[], http_version="1.0")) + assert not _keep_alive(Response(status_code=200, headers=[], http_version="1.0")) # type: ignore[arg-type] -def test__body_framing(): - def headers(cl, te): +def test__body_framing() -> None: + def headers(cl: Optional[int], te: bool) -> List[Tuple[str, str]]: headers = [] if cl is not None: headers.append(("Content-Length", str(cl))) @@ -46,16 +68,19 @@ def headers(cl, te): headers.append(("Transfer-Encoding", "chunked")) return headers - def resp(status_code=200, cl=None, te=False): + def resp( + status_code: int = 200, cl: Optional[int] = None, te: bool = False + ) -> Response: return Response(status_code=status_code, headers=headers(cl, te)) - def req(cl=None, te=False): + def req(cl: Optional[int] = None, te: bool = False) -> Request: h = headers(cl, te) h += [("Host", "example.com")] return Request(method="GET", target="/", headers=h) # Special cases where the headers are ignored: for kwargs in [{}, {"cl": 100}, {"te": True}, {"cl": 100, "te": True}]: + kwargs = cast(Dict[str, Any], kwargs) for meth, r in [ (b"HEAD", resp(**kwargs)), (b"GET", resp(status_code=204, **kwargs)), @@ -65,21 +90,22 @@ def req(cl=None, te=False): # Transfer-encoding for kwargs in [{"te": True}, {"cl": 100, "te": True}]: - for meth, r in [(None, req(**kwargs)), (b"GET", resp(**kwargs))]: + kwargs = cast(Dict[str, Any], kwargs) + for meth, r in [(None, req(**kwargs)), (b"GET", resp(**kwargs))]: # type: ignore assert _body_framing(meth, r) == ("chunked", ()) # Content-Length - for meth, r in [(None, req(cl=100)), (b"GET", resp(cl=100))]: + for meth, r in [(None, req(cl=100)), (b"GET", resp(cl=100))]: # type: ignore assert _body_framing(meth, r) == ("content-length", (100,)) # No headers - assert _body_framing(None, req()) == ("content-length", (0,)) + assert _body_framing(None, req()) == ("content-length", (0,)) # type: ignore assert _body_framing(b"GET", resp()) == ("http/1.0", ()) -def test_Connection_basics_and_content_length(): +def test_Connection_basics_and_content_length() -> None: with pytest.raises(ValueError): - Connection("CLIENT") + Connection("CLIENT") # type: ignore p = ConnectionPair() assert p.conn[CLIENT].our_role is CLIENT @@ -109,7 +135,7 @@ def test_Connection_basics_and_content_length(): assert p.conn[CLIENT].their_http_version is None assert p.conn[SERVER].their_http_version == b"1.1" - data = p.send(SERVER, InformationalResponse(status_code=100, headers=[])) + data = p.send(SERVER, InformationalResponse(status_code=100, headers=[])) # type: ignore[arg-type] assert data == b"HTTP/1.1 100 \r\n\r\n" data = p.send(SERVER, Response(status_code=200, headers=[("Content-Length", "11")])) @@ -144,7 +170,7 @@ def test_Connection_basics_and_content_length(): assert conn.states == {CLIENT: DONE, SERVER: DONE} -def test_chunked(): +def test_chunked() -> None: p = ConnectionPair() p.send( @@ -175,7 +201,7 @@ def test_chunked(): assert conn.states == {CLIENT: DONE, SERVER: DONE} -def test_chunk_boundaries(): +def test_chunk_boundaries() -> None: conn = Connection(our_role=SERVER) request = ( @@ -214,14 +240,14 @@ def test_chunk_boundaries(): assert conn.next_event() == EndOfMessage() -def test_client_talking_to_http10_server(): +def test_client_talking_to_http10_server() -> None: c = Connection(CLIENT) c.send(Request(method="GET", target="/", headers=[("Host", "example.com")])) c.send(EndOfMessage()) assert c.our_state is DONE # No content-length, so Http10 framing for body assert receive_and_get(c, b"HTTP/1.0 200 OK\r\n\r\n") == [ - Response(status_code=200, headers=[], http_version="1.0", reason=b"OK") + Response(status_code=200, headers=[], http_version="1.0", reason=b"OK") # type: ignore[arg-type] ] assert c.our_state is MUST_CLOSE assert receive_and_get(c, b"12345") == [Data(data=b"12345")] @@ -230,19 +256,19 @@ def test_client_talking_to_http10_server(): assert c.their_state is CLOSED -def test_server_talking_to_http10_client(): +def test_server_talking_to_http10_client() -> None: c = Connection(SERVER) # No content-length, so no body # NB: no host header assert receive_and_get(c, b"GET / HTTP/1.0\r\n\r\n") == [ - Request(method="GET", target="/", headers=[], http_version="1.0"), + Request(method="GET", target="/", headers=[], http_version="1.0"), # type: ignore[arg-type] EndOfMessage(), ] assert c.their_state is MUST_CLOSE # We automatically Connection: close back at them assert ( - c.send(Response(status_code=200, headers=[])) + c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] == b"HTTP/1.1 200 \r\nConnection: close\r\n\r\n" ) @@ -267,7 +293,7 @@ def test_server_talking_to_http10_client(): assert receive_and_get(c, b"") == [ConnectionClosed()] -def test_automatic_transfer_encoding_in_response(): +def test_automatic_transfer_encoding_in_response() -> None: # Check that in responses, the user can specify either Transfer-Encoding: # chunked or no framing at all, and in both cases we automatically select # the right option depending on whether the peer speaks HTTP/1.0 or @@ -279,6 +305,7 @@ def test_automatic_transfer_encoding_in_response(): # because if both are set then Transfer-Encoding wins [("Transfer-Encoding", "chunked"), ("Content-Length", "100")], ]: + user_headers = cast(List[Tuple[str, str]], user_headers) p = ConnectionPair() p.send( CLIENT, @@ -308,7 +335,7 @@ def test_automatic_transfer_encoding_in_response(): assert c.send(Data(data=b"12345")) == b"12345" -def test_automagic_connection_close_handling(): +def test_automagic_connection_close_handling() -> None: p = ConnectionPair() # If the user explicitly sets Connection: close, then we notice and # respect it @@ -329,7 +356,7 @@ def test_automagic_connection_close_handling(): p.send( SERVER, # no header here... - [Response(status_code=204, headers=[]), EndOfMessage()], + [Response(status_code=204, headers=[]), EndOfMessage()], # type: ignore[arg-type] # ...but oh look, it arrived anyway expect=[ Response(status_code=204, headers=[("connection", "close")]), @@ -340,8 +367,8 @@ def test_automagic_connection_close_handling(): assert conn.states == {CLIENT: MUST_CLOSE, SERVER: MUST_CLOSE} -def test_100_continue(): - def setup(): +def test_100_continue() -> None: + def setup() -> ConnectionPair: p = ConnectionPair() p.send( CLIENT, @@ -363,7 +390,7 @@ def setup(): # Disabled by 100 Continue p = setup() - p.send(SERVER, InformationalResponse(status_code=100, headers=[])) + p.send(SERVER, InformationalResponse(status_code=100, headers=[])) # type: ignore[arg-type] for conn in p.conns: assert not conn.client_is_waiting_for_100_continue assert not conn.they_are_waiting_for_100_continue @@ -385,7 +412,7 @@ def setup(): assert not conn.they_are_waiting_for_100_continue -def test_max_incomplete_event_size_countermeasure(): +def test_max_incomplete_event_size_countermeasure() -> None: # Infinitely long headers are definitely not okay c = Connection(SERVER) c.receive_data(b"GET / HTTP/1.0\r\nEndless: ") @@ -444,7 +471,7 @@ def test_max_incomplete_event_size_countermeasure(): # Even more data comes in, still no problem c.receive_data(b"X" * 1000) # We can respond and reuse to get the second pipelined request - c.send(Response(status_code=200, headers=[])) + c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] c.send(EndOfMessage()) c.start_next_cycle() assert get_all_events(c) == [ @@ -454,14 +481,14 @@ def test_max_incomplete_event_size_countermeasure(): # But once we unpause and try to read the next message, and find that it's # incomplete and the buffer is *still* way too large, then *that's* a # problem: - c.send(Response(status_code=200, headers=[])) + c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] c.send(EndOfMessage()) c.start_next_cycle() with pytest.raises(RemoteProtocolError): c.next_event() -def test_reuse_simple(): +def test_reuse_simple() -> None: p = ConnectionPair() p.send( CLIENT, @@ -494,7 +521,7 @@ def test_reuse_simple(): ) -def test_pipelining(): +def test_pipelining() -> None: # Client doesn't support pipelining, so we have to do this by hand c = Connection(SERVER) assert c.next_event() is NEED_DATA @@ -520,7 +547,7 @@ def test_pipelining(): assert c.next_event() is PAUSED - c.send(Response(status_code=200, headers=[])) + c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] c.send(EndOfMessage()) assert c.their_state is DONE assert c.our_state is DONE @@ -537,7 +564,7 @@ def test_pipelining(): EndOfMessage(), ] assert c.next_event() is PAUSED - c.send(Response(status_code=200, headers=[])) + c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] c.send(EndOfMessage()) c.start_next_cycle() @@ -547,7 +574,7 @@ def test_pipelining(): ] # Doesn't pause this time, no trailing data assert c.next_event() is NEED_DATA - c.send(Response(status_code=200, headers=[])) + c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] c.send(EndOfMessage()) # Arrival of more data triggers pause @@ -566,7 +593,7 @@ def test_pipelining(): c.receive_data(b"FDSA") -def test_protocol_switch(): +def test_protocol_switch() -> None: for (req, deny, accept) in [ ( Request( @@ -608,7 +635,7 @@ def test_protocol_switch(): ), ]: - def setup(): + def setup() -> ConnectionPair: p = ConnectionPair() p.send(CLIENT, req) # No switch-related state change stuff yet; the client has to @@ -656,7 +683,7 @@ def setup(): sc.send(EndOfMessage()) sc.start_next_cycle() assert get_all_events(sc) == [ - Request(method="GET", target="/", headers=[], http_version="1.0"), + Request(method="GET", target="/", headers=[], http_version="1.0"), # type: ignore[arg-type] EndOfMessage(), ] @@ -673,7 +700,7 @@ def setup(): p = setup() sc = p.conn[SERVER] - sc.receive_data(b"") == [] + sc.receive_data(b"") assert sc.next_event() is PAUSED sc.send(deny) assert sc.next_event() == ConnectionClosed() @@ -691,12 +718,12 @@ def setup(): p.conn[SERVER].send(Data(data=b"123")) -def test_close_simple(): +def test_close_simple() -> None: # Just immediately closing a new connection without anything having # happened yet. for (who_shot_first, who_shot_second) in [(CLIENT, SERVER), (SERVER, CLIENT)]: - def setup(): + def setup() -> ConnectionPair: p = ConnectionPair() p.send(who_shot_first, ConnectionClosed()) for conn in p.conns: @@ -732,7 +759,7 @@ def setup(): p.conn[who_shot_first].next_event() -def test_close_different_states(): +def test_close_different_states() -> None: req = [ Request(method="GET", target="/foo", headers=[("Host", "a")]), EndOfMessage(), @@ -798,7 +825,7 @@ def test_close_different_states(): # Receive several requests and then client shuts down their side of the # connection; we can respond to each -def test_pipelined_close(): +def test_pipelined_close() -> None: c = Connection(SERVER) # 2 requests then a close c.receive_data( @@ -818,7 +845,7 @@ def test_pipelined_close(): EndOfMessage(), ] assert c.states[CLIENT] is DONE - c.send(Response(status_code=200, headers=[])) + c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] c.send(EndOfMessage()) assert c.states[SERVER] is DONE c.start_next_cycle() @@ -833,21 +860,23 @@ def test_pipelined_close(): ConnectionClosed(), ] assert c.states == {CLIENT: CLOSED, SERVER: SEND_RESPONSE} - c.send(Response(status_code=200, headers=[])) + c.send(Response(status_code=200, headers=[])) # type: ignore[arg-type] c.send(EndOfMessage()) assert c.states == {CLIENT: CLOSED, SERVER: MUST_CLOSE} c.send(ConnectionClosed()) assert c.states == {CLIENT: CLOSED, SERVER: CLOSED} -def test_sendfile(): +def test_sendfile() -> None: class SendfilePlaceholder: - def __len__(self): + def __len__(self) -> int: return 10 placeholder = SendfilePlaceholder() - def setup(header, http_version): + def setup( + header: Tuple[str, str], http_version: str + ) -> Tuple[Connection, Optional[List[bytes]]]: c = Connection(SERVER) receive_and_get( c, "GET / HTTP/{}\r\nHost: a\r\n\r\n".format(http_version).encode("ascii") @@ -856,25 +885,25 @@ def setup(header, http_version): if header: headers.append(header) c.send(Response(status_code=200, headers=headers)) - return c, c.send_with_data_passthrough(Data(data=placeholder)) + return c, c.send_with_data_passthrough(Data(data=placeholder)) # type: ignore c, data = setup(("Content-Length", "10"), "1.1") - assert data == [placeholder] + assert data == [placeholder] # type: ignore # Raises an error if the connection object doesn't think we've sent # exactly 10 bytes c.send(EndOfMessage()) _, data = setup(("Transfer-Encoding", "chunked"), "1.1") - assert placeholder in data - data[data.index(placeholder)] = b"x" * 10 - assert b"".join(data) == b"a\r\nxxxxxxxxxx\r\n" + assert placeholder in data # type: ignore + data[data.index(placeholder)] = b"x" * 10 # type: ignore + assert b"".join(data) == b"a\r\nxxxxxxxxxx\r\n" # type: ignore - c, data = setup(None, "1.0") - assert data == [placeholder] + c, data = setup(None, "1.0") # type: ignore + assert data == [placeholder] # type: ignore assert c.our_state is SEND_BODY -def test_errors(): +def test_errors() -> None: # After a receive error, you can't receive for role in [CLIENT, SERVER]: c = Connection(our_role=role) @@ -890,14 +919,14 @@ def test_errors(): # But we can still yell at the client for sending us gibberish if role is SERVER: assert ( - c.send(Response(status_code=400, headers=[])) + c.send(Response(status_code=400, headers=[])) # type: ignore[arg-type] == b"HTTP/1.1 400 \r\nConnection: close\r\n\r\n" ) # After an error sending, you can no longer send # (This is especially important for things like content-length errors, # where there's complex internal state being modified) - def conn(role): + def conn(role: Sentinel) -> Connection: c = Connection(our_role=role) if role is SERVER: # Put it into the state where it *could* send a response... @@ -917,8 +946,8 @@ def conn(role): http_version="1.0", ) elif role is SERVER: - good = Response(status_code=200, headers=[]) - bad = Response(status_code=200, headers=[], http_version="1.0") + good = Response(status_code=200, headers=[]) # type: ignore[arg-type,assignment] + bad = Response(status_code=200, headers=[], http_version="1.0") # type: ignore[arg-type,assignment] # Make sure 'good' actually is good c = conn(role) c.send(good) @@ -944,14 +973,14 @@ def conn(role): assert c.their_state is not ERROR -def test_idle_receive_nothing(): +def test_idle_receive_nothing() -> None: # At one point this incorrectly raised an error for role in [CLIENT, SERVER]: c = Connection(role) assert c.next_event() is NEED_DATA -def test_connection_drop(): +def test_connection_drop() -> None: c = Connection(SERVER) c.receive_data(b"GET /") assert c.next_event() is NEED_DATA @@ -960,7 +989,7 @@ def test_connection_drop(): c.next_event() -def test_408_request_timeout(): +def test_408_request_timeout() -> None: # Should be able to send this spontaneously as a server without seeing # anything from client p = ConnectionPair() @@ -968,7 +997,7 @@ def test_408_request_timeout(): # This used to raise IndexError -def test_empty_request(): +def test_empty_request() -> None: c = Connection(SERVER) c.receive_data(b"\r\n") with pytest.raises(RemoteProtocolError): @@ -976,7 +1005,7 @@ def test_empty_request(): # This used to raise IndexError -def test_empty_response(): +def test_empty_response() -> None: c = Connection(CLIENT) c.send(Request(method="GET", target="/", headers=[("Host", "a")])) c.receive_data(b"\r\n") @@ -992,7 +1021,7 @@ def test_empty_response(): b"\x16\x03\x01\x00\xa5", # Typical start of a TLS Client Hello ], ) -def test_early_detection_of_invalid_request(data): +def test_early_detection_of_invalid_request(data: bytes) -> None: c = Connection(SERVER) # Early detection should occur before even receiving a `\r\n` c.receive_data(data) @@ -1008,7 +1037,7 @@ def test_early_detection_of_invalid_request(data): b"\x16\x03\x03\x00\x31", # Typical start of a TLS Server Hello ], ) -def test_early_detection_of_invalid_response(data): +def test_early_detection_of_invalid_response(data: bytes) -> None: c = Connection(CLIENT) # Early detection should occur before even receiving a `\r\n` c.receive_data(data) @@ -1020,8 +1049,8 @@ def test_early_detection_of_invalid_response(data): # The correct way to handle HEAD is to put whatever headers we *would* have # put if it were a GET -- even though we know that for HEAD, those headers # will be ignored. -def test_HEAD_framing_headers(): - def setup(method, http_version): +def test_HEAD_framing_headers() -> None: + def setup(method: bytes, http_version: bytes) -> Connection: c = Connection(SERVER) c.receive_data( method + b" / HTTP/" + http_version + b"\r\n" + b"Host: example.com\r\n\r\n" @@ -1034,14 +1063,14 @@ def setup(method, http_version): # No Content-Length, HTTP/1.1 peer, should use chunked c = setup(method, b"1.1") assert ( - c.send(Response(status_code=200, headers=[])) == b"HTTP/1.1 200 \r\n" + c.send(Response(status_code=200, headers=[])) == b"HTTP/1.1 200 \r\n" # type: ignore[arg-type] b"Transfer-Encoding: chunked\r\n\r\n" ) # No Content-Length, HTTP/1.0 peer, frame with connection: close c = setup(method, b"1.0") assert ( - c.send(Response(status_code=200, headers=[])) == b"HTTP/1.1 200 \r\n" + c.send(Response(status_code=200, headers=[])) == b"HTTP/1.1 200 \r\n" # type: ignore[arg-type] b"Connection: close\r\n\r\n" ) @@ -1062,7 +1091,7 @@ def setup(method, http_version): ) -def test_special_exceptions_for_lost_connection_in_message_body(): +def test_special_exceptions_for_lost_connection_in_message_body() -> None: c = Connection(SERVER) c.receive_data( b"POST / HTTP/1.1\r\n" b"Host: example.com\r\n" b"Content-Length: 100\r\n\r\n" @@ -1086,7 +1115,7 @@ def test_special_exceptions_for_lost_connection_in_message_body(): assert type(c.next_event()) is Request assert c.next_event() is NEED_DATA c.receive_data(b"8\r\n012345") - assert c.next_event().data == b"012345" + assert c.next_event().data == b"012345" # type: ignore c.receive_data(b"") with pytest.raises(RemoteProtocolError) as excinfo: c.next_event() diff --git a/h11/tests/test_events.py b/h11/tests/test_events.py index 4748c4b..64b6808 100644 --- a/h11/tests/test_events.py +++ b/h11/tests/test_events.py @@ -3,11 +3,19 @@ import pytest from .. import _events -from .._events import * +from .._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) from .._util import LocalProtocolError -def test_events(): +def test_events() -> None: with pytest.raises(LocalProtocolError): # Missing Host: req = Request( @@ -68,9 +76,9 @@ def test_events(): ) # Request target is validated - for bad_char in b"\x00\x20\x7f\xee": + for bad_byte in b"\x00\x20\x7f\xee": target = bytearray(b"/") - target.append(bad_char) + target.append(bad_byte) with pytest.raises(LocalProtocolError): Request( method="GET", target=target, headers=[("Host", "a")], http_version="1.1" @@ -84,19 +92,19 @@ def test_events(): with pytest.raises(LocalProtocolError): InformationalResponse(status_code=200, headers=[("Host", "a")]) - resp = Response(status_code=204, headers=[], http_version="1.0") + resp = Response(status_code=204, headers=[], http_version="1.0") # type: ignore[arg-type] assert resp.status_code == 204 assert resp.headers == [] assert resp.http_version == b"1.0" with pytest.raises(LocalProtocolError): - resp = Response(status_code=100, headers=[], http_version="1.0") + resp = Response(status_code=100, headers=[], http_version="1.0") # type: ignore[arg-type] with pytest.raises(LocalProtocolError): - Response(status_code="100", headers=[], http_version="1.0") + Response(status_code="100", headers=[], http_version="1.0") # type: ignore[arg-type] with pytest.raises(LocalProtocolError): - InformationalResponse(status_code=b"100", headers=[], http_version="1.0") + InformationalResponse(status_code=b"100", headers=[], http_version="1.0") # type: ignore[arg-type] d = Data(data=b"asdf") assert d.data == b"asdf" @@ -108,16 +116,16 @@ def test_events(): assert repr(cc) == "ConnectionClosed()" -def test_intenum_status_code(): +def test_intenum_status_code() -> None: # https://github.com/python-hyper/h11/issues/72 - r = Response(status_code=HTTPStatus.OK, headers=[], http_version="1.0") + r = Response(status_code=HTTPStatus.OK, headers=[], http_version="1.0") # type: ignore[arg-type] assert r.status_code == HTTPStatus.OK assert type(r.status_code) is not type(HTTPStatus.OK) assert type(r.status_code) is int -def test_header_casing(): +def test_header_casing() -> None: r = Request( method="GET", target="/", diff --git a/h11/tests/test_headers.py b/h11/tests/test_headers.py index ff3dc8d..ba53d08 100644 --- a/h11/tests/test_headers.py +++ b/h11/tests/test_headers.py @@ -1,9 +1,17 @@ import pytest -from .._headers import * - - -def test_normalize_and_validate(): +from .._events import Request +from .._headers import ( + get_comma_header, + has_expect_100_continue, + Headers, + normalize_and_validate, + set_comma_header, +) +from .._util import LocalProtocolError + + +def test_normalize_and_validate() -> None: assert normalize_and_validate([("foo", "bar")]) == [(b"foo", b"bar")] assert normalize_and_validate([(b"foo", b"bar")]) == [(b"foo", b"bar")] @@ -84,7 +92,7 @@ def test_normalize_and_validate(): assert excinfo.value.error_status_hint == 501 # Not Implemented -def test_get_set_comma_header(): +def test_get_set_comma_header() -> None: headers = normalize_and_validate( [ ("Connection", "close"), @@ -95,10 +103,10 @@ def test_get_set_comma_header(): assert get_comma_header(headers, b"connection") == [b"close", b"foo", b"bar"] - headers = set_comma_header(headers, b"newthing", ["a", "b"]) + headers = set_comma_header(headers, b"newthing", ["a", "b"]) # type: ignore with pytest.raises(LocalProtocolError): - set_comma_header(headers, b"newthing", [" a", "b"]) + set_comma_header(headers, b"newthing", [" a", "b"]) # type: ignore assert headers == [ (b"connection", b"close"), @@ -108,7 +116,7 @@ def test_get_set_comma_header(): (b"newthing", b"b"), ] - headers = set_comma_header(headers, b"whatever", ["different thing"]) + headers = set_comma_header(headers, b"whatever", ["different thing"]) # type: ignore assert headers == [ (b"connection", b"close"), @@ -119,9 +127,7 @@ def test_get_set_comma_header(): ] -def test_has_100_continue(): - from .._events import Request - +def test_has_100_continue() -> None: assert has_expect_100_continue( Request( method="GET", diff --git a/h11/tests/test_helpers.py b/h11/tests/test_helpers.py index 1477947..c329c76 100644 --- a/h11/tests/test_helpers.py +++ b/h11/tests/test_helpers.py @@ -1,12 +1,21 @@ -from .helpers import * +from .._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) +from .helpers import normalize_data_events -def test_normalize_data_events(): +def test_normalize_data_events() -> None: assert normalize_data_events( [ Data(data=bytearray(b"1")), Data(data=b"2"), - Response(status_code=200, headers=[]), + Response(status_code=200, headers=[]), # type: ignore[arg-type] Data(data=b"3"), Data(data=b"4"), EndOfMessage(), @@ -16,7 +25,7 @@ def test_normalize_data_events(): ] ) == [ Data(data=b"12"), - Response(status_code=200, headers=[]), + Response(status_code=200, headers=[]), # type: ignore[arg-type] Data(data=b"34"), EndOfMessage(), Data(data=b"567"), diff --git a/h11/tests/test_io.py b/h11/tests/test_io.py index 459a627..e9c01bd 100644 --- a/h11/tests/test_io.py +++ b/h11/tests/test_io.py @@ -1,6 +1,16 @@ +from typing import Any, Callable, Generator, List + import pytest -from .._events import * +from .._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) from .._headers import Headers, normalize_and_validate from .._readers import ( _obsolete_line_fold, @@ -10,7 +20,18 @@ READERS, ) from .._receivebuffer import ReceiveBuffer -from .._state import * +from .._state import ( + CLIENT, + CLOSED, + DONE, + IDLE, + MIGHT_SWITCH_PROTOCOL, + MUST_CLOSE, + SEND_BODY, + SEND_RESPONSE, + SERVER, + SWITCHED_PROTOCOL, +) from .._util import LocalProtocolError from .._writers import ( ChunkedWriter, @@ -40,7 +61,7 @@ ), ( (SERVER, SEND_RESPONSE), - Response(status_code=200, headers=[], reason=b"OK"), + Response(status_code=200, headers=[], reason=b"OK"), # type: ignore[arg-type] b"HTTP/1.1 200 OK\r\n\r\n", ), ( @@ -52,36 +73,35 @@ ), ( (SERVER, SEND_RESPONSE), - InformationalResponse(status_code=101, headers=[], reason=b"Upgrade"), + InformationalResponse(status_code=101, headers=[], reason=b"Upgrade"), # type: ignore[arg-type] b"HTTP/1.1 101 Upgrade\r\n\r\n", ), ] -def dowrite(writer, obj): - got_list = [] +def dowrite(writer: Callable[..., None], obj: Any) -> bytes: + got_list: List[bytes] = [] writer(obj, got_list.append) return b"".join(got_list) -def tw(writer, obj, expected): +def tw(writer: Any, obj: Any, expected: Any) -> None: got = dowrite(writer, obj) assert got == expected -def makebuf(data): +def makebuf(data: bytes) -> ReceiveBuffer: buf = ReceiveBuffer() buf += data return buf -def tr(reader, data, expected): - def check(got): +def tr(reader: Any, data: bytes, expected: Any) -> None: + def check(got: Any) -> None: assert got == expected # Headers should always be returned as bytes, not e.g. bytearray # https://github.com/python-hyper/wsproto/pull/54#issuecomment-377709478 for name, value in getattr(got, "headers", []): - print(name, value) assert type(name) is bytes assert type(value) is bytes @@ -104,17 +124,17 @@ def check(got): assert bytes(buf) == b"trailing" -def test_writers_simple(): +def test_writers_simple() -> None: for ((role, state), event, binary) in SIMPLE_CASES: tw(WRITERS[role, state], event, binary) -def test_readers_simple(): +def test_readers_simple() -> None: for ((role, state), event, binary) in SIMPLE_CASES: tr(READERS[role, state], binary, event) -def test_writers_unusual(): +def test_writers_unusual() -> None: # Simple test of the write_headers utility routine tw( write_headers, @@ -145,7 +165,7 @@ def test_writers_unusual(): ) -def test_readers_unusual(): +def test_readers_unusual() -> None: # Reading HTTP/1.0 tr( READERS[CLIENT, IDLE], @@ -162,7 +182,7 @@ def test_readers_unusual(): tr( READERS[CLIENT, IDLE], b"HEAD /foo HTTP/1.0\r\n\r\n", - Request(method="HEAD", target="/foo", headers=[], http_version="1.0"), + Request(method="HEAD", target="/foo", headers=[], http_version="1.0"), # type: ignore[arg-type] ) tr( @@ -305,7 +325,7 @@ def test_readers_unusual(): tr(READERS[CLIENT, IDLE], b"HEAD /foo HTTP/1.1\r\n" b": line\r\n\r\n", None) -def test__obsolete_line_fold_bytes(): +def test__obsolete_line_fold_bytes() -> None: # _obsolete_line_fold has a defensive cast to bytearray, which is # necessary to protect against O(n^2) behavior in case anyone ever passes # in regular bytestrings... but right now we never pass in regular @@ -318,7 +338,9 @@ def test__obsolete_line_fold_bytes(): ] -def _run_reader_iter(reader, buf, do_eof): +def _run_reader_iter( + reader: Any, buf: bytes, do_eof: bool +) -> Generator[Any, None, None]: while True: event = reader(buf) if event is None: @@ -333,12 +355,12 @@ def _run_reader_iter(reader, buf, do_eof): yield reader.read_eof() -def _run_reader(*args): +def _run_reader(*args: Any) -> List[Event]: events = list(_run_reader_iter(*args)) return normalize_data_events(events) -def t_body_reader(thunk, data, expected, do_eof=False): +def t_body_reader(thunk: Any, data: bytes, expected: Any, do_eof: bool = False) -> None: # Simple: consume whole thing print("Test 1") buf = makebuf(data) @@ -361,7 +383,7 @@ def t_body_reader(thunk, data, expected, do_eof=False): assert _run_reader(thunk(), buf, False) == expected -def test_ContentLengthReader(): +def test_ContentLengthReader() -> None: t_body_reader(lambda: ContentLengthReader(0), b"", [EndOfMessage()]) t_body_reader( @@ -371,7 +393,7 @@ def test_ContentLengthReader(): ) -def test_Http10Reader(): +def test_Http10Reader() -> None: t_body_reader(Http10Reader, b"", [EndOfMessage()], do_eof=True) t_body_reader(Http10Reader, b"asdf", [Data(data=b"asdf")], do_eof=False) t_body_reader( @@ -379,7 +401,7 @@ def test_Http10Reader(): ) -def test_ChunkedReader(): +def test_ChunkedReader() -> None: t_body_reader(ChunkedReader, b"0\r\n\r\n", [EndOfMessage()]) t_body_reader( @@ -434,7 +456,7 @@ def test_ChunkedReader(): ) -def test_ContentLengthWriter(): +def test_ContentLengthWriter() -> None: w = ContentLengthWriter(5) assert dowrite(w, Data(data=b"123")) == b"123" assert dowrite(w, Data(data=b"45")) == b"45" @@ -461,7 +483,7 @@ def test_ContentLengthWriter(): dowrite(w, EndOfMessage(headers=[("Etag", "asdf")])) -def test_ChunkedWriter(): +def test_ChunkedWriter() -> None: w = ChunkedWriter() assert dowrite(w, Data(data=b"aaa")) == b"3\r\naaa\r\n" assert dowrite(w, Data(data=b"a" * 20)) == b"14\r\n" + b"a" * 20 + b"\r\n" @@ -476,7 +498,7 @@ def test_ChunkedWriter(): ) -def test_Http10Writer(): +def test_Http10Writer() -> None: w = Http10Writer() assert dowrite(w, Data(data=b"1234")) == b"1234" assert dowrite(w, EndOfMessage()) == b"" @@ -485,12 +507,12 @@ def test_Http10Writer(): dowrite(w, EndOfMessage(headers=[("Etag", "asdf")])) -def test_reject_garbage_after_request_line(): +def test_reject_garbage_after_request_line() -> None: with pytest.raises(LocalProtocolError): tr(READERS[SERVER, SEND_RESPONSE], b"HTTP/1.0 200 OK\x00xxxx\r\n\r\n", None) -def test_reject_garbage_after_response_line(): +def test_reject_garbage_after_response_line() -> None: with pytest.raises(LocalProtocolError): tr( READERS[CLIENT, IDLE], @@ -499,7 +521,7 @@ def test_reject_garbage_after_response_line(): ) -def test_reject_garbage_in_header_line(): +def test_reject_garbage_in_header_line() -> None: with pytest.raises(LocalProtocolError): tr( READERS[CLIENT, IDLE], @@ -508,7 +530,7 @@ def test_reject_garbage_in_header_line(): ) -def test_reject_non_vchar_in_path(): +def test_reject_non_vchar_in_path() -> None: for bad_char in b"\x00\x20\x7f\xee": message = bytearray(b"HEAD /") message.append(bad_char) @@ -518,7 +540,7 @@ def test_reject_non_vchar_in_path(): # https://github.com/python-hyper/h11/issues/57 -def test_allow_some_garbage_in_cookies(): +def test_allow_some_garbage_in_cookies() -> None: tr( READERS[CLIENT, IDLE], b"HEAD /foo HTTP/1.1\r\n" @@ -536,7 +558,7 @@ def test_allow_some_garbage_in_cookies(): ) -def test_host_comes_first(): +def test_host_comes_first() -> None: tw( write_headers, normalize_and_validate([("foo", "bar"), ("Host", "example.com")]), diff --git a/h11/tests/test_receivebuffer.py b/h11/tests/test_receivebuffer.py index 3a61f9d..21a3870 100644 --- a/h11/tests/test_receivebuffer.py +++ b/h11/tests/test_receivebuffer.py @@ -1,11 +1,12 @@ import re +from typing import Tuple import pytest from .._receivebuffer import ReceiveBuffer -def test_receivebuffer(): +def test_receivebuffer() -> None: b = ReceiveBuffer() assert not b assert len(b) == 0 @@ -118,7 +119,7 @@ def test_receivebuffer(): ), ], ) -def test_receivebuffer_for_invalid_delimiter(data): +def test_receivebuffer_for_invalid_delimiter(data: Tuple[bytes]) -> None: b = ReceiveBuffer() for line in data: diff --git a/h11/tests/test_state.py b/h11/tests/test_state.py index efe83f0..bc974e6 100644 --- a/h11/tests/test_state.py +++ b/h11/tests/test_state.py @@ -1,12 +1,33 @@ import pytest -from .._events import * -from .._state import * -from .._state import _SWITCH_CONNECT, _SWITCH_UPGRADE, ConnectionState +from .._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) +from .._state import ( + _SWITCH_CONNECT, + _SWITCH_UPGRADE, + CLIENT, + CLOSED, + ConnectionState, + DONE, + IDLE, + MIGHT_SWITCH_PROTOCOL, + MUST_CLOSE, + SEND_BODY, + SEND_RESPONSE, + SERVER, + SWITCHED_PROTOCOL, +) from .._util import LocalProtocolError -def test_ConnectionState(): +def test_ConnectionState() -> None: cs = ConnectionState() # Basic event-triggered transitions @@ -38,7 +59,7 @@ def test_ConnectionState(): assert cs.states == {CLIENT: MUST_CLOSE, SERVER: CLOSED} -def test_ConnectionState_keep_alive(): +def test_ConnectionState_keep_alive() -> None: # keep_alive = False cs = ConnectionState() cs.process_event(CLIENT, Request) @@ -51,7 +72,7 @@ def test_ConnectionState_keep_alive(): assert cs.states == {CLIENT: MUST_CLOSE, SERVER: MUST_CLOSE} -def test_ConnectionState_keep_alive_in_DONE(): +def test_ConnectionState_keep_alive_in_DONE() -> None: # Check that if keep_alive is disabled when the CLIENT is already in DONE, # then this is sufficient to immediately trigger the DONE -> MUST_CLOSE # transition @@ -63,7 +84,7 @@ def test_ConnectionState_keep_alive_in_DONE(): assert cs.states[CLIENT] is MUST_CLOSE -def test_ConnectionState_switch_denied(): +def test_ConnectionState_switch_denied() -> None: for switch_type in (_SWITCH_CONNECT, _SWITCH_UPGRADE): for deny_early in (True, False): cs = ConnectionState() @@ -107,7 +128,7 @@ def test_ConnectionState_switch_denied(): } -def test_ConnectionState_protocol_switch_accepted(): +def test_ConnectionState_protocol_switch_accepted() -> None: for switch_event in [_SWITCH_UPGRADE, _SWITCH_CONNECT]: cs = ConnectionState() cs.process_client_switch_proposal(switch_event) @@ -125,7 +146,7 @@ def test_ConnectionState_protocol_switch_accepted(): assert cs.states == {CLIENT: SWITCHED_PROTOCOL, SERVER: SWITCHED_PROTOCOL} -def test_ConnectionState_double_protocol_switch(): +def test_ConnectionState_double_protocol_switch() -> None: # CONNECT + Upgrade is legal! Very silly, but legal. So we support # it. Because sometimes doing the silly thing is easier than not. for server_switch in [None, _SWITCH_UPGRADE, _SWITCH_CONNECT]: @@ -144,7 +165,7 @@ def test_ConnectionState_double_protocol_switch(): assert cs.states == {CLIENT: SWITCHED_PROTOCOL, SERVER: SWITCHED_PROTOCOL} -def test_ConnectionState_inconsistent_protocol_switch(): +def test_ConnectionState_inconsistent_protocol_switch() -> None: for client_switches, server_switch in [ ([], _SWITCH_CONNECT), ([], _SWITCH_UPGRADE), @@ -152,14 +173,14 @@ def test_ConnectionState_inconsistent_protocol_switch(): ([_SWITCH_CONNECT], _SWITCH_UPGRADE), ]: cs = ConnectionState() - for client_switch in client_switches: + for client_switch in client_switches: # type: ignore[attr-defined] cs.process_client_switch_proposal(client_switch) cs.process_event(CLIENT, Request) with pytest.raises(LocalProtocolError): cs.process_event(SERVER, Response, server_switch) -def test_ConnectionState_keepalive_protocol_switch_interaction(): +def test_ConnectionState_keepalive_protocol_switch_interaction() -> None: # keep_alive=False + pending_switch_proposals cs = ConnectionState() cs.process_client_switch_proposal(_SWITCH_UPGRADE) @@ -177,7 +198,7 @@ def test_ConnectionState_keepalive_protocol_switch_interaction(): assert cs.states == {CLIENT: MUST_CLOSE, SERVER: SEND_BODY} -def test_ConnectionState_reuse(): +def test_ConnectionState_reuse() -> None: cs = ConnectionState() with pytest.raises(LocalProtocolError): @@ -242,7 +263,7 @@ def test_ConnectionState_reuse(): assert cs.states == {CLIENT: IDLE, SERVER: IDLE} -def test_server_request_is_illegal(): +def test_server_request_is_illegal() -> None: # There used to be a bug in how we handled the Request special case that # made this allowed... cs = ConnectionState() diff --git a/h11/tests/test_util.py b/h11/tests/test_util.py index d851bdc..cc086a6 100644 --- a/h11/tests/test_util.py +++ b/h11/tests/test_util.py @@ -1,18 +1,26 @@ import re import sys import traceback +from typing import NoReturn import pytest -from .._util import * +from .._util import ( + bytesify, + LocalProtocolError, + make_sentinel, + ProtocolError, + RemoteProtocolError, + validate, +) -def test_ProtocolError(): +def test_ProtocolError() -> None: with pytest.raises(TypeError): ProtocolError("abstract base class") -def test_LocalProtocolError(): +def test_LocalProtocolError() -> None: try: raise LocalProtocolError("foo") except LocalProtocolError as e: @@ -25,7 +33,7 @@ def test_LocalProtocolError(): assert str(e) == "foo" assert e.error_status_hint == 418 - def thunk(): + def thunk() -> NoReturn: raise LocalProtocolError("a", error_status_hint=420) try: @@ -42,7 +50,7 @@ def thunk(): assert new_traceback.endswith(orig_traceback) -def test_validate(): +def test_validate() -> None: my_re = re.compile(br"(?P[0-9]+)\.(?P[0-9]+)") with pytest.raises(LocalProtocolError): validate(my_re, b"0.") @@ -57,7 +65,7 @@ def test_validate(): validate(my_re, b"0.1\n") -def test_validate_formatting(): +def test_validate_formatting() -> None: my_re = re.compile(br"foo") with pytest.raises(LocalProtocolError) as excinfo: @@ -73,7 +81,7 @@ def test_validate_formatting(): assert "oops 10 xx" in str(excinfo.value) -def test_make_sentinel(): +def test_make_sentinel() -> None: S = make_sentinel("S") assert repr(S) == "S" assert S == S @@ -87,7 +95,7 @@ def test_make_sentinel(): assert type(S) is not type(S2) -def test_bytesify(): +def test_bytesify() -> None: assert bytesify(b"123") == b"123" assert bytesify(bytearray(b"123")) == b"123" assert bytesify("123") == b"123" diff --git a/setup.cfg b/setup.cfg index 0bd1262..95cb0f4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,3 +8,9 @@ line_length=88 multi_line_output=3 no_lines_before=LOCALFOLDER order_by_type=False + +[mypy] +strict = true +warn_unused_configs = true +warn_unused_ignores = true +show_error_codes = true diff --git a/tox.ini b/tox.ini index d4a2272..f333194 100644 --- a/tox.ini +++ b/tox.ini @@ -1,11 +1,11 @@ [tox] -envlist = format, py36, py37, py38, py39, pypy3 +envlist = format, py36, py37, py38, py39, pypy3, mypy [gh-actions] python = 3.6: py36 3.7: py37 - 3.8: py38, format + 3.8: py38, format, mypy 3.9: py39 pypy3: pypy3 @@ -21,3 +21,11 @@ deps = commands = black --check --diff h11/ bench/ examples/ fuzz/ isort --check --diff h11 bench examples fuzz + +[testenv:mypy] +basepython = python3.8 +deps = + mypy + pytest +commands = + mypy h11/